summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-08 09:28:24 +0200
committerPaul Buetow <paul@buetow.org>2026-03-08 09:28:24 +0200
commit4f4c6422d0d5f8038bf918fd3da28b24428e0078 (patch)
treee7fdc5edf317239b005acc772d203d2ee81c2e47
parentb338ad35897117e38ad9a72dfe5cce5d0d05d6ba (diff)
task: replace panic path in known hosts trust flow (task 373)
-rw-r--r--internal/ssh/client/knownhostscallback.go85
1 files changed, 65 insertions, 20 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index 45451ea..ac2ec92 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -152,7 +152,10 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
case <-c.trustAllHostsCh:
// Trust-all mode is non-interactive; avoid warning-level noise on stdout.
dlog.Client.Debug("Trusting host keys of servers", servers)
- c.trustHosts(hosts)
+ if err := c.trustHosts(hosts); err != nil {
+ dlog.Client.Error("Unable to update known hosts file", c.knownHostsPath, err)
+ c.dontTrustHosts(hosts)
+ }
return
default:
}
@@ -168,9 +171,11 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
Long: "yes",
Short: "y",
Callback: func() {
- c.trustHosts(hosts)
- },
- EndCallback: func() {
+ if err := c.trustHosts(hosts); err != nil {
+ dlog.Client.Error("Unable to update known hosts file", c.knownHostsPath, err)
+ c.dontTrustHosts(hosts)
+ return
+ }
dlog.Client.Info("Added hosts to known hosts file", c.knownHostsPath)
},
}
@@ -180,10 +185,16 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
Long: "all",
Short: "a",
Callback: func() {
- close(c.trustAllHostsCh)
- c.trustHosts(hosts)
- },
- EndCallback: func() {
+ if err := c.trustHosts(hosts); err != nil {
+ dlog.Client.Error("Unable to update known hosts file", c.knownHostsPath, err)
+ c.dontTrustHosts(hosts)
+ return
+ }
+ select {
+ case <-c.trustAllHostsCh:
+ default:
+ close(c.trustAllHostsCh)
+ }
dlog.Client.Info("Added hosts to known hosts file", c.knownHostsPath)
},
}
@@ -217,40 +228,52 @@ func (c *KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
p.Ask()
}
-func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) {
+func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) error {
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
+ cleanupTmp := func() {
+ if err := os.Remove(tmpKnownHostsPath); err != nil && !os.IsNotExist(err) {
+ dlog.Client.Debug("Unable to remove temporary known hosts file", tmpKnownHostsPath, err)
+ }
+ }
newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
- panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
+ return fmt.Errorf("open temp known hosts file %s: %w", tmpKnownHostsPath, err)
+ }
+ if err := newFd.Chmod(0600); err != nil {
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("chmod temp known hosts file %s: %w", tmpKnownHostsPath, err)
}
- defer newFd.Close()
// Newly trusted hosts in normalized form
addresses := make(map[string]struct{})
// First write to new known hosts file, and keep track of addresses
for _, unknown := range hosts {
- unknown.responseCh <- trustHost
-
// Add once as [HOSTNAME]:PORT
addresses[knownhosts.Normalize(unknown.server)] = struct{}{}
// And once as [IP]:PORT
addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{}
if _, err := newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)); err != nil {
- panic(err)
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("write host known_hosts entry: %w", err)
}
if _, err := newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)); err != nil {
- panic(err)
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("write ip known_hosts entry: %w", err)
}
}
// Read old known hosts file, to see which are old and new entries
oldFd, err := os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0600)
if err != nil {
- panic(err)
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("open known hosts file %s: %w", c.knownHostsPath, err)
}
- defer oldFd.Close()
scanner := bufio.NewScanner(oldFd)
// Now, append all still valid old entries to the new host file
@@ -260,18 +283,40 @@ func (c *KnownHostsCallback) trustHosts(hosts []unknownHost) {
if _, ok := addresses[address]; !ok {
if _, err := newFd.WriteString(fmt.Sprintf("%s\n", line)); err != nil {
- panic(err)
+ oldFd.Close()
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("append existing known_hosts entry: %w", err)
}
}
}
if err := scanner.Err(); err != nil {
- panic(err)
+ oldFd.Close()
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("scan existing known_hosts entries: %w", err)
+ }
+
+ if err := oldFd.Close(); err != nil {
+ newFd.Close()
+ cleanupTmp()
+ return fmt.Errorf("close known hosts file %s: %w", c.knownHostsPath, err)
+ }
+ if err := newFd.Close(); err != nil {
+ cleanupTmp()
+ return fmt.Errorf("close temp known hosts file %s: %w", tmpKnownHostsPath, err)
}
// Now, replace old known hosts file
if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil {
- panic(err)
+ cleanupTmp()
+ return fmt.Errorf("replace known_hosts file %s: %w", c.knownHostsPath, err)
+ }
+
+ for _, unknown := range hosts {
+ unknown.responseCh <- trustHost
}
+ return nil
}
func (c *KnownHostsCallback) dontTrustHosts(hosts []unknownHost) {