summaryrefslogtreecommitdiff
path: root/internal/ssh/client/authmethods_test.go
blob: 04751f5d70cc10a02ffac074005405de638135bc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package client

import (
	"fmt"
	"reflect"
	"testing"

	"github.com/mimecast/dtail/internal/io/dlog"

	gossh "golang.org/x/crypto/ssh"
)

// TestCollectKnownHostsAuthMethodsOrder verifies the order in which
// collectKnownHostsAuthMethods tries authentication sources: the explicitly
// configured private key first, then the SSH agent key at the given index,
// then the default key files under $HOME/.ssh (id_rsa, id_dsa, id_ecdsa,
// id_ed25519). It also checks that only the sources whose constructor
// succeeded end up in the returned slice.
func TestCollectKnownHostsAuthMethodsOrder(t *testing.T) {
	homeDir := "/tmp/dtail-auth-order"
	t.Setenv("HOME", homeDir)

	// defaultKey builds a default key path from homeDir so that the stub
	// fixtures and the expected call order below cannot drift apart.
	defaultKey := func(name string) string { return homeDir + "/.ssh/" + name }

	// Replace the package-level constructors and the client logger with test
	// doubles; restore the originals when the test finishes.
	originalPrivateKeyAuthMethod := privateKeyAuthMethod
	originalAgentAuthMethod := agentAuthMethod
	originalLogger := dlog.Client
	dlog.Client = &dlog.DLog{}
	t.Cleanup(func() {
		privateKeyAuthMethod = originalPrivateKeyAuthMethod
		agentAuthMethod = originalAgentAuthMethod
		dlog.Client = originalLogger
	})

	// callOrder records every constructor invocation in sequence.
	var callOrder []string
	// Only these keys "exist"; the stub returns an error for any other path.
	successfulPrivateKeys := map[string]bool{
		"/custom/id_fast":    true,
		defaultKey("id_rsa"): true,
		defaultKey("id_dsa"): true,
	}

	privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
		callOrder = append(callOrder, "private:"+path)
		if !successfulPrivateKeys[path] {
			return nil, fmt.Errorf("missing private key: %s", path)
		}
		return gossh.Password(path), nil
	}
	agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
		callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
		return gossh.Password("agent"), nil
	}

	methods := collectKnownHostsAuthMethods("/custom/id_fast", 7)

	expectedOrder := []string{
		"private:/custom/id_fast",
		"agent:7",
		"private:" + defaultKey("id_rsa"),
		"private:" + defaultKey("id_dsa"),
		"private:" + defaultKey("id_ecdsa"),
		"private:" + defaultKey("id_ed25519"),
	}
	// Check the order first: when the count is wrong, the call trace is what
	// explains why, so both mismatches are reported rather than aborting early.
	if !reflect.DeepEqual(callOrder, expectedOrder) {
		t.Errorf("Unexpected auth method call order.\nexpected: %v\ngot:      %v", expectedOrder, callOrder)
	}
	// custom key + agent + id_rsa + id_dsa; the failing id_ecdsa and
	// id_ed25519 lookups must not contribute a method.
	if len(methods) != 4 {
		t.Errorf("Expected 4 auth methods, got %d", len(methods))
	}
}

// TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath verifies that
// when the explicitly configured private key is one of the default key paths
// ($HOME/.ssh/id_rsa here), collectKnownHostsAuthMethods does not try that
// file a second time while walking the default keys.
func TestCollectKnownHostsAuthMethodsSkipsDuplicateDefaultPath(t *testing.T) {
	homeDir := "/tmp/dtail-auth-dedupe"
	t.Setenv("HOME", homeDir)

	// defaultKey builds a default key path from homeDir so that the stub,
	// the call argument and the expected call order stay in sync.
	defaultKey := func(name string) string { return homeDir + "/.ssh/" + name }

	// Replace the package-level constructors and the client logger with test
	// doubles; restore the originals when the test finishes.
	originalPrivateKeyAuthMethod := privateKeyAuthMethod
	originalAgentAuthMethod := agentAuthMethod
	originalLogger := dlog.Client
	dlog.Client = &dlog.DLog{}
	t.Cleanup(func() {
		privateKeyAuthMethod = originalPrivateKeyAuthMethod
		agentAuthMethod = originalAgentAuthMethod
		dlog.Client = originalLogger
	})

	// callOrder records every constructor invocation in sequence.
	var callOrder []string
	// Only id_rsa "exists"; every other path makes the stub return an error.
	privateKeyAuthMethod = func(path string) (gossh.AuthMethod, error) {
		callOrder = append(callOrder, "private:"+path)
		if path == defaultKey("id_rsa") {
			return gossh.Password(path), nil
		}
		return nil, fmt.Errorf("missing private key: %s", path)
	}
	agentAuthMethod = func(keyIndex int) (gossh.AuthMethod, error) {
		callOrder = append(callOrder, fmt.Sprintf("agent:%d", keyIndex))
		return gossh.Password("agent"), nil
	}

	methods := collectKnownHostsAuthMethods(defaultKey("id_rsa"), 2)

	// id_rsa appears exactly once: as the explicit key, not again among the
	// default keys.
	expectedOrder := []string{
		"private:" + defaultKey("id_rsa"),
		"agent:2",
		"private:" + defaultKey("id_dsa"),
		"private:" + defaultKey("id_ecdsa"),
		"private:" + defaultKey("id_ed25519"),
	}
	// Check the order first: when the count is wrong, the call trace is what
	// explains why, so both mismatches are reported rather than aborting early.
	if !reflect.DeepEqual(callOrder, expectedOrder) {
		t.Errorf("Unexpected auth method call order.\nexpected: %v\ngot:      %v", expectedOrder, callOrder)
	}
	// explicit id_rsa + agent; all remaining default keys fail in the stub.
	if len(methods) != 2 {
		t.Errorf("Expected 2 auth methods, got %d", len(methods))
	}
}