diff options
Diffstat (limited to 'internal/clients/baseclient.go')
| -rw-r--r-- | internal/clients/baseclient.go | 130 |
1 files changed, 65 insertions, 65 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 574ae94..b1540ea 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -1,13 +1,14 @@ package clients import ( + "context" "regexp" "sync" "time" "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/ssh/client" @@ -27,111 +28,110 @@ type baseClient struct { sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys hostKeyCallback *client.HostKeyCallback - // To stop the client. - stop chan struct{} - // To indicate that the client has stopped. - stopped chan struct{} // Throttle how fast we initiate SSH connections concurrently throttleCh chan struct{} // Retry connection upon failure? retry bool - // Connection helper. - maker connectionMaker + // Connection maker helper. + maker maker } -func (c *baseClient) init(maker connectionMaker) { +func (c *baseClient) init(maker maker) { logger.Info("Initiating base client") c.maker = maker - //c.connections = make(map[string]*remote.Connection) c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) - // Retrieve a shuffled list of remote dtail servers. - shuffleServers := true - discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) } if _, err := regexp.Compile(c.Regex); err != nil { logger.FatalExit(c.Regex, "Can't test compile regex", err) } - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(c.stop) - - // Periodically print out connection stats to the client. c.stats = newTailStats(len(c.connections)) - go c.stats.periodicLogStats(c.throttleCh, c.stop) } -func (c *baseClient) Start() (status int) { +func (c *baseClient) Start(ctx context.Context) (status int) { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + // Periodically print out connection stats to the client. + go c.stats.periodicLogStats(ctx, c.throttleCh) + // Keep count of active connections active := make(chan struct{}, len(c.connections)) - var wg sync.WaitGroup - wg.Add(len(c.connections)) - + var mutex sync.Mutex for i, conn := range c.connections { go func(i int, conn *remote.Connection) { - active <- struct{}{} - defer func() { - logger.Debug(conn.Server, "Disconnected completely...") - <-active - }() - wg.Done() - - for { - conn.Start(c.throttleCh, c.stats.connectionsEstCh) - if !c.retry { - return - } - time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconencting") - conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + connStatus := c.start(ctx, active, i, conn) + + // Update global status. + mutex.Lock() + defer mutex.Unlock() + if connStatus > status { + status = connStatus } }(i, conn) } - wg.Wait() - c.waitUntilDone(active) - + c.waitUntilDone(ctx, active) return } -func (c *baseClient) waitUntilDone(active chan struct{}) { - defer close(c.stopped) +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { + // Increment connection count + active <- struct{}{} + // Derement connection count + defer func() { <-active }() - if c.Mode != omode.TailClient { - c.waitUntilZero(active) - logger.Info("All connections stopped") - return - } + for { + connCtx, cancel := conn.Handler.WithCancel(ctx) + defer cancel() - <-c.stop - logger.Info("Stopping client") - for _, conn := range c.connections { - conn.Stop() + conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) + // Retrieve status code from handler (dtail client will exit with that status) + status = conn.Handler.Status() + + if !c.retry { + return + } + + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconnecting") + + conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn } +} - c.waitUntilZero(active) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = c.maker.makeHandler(server) + conn.Commands = c.maker.makeCommands() + + return conn } -func (c *baseClient) waitUntilZero(active chan struct{}) { +func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { + defer logger.Info("Terminated connection") + + // We want to have at least one active connection + <-active + // Put it back on the channel + active <- struct{}{} + + if c.Mode == omode.TailClient { + <-ctx.Done() + } + for { - logger.Debug("Active connections", len(active)) - if len(active) == 0 { + numActive := len(active) + if numActive == 0 { return } + logger.Debug("Active connections", numActive) time.Sleep(time.Second) } } - -func (c *baseClient) Stop() { - close(c.stop) - <-c.WaitC() -} - -func (c *baseClient) WaitC() <-chan struct{} { - return c.stopped -} |
