diff options
| author | Paul Buetow <paul@buetow.org> | 2021-10-05 10:00:38 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2021-10-05 10:00:38 +0300 |
| commit | f70622f307629a2542ea5eb128dea8c1043d3a40 (patch) | |
| tree | 82455dac0c870b28aea8c96a426050dc215a8818 /internal/server | |
| parent | 599075bc6580ba77dc22ba1c1ec8aa908ef2462d (diff) | |
more on this
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/handlers/basehandler.go | 283 | ||||
| -rw-r--r-- | internal/server/handlers/readcommand.go | 4 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 332 |
3 files changed, 319 insertions, 300 deletions
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go new file mode 100644 index 0000000..12fb2b3 --- /dev/null +++ b/internal/server/handlers/basehandler.go @@ -0,0 +1,283 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/mapr/server" + "github.com/mimecast/dtail/internal/protocol" + user "github.com/mimecast/dtail/internal/user/server" +) + +type handleCommandCb func(context.Context, int, []string) + +type baseHandler struct { + done *internal.Done + handleCommandCb handleCommandCb + lines chan line.Line + aggregate *server.Aggregate + maprMessages chan string + serverMessages chan string + hostname string + user *user.User + ackCloseReceived chan struct{} + activeCommands int32 + quiet bool + spartan bool + serverless bool + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +// Shutdown the handler. +func (h *baseHandler) Shutdown() { + h.done.Shutdown() +} + +// Done channel of the handler. +func (h *baseHandler) Done() <-chan struct{} { + return h.done.Done() +} + +// Read is to send data to the dtail client via Reader interface. +func (h *baseHandler) Read(p []byte) (n int, err error) { + defer h.readBuf.Reset() + + select { + case message := <-h.serverMessages: + if message[0] == '.' { + // Handle hidden message (don't display to the user, interpreted by dtail client) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + return + } + + if h.serverless { + // In serverless mode we have logged the server message already via the + // dlog logger, no need to send the message again to the client part. + return + } + + // Handle normal server message (display to the user) + h.readBuf.WriteString("SERVER") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + + case message := <-h.maprMessages: + // Send mapreduce-aggregated data as a message. + h.readBuf.WriteString("AGGREGATE") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + + case line := <-h.lines: + if !h.spartan { + h.readBuf.WriteString("REMOTE") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc)) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(fmt.Sprintf("%v", line.Count)) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(line.SourceID) + h.readBuf.WriteString(protocol.FieldDelimiter) + } + h.readBuf.WriteString(line.Content.String()) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + pool.RecycleBytesBuffer(line.Content) + + case <-time.After(time.Second): + // Once in a while check whether we are done. + select { + case <-h.done.Done(): + err = io.EOF + return + default: + } + } + return +} + +// Write is to receive data from the dtail client via Writer interface. +func (h *baseHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + switch b { + case ';': + h.handleCommand(string(h.writeBuf.Bytes())) + h.writeBuf.Reset() + default: + h.writeBuf.WriteByte(b) + } + } + + n = len(p) + return +} + +func (h *baseHandler) handleCommand(commandStr string) { + dlog.Server.Debug(h.user, commandStr) + + args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add) + return + } + + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-h.done.Done() + cancel() + }() + + h.handleCommandCb(ctx, argc, args) +} + +func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) { + argc := len(args) + var add string + + if argc <= 2 || args[0] != "protocol" { + return args, argc, add, errors.New("unable to determine protocol version") + } + + if args[1] != protocol.ProtocolCompat { + clientCompat, _ := strconv.Atoi(args[1]) + serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat) + if clientCompat <= 3 { + // Protocol version 3 or lower expect a newline as message separator + // One day (after 2 major versions) this exception may be removed! + add = "\n" + } + + toUpdate := "client" + if clientCompat > serverCompat { + toUpdate = "server" + } + + err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!", + protocol.ProtocolCompat, args[1], toUpdate) + return args, argc, add, err + } + + return args[2:], argc - 2, add, nil +} + +func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible") + + if argc != 2 || args[0] != "base64" { + return args, argc, err + } + + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) + + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) + + return args, argc, nil +} + +func (h *baseHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + if !h.quiet { + h.send(h.serverMessages, dlog.Server.Warn(h.user, "Unable to parse command", args, argc)) + } + return + } + if args[1] == "close" && args[2] == "connection" { + select { + case <-h.ackCloseReceived: + default: + close(h.ackCloseReceived) + } + } +} + +func (h *baseHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.done.Done(): + } +} + +func (h *baseHandler) flush() { + dlog.Server.Debug(h.user, "flush()") + + numUnsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.maprMessages) + } + + for i := 0; i < 3; i++ { + if numUnsentMessages() == 0 { + dlog.Server.Debug(h.user, "All lines sent") + return + } + dlog.Server.Debug(h.user, "Still lines to be sent") + time.Sleep(time.Second) + } + + dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages()) +} + +func (h *baseHandler) shutdown() { + dlog.Server.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessages <- ".syn close connection": + case <-h.done.Done(): + } + }() + + select { + case <-h.ackCloseReceived: + case <-time.After(time.Second * 5): + dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.done.Done(): + } + + h.done.Shutdown() +} + +func (h *baseHandler) incrementActiveCommands() { + atomic.AddInt32(&h.activeCommands, 1) +} + +func (h *baseHandler) decrementActiveCommands() int32 { + atomic.AddInt32(&h.activeCommands, -1) + return atomic.LoadInt32(&h.activeCommands) +} diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go index 6579018..abc44c7 100644 --- a/internal/server/handlers/readcommand.go +++ b/internal/server/handlers/readcommand.go @@ -32,13 +32,13 @@ func (r *readCommand) Start(ctx context.Context, argc int, args []string, retrie if argc >= 4 { deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " ")) if err != nil { - r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, commandParseWarning, err)) + r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, "Unable to parse command", err)) return } re = deserializedRegex } if argc < 3 { - r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, commandParseWarning, args, argc)) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, "Unable to parse command", args, argc)) return } r.readGlob(ctx, args[1], re, retries) diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index ace2626..2ec4fbf 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -1,69 +1,60 @@ package handlers import ( - "bytes" "context" - "encoding/base64" - "errors" - "fmt" - "io" "os" - "strconv" "strings" - "sync/atomic" - "time" "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" - "github.com/mimecast/dtail/internal/io/pool" - "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/protocol" user "github.com/mimecast/dtail/internal/user/server" ) -const ( - commandParseWarning string = "Unable to parse command" -) - // ServerHandler implements the Reader and Writer interfaces to handle // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - done *internal.Done - lines chan line.Line - regex string - aggregate *server.Aggregate - maprMessages chan string - serverMessages chan string - hostname string - user *user.User - catLimiter chan struct{} - tailLimiter chan struct{} - ackCloseReceived chan struct{} - activeCommands int32 - quiet bool - spartan bool - serverless bool - readBuf bytes.Buffer - writeBuf bytes.Buffer + baseHandler + catLimiter chan struct{} + tailLimiter chan struct{} + regex string + /* + done *internal.Done + lines chan line.Line + aggregate *server.Aggregate + maprMessages chan string + serverMessages chan string + hostname string + user *user.User + ackCloseReceived chan struct{} + activeCommands int32 + quiet bool + spartan bool + serverless bool + readBuf bytes.Buffer + writeBuf bytes.Buffer + */ } // NewServerHandler returns the server handler. func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler { h := ServerHandler{ - done: internal.NewDone(), - lines: make(chan line.Line, 100), - serverMessages: make(chan string, 10), - maprMessages: make(chan string, 10), - ackCloseReceived: make(chan struct{}), - catLimiter: catLimiter, - tailLimiter: tailLimiter, - regex: ".", - user: user, - } + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + maprMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + user: user, + }, + catLimiter: catLimiter, + tailLimiter: tailLimiter, + regex: ".", + } + h.handleCommandCb = h.handleUserCommand fqdn, err := os.Hostname() if err != nil { @@ -76,192 +67,8 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S return &h } -// Shutdown the handler. -func (h *ServerHandler) Shutdown() { - h.done.Shutdown() -} - -// Done channel of the handler. -func (h *ServerHandler) Done() <-chan struct{} { - return h.done.Done() -} - -// Read is to send data to the dtail client via Reader interface. -func (h *ServerHandler) Read(p []byte) (n int, err error) { - defer h.readBuf.Reset() - - select { - case message := <-h.serverMessages: - if message[0] == '.' { - // Handle hidden message (don't display to the user, interpreted by dtail client) - h.readBuf.WriteString(message) - h.readBuf.WriteByte(protocol.MessageDelimiter) - n = copy(p, h.readBuf.Bytes()) - return - } - - if h.serverless { - // In serverless mode we have logged the server message already via the - // dlog logger, no need to send the message again to the client part. - return - } - - // Handle normal server message (display to the user) - h.readBuf.WriteString("SERVER") - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(h.hostname) - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(message) - h.readBuf.WriteByte(protocol.MessageDelimiter) - n = copy(p, h.readBuf.Bytes()) - - case message := <-h.maprMessages: - // Send mapreduce-aggregated data as a message. - h.readBuf.WriteString("AGGREGATE") - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(h.hostname) - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(message) - h.readBuf.WriteByte(protocol.MessageDelimiter) - n = copy(p, h.readBuf.Bytes()) - - case line := <-h.lines: - if !h.spartan { - h.readBuf.WriteString("REMOTE") - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(h.hostname) - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc)) - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(fmt.Sprintf("%v", line.Count)) - h.readBuf.WriteString(protocol.FieldDelimiter) - h.readBuf.WriteString(line.SourceID) - h.readBuf.WriteString(protocol.FieldDelimiter) - } - h.readBuf.WriteString(line.Content.String()) - h.readBuf.WriteByte(protocol.MessageDelimiter) - n = copy(p, h.readBuf.Bytes()) - pool.RecycleBytesBuffer(line.Content) - - case <-time.After(time.Second): - // Once in a while check whether we are done. - select { - case <-h.done.Done(): - err = io.EOF - return - default: - } - } - return -} - -// Write is to receive data from the dtail client via Writer interface. -func (h *ServerHandler) Write(p []byte) (n int, err error) { - for _, b := range p { - switch b { - case ';': - h.handleCommand(string(h.writeBuf.Bytes())) - h.writeBuf.Reset() - default: - h.writeBuf.WriteByte(b) - } - } - - n = len(p) - return -} - -func (h *ServerHandler) handleCommand(commandStr string) { - dlog.Server.Debug(h.user, commandStr) - ctx := context.Background() - - args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) - if err != nil { - h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add) - return - } - - args, argc, err = h.handleBase64(args, argc) - if err != nil { - h.send(h.serverMessages, dlog.Server.Error(h.user, err)) - return - } - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return - } - - ctx, cancel := context.WithCancel(ctx) - go func() { - <-h.done.Done() - cancel() - }() - - h.handleUserCommand(ctx, argc, args) -} - -func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, string, error) { - argc := len(args) - var add string - - if argc <= 2 || args[0] != "protocol" { - return args, argc, add, errors.New("unable to determine protocol version") - } - - if args[1] != protocol.ProtocolCompat { - clientCompat, _ := strconv.Atoi(args[1]) - serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat) - if clientCompat <= 3 { - // Protocol version 3 or lower expect a newline as message separator - // One day (after 2 major versions) this exception may be removed! - add = "\n" - } - - toUpdate := "client" - if clientCompat > serverCompat { - toUpdate = "server" - } - - err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!", - protocol.ProtocolCompat, args[1], toUpdate) - return args, argc, add, err - } - - return args[2:], argc - 2, add, nil -} - -func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { - err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible") - - if argc != 2 || args[0] != "base64" { - return args, argc, err - } - - decoded, err := base64.StdEncoding.DecodeString(args[1]) - if err != nil { - return args, argc, err - } - decodedStr := string(decoded) - - args = strings.Split(decodedStr, " ") - argc = len(decodedStr) - dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - - return args, argc, nil -} - -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "debug": - h.send(h.serverMessages, dlog.Server.Debug(h.user, "Receiving debug command", argc, args)) - default: - dlog.Server.Warn(h.user, "Received unknown control command", argc, args) - } -} - func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { - dlog.Server.Debug(h.user, "handleUserCommand", argc, args) + dlog.Server.Debug(h.user, "Handling user command", argc, args) h.incrementActiveCommands() commandFinished := func() { @@ -332,74 +139,3 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] commandFinished() } } - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - if !h.quiet { - h.send(h.serverMessages, dlog.Server.Warn(h.user, commandParseWarning, args, argc)) - } - return - } - if args[1] == "close" && args[2] == "connection" { - select { - case <-h.ackCloseReceived: - default: - close(h.ackCloseReceived) - } - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.done.Done(): - } -} - -func (h *ServerHandler) flush() { - dlog.Server.Debug(h.user, "flush()") - - unsentMessages := func() int { - return len(h.lines) + len(h.serverMessages) + len(h.maprMessages) - } - for i := 0; i < 3; i++ { - if unsentMessages() == 0 { - dlog.Server.Debug(h.user, "All lines sent") - return - } - dlog.Server.Debug(h.user, "Still lines to be sent") - time.Sleep(time.Second) - } - - dlog.Server.Warn(h.user, "Some lines remain unsent", unsentMessages()) -} - -func (h *ServerHandler) shutdown() { - dlog.Server.Debug(h.user, "shutdown()") - h.flush() - - go func() { - select { - case h.serverMessages <- ".syn close connection": - case <-h.done.Done(): - } - }() - - select { - case <-h.ackCloseReceived: - case <-time.After(time.Second * 5): - dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") - case <-h.done.Done(): - } - - h.done.Shutdown() -} - -func (h *ServerHandler) incrementActiveCommands() { - atomic.AddInt32(&h.activeCommands, 1) -} - -func (h *ServerHandler) decrementActiveCommands() int32 { - atomic.AddInt32(&h.activeCommands, -1) - return atomic.LoadInt32(&h.activeCommands) -} |
