diff options
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/handlers/controlhandler.go | 42 | ||||
| -rw-r--r-- | internal/server/handlers/handler.go | 2 | ||||
| -rw-r--r-- | internal/server/handlers/mapcommand.go | 35 | ||||
| -rw-r--r-- | internal/server/handlers/readcommand.go | 158 | ||||
| -rw-r--r-- | internal/server/handlers/runcommand.go | 73 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 521 | ||||
| -rw-r--r-- | internal/server/server.go | 70 | ||||
| -rw-r--r-- | internal/server/stats.go | 10 |
8 files changed, 503 insertions, 408 deletions
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go index 482f759..a33a78b 100644 --- a/internal/server/handlers/controlhandler.go +++ b/internal/server/handlers/controlhandler.go @@ -1,33 +1,34 @@ package handlers import ( + "context" "fmt" "io" "os" "strings" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" ) // ControlHandler is used for control functions and health monitoring. type ControlHandler struct { - serverMessages chan string - pong chan struct{} - stop chan struct{} - payload []byte + ctx context.Context + done chan struct{} hostname string + payload []byte + serverMessages chan string user *user.User } // NewControlHandler returns a new control handler. -func NewControlHandler(user *user.User) *ControlHandler { +func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) { logger.Debug(user, "Creating control handler") h := ControlHandler{ + ctx: ctx, + done: make(chan struct{}), serverMessages: make(chan string, 10), - pong: make(chan struct{}, 10), - stop: make(chan struct{}), user: user, } @@ -38,7 +39,8 @@ func NewControlHandler(user *user.User) *ControlHandler { s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + + return &h, h.done } // Read is to send data to the client via the Reader interface. @@ -49,11 +51,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) { wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return - case <-h.pong: - logger.Info(h.user, "Sending pong") - n = copy(p, []byte(".pong\n")) - return - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } } @@ -65,7 +63,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { switch c { case ';': wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(wholePayload) + h.handleCommand(h.ctx, wholePayload) h.payload = nil default: @@ -77,17 +75,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { return } -// Close the control handler. -func (h *ControlHandler) Close() { - close(h.stop) -} - -// Wait returns the handler stop channel. -func (h *ControlHandler) Wait() <-chan struct{} { - return h.stop -} - -func (h *ControlHandler) handleCommand(command string) { +func (h *ControlHandler) handleCommand(ctx context.Context, command string) { logger.Info(h.user, command) s := strings.Split(command, " ") logger.Debug(h.user, "Receiving command", command, s) @@ -96,8 +84,6 @@ func (h *ControlHandler) handleCommand(command string) { case "health": h.serverMessages <- "OK: DTail SSH Server seems fine" h.serverMessages <- "done;" - case "ping": - h.pong <- struct{}{} case "debug": h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s) default: diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go index 8b1f73e..c42ceb9 100644 --- a/internal/server/handlers/handler.go +++ b/internal/server/handlers/handler.go @@ -5,6 +5,4 @@ import "io" // Handler interface for server side functionality. type Handler interface { io.ReadWriter - Close() - Wait() <-chan struct{} } diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go new file mode 100644 index 0000000..10372da --- /dev/null +++ b/internal/server/handlers/mapcommand.go @@ -0,0 +1,35 @@ +package handlers + +import ( + "context" + "strings" + + "github.com/mimecast/dtail/internal/mapr/server" +) + +// Map command implements the mapreduce command server side. +type mapCommand struct { + aggregate *server.Aggregate + server *ServerHandler +} + +// NewMapCommand returns a new server side mapreduce command. +func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) { + mapCommand := mapCommand{ + server: serverHandler, + } + + queryStr := strings.Join(args[1:], " ") + aggregate, err := server.NewAggregate(queryStr) + if err != nil { + return mapCommand, nil, err + } + + mapCommand.aggregate = aggregate + return mapCommand, aggregate, nil + +} + +func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { + m.aggregate.Start(ctx, aggregatedMessages) +} diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go new file mode 100644 index 0000000..e4079e8 --- /dev/null +++ b/internal/server/handlers/readcommand.go @@ -0,0 +1,158 @@ +package handlers + +import ( + "context" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/fs" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/omode" +) + +type readCommand struct { + server *ServerHandler + mode omode.Mode +} + +func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand { + return &readCommand{ + server: server, + mode: mode, + } +} + +func (r *readCommand) Start(ctx context.Context, argc int, args []string) { + regex := "." + if argc >= 4 { + regex = args[3] + } + if argc < 3 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + r.readGlob(ctx, args[1], regex) +} + +func (r *readCommand) readGlob(ctx context.Context, glob string, regex string) { + retryInterval := time.Second * 5 + glob = filepath.Clean(glob) + + maxRetries := 10 + for { + maxRetries-- + if maxRetries < 0 { + r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)")) + return + } + + paths, err := filepath.Glob(glob) + if err != nil { + logger.Warn(r.server.user, glob, err) + time.Sleep(retryInterval) + continue + } + + if numPaths := len(paths); numPaths == 0 { + logger.Error(r.server.user, "No such file(s) to read", glob) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + select { + case <-ctx.Done(): + return + default: + } + time.Sleep(retryInterval) + continue + } + + r.readFiles(ctx, paths, glob, regex, retryInterval) + break + } +} + +func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, regex string, retryInterval time.Duration) { + var wg sync.WaitGroup + wg.Add(len(paths)) + + for _, path := range paths { + go r.readFileIfPermissions(ctx, &wg, path, glob, regex) + } + + wg.Wait() +} + +func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGroup, path, glob, regex string) { + defer wg.Done() + globID := r.makeGlobID(path, glob) + + if !r.server.user.HasFilePermission(path, "readfiles") { + logger.Error(r.server.user, "No permission to read file", path, globID) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + return + } + + r.readFile(ctx, path, globID, regex) +} + +func (r *readCommand) readFile(ctx context.Context, path, globID, regex string) { + logger.Info(r.server.user, "Start reading file", path, globID) + + var reader fs.FileReader + switch r.mode { + case omode.TailClient: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + case omode.GrepClient, omode.CatClient: + reader = fs.NewCatFile(path, globID, r.server.serverMessages, r.server.catLimiter) + default: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + } + + lines := r.server.lines + + // Plug in mappreduce engine + if r.server.aggregate != nil { + lines = r.server.aggregate.Lines + } + + for { + if err := reader.Start(ctx, lines, regex); err != nil { + logger.Error(r.server.user, path, globID, err) + } + + select { + case <-ctx.Done(): + return + default: + if !reader.Retry() { + return + } + } + + time.Sleep(time.Second * 2) + logger.Info(path, globID, "Reading file again") + } +} + +func (r *readCommand) makeGlobID(path, glob string) string { + var idParts []string + pathParts := strings.Split(path, "/") + + for i, globPart := range strings.Split(glob, "/") { + if strings.Contains(globPart, "*") { + idParts = append(idParts, pathParts[i]) + } + } + + if len(idParts) > 0 { + return strings.Join(idParts, "/") + } + + if len(pathParts) > 0 { + return pathParts[len(pathParts)-1] + } + + r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob)) + return "" +} diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go new file mode 100644 index 0000000..e260060 --- /dev/null +++ b/internal/server/handlers/runcommand.go @@ -0,0 +1,73 @@ +package handlers + +import ( + "context" + "fmt" + "os/exec" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/run" +) + +type runCommand struct { + server *ServerHandler + run run.Run +} + +func newRunCommand(server *ServerHandler) runCommand { + return runCommand{ + server: server, + } +} + +func (r runCommand) Start(ctx context.Context, argc int, args []string) { + if argc < 2 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + commands := strings.Split(strings.Join(args[1:], " "), ";") + r.start(ctx, commands) +} + +func (r runCommand) start(ctx context.Context, commands []string) { + for _, command := range commands { + command = strings.TrimSpace(command) + if len(command) == 0 { + continue + } + splitted := strings.Split(command, " ") + path := splitted[0] + args := splitted[1:] + + qualifiedPath, err := exec.LookPath(path) + if err != nil { + logger.Error(r.server.user, err) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") { + logger.Error(r.server.user, "No permission to execute path", qualifiedPath) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + r.run = run.New(qualifiedPath, args) + pid, ec, err := r.run.Start(ctx, r.server.lines) + + if err != nil { + message := fmt.Sprintf("Unable to execute remote command '%s'", command) + logger.Error(r.server.user, message, ec, pid, err) + r.server.sendServerMessage(logger.Error(message, ec, pid, err)) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", ec)) + return + } + + message := fmt.Sprintf("Remote process '%d' exited with status '%d'", pid, ec) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) + r.server.sendServerMessage(logger.Info("run", pid, ec, message)) + } +} diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index bed8609..3f0d6ce 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -1,17 +1,19 @@ package handlers import ( + "context" + "encoding/base64" + "errors" "fmt" "io" "os" - "path/filepath" "strings" "sync" "time" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" user "github.com/mimecast/dtail/internal/user/server" @@ -26,51 +28,33 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - // Local log file readers - fileReaders []fs.FileReader - fileReadersMtx *sync.Mutex - // Channel for read lines. - lines chan fs.LineRead - // Only process log lines matching this regex. - regex string - // Server side mapr log aggregation. - aggregate *server.Aggregate - // Channel of aggregated log lines. + mutex *sync.Mutex + lines chan line.Line + regex string + aggregate *server.Aggregate aggregatedMessages chan string - // Channel for server messages to be sent to the client. - serverMessages chan string - // Channel for hidden messages to be sent to the client. - hiddenMessages chan string - // The current payload sent to the client. - payload []byte - // The current server hostname. - hostname string - // The user connecting to dtail. - user *user.User - // To limit the server wide max amount of concurrent cats - catLimiter chan struct{} - // To limit the server wide max amount of concurrent tails - tailLimiter chan struct{} - // Server can tell handler to stop the handler. - stop chan struct{} - // Indicate that client responded to server with "ack stop connection" - ackStopReceived chan struct{} - // Stop timeout. - stopTimeout chan struct{} + serverMessages chan string + payload []byte + hostname string + user *user.User + catLimiter chan struct{} + tailLimiter chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler { - logger.Debug(user, "Creating tail handler") +func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - fileReadersMtx: &sync.Mutex{}, - lines: make(chan fs.LineRead, 100), + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), - hiddenMessages: make(chan string, 10), - ackStopReceived: make(chan struct{}), - stopTimeout: make(chan struct{}), - stop: make(chan struct{}), + ackCloseReceived: make(chan struct{}), catLimiter: catLimiter, tailLimiter: tailLimiter, regex: ".", @@ -85,37 +69,46 @@ func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter cha s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + return &h, h.done } // Read is to send data to the dtail client via Reader interface. func (h *ServerHandler) Read(p []byte) (n int, err error) { for { select { + case message := <-h.serverMessages: + if message[0] == '.' { + // Handle hidden message (don't display to the user, interpreted by dtail client) + wholePayload := []byte(fmt.Sprintf("%s\n", message)) + n = copy(p, wholePayload) + return + } + + // Handle normal server message (display to the user) wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return + case message := <-h.aggregatedMessages: + // Send mapreduce-aggregated data as a message. data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message) - //logger.Debug("Sending aggregation data", data) wholePayload := []byte(data) n = copy(p, wholePayload) return - case message := <-h.hiddenMessages: - //logger.Debug(h.user, "Sending hidden message", message) - wholePayload := []byte(fmt.Sprintf(".%s\n", message)) - n = copy(p, wholePayload) - return + case line := <-h.lines: + // Send normal file content data as a message. serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", - h.hostname, line.TransmittedPerc, line.Count, *line.GlobID)) + h.hostname, line.TransmittedPerc, line.Count, line.SourceID)) wholePayload := append(serverInfo, line.Content[:]...) n = copy(p, wholePayload) return + case <-time.After(time.Second): + // Once in a while check whether we are done. select { - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF default: } @@ -129,7 +122,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { switch c { case ';': commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(commandStr) + h.handleCommand(h.ctx, commandStr) h.payload = nil default: h.payload = append(h.payload, c) @@ -140,210 +133,167 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { return } -// Close the server handler. -func (h *ServerHandler) Close() { - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { + logger.Debug(h.user, commandStr) - for _, reader := range h.fileReaders { - reader.Stop() + args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - if h.aggregate != nil { - h.aggregate.Close() + + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - close(h.stop) -} + if h.user.Name == config.ControlUser { + h.handleControlCommand(argc, args) + return + } -func (h *ServerHandler) makeGlobID(path, glob string) string { - var idParts []string - pathParts := strings.Split(path, "/") + h.handleUserCommand(ctx, argc, args) +} - for i, globPart := range strings.Split(glob, "/") { - if strings.Contains(globPart, "*") { - idParts = append(idParts, pathParts[i]) - } - } +func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) { + argc := len(args) - if len(idParts) > 0 { - return strings.Join(idParts, "/") + if argc <= 2 || args[0] != "protocol" { + return args, argc, errors.New("unable to determine protocol version") } - if len(pathParts) > 0 { - return pathParts[len(pathParts)-1] + if args[1] != version.ProtocolCompat { + err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1]) + return args, argc, err } - h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob)) - return "" + return args[2:], argc - 2, nil } -func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) { - retryInterval := time.Second * 5 - glob = filepath.Clean(glob) - - errors := make(chan struct{}) - stop := make(chan struct{}) - defer close(stop) +func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("Unable to decode client message") - go func() { - for { - select { - case <-errors: - h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs")) - case <-stop: - return - case <-h.stop: - return - } - } - }() + if argc != 2 || args[0] != "base64" { + return args, argc, err + } - maxRetries := 10 - for { - maxRetries-- - if maxRetries < 0 { - h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)")) - h.internalClose() - return - } + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) - paths, err := filepath.Glob(glob) - if err != nil { - logger.Warn(h.user, glob, err) - time.Sleep(retryInterval) - continue - } + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - if numPaths := len(paths); numPaths == 0 { - logger.Error(h.user, "No such file(s) to read", glob) - select { - case errors <- struct{}{}: - case <-h.stop: - return - default: - } - time.Sleep(retryInterval) - continue - } + return args, argc, nil +} - h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors) - break +func (h *ServerHandler) handleControlCommand(argc int, args []string) { + switch args[0] { + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) + default: + logger.Warn(h.user, "Received unknown command", argc, args) } } -func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) { - var wg sync.WaitGroup - wg.Add(len(paths)) +func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { + logger.Debug(h.user, "handleUserCommand", argc, args) - read := func(path string, wg *sync.WaitGroup) { - defer wg.Done() - globID := h.makeGlobID(path, glob) + switch args[0] { + case "grep", "cat": + command := newReadCommand(h, omode.CatClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - if !h.user.HasFilePermission(path) { - logger.Error(h.user, "No permission to read file", path, globID) - select { - case errors <- struct{}{}: - default: + case "tail": + command := newReadCommand(h, omode.TailClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() } + }() + + case "map": + command, aggregate, err := newMapCommand(h, argc, args) + if err != nil { + h.sendServerMessage(err.Error()) + logger.Error(h.user, err) return } - h.startReadingFile(mode, path, globID, regex) - } - - for _, path := range paths { - go read(path, &wg) - } + h.aggregate = aggregate + go func() { + command.Start(ctx, h.aggregatedMessages) + h.shutdown() + }() + + case "run": + command := newRunCommand(h) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - wg.Wait() -} + case "ack", ".ack": + h.handleAckCommand(argc, args) -func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) { - defer h.stopReadingFile(path) - logger.Info(h.user, "Start reading file", path, globID) - - var reader fs.FileReader - switch mode { - case omode.TailClient: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) - case omode.GrepClient: - fallthrough - case omode.CatClient: - reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter) default: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args)) } +} - h.fileReadersMtx.Lock() - h.fileReaders = append(h.fileReaders, reader) - h.fileReadersMtx.Unlock() - - lines := h.lines - // Plugin mappreduce engine - if h.aggregate != nil { - lines = h.aggregate.Lines +func (h *ServerHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc)) + return } - - for { - if err := reader.Start(lines, regex); err != nil { - logger.Error(h.user, path, globID, err) - } - - select { - case <-h.stop: - return - default: - if !reader.Retry() { - return - } - } - - time.Sleep(time.Second * 2) - logger.Info(path, globID, "Reading file again") + if args[1] == "close" && args[2] == "connection" { + close(h.ackCloseReceived) } } -func (h *ServerHandler) stopReadingFile(path string) { - logger.Info(h.user, "Stop reading file", path) +func (h *ServerHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.ctx.Done(): + } +} - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) sendServerMessage(message string) { + h.send(h.serverMessageC(), message) +} - path = filepath.Clean(path) - var fileReaders []fs.FileReader +func (h *ServerHandler) serverMessageC() chan<- string { + return h.serverMessages +} - for _, reader := range h.fileReaders { - if reader.FilePath() == path { - reader.Stop() - continue - } - fileReaders = append(fileReaders, reader) - } +func (h *ServerHandler) flush() { + logger.Debug(h.user, "flush()") - if len(fileReaders) == len(h.fileReaders) { - logger.Warn(h.user, "Didn't read file path", path) - return + if h.aggregate != nil { + h.aggregate.Flush() } - h.fileReaders = fileReaders - - if len(fileReaders) == 0 { - if h.aggregate != nil { - h.aggregate.Serialize() - } - h.allLinesSent() + unsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) } -} - -func (h *ServerHandler) numUnsentMessages() int { - return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages) -} - -func (h *ServerHandler) allLinesSent() { - defer h.internalClose() for i := 0; i < 3; i++ { - if h.numUnsentMessages() == 0 { + if unsentMessages() == 0 { logger.Debug(h.user, "All lines sent") return } @@ -351,142 +301,43 @@ func (h *ServerHandler) allLinesSent() { time.Sleep(time.Second) } - logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages()) + logger.Warn(h.user, "Some lines remain unsent", unsentMessages()) } -// Handler decides to shutdown the connection, not the server itself. -func (h *ServerHandler) internalClose() { - select { - case h.hiddenMessages <- "syn close connection": - case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - return - } +func (h *ServerHandler) shutdown() { + logger.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessageC() <- ".syn close connection": + case <-h.ctx.Done(): + } + }() select { - case <-h.Wait(): + case <-h.ackCloseReceived: case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - } -} - -func (h *ServerHandler) handleCommand(commandStr string) { - logger.Info(h.user, commandStr) - - args := strings.Split(commandStr, " ") - argc := len(args) - - logger.Debug(h.user, "Received command", commandStr, argc, args) - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return + logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.ctx.Done(): } - h.handleUserCommand(argc, args) -} - -// Special (restricted) set of commands for anonymous ControlUser access. -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "ping": - h.send(h.hiddenMessages, "pong") - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) - default: - logger.Warn(h.user, "Received unknown command", argc, args) - } -} - -// Commands for authed users. -func (h *ServerHandler) handleUserCommand(argc int, args []string) { - switch args[0] { - case "grep": - fallthrough - case "cat": - h.handleReadCommand(argc, args, omode.CatClient) - case "tail": - h.handleReadCommand(argc, args, omode.TailClient) - case "map": - h.handleMapCommand(argc, args) - case "ack": - h.handleAckCommand(argc, args) - case "ping": - h.send(h.hiddenMessages, "pong") - case "version": - h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String())) - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args)) + select { + case h.done <- struct{}{}: default: - h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args)) } } -func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) { - regex := "." - if argc >= 4 { - regex = args[3] - } - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - go h.processFileGlob(mode, args[1], regex) +func (h *ServerHandler) incrementActiveReaders() { + // TODO: Use atomic counter variable instead, so we can get rid of the mutex + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders++ } +func (h *ServerHandler) decrementActiveReaders() int { + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders-- -func (h *ServerHandler) handleMapCommand(argc int, args []string) { - if argc < 2 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - - queryStr := strings.Join(args[1:], " ") - logger.Info(h.user, "Creating new mapr aggregator", queryStr) - aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr) - - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - h.aggregate = aggregate -} - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - if args[1] == "close" && args[2] == "connection" { - close(h.ackStopReceived) - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.stop: - } -} - -// Wait (block) until server handler is closed or a timeout has exceeded. -func (h *ServerHandler) Wait() <-chan struct{} { - wait := make(chan struct{}) - - go func() { - select { - case <-h.ackStopReceived: - logger.Debug(h.user, "Closing wait channel due to ACK stop received") - close(wait) - case <-h.stopTimeout: - logger.Debug(h.user, "Closing wait channel due to wait timeout") - close(wait) - case <-h.stop: - logger.Debug(h.user, "Closing wait channel due to stop") - } - }() - - return wait + return h.activeReaders } diff --git a/internal/server/server.go b/internal/server/server.go index 27a98f5..42eb74c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,14 @@ package server import ( + "context" "errors" "fmt" "io" "net" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -26,8 +27,6 @@ type Server struct { catLimiterCh chan struct{} // To control the max amount of concurrent tails tailLimiterCh chan struct{} - // Ask to shutdown the server - stop chan struct{} } // New returns a new server. @@ -38,7 +37,6 @@ func New() *Server { sshServerConfig: &gossh.ServerConfig{}, catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), - stop: make(chan struct{}), } s.sshServerConfig.PasswordCallback = s.controlUserCallback @@ -54,7 +52,7 @@ func New() *Server { } // Start the server. -func (s *Server) Start() int { +func (s *Server) Start(ctx context.Context) int { logger.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) @@ -64,7 +62,7 @@ func (s *Server) Start() int { logger.FatalExit("Failed to open listening TCP socket", err) } - go s.stats.periodicLogServerStats(s.stop) + go s.stats.periodicLogServerStats(ctx) for { conn, err := listener.Accept() // Blocking @@ -79,11 +77,11 @@ func (s *Server) Start() int { continue } - go s.handleConnection(conn) + go s.handleConnection(ctx, conn) } } -func (s *Server) handleConnection(conn net.Conn) { +func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { logger.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) @@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) { go gossh.DiscardRequests(reqs) for newChannel := range chans { - go s.handleChannel(sshConn, newChannel) + go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) logger.Info(user, "Invoking channel handler") @@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) return } - if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { logger.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { logger.Info(user, "Invoking request handler") for req := range in { @@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch switch req.Type { case "shell": + handlerCtx, cancel := context.WithCancel(ctx) + var handler handlers.Handler + var done <-chan struct{} + switch user.Name { case config.ControlUser: - handler = handlers.NewControlHandler(user) + handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) } - // Bi-directionally connect SSH stream to SSH handler - brokenPipe1 := make(chan struct{}) go func() { - defer close(brokenPipe1) + // Handler finished work, cancel all remaining routines + defer cancel() + <-done + }() + + go func() { + // Broken pipe, cancel + defer cancel() + io.Copy(channel, handler) }() - brokenPipe2 := make(chan struct{}) go func() { - defer close(brokenPipe2) + // Broken pipe, cancel + defer cancel() + io.Copy(handler, channel) }() - // Ensure to close all fd's and stop all goroutines once ssh connection terminated go func() { - defer s.stats.decrementConnections() - defer handler.Close() + defer cancel() if err := sshConn.Wait(); err != nil && err != io.EOF { logger.Error(user, err) } + s.stats.decrementConnections() logger.Info(user, "Good bye Mister!") }() - // Close the underlying ssh socket when server shuts down go func() { - select { - case <-s.stop: - logger.Debug(user, "Server initiating shutdown on handler") - case <-handler.Wait(): - logger.Debug(user, "Handler initiating shutdown by its own") - case <-brokenPipe1: - logger.Debug(user, "Broken pipe1") - case <-brokenPipe2: - logger.Debug(user, "Broken pipe2") - } + <-handlerCtx.Done() sshConn.Close() logger.Info(user, "Closed SSH connection") }() @@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g return nil, fmt.Errorf("Not authorized") } - -// Stop the server. -func (s *Server) Stop() { - close(s.stop) - s.stats.waitForConnections() -} diff --git a/internal/server/stats.go b/internal/server/stats.go index beb1885..4d661f7 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -1,12 +1,14 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various server stats. @@ -65,12 +67,12 @@ func (s *stats) serverLimitExceeded() error { return nil } -func (s *stats) periodicLogServerStats(stop <-chan struct{}) { +func (s *stats) periodicLogServerStats(ctx context.Context) { for { select { case <-time.NewTimer(time.Second * 10).C: s.logServerStats() - case <-stop: + case <-ctx.Done(): return } } |
