diff options
| -rw-r--r-- | integrationtests/dtail_test.go | 103 | ||||
| -rw-r--r-- | integrationtests/dtailhealthcheck_test.go | 8 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection.go | 17 |
3 files changed, 81 insertions, 47 deletions
diff --git a/integrationtests/dtail_test.go b/integrationtests/dtail_test.go index 267cd26..79b5881 100644 --- a/integrationtests/dtail_test.go +++ b/integrationtests/dtail_test.go @@ -2,18 +2,31 @@ package integrationtests import ( "context" + "fmt" "os" + "strings" "testing" + "time" ) +// TODO: Have a serverless variant too. func TestDTailWithServer(t *testing.T) { followFile := "dtail.follow.tmp" - //serverStdoutFile := "dtail.dserver.stdout.tmp" - //greetings := []string{"world", "sol system", "milky way", "universe", "multiverse"} + greetings := []string{"world!", "sol-system!", "milky-way!", "universe!", "multiverse!"} ctx, cancel := context.WithCancel(context.Background()) defer cancel() + go func() { + select { + case <-time.After(time.Minute): + t.Error("Max time for this test exceeded!") + cancel() + case <-ctx.Done(): + return + } + }() + serverCh, _, _, err := startCommand(ctx, "../dserver", "--logger", "stdout", @@ -40,55 +53,69 @@ func TestDTailWithServer(t *testing.T) { t.Error(err) return } + // Write greetings to followFile + fd, err := os.Create(followFile) + if err != nil { + t.Error(err) + } + defer fd.Close() + + go func() { + var circular int + for { + select { + case <-time.After(time.Second): + fd.WriteString(time.Now().String()) + fd.WriteString(fmt.Sprintf(" - Hello %s\n", greetings[circular])) + circular = (circular + 1) % len(greetings) + case <-ctx.Done(): + return + } + } + }() + + var greetingsRecv []string - for { + for len(greetingsRecv) < len(greetings) { select { case line := <-serverCh: t.Log("server:", line) case line := <-clientCh: t.Log("client:", line) + if strings.Contains(line, "Hello ") { + s := strings.Split(line, " ") + greeting := s[len(s)-1] + greetingsRecv = append(greetingsRecv, greeting) + t.Log("Received greeting", greeting, len(greetingsRecv)) + } case <-ctx.Done(): t.Log("Done reading client and server pipes") } } - /* - // Start dtail client, connect to the server and follow followFile. - - //clientStdoutFile := "dtail.stdout.tmp" - /* - - t.Log(clientArgs) - // TODO: Pipe with dtail command to read stdin stream. - // runCommandContextRetry(ctx, t, "../dtail", clientArgs, clientStdoutFile) - - // Write greetings to followFile - fd, err := os.Create(followFile) - if err != nil { - t.Error(err) - } - defer fd.Close() + // We expect to have received the greetings in the same order they were sent.` + offset := -1 + for i, g := range greetings { + if g == greetingsRecv[0] { + offset = i + break + } + } + if offset == -1 { + t.Error("Could not find first offset of greetings received") + return + } - go func() { - var circular int - for { - select { - case <-ctx.Done(): - return - case <-time.After(time.Second): - fd.WriteString(time.Now().String()) - fd.WriteString(fmt.Sprintf(" - Hello %s!\n", greetings[circular])) - circular = (circular + 1) % len(greetings) - } - } - }() - */ + for i, g := range greetingsRecv { + index := (i + offset) % len(greetings) + if greetings[index] != g { + t.Error(fmt.Sprintf("Expected '%s' but got '%s' at '%v' vs '%v'\n", + g, greetings[index], greetings, greetingsRecv)) + return + } + } - /* - os.Remove(serverStdoutFile) - os.Remove(clientStdoutFile) - os.Remove(followFile) - */ + os.Remove(followFile) } func TestDTailColorTable(t *testing.T) { diff --git a/integrationtests/dtailhealthcheck_test.go b/integrationtests/dtailhealthcheck_test.go index a99bfdc..bb6c146 100644 --- a/integrationtests/dtailhealthcheck_test.go +++ b/integrationtests/dtailhealthcheck_test.go @@ -53,14 +53,18 @@ func TestDTailHealthCheck3(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - startCommand(ctx, + _, _, _, err := startCommand(ctx, "../dserver", "--logger", "stdout", "--logLevel", "trace", "--port", "4242", ) + if err != nil { + t.Error(err) + return + } - _, err := runCommandRetry(ctx, 10, stdoutFile, + _, err = runCommandRetry(ctx, 10, stdoutFile, "../dtailhealthcheck", "--server", "localhost:4242") if err != nil { t.Error(err) diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go index 2737ede..1df4d73 100644 --- a/internal/clients/connectors/serverconnection.go +++ b/internal/clients/connectors/serverconnection.go @@ -19,7 +19,11 @@ import ( // ServerConnection represents a connection to a single remote dtail server via // SSH protocol. type ServerConnection struct { - server string + // The full server string as received from the server discovery (can be with port number) + server string + // Only the hostname or FQDN (without the port number) + hostname string + // Only the port number. port int config *ssh.ClientConfig handler handlers.Handler @@ -37,7 +41,6 @@ func NewServerConnection(server string, userName string, c := ServerConnection{ hostKeyCallback: hostKeyCallback, server: server, - port: config.Common.SSHPort, handler: handler, commands: commands, config: &ssh.ClientConfig{ @@ -48,7 +51,6 @@ func NewServerConnection(server string, userName string, }, } - // TODO: After reconnecting the port is wrong! Due to string slicing? c.initServerPort() return &c } @@ -61,6 +63,7 @@ func (c *ServerConnection) Handler() handlers.Handler { return c.handler } // Attempt to parse the server port address from the provided server FQDN. func (c *ServerConnection) initServerPort() { + c.port = config.Common.SSHPort parts := strings.Split(c.server, ":") if len(parts) == 2 { dlog.Client.Debug("Parsing port from hostname", parts) @@ -68,7 +71,7 @@ func (c *ServerConnection) initServerPort() { if err != nil { dlog.Client.FatalPanic("Unable to parse client port", c.server, parts, err) } - c.server = parts[0] + c.hostname = parts[0] c.port = port } } @@ -103,8 +106,8 @@ func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc, }() if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { - dlog.Client.Warn(c.server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.server, c.port)) { + dlog.Client.Warn(c.server, err) + if c.hostKeyCallback.Untrusted(c.server) { dlog.Client.Debug(c.server, "Not trusting host") } } @@ -125,7 +128,7 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, }() dlog.Client.Debug(c.server, "Dialing into the connection") - address := fmt.Sprintf("%s:%d", c.server, c.port) + address := fmt.Sprintf("%s:%d", c.hostname, c.port) client, err := ssh.Dial("tcp", address, c.config) if err != nil { |
