summaryrefslogtreecommitdiff
path: root/internal/server
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2020-03-04 16:32:27 +0000
committerPaul Buetow <pbuetow@mimecast.com>2020-03-04 16:32:27 +0000
commit238dd3930e9c58397a86f649c77912ee32e4d7f0 (patch)
treeb4cda0b8c677188b600478522471628b5d4efea4 /internal/server
parent89d3ebfc4e0c947977e5f455ee76f3ce29cc92c7 (diff)
can tail probe with a given timeout and then write a mapreduce result
Diffstat (limited to 'internal/server')
-rw-r--r--internal/server/handlers/serverhandler.go70
1 files changed, 42 insertions, 28 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 739696c..939388c 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/mimecast/dtail/internal/config"
@@ -30,7 +31,6 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- mutex *sync.Mutex
lines chan line.Line
regex string
aggregate *server.Aggregate
@@ -47,7 +47,8 @@ type ServerHandler struct {
serverCtx context.Context
handlerCtx context.Context
done chan struct{}
- activeCommands int
+ activeCommands int32
+ activeReaders int32
background background.Background
}
@@ -57,7 +58,6 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
serverCtx: serverCtx,
handlerCtx: handlerCtx,
done: make(chan struct{}),
- mutex: &sync.Mutex{},
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -170,10 +170,11 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
}
if timeout > 0 {
- logger.Debug("Command with timeout context", argc, args, timeout)
+ logger.Info(h.user, "Command with timeout context", argc, args, timeout)
commandCtx, cancel := context.WithTimeout(ctx, timeout)
go func() {
<-commandCtx.Done()
+ logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
cancel()
}()
h.handleUserCommand(commandCtx, argc, args, timeout)
@@ -241,11 +242,19 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
logger.Debug(h.user, "handleUserCommand", argc, args)
h.incrementActiveCommands()
- finished := func() {
+ commandFinished := func() {
if h.decrementActiveCommands() == 0 {
h.shutdown()
}
}
+ readerFinished := func() {
+ if h.decrementActiveReaders() == 0 {
+ if h.aggregate == nil {
+ return
+ }
+ h.aggregate.Cancel()
+ }
+ }
splitted := strings.Split(args[0], ":")
commandName := splitted[0]
@@ -253,24 +262,27 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
options, err := readOptions(splitted[1:])
if err != nil {
h.sendServerMessage(logger.Error(h.user, err))
- finished()
+ commandFinished()
return
}
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
- h.incrementActiveCommands()
go func() {
+ h.incrementActiveReaders()
command.Start(ctx, argc, args)
- finished()
+ readerFinished()
+ commandFinished()
}()
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
+ h.incrementActiveReaders()
command.Start(ctx, argc, args)
- finished()
+ readerFinished()
+ commandFinished()
}()
case "map":
@@ -278,14 +290,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err != nil {
h.sendServerMessage(err.Error())
logger.Error(h.user, err)
- finished()
+ commandFinished()
return
}
h.aggregate = aggregate
go func() {
command.Start(ctx, h.aggregatedMessages)
- finished()
+ commandFinished()
}()
case "run":
@@ -301,7 +313,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
} else {
h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName))
}
- finished()
+ commandFinished()
return
}
@@ -313,7 +325,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
count++
}
h.sendServerMessage(fmt.Sprintf("Found %d jobs", count))
- finished()
+ commandFinished()
return
}
@@ -339,7 +351,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
- finished()
+ commandFinished()
return
}
ctx = commandCtx
@@ -347,7 +359,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil {
h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err))
- finished()
+ commandFinished()
return
}
@@ -360,21 +372,21 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if background {
h.sendServerMessage(logger.Info(h.user, jobName, "job started in background"))
- finished()
+ commandFinished()
return
}
// Command run in foreground, wait for it to complete before finishing the connection.
wg.Wait()
- finished()
+ commandFinished()
case "ack", ".ack":
h.handleAckCommand(argc, args)
- finished()
+ commandFinished()
default:
h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options))
- finished()
+ commandFinished()
}
}
@@ -450,19 +462,21 @@ func (h *ServerHandler) shutdown() {
}
func (h *ServerHandler) incrementActiveCommands() {
- // TODO: Use atomic counter variable instead, so we can get rid of the mutex
- h.mutex.Lock()
- defer h.mutex.Unlock()
+ atomic.AddInt32(&h.activeCommands, 1)
+}
- h.activeCommands++
+func (h *ServerHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
}
-func (h *ServerHandler) decrementActiveCommands() int {
- h.mutex.Lock()
- defer h.mutex.Unlock()
+func (h *ServerHandler) incrementActiveReaders() {
+ atomic.AddInt32(&h.activeReaders, 1)
+}
- h.activeCommands--
- return h.activeCommands
+func (h *ServerHandler) decrementActiveReaders() int32 {
+ atomic.AddInt32(&h.activeReaders, -1)
+ return atomic.LoadInt32(&h.activeReaders)
}
func readOptions(opts []string) (map[string]string, error) {