summaryrefslogtreecommitdiff
path: root/internal/ssh
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
committerPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
commitf4207a55f71bfbcfdc532d5cdd3befaa3474a157 (patch)
treeea5e4a2d2a67035f645bdee496ae55a52034178a /internal/ssh
parentd80d6070557e3a800e3a54967af9eced518f116b (diff)
parent739205206d63bf42f4e843b39d04d4c8cd8207c3 (diff)
merge develop
Diffstat (limited to 'internal/ssh')
-rw-r--r--internal/ssh/client/authmethods.go67
-rw-r--r--internal/ssh/client/customkeycallback.go3
-rw-r--r--internal/ssh/client/knownhostscallback.go38
-rw-r--r--internal/ssh/server/hostkey.go18
-rw-r--r--internal/ssh/server/publickeycallback.go38
-rw-r--r--internal/ssh/ssh.go11
6 files changed, 96 insertions, 79 deletions
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index bbfb7be..37f8382 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -4,89 +4,106 @@ import (
"os"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/ssh"
gossh "golang.org/x/crypto/ssh"
)
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
-func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
+ hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{},
+ privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+
if len(sshAuthMethods) > 0 {
simpleCallback, err := NewSimpleCallback()
if err != nil {
- logger.FatalExit(err)
+ dlog.Client.FatalPanic(err)
}
return sshAuthMethods, simpleCallback
}
-
return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath)
}
-func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
- var sshAuthMethods []gossh.AuthMethod
+func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
+ privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+ var sshAuthMethods []gossh.AuthMethod
knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts"
- knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, throttleCh)
+ knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts,
+ throttleCh)
if err != nil {
- logger.FatalExit(knownHostsPath, err)
- }
- logger.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath)
-
- if config.Common.ExperimentalFeaturesEnable {
- sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
- logger.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods")
+ dlog.Client.FatalPanic(knownHostsPath, err)
}
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath)
+ /*
+ if config.Client.ExperimentalFeaturesEnable {
+ sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods")
+ }
+ */
// First try to read custom private key path.
if privateKeyPath != "" {
authMethod, err := ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthMethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthMethods",
+ "Added path to list of auth methods, not adding further methods",
+ privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.FatalExit("Unable to use private SSH key", privateKeyPath, err)
+ dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
}
// Second, try SSH Agent
authMethod, err := ssh.Agent()
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK) to list of auth methods, not adding further methods")
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+
+ "to list of auth methods, not adding further methods")
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
+ dlog.Client.Debug("initKnownHostsAuthMethods",
+ "Unable to init SSH Agent auth method", err)
// Third, try Linux/UNIX default key paths
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthmethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key",
+ privateKeyPath, err)
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthmethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthmethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
- logger.FatalExit("Unable to find private SSH key information")
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key",
+ privateKeyPath, err)
+
+ // This is only a panic when we expect to do something about it.
+ if !config.Client.SSHDontAddHostsToKnownHostsFile {
+ dlog.Client.FatalPanic("Unable to find private SSH key information")
+ }
// Never reach this point.
return sshAuthMethods, knownHostsCallback
diff --git a/internal/ssh/client/customkeycallback.go b/internal/ssh/client/customkeycallback.go
index 73e5289..53b8e3c 100644
--- a/internal/ssh/client/customkeycallback.go
+++ b/internal/ssh/client/customkeycallback.go
@@ -7,8 +7,7 @@ import (
)
// CustomCallback is a custom host key callback wrapper.
-type CustomCallback struct {
-}
+type CustomCallback struct{}
// NewCustomCallback returns a new wrapper.
func NewCustomCallback() (*CustomCallback, error) {
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index 1ccf6c6..2aa0168 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -10,7 +10,8 @@ import (
"sync"
"time"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/prompt"
"golang.org/x/crypto/ssh"
@@ -46,8 +47,9 @@ type KnownHostsCallback struct {
}
// NewKnownHostsCallback returns a new wrapper.
-func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) {
- // Ensure file exists
+func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
+ throttleCh chan struct{}) (HostKeyCallback, error) {
+
os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
untrustedHosts := make(map[string]bool)
@@ -59,11 +61,9 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh
untrustedHosts: untrustedHosts,
mutex: &sync.Mutex{},
}
-
if trustAllHosts {
close(c.trustAllHostsCh)
}
-
return c, nil
}
@@ -75,14 +75,12 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
if err != nil {
return err
}
-
// Check for valid entry in known_hosts file
err = knownHostsCb(server, remote, key)
if err == nil {
// OK
return nil
}
-
// Make sure that interactive user callback does not interfere with
// SSH connection throttler.
<-c.throttleCh
@@ -96,11 +94,9 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
ipLine: knownhosts.Line([]string{remote.String()}, key),
responseCh: make(chan response),
}
-
- logger.Warn("Encountered unknown host", unknown)
+ dlog.Common.Warn("Encountered unknown host", unknown)
// Notify user that there is an unknown host
c.unknownCh <- unknown
-
// Wait for user input.
switch <-unknown.responseCh {
case trustHost:
@@ -112,7 +108,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
c.mutex.Lock()
defer c.mutex.Unlock()
c.untrustedHosts[server] = true
-
return err
}
}
@@ -121,7 +116,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
// be added to the known hosts or not.
func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
var hosts []unknownHost
-
for {
// Check whether there is a unknown host
select {
@@ -139,7 +133,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
hosts = []unknownHost{}
}
case <-ctx.Done():
- logger.Debug("Stopping goroutine prompting new hosts...")
+ dlog.Common.Debug("Stopping goroutine prompting new hosts...")
return
}
}
@@ -147,14 +141,13 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
var servers []string
-
for _, host := range hosts {
servers = append(servers, host.server)
}
select {
case <-c.trustAllHostsCh:
- logger.Warn("Trusting host keys of servers", servers)
+ dlog.Common.Warn("Trusting host keys of servers", servers)
c.trustHosts(hosts)
return
default:
@@ -165,7 +158,6 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
strings.Join(servers, ","),
"Do you want to trust these hosts?",
)
-
p := prompt.New(question)
a := prompt.Answer{
@@ -175,7 +167,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
c.trustHosts(hosts)
},
EndCallback: func() {
- logger.Info("Added hosts to known hosts file", c.knownHostsPath)
+ dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath)
},
}
p.Add(a)
@@ -188,7 +180,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
c.trustHosts(hosts)
},
EndCallback: func() {
- logger.Info("Added hosts to known hosts file", c.knownHostsPath)
+ dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath)
},
}
p.Add(a)
@@ -200,7 +192,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
c.dontTrustHosts(hosts)
},
EndCallback: func() {
- logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath)
+ dlog.Common.Info("Didn't add hosts to known hosts file", c.knownHostsPath)
},
}
p.Add(a)
@@ -224,6 +216,11 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
+ if config.Client.SSHDontAddHostsToKnownHostsFile {
+ dlog.Common.Verbose("Not adding hosts to known hosts file, as disabled by config")
+ return
+ }
+
newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
@@ -232,7 +229,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
// Newly trusted hosts in normalized form
addresses := make(map[string]struct{})
-
// First write to new known hosts file, and keep track of addresses
for _, unknown := range hosts {
unknown.responseCh <- trustHost
@@ -255,7 +251,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
defer oldFd.Close()
scanner := bufio.NewScanner(oldFd)
-
// Now, append all still valid old entries to the new host file
for scanner.Scan() {
line := scanner.Text()
@@ -283,6 +278,5 @@ func (c KnownHostsCallback) Untrusted(server string) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.untrustedHosts[server]
-
return ok
}
diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go
index 07790ad..33bd4e8 100644
--- a/internal/ssh/server/hostkey.go
+++ b/internal/ssh/server/hostkey.go
@@ -1,11 +1,12 @@
package server
import (
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/ssh"
"io/ioutil"
"os"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/ssh"
)
// PrivateHostKey retrieves the private server RSA host key.
@@ -14,24 +15,25 @@ func PrivateHostKey() []byte {
_, err := os.Stat(hostKeyFile)
if os.IsNotExist(err) {
- logger.Info("Generating private server RSA host key")
+ dlog.Common.Info("Generating private server RSA host key")
privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits)
if err != nil {
- logger.FatalExit("Failed to generate private server RSA host key", err)
+ dlog.Common.FatalPanic("Failed to generate private server RSA host key", err)
}
pem := ssh.EncodePrivateKeyToPEM(privateKey)
if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil {
- logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err)
+ dlog.Common.Error("Unable to write private server RSA host key to file",
+ hostKeyFile, err)
}
return pem
}
- logger.Info("Reading private server RSA host key from file", hostKeyFile)
+ dlog.Common.Info("Reading private server RSA host key from file", hostKeyFile)
pem, err := ioutil.ReadFile(hostKeyFile)
if err != nil {
- logger.FatalExit("Failed to load private server RSA host key", err)
+ dlog.Common.FatalPanic("Failed to load private server RSA host key", err)
}
return pem
}
diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go
index e81f019..ebc428a 100644
--- a/internal/ssh/server/publickeycallback.go
+++ b/internal/ssh/server/publickeycallback.go
@@ -7,28 +7,34 @@ import (
osUser "os/user"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
user "github.com/mimecast/dtail/internal/user/server"
gossh "golang.org/x/crypto/ssh"
)
-// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not.
-func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {
- user := user.New(c.User(), c.RemoteAddr().String())
- logger.Info(user, "Incoming authorization")
+// PublicKeyCallback is for the server to check whether a public SSH key is
+// authorized ot not.
+func PublicKeyCallback(c gossh.ConnMetadata,
+ offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {
+ user, err := user.New(c.User(), c.RemoteAddr().String())
+ if err != nil {
+ return nil, err
+ }
+
+ dlog.Common.Info(user, "Incoming authorization")
cwd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error())
}
-
if config.ServerRelaxedAuthEnable {
- logger.Fatal(user, "Granting permissions via relaxed-auth")
+ dlog.Common.Fatal(user, "Granting permissions via relaxed-auth")
return nil, nil
}
- authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name)
+ authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd,
+ config.Common.CacheDir, user.Name)
if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) {
user, err := osUser.Lookup(user.Name)
if err != nil {
@@ -38,26 +44,28 @@ func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*go
authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys"
}
- logger.Info(user, "Reading", authorizedKeysFile)
+ dlog.Common.Info(user, "Reading", authorizedKeysFile)
authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile)
if err != nil {
- return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error())
+ return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s",
+ authorizedKeysFile, user, err.Error())
}
authorizedKeysMap := map[string]bool{}
for len(authorizedKeysBytes) > 0 {
authorizedPubKey, _, _, restBytes, err := gossh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
- return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error())
+ return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s",
+ user, err.Error())
}
authorizedKeysMap[string(authorizedPubKey.Marshal())] = true
authorizedKeysBytes = restBytes
-
- logger.Debug(user, "Authorized public key fingerprint", gossh.FingerprintSHA256(authorizedPubKey))
+ dlog.Common.Debug(user, "Authorized public key fingerprint",
+ gossh.FingerprintSHA256(authorizedPubKey))
}
- logger.Debug(user, "Offered public key fingerprint", gossh.FingerprintSHA256(offeredPubKey))
-
+ dlog.Common.Debug(user, "Offered public key fingerprint",
+ gossh.FingerprintSHA256(offeredPubKey))
if authorizedKeysMap[string(offeredPubKey.Marshal())] {
return &gossh.Permissions{
Extensions: map[string]string{
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 3a2e416..db5aaf1 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -6,12 +6,13 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
- "github.com/mimecast/dtail/internal/io/logger"
"io/ioutil"
"net"
"os"
"syscall"
+ "github.com/mimecast/dtail/internal/io/dlog"
+
gossh "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
@@ -23,12 +24,10 @@ func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
if err != nil {
return nil, err
}
-
err = privateKey.Validate()
if err != nil {
return nil, err
}
-
return privateKey, nil
}
@@ -41,7 +40,6 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
Headers: nil,
Bytes: derFormat,
}
-
return pem.EncodeToMemory(&block)
}
@@ -57,7 +55,7 @@ func Agent() (gossh.AuthMethod, error) {
return nil, err
}
for i, key := range keys {
- logger.Debug("Public key", i, key)
+ dlog.Common.Debug("Public key", i, key)
}
return gossh.PublicKeysCallback(agentClient.Signers), nil
}
@@ -79,7 +77,6 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) {
if err != nil {
return nil, err
}
-
key, err := gossh.ParsePrivateKey(buffer)
if err != nil {
return nil, err
@@ -105,7 +102,7 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) {
func PrivateKey(keyFile string) (gossh.AuthMethod, error) {
signer, err := KeyFile(keyFile)
if err != nil {
- logger.Debug(keyFile, err)
+ dlog.Common.Debug(keyFile, err)
return nil, err
}
return gossh.AuthMethod(signer), nil