summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-13 07:48:40 +0200
committerPaul Buetow <paul@buetow.org>2026-03-13 07:48:40 +0200
commitc88dddee1953c938b47830ec13696f23770eb22d (patch)
tree35cca5c6bab8c62bf2bc18895764ff9a0bc84741
parent2a665812a0c224ef32d37b2cca681512c5b7d6c1 (diff)
task 400: add server session command scaffolding
-rw-r--r--internal/clients/session_spec.go116
-rw-r--r--internal/server/handlers/basehandler.go42
-rw-r--r--internal/server/handlers/serverhandler.go14
-rw-r--r--internal/server/handlers/sessioncommand.go136
-rw-r--r--internal/server/handlers/sessioncommand_test.go156
-rw-r--r--internal/session/spec.go124
6 files changed, 466 insertions, 122 deletions
diff --git a/internal/clients/session_spec.go b/internal/clients/session_spec.go
index 37b1803..a5218f6 100644
--- a/internal/clients/session_spec.go
+++ b/internal/clients/session_spec.go
@@ -1,124 +1,14 @@
package clients
import (
- "fmt"
- "strings"
-
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/regex"
+ sessionspec "github.com/mimecast/dtail/internal/session"
)
// SessionSpec captures the mutable, per-connection workload a DTail client wants to run.
-type SessionSpec struct {
- Mode omode.Mode
- Files []string
- Options string
- Query string
- Regex string
- RegexInvert bool
- Timeout int
-}
+type SessionSpec = sessionspec.Spec
// NewSessionSpec returns a session specification from client args.
func NewSessionSpec(args config.Args) SessionSpec {
- return SessionSpec{
- Mode: args.Mode,
- Files: splitSessionFiles(args.What),
- Options: args.SerializeOptions(),
- Query: strings.TrimSpace(args.QueryStr),
- Regex: args.RegexStr,
- RegexInvert: args.RegexInvert,
- Timeout: args.Timeout,
- }
-}
-
-// Commands returns the legacy command stream for this session specification.
-func (s SessionSpec) Commands() ([]string, error) {
- switch {
- case s.Mode == omode.HealthClient:
- return []string{"health"}, nil
- case s.Query != "":
- return s.queryCommands()
- default:
- return s.readCommands(s.Mode.String())
- }
-}
-
-func (s SessionSpec) queryCommands() ([]string, error) {
- if s.Mode != omode.MapClient && s.Mode != omode.TailClient {
- return nil, fmt.Errorf("session spec query mode requires map or tail mode, got %s", s.Mode)
- }
-
- regexValue, err := s.serializedRegex()
- if err != nil {
- return nil, err
- }
-
- commands := []string{fmt.Sprintf("map:%s %s", s.Options, s.Query)}
- readMode := "cat"
- if s.Mode == omode.TailClient {
- readMode = "tail"
- }
-
- for _, file := range s.Files {
- if s.Timeout > 0 {
- commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", s.Timeout, readMode, file, regexValue))
- continue
- }
- commands = append(commands, fmt.Sprintf("%s:%s %s %s", readMode, s.Options, file, regexValue))
- }
-
- return commands, nil
-}
-
-func (s SessionSpec) readCommands(mode string) ([]string, error) {
- switch s.Mode {
- case omode.TailClient, omode.CatClient, omode.GrepClient:
- default:
- return nil, fmt.Errorf("unsupported session mode %s", s.Mode)
- }
-
- regexValue, err := s.serializedRegex()
- if err != nil {
- return nil, err
- }
-
- var commands []string
- for _, file := range s.Files {
- commands = append(commands, fmt.Sprintf("%s:%s %s %s", mode, s.Options, file, regexValue))
- }
-
- return commands, nil
-}
-
-func (s SessionSpec) serializedRegex() (string, error) {
- flag := regex.Default
- if s.RegexInvert {
- flag = regex.Invert
- }
-
- re, err := regex.New(s.Regex, flag)
- if err != nil {
- return "", err
- }
-
- return re.Serialize()
-}
-
-func splitSessionFiles(what string) []string {
- if strings.TrimSpace(what) == "" {
- return nil
- }
-
- rawFiles := strings.Split(what, ",")
- files := make([]string, 0, len(rawFiles))
- for _, file := range rawFiles {
- file = strings.TrimSpace(file)
- if file == "" {
- continue
- }
- files = append(files, file)
- }
- return files
+ return sessionspec.NewSpec(args)
}
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
index 030baf9..42cc4cc 100644
--- a/internal/server/handlers/basehandler.go
+++ b/internal/server/handlers/basehandler.go
@@ -180,28 +180,30 @@ func (h *baseHandler) handleCommand(commandStr string) {
h.sendln(h.serverMessages, dlog.Server.Error(h.user, err))
return
}
- ctx, cancel := context.WithCancel(context.Background())
- go func() {
- <-h.done.Done()
- cancel()
- }()
+ ctx, _ := h.newCommandContext(context.Background())
+
+ if err := h.dispatchCommand(ctx, args, argc); err != nil {
+ h.sendln(h.serverMessages, dlog.Server.Error(h.user, err))
+ }
+}
+func (h *baseHandler) dispatchCommand(ctx context.Context, args []string, argc int) error {
parts := strings.Split(args[0], ":")
commandName := parts[0]
// Either no options or empty options provided.
if len(parts) == 1 || len(parts[1]) == 0 {
h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName)
- return
+ return nil
}
options, ltx, err := config.DeserializeOptions(parts[1:])
if err != nil {
- h.sendln(h.serverMessages, dlog.Server.Error(h.user, err))
- return
+ return err
}
h.handleOptions(options)
h.handleCommandCb(ctx, ltx, argc, args, commandName)
+ return nil
}
func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
@@ -212,6 +214,30 @@ func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, erro
return h.codec.handleBase64(args, argc)
}
+func (h *baseHandler) handleRawCommand(ctx context.Context, command string) error {
+ args := strings.Fields(command)
+ if len(args) == 0 {
+ return fmt.Errorf("empty command")
+ }
+ return h.dispatchCommand(ctx, args, len(args))
+}
+
+func (h *baseHandler) newCommandContext(parent context.Context) (context.Context, context.CancelFunc) {
+ if parent == nil {
+ parent = context.Background()
+ }
+
+ ctx, cancel := context.WithCancel(parent)
+ go func() {
+ select {
+ case <-h.done.Done():
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+ return ctx, cancel
+}
+
func (h *baseHandler) handleAckCommand(argc int, args []string) {
if argc < 3 {
if !h.quiet {
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 732cc06..e8c234b 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -12,6 +12,7 @@ import (
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/protocol"
sshserver "github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -29,6 +30,7 @@ type ServerHandler struct {
authKeyStore *sshserver.AuthKeyStore
regex string
commands map[string]commandHandler
+ sessionState sessionCommandState
// Track pending files waiting for limiter slots
pendingFiles int32
}
@@ -77,6 +79,7 @@ func NewServerHandler(user *user.User, catLimiter,
s := strings.Split(fqdn, ".")
h.hostname = s[0]
+ h.send(h.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1)
return &h
}
@@ -112,7 +115,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon
}
func shouldShutdownOnCommandCompletion(commandName string) bool {
- return !strings.EqualFold(commandName, "AUTHKEY")
+ switch {
+ case strings.EqualFold(commandName, "AUTHKEY"):
+ return false
+ case strings.EqualFold(commandName, "SESSION"):
+ return false
+ default:
+ return true
+ }
}
func (h *ServerHandler) newCommandRegistry() map[string]commandHandler {
@@ -123,7 +133,9 @@ func (h *ServerHandler) newCommandRegistry() map[string]commandHandler {
"map": h.handleMapCommand,
".ack": h.handleAckUserCommand,
"AUTHKEY": h.handleAuthKeyCommand,
+ "SESSION": h.handleSessionCommand,
"authkey": h.handleAuthKeyCommand,
+ "session": h.handleSessionCommand,
}
}
diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go
new file mode 100644
index 0000000..bc5f83e
--- /dev/null
+++ b/internal/server/handlers/sessioncommand.go
@@ -0,0 +1,136 @@
+package handlers
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "strconv"
+ "strings"
+ "sync"
+
+ "github.com/mimecast/dtail/internal/lcontext"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/protocol"
+ "github.com/mimecast/dtail/internal/session"
+)
+
+const (
+ sessionAckStartOKPrefix = ".syn session start ok"
+ sessionAckUpdateOKPrefix = ".syn session update ok"
+ sessionAckErrorPrefix = ".syn session err "
+)
+
+type sessionCommandState struct {
+ mu sync.Mutex
+ active bool
+ generation uint64
+ spec session.Spec
+}
+
+func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) {
+ defer commandFinished()
+
+ action, generation, spec, err := parseSessionCommand(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, sessionAckErrorPrefix+err.Error())
+ return
+ }
+
+ switch action {
+ case "START":
+ h.sessionState.storeStart(spec)
+ h.send(h.serverMessages, sessionAckStartOKPrefix)
+ case "UPDATE":
+ if !h.sessionState.activeSession() {
+ h.send(h.serverMessages, sessionAckErrorPrefix+"session not started")
+ return
+ }
+ h.sessionState.storeUpdate(spec, generation)
+ h.send(h.serverMessages, sessionAckUpdateOKPrefix)
+ default:
+ h.send(h.serverMessages, sessionAckErrorPrefix+"unknown action")
+ }
+}
+
+func parseSessionCommand(args []string, argc int) (action string, generation uint64, spec session.Spec, err error) {
+ if argc < 3 {
+ return "", 0, spec, fmt.Errorf("invalid SESSION command")
+ }
+
+ action = strings.ToUpper(strings.TrimSpace(args[1]))
+ payloadIndex := 2
+ if action == "UPDATE" && argc >= 4 {
+ generation, err = strconv.ParseUint(args[2], 10, 64)
+ if err != nil {
+ return "", 0, spec, fmt.Errorf("invalid session generation")
+ }
+ payloadIndex = 3
+ }
+
+ payload, err := base64.StdEncoding.DecodeString(args[payloadIndex])
+ if err != nil {
+ return "", 0, spec, fmt.Errorf("invalid session payload")
+ }
+ if err := json.Unmarshal(payload, &spec); err != nil {
+ return "", 0, spec, fmt.Errorf("invalid session spec")
+ }
+ if err := validateSessionSpec(spec); err != nil {
+ return "", 0, spec, err
+ }
+
+ return action, generation, spec, nil
+}
+
+func validateSessionSpec(spec session.Spec) error {
+ switch spec.Mode {
+ case omode.TailClient, omode.CatClient, omode.GrepClient, omode.MapClient, omode.HealthClient:
+ default:
+ return fmt.Errorf("unsupported session mode")
+ }
+
+ if spec.Query != "" && spec.Mode != omode.MapClient && spec.Mode != omode.TailClient {
+ return fmt.Errorf("query sessions require map or tail mode")
+ }
+
+ if spec.Query == "" && spec.Mode == omode.MapClient {
+ return fmt.Errorf("missing session query")
+ }
+
+ if _, err := spec.Commands(); err != nil {
+ return fmt.Errorf("invalid session spec")
+ }
+
+ return nil
+}
+
+func (s *sessionCommandState) storeStart(spec session.Spec) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.active = true
+ s.generation = 1
+ s.spec = spec
+}
+
+func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ s.active = true
+ if generation == 0 {
+ generation = s.generation + 1
+ }
+ s.generation = generation
+ s.spec = spec
+}
+
+func (s *sessionCommandState) activeSession() bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.active
+}
+
+func (s *sessionCommandState) advertisedCapabilities() string {
+ return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1
+}
diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go
new file mode 100644
index 0000000..6af8c5b
--- /dev/null
+++ b/internal/server/handlers/sessioncommand_test.go
@@ -0,0 +1,156 @@
+package handlers
+
+import (
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/lcontext"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/protocol"
+ "github.com/mimecast/dtail/internal/session"
+ userserver "github.com/mimecast/dtail/internal/user/server"
+)
+
+func TestNewServerHandlerAdvertisesQueryUpdateCapability(t *testing.T) {
+ handler := newSessionTestHandler("session-capability-user")
+
+ if message := readServerMessage(t, handler.serverMessages); message != protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1 {
+ t.Fatalf("unexpected capability advertisement: %q", message)
+ }
+}
+
+func TestHandleSessionCommandStartStoresSpec(t *testing.T) {
+ handler := newSessionTestHandler("session-start-user")
+ readServerMessage(t, handler.serverMessages)
+
+ spec := session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ payload := mustSessionPayload(t, spec)
+
+ commandFinished := false
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() {
+ commandFinished = true
+ })
+
+ if !commandFinished {
+ t.Fatalf("expected commandFinished callback")
+ }
+ if !handler.sessionState.activeSession() {
+ t.Fatalf("expected session state to become active")
+ }
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix {
+ t.Fatalf("unexpected session start message: %q", message)
+ }
+}
+
+func TestHandleSessionCommandUpdateRequiresActiveSession(t *testing.T) {
+ handler := newSessionTestHandler("session-update-user")
+ readServerMessage(t, handler.serverMessages)
+
+ spec := session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ payload := mustSessionPayload(t, spec)
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", payload}, func() {})
+
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"session not started" {
+ t.Fatalf("unexpected session update error: %q", message)
+ }
+}
+
+func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) {
+ handler := newSessionTestHandler("session-invalid-user")
+ readServerMessage(t, handler.serverMessages)
+
+ handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", "not-base64"}, func() {})
+
+ if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session payload" {
+ t.Fatalf("unexpected invalid payload message: %q", message)
+ }
+}
+
+func newSessionTestHandler(userName string) *ServerHandler {
+ handler := &ServerHandler{
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan *line.Line, 4),
+ serverMessages: make(chan string, 8),
+ maprMessages: make(chan string, 4),
+ ackCloseReceived: make(chan struct{}),
+ user: &userserver.User{Name: userName},
+ codec: newProtocolCodec(&userserver.User{Name: userName}),
+ },
+ serverCfg: &config.ServerConfig{
+ AuthKeyEnabled: true,
+ },
+ }
+ handler.send(handler.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1)
+ return handler
+}
+
+func mustSessionPayload(t *testing.T, spec session.Spec) string {
+ t.Helper()
+
+ payload, err := json.Marshal(spec)
+ if err != nil {
+ t.Fatalf("marshal session spec: %v", err)
+ }
+ return base64.StdEncoding.EncodeToString(payload)
+}
+
+func TestParseSessionCommandWithGeneration(t *testing.T) {
+ spec := session.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+
+ action, generation, parsedSpec, err := parseSessionCommand([]string{"SESSION", "UPDATE", "7", mustSessionPayload(t, spec)}, 4)
+ if err != nil {
+ t.Fatalf("parseSessionCommand error: %v", err)
+ }
+ if action != "UPDATE" {
+ t.Fatalf("unexpected action: %s", action)
+ }
+ if generation != 7 {
+ t.Fatalf("unexpected generation: %d", generation)
+ }
+ if parsedSpec.Mode != spec.Mode {
+ t.Fatalf("unexpected parsed mode: %v", parsedSpec.Mode)
+ }
+}
+
+func TestSessionStateStoreUpdateAutoIncrementsGeneration(t *testing.T) {
+ var state sessionCommandState
+
+ state.storeStart(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/a"}, Regex: "ERROR"})
+ state.storeUpdate(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/b"}, Regex: "WARN"}, 0)
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ if state.generation != 2 {
+ t.Fatalf("unexpected generation: %d", state.generation)
+ }
+}
+
+func TestSessionCommandReadServerMessageTimeoutProtection(t *testing.T) {
+ messages := make(chan string)
+
+ select {
+ case <-messages:
+ t.Fatalf("unexpected message")
+ case <-time.After(5 * time.Millisecond):
+ }
+}
diff --git a/internal/session/spec.go b/internal/session/spec.go
new file mode 100644
index 0000000..2d1b77d
--- /dev/null
+++ b/internal/session/spec.go
@@ -0,0 +1,124 @@
+package session
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/regex"
+)
+
+// Spec captures the mutable, per-connection workload a DTail client wants to run.
+type Spec struct {
+ Mode omode.Mode `json:"mode"`
+ Files []string `json:"files"`
+ Options string `json:"options,omitempty"`
+ Query string `json:"query,omitempty"`
+ Regex string `json:"regex,omitempty"`
+ RegexInvert bool `json:"regex_invert,omitempty"`
+ Timeout int `json:"timeout,omitempty"`
+}
+
+// NewSpec returns a session specification from client args.
+func NewSpec(args config.Args) Spec {
+ return Spec{
+ Mode: args.Mode,
+ Files: splitFiles(args.What),
+ Options: args.SerializeOptions(),
+ Query: strings.TrimSpace(args.QueryStr),
+ Regex: args.RegexStr,
+ RegexInvert: args.RegexInvert,
+ Timeout: args.Timeout,
+ }
+}
+
+// Commands returns the legacy command stream for this session specification.
+func (s Spec) Commands() ([]string, error) {
+ switch {
+ case s.Mode == omode.HealthClient:
+ return []string{"health"}, nil
+ case s.Query != "":
+ return s.queryCommands()
+ default:
+ return s.readCommands(s.Mode.String())
+ }
+}
+
+func (s Spec) queryCommands() ([]string, error) {
+ if s.Mode != omode.MapClient && s.Mode != omode.TailClient {
+ return nil, fmt.Errorf("session spec query mode requires map or tail mode, got %s", s.Mode)
+ }
+
+ regexValue, err := s.serializedRegex()
+ if err != nil {
+ return nil, err
+ }
+
+ commands := []string{fmt.Sprintf("map:%s %s", s.Options, s.Query)}
+ readMode := "cat"
+ if s.Mode == omode.TailClient {
+ readMode = "tail"
+ }
+
+ for _, file := range s.Files {
+ if s.Timeout > 0 {
+ commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", s.Timeout, readMode, file, regexValue))
+ continue
+ }
+ commands = append(commands, fmt.Sprintf("%s:%s %s %s", readMode, s.Options, file, regexValue))
+ }
+
+ return commands, nil
+}
+
+func (s Spec) readCommands(mode string) ([]string, error) {
+ switch s.Mode {
+ case omode.TailClient, omode.CatClient, omode.GrepClient:
+ default:
+ return nil, fmt.Errorf("unsupported session mode %s", s.Mode)
+ }
+
+ regexValue, err := s.serializedRegex()
+ if err != nil {
+ return nil, err
+ }
+
+ var commands []string
+ for _, file := range s.Files {
+ commands = append(commands, fmt.Sprintf("%s:%s %s %s", mode, s.Options, file, regexValue))
+ }
+
+ return commands, nil
+}
+
+func (s Spec) serializedRegex() (string, error) {
+ flag := regex.Default
+ if s.RegexInvert {
+ flag = regex.Invert
+ }
+
+ re, err := regex.New(s.Regex, flag)
+ if err != nil {
+ return "", err
+ }
+
+ return re.Serialize()
+}
+
+func splitFiles(what string) []string {
+ if strings.TrimSpace(what) == "" {
+ return nil
+ }
+
+ rawFiles := strings.Split(what, ",")
+ files := make([]string, 0, len(rawFiles))
+ for _, file := range rawFiles {
+ file = strings.TrimSpace(file)
+ if file == "" {
+ continue
+ }
+ files = append(files, file)
+ }
+ return files
+}