diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-02-11 13:10:26 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-02-11 13:10:26 +0000 |
| commit | ab01ba7e88308d9ad9381fce4d2b27c1a3ad5f00 (patch) | |
| tree | 1ed536ad81a3c87c21eebd1438792ee91ce969e0 /internal/ssh/client/knownhostscallback.go | |
| parent | 581f729cdbcee98cbaa4f19a7a96a17f27add624 (diff) | |
more on scheduled jobs and ssh callbacks
Diffstat (limited to 'internal/ssh/client/knownhostscallback.go')
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 288 |
1 files changed, 288 insertions, 0 deletions
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go new file mode 100644 index 0000000..1ccf6c6 --- /dev/null +++ b/internal/ssh/client/knownhostscallback.go @@ -0,0 +1,288 @@ +package client + +import ( + "bufio" + "context" + "fmt" + "net" + "os" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/prompt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +type response int + +const ( + trustHost response = iota + dontTrustHost response = iota +) + +// Represents an unknown host. +type unknownHost struct { + server string + remote net.Addr + key ssh.PublicKey + hostLine string + ipLine string + responseCh chan response +} + +// KnownHostsCallback is a wrapper around ssh.KnownHosts so that we can add all +// unknown hosts in a single batch to the known_hosts file. +type KnownHostsCallback struct { + knownHostsPath string + unknownCh chan unknownHost + throttleCh chan struct{} + trustAllHostsCh chan struct{} + untrustedHosts map[string]bool + mutex *sync.Mutex +} + +// NewKnownHostsCallback returns a new wrapper. +func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { + // Ensure file exists + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + untrustedHosts := make(map[string]bool) + + c := KnownHostsCallback{ + knownHostsPath: knownHostsPath, + unknownCh: make(chan unknownHost), + trustAllHostsCh: make(chan struct{}), + throttleCh: throttleCh, + untrustedHosts: untrustedHosts, + mutex: &sync.Mutex{}, + } + + if trustAllHosts { + close(c.trustAllHostsCh) + } + + return c, nil +} + +// Wrap the host key callback. +func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + // Parse known_hosts file + knownHostsCb, err := knownhosts.New(c.knownHostsPath) + if err != nil { + return err + } + + // Check for valid entry in known_hosts file + err = knownHostsCb(server, remote, key) + if err == nil { + // OK + return nil + } + + // Make sure that interactive user callback does not interfere with + // SSH connection throttler. + <-c.throttleCh + defer func() { c.throttleCh <- struct{}{} }() + + unknown := unknownHost{ + server: server, + remote: remote, + key: key, + hostLine: knownhosts.Line([]string{server}, key), + ipLine: knownhosts.Line([]string{remote.String()}, key), + responseCh: make(chan response), + } + + logger.Warn("Encountered unknown host", unknown) + // Notify user that there is an unknown host + c.unknownCh <- unknown + + // Wait for user input. + switch <-unknown.responseCh { + case trustHost: + // End user acknowledged host key + return nil + case dontTrustHost: + } + + c.mutex.Lock() + defer c.mutex.Unlock() + c.untrustedHosts[server] = true + + return err + } +} + +// PromptAddHosts prompts a question to the user whether unknown hosts should +// be added to the known hosts or not. +func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { + var hosts []unknownHost + + for { + // Check whether there is a unknown host + select { + case unknown := <-c.unknownCh: + hosts = append(hosts, unknown) + // Ask every 50 unknown hosts + if len(hosts) >= 50 { + c.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-time.After(2 * time.Second): + // Or ask when after 2 seconds no new unknown hosts were added. + if len(hosts) > 0 { + c.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-ctx.Done(): + logger.Debug("Stopping goroutine prompting new hosts...") + return + } + } +} + +func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { + var servers []string + + for _, host := range hosts { + servers = append(servers, host.server) + } + + select { + case <-c.trustAllHostsCh: + logger.Warn("Trusting host keys of servers", servers) + c.trustHosts(hosts) + return + default: + } + + question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s", + len(servers), + strings.Join(servers, ","), + "Do you want to trust these hosts?", + ) + + p := prompt.New(question) + + a := prompt.Answer{ + Long: "yes", + Short: "y", + Callback: func() { + c.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", c.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "all", + Short: "a", + Callback: func() { + close(c.trustAllHostsCh) + c.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", c.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "no", + Short: "n", + Callback: func() { + c.dontTrustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "details", + Short: "d", + AskAgain: true, + Callback: func() { + for _, unknown := range hosts { + fmt.Println(unknown.hostLine) + fmt.Println(unknown.ipLine) + } + }, + } + p.Add(a) + + p.Ask() +} + +func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { + tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) + + newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) + } + defer newFd.Close() + + // Newly trusted hosts in normalized form + addresses := make(map[string]struct{}) + + // First write to new known hosts file, and keep track of addresses + for _, unknown := range hosts { + unknown.responseCh <- trustHost + + // Add once as [HOSTNAME]:PORT + addresses[knownhosts.Normalize(unknown.server)] = struct{}{} + // And once as [IP]:PORT + addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} + + newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)) + newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)) + } + + // Read old known hosts file, to see which are old and new entries + os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + oldFd, err := os.Open(c.knownHostsPath) + if err != nil { + panic(err) + } + defer oldFd.Close() + + scanner := bufio.NewScanner(oldFd) + + // Now, append all still valid old entries to the new host file + for scanner.Scan() { + line := scanner.Text() + address := strings.SplitN(line, " ", 2)[0] + + if _, ok := addresses[address]; !ok { + newFd.WriteString(fmt.Sprintf("%s\n", line)) + } + } + + // Now, replace old known hosts file + if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil { + panic(err) + } +} + +func (c KnownHostsCallback) dontTrustHosts(hosts []unknownHost) { + for _, unknown := range hosts { + unknown.responseCh <- dontTrustHost + } +} + +// Untrusted returns true if the host is not trusted. False otherwise. +func (c KnownHostsCallback) Untrusted(server string) bool { + c.mutex.Lock() + defer c.mutex.Unlock() + _, ok := c.untrustedHosts[server] + + return ok +} |
