diff options
Diffstat (limited to 'internal/clients/connectors/serverconnection_test.go')
| -rw-r--r-- | internal/clients/connectors/serverconnection_test.go | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go index 8ab126b..227a1e9 100644 --- a/internal/clients/connectors/serverconnection_test.go +++ b/internal/clients/connectors/serverconnection_test.go @@ -1,12 +1,16 @@ package connectors import ( + "context" "os" "path/filepath" "testing" + "time" "github.com/mimecast/dtail/internal/clients/handlers" "github.com/mimecast/dtail/internal/io/dlog" + + "golang.org/x/crypto/ssh" ) func TestExtractAuthKeyBase64(t *testing.T) { @@ -76,6 +80,90 @@ func TestSendAuthKeyRegistrationCommand(t *testing.T) { } } +func TestNewServerConnectionUsesInjectedSettings(t *testing.T) { + resetClientLogger(t) + + conn := NewServerConnection( + "srv1", + "user", + nil, + testHostKeyCallback{}, + &mockHandler{}, + nil, + "", + false, + testSSHSettings{port: 3022, timeout: 5 * time.Second}, + ) + + if conn.hostname != "srv1" { + t.Fatalf("Expected hostname srv1, got %q", conn.hostname) + } + if conn.port != 3022 { + t.Fatalf("Expected injected port 3022, got %d", conn.port) + } + if conn.config.Timeout != 5*time.Second { + t.Fatalf("Expected injected timeout 5s, got %v", conn.config.Timeout) + } +} + +func TestNewServerConnectionFallsBackToDefaults(t *testing.T) { + resetClientLogger(t) + + conn := NewServerConnection( + "srv1", + "user", + nil, + testHostKeyCallback{}, + &mockHandler{}, + nil, + "", + false, + testSSHSettings{}, + ) + + if conn.port != defaultSSHPort { + t.Fatalf("Expected default port %d, got %d", defaultSSHPort, conn.port) + } + if conn.config.Timeout != defaultSSHConnectTimeout { + t.Fatalf("Expected default timeout %v, got %v", defaultSSHConnectTimeout, conn.config.Timeout) + } +} + +type testSSHSettings struct { + port int + timeout time.Duration +} + +func (s testSSHSettings) SSHPort() int { + return s.port +} + +func (s testSSHSettings) SSHConnectTimeout() time.Duration { + return s.timeout +} + +type testHostKeyCallback struct{} + +func (testHostKeyCallback) Wrap() ssh.HostKeyCallback { + return ssh.InsecureIgnoreHostKey() +} + +func (testHostKeyCallback) Untrusted(string) bool { + return false +} + +func (testHostKeyCallback) PromptAddHosts(context.Context) {} + +func resetClientLogger(t *testing.T) { + t.Helper() + + originalLogger := dlog.Client + dlog.Client = &dlog.DLog{} + t.Cleanup(func() { + dlog.Client = originalLogger + }) +} + type mockHandler struct { commands []string } |
