summaryrefslogtreecommitdiff
path: root/internal/server/handlers/basehandler.go
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2021-10-05 10:00:38 +0300
committerPaul Buetow <paul@buetow.org>2021-10-05 10:00:38 +0300
commitf70622f307629a2542ea5eb128dea8c1043d3a40 (patch)
tree82455dac0c870b28aea8c96a426050dc215a8818 /internal/server/handlers/basehandler.go
parent599075bc6580ba77dc22ba1c1ec8aa908ef2462d (diff)
more on this
Diffstat (limited to 'internal/server/handlers/basehandler.go')
-rw-r--r--internal/server/handlers/basehandler.go283
1 files changed, 283 insertions, 0 deletions
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
new file mode 100644
index 0000000..12fb2b3
--- /dev/null
+++ b/internal/server/handlers/basehandler.go
@@ -0,0 +1,283 @@
+package handlers
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/protocol"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type handleCommandCb func(context.Context, int, []string)
+
+type baseHandler struct {
+ done *internal.Done
+ handleCommandCb handleCommandCb
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ quiet bool
+ spartan bool
+ serverless bool
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+}
+
+// Shutdown the handler.
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ defer h.readBuf.Reset()
+
+ select {
+ case message := <-h.serverMessages:
+ if message[0] == '.' {
+ // Handle hidden message (don't display to the user, interpreted by dtail client)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ return
+ }
+
+ if h.serverless {
+ // In serverless mode we have logged the server message already via the
+ // dlog logger, no need to send the message again to the client part.
+ return
+ }
+
+ // Handle normal server message (display to the user)
+ h.readBuf.WriteString("SERVER")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case message := <-h.maprMessages:
+ // Send mapreduce-aggregated data as a message.
+ h.readBuf.WriteString("AGGREGATE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case line := <-h.lines:
+ if !h.spartan {
+ h.readBuf.WriteString("REMOTE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(line.SourceID)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ }
+ h.readBuf.WriteString(line.Content.String())
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ pool.RecycleBytesBuffer(line.Content)
+
+ case <-time.After(time.Second):
+ // Once in a while check whether we are done.
+ select {
+ case <-h.done.Done():
+ err = io.EOF
+ return
+ default:
+ }
+ }
+ return
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ switch b {
+ case ';':
+ h.handleCommand(string(h.writeBuf.Bytes()))
+ h.writeBuf.Reset()
+ default:
+ h.writeBuf.WriteByte(b)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+func (h *baseHandler) handleCommand(commandStr string) {
+ dlog.Server.Debug(h.user, commandStr)
+
+ args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
+ return
+ }
+
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
+ h.handleCommandCb(ctx, argc, args)
+}
+
+func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
+ argc := len(args)
+ var add string
+
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, add, errors.New("unable to determine protocol version")
+ }
+
+ if args[1] != protocol.ProtocolCompat {
+ clientCompat, _ := strconv.Atoi(args[1])
+ serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
+ if clientCompat <= 3 {
+ // Protocol version 3 or lower expect a newline as message separator
+ // One day (after 2 major versions) this exception may be removed!
+ add = "\n"
+ }
+
+ toUpdate := "client"
+ if clientCompat > serverCompat {
+ toUpdate = "server"
+ }
+
+ err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!",
+ protocol.ProtocolCompat, args[1], toUpdate)
+ return args, argc, add, err
+ }
+
+ return args[2:], argc - 2, add, nil
+}
+
+func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
+
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
+
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
+
+ return args, argc, nil
+}
+
+func (h *baseHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ if !h.quiet {
+ h.send(h.serverMessages, dlog.Server.Warn(h.user, "Unable to parse command", args, argc))
+ }
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ select {
+ case <-h.ackCloseReceived:
+ default:
+ close(h.ackCloseReceived)
+ }
+ }
+}
+
+func (h *baseHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.done.Done():
+ }
+}
+
+func (h *baseHandler) flush() {
+ dlog.Server.Debug(h.user, "flush()")
+
+ numUnsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
+ }
+
+ for i := 0; i < 3; i++ {
+ if numUnsentMessages() == 0 {
+ dlog.Server.Debug(h.user, "All lines sent")
+ return
+ }
+ dlog.Server.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Second)
+ }
+
+ dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages())
+}
+
+func (h *baseHandler) shutdown() {
+ dlog.Server.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessages <- ".syn close connection":
+ case <-h.done.Done():
+ }
+ }()
+
+ select {
+ case <-h.ackCloseReceived:
+ case <-time.After(time.Second * 5):
+ dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.done.Done():
+ }
+
+ h.done.Shutdown()
+}
+
+func (h *baseHandler) incrementActiveCommands() {
+ atomic.AddInt32(&h.activeCommands, 1)
+}
+
+func (h *baseHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
+}