summaryrefslogtreecommitdiff
path: root/internal/clients/remote/connection.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/remote/connection.go')
-rw-r--r--internal/clients/remote/connection.go116
1 files changed, 38 insertions, 78 deletions
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
index bfc7bc5..71639b1 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/remote/connection.go
@@ -1,16 +1,18 @@
package remote
import (
- "github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
- "github.com/mimecast/dtail/internal/ssh/client"
+ "context"
"fmt"
"io"
"strconv"
"strings"
"time"
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
"golang.org/x/crypto/ssh"
)
@@ -30,8 +32,6 @@ type Connection struct {
Commands []string
// Is it a persistent connection or a one-off?
isOneOff bool
- // Used to stop the connection
- stop chan struct{}
// To deal with SSH server host keys
hostKeyCallback *client.HostKeyCallback
}
@@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod,
HostKeyCallback: hostKeyCallback.Wrap(),
Timeout: time.Second * 3,
},
- stop: make(chan struct{}),
}
c.initServerPort(server)
@@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
},
- stop: make(chan struct{}),
isOneOff: true,
}
@@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) {
}
}
-// Start the server connection. Build up SSH session and send some DTail commandc.
-func (c *Connection) Start(throttleCh, statsCh chan struct{}) {
+// Start the server connection. Build up SSH session and send some DTail commands.
+func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
+ // Throttle how many connections can be established concurrently (based on ch length)
select {
- case <-c.stop:
- logger.Info(c.Server, c.port, "Disconnecting client")
+ case throttleCh <- struct{}{}:
+ defer func() { <-throttleCh }()
+ case <-ctx.Done():
return
- default:
}
- // Wait for SSH connection throttler
- throttleCh <- struct{}{}
-
- // Wait until connection has been initiated or an error occured
- // during initialization.
- throttleStopCh := make(chan struct{}, 2)
go func() {
- <-throttleStopCh
- <-throttleCh
- }()
+ defer cancel()
- if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil {
- logger.Warn(c.Server, c.port, err)
- throttleStopCh <- struct{}{}
+ if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil {
+ logger.Warn(c.Server, c.port, err)
- if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
- logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port)
- return
+ if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
+ logger.Debug("Not trusting host", c.Server, c.port)
+ return
+ }
}
- }
+ }()
+
+ <-ctx.Done()
}
// Dail into a new SSH connection. Close connection in case of an error.
-func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error {
+func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error {
statsCh <- struct{}{}
defer func() { <-statsCh }()
@@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st
}
defer client.Close()
- return c.session(client, throttleStopCh)
+ return c.session(ctx, cancel, client)
}
// Create the SSH session. Close the session in case of an error.
-func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error {
+func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error {
logger.Debug(c.Server, "session")
session, err := client.NewSession()
@@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{})
}
defer session.Close()
- return c.handle(session, throttleStopCh)
+ return c.handle(ctx, cancel, session)
}
-// Handle the SSH session. Also send periodic pings to the server in order
-// to determine that session is still intact.
-func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error {
- defer c.Handler.Stop()
-
+func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error {
logger.Debug(c.Server, "handle")
stdinPipe, err := session.StdinPipe()
@@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}
return err
}
- // Establish Bi-directional pipe between SSH session and client handler.
- brokenStdinPipe := make(chan struct{})
go func() {
- defer close(brokenStdinPipe)
+ defer cancel()
io.Copy(stdinPipe, c.Handler)
}()
- brokenStdoutPipe := make(chan struct{})
go func() {
- defer close(brokenStdoutPipe)
+ defer cancel()
io.Copy(c.Handler, stdoutPipe)
}()
- // SSH session established, other goroutine can initiate session now.
- throttleStopCh <- struct{}{}
+ go func() {
+ defer cancel()
+ select {
+ case <-c.Handler.Done():
+ case <-ctx.Done():
+ }
+ }()
// Send all commands to client.
for _, command := range c.Commands {
logger.Debug(command)
- c.Handler.SendCommand(command)
+ c.Handler.SendMessage(command)
}
- if !c.isOneOff {
- return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe)
- }
-
- <-c.stop
-
- // Normal shutdown, all fine
+ <-ctx.Done()
return nil
}
-
-// Periodically check whether connection is still alive or not.
-func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error {
- for {
- select {
- case <-time.After(time.Second * 3):
- if err := c.Handler.Ping(); err != nil {
- return err
- }
- case <-brokenStdinPipe:
- logger.Debug("Broken stdin pipe", c.Server, c.port)
- return nil
- case <-brokenStdoutPipe:
- logger.Debug("Broken stdout pipe", c.Server, c.port)
- return nil
- case <-c.stop:
- return nil
- }
- }
-}
-
-// Stop the connection.
-func (c *Connection) Stop() {
- close(c.stop)
-}