summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <35781042+pbuetow@users.noreply.github.com>2020-09-19 19:52:11 +0100
committerGitHub <noreply@github.com>2020-09-19 19:52:11 +0100
commit3c889d2eed4e12af505ea84d46d8e52d21057a1f (patch)
tree8e6d9f697fe9a5c70f200d54745bb5daecac6bde /internal
parentec67d9833095dfbe620dd3c99ea0caba391c4b87 (diff)
parentdf2ff83897cde61d04b12958c6f6d458c69502f4 (diff)
Merge pull request #14 from snonux/develop
Refactor context handling
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/baseclient.go2
-rw-r--r--internal/clients/handlers/basehandler.go23
-rw-r--r--internal/clients/handlers/clienthandler.go5
-rw-r--r--internal/clients/handlers/handler.go3
-rw-r--r--internal/clients/handlers/healthhandler.go17
-rw-r--r--internal/clients/handlers/maprhandler.go5
-rw-r--r--internal/clients/handlers/withcancel.go24
-rw-r--r--internal/clients/healthclient.go2
-rw-r--r--internal/clients/remote/connection.go7
-rw-r--r--internal/done.go32
-rw-r--r--internal/mapr/server/aggregate.go80
-rw-r--r--internal/server/handlers/controlhandler.go26
-rw-r--r--internal/server/handlers/handler.go2
-rw-r--r--internal/server/handlers/serverhandler.go75
-rw-r--r--internal/server/server.go39
15 files changed, 191 insertions, 151 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 008a01e..d8d4fde 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -99,7 +99,7 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con
defer func() { <-active }()
for {
- connCtx, cancel := conn.Handler.WithCancel(ctx)
+ connCtx, cancel := context.WithCancel(ctx)
defer cancel()
conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
index 65bbfd7..b5045e2 100644
--- a/internal/clients/handlers/basehandler.go
+++ b/internal/clients/handlers/basehandler.go
@@ -8,12 +8,13 @@ import (
"strings"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/version"
)
type baseHandler struct {
- withCancel
+ done *internal.Done
server string
shellStarted bool
commands chan string
@@ -29,6 +30,14 @@ func (h *baseHandler) Status() int {
return h.status
}
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
// SendMessage to the server.
func (h *baseHandler) SendMessage(command string) error {
encoded := base64.StdEncoding.EncodeToString([]byte(command))
@@ -38,7 +47,8 @@ func (h *baseHandler) SendMessage(command string) error {
case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded):
case <-time.After(time.Second * 5):
return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded)
- case <-h.ctx.Done():
+ case <-h.Done():
+ return nil
}
return nil
@@ -65,7 +75,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) {
select {
case command := <-h.commands:
n = copy(p, []byte(command))
- case <-h.ctx.Done():
+ case <-h.Done():
return 0, io.EOF
}
return
@@ -95,10 +105,11 @@ func (h *baseHandler) handleHiddenMessage(message string) {
case strings.HasPrefix(message, ".syn close connection"):
h.SendMessage(".ack close connection")
select {
- case <-time.After(time.Second * 1):
+ case <-time.After(time.Second * 5):
logger.Debug("Shutting down client after timeout and sending ack to server")
- h.withCancel.shutdown()
- case <-h.ctx.Done():
+ h.Shutdown()
+ case <-h.Done():
+ return
}
case strings.HasPrefix(message, ".run exitstatus"):
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
index fcd8052..2bcb038 100644
--- a/internal/clients/handlers/clienthandler.go
+++ b/internal/clients/handlers/clienthandler.go
@@ -1,6 +1,7 @@
package handlers
import (
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
)
@@ -19,9 +20,7 @@ func NewClientHandler(server string) *ClientHandler {
shellStarted: false,
commands: make(chan string),
status: -1,
- withCancel: withCancel{
- done: make(chan struct{}),
- },
+ done: internal.NewDone(),
},
}
}
diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go
index c53ca34..afa87e2 100644
--- a/internal/clients/handlers/handler.go
+++ b/internal/clients/handlers/handler.go
@@ -1,7 +1,6 @@
package handlers
import (
- "context"
"io"
)
@@ -11,6 +10,6 @@ type Handler interface {
SendMessage(command string) error
Server() string
Status() int
- WithCancel(ctx context.Context) (context.Context, context.CancelFunc)
+ Shutdown()
Done() <-chan struct{}
}
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
index 9051015..95693ab 100644
--- a/internal/clients/handlers/healthhandler.go
+++ b/internal/clients/handlers/healthhandler.go
@@ -4,11 +4,13 @@ import (
"errors"
"fmt"
"time"
+
+ "github.com/mimecast/dtail/internal"
)
// HealthHandler implements the handler required for health checks.
type HealthHandler struct {
- withCancel
+ done *internal.Done
// Buffer of incoming data from server.
receiveBuf []byte
// To send commands to the server.
@@ -27,9 +29,7 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
receive: receive,
commands: make(chan string),
status: -1,
- withCancel: withCancel{
- done: make(chan struct{}),
- },
+ done: internal.NewDone(),
}
return &h
@@ -45,12 +45,21 @@ func (h *HealthHandler) Status() int {
return h.status
}
+func (h *HealthHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+func (h *HealthHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
// SendMessage sends a DTail command to the server.
func (h *HealthHandler) SendMessage(command string) error {
select {
case h.commands <- fmt.Sprintf("%s;", command):
case <-time.NewTimer(time.Second * 10).C:
return errors.New("Timed out sending command " + command)
+ case <-h.Done():
}
return nil
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index b908f3b..fb71c8f 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -3,6 +3,7 @@ package handlers
import (
"strings"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/client"
@@ -24,9 +25,7 @@ func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGr
shellStarted: false,
commands: make(chan string),
status: -1,
- withCancel: withCancel{
- done: make(chan struct{}),
- },
+ done: internal.NewDone(),
},
query: query,
aggregate: client.NewAggregate(server, query, globalGroup),
diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go
deleted file mode 100644
index 7c9cf4e..0000000
--- a/internal/clients/handlers/withcancel.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package handlers
-
-import "context"
-
-type withCancel struct {
- ctx context.Context
- done chan struct{}
-}
-
-// WithCancel sets and returns the context used.
-func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) {
- cancelCtx, cancel := context.WithCancel(ctx)
- w.ctx = cancelCtx
-
- return cancelCtx, cancel
-}
-
-func (w *withCancel) Done() <-chan struct{} {
- return w.done
-}
-
-func (w *withCancel) shutdown() {
- close(w.done)
-}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index 7313583..e93f6be 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -50,7 +50,7 @@ func (c *HealthClient) Start(ctx context.Context) (status int) {
conn.Handler = handlers.NewHealthHandler(c.server, receive)
conn.Commands = []string{c.mode.String()}
- connCtx, cancel := conn.Handler.WithCancel(ctx)
+ connCtx, cancel := context.WithCancel(ctx)
go conn.Start(connCtx, cancel, throttleCh, statsCh)
for {
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
index 2d97d14..b29ffed 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/remote/connection.go
@@ -177,21 +177,21 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess
}
go func() {
- defer cancel()
io.Copy(stdinPipe, c.Handler)
+ cancel()
}()
go func() {
- defer cancel()
io.Copy(c.Handler, stdoutPipe)
+ cancel()
}()
go func() {
- defer cancel()
select {
case <-c.Handler.Done():
case <-ctx.Done():
}
+ cancel()
}()
// Send all commands to client.
@@ -207,5 +207,6 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess
}
<-ctx.Done()
+ c.Handler.Shutdown()
return nil
}
diff --git a/internal/done.go b/internal/done.go
new file mode 100644
index 0000000..2326eee
--- /dev/null
+++ b/internal/done.go
@@ -0,0 +1,32 @@
+package internal
+
+import (
+ "sync"
+)
+
+type Done struct {
+ ch chan struct{}
+ mutex sync.Mutex
+}
+
+func NewDone() *Done {
+ return &Done{
+ ch: make(chan struct{}),
+ }
+}
+
+func (d *Done) Done() <-chan struct{} {
+ return d.ch
+}
+
+func (d *Done) Shutdown() {
+ d.mutex.Lock()
+ defer d.mutex.Unlock()
+
+ select {
+ case <-d.ch:
+ return
+ default:
+ close(d.ch)
+ }
+}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 1028943..cd59b63 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -6,6 +6,7 @@ import (
"strings"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
@@ -15,6 +16,7 @@ import (
// Aggregate is for aggregating mapreduce data on the DTail server side.
type Aggregate struct {
+ done *internal.Done
// Log lines to process (parsing MAPREDUCE lines).
Lines chan line.Line
// Hostname of the current server (used to populate $hostname field).
@@ -23,12 +25,12 @@ type Aggregate struct {
serialize chan struct{}
// Signals to flush data.
flush chan struct{}
+ // Signals that data has been flushed
+ flushed chan struct{}
// The mapr query
query *mapr.Query
// The mapr log format parser
parser *logformat.Parser
- cancel context.CancelFunc
- ctx context.Context
}
// NewAggregate return a new server side aggregator.
@@ -64,56 +66,63 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
}
}
- ctx, cancel := context.WithCancel(context.Background())
-
a := Aggregate{
+ done: internal.NewDone(),
Lines: make(chan line.Line, 100),
serialize: make(chan struct{}),
flush: make(chan struct{}),
+ flushed: make(chan struct{}),
hostname: s[0],
query: query,
parser: logParser,
- ctx: ctx,
- cancel: cancel,
}
return &a, nil
}
+func (a *Aggregate) Shutdown() {
+ a.Flush()
+ a.done.Shutdown()
+}
+
// Start an aggregation.
func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
- defer a.cancel()
- fieldsCh := a.linesToFields(ctx)
+ myCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ go func() {
+ select {
+ case <-myCtx.Done():
+ a.done.Shutdown()
+ case <-a.done.Done():
+ cancel()
+ }
+ }()
+
+ fieldsCh := a.makeFields(myCtx)
// Add fields (e.g. via 'set' clause)
if len(a.query.Set) > 0 {
- fieldsCh = a.addMoreFields(ctx, fieldsCh)
+ fieldsCh = a.addFields(myCtx, fieldsCh)
}
- go a.fieldsToMaprLines(ctx, fieldsCh, maprLines)
- a.periodicAggregateTimer(ctx)
+ go a.aggregateTimer(myCtx)
+ a.makeMaprLines(myCtx, fieldsCh, maprLines)
}
-// Cancel the aggregation.
-func (a *Aggregate) Cancel() {
- a.cancel()
-}
-
-func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
+func (a *Aggregate) aggregateTimer(ctx context.Context) {
for {
select {
case <-time.After(a.query.Interval):
a.Serialize(ctx)
case <-ctx.Done():
return
- case <-a.ctx.Done():
- return
}
}
}
-func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string {
+func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string {
ch := make(chan map[string]string)
go func() {
@@ -144,8 +153,6 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
case <-ctx.Done():
return
- case <-a.ctx.Done():
- return
}
}
}()
@@ -153,14 +160,14 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
return ch
}
-func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
+func (a *Aggregate) addFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
ch := make(chan map[string]string)
go func() {
defer close(ch)
for {
- // fieldsCh will be closed via 'linesToFields' if ctx is done
+ // fieldsCh will be closed via 'makeFields' if ctx is done
fields, ok := <-fieldsCh
if !ok {
return
@@ -179,7 +186,7 @@ func (a *Aggregate) addMoreFields(ctx context.Context, fieldsCh <-chan map[strin
return ch
}
-func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
+func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
group := mapr.NewGroupSet()
serialize := func() {
@@ -200,18 +207,10 @@ func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[s
case <-a.serialize:
serialize()
case <-a.flush:
- logger.Info("Flushing mapreduce result")
serialize()
- a.flush <- struct{}{}
- logger.Info("Done flushing mapreduce result")
+ a.flushed <- struct{}{}
case <-ctx.Done():
return
- case <-a.ctx.Done():
- logger.Info("Flushing mapreduce result")
- serialize()
- a.flush <- struct{}{}
- logger.Info("Done flushing mapreduce result")
- return
}
}
}
@@ -254,6 +253,8 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
func (a *Aggregate) Serialize(ctx context.Context) {
select {
case a.serialize <- struct{}{}:
+ case <-time.After(time.Minute):
+ logger.Warn("Starting to serialize mapredice data takes over a minute")
case <-ctx.Done():
}
}
@@ -261,15 +262,20 @@ func (a *Aggregate) Serialize(ctx context.Context) {
// Flush all data.
func (a *Aggregate) Flush() {
select {
- case <-a.ctx.Done():
- return
case a.flush <- struct{}{}:
+ logger.Info("Flushing mapreduce data")
case <-time.After(time.Minute):
+ logger.Warn("Starting to flush mapreduce data takes over a minute")
+ return
+ case <-a.done.Done():
return
}
select {
- case <-a.flush:
+ case <-a.flushed:
+ logger.Info("Done flushing")
case <-time.After(time.Minute):
+ logger.Warn("Waiting for data to be flushed takes over a minute")
+ case <-a.done.Done():
}
}
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index daa9835..9a8eb75 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -1,20 +1,19 @@
package handlers
import (
- "context"
"fmt"
"io"
"os"
"strings"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
)
// ControlHandler is used for control functions and health monitoring.
type ControlHandler struct {
- ctx context.Context
- done chan struct{}
+ done *internal.Done
hostname string
payload []byte
serverMessages chan string
@@ -22,12 +21,11 @@ type ControlHandler struct {
}
// NewControlHandler returns a new control handler.
-func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) {
+func NewControlHandler(user *user.User) *ControlHandler {
logger.Debug(user, "Creating control handler")
h := ControlHandler{
- ctx: ctx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
serverMessages: make(chan string, 10),
user: user,
}
@@ -40,7 +38,15 @@ func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+func (h *ControlHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+func (h *ControlHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the client via the Reader interface.
@@ -51,7 +57,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) {
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
- case <-h.ctx.Done():
+ case <-h.done.Done():
return 0, io.EOF
}
}
@@ -63,7 +69,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.ctx, wholePayload)
+ h.handleCommand(wholePayload)
h.payload = nil
default:
@@ -75,7 +81,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ControlHandler) handleCommand(ctx context.Context, command string) {
+func (h *ControlHandler) handleCommand(command string) {
logger.Info(h.user, command)
s := strings.Split(command, " ")
logger.Debug(h.user, "Receiving command", command, s)
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
index c42ceb9..b04e854 100644
--- a/internal/server/handlers/handler.go
+++ b/internal/server/handlers/handler.go
@@ -5,4 +5,6 @@ import "io"
// Handler interface for server side functionality.
type Handler interface {
io.ReadWriter
+ Shutdown()
+ Done() <-chan struct{}
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 7017f3e..164a280 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
+ "github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/io/logger"
@@ -31,33 +32,28 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- // TODO: Move all these channels into a separate struct for readability!
+ done *internal.Done
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
+ aggregatedMessages chan string
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
catLimiter chan struct{}
tailLimiter chan struct{}
globalServerWaitFor chan struct{}
ackCloseReceived chan struct{}
- serverCtx context.Context
- handlerCtx context.Context
- done chan struct{}
activeCommands int32
activeReaders int32
background background.Background
}
// NewServerHandler returns the server handler.
-func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}, background background.Background) *ServerHandler {
h := ServerHandler{
- serverCtx: serverCtx,
- handlerCtx: handlerCtx,
- done: make(chan struct{}),
+ done: internal.NewDone(),
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -78,7 +74,15 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h, h.done
+ return &h
+}
+
+func (h *ServerHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+func (h *ServerHandler) Done() <-chan struct{} {
+ return h.done.Done()
}
// Read is to send data to the dtail client via Reader interface.
@@ -120,7 +124,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case <-time.After(time.Second):
// Once in a while check whether we are done.
select {
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
return 0, io.EOF
default:
}
@@ -134,7 +138,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.handlerCtx, commandStr)
+ h.handleCommand(commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -145,9 +149,10 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+func (h *ServerHandler) handleCommand(commandStr string) {
logger.Debug(h.user, commandStr)
var timeout time.Duration
+ ctx := context.Background()
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -172,15 +177,21 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
+ ctx, cancel := context.WithCancel(ctx)
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
if timeout > 0 {
logger.Info(h.user, "Command with timeout context", argc, args, timeout)
- commandCtx, cancel := context.WithTimeout(ctx, timeout)
+ ctx, cancel := context.WithTimeout(ctx, timeout)
go func() {
- <-commandCtx.Done()
+ <-ctx.Done()
logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
cancel()
}()
- h.handleUserCommand(commandCtx, argc, args, timeout)
+ h.handleUserCommand(ctx, argc, args, timeout)
return
}
@@ -255,7 +266,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if h.aggregate == nil {
return
}
- h.aggregate.Cancel()
+ h.aggregate.Shutdown()
}
}
@@ -348,9 +359,8 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
// Set default background timeout.
timeout = time.Hour * 1
}
- // Use a new context based on the server context, so that background job does not get
- // terminated when handler/SSH connection terminates.
- commandCtx, cancel := context.WithTimeout(h.serverCtx, timeout)
+
+ commandCtx, cancel := context.WithTimeout(ctx, timeout)
if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
@@ -406,7 +416,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) {
func (h *ServerHandler) send(ch chan<- string, message string) {
select {
case ch <- message:
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}
@@ -447,7 +457,7 @@ func (h *ServerHandler) shutdown() {
go func() {
select {
case h.serverMessageC() <- ".syn close connection":
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
}()
@@ -455,13 +465,10 @@ func (h *ServerHandler) shutdown() {
case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.handlerCtx.Done():
+ case <-h.done.Done():
}
- select {
- case h.done <- struct{}{}:
- default:
- }
+ h.done.Shutdown()
}
func (h *ServerHandler) incrementActiveCommands() {
diff --git a/internal/server/server.go b/internal/server/server.go
index a446738..5e2a521 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -178,53 +178,46 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
switch req.Type {
case "shell":
- handlerCtx, cancel := context.WithCancel(ctx)
-
var handler handlers.Handler
- var done <-chan struct{}
-
switch user.Name {
case config.ControlUser:
- handler, done = handlers.NewControlHandler(handlerCtx, user)
+ handler = handlers.NewControlHandler(user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background)
+ handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor, s.background)
}
- go func() {
- // Handler finished work, cancel all remaining routines
- defer cancel()
-
- <-done
- }()
+ terminate := func() {
+ handler.Shutdown()
+ sshConn.Close()
+ }
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(channel, handler)
+ terminate()
}()
go func() {
// Broken pipe, cancel
- defer cancel()
-
io.Copy(handler, channel)
+ terminate()
}()
go func() {
- defer cancel()
+ select {
+ case <-ctx.Done():
+ case <-handler.Done():
+ }
+ terminate()
+ }()
+ go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
- }()
-
- go func() {
- <-handlerCtx.Done()
- sshConn.Close()
- logger.Info(user, "Closed SSH connection")
+ terminate()
}()
// Only serving shell type