summaryrefslogtreecommitdiff
path: root/internal/ssh
diff options
context:
space:
mode:
Diffstat (limited to 'internal/ssh')
-rw-r--r--internal/ssh/client/knownhostscallback.go16
-rw-r--r--internal/ssh/ssh.go11
2 files changed, 14 insertions, 13 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index fe3543c..9c73864 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -45,6 +45,8 @@ type KnownHostsCallback struct {
mutex *sync.Mutex
}
+var _ HostKeyCallback = (*KnownHostsCallback)(nil)
+
// NewKnownHostsCallback returns a new wrapper.
func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
throttleCh chan struct{}) (HostKeyCallback, error) {
@@ -63,11 +65,11 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
if trustAllHosts {
close(c.trustAllHostsCh)
}
- return c, nil
+ return &c, nil
}
// Wrap the host key callback.
-func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
+func (c *KnownHostsCallback) Wrap() ssh.HostKeyCallback {
return func(server string, remote net.Addr, key ssh.PublicKey) error {
// Parse known_hosts file
knownHostsCb, err := knownhosts.New(c.knownHostsPath)
@@ -113,7 +115,7 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
// PromptAddHosts prompts a question to the user whether unknown hosts should
// be added to the known hosts or not.
-func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
+func (c *KnownHostsCallback) PromptAddHosts(ctx context.Context) {
var hosts []unknownHost
for {
// Check whether there is a unknown host
@@ -138,7 +140,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
}
}
-func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
var servers []string
for _, host := range hosts {
servers = append(servers, host.server)
@@ -212,7 +214,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
p.Ask()
}
-func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) {
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
@@ -265,14 +267,14 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
}
}
-func (c KnownHostsCallback) dontTrustHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) dontTrustHosts(hosts []unknownHost) {
for _, unknown := range hosts {
unknown.responseCh <- dontTrustHost
}
}
// Untrusted returns true if the host is not trusted. False otherwise.
-func (c KnownHostsCallback) Untrusted(server string) bool {
+func (c *KnownHostsCallback) Untrusted(server string) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.untrustedHosts[server]
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 9c2dcb8..32e01b3 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -21,11 +21,10 @@ import (
func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, size)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to generate RSA key: %w", err)
}
- err = privateKey.Validate()
- if err != nil {
- return nil, err
+ if err = privateKey.Validate(); err != nil {
+ return nil, fmt.Errorf("failed to validate generated RSA key: %w", err)
}
return privateKey, nil
}
@@ -46,12 +45,12 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
func Agent() (gossh.AuthMethod, error) {
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to connect to SSH agent: %w", err)
}
agentClient := agent.NewClient(sshAgent)
keys, err := agentClient.List()
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("failed to list SSH agent keys: %w", err)
}
for i, key := range keys {
dlog.Common.Debug("Public key", i, key)