diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-02-22 10:34:19 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-02-22 10:34:19 +0000 |
| commit | 4d2ab8e6dd645d345fa26d8a067ad6dc14fc1fce (patch) | |
| tree | 1d5568910d40de66fb1f4a796ab1af6774cf6dc9 /internal | |
| parent | 7b400fdc922599bdd6f6c6d6c1dc4a664f104365 (diff) | |
serverhandler understands background jobs better
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/server/background/background.go | 9 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 97 | ||||
| -rw-r--r-- | internal/server/server.go | 2 |
3 files changed, 68 insertions, 40 deletions
diff --git a/internal/server/background/background.go b/internal/server/background/background.go index d31c1f2..537ccbb 100644 --- a/internal/server/background/background.go +++ b/internal/server/background/background.go @@ -2,6 +2,7 @@ package background import ( "context" + "fmt" "sync" "github.com/mimecast/dtail/internal/io/logger" @@ -23,12 +24,18 @@ func NewBackground() *Background { } } -func (b Background) Add(name string, cancel context.CancelFunc, done <-chan struct{}) { +func (b Background) Add(name string, cancel context.CancelFunc, done <-chan struct{}) error { b.mutex.Lock() defer b.mutex.Unlock() + if _, ok := b.jobs[name]; ok { + return fmt.Errorf("job '%s' already exists", name) + } + logger.Debug("background", name, "adding job") b.jobs[name] = job{cancel, done} + + return nil } func (b Background) get(name string) (job, bool) { diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index cc15c63..bcd3f85 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -2,6 +2,7 @@ package handlers import ( "context" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -44,16 +45,18 @@ type ServerHandler struct { tailLimiter chan struct{} globalServerWaitFor chan struct{} ackCloseReceived chan struct{} - ctx context.Context + serverCtx context.Context + handlerCtx context.Context done chan struct{} - activeReaders int - background *background.Commands + activeCommands int + background *background.Background } // NewServerHandler returns the server handler. -func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) { +func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - ctx: ctx, + serverCtx: serverCtx, + handlerCtx: handlerCtx, done: make(chan struct{}), mutex: &sync.Mutex{}, lines: make(chan line.Line, 100), @@ -65,7 +68,7 @@ func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimi globalServerWaitFor: globalServerWaitFor, regex: ".", user: user, - background: background.NewCommands(), + background: background.NewBackground(), } fqdn, err := os.Hostname() @@ -115,7 +118,7 @@ func (h *ServerHandler) Read(p []byte) (n int, err error) { case <-time.After(time.Second): // Once in a while check whether we are done. select { - case <-h.ctx.Done(): + case <-h.handlerCtx.Done(): return 0, io.EOF default: } @@ -129,7 +132,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { switch c { case ';': commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(h.ctx, commandStr) + h.handleCommand(h.handlerCtx, commandStr) h.payload = nil default: h.payload = append(h.payload, c) @@ -238,29 +241,31 @@ func (h *ServerHandler) handleControlCommand(argc int, args []string) { func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { logger.Debug(h.user, "handleUserCommand", argc, args) + h.incrementActiveCommands() + finished := func() { + if h.decrementActiveCommands() == 0 { + h.shutdown() + } + } + splitted := strings.Split(args[0], ":") command := splitted[0] - commandFlags := splitted[1:] + flags := splitted[1:] switch command { case "grep", "cat": command := newReadCommand(h, omode.CatClient) - h.incrementActiveReaders() + h.incrementActiveCommands() go func() { command.Start(ctx, argc, args) - if h.decrementActiveReaders() == 0 { - h.shutdown() - } + finished() }() case "tail": command := newReadCommand(h, omode.TailClient) - h.incrementActiveReaders() go func() { command.Start(ctx, argc, args) - if h.decrementActiveReaders() == 0 { - h.shutdown() - } + finished() }() case "map": @@ -268,46 +273,62 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err != nil { h.sendServerMessage(err.Error()) logger.Error(h.user, err) + finished() return } h.aggregate = aggregate go func() { command.Start(ctx, h.aggregatedMessages) - h.shutdown() + finished() }() case "run": + // TODO: Refactor this "run" case, move code to runcommand.go command := newRunCommand(h) - if contains(commandFlags, "stop_background") { - h.background.Stop(argc, args) + checksum := sha256.Sum256([]byte(strings.Join(args, " "))) + name := fmt.Sprintf("%s.%s", h.user.Name, checksum) + + if contains(flags, "background.stop") { + h.background.Stop(name) + finished() return } done := make(chan struct{}) - if contains(commandFlags, "start_background") { - commandCtx, cancel := context.WithCancel(ctx) - h.background.Add(argc, args, cancel, done) - ctx = commandCtx + + if contains(flags, "background.start") { + commandCtx, cancel := context.WithTimeout(h.serverCtx, time.Hour) + if err := h.background.Add(name, cancel, done); err != nil { + h.sendServerMessage(logger.Error(h.user, err, args)) + finished() + return + } + + go func() { + command.Start(commandCtx, argc, args) + close(done) + }() + finished() + return } - h.incrementActiveReaders() go func() { h.globalServerWaitFor <- struct{}{} }() go func() { command.Start(ctx, argc, args) close(done) <-h.globalServerWaitFor - if h.decrementActiveReaders() == 0 { - h.shutdown() - } + finished() }() case "ack", ".ack": h.handleAckCommand(argc, args) + finished() default: h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args)) + finished() } } @@ -324,7 +345,7 @@ func (h *ServerHandler) handleAckCommand(argc int, args []string) { func (h *ServerHandler) send(ch chan<- string, message string) { select { case ch <- message: - case <-h.ctx.Done(): + case <-h.handlerCtx.Done(): } } @@ -346,7 +367,6 @@ func (h *ServerHandler) flush() { unsentMessages := func() int { return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) } - for i := 0; i < 3; i++ { if unsentMessages() == 0 { logger.Debug(h.user, "All lines sent") @@ -366,7 +386,7 @@ func (h *ServerHandler) shutdown() { go func() { select { case h.serverMessageC() <- ".syn close connection": - case <-h.ctx.Done(): + case <-h.handlerCtx.Done(): } }() @@ -374,7 +394,7 @@ func (h *ServerHandler) shutdown() { case <-h.ackCloseReceived: case <-time.After(time.Second * 5): logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") - case <-h.ctx.Done(): + case <-h.handlerCtx.Done(): } select { @@ -383,19 +403,20 @@ func (h *ServerHandler) shutdown() { } } -func (h *ServerHandler) incrementActiveReaders() { - // REFACTOR: Use atomic counter variable instead, so we can get rid of the mutex +func (h *ServerHandler) incrementActiveCommands() { + // TODO: Use atomic counter variable instead, so we can get rid of the mutex h.mutex.Lock() defer h.mutex.Unlock() - h.activeReaders++ + h.activeCommands++ } -func (h *ServerHandler) decrementActiveReaders() int { + +func (h *ServerHandler) decrementActiveCommands() int { h.mutex.Lock() defer h.mutex.Unlock() - h.activeReaders-- - return h.activeReaders + h.activeCommands-- + return h.activeCommands } func contains(haystack []string, needle string) bool { diff --git a/internal/server/server.go b/internal/server/server.go index 9314468..1421540 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -178,7 +178,7 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch case config.ControlUser: handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor) + handler, done = handlers.NewServerHandler(handlerCtx, ctx, user, s.catLimiter, s.tailLimiter, s.shutdownWaitFor) } go func() { |
