summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2020-05-13 10:54:34 +0100
committerPaul Buetow <pbuetow@mimecast.com>2020-05-13 10:54:34 +0100
commit1c56e05adcdd8846708557d631dfb6f73bc26db0 (patch)
treef02c1ecfb551a99ff5885dfa47ab1db0d6316cb3
parent077af4d92fe318f8383b026e92f05390d764830a (diff)
fix bug in connection throttling
-rw-r--r--internal/clients/remote/connection.go45
-rw-r--r--internal/version/version.go4
2 files changed, 35 insertions, 14 deletions
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
index f95c6a6..5941115 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/remote/connection.go
@@ -34,6 +34,8 @@ type Connection struct {
isOneOff bool
// To deal with SSH server host keys
hostKeyCallback client.HostKeyCallback
+ // To determine if connection throttling has finished or not
+ throttlingDone bool
}
// NewConnection returns a new connection.
@@ -91,22 +93,31 @@ func (c *Connection) initServerPort(server string) {
// Start the server connection. Build up SSH session and send some DTail commands.
func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
// Throttle how many connections can be established concurrently (based on ch length)
+ logger.Debug("Throttling connection", len(throttleCh), cap(throttleCh))
+
select {
case throttleCh <- struct{}{}:
- defer func() { <-throttleCh }()
case <-ctx.Done():
+ logger.Debug("Not establishing connection as context is done", len(throttleCh), cap(throttleCh), c.Server)
return
}
+ logger.Debug("Throttling says that the connection can be established", len(throttleCh), cap(throttleCh), c.Server)
+
go func() {
- defer cancel()
+ defer func() {
+ if !c.throttlingDone {
+ logger.Debug("Unthrottling connection (1)", len(throttleCh), cap(throttleCh), c.Server)
+ c.throttlingDone = true
+ <-throttleCh
+ }
+ cancel()
+ }()
- if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil {
+ if err := c.dial(ctx, cancel, c.Server, c.port, throttleCh, statsCh); err != nil {
logger.Warn(c.Server, c.port, err)
-
if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
logger.Debug("Not trusting host", c.Server, c.port)
- return
}
}
}()
@@ -115,11 +126,15 @@ func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throt
}
// Dail into a new SSH connection. Close connection in case of an error.
-func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error {
+func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, throttleCh, statsCh chan struct{}) error {
+ logger.Debug("Incrementing connection stats", host)
statsCh <- struct{}{}
- defer func() { <-statsCh }()
+ defer func() {
+ logger.Debug("Decrementing connection stats", host)
+ <-statsCh
+ }()
- logger.Debug(host, "dial")
+ logger.Debug("Dialing into the connection", host)
address := fmt.Sprintf("%s:%d", host, port)
client, err := ssh.Dial("tcp", address, c.config)
@@ -128,11 +143,11 @@ func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host s
}
defer client.Close()
- return c.session(ctx, cancel, client)
+ return c.session(ctx, cancel, client, throttleCh)
}
// Create the SSH session. Close the session in case of an error.
-func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error {
+func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error {
logger.Debug(c.Server, "session")
session, err := client.NewSession()
@@ -141,10 +156,10 @@ func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, cli
}
defer session.Close()
- return c.handle(ctx, cancel, session)
+ return c.handle(ctx, cancel, session, throttleCh)
}
-func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error {
+func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error {
logger.Debug(c.Server, "handle")
stdinPipe, err := session.StdinPipe()
@@ -185,6 +200,12 @@ func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, sess
c.Handler.SendMessage(command)
}
+ if !c.throttlingDone {
+ logger.Debug("Unthrottling connection (2)", len(throttleCh), cap(throttleCh), c.Server)
+ c.throttlingDone = true
+ <-throttleCh
+ }
+
<-ctx.Done()
return nil
}
diff --git a/internal/version/version.go b/internal/version/version.go
index ecb9e50..17719b6 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -11,9 +11,9 @@ const (
// Name of DTail.
Name string = "DTail"
// Version of DTail.
- Version string = "2.2.0"
+ Version string = "2.2.1"
// Additional information for DTail
- Additional string = ""
+ Additional string = "develop"
// ProtocolCompat -ibility version.
ProtocolCompat string = "2"
)