diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-13 07:48:40 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-13 07:48:40 +0200 |
| commit | c88dddee1953c938b47830ec13696f23770eb22d (patch) | |
| tree | 35cca5c6bab8c62bf2bc18895764ff9a0bc84741 | |
| parent | 2a665812a0c224ef32d37b2cca681512c5b7d6c1 (diff) | |
task 400: add server session command scaffolding
| -rw-r--r-- | internal/clients/session_spec.go | 116 | ||||
| -rw-r--r-- | internal/server/handlers/basehandler.go | 42 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 14 | ||||
| -rw-r--r-- | internal/server/handlers/sessioncommand.go | 136 | ||||
| -rw-r--r-- | internal/server/handlers/sessioncommand_test.go | 156 | ||||
| -rw-r--r-- | internal/session/spec.go | 124 |
6 files changed, 466 insertions, 122 deletions
diff --git a/internal/clients/session_spec.go b/internal/clients/session_spec.go index 37b1803..a5218f6 100644 --- a/internal/clients/session_spec.go +++ b/internal/clients/session_spec.go @@ -1,124 +1,14 @@ package clients import ( - "fmt" - "strings" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/regex" + sessionspec "github.com/mimecast/dtail/internal/session" ) // SessionSpec captures the mutable, per-connection workload a DTail client wants to run. -type SessionSpec struct { - Mode omode.Mode - Files []string - Options string - Query string - Regex string - RegexInvert bool - Timeout int -} +type SessionSpec = sessionspec.Spec // NewSessionSpec returns a session specification from client args. func NewSessionSpec(args config.Args) SessionSpec { - return SessionSpec{ - Mode: args.Mode, - Files: splitSessionFiles(args.What), - Options: args.SerializeOptions(), - Query: strings.TrimSpace(args.QueryStr), - Regex: args.RegexStr, - RegexInvert: args.RegexInvert, - Timeout: args.Timeout, - } -} - -// Commands returns the legacy command stream for this session specification. -func (s SessionSpec) Commands() ([]string, error) { - switch { - case s.Mode == omode.HealthClient: - return []string{"health"}, nil - case s.Query != "": - return s.queryCommands() - default: - return s.readCommands(s.Mode.String()) - } -} - -func (s SessionSpec) queryCommands() ([]string, error) { - if s.Mode != omode.MapClient && s.Mode != omode.TailClient { - return nil, fmt.Errorf("session spec query mode requires map or tail mode, got %s", s.Mode) - } - - regexValue, err := s.serializedRegex() - if err != nil { - return nil, err - } - - commands := []string{fmt.Sprintf("map:%s %s", s.Options, s.Query)} - readMode := "cat" - if s.Mode == omode.TailClient { - readMode = "tail" - } - - for _, file := range s.Files { - if s.Timeout > 0 { - commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", s.Timeout, readMode, file, regexValue)) - continue - } - commands = append(commands, fmt.Sprintf("%s:%s %s %s", readMode, s.Options, file, regexValue)) - } - - return commands, nil -} - -func (s SessionSpec) readCommands(mode string) ([]string, error) { - switch s.Mode { - case omode.TailClient, omode.CatClient, omode.GrepClient: - default: - return nil, fmt.Errorf("unsupported session mode %s", s.Mode) - } - - regexValue, err := s.serializedRegex() - if err != nil { - return nil, err - } - - var commands []string - for _, file := range s.Files { - commands = append(commands, fmt.Sprintf("%s:%s %s %s", mode, s.Options, file, regexValue)) - } - - return commands, nil -} - -func (s SessionSpec) serializedRegex() (string, error) { - flag := regex.Default - if s.RegexInvert { - flag = regex.Invert - } - - re, err := regex.New(s.Regex, flag) - if err != nil { - return "", err - } - - return re.Serialize() -} - -func splitSessionFiles(what string) []string { - if strings.TrimSpace(what) == "" { - return nil - } - - rawFiles := strings.Split(what, ",") - files := make([]string, 0, len(rawFiles)) - for _, file := range rawFiles { - file = strings.TrimSpace(file) - if file == "" { - continue - } - files = append(files, file) - } - return files + return sessionspec.NewSpec(args) } diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go index 030baf9..42cc4cc 100644 --- a/internal/server/handlers/basehandler.go +++ b/internal/server/handlers/basehandler.go @@ -180,28 +180,30 @@ func (h *baseHandler) handleCommand(commandStr string) { h.sendln(h.serverMessages, dlog.Server.Error(h.user, err)) return } - ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-h.done.Done() - cancel() - }() + ctx, _ := h.newCommandContext(context.Background()) + + if err := h.dispatchCommand(ctx, args, argc); err != nil { + h.sendln(h.serverMessages, dlog.Server.Error(h.user, err)) + } +} +func (h *baseHandler) dispatchCommand(ctx context.Context, args []string, argc int) error { parts := strings.Split(args[0], ":") commandName := parts[0] // Either no options or empty options provided. if len(parts) == 1 || len(parts[1]) == 0 { h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName) - return + return nil } options, ltx, err := config.DeserializeOptions(parts[1:]) if err != nil { - h.sendln(h.serverMessages, dlog.Server.Error(h.user, err)) - return + return err } h.handleOptions(options) h.handleCommandCb(ctx, ltx, argc, args, commandName) + return nil } func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) { @@ -212,6 +214,30 @@ func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, erro return h.codec.handleBase64(args, argc) } +func (h *baseHandler) handleRawCommand(ctx context.Context, command string) error { + args := strings.Fields(command) + if len(args) == 0 { + return fmt.Errorf("empty command") + } + return h.dispatchCommand(ctx, args, len(args)) +} + +func (h *baseHandler) newCommandContext(parent context.Context) (context.Context, context.CancelFunc) { + if parent == nil { + parent = context.Background() + } + + ctx, cancel := context.WithCancel(parent) + go func() { + select { + case <-h.done.Done(): + cancel() + case <-ctx.Done(): + } + }() + return ctx, cancel +} + func (h *baseHandler) handleAckCommand(argc int, args []string) { if argc < 3 { if !h.quiet { diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 732cc06..e8c234b 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -12,6 +12,7 @@ import ( "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/protocol" sshserver "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -29,6 +30,7 @@ type ServerHandler struct { authKeyStore *sshserver.AuthKeyStore regex string commands map[string]commandHandler + sessionState sessionCommandState // Track pending files waiting for limiter slots pendingFiles int32 } @@ -77,6 +79,7 @@ func NewServerHandler(user *user.User, catLimiter, s := strings.Split(fqdn, ".") h.hostname = s[0] + h.send(h.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1) return &h } @@ -112,7 +115,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LCon } func shouldShutdownOnCommandCompletion(commandName string) bool { - return !strings.EqualFold(commandName, "AUTHKEY") + switch { + case strings.EqualFold(commandName, "AUTHKEY"): + return false + case strings.EqualFold(commandName, "SESSION"): + return false + default: + return true + } } func (h *ServerHandler) newCommandRegistry() map[string]commandHandler { @@ -123,7 +133,9 @@ func (h *ServerHandler) newCommandRegistry() map[string]commandHandler { "map": h.handleMapCommand, ".ack": h.handleAckUserCommand, "AUTHKEY": h.handleAuthKeyCommand, + "SESSION": h.handleSessionCommand, "authkey": h.handleAuthKeyCommand, + "session": h.handleSessionCommand, } } diff --git a/internal/server/handlers/sessioncommand.go b/internal/server/handlers/sessioncommand.go new file mode 100644 index 0000000..bc5f83e --- /dev/null +++ b/internal/server/handlers/sessioncommand.go @@ -0,0 +1,136 @@ +package handlers + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "strconv" + "strings" + "sync" + + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/protocol" + "github.com/mimecast/dtail/internal/session" +) + +const ( + sessionAckStartOKPrefix = ".syn session start ok" + sessionAckUpdateOKPrefix = ".syn session update ok" + sessionAckErrorPrefix = ".syn session err " +) + +type sessionCommandState struct { + mu sync.Mutex + active bool + generation uint64 + spec session.Spec +} + +func (h *ServerHandler) handleSessionCommand(_ context.Context, _ lcontext.LContext, argc int, args []string, commandFinished func()) { + defer commandFinished() + + action, generation, spec, err := parseSessionCommand(args, argc) + if err != nil { + h.send(h.serverMessages, sessionAckErrorPrefix+err.Error()) + return + } + + switch action { + case "START": + h.sessionState.storeStart(spec) + h.send(h.serverMessages, sessionAckStartOKPrefix) + case "UPDATE": + if !h.sessionState.activeSession() { + h.send(h.serverMessages, sessionAckErrorPrefix+"session not started") + return + } + h.sessionState.storeUpdate(spec, generation) + h.send(h.serverMessages, sessionAckUpdateOKPrefix) + default: + h.send(h.serverMessages, sessionAckErrorPrefix+"unknown action") + } +} + +func parseSessionCommand(args []string, argc int) (action string, generation uint64, spec session.Spec, err error) { + if argc < 3 { + return "", 0, spec, fmt.Errorf("invalid SESSION command") + } + + action = strings.ToUpper(strings.TrimSpace(args[1])) + payloadIndex := 2 + if action == "UPDATE" && argc >= 4 { + generation, err = strconv.ParseUint(args[2], 10, 64) + if err != nil { + return "", 0, spec, fmt.Errorf("invalid session generation") + } + payloadIndex = 3 + } + + payload, err := base64.StdEncoding.DecodeString(args[payloadIndex]) + if err != nil { + return "", 0, spec, fmt.Errorf("invalid session payload") + } + if err := json.Unmarshal(payload, &spec); err != nil { + return "", 0, spec, fmt.Errorf("invalid session spec") + } + if err := validateSessionSpec(spec); err != nil { + return "", 0, spec, err + } + + return action, generation, spec, nil +} + +func validateSessionSpec(spec session.Spec) error { + switch spec.Mode { + case omode.TailClient, omode.CatClient, omode.GrepClient, omode.MapClient, omode.HealthClient: + default: + return fmt.Errorf("unsupported session mode") + } + + if spec.Query != "" && spec.Mode != omode.MapClient && spec.Mode != omode.TailClient { + return fmt.Errorf("query sessions require map or tail mode") + } + + if spec.Query == "" && spec.Mode == omode.MapClient { + return fmt.Errorf("missing session query") + } + + if _, err := spec.Commands(); err != nil { + return fmt.Errorf("invalid session spec") + } + + return nil +} + +func (s *sessionCommandState) storeStart(spec session.Spec) { + s.mu.Lock() + defer s.mu.Unlock() + + s.active = true + s.generation = 1 + s.spec = spec +} + +func (s *sessionCommandState) storeUpdate(spec session.Spec, generation uint64) { + s.mu.Lock() + defer s.mu.Unlock() + + s.active = true + if generation == 0 { + generation = s.generation + 1 + } + s.generation = generation + s.spec = spec +} + +func (s *sessionCommandState) activeSession() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.active +} + +func (s *sessionCommandState) advertisedCapabilities() string { + return protocol.HiddenCapabilitiesPrefix + protocol.CapabilityQueryUpdateV1 +} diff --git a/internal/server/handlers/sessioncommand_test.go b/internal/server/handlers/sessioncommand_test.go new file mode 100644 index 0000000..6af8c5b --- /dev/null +++ b/internal/server/handlers/sessioncommand_test.go @@ -0,0 +1,156 @@ +package handlers + +import ( + "context" + "encoding/base64" + "encoding/json" + "testing" + "time" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/protocol" + "github.com/mimecast/dtail/internal/session" + userserver "github.com/mimecast/dtail/internal/user/server" +) + +func TestNewServerHandlerAdvertisesQueryUpdateCapability(t *testing.T) { + handler := newSessionTestHandler("session-capability-user") + + if message := readServerMessage(t, handler.serverMessages); message != protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1 { + t.Fatalf("unexpected capability advertisement: %q", message) + } +} + +func TestHandleSessionCommandStartStoresSpec(t *testing.T) { + handler := newSessionTestHandler("session-start-user") + readServerMessage(t, handler.serverMessages) + + spec := session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + payload := mustSessionPayload(t, spec) + + commandFinished := false + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", payload}, func() { + commandFinished = true + }) + + if !commandFinished { + t.Fatalf("expected commandFinished callback") + } + if !handler.sessionState.activeSession() { + t.Fatalf("expected session state to become active") + } + if message := readServerMessage(t, handler.serverMessages); message != sessionAckStartOKPrefix { + t.Fatalf("unexpected session start message: %q", message) + } +} + +func TestHandleSessionCommandUpdateRequiresActiveSession(t *testing.T) { + handler := newSessionTestHandler("session-update-user") + readServerMessage(t, handler.serverMessages) + + spec := session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + payload := mustSessionPayload(t, spec) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "UPDATE", payload}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"session not started" { + t.Fatalf("unexpected session update error: %q", message) + } +} + +func TestHandleSessionCommandRejectsInvalidPayload(t *testing.T) { + handler := newSessionTestHandler("session-invalid-user") + readServerMessage(t, handler.serverMessages) + + handler.handleSessionCommand(context.Background(), lcontext.LContext{}, 3, []string{"SESSION", "START", "not-base64"}, func() {}) + + if message := readServerMessage(t, handler.serverMessages); message != sessionAckErrorPrefix+"invalid session payload" { + t.Fatalf("unexpected invalid payload message: %q", message) + } +} + +func newSessionTestHandler(userName string) *ServerHandler { + handler := &ServerHandler{ + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan *line.Line, 4), + serverMessages: make(chan string, 8), + maprMessages: make(chan string, 4), + ackCloseReceived: make(chan struct{}), + user: &userserver.User{Name: userName}, + codec: newProtocolCodec(&userserver.User{Name: userName}), + }, + serverCfg: &config.ServerConfig{ + AuthKeyEnabled: true, + }, + } + handler.send(handler.serverMessages, protocol.HiddenCapabilitiesPrefix+protocol.CapabilityQueryUpdateV1) + return handler +} + +func mustSessionPayload(t *testing.T, spec session.Spec) string { + t.Helper() + + payload, err := json.Marshal(spec) + if err != nil { + t.Fatalf("marshal session spec: %v", err) + } + return base64.StdEncoding.EncodeToString(payload) +} + +func TestParseSessionCommandWithGeneration(t *testing.T) { + spec := session.Spec{ + Mode: omode.TailClient, + Files: []string{"/var/log/app.log"}, + Regex: "ERROR", + } + + action, generation, parsedSpec, err := parseSessionCommand([]string{"SESSION", "UPDATE", "7", mustSessionPayload(t, spec)}, 4) + if err != nil { + t.Fatalf("parseSessionCommand error: %v", err) + } + if action != "UPDATE" { + t.Fatalf("unexpected action: %s", action) + } + if generation != 7 { + t.Fatalf("unexpected generation: %d", generation) + } + if parsedSpec.Mode != spec.Mode { + t.Fatalf("unexpected parsed mode: %v", parsedSpec.Mode) + } +} + +func TestSessionStateStoreUpdateAutoIncrementsGeneration(t *testing.T) { + var state sessionCommandState + + state.storeStart(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/a"}, Regex: "ERROR"}) + state.storeUpdate(session.Spec{Mode: omode.TailClient, Files: []string{"/tmp/b"}, Regex: "WARN"}, 0) + + state.mu.Lock() + defer state.mu.Unlock() + if state.generation != 2 { + t.Fatalf("unexpected generation: %d", state.generation) + } +} + +func TestSessionCommandReadServerMessageTimeoutProtection(t *testing.T) { + messages := make(chan string) + + select { + case <-messages: + t.Fatalf("unexpected message") + case <-time.After(5 * time.Millisecond): + } +} diff --git a/internal/session/spec.go b/internal/session/spec.go new file mode 100644 index 0000000..2d1b77d --- /dev/null +++ b/internal/session/spec.go @@ -0,0 +1,124 @@ +package session + +import ( + "fmt" + "strings" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/regex" +) + +// Spec captures the mutable, per-connection workload a DTail client wants to run. +type Spec struct { + Mode omode.Mode `json:"mode"` + Files []string `json:"files"` + Options string `json:"options,omitempty"` + Query string `json:"query,omitempty"` + Regex string `json:"regex,omitempty"` + RegexInvert bool `json:"regex_invert,omitempty"` + Timeout int `json:"timeout,omitempty"` +} + +// NewSpec returns a session specification from client args. +func NewSpec(args config.Args) Spec { + return Spec{ + Mode: args.Mode, + Files: splitFiles(args.What), + Options: args.SerializeOptions(), + Query: strings.TrimSpace(args.QueryStr), + Regex: args.RegexStr, + RegexInvert: args.RegexInvert, + Timeout: args.Timeout, + } +} + +// Commands returns the legacy command stream for this session specification. +func (s Spec) Commands() ([]string, error) { + switch { + case s.Mode == omode.HealthClient: + return []string{"health"}, nil + case s.Query != "": + return s.queryCommands() + default: + return s.readCommands(s.Mode.String()) + } +} + +func (s Spec) queryCommands() ([]string, error) { + if s.Mode != omode.MapClient && s.Mode != omode.TailClient { + return nil, fmt.Errorf("session spec query mode requires map or tail mode, got %s", s.Mode) + } + + regexValue, err := s.serializedRegex() + if err != nil { + return nil, err + } + + commands := []string{fmt.Sprintf("map:%s %s", s.Options, s.Query)} + readMode := "cat" + if s.Mode == omode.TailClient { + readMode = "tail" + } + + for _, file := range s.Files { + if s.Timeout > 0 { + commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", s.Timeout, readMode, file, regexValue)) + continue + } + commands = append(commands, fmt.Sprintf("%s:%s %s %s", readMode, s.Options, file, regexValue)) + } + + return commands, nil +} + +func (s Spec) readCommands(mode string) ([]string, error) { + switch s.Mode { + case omode.TailClient, omode.CatClient, omode.GrepClient: + default: + return nil, fmt.Errorf("unsupported session mode %s", s.Mode) + } + + regexValue, err := s.serializedRegex() + if err != nil { + return nil, err + } + + var commands []string + for _, file := range s.Files { + commands = append(commands, fmt.Sprintf("%s:%s %s %s", mode, s.Options, file, regexValue)) + } + + return commands, nil +} + +func (s Spec) serializedRegex() (string, error) { + flag := regex.Default + if s.RegexInvert { + flag = regex.Invert + } + + re, err := regex.New(s.Regex, flag) + if err != nil { + return "", err + } + + return re.Serialize() +} + +func splitFiles(what string) []string { + if strings.TrimSpace(what) == "" { + return nil + } + + rawFiles := strings.Split(what, ",") + files := make([]string, 0, len(rawFiles)) + for _, file := range rawFiles { + file = strings.TrimSpace(file) + if file == "" { + continue + } + files = append(files, file) + } + return files +} |
