summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2021-09-18 18:43:19 +0300
committerPaul Buetow <paul@buetow.org>2021-10-02 12:26:29 +0300
commit69b88a1cae0a61bd22530c384f40166b37b9f1ea (patch)
tree3a0bbe5a25c3035e765ed40133f5a41f4f8dfedd
parent6506e20f6c80f4acb7434eb9dd14f784a67189cd (diff)
remote connector is now an interface
-rwxr-xr-xdocker/spindown.sh2
-rw-r--r--internal/clients/baseclient.go24
-rw-r--r--internal/clients/connectors/connector.go17
-rw-r--r--internal/clients/connectors/serverconnection.go (renamed from internal/clients/remote/connection.go)105
-rw-r--r--internal/clients/healthclient.go12
5 files changed, 95 insertions, 65 deletions
diff --git a/docker/spindown.sh b/docker/spindown.sh
index 2202d22..7cf9cc6 100755
--- a/docker/spindown.sh
+++ b/docker/spindown.sh
@@ -11,3 +11,5 @@ for (( i=0; i < $NUM_INSTANCES; i++ )); do
echo Removing $name
docker rm $name
done
+
+exit 0
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index de0c101..d0631fc 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -5,7 +5,7 @@ import (
"sync"
"time"
- "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/clients/connectors"
"github.com/mimecast/dtail/internal/discovery"
"github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/omode"
@@ -23,7 +23,7 @@ type baseClient struct {
// List of remote servers to connect to.
servers []string
// We have one connection per remote server.
- connections []*remote.Connection
+ connections []connectors.Connector
// SSH auth methods to use to connect to the remote servers.
sshAuthMethods []gossh.AuthMethod
// To deal with SSH host keys
@@ -77,7 +77,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
var mutex sync.Mutex
for i, conn := range c.connections {
- go func(i int, conn *remote.Connection) {
+ go func(i int, conn connectors.Connector) {
connStatus := c.start(ctx, active, i, conn)
// Update global status.
@@ -93,7 +93,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
return
}
-func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) {
+func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn connectors.Connector) (status int) {
// Increment connection count
active <- struct{}{}
// Derement connection count
@@ -105,26 +105,20 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con
conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
// Retrieve status code from handler (dtail client will exit with that status)
- status = conn.Handler.Status()
+ status = conn.Handler().Status()
if !c.retry {
return
}
time.Sleep(time.Second * 2)
- logger.Debug(conn.Server, "Reconnecting")
-
- conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
- c.connections[i] = conn
+ logger.Debug(conn.Server(), "Reconnecting")
+ c.connections[i] = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback)
}
}
-func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = c.maker.makeHandler(server)
- conn.Commands = c.maker.makeCommands()
-
- return conn
+func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) connectors.Connector {
+ return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands())
}
func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go
new file mode 100644
index 0000000..3ab6a08
--- /dev/null
+++ b/internal/clients/connectors/connector.go
@@ -0,0 +1,17 @@
+package connectors
+
+import (
+ "context"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+)
+
+// Connector interface.
+type Connector interface {
+ // Start the connection.
+ Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{})
+ // Server hostname.
+ Server() string
+ // Handler for the connection.
+ Handler() handlers.Handler
+}
diff --git a/internal/clients/remote/connection.go b/internal/clients/connectors/serverconnection.go
index b29ffed..fab2f87 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/connectors/serverconnection.go
@@ -1,4 +1,4 @@
-package remote
+package connectors
import (
"context"
@@ -16,20 +16,20 @@ import (
"golang.org/x/crypto/ssh"
)
-// Connection represents a client connection connection to a single server.
-type Connection struct {
+// ServerConnection represents a client connection connection to a single server.
+type ServerConnection struct {
// The remote server's hostname connected to.
- Server string
+ server string
// The remote server's port connected to.
port int
// The SSH client configuration used.
config *ssh.ClientConfig
// The SSH client handler to use.
- Handler handlers.Handler
+ handler handlers.Handler
// DTail commands sent from client to server. When client loses
// connection to the server it re-connects automatically and sends the
// same commands again.
- Commands []string
+ commands []string
// Is it a persistent connection or a one-off?
isOneOff bool
// To deal with SSH server host keys
@@ -38,28 +38,33 @@ type Connection struct {
throttlingDone bool
}
-// NewConnection returns a new connection.
-func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback) *Connection {
+// NewServerConnection returns a new connection.
+func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, handler handlers.Handler, commands []string) *ServerConnection {
logger.Debug(server, "Creating new connection")
- c := Connection{
+ c := ServerConnection{
hostKeyCallback: hostKeyCallback,
+ server: server,
+ handler: handler,
+ commands: commands,
config: &ssh.ClientConfig{
User: userName,
Auth: authMethods,
HostKeyCallback: hostKeyCallback.Wrap(),
- Timeout: time.Second * 3,
+ Timeout: time.Second * 2,
},
}
- c.initServerPort(server)
-
+ c.initServerPort()
return &c
}
-// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit).
-func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection {
- c := Connection{
+// NewOneOffServerConnection creates new one-off connection (only for sending a series of commands and then quit).
+func NewOneOffServerConnection(server string, userName string, authMethods []ssh.AuthMethod, handler handlers.Handler, commands []string) *ServerConnection {
+ c := ServerConnection{
+ server: server,
+ handler: handler,
+ commands: commands,
config: &ssh.ClientConfig{
User: userName,
Auth: authMethods,
@@ -68,46 +73,54 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM
isOneOff: true,
}
- c.initServerPort(server)
-
+ c.initServerPort()
return &c
}
+// Server hostname
+func (c *ServerConnection) Server() string {
+ return c.server
+}
+
+// Handler for the connection
+func (c *ServerConnection) Handler() handlers.Handler {
+ return c.handler
+}
+
// Attempt to parse the server port address from the provided server FQDN.
-func (c *Connection) initServerPort(server string) {
- c.Server = server
+func (c *ServerConnection) initServerPort() {
c.port = config.Common.SSHPort
- parts := strings.Split(server, ":")
+ parts := strings.Split(c.server, ":")
if len(parts) == 2 {
logger.Debug("Parsing port from hostname", parts)
port, err := strconv.Atoi(parts[1])
if err != nil {
- logger.FatalExit("Unable to parse client port", server, parts, err)
+ logger.FatalExit("Unable to parse client port", c.server, parts, err)
}
- c.Server = parts[0]
+ c.server = parts[0]
c.port = port
}
}
// Start the server connection. Build up SSH session and send some DTail commands.
-func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
+func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
// Throttle how many connections can be established concurrently (based on ch length)
- logger.Debug(c.Server, "Throttling connection", len(throttleCh), cap(throttleCh))
+ logger.Debug(c.server, "Throttling connection", len(throttleCh), cap(throttleCh))
select {
case throttleCh <- struct{}{}:
case <-ctx.Done():
- logger.Debug(c.Server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh))
+ logger.Debug(c.server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh))
return
}
- logger.Debug(c.Server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh))
+ logger.Debug(c.server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh))
go func() {
defer func() {
if !c.throttlingDone {
- logger.Debug(c.Server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh))
+ logger.Debug(c.server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh))
c.throttlingDone = true
<-throttleCh
}
@@ -115,9 +128,9 @@ func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throt
}()
if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil {
- logger.Warn(c.Server, c.port, err)
- if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
- logger.Debug(c.Server, "Not trusting host")
+ logger.Warn(c.server, c.port, err)
+ if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.server, c.port)) {
+ logger.Debug(c.server, "Not trusting host")
}
}
}()
@@ -126,16 +139,16 @@ func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throt
}
// Dail into a new SSH connection. Close connection in case of an error.
-func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error {
- logger.Debug(c.Server, "Incrementing connection stats")
+func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error {
+ logger.Debug(c.server, "Incrementing connection stats")
statsCh <- struct{}{}
defer func() {
- logger.Debug(c.Server, "Decrementing connection stats")
+ logger.Debug(c.server, "Decrementing connection stats")
<-statsCh
}()
- logger.Debug(c.Server, "Dialing into the connection")
- address := fmt.Sprintf("%s:%d", c.Server, c.port)
+ logger.Debug(c.server, "Dialing into the connection")
+ address := fmt.Sprintf("%s:%d", c.server, c.port)
client, err := ssh.Dial("tcp", address, c.config)
if err != nil {
@@ -147,8 +160,8 @@ func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, thrott
}
// Create the SSH session. Close the session in case of an error.
-func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error {
- logger.Debug(c.Server, "session")
+func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error {
+ logger.Debug(c.server, "session")
session, err := client.NewSession()
if err != nil {
@@ -159,8 +172,8 @@ func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, cli
return c.handle(ctx, cancel, session, throttleCh)
}
-func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error {
- logger.Debug(c.Server, "handle")
+func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error {
+ logger.Debug(c.server, "handle")
stdinPipe, err := session.StdinPipe()
if err != nil {
@@ -177,36 +190,36 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess
}
go func() {
- io.Copy(stdinPipe, c.Handler)
+ io.Copy(stdinPipe, c.handler)
cancel()
}()
go func() {
- io.Copy(c.Handler, stdoutPipe)
+ io.Copy(c.handler, stdoutPipe)
cancel()
}()
go func() {
select {
- case <-c.Handler.Done():
+ case <-c.handler.Done():
case <-ctx.Done():
}
cancel()
}()
// Send all commands to client.
- for _, command := range c.Commands {
+ for _, command := range c.commands {
logger.Debug(command)
- c.Handler.SendMessage(command)
+ c.handler.SendMessage(command)
}
if !c.throttlingDone {
- logger.Debug(c.Server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh))
+ logger.Debug(c.server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh))
c.throttlingDone = true
<-throttleCh
}
<-ctx.Done()
- c.Handler.Shutdown()
+ c.handler.Shutdown()
return nil
}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index 692464c..47007b6 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -7,8 +7,8 @@ import (
"strings"
"time"
+ "github.com/mimecast/dtail/internal/clients/connectors"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/protocol"
@@ -47,9 +47,13 @@ func (c *HealthClient) Start(ctx context.Context) (status int) {
throttleCh := make(chan struct{}, runtime.NumCPU())
statsCh := make(chan struct{}, 1)
- conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods)
- conn.Handler = handlers.NewHealthHandler(c.server, receive)
- conn.Commands = []string{c.mode.String()}
+ conn := connectors.NewOneOffServerConnection(
+ c.server,
+ c.userName,
+ c.sshAuthMethods,
+ handlers.NewHealthHandler(c.server, receive),
+ []string{c.mode.String()},
+ )
connCtx, cancel := context.WithCancel(ctx)
go conn.Start(connCtx, cancel, throttleCh, statsCh)