From e0f4ccc46c8601f322640b72e100f973a837ef02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20B=C3=BCtow?= Date: Sun, 16 Feb 2020 18:07:36 +0000 Subject: server kills subprocesses correctly on cancel --- internal/server/handlers/runcommand.go | 4 ++ internal/server/handlers/serverhandler.go | 73 +++++++++++++++++++++++-------- internal/server/server.go | 46 ++++++++++++++++--- 3 files changed, 99 insertions(+), 24 deletions(-) (limited to 'internal/server') diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go index 120c402..95db52f 100644 --- a/internal/server/handlers/runcommand.go +++ b/internal/server/handlers/runcommand.go @@ -92,4 +92,8 @@ func (r runCommand) start(ctx context.Context, command string) { r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, ec))) + + logger.Debug(r.server.user, "Waiting for Pgroup to be killed") + <-r.run.PgroupKilled() + logger.Debug(r.server.user, "Pgroup killed") } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index b840c77..946ae83 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "os" + "strconv" "strings" "sync" "time" @@ -37,28 +38,31 @@ type ServerHandler struct { payload []byte hostname string user *user.User - catLimiter chan struct{} - tailLimiter chan struct{} - ackCloseReceived chan struct{} - ctx context.Context - done chan struct{} - activeReaders int + // TODO: Move all these channels into a separate struct for readability! + catLimiter chan struct{} + tailLimiter chan struct{} + globalServerWaitFor chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { +func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - ctx: ctx, - done: make(chan struct{}), - mutex: &sync.Mutex{}, - lines: make(chan line.Line, 100), - serverMessages: make(chan string, 10), - aggregatedMessages: make(chan string, 10), - ackCloseReceived: make(chan struct{}), - catLimiter: catLimiter, - tailLimiter: tailLimiter, - regex: ".", - user: user, + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + aggregatedMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + catLimiter: catLimiter, + tailLimiter: tailLimiter, + globalServerWaitFor: globalServerWaitFor, + regex: ".", + user: user, } fqdn, err := os.Hostname() @@ -135,6 +139,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { logger.Debug(h.user, commandStr) + var timeout time.Duration args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) if err != nil { @@ -148,11 +153,28 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { return } + args, argc, timeout, err = h.handleTimeout(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return + } + if h.user.Name == config.ControlUser { h.handleControlCommand(argc, args) return } + if timeout > 0 { + logger.Debug("Command with timeout context", argc, args, timeout) + commandCtx, cancel := context.WithTimeout(ctx, timeout) + go func() { + <-commandCtx.Done() + cancel() + }() + h.handleUserCommand(commandCtx, argc, args) + return + } + h.handleUserCommand(ctx, argc, args) } @@ -191,6 +213,16 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er return args, argc, nil } +func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) { + if argc <= 2 || args[0] != "timeout" { + // No timeout specified + return args, argc, time.Duration(0) * time.Second, nil + } + + timeout, err := strconv.Atoi(args[1]) + return args[2:], argc - 2, time.Duration(timeout) * time.Second, err +} + func (h *ServerHandler) handleControlCommand(argc int, args []string) { switch args[0] { case "debug": @@ -241,8 +273,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] case "run": command := newRunCommand(h) h.incrementActiveReaders() + go func() { h.globalServerWaitFor <- struct{}{} }() go func() { command.Start(ctx, argc, args) + <-h.globalServerWaitFor if h.decrementActiveReaders() == 0 { h.shutdown() } @@ -332,12 +366,13 @@ func (h *ServerHandler) incrementActiveReaders() { // TODO: Use atomic counter variable instead, so we can get rid of the mutex h.mutex.Lock() defer h.mutex.Unlock() + h.activeReaders++ } func (h *ServerHandler) decrementActiveReaders() int { h.mutex.Lock() defer h.mutex.Unlock() - h.activeReaders-- + h.activeReaders-- return h.activeReaders } diff --git a/internal/server/server.go b/internal/server/server.go index 9b9d8a0..9314468 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "time" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/logger" @@ -24,11 +25,13 @@ type Server struct { // SSH server configuration. sshServerConfig *gossh.ServerConfig // To control the max amount of concurrent cats (which can cause a lot of I/O on the server) - catLimiterCh chan struct{} + catLimiter chan struct{} // To control the max amount of concurrent tails - tailLimiterCh chan struct{} + tailLimiter chan struct{} // To run scheduled tasks (if configured) sched *scheduler + // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed. + shutdownWaitFor chan struct{} } // New returns a new server. @@ -37,8 +40,9 @@ func New() *Server { s := Server{ sshServerConfig: &gossh.ServerConfig{}, - catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), - tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), + catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats), + tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails), + shutdownWaitFor: make(chan struct{}, 1000), sched: newScheduler(), } @@ -60,6 +64,7 @@ func (s *Server) Start(ctx context.Context) int { bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) logger.Info("Binding server", bindAt) + listener, err := net.Listen("tcp", bindAt) if err != nil { logger.FatalExit("Failed to open listening TCP socket", err) @@ -67,10 +72,40 @@ func (s *Server) Start(ctx context.Context) int { go s.stats.start(ctx) go s.sched.start(ctx) + go s.listenerLoop(ctx, listener) + + select { + case <-ctx.Done(): + // Wait until all commands/jobs/children are no more! + s.wait() + } + + // For future use. + return 0 +} + +func (s *Server) wait() { + for { + num := len(s.shutdownWaitFor) + logger.Debug("Waiting for stuff to finish", num) + if num <= 0 { + return + } + time.Sleep(time.Second) + } +} + +func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { + logger.Debug("Starting listener loop") for { conn, err := listener.Accept() // Blocking if err != nil { + select { + case <-ctx.Done(): + return + default: + } logger.Error("Failed to accept incoming connection", err) continue } @@ -143,12 +178,13 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch case config.ControlUser: handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor) } go func() { // Handler finished work, cancel all remaining routines defer cancel() + <-done }() -- cgit v1.2.3