summaryrefslogtreecommitdiff
path: root/internal/server/handlers/serverhandler.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/server/handlers/serverhandler.go')
-rw-r--r--internal/server/handlers/serverhandler.go73
1 files changed, 54 insertions, 19 deletions
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index b840c77..946ae83 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"os"
+ "strconv"
"strings"
"sync"
"time"
@@ -37,28 +38,31 @@ type ServerHandler struct {
payload []byte
hostname string
user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- ctx context.Context
- done chan struct{}
- activeReaders int
+ // TODO: Move all these channels into a separate struct for readability!
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ globalServerWaitFor chan struct{}
+ ackCloseReceived chan struct{}
+ ctx context.Context
+ done chan struct{}
+ activeReaders int
}
// NewServerHandler returns the server handler.
-func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) {
+func NewServerHandler(ctx context.Context, user *user.User, catLimiter, tailLimiter, globalServerWaitFor chan struct{}) (*ServerHandler, <-chan struct{}) {
h := ServerHandler{
- ctx: ctx,
- done: make(chan struct{}),
- mutex: &sync.Mutex{},
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
+ ctx: ctx,
+ done: make(chan struct{}),
+ mutex: &sync.Mutex{},
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ aggregatedMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ globalServerWaitFor: globalServerWaitFor,
+ regex: ".",
+ user: user,
}
fqdn, err := os.Hostname()
@@ -135,6 +139,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
logger.Debug(h.user, commandStr)
+ var timeout time.Duration
args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
if err != nil {
@@ -148,11 +153,28 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
return
}
+ args, argc, timeout, err = h.handleTimeout(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
+ }
+
if h.user.Name == config.ControlUser {
h.handleControlCommand(argc, args)
return
}
+ if timeout > 0 {
+ logger.Debug("Command with timeout context", argc, args, timeout)
+ commandCtx, cancel := context.WithTimeout(ctx, timeout)
+ go func() {
+ <-commandCtx.Done()
+ cancel()
+ }()
+ h.handleUserCommand(commandCtx, argc, args)
+ return
+ }
+
h.handleUserCommand(ctx, argc, args)
}
@@ -191,6 +213,16 @@ func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, er
return args, argc, nil
}
+func (h *ServerHandler) handleTimeout(args []string, argc int) ([]string, int, time.Duration, error) {
+ if argc <= 2 || args[0] != "timeout" {
+ // No timeout specified
+ return args, argc, time.Duration(0) * time.Second, nil
+ }
+
+ timeout, err := strconv.Atoi(args[1])
+ return args[2:], argc - 2, time.Duration(timeout) * time.Second, err
+}
+
func (h *ServerHandler) handleControlCommand(argc int, args []string) {
switch args[0] {
case "debug":
@@ -241,8 +273,10 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
case "run":
command := newRunCommand(h)
h.incrementActiveReaders()
+ go func() { h.globalServerWaitFor <- struct{}{} }()
go func() {
command.Start(ctx, argc, args)
+ <-h.globalServerWaitFor
if h.decrementActiveReaders() == 0 {
h.shutdown()
}
@@ -332,12 +366,13 @@ func (h *ServerHandler) incrementActiveReaders() {
// TODO: Use atomic counter variable instead, so we can get rid of the mutex
h.mutex.Lock()
defer h.mutex.Unlock()
+
h.activeReaders++
}
func (h *ServerHandler) decrementActiveReaders() int {
h.mutex.Lock()
defer h.mutex.Unlock()
- h.activeReaders--
+ h.activeReaders--
return h.activeReaders
}