diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-02-11 13:10:26 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-02-11 13:10:26 +0000 |
| commit | ecf7c86bc2f64068e6256ac1b9738c712a9858e9 (patch) | |
| tree | 1ed536ad81a3c87c21eebd1438792ee91ce969e0 /internal/ssh | |
| parent | 410ca88465f065f244f88c1d4089cb0fa4a45799 (diff) | |
more on scheduled jobs and ssh callbacks
Diffstat (limited to 'internal/ssh')
| -rw-r--r-- | internal/ssh/client/authmethods.go | 18 | ||||
| -rw-r--r-- | internal/ssh/client/customkeycallback.go | 24 | ||||
| -rw-r--r-- | internal/ssh/client/hostkeycallback.go | 280 | ||||
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 288 | ||||
| -rw-r--r-- | internal/ssh/client/simplecallback.go | 36 |
5 files changed, 363 insertions, 283 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 072fad0..44c5601 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -10,16 +10,20 @@ import ( gossh "golang.org/x/crypto/ssh" ) -// InitSSHAuthMethods initialises all known SSH auth methods on othe client side. -func InitSSHAuthMethods(args clients.Args, trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) { - if len(args.SSHAuthMethods) > 0 { - hostKeyCallback, err := NewSimpleCallback(trustAllHosts) +// InitSSHAuthMethods initialises all known SSH auth methods on the client side. +func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, HostKeyCallback) { + if len(sshAuthMethods) > 0 { + simpleCallback, err := NewSimpleCallback() if err != nil { logger.FatalExit(err) } - return args.SSHAuthMethods, hostKeyCallback + return sshAuthMethods, simpleCallback } + return initKnownHostsAuthMethods(trustAllHosts, throttleCh) +} + +func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, HostKeyCallback) { var sshAuthMethods []gossh.AuthMethod if config.Common.ExperimentalFeaturesEnable { sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) @@ -44,10 +48,10 @@ func InitSSHAuthMethods(args clients.Args, trustAllHosts bool, throttleCh chan s } knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts" - hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh) + knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, throttleCh) if err != nil { logger.FatalExit(knownHostsPath, err) } - return sshAuthMethods, hostKeyCallback + return sshAuthMethods, knownHostsCallback } diff --git a/internal/ssh/client/customkeycallback.go b/internal/ssh/client/customkeycallback.go new file mode 100644 index 0000000..73e5289 --- /dev/null +++ b/internal/ssh/client/customkeycallback.go @@ -0,0 +1,24 @@ +package client + +import ( + "net" + + "golang.org/x/crypto/ssh" +) + +// CustomCallback is a custom host key callback wrapper. +type CustomCallback struct { +} + +// NewCustomCallback returns a new wrapper. +func NewCustomCallback() (*CustomCallback, error) { + h := CustomCallback{} + return &h, nil +} + +// Wrap the host key callback. +func (h *CustomCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + return nil + } +} diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go index d090d7f..95543f2 100644 --- a/internal/ssh/client/hostkeycallback.go +++ b/internal/ssh/client/hostkeycallback.go @@ -1,287 +1,15 @@ package client import ( - "bufio" "context" - "fmt" - "net" - "os" - "strings" - "sync" - "time" - - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/io/prompt" "golang.org/x/crypto/ssh" - "golang.org/x/crypto/ssh/knownhosts" -) - -type response int - -const ( - trustHost response = iota - dontTrustHost response = iota ) -// Represents an unknown host. -type unknownHost struct { - server string - remote net.Addr - key ssh.PublicKey - hostLine string - ipLine string - responseCh chan response -} - // HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all // unknown hosts in a single batch to the known_hosts file. -type HostKeyCallback struct { - knownHostsPath string - unknownCh chan unknownHost - throttleCh chan struct{} - trustAllHostsCh chan struct{} - untrustedHosts map[string]bool - mutex sync.Mutex -} - -// NewHostKeyCallback returns a new wrapper. -func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) { - // Ensure file exists - os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) - - h := HostKeyCallback{ - knownHostsPath: knownHostsPath, - unknownCh: make(chan unknownHost), - trustAllHostsCh: make(chan struct{}), - throttleCh: throttleCh, - untrustedHosts: make(map[string]bool), - } - - if trustAllHosts { - close(h.trustAllHostsCh) - } - - return &h, nil -} - -// Wrap the host key callback. -func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { - return func(server string, remote net.Addr, key ssh.PublicKey) error { - // Parse known_hosts file - knownHostsCb, err := knownhosts.New(h.knownHostsPath) - if err != nil { - // Problem parsing it - return err - } - - // Check for valid entry in known_hosts file - err = knownHostsCb(server, remote, key) - if err == nil { - // OK - return nil - } - - // Make sure that interactive user callback does not interfere with - // SSH connection throttler. - <-h.throttleCh - defer func() { h.throttleCh <- struct{}{} }() - - unknown := unknownHost{ - server: server, - remote: remote, - key: key, - hostLine: knownhosts.Line([]string{server}, key), - ipLine: knownhosts.Line([]string{remote.String()}, key), - responseCh: make(chan response), - } - - logger.Warn("Encountered unknown host", unknown) - // Notify user that there is an unknown host - h.unknownCh <- unknown - - // Wait for user input. - switch <-unknown.responseCh { - case trustHost: - // End user acknowledged host key - return nil - case dontTrustHost: - } - - h.mutex.Lock() - defer h.mutex.Unlock() - h.untrustedHosts[server] = true - - return err - } -} - -// PromptAddHosts prompts a question to the user whether unknown hosts should -// be added to the known hosts or not. -func (h *HostKeyCallback) PromptAddHosts(ctx context.Context) { - var hosts []unknownHost - - for { - // Check whether there is a unknown host - select { - case unknown := <-h.unknownCh: - hosts = append(hosts, unknown) - // Ask every 50 unknown hosts - if len(hosts) >= 50 { - h.promptAddHosts(hosts) - hosts = []unknownHost{} - } - case <-time.After(2 * time.Second): - // Or ask when after 2 seconds no new unknown hosts were added. - if len(hosts) > 0 { - h.promptAddHosts(hosts) - hosts = []unknownHost{} - } - case <-ctx.Done(): - logger.Debug("Stopping goroutine prompting new hosts...") - return - } - } -} - -func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) { - var servers []string - - for _, host := range hosts { - servers = append(servers, host.server) - } - - select { - case <-h.trustAllHostsCh: - logger.Warn("Trusting host keys of servers", servers) - h.trustHosts(hosts) - return - default: - } - - question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s", - len(servers), - strings.Join(servers, ","), - "Do you want to trust these hosts?", - ) - - p := prompt.New(question) - - a := prompt.Answer{ - Long: "yes", - Short: "y", - Callback: func() { - h.trustHosts(hosts) - }, - EndCallback: func() { - logger.Info("Added hosts to known hosts file", h.knownHostsPath) - }, - } - p.Add(a) - - a = prompt.Answer{ - Long: "all", - Short: "a", - Callback: func() { - close(h.trustAllHostsCh) - h.trustHosts(hosts) - }, - EndCallback: func() { - logger.Info("Added hosts to known hosts file", h.knownHostsPath) - }, - } - p.Add(a) - - a = prompt.Answer{ - Long: "no", - Short: "n", - Callback: func() { - h.dontTrustHosts(hosts) - }, - EndCallback: func() { - logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath) - }, - } - p.Add(a) - - a = prompt.Answer{ - Long: "details", - Short: "d", - AskAgain: true, - Callback: func() { - for _, unknown := range hosts { - fmt.Println(unknown.hostLine) - fmt.Println(unknown.ipLine) - } - }, - } - p.Add(a) - - p.Ask() -} - -func (h *HostKeyCallback) trustHosts(hosts []unknownHost) { - tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath) - - newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) - if err != nil { - panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) - } - defer newFd.Close() - - // Newly trusted hosts in normalized form - addresses := make(map[string]struct{}) - - // First write to new known hosts file, and keep track of addresses - for _, unknown := range hosts { - unknown.responseCh <- trustHost - - // Add once as [HOSTNAME]:PORT - addresses[knownhosts.Normalize(unknown.server)] = struct{}{} - // And once as [IP]:PORT - addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} - - newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)) - newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)) - } - - // Read old known hosts file, to see which are old and new entries - os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) - oldFd, err := os.Open(h.knownHostsPath) - if err != nil { - panic(err) - } - defer oldFd.Close() - - scanner := bufio.NewScanner(oldFd) - - // Now, append all still valid old entries to the new host file - for scanner.Scan() { - line := scanner.Text() - address := strings.SplitN(line, " ", 2)[0] - - if _, ok := addresses[address]; !ok { - newFd.WriteString(fmt.Sprintf("%s\n", line)) - } - } - - // Now, replace old known hosts file - if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil { - panic(err) - } -} - -func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) { - for _, unknown := range hosts { - unknown.responseCh <- dontTrustHost - } -} - -// Untrusted returns true if the host is not trusted. False otherwise. -func (h *HostKeyCallback) Untrusted(server string) bool { - h.mutex.Lock() - defer h.mutex.Unlock() - _, ok := h.untrustedHosts[server] - - return ok +type HostKeyCallback interface { + Wrap() ssh.HostKeyCallback + Untrusted(server string) bool + PromptAddHosts(ctx context.Context) } diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go new file mode 100644 index 0000000..1ccf6c6 --- /dev/null +++ b/internal/ssh/client/knownhostscallback.go @@ -0,0 +1,288 @@ +package client + +import ( + "bufio" + "context" + "fmt" + "net" + "os" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/prompt" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +type response int + +const ( + trustHost response = iota + dontTrustHost response = iota +) + +// Represents an unknown host. +type unknownHost struct { + server string + remote net.Addr + key ssh.PublicKey + hostLine string + ipLine string + responseCh chan response +} + +// KnownHostsCallback is a wrapper around ssh.KnownHosts so that we can add all +// unknown hosts in a single batch to the known_hosts file. +type KnownHostsCallback struct { + knownHostsPath string + unknownCh chan unknownHost + throttleCh chan struct{} + trustAllHostsCh chan struct{} + untrustedHosts map[string]bool + mutex *sync.Mutex +} + +// NewKnownHostsCallback returns a new wrapper. +func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { + // Ensure file exists + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + untrustedHosts := make(map[string]bool) + + c := KnownHostsCallback{ + knownHostsPath: knownHostsPath, + unknownCh: make(chan unknownHost), + trustAllHostsCh: make(chan struct{}), + throttleCh: throttleCh, + untrustedHosts: untrustedHosts, + mutex: &sync.Mutex{}, + } + + if trustAllHosts { + close(c.trustAllHostsCh) + } + + return c, nil +} + +// Wrap the host key callback. +func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + // Parse known_hosts file + knownHostsCb, err := knownhosts.New(c.knownHostsPath) + if err != nil { + return err + } + + // Check for valid entry in known_hosts file + err = knownHostsCb(server, remote, key) + if err == nil { + // OK + return nil + } + + // Make sure that interactive user callback does not interfere with + // SSH connection throttler. + <-c.throttleCh + defer func() { c.throttleCh <- struct{}{} }() + + unknown := unknownHost{ + server: server, + remote: remote, + key: key, + hostLine: knownhosts.Line([]string{server}, key), + ipLine: knownhosts.Line([]string{remote.String()}, key), + responseCh: make(chan response), + } + + logger.Warn("Encountered unknown host", unknown) + // Notify user that there is an unknown host + c.unknownCh <- unknown + + // Wait for user input. + switch <-unknown.responseCh { + case trustHost: + // End user acknowledged host key + return nil + case dontTrustHost: + } + + c.mutex.Lock() + defer c.mutex.Unlock() + c.untrustedHosts[server] = true + + return err + } +} + +// PromptAddHosts prompts a question to the user whether unknown hosts should +// be added to the known hosts or not. +func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { + var hosts []unknownHost + + for { + // Check whether there is a unknown host + select { + case unknown := <-c.unknownCh: + hosts = append(hosts, unknown) + // Ask every 50 unknown hosts + if len(hosts) >= 50 { + c.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-time.After(2 * time.Second): + // Or ask when after 2 seconds no new unknown hosts were added. + if len(hosts) > 0 { + c.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-ctx.Done(): + logger.Debug("Stopping goroutine prompting new hosts...") + return + } + } +} + +func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { + var servers []string + + for _, host := range hosts { + servers = append(servers, host.server) + } + + select { + case <-c.trustAllHostsCh: + logger.Warn("Trusting host keys of servers", servers) + c.trustHosts(hosts) + return + default: + } + + question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s", + len(servers), + strings.Join(servers, ","), + "Do you want to trust these hosts?", + ) + + p := prompt.New(question) + + a := prompt.Answer{ + Long: "yes", + Short: "y", + Callback: func() { + c.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", c.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "all", + Short: "a", + Callback: func() { + close(c.trustAllHostsCh) + c.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", c.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "no", + Short: "n", + Callback: func() { + c.dontTrustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "details", + Short: "d", + AskAgain: true, + Callback: func() { + for _, unknown := range hosts { + fmt.Println(unknown.hostLine) + fmt.Println(unknown.ipLine) + } + }, + } + p.Add(a) + + p.Ask() +} + +func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { + tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) + + newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) + } + defer newFd.Close() + + // Newly trusted hosts in normalized form + addresses := make(map[string]struct{}) + + // First write to new known hosts file, and keep track of addresses + for _, unknown := range hosts { + unknown.responseCh <- trustHost + + // Add once as [HOSTNAME]:PORT + addresses[knownhosts.Normalize(unknown.server)] = struct{}{} + // And once as [IP]:PORT + addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} + + newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)) + newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)) + } + + // Read old known hosts file, to see which are old and new entries + os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + oldFd, err := os.Open(c.knownHostsPath) + if err != nil { + panic(err) + } + defer oldFd.Close() + + scanner := bufio.NewScanner(oldFd) + + // Now, append all still valid old entries to the new host file + for scanner.Scan() { + line := scanner.Text() + address := strings.SplitN(line, " ", 2)[0] + + if _, ok := addresses[address]; !ok { + newFd.WriteString(fmt.Sprintf("%s\n", line)) + } + } + + // Now, replace old known hosts file + if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil { + panic(err) + } +} + +func (c KnownHostsCallback) dontTrustHosts(hosts []unknownHost) { + for _, unknown := range hosts { + unknown.responseCh <- dontTrustHost + } +} + +// Untrusted returns true if the host is not trusted. False otherwise. +func (c KnownHostsCallback) Untrusted(server string) bool { + c.mutex.Lock() + defer c.mutex.Unlock() + _, ok := c.untrustedHosts[server] + + return ok +} diff --git a/internal/ssh/client/simplecallback.go b/internal/ssh/client/simplecallback.go new file mode 100644 index 0000000..580fa36 --- /dev/null +++ b/internal/ssh/client/simplecallback.go @@ -0,0 +1,36 @@ +package client + +import ( + "context" + "net" + + "golang.org/x/crypto/ssh" +) + +// SimpleCallback is a wrapper around ssh.KnownHosts so that we can add all +// unknown hosts in a single batch to the known_hosts file. +type SimpleCallback struct { +} + +// NewSimpleCallback returns a new wrapper. +func NewSimpleCallback() (SimpleCallback, error) { + return SimpleCallback{}, nil +} + +// Wrap the host key callback. +func (SimpleCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + return nil + } +} + +// Untrusted returns whether host is not trusted or not. +func (SimpleCallback) Untrusted(server string) bool { + return false +} + +// PromptAddHosts prompts a question to the user whether unknown hosts should +// be added to the known hosts or not. +func (SimpleCallback) PromptAddHosts(ctx context.Context) { + // Not used here. +} |
