diff options
| -rw-r--r-- | internal/clients/baseclient.go | 2 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 23 | ||||
| -rw-r--r-- | internal/clients/handlers/clienthandler.go | 5 | ||||
| -rw-r--r-- | internal/clients/handlers/handler.go | 3 | ||||
| -rw-r--r-- | internal/clients/handlers/healthhandler.go | 17 | ||||
| -rw-r--r-- | internal/clients/handlers/maprhandler.go | 5 | ||||
| -rw-r--r-- | internal/clients/handlers/withcancel.go | 24 | ||||
| -rw-r--r-- | internal/clients/healthclient.go | 2 | ||||
| -rw-r--r-- | internal/clients/remote/connection.go | 7 | ||||
| -rw-r--r-- | internal/done.go | 32 | ||||
| -rw-r--r-- | internal/mapr/server/aggregate.go | 80 | ||||
| -rw-r--r-- | internal/server/handlers/controlhandler.go | 26 | ||||
| -rw-r--r-- | internal/server/handlers/handler.go | 2 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 75 | ||||
| -rw-r--r-- | internal/server/server.go | 39 |
15 files changed, 191 insertions, 151 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 008a01e..d8d4fde 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -99,7 +99,7 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con defer func() { <-active }() for { - connCtx, cancel := conn.Handler.WithCancel(ctx) + connCtx, cancel := context.WithCancel(ctx) defer cancel() conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 65bbfd7..b5045e2 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -8,12 +8,13 @@ import ( "strings" "time" + "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/version" ) type baseHandler struct { - withCancel + done *internal.Done server string shellStarted bool commands chan string @@ -29,6 +30,14 @@ func (h *baseHandler) Status() int { return h.status } +func (h *baseHandler) Done() <-chan struct{} { + return h.done.Done() +} + +func (h *baseHandler) Shutdown() { + h.done.Shutdown() +} + // SendMessage to the server. func (h *baseHandler) SendMessage(command string) error { encoded := base64.StdEncoding.EncodeToString([]byte(command)) @@ -38,7 +47,8 @@ func (h *baseHandler) SendMessage(command string) error { case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): case <-time.After(time.Second * 5): return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) - case <-h.ctx.Done(): + case <-h.Done(): + return nil } return nil @@ -65,7 +75,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { select { case command := <-h.commands: n = copy(p, []byte(command)) - case <-h.ctx.Done(): + case <-h.Done(): return 0, io.EOF } return @@ -95,10 +105,11 @@ func (h *baseHandler) handleHiddenMessage(message string) { case strings.HasPrefix(message, ".syn close connection"): h.SendMessage(".ack close connection") select { - case <-time.After(time.Second * 1): + case <-time.After(time.Second * 5): logger.Debug("Shutting down client after timeout and sending ack to server") - h.withCancel.shutdown() - case <-h.ctx.Done(): + h.Shutdown() + case <-h.Done(): + return } case strings.HasPrefix(message, ".run exitstatus"): diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index fcd8052..2bcb038 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -1,6 +1,7 @@ package handlers import ( + "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/io/logger" ) @@ -19,9 +20,7 @@ func NewClientHandler(server string) *ClientHandler { shellStarted: false, commands: make(chan string), status: -1, - withCancel: withCancel{ - done: make(chan struct{}), - }, + done: internal.NewDone(), }, } } diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go index c53ca34..afa87e2 100644 --- a/internal/clients/handlers/handler.go +++ b/internal/clients/handlers/handler.go @@ -1,7 +1,6 @@ package handlers import ( - "context" "io" ) @@ -11,6 +10,6 @@ type Handler interface { SendMessage(command string) error Server() string Status() int - WithCancel(ctx context.Context) (context.Context, context.CancelFunc) + Shutdown() Done() <-chan struct{} } diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index 9051015..95693ab 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -4,11 +4,13 @@ import ( "errors" "fmt" "time" + + "github.com/mimecast/dtail/internal" ) // HealthHandler implements the handler required for health checks. type HealthHandler struct { - withCancel + done *internal.Done // Buffer of incoming data from server. receiveBuf []byte // To send commands to the server. @@ -27,9 +29,7 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler { receive: receive, commands: make(chan string), status: -1, - withCancel: withCancel{ - done: make(chan struct{}), - }, + done: internal.NewDone(), } return &h @@ -45,12 +45,21 @@ func (h *HealthHandler) Status() int { return h.status } +func (h *HealthHandler) Done() <-chan struct{} { + return h.done.Done() +} + +func (h *HealthHandler) Shutdown() { + h.done.Shutdown() +} + // SendMessage sends a DTail command to the server. func (h *HealthHandler) SendMessage(command string) error { select { case h.commands <- fmt.Sprintf("%s;", command): case <-time.NewTimer(time.Second * 10).C: return errors.New("Timed out sending command " + command) + case <-h.Done(): } return nil diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index b908f3b..fb71c8f 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -3,6 +3,7 @@ package handlers import ( "strings" + "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/client" @@ -24,9 +25,7 @@ func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGr shellStarted: false, commands: make(chan string), status: -1, - withCancel: withCancel{ - done: make(chan struct{}), - }, + done: internal.NewDone(), }, query: query, aggregate: client.NewAggregate(server, query, globalGroup), diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go deleted file mode 100644 index 7c9cf4e..0000000 --- a/internal/clients/handlers/withcancel.go +++ /dev/null @@ -1,24 +0,0 @@ -package handlers - -import "context" - -type withCancel struct { - ctx context.Context - done chan struct{} -} - -// WithCancel sets and returns the context used. -func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) { - cancelCtx, cancel := context.WithCancel(ctx) - w.ctx = cancelCtx - - return cancelCtx, cancel -} - -func (w *withCancel) Done() <-chan struct{} { - return w.done -} - -func (w *withCancel) shutdown() { - close(w.done) -} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index 7313583..e93f6be 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -50,7 +50,7 @@ func (c *HealthClient) Start(ctx context.Context) (status int) { conn.Handler = handlers.NewHealthHandler(c.server, receive) conn.Commands = []string{c.mode.String()} - connCtx, cancel := conn.Handler.WithCancel(ctx) + connCtx, cancel := context.WithCancel(ctx) go conn.Start(connCtx, cancel, throttleCh, statsCh) for { diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index 2d97d14..b29ffed 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -177,21 +177,21 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess } go func() { - defer cancel() io.Copy(stdinPipe, c.Handler) + cancel() }() go func() { - defer cancel() io.Copy(c.Handler, stdoutPipe) + cancel() }() go func() { - defer cancel() select { case <-c.Handler.Done(): case <-ctx.Done(): } + cancel() }() // Send all commands to client. @@ -207,5 +207,6 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess } <-ctx.Done() + c.Handler.Shutdown() return nil } diff --git a/internal/done.go b/internal/done.go new file mode 100644 index 0000000..2326eee --- /dev/null +++ b/internal/done.go @@ -0,0 +1,32 @@ +package internal + +import ( + "sync" +) + +type Done struct { + ch chan struct{} + mutex sync.Mutex +} + +func NewDone() *Done { + return &Done{ + ch: make(chan struct{}), + } +} + +func (d *Done) Done() <-chan struct{} { + return d.ch +} + +func (d *Done) Shutdown() { + d.mutex.Lock() + defer d.mutex.Unlock() + + select { + case <-d.ch: + return + default: + close(d.ch) + } +} diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go index 1028943..cd59b63 100644 --- a/internal/mapr/server/aggregate.go +++ b/internal/mapr/server/aggregate.go @@ -6,6 +6,7 @@ import ( "strings" "time" + "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/io/logger" @@ -15,6 +16,7 @@ import ( // Aggregate is for aggregating mapreduce data on the DTail server side. type Aggregate struct { + done *internal.Done // Log lines to process (parsing MAPREDUCE lines). Lines chan line.Line // Hostname of the current server (used to populate $hostname field). @@ -23,12 +25,12 @@ type Aggregate struct { serialize chan struct{} // Signals to flush data. flush chan struct{} + // Signals that data has been flushed + flushed chan struct{} // The mapr query query *mapr.Query // The mapr log format parser parser *logformat.Parser - cancel context.CancelFunc - ctx context.Context } // NewAggregate return a new server side aggregator. @@ -64,56 +66,63 @@ func NewAggregate(queryStr string) (*Aggregate, error) { } } - ctx, cancel := context.WithCancel(context.Background()) - a := Aggregate{ + done: internal.NewDone(), Lines: make(chan line.Line, 100), serialize: make(chan struct{}), flush: make(chan struct{}), + flushed: make(chan struct{}), hostname: s[0], query: query, parser: logParser, - ctx: ctx, - cancel: cancel, } return &a, nil } +func (a *Aggregate) Shutdown() { + a.Flush() + a.done.Shutdown() +} + // Start an aggregation. func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { - defer a.cancel() - fieldsCh := a.linesToFields(ctx) + myCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + select { + case <-myCtx.Done(): + a.done.Shutdown() + case <-a.done.Done(): + cancel() + } + }() + + fieldsCh := a.makeFields(myCtx) // Add fields (e.g. via 'set' clause) if len(a.query.Set) > 0 { - fieldsCh = a.addMoreFields(ctx, fieldsCh) + fieldsCh = a.addFields(myCtx, fieldsCh) } - go a.fieldsToMaprLines(ctx, fieldsCh, maprLines) - a.periodicAggregateTimer(ctx) + go a.aggregateTimer(myCtx) + a.makeMaprLines(myCtx, fieldsCh, maprLines) } -// Cancel the aggregation. -func (a *Aggregate) Cancel() { - a.cancel() -} - -func (a *Aggregate) periodicAggregateTimer(ctx context.Context) { +func (a *Aggregate) aggregateTimer(ctx context.Context) { for { select { case <-time.After(a.query.Interval): a.Serialize(ctx) case <-ctx.Done(): return - case <-a.ctx.Done(): - return } } } -func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string { +func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string { ch := make(chan map[string]string) go func() { @@ -144,8 +153,6 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string } case <-ctx.Done(): return - case <-a.ctx.Done(): - return } } }() @@ -153,14 +160,14 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string return ch } -func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string { +func (a *Aggregate) addFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string { ch := make(chan map[string]string) go func() { defer close(ch) for { - // fieldsCh will be closed via 'linesToFields' if ctx is done + // fieldsCh will be closed via 'makeFields' if ctx is done fields, ok := <-fieldsCh if !ok { return @@ -179,7 +186,7 @@ func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[strin return ch } -func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { +func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { group := mapr.NewGroupSet() serialize := func() { @@ -200,18 +207,10 @@ func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[s case <-a.serialize: serialize() case <-a.flush: - logger.Info("Flushing mapreduce result") serialize() - a.flush <- struct{}{} - logger.Info("Done flushing mapreduce result") + a.flushed <- struct{}{} case <-ctx.Done(): return - case <-a.ctx.Done(): - logger.Info("Flushing mapreduce result") - serialize() - a.flush <- struct{}{} - logger.Info("Done flushing mapreduce result") - return } } } @@ -254,6 +253,8 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { func (a *Aggregate) Serialize(ctx context.Context) { select { case a.serialize <- struct{}{}: + case <-time.After(time.Minute): + logger.Warn("Starting to serialize mapredice data takes over a minute") case <-ctx.Done(): } } @@ -261,15 +262,20 @@ func (a *Aggregate) Serialize(ctx context.Context) { // Flush all data. func (a *Aggregate) Flush() { select { - case <-a.ctx.Done(): - return case a.flush <- struct{}{}: + logger.Info("Flushing mapreduce data") case <-time.After(time.Minute): + logger.Warn("Starting to flush mapreduce data takes over a minute") + return + case <-a.done.Done(): return } select { - case <-a.flush: + case <-a.flushed: + logger.Info("Done flushing") case <-time.After(time.Minute): + logger.Warn("Waiting for data to be flushed takes over a minute") + case <-a.done.Done(): } } diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go index daa9835..9a8eb75 100644 --- a/internal/server/handlers/controlhandler.go +++ b/internal/server/handlers/controlhandler.go @@ -1,20 +1,19 @@ package handlers import ( - "context" "fmt" "io" "os" "strings" + "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" ) // ControlHandler is used for control functions and health monitoring. type ControlHandler struct { - ctx context.Context - done chan struct{} + done *internal.Done hostname string payload []byte serverMessages chan string @@ -22,12 +21,11 @@ type ControlHandler struct { } // NewControlHandler returns a new control handler. -func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) { +func NewControlHandler(user *user.User) *ControlHandler { logger.Debug(user, "Creating control handler") h := ControlHandler{ - ctx: ctx, - done: make(chan struct{}), + done: internal.NewDone(), serverMessages: make(chan string, 10), user: user, } @@ -40,7 +38,15 @@ func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, < s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h, h.done + return &h +} + +func (h *ControlHandler) Shutdown() { + h.done.Shutdown() +} + +func (h *ControlHandler) Done() <-chan struct{} { + return h.done.Done() } // Read is to send data to the client via the Reader interface. @@ -51,7 +57,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) { wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return - case <-h.ctx.Done(): + case <-h.done.Done(): return 0, io.EOF } } @@ -63,7 +69,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { switch c { case ';': wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(h.ctx, wholePayload) + h.handleCommand(wholePayload) h.payload = nil default: @@ -75,7 +81,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { return } -func (h *ControlHandler) handleCommand(ctx context.Context, command string) { +func (h *ControlHandler) handleCommand(command string) { logger.Info(h.user, command) s := strings.Split(command, " ") logger.Debug(h.user, "Receiving command", command, s) diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go index c42ceb9..b04e854 100644 --- a/internal/server/handlers/handler.go +++ b/internal/server/handlers/handler.go @@ -5,4 +5,6 @@ import "io" // Handler interface for server side functionality. type Handler interface { io.ReadWriter + Shutdown() + Done() <-chan struct{} } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 7017f3e..164a280 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/io/logger" @@ -31,33 +32,28 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - lines chan line.Line - regex string - aggregate *server.Aggregate - aggregatedMessages chan string - serverMessages chan string - payload []byte - hostname string - user *user.User - // TODO: Move all these channels into a separate struct for readability! + done *internal.Done + lines chan line.Line + regex string + aggregate *server.Aggregate + aggregatedMessages chan string + serverMessages chan string + payload []byte + hostname string + user *user.User catLimiter chan struct{} tailLimiter chan struct{} globalServerWaitFor chan struct{} ackCloseReceived chan struct{} - serverCtx context.Context - handlerCtx context.Context - done chan struct{} activeCommands int32 activeReaders int32 background background.Background } // NewServerHandler returns the server handler. -func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) (*ServerHandler, <-chan struct{}) { +func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) *ServerHandler { h := ServerHandler{ - serverCtx: serverCtx, - handlerCtx: handlerCtx, - done: make(chan struct{}), + done: internal.NewDone(), lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), @@ -78,7 +74,15 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h, h.done + return &h +} + +func (h *ServerHandler) Shutdown() { + h.done.Shutdown() +} + +func (h *ServerHandler) Done() <-chan struct{} { + return h.done.Done() } // Read is to send data to the dtail client via Reader interface. @@ -120,7 +124,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) { case <-time.After(time.Second): // Once in a while check whether we are done. select { - case <-h.handlerCtx.Done(): + case <-h.done.Done(): return 0, io.EOF default: } @@ -134,7 +138,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { switch c { case ';': commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(h.handlerCtx, commandStr) + h.handleCommand(commandStr) h.payload = nil default: h.payload = append(h.payload, c) @@ -145,9 +149,10 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { return } -func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { +func (h *ServerHandler) handleCommand(commandStr string) { logger.Debug(h.user, commandStr) var timeout time.Duration + ctx := context.Background() args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) if err != nil { @@ -172,15 +177,21 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { return } + ctx, cancel := context.WithCancel(ctx) + go func() { + <-h.done.Done() + cancel() + }() + if timeout > 0 { logger.Info(h.user, "Command with timeout context", argc, args, timeout) - commandCtx, cancel := context.WithTimeout(ctx, timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) go func() { - <-commandCtx.Done() + <-ctx.Done() logger.Info(h.user, "Command timed out, canceling it", args, args, timeout) cancel() }() - h.handleUserCommand(commandCtx, argc, args, timeout) + h.handleUserCommand(ctx, argc, args, timeout) return } @@ -255,7 +266,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if h.aggregate == nil { return } - h.aggregate.Cancel() + h.aggregate.Shutdown() } } @@ -348,9 +359,8 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] // Set default background timeout. timeout = time.Hour * 1 } - // Use a new context based on the server context, so that background job does not get - // terminated when handler/SSH connection terminates. - commandCtx, cancel := context.WithTimeout(h.serverCtx, timeout) + + commandCtx, cancel := context.WithTimeout(ctx, timeout) if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil { h.sendServerMessage(logger.Error(h.user, err, jobName, args)) @@ -406,7 +416,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) { func (h *ServerHandler) send(ch chan<- string, message string) { select { case ch <- message: - case <-h.handlerCtx.Done(): + case <-h.done.Done(): } } @@ -447,7 +457,7 @@ func (h *ServerHandler) shutdown() { go func() { select { case h.serverMessageC() <- ".syn close connection": - case <-h.handlerCtx.Done(): + case <-h.done.Done(): } }() @@ -455,13 +465,10 @@ func (h *ServerHandler) shutdown() { case <-h.ackCloseReceived: case <-time.After(time.Second * 5): logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") - case <-h.handlerCtx.Done(): + case <-h.done.Done(): } - select { - case h.done <- struct{}{}: - default: - } + h.done.Shutdown() } func (h *ServerHandler) incrementActiveCommands() { diff --git a/internal/server/server.go b/internal/server/server.go index a446738..5e2a521 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -178,53 +178,46 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch switch req.Type { case "shell": - handlerCtx, cancel := context.WithCancel(ctx) - var handler handlers.Handler - var done <-chan struct{} - switch user.Name { case config.ControlUser: - handler, done = handlers.NewControlHandler(handlerCtx, user) + handler = handlers.NewControlHandler(user) default: - handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background) + handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background) } - go func() { - // Handler finished work, cancel all remaining routines - defer cancel() - - <-done - }() + terminate := func() { + handler.Shutdown() + sshConn.Close() + } go func() { // Broken pipe, cancel - defer cancel() - io.Copy(channel, handler) + terminate() }() go func() { // Broken pipe, cancel - defer cancel() - io.Copy(handler, channel) + terminate() }() go func() { - defer cancel() + select { + case <-ctx.Done(): + case <-handler.Done(): + } + terminate() + }() + go func() { if err := sshConn.Wait(); err != nil && err != io.EOF { logger.Error(user, err) } s.stats.decrementConnections() logger.Info(user, "Good bye Mister!") - }() - - go func() { - <-handlerCtx.Done() - sshConn.Close() - logger.Info(user, "Closed SSH connection") + terminate() }() // Only serving shell type |
