summaryrefslogtreecommitdiff
path: root/internal/clients/baseclient.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/baseclient.go')
-rw-r--r--internal/clients/baseclient.go130
1 files changed, 65 insertions, 65 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 574ae94..b1540ea 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -1,13 +1,14 @@
package clients
import (
+ "context"
"regexp"
"sync"
"time"
"github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/discovery"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/ssh/client"
@@ -27,111 +28,110 @@ type baseClient struct {
sshAuthMethods []gossh.AuthMethod
// To deal with SSH host keys
hostKeyCallback *client.HostKeyCallback
- // To stop the client.
- stop chan struct{}
- // To indicate that the client has stopped.
- stopped chan struct{}
// Throttle how fast we initiate SSH connections concurrently
throttleCh chan struct{}
// Retry connection upon failure?
retry bool
- // Connection helper.
- maker connectionMaker
+ // Connection maker helper.
+ maker maker
}
-func (c *baseClient) init(maker connectionMaker) {
+func (c *baseClient) init(maker maker) {
logger.Info("Initiating base client")
c.maker = maker
- //c.connections = make(map[string]*remote.Connection)
c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh)
+ discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle)
- // Retrieve a shuffled list of remote dtail servers.
- shuffleServers := true
- discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers)
for _, server := range discoveryService.ServerList() {
- c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
+ c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
}
if _, err := regexp.Compile(c.Regex); err != nil {
logger.FatalExit(c.Regex, "Can't test compile regex", err)
}
- // Periodically check for unknown hosts, and ask the user whether to trust them or not.
- go c.hostKeyCallback.PromptAddHosts(c.stop)
-
- // Periodically print out connection stats to the client.
c.stats = newTailStats(len(c.connections))
- go c.stats.periodicLogStats(c.throttleCh, c.stop)
}
-func (c *baseClient) Start() (status int) {
+func (c *baseClient) Start(ctx context.Context) (status int) {
+ // Periodically check for unknown hosts, and ask the user whether to trust them or not.
+ go c.hostKeyCallback.PromptAddHosts(ctx)
+ // Periodically print out connection stats to the client.
+ go c.stats.periodicLogStats(ctx, c.throttleCh)
+ // Keep count of active connections
active := make(chan struct{}, len(c.connections))
- var wg sync.WaitGroup
- wg.Add(len(c.connections))
-
+ var mutex sync.Mutex
for i, conn := range c.connections {
go func(i int, conn *remote.Connection) {
- active <- struct{}{}
- defer func() {
- logger.Debug(conn.Server, "Disconnected completely...")
- <-active
- }()
- wg.Done()
-
- for {
- conn.Start(c.throttleCh, c.stats.connectionsEstCh)
- if !c.retry {
- return
- }
- time.Sleep(time.Second * 2)
- logger.Debug(conn.Server, "Reconencting")
- conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
- c.connections[i] = conn
+ connStatus := c.start(ctx, active, i, conn)
+
+ // Update global status.
+ mutex.Lock()
+ defer mutex.Unlock()
+ if connStatus > status {
+ status = connStatus
}
}(i, conn)
}
- wg.Wait()
- c.waitUntilDone(active)
-
+ c.waitUntilDone(ctx, active)
return
}
-func (c *baseClient) waitUntilDone(active chan struct{}) {
- defer close(c.stopped)
+func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) {
+ // Increment connection count
+ active <- struct{}{}
+ // Derement connection count
+ defer func() { <-active }()
- if c.Mode != omode.TailClient {
- c.waitUntilZero(active)
- logger.Info("All connections stopped")
- return
- }
+ for {
+ connCtx, cancel := conn.Handler.WithCancel(ctx)
+ defer cancel()
- <-c.stop
- logger.Info("Stopping client")
- for _, conn := range c.connections {
- conn.Stop()
+ conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
+ // Retrieve status code from handler (dtail client will exit with that status)
+ status = conn.Handler.Status()
+
+ if !c.retry {
+ return
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Debug(conn.Server, "Reconnecting")
+
+ conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
+ c.connections[i] = conn
}
+}
- c.waitUntilZero(active)
+func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = c.maker.makeHandler(server)
+ conn.Commands = c.maker.makeCommands()
+
+ return conn
}
-func (c *baseClient) waitUntilZero(active chan struct{}) {
+func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
+ defer logger.Info("Terminated connection")
+
+ // We want to have at least one active connection
+ <-active
+ // Put it back on the channel
+ active <- struct{}{}
+
+ if c.Mode == omode.TailClient {
+ <-ctx.Done()
+ }
+
for {
- logger.Debug("Active connections", len(active))
- if len(active) == 0 {
+ numActive := len(active)
+ if numActive == 0 {
return
}
+ logger.Debug("Active connections", numActive)
time.Sleep(time.Second)
}
}
-
-func (c *baseClient) Stop() {
- close(c.stop)
- <-c.WaitC()
-}
-
-func (c *baseClient) WaitC() <-chan struct{} {
- return c.stopped
-}