summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2021-10-05 10:00:38 +0300
committerPaul Buetow <paul@buetow.org>2021-10-05 10:00:38 +0300
commitf70622f307629a2542ea5eb128dea8c1043d3a40 (patch)
tree82455dac0c870b28aea8c96a426050dc215a8818 /internal/server
parent599075bc6580ba77dc22ba1c1ec8aa908ef2462d (diff)
more on this
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/basehandler.go283
-rw-r--r--internal/server/handlers/readcommand.go4
-rw-r--r--internal/server/handlers/serverhandler.go332
3 files changed, 319 insertions, 300 deletions
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
new file mode 100644
index 0000000..12fb2b3
--- /dev/null
+++ b/internal/server/handlers/basehandler.go
@@ -0,0 +1,283 @@
+package handlers
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/protocol"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type handleCommandCb func(context.Context, int, []string)
+
+type baseHandler struct {
+ done *internal.Done
+ handleCommandCb handleCommandCb
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ quiet bool
+ spartan bool
+ serverless bool
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+}
+
+// Shutdown the handler.
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ defer h.readBuf.Reset()
+
+ select {
+ case message := <-h.serverMessages:
+ if message[0] == '.' {
+ // Handle hidden message (don't display to the user, interpreted by dtail client)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ return
+ }
+
+ if h.serverless {
+ // In serverless mode we have logged the server message already via the
+ // dlog logger, no need to send the message again to the client part.
+ return
+ }
+
+ // Handle normal server message (display to the user)
+ h.readBuf.WriteString("SERVER")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case message := <-h.maprMessages:
+ // Send mapreduce-aggregated data as a message.
+ h.readBuf.WriteString("AGGREGATE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case line := <-h.lines:
+ if !h.spartan {
+ h.readBuf.WriteString("REMOTE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(line.SourceID)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ }
+ h.readBuf.WriteString(line.Content.String())
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ pool.RecycleBytesBuffer(line.Content)
+
+ case <-time.After(time.Second):
+ // Once in a while check whether we are done.
+ select {
+ case <-h.done.Done():
+ err = io.EOF
+ return
+ default:
+ }
+ }
+ return
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ switch b {
+ case ';':
+ h.handleCommand(string(h.writeBuf.Bytes()))
+ h.writeBuf.Reset()
+ default:
+ h.writeBuf.WriteByte(b)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+func (h *baseHandler) handleCommand(commandStr string) {
+ dlog.Server.Debug(h.user, commandStr)
+
+ args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
+ return
+ }
+
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
+ h.handleCommandCb(ctx, argc, args)
+}
+
+func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
+ argc := len(args)
+ var add string
+
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, add, errors.New("unable to determine protocol version")
+ }
+
+ if args[1] != protocol.ProtocolCompat {
+ clientCompat, _ := strconv.Atoi(args[1])
+ serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
+ if clientCompat <= 3 {
+ // Protocol version 3 or lower expect a newline as message separator
+ // One day (after 2 major versions) this exception may be removed!
+ add = "\n"
+ }
+
+ toUpdate := "client"
+ if clientCompat > serverCompat {
+ toUpdate = "server"
+ }
+
+ err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!",
+ protocol.ProtocolCompat, args[1], toUpdate)
+ return args, argc, add, err
+ }
+
+ return args[2:], argc - 2, add, nil
+}
+
+func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
+
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
+
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
+
+ return args, argc, nil
+}
+
+func (h *baseHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ if !h.quiet {
+ h.send(h.serverMessages, dlog.Server.Warn(h.user, "Unable to parse command", args, argc))
+ }
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ select {
+ case <-h.ackCloseReceived:
+ default:
+ close(h.ackCloseReceived)
+ }
+ }
+}
+
+func (h *baseHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.done.Done():
+ }
+}
+
+func (h *baseHandler) flush() {
+ dlog.Server.Debug(h.user, "flush()")
+
+ numUnsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
+ }
+
+ for i := 0; i < 3; i++ {
+ if numUnsentMessages() == 0 {
+ dlog.Server.Debug(h.user, "All lines sent")
+ return
+ }
+ dlog.Server.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Second)
+ }
+
+ dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages())
+}
+
+func (h *baseHandler) shutdown() {
+ dlog.Server.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessages <- ".syn close connection":
+ case <-h.done.Done():
+ }
+ }()
+
+ select {
+ case <-h.ackCloseReceived:
+ case <-time.After(time.Second * 5):
+ dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.done.Done():
+ }
+
+ h.done.Shutdown()
+}
+
+func (h *baseHandler) incrementActiveCommands() {
+ atomic.AddInt32(&h.activeCommands, 1)
+}
+
+func (h *baseHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
+}
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index 6579018..abc44c7 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -32,13 +32,13 @@ func (r *readCommand) Start(ctx context.Context, argc int, args []string, retrie
if argc >= 4 {
deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " "))
if err != nil {
- r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, commandParseWarning, err))
+ r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, "Unable to parse command", err))
return
}
re = deserializedRegex
}
if argc < 3 {
- r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, commandParseWarning, args, argc))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, "Unable to parse command", args, argc))
return
}
r.readGlob(ctx, args[1], re, retries)
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index ace2626..2ec4fbf 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -1,69 +1,60 @@
package handlers
import (
- "bytes"
"context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
"os"
- "strconv"
"strings"
- "sync/atomic"
- "time"
"github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/pool"
- "github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/protocol"
user "github.com/mimecast/dtail/internal/user/server"
)
-const (
- commandParseWarning string = "Unable to parse command"
-)
-
// ServerHandler implements the Reader and Writer interfaces to handle
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- maprMessages chan string
- serverMessages chan string
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- quiet bool
- spartan bool
- serverless bool
- readBuf bytes.Buffer
- writeBuf bytes.Buffer
+ baseHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ regex string
+ /*
+ done *internal.Done
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ quiet bool
+ spartan bool
+ serverless bool
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+ */
}
// NewServerHandler returns the server handler.
func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- maprMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
- }
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ }
+ h.handleCommandCb = h.handleUserCommand
fqdn, err := os.Hostname()
if err != nil {
@@ -76,192 +67,8 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S
return &h
}
-// Shutdown the handler.
-func (h *ServerHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ServerHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the dtail client via Reader interface.
-func (h *ServerHandler) Read(p []byte) (n int, err error) {
- defer h.readBuf.Reset()
-
- select {
- case message := <-h.serverMessages:
- if message[0] == '.' {
- // Handle hidden message (don't display to the user, interpreted by dtail client)
- h.readBuf.WriteString(message)
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
- return
- }
-
- if h.serverless {
- // In serverless mode we have logged the server message already via the
- // dlog logger, no need to send the message again to the client part.
- return
- }
-
- // Handle normal server message (display to the user)
- h.readBuf.WriteString("SERVER")
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(h.hostname)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(message)
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
-
- case message := <-h.maprMessages:
- // Send mapreduce-aggregated data as a message.
- h.readBuf.WriteString("AGGREGATE")
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(h.hostname)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(message)
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
-
- case line := <-h.lines:
- if !h.spartan {
- h.readBuf.WriteString("REMOTE")
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(h.hostname)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
- h.readBuf.WriteString(protocol.FieldDelimiter)
- h.readBuf.WriteString(line.SourceID)
- h.readBuf.WriteString(protocol.FieldDelimiter)
- }
- h.readBuf.WriteString(line.Content.String())
- h.readBuf.WriteByte(protocol.MessageDelimiter)
- n = copy(p, h.readBuf.Bytes())
- pool.RecycleBytesBuffer(line.Content)
-
- case <-time.After(time.Second):
- // Once in a while check whether we are done.
- select {
- case <-h.done.Done():
- err = io.EOF
- return
- default:
- }
- }
- return
-}
-
-// Write is to receive data from the dtail client via Writer interface.
-func (h *ServerHandler) Write(p []byte) (n int, err error) {
- for _, b := range p {
- switch b {
- case ';':
- h.handleCommand(string(h.writeBuf.Bytes()))
- h.writeBuf.Reset()
- default:
- h.writeBuf.WriteByte(b)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- dlog.Server.Debug(h.user, commandStr)
- ctx := context.Background()
-
- args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
- if err != nil {
- h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
- return
- }
-
- args, argc, err = h.handleBase64(args, argc)
- if err != nil {
- h.send(h.serverMessages, dlog.Server.Error(h.user, err))
- return
- }
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
- }
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- <-h.done.Done()
- cancel()
- }()
-
- h.handleUserCommand(ctx, argc, args)
-}
-
-func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
- argc := len(args)
- var add string
-
- if argc <= 2 || args[0] != "protocol" {
- return args, argc, add, errors.New("unable to determine protocol version")
- }
-
- if args[1] != protocol.ProtocolCompat {
- clientCompat, _ := strconv.Atoi(args[1])
- serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
- if clientCompat <= 3 {
- // Protocol version 3 or lower expect a newline as message separator
- // One day (after 2 major versions) this exception may be removed!
- add = "\n"
- }
-
- toUpdate := "client"
- if clientCompat > serverCompat {
- toUpdate = "server"
- }
-
- err := fmt.Errorf("DTail server protocol version '%s' does not match client protocol version '%s', please update DTail %s!",
- protocol.ProtocolCompat, args[1], toUpdate)
- return args, argc, add, err
- }
-
- return args[2:], argc - 2, add, nil
-}
-
-func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
- err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
-
- if argc != 2 || args[0] != "base64" {
- return args, argc, err
- }
-
- decoded, err := base64.StdEncoding.DecodeString(args[1])
- if err != nil {
- return args, argc, err
- }
- decodedStr := string(decoded)
-
- args = strings.Split(decodedStr, " ")
- argc = len(decodedStr)
- dlog.Server.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
-
- return args, argc, nil
-}
-
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "debug":
- h.send(h.serverMessages, dlog.Server.Debug(h.user, "Receiving debug command", argc, args))
- default:
- dlog.Server.Warn(h.user, "Received unknown control command", argc, args)
- }
-}
-
func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
- dlog.Server.Debug(h.user, "handleUserCommand", argc, args)
+ dlog.Server.Debug(h.user, "Handling user command", argc, args)
h.incrementActiveCommands()
commandFinished := func() {
@@ -332,74 +139,3 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
commandFinished()
}
}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- if !h.quiet {
- h.send(h.serverMessages, dlog.Server.Warn(h.user, commandParseWarning, args, argc))
- }
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- select {
- case <-h.ackCloseReceived:
- default:
- close(h.ackCloseReceived)
- }
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.done.Done():
- }
-}
-
-func (h *ServerHandler) flush() {
- dlog.Server.Debug(h.user, "flush()")
-
- unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
- }
- for i := 0; i < 3; i++ {
- if unsentMessages() == 0 {
- dlog.Server.Debug(h.user, "All lines sent")
- return
- }
- dlog.Server.Debug(h.user, "Still lines to be sent")
- time.Sleep(time.Second)
- }
-
- dlog.Server.Warn(h.user, "Some lines remain unsent", unsentMessages())
-}
-
-func (h *ServerHandler) shutdown() {
- dlog.Server.Debug(h.user, "shutdown()")
- h.flush()
-
- go func() {
- select {
- case h.serverMessages <- ".syn close connection":
- case <-h.done.Done():
- }
- }()
-
- select {
- case <-h.ackCloseReceived:
- case <-time.After(time.Second * 5):
- dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.done.Done():
- }
-
- h.done.Shutdown()
-}
-
-func (h *ServerHandler) incrementActiveCommands() {
- atomic.AddInt32(&h.activeCommands, 1)
-}
-
-func (h *ServerHandler) decrementActiveCommands() int32 {
- atomic.AddInt32(&h.activeCommands, -1)
- return atomic.LoadInt32(&h.activeCommands)
-}