diff options
| author | Paul Buetow <paul@buetow.org> | 2021-09-18 18:43:19 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2021-10-02 12:26:29 +0300 |
| commit | 69b88a1cae0a61bd22530c384f40166b37b9f1ea (patch) | |
| tree | 3a0bbe5a25c3035e765ed40133f5a41f4f8dfedd | |
| parent | 6506e20f6c80f4acb7434eb9dd14f784a67189cd (diff) | |
remote connector is now an interface
| -rwxr-xr-x | docker/spindown.sh | 2 | ||||
| -rw-r--r-- | internal/clients/baseclient.go | 24 | ||||
| -rw-r--r-- | internal/clients/connectors/connector.go | 17 | ||||
| -rw-r--r-- | internal/clients/connectors/serverconnection.go (renamed from internal/clients/remote/connection.go) | 105 | ||||
| -rw-r--r-- | internal/clients/healthclient.go | 12 |
5 files changed, 95 insertions, 65 deletions
diff --git a/docker/spindown.sh b/docker/spindown.sh index 2202d22..7cf9cc6 100755 --- a/docker/spindown.sh +++ b/docker/spindown.sh @@ -11,3 +11,5 @@ for (( i=0; i < $NUM_INSTANCES; i++ )); do echo Removing $name docker rm $name done + +exit 0 diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index de0c101..d0631fc 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/clients/connectors" "github.com/mimecast/dtail/internal/discovery" "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" @@ -23,7 +23,7 @@ type baseClient struct { // List of remote servers to connect to. servers []string // We have one connection per remote server. - connections []*remote.Connection + connections []connectors.Connector // SSH auth methods to use to connect to the remote servers. sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys @@ -77,7 +77,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i var mutex sync.Mutex for i, conn := range c.connections { - go func(i int, conn *remote.Connection) { + go func(i int, conn connectors.Connector) { connStatus := c.start(ctx, active, i, conn) // Update global status. @@ -93,7 +93,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i return } -func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn connectors.Connector) (status int) { // Increment connection count active <- struct{}{} // Derement connection count @@ -105,26 +105,20 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) // Retrieve status code from handler (dtail client will exit with that status) - status = conn.Handler.Status() + status = conn.Handler().Status() if !c.retry { return } time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconnecting") - - conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + logger.Debug(conn.Server(), "Reconnecting") + c.connections[i] = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback) } } -func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = c.maker.makeHandler(server) - conn.Commands = c.maker.makeCommands() - - return conn +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) connectors.Connector { + return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands()) } func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go new file mode 100644 index 0000000..3ab6a08 --- /dev/null +++ b/internal/clients/connectors/connector.go @@ -0,0 +1,17 @@ +package connectors + +import ( + "context" + + "github.com/mimecast/dtail/internal/clients/handlers" +) + +// Connector interface. +type Connector interface { + // Start the connection. + Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) + // Server hostname. + Server() string + // Handler for the connection. + Handler() handlers.Handler +} diff --git a/internal/clients/remote/connection.go b/internal/clients/connectors/serverconnection.go index b29ffed..fab2f87 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/connectors/serverconnection.go @@ -1,4 +1,4 @@ -package remote +package connectors import ( "context" @@ -16,20 +16,20 @@ import ( "golang.org/x/crypto/ssh" ) -// Connection represents a client connection connection to a single server. -type Connection struct { +// ServerConnection represents a client connection connection to a single server. +type ServerConnection struct { // The remote server's hostname connected to. - Server string + server string // The remote server's port connected to. port int // The SSH client configuration used. config *ssh.ClientConfig // The SSH client handler to use. - Handler handlers.Handler + handler handlers.Handler // DTail commands sent from client to server. When client loses // connection to the server it re-connects automatically and sends the // same commands again. - Commands []string + commands []string // Is it a persistent connection or a one-off? isOneOff bool // To deal with SSH server host keys @@ -38,28 +38,33 @@ type Connection struct { throttlingDone bool } -// NewConnection returns a new connection. -func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback) *Connection { +// NewServerConnection returns a new connection. +func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, handler handlers.Handler, commands []string) *ServerConnection { logger.Debug(server, "Creating new connection") - c := Connection{ + c := ServerConnection{ hostKeyCallback: hostKeyCallback, + server: server, + handler: handler, + commands: commands, config: &ssh.ClientConfig{ User: userName, Auth: authMethods, HostKeyCallback: hostKeyCallback.Wrap(), - Timeout: time.Second * 3, + Timeout: time.Second * 2, }, } - c.initServerPort(server) - + c.initServerPort() return &c } -// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit). -func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection { - c := Connection{ +// NewOneOffServerConnection creates new one-off connection (only for sending a series of commands and then quit). +func NewOneOffServerConnection(server string, userName string, authMethods []ssh.AuthMethod, handler handlers.Handler, commands []string) *ServerConnection { + c := ServerConnection{ + server: server, + handler: handler, + commands: commands, config: &ssh.ClientConfig{ User: userName, Auth: authMethods, @@ -68,46 +73,54 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM isOneOff: true, } - c.initServerPort(server) - + c.initServerPort() return &c } +// Server hostname +func (c *ServerConnection) Server() string { + return c.server +} + +// Handler for the connection +func (c *ServerConnection) Handler() handlers.Handler { + return c.handler +} + // Attempt to parse the server port address from the provided server FQDN. -func (c *Connection) initServerPort(server string) { - c.Server = server +func (c *ServerConnection) initServerPort() { c.port = config.Common.SSHPort - parts := strings.Split(server, ":") + parts := strings.Split(c.server, ":") if len(parts) == 2 { logger.Debug("Parsing port from hostname", parts) port, err := strconv.Atoi(parts[1]) if err != nil { - logger.FatalExit("Unable to parse client port", server, parts, err) + logger.FatalExit("Unable to parse client port", c.server, parts, err) } - c.Server = parts[0] + c.server = parts[0] c.port = port } } // Start the server connection. Build up SSH session and send some DTail commands. -func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { +func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { // Throttle how many connections can be established concurrently (based on ch length) - logger.Debug(c.Server, "Throttling connection", len(throttleCh), cap(throttleCh)) + logger.Debug(c.server, "Throttling connection", len(throttleCh), cap(throttleCh)) select { case throttleCh <- struct{}{}: case <-ctx.Done(): - logger.Debug(c.Server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh)) + logger.Debug(c.server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh)) return } - logger.Debug(c.Server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh)) + logger.Debug(c.server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh)) go func() { defer func() { if !c.throttlingDone { - logger.Debug(c.Server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh)) + logger.Debug(c.server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh)) c.throttlingDone = true <-throttleCh } @@ -115,9 +128,9 @@ func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throt }() if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug(c.Server, "Not trusting host") + logger.Warn(c.server, c.port, err) + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.server, c.port)) { + logger.Debug(c.server, "Not trusting host") } } }() @@ -126,16 +139,16 @@ func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throt } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error { - logger.Debug(c.Server, "Incrementing connection stats") +func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error { + logger.Debug(c.server, "Incrementing connection stats") statsCh <- struct{}{} defer func() { - logger.Debug(c.Server, "Decrementing connection stats") + logger.Debug(c.server, "Decrementing connection stats") <-statsCh }() - logger.Debug(c.Server, "Dialing into the connection") - address := fmt.Sprintf("%s:%d", c.Server, c.port) + logger.Debug(c.server, "Dialing into the connection") + address := fmt.Sprintf("%s:%d", c.server, c.port) client, err := ssh.Dial("tcp", address, c.config) if err != nil { @@ -147,8 +160,8 @@ func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, thrott } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { - logger.Debug(c.Server, "session") +func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { + logger.Debug(c.server, "session") session, err := client.NewSession() if err != nil { @@ -159,8 +172,8 @@ func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, cli return c.handle(ctx, cancel, session, throttleCh) } -func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { - logger.Debug(c.Server, "handle") +func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { + logger.Debug(c.server, "handle") stdinPipe, err := session.StdinPipe() if err != nil { @@ -177,36 +190,36 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess } go func() { - io.Copy(stdinPipe, c.Handler) + io.Copy(stdinPipe, c.handler) cancel() }() go func() { - io.Copy(c.Handler, stdoutPipe) + io.Copy(c.handler, stdoutPipe) cancel() }() go func() { select { - case <-c.Handler.Done(): + case <-c.handler.Done(): case <-ctx.Done(): } cancel() }() // Send all commands to client. - for _, command := range c.Commands { + for _, command := range c.commands { logger.Debug(command) - c.Handler.SendMessage(command) + c.handler.SendMessage(command) } if !c.throttlingDone { - logger.Debug(c.Server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh)) + logger.Debug(c.server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh)) c.throttlingDone = true <-throttleCh } <-ctx.Done() - c.Handler.Shutdown() + c.handler.Shutdown() return nil } diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index 692464c..47007b6 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -7,8 +7,8 @@ import ( "strings" "time" + "github.com/mimecast/dtail/internal/clients/connectors" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/protocol" @@ -47,9 +47,13 @@ func (c *HealthClient) Start(ctx context.Context) (status int) { throttleCh := make(chan struct{}, runtime.NumCPU()) statsCh := make(chan struct{}, 1) - conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods) - conn.Handler = handlers.NewHealthHandler(c.server, receive) - conn.Commands = []string{c.mode.String()} + conn := connectors.NewOneOffServerConnection( + c.server, + c.userName, + c.sshAuthMethods, + handlers.NewHealthHandler(c.server, receive), + []string{c.mode.String()}, + ) connCtx, cancel := context.WithCancel(ctx) go conn.Start(connCtx, cancel, throttleCh, statsCh) |
