From f796df4c2f4bc8b61152c7d3e363152fb7bbc6f9 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Sat, 18 Sep 2021 18:43:19 +0300 Subject: remote connector is now an interface --- internal/clients/baseclient.go | 24 +-- internal/clients/connectors/connector.go | 17 ++ internal/clients/connectors/serverconnection.go | 225 ++++++++++++++++++++++++ internal/clients/healthclient.go | 12 +- internal/clients/remote/connection.go | 212 ---------------------- 5 files changed, 259 insertions(+), 231 deletions(-) create mode 100644 internal/clients/connectors/connector.go create mode 100644 internal/clients/connectors/serverconnection.go delete mode 100644 internal/clients/remote/connection.go (limited to 'internal') diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index de0c101..d0631fc 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/clients/connectors" "github.com/mimecast/dtail/internal/discovery" "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" @@ -23,7 +23,7 @@ type baseClient struct { // List of remote servers to connect to. servers []string // We have one connection per remote server. - connections []*remote.Connection + connections []connectors.Connector // SSH auth methods to use to connect to the remote servers. sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys @@ -77,7 +77,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i var mutex sync.Mutex for i, conn := range c.connections { - go func(i int, conn *remote.Connection) { + go func(i int, conn connectors.Connector) { connStatus := c.start(ctx, active, i, conn) // Update global status. @@ -93,7 +93,7 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i return } -func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn connectors.Connector) (status int) { // Increment connection count active <- struct{}{} // Derement connection count @@ -105,26 +105,20 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) // Retrieve status code from handler (dtail client will exit with that status) - status = conn.Handler.Status() + status = conn.Handler().Status() if !c.retry { return } time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconnecting") - - conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + logger.Debug(conn.Server(), "Reconnecting") + c.connections[i] = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback) } } -func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = c.maker.makeHandler(server) - conn.Commands = c.maker.makeCommands() - - return conn +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) connectors.Connector { + return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands()) } func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go new file mode 100644 index 0000000..3ab6a08 --- /dev/null +++ b/internal/clients/connectors/connector.go @@ -0,0 +1,17 @@ +package connectors + +import ( + "context" + + "github.com/mimecast/dtail/internal/clients/handlers" +) + +// Connector interface. +type Connector interface { + // Start the connection. + Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) + // Server hostname. + Server() string + // Handler for the connection. + Handler() handlers.Handler +} diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go new file mode 100644 index 0000000..fab2f87 --- /dev/null +++ b/internal/clients/connectors/serverconnection.go @@ -0,0 +1,225 @@ +package connectors + +import ( + "context" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/ssh/client" + + "golang.org/x/crypto/ssh" +) + +// ServerConnection represents a client connection connection to a single server. +type ServerConnection struct { + // The remote server's hostname connected to. + server string + // The remote server's port connected to. + port int + // The SSH client configuration used. + config *ssh.ClientConfig + // The SSH client handler to use. + handler handlers.Handler + // DTail commands sent from client to server. When client loses + // connection to the server it re-connects automatically and sends the + // same commands again. + commands []string + // Is it a persistent connection or a one-off? + isOneOff bool + // To deal with SSH server host keys + hostKeyCallback client.HostKeyCallback + // To determine if connection throttling has finished or not + throttlingDone bool +} + +// NewServerConnection returns a new connection. +func NewServerConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, handler handlers.Handler, commands []string) *ServerConnection { + logger.Debug(server, "Creating new connection") + + c := ServerConnection{ + hostKeyCallback: hostKeyCallback, + server: server, + handler: handler, + commands: commands, + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: hostKeyCallback.Wrap(), + Timeout: time.Second * 2, + }, + } + + c.initServerPort() + return &c +} + +// NewOneOffServerConnection creates new one-off connection (only for sending a series of commands and then quit). +func NewOneOffServerConnection(server string, userName string, authMethods []ssh.AuthMethod, handler handlers.Handler, commands []string) *ServerConnection { + c := ServerConnection{ + server: server, + handler: handler, + commands: commands, + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }, + isOneOff: true, + } + + c.initServerPort() + return &c +} + +// Server hostname +func (c *ServerConnection) Server() string { + return c.server +} + +// Handler for the connection +func (c *ServerConnection) Handler() handlers.Handler { + return c.handler +} + +// Attempt to parse the server port address from the provided server FQDN. +func (c *ServerConnection) initServerPort() { + c.port = config.Common.SSHPort + parts := strings.Split(c.server, ":") + + if len(parts) == 2 { + logger.Debug("Parsing port from hostname", parts) + port, err := strconv.Atoi(parts[1]) + if err != nil { + logger.FatalExit("Unable to parse client port", c.server, parts, err) + } + c.server = parts[0] + c.port = port + } +} + +// Start the server connection. Build up SSH session and send some DTail commands. +func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { + // Throttle how many connections can be established concurrently (based on ch length) + logger.Debug(c.server, "Throttling connection", len(throttleCh), cap(throttleCh)) + + select { + case throttleCh <- struct{}{}: + case <-ctx.Done(): + logger.Debug(c.server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh)) + return + } + + logger.Debug(c.server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh)) + + go func() { + defer func() { + if !c.throttlingDone { + logger.Debug(c.server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh)) + c.throttlingDone = true + <-throttleCh + } + cancel() + }() + + if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { + logger.Warn(c.server, c.port, err) + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.server, c.port)) { + logger.Debug(c.server, "Not trusting host") + } + } + }() + + <-ctx.Done() +} + +// Dail into a new SSH connection. Close connection in case of an error. +func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error { + logger.Debug(c.server, "Incrementing connection stats") + statsCh <- struct{}{} + defer func() { + logger.Debug(c.server, "Decrementing connection stats") + <-statsCh + }() + + logger.Debug(c.server, "Dialing into the connection") + address := fmt.Sprintf("%s:%d", c.server, c.port) + + client, err := ssh.Dial("tcp", address, c.config) + if err != nil { + return err + } + defer client.Close() + + return c.session(ctx, cancel, client, throttleCh) +} + +// Create the SSH session. Close the session in case of an error. +func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { + logger.Debug(c.server, "session") + + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + return c.handle(ctx, cancel, session, throttleCh) +} + +func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { + logger.Debug(c.server, "handle") + + stdinPipe, err := session.StdinPipe() + if err != nil { + return err + } + + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + + if err := session.Shell(); err != nil { + return err + } + + go func() { + io.Copy(stdinPipe, c.handler) + cancel() + }() + + go func() { + io.Copy(c.handler, stdoutPipe) + cancel() + }() + + go func() { + select { + case <-c.handler.Done(): + case <-ctx.Done(): + } + cancel() + }() + + // Send all commands to client. + for _, command := range c.commands { + logger.Debug(command) + c.handler.SendMessage(command) + } + + if !c.throttlingDone { + logger.Debug(c.server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh)) + c.throttlingDone = true + <-throttleCh + } + + <-ctx.Done() + c.handler.Shutdown() + return nil +} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index 692464c..47007b6 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -7,8 +7,8 @@ import ( "strings" "time" + "github.com/mimecast/dtail/internal/clients/connectors" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/protocol" @@ -47,9 +47,13 @@ func (c *HealthClient) Start(ctx context.Context) (status int) { throttleCh := make(chan struct{}, runtime.NumCPU()) statsCh := make(chan struct{}, 1) - conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods) - conn.Handler = handlers.NewHealthHandler(c.server, receive) - conn.Commands = []string{c.mode.String()} + conn := connectors.NewOneOffServerConnection( + c.server, + c.userName, + c.sshAuthMethods, + handlers.NewHealthHandler(c.server, receive), + []string{c.mode.String()}, + ) connCtx, cancel := context.WithCancel(ctx) go conn.Start(connCtx, cancel, throttleCh, statsCh) diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go deleted file mode 100644 index b29ffed..0000000 --- a/internal/clients/remote/connection.go +++ /dev/null @@ -1,212 +0,0 @@ -package remote - -import ( - "context" - "fmt" - "io" - "strconv" - "strings" - "time" - - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/ssh/client" - - "golang.org/x/crypto/ssh" -) - -// Connection represents a client connection connection to a single server. -type Connection struct { - // The remote server's hostname connected to. - Server string - // The remote server's port connected to. - port int - // The SSH client configuration used. - config *ssh.ClientConfig - // The SSH client handler to use. - Handler handlers.Handler - // DTail commands sent from client to server. When client loses - // connection to the server it re-connects automatically and sends the - // same commands again. - Commands []string - // Is it a persistent connection or a one-off? - isOneOff bool - // To deal with SSH server host keys - hostKeyCallback client.HostKeyCallback - // To determine if connection throttling has finished or not - throttlingDone bool -} - -// NewConnection returns a new connection. -func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback) *Connection { - logger.Debug(server, "Creating new connection") - - c := Connection{ - hostKeyCallback: hostKeyCallback, - config: &ssh.ClientConfig{ - User: userName, - Auth: authMethods, - HostKeyCallback: hostKeyCallback.Wrap(), - Timeout: time.Second * 3, - }, - } - - c.initServerPort(server) - - return &c -} - -// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit). -func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection { - c := Connection{ - config: &ssh.ClientConfig{ - User: userName, - Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }, - isOneOff: true, - } - - c.initServerPort(server) - - return &c -} - -// Attempt to parse the server port address from the provided server FQDN. -func (c *Connection) initServerPort(server string) { - c.Server = server - c.port = config.Common.SSHPort - parts := strings.Split(server, ":") - - if len(parts) == 2 { - logger.Debug("Parsing port from hostname", parts) - port, err := strconv.Atoi(parts[1]) - if err != nil { - logger.FatalExit("Unable to parse client port", server, parts, err) - } - c.Server = parts[0] - c.port = port - } -} - -// Start the server connection. Build up SSH session and send some DTail commands. -func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { - // Throttle how many connections can be established concurrently (based on ch length) - logger.Debug(c.Server, "Throttling connection", len(throttleCh), cap(throttleCh)) - - select { - case throttleCh <- struct{}{}: - case <-ctx.Done(): - logger.Debug(c.Server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh)) - return - } - - logger.Debug(c.Server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh)) - - go func() { - defer func() { - if !c.throttlingDone { - logger.Debug(c.Server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh)) - c.throttlingDone = true - <-throttleCh - } - cancel() - }() - - if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug(c.Server, "Not trusting host") - } - } - }() - - <-ctx.Done() -} - -// Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error { - logger.Debug(c.Server, "Incrementing connection stats") - statsCh <- struct{}{} - defer func() { - logger.Debug(c.Server, "Decrementing connection stats") - <-statsCh - }() - - logger.Debug(c.Server, "Dialing into the connection") - address := fmt.Sprintf("%s:%d", c.Server, c.port) - - client, err := ssh.Dial("tcp", address, c.config) - if err != nil { - return err - } - defer client.Close() - - return c.session(ctx, cancel, client, throttleCh) -} - -// Create the SSH session. Close the session in case of an error. -func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { - logger.Debug(c.Server, "session") - - session, err := client.NewSession() - if err != nil { - return err - } - defer session.Close() - - return c.handle(ctx, cancel, session, throttleCh) -} - -func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { - logger.Debug(c.Server, "handle") - - stdinPipe, err := session.StdinPipe() - if err != nil { - return err - } - - stdoutPipe, err := session.StdoutPipe() - if err != nil { - return err - } - - if err := session.Shell(); err != nil { - return err - } - - go func() { - io.Copy(stdinPipe, c.Handler) - cancel() - }() - - go func() { - io.Copy(c.Handler, stdoutPipe) - cancel() - }() - - go func() { - select { - case <-c.Handler.Done(): - case <-ctx.Done(): - } - cancel() - }() - - // Send all commands to client. - for _, command := range c.Commands { - logger.Debug(command) - c.Handler.SendMessage(command) - } - - if !c.throttlingDone { - logger.Debug(c.Server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh)) - c.throttlingDone = true - <-throttleCh - } - - <-ctx.Done() - c.Handler.Shutdown() - return nil -} -- cgit v1.2.3