summaryrefslogtreecommitdiff
path: root/internal/ssh/client/authmethods_test.go
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-03 10:15:02 +0200
committerPaul Buetow <paul@buetow.org>2026-03-03 10:15:02 +0200
commitf17ffe1bae2f176e4dda90ff4dd2cb267332a7b4 (patch)
tree2feecb488d4f90efb25b0dd7cbb2c5718926f3e2 /internal/ssh/client/authmethods_test.go
parent7d3685a5ed4bfac85673793f8ae6d9c5a6cff962 (diff)
feat(ssh-client): collect auth methods in fallback order
Diffstat (limited to 'internal/ssh/client/authmethods_test.go')
-rw-r--r--internal/ssh/client/authmethods_test.go106
1 files changed, 106 insertions, 0 deletions
diff --git a/internal/ssh/client/authmethods_test.go b/internal/ssh/client/authmethods_test.go
new file mode 100644
index 0000000..04751f5
--- /dev/null
+++ b/internal/ssh/client/authmethods_test.go
@@ -0,0 +1,106 @@
+package client
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+
+ "github.com/mimecast/dtail/internal/io/dlog"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) {
+ homeDir := "/tmp/dtail-auth-order"
+ t.Setenv("HOME", homeDir)
+
+ originalPrivateKeyAuthMethod := privateKeyAuthMethod
+ originalAgentAuthMethod := agentAuthMethod
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ privateKeyAuthMethod = originalPrivateKeyAuthMethod
+ agentAuthMethod = originalAgentAuthMethod
+ dlog.Client = originalLogger
+ })
+
+ var callOrder []string
+ successfulPrivateKeys := map[string]bool{
+ "/custom/id_fast": true,
+ homeDir + "/.ssh/id_rsa": true,
+ homeDir + "/.ssh/id_dsa": true,
+ }
+
+ privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, "private:"+path)
+ if !successfulPrivateKeys[path] {
+ return nil, fmt.Errorf("missing private key: %s", path)
+ }
+ return gossh.Password(path), nil
+ }
+ agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
+ return gossh.Password("agent"), nil
+ }
+
+ methods := collectKnownHostsAuthMethods("/custom/id_fast", 7)
+ if len(methods) != 4 {
+ t.Fatalf("Expected 4 auth methods, got %d", len(methods))
+ }
+
+ expectedOrder := []string{
+ "private:/custom/id_fast",
+ "agent:7",
+ "private:/tmp/dtail-auth-order/.ssh/id_rsa",
+ "private:/tmp/dtail-auth-order/.ssh/id_dsa",
+ "private:/tmp/dtail-auth-order/.ssh/id_ecdsa",
+ "private:/tmp/dtail-auth-order/.ssh/id_ed25519",
+ }
+ if !reflect.DeepEqual(callOrder, expectedOrder) {
+ t.Fatalf("Unexpected auth method call order.\nexpected: %v\ngot: %v", expectedOrder, callOrder)
+ }
+}
+
+func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) {
+ homeDir := "/tmp/dtail-auth-dedupe"
+ t.Setenv("HOME", homeDir)
+
+ originalPrivateKeyAuthMethod := privateKeyAuthMethod
+ originalAgentAuthMethod := agentAuthMethod
+ originalLogger := dlog.Client
+ dlog.Client = &dlog.DLog{}
+ t.Cleanup(func() {
+ privateKeyAuthMethod = originalPrivateKeyAuthMethod
+ agentAuthMethod = originalAgentAuthMethod
+ dlog.Client = originalLogger
+ })
+
+ var callOrder []string
+ privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, "private:"+path)
+ if path == homeDir+"/.ssh/id_rsa" {
+ return gossh.Password(path), nil
+ }
+ return nil, fmt.Errorf("missing private key: %s", path)
+ }
+ agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
+ callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
+ return gossh.Password("agent"), nil
+ }
+
+ methods := collectKnownHostsAuthMethods(homeDir+"/.ssh/id_rsa", 2)
+ if len(methods) != 2 {
+ t.Fatalf("Expected 2 auth methods, got %d", len(methods))
+ }
+
+ expectedOrder := []string{
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_rsa",
+ "agent:2",
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_dsa",
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_ecdsa",
+ "private:/tmp/dtail-auth-dedupe/.ssh/id_ed25519",
+ }
+ if !reflect.DeepEqual(callOrder, expectedOrder) {
+ t.Fatalf("Unexpected auth method call order.\nexpected: %v\ngot: %v", expectedOrder, callOrder)
+ }
+}