summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-02-16 18:07:36 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-02-16 18:07:36 +0000
commite0f4ccc46c8601f322640b72e100f973a837ef02 (patch)
tree61a1fcf66daea222da19500b0b6ae60d1e89a5d9 /internal/server
parent6bca637513e065a33cadaccad97ada25eb7a6b00 (diff)
server kills subprocesses correctly on cancel
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/runcommand.go4
-rw-r--r--internal/server/handlers/serverhandler.go73
-rw-r--r--internal/server/server.go46
3 files changed, 99 insertions, 24 deletions
diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go
index 120c402..95db52f 100644
--- a/internal/server/handlers/runcommand.go
+++ b/internal/server/handlers/runcommand.go
@@ -92,4 +92,8 @@ func (r runCommand) start(ctx context.Context, command string) {
r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec))
r.server.sendServerMessage(logger.Info(fmt.Sprintf("Process %d exited with status %d", pid, ec)))
+
+ logger.Debug(r.server.user, "Waiting for Pgroup to be killed")
+ <-r.run.PgroupKilled()
+ logger.Debug(r.server.user, "Pgroup killed")
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index b840c77..946ae83 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
+ "strconv"
"strings"
"sync"
"time"
@@ -37,28 +38,31 @@ type ServerHandler struct {
payload []byte
hostname string
user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- ctx context.Context
- done chan struct{}
- activeReaders int
+ // TODO: Move all these channels into a separate struct for readability!
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ globalServerWaitFor chan struct{}
+ ackCloseReceived chan struct{}
+ ctx context.Context
+ done chan struct{}
+ activeReaders int
}
// NewServerHandler returns the server handler.
-func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) {
h := ServerHandler{
- ctx: ctx,
- done: make(chan struct{}),
- mutex: &sync.Mutex{},
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
+ ctx: ctx,
+ done: make(chan struct{}),
+ mutex: &sync.Mutex{},
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ aggregatedMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ globalServerWaitFor: globalServerWaitFor,
+ regex: ".",
+ user: user,
}
fqdn, err := os.Hostname()
@@ -135,6 +139,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
logger.Debug(h.user, commandStr)
+ var timeout time.Duration
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -148,11 +153,28 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
+ args, argc, timeout, err = h.handleTimeout(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
+ }
+
if h.user.Name == config.ControlUser {
h.handleControlCommand(argc, args)
return
}
+ if timeout > 0 {
+ logger.Debug("Command with timeout context", argc, args, timeout)
+ commandCtx, cancel := context.WithTimeout(ctx, timeout)
+ go func() {
+ <-commandCtx.Done()
+ cancel()
+ }()
+ h.handleUserCommand(commandCtx, argc, args)
+ return
+ }
+
h.handleUserCommand(ctx, argc, args)
}
@@ -191,6 +213,16 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er
return args, argc, nil
}
+func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) {
+ if argc <= 2 || args[0] != "timeout" {
+ // No timeout specified
+ return args, argc, time.Duration(0) * time.Second, nil
+ }
+
+ timeout, err := strconv.Atoi(args[1])
+ return args[2:], argc - 2, time.Duration(timeout) * time.Second, err
+}
+
func (h *ServerHandler) handleControlCommand(argc int, args []string) {
switch args[0] {
case "debug":
@@ -241,8 +273,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
case "run":
command := newRunCommand(h)
h.incrementActiveReaders()
+ go func() { h.globalServerWaitFor <- struct{}{} }()
go func() {
command.Start(ctx, argc, args)
+ <-h.globalServerWaitFor
if h.decrementActiveReaders() == 0 {
h.shutdown()
}
@@ -332,12 +366,13 @@ func (h *ServerHandler) incrementActiveReaders() {
// TODO: Use atomic counter variable instead, so we can get rid of the mutex
h.mutex.Lock()
defer h.mutex.Unlock()
+
h.activeReaders++
}
func (h *ServerHandler) decrementActiveReaders() int {
h.mutex.Lock()
defer h.mutex.Unlock()
- h.activeReaders--
+ h.activeReaders--
return h.activeReaders
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 9b9d8a0..9314468 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
+ "time"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/logger"
@@ -24,11 +25,13 @@ type Server struct {
// SSH server configuration.
sshServerConfig *gossh.ServerConfig
// To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
- catLimiterCh chan struct{}
+ catLimiter chan struct{}
// To control the max amount of concurrent tails
- tailLimiterCh chan struct{}
+ tailLimiter chan struct{}
// To run scheduled tasks (if configured)
sched *scheduler
+ // Wait counter, e.g. there might be still subprocesses (forked by drun) to be killed.
+ shutdownWaitFor chan struct{}
}
// New returns a new server.
@@ -37,8 +40,9 @@ func New() *Server {
s := Server{
sshServerConfig: &gossh.ServerConfig{},
- catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
- tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
+ catLimiter: make(chan struct{}, config.Server.MaxConcurrentCats),
+ tailLimiter: make(chan struct{}, config.Server.MaxConcurrentTails),
+ shutdownWaitFor: make(chan struct{}, 1000),
sched: newScheduler(),
}
@@ -60,6 +64,7 @@ func (s *Server) Start(ctx context.Context) int {
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
logger.Info("Binding server", bindAt)
+
listener, err := net.Listen("tcp", bindAt)
if err != nil {
logger.FatalExit("Failed to open listening TCP socket", err)
@@ -67,10 +72,40 @@ func (s *Server) Start(ctx context.Context) int {
go s.stats.start(ctx)
go s.sched.start(ctx)
+ go s.listenerLoop(ctx, listener)
+
+ select {
+ case <-ctx.Done():
+ // Wait until all commands/jobs/children are no more!
+ s.wait()
+ }
+
+ // For future use.
+ return 0
+}
+
+func (s *Server) wait() {
+ for {
+ num := len(s.shutdownWaitFor)
+ logger.Debug("Waiting for stuff to finish", num)
+ if num <= 0 {
+ return
+ }
+ time.Sleep(time.Second)
+ }
+}
+
+func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
+ logger.Debug("Starting listener loop")
for {
conn, err := listener.Accept() // Blocking
if err != nil {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
logger.Error("Failed to accept incoming connection", err)
continue
}
@@ -143,12 +178,13 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case config.ControlUser:
handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh)
+ handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
}
go func() {
// Handler finished work, cancel all remaining routines
defer cancel()
+
<-done
}()