diff options
Diffstat (limited to 'internal/clients/remote/connection.go')
| -rw-r--r-- | internal/clients/remote/connection.go | 116 |
1 files changed, 38 insertions, 78 deletions
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index bfc7bc5..71639b1 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -1,16 +1,18 @@ package remote import ( - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/ssh/client" + "context" "fmt" "io" "strconv" "strings" "time" + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/ssh/client" + "golang.org/x/crypto/ssh" ) @@ -30,8 +32,6 @@ type Connection struct { Commands []string // Is it a persistent connection or a one-off? isOneOff bool - // Used to stop the connection - stop chan struct{} // To deal with SSH server host keys hostKeyCallback *client.HostKeyCallback } @@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, HostKeyCallback: hostKeyCallback.Wrap(), Timeout: time.Second * 3, }, - stop: make(chan struct{}), } c.initServerPort(server) @@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }, - stop: make(chan struct{}), isOneOff: true, } @@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) { } } -// Start the server connection. Build up SSH session and send some DTail commandc. -func (c *Connection) Start(throttleCh, statsCh chan struct{}) { +// Start the server connection. Build up SSH session and send some DTail commands. +func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { + // Throttle how many connections can be established concurrently (based on ch length) select { - case <-c.stop: - logger.Info(c.Server, c.port, "Disconnecting client") + case throttleCh <- struct{}{}: + defer func() { <-throttleCh }() + case <-ctx.Done(): return - default: } - // Wait for SSH connection throttler - throttleCh <- struct{}{} - - // Wait until connection has been initiated or an error occured - // during initialization. - throttleStopCh := make(chan struct{}, 2) go func() { - <-throttleStopCh - <-throttleCh - }() + defer cancel() - if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - throttleStopCh <- struct{}{} + if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil { + logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port) - return + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { + logger.Debug("Not trusting host", c.Server, c.port) + return + } } - } + }() + + <-ctx.Done() } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error { +func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error { statsCh <- struct{}{} defer func() { <-statsCh }() @@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st } defer client.Close() - return c.session(client, throttleStopCh) + return c.session(ctx, cancel, client) } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { +func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error { logger.Debug(c.Server, "session") session, err := client.NewSession() @@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) } defer session.Close() - return c.handle(session, throttleStopCh) + return c.handle(ctx, cancel, session) } -// Handle the SSH session. Also send periodic pings to the server in order -// to determine that session is still intact. -func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { - defer c.Handler.Stop() - +func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() @@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{} return err } - // Establish Bi-directional pipe between SSH session and client handler. - brokenStdinPipe := make(chan struct{}) go func() { - defer close(brokenStdinPipe) + defer cancel() io.Copy(stdinPipe, c.Handler) }() - brokenStdoutPipe := make(chan struct{}) go func() { - defer close(brokenStdoutPipe) + defer cancel() io.Copy(c.Handler, stdoutPipe) }() - // SSH session established, other goroutine can initiate session now. - throttleStopCh <- struct{}{} + go func() { + defer cancel() + select { + case <-c.Handler.Done(): + case <-ctx.Done(): + } + }() // Send all commands to client. for _, command := range c.Commands { logger.Debug(command) - c.Handler.SendCommand(command) + c.Handler.SendMessage(command) } - if !c.isOneOff { - return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) - } - - <-c.stop - - // Normal shutdown, all fine + <-ctx.Done() return nil } - -// Periodically check whether connection is still alive or not. -func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { - for { - select { - case <-time.After(time.Second * 3): - if err := c.Handler.Ping(); err != nil { - return err - } - case <-brokenStdinPipe: - logger.Debug("Broken stdin pipe", c.Server, c.port) - return nil - case <-brokenStdoutPipe: - logger.Debug("Broken stdout pipe", c.Server, c.port) - return nil - case <-c.stop: - return nil - } - } -} - -// Stop the connection. -func (c *Connection) Stop() { - close(c.stop) -} |
