diff options
| author | Paul Buetow <pbuetow@mimecast.com> | 2021-10-27 09:54:06 +0300 |
|---|---|---|
| committer | Paul Buetow <pbuetow@mimecast.com> | 2021-10-27 09:54:06 +0300 |
| commit | d0492e3f63f86ce053f59362a55d23cf6397d526 (patch) | |
| tree | efe6668328b58db3e14d3b6c408295e0f3bf85dd | |
| parent | 7c927e7e6d913168798d245a7ceea32a9ca85643 (diff) | |
integration tests use separate ssh private key file
| -rw-r--r-- | .gitignore | 3 | ||||
| -rw-r--r-- | cmd/dcat/main.go | 2 | ||||
| -rw-r--r-- | cmd/dgrep/main.go | 2 | ||||
| -rw-r--r-- | cmd/dmap/main.go | 2 | ||||
| -rw-r--r-- | cmd/dtail/main.go | 2 | ||||
| -rw-r--r-- | integrationtests/dcat_test.go | 6 | ||||
| -rw-r--r-- | integrationtests/dgrep_test.go | 8 | ||||
| -rw-r--r-- | integrationtests/dmap_test.go | 6 | ||||
| -rw-r--r-- | integrationtests/dtail_test.go | 4 | ||||
| -rw-r--r-- | integrationtests/dtailhealth_test.go | 6 | ||||
| -rw-r--r-- | internal/clients/baseclient.go | 2 | ||||
| -rw-r--r-- | internal/config/args.go | 52 | ||||
| -rw-r--r-- | internal/config/initializer.go | 2 | ||||
| -rw-r--r-- | internal/ssh/client/authmethods.go | 40 | ||||
| -rw-r--r-- | internal/ssh/client/clientkeypair.go | 91 |
15 files changed, 167 insertions, 61 deletions
@@ -14,3 +14,6 @@ tags /dserver /dtailhealth known_hosts +id_rsa +id_rsa.pub +ssh_host_key diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go index 5fd22ea..8ac5eda 100644 --- a/cmd/dcat/main.go +++ b/cmd/dcat/main.go @@ -40,7 +40,7 @@ func main() { flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir") flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name") flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level") - flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key") + flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key") flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect") flag.StringVar(&args.UserName, "user", userName, "Your system user name") flag.StringVar(&args.What, "files", "", "File(s) to read") diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go index 02b2463..8657fa9 100644 --- a/cmd/dgrep/main.go +++ b/cmd/dgrep/main.go @@ -44,7 +44,7 @@ func main() { flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir") flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name") flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level") - flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key") + flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key") flag.StringVar(&args.RegexStr, "regex", ".", "Regular expression") flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect") flag.StringVar(&args.UserName, "user", userName, "Your system user name") diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go index 2c941f3..79a53ba 100644 --- a/cmd/dmap/main.go +++ b/cmd/dmap/main.go @@ -44,7 +44,7 @@ func main() { flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir") flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name") flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level") - flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key") + flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key") flag.StringVar(&args.QueryStr, "query", "", "Map reduce query") flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect") flag.StringVar(&args.UserName, "user", userName, "Your system user name") diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go index ff0cea9..2b46141 100644 --- a/cmd/dtail/main.go +++ b/cmd/dtail/main.go @@ -57,7 +57,7 @@ func main() { flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir") flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name") flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level") - flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key") + flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key") flag.StringVar(&args.QueryStr, "query", "", "Map reduce query") flag.StringVar(&args.RegexStr, "regex", ".", "Regular expression") flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect") diff --git a/integrationtests/dcat_test.go b/integrationtests/dcat_test.go index 777e835..6928afa 100644 --- a/integrationtests/dcat_test.go +++ b/integrationtests/dcat_test.go @@ -9,7 +9,7 @@ import ( ) func TestDCat(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -33,7 +33,7 @@ func TestDCat(t *testing.T) { } func TestDCat2(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { return } testdataFile := "dcat2.txt" @@ -62,7 +62,7 @@ func TestDCat2(t *testing.T) { } func TestDCatColors(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { return } diff --git a/integrationtests/dgrep_test.go b/integrationtests/dgrep_test.go index 26abc2f..35c3ff5 100644 --- a/integrationtests/dgrep_test.go +++ b/integrationtests/dgrep_test.go @@ -9,7 +9,7 @@ import ( ) func TestDGrep(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -38,7 +38,7 @@ func TestDGrep(t *testing.T) { } func TestDGrep2(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -68,7 +68,7 @@ func TestDGrep2(t *testing.T) { } func TestDGrepContext(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -98,7 +98,7 @@ func TestDGrepContext(t *testing.T) { } func TestDGrepContext2(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } diff --git a/integrationtests/dmap_test.go b/integrationtests/dmap_test.go index 53b8574..6a93b7b 100644 --- a/integrationtests/dmap_test.go +++ b/integrationtests/dmap_test.go @@ -10,7 +10,7 @@ import ( ) func TestDMap(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -56,7 +56,7 @@ func TestDMap(t *testing.T) { } func TestDMap2(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -93,7 +93,7 @@ func TestDMap2(t *testing.T) { } func TestDMap3(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } diff --git a/integrationtests/dtail_test.go b/integrationtests/dtail_test.go index e9cf257..2f2708e 100644 --- a/integrationtests/dtail_test.go +++ b/integrationtests/dtail_test.go @@ -12,7 +12,7 @@ import ( ) func TestDTailWithServer(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -131,7 +131,7 @@ func TestDTailWithServer(t *testing.T) { } func TestDTailColorTable(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } diff --git a/integrationtests/dtailhealth_test.go b/integrationtests/dtailhealth_test.go index 271f11d..b53c425 100644 --- a/integrationtests/dtailhealth_test.go +++ b/integrationtests/dtailhealth_test.go @@ -10,7 +10,7 @@ import ( ) func TestDTailHealthCheck(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -32,7 +32,7 @@ func TestDTailHealthCheck(t *testing.T) { } func TestDTailHealthCheck2(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } @@ -57,7 +57,7 @@ func TestDTailHealthCheck2(t *testing.T) { } func TestDTailHealthCheck3(t *testing.T) { - if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") { + if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { t.Log("Skipping") return } diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 4a7bd84..41521ea 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -56,7 +56,7 @@ func (c *baseClient) init() { } c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods( c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, - c.throttleCh, c.Args.PrivateKeyPathFile) + c.throttleCh, c.Args.SSHPrivateKeyPathFile) } func (c *baseClient) makeConnections(maker maker) { diff --git a/internal/config/args.go b/internal/config/args.go index 3d7ac7d..859166d 100644 --- a/internal/config/args.go +++ b/internal/config/args.go @@ -15,31 +15,31 @@ import ( // Args is a helper struct to summarize common client arguments. type Args struct { lcontext.LContext - Arguments []string - ConfigFile string - ConnectionsPerCPU int - Discovery string - LogDir string - Logger string - LogLevel string - Mode omode.Mode - NoColor bool - PrivateKeyPathFile string - QueryStr string - Quiet bool - RegexInvert bool - RegexStr string - Serverless bool - ServersStr string - Spartan bool - SSHAuthMethods []gossh.AuthMethod - SSHBindAddress string - SSHHostKeyCallback gossh.HostKeyCallback - SSHPort int - Timeout int - TrustAllHosts bool - UserName string - What string + Arguments []string + ConfigFile string + ConnectionsPerCPU int + Discovery string + LogDir string + Logger string + LogLevel string + Mode omode.Mode + NoColor bool + QueryStr string + Quiet bool + RegexInvert bool + RegexStr string + SSHAuthMethods []gossh.AuthMethod + SSHBindAddress string + SSHHostKeyCallback gossh.HostKeyCallback + SSHPort int + SSHPrivateKeyPathFile string + Serverless bool + ServersStr string + Spartan bool + Timeout int + TrustAllHosts bool + UserName string + What string } func (a *Args) String() string { @@ -56,7 +56,6 @@ func (a *Args) String() string { sb.WriteString(fmt.Sprintf("%s:%v,", "Logger", a.Logger)) sb.WriteString(fmt.Sprintf("%s:%v,", "Mode", a.Mode)) sb.WriteString(fmt.Sprintf("%s:%v,", "NoColor", a.NoColor)) - sb.WriteString(fmt.Sprintf("%s:%v,", "PrivateKeyPathFile", a.PrivateKeyPathFile)) sb.WriteString(fmt.Sprintf("%s:%v,", "QueryStr", a.QueryStr)) sb.WriteString(fmt.Sprintf("%s:%v,", "Quiet", a.Quiet)) sb.WriteString(fmt.Sprintf("%s:%v,", "RegexInvert", a.RegexInvert)) @@ -64,6 +63,7 @@ func (a *Args) String() string { sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAuthMethods", a.SSHAuthMethods)) sb.WriteString(fmt.Sprintf("%s:%v,", "SSHBindAddress", a.SSHBindAddress)) sb.WriteString(fmt.Sprintf("%s:%v,", "SSHHostKeyCallback", a.SSHHostKeyCallback)) + sb.WriteString(fmt.Sprintf("%s:%v,", "SSHPrivateKeyPathFile", a.SSHPrivateKeyPathFile)) sb.WriteString(fmt.Sprintf("%s:%v,", "SSHPort", a.SSHPort)) sb.WriteString(fmt.Sprintf("%s:%v,", "Serverless", a.Serverless)) sb.WriteString(fmt.Sprintf("%s:%v,", "ServersStr", a.ServersStr)) diff --git a/internal/config/initializer.go b/internal/config/initializer.go index 936df8a..74d0289 100644 --- a/internal/config/initializer.go +++ b/internal/config/initializer.go @@ -82,7 +82,7 @@ func (in *initializer) transformConfig(sourceProcess source.Source, args *Args, // There are some special options which can be set by environment variable. func (in *initializer) readEnvironmentVars() { - if Env("DTAIL_RUN_INTEGRATION_TESTS") { + if Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { os.Setenv("DTAIL_HOSTNAME_OVERRIDE", "integrationtest") os.Setenv("DTAIL_SSH_KNOWN_HOSTS_FILE", "./known_hosts") in.Server.HostKeyFile = "./ssh_host_key" diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 2ee32ad..49cf938 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -25,27 +25,44 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath) } +func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod { + var sshAuthMethods []gossh.AuthMethod + privateKeyPath := "./id_rsa" + + GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096) + authMethod, err := ssh.PrivateKey(privateKeyPath) + if err != nil { + dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) + } + + sshAuthMethods = append(sshAuthMethods, authMethod) + dlog.Client.Debug("initKnownHostsAuthMethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) + return sshAuthMethods +} + func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { - var sshAuthMethods []gossh.AuthMethod knownHostsFile := config.SSHKnownHostsFile() - knownHostsCallback, err := NewKnownHostsCallback(knownHostsFile, trustAllHosts, - throttleCh) + knownHostsCallback, err := NewKnownHostsCallback(knownHostsFile, trustAllHosts, throttleCh) if err != nil { dlog.Client.FatalPanic(knownHostsFile, err) } - dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile) + if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") { + return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback + } + + var sshAuthMethods []gossh.AuthMethod // First try to read custom private key path. if privateKeyPath != "" { authMethod, err := ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) dlog.Client.Debug("initKnownHostsAuthMethods", - "Added path to list of auth methods, not adding further methods", - privateKeyPath) + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) @@ -59,8 +76,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, "to list of auth methods, not adding further methods") return sshAuthMethods, knownHostsCallback } - dlog.Client.Debug("initKnownHostsAuthMethods", - "Unable to init SSH Agent auth method", err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) // Third, try Linux/UNIX default key paths privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa" @@ -71,8 +87,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", - privateKeyPath, err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa" authMethod, err = ssh.PrivateKey(privateKeyPath) @@ -92,10 +107,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, return sshAuthMethods, knownHostsCallback } - dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", - privateKeyPath, err) - - dlog.Client.FatalPanic("Unable to find private SSH key information") + dlog.Client.FatalPanic("Unable to find private SSH key information", privateKeyPath, err) // Never reach this point. return sshAuthMethods, knownHostsCallback } diff --git a/internal/ssh/client/clientkeypair.go b/internal/ssh/client/clientkeypair.go new file mode 100644 index 0000000..0e21d0c --- /dev/null +++ b/internal/ssh/client/clientkeypair.go @@ -0,0 +1,91 @@ +package client + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "io/ioutil" + "os" + + "github.com/mimecast/dtail/internal/io/dlog" + "golang.org/x/crypto/ssh" +) + +// GeneratePrivatePublicKeyPairIfNotExists generates a SSH key pair (used by the integration tests) +func GeneratePrivatePublicKeyPairIfNotExists(keyPath string, bitSize int) { + if _, err := os.Stat(keyPath); err == nil { + dlog.Common.Debug("Private/public key pair already exists", keyPath) + return + } + GeneratePrivatePublicKeyPair(keyPath, bitSize) +} + +// GeneratePrivatePublicKeyPair generates a SSH key pair (used by the integration tests) +func GeneratePrivatePublicKeyPair(keyPath string, bitSize int) { + privateKeyPath := keyPath + publicKeyPath := fmt.Sprintf("%s.pub", keyPath) + + dlog.Common.Debug("Generating private/public key pair", privateKeyPath, publicKeyPath) + + privateKey, err := generatePrivateKey(bitSize) + if err != nil { + dlog.Common.FatalPanic(err) + } + publicKeyBytes, err := generatePublicKey(&privateKey.PublicKey) + if err != nil { + dlog.Common.FatalPanic(err) + } + privateKeyBytes := encodePrivateKeyToPEM(privateKey) + err = writeKey(privateKeyBytes, privateKeyPath) + if err != nil { + dlog.Common.FatalPanic(err) + } + err = writeKey([]byte(publicKeyBytes), publicKeyPath) + if err != nil { + dlog.Common.FatalPanic(err) + } + + dlog.Common.Debug("Done generating private/public key pair", privateKeyPath, publicKeyPath) +} + +func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, bitSize) + if err != nil { + return nil, err + } + err = privateKey.Validate() + if err != nil { + return nil, err + } + return privateKey, nil +} + +func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { + privDER := x509.MarshalPKCS1PrivateKey(privateKey) + privBlock := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: privDER, + } + privatePEM := pem.EncodeToMemory(&privBlock) + return privatePEM +} + +func generatePublicKey(privatekey *rsa.PublicKey) ([]byte, error) { + publicRsaKey, err := ssh.NewPublicKey(privatekey) + if err != nil { + return nil, err + } + pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey) + return pubKeyBytes, nil +} + +func writeKey(keyBytes []byte, saveFileTo string) error { + err := ioutil.WriteFile(saveFileTo, keyBytes, 0600) + if err != nil { + return err + } + return nil +} |
