summaryrefslogtreecommitdiff
path: root/internal/ssh
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-02-11 13:10:26 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-02-11 13:10:26 +0000
commitecf7c86bc2f64068e6256ac1b9738c712a9858e9 (patch)
tree1ed536ad81a3c87c21eebd1438792ee91ce969e0 /internal/ssh
parent410ca88465f065f244f88c1d4089cb0fa4a45799 (diff)
more on scheduled jobs and ssh callbacks
Diffstat (limited to 'internal/ssh')
-rw-r--r--internal/ssh/client/authmethods.go18
-rw-r--r--internal/ssh/client/customkeycallback.go24
-rw-r--r--internal/ssh/client/hostkeycallback.go280
-rw-r--r--internal/ssh/client/knownhostscallback.go288
-rw-r--r--internal/ssh/client/simplecallback.go36
5 files changed, 363 insertions, 283 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index 072fad0..44c5601 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -10,16 +10,20 @@ import (
gossh "golang.org/x/crypto/ssh"
)
-// InitSSHAuthMethods initialises all known SSH auth methods on othe client side.
-func InitSSHAuthMethods(args clients.Args, trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) {
- if len(args.SSHAuthMethods) > 0 {
- hostKeyCallback, err := NewSimpleCallback(trustAllHosts)
+// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
+func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, HostKeyCallback) {
+ if len(sshAuthMethods) > 0 {
+ simpleCallback, err := NewSimpleCallback()
if err != nil {
logger.FatalExit(err)
}
- return args.SSHAuthMethods, hostKeyCallback
+ return sshAuthMethods, simpleCallback
}
+ return initKnownHostsAuthMethods(trustAllHosts, throttleCh)
+}
+
+func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, HostKeyCallback) {
var sshAuthMethods []gossh.AuthMethod
if config.Common.ExperimentalFeaturesEnable {
sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
@@ -44,10 +48,10 @@ func InitSSHAuthMethods(args clients.Args, trustAllHosts bool, throttleCh chan s
}
knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts"
- hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh)
+ knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, throttleCh)
if err != nil {
logger.FatalExit(knownHostsPath, err)
}
- return sshAuthMethods, hostKeyCallback
+ return sshAuthMethods, knownHostsCallback
}
diff --git a/internal/ssh/client/customkeycallback.go b/internal/ssh/client/customkeycallback.go
new file mode 100644
index 0000000..73e5289
--- /dev/null
+++ b/internal/ssh/client/customkeycallback.go
@@ -0,0 +1,24 @@
+package client
+
+import (
+ "net"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// CustomCallback is a custom host key callback wrapper.
+type CustomCallback struct {
+}
+
+// NewCustomCallback returns a new wrapper.
+func NewCustomCallback() (*CustomCallback, error) {
+ h := CustomCallback{}
+ return &h, nil
+}
+
+// Wrap the host key callback.
+func (h *CustomCallback) Wrap() ssh.HostKeyCallback {
+ return func(server string, remote net.Addr, key ssh.PublicKey) error {
+ return nil
+ }
+}
diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go
index d090d7f..95543f2 100644
--- a/internal/ssh/client/hostkeycallback.go
+++ b/internal/ssh/client/hostkeycallback.go
@@ -1,287 +1,15 @@
package client
import (
- "bufio"
"context"
- "fmt"
- "net"
- "os"
- "strings"
- "sync"
- "time"
-
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/io/prompt"
"golang.org/x/crypto/ssh"
- "golang.org/x/crypto/ssh/knownhosts"
-)
-
-type response int
-
-const (
- trustHost response = iota
- dontTrustHost response = iota
)
-// Represents an unknown host.
-type unknownHost struct {
- server string
- remote net.Addr
- key ssh.PublicKey
- hostLine string
- ipLine string
- responseCh chan response
-}
-
// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all
// unknown hosts in a single batch to the known_hosts file.
-type HostKeyCallback struct {
- knownHostsPath string
- unknownCh chan unknownHost
- throttleCh chan struct{}
- trustAllHostsCh chan struct{}
- untrustedHosts map[string]bool
- mutex sync.Mutex
-}
-
-// NewHostKeyCallback returns a new wrapper.
-func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) {
- // Ensure file exists
- os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
-
- h := HostKeyCallback{
- knownHostsPath: knownHostsPath,
- unknownCh: make(chan unknownHost),
- trustAllHostsCh: make(chan struct{}),
- throttleCh: throttleCh,
- untrustedHosts: make(map[string]bool),
- }
-
- if trustAllHosts {
- close(h.trustAllHostsCh)
- }
-
- return &h, nil
-}
-
-// Wrap the host key callback.
-func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback {
- return func(server string, remote net.Addr, key ssh.PublicKey) error {
- // Parse known_hosts file
- knownHostsCb, err := knownhosts.New(h.knownHostsPath)
- if err != nil {
- // Problem parsing it
- return err
- }
-
- // Check for valid entry in known_hosts file
- err = knownHostsCb(server, remote, key)
- if err == nil {
- // OK
- return nil
- }
-
- // Make sure that interactive user callback does not interfere with
- // SSH connection throttler.
- <-h.throttleCh
- defer func() { h.throttleCh <- struct{}{} }()
-
- unknown := unknownHost{
- server: server,
- remote: remote,
- key: key,
- hostLine: knownhosts.Line([]string{server}, key),
- ipLine: knownhosts.Line([]string{remote.String()}, key),
- responseCh: make(chan response),
- }
-
- logger.Warn("Encountered unknown host", unknown)
- // Notify user that there is an unknown host
- h.unknownCh <- unknown
-
- // Wait for user input.
- switch <-unknown.responseCh {
- case trustHost:
- // End user acknowledged host key
- return nil
- case dontTrustHost:
- }
-
- h.mutex.Lock()
- defer h.mutex.Unlock()
- h.untrustedHosts[server] = true
-
- return err
- }
-}
-
-// PromptAddHosts prompts a question to the user whether unknown hosts should
-// be added to the known hosts or not.
-func (h *HostKeyCallback) PromptAddHosts(ctx context.Context) {
- var hosts []unknownHost
-
- for {
- // Check whether there is a unknown host
- select {
- case unknown := <-h.unknownCh:
- hosts = append(hosts, unknown)
- // Ask every 50 unknown hosts
- if len(hosts) >= 50 {
- h.promptAddHosts(hosts)
- hosts = []unknownHost{}
- }
- case <-time.After(2 * time.Second):
- // Or ask when after 2 seconds no new unknown hosts were added.
- if len(hosts) > 0 {
- h.promptAddHosts(hosts)
- hosts = []unknownHost{}
- }
- case <-ctx.Done():
- logger.Debug("Stopping goroutine prompting new hosts...")
- return
- }
- }
-}
-
-func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) {
- var servers []string
-
- for _, host := range hosts {
- servers = append(servers, host.server)
- }
-
- select {
- case <-h.trustAllHostsCh:
- logger.Warn("Trusting host keys of servers", servers)
- h.trustHosts(hosts)
- return
- default:
- }
-
- question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s",
- len(servers),
- strings.Join(servers, ","),
- "Do you want to trust these hosts?",
- )
-
- p := prompt.New(question)
-
- a := prompt.Answer{
- Long: "yes",
- Short: "y",
- Callback: func() {
- h.trustHosts(hosts)
- },
- EndCallback: func() {
- logger.Info("Added hosts to known hosts file", h.knownHostsPath)
- },
- }
- p.Add(a)
-
- a = prompt.Answer{
- Long: "all",
- Short: "a",
- Callback: func() {
- close(h.trustAllHostsCh)
- h.trustHosts(hosts)
- },
- EndCallback: func() {
- logger.Info("Added hosts to known hosts file", h.knownHostsPath)
- },
- }
- p.Add(a)
-
- a = prompt.Answer{
- Long: "no",
- Short: "n",
- Callback: func() {
- h.dontTrustHosts(hosts)
- },
- EndCallback: func() {
- logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath)
- },
- }
- p.Add(a)
-
- a = prompt.Answer{
- Long: "details",
- Short: "d",
- AskAgain: true,
- Callback: func() {
- for _, unknown := range hosts {
- fmt.Println(unknown.hostLine)
- fmt.Println(unknown.ipLine)
- }
- },
- }
- p.Add(a)
-
- p.Ask()
-}
-
-func (h *HostKeyCallback) trustHosts(hosts []unknownHost) {
- tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath)
-
- newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
- if err != nil {
- panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
- }
- defer newFd.Close()
-
- // Newly trusted hosts in normalized form
- addresses := make(map[string]struct{})
-
- // First write to new known hosts file, and keep track of addresses
- for _, unknown := range hosts {
- unknown.responseCh <- trustHost
-
- // Add once as [HOSTNAME]:PORT
- addresses[knownhosts.Normalize(unknown.server)] = struct{}{}
- // And once as [IP]:PORT
- addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{}
-
- newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine))
- newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine))
- }
-
- // Read old known hosts file, to see which are old and new entries
- os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
- oldFd, err := os.Open(h.knownHostsPath)
- if err != nil {
- panic(err)
- }
- defer oldFd.Close()
-
- scanner := bufio.NewScanner(oldFd)
-
- // Now, append all still valid old entries to the new host file
- for scanner.Scan() {
- line := scanner.Text()
- address := strings.SplitN(line, " ", 2)[0]
-
- if _, ok := addresses[address]; !ok {
- newFd.WriteString(fmt.Sprintf("%s\n", line))
- }
- }
-
- // Now, replace old known hosts file
- if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil {
- panic(err)
- }
-}
-
-func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) {
- for _, unknown := range hosts {
- unknown.responseCh <- dontTrustHost
- }
-}
-
-// Untrusted returns true if the host is not trusted. False otherwise.
-func (h *HostKeyCallback) Untrusted(server string) bool {
- h.mutex.Lock()
- defer h.mutex.Unlock()
- _, ok := h.untrustedHosts[server]
-
- return ok
+type HostKeyCallback interface {
+ Wrap() ssh.HostKeyCallback
+ Untrusted(server string) bool
+ PromptAddHosts(ctx context.Context)
}
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
new file mode 100644
index 0000000..1ccf6c6
--- /dev/null
+++ b/internal/ssh/client/knownhostscallback.go
@@ -0,0 +1,288 @@
+package client
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/prompt"
+
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/knownhosts"
+)
+
+type response int
+
+const (
+ trustHost response = iota
+ dontTrustHost response = iota
+)
+
+// Represents an unknown host.
+type unknownHost struct {
+ server string
+ remote net.Addr
+ key ssh.PublicKey
+ hostLine string
+ ipLine string
+ responseCh chan response
+}
+
+// KnownHostsCallback is a wrapper around ssh.KnownHosts so that we can add all
+// unknown hosts in a single batch to the known_hosts file.
+type KnownHostsCallback struct {
+ knownHostsPath string
+ unknownCh chan unknownHost
+ throttleCh chan struct{}
+ trustAllHostsCh chan struct{}
+ untrustedHosts map[string]bool
+ mutex *sync.Mutex
+}
+
+// NewKnownHostsCallback returns a new wrapper.
+func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) {
+ // Ensure file exists
+ os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+ untrustedHosts := make(map[string]bool)
+
+ c := KnownHostsCallback{
+ knownHostsPath: knownHostsPath,
+ unknownCh: make(chan unknownHost),
+ trustAllHostsCh: make(chan struct{}),
+ throttleCh: throttleCh,
+ untrustedHosts: untrustedHosts,
+ mutex: &sync.Mutex{},
+ }
+
+ if trustAllHosts {
+ close(c.trustAllHostsCh)
+ }
+
+ return c, nil
+}
+
+// Wrap the host key callback.
+func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
+ return func(server string, remote net.Addr, key ssh.PublicKey) error {
+ // Parse known_hosts file
+ knownHostsCb, err := knownhosts.New(c.knownHostsPath)
+ if err != nil {
+ return err
+ }
+
+ // Check for valid entry in known_hosts file
+ err = knownHostsCb(server, remote, key)
+ if err == nil {
+ // OK
+ return nil
+ }
+
+ // Make sure that interactive user callback does not interfere with
+ // SSH connection throttler.
+ <-c.throttleCh
+ defer func() { c.throttleCh <- struct{}{} }()
+
+ unknown := unknownHost{
+ server: server,
+ remote: remote,
+ key: key,
+ hostLine: knownhosts.Line([]string{server}, key),
+ ipLine: knownhosts.Line([]string{remote.String()}, key),
+ responseCh: make(chan response),
+ }
+
+ logger.Warn("Encountered unknown host", unknown)
+ // Notify user that there is an unknown host
+ c.unknownCh <- unknown
+
+ // Wait for user input.
+ switch <-unknown.responseCh {
+ case trustHost:
+ // End user acknowledged host key
+ return nil
+ case dontTrustHost:
+ }
+
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ c.untrustedHosts[server] = true
+
+ return err
+ }
+}
+
+// PromptAddHosts prompts a question to the user whether unknown hosts should
+// be added to the known hosts or not.
+func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
+ var hosts []unknownHost
+
+ for {
+ // Check whether there is a unknown host
+ select {
+ case unknown := <-c.unknownCh:
+ hosts = append(hosts, unknown)
+ // Ask every 50 unknown hosts
+ if len(hosts) >= 50 {
+ c.promptAddHosts(hosts)
+ hosts = []unknownHost{}
+ }
+ case <-time.After(2 * time.Second):
+ // Or ask when after 2 seconds no new unknown hosts were added.
+ if len(hosts) > 0 {
+ c.promptAddHosts(hosts)
+ hosts = []unknownHost{}
+ }
+ case <-ctx.Done():
+ logger.Debug("Stopping goroutine prompting new hosts...")
+ return
+ }
+ }
+}
+
+func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
+ var servers []string
+
+ for _, host := range hosts {
+ servers = append(servers, host.server)
+ }
+
+ select {
+ case <-c.trustAllHostsCh:
+ logger.Warn("Trusting host keys of servers", servers)
+ c.trustHosts(hosts)
+ return
+ default:
+ }
+
+ question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s",
+ len(servers),
+ strings.Join(servers, ","),
+ "Do you want to trust these hosts?",
+ )
+
+ p := prompt.New(question)
+
+ a := prompt.Answer{
+ Long: "yes",
+ Short: "y",
+ Callback: func() {
+ c.trustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Added hosts to known hosts file", c.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "all",
+ Short: "a",
+ Callback: func() {
+ close(c.trustAllHostsCh)
+ c.trustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Added hosts to known hosts file", c.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "no",
+ Short: "n",
+ Callback: func() {
+ c.dontTrustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "details",
+ Short: "d",
+ AskAgain: true,
+ Callback: func() {
+ for _, unknown := range hosts {
+ fmt.Println(unknown.hostLine)
+ fmt.Println(unknown.ipLine)
+ }
+ },
+ }
+ p.Add(a)
+
+ p.Ask()
+}
+
+func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
+ tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
+
+ newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ if err != nil {
+ panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
+ }
+ defer newFd.Close()
+
+ // Newly trusted hosts in normalized form
+ addresses := make(map[string]struct{})
+
+ // First write to new known hosts file, and keep track of addresses
+ for _, unknown := range hosts {
+ unknown.responseCh <- trustHost
+
+ // Add once as [HOSTNAME]:PORT
+ addresses[knownhosts.Normalize(unknown.server)] = struct{}{}
+ // And once as [IP]:PORT
+ addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{}
+
+ newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine))
+ newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine))
+ }
+
+ // Read old known hosts file, to see which are old and new entries
+ os.OpenFile(c.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+ oldFd, err := os.Open(c.knownHostsPath)
+ if err != nil {
+ panic(err)
+ }
+ defer oldFd.Close()
+
+ scanner := bufio.NewScanner(oldFd)
+
+ // Now, append all still valid old entries to the new host file
+ for scanner.Scan() {
+ line := scanner.Text()
+ address := strings.SplitN(line, " ", 2)[0]
+
+ if _, ok := addresses[address]; !ok {
+ newFd.WriteString(fmt.Sprintf("%s\n", line))
+ }
+ }
+
+ // Now, replace old known hosts file
+ if err := os.Rename(tmpKnownHostsPath, c.knownHostsPath); err != nil {
+ panic(err)
+ }
+}
+
+func (c KnownHostsCallback) dontTrustHosts(hosts []unknownHost) {
+ for _, unknown := range hosts {
+ unknown.responseCh <- dontTrustHost
+ }
+}
+
+// Untrusted returns true if the host is not trusted. False otherwise.
+func (c KnownHostsCallback) Untrusted(server string) bool {
+ c.mutex.Lock()
+ defer c.mutex.Unlock()
+ _, ok := c.untrustedHosts[server]
+
+ return ok
+}
diff --git a/internal/ssh/client/simplecallback.go b/internal/ssh/client/simplecallback.go
new file mode 100644
index 0000000..580fa36
--- /dev/null
+++ b/internal/ssh/client/simplecallback.go
@@ -0,0 +1,36 @@
+package client
+
+import (
+ "context"
+ "net"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// SimpleCallback is a wrapper around ssh.KnownHosts so that we can add all
+// unknown hosts in a single batch to the known_hosts file.
+type SimpleCallback struct {
+}
+
+// NewSimpleCallback returns a new wrapper.
+func NewSimpleCallback() (SimpleCallback, error) {
+ return SimpleCallback{}, nil
+}
+
+// Wrap the host key callback.
+func (SimpleCallback) Wrap() ssh.HostKeyCallback {
+ return func(server string, remote net.Addr, key ssh.PublicKey) error {
+ return nil
+ }
+}
+
+// Untrusted returns whether host is not trusted or not.
+func (SimpleCallback) Untrusted(server string) bool {
+ return false
+}
+
+// PromptAddHosts prompts a question to the user whether unknown hosts should
+// be added to the known hosts or not.
+func (SimpleCallback) PromptAddHosts(ctx context.Context) {
+ // Not used here.
+}