summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/server/background/background.go9
-rw-r--r--internal/server/handlers/serverhandler.go97
-rw-r--r--internal/server/server.go2
3 files changed, 68 insertions, 40 deletions
diff --git a/internal/server/background/background.go b/internal/server/background/background.go
index d31c1f2..537ccbb 100644
--- a/internal/server/background/background.go
+++ b/internal/server/background/background.go
@@ -2,6 +2,7 @@ package background
import (
"context"
+ "fmt"
"sync"
"github.com/mimecast/dtail/internal/io/logger"
@@ -23,12 +24,18 @@ func NewBackground() *Background {
}
}
-func (b Background) Add(name string, cancel context.CancelFunc, done <-chan struct{}) {
+func (b Background) Add(name string, cancel context.CancelFunc, done <-chan struct{}) error {
b.mutex.Lock()
defer b.mutex.Unlock()
+ if _, ok := b.jobs[name]; ok {
+ return fmt.Errorf("job '%s' already exists", name)
+ }
+
logger.Debug("background", name, "adding job")
b.jobs[name] = job{cancel, done}
+
+ return nil
}
func (b Background) get(name string) (job, bool) {
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index cc15c63..bcd3f85 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -2,6 +2,7 @@ package handlers
import (
"context"
+ "crypto/sha256"
"encoding/base64"
"errors"
"fmt"
@@ -44,16 +45,18 @@ type ServerHandler struct {
tailLimiter chan struct{}
globalServerWaitFor chan struct{}
ackCloseReceived chan struct{}
- ctx context.Context
+ serverCtx context.Context
+ handlerCtx context.Context
done chan struct{}
- activeReaders int
- background *background.Commands
+ activeCommands int
+ background *background.Background
}
// NewServerHandler returns the server handler.
-func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) {
h := ServerHandler{
- ctx: ctx,
+ serverCtx: serverCtx,
+ handlerCtx: handlerCtx,
done: make(chan struct{}),
mutex: &sync.Mutex{},
lines: make(chan line.Line, 100),
@@ -65,7 +68,7 @@ func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimi
globalServerWaitFor: globalServerWaitFor,
regex: ".",
user: user,
- background: background.NewCommands(),
+ background: background.NewBackground(),
}
fqdn, err := os.Hostname()
@@ -115,7 +118,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) {
case <-time.After(time.Second):
// Once in a while check whether we are done.
select {
- case <-h.ctx.Done():
+ case <-h.handlerCtx.Done():
return 0, io.EOF
default:
}
@@ -129,7 +132,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(h.ctx, commandStr)
+ h.handleCommand(h.handlerCtx, commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -238,29 +241,31 @@ func (h *ServerHandler) handleControlCommand(argc int, args []string) {
func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
logger.Debug(h.user, "handleUserCommand", argc, args)
+ h.incrementActiveCommands()
+ finished := func() {
+ if h.decrementActiveCommands() == 0 {
+ h.shutdown()
+ }
+ }
+
splitted := strings.Split(args[0], ":")
command := splitted[0]
- commandFlags := splitted[1:]
+ flags := splitted[1:]
switch command {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
- h.incrementActiveReaders()
+ h.incrementActiveCommands()
go func() {
command.Start(ctx, argc, args)
- if h.decrementActiveReaders() == 0 {
- h.shutdown()
- }
+ finished()
}()
case "tail":
command := newReadCommand(h, omode.TailClient)
- h.incrementActiveReaders()
go func() {
command.Start(ctx, argc, args)
- if h.decrementActiveReaders() == 0 {
- h.shutdown()
- }
+ finished()
}()
case "map":
@@ -268,46 +273,62 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err != nil {
h.sendServerMessage(err.Error())
logger.Error(h.user, err)
+ finished()
return
}
h.aggregate = aggregate
go func() {
command.Start(ctx, h.aggregatedMessages)
- h.shutdown()
+ finished()
}()
case "run":
+ // TODO: Refactor this "run" case, move code to runcommand.go
command := newRunCommand(h)
- if contains(commandFlags, "stop_background") {
- h.background.Stop(argc, args)
+ checksum := sha256.Sum256([]byte(strings.Join(args, " ")))
+ name := fmt.Sprintf("%s.%s", h.user.Name, checksum)
+
+ if contains(flags, "background.stop") {
+ h.background.Stop(name)
+ finished()
return
}
done := make(chan struct{})
- if contains(commandFlags, "start_background") {
- commandCtx, cancel := context.WithCancel(ctx)
- h.background.Add(argc, args, cancel, done)
- ctx = commandCtx
+
+ if contains(flags, "background.start") {
+ commandCtx, cancel := context.WithTimeout(h.serverCtx, time.Hour)
+ if err := h.background.Add(name, cancel, done); err != nil {
+ h.sendServerMessage(logger.Error(h.user, err, args))
+ finished()
+ return
+ }
+
+ go func() {
+ command.Start(commandCtx, argc, args)
+ close(done)
+ }()
+ finished()
+ return
}
- h.incrementActiveReaders()
go func() { h.globalServerWaitFor <- struct{}{} }()
go func() {
command.Start(ctx, argc, args)
close(done)
<-h.globalServerWaitFor
- if h.decrementActiveReaders() == 0 {
- h.shutdown()
- }
+ finished()
}()
case "ack", ".ack":
h.handleAckCommand(argc, args)
+ finished()
default:
h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args))
+ finished()
}
}
@@ -324,7 +345,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) {
func (h *ServerHandler) send(ch chan<- string, message string) {
select {
case ch <- message:
- case <-h.ctx.Done():
+ case <-h.handlerCtx.Done():
}
}
@@ -346,7 +367,6 @@ func (h *ServerHandler) flush() {
unsentMessages := func() int {
return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
}
-
for i := 0; i < 3; i++ {
if unsentMessages() == 0 {
logger.Debug(h.user, "All lines sent")
@@ -366,7 +386,7 @@ func (h *ServerHandler) shutdown() {
go func() {
select {
case h.serverMessageC() <- ".syn close connection":
- case <-h.ctx.Done():
+ case <-h.handlerCtx.Done():
}
}()
@@ -374,7 +394,7 @@ func (h *ServerHandler) shutdown() {
case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.ctx.Done():
+ case <-h.handlerCtx.Done():
}
select {
@@ -383,19 +403,20 @@ func (h *ServerHandler) shutdown() {
}
}
-func (h *ServerHandler) incrementActiveReaders() {
- // REFACTOR: Use atomic counter variable instead, so we can get rid of the mutex
+func (h *ServerHandler) incrementActiveCommands() {
+ // TODO: Use atomic counter variable instead, so we can get rid of the mutex
h.mutex.Lock()
defer h.mutex.Unlock()
- h.activeReaders++
+ h.activeCommands++
}
-func (h *ServerHandler) decrementActiveReaders() int {
+
+func (h *ServerHandler) decrementActiveCommands() int {
h.mutex.Lock()
defer h.mutex.Unlock()
- h.activeReaders--
- return h.activeReaders
+ h.activeCommands--
+ return h.activeCommands
}
func contains(haystack []string, needle string) bool {
diff --git a/internal/server/server.go b/internal/server/server.go
index 9314468..1421540 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -178,7 +178,7 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case config.ControlUser:
handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
+ handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor)
}
go func() {