diff options
| author | Paul Buetow <pbuetow@mimecast.com> | 2020-05-13 10:54:34 +0100 |
|---|---|---|
| committer | Paul Buetow <pbuetow@mimecast.com> | 2020-05-13 10:54:34 +0100 |
| commit | 1c56e05adcdd8846708557d631dfb6f73bc26db0 (patch) | |
| tree | f02c1ecfb551a99ff5885dfa47ab1db0d6316cb3 | |
| parent | 077af4d92fe318f8383b026e92f05390d764830a (diff) | |
fix bug in connection throttling
| -rw-r--r-- | internal/clients/remote/connection.go | 45 | ||||
| -rw-r--r-- | internal/version/version.go | 4 |
2 files changed, 35 insertions, 14 deletions
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index f95c6a6..5941115 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -34,6 +34,8 @@ type Connection struct { isOneOff bool // To deal with SSH server host keys hostKeyCallback client.HostKeyCallback + // To determine if connection throttling has finished or not + throttlingDone bool } // NewConnection returns a new connection. @@ -91,22 +93,31 @@ func (c *Connection) initServerPort(server string) { // Start the server connection. Build up SSH session and send some DTail commands. func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { // Throttle how many connections can be established concurrently (based on ch length) + logger.Debug("Throttling connection", len(throttleCh), cap(throttleCh)) + select { case throttleCh <- struct{}{}: - defer func() { <-throttleCh }() case <-ctx.Done(): + logger.Debug("Not establishing connection as context is done", len(throttleCh), cap(throttleCh), c.Server) return } + logger.Debug("Throttling says that the connection can be established", len(throttleCh), cap(throttleCh), c.Server) + go func() { - defer cancel() + defer func() { + if !c.throttlingDone { + logger.Debug("Unthrottling connection (1)", len(throttleCh), cap(throttleCh), c.Server) + c.throttlingDone = true + <-throttleCh + } + cancel() + }() - if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil { + if err := c.dial(ctx, cancel, c.Server, c.port, throttleCh, statsCh); err != nil { logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { logger.Debug("Not trusting host", c.Server, c.port) - return } } }() @@ -115,11 +126,15 @@ func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throt } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error { +func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, throttleCh, statsCh chan struct{}) error { + logger.Debug("Incrementing connection stats", host) statsCh <- struct{}{} - defer func() { <-statsCh }() + defer func() { + logger.Debug("Decrementing connection stats", host) + <-statsCh + }() - logger.Debug(host, "dial") + logger.Debug("Dialing into the connection", host) address := fmt.Sprintf("%s:%d", host, port) client, err := ssh.Dial("tcp", address, c.config) @@ -128,11 +143,11 @@ func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host s } defer client.Close() - return c.session(ctx, cancel, client) + return c.session(ctx, cancel, client, throttleCh) } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error { +func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { logger.Debug(c.Server, "session") session, err := client.NewSession() @@ -141,10 +156,10 @@ func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, cli } defer session.Close() - return c.handle(ctx, cancel, session) + return c.handle(ctx, cancel, session, throttleCh) } -func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error { +func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() @@ -185,6 +200,12 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess c.Handler.SendMessage(command) } + if !c.throttlingDone { + logger.Debug("Unthrottling connection (2)", len(throttleCh), cap(throttleCh), c.Server) + c.throttlingDone = true + <-throttleCh + } + <-ctx.Done() return nil } diff --git a/internal/version/version.go b/internal/version/version.go index ecb9e50..17719b6 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -11,9 +11,9 @@ const ( // Name of DTail. Name string = "DTail" // Version of DTail. - Version string = "2.2.0" + Version string = "2.2.1" // Additional information for DTail - Additional string = "" + Additional string = "develop" // ProtocolCompat -ibility version. ProtocolCompat string = "2" ) |
