summaryrefslogtreecommitdiff
path: root/internal/clients/baseclient.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients/baseclient.go')
-rw-r--r--internal/clients/baseclient.go131
1 files changed, 40 insertions, 91 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index f83fcfd..4a7bd84 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -2,15 +2,13 @@ package clients
import (
"context"
- "fmt"
- "strings"
"sync"
"time"
- "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/clients/connectors"
+ "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/discovery"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/regex"
"github.com/mimecast/dtail/internal/ssh/client"
@@ -19,13 +17,13 @@ import (
// This is the main client data structure.
type baseClient struct {
- Args
+ config.Args
// To display client side stats
stats *stats
// List of remote servers to connect to.
servers []string
// We have one connection per remote server.
- connections []*remote.Connection
+ connections []connectors.Connector
// SSH auth methods to use to connect to the remote servers.
sshAuthMethods []gossh.AuthMethod
// To deal with SSH host keys
@@ -41,7 +39,7 @@ type baseClient struct {
}
func (c *baseClient) init() {
- logger.Debug("Initiating base client")
+ dlog.Client.Debug("Initiating base client", c.Args.String())
flag := regex.Default
if c.Args.RegexInvert {
@@ -49,12 +47,16 @@ func (c *baseClient) init() {
}
regex, err := regex.New(c.Args.RegexStr, flag)
if err != nil {
- logger.FatalExit(c.Regex, "invalid regex!", err, regex)
+ dlog.Client.FatalPanic(c.Regex, "Invalid regex!", err, regex)
}
c.Regex = regex
- logger.Debug("Regex", c.Regex)
- c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, c.throttleCh, c.Args.PrivateKeyPathFile)
+ if c.Args.Serverless {
+ return
+ }
+ c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(
+ c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts,
+ c.throttleCh, c.Args.PrivateKeyPathFile)
}
func (c *baseClient) makeConnections(maker maker) {
@@ -62,26 +64,31 @@ func (c *baseClient) makeConnections(maker maker) {
discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle)
for _, server := range discoveryService.ServerList() {
- c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
+ c.connections = append(c.connections, c.makeConnection(server,
+ c.sshAuthMethods, c.hostKeyCallback))
}
c.stats = newTailStats(len(c.connections))
}
func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) {
- // Periodically check for unknown hosts, and ask the user whether to trust them or not.
- go c.hostKeyCallback.PromptAddHosts(ctx)
+ dlog.Client.Trace("Starting base client")
+ // Can be nil when serverless.
+ if c.hostKeyCallback != nil {
+ // Periodically check for unknown hosts, and ask the user whether to trust them or not.
+ go c.hostKeyCallback.PromptAddHosts(ctx)
+ }
// Print client stats every time something on statsCh is recieved.
go c.stats.Start(ctx, c.throttleCh, statsCh, c.Args.Quiet)
- // Keep count of active connections
- active := make(chan struct{}, len(c.connections))
+ var wg sync.WaitGroup
+ wg.Add(len(c.connections))
var mutex sync.Mutex
- for i, conn := range c.connections {
- go func(i int, conn *remote.Connection) {
- connStatus := c.start(ctx, active, i, conn)
- // Update global status.
+ for i, conn := range c.connections {
+ go func(i int, conn connectors.Connector) {
+ defer wg.Done()
+ connStatus := c.startConnection(ctx, i, conn)
mutex.Lock()
defer mutex.Unlock()
if connStatus > status {
@@ -90,15 +97,12 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
}(i, conn)
}
- c.waitUntilDone(ctx, active)
+ wg.Wait()
return
}
-func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) {
- // Increment connection count
- active <- struct{}{}
- // Derement connection count
- defer func() { <-active }()
+func (c *baseClient) startConnection(ctx context.Context, i int,
+ conn connectors.Connector) (status int) {
for {
connCtx, cancel := context.WithCancel(ctx)
@@ -106,80 +110,25 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con
conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
// Retrieve status code from handler (dtail client will exit with that status)
- status = conn.Handler.Status()
+ status = conn.Handler().Status()
if !c.retry {
return
}
time.Sleep(time.Second * 2)
- logger.Debug(conn.Server, "Reconnecting")
-
- conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
+ dlog.Client.Debug(conn.Server(), "Reconnecting")
+ conn = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback)
c.connections[i] = conn
}
}
-func (c *baseClient) makeCommandOptions() map[string]string {
- options := make(map[string]string)
-
- if c.Args.Quiet {
- options["quiet"] = fmt.Sprintf("%v", c.Args.Quiet)
- }
- if c.Args.LContext.MaxCount != 0 {
- options["max"] = fmt.Sprintf("%d", c.Args.LContext.MaxCount)
- }
- if c.Args.LContext.BeforeContext != 0 {
- options["before"] = fmt.Sprintf("%d", c.Args.LContext.BeforeContext)
- }
- if c.Args.LContext.AfterContext != 0 {
- options["after"] = fmt.Sprintf("%d", c.Args.LContext.AfterContext)
- }
-
- return options
-}
-
-func (c *baseClient) commandOptionsToString(options map[string]string) string {
- var sb strings.Builder
-
- count := 0
- for k, v := range options {
- if count > 0 {
- sb.WriteString(":")
- }
- sb.WriteString(fmt.Sprintf("%s=%s", k, v))
- count++
- }
-
- return sb.String()
-}
-
-func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = c.maker.makeHandler(server)
- conn.Commands = c.maker.makeCommands(c.makeCommandOptions())
-
- return conn
-}
-
-func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
- defer logger.Debug("Terminated connection")
-
- // We want to have at least one active connection
- <-active
- // Put it back on the channel
- active <- struct{}{}
-
- if c.Mode == omode.TailClient && c.retry {
- <-ctx.Done()
- }
-
- for {
- numActive := len(active)
- if numActive == 0 {
- return
- }
- logger.Debug("Active connections", numActive)
- time.Sleep(time.Second)
+func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod,
+ hostKeyCallback client.HostKeyCallback) connectors.Connector {
+ if c.Args.Serverless {
+ return connectors.NewServerless(c.UserName, c.maker.makeHandler(server),
+ c.maker.makeCommands())
}
+ return connectors.NewServerConnection(server, c.UserName, sshAuthMethods,
+ hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands())
}