diff options
Diffstat (limited to 'internal/ssh/client/authmethods.go')
| -rw-r--r-- | internal/ssh/client/authmethods.go | 93 |
1 files changed, 53 insertions, 40 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 1a4cb3f..a414ade 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -11,7 +11,10 @@ import ( gossh "golang.org/x/crypto/ssh" ) -const addedPathStr string = "Added path to list of auth methods, not adding further methods" +var ( + privateKeyAuthMethod = ssh.PrivateKey + agentAuthMethod = ssh.AgentWithKeyIndex +) // InitSSHAuthMethods initialises all known SSH auth methods on the client side. func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, @@ -39,14 +42,13 @@ func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { } sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", addedPathStr, privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", privateKeyPath) return sshAuthMethods } func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { - var sshAuthMethods []gossh.AuthMethod knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME")) if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { // In case of integration test, override known hosts file path. @@ -63,54 +65,65 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback } - // Try to read custom private key path. - if privateKeyPath != "" { - authMethod, err := ssh.PrivateKey(privateKeyPath) - if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback - } - dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) + sshAuthMethods := collectKnownHostsAuthMethods(privateKeyPath, agentKeyIndex) + if len(sshAuthMethods) == 0 { + dlog.Client.FatalPanic("Unable to find private SSH key information") } - // Second, try SSH Agent - authMethod, err := ssh.AgentWithKeyIndex(agentKeyIndex) - if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+ - "to list of auth methods, not adding further methods") - return sshAuthMethods, knownHostsCallback + return sshAuthMethods, knownHostsCallback +} + +func collectKnownHostsAuthMethods(privateKeyPath string, agentKeyIndex int) []gossh.AuthMethod { + var sshAuthMethods []gossh.AuthMethod + + home := os.Getenv("HOME") + defaultPrivateKeyPaths := []string{ + home + "/.ssh/id_rsa", + home + "/.ssh/id_dsa", + home + "/.ssh/id_ecdsa", + home + "/.ssh/id_ed25519", } - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) - // Third, try Linux/UNIX default key paths - privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa" - authMethod, err = ssh.PrivateKey(privateKeyPath) - if err == nil { - sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback + if privateKeyPath == "" { + privateKeyPath = defaultPrivateKeyPaths[0] } - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) - privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa" - authMethod, err = ssh.PrivateKey(privateKeyPath) - if err == nil { + addedPrivateKeyPaths := make(map[string]bool, len(defaultPrivateKeyPaths)+1) + addPrivateKeyAuthMethod := func(path string) { + if path == "" { + return + } + if addedPrivateKeyPaths[path] { + return + } + + authMethod, err := privateKeyAuthMethod(path) + if err != nil { + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", path, err) + return + } + sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback + addedPrivateKeyPaths[path] = true + dlog.Client.Debug("initKnownHostsAuthMethods", "Added private key auth method", path) } - privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa" - authMethod, err = ssh.PrivateKey(privateKeyPath) + // First, the explicit auth key path (or default ~/.ssh/id_rsa). + addPrivateKeyAuthMethod(privateKeyPath) + + // Second, SSH agent (YubiKey-backed keys are typically exposed here). + authMethod, err := agentAuthMethod(agentKeyIndex) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - dlog.Client.Debug("initKnownHostsAuthmethods", addedPathStr, privateKeyPath) - return sshAuthMethods, knownHostsCallback + dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH agent auth method") + } else { + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) } - dlog.Client.FatalPanic("Unable to find private SSH key information", privateKeyPath, err) - // Never reach this point. - return sshAuthMethods, knownHostsCallback + // Third, additional default private key paths. + for _, path := range defaultPrivateKeyPaths { + addPrivateKeyAuthMethod(path) + } + + return sshAuthMethods } |
