diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-08 09:28:24 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-08 09:28:24 +0200 |
| commit | 4f4c6422d0d5f8038bf918fd3da28b24428e0078 (patch) | |
| tree | e7fdc5edf317239b005acc772d203d2ee81c2e47 | |
| parent | b338ad35897117e38ad9a72dfe5cce5d0d05d6ba (diff) | |
task: replace panic path in known hosts trust flow (task 373)
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 85 |
1 files changed, 65 insertions, 20 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index 45451ea..ac2ec92 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -152,7 +152,10 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { case <-c.trustAllHostsCh: // Trust-all mode is non-interactive; avoid warning-level noise on stdout. dlog.Client.Debug("Trusting host keys of servers", servers) - c.trustHosts(hosts) + if err := c.trustHosts(hosts); err != nil { + dlog.Client.Error("Unable to update known hosts file", c.knownHostsPath, err) + c.dontTrustHosts(hosts) + } return default: } @@ -168,9 +171,11 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { Long: "yes", Short: "y", Callback: func() { - c.trustHosts(hosts) - }, - EndCallback: func() { + if err := c.trustHosts(hosts); err != nil { + dlog.Client.Error("Unable to update known hosts file", c.knownHostsPath, err) + c.dontTrustHosts(hosts) + return + } dlog.Client.Info("Added hosts to known hosts file", c.knownHostsPath) }, } @@ -180,10 +185,16 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { Long: "all", Short: "a", Callback: func() { - close(c.trustAllHostsCh) - c.trustHosts(hosts) - }, - EndCallback: func() { + if err := c.trustHosts(hosts); err != nil { + dlog.Client.Error("Unable to update known hosts file", c.knownHostsPath, err) + c.dontTrustHosts(hosts) + return + } + select { + case <-c.trustAllHostsCh: + default: + close(c.trustAllHostsCh) + } dlog.Client.Info("Added hosts to known hosts file", c.knownHostsPath) }, } @@ -217,40 +228,52 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) { p.Ask() } -func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) { +func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error { tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) + cleanupTmp := func() { + if err := os.Remove(tmpKnownHostsPath); err != nil && !os.IsNotExist(err) { + dlog.Client.Debug("Unable to remove temporary known hosts file", tmpKnownHostsPath, err) + } + } newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { - panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) + return fmt.Errorf("open temp known hosts file %s: %w", tmpKnownHostsPath, err) + } + if err := newFd.Chmod(0600); err != nil { + newFd.Close() + cleanupTmp() + return fmt.Errorf("chmod temp known hosts file %s: %w", tmpKnownHostsPath, err) } - defer newFd.Close() // Newly trusted hosts in normalized form addresses := make(map[string]struct{}) // First write to new known hosts file, and keep track of addresses for _, unknown := range hosts { - unknown.responseCh <- trustHost - // Add once as [HOSTNAME]:PORT addresses[knownhosts.Normalize(unknown.server)] = struct{}{} // And once as [IP]:PORT addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} if _, err := newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)); err != nil { - panic(err) + newFd.Close() + cleanupTmp() + return fmt.Errorf("write host known_hosts entry: %w", err) } if _, err := newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)); err != nil { - panic(err) + newFd.Close() + cleanupTmp() + return fmt.Errorf("write ip known_hosts entry: %w", err) } } // Read old known hosts file, to see which are old and new entries oldFd, err := os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0600) if err != nil { - panic(err) + newFd.Close() + cleanupTmp() + return fmt.Errorf("open known hosts file %s: %w", c.knownHostsPath, err) } - defer oldFd.Close() scanner := bufio.NewScanner(oldFd) // Now, append all still valid old entries to the new host file @@ -260,18 +283,40 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) { if _, ok := addresses[address]; !ok { if _, err := newFd.WriteString(fmt.Sprintf("%s\n", line)); err != nil { - panic(err) + oldFd.Close() + newFd.Close() + cleanupTmp() + return fmt.Errorf("append existing known_hosts entry: %w", err) } } } if err := scanner.Err(); err != nil { - panic(err) + oldFd.Close() + newFd.Close() + cleanupTmp() + return fmt.Errorf("scan existing known_hosts entries: %w", err) + } + + if err := oldFd.Close(); err != nil { + newFd.Close() + cleanupTmp() + return fmt.Errorf("close known hosts file %s: %w", c.knownHostsPath, err) + } + if err := newFd.Close(); err != nil { + cleanupTmp() + return fmt.Errorf("close temp known hosts file %s: %w", tmpKnownHostsPath, err) } // Now, replace old known hosts file if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil { - panic(err) + cleanupTmp() + return fmt.Errorf("replace known_hosts file %s: %w", c.knownHostsPath, err) + } + + for _, unknown := range hosts { + unknown.responseCh <- trustHost } + return nil } func (c *KnownHostsCallback) dontTrustHosts(hosts []unknownHost) { |
