diff options
Diffstat (limited to 'internal/ssh')
| -rw-r--r-- | internal/ssh/client/authmethods.go | 67 | ||||
| -rw-r--r-- | internal/ssh/client/customkeycallback.go | 3 | ||||
| -rw-r--r-- | internal/ssh/client/knownhostscallback.go | 38 | ||||
| -rw-r--r-- | internal/ssh/server/hostkey.go | 18 | ||||
| -rw-r--r-- | internal/ssh/server/publickeycallback.go | 38 | ||||
| -rw-r--r-- | internal/ssh/ssh.go | 11 |
6 files changed, 96 insertions, 79 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index bbfb7be..37f8382 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -4,89 +4,106 @@ import ( "os" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/ssh" gossh "golang.org/x/crypto/ssh" ) // InitSSHAuthMethods initialises all known SSH auth methods on the client side. -func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { +func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, + hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, + privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + if len(sshAuthMethods) > 0 { simpleCallback, err := NewSimpleCallback() if err != nil { - logger.FatalExit(err) + dlog.Client.FatalPanic(err) } return sshAuthMethods, simpleCallback } - return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath) } -func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { - var sshAuthMethods []gossh.AuthMethod +func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, + privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + var sshAuthMethods []gossh.AuthMethod knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts" - knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, throttleCh) + knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, + throttleCh) if err != nil { - logger.FatalExit(knownHostsPath, err) - } - logger.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath) - - if config.Common.ExperimentalFeaturesEnable { - sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) - logger.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods") + dlog.Client.FatalPanic(knownHostsPath, err) } + dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath) + /* + if config.Client.ExperimentalFeaturesEnable { + sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) + dlog.Client.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods") + } + */ // First try to read custom private key path. if privateKeyPath != "" { authMethod, err := ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthMethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthMethods", + "Added path to list of auth methods, not adding further methods", + privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.FatalExit("Unable to use private SSH key", privateKeyPath, err) + dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) } // Second, try SSH Agent authMethod, err := ssh.Agent() if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK) to list of auth methods, not adding further methods") + dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+ + "to list of auth methods, not adding further methods") return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) + dlog.Client.Debug("initKnownHostsAuthMethods", + "Unable to init SSH Agent auth method", err) // Third, try Linux/UNIX default key paths privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa" authMethod, err = ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthmethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", + privateKeyPath, err) privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa" authMethod, err = ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthmethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa" authMethod, err = ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthmethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) - logger.FatalExit("Unable to find private SSH key information") + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", + privateKeyPath, err) + + // This is only a panic when we expect to do something about it. + if !config.Client.SSHDontAddHostsToKnownHostsFile { + dlog.Client.FatalPanic("Unable to find private SSH key information") + } // Never reach this point. return sshAuthMethods, knownHostsCallback diff --git a/internal/ssh/client/customkeycallback.go b/internal/ssh/client/customkeycallback.go index 73e5289..53b8e3c 100644 --- a/internal/ssh/client/customkeycallback.go +++ b/internal/ssh/client/customkeycallback.go @@ -7,8 +7,7 @@ import ( ) // CustomCallback is a custom host key callback wrapper. -type CustomCallback struct { -} +type CustomCallback struct{} // NewCustomCallback returns a new wrapper. func NewCustomCallback() (*CustomCallback, error) { diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index 1ccf6c6..2aa0168 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -10,7 +10,8 @@ import ( "sync" "time" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/prompt" "golang.org/x/crypto/ssh" @@ -46,8 +47,9 @@ type KnownHostsCallback struct { } // NewKnownHostsCallback returns a new wrapper. -func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { - // Ensure file exists +func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, + throttleCh chan struct{}) (HostKeyCallback, error) { + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) untrustedHosts := make(map[string]bool) @@ -59,11 +61,9 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh untrustedHosts: untrustedHosts, mutex: &sync.Mutex{}, } - if trustAllHosts { close(c.trustAllHostsCh) } - return c, nil } @@ -75,14 +75,12 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { if err != nil { return err } - // Check for valid entry in known_hosts file err = knownHostsCb(server, remote, key) if err == nil { // OK return nil } - // Make sure that interactive user callback does not interfere with // SSH connection throttler. <-c.throttleCh @@ -96,11 +94,9 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { ipLine: knownhosts.Line([]string{remote.String()}, key), responseCh: make(chan response), } - - logger.Warn("Encountered unknown host", unknown) + dlog.Common.Warn("Encountered unknown host", unknown) // Notify user that there is an unknown host c.unknownCh <- unknown - // Wait for user input. switch <-unknown.responseCh { case trustHost: @@ -112,7 +108,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { c.mutex.Lock() defer c.mutex.Unlock() c.untrustedHosts[server] = true - return err } } @@ -121,7 +116,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { // be added to the known hosts or not. func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { var hosts []unknownHost - for { // Check whether there is a unknown host select { @@ -139,7 +133,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { hosts = []unknownHost{} } case <-ctx.Done(): - logger.Debug("Stopping goroutine prompting new hosts...") + dlog.Common.Debug("Stopping goroutine prompting new hosts...") return } } @@ -147,14 +141,13 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { var servers []string - for _, host := range hosts { servers = append(servers, host.server) } select { case <-c.trustAllHostsCh: - logger.Warn("Trusting host keys of servers", servers) + dlog.Common.Warn("Trusting host keys of servers", servers) c.trustHosts(hosts) return default: @@ -165,7 +158,6 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { strings.Join(servers, ","), "Do you want to trust these hosts?", ) - p := prompt.New(question) a := prompt.Answer{ @@ -175,7 +167,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { c.trustHosts(hosts) }, EndCallback: func() { - logger.Info("Added hosts to known hosts file", c.knownHostsPath) + dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath) }, } p.Add(a) @@ -188,7 +180,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { c.trustHosts(hosts) }, EndCallback: func() { - logger.Info("Added hosts to known hosts file", c.knownHostsPath) + dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath) }, } p.Add(a) @@ -200,7 +192,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { c.dontTrustHosts(hosts) }, EndCallback: func() { - logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath) + dlog.Common.Info("Didn't add hosts to known hosts file", c.knownHostsPath) }, } p.Add(a) @@ -224,6 +216,11 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) + if config.Client.SSHDontAddHostsToKnownHostsFile { + dlog.Common.Verbose("Not adding hosts to known hosts file, as disabled by config") + return + } + newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) @@ -232,7 +229,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { // Newly trusted hosts in normalized form addresses := make(map[string]struct{}) - // First write to new known hosts file, and keep track of addresses for _, unknown := range hosts { unknown.responseCh <- trustHost @@ -255,7 +251,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { defer oldFd.Close() scanner := bufio.NewScanner(oldFd) - // Now, append all still valid old entries to the new host file for scanner.Scan() { line := scanner.Text() @@ -283,6 +278,5 @@ func (c KnownHostsCallback) Untrusted(server string) bool { c.mutex.Lock() defer c.mutex.Unlock() _, ok := c.untrustedHosts[server] - return ok } diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 07790ad..33bd4e8 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -1,11 +1,12 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/ssh" "io/ioutil" "os" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/ssh" ) // PrivateHostKey retrieves the private server RSA host key. @@ -14,24 +15,25 @@ func PrivateHostKey() []byte { _, err := os.Stat(hostKeyFile) if os.IsNotExist(err) { - logger.Info("Generating private server RSA host key") + dlog.Common.Info("Generating private server RSA host key") privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits) if err != nil { - logger.FatalExit("Failed to generate private server RSA host key", err) + dlog.Common.FatalPanic("Failed to generate private server RSA host key", err) } pem := ssh.EncodePrivateKeyToPEM(privateKey) if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil { - logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err) + dlog.Common.Error("Unable to write private server RSA host key to file", + hostKeyFile, err) } return pem } - logger.Info("Reading private server RSA host key from file", hostKeyFile) + dlog.Common.Info("Reading private server RSA host key from file", hostKeyFile) pem, err := ioutil.ReadFile(hostKeyFile) if err != nil { - logger.FatalExit("Failed to load private server RSA host key", err) + dlog.Common.FatalPanic("Failed to load private server RSA host key", err) } return pem } diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index e81f019..ebc428a 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -7,28 +7,34 @@ import ( osUser "os/user" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" user "github.com/mimecast/dtail/internal/user/server" gossh "golang.org/x/crypto/ssh" ) -// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not. -func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) { - user := user.New(c.User(), c.RemoteAddr().String()) - logger.Info(user, "Incoming authorization") +// PublicKeyCallback is for the server to check whether a public SSH key is +// authorized ot not. +func PublicKeyCallback(c gossh.ConnMetadata, + offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) { + user, err := user.New(c.User(), c.RemoteAddr().String()) + if err != nil { + return nil, err + } + + dlog.Common.Info(user, "Incoming authorization") cwd, err := os.Getwd() if err != nil { return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error()) } - if config.ServerRelaxedAuthEnable { - logger.Fatal(user, "Granting permissions via relaxed-auth") + dlog.Common.Fatal(user, "Granting permissions via relaxed-auth") return nil, nil } - authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name) + authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, + config.Common.CacheDir, user.Name) if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) { user, err := osUser.Lookup(user.Name) if err != nil { @@ -38,26 +44,28 @@ func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*go authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys" } - logger.Info(user, "Reading", authorizedKeysFile) + dlog.Common.Info(user, "Reading", authorizedKeysFile) authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile) if err != nil { - return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error()) + return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", + authorizedKeysFile, user, err.Error()) } authorizedKeysMap := map[string]bool{} for len(authorizedKeysBytes) > 0 { authorizedPubKey, _, _, restBytes, err := gossh.ParseAuthorizedKey(authorizedKeysBytes) if err != nil { - return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error()) + return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", + user, err.Error()) } authorizedKeysMap[string(authorizedPubKey.Marshal())] = true authorizedKeysBytes = restBytes - - logger.Debug(user, "Authorized public key fingerprint", gossh.FingerprintSHA256(authorizedPubKey)) + dlog.Common.Debug(user, "Authorized public key fingerprint", + gossh.FingerprintSHA256(authorizedPubKey)) } - logger.Debug(user, "Offered public key fingerprint", gossh.FingerprintSHA256(offeredPubKey)) - + dlog.Common.Debug(user, "Offered public key fingerprint", + gossh.FingerprintSHA256(offeredPubKey)) if authorizedKeysMap[string(offeredPubKey.Marshal())] { return &gossh.Permissions{ Extensions: map[string]string{ diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 3a2e416..db5aaf1 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -6,12 +6,13 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "github.com/mimecast/dtail/internal/io/logger" "io/ioutil" "net" "os" "syscall" + "github.com/mimecast/dtail/internal/io/dlog" + gossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/terminal" @@ -23,12 +24,10 @@ func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) { if err != nil { return nil, err } - err = privateKey.Validate() if err != nil { return nil, err } - return privateKey, nil } @@ -41,7 +40,6 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { Headers: nil, Bytes: derFormat, } - return pem.EncodeToMemory(&block) } @@ -57,7 +55,7 @@ func Agent() (gossh.AuthMethod, error) { return nil, err } for i, key := range keys { - logger.Debug("Public key", i, key) + dlog.Common.Debug("Public key", i, key) } return gossh.PublicKeysCallback(agentClient.Signers), nil } @@ -79,7 +77,6 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) { if err != nil { return nil, err } - key, err := gossh.ParsePrivateKey(buffer) if err != nil { return nil, err @@ -105,7 +102,7 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) { func PrivateKey(keyFile string) (gossh.AuthMethod, error) { signer, err := KeyFile(keyFile) if err != nil { - logger.Debug(keyFile, err) + dlog.Common.Debug(keyFile, err) return nil, err } return gossh.AuthMethod(signer), nil |
