diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 07:48:40 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 07:48:40 +0200 |
| commit | c88dddee1953c938b47830ec13696f23770eb22d (patch) | |
| tree | 35cca5c6bab8c62bf2bc18895764ff9a0bc84741 /internal/server | |
| parent | 2a665812a0c224ef32d37b2cca681512c5b7d6c1 (diff) | |
task 400: add server session command scaffolding
Diffstat (limited to 'internal/server')
| -rw-r--r-- | internal/server/handlers/basehandler.go | 42 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 14 | ||||
| -rw-r--r-- | internal/server/handlers/sessioncommand.go | 136 | ||||
| -rw-r--r-- | internal/server/handlers/sessioncommand_test.go | 156 |
4 files changed, 339 insertions, 9 deletions
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go index 030baf9..42cc4cc 100644 --- a/internal/server/handlers/basehandler.go +++ b/internal/server/handlers/basehandler.go @@ -180,28 +180,30 @@ func (h *baseHandler) handleCommand(commandStr string) { h.sendln(h.serverMessages, dlog.Server.Error(h.user, err)) return } - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-h.done.Done() - cancel() - }() + ctx, _ := h.newCommandContext(context.Background()) + + if err := h.dispatchCommand(ctx, args, argc); err != nil { + h.sendln(h.serverMessages, dlog.Server.Error(h.user, err)) + } +} +func (h *baseHandler) dispatchCommand(ctx context.Context, args []string, argc int) error { parts := strings.Split(args[0], ":") commandName := parts[0] // Either no options or empty options provided. if len(parts) == 1 || len(parts[1]) == 0 { h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName) - return + return nil } options, ltx, err := config.DeserializeOptions(parts[1:]) if err != nil { - h.sendln(h.serverMessages, dlog.Server.Error(h.user, err)) - return + return err } h.handleOptions(options) h.handleCommandCb(ctx, ltx, argc, args, commandName) + return nil } func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) { @@ -212,6 +214,30 @@ func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, erro return h.codec.handleBase64(args, argc) } +func (h *baseHandler) handleRawCommand(ctx context.Context, command string) error { + args := strings.Fields(command) + if len(args) == 0 { + return fmt.Errorf("empty command") + } + return h.dispatchCommand(ctx, args, len(args)) +} + +func (h *baseHandler) newCommandContext(parent context.Context) (context.Context, context.CancelFunc) { + if parent == nil { + parent = context.Background() + } + + ctx, cancel := context.WithCancel(parent) + go func() { + select { + case <-h.done.Done(): + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} + func (h *baseHandler) handleAckCommand(argc int, args []string) { if argc < 3 { if !h.quiet { diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 732cc06..e8c234b 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -12,6 +12,7 @@ import ( "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/protocol" sshserver "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -29,6 +30,7 @@ type ServerHandler struct { authKeyStore *sshserver.AuthKeyStore regex string commands map[string]commandHandler + sessionState sessionCommandState // Track pending files waiting for limiter slots pendingFiles int32 } @@ -77,6 +79,7 @@ func NewServerHandler(user *user.User, catLimiter, s := strings.Split(fqdn, ".") h.hostname = s[0] + h.send(h.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1) return &h } @@ -112,7 +115,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon } func shouldShutdownOnCommandCompletion(commandName string) bool { - return !strings.EqualFold(commandName, "AUTHKEY") + switch { + case strings.EqualFold(commandName, "AUTHKEY"): + return false + case strings.EqualFold(commandName, "SESSION"): + return false + default: + return true + } } func (h *ServerHandler) newCommandRegistry() map[string]commandHandler { @@ -123,7 +133,9 @@ func (h *ServerHandler) newCommandRegistry() map[string]commandHandler { "map": h.handleMapCommand, ".ack": h.handleAckUserCommand, "AUTHKEY": h.handleAuthKeyCommand, + "SESSION": h.handleSessionCommand, "authkey": h.handleAuthKeyCommand, + "session": h.handleSessionCommand, } } diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go new file mode 100644 index 0000000..bc5f83e --- /dev/null +++ b/internal/server/handlers/sessioncommand.go @@ -0,0 +1,136 @@ +package handlers + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strconv" + "strings" + "sync" + + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/protocol" + "github.com/mimecast/dtail/internal/session" +) + +const ( + sessionAckStartOKPrefix = ".syn session start ok" + sessionAckUpdateOKPrefix = ".syn session update ok" + sessionAckErrorPrefix = ".syn session err " +) + +type sessionCommandState struct { + mu sync.Mutex + active bool + generation uint64 + spec session.Spec +} + +func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { + defer commandFinished() + + action, generation, spec, err := parseSessionCommand(args, argc) + if err != nil { + h.send(h.serverMessages, sessionAckErrorPrefix+err.Error()) + return + } + + switch action { + case "START": + h.sessionState.storeStart(spec) + h.send(h.serverMessages, sessionAckStartOKPrefix) + case "UPDATE": + if !h.sessionState.activeSession() { + h.send(h.serverMessages, sessionAckErrorPrefix+"session not started") + return + } + h.sessionState.storeUpdate(spec, generation) + h.send(h.serverMessages, sessionAckUpdateOKPrefix) + default: + h.send(h.serverMessages, sessionAckErrorPrefix+"unknown action") + } +} + +func parseSessionCommand(args []string, argc int) (action string, generation uint64, spec session.Spec, err error) { + if argc < 3 { + return "", 0, spec, fmt.Errorf("invalid SESSION command") + } + + action = strings.ToUpper(strings.TrimSpace(args[1])) + payloadIndex := 2 + if action == "UPDATE" && argc >= 4 { + generation, err = strconv.ParseUint(args[2], 10, 64) + if err != nil { + return "", 0, spec, fmt.Errorf("invalid session generation") + } + payloadIndex = 3 + } + + payload, err := base64.StdEncoding.DecodeString(args[payloadIndex]) + if err != nil { + return "", 0, spec, fmt.Errorf("invalid session payload") + } + if err := json.Unmarshal(payload, &spec); err != nil { + return "", 0, spec, fmt.Errorf("invalid session spec") + } + if err := validateSessionSpec(spec); err != nil { + return "", 0, spec, err + } + + return action, generation, spec, nil +} + +func validateSessionSpec(spec session.Spec) error { + switch spec.Mode { + case omode.TailClient, omode.CatClient, omode.GrepClient, omode.MapClient, omode.HealthClient: + default: + return fmt.Errorf("unsupported session mode") + } + + if spec.Query != "" && spec.Mode != omode.MapClient && spec.Mode != omode.TailClient { + return fmt.Errorf("query sessions require map or tail mode") + } + + if spec.Query == "" && spec.Mode == omode.MapClient { + return fmt.Errorf("missing session query") + } + + if _, err := spec.Commands(); err != nil { + return fmt.Errorf("invalid session spec") + } + + return nil +} + +func (s *sessionCommandState) storeStart(spec session.Spec) { + s.mu.Lock() + defer s.mu.Unlock() + + s.active = true + s.generation = 1 + s.spec = spec +} + +func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.active = true + if generation == 0 { + generation = s.generation + 1 + } + s.generation = generation + s.spec = spec +} + +func (s *sessionCommandState) activeSession() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.active +} + +func (s *sessionCommandState) advertisedCapabilities() string { + return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1 +} diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go new file mode 100644 index 0000000..6af8c5b --- /dev/null +++ b/internal/server/handlers/sessioncommand_test.go @@ -0,0 +1,156 @@ +package handlers + +import ( + "context" + "encoding/base64" + "encoding/json" + "testing" + "time" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/protocol" + "github.com/mimecast/dtail/internal/session" + userserver "github.com/mimecast/dtail/internal/user/server" +) + +func TestNewServerHandlerAdvertisesQueryUpdateCapability(t *testing.T) { + handler := newSessionTestHandler("session-capability-user") + + if message := readServerMessage(t, handler.serverMessages); message != protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1 { + t.Fatalf("unexpected capability advertisement: %q", message) + } +} + +func TestHandleSessionCommandStartStoresSpec(t *testing.T) { + handler := newSessionTestHandler("session-start-user") + readServerMessage(t, handler.serverMessages) + + spec := session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + payload := mustSessionPayload(t, spec) + + commandFinished := false + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() { + commandFinished = true + }) + + if !commandFinished { + t.Fatalf("expected commandFinished callback") + } + if !handler.sessionState.activeSession() { + t.Fatalf("expected session state to become active") + } + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix { + t.Fatalf("unexpected session start message: %q", message) + } +} + +func TestHandleSessionCommandUpdateRequiresActiveSession(t *testing.T) { + handler := newSessionTestHandler("session-update-user") + readServerMessage(t, handler.serverMessages) + + spec := session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + payload := mustSessionPayload(t, spec) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"session not started" { + t.Fatalf("unexpected session update error: %q", message) + } +} + +func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) { + handler := newSessionTestHandler("session-invalid-user") + readServerMessage(t, handler.serverMessages) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", "not-base64"}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session payload" { + t.Fatalf("unexpected invalid payload message: %q", message) + } +} + +func newSessionTestHandler(userName string) *ServerHandler { + handler := &ServerHandler{ + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan *line.Line, 4), + serverMessages: make(chan string, 8), + maprMessages: make(chan string, 4), + ackCloseReceived: make(chan struct{}), + user: &userserver.User{Name: userName}, + codec: newProtocolCodec(&userserver.User{Name: userName}), + }, + serverCfg: &config.ServerConfig{ + AuthKeyEnabled: true, + }, + } + handler.send(handler.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1) + return handler +} + +func mustSessionPayload(t *testing.T, spec session.Spec) string { + t.Helper() + + payload, err := json.Marshal(spec) + if err != nil { + t.Fatalf("marshal session spec: %v", err) + } + return base64.StdEncoding.EncodeToString(payload) +} + +func TestParseSessionCommandWithGeneration(t *testing.T) { + spec := session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + + action, generation, parsedSpec, err := parseSessionCommand([]string{"SESSION", "UPDATE", "7", mustSessionPayload(t, spec)}, 4) + if err != nil { + t.Fatalf("parseSessionCommand error: %v", err) + } + if action != "UPDATE" { + t.Fatalf("unexpected action: %s", action) + } + if generation != 7 { + t.Fatalf("unexpected generation: %d", generation) + } + if parsedSpec.Mode != spec.Mode { + t.Fatalf("unexpected parsed mode: %v", parsedSpec.Mode) + } +} + +func TestSessionStateStoreUpdateAutoIncrementsGeneration(t *testing.T) { + var state sessionCommandState + + state.storeStart(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/a"}, Regex: "ERROR"}) + state.storeUpdate(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/b"}, Regex: "WARN"}, 0) + + state.mu.Lock() + defer state.mu.Unlock() + if state.generation != 2 { + t.Fatalf("unexpected generation: %d", state.generation) + } +} + +func TestSessionCommandReadServerMessageTimeoutProtection(t *testing.T) { + messages := make(chan string) + + select { + case <-messages: + t.Fatalf("unexpected message") + case <-time.After(5 * time.Millisecond): + } +} |
