summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--integrationtests/dtail_test.go103
-rw-r--r--integrationtests/dtailhealthcheck_test.go8
-rw-r--r--internal/clients/connectors/serverconnection.go17
3 files changed, 81 insertions, 47 deletions
diff --git a/integrationtests/dtail_test.go b/integrationtests/dtail_test.go
index 267cd26..79b5881 100644
--- a/integrationtests/dtail_test.go
+++ b/integrationtests/dtail_test.go
@@ -2,18 +2,31 @@ package integrationtests
import (
"context"
+ "fmt"
"os"
+ "strings"
"testing"
+ "time"
)
+// TODO: Have a serverless variant too.
func TestDTailWithServer(t *testing.T) {
followFile := "dtail.follow.tmp"
- //serverStdoutFile := "dtail.dserver.stdout.tmp"
- //greetings := []string{"world", "sol system", "milky way", "universe", "multiverse"}
+ greetings := []string{"world!", "sol-system!", "milky-way!", "universe!", "multiverse!"}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
+ go func() {
+ select {
+ case <-time.After(time.Minute):
+ t.Error("Max time for this test exceeded!")
+ cancel()
+ case <-ctx.Done():
+ return
+ }
+ }()
+
serverCh, _, _, err := startCommand(ctx,
"../dserver",
"--logger", "stdout",
@@ -40,55 +53,69 @@ func TestDTailWithServer(t *testing.T) {
t.Error(err)
return
}
+ // Write greetings to followFile
+ fd, err := os.Create(followFile)
+ if err != nil {
+ t.Error(err)
+ }
+ defer fd.Close()
+
+ go func() {
+ var circular int
+ for {
+ select {
+ case <-time.After(time.Second):
+ fd.WriteString(time.Now().String())
+ fd.WriteString(fmt.Sprintf(" - Hello %s\n", greetings[circular]))
+ circular = (circular + 1) % len(greetings)
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ var greetingsRecv []string
- for {
+ for len(greetingsRecv) < len(greetings) {
select {
case line := <-serverCh:
t.Log("server:", line)
case line := <-clientCh:
t.Log("client:", line)
+ if strings.Contains(line, "Hello ") {
+ s := strings.Split(line, " ")
+ greeting := s[len(s)-1]
+ greetingsRecv = append(greetingsRecv, greeting)
+ t.Log("Received greeting", greeting, len(greetingsRecv))
+ }
case <-ctx.Done():
t.Log("Done reading client and server pipes")
}
}
- /*
- // Start dtail client, connect to the server and follow followFile.
-
- //clientStdoutFile := "dtail.stdout.tmp"
- /*
-
- t.Log(clientArgs)
- // TODO: Pipe with dtail command to read stdin stream.
- // runCommandContextRetry(ctx, t, "../dtail", clientArgs, clientStdoutFile)
-
- // Write greetings to followFile
- fd, err := os.Create(followFile)
- if err != nil {
- t.Error(err)
- }
- defer fd.Close()
+ // We expect to have received the greetings in the same order they were sent.`
+ offset := -1
+ for i, g := range greetings {
+ if g == greetingsRecv[0] {
+ offset = i
+ break
+ }
+ }
+ if offset == -1 {
+ t.Error("Could not find first offset of greetings received")
+ return
+ }
- go func() {
- var circular int
- for {
- select {
- case <-ctx.Done():
- return
- case <-time.After(time.Second):
- fd.WriteString(time.Now().String())
- fd.WriteString(fmt.Sprintf(" - Hello %s!\n", greetings[circular]))
- circular = (circular + 1) % len(greetings)
- }
- }
- }()
- */
+ for i, g := range greetingsRecv {
+ index := (i + offset) % len(greetings)
+ if greetings[index] != g {
+ t.Error(fmt.Sprintf("Expected '%s' but got '%s' at '%v' vs '%v'\n",
+ g, greetings[index], greetings, greetingsRecv))
+ return
+ }
+ }
- /*
- os.Remove(serverStdoutFile)
- os.Remove(clientStdoutFile)
- os.Remove(followFile)
- */
+ os.Remove(followFile)
}
func TestDTailColorTable(t *testing.T) {
diff --git a/integrationtests/dtailhealthcheck_test.go b/integrationtests/dtailhealthcheck_test.go
index a99bfdc..bb6c146 100644
--- a/integrationtests/dtailhealthcheck_test.go
+++ b/integrationtests/dtailhealthcheck_test.go
@@ -53,14 +53,18 @@ func TestDTailHealthCheck3(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- startCommand(ctx,
+ _, _, _, err := startCommand(ctx,
"../dserver",
"--logger", "stdout",
"--logLevel", "trace",
"--port", "4242",
)
+ if err != nil {
+ t.Error(err)
+ return
+ }
- _, err := runCommandRetry(ctx, 10, stdoutFile,
+ _, err = runCommandRetry(ctx, 10, stdoutFile,
"../dtailhealthcheck", "--server", "localhost:4242")
if err != nil {
t.Error(err)
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
index 2737ede..1df4d73 100644
--- a/internal/clients/connectors/serverconnection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -19,7 +19,11 @@ import (
// ServerConnection represents a connection to a single remote dtail server via
// SSH protocol.
type ServerConnection struct {
- server string
+ // The full server string as received from the server discovery (can be with port number)
+ server string
+ // Only the hostname or FQDN (without the port number)
+ hostname string
+ // Only the port number.
port int
config *ssh.ClientConfig
handler handlers.Handler
@@ -37,7 +41,6 @@ func NewServerConnection(server string, userName string,
c := ServerConnection{
hostKeyCallback: hostKeyCallback,
server: server,
- port: config.Common.SSHPort,
handler: handler,
commands: commands,
config: &ssh.ClientConfig{
@@ -48,7 +51,6 @@ func NewServerConnection(server string, userName string,
},
}
- // TODO: After reconnecting the port is wrong! Due to string slicing?
c.initServerPort()
return &c
}
@@ -61,6 +63,7 @@ func (c *ServerConnection) Handler() handlers.Handler { return c.handler }
// Attempt to parse the server port address from the provided server FQDN.
func (c *ServerConnection) initServerPort() {
+ c.port = config.Common.SSHPort
parts := strings.Split(c.server, ":")
if len(parts) == 2 {
dlog.Client.Debug("Parsing port from hostname", parts)
@@ -68,7 +71,7 @@ func (c *ServerConnection) initServerPort() {
if err != nil {
dlog.Client.FatalPanic("Unable to parse client port", c.server, parts, err)
}
- c.server = parts[0]
+ c.hostname = parts[0]
c.port = port
}
}
@@ -103,8 +106,8 @@ func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc,
}()
if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil {
- dlog.Client.Warn(c.server, c.port, err)
- if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.server, c.port)) {
+ dlog.Client.Warn(c.server, err)
+ if c.hostKeyCallback.Untrusted(c.server) {
dlog.Client.Debug(c.server, "Not trusting host")
}
}
@@ -125,7 +128,7 @@ func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc,
}()
dlog.Client.Debug(c.server, "Dialing into the connection")
- address := fmt.Sprintf("%s:%d", c.server, c.port)
+ address := fmt.Sprintf("%s:%d", c.hostname, c.port)
client, err := ssh.Dial("tcp", address, c.config)
if err != nil {