summaryrefslogtreecommitdiff
path: root/internal/server/server.go
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-01-26 11:26:53 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-02-07 13:31:15 +0000
commit0945da8dfefcbb723eecea0e5f4eafff63398253 (patch)
treef06dab4d2bf21d25d176b23d5baeca588d27f5d7 /internal/server/server.go
parent2a8e5de265a0e0a31a5834909d6879f5c9941467 (diff)
Introduce drun command, refactor code to use context package
Diffstat (limited to 'internal/server/server.go')
-rw-r--r--internal/server/server.go70
1 files changed, 31 insertions, 39 deletions
diff --git a/internal/server/server.go b/internal/server/server.go
index 27a98f5..42eb74c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,13 +1,14 @@
package server
import (
+ "context"
"errors"
"fmt"
"io"
"net"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -26,8 +27,6 @@ type Server struct {
catLimiterCh chan struct{}
// To control the max amount of concurrent tails
tailLimiterCh chan struct{}
- // Ask to shutdown the server
- stop chan struct{}
}
// New returns a new server.
@@ -38,7 +37,6 @@ func New() *Server {
sshServerConfig: &gossh.ServerConfig{},
catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
- stop: make(chan struct{}),
}
s.sshServerConfig.PasswordCallback = s.controlUserCallback
@@ -54,7 +52,7 @@ func New() *Server {
}
// Start the server.
-func (s *Server) Start() int {
+func (s *Server) Start(ctx context.Context) int {
logger.Info("Starting server")
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
@@ -64,7 +62,7 @@ func (s *Server) Start() int {
logger.FatalExit("Failed to open listening TCP socket", err)
}
- go s.stats.periodicLogServerStats(s.stop)
+ go s.stats.periodicLogServerStats(ctx)
for {
conn, err := listener.Accept() // Blocking
@@ -79,11 +77,11 @@ func (s *Server) Start() int {
continue
}
- go s.handleConnection(conn)
+ go s.handleConnection(ctx, conn)
}
}
-func (s *Server) handleConnection(conn net.Conn) {
+func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
logger.Info("Handling connection")
sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
@@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) {
go gossh.DiscardRequests(reqs)
for newChannel := range chans {
- go s.handleChannel(sshConn, newChannel)
+ go s.handleChannel(ctx, sshConn, newChannel)
}
}
-func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) {
+func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) {
user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
logger.Info(user, "Invoking channel handler")
@@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel)
return
}
- if err := s.handleRequests(sshConn, requests, channel, user); err != nil {
+ if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil {
logger.Error(user, err)
sshConn.Close()
}
}
-func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
logger.Info(user, "Invoking request handler")
for req := range in {
@@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch
switch req.Type {
case "shell":
+ handlerCtx, cancel := context.WithCancel(ctx)
+
var handler handlers.Handler
+ var done <-chan struct{}
+
switch user.Name {
case config.ControlUser:
- handler = handlers.NewControlHandler(user)
+ handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh)
+ handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh)
}
- // Bi-directionally connect SSH stream to SSH handler
- brokenPipe1 := make(chan struct{})
go func() {
- defer close(brokenPipe1)
+ // Handler finished work, cancel all remaining routines
+ defer cancel()
+ <-done
+ }()
+
+ go func() {
+ // Broken pipe, cancel
+ defer cancel()
+
io.Copy(channel, handler)
}()
- brokenPipe2 := make(chan struct{})
go func() {
- defer close(brokenPipe2)
+ // Broken pipe, cancel
+ defer cancel()
+
io.Copy(handler, channel)
}()
- // Ensure to close all fd's and stop all goroutines once ssh connection terminated
go func() {
- defer s.stats.decrementConnections()
- defer handler.Close()
+ defer cancel()
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
+ s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
}()
- // Close the underlying ssh socket when server shuts down
go func() {
- select {
- case <-s.stop:
- logger.Debug(user, "Server initiating shutdown on handler")
- case <-handler.Wait():
- logger.Debug(user, "Handler initiating shutdown by its own")
- case <-brokenPipe1:
- logger.Debug(user, "Broken pipe1")
- case <-brokenPipe2:
- logger.Debug(user, "Broken pipe2")
- }
+ <-handlerCtx.Done()
sshConn.Close()
logger.Info(user, "Closed SSH connection")
}()
@@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g
return nil, fmt.Errorf("Not authorized")
}
-
-// Stop the server.
-func (s *Server) Stop() {
- close(s.stop)
- s.stats.waitForConnections()
-}