From 0945da8dfefcbb723eecea0e5f4eafff63398253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20B=C3=BCtow?= Date: Sun, 26 Jan 2020 11:26:53 +0000 Subject: Introduce drun command, refactor code to use context package --- internal/server/server.go | 70 +++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 39 deletions(-) (limited to 'internal/server/server.go') diff --git a/internal/server/server.go b/internal/server/server.go index 27a98f5..42eb74c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,14 @@ package server import ( + "context" "errors" "fmt" "io" "net" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -26,8 +27,6 @@ type Server struct { catLimiterCh chan struct{} // To control the max amount of concurrent tails tailLimiterCh chan struct{} - // Ask to shutdown the server - stop chan struct{} } // New returns a new server. @@ -38,7 +37,6 @@ func New() *Server { sshServerConfig: &gossh.ServerConfig{}, catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), - stop: make(chan struct{}), } s.sshServerConfig.PasswordCallback = s.controlUserCallback @@ -54,7 +52,7 @@ func New() *Server { } // Start the server. -func (s *Server) Start() int { +func (s *Server) Start(ctx context.Context) int { logger.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) @@ -64,7 +62,7 @@ func (s *Server) Start() int { logger.FatalExit("Failed to open listening TCP socket", err) } - go s.stats.periodicLogServerStats(s.stop) + go s.stats.periodicLogServerStats(ctx) for { conn, err := listener.Accept() // Blocking @@ -79,11 +77,11 @@ func (s *Server) Start() int { continue } - go s.handleConnection(conn) + go s.handleConnection(ctx, conn) } } -func (s *Server) handleConnection(conn net.Conn) { +func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { logger.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) @@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) { go gossh.DiscardRequests(reqs) for newChannel := range chans { - go s.handleChannel(sshConn, newChannel) + go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) logger.Info(user, "Invoking channel handler") @@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) return } - if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { logger.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { logger.Info(user, "Invoking request handler") for req := range in { @@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch switch req.Type { case "shell": + handlerCtx, cancel := context.WithCancel(ctx) + var handler handlers.Handler + var done <-chan struct{} + switch user.Name { case config.ControlUser: - handler = handlers.NewControlHandler(user) + handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) } - // Bi-directionally connect SSH stream to SSH handler - brokenPipe1 := make(chan struct{}) go func() { - defer close(brokenPipe1) + // Handler finished work, cancel all remaining routines + defer cancel() + <-done + }() + + go func() { + // Broken pipe, cancel + defer cancel() + io.Copy(channel, handler) }() - brokenPipe2 := make(chan struct{}) go func() { - defer close(brokenPipe2) + // Broken pipe, cancel + defer cancel() + io.Copy(handler, channel) }() - // Ensure to close all fd's and stop all goroutines once ssh connection terminated go func() { - defer s.stats.decrementConnections() - defer handler.Close() + defer cancel() if err := sshConn.Wait(); err != nil && err != io.EOF { logger.Error(user, err) } + s.stats.decrementConnections() logger.Info(user, "Good bye Mister!") }() - // Close the underlying ssh socket when server shuts down go func() { - select { - case <-s.stop: - logger.Debug(user, "Server initiating shutdown on handler") - case <-handler.Wait(): - logger.Debug(user, "Handler initiating shutdown by its own") - case <-brokenPipe1: - logger.Debug(user, "Broken pipe1") - case <-brokenPipe2: - logger.Debug(user, "Broken pipe2") - } + <-handlerCtx.Done() sshConn.Close() logger.Info(user, "Closed SSH connection") }() @@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g return nil, fmt.Errorf("Not authorized") } - -// Stop the server. -func (s *Server) Stop() { - close(s.stop) - s.stats.waitForConnections() -} -- cgit v1.2.3