summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go107
1 files changed, 61 insertions, 46 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 8b581b1..38b042f 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -170,52 +170,7 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn,
switch req.Type {
case "shell":
- var handler handlers.Handler
- switch user.Name {
- case config.HealthUser:
- handler = handlers.NewHealthHandler(user)
- default:
- handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter)
- }
- terminate := func() {
- handler.Shutdown()
- sshConn.Close()
- }
-
- go func() {
- defer terminate()
- // Broken pipe, cancel
- if _, err := io.Copy(channel, handler); err != nil {
- dlog.Server.Trace(user, fmt.Errorf("channel->handler: %w", err))
- }
- }()
- go func() {
- defer terminate()
- // Broken pipe, cancel
- if _, err := io.Copy(handler, channel); err != nil {
- dlog.Server.Trace(user, fmt.Errorf("handler->channel: %w", err))
- }
- }()
- go func() {
- select {
- case <-ctx.Done():
- case <-handler.Done():
- }
- terminate()
- }()
- go func() {
- if err := sshConn.Wait(); err != nil && err != io.EOF {
- dlog.Server.Error(user, err)
- }
- s.stats.decrementConnections()
- dlog.Server.Info(user, "Good bye Mister!")
- terminate()
- }()
-
- // Only serving shell type
- if err := req.Reply(true, nil); err != nil {
- dlog.Server.Trace(user, fmt.Errorf("reply(true): %w", err))
- }
+ s.handleShellRequest(ctx, sshConn, channel, user, req)
default:
if err := req.Reply(false, nil); err != nil {
dlog.Server.Trace(user, fmt.Errorf("reply(false): %w", err))
@@ -227,6 +182,66 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn,
return nil
}
+// handleShellRequest sets up the shell session with handler goroutines for I/O,
+// context cancellation, and connection lifecycle management.
+func (s *Server) handleShellRequest(ctx context.Context, sshConn gossh.Conn,
+ channel gossh.Channel, user *user.User, req *gossh.Request) {
+
+ // Create the appropriate handler based on user type
+ var handler handlers.Handler
+ switch user.Name {
+ case config.HealthUser:
+ handler = handlers.NewHealthHandler(user)
+ default:
+ handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter)
+ }
+
+ terminate := func() {
+ handler.Shutdown()
+ sshConn.Close()
+ }
+
+ // Start goroutine to copy data from channel to handler
+ go func() {
+ defer terminate()
+ if _, err := io.Copy(channel, handler); err != nil {
+ dlog.Server.Trace(user, fmt.Errorf("channel->handler: %w", err))
+ }
+ }()
+
+ // Start goroutine to copy data from handler to channel
+ go func() {
+ defer terminate()
+ if _, err := io.Copy(handler, channel); err != nil {
+ dlog.Server.Trace(user, fmt.Errorf("handler->channel: %w", err))
+ }
+ }()
+
+ // Start goroutine to handle context or handler completion
+ go func() {
+ select {
+ case <-ctx.Done():
+ case <-handler.Done():
+ }
+ terminate()
+ }()
+
+ // Start goroutine to handle connection lifecycle and cleanup
+ go func() {
+ if err := sshConn.Wait(); err != nil && err != io.EOF {
+ dlog.Server.Error(user, err)
+ }
+ s.stats.decrementConnections()
+ dlog.Server.Info(user, "Good bye Mister!")
+ terminate()
+ }()
+
+ // Reply to indicate shell request was accepted
+ if err := req.Reply(true, nil); err != nil {
+ dlog.Server.Trace(user, fmt.Errorf("reply(true): %w", err))
+ }
+}
+
// Callback for SSH authentication.
func (s *Server) Callback(c gossh.ConnMetadata,
authPayload []byte) (*gossh.Permissions, error) {