summaryrefslogtreecommitdiff
path: root/internal/server/handlers/serverhandler.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/handlers/serverhandler.go')
-rw-r--r--internal/server/handlers/serverhandler.go412
1 files changed, 34 insertions, 378 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 39d5d5f..36574a9 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -2,69 +2,50 @@ package handlers
import (
"context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
"os"
- "strconv"
"strings"
- "sync/atomic"
- "time"
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/lcontext"
- "github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
user "github.com/mimecast/dtail/internal/user/server"
- "github.com/mimecast/dtail/internal/version"
-)
-
-const (
- commandParseWarning string = "Unable to parse command"
)
// ServerHandler implements the Reader and Writer interfaces to handle
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- activeReaders int32
- quiet bool
+ baseHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ regex string
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
+func NewServerHandler(user *user.User, catLimiter,
+ tailLimiter chan struct{}) *ServerHandler {
+
+ dlog.Server.Debug(user, "Creating new server handler")
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
- }
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ }
+ h.handleCommandCb = h.handleUserCommand
fqdn, err := os.Hostname()
if err != nil {
- logger.FatalExit(err)
+ dlog.Server.FatalPanic(err)
}
s := strings.Split(fqdn, ".")
@@ -73,374 +54,49 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S
return &h
}
-// Shutdown the handler.
-func (h *ServerHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ServerHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the dtail client via Reader interface.
-func (h *ServerHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- if len(message) == 0 {
- logger.Warn(h.user, "Empty message received")
- return
- }
- if message[0] == '.' {
- // Handle hidden message (don't display to the user, interpreted by dtail client)
- wholePayload := []byte(fmt.Sprintf("%s\n", message))
- n = copy(p, wholePayload)
- return
- }
-
- // Handle normal server message (display to the user)
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
-
- case message := <-h.aggregatedMessages:
- // Send mapreduce-aggregated data as a message.
- data := fmt.Sprintf("AGGREGATE➔%s➔%s\n", h.hostname, message)
- wholePayload := []byte(data)
- n = copy(p, wholePayload)
- return
-
- case line := <-h.lines:
- // Send normal file content data as a message.
- serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
- h.hostname, line.TransmittedPerc, line.Count, line.SourceID))
- wholePayload := append(serverInfo, line.Content[:]...)
- n = copy(p, wholePayload)
- return
-
- case <-time.After(time.Second):
- // Once in a while check whether we are done.
- select {
- case <-h.done.Done():
- return 0, io.EOF
- default:
- }
- }
- }
-}
-
-// Write is to receive data from the dtail client via Writer interface.
-func (h *ServerHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(commandStr)
- h.payload = nil
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- logger.Debug(h.user, commandStr)
- ctx := context.Background()
-
- args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- args, argc, err = h.handleBase64(args, argc)
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
- }
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- <-h.done.Done()
- cancel()
- }()
-
- h.handleUserCommand(ctx, argc, args)
-}
-
-func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
- argc := len(args)
-
- if argc <= 2 || args[0] != "protocol" {
- return args, argc, errors.New("unable to determine protocol version")
- }
-
- if args[1] != version.ProtocolCompat {
- err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
- return args, argc, err
- }
-
- return args[2:], argc - 2, nil
-}
-
-func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
- err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
-
- if argc != 2 || args[0] != "base64" {
- return args, argc, err
- }
-
- decoded, err := base64.StdEncoding.DecodeString(args[1])
- if err != nil {
- return args, argc, err
- }
- decodedStr := string(decoded)
-
- args = strings.Split(decodedStr, " ")
- argc = len(decodedStr)
- logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
-
- return args, argc, nil
-}
-
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
- default:
- logger.Warn(h.user, "Received unknown control command", argc, args)
- }
-}
-
-func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
- logger.Debug(h.user, "handleUserCommand", argc, args)
+func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, commandName string) {
+ dlog.Server.Debug(h.user, "Handling user command", argc, args)
h.incrementActiveCommands()
commandFinished := func() {
if h.decrementActiveCommands() == 0 {
h.shutdown()
}
}
- readerFinished := func() {
- if h.decrementActiveReaders() == 0 {
- if h.aggregate == nil {
- return
- }
- h.aggregate.Shutdown()
- }
- }
-
- splitted := strings.Split(args[0], ":")
- commandName := splitted[0]
-
- options, lContext, err := readOptions(splitted[1:])
- if err != nil {
- h.sendServerMessage(logger.Error(h.user, err))
- commandFinished()
- return
- }
- if quiet, ok := options["quiet"]; ok {
- if quiet == "true" {
- logger.Debug(h.user, "Enabling quiet mode")
- h.quiet = true
- }
- }
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 1)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 1)
commandFinished()
}()
-
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 10)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 10)
commandFinished()
}()
-
case "map":
command, aggregate, err := newMapCommand(h, argc, args)
if err != nil {
- h.sendServerMessage(err.Error())
- logger.Error(h.user, err)
+ h.send(h.serverMessages, err.Error())
+ dlog.Server.Error(h.user, err)
commandFinished()
return
}
-
h.aggregate = aggregate
go func() {
- command.Start(ctx, h.aggregatedMessages)
+ command.Start(ctx, h.maprMessages)
commandFinished()
}()
-
- case "ack", ".ack":
+ case ".ack":
h.handleAckCommand(argc, args)
commandFinished()
-
default:
- h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options))
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown user command", commandName, argc, args))
commandFinished()
}
}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- close(h.ackCloseReceived)
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.done.Done():
- }
-}
-
-func (h *ServerHandler) sendServerMessage(message string) {
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) sendServerWarnMessage(message string) {
- if h.quiet {
- return
- }
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) serverMessageC() chan<- string {
- return h.serverMessages
-}
-
-func (h *ServerHandler) flush() {
- logger.Debug(h.user, "flush()")
-
- if h.aggregate != nil {
- h.aggregate.Flush()
- }
-
- unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
- }
- for i := 0; i < 3; i++ {
- if unsentMessages() == 0 {
- logger.Debug(h.user, "All lines sent")
- return
- }
- logger.Debug(h.user, "Still lines to be sent")
- time.Sleep(time.Second)
- }
-
- logger.Warn(h.user, "Some lines remain unsent", unsentMessages())
-}
-
-func (h *ServerHandler) shutdown() {
- logger.Debug(h.user, "shutdown()")
- h.flush()
-
- go func() {
- select {
- case h.serverMessageC() <- ".syn close connection":
- case <-h.done.Done():
- }
- }()
-
- select {
- case <-h.ackCloseReceived:
- case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.done.Done():
- }
-
- h.done.Shutdown()
-}
-
-func (h *ServerHandler) incrementActiveCommands() {
- atomic.AddInt32(&h.activeCommands, 1)
-}
-
-func (h *ServerHandler) decrementActiveCommands() int32 {
- atomic.AddInt32(&h.activeCommands, -1)
- return atomic.LoadInt32(&h.activeCommands)
-}
-
-func (h *ServerHandler) incrementActiveReaders() {
- atomic.AddInt32(&h.activeReaders, 1)
-}
-
-func (h *ServerHandler) decrementActiveReaders() int32 {
- atomic.AddInt32(&h.activeReaders, -1)
- return atomic.LoadInt32(&h.activeReaders)
-}
-
-// TODO: All options related code should be in its own package (client + server)
-func readOptions(opts []string) (map[string]string, lcontext.LContext, error) {
- options := make(map[string]string, len(opts))
- // Local search context
- var lContext lcontext.LContext
-
- for _, o := range opts {
- kv := strings.SplitN(o, "=", 2)
- if len(kv) != 2 {
- continue
- }
- key := kv[0]
- val := kv[1]
-
- if strings.HasPrefix(val, "base64%") {
- s := strings.SplitN(val, "%", 2)
- decoded, err := base64.StdEncoding.DecodeString(s[1])
- if err != nil {
- return options, lContext, err
- }
- val = string(decoded)
- }
-
- switch key {
- case "before":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.BeforeContext = iVal
- case "after":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.AfterContext = iVal
- case "max":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.MaxCount = iVal
- default:
- options[key] = val
- }
- }
-
- return options, lContext, nil
-}