summaryrefslogtreecommitdiff
path: root/internal/clients
diff options
context:
space:
mode:
Diffstat (limited to 'internal/clients')
-rw-r--r--internal/clients/args.go3
-rw-r--r--internal/clients/baseclient.go130
-rw-r--r--internal/clients/catclient.go20
-rw-r--r--internal/clients/client.go5
-rw-r--r--internal/clients/connectionmaker.go12
-rw-r--r--internal/clients/execclient.go48
-rw-r--r--internal/clients/grepclient.go20
-rw-r--r--internal/clients/handlers/basehandler.go84
-rw-r--r--internal/clients/handlers/clienthandler.go11
-rw-r--r--internal/clients/handlers/handler.go12
-rw-r--r--internal/clients/handlers/healthhandler.go21
-rw-r--r--internal/clients/handlers/maprhandler.go21
-rw-r--r--internal/clients/handlers/withcancel.go24
-rw-r--r--internal/clients/healthclient.go7
-rw-r--r--internal/clients/maker.go8
-rw-r--r--internal/clients/maprclient.go52
-rw-r--r--internal/clients/remote/connection.go116
-rw-r--r--internal/clients/runclient.go40
-rw-r--r--internal/clients/stats.go8
-rw-r--r--internal/clients/tailclient.go21
20 files changed, 306 insertions, 357 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go
index 5fe0a72..dea5a9e 100644
--- a/internal/clients/args.go
+++ b/internal/clients/args.go
@@ -9,10 +9,9 @@ type Args struct {
Mode omode.Mode
ServersStr string
UserName string
- Files string
+ What string
Regex string
TrustAllHosts bool
Discovery string
ConnectionsPerCPU int
- PingTimeout int
}
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 574ae94..b1540ea 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -1,13 +1,14 @@
package clients
import (
+ "context"
"regexp"
"sync"
"time"
"github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/discovery"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/ssh/client"
@@ -27,111 +28,110 @@ type baseClient struct {
sshAuthMethods []gossh.AuthMethod
// To deal with SSH host keys
hostKeyCallback *client.HostKeyCallback
- // To stop the client.
- stop chan struct{}
- // To indicate that the client has stopped.
- stopped chan struct{}
// Throttle how fast we initiate SSH connections concurrently
throttleCh chan struct{}
// Retry connection upon failure?
retry bool
- // Connection helper.
- maker connectionMaker
+ // Connection maker helper.
+ maker maker
}
-func (c *baseClient) init(maker connectionMaker) {
+func (c *baseClient) init(maker maker) {
logger.Info("Initiating base client")
c.maker = maker
- //c.connections = make(map[string]*remote.Connection)
c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh)
+ discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle)
- // Retrieve a shuffled list of remote dtail servers.
- shuffleServers := true
- discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers)
for _, server := range discoveryService.ServerList() {
- c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
+ c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
}
if _, err := regexp.Compile(c.Regex); err != nil {
logger.FatalExit(c.Regex, "Can't test compile regex", err)
}
- // Periodically check for unknown hosts, and ask the user whether to trust them or not.
- go c.hostKeyCallback.PromptAddHosts(c.stop)
-
- // Periodically print out connection stats to the client.
c.stats = newTailStats(len(c.connections))
- go c.stats.periodicLogStats(c.throttleCh, c.stop)
}
-func (c *baseClient) Start() (status int) {
+func (c *baseClient) Start(ctx context.Context) (status int) {
+ // Periodically check for unknown hosts, and ask the user whether to trust them or not.
+ go c.hostKeyCallback.PromptAddHosts(ctx)
+ // Periodically print out connection stats to the client.
+ go c.stats.periodicLogStats(ctx, c.throttleCh)
+ // Keep count of active connections
active := make(chan struct{}, len(c.connections))
- var wg sync.WaitGroup
- wg.Add(len(c.connections))
-
+ var mutex sync.Mutex
for i, conn := range c.connections {
go func(i int, conn *remote.Connection) {
- active <- struct{}{}
- defer func() {
- logger.Debug(conn.Server, "Disconnected completely...")
- <-active
- }()
- wg.Done()
-
- for {
- conn.Start(c.throttleCh, c.stats.connectionsEstCh)
- if !c.retry {
- return
- }
- time.Sleep(time.Second * 2)
- logger.Debug(conn.Server, "Reconencting")
- conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
- c.connections[i] = conn
+ connStatus := c.start(ctx, active, i, conn)
+
+ // Update global status.
+ mutex.Lock()
+ defer mutex.Unlock()
+ if connStatus > status {
+ status = connStatus
}
}(i, conn)
}
- wg.Wait()
- c.waitUntilDone(active)
-
+ c.waitUntilDone(ctx, active)
return
}
-func (c *baseClient) waitUntilDone(active chan struct{}) {
- defer close(c.stopped)
+func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) {
+ // Increment connection count
+ active <- struct{}{}
+ // Derement connection count
+ defer func() { <-active }()
- if c.Mode != omode.TailClient {
- c.waitUntilZero(active)
- logger.Info("All connections stopped")
- return
- }
+ for {
+ connCtx, cancel := conn.Handler.WithCancel(ctx)
+ defer cancel()
- <-c.stop
- logger.Info("Stopping client")
- for _, conn := range c.connections {
- conn.Stop()
+ conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
+ // Retrieve status code from handler (dtail client will exit with that status)
+ status = conn.Handler.Status()
+
+ if !c.retry {
+ return
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Debug(conn.Server, "Reconnecting")
+
+ conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
+ c.connections[i] = conn
}
+}
- c.waitUntilZero(active)
+func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = c.maker.makeHandler(server)
+ conn.Commands = c.maker.makeCommands()
+
+ return conn
}
-func (c *baseClient) waitUntilZero(active chan struct{}) {
+func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
+ defer logger.Info("Terminated connection")
+
+ // We want to have at least one active connection
+ <-active
+ // Put it back on the channel
+ active <- struct{}{}
+
+ if c.Mode == omode.TailClient {
+ <-ctx.Done()
+ }
+
for {
- logger.Debug("Active connections", len(active))
- if len(active) == 0 {
+ numActive := len(active)
+ if numActive == 0 {
return
}
+ logger.Debug("Active connections", numActive)
time.Sleep(time.Second)
}
}
-
-func (c *baseClient) Stop() {
- close(c.stop)
- <-c.WaitC()
-}
-
-func (c *baseClient) WaitC() <-chan struct{} {
- return c.stopped
-}
diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go
index 5ea701d..7fd6bdc 100644
--- a/internal/clients/catclient.go
+++ b/internal/clients/catclient.go
@@ -7,11 +7,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// CatClient is a client for returning a whole file from the beginning to the end.
@@ -31,8 +27,6 @@ func NewCatClient(args Args) (*CatClient, error) {
c := CatClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
},
@@ -43,11 +37,13 @@ func NewCatClient(args Args) (*CatClient, error) {
return &c, nil
}
-func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+func (c CatClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
+
+func (c CatClient) makeCommands() (commands []string) {
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
- return conn
+ return
}
diff --git a/internal/clients/client.go b/internal/clients/client.go
index 85d1aae..1fc5e23 100644
--- a/internal/clients/client.go
+++ b/internal/clients/client.go
@@ -1,7 +1,8 @@
package clients
+import "context"
+
// Client is the interface for the end user command line client.
type Client interface {
- Start() int
- Stop()
+ Start(ctx context.Context) int
}
diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go
deleted file mode 100644
index 0617992..0000000
--- a/internal/clients/connectionmaker.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package clients
-
-import (
- "github.com/mimecast/dtail/internal/clients/remote"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
-)
-
-type connectionMaker interface {
- makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection
-}
diff --git a/internal/clients/execclient.go b/internal/clients/execclient.go
deleted file mode 100644
index 10bd081..0000000
--- a/internal/clients/execclient.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package clients
-
-import (
- "fmt"
- "runtime"
- "strings"
-
- "github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
- "github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
-)
-
-// ExecClient is a client for execute various commands on the server.
-type ExecClient struct {
- baseClient
-}
-
-// NewExecClient returns a new cat client.
-func NewExecClient(args Args) (*ExecClient, error) {
- args.Regex = "."
- args.Mode = omode.ExecClient
-
- c := ExecClient{
- baseClient: baseClient{
- Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
- throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
- retry: false,
- },
- }
-
- c.init(c)
-
- return &c, nil
-}
-
-func (c ExecClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
- for _, file := range strings.Split(c.Files, ";") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s", c.Mode.String(), file))
- }
- return conn
-}
diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go
index c568f63..8d11458 100644
--- a/internal/clients/grepclient.go
+++ b/internal/clients/grepclient.go
@@ -7,11 +7,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed.
@@ -29,8 +25,6 @@ func NewGrepClient(args Args) (*GrepClient, error) {
c := GrepClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
},
@@ -41,13 +35,13 @@ func NewGrepClient(args Args) (*GrepClient, error) {
return &c, nil
}
-func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+func (c GrepClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+func (c GrepClient) makeCommands() (commands []string) {
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
-
- return conn
+ return
}
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
index 19246f9..68b8ddc 100644
--- a/internal/clients/handlers/basehandler.go
+++ b/internal/clients/handlers/basehandler.go
@@ -1,60 +1,44 @@
package handlers
import (
- "github.com/mimecast/dtail/internal/logger"
- "errors"
+ "encoding/base64"
"fmt"
"io"
+ "strconv"
"strings"
"time"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/version"
)
type baseHandler struct {
+ withCancel
server string
shellStarted bool
commands chan string
- pong chan struct{}
receiveBuf []byte
- stop chan struct{}
- pingTimeout int
+ status int
}
func (h *baseHandler) Server() string {
return h.server
}
-// Used to determine whether server is still responding to requests or not.
-func (h *baseHandler) Ping() error {
- if h.pingTimeout == 0 {
- // Server ping disabled
- return nil
- }
-
- if err := h.SendCommand("ping"); err != nil {
- return err
- }
-
- select {
- case <-h.pong:
- return nil
- case <-time.After(time.Duration(h.pingTimeout) * time.Second):
- }
-
- return errors.New("Didn't receive any server pongs (ping replies)")
+func (h *baseHandler) Status() int {
+ return h.status
}
-func (h *baseHandler) SendCommand(command string) error {
- if command == "ping" {
- logger.Trace("Sending command", h.server, command)
- } else {
- logger.Debug("Sending command", h.server, command)
- }
+// SendMessage to the server.
+func (h *baseHandler) SendMessage(command string) error {
+ encoded := base64.StdEncoding.EncodeToString([]byte(command))
+ logger.Debug("Sending command", h.server, command, encoded)
select {
- case h.commands <- fmt.Sprintf("%s;", command):
+ case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded):
case <-time.After(time.Second * 5):
- return errors.New("Timed out sending command " + command)
- case <-h.stop:
+ return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded)
+ case <-h.ctx.Done():
}
return nil
@@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) {
select {
case command := <-h.commands:
n = copy(p, []byte(command))
- case <-h.stop:
+ case <-h.ctx.Done():
return 0, io.EOF
}
return
@@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) {
if len(h.receiveBuf) == 0 {
return
}
+
// Hidden server commands starti with a dot "."
if h.receiveBuf[0] == '.' {
h.handleHiddenMessage(message)
@@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) {
h.receiveBuf = h.receiveBuf[:0]
return
}
+
logger.Raw(message)
h.receiveBuf = h.receiveBuf[:0]
}
@@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) {
// to the end user.
func (h *baseHandler) handleHiddenMessage(message string) {
switch {
- case strings.HasPrefix(message, ".pong"):
- h.pong <- struct{}{}
case strings.HasPrefix(message, ".syn close connection"):
- h.SendCommand("ack close connection")
- }
-}
+ h.SendMessage(".ack close connection")
+ select {
+ case <-time.After(time.Second * 1):
+ logger.Debug("Shutting down client after timeout and sending ack to server")
+ h.withCancel.shutdown()
+ case <-h.ctx.Done():
+ }
-// Stop the handler.
-func (h *baseHandler) Stop() {
- select {
- case <-h.stop:
- default:
- logger.Debug("Stopping base handler", h.server)
- close(h.stop)
+ case strings.HasPrefix(message, ".run exitstatus"):
+ splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ")
+ if len(splitted) != 3 {
+ logger.Error("Unable to retrieve exitstatus", message)
+ return
+ }
+ i, err := strconv.Atoi(splitted[2])
+ if err != nil {
+ logger.Error("Unable to retrieve exitstatus", message, err)
+ return
+ }
+ h.status = i
+ logger.Debug("Retrieved exitstatus", h.status)
}
}
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
index 4738cd3..fcd8052 100644
--- a/internal/clients/handlers/clienthandler.go
+++ b/internal/clients/handlers/clienthandler.go
@@ -1,7 +1,7 @@
package handlers
import (
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
)
// ClientHandler is the basic client handler interface.
@@ -10,7 +10,7 @@ type ClientHandler struct {
}
// NewClientHandler creates a new client handler.
-func NewClientHandler(server string, pingTimeout int) *ClientHandler {
+func NewClientHandler(server string) *ClientHandler {
logger.Debug(server, "Creating new client handler")
return &ClientHandler{
@@ -18,9 +18,10 @@ func NewClientHandler(server string, pingTimeout int) *ClientHandler {
server: server,
shellStarted: false,
commands: make(chan string),
- pong: make(chan struct{}, 1),
- stop: make(chan struct{}),
- pingTimeout: pingTimeout,
+ status: -1,
+ withCancel: withCancel{
+ done: make(chan struct{}),
+ },
},
}
}
diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go
index 2013be0..c53ca34 100644
--- a/internal/clients/handlers/handler.go
+++ b/internal/clients/handlers/handler.go
@@ -1,12 +1,16 @@
package handlers
-import "io"
+import (
+ "context"
+ "io"
+)
// Handler provides all methods which can be run on any client handler.
type Handler interface {
io.ReadWriter
- Ping() error
- Stop()
- SendCommand(command string) error
+ SendMessage(command string) error
Server() string
+ Status() int
+ WithCancel(ctx context.Context) (context.Context, context.CancelFunc)
+ Done() <-chan struct{}
}
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
index 4051e2c..9051015 100644
--- a/internal/clients/handlers/healthhandler.go
+++ b/internal/clients/handlers/healthhandler.go
@@ -8,6 +8,7 @@ import (
// HealthHandler implements the handler required for health checks.
type HealthHandler struct {
+ withCancel
// Buffer of incoming data from server.
receiveBuf []byte
// To send commands to the server.
@@ -16,6 +17,7 @@ type HealthHandler struct {
receive chan<- string
// The remote server address
server string
+ status int
}
// NewHealthHandler returns a new health check handler.
@@ -24,6 +26,10 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
server: server,
receive: receive,
commands: make(chan string),
+ status: -1,
+ withCancel: withCancel{
+ done: make(chan struct{}),
+ },
}
return &h
@@ -34,18 +40,13 @@ func (h *HealthHandler) Server() string {
return h.server
}
-// Stop is not of use for health check handler.
-func (h *HealthHandler) Stop() {
- // Nothing done here.
+// Status of the handler.
+func (h *HealthHandler) Status() int {
+ return h.status
}
-// Ping is not of use for health check handler.
-func (h *HealthHandler) Ping() error {
- return nil
-}
-
-// SendCommand send a DTail command to the server.
-func (h *HealthHandler) SendCommand(command string) error {
+// SendMessage sends a DTail command to the server.
+func (h *HealthHandler) SendMessage(command string) error {
select {
case h.commands <- fmt.Sprintf("%s;", command):
case <-time.NewTimer(time.Second * 10).C:
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index d76cdfd..874bb7d 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -1,10 +1,11 @@
package handlers
import (
- "github.com/mimecast/dtail/internal/logger"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/client"
- "strings"
)
// MaprHandler is the handler used on the client side for running mapreduce aggregations.
@@ -16,15 +17,16 @@ type MaprHandler struct {
}
// NewMaprHandler returns a new mapreduce client handler.
-func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler {
+func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler {
return &MaprHandler{
baseHandler: baseHandler{
server: server,
shellStarted: false,
commands: make(chan string),
- pong: make(chan struct{}, 1),
- stop: make(chan struct{}),
- pingTimeout: pingTimeout,
+ status: -1,
+ withCancel: withCancel{
+ done: make(chan struct{}),
+ },
},
query: query,
aggregate: client.NewAggregate(server, query, globalGroup),
@@ -65,10 +67,3 @@ func (h *MaprHandler) handleAggregateMessage(message string) {
h.aggregate.Aggregate(parts[2:])
logger.Debug("Aggregated aggregate data", h.server, h.count)
}
-
-// Stop stops the mapreduce client handler.
-func (h *MaprHandler) Stop() {
- logger.Debug("Stopping mapreduce handler", h.server)
- h.aggregate.Stop()
- h.baseHandler.Stop()
-}
diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go
new file mode 100644
index 0000000..7c9cf4e
--- /dev/null
+++ b/internal/clients/handlers/withcancel.go
@@ -0,0 +1,24 @@
+package handlers
+
+import "context"
+
+type withCancel struct {
+ ctx context.Context
+ done chan struct{}
+}
+
+// WithCancel sets and returns the context used.
+func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) {
+ cancelCtx, cancel := context.WithCancel(ctx)
+ w.ctx = cancelCtx
+
+ return cancelCtx, cancel
+}
+
+func (w *withCancel) Done() <-chan struct{} {
+ return w.done
+}
+
+func (w *withCancel) shutdown() {
+ close(w.done)
+}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index ff13b83..7313583 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -1,6 +1,7 @@
package clients
import (
+ "context"
"fmt"
"runtime"
"strings"
@@ -39,7 +40,7 @@ func NewHealthClient(mode omode.Mode) (*HealthClient, error) {
}
// Start the health client.
-func (c *HealthClient) Start() (status int) {
+func (c *HealthClient) Start(ctx context.Context) (status int) {
receive := make(chan string)
throttleCh := make(chan struct{}, runtime.NumCPU())
@@ -49,8 +50,8 @@ func (c *HealthClient) Start() (status int) {
conn.Handler = handlers.NewHealthHandler(c.server, receive)
conn.Commands = []string{c.mode.String()}
- go conn.Start(throttleCh, statsCh)
- defer conn.Stop()
+ connCtx, cancel := conn.Handler.WithCancel(ctx)
+ go conn.Start(connCtx, cancel, throttleCh, statsCh)
for {
select {
diff --git a/internal/clients/maker.go b/internal/clients/maker.go
new file mode 100644
index 0000000..da9dfc9
--- /dev/null
+++ b/internal/clients/maker.go
@@ -0,0 +1,8 @@
+package clients
+
+import "github.com/mimecast/dtail/internal/clients/handlers"
+
+type maker interface {
+ makeHandler(server string) handlers.Handler
+ makeCommands() (commands []string)
+}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index 9070827..b581844 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -1,6 +1,7 @@
package clients
import (
+ "context"
"errors"
"fmt"
"runtime"
@@ -8,13 +9,9 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// MaprClient is used for running mapreduce aggregations on remote files.
@@ -39,8 +36,6 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
c := MaprClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: args.Mode == omode.TailClient,
},
@@ -70,35 +65,36 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
return &c, nil
}
-func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout)
+// Start starts the mapreduce client.
+func (c *MaprClient) Start(ctx context.Context) (status int) {
+ if c.query.Outfile == "" {
+ // Only print out periodic results if we don't write an outfile
+ go c.periodicPrintResults(ctx)
+ }
- conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery))
- commandStr := "tail"
+ status = c.baseClient.Start(ctx)
if c.additative {
- commandStr = "cat"
+ c.recievedFinalResult()
}
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex))
- }
+ return
+}
- return conn
+func (c MaprClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewMaprHandler(server, c.query, c.globalGroup)
}
-// Start starts the mapreduce client.
-func (c *MaprClient) Start() (status int) {
- if c.query.Outfile == "" {
- // Only print out periodic results if we don't write an outfile
- go c.periodicPrintResults()
- }
+func (c MaprClient) makeCommands() (commands []string) {
+ commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery))
- status = c.baseClient.Start()
+ modeStr := "tail"
if c.additative {
- c.recievedFinalResult()
+ modeStr = "cat"
+ }
+
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex))
}
- c.baseClient.Stop()
return
}
@@ -120,13 +116,13 @@ func (c *MaprClient) recievedFinalResult() {
logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile))
}
-func (c *MaprClient) periodicPrintResults() {
+func (c *MaprClient) periodicPrintResults(ctx context.Context) {
for {
select {
case <-time.After(c.query.Interval):
logger.Info("Gathering interim mapreduce result")
c.printResults()
- case <-c.baseClient.stop:
+ case <-ctx.Done():
return
}
}
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
index bfc7bc5..71639b1 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/remote/connection.go
@@ -1,16 +1,18 @@
package remote
import (
- "github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
- "github.com/mimecast/dtail/internal/ssh/client"
+ "context"
"fmt"
"io"
"strconv"
"strings"
"time"
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
"golang.org/x/crypto/ssh"
)
@@ -30,8 +32,6 @@ type Connection struct {
Commands []string
// Is it a persistent connection or a one-off?
isOneOff bool
- // Used to stop the connection
- stop chan struct{}
// To deal with SSH server host keys
hostKeyCallback *client.HostKeyCallback
}
@@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod,
HostKeyCallback: hostKeyCallback.Wrap(),
Timeout: time.Second * 3,
},
- stop: make(chan struct{}),
}
c.initServerPort(server)
@@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
},
- stop: make(chan struct{}),
isOneOff: true,
}
@@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) {
}
}
-// Start the server connection. Build up SSH session and send some DTail commandc.
-func (c *Connection) Start(throttleCh, statsCh chan struct{}) {
+// Start the server connection. Build up SSH session and send some DTail commands.
+func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
+ // Throttle how many connections can be established concurrently (based on ch length)
select {
- case <-c.stop:
- logger.Info(c.Server, c.port, "Disconnecting client")
+ case throttleCh <- struct{}{}:
+ defer func() { <-throttleCh }()
+ case <-ctx.Done():
return
- default:
}
- // Wait for SSH connection throttler
- throttleCh <- struct{}{}
-
- // Wait until connection has been initiated or an error occured
- // during initialization.
- throttleStopCh := make(chan struct{}, 2)
go func() {
- <-throttleStopCh
- <-throttleCh
- }()
+ defer cancel()
- if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil {
- logger.Warn(c.Server, c.port, err)
- throttleStopCh <- struct{}{}
+ if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil {
+ logger.Warn(c.Server, c.port, err)
- if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
- logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port)
- return
+ if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
+ logger.Debug("Not trusting host", c.Server, c.port)
+ return
+ }
}
- }
+ }()
+
+ <-ctx.Done()
}
// Dail into a new SSH connection. Close connection in case of an error.
-func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error {
+func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error {
statsCh <- struct{}{}
defer func() { <-statsCh }()
@@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st
}
defer client.Close()
- return c.session(client, throttleStopCh)
+ return c.session(ctx, cancel, client)
}
// Create the SSH session. Close the session in case of an error.
-func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error {
+func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error {
logger.Debug(c.Server, "session")
session, err := client.NewSession()
@@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{})
}
defer session.Close()
- return c.handle(session, throttleStopCh)
+ return c.handle(ctx, cancel, session)
}
-// Handle the SSH session. Also send periodic pings to the server in order
-// to determine that session is still intact.
-func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error {
- defer c.Handler.Stop()
-
+func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error {
logger.Debug(c.Server, "handle")
stdinPipe, err := session.StdinPipe()
@@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}
return err
}
- // Establish Bi-directional pipe between SSH session and client handler.
- brokenStdinPipe := make(chan struct{})
go func() {
- defer close(brokenStdinPipe)
+ defer cancel()
io.Copy(stdinPipe, c.Handler)
}()
- brokenStdoutPipe := make(chan struct{})
go func() {
- defer close(brokenStdoutPipe)
+ defer cancel()
io.Copy(c.Handler, stdoutPipe)
}()
- // SSH session established, other goroutine can initiate session now.
- throttleStopCh <- struct{}{}
+ go func() {
+ defer cancel()
+ select {
+ case <-c.Handler.Done():
+ case <-ctx.Done():
+ }
+ }()
// Send all commands to client.
for _, command := range c.Commands {
logger.Debug(command)
- c.Handler.SendCommand(command)
+ c.Handler.SendMessage(command)
}
- if !c.isOneOff {
- return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe)
- }
-
- <-c.stop
-
- // Normal shutdown, all fine
+ <-ctx.Done()
return nil
}
-
-// Periodically check whether connection is still alive or not.
-func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error {
- for {
- select {
- case <-time.After(time.Second * 3):
- if err := c.Handler.Ping(); err != nil {
- return err
- }
- case <-brokenStdinPipe:
- logger.Debug("Broken stdin pipe", c.Server, c.port)
- return nil
- case <-brokenStdoutPipe:
- logger.Debug("Broken stdout pipe", c.Server, c.port)
- return nil
- case <-c.stop:
- return nil
- }
- }
-}
-
-// Stop the connection.
-func (c *Connection) Stop() {
- close(c.stop)
-}
diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go
new file mode 100644
index 0000000..7a62fcc
--- /dev/null
+++ b/internal/clients/runclient.go
@@ -0,0 +1,40 @@
+package clients
+
+import (
+ "fmt"
+ "runtime"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/omode"
+)
+
+// RunClient is a client to run various commands on the server.
+type RunClient struct {
+ baseClient
+}
+
+// NewRunClient returns a new cat client.
+func NewRunClient(args Args) (*RunClient, error) {
+ args.Mode = omode.RunClient
+
+ c := RunClient{
+ baseClient: baseClient{
+ Args: args,
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: false,
+ },
+ }
+
+ c.init(c)
+ return &c, nil
+}
+
+func (c RunClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
+
+func (c RunClient) makeCommands() (commands []string) {
+ // Send "run COMMAND" to server!
+ commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What))
+ return
+}
diff --git a/internal/clients/stats.go b/internal/clients/stats.go
index d36cef6..ec6adfe 100644
--- a/internal/clients/stats.go
+++ b/internal/clients/stats.go
@@ -1,11 +1,13 @@
package clients
import (
- "github.com/mimecast/dtail/internal/logger"
+ "context"
"fmt"
"runtime"
"sync"
"time"
+
+ "github.com/mimecast/dtail/internal/io/logger"
)
// Used to collect and display various client stats.
@@ -28,14 +30,14 @@ func newTailStats(connectionsTotal int) *stats {
}
}
-func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) {
+func (s *stats) periodicLogStats(ctx context.Context, throttleCh chan struct{}) {
connectedLast := 0
statsInterval := 5
for {
select {
case <-time.After(time.Second * time.Duration(statsInterval)):
- case <-stop:
+ case <-ctx.Done():
return
}
diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go
index 674ca36..4d81fd5 100644
--- a/internal/clients/tailclient.go
+++ b/internal/clients/tailclient.go
@@ -6,11 +6,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines).
@@ -25,25 +21,22 @@ func NewTailClient(args Args) (*TailClient, error) {
c := TailClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: true,
},
}
c.init(c)
-
return &c, nil
}
-func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+func (c TailClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+func (c TailClient) makeCommands() (commands []string) {
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
-
- return conn
+ return
}