diff options
Diffstat (limited to 'internal/clients')
| -rw-r--r-- | internal/clients/args.go | 3 | ||||
| -rw-r--r-- | internal/clients/baseclient.go | 130 | ||||
| -rw-r--r-- | internal/clients/catclient.go | 20 | ||||
| -rw-r--r-- | internal/clients/client.go | 5 | ||||
| -rw-r--r-- | internal/clients/connectionmaker.go | 12 | ||||
| -rw-r--r-- | internal/clients/execclient.go | 48 | ||||
| -rw-r--r-- | internal/clients/grepclient.go | 20 | ||||
| -rw-r--r-- | internal/clients/handlers/basehandler.go | 84 | ||||
| -rw-r--r-- | internal/clients/handlers/clienthandler.go | 11 | ||||
| -rw-r--r-- | internal/clients/handlers/handler.go | 12 | ||||
| -rw-r--r-- | internal/clients/handlers/healthhandler.go | 21 | ||||
| -rw-r--r-- | internal/clients/handlers/maprhandler.go | 21 | ||||
| -rw-r--r-- | internal/clients/handlers/withcancel.go | 24 | ||||
| -rw-r--r-- | internal/clients/healthclient.go | 7 | ||||
| -rw-r--r-- | internal/clients/maker.go | 8 | ||||
| -rw-r--r-- | internal/clients/maprclient.go | 52 | ||||
| -rw-r--r-- | internal/clients/remote/connection.go | 116 | ||||
| -rw-r--r-- | internal/clients/runclient.go | 40 | ||||
| -rw-r--r-- | internal/clients/stats.go | 8 | ||||
| -rw-r--r-- | internal/clients/tailclient.go | 21 |
20 files changed, 306 insertions, 357 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go index 5fe0a72..dea5a9e 100644 --- a/internal/clients/args.go +++ b/internal/clients/args.go @@ -9,10 +9,9 @@ type Args struct { Mode omode.Mode ServersStr string UserName string - Files string + What string Regex string TrustAllHosts bool Discovery string ConnectionsPerCPU int - PingTimeout int } diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 574ae94..b1540ea 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -1,13 +1,14 @@ package clients import ( + "context" "regexp" "sync" "time" "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/ssh/client" @@ -27,111 +28,110 @@ type baseClient struct { sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys hostKeyCallback *client.HostKeyCallback - // To stop the client. - stop chan struct{} - // To indicate that the client has stopped. - stopped chan struct{} // Throttle how fast we initiate SSH connections concurrently throttleCh chan struct{} // Retry connection upon failure? retry bool - // Connection helper. - maker connectionMaker + // Connection maker helper. + maker maker } -func (c *baseClient) init(maker connectionMaker) { +func (c *baseClient) init(maker maker) { logger.Info("Initiating base client") c.maker = maker - //c.connections = make(map[string]*remote.Connection) c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) - // Retrieve a shuffled list of remote dtail servers. - shuffleServers := true - discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) } if _, err := regexp.Compile(c.Regex); err != nil { logger.FatalExit(c.Regex, "Can't test compile regex", err) } - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(c.stop) - - // Periodically print out connection stats to the client. c.stats = newTailStats(len(c.connections)) - go c.stats.periodicLogStats(c.throttleCh, c.stop) } -func (c *baseClient) Start() (status int) { +func (c *baseClient) Start(ctx context.Context) (status int) { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + // Periodically print out connection stats to the client. + go c.stats.periodicLogStats(ctx, c.throttleCh) + // Keep count of active connections active := make(chan struct{}, len(c.connections)) - var wg sync.WaitGroup - wg.Add(len(c.connections)) - + var mutex sync.Mutex for i, conn := range c.connections { go func(i int, conn *remote.Connection) { - active <- struct{}{} - defer func() { - logger.Debug(conn.Server, "Disconnected completely...") - <-active - }() - wg.Done() - - for { - conn.Start(c.throttleCh, c.stats.connectionsEstCh) - if !c.retry { - return - } - time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconencting") - conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + connStatus := c.start(ctx, active, i, conn) + + // Update global status. + mutex.Lock() + defer mutex.Unlock() + if connStatus > status { + status = connStatus } }(i, conn) } - wg.Wait() - c.waitUntilDone(active) - + c.waitUntilDone(ctx, active) return } -func (c *baseClient) waitUntilDone(active chan struct{}) { - defer close(c.stopped) +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { + // Increment connection count + active <- struct{}{} + // Derement connection count + defer func() { <-active }() - if c.Mode != omode.TailClient { - c.waitUntilZero(active) - logger.Info("All connections stopped") - return - } + for { + connCtx, cancel := conn.Handler.WithCancel(ctx) + defer cancel() - <-c.stop - logger.Info("Stopping client") - for _, conn := range c.connections { - conn.Stop() + conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) + // Retrieve status code from handler (dtail client will exit with that status) + status = conn.Handler.Status() + + if !c.retry { + return + } + + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconnecting") + + conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn } +} - c.waitUntilZero(active) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = c.maker.makeHandler(server) + conn.Commands = c.maker.makeCommands() + + return conn } -func (c *baseClient) waitUntilZero(active chan struct{}) { +func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { + defer logger.Info("Terminated connection") + + // We want to have at least one active connection + <-active + // Put it back on the channel + active <- struct{}{} + + if c.Mode == omode.TailClient { + <-ctx.Done() + } + for { - logger.Debug("Active connections", len(active)) - if len(active) == 0 { + numActive := len(active) + if numActive == 0 { return } + logger.Debug("Active connections", numActive) time.Sleep(time.Second) } } - -func (c *baseClient) Stop() { - close(c.stop) - <-c.WaitC() -} - -func (c *baseClient) WaitC() <-chan struct{} { - return c.stopped -} diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go index 5ea701d..7fd6bdc 100644 --- a/internal/clients/catclient.go +++ b/internal/clients/catclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // CatClient is a client for returning a whole file from the beginning to the end. @@ -31,8 +27,6 @@ func NewCatClient(args Args) (*CatClient, error) { c := CatClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -43,11 +37,13 @@ func NewCatClient(args Args) (*CatClient, error) { return &c, nil } -func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c CatClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c CatClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - return conn + return } diff --git a/internal/clients/client.go b/internal/clients/client.go index 85d1aae..1fc5e23 100644 --- a/internal/clients/client.go +++ b/internal/clients/client.go @@ -1,7 +1,8 @@ package clients +import "context" + // Client is the interface for the end user command line client. type Client interface { - Start() int - Stop() + Start(ctx context.Context) int } diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go deleted file mode 100644 index 0617992..0000000 --- a/internal/clients/connectionmaker.go +++ /dev/null @@ -1,12 +0,0 @@ -package clients - -import ( - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -type connectionMaker interface { - makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection -} diff --git a/internal/clients/execclient.go b/internal/clients/execclient.go deleted file mode 100644 index 10bd081..0000000 --- a/internal/clients/execclient.go +++ /dev/null @@ -1,48 +0,0 @@ -package clients - -import ( - "fmt" - "runtime" - "strings" - - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -// ExecClient is a client for execute various commands on the server. -type ExecClient struct { - baseClient -} - -// NewExecClient returns a new cat client. -func NewExecClient(args Args) (*ExecClient, error) { - args.Regex = "." - args.Mode = omode.ExecClient - - c := ExecClient{ - baseClient: baseClient{ - Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), - throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), - retry: false, - }, - } - - c.init(c) - - return &c, nil -} - -func (c ExecClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ";") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s", c.Mode.String(), file)) - } - return conn -} diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go index c568f63..8d11458 100644 --- a/internal/clients/grepclient.go +++ b/internal/clients/grepclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. @@ -29,8 +25,6 @@ func NewGrepClient(args Args) (*GrepClient, error) { c := GrepClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -41,13 +35,13 @@ func NewGrepClient(args Args) (*GrepClient, error) { return &c, nil } -func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c GrepClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c GrepClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 19246f9..68b8ddc 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -1,60 +1,44 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" - "errors" + "encoding/base64" "fmt" "io" + "strconv" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/version" ) type baseHandler struct { + withCancel server string shellStarted bool commands chan string - pong chan struct{} receiveBuf []byte - stop chan struct{} - pingTimeout int + status int } func (h *baseHandler) Server() string { return h.server } -// Used to determine whether server is still responding to requests or not. -func (h *baseHandler) Ping() error { - if h.pingTimeout == 0 { - // Server ping disabled - return nil - } - - if err := h.SendCommand("ping"); err != nil { - return err - } - - select { - case <-h.pong: - return nil - case <-time.After(time.Duration(h.pingTimeout) * time.Second): - } - - return errors.New("Didn't receive any server pongs (ping replies)") +func (h *baseHandler) Status() int { + return h.status } -func (h *baseHandler) SendCommand(command string) error { - if command == "ping" { - logger.Trace("Sending command", h.server, command) - } else { - logger.Debug("Sending command", h.server, command) - } +// SendMessage to the server. +func (h *baseHandler) SendMessage(command string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(command)) + logger.Debug("Sending command", h.server, command, encoded) select { - case h.commands <- fmt.Sprintf("%s;", command): + case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): case <-time.After(time.Second * 5): - return errors.New("Timed out sending command " + command) - case <-h.stop: + return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) + case <-h.ctx.Done(): } return nil @@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { select { case command := <-h.commands: n = copy(p, []byte(command)) - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } return @@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) { if len(h.receiveBuf) == 0 { return } + // Hidden server commands starti with a dot "." if h.receiveBuf[0] == '.' { h.handleHiddenMessage(message) @@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) { h.receiveBuf = h.receiveBuf[:0] return } + logger.Raw(message) h.receiveBuf = h.receiveBuf[:0] } @@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) { // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { switch { - case strings.HasPrefix(message, ".pong"): - h.pong <- struct{}{} case strings.HasPrefix(message, ".syn close connection"): - h.SendCommand("ack close connection") - } -} + h.SendMessage(".ack close connection") + select { + case <-time.After(time.Second * 1): + logger.Debug("Shutting down client after timeout and sending ack to server") + h.withCancel.shutdown() + case <-h.ctx.Done(): + } -// Stop the handler. -func (h *baseHandler) Stop() { - select { - case <-h.stop: - default: - logger.Debug("Stopping base handler", h.server) - close(h.stop) + case strings.HasPrefix(message, ".run exitstatus"): + splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ") + if len(splitted) != 3 { + logger.Error("Unable to retrieve exitstatus", message) + return + } + i, err := strconv.Atoi(splitted[2]) + if err != nil { + logger.Error("Unable to retrieve exitstatus", message, err) + return + } + h.status = i + logger.Debug("Retrieved exitstatus", h.status) } } diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 4738cd3..fcd8052 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -1,7 +1,7 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) // ClientHandler is the basic client handler interface. @@ -10,7 +10,7 @@ type ClientHandler struct { } // NewClientHandler creates a new client handler. -func NewClientHandler(server string, pingTimeout int) *ClientHandler { +func NewClientHandler(server string) *ClientHandler { logger.Debug(server, "Creating new client handler") return &ClientHandler{ @@ -18,9 +18,10 @@ func NewClientHandler(server string, pingTimeout int) *ClientHandler { server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, } } diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go index 2013be0..c53ca34 100644 --- a/internal/clients/handlers/handler.go +++ b/internal/clients/handlers/handler.go @@ -1,12 +1,16 @@ package handlers -import "io" +import ( + "context" + "io" +) // Handler provides all methods which can be run on any client handler. type Handler interface { io.ReadWriter - Ping() error - Stop() - SendCommand(command string) error + SendMessage(command string) error Server() string + Status() int + WithCancel(ctx context.Context) (context.Context, context.CancelFunc) + Done() <-chan struct{} } diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index 4051e2c..9051015 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -8,6 +8,7 @@ import ( // HealthHandler implements the handler required for health checks. type HealthHandler struct { + withCancel // Buffer of incoming data from server. receiveBuf []byte // To send commands to the server. @@ -16,6 +17,7 @@ type HealthHandler struct { receive chan<- string // The remote server address server string + status int } // NewHealthHandler returns a new health check handler. @@ -24,6 +26,10 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler { server: server, receive: receive, commands: make(chan string), + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, } return &h @@ -34,18 +40,13 @@ func (h *HealthHandler) Server() string { return h.server } -// Stop is not of use for health check handler. -func (h *HealthHandler) Stop() { - // Nothing done here. +// Status of the handler. +func (h *HealthHandler) Status() int { + return h.status } -// Ping is not of use for health check handler. -func (h *HealthHandler) Ping() error { - return nil -} - -// SendCommand send a DTail command to the server. -func (h *HealthHandler) SendCommand(command string) error { +// SendMessage sends a DTail command to the server. +func (h *HealthHandler) SendMessage(command string) error { select { case h.commands <- fmt.Sprintf("%s;", command): case <-time.NewTimer(time.Second * 10).C: diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index d76cdfd..874bb7d 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -1,10 +1,11 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/client" - "strings" ) // MaprHandler is the handler used on the client side for running mapreduce aggregations. @@ -16,15 +17,16 @@ type MaprHandler struct { } // NewMaprHandler returns a new mapreduce client handler. -func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler { +func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler { return &MaprHandler{ baseHandler: baseHandler{ server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, query: query, aggregate: client.NewAggregate(server, query, globalGroup), @@ -65,10 +67,3 @@ func (h *MaprHandler) handleAggregateMessage(message string) { h.aggregate.Aggregate(parts[2:]) logger.Debug("Aggregated aggregate data", h.server, h.count) } - -// Stop stops the mapreduce client handler. -func (h *MaprHandler) Stop() { - logger.Debug("Stopping mapreduce handler", h.server) - h.aggregate.Stop() - h.baseHandler.Stop() -} diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go new file mode 100644 index 0000000..7c9cf4e --- /dev/null +++ b/internal/clients/handlers/withcancel.go @@ -0,0 +1,24 @@ +package handlers + +import "context" + +type withCancel struct { + ctx context.Context + done chan struct{} +} + +// WithCancel sets and returns the context used. +func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) { + cancelCtx, cancel := context.WithCancel(ctx) + w.ctx = cancelCtx + + return cancelCtx, cancel +} + +func (w *withCancel) Done() <-chan struct{} { + return w.done +} + +func (w *withCancel) shutdown() { + close(w.done) +} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index ff13b83..7313583 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "fmt" "runtime" "strings" @@ -39,7 +40,7 @@ func NewHealthClient(mode omode.Mode) (*HealthClient, error) { } // Start the health client. -func (c *HealthClient) Start() (status int) { +func (c *HealthClient) Start(ctx context.Context) (status int) { receive := make(chan string) throttleCh := make(chan struct{}, runtime.NumCPU()) @@ -49,8 +50,8 @@ func (c *HealthClient) Start() (status int) { conn.Handler = handlers.NewHealthHandler(c.server, receive) conn.Commands = []string{c.mode.String()} - go conn.Start(throttleCh, statsCh) - defer conn.Stop() + connCtx, cancel := conn.Handler.WithCancel(ctx) + go conn.Start(connCtx, cancel, throttleCh, statsCh) for { select { diff --git a/internal/clients/maker.go b/internal/clients/maker.go new file mode 100644 index 0000000..da9dfc9 --- /dev/null +++ b/internal/clients/maker.go @@ -0,0 +1,8 @@ +package clients + +import "github.com/mimecast/dtail/internal/clients/handlers" + +type maker interface { + makeHandler(server string) handlers.Handler + makeCommands() (commands []string) +} diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index 9070827..b581844 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "errors" "fmt" "runtime" @@ -8,13 +9,9 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // MaprClient is used for running mapreduce aggregations on remote files. @@ -39,8 +36,6 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { c := MaprClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: args.Mode == omode.TailClient, }, @@ -70,35 +65,36 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { return &c, nil } -func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout) +// Start starts the mapreduce client. +func (c *MaprClient) Start(ctx context.Context) (status int) { + if c.query.Outfile == "" { + // Only print out periodic results if we don't write an outfile + go c.periodicPrintResults(ctx) + } - conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery)) - commandStr := "tail" + status = c.baseClient.Start(ctx) if c.additative { - commandStr = "cat" + c.recievedFinalResult() } - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex)) - } + return +} - return conn +func (c MaprClient) makeHandler(server string) handlers.Handler { + return handlers.NewMaprHandler(server, c.query, c.globalGroup) } -// Start starts the mapreduce client. -func (c *MaprClient) Start() (status int) { - if c.query.Outfile == "" { - // Only print out periodic results if we don't write an outfile - go c.periodicPrintResults() - } +func (c MaprClient) makeCommands() (commands []string) { + commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery)) - status = c.baseClient.Start() + modeStr := "tail" if c.additative { - c.recievedFinalResult() + modeStr = "cat" + } + + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex)) } - c.baseClient.Stop() return } @@ -120,13 +116,13 @@ func (c *MaprClient) recievedFinalResult() { logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile)) } -func (c *MaprClient) periodicPrintResults() { +func (c *MaprClient) periodicPrintResults(ctx context.Context) { for { select { case <-time.After(c.query.Interval): logger.Info("Gathering interim mapreduce result") c.printResults() - case <-c.baseClient.stop: + case <-ctx.Done(): return } } diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index bfc7bc5..71639b1 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -1,16 +1,18 @@ package remote import ( - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/ssh/client" + "context" "fmt" "io" "strconv" "strings" "time" + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/ssh/client" + "golang.org/x/crypto/ssh" ) @@ -30,8 +32,6 @@ type Connection struct { Commands []string // Is it a persistent connection or a one-off? isOneOff bool - // Used to stop the connection - stop chan struct{} // To deal with SSH server host keys hostKeyCallback *client.HostKeyCallback } @@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, HostKeyCallback: hostKeyCallback.Wrap(), Timeout: time.Second * 3, }, - stop: make(chan struct{}), } c.initServerPort(server) @@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }, - stop: make(chan struct{}), isOneOff: true, } @@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) { } } -// Start the server connection. Build up SSH session and send some DTail commandc. -func (c *Connection) Start(throttleCh, statsCh chan struct{}) { +// Start the server connection. Build up SSH session and send some DTail commands. +func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { + // Throttle how many connections can be established concurrently (based on ch length) select { - case <-c.stop: - logger.Info(c.Server, c.port, "Disconnecting client") + case throttleCh <- struct{}{}: + defer func() { <-throttleCh }() + case <-ctx.Done(): return - default: } - // Wait for SSH connection throttler - throttleCh <- struct{}{} - - // Wait until connection has been initiated or an error occured - // during initialization. - throttleStopCh := make(chan struct{}, 2) go func() { - <-throttleStopCh - <-throttleCh - }() + defer cancel() - if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - throttleStopCh <- struct{}{} + if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil { + logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port) - return + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { + logger.Debug("Not trusting host", c.Server, c.port) + return + } } - } + }() + + <-ctx.Done() } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error { +func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error { statsCh <- struct{}{} defer func() { <-statsCh }() @@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st } defer client.Close() - return c.session(client, throttleStopCh) + return c.session(ctx, cancel, client) } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { +func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error { logger.Debug(c.Server, "session") session, err := client.NewSession() @@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) } defer session.Close() - return c.handle(session, throttleStopCh) + return c.handle(ctx, cancel, session) } -// Handle the SSH session. Also send periodic pings to the server in order -// to determine that session is still intact. -func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { - defer c.Handler.Stop() - +func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() @@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{} return err } - // Establish Bi-directional pipe between SSH session and client handler. - brokenStdinPipe := make(chan struct{}) go func() { - defer close(brokenStdinPipe) + defer cancel() io.Copy(stdinPipe, c.Handler) }() - brokenStdoutPipe := make(chan struct{}) go func() { - defer close(brokenStdoutPipe) + defer cancel() io.Copy(c.Handler, stdoutPipe) }() - // SSH session established, other goroutine can initiate session now. - throttleStopCh <- struct{}{} + go func() { + defer cancel() + select { + case <-c.Handler.Done(): + case <-ctx.Done(): + } + }() // Send all commands to client. for _, command := range c.Commands { logger.Debug(command) - c.Handler.SendCommand(command) + c.Handler.SendMessage(command) } - if !c.isOneOff { - return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) - } - - <-c.stop - - // Normal shutdown, all fine + <-ctx.Done() return nil } - -// Periodically check whether connection is still alive or not. -func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { - for { - select { - case <-time.After(time.Second * 3): - if err := c.Handler.Ping(); err != nil { - return err - } - case <-brokenStdinPipe: - logger.Debug("Broken stdin pipe", c.Server, c.port) - return nil - case <-brokenStdoutPipe: - logger.Debug("Broken stdout pipe", c.Server, c.port) - return nil - case <-c.stop: - return nil - } - } -} - -// Stop the connection. -func (c *Connection) Stop() { - close(c.stop) -} diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go new file mode 100644 index 0000000..7a62fcc --- /dev/null +++ b/internal/clients/runclient.go @@ -0,0 +1,40 @@ +package clients + +import ( + "fmt" + "runtime" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/omode" +) + +// RunClient is a client to run various commands on the server. +type RunClient struct { + baseClient +} + +// NewRunClient returns a new cat client. +func NewRunClient(args Args) (*RunClient, error) { + args.Mode = omode.RunClient + + c := RunClient{ + baseClient: baseClient{ + Args: args, + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, + } + + c.init(c) + return &c, nil +} + +func (c RunClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c RunClient) makeCommands() (commands []string) { + // Send "run COMMAND" to server! + commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What)) + return +} diff --git a/internal/clients/stats.go b/internal/clients/stats.go index d36cef6..ec6adfe 100644 --- a/internal/clients/stats.go +++ b/internal/clients/stats.go @@ -1,11 +1,13 @@ package clients import ( - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various client stats. @@ -28,14 +30,14 @@ func newTailStats(connectionsTotal int) *stats { } } -func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) { +func (s *stats) periodicLogStats(ctx context.Context, throttleCh chan struct{}) { connectedLast := 0 statsInterval := 5 for { select { case <-time.After(time.Second * time.Duration(statsInterval)): - case <-stop: + case <-ctx.Done(): return } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 674ca36..4d81fd5 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -6,11 +6,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines). @@ -25,25 +21,22 @@ func NewTailClient(args Args) (*TailClient, error) { c := TailClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: true, }, } c.init(c) - return &c, nil } -func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c TailClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c TailClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } |
