summaryrefslogtreecommitdiff
path: root/internal/clients/connectors/serverconnection_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/connectors/serverconnection_test.go')
-rw-r--r--internal/clients/connectors/serverconnection_test.go174
1 files changed, 174 insertions, 0 deletions
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go
index 9307b24..76c4eb6 100644
--- a/internal/clients/connectors/serverconnection_test.go
+++ b/internal/clients/connectors/serverconnection_test.go
@@ -2,6 +2,7 @@ package connectors
import (
"context"
+ "errors"
"os"
"path/filepath"
"testing"
@@ -9,7 +10,9 @@ import (
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/protocol"
+ sessionspec "github.com/mimecast/dtail/internal/session"
"golang.org/x/crypto/ssh"
)
@@ -91,6 +94,8 @@ func TestNewServerConnectionUsesInjectedSettings(t *testing.T) {
testHostKeyCallback{},
&mockHandler{},
nil,
+ sessionspec.Spec{},
+ false,
"",
false,
testSSHSettings{port: 3022, timeout: 5 * time.Second},
@@ -117,6 +122,8 @@ func TestNewServerConnectionFallsBackToDefaults(t *testing.T) {
testHostKeyCallback{},
&mockHandler{},
nil,
+ sessionspec.Spec{},
+ false,
"",
false,
testSSHSettings{},
@@ -173,6 +180,159 @@ func TestServerConnectionSupportsQueryUpdatesRequiresCapabilityFlag(t *testing.T
}
}
+func TestServerConnectionApplySessionSpecStart(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{{
+ Action: "start",
+ Generation: 1,
+ }},
+ },
+ }
+
+ spec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ if err := conn.ApplySessionSpec(spec, 10*time.Millisecond); err != nil {
+ t.Fatalf("ApplySessionSpec() error = %v", err)
+ }
+
+ mock := conn.handler.(*mockHandler)
+ if len(mock.commands) != 1 {
+ t.Fatalf("expected one session command, got %d", len(mock.commands))
+ }
+ if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 1 || committedSpec.Regex != "ERROR" {
+ t.Fatalf("unexpected committed session: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
+func TestServerConnectionApplySessionSpecUpdateUsesNextGeneration(t *testing.T) {
+ resetClientLogger(t)
+
+ mock := &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{
+ {Action: "start", Generation: 4},
+ {Action: "update", Generation: 5},
+ },
+ }
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: mock,
+ }
+
+ startSpec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }
+ updateSpec := sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "WARN",
+ }
+
+ if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil {
+ t.Fatalf("start ApplySessionSpec() error = %v", err)
+ }
+ if err := conn.ApplySessionSpec(updateSpec, 10*time.Millisecond); err != nil {
+ t.Fatalf("update ApplySessionSpec() error = %v", err)
+ }
+ if len(mock.commands) != 2 {
+ t.Fatalf("expected two session commands, got %d", len(mock.commands))
+ }
+ if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 5 || committedSpec.Regex != "WARN" {
+ t.Fatalf("unexpected committed session after update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
+func TestServerConnectionApplySessionSpecFallsBackForUnsupportedServer(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := &ServerConnection{
+ handler: &mockHandler{},
+ }
+
+ err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"}, 5*time.Millisecond)
+ if !errors.Is(err, ErrSessionUnsupported) {
+ t.Fatalf("expected ErrSessionUnsupported, got %v", err)
+ }
+}
+
+func TestServerConnectionApplySessionSpecPreservesCommittedStateOnRejectedUpdate(t *testing.T) {
+ resetClientLogger(t)
+
+ mock := &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{
+ {Action: "start", Generation: 2},
+ {Action: "error", Error: "bad reload"},
+ },
+ }
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: mock,
+ }
+
+ startSpec := sessionspec.Spec{Mode: omode.TailClient, Regex: "ERROR"}
+ if err := conn.ApplySessionSpec(startSpec, 10*time.Millisecond); err != nil {
+ t.Fatalf("start ApplySessionSpec() error = %v", err)
+ }
+
+ err := conn.ApplySessionSpec(sessionspec.Spec{Mode: omode.TailClient, Regex: "WARN"}, 10*time.Millisecond)
+ if !errors.Is(err, ErrSessionRejected) {
+ t.Fatalf("expected ErrSessionRejected, got %v", err)
+ }
+ if committedSpec, generation, ok := conn.CommittedSession(); !ok || generation != 2 || committedSpec.Regex != "ERROR" {
+ t.Fatalf("unexpected committed session after rejected update: spec=%#v generation=%d ok=%v", committedSpec, generation, ok)
+ }
+}
+
+func TestServerConnectionApplySessionSpecRejectsUnexpectedAck(t *testing.T) {
+ resetClientLogger(t)
+
+ mock := &mockHandler{
+ waitForCapabilities: true,
+ capabilities: map[string]bool{
+ protocol.CapabilityQueryUpdateV1: true,
+ },
+ sessionAcks: []handlers.SessionAck{
+ {Action: "update", Generation: 1},
+ },
+ }
+ conn := &ServerConnection{
+ server: "srv1",
+ handler: mock,
+ }
+
+ err := conn.ApplySessionSpec(sessionspec.Spec{
+ Mode: omode.TailClient,
+ Files: []string{"/var/log/app.log"},
+ Regex: "ERROR",
+ }, 10*time.Millisecond)
+ if !errors.Is(err, ErrUnexpectedSessionAck) {
+ t.Fatalf("expected ErrUnexpectedSessionAck, got %v", err)
+ }
+ if _, _, ok := conn.CommittedSession(); ok {
+ t.Fatalf("unexpected committed session after mismatched ack")
+ }
+}
+
type testSSHSettings struct {
port int
timeout time.Duration
@@ -212,6 +372,7 @@ type mockHandler struct {
commands []string
capabilities map[string]bool
waitForCapabilities bool
+ sessionAcks []handlers.SessionAck
}
var _ handlers.Handler = (*mockHandler)(nil)
@@ -253,6 +414,19 @@ func (m *mockHandler) WaitForCapabilities(timeout time.Duration) bool {
return m.waitForCapabilities
}
+func (m *mockHandler) WaitForSessionAck(timeout time.Duration) (handlers.SessionAck, bool) {
+ if timeout <= 0 {
+ return handlers.SessionAck{}, false
+ }
+ if len(m.sessionAcks) == 0 {
+ return handlers.SessionAck{}, false
+ }
+
+ ack := m.sessionAcks[0]
+ m.sessionAcks = m.sessionAcks[1:]
+ return ack, true
+}
+
func (m *mockHandler) Read(_ []byte) (int, error) {
return 0, nil
}