summaryrefslogtreecommitdiff
path: root/internal/server/handlers/serverhandler.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/handlers/serverhandler.go')
-rw-r--r--internal/server/handlers/serverhandler.go75
1 files changed, 41 insertions, 34 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 7017f3e..164a280 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
@@ -31,33 +32,28 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- // TODO: Move all these channels into a separate struct for readability!
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ aggregatedMessages chan string
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
catLimiter chan struct{}
tailLimiter chan struct{}
globalServerWaitFor chan struct{}
ackCloseReceived chan struct{}
- serverCtx context.Context
- handlerCtx context.Context
- done chan struct{}
activeCommands int32
activeReaders int32
background background.Background
}
// NewServerHandler returns the server handler.
-func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) *ServerHandler {
h := ServerHandler{
- serverCtx: serverCtx,
- handlerCtx: handlerCtx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -78,7 +74,15 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+func (h *ServerHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+func (h *ServerHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the dtail client via Reader interface.
@@ -120,7 +124,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case <-time.After(time.Second):
// Once in a while check whether we are done.
select {
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
return 0, io.EOF
default:
}
@@ -134,7 +138,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.handlerCtx, commandStr)
+ h.handleCommand(commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -145,9 +149,10 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+func (h *ServerHandler) handleCommand(commandStr string) {
logger.Debug(h.user, commandStr)
var timeout time.Duration
+ ctx := context.Background()
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -172,15 +177,21 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
if timeout > 0 {
logger.Info(h.user, "Command with timeout context", argc, args, timeout)
- commandCtx, cancel := context.WithTimeout(ctx, timeout)
+ ctx, cancel := context.WithTimeout(ctx, timeout)
go func() {
- <-commandCtx.Done()
+ <-ctx.Done()
logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
cancel()
}()
- h.handleUserCommand(commandCtx, argc, args, timeout)
+ h.handleUserCommand(ctx, argc, args, timeout)
return
}
@@ -255,7 +266,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if h.aggregate == nil {
return
}
- h.aggregate.Cancel()
+ h.aggregate.Shutdown()
}
}
@@ -348,9 +359,8 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
// Set default background timeout.
timeout = time.Hour * 1
}
- // Use a new context based on the server context, so that background job does not get
- // terminated when handler/SSH connection terminates.
- commandCtx, cancel := context.WithTimeout(h.serverCtx, timeout)
+
+ commandCtx, cancel := context.WithTimeout(ctx, timeout)
if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
@@ -406,7 +416,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) {
func (h *ServerHandler) send(ch chan<- string, message string) {
select {
case ch <- message:
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}
@@ -447,7 +457,7 @@ func (h *ServerHandler) shutdown() {
go func() {
select {
case h.serverMessageC() <- ".syn close connection":
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}()
@@ -455,13 +465,10 @@ func (h *ServerHandler) shutdown() {
case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
- select {
- case h.done <- struct{}{}:
- default:
- }
+ h.done.Shutdown()
}
func (h *ServerHandler) incrementActiveCommands() {