diff options
Diffstat (limited to 'internal/clients/baseclient.go')
| -rw-r--r-- | internal/clients/baseclient.go | 131 |
1 files changed, 40 insertions, 91 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index f83fcfd..4a7bd84 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -2,15 +2,13 @@ package clients import ( "context" - "fmt" - "strings" "sync" "time" - "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/clients/connectors" + "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/regex" "github.com/mimecast/dtail/internal/ssh/client" @@ -19,13 +17,13 @@ import ( // This is the main client data structure. type baseClient struct { - Args + config.Args // To display client side stats stats *stats // List of remote servers to connect to. servers []string // We have one connection per remote server. - connections []*remote.Connection + connections []connectors.Connector // SSH auth methods to use to connect to the remote servers. sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys @@ -41,7 +39,7 @@ type baseClient struct { } func (c *baseClient) init() { - logger.Debug("Initiating base client") + dlog.Client.Debug("Initiating base client", c.Args.String()) flag := regex.Default if c.Args.RegexInvert { @@ -49,12 +47,16 @@ func (c *baseClient) init() { } regex, err := regex.New(c.Args.RegexStr, flag) if err != nil { - logger.FatalExit(c.Regex, "invalid regex!", err, regex) + dlog.Client.FatalPanic(c.Regex, "Invalid regex!", err, regex) } c.Regex = regex - logger.Debug("Regex", c.Regex) - c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, c.throttleCh, c.Args.PrivateKeyPathFile) + if c.Args.Serverless { + return + } + c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods( + c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, + c.throttleCh, c.Args.PrivateKeyPathFile) } func (c *baseClient) makeConnections(maker maker) { @@ -62,26 +64,31 @@ func (c *baseClient) makeConnections(maker maker) { discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, + c.sshAuthMethods, c.hostKeyCallback)) } c.stats = newTailStats(len(c.connections)) } func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) { - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(ctx) + dlog.Client.Trace("Starting base client") + // Can be nil when serverless. + if c.hostKeyCallback != nil { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + } // Print client stats every time something on statsCh is recieved. go c.stats.Start(ctx, c.throttleCh, statsCh, c.Args.Quiet) - // Keep count of active connections - active := make(chan struct{}, len(c.connections)) + var wg sync.WaitGroup + wg.Add(len(c.connections)) var mutex sync.Mutex - for i, conn := range c.connections { - go func(i int, conn *remote.Connection) { - connStatus := c.start(ctx, active, i, conn) - // Update global status. + for i, conn := range c.connections { + go func(i int, conn connectors.Connector) { + defer wg.Done() + connStatus := c.startConnection(ctx, i, conn) mutex.Lock() defer mutex.Unlock() if connStatus > status { @@ -90,15 +97,12 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i }(i, conn) } - c.waitUntilDone(ctx, active) + wg.Wait() return } -func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { - // Increment connection count - active <- struct{}{} - // Derement connection count - defer func() { <-active }() +func (c *baseClient) startConnection(ctx context.Context, i int, + conn connectors.Connector) (status int) { for { connCtx, cancel := context.WithCancel(ctx) @@ -106,80 +110,25 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) // Retrieve status code from handler (dtail client will exit with that status) - status = conn.Handler.Status() + status = conn.Handler().Status() if !c.retry { return } time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconnecting") - - conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + dlog.Client.Debug(conn.Server(), "Reconnecting") + conn = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback) c.connections[i] = conn } } -func (c *baseClient) makeCommandOptions() map[string]string { - options := make(map[string]string) - - if c.Args.Quiet { - options["quiet"] = fmt.Sprintf("%v", c.Args.Quiet) - } - if c.Args.LContext.MaxCount != 0 { - options["max"] = fmt.Sprintf("%d", c.Args.LContext.MaxCount) - } - if c.Args.LContext.BeforeContext != 0 { - options["before"] = fmt.Sprintf("%d", c.Args.LContext.BeforeContext) - } - if c.Args.LContext.AfterContext != 0 { - options["after"] = fmt.Sprintf("%d", c.Args.LContext.AfterContext) - } - - return options -} - -func (c *baseClient) commandOptionsToString(options map[string]string) string { - var sb strings.Builder - - count := 0 - for k, v := range options { - if count > 0 { - sb.WriteString(":") - } - sb.WriteString(fmt.Sprintf("%s=%s", k, v)) - count++ - } - - return sb.String() -} - -func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = c.maker.makeHandler(server) - conn.Commands = c.maker.makeCommands(c.makeCommandOptions()) - - return conn -} - -func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { - defer logger.Debug("Terminated connection") - - // We want to have at least one active connection - <-active - // Put it back on the channel - active <- struct{}{} - - if c.Mode == omode.TailClient && c.retry { - <-ctx.Done() - } - - for { - numActive := len(active) - if numActive == 0 { - return - } - logger.Debug("Active connections", numActive) - time.Sleep(time.Second) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, + hostKeyCallback client.HostKeyCallback) connectors.Connector { + if c.Args.Serverless { + return connectors.NewServerless(c.UserName, c.maker.makeHandler(server), + c.maker.makeCommands()) } + return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, + hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands()) } |
