diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-02-16 18:07:36 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-02-16 18:07:36 +0000 |
| commit | e0f4ccc46c8601f322640b72e100f973a837ef02 (patch) | |
| tree | 61a1fcf66daea222da19500b0b6ae60d1e89a5d9 /internal/server/handlers/serverhandler.go | |
| parent | 6bca637513e065a33cadaccad97ada25eb7a6b00 (diff) | |
server kills subprocesses correctly on cancel
Diffstat (limited to 'internal/server/handlers/serverhandler.go')
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 73 |
1 files changed, 54 insertions, 19 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index b840c77..946ae83 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "sync" "time" @@ -37,28 +38,31 @@ type ServerHandler struct { payload []byte hostname string user *user.User - catLimiter chan struct{} - tailLimiter chan struct{} - ackCloseReceived chan struct{} - ctx context.Context - done chan struct{} - activeReaders int + // TODO: Move all these channels into a separate struct for readability! + catLimiter chan struct{} + tailLimiter chan struct{} + globalServerWaitFor chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { +func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - ctx: ctx, - done: make(chan struct{}), - mutex: &sync.Mutex{}, - lines: make(chan line.Line, 100), - serverMessages: make(chan string, 10), - aggregatedMessages: make(chan string, 10), - ackCloseReceived: make(chan struct{}), - catLimiter: catLimiter, - tailLimiter: tailLimiter, - regex: ".", - user: user, + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + aggregatedMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + catLimiter: catLimiter, + tailLimiter: tailLimiter, + globalServerWaitFor: globalServerWaitFor, + regex: ".", + user: user, } fqdn, err := os.Hostname() @@ -135,6 +139,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { logger.Debug(h.user, commandStr) + var timeout time.Duration args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) if err != nil { @@ -148,11 +153,28 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { return } + args, argc, timeout, err = h.handleTimeout(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return + } + if h.user.Name == config.ControlUser { h.handleControlCommand(argc, args) return } + if timeout > 0 { + logger.Debug("Command with timeout context", argc, args, timeout) + commandCtx, cancel := context.WithTimeout(ctx, timeout) + go func() { + <-commandCtx.Done() + cancel() + }() + h.handleUserCommand(commandCtx, argc, args) + return + } + h.handleUserCommand(ctx, argc, args) } @@ -191,6 +213,16 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er return args, argc, nil } +func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) { + if argc <= 2 || args[0] != "timeout" { + // No timeout specified + return args, argc, time.Duration(0) * time.Second, nil + } + + timeout, err := strconv.Atoi(args[1]) + return args[2:], argc - 2, time.Duration(timeout) * time.Second, err +} + func (h *ServerHandler) handleControlCommand(argc int, args []string) { switch args[0] { case "debug": @@ -241,8 +273,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] case "run": command := newRunCommand(h) h.incrementActiveReaders() + go func() { h.globalServerWaitFor <- struct{}{} }() go func() { command.Start(ctx, argc, args) + <-h.globalServerWaitFor if h.decrementActiveReaders() == 0 { h.shutdown() } @@ -332,12 +366,13 @@ func (h *ServerHandler) incrementActiveReaders() { // TODO: Use atomic counter variable instead, so we can get rid of the mutex h.mutex.Lock() defer h.mutex.Unlock() + h.activeReaders++ } func (h *ServerHandler) decrementActiveReaders() int { h.mutex.Lock() defer h.mutex.Unlock() - h.activeReaders-- + h.activeReaders-- return h.activeReaders } |
