summaryrefslogtreecommitdiff
path: root/internal/server/handlers
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
committerPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
commitf4207a55f71bfbcfdc532d5cdd3befaa3474a157 (patch)
treeea5e4a2d2a67035f645bdee496ae55a52034178a /internal/server/handlers
parentd80d6070557e3a800e3a54967af9eced518f116b (diff)
parent739205206d63bf42f4e843b39d04d4c8cd8207c3 (diff)
merge develop
Diffstat (limited to 'internal/server/handlers')
-rw-r--r--internal/server/handlers/basehandler.go320
-rw-r--r--internal/server/handlers/controlhandler.go98
-rw-r--r--internal/server/handlers/healthhandler.go58
-rw-r--r--internal/server/handlers/mapcommand.go7
-rw-r--r--internal/server/handlers/readcommand.go83
-rw-r--r--internal/server/handlers/serverhandler.go412
6 files changed, 463 insertions, 515 deletions
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
new file mode 100644
index 0000000..6d10d17
--- /dev/null
+++ b/internal/server/handlers/basehandler.go
@@ -0,0 +1,320 @@
+package handlers
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/lcontext"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/protocol"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type handleCommandCb func(context.Context, lcontext.LContext, int, []string, string)
+
+type baseHandler struct {
+ done *internal.Done
+ handleCommandCb handleCommandCb
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+
+ // Some global options + sync primitives required.
+ once sync.Once
+ mutex sync.Mutex
+ quiet bool
+ spartan bool
+ serverless bool
+}
+
+// Shutdown the handler.
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ defer h.readBuf.Reset()
+
+ select {
+ case message := <-h.serverMessages:
+ if len(message) > 0 && message[0] == '.' {
+ // Handle hidden message (don't display to the user)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ return
+ }
+
+ if h.serverless {
+ return
+ }
+
+ // Handle normal server message (display to the user)
+ h.readBuf.WriteString("SERVER")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case message := <-h.maprMessages:
+ // Send mapreduce-aggregated data as a message.
+ h.readBuf.WriteString("AGGREGATE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case line := <-h.lines:
+ if !h.spartan {
+ h.readBuf.WriteString("REMOTE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(line.SourceID)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ }
+ h.readBuf.WriteString(line.Content.String())
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ pool.RecycleBytesBuffer(line.Content)
+
+ case <-time.After(time.Second):
+ // Once in a while check whether we are done.
+ select {
+ case <-h.done.Done():
+ err = io.EOF
+ return
+ default:
+ }
+ }
+ return
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ switch b {
+ case ';':
+ h.handleCommand(string(h.writeBuf.Bytes()))
+ h.writeBuf.Reset()
+ default:
+ h.writeBuf.WriteByte(b)
+ }
+ }
+ n = len(p)
+ return
+}
+
+func (h *baseHandler) handleCommand(commandStr string) {
+ dlog.Server.Debug(h.user, commandStr)
+
+ args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
+ return
+ }
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
+ parts := strings.Split(args[0], ":")
+ commandName := parts[0]
+
+ // Either no options or empty options provided.
+ if len(parts) == 1 || len(parts[1]) == 0 {
+ h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName)
+ return
+ }
+
+ options, ltx, err := config.DeserializeOptions(parts[1:])
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+ h.handleOptions(options)
+ h.handleCommandCb(ctx, ltx, argc, args, commandName)
+}
+
+func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
+ argc := len(args)
+ var add string
+
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, add, errors.New("unable to determine protocol version")
+ }
+
+ if args[1] != protocol.ProtocolCompat {
+ clientCompat, _ := strconv.Atoi(args[1])
+ serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
+ if clientCompat <= 3 {
+ // Protocol version 3 or lower expect a newline as message separator
+ // One day (after 2 major versions) this exception may be removed!
+ add = "\n"
+ }
+
+ toUpdate := "client"
+ if clientCompat > serverCompat {
+ toUpdate = "server"
+ }
+ err := fmt.Errorf("the DTail server protocol version '%s' does not match "+
+ "client protocol version '%s', please update DTail %s",
+ protocol.ProtocolCompat, args[1], toUpdate)
+ return args, argc, add, err
+ }
+
+ return args[2:], argc - 2, add, nil
+}
+
+func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("unable to decode client message, DTail server and client " +
+ "versions may not be compatible")
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
+
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ dlog.Server.Trace(h.user, "Base64 decoded received command",
+ decodedStr, argc, args)
+
+ return args, argc, nil
+}
+
+func (h *baseHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ if !h.quiet {
+ h.send(h.serverMessages, dlog.Server.Warn(h.user,
+ "Unable to parse command", args, argc))
+ }
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ select {
+ case <-h.ackCloseReceived:
+ default:
+ close(h.ackCloseReceived)
+ }
+ }
+}
+
+func (h *baseHandler) handleOptions(options map[string]string) {
+ // We have to make sure that this block is executed only once.
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ // We can read the options only once, will cause a data race otherwise if
+ // changed multiple times for multiple incoming commands.
+ h.once.Do(func() {
+ if quiet, _ := options["quiet"]; quiet == "true" {
+ dlog.Server.Debug(h.user, "Enabling quiet mode")
+ h.quiet = true
+ }
+ if spartan, _ := options["spartan"]; spartan == "true" {
+ dlog.Server.Debug(h.user, "Enabling spartan mode")
+ h.spartan = true
+ }
+ if serverless, _ := options["serverless"]; serverless == "true" {
+ dlog.Server.Debug(h.user, "Enabling serverless mode")
+ h.serverless = true
+ }
+ })
+}
+
+func (h *baseHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.done.Done():
+ }
+}
+
+func (h *baseHandler) flush() {
+ dlog.Server.Trace(h.user, "flush()")
+ numUnsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
+ }
+ for i := 0; i < 10; i++ {
+ if numUnsentMessages() == 0 {
+ dlog.Server.Debug(h.user, "ALL lines sent", fmt.Sprintf("%p", h))
+ return
+ }
+ dlog.Server.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Millisecond * 10)
+ }
+ dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages())
+}
+
+func (h *baseHandler) shutdown() {
+ dlog.Server.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessages <- ".syn close connection":
+ case <-h.done.Done():
+ }
+ }()
+
+ select {
+ case <-h.ackCloseReceived:
+ case <-time.After(time.Second * 5):
+ dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.done.Done():
+ }
+ h.done.Shutdown()
+}
+
+func (h *baseHandler) incrementActiveCommands() {
+ atomic.AddInt32(&h.activeCommands, 1)
+}
+
+func (h *baseHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
+}
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
deleted file mode 100644
index 1e17c78..0000000
--- a/internal/server/handlers/controlhandler.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package handlers
-
-import (
- "fmt"
- "io"
- "os"
- "strings"
-
- "github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/io/logger"
- user "github.com/mimecast/dtail/internal/user/server"
-)
-
-// ControlHandler is used for control functions and health monitoring.
-type ControlHandler struct {
- done *internal.Done
- hostname string
- payload []byte
- serverMessages chan string
- user *user.User
-}
-
-// NewControlHandler returns a new control handler.
-func NewControlHandler(user *user.User) *ControlHandler {
- logger.Debug(user, "Creating control handler")
-
- h := ControlHandler{
- done: internal.NewDone(),
- serverMessages: make(chan string, 10),
- user: user,
- }
-
- fqdn, err := os.Hostname()
- if err != nil {
- logger.FatalExit(err)
- }
-
- s := strings.Split(fqdn, ".")
- h.hostname = s[0]
-
- return &h
-}
-
-// Shutdown the handler.
-func (h *ControlHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ControlHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the client via the Reader interface.
-func (h *ControlHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
- case <-h.done.Done():
- return 0, io.EOF
- }
- }
-}
-
-// Write is to read data to the client via the Writer interface.
-func (h *ControlHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(wholePayload)
- h.payload = nil
-
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ControlHandler) handleCommand(command string) {
- logger.Info(h.user, command)
- s := strings.Split(command, " ")
- logger.Debug(h.user, "Receiving command", command, s)
-
- switch s[0] {
- case "health":
- h.serverMessages <- "OK: DTail SSH Server seems fine"
- h.serverMessages <- "done;"
- default:
- h.serverMessages <- logger.Error(h.user, "Received unknown control command", command, s)
- }
-}
diff --git a/internal/server/handlers/healthhandler.go b/internal/server/handlers/healthhandler.go
new file mode 100644
index 0000000..6dd9872
--- /dev/null
+++ b/internal/server/handlers/healthhandler.go
@@ -0,0 +1,58 @@
+package handlers
+
+import (
+ "context"
+ "os"
+ "strings"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/lcontext"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+// HealthHandler is for the remote health check.
+type HealthHandler struct {
+ baseHandler
+}
+
+// NewHealthHandler returns the server handler.
+func NewHealthHandler(user *user.User) *HealthHandler {
+ dlog.Server.Debug(user, "Creating new server health handler")
+ h := HealthHandler{
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ }
+ h.handleCommandCb = h.handleHealthCommand
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ dlog.Server.FatalPanic(err)
+ }
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+ return &h
+}
+
+func (h *HealthHandler) handleHealthCommand(ctx context.Context,
+ ltx lcontext.LContext, argc int, args []string, commandName string) {
+
+ dlog.Server.Debug(h.user, "Handling health command", argc, args)
+ switch commandName {
+ case "health":
+ h.send(h.serverMessages, "OK")
+ case ".ack":
+ h.handleAckCommand(argc, args)
+ default:
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown health command", commandName, argc, args))
+ }
+ h.shutdown()
+}
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
index c3e600e..65e0ed8 100644
--- a/internal/server/handlers/mapcommand.go
+++ b/internal/server/handlers/mapcommand.go
@@ -14,18 +14,17 @@ type mapCommand struct {
}
// NewMapCommand returns a new server side mapreduce command.
-func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) {
- m := mapCommand{server: serverHandler}
+func newMapCommand(serverHandler *ServerHandler, argc int,
+ args []string) (mapCommand, *server.Aggregate, error) {
+ m := mapCommand{server: serverHandler}
queryStr := strings.Join(args[1:], " ")
aggregate, err := server.NewAggregate(queryStr)
if err != nil {
return m, nil, err
}
-
m.aggregate = aggregate
return m, aggregate, nil
-
}
func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index b659c06..4728a55 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -7,8 +7,9 @@ import (
"sync"
"time"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/fs"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/regex"
@@ -26,39 +27,45 @@ func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand {
}
}
-func (r *readCommand) Start(ctx context.Context, lContext lcontext.LContext, argc int, args []string, retries int) {
- re := regex.NewNoop()
+func (r *readCommand) Start(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, retries int) {
+ re := regex.NewNoop()
if argc >= 4 {
deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " "))
if err != nil {
- r.server.sendServerMessage(logger.Error(r.server.user, commandParseWarning, err))
+ r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user,
+ "Unable to parse command", err))
return
}
re = deserializedRegex
}
if argc < 3 {
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to parse command", args, argc))
return
}
- r.readGlob(ctx, lContext, args[1], re, retries)
+ r.readGlob(ctx, ltx, args[1], re, retries)
}
-func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, glob string, re regex.Regex, retries int) {
+func (r *readCommand) readGlob(ctx context.Context, ltx lcontext.LContext,
+ glob string, re regex.Regex, retries int) {
+
retryInterval := time.Second * 5
glob = filepath.Clean(glob)
for retryCount := 0; retryCount < retries; retryCount++ {
paths, err := filepath.Glob(glob)
if err != nil {
- logger.Warn(r.server.user, glob, err)
+ dlog.Server.Warn(r.server.user, glob, err)
time.Sleep(retryInterval)
continue
}
if numPaths := len(paths); numPaths == 0 {
- logger.Error(r.server.user, "No such file(s) to read", glob)
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ dlog.Server.Error(r.server.user, "No such file(s) to read", glob)
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to read file(s), check server logs"))
select {
case <-ctx.Done():
return
@@ -68,41 +75,44 @@ func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext,
continue
}
- r.readFiles(ctx, lContext, paths, glob, re, retryInterval)
+ r.readFiles(ctx, ltx, paths, glob, re, retryInterval)
return
}
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Giving up to read file(s)"))
return
}
-func (r *readCommand) readFiles(ctx context.Context, lContext lcontext.LContext, paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
+func (r *readCommand) readFiles(ctx context.Context, ltx lcontext.LContext,
+ paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
+
var wg sync.WaitGroup
wg.Add(len(paths))
-
for _, path := range paths {
- go r.readFileIfPermissions(ctx, lContext, &wg, path, glob, re)
+ go r.readFileIfPermissions(ctx, ltx, &wg, path, glob, re)
}
-
wg.Wait()
}
-func (r *readCommand) readFileIfPermissions(ctx context.Context, lContext lcontext.LContext, wg *sync.WaitGroup, path, glob string, re regex.Regex) {
+func (r *readCommand) readFileIfPermissions(ctx context.Context, ltx lcontext.LContext,
+ wg *sync.WaitGroup, path, glob string, re regex.Regex) {
+
defer wg.Done()
globID := r.makeGlobID(path, glob)
-
if !r.server.user.HasFilePermission(path, "readfiles") {
- logger.Error(r.server.user, "No permission to read file", path, globID)
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ dlog.Server.Error(r.server.user, "No permission to read file", path, globID)
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to read file(s), check server logs"))
return
}
-
- r.readFile(ctx, lContext, path, globID, re)
+ r.readFile(ctx, ltx, path, globID, re)
}
-func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, path, globID string, re regex.Regex) {
- logger.Info(r.server.user, "Start reading file", path, globID)
+func (r *readCommand) readFile(ctx context.Context, ltx lcontext.LContext,
+ path, globID string, re regex.Regex) {
+ dlog.Server.Info(r.server.user, "Start reading file", path, globID)
var reader fs.FileReader
switch r.mode {
case omode.TailClient:
@@ -114,15 +124,19 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext,
}
lines := r.server.lines
-
- // Plug in mappreduce engine
- if r.server.aggregate != nil {
- lines = r.server.aggregate.Lines
- }
+ aggregate := r.server.aggregate
for {
- if err := reader.Start(ctx, lContext, lines, re); err != nil {
- logger.Error(r.server.user, path, globID, err)
+ if aggregate != nil {
+ lines = make(chan line.Line, 100)
+ aggregate.NextLinesCh <- lines
+ }
+ if err := reader.Start(ctx, ltx, lines, re); err != nil {
+ dlog.Server.Error(r.server.user, path, globID, err)
+ }
+ if aggregate != nil {
+ // Also makes aggregate to Flush
+ close(lines)
}
select {
@@ -133,9 +147,8 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext,
return
}
}
-
time.Sleep(time.Second * 2)
- logger.Info(path, globID, "Reading file again")
+ dlog.Server.Info(path, globID, "Reading file again")
}
}
@@ -152,11 +165,11 @@ func (r *readCommand) makeGlobID(path, glob string) string {
if len(idParts) > 0 {
return strings.Join(idParts, "/")
}
-
if len(pathParts) > 0 {
return pathParts[len(pathParts)-1]
}
- r.server.sendServerWarnMessage(logger.Warn("Empty file path given?", path, glob))
+ r.server.send(r.server.serverMessages,
+ dlog.Server.Warn("Empty file path given?", path, glob))
return ""
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 39d5d5f..36574a9 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -2,69 +2,50 @@ package handlers
import (
"context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
"os"
- "strconv"
"strings"
- "sync/atomic"
- "time"
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/lcontext"
- "github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
user "github.com/mimecast/dtail/internal/user/server"
- "github.com/mimecast/dtail/internal/version"
-)
-
-const (
- commandParseWarning string = "Unable to parse command"
)
// ServerHandler implements the Reader and Writer interfaces to handle
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- activeReaders int32
- quiet bool
+ baseHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ regex string
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
+func NewServerHandler(user *user.User, catLimiter,
+ tailLimiter chan struct{}) *ServerHandler {
+
+ dlog.Server.Debug(user, "Creating new server handler")
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
- }
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ }
+ h.handleCommandCb = h.handleUserCommand
fqdn, err := os.Hostname()
if err != nil {
- logger.FatalExit(err)
+ dlog.Server.FatalPanic(err)
}
s := strings.Split(fqdn, ".")
@@ -73,374 +54,49 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S
return &h
}
-// Shutdown the handler.
-func (h *ServerHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ServerHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the dtail client via Reader interface.
-func (h *ServerHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- if len(message) == 0 {
- logger.Warn(h.user, "Empty message received")
- return
- }
- if message[0] == '.' {
- // Handle hidden message (don't display to the user, interpreted by dtail client)
- wholePayload := []byte(fmt.Sprintf("%s\n", message))
- n = copy(p, wholePayload)
- return
- }
-
- // Handle normal server message (display to the user)
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
-
- case message := <-h.aggregatedMessages:
- // Send mapreduce-aggregated data as a message.
- data := fmt.Sprintf("AGGREGATE➔%s➔%s\n", h.hostname, message)
- wholePayload := []byte(data)
- n = copy(p, wholePayload)
- return
-
- case line := <-h.lines:
- // Send normal file content data as a message.
- serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
- h.hostname, line.TransmittedPerc, line.Count, line.SourceID))
- wholePayload := append(serverInfo, line.Content[:]...)
- n = copy(p, wholePayload)
- return
-
- case <-time.After(time.Second):
- // Once in a while check whether we are done.
- select {
- case <-h.done.Done():
- return 0, io.EOF
- default:
- }
- }
- }
-}
-
-// Write is to receive data from the dtail client via Writer interface.
-func (h *ServerHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(commandStr)
- h.payload = nil
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- logger.Debug(h.user, commandStr)
- ctx := context.Background()
-
- args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- args, argc, err = h.handleBase64(args, argc)
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
- }
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- <-h.done.Done()
- cancel()
- }()
-
- h.handleUserCommand(ctx, argc, args)
-}
-
-func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
- argc := len(args)
-
- if argc <= 2 || args[0] != "protocol" {
- return args, argc, errors.New("unable to determine protocol version")
- }
-
- if args[1] != version.ProtocolCompat {
- err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
- return args, argc, err
- }
-
- return args[2:], argc - 2, nil
-}
-
-func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
- err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
-
- if argc != 2 || args[0] != "base64" {
- return args, argc, err
- }
-
- decoded, err := base64.StdEncoding.DecodeString(args[1])
- if err != nil {
- return args, argc, err
- }
- decodedStr := string(decoded)
-
- args = strings.Split(decodedStr, " ")
- argc = len(decodedStr)
- logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
-
- return args, argc, nil
-}
-
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
- default:
- logger.Warn(h.user, "Received unknown control command", argc, args)
- }
-}
-
-func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
- logger.Debug(h.user, "handleUserCommand", argc, args)
+func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, commandName string) {
+ dlog.Server.Debug(h.user, "Handling user command", argc, args)
h.incrementActiveCommands()
commandFinished := func() {
if h.decrementActiveCommands() == 0 {
h.shutdown()
}
}
- readerFinished := func() {
- if h.decrementActiveReaders() == 0 {
- if h.aggregate == nil {
- return
- }
- h.aggregate.Shutdown()
- }
- }
-
- splitted := strings.Split(args[0], ":")
- commandName := splitted[0]
-
- options, lContext, err := readOptions(splitted[1:])
- if err != nil {
- h.sendServerMessage(logger.Error(h.user, err))
- commandFinished()
- return
- }
- if quiet, ok := options["quiet"]; ok {
- if quiet == "true" {
- logger.Debug(h.user, "Enabling quiet mode")
- h.quiet = true
- }
- }
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 1)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 1)
commandFinished()
}()
-
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 10)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 10)
commandFinished()
}()
-
case "map":
command, aggregate, err := newMapCommand(h, argc, args)
if err != nil {
- h.sendServerMessage(err.Error())
- logger.Error(h.user, err)
+ h.send(h.serverMessages, err.Error())
+ dlog.Server.Error(h.user, err)
commandFinished()
return
}
-
h.aggregate = aggregate
go func() {
- command.Start(ctx, h.aggregatedMessages)
+ command.Start(ctx, h.maprMessages)
commandFinished()
}()
-
- case "ack", ".ack":
+ case ".ack":
h.handleAckCommand(argc, args)
commandFinished()
-
default:
- h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options))
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown user command", commandName, argc, args))
commandFinished()
}
}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- close(h.ackCloseReceived)
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.done.Done():
- }
-}
-
-func (h *ServerHandler) sendServerMessage(message string) {
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) sendServerWarnMessage(message string) {
- if h.quiet {
- return
- }
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) serverMessageC() chan<- string {
- return h.serverMessages
-}
-
-func (h *ServerHandler) flush() {
- logger.Debug(h.user, "flush()")
-
- if h.aggregate != nil {
- h.aggregate.Flush()
- }
-
- unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
- }
- for i := 0; i < 3; i++ {
- if unsentMessages() == 0 {
- logger.Debug(h.user, "All lines sent")
- return
- }
- logger.Debug(h.user, "Still lines to be sent")
- time.Sleep(time.Second)
- }
-
- logger.Warn(h.user, "Some lines remain unsent", unsentMessages())
-}
-
-func (h *ServerHandler) shutdown() {
- logger.Debug(h.user, "shutdown()")
- h.flush()
-
- go func() {
- select {
- case h.serverMessageC() <- ".syn close connection":
- case <-h.done.Done():
- }
- }()
-
- select {
- case <-h.ackCloseReceived:
- case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.done.Done():
- }
-
- h.done.Shutdown()
-}
-
-func (h *ServerHandler) incrementActiveCommands() {
- atomic.AddInt32(&h.activeCommands, 1)
-}
-
-func (h *ServerHandler) decrementActiveCommands() int32 {
- atomic.AddInt32(&h.activeCommands, -1)
- return atomic.LoadInt32(&h.activeCommands)
-}
-
-func (h *ServerHandler) incrementActiveReaders() {
- atomic.AddInt32(&h.activeReaders, 1)
-}
-
-func (h *ServerHandler) decrementActiveReaders() int32 {
- atomic.AddInt32(&h.activeReaders, -1)
- return atomic.LoadInt32(&h.activeReaders)
-}
-
-// TODO: All options related code should be in its own package (client + server)
-func readOptions(opts []string) (map[string]string, lcontext.LContext, error) {
- options := make(map[string]string, len(opts))
- // Local search context
- var lContext lcontext.LContext
-
- for _, o := range opts {
- kv := strings.SplitN(o, "=", 2)
- if len(kv) != 2 {
- continue
- }
- key := kv[0]
- val := kv[1]
-
- if strings.HasPrefix(val, "base64%") {
- s := strings.SplitN(val, "%", 2)
- decoded, err := base64.StdEncoding.DecodeString(s[1])
- if err != nil {
- return options, lContext, err
- }
- val = string(decoded)
- }
-
- switch key {
- case "before":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.BeforeContext = iVal
- case "after":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.AfterContext = iVal
- case "max":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.MaxCount = iVal
- default:
- options[key] = val
- }
- }
-
- return options, lContext, nil
-}