diff options
Diffstat (limited to 'internal/server/server.go')
| -rw-r--r-- | internal/server/server.go | 107 |
1 files changed, 61 insertions, 46 deletions
diff --git a/internal/server/server.go b/internal/server/server.go index 8b581b1..38b042f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -170,52 +170,7 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, switch req.Type { case "shell": - var handler handlers.Handler - switch user.Name { - case config.HealthUser: - handler = handlers.NewHealthHandler(user) - default: - handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter) - } - terminate := func() { - handler.Shutdown() - sshConn.Close() - } - - go func() { - defer terminate() - // Broken pipe, cancel - if _, err := io.Copy(channel, handler); err != nil { - dlog.Server.Trace(user, fmt.Errorf("channel->handler: %w", err)) - } - }() - go func() { - defer terminate() - // Broken pipe, cancel - if _, err := io.Copy(handler, channel); err != nil { - dlog.Server.Trace(user, fmt.Errorf("handler->channel: %w", err)) - } - }() - go func() { - select { - case <-ctx.Done(): - case <-handler.Done(): - } - terminate() - }() - go func() { - if err := sshConn.Wait(); err != nil && err != io.EOF { - dlog.Server.Error(user, err) - } - s.stats.decrementConnections() - dlog.Server.Info(user, "Good bye Mister!") - terminate() - }() - - // Only serving shell type - if err := req.Reply(true, nil); err != nil { - dlog.Server.Trace(user, fmt.Errorf("reply(true): %w", err)) - } + s.handleShellRequest(ctx, sshConn, channel, user, req) default: if err := req.Reply(false, nil); err != nil { dlog.Server.Trace(user, fmt.Errorf("reply(false): %w", err)) @@ -227,6 +182,66 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, return nil } +// handleShellRequest sets up the shell session with handler goroutines for I/O, +// context cancellation, and connection lifecycle management. +func (s *Server) handleShellRequest(ctx context.Context, sshConn gossh.Conn, + channel gossh.Channel, user *user.User, req *gossh.Request) { + + // Create the appropriate handler based on user type + var handler handlers.Handler + switch user.Name { + case config.HealthUser: + handler = handlers.NewHealthHandler(user) + default: + handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter) + } + + terminate := func() { + handler.Shutdown() + sshConn.Close() + } + + // Start goroutine to copy data from channel to handler + go func() { + defer terminate() + if _, err := io.Copy(channel, handler); err != nil { + dlog.Server.Trace(user, fmt.Errorf("channel->handler: %w", err)) + } + }() + + // Start goroutine to copy data from handler to channel + go func() { + defer terminate() + if _, err := io.Copy(handler, channel); err != nil { + dlog.Server.Trace(user, fmt.Errorf("handler->channel: %w", err)) + } + }() + + // Start goroutine to handle context or handler completion + go func() { + select { + case <-ctx.Done(): + case <-handler.Done(): + } + terminate() + }() + + // Start goroutine to handle connection lifecycle and cleanup + go func() { + if err := sshConn.Wait(); err != nil && err != io.EOF { + dlog.Server.Error(user, err) + } + s.stats.decrementConnections() + dlog.Server.Info(user, "Good bye Mister!") + terminate() + }() + + // Reply to indicate shell request was accepted + if err := req.Reply(true, nil); err != nil { + dlog.Server.Trace(user, fmt.Errorf("reply(true): %w", err)) + } +} + // Callback for SSH authentication. func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) { |
