summaryrefslogtreecommitdiff
path: root/internal/clients/connectors/serverconnection_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/connectors/serverconnection_test.go')
-rw-r--r--internal/clients/connectors/serverconnection_test.go88
1 files changed, 88 insertions, 0 deletions
diff --git a/internal/clients/connectors/serverconnection_test.go b/internal/clients/connectors/serverconnection_test.go
index 8ab126b..227a1e9 100644
--- a/internal/clients/connectors/serverconnection_test.go
+++ b/internal/clients/connectors/serverconnection_test.go
@@ -1,12 +1,16 @@
package connectors
import (
+ "context"
"os"
"path/filepath"
"testing"
+ "time"
"github.com/mimecast/dtail/internal/clients/handlers"
"github.com/mimecast/dtail/internal/io/dlog"
+
+ "golang.org/x/crypto/ssh"
)
func TestExtractAuthKeyBase64(t *testing.T) {
@@ -76,6 +80,90 @@ func TestSendAuthKeyRegistrationCommand(t *testing.T) {
}
}
+func TestNewServerConnectionUsesInjectedSettings(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := NewServerConnection(
+ "srv1",
+ "user",
+ nil,
+ testHostKeyCallback{},
+ &mockHandler{},
+ nil,
+ "",
+ false,
+ testSSHSettings{port: 3022, timeout: 5 * time.Second},
+ )
+
+ if conn.hostname != "srv1" {
+ t.Fatalf("Expected hostname srv1, got %q", conn.hostname)
+ }
+ if conn.port != 3022 {
+ t.Fatalf("Expected injected port 3022, got %d", conn.port)
+ }
+ if conn.config.Timeout != 5*time.Second {
+ t.Fatalf("Expected injected timeout 5s, got %v", conn.config.Timeout)
+ }
+}
+
+func TestNewServerConnectionFallsBackToDefaults(t *testing.T) {
+ resetClientLogger(t)
+
+ conn := NewServerConnection(
+ "srv1",
+ "user",
+ nil,
+ testHostKeyCallback{},
+ &mockHandler{},
+ nil,
+ "",
+ false,
+ testSSHSettings{},
+ )
+
+ if conn.port != defaultSSHPort {
+ t.Fatalf("Expected default port %d, got %d", defaultSSHPort, conn.port)
+ }
+ if conn.config.Timeout != defaultSSHConnectTimeout {
+ t.Fatalf("Expected default timeout %v, got %v", defaultSSHConnectTimeout, conn.config.Timeout)
+ }
+}
+
+type testSSHSettings struct {
+ port int
+ timeout time.Duration
+}
+
+func (s testSSHSettings) SSHPort() int {
+ return s.port
+}
+
+func (s testSSHSettings) SSHConnectTimeout() time.Duration {
+ return s.timeout
+}
+
+type testHostKeyCallback struct{}
+
+func (testHostKeyCallback) Wrap() ssh.HostKeyCallback {
+ return ssh.InsecureIgnoreHostKey()
+}
+
+func (testHostKeyCallback) Untrusted(string) bool {
+ return false
+}
+
+func (testHostKeyCallback) PromptAddHosts(context.Context) {}
+
+func resetClientLogger(t *testing.T) {
+ t.Helper()
+
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ dlog.Client = originalLogger
+ })
+}
+
type mockHandler struct {
commands []string
}