diff options
Diffstat (limited to 'internal/server/server.go')
| -rw-r--r-- | internal/server/server.go | 46 |
1 files changed, 41 insertions, 5 deletions
diff --git a/internal/server/server.go b/internal/server/server.go index 9b9d8a0..9314468 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "time" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/io/logger" @@ -24,11 +25,13 @@ type Server struct { // SSH server configuration. sshServerConfig *gossh.ServerConfig // To control the max amount of concurrent cats (which can cause a lot of I/O on the server) - catLimiterCh chan struct{} + catLimiter chan struct{} // To control the max amount of concurrent tails - tailLimiterCh chan struct{} + tailLimiter chan struct{} // To run scheduled tasks (if configured) sched *scheduler + // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed. + shutdownWaitFor chan struct{} } // New returns a new server. @@ -37,8 +40,9 @@ func New() *Server { s := Server{ sshServerConfig: &gossh.ServerConfig{}, - catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), - tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), + catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats), + tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails), + shutdownWaitFor: make(chan struct{}, 1000), sched: newScheduler(), } @@ -60,6 +64,7 @@ func (s *Server) Start(ctx context.Context) int { bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) logger.Info("Binding server", bindAt) + listener, err := net.Listen("tcp", bindAt) if err != nil { logger.FatalExit("Failed to open listening TCP socket", err) @@ -67,10 +72,40 @@ func (s *Server) Start(ctx context.Context) int { go s.stats.start(ctx) go s.sched.start(ctx) + go s.listenerLoop(ctx, listener) + + select { + case <-ctx.Done(): + // Wait until all commands/jobs/children are no more! + s.wait() + } + + // For future use. + return 0 +} + +func (s *Server) wait() { + for { + num := len(s.shutdownWaitFor) + logger.Debug("Waiting for stuff to finish", num) + if num <= 0 { + return + } + time.Sleep(time.Second) + } +} + +func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { + logger.Debug("Starting listener loop") for { conn, err := listener.Accept() // Blocking if err != nil { + select { + case <-ctx.Done(): + return + default: + } logger.Error("Failed to accept incoming connection", err) continue } @@ -143,12 +178,13 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch case config.ControlUser: handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor) } go func() { // Handler finished work, cancel all remaining routines defer cancel() + <-done }() |
