diff options
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 70 |
1 files changed, 42 insertions, 28 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 739696c..939388c 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/mimecast/dtail/internal/config" @@ -30,7 +31,6 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - mutex *sync.Mutex lines chan line.Line regex string aggregate *server.Aggregate @@ -47,7 +47,8 @@ type ServerHandler struct { serverCtx context.Context handlerCtx context.Context done chan struct{} - activeCommands int + activeCommands int32 + activeReaders int32 background background.Background } @@ -57,7 +58,6 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca serverCtx: serverCtx, handlerCtx: handlerCtx, done: make(chan struct{}), - mutex: &sync.Mutex{}, lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), @@ -170,10 +170,11 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { } if timeout > 0 { - logger.Debug("Command with timeout context", argc, args, timeout) + logger.Info(h.user, "Command with timeout context", argc, args, timeout) commandCtx, cancel := context.WithTimeout(ctx, timeout) go func() { <-commandCtx.Done() + logger.Info(h.user, "Command timed out, canceling it", args, args, timeout) cancel() }() h.handleUserCommand(commandCtx, argc, args, timeout) @@ -241,11 +242,19 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] logger.Debug(h.user, "handleUserCommand", argc, args) h.incrementActiveCommands() - finished := func() { + commandFinished := func() { if h.decrementActiveCommands() == 0 { h.shutdown() } } + readerFinished := func() { + if h.decrementActiveReaders() == 0 { + if h.aggregate == nil { + return + } + h.aggregate.Cancel() + } + } splitted := strings.Split(args[0], ":") commandName := splitted[0] @@ -253,24 +262,27 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] options, err := readOptions(splitted[1:]) if err != nil { h.sendServerMessage(logger.Error(h.user, err)) - finished() + commandFinished() return } switch commandName { case "grep", "cat": command := newReadCommand(h, omode.CatClient) - h.incrementActiveCommands() go func() { + h.incrementActiveReaders() command.Start(ctx, argc, args) - finished() + readerFinished() + commandFinished() }() case "tail": command := newReadCommand(h, omode.TailClient) go func() { + h.incrementActiveReaders() command.Start(ctx, argc, args) - finished() + readerFinished() + commandFinished() }() case "map": @@ -278,14 +290,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err != nil { h.sendServerMessage(err.Error()) logger.Error(h.user, err) - finished() + commandFinished() return } h.aggregate = aggregate go func() { command.Start(ctx, h.aggregatedMessages) - finished() + commandFinished() }() case "run": @@ -301,7 +313,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] } else { h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName)) } - finished() + commandFinished() return } @@ -313,7 +325,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] count++ } h.sendServerMessage(fmt.Sprintf("Found %d jobs", count)) - finished() + commandFinished() return } @@ -339,7 +351,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil { h.sendServerMessage(logger.Error(h.user, err, jobName, args)) - finished() + commandFinished() return } ctx = commandCtx @@ -347,7 +359,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil { h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err)) - finished() + commandFinished() return } @@ -360,21 +372,21 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if background { h.sendServerMessage(logger.Info(h.user, jobName, "job started in background")) - finished() + commandFinished() return } // Command run in foreground, wait for it to complete before finishing the connection. wg.Wait() - finished() + commandFinished() case "ack", ".ack": h.handleAckCommand(argc, args) - finished() + commandFinished() default: h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options)) - finished() + commandFinished() } } @@ -450,19 +462,21 @@ func (h *ServerHandler) shutdown() { } func (h *ServerHandler) incrementActiveCommands() { - // TODO: Use atomic counter variable instead, so we can get rid of the mutex - h.mutex.Lock() - defer h.mutex.Unlock() + atomic.AddInt32(&h.activeCommands, 1) +} - h.activeCommands++ +func (h *ServerHandler) decrementActiveCommands() int32 { + atomic.AddInt32(&h.activeCommands, -1) + return atomic.LoadInt32(&h.activeCommands) } -func (h *ServerHandler) decrementActiveCommands() int { - h.mutex.Lock() - defer h.mutex.Unlock() +func (h *ServerHandler) incrementActiveReaders() { + atomic.AddInt32(&h.activeReaders, 1) +} - h.activeCommands-- - return h.activeCommands +func (h *ServerHandler) decrementActiveReaders() int32 { + atomic.AddInt32(&h.activeReaders, -1) + return atomic.LoadInt32(&h.activeReaders) } func readOptions(opts []string) (map[string]string, error) { |
