summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go46
1 files changed, 41 insertions, 5 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 9b9d8a0..9314468 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
+ "time"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
@@ -24,11 +25,13 @@ type Server struct {
// SSH server configuration.
sshServerConfig *gossh.ServerConfig
// To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
- catLimiterCh chan struct{}
+ catLimiter chan struct{}
// To control the max amount of concurrent tails
- tailLimiterCh chan struct{}
+ tailLimiter chan struct{}
// To run scheduled tasks (if configured)
sched *scheduler
+ // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed.
+ shutdownWaitFor chan struct{}
}
// New returns a new server.
@@ -37,8 +40,9 @@ func New() *Server {
s := Server{
sshServerConfig: &gossh.ServerConfig{},
- catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
- tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
+ catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats),
+ tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails),
+ shutdownWaitFor: make(chan struct{}, 1000),
sched: newScheduler(),
}
@@ -60,6 +64,7 @@ func (s *Server) Start(ctx context.Context) int {
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
logger.Info("Binding server", bindAt)
+
listener, err := net.Listen("tcp", bindAt)
if err != nil {
logger.FatalExit("Failed to open listening TCP socket", err)
@@ -67,10 +72,40 @@ func (s *Server) Start(ctx context.Context) int {
go s.stats.start(ctx)
go s.sched.start(ctx)
+ go s.listenerLoop(ctx, listener)
+
+ select {
+ case <-ctx.Done():
+ // Wait until all commands/jobs/children are no more!
+ s.wait()
+ }
+
+ // For future use.
+ return 0
+}
+
+func (s *Server) wait() {
+ for {
+ num := len(s.shutdownWaitFor)
+ logger.Debug("Waiting for stuff to finish", num)
+ if num <= 0 {
+ return
+ }
+ time.Sleep(time.Second)
+ }
+}
+
+func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
+ logger.Debug("Starting listener loop")
for {
conn, err := listener.Accept() // Blocking
if err != nil {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
logger.Error("Failed to accept incoming connection", err)
continue
}
@@ -143,12 +178,13 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case config.ControlUser:
handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh)
+ handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
}
go func() {
// Handler finished work, cancel all remaining routines
defer cancel()
+
<-done
}()