diff options
Diffstat (limited to 'internal/ssh')
| -rw-r--r-- | internal/ssh/client/authmethods.go | 8 | ||||
| -rw-r--r-- | internal/ssh/ssh.go | 30 |
2 files changed, 33 insertions, 5 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 6128018..1a4cb3f 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -16,7 +16,7 @@ const addedPathStr string = "Added path to list of auth methods, not adding furt // InitSSHAuthMethods initialises all known SSH auth methods on the client side. func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, - privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { if len(sshAuthMethods) > 0 { simpleCallback, err := NewSimpleCallback() @@ -25,7 +25,7 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, } return sshAuthMethods, simpleCallback } - return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath) + return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath, agentKeyIndex) } func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { @@ -44,7 +44,7 @@ func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { } func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, - privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + privateKeyPath string, agentKeyIndex int) ([]gossh.AuthMethod, HostKeyCallback) { var sshAuthMethods []gossh.AuthMethod knownHostsFile := fmt.Sprintf("%s/.ssh/known_hosts", os.Getenv("HOME")) @@ -75,7 +75,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, } // Second, try SSH Agent - authMethod, err := ssh.Agent() + authMethod, err := ssh.AgentWithKeyIndex(agentKeyIndex) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+ diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 32e01b3..41cce05 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -43,6 +43,12 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { // Agent used for SSH auth. func Agent() (gossh.AuthMethod, error) { + return AgentWithKeyIndex(-1) +} + +// AgentWithKeyIndex used for SSH auth with a specific key index from the agent. +// If keyIndex is -1, all keys are used. Otherwise, only the specified key is used. +func AgentWithKeyIndex(keyIndex int) (gossh.AuthMethod, error) { sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) if err != nil { return nil, fmt.Errorf("failed to connect to SSH agent: %w", err) @@ -55,7 +61,29 @@ func Agent() (gossh.AuthMethod, error) { for i, key := range keys { dlog.Common.Debug("Public key", i, key) } - return gossh.PublicKeysCallback(agentClient.Signers), nil + + // If no specific key index requested, use all keys (backwards compatible default) + if keyIndex < 0 { + return gossh.PublicKeysCallback(agentClient.Signers), nil + } + + // Use only the specified key index (0-based) + if keyIndex >= len(keys) { + return nil, fmt.Errorf("key index %d out of range (agent has %d keys)", keyIndex, len(keys)) + } + + dlog.Common.Debug("Using SSH agent key at index", keyIndex) + return gossh.PublicKeysCallback(func() ([]gossh.Signer, error) { + signers, err := agentClient.Signers() + if err != nil { + return nil, err + } + if keyIndex >= len(signers) { + return nil, fmt.Errorf("key index %d out of range (agent has %d signers)", keyIndex, len(signers)) + } + // Return only the specified signer + return []gossh.Signer{signers[keyIndex]}, nil + }), nil } // EnterKeyPhrase is required to read phrase protected private keys. |
