summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <35781042+pbuetow@users.noreply.github.com>2020-09-19 19:52:11 +0100
committerGitHub <noreply@github.com>2020-09-19 19:52:11 +0100
commit3c889d2eed4e12af505ea84d46d8e52d21057a1f (patch)
tree8e6d9f697fe9a5c70f200d54745bb5daecac6bde /internal/server
parentec67d9833095dfbe620dd3c99ea0caba391c4b87 (diff)
parentdf2ff83897cde61d04b12958c6f6d458c69502f4 (diff)
Merge pull request #14 from snonux/develop
Refactor context handling
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/controlhandler.go26
-rw-r--r--internal/server/handlers/handler.go2
-rw-r--r--internal/server/handlers/serverhandler.go75
-rw-r--r--internal/server/server.go39
4 files changed, 75 insertions, 67 deletions
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index daa9835..9a8eb75 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -1,20 +1,19 @@
package handlers
import (
- "context"
"fmt"
"io"
"os"
"strings"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
)
// ControlHandler is used for control functions and health monitoring.
type ControlHandler struct {
- ctx context.Context
- done chan struct{}
+ done *internal.Done
hostname string
payload []byte
serverMessages chan string
@@ -22,12 +21,11 @@ type ControlHandler struct {
}
// NewControlHandler returns a new control handler.
-func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) {
+func NewControlHandler(user *user.User) *ControlHandler {
logger.Debug(user, "Creating control handler")
h := ControlHandler{
- ctx: ctx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
serverMessages: make(chan string, 10),
user: user,
}
@@ -40,7 +38,15 @@ func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+func (h *ControlHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+func (h *ControlHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the client via the Reader interface.
@@ -51,7 +57,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) {
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
- case <-h.ctx.Done():
+ case <-h.done.Done():
return 0, io.EOF
}
}
@@ -63,7 +69,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.ctx, wholePayload)
+ h.handleCommand(wholePayload)
h.payload = nil
default:
@@ -75,7 +81,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ControlHandler) handleCommand(ctx context.Context, command string) {
+func (h *ControlHandler) handleCommand(command string) {
logger.Info(h.user, command)
s := strings.Split(command, " ")
logger.Debug(h.user, "Receiving command", command, s)
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
index c42ceb9..b04e854 100644
--- a/internal/server/handlers/handler.go
+++ b/internal/server/handlers/handler.go
@@ -5,4 +5,6 @@ import "io"
// Handler interface for server side functionality.
type Handler interface {
io.ReadWriter
+ Shutdown()
+ Done() <-chan struct{}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 7017f3e..164a280 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
@@ -31,33 +32,28 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- // TODO: Move all these channels into a separate struct for readability!
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ aggregatedMessages chan string
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
catLimiter chan struct{}
tailLimiter chan struct{}
globalServerWaitFor chan struct{}
ackCloseReceived chan struct{}
- serverCtx context.Context
- handlerCtx context.Context
- done chan struct{}
activeCommands int32
activeReaders int32
background background.Background
}
// NewServerHandler returns the server handler.
-func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) *ServerHandler {
h := ServerHandler{
- serverCtx: serverCtx,
- handlerCtx: handlerCtx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -78,7 +74,15 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+func (h *ServerHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+func (h *ServerHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the dtail client via Reader interface.
@@ -120,7 +124,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case <-time.After(time.Second):
// Once in a while check whether we are done.
select {
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
return 0, io.EOF
default:
}
@@ -134,7 +138,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.handlerCtx, commandStr)
+ h.handleCommand(commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -145,9 +149,10 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+func (h *ServerHandler) handleCommand(commandStr string) {
logger.Debug(h.user, commandStr)
var timeout time.Duration
+ ctx := context.Background()
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -172,15 +177,21 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
if timeout > 0 {
logger.Info(h.user, "Command with timeout context", argc, args, timeout)
- commandCtx, cancel := context.WithTimeout(ctx, timeout)
+ ctx, cancel := context.WithTimeout(ctx, timeout)
go func() {
- <-commandCtx.Done()
+ <-ctx.Done()
logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
cancel()
}()
- h.handleUserCommand(commandCtx, argc, args, timeout)
+ h.handleUserCommand(ctx, argc, args, timeout)
return
}
@@ -255,7 +266,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if h.aggregate == nil {
return
}
- h.aggregate.Cancel()
+ h.aggregate.Shutdown()
}
}
@@ -348,9 +359,8 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
// Set default background timeout.
timeout = time.Hour * 1
}
- // Use a new context based on the server context, so that background job does not get
- // terminated when handler/SSH connection terminates.
- commandCtx, cancel := context.WithTimeout(h.serverCtx, timeout)
+
+ commandCtx, cancel := context.WithTimeout(ctx, timeout)
if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
@@ -406,7 +416,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) {
func (h *ServerHandler) send(ch chan<- string, message string) {
select {
case ch <- message:
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}
@@ -447,7 +457,7 @@ func (h *ServerHandler) shutdown() {
go func() {
select {
case h.serverMessageC() <- ".syn close connection":
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}()
@@ -455,13 +465,10 @@ func (h *ServerHandler) shutdown() {
case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
- select {
- case h.done <- struct{}{}:
- default:
- }
+ h.done.Shutdown()
}
func (h *ServerHandler) incrementActiveCommands() {
diff --git a/internal/server/server.go b/internal/server/server.go
index a446738..5e2a521 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -178,53 +178,46 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
switch req.Type {
case "shell":
- handlerCtx, cancel := context.WithCancel(ctx)
-
var handler handlers.Handler
- var done <-chan struct{}
-
switch user.Name {
case config.ControlUser:
- handler, done = handlers.NewControlHandler(handlerCtx, user)
+ handler = handlers.NewControlHandler(user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background)
+ handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background)
}
- go func() {
- // Handler finished work, cancel all remaining routines
- defer cancel()
-
- <-done
- }()
+ terminate := func() {
+ handler.Shutdown()
+ sshConn.Close()
+ }
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(channel, handler)
+ terminate()
}()
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(handler, channel)
+ terminate()
}()
go func() {
- defer cancel()
+ select {
+ case <-ctx.Done():
+ case <-handler.Done():
+ }
+ terminate()
+ }()
+ go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
- }()
-
- go func() {
- <-handlerCtx.Done()
- sshConn.Close()
- logger.Info(user, "Closed SSH connection")
+ terminate()
}()
// Only serving shell type