summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2021-10-27 09:54:06 +0300
committerPaul Buetow <pbuetow@mimecast.com>2021-10-27 09:54:06 +0300
commitd0492e3f63f86ce053f59362a55d23cf6397d526 (patch)
treeefe6668328b58db3e14d3b6c408295e0f3bf85dd
parent7c927e7e6d913168798d245a7ceea32a9ca85643 (diff)
integration tests use separate ssh private key file
-rw-r--r--.gitignore3
-rw-r--r--cmd/dcat/main.go2
-rw-r--r--cmd/dgrep/main.go2
-rw-r--r--cmd/dmap/main.go2
-rw-r--r--cmd/dtail/main.go2
-rw-r--r--integrationtests/dcat_test.go6
-rw-r--r--integrationtests/dgrep_test.go8
-rw-r--r--integrationtests/dmap_test.go6
-rw-r--r--integrationtests/dtail_test.go4
-rw-r--r--integrationtests/dtailhealth_test.go6
-rw-r--r--internal/clients/baseclient.go2
-rw-r--r--internal/config/args.go52
-rw-r--r--internal/config/initializer.go2
-rw-r--r--internal/ssh/client/authmethods.go40
-rw-r--r--internal/ssh/client/clientkeypair.go91
15 files changed, 167 insertions, 61 deletions
diff --git a/.gitignore b/.gitignore
index 20e162c..44dc08e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,6 @@ tags
/dserver
/dtailhealth
known_hosts
+id_rsa
+id_rsa.pub
+ssh_host_key
diff --git a/cmd/dcat/main.go b/cmd/dcat/main.go
index 5fd22ea..8ac5eda 100644
--- a/cmd/dcat/main.go
+++ b/cmd/dcat/main.go
@@ -40,7 +40,7 @@ func main() {
flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir")
flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name")
flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level")
- flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key")
+ flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key")
flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect")
flag.StringVar(&args.UserName, "user", userName, "Your system user name")
flag.StringVar(&args.What, "files", "", "File(s) to read")
diff --git a/cmd/dgrep/main.go b/cmd/dgrep/main.go
index 02b2463..8657fa9 100644
--- a/cmd/dgrep/main.go
+++ b/cmd/dgrep/main.go
@@ -44,7 +44,7 @@ func main() {
flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir")
flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name")
flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level")
- flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key")
+ flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key")
flag.StringVar(&args.RegexStr, "regex", ".", "Regular expression")
flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect")
flag.StringVar(&args.UserName, "user", userName, "Your system user name")
diff --git a/cmd/dmap/main.go b/cmd/dmap/main.go
index 2c941f3..79a53ba 100644
--- a/cmd/dmap/main.go
+++ b/cmd/dmap/main.go
@@ -44,7 +44,7 @@ func main() {
flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir")
flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name")
flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level")
- flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key")
+ flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key")
flag.StringVar(&args.QueryStr, "query", "", "Map reduce query")
flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect")
flag.StringVar(&args.UserName, "user", userName, "Your system user name")
diff --git a/cmd/dtail/main.go b/cmd/dtail/main.go
index ff0cea9..2b46141 100644
--- a/cmd/dtail/main.go
+++ b/cmd/dtail/main.go
@@ -57,7 +57,7 @@ func main() {
flag.StringVar(&args.LogDir, "logDir", "~/log", "Log dir")
flag.StringVar(&args.Logger, "logger", config.DefaultClientLogger, "Logger name")
flag.StringVar(&args.LogLevel, "logLevel", config.DefaultLogLevel, "Log level")
- flag.StringVar(&args.PrivateKeyPathFile, "key", "", "Path to private key")
+ flag.StringVar(&args.SSHPrivateKeyPathFile, "key", "", "Path to private key")
flag.StringVar(&args.QueryStr, "query", "", "Map reduce query")
flag.StringVar(&args.RegexStr, "regex", ".", "Regular expression")
flag.StringVar(&args.ServersStr, "servers", "", "Remote servers to connect")
diff --git a/integrationtests/dcat_test.go b/integrationtests/dcat_test.go
index 777e835..6928afa 100644
--- a/integrationtests/dcat_test.go
+++ b/integrationtests/dcat_test.go
@@ -9,7 +9,7 @@ import (
)
func TestDCat(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -33,7 +33,7 @@ func TestDCat(t *testing.T) {
}
func TestDCat2(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
return
}
testdataFile := "dcat2.txt"
@@ -62,7 +62,7 @@ func TestDCat2(t *testing.T) {
}
func TestDCatColors(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
return
}
diff --git a/integrationtests/dgrep_test.go b/integrationtests/dgrep_test.go
index 26abc2f..35c3ff5 100644
--- a/integrationtests/dgrep_test.go
+++ b/integrationtests/dgrep_test.go
@@ -9,7 +9,7 @@ import (
)
func TestDGrep(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -38,7 +38,7 @@ func TestDGrep(t *testing.T) {
}
func TestDGrep2(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -68,7 +68,7 @@ func TestDGrep2(t *testing.T) {
}
func TestDGrepContext(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -98,7 +98,7 @@ func TestDGrepContext(t *testing.T) {
}
func TestDGrepContext2(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
diff --git a/integrationtests/dmap_test.go b/integrationtests/dmap_test.go
index 53b8574..6a93b7b 100644
--- a/integrationtests/dmap_test.go
+++ b/integrationtests/dmap_test.go
@@ -10,7 +10,7 @@ import (
)
func TestDMap(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -56,7 +56,7 @@ func TestDMap(t *testing.T) {
}
func TestDMap2(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -93,7 +93,7 @@ func TestDMap2(t *testing.T) {
}
func TestDMap3(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
diff --git a/integrationtests/dtail_test.go b/integrationtests/dtail_test.go
index e9cf257..2f2708e 100644
--- a/integrationtests/dtail_test.go
+++ b/integrationtests/dtail_test.go
@@ -12,7 +12,7 @@ import (
)
func TestDTailWithServer(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -131,7 +131,7 @@ func TestDTailWithServer(t *testing.T) {
}
func TestDTailColorTable(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
diff --git a/integrationtests/dtailhealth_test.go b/integrationtests/dtailhealth_test.go
index 271f11d..b53c425 100644
--- a/integrationtests/dtailhealth_test.go
+++ b/integrationtests/dtailhealth_test.go
@@ -10,7 +10,7 @@ import (
)
func TestDTailHealthCheck(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -32,7 +32,7 @@ func TestDTailHealthCheck(t *testing.T) {
}
func TestDTailHealthCheck2(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
@@ -57,7 +57,7 @@ func TestDTailHealthCheck2(t *testing.T) {
}
func TestDTailHealthCheck3(t *testing.T) {
- if !config.Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if !config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
t.Log("Skipping")
return
}
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 4a7bd84..41521ea 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -56,7 +56,7 @@ func (c *baseClient) init() {
}
c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(
c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts,
- c.throttleCh, c.Args.PrivateKeyPathFile)
+ c.throttleCh, c.Args.SSHPrivateKeyPathFile)
}
func (c *baseClient) makeConnections(maker maker) {
diff --git a/internal/config/args.go b/internal/config/args.go
index 3d7ac7d..859166d 100644
--- a/internal/config/args.go
+++ b/internal/config/args.go
@@ -15,31 +15,31 @@ import (
// Args is a helper struct to summarize common client arguments.
type Args struct {
lcontext.LContext
- Arguments []string
- ConfigFile string
- ConnectionsPerCPU int
- Discovery string
- LogDir string
- Logger string
- LogLevel string
- Mode omode.Mode
- NoColor bool
- PrivateKeyPathFile string
- QueryStr string
- Quiet bool
- RegexInvert bool
- RegexStr string
- Serverless bool
- ServersStr string
- Spartan bool
- SSHAuthMethods []gossh.AuthMethod
- SSHBindAddress string
- SSHHostKeyCallback gossh.HostKeyCallback
- SSHPort int
- Timeout int
- TrustAllHosts bool
- UserName string
- What string
+ Arguments []string
+ ConfigFile string
+ ConnectionsPerCPU int
+ Discovery string
+ LogDir string
+ Logger string
+ LogLevel string
+ Mode omode.Mode
+ NoColor bool
+ QueryStr string
+ Quiet bool
+ RegexInvert bool
+ RegexStr string
+ SSHAuthMethods []gossh.AuthMethod
+ SSHBindAddress string
+ SSHHostKeyCallback gossh.HostKeyCallback
+ SSHPort int
+ SSHPrivateKeyPathFile string
+ Serverless bool
+ ServersStr string
+ Spartan bool
+ Timeout int
+ TrustAllHosts bool
+ UserName string
+ What string
}
func (a *Args) String() string {
@@ -56,7 +56,6 @@ func (a *Args) String() string {
sb.WriteString(fmt.Sprintf("%s:%v,", "Logger", a.Logger))
sb.WriteString(fmt.Sprintf("%s:%v,", "Mode", a.Mode))
sb.WriteString(fmt.Sprintf("%s:%v,", "NoColor", a.NoColor))
- sb.WriteString(fmt.Sprintf("%s:%v,", "PrivateKeyPathFile", a.PrivateKeyPathFile))
sb.WriteString(fmt.Sprintf("%s:%v,", "QueryStr", a.QueryStr))
sb.WriteString(fmt.Sprintf("%s:%v,", "Quiet", a.Quiet))
sb.WriteString(fmt.Sprintf("%s:%v,", "RegexInvert", a.RegexInvert))
@@ -64,6 +63,7 @@ func (a *Args) String() string {
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAuthMethods", a.SSHAuthMethods))
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHBindAddress", a.SSHBindAddress))
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHHostKeyCallback", a.SSHHostKeyCallback))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "SSHPrivateKeyPathFile", a.SSHPrivateKeyPathFile))
sb.WriteString(fmt.Sprintf("%s:%v,", "SSHPort", a.SSHPort))
sb.WriteString(fmt.Sprintf("%s:%v,", "Serverless", a.Serverless))
sb.WriteString(fmt.Sprintf("%s:%v,", "ServersStr", a.ServersStr))
diff --git a/internal/config/initializer.go b/internal/config/initializer.go
index 936df8a..74d0289 100644
--- a/internal/config/initializer.go
+++ b/internal/config/initializer.go
@@ -82,7 +82,7 @@ func (in *initializer) transformConfig(sourceProcess source.Source, args *Args,
// There are some special options which can be set by environment variable.
func (in *initializer) readEnvironmentVars() {
- if Env("DTAIL_RUN_INTEGRATION_TESTS") {
+ if Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
os.Setenv("DTAIL_HOSTNAME_OVERRIDE", "integrationtest")
os.Setenv("DTAIL_SSH_KNOWN_HOSTS_FILE", "./known_hosts")
in.Server.HostKeyFile = "./ssh_host_key"
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index 2ee32ad..49cf938 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -25,27 +25,44 @@ func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath)
}
+func initIntegrationTestKnownHostsAuthMethods() []gossh.AuthMethod {
+ var sshAuthMethods []gossh.AuthMethod
+ privateKeyPath := "./id_rsa"
+
+ GeneratePrivatePublicKeyPairIfNotExists(privateKeyPath, 4096)
+ authMethod, err := ssh.PrivateKey(privateKeyPath)
+ if err != nil {
+ dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
+ }
+
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ dlog.Client.Debug("initKnownHostsAuthMethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ return sshAuthMethods
+}
+
func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
- var sshAuthMethods []gossh.AuthMethod
knownHostsFile := config.SSHKnownHostsFile()
- knownHostsCallback, err := NewKnownHostsCallback(knownHostsFile, trustAllHosts,
- throttleCh)
+ knownHostsCallback, err := NewKnownHostsCallback(knownHostsFile, trustAllHosts, throttleCh)
if err != nil {
dlog.Client.FatalPanic(knownHostsFile, err)
}
-
dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsFile)
+ if config.Env("DTAIL_INTEGRATION_TEST_RUN_MODE") {
+ return initIntegrationTestKnownHostsAuthMethods(), knownHostsCallback
+ }
+
+ var sshAuthMethods []gossh.AuthMethod
// First try to read custom private key path.
if privateKeyPath != "" {
authMethod, err := ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
dlog.Client.Debug("initKnownHostsAuthMethods",
- "Added path to list of auth methods, not adding further methods",
- privateKeyPath)
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
@@ -59,8 +76,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
"to list of auth methods, not adding further methods")
return sshAuthMethods, knownHostsCallback
}
- dlog.Client.Debug("initKnownHostsAuthMethods",
- "Unable to init SSH Agent auth method", err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
// Third, try Linux/UNIX default key paths
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa"
@@ -71,8 +87,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
"Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key",
- privateKeyPath, err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
@@ -92,10 +107,7 @@ func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
return sshAuthMethods, knownHostsCallback
}
- dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key",
- privateKeyPath, err)
-
- dlog.Client.FatalPanic("Unable to find private SSH key information")
+ dlog.Client.FatalPanic("Unable to find private SSH key information", privateKeyPath, err)
// Never reach this point.
return sshAuthMethods, knownHostsCallback
}
diff --git a/internal/ssh/client/clientkeypair.go b/internal/ssh/client/clientkeypair.go
new file mode 100644
index 0000000..0e21d0c
--- /dev/null
+++ b/internal/ssh/client/clientkeypair.go
@@ -0,0 +1,91 @@
+package client
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "os"
+
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "golang.org/x/crypto/ssh"
+)
+
+// GeneratePrivatePublicKeyPairIfNotExists generates a SSH key pair (used by the integration tests)
+func GeneratePrivatePublicKeyPairIfNotExists(keyPath string, bitSize int) {
+ if _, err := os.Stat(keyPath); err == nil {
+ dlog.Common.Debug("Private/public key pair already exists", keyPath)
+ return
+ }
+ GeneratePrivatePublicKeyPair(keyPath, bitSize)
+}
+
+// GeneratePrivatePublicKeyPair generates a SSH key pair (used by the integration tests)
+func GeneratePrivatePublicKeyPair(keyPath string, bitSize int) {
+ privateKeyPath := keyPath
+ publicKeyPath := fmt.Sprintf("%s.pub", keyPath)
+
+ dlog.Common.Debug("Generating private/public key pair", privateKeyPath, publicKeyPath)
+
+ privateKey, err := generatePrivateKey(bitSize)
+ if err != nil {
+ dlog.Common.FatalPanic(err)
+ }
+ publicKeyBytes, err := generatePublicKey(&privateKey.PublicKey)
+ if err != nil {
+ dlog.Common.FatalPanic(err)
+ }
+ privateKeyBytes := encodePrivateKeyToPEM(privateKey)
+ err = writeKey(privateKeyBytes, privateKeyPath)
+ if err != nil {
+ dlog.Common.FatalPanic(err)
+ }
+ err = writeKey([]byte(publicKeyBytes), publicKeyPath)
+ if err != nil {
+ dlog.Common.FatalPanic(err)
+ }
+
+ dlog.Common.Debug("Done generating private/public key pair", privateKeyPath, publicKeyPath)
+}
+
+func generatePrivateKey(bitSize int) (*rsa.PrivateKey, error) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, bitSize)
+ if err != nil {
+ return nil, err
+ }
+ err = privateKey.Validate()
+ if err != nil {
+ return nil, err
+ }
+ return privateKey, nil
+}
+
+func encodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
+ privDER := x509.MarshalPKCS1PrivateKey(privateKey)
+ privBlock := pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Headers: nil,
+ Bytes: privDER,
+ }
+ privatePEM := pem.EncodeToMemory(&privBlock)
+ return privatePEM
+}
+
+func generatePublicKey(privatekey *rsa.PublicKey) ([]byte, error) {
+ publicRsaKey, err := ssh.NewPublicKey(privatekey)
+ if err != nil {
+ return nil, err
+ }
+ pubKeyBytes := ssh.MarshalAuthorizedKey(publicRsaKey)
+ return pubKeyBytes, nil
+}
+
+func writeKey(keyBytes []byte, saveFileTo string) error {
+ err := ioutil.WriteFile(saveFileTo, keyBytes, 0600)
+ if err != nil {
+ return err
+ }
+ return nil
+}