summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-01-26 11:26:53 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-02-07 13:31:15 +0000
commit0945da8dfefcbb723eecea0e5f4eafff63398253 (patch)
treef06dab4d2bf21d25d176b23d5baeca588d27f5d7 /internal
parent2a8e5de265a0e0a31a5834909d6879f5c9941467 (diff)
Introduce drun command, refactor code to use context package
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/args.go3
-rw-r--r--internal/clients/baseclient.go130
-rw-r--r--internal/clients/catclient.go20
-rw-r--r--internal/clients/client.go5
-rw-r--r--internal/clients/connectionmaker.go12
-rw-r--r--internal/clients/execclient.go48
-rw-r--r--internal/clients/grepclient.go20
-rw-r--r--internal/clients/handlers/basehandler.go84
-rw-r--r--internal/clients/handlers/clienthandler.go11
-rw-r--r--internal/clients/handlers/handler.go12
-rw-r--r--internal/clients/handlers/healthhandler.go21
-rw-r--r--internal/clients/handlers/maprhandler.go21
-rw-r--r--internal/clients/handlers/withcancel.go24
-rw-r--r--internal/clients/healthclient.go7
-rw-r--r--internal/clients/maker.go8
-rw-r--r--internal/clients/maprclient.go52
-rw-r--r--internal/clients/remote/connection.go116
-rw-r--r--internal/clients/runclient.go40
-rw-r--r--internal/clients/stats.go8
-rw-r--r--internal/clients/tailclient.go21
-rw-r--r--internal/discovery/comma.go2
-rw-r--r--internal/discovery/discovery.go21
-rw-r--r--internal/discovery/file.go2
-rw-r--r--internal/io/fs/catfile.go (renamed from internal/fs/catfile.go)6
-rw-r--r--internal/io/fs/filereader.go (renamed from internal/fs/filereader.go)9
-rw-r--r--internal/io/fs/permissions/permission.go (renamed from internal/fs/permissions/permission.go)2
-rw-r--r--internal/io/fs/permissions/permission_linux.c (renamed from internal/fs/permissions/permission_linux.c)0
-rw-r--r--internal/io/fs/permissions/permission_linux.go (renamed from internal/fs/permissions/permission_linux.go)0
-rw-r--r--internal/io/fs/permissions/permission_linux.h (renamed from internal/fs/permissions/permission_linux.h)0
-rw-r--r--internal/io/fs/permissions/permission_test.go (renamed from internal/fs/permissions/permission_test.go)0
-rw-r--r--internal/io/fs/readfile.go (renamed from internal/fs/readfile.go)73
-rw-r--r--internal/io/fs/stats.go (renamed from internal/fs/stats.go)0
-rw-r--r--internal/io/fs/tailfile.go (renamed from internal/fs/tailfile.go)6
-rw-r--r--internal/io/line/line.go (renamed from internal/fs/lineread.go)14
-rw-r--r--internal/io/logger/logger.go (renamed from internal/logger/logger.go)56
-rw-r--r--internal/io/run/run.go104
-rw-r--r--internal/mapr/aggregateset.go5
-rw-r--r--internal/mapr/client/aggregate.go25
-rw-r--r--internal/mapr/groupset.go5
-rw-r--r--internal/mapr/logformat/parser.go2
-rw-r--r--internal/mapr/query.go2
-rw-r--r--internal/mapr/server/aggregate.go141
-rw-r--r--internal/mapr/wherecondition.go2
-rw-r--r--internal/omode/mode.go6
-rw-r--r--internal/pprof/pprof.go3
-rw-r--r--internal/prompt/prompt.go2
-rw-r--r--internal/server/handlers/controlhandler.go42
-rw-r--r--internal/server/handlers/handler.go2
-rw-r--r--internal/server/handlers/mapcommand.go35
-rw-r--r--internal/server/handlers/readcommand.go158
-rw-r--r--internal/server/handlers/runcommand.go73
-rw-r--r--internal/server/handlers/serverhandler.go521
-rw-r--r--internal/server/server.go70
-rw-r--r--internal/server/stats.go10
-rw-r--r--internal/ssh/client/authmethods.go2
-rw-r--r--internal/ssh/client/hostkeycallback.go10
-rw-r--r--internal/ssh/server/hostkey.go2
-rw-r--r--internal/ssh/server/publickeycallback.go2
-rw-r--r--internal/ssh/ssh.go2
-rw-r--r--internal/user/name.go15
-rw-r--r--internal/user/server/user.go44
-rw-r--r--internal/version/version.go22
62 files changed, 1161 insertions, 1000 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go
index 5fe0a72..dea5a9e 100644
--- a/internal/clients/args.go
+++ b/internal/clients/args.go
@@ -9,10 +9,9 @@ type Args struct {
Mode omode.Mode
ServersStr string
UserName string
- Files string
+ What string
Regex string
TrustAllHosts bool
Discovery string
ConnectionsPerCPU int
- PingTimeout int
}
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 574ae94..b1540ea 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -1,13 +1,14 @@
package clients
import (
+ "context"
"regexp"
"sync"
"time"
"github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/discovery"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/ssh/client"
@@ -27,111 +28,110 @@ type baseClient struct {
sshAuthMethods []gossh.AuthMethod
// To deal with SSH host keys
hostKeyCallback *client.HostKeyCallback
- // To stop the client.
- stop chan struct{}
- // To indicate that the client has stopped.
- stopped chan struct{}
// Throttle how fast we initiate SSH connections concurrently
throttleCh chan struct{}
// Retry connection upon failure?
retry bool
- // Connection helper.
- maker connectionMaker
+ // Connection maker helper.
+ maker maker
}
-func (c *baseClient) init(maker connectionMaker) {
+func (c *baseClient) init(maker maker) {
logger.Info("Initiating base client")
c.maker = maker
- //c.connections = make(map[string]*remote.Connection)
c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh)
+ discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle)
- // Retrieve a shuffled list of remote dtail servers.
- shuffleServers := true
- discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers)
for _, server := range discoveryService.ServerList() {
- c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
+ c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
}
if _, err := regexp.Compile(c.Regex); err != nil {
logger.FatalExit(c.Regex, "Can't test compile regex", err)
}
- // Periodically check for unknown hosts, and ask the user whether to trust them or not.
- go c.hostKeyCallback.PromptAddHosts(c.stop)
-
- // Periodically print out connection stats to the client.
c.stats = newTailStats(len(c.connections))
- go c.stats.periodicLogStats(c.throttleCh, c.stop)
}
-func (c *baseClient) Start() (status int) {
+func (c *baseClient) Start(ctx context.Context) (status int) {
+ // Periodically check for unknown hosts, and ask the user whether to trust them or not.
+ go c.hostKeyCallback.PromptAddHosts(ctx)
+ // Periodically print out connection stats to the client.
+ go c.stats.periodicLogStats(ctx, c.throttleCh)
+ // Keep count of active connections
active := make(chan struct{}, len(c.connections))
- var wg sync.WaitGroup
- wg.Add(len(c.connections))
-
+ var mutex sync.Mutex
for i, conn := range c.connections {
go func(i int, conn *remote.Connection) {
- active <- struct{}{}
- defer func() {
- logger.Debug(conn.Server, "Disconnected completely...")
- <-active
- }()
- wg.Done()
-
- for {
- conn.Start(c.throttleCh, c.stats.connectionsEstCh)
- if !c.retry {
- return
- }
- time.Sleep(time.Second * 2)
- logger.Debug(conn.Server, "Reconencting")
- conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
- c.connections[i] = conn
+ connStatus := c.start(ctx, active, i, conn)
+
+ // Update global status.
+ mutex.Lock()
+ defer mutex.Unlock()
+ if connStatus > status {
+ status = connStatus
}
}(i, conn)
}
- wg.Wait()
- c.waitUntilDone(active)
-
+ c.waitUntilDone(ctx, active)
return
}
-func (c *baseClient) waitUntilDone(active chan struct{}) {
- defer close(c.stopped)
+func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) {
+ // Increment connection count
+ active <- struct{}{}
+ // Derement connection count
+ defer func() { <-active }()
- if c.Mode != omode.TailClient {
- c.waitUntilZero(active)
- logger.Info("All connections stopped")
- return
- }
+ for {
+ connCtx, cancel := conn.Handler.WithCancel(ctx)
+ defer cancel()
- <-c.stop
- logger.Info("Stopping client")
- for _, conn := range c.connections {
- conn.Stop()
+ conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
+ // Retrieve status code from handler (dtail client will exit with that status)
+ status = conn.Handler.Status()
+
+ if !c.retry {
+ return
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Debug(conn.Server, "Reconnecting")
+
+ conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
+ c.connections[i] = conn
}
+}
- c.waitUntilZero(active)
+func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = c.maker.makeHandler(server)
+ conn.Commands = c.maker.makeCommands()
+
+ return conn
}
-func (c *baseClient) waitUntilZero(active chan struct{}) {
+func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
+ defer logger.Info("Terminated connection")
+
+ // We want to have at least one active connection
+ <-active
+ // Put it back on the channel
+ active <- struct{}{}
+
+ if c.Mode == omode.TailClient {
+ <-ctx.Done()
+ }
+
for {
- logger.Debug("Active connections", len(active))
- if len(active) == 0 {
+ numActive := len(active)
+ if numActive == 0 {
return
}
+ logger.Debug("Active connections", numActive)
time.Sleep(time.Second)
}
}
-
-func (c *baseClient) Stop() {
- close(c.stop)
- <-c.WaitC()
-}
-
-func (c *baseClient) WaitC() <-chan struct{} {
- return c.stopped
-}
diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go
index 5ea701d..7fd6bdc 100644
--- a/internal/clients/catclient.go
+++ b/internal/clients/catclient.go
@@ -7,11 +7,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// CatClient is a client for returning a whole file from the beginning to the end.
@@ -31,8 +27,6 @@ func NewCatClient(args Args) (*CatClient, error) {
c := CatClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
},
@@ -43,11 +37,13 @@ func NewCatClient(args Args) (*CatClient, error) {
return &c, nil
}
-func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+func (c CatClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
+
+func (c CatClient) makeCommands() (commands []string) {
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
- return conn
+ return
}
diff --git a/internal/clients/client.go b/internal/clients/client.go
index 85d1aae..1fc5e23 100644
--- a/internal/clients/client.go
+++ b/internal/clients/client.go
@@ -1,7 +1,8 @@
package clients
+import "context"
+
// Client is the interface for the end user command line client.
type Client interface {
- Start() int
- Stop()
+ Start(ctx context.Context) int
}
diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go
deleted file mode 100644
index 0617992..0000000
--- a/internal/clients/connectionmaker.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package clients
-
-import (
- "github.com/mimecast/dtail/internal/clients/remote"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
-)
-
-type connectionMaker interface {
- makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection
-}
diff --git a/internal/clients/execclient.go b/internal/clients/execclient.go
deleted file mode 100644
index 10bd081..0000000
--- a/internal/clients/execclient.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package clients
-
-import (
- "fmt"
- "runtime"
- "strings"
-
- "github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
- "github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
-)
-
-// ExecClient is a client for execute various commands on the server.
-type ExecClient struct {
- baseClient
-}
-
-// NewExecClient returns a new cat client.
-func NewExecClient(args Args) (*ExecClient, error) {
- args.Regex = "."
- args.Mode = omode.ExecClient
-
- c := ExecClient{
- baseClient: baseClient{
- Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
- throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
- retry: false,
- },
- }
-
- c.init(c)
-
- return &c, nil
-}
-
-func (c ExecClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
- for _, file := range strings.Split(c.Files, ";") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s", c.Mode.String(), file))
- }
- return conn
-}
diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go
index c568f63..8d11458 100644
--- a/internal/clients/grepclient.go
+++ b/internal/clients/grepclient.go
@@ -7,11 +7,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed.
@@ -29,8 +25,6 @@ func NewGrepClient(args Args) (*GrepClient, error) {
c := GrepClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: false,
},
@@ -41,13 +35,13 @@ func NewGrepClient(args Args) (*GrepClient, error) {
return &c, nil
}
-func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+func (c GrepClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+func (c GrepClient) makeCommands() (commands []string) {
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
-
- return conn
+ return
}
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
index 19246f9..68b8ddc 100644
--- a/internal/clients/handlers/basehandler.go
+++ b/internal/clients/handlers/basehandler.go
@@ -1,60 +1,44 @@
package handlers
import (
- "github.com/mimecast/dtail/internal/logger"
- "errors"
+ "encoding/base64"
"fmt"
"io"
+ "strconv"
"strings"
"time"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/version"
)
type baseHandler struct {
+ withCancel
server string
shellStarted bool
commands chan string
- pong chan struct{}
receiveBuf []byte
- stop chan struct{}
- pingTimeout int
+ status int
}
func (h *baseHandler) Server() string {
return h.server
}
-// Used to determine whether server is still responding to requests or not.
-func (h *baseHandler) Ping() error {
- if h.pingTimeout == 0 {
- // Server ping disabled
- return nil
- }
-
- if err := h.SendCommand("ping"); err != nil {
- return err
- }
-
- select {
- case <-h.pong:
- return nil
- case <-time.After(time.Duration(h.pingTimeout) * time.Second):
- }
-
- return errors.New("Didn't receive any server pongs (ping replies)")
+func (h *baseHandler) Status() int {
+ return h.status
}
-func (h *baseHandler) SendCommand(command string) error {
- if command == "ping" {
- logger.Trace("Sending command", h.server, command)
- } else {
- logger.Debug("Sending command", h.server, command)
- }
+// SendMessage to the server.
+func (h *baseHandler) SendMessage(command string) error {
+ encoded := base64.StdEncoding.EncodeToString([]byte(command))
+ logger.Debug("Sending command", h.server, command, encoded)
select {
- case h.commands <- fmt.Sprintf("%s;", command):
+ case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded):
case <-time.After(time.Second * 5):
- return errors.New("Timed out sending command " + command)
- case <-h.stop:
+ return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded)
+ case <-h.ctx.Done():
}
return nil
@@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) {
select {
case command := <-h.commands:
n = copy(p, []byte(command))
- case <-h.stop:
+ case <-h.ctx.Done():
return 0, io.EOF
}
return
@@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) {
if len(h.receiveBuf) == 0 {
return
}
+
// Hidden server commands starti with a dot "."
if h.receiveBuf[0] == '.' {
h.handleHiddenMessage(message)
@@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) {
h.receiveBuf = h.receiveBuf[:0]
return
}
+
logger.Raw(message)
h.receiveBuf = h.receiveBuf[:0]
}
@@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) {
// to the end user.
func (h *baseHandler) handleHiddenMessage(message string) {
switch {
- case strings.HasPrefix(message, ".pong"):
- h.pong <- struct{}{}
case strings.HasPrefix(message, ".syn close connection"):
- h.SendCommand("ack close connection")
- }
-}
+ h.SendMessage(".ack close connection")
+ select {
+ case <-time.After(time.Second * 1):
+ logger.Debug("Shutting down client after timeout and sending ack to server")
+ h.withCancel.shutdown()
+ case <-h.ctx.Done():
+ }
-// Stop the handler.
-func (h *baseHandler) Stop() {
- select {
- case <-h.stop:
- default:
- logger.Debug("Stopping base handler", h.server)
- close(h.stop)
+ case strings.HasPrefix(message, ".run exitstatus"):
+ splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ")
+ if len(splitted) != 3 {
+ logger.Error("Unable to retrieve exitstatus", message)
+ return
+ }
+ i, err := strconv.Atoi(splitted[2])
+ if err != nil {
+ logger.Error("Unable to retrieve exitstatus", message, err)
+ return
+ }
+ h.status = i
+ logger.Debug("Retrieved exitstatus", h.status)
}
}
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
index 4738cd3..fcd8052 100644
--- a/internal/clients/handlers/clienthandler.go
+++ b/internal/clients/handlers/clienthandler.go
@@ -1,7 +1,7 @@
package handlers
import (
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
)
// ClientHandler is the basic client handler interface.
@@ -10,7 +10,7 @@ type ClientHandler struct {
}
// NewClientHandler creates a new client handler.
-func NewClientHandler(server string, pingTimeout int) *ClientHandler {
+func NewClientHandler(server string) *ClientHandler {
logger.Debug(server, "Creating new client handler")
return &ClientHandler{
@@ -18,9 +18,10 @@ func NewClientHandler(server string, pingTimeout int) *ClientHandler {
server: server,
shellStarted: false,
commands: make(chan string),
- pong: make(chan struct{}, 1),
- stop: make(chan struct{}),
- pingTimeout: pingTimeout,
+ status: -1,
+ withCancel: withCancel{
+ done: make(chan struct{}),
+ },
},
}
}
diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go
index 2013be0..c53ca34 100644
--- a/internal/clients/handlers/handler.go
+++ b/internal/clients/handlers/handler.go
@@ -1,12 +1,16 @@
package handlers
-import "io"
+import (
+ "context"
+ "io"
+)
// Handler provides all methods which can be run on any client handler.
type Handler interface {
io.ReadWriter
- Ping() error
- Stop()
- SendCommand(command string) error
+ SendMessage(command string) error
Server() string
+ Status() int
+ WithCancel(ctx context.Context) (context.Context, context.CancelFunc)
+ Done() <-chan struct{}
}
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
index 4051e2c..9051015 100644
--- a/internal/clients/handlers/healthhandler.go
+++ b/internal/clients/handlers/healthhandler.go
@@ -8,6 +8,7 @@ import (
// HealthHandler implements the handler required for health checks.
type HealthHandler struct {
+ withCancel
// Buffer of incoming data from server.
receiveBuf []byte
// To send commands to the server.
@@ -16,6 +17,7 @@ type HealthHandler struct {
receive chan<- string
// The remote server address
server string
+ status int
}
// NewHealthHandler returns a new health check handler.
@@ -24,6 +26,10 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
server: server,
receive: receive,
commands: make(chan string),
+ status: -1,
+ withCancel: withCancel{
+ done: make(chan struct{}),
+ },
}
return &h
@@ -34,18 +40,13 @@ func (h *HealthHandler) Server() string {
return h.server
}
-// Stop is not of use for health check handler.
-func (h *HealthHandler) Stop() {
- // Nothing done here.
+// Status of the handler.
+func (h *HealthHandler) Status() int {
+ return h.status
}
-// Ping is not of use for health check handler.
-func (h *HealthHandler) Ping() error {
- return nil
-}
-
-// SendCommand send a DTail command to the server.
-func (h *HealthHandler) SendCommand(command string) error {
+// SendMessage sends a DTail command to the server.
+func (h *HealthHandler) SendMessage(command string) error {
select {
case h.commands <- fmt.Sprintf("%s;", command):
case <-time.NewTimer(time.Second * 10).C:
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index d76cdfd..874bb7d 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -1,10 +1,11 @@
package handlers
import (
- "github.com/mimecast/dtail/internal/logger"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/client"
- "strings"
)
// MaprHandler is the handler used on the client side for running mapreduce aggregations.
@@ -16,15 +17,16 @@ type MaprHandler struct {
}
// NewMaprHandler returns a new mapreduce client handler.
-func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler {
+func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler {
return &MaprHandler{
baseHandler: baseHandler{
server: server,
shellStarted: false,
commands: make(chan string),
- pong: make(chan struct{}, 1),
- stop: make(chan struct{}),
- pingTimeout: pingTimeout,
+ status: -1,
+ withCancel: withCancel{
+ done: make(chan struct{}),
+ },
},
query: query,
aggregate: client.NewAggregate(server, query, globalGroup),
@@ -65,10 +67,3 @@ func (h *MaprHandler) handleAggregateMessage(message string) {
h.aggregate.Aggregate(parts[2:])
logger.Debug("Aggregated aggregate data", h.server, h.count)
}
-
-// Stop stops the mapreduce client handler.
-func (h *MaprHandler) Stop() {
- logger.Debug("Stopping mapreduce handler", h.server)
- h.aggregate.Stop()
- h.baseHandler.Stop()
-}
diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go
new file mode 100644
index 0000000..7c9cf4e
--- /dev/null
+++ b/internal/clients/handlers/withcancel.go
@@ -0,0 +1,24 @@
+package handlers
+
+import "context"
+
+type withCancel struct {
+ ctx context.Context
+ done chan struct{}
+}
+
+// WithCancel sets and returns the context used.
+func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) {
+ cancelCtx, cancel := context.WithCancel(ctx)
+ w.ctx = cancelCtx
+
+ return cancelCtx, cancel
+}
+
+func (w *withCancel) Done() <-chan struct{} {
+ return w.done
+}
+
+func (w *withCancel) shutdown() {
+ close(w.done)
+}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index ff13b83..7313583 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -1,6 +1,7 @@
package clients
import (
+ "context"
"fmt"
"runtime"
"strings"
@@ -39,7 +40,7 @@ func NewHealthClient(mode omode.Mode) (*HealthClient, error) {
}
// Start the health client.
-func (c *HealthClient) Start() (status int) {
+func (c *HealthClient) Start(ctx context.Context) (status int) {
receive := make(chan string)
throttleCh := make(chan struct{}, runtime.NumCPU())
@@ -49,8 +50,8 @@ func (c *HealthClient) Start() (status int) {
conn.Handler = handlers.NewHealthHandler(c.server, receive)
conn.Commands = []string{c.mode.String()}
- go conn.Start(throttleCh, statsCh)
- defer conn.Stop()
+ connCtx, cancel := conn.Handler.WithCancel(ctx)
+ go conn.Start(connCtx, cancel, throttleCh, statsCh)
for {
select {
diff --git a/internal/clients/maker.go b/internal/clients/maker.go
new file mode 100644
index 0000000..da9dfc9
--- /dev/null
+++ b/internal/clients/maker.go
@@ -0,0 +1,8 @@
+package clients
+
+import "github.com/mimecast/dtail/internal/clients/handlers"
+
+type maker interface {
+ makeHandler(server string) handlers.Handler
+ makeCommands() (commands []string)
+}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index 9070827..b581844 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -1,6 +1,7 @@
package clients
import (
+ "context"
"errors"
"fmt"
"runtime"
@@ -8,13 +9,9 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// MaprClient is used for running mapreduce aggregations on remote files.
@@ -39,8 +36,6 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
c := MaprClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: args.Mode == omode.TailClient,
},
@@ -70,35 +65,36 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
return &c, nil
}
-func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout)
+// Start starts the mapreduce client.
+func (c *MaprClient) Start(ctx context.Context) (status int) {
+ if c.query.Outfile == "" {
+ // Only print out periodic results if we don't write an outfile
+ go c.periodicPrintResults(ctx)
+ }
- conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery))
- commandStr := "tail"
+ status = c.baseClient.Start(ctx)
if c.additative {
- commandStr = "cat"
+ c.recievedFinalResult()
}
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex))
- }
+ return
+}
- return conn
+func (c MaprClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewMaprHandler(server, c.query, c.globalGroup)
}
-// Start starts the mapreduce client.
-func (c *MaprClient) Start() (status int) {
- if c.query.Outfile == "" {
- // Only print out periodic results if we don't write an outfile
- go c.periodicPrintResults()
- }
+func (c MaprClient) makeCommands() (commands []string) {
+ commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery))
- status = c.baseClient.Start()
+ modeStr := "tail"
if c.additative {
- c.recievedFinalResult()
+ modeStr = "cat"
+ }
+
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex))
}
- c.baseClient.Stop()
return
}
@@ -120,13 +116,13 @@ func (c *MaprClient) recievedFinalResult() {
logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile))
}
-func (c *MaprClient) periodicPrintResults() {
+func (c *MaprClient) periodicPrintResults(ctx context.Context) {
for {
select {
case <-time.After(c.query.Interval):
logger.Info("Gathering interim mapreduce result")
c.printResults()
- case <-c.baseClient.stop:
+ case <-ctx.Done():
return
}
}
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
index bfc7bc5..71639b1 100644
--- a/internal/clients/remote/connection.go
+++ b/internal/clients/remote/connection.go
@@ -1,16 +1,18 @@
package remote
import (
- "github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
- "github.com/mimecast/dtail/internal/ssh/client"
+ "context"
"fmt"
"io"
"strconv"
"strings"
"time"
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
"golang.org/x/crypto/ssh"
)
@@ -30,8 +32,6 @@ type Connection struct {
Commands []string
// Is it a persistent connection or a one-off?
isOneOff bool
- // Used to stop the connection
- stop chan struct{}
// To deal with SSH server host keys
hostKeyCallback *client.HostKeyCallback
}
@@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod,
HostKeyCallback: hostKeyCallback.Wrap(),
Timeout: time.Second * 3,
},
- stop: make(chan struct{}),
}
c.initServerPort(server)
@@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
},
- stop: make(chan struct{}),
isOneOff: true,
}
@@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) {
}
}
-// Start the server connection. Build up SSH session and send some DTail commandc.
-func (c *Connection) Start(throttleCh, statsCh chan struct{}) {
+// Start the server connection. Build up SSH session and send some DTail commands.
+func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
+ // Throttle how many connections can be established concurrently (based on ch length)
select {
- case <-c.stop:
- logger.Info(c.Server, c.port, "Disconnecting client")
+ case throttleCh <- struct{}{}:
+ defer func() { <-throttleCh }()
+ case <-ctx.Done():
return
- default:
}
- // Wait for SSH connection throttler
- throttleCh <- struct{}{}
-
- // Wait until connection has been initiated or an error occured
- // during initialization.
- throttleStopCh := make(chan struct{}, 2)
go func() {
- <-throttleStopCh
- <-throttleCh
- }()
+ defer cancel()
- if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil {
- logger.Warn(c.Server, c.port, err)
- throttleStopCh <- struct{}{}
+ if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil {
+ logger.Warn(c.Server, c.port, err)
- if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
- logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port)
- return
+ if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
+ logger.Debug("Not trusting host", c.Server, c.port)
+ return
+ }
}
- }
+ }()
+
+ <-ctx.Done()
}
// Dail into a new SSH connection. Close connection in case of an error.
-func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error {
+func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error {
statsCh <- struct{}{}
defer func() { <-statsCh }()
@@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st
}
defer client.Close()
- return c.session(client, throttleStopCh)
+ return c.session(ctx, cancel, client)
}
// Create the SSH session. Close the session in case of an error.
-func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error {
+func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error {
logger.Debug(c.Server, "session")
session, err := client.NewSession()
@@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{})
}
defer session.Close()
- return c.handle(session, throttleStopCh)
+ return c.handle(ctx, cancel, session)
}
-// Handle the SSH session. Also send periodic pings to the server in order
-// to determine that session is still intact.
-func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error {
- defer c.Handler.Stop()
-
+func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error {
logger.Debug(c.Server, "handle")
stdinPipe, err := session.StdinPipe()
@@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}
return err
}
- // Establish Bi-directional pipe between SSH session and client handler.
- brokenStdinPipe := make(chan struct{})
go func() {
- defer close(brokenStdinPipe)
+ defer cancel()
io.Copy(stdinPipe, c.Handler)
}()
- brokenStdoutPipe := make(chan struct{})
go func() {
- defer close(brokenStdoutPipe)
+ defer cancel()
io.Copy(c.Handler, stdoutPipe)
}()
- // SSH session established, other goroutine can initiate session now.
- throttleStopCh <- struct{}{}
+ go func() {
+ defer cancel()
+ select {
+ case <-c.Handler.Done():
+ case <-ctx.Done():
+ }
+ }()
// Send all commands to client.
for _, command := range c.Commands {
logger.Debug(command)
- c.Handler.SendCommand(command)
+ c.Handler.SendMessage(command)
}
- if !c.isOneOff {
- return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe)
- }
-
- <-c.stop
-
- // Normal shutdown, all fine
+ <-ctx.Done()
return nil
}
-
-// Periodically check whether connection is still alive or not.
-func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error {
- for {
- select {
- case <-time.After(time.Second * 3):
- if err := c.Handler.Ping(); err != nil {
- return err
- }
- case <-brokenStdinPipe:
- logger.Debug("Broken stdin pipe", c.Server, c.port)
- return nil
- case <-brokenStdoutPipe:
- logger.Debug("Broken stdout pipe", c.Server, c.port)
- return nil
- case <-c.stop:
- return nil
- }
- }
-}
-
-// Stop the connection.
-func (c *Connection) Stop() {
- close(c.stop)
-}
diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go
new file mode 100644
index 0000000..7a62fcc
--- /dev/null
+++ b/internal/clients/runclient.go
@@ -0,0 +1,40 @@
+package clients
+
+import (
+ "fmt"
+ "runtime"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/omode"
+)
+
+// RunClient is a client to run various commands on the server.
+type RunClient struct {
+ baseClient
+}
+
+// NewRunClient returns a new cat client.
+func NewRunClient(args Args) (*RunClient, error) {
+ args.Mode = omode.RunClient
+
+ c := RunClient{
+ baseClient: baseClient{
+ Args: args,
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: false,
+ },
+ }
+
+ c.init(c)
+ return &c, nil
+}
+
+func (c RunClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
+
+func (c RunClient) makeCommands() (commands []string) {
+ // Send "run COMMAND" to server!
+ commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What))
+ return
+}
diff --git a/internal/clients/stats.go b/internal/clients/stats.go
index d36cef6..ec6adfe 100644
--- a/internal/clients/stats.go
+++ b/internal/clients/stats.go
@@ -1,11 +1,13 @@
package clients
import (
- "github.com/mimecast/dtail/internal/logger"
+ "context"
"fmt"
"runtime"
"sync"
"time"
+
+ "github.com/mimecast/dtail/internal/io/logger"
)
// Used to collect and display various client stats.
@@ -28,14 +30,14 @@ func newTailStats(connectionsTotal int) *stats {
}
}
-func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) {
+func (s *stats) periodicLogStats(ctx context.Context, throttleCh chan struct{}) {
connectedLast := 0
statsInterval := 5
for {
select {
case <-time.After(time.Second * time.Duration(statsInterval)):
- case <-stop:
+ case <-ctx.Done():
return
}
diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go
index 674ca36..4d81fd5 100644
--- a/internal/clients/tailclient.go
+++ b/internal/clients/tailclient.go
@@ -6,11 +6,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/omode"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- gossh "golang.org/x/crypto/ssh"
)
// TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines).
@@ -25,25 +21,22 @@ func NewTailClient(args Args) (*TailClient, error) {
c := TailClient{
baseClient: baseClient{
Args: args,
- stop: make(chan struct{}),
- stopped: make(chan struct{}),
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
retry: true,
},
}
c.init(c)
-
return &c, nil
}
-func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+func (c TailClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewClientHandler(server)
+}
- for _, file := range strings.Split(c.Files, ",") {
- conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+func (c TailClient) makeCommands() (commands []string) {
+ for _, file := range strings.Split(c.What, ",") {
+ commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
-
- return conn
+ return
}
diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go
index ad18be0..94276c7 100644
--- a/internal/discovery/comma.go
+++ b/internal/discovery/comma.go
@@ -1,7 +1,7 @@
package discovery
import (
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"strings"
)
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index d76c1b2..1090ea9 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -1,7 +1,6 @@
package discovery
import (
- "github.com/mimecast/dtail/internal/logger"
"fmt"
"math/rand"
"os"
@@ -9,6 +8,16 @@ import (
"regexp"
"strings"
"time"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+)
+
+// ServerOrder to specify how to sort the server list.
+type ServerOrder int
+
+const (
+ // Shuffle the server list?
+ Shuffle ServerOrder = iota
)
// Discovery method for discovering a list of available DTail servers.
@@ -21,12 +30,12 @@ type Discovery struct {
server string
// To filter server list.
regex *regexp.Regexp
- // To shuffle resulting server list.
- shuffle bool
+ // How to order the server list.
+ order ServerOrder
}
// New returns a new discovery method.
-func New(method, server string, shuffle bool) *Discovery {
+func New(method, server string, order ServerOrder) *Discovery {
module := method
options := ""
@@ -43,7 +52,7 @@ func New(method, server string, shuffle bool) *Discovery {
module: strings.ToUpper(module),
options: options,
server: server,
- shuffle: shuffle,
+ order: order,
}
if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") {
@@ -84,7 +93,7 @@ func (d *Discovery) ServerList() []string {
servers = d.dedupList(servers)
- if d.shuffle {
+ if d.order == Shuffle {
servers = d.shuffleList(servers)
}
diff --git a/internal/discovery/file.go b/internal/discovery/file.go
index 2edc867..c04173e 100644
--- a/internal/discovery/file.go
+++ b/internal/discovery/file.go
@@ -2,7 +2,7 @@ package discovery
import (
"bufio"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"os"
)
diff --git a/internal/fs/catfile.go b/internal/io/fs/catfile.go
index 99f521f..7f387bc 100644
--- a/internal/fs/catfile.go
+++ b/internal/io/fs/catfile.go
@@ -1,7 +1,5 @@
package fs
-import "sync"
-
// CatFile is for reading a whole file.
type CatFile struct {
readFile
@@ -9,19 +7,15 @@ type CatFile struct {
// NewCatFile returns a new file catter.
func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile {
- var mutex sync.Mutex
-
return CatFile{
readFile: readFile{
filePath: filePath,
- stop: make(chan struct{}),
globID: globID,
serverMessages: serverMessages,
retry: false,
canSkipLines: false,
seekEOF: false,
limiter: limiter,
- mutex: &mutex,
},
}
}
diff --git a/internal/fs/filereader.go b/internal/io/fs/filereader.go
index 5a08e27..05e58a1 100644
--- a/internal/fs/filereader.go
+++ b/internal/io/fs/filereader.go
@@ -1,9 +1,14 @@
package fs
+import (
+ "context"
+
+ "github.com/mimecast/dtail/internal/io/line"
+)
+
// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file.
type FileReader interface {
- Start(lines chan<- LineRead, regex string) error
+ Start(ctx context.Context, lines chan<- line.Line, regex string) error
FilePath() string
Retry() bool
- Stop()
}
diff --git a/internal/fs/permissions/permission.go b/internal/io/fs/permissions/permission.go
index 6e83309..0ed4f17 100644
--- a/internal/fs/permissions/permission.go
+++ b/internal/io/fs/permissions/permission.go
@@ -3,7 +3,7 @@
package permissions
import (
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
)
// ToRead is to check whether user has read permissions to a given file.
diff --git a/internal/fs/permissions/permission_linux.c b/internal/io/fs/permissions/permission_linux.c
index cd10525..cd10525 100644
--- a/internal/fs/permissions/permission_linux.c
+++ b/internal/io/fs/permissions/permission_linux.c
diff --git a/internal/fs/permissions/permission_linux.go b/internal/io/fs/permissions/permission_linux.go
index feae729..feae729 100644
--- a/internal/fs/permissions/permission_linux.go
+++ b/internal/io/fs/permissions/permission_linux.go
diff --git a/internal/fs/permissions/permission_linux.h b/internal/io/fs/permissions/permission_linux.h
index a2c266e..a2c266e 100644
--- a/internal/fs/permissions/permission_linux.h
+++ b/internal/io/fs/permissions/permission_linux.h
diff --git a/internal/fs/permissions/permission_test.go b/internal/io/fs/permissions/permission_test.go
index d415ac2..d415ac2 100644
--- a/internal/fs/permissions/permission_test.go
+++ b/internal/io/fs/permissions/permission_test.go
diff --git a/internal/fs/readfile.go b/internal/io/fs/readfile.go
index 312447a..321432e 100644
--- a/internal/fs/readfile.go
+++ b/internal/io/fs/readfile.go
@@ -3,7 +3,7 @@ package fs
import (
"bufio"
"compress/gzip"
- "github.com/mimecast/dtail/internal/logger"
+ "context"
"errors"
"io"
"os"
@@ -12,6 +12,9 @@ import (
"sync"
"time"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/logger"
+
"github.com/DataDog/zstd"
)
@@ -27,16 +30,12 @@ type readFile struct {
globID string
// Channel to send a server message to the dtail client
serverMessages chan<- string
- // Signals to stop tailing the log file.
- stop chan struct{}
// Periodically retry reading file.
retry bool
// Can I skip messages when there are too many?
canSkipLines bool
// Seek to the EOF before processing file?
seekEOF bool
- // Mutex to control the stopping of the file
- mutex *sync.Mutex
limiter chan struct{}
}
@@ -51,7 +50,7 @@ func (f readFile) Retry() bool {
}
// Start tailing a log file.
-func (f readFile) Start(lines chan<- LineRead, regex string) error {
+func (f readFile) Start(ctx context.Context, lines chan<- line.Line, regex string) error {
defer func() {
select {
case <-f.limiter:
@@ -64,7 +63,7 @@ func (f readFile) Start(lines chan<- LineRead, regex string) error {
default:
select {
case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."):
- case <-f.stop:
+ case <-ctx.Done():
return nil
}
f.limiter <- struct{}{}
@@ -86,44 +85,30 @@ func (f readFile) Start(lines chan<- LineRead, regex string) error {
var wg sync.WaitGroup
wg.Add(1)
- go f.periodicTruncateCheck(truncate)
- go f.filter(&wg, rawLines, lines, regex)
+ go f.periodicTruncateCheck(ctx, truncate)
+ go f.filter(ctx, &wg, rawLines, lines, regex)
- err = f.read(fd, rawLines, truncate)
+ err = f.read(ctx, fd, rawLines, truncate)
close(rawLines)
wg.Wait()
return err
}
-func (f readFile) periodicTruncateCheck(truncate chan struct{}) {
+func (f readFile) periodicTruncateCheck(ctx context.Context, truncate chan struct{}) {
for {
select {
case <-time.After(time.Second * 3):
select {
case truncate <- struct{}{}:
- case <-f.stop:
+ case <-ctx.Done():
}
- case <-f.stop:
+ case <-ctx.Done():
return
}
}
}
-// Stop reading file.
-func (f readFile) Stop() {
- f.mutex.Lock()
- defer f.mutex.Unlock()
-
- select {
- case <-f.stop:
- return
- default:
- }
-
- close(f.stop)
-}
-
func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) {
switch {
case strings.HasSuffix(f.FilePath(), ".gz"):
@@ -146,27 +131,31 @@ func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) {
return
}
-func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error {
+func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error {
+ var offset uint64
+
reader, err := f.makeReader(fd)
if err != nil {
return err
}
rawLine := make([]byte, 0, 512)
- var offset uint64
lineLengthThreshold := 1024 * 1024 // 1mb
longLineWarning := false
for {
select {
+ case <-ctx.Done():
+ return nil
+ default:
+ }
+
+ select {
case <-truncate:
if isTruncated, err := f.truncated(fd); isTruncated {
return err
}
logger.Info(f.filePath, "Current offset", offset)
-
- case <-f.stop:
- return nil
default:
}
@@ -196,7 +185,7 @@ func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct
rawLine = append(rawLine, '\n')
select {
case rawLines <- rawLine:
- case <-f.stop:
+ case <-ctx.Done():
return nil
}
rawLine = make([]byte, 0, 512)
@@ -219,7 +208,7 @@ func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct
rawLine = append(rawLine, '\n')
select {
case rawLines <- rawLine:
- case <-f.stop:
+ case <-ctx.Done():
return nil
}
rawLine = make([]byte, 0, 512)
@@ -228,7 +217,7 @@ func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct
}
// Filter log lines matching a given regular expression.
-func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) {
+func (f readFile) filter(ctx context.Context, wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- line.Line, regex string) {
defer wg.Done()
if regex == "" {
@@ -252,7 +241,7 @@ func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<
if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok {
select {
case lines <- filteredLine:
- case <-f.stop:
+ case <-ctx.Done():
return
}
}
@@ -260,10 +249,10 @@ func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<
}
}
-func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) {
- var read LineRead
+func (f readFile) transmittable(lineBytes []byte, length, capacity int) (line.Line, bool) {
+ var read line.Line
- if !f.re.Match(line) {
+ if !f.re.Match(lineBytes) {
f.updateLineNotMatched()
f.updateLineNotTransmitted()
return read, false
@@ -277,9 +266,9 @@ func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bo
}
f.updateLineTransmitted()
- read = LineRead{
- Content: line,
- GlobID: &f.globID,
+ read = line.Line{
+ Content: lineBytes,
+ SourceID: f.globID,
Count: f.totalLineCount(),
TransmittedPerc: f.transmittedPerc(),
}
diff --git a/internal/fs/stats.go b/internal/io/fs/stats.go
index 4121ff7..4121ff7 100644
--- a/internal/fs/stats.go
+++ b/internal/io/fs/stats.go
diff --git a/internal/fs/tailfile.go b/internal/io/fs/tailfile.go
index a19d4e6..14994e5 100644
--- a/internal/fs/tailfile.go
+++ b/internal/io/fs/tailfile.go
@@ -1,7 +1,5 @@
package fs
-import "sync"
-
// TailFile is to tail and filter a log file.
type TailFile struct {
readFile
@@ -9,19 +7,15 @@ type TailFile struct {
// NewTailFile returns a new file tailer.
func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile {
- var mutex sync.Mutex
-
return TailFile{
readFile: readFile{
filePath: filePath,
- stop: make(chan struct{}),
globID: globID,
serverMessages: serverMessages,
retry: true,
canSkipLines: true,
seekEOF: true,
limiter: limiter,
- mutex: &mutex,
},
}
}
diff --git a/internal/fs/lineread.go b/internal/io/line/line.go
index 7ee558e..9db93c0 100644
--- a/internal/fs/lineread.go
+++ b/internal/io/line/line.go
@@ -1,11 +1,11 @@
-package fs
+package line
import (
"fmt"
)
-// LineRead represents a read log line.
-type LineRead struct {
+// Line represents a read log line.
+type Line struct {
// The content of the log line.
Content []byte
// Until now, how many log lines were processed?
@@ -15,14 +15,14 @@ type LineRead struct {
// lines if that happens but it will signal to the client how
// many log lines in % could be transmitted to the client.
TransmittedPerc int
- GlobID *string
+ SourceID string
}
// Return a human readable representation of the followed line.
-func (l LineRead) String() string {
- return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)",
+func (l Line) String() string {
+ return fmt.Sprintf("Line(Content:%s,TransmittedPerc:%v,Count:%v,SourceID:%s)",
string(l.Content),
l.TransmittedPerc,
l.Count,
- *l.GlobID)
+ l.SourceID)
}
diff --git a/internal/logger/logger.go b/internal/io/logger/logger.go
index ca85e32..e30b907 100644
--- a/internal/logger/logger.go
+++ b/internal/io/logger/logger.go
@@ -2,6 +2,7 @@ package logger
import (
"bufio"
+ "context"
"fmt"
"os"
"os/signal"
@@ -48,17 +49,13 @@ var lastDateStr string
var serverEnable bool
// Used to make logging non-blocking.
-var logBufCh chan buf
+var fileLogBufCh chan buf
var stdoutBufCh chan string
// Stdout channel, required to pause output
var pauseCh chan struct{}
var resumeCh chan struct{}
-// Tell the logger that we are done, program shuts down
-var stop chan struct{}
-var stdoutFlushed chan struct{}
-
// Tell the logger about logrotation
var rotateCh chan os.Signal
@@ -103,7 +100,7 @@ type buf struct {
}
// Start logging.
-func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) {
+func Start(ctx context.Context, myServerEnable, debugEnable, silentEnable, nothingEnable bool) {
serverEnable = myServerEnable
mode := logMode(debugEnable, silentEnable, nothingEnable)
@@ -125,7 +122,7 @@ func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) {
case StdoutStrategy:
fallthrough
default:
- logToFile = false
+ logToFile = !serverEnable
logToStdout = true
}
@@ -138,8 +135,6 @@ func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) {
pauseCh = make(chan struct{})
resumeCh = make(chan struct{})
- stop = make(chan struct{})
- stdoutFlushed = make(chan struct{})
// Setup logrotation
rotateCh = make(chan os.Signal, 1)
@@ -147,12 +142,12 @@ func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) {
if logToStdout {
stdoutBufCh = make(chan string, runtime.NumCPU()*100)
- go writeToStdout()
+ go writeToStdout(ctx)
}
if logToFile {
- logBufCh = make(chan buf, runtime.NumCPU()*100)
- go writeToFile()
+ fileLogBufCh = make(chan buf, runtime.NumCPU()*100)
+ go writeToFile(ctx)
}
}
@@ -264,7 +259,7 @@ func write(what, severity, message string) {
if logToFile {
t := time.Now()
timeStr := t.Format("20060102-150405")
- logBufCh <- buf{
+ fileLogBufCh <- buf{
time: t,
message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message),
}
@@ -304,16 +299,16 @@ func Raw(message string) {
return
}
+ if logToFile {
+ fileLogBufCh <- buf{time.Now(), message}
+ }
+
if logToStdout {
if color.Colored {
message = color.Colorfy(message)
}
stdoutBufCh <- message
}
-
- if logToFile {
- logBufCh <- buf{time.Now(), message}
- }
}
// Close log writer (e.g. on change of day).
@@ -367,9 +362,8 @@ func updateFileWriter(dateStr string) *bufio.Writer {
return writer
}
-func flushStdout() {
- defer close(stdoutFlushed)
-
+// Flush all outstanding lines.
+func Flush() {
for {
select {
case message := <-stdoutBufCh:
@@ -381,7 +375,7 @@ func flushStdout() {
}
}
-func writeToStdout() {
+func writeToStdout(ctx context.Context) {
for {
select {
case message := <-stdoutBufCh:
@@ -395,21 +389,21 @@ func writeToStdout() {
case <-stdoutBufCh:
case <-resumeCh:
break PAUSE
- case <-stop:
+ case <-ctx.Done():
return
}
}
- case <-stop:
- flushStdout()
+ case <-ctx.Done():
+ Flush()
return
}
}
}
-func writeToFile() {
+func writeToFile(ctx context.Context) {
for {
select {
- case buf := <-logBufCh:
+ case buf := <-fileLogBufCh:
dateStr := buf.time.Format("20060102")
w := fileWriter(dateStr)
w.WriteString(buf.message)
@@ -420,11 +414,11 @@ func writeToFile() {
case <-stdoutBufCh:
case <-resumeCh:
break PAUSE
- case <-stop:
+ case <-ctx.Done():
return
}
}
- case <-stop:
+ case <-ctx.Done():
return
}
}
@@ -449,9 +443,3 @@ func Resume() {
resumeCh <- struct{}{}
}
}
-
-// Stop logging.
-func Stop() {
- close(stop)
- <-stdoutFlushed
-}
diff --git a/internal/io/run/run.go b/internal/io/run/run.go
new file mode 100644
index 0000000..b608639
--- /dev/null
+++ b/internal/io/run/run.go
@@ -0,0 +1,104 @@
+package run
+
+import (
+ "bufio"
+ "context"
+ "io"
+ "os/exec"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/logger"
+)
+
+// Run is for execute a command.
+type Run struct {
+ commandPath string
+ args []string
+ cmd *exec.Cmd
+}
+
+// New returns a new command runner.
+func New(commandPath string, args []string) Run {
+ return Run{
+ commandPath: commandPath,
+ args: args,
+ }
+}
+
+// Start running the command.
+func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int, err error) {
+ done := make(chan struct{})
+ defer close(done)
+
+ ec = -1
+ pid = -1
+
+ if len(r.args) > 0 {
+ logger.Debug(r.commandPath, strings.Join(r.args, " "))
+ r.cmd = exec.CommandContext(ctx, r.commandPath, strings.Join(r.args, " "))
+ } else {
+ logger.Debug(r.commandPath)
+ r.cmd = exec.CommandContext(ctx, r.commandPath)
+ }
+
+ stdoutPipe, myErr := r.cmd.StdoutPipe()
+ if err != nil {
+ err = myErr
+ return
+ }
+
+ stderrPipe, myErr := r.cmd.StderrPipe()
+ if myErr != nil {
+ err = myErr
+ return
+ }
+
+ if myErr := r.cmd.Start(); err != nil {
+ err = myErr
+ return
+ }
+
+ pid = r.cmd.Process.Pid
+ ec = 0
+
+ var wg sync.WaitGroup
+ wg.Add(2)
+
+ go r.pipeToLines(done, &wg, pid, stdoutPipe, "STDOUT", lines)
+ go r.pipeToLines(done, &wg, pid, stderrPipe, "STDERR", lines)
+
+ if err = r.cmd.Wait(); err != nil {
+ if exitError, ok := err.(*exec.ExitError); ok {
+ ec = exitError.ExitCode()
+ }
+ }
+
+ return
+}
+
+func (r Run) pipeToLines(done chan struct{}, wg *sync.WaitGroup, pid int, reader io.Reader, what string, lines chan<- line.Line) {
+ defer wg.Done()
+ bufReader := bufio.NewReader(reader)
+
+ for {
+ lineStr, err := bufReader.ReadString('\n')
+ for err == nil {
+ lines <- line.Line{
+ Content: []byte(lineStr),
+ Count: uint64(pid),
+ TransmittedPerc: 100,
+ SourceID: what,
+ }
+ lineStr, err = bufReader.ReadString('\n')
+ }
+ select {
+ case <-done:
+ return
+ default:
+ }
+ time.Sleep(time.Millisecond * 10)
+ }
+}
diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go
index 2096c3c..7fb4c17 100644
--- a/internal/mapr/aggregateset.go
+++ b/internal/mapr/aggregateset.go
@@ -1,6 +1,7 @@
package mapr
import (
+ "context"
"fmt"
"strconv"
"strings"
@@ -64,7 +65,7 @@ func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error {
}
// Serialize the aggregate set so it can be sent over the wire.
-func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) {
+func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<- string) {
//logger.Trace("Serialising mapr.AggregateSet", s)
var sb strings.Builder
@@ -87,7 +88,7 @@ func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan st
select {
case ch <- sb.String():
- case <-stop:
+ case <-ctx.Done():
}
}
diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go
index 3f2b7a5..1272a19 100644
--- a/internal/mapr/client/aggregate.go
+++ b/internal/mapr/client/aggregate.go
@@ -1,10 +1,11 @@
package client
import (
- "github.com/mimecast/dtail/internal/logger"
- "github.com/mimecast/dtail/internal/mapr"
"strconv"
"strings"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/mapr"
)
// Aggregate mapreduce data on the DTail client side.
@@ -15,7 +16,6 @@ type Aggregate struct {
group *mapr.GroupSet
// This represents the merged aggregated data of all servers.
globalGroup *mapr.GlobalGroupSet
- stop chan struct{}
// The server we aggregate the data for (logging and debugging purposes only)
server string
}
@@ -26,20 +26,12 @@ func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGrou
query: query,
group: mapr.NewGroupSet(),
globalGroup: globalGroup,
- stop: make(chan struct{}),
server: server,
}
}
// Aggregate data from mapr log line into local (and global) group sets.
func (a *Aggregate) Aggregate(parts []string) {
- select {
- case <-a.stop:
- logger.Error("Client aggregator stopped for server, not processing new data", a.server)
- return
- default:
- }
-
groupKey := parts[0]
samples, err := strconv.Atoi(parts[1])
if err != nil {
@@ -87,14 +79,3 @@ func (a *Aggregate) makeFields(parts []string) map[string]string {
return fields
}
-
-// Stop the client side mapreduce aggregator.
-func (a *Aggregate) Stop() {
- logger.Debug("Stopping client mapreduce aggregator")
- close(a.stop)
-
- err := a.globalGroup.Merge(a.query, a.group)
- if err != nil {
- panic(err)
- }
-}
diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go
index d8f9379..e9e0d37 100644
--- a/internal/mapr/groupset.go
+++ b/internal/mapr/groupset.go
@@ -1,6 +1,7 @@
package mapr
import (
+ "context"
"errors"
"fmt"
"io/ioutil"
@@ -46,9 +47,9 @@ func (g *GroupSet) GetSet(groupKey string) *AggregateSet {
}
// Serialize the group set (e.g. to send it over the wire).
-func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) {
+func (g *GroupSet) Serialize(ctx context.Context, ch chan<- string) {
for groupKey, set := range g.sets {
- set.Serialize(groupKey, ch, stop)
+ set.Serialize(ctx, groupKey, ch)
}
}
diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go
index 5730d29..09c706b 100644
--- a/internal/mapr/logformat/parser.go
+++ b/internal/mapr/logformat/parser.go
@@ -1,9 +1,9 @@
package logformat
import (
- "github.com/mimecast/dtail/internal/logger"
"errors"
"fmt"
+ "github.com/mimecast/dtail/internal/io/logger"
"os"
"reflect"
"strings"
diff --git a/internal/mapr/query.go b/internal/mapr/query.go
index 3805d15..0127be3 100644
--- a/internal/mapr/query.go
+++ b/internal/mapr/query.go
@@ -1,9 +1,9 @@
package mapr
import (
- "github.com/mimecast/dtail/internal/logger"
"errors"
"fmt"
+ "github.com/mimecast/dtail/internal/io/logger"
"strconv"
"strings"
"time"
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 900756e..922dcbd 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -1,26 +1,28 @@
package server
import (
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/fs"
- "github.com/mimecast/dtail/internal/logger"
- "github.com/mimecast/dtail/internal/mapr"
- "github.com/mimecast/dtail/internal/mapr/logformat"
+ "context"
"os"
"strings"
"time"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/mapr"
+ "github.com/mimecast/dtail/internal/mapr/logformat"
)
// Aggregate is for aggregating mapreduce data on the DTail server side.
type Aggregate struct {
// Log lines to process (parsing MAPREDUCE lines).
- Lines chan fs.LineRead
+ Lines chan line.Line
// Hostname of the current server (used to populate $hostname field).
hostname string
- // Signals to exit goroutine.
- stop chan struct{}
// Signals to serialize data.
serialize chan struct{}
+ // Signals to flush data.
+ flush chan struct{}
// The mapr query
query *mapr.Query
// The mapr log format parser
@@ -28,7 +30,7 @@ type Aggregate struct {
}
// NewAggregate return a new server side aggregator.
-func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) {
+func NewAggregate(queryStr string) (*Aggregate, error) {
query, err := mapr.NewQuery(queryStr)
if err != nil {
return nil, err
@@ -47,76 +49,98 @@ func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error)
}
a := Aggregate{
- Lines: make(chan fs.LineRead, 100),
- stop: make(chan struct{}),
+ Lines: make(chan line.Line, 100),
serialize: make(chan struct{}),
+ flush: make(chan struct{}),
hostname: s[0],
query: query,
parser: logParser,
}
- go a.periodicAggregateTimer()
-
- fieldsCh := make(chan map[string]string)
- go a.readFields(fieldsCh, maprLines)
- go a.readLines(fieldsCh)
-
return &a, nil
}
-func (a *Aggregate) periodicAggregateTimer() {
+// Start an aggregation run.
+func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
+ fieldsCh := a.linesToFields(ctx)
+ go a.fieldsToMaprLines(ctx, fieldsCh, maprLines)
+ a.periodicAggregateTimer(ctx)
+}
+
+func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
for {
select {
case <-time.After(a.query.Interval):
- a.Serialize()
- case <-a.stop:
+ a.Serialize(ctx)
+ case <-ctx.Done():
return
}
}
}
-func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) {
- group := mapr.NewGroupSet()
+func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string {
+ fieldsCh := make(chan map[string]string)
- for {
- select {
- case fields := <-fieldsCh:
- a.aggregate(group, fields)
- case <-a.serialize:
- logger.Info("Serializing mapreduce result")
- group.Serialize(maprLines, a.stop)
- logger.Info("Done serializing mapreduce result")
- group = mapr.NewGroupSet()
- case <-a.stop:
- return
+ go func() {
+ defer close(fieldsCh)
+
+ for {
+ select {
+ case line, ok := <-a.Lines:
+ if !ok {
+ return
+ }
+
+ maprLine := strings.TrimSpace(string(line.Content))
+ fields, err := a.parser.MakeFields(maprLine)
+
+ if err != nil {
+ logger.Error(err)
+ continue
+ }
+ if !a.query.WhereClause(fields) {
+ continue
+ }
+
+ select {
+ case fieldsCh <- fields:
+ case <-ctx.Done():
+ }
+ case <-ctx.Done():
+ return
+ }
}
- }
+ }()
+
+ return fieldsCh
}
-func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) {
+func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
+ group := mapr.NewGroupSet()
+
for {
select {
- case line, ok := <-a.Lines:
+ case fields, ok := <-fieldsCh:
if !ok {
+ logger.Info("Serializing mapreduce result (final)")
+ group.Serialize(ctx, maprLines)
+ group = mapr.NewGroupSet()
+ logger.Info("Done serializing mapreduce result (final)")
return
}
-
- maprLine := strings.TrimSpace(string(line.Content))
- fields, err := a.parser.MakeFields(maprLine)
-
- if err != nil {
- logger.Error(err)
- continue
- }
- if !a.query.WhereClause(fields) {
- continue
- }
-
- select {
- case fieldsCh <- fields:
- case <-a.stop:
- }
- case <-a.stop:
+ a.aggregate(group, fields)
+ case <-a.serialize:
+ logger.Info("Serializing mapreduce result")
+ group.Serialize(ctx, maprLines)
+ group = mapr.NewGroupSet()
+ logger.Info("Done serializing mapreduce result")
+ case <-a.flush:
+ logger.Info("Flushing mapreduce result")
+ group.Serialize(ctx, maprLines)
+ group = mapr.NewGroupSet()
+ a.flush <- struct{}{}
+ logger.Info("Done flushing mapreduce result")
+ case <-ctx.Done():
return
}
}
@@ -157,14 +181,15 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
}
// Serialize all the aggregated data.
-func (a *Aggregate) Serialize() {
+func (a *Aggregate) Serialize(ctx context.Context) {
select {
case a.serialize <- struct{}{}:
- case <-a.stop:
+ case <-ctx.Done():
}
}
-// Close the aggregator.
-func (a *Aggregate) Close() {
- close(a.stop)
+// Flush all data.
+func (a *Aggregate) Flush() {
+ a.flush <- struct{}{}
+ <-a.flush
}
diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go
index e1f4e5b..ab46bed 100644
--- a/internal/mapr/wherecondition.go
+++ b/internal/mapr/wherecondition.go
@@ -1,9 +1,9 @@
package mapr
import (
- "github.com/mimecast/dtail/internal/logger"
"errors"
"fmt"
+ "github.com/mimecast/dtail/internal/io/logger"
"strconv"
"strings"
)
diff --git a/internal/omode/mode.go b/internal/omode/mode.go
index 57366d2..e29aacc 100644
--- a/internal/omode/mode.go
+++ b/internal/omode/mode.go
@@ -12,7 +12,7 @@ const (
GrepClient Mode = iota
MapClient Mode = iota
HealthClient Mode = iota
- ExecClient Mode = iota
+ RunClient Mode = iota
)
func (m Mode) String() string {
@@ -29,8 +29,8 @@ func (m Mode) String() string {
return "map"
case HealthClient:
return "health"
- case ExecClient:
- return "exec"
+ case RunClient:
+ return "run"
default:
return "unknown"
}
diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go
index f78bcf6..c6d11ca 100644
--- a/internal/pprof/pprof.go
+++ b/internal/pprof/pprof.go
@@ -7,9 +7,10 @@ import (
_ "net/http/pprof"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
)
+// Start the profiler HTTP server.
func Start() {
bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort)
logger.Info("Starting PProf server", bindAddr)
diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go
index 76a2726..a438d33 100644
--- a/internal/prompt/prompt.go
+++ b/internal/prompt/prompt.go
@@ -2,8 +2,8 @@ package prompt
import (
"bufio"
- "github.com/mimecast/dtail/internal/logger"
"fmt"
+ "github.com/mimecast/dtail/internal/io/logger"
"os"
"strings"
)
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
index 482f759..a33a78b 100644
--- a/internal/server/handlers/controlhandler.go
+++ b/internal/server/handlers/controlhandler.go
@@ -1,33 +1,34 @@
package handlers
import (
+ "context"
"fmt"
"io"
"os"
"strings"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
)
// ControlHandler is used for control functions and health monitoring.
type ControlHandler struct {
- serverMessages chan string
- pong chan struct{}
- stop chan struct{}
- payload []byte
+ ctx context.Context
+ done chan struct{}
hostname string
+ payload []byte
+ serverMessages chan string
user *user.User
}
// NewControlHandler returns a new control handler.
-func NewControlHandler(user *user.User) *ControlHandler {
+func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) {
logger.Debug(user, "Creating control handler")
h := ControlHandler{
+ ctx: ctx,
+ done: make(chan struct{}),
serverMessages: make(chan string, 10),
- pong: make(chan struct{}, 10),
- stop: make(chan struct{}),
user: user,
}
@@ -38,7 +39,8 @@ func NewControlHandler(user *user.User) *ControlHandler {
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h
+
+ return &h, h.done
}
// Read is to send data to the client via the Reader interface.
@@ -49,11 +51,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) {
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
- case <-h.pong:
- logger.Info(h.user, "Sending pong")
- n = copy(p, []byte(".pong\n"))
- return
- case <-h.stop:
+ case <-h.ctx.Done():
return 0, io.EOF
}
}
@@ -65,7 +63,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(wholePayload)
+ h.handleCommand(h.ctx, wholePayload)
h.payload = nil
default:
@@ -77,17 +75,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) {
return
}
-// Close the control handler.
-func (h *ControlHandler) Close() {
- close(h.stop)
-}
-
-// Wait returns the handler stop channel.
-func (h *ControlHandler) Wait() <-chan struct{} {
- return h.stop
-}
-
-func (h *ControlHandler) handleCommand(command string) {
+func (h *ControlHandler) handleCommand(ctx context.Context, command string) {
logger.Info(h.user, command)
s := strings.Split(command, " ")
logger.Debug(h.user, "Receiving command", command, s)
@@ -96,8 +84,6 @@ func (h *ControlHandler) handleCommand(command string) {
case "health":
h.serverMessages <- "OK: DTail SSH Server seems fine"
h.serverMessages <- "done;"
- case "ping":
- h.pong <- struct{}{}
case "debug":
h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s)
default:
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
index 8b1f73e..c42ceb9 100644
--- a/internal/server/handlers/handler.go
+++ b/internal/server/handlers/handler.go
@@ -5,6 +5,4 @@ import "io"
// Handler interface for server side functionality.
type Handler interface {
io.ReadWriter
- Close()
- Wait() <-chan struct{}
}
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
new file mode 100644
index 0000000..10372da
--- /dev/null
+++ b/internal/server/handlers/mapcommand.go
@@ -0,0 +1,35 @@
+package handlers
+
+import (
+ "context"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/mapr/server"
+)
+
+// Map command implements the mapreduce command server side.
+type mapCommand struct {
+ aggregate *server.Aggregate
+ server *ServerHandler
+}
+
+// NewMapCommand returns a new server side mapreduce command.
+func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) {
+ mapCommand := mapCommand{
+ server: serverHandler,
+ }
+
+ queryStr := strings.Join(args[1:], " ")
+ aggregate, err := server.NewAggregate(queryStr)
+ if err != nil {
+ return mapCommand, nil, err
+ }
+
+ mapCommand.aggregate = aggregate
+ return mapCommand, aggregate, nil
+
+}
+
+func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
+ m.aggregate.Start(ctx, aggregatedMessages)
+}
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
new file mode 100644
index 0000000..e4079e8
--- /dev/null
+++ b/internal/server/handlers/readcommand.go
@@ -0,0 +1,158 @@
+package handlers
+
+import (
+ "context"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/io/fs"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/omode"
+)
+
+type readCommand struct {
+ server *ServerHandler
+ mode omode.Mode
+}
+
+func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand {
+ return &readCommand{
+ server: server,
+ mode: mode,
+ }
+}
+
+func (r *readCommand) Start(ctx context.Context, argc int, args []string) {
+ regex := "."
+ if argc >= 4 {
+ regex = args[3]
+ }
+ if argc < 3 {
+ r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ return
+ }
+ r.readGlob(ctx, args[1], regex)
+}
+
+func (r *readCommand) readGlob(ctx context.Context, glob string, regex string) {
+ retryInterval := time.Second * 5
+ glob = filepath.Clean(glob)
+
+ maxRetries := 10
+ for {
+ maxRetries--
+ if maxRetries < 0 {
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
+ return
+ }
+
+ paths, err := filepath.Glob(glob)
+ if err != nil {
+ logger.Warn(r.server.user, glob, err)
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ if numPaths := len(paths); numPaths == 0 {
+ logger.Error(r.server.user, "No such file(s) to read", glob)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ r.readFiles(ctx, paths, glob, regex, retryInterval)
+ break
+ }
+}
+
+func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, regex string, retryInterval time.Duration) {
+ var wg sync.WaitGroup
+ wg.Add(len(paths))
+
+ for _, path := range paths {
+ go r.readFileIfPermissions(ctx, &wg, path, glob, regex)
+ }
+
+ wg.Wait()
+}
+
+func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGroup, path, glob, regex string) {
+ defer wg.Done()
+ globID := r.makeGlobID(path, glob)
+
+ if !r.server.user.HasFilePermission(path, "readfiles") {
+ logger.Error(r.server.user, "No permission to read file", path, globID)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ return
+ }
+
+ r.readFile(ctx, path, globID, regex)
+}
+
+func (r *readCommand) readFile(ctx context.Context, path, globID, regex string) {
+ logger.Info(r.server.user, "Start reading file", path, globID)
+
+ var reader fs.FileReader
+ switch r.mode {
+ case omode.TailClient:
+ reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter)
+ case omode.GrepClient, omode.CatClient:
+ reader = fs.NewCatFile(path, globID, r.server.serverMessages, r.server.catLimiter)
+ default:
+ reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter)
+ }
+
+ lines := r.server.lines
+
+ // Plug in mappreduce engine
+ if r.server.aggregate != nil {
+ lines = r.server.aggregate.Lines
+ }
+
+ for {
+ if err := reader.Start(ctx, lines, regex); err != nil {
+ logger.Error(r.server.user, path, globID, err)
+ }
+
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ if !reader.Retry() {
+ return
+ }
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Info(path, globID, "Reading file again")
+ }
+}
+
+func (r *readCommand) makeGlobID(path, glob string) string {
+ var idParts []string
+ pathParts := strings.Split(path, "/")
+
+ for i, globPart := range strings.Split(glob, "/") {
+ if strings.Contains(globPart, "*") {
+ idParts = append(idParts, pathParts[i])
+ }
+ }
+
+ if len(idParts) > 0 {
+ return strings.Join(idParts, "/")
+ }
+
+ if len(pathParts) > 0 {
+ return pathParts[len(pathParts)-1]
+ }
+
+ r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob))
+ return ""
+}
diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go
new file mode 100644
index 0000000..e260060
--- /dev/null
+++ b/internal/server/handlers/runcommand.go
@@ -0,0 +1,73 @@
+package handlers
+
+import (
+ "context"
+ "fmt"
+ "os/exec"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/run"
+)
+
+type runCommand struct {
+ server *ServerHandler
+ run run.Run
+}
+
+func newRunCommand(server *ServerHandler) runCommand {
+ return runCommand{
+ server: server,
+ }
+}
+
+func (r runCommand) Start(ctx context.Context, argc int, args []string) {
+ if argc < 2 {
+ r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ return
+ }
+ commands := strings.Split(strings.Join(args[1:], " "), ";")
+ r.start(ctx, commands)
+}
+
+func (r runCommand) start(ctx context.Context, commands []string) {
+ for _, command := range commands {
+ command = strings.TrimSpace(command)
+ if len(command) == 0 {
+ continue
+ }
+ splitted := strings.Split(command, " ")
+ path := splitted[0]
+ args := splitted[1:]
+
+ qualifiedPath, err := exec.LookPath(path)
+ if err != nil {
+ logger.Error(r.server.user, err)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs"))
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1))
+ return
+ }
+
+ if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") {
+ logger.Error(r.server.user, "No permission to execute path", qualifiedPath)
+ r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs"))
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1))
+ return
+ }
+
+ r.run = run.New(qualifiedPath, args)
+ pid, ec, err := r.run.Start(ctx, r.server.lines)
+
+ if err != nil {
+ message := fmt.Sprintf("Unable to execute remote command '%s'", command)
+ logger.Error(r.server.user, message, ec, pid, err)
+ r.server.sendServerMessage(logger.Error(message, ec, pid, err))
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", ec))
+ return
+ }
+
+ message := fmt.Sprintf("Remote process '%d' exited with status '%d'", pid, ec)
+ r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec))
+ r.server.sendServerMessage(logger.Info("run", pid, ec, message))
+ }
+}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index bed8609..3f0d6ce 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -1,17 +1,19 @@
package handlers
import (
+ "context"
+ "encoding/base64"
+ "errors"
"fmt"
"io"
"os"
- "path/filepath"
"strings"
"sync"
"time"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/fs"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
user "github.com/mimecast/dtail/internal/user/server"
@@ -26,51 +28,33 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- // Local log file readers
- fileReaders []fs.FileReader
- fileReadersMtx *sync.Mutex
- // Channel for read lines.
- lines chan fs.LineRead
- // Only process log lines matching this regex.
- regex string
- // Server side mapr log aggregation.
- aggregate *server.Aggregate
- // Channel of aggregated log lines.
+ mutex *sync.Mutex
+ lines chan line.Line
+ regex string
+ aggregate *server.Aggregate
aggregatedMessages chan string
- // Channel for server messages to be sent to the client.
- serverMessages chan string
- // Channel for hidden messages to be sent to the client.
- hiddenMessages chan string
- // The current payload sent to the client.
- payload []byte
- // The current server hostname.
- hostname string
- // The user connecting to dtail.
- user *user.User
- // To limit the server wide max amount of concurrent cats
- catLimiter chan struct{}
- // To limit the server wide max amount of concurrent tails
- tailLimiter chan struct{}
- // Server can tell handler to stop the handler.
- stop chan struct{}
- // Indicate that client responded to server with "ack stop connection"
- ackStopReceived chan struct{}
- // Stop timeout.
- stopTimeout chan struct{}
+ serverMessages chan string
+ payload []byte
+ hostname string
+ user *user.User
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ ackCloseReceived chan struct{}
+ ctx context.Context
+ done chan struct{}
+ activeReaders int
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler {
- logger.Debug(user, "Creating tail handler")
+func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) {
h := ServerHandler{
- fileReadersMtx: &sync.Mutex{},
- lines: make(chan fs.LineRead, 100),
+ ctx: ctx,
+ done: make(chan struct{}),
+ mutex: &sync.Mutex{},
+ lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
- hiddenMessages: make(chan string, 10),
- ackStopReceived: make(chan struct{}),
- stopTimeout: make(chan struct{}),
- stop: make(chan struct{}),
+ ackCloseReceived: make(chan struct{}),
catLimiter: catLimiter,
tailLimiter: tailLimiter,
regex: ".",
@@ -85,37 +69,46 @@ func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter cha
s := strings.Split(fqdn, ".")
h.hostname = s[0]
- return &h
+ return &h, h.done
}
// Read is to send data to the dtail client via Reader interface.
func (h *ServerHandler) Read(p []byte) (n int, err error) {
for {
select {
+
case message := <-h.serverMessages:
+ if message[0] == '.' {
+ // Handle hidden message (don't display to the user, interpreted by dtail client)
+ wholePayload := []byte(fmt.Sprintf("%s\n", message))
+ n = copy(p, wholePayload)
+ return
+ }
+
+ // Handle normal server message (display to the user)
wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
n = copy(p, wholePayload)
return
+
case message := <-h.aggregatedMessages:
+ // Send mapreduce-aggregated data as a message.
data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message)
- //logger.Debug("Sending aggregation data", data)
wholePayload := []byte(data)
n = copy(p, wholePayload)
return
- case message := <-h.hiddenMessages:
- //logger.Debug(h.user, "Sending hidden message", message)
- wholePayload := []byte(fmt.Sprintf(".%s\n", message))
- n = copy(p, wholePayload)
- return
+
case line := <-h.lines:
+ // Send normal file content data as a message.
serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
- h.hostname, line.TransmittedPerc, line.Count, *line.GlobID))
+ h.hostname, line.TransmittedPerc, line.Count, line.SourceID))
wholePayload := append(serverInfo, line.Content[:]...)
n = copy(p, wholePayload)
return
+
case <-time.After(time.Second):
+ // Once in a while check whether we are done.
select {
- case <-h.stop:
+ case <-h.ctx.Done():
return 0, io.EOF
default:
}
@@ -129,7 +122,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
switch c {
case ';':
commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(commandStr)
+ h.handleCommand(h.ctx, commandStr)
h.payload = nil
default:
h.payload = append(h.payload, c)
@@ -140,210 +133,167 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) {
return
}
-// Close the server handler.
-func (h *ServerHandler) Close() {
- h.fileReadersMtx.Lock()
- defer h.fileReadersMtx.Unlock()
+func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
+ logger.Debug(h.user, commandStr)
- for _, reader := range h.fileReaders {
- reader.Stop()
+ args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
}
- if h.aggregate != nil {
- h.aggregate.Close()
+
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
}
- close(h.stop)
-}
+ if h.user.Name == config.ControlUser {
+ h.handleControlCommand(argc, args)
+ return
+ }
-func (h *ServerHandler) makeGlobID(path, glob string) string {
- var idParts []string
- pathParts := strings.Split(path, "/")
+ h.handleUserCommand(ctx, argc, args)
+}
- for i, globPart := range strings.Split(glob, "/") {
- if strings.Contains(globPart, "*") {
- idParts = append(idParts, pathParts[i])
- }
- }
+func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
+ argc := len(args)
- if len(idParts) > 0 {
- return strings.Join(idParts, "/")
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, errors.New("unable to determine protocol version")
}
- if len(pathParts) > 0 {
- return pathParts[len(pathParts)-1]
+ if args[1] != version.ProtocolCompat {
+ err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
+ return args, argc, err
}
- h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob))
- return ""
+ return args[2:], argc - 2, nil
}
-func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) {
- retryInterval := time.Second * 5
- glob = filepath.Clean(glob)
-
- errors := make(chan struct{})
- stop := make(chan struct{})
- defer close(stop)
+func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("Unable to decode client message")
- go func() {
- for {
- select {
- case <-errors:
- h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs"))
- case <-stop:
- return
- case <-h.stop:
- return
- }
- }
- }()
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
- maxRetries := 10
- for {
- maxRetries--
- if maxRetries < 0 {
- h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)"))
- h.internalClose()
- return
- }
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
- paths, err := filepath.Glob(glob)
- if err != nil {
- logger.Warn(h.user, glob, err)
- time.Sleep(retryInterval)
- continue
- }
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
- if numPaths := len(paths); numPaths == 0 {
- logger.Error(h.user, "No such file(s) to read", glob)
- select {
- case errors <- struct{}{}:
- case <-h.stop:
- return
- default:
- }
- time.Sleep(retryInterval)
- continue
- }
+ return args, argc, nil
+}
- h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors)
- break
+func (h *ServerHandler) handleControlCommand(argc int, args []string) {
+ switch args[0] {
+ case "debug":
+ h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
+ default:
+ logger.Warn(h.user, "Received unknown command", argc, args)
}
}
-func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) {
- var wg sync.WaitGroup
- wg.Add(len(paths))
+func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
+ logger.Debug(h.user, "handleUserCommand", argc, args)
- read := func(path string, wg *sync.WaitGroup) {
- defer wg.Done()
- globID := h.makeGlobID(path, glob)
+ switch args[0] {
+ case "grep", "cat":
+ command := newReadCommand(h, omode.CatClient)
+ h.incrementActiveReaders()
+ go func() {
+ command.Start(ctx, argc, args)
+ if h.decrementActiveReaders() == 0 {
+ h.shutdown()
+ }
+ }()
- if !h.user.HasFilePermission(path) {
- logger.Error(h.user, "No permission to read file", path, globID)
- select {
- case errors <- struct{}{}:
- default:
+ case "tail":
+ command := newReadCommand(h, omode.TailClient)
+ h.incrementActiveReaders()
+ go func() {
+ command.Start(ctx, argc, args)
+ if h.decrementActiveReaders() == 0 {
+ h.shutdown()
}
+ }()
+
+ case "map":
+ command, aggregate, err := newMapCommand(h, argc, args)
+ if err != nil {
+ h.sendServerMessage(err.Error())
+ logger.Error(h.user, err)
return
}
- h.startReadingFile(mode, path, globID, regex)
- }
-
- for _, path := range paths {
- go read(path, &wg)
- }
+ h.aggregate = aggregate
+ go func() {
+ command.Start(ctx, h.aggregatedMessages)
+ h.shutdown()
+ }()
+
+ case "run":
+ command := newRunCommand(h)
+ h.incrementActiveReaders()
+ go func() {
+ command.Start(ctx, argc, args)
+ if h.decrementActiveReaders() == 0 {
+ h.shutdown()
+ }
+ }()
- wg.Wait()
-}
+ case "ack", ".ack":
+ h.handleAckCommand(argc, args)
-func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) {
- defer h.stopReadingFile(path)
- logger.Info(h.user, "Start reading file", path, globID)
-
- var reader fs.FileReader
- switch mode {
- case omode.TailClient:
- reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
- case omode.GrepClient:
- fallthrough
- case omode.CatClient:
- reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter)
default:
- reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
+ h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args))
}
+}
- h.fileReadersMtx.Lock()
- h.fileReaders = append(h.fileReaders, reader)
- h.fileReadersMtx.Unlock()
-
- lines := h.lines
- // Plugin mappreduce engine
- if h.aggregate != nil {
- lines = h.aggregate.Lines
+func (h *ServerHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc))
+ return
}
-
- for {
- if err := reader.Start(lines, regex); err != nil {
- logger.Error(h.user, path, globID, err)
- }
-
- select {
- case <-h.stop:
- return
- default:
- if !reader.Retry() {
- return
- }
- }
-
- time.Sleep(time.Second * 2)
- logger.Info(path, globID, "Reading file again")
+ if args[1] == "close" && args[2] == "connection" {
+ close(h.ackCloseReceived)
}
}
-func (h *ServerHandler) stopReadingFile(path string) {
- logger.Info(h.user, "Stop reading file", path)
+func (h *ServerHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.ctx.Done():
+ }
+}
- h.fileReadersMtx.Lock()
- defer h.fileReadersMtx.Unlock()
+func (h *ServerHandler) sendServerMessage(message string) {
+ h.send(h.serverMessageC(), message)
+}
- path = filepath.Clean(path)
- var fileReaders []fs.FileReader
+func (h *ServerHandler) serverMessageC() chan<- string {
+ return h.serverMessages
+}
- for _, reader := range h.fileReaders {
- if reader.FilePath() == path {
- reader.Stop()
- continue
- }
- fileReaders = append(fileReaders, reader)
- }
+func (h *ServerHandler) flush() {
+ logger.Debug(h.user, "flush()")
- if len(fileReaders) == len(h.fileReaders) {
- logger.Warn(h.user, "Didn't read file path", path)
- return
+ if h.aggregate != nil {
+ h.aggregate.Flush()
}
- h.fileReaders = fileReaders
-
- if len(fileReaders) == 0 {
- if h.aggregate != nil {
- h.aggregate.Serialize()
- }
- h.allLinesSent()
+ unsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
}
-}
-
-func (h *ServerHandler) numUnsentMessages() int {
- return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages)
-}
-
-func (h *ServerHandler) allLinesSent() {
- defer h.internalClose()
for i := 0; i < 3; i++ {
- if h.numUnsentMessages() == 0 {
+ if unsentMessages() == 0 {
logger.Debug(h.user, "All lines sent")
return
}
@@ -351,142 +301,43 @@ func (h *ServerHandler) allLinesSent() {
time.Sleep(time.Second)
}
- logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages())
+ logger.Warn(h.user, "Some lines remain unsent", unsentMessages())
}
-// Handler decides to shutdown the connection, not the server itself.
-func (h *ServerHandler) internalClose() {
- select {
- case h.hiddenMessages <- "syn close connection":
- case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Not waiting for ack close connection")
- close(h.stopTimeout)
- return
- }
+func (h *ServerHandler) shutdown() {
+ logger.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessageC() <- ".syn close connection":
+ case <-h.ctx.Done():
+ }
+ }()
select {
- case <-h.Wait():
+ case <-h.ackCloseReceived:
case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Not waiting for ack close connection")
- close(h.stopTimeout)
- }
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- logger.Info(h.user, commandStr)
-
- args := strings.Split(commandStr, " ")
- argc := len(args)
-
- logger.Debug(h.user, "Received command", commandStr, argc, args)
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
+ logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.ctx.Done():
}
- h.handleUserCommand(argc, args)
-}
-
-// Special (restricted) set of commands for anonymous ControlUser access.
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "ping":
- h.send(h.hiddenMessages, "pong")
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
- default:
- logger.Warn(h.user, "Received unknown command", argc, args)
- }
-}
-
-// Commands for authed users.
-func (h *ServerHandler) handleUserCommand(argc int, args []string) {
- switch args[0] {
- case "grep":
- fallthrough
- case "cat":
- h.handleReadCommand(argc, args, omode.CatClient)
- case "tail":
- h.handleReadCommand(argc, args, omode.TailClient)
- case "map":
- h.handleMapCommand(argc, args)
- case "ack":
- h.handleAckCommand(argc, args)
- case "ping":
- h.send(h.hiddenMessages, "pong")
- case "version":
- h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String()))
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args))
+ select {
+ case h.done <- struct{}{}:
default:
- h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args))
}
}
-func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) {
- regex := "."
- if argc >= 4 {
- regex = args[3]
- }
- if argc < 3 {
- h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- go h.processFileGlob(mode, args[1], regex)
+func (h *ServerHandler) incrementActiveReaders() {
+ // TODO: Use atomic counter variable instead, so we can get rid of the mutex
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ h.activeReaders++
}
+func (h *ServerHandler) decrementActiveReaders() int {
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ h.activeReaders--
-func (h *ServerHandler) handleMapCommand(argc int, args []string) {
- if argc < 2 {
- h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
-
- queryStr := strings.Join(args[1:], " ")
- logger.Info(h.user, "Creating new mapr aggregator", queryStr)
- aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr)
-
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- h.aggregate = aggregate
-}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- close(h.ackStopReceived)
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.stop:
- }
-}
-
-// Wait (block) until server handler is closed or a timeout has exceeded.
-func (h *ServerHandler) Wait() <-chan struct{} {
- wait := make(chan struct{})
-
- go func() {
- select {
- case <-h.ackStopReceived:
- logger.Debug(h.user, "Closing wait channel due to ACK stop received")
- close(wait)
- case <-h.stopTimeout:
- logger.Debug(h.user, "Closing wait channel due to wait timeout")
- close(wait)
- case <-h.stop:
- logger.Debug(h.user, "Closing wait channel due to stop")
- }
- }()
-
- return wait
+ return h.activeReaders
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 27a98f5..42eb74c 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -1,13 +1,14 @@
package server
import (
+ "context"
"errors"
"fmt"
"io"
"net"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -26,8 +27,6 @@ type Server struct {
catLimiterCh chan struct{}
// To control the max amount of concurrent tails
tailLimiterCh chan struct{}
- // Ask to shutdown the server
- stop chan struct{}
}
// New returns a new server.
@@ -38,7 +37,6 @@ func New() *Server {
sshServerConfig: &gossh.ServerConfig{},
catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
- stop: make(chan struct{}),
}
s.sshServerConfig.PasswordCallback = s.controlUserCallback
@@ -54,7 +52,7 @@ func New() *Server {
}
// Start the server.
-func (s *Server) Start() int {
+func (s *Server) Start(ctx context.Context) int {
logger.Info("Starting server")
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
@@ -64,7 +62,7 @@ func (s *Server) Start() int {
logger.FatalExit("Failed to open listening TCP socket", err)
}
- go s.stats.periodicLogServerStats(s.stop)
+ go s.stats.periodicLogServerStats(ctx)
for {
conn, err := listener.Accept() // Blocking
@@ -79,11 +77,11 @@ func (s *Server) Start() int {
continue
}
- go s.handleConnection(conn)
+ go s.handleConnection(ctx, conn)
}
}
-func (s *Server) handleConnection(conn net.Conn) {
+func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
logger.Info("Handling connection")
sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
@@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) {
go gossh.DiscardRequests(reqs)
for newChannel := range chans {
- go s.handleChannel(sshConn, newChannel)
+ go s.handleChannel(ctx, sshConn, newChannel)
}
}
-func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) {
+func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) {
user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
logger.Info(user, "Invoking channel handler")
@@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel)
return
}
- if err := s.handleRequests(sshConn, requests, channel, user); err != nil {
+ if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil {
logger.Error(user, err)
sshConn.Close()
}
}
-func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
logger.Info(user, "Invoking request handler")
for req := range in {
@@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch
switch req.Type {
case "shell":
+ handlerCtx, cancel := context.WithCancel(ctx)
+
var handler handlers.Handler
+ var done <-chan struct{}
+
switch user.Name {
case config.ControlUser:
- handler = handlers.NewControlHandler(user)
+ handler, done = handlers.NewControlHandler(handlerCtx, user)
default:
- handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh)
+ handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh)
}
- // Bi-directionally connect SSH stream to SSH handler
- brokenPipe1 := make(chan struct{})
go func() {
- defer close(brokenPipe1)
+ // Handler finished work, cancel all remaining routines
+ defer cancel()
+ <-done
+ }()
+
+ go func() {
+ // Broken pipe, cancel
+ defer cancel()
+
io.Copy(channel, handler)
}()
- brokenPipe2 := make(chan struct{})
go func() {
- defer close(brokenPipe2)
+ // Broken pipe, cancel
+ defer cancel()
+
io.Copy(handler, channel)
}()
- // Ensure to close all fd's and stop all goroutines once ssh connection terminated
go func() {
- defer s.stats.decrementConnections()
- defer handler.Close()
+ defer cancel()
if err := sshConn.Wait(); err != nil && err != io.EOF {
logger.Error(user, err)
}
+ s.stats.decrementConnections()
logger.Info(user, "Good bye Mister!")
}()
- // Close the underlying ssh socket when server shuts down
go func() {
- select {
- case <-s.stop:
- logger.Debug(user, "Server initiating shutdown on handler")
- case <-handler.Wait():
- logger.Debug(user, "Handler initiating shutdown by its own")
- case <-brokenPipe1:
- logger.Debug(user, "Broken pipe1")
- case <-brokenPipe2:
- logger.Debug(user, "Broken pipe2")
- }
+ <-handlerCtx.Done()
sshConn.Close()
logger.Info(user, "Closed SSH connection")
}()
@@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g
return nil, fmt.Errorf("Not authorized")
}
-
-// Stop the server.
-func (s *Server) Stop() {
- close(s.stop)
- s.stats.waitForConnections()
-}
diff --git a/internal/server/stats.go b/internal/server/stats.go
index beb1885..4d661f7 100644
--- a/internal/server/stats.go
+++ b/internal/server/stats.go
@@ -1,12 +1,14 @@
package server
import (
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "context"
"fmt"
"runtime"
"sync"
"time"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/logger"
)
// Used to collect and display various server stats.
@@ -65,12 +67,12 @@ func (s *stats) serverLimitExceeded() error {
return nil
}
-func (s *stats) periodicLogServerStats(stop <-chan struct{}) {
+func (s *stats) periodicLogServerStats(ctx context.Context) {
for {
select {
case <-time.NewTimer(time.Second * 10).C:
s.logServerStats()
- case <-stop:
+ case <-ctx.Done():
return
}
}
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index 3392eb1..967866f 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -2,7 +2,7 @@ package client
import (
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/ssh"
"os"
diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go
index 4023e59..7ae2396 100644
--- a/internal/ssh/client/hostkeycallback.go
+++ b/internal/ssh/client/hostkeycallback.go
@@ -2,8 +2,7 @@ package client
import (
"bufio"
- "github.com/mimecast/dtail/internal/logger"
- "github.com/mimecast/dtail/internal/prompt"
+ "context"
"fmt"
"net"
"os"
@@ -11,6 +10,9 @@ import (
"sync"
"time"
+ "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/prompt"
+
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)
@@ -116,7 +118,7 @@ func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback {
// PromptAddHosts prompts a question to the user whether unknown hosts should
// be added to the known hosts or not.
-func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) {
+func (h *HostKeyCallback) PromptAddHosts(ctx context.Context) {
var hosts []unknownHost
for {
@@ -135,7 +137,7 @@ func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) {
h.promptAddHosts(hosts)
hosts = []unknownHost{}
}
- case <-stop:
+ case <-ctx.Done():
logger.Debug("Stopping goroutine prompting new hosts...")
return
}
diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go
index 7baa4aa..07790ad 100644
--- a/internal/ssh/server/hostkey.go
+++ b/internal/ssh/server/hostkey.go
@@ -2,7 +2,7 @@ package server
import (
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/ssh"
"io/ioutil"
"os"
diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go
index c6929d7..757def7 100644
--- a/internal/ssh/server/publickeycallback.go
+++ b/internal/ssh/server/publickeycallback.go
@@ -7,7 +7,7 @@ import (
osUser "os/user"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/io/logger"
user "github.com/mimecast/dtail/internal/user/server"
gossh "golang.org/x/crypto/ssh"
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 77cc341..3a2e416 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -4,9 +4,9 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
- "github.com/mimecast/dtail/internal/logger"
"encoding/pem"
"fmt"
+ "github.com/mimecast/dtail/internal/io/logger"
"io/ioutil"
"net"
"os"
diff --git a/internal/user/name.go b/internal/user/name.go
index 5171ec7..28ab0a4 100644
--- a/internal/user/name.go
+++ b/internal/user/name.go
@@ -2,10 +2,10 @@ package user
import (
"os/user"
- )
+)
-
-func Name() string {
+// NoRootCheck verifies that the DTail run user is not with UID or GID 0.
+func NoRootCheck() {
user, err := user.Current()
if err != nil {
panic(err)
@@ -18,7 +18,14 @@ func Name() string {
if user.Gid == "0" {
panic("Not allowed to run as GID 0")
}
+}
+
+// Name of the current run user.
+func Name() string {
+ user, err := user.Current()
+ if err != nil {
+ panic(err)
+ }
return user.Username
}
-
diff --git a/internal/user/server/user.go b/internal/user/server/user.go
index fad38d8..271a4ac 100644
--- a/internal/user/server/user.go
+++ b/internal/user/server/user.go
@@ -1,14 +1,15 @@
package server
import (
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/fs/permissions"
- "github.com/mimecast/dtail/internal/logger"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/fs/permissions"
+ "github.com/mimecast/dtail/internal/io/logger"
)
const maxLinkDepth int = 100
@@ -37,26 +38,28 @@ func (u *User) String() string {
}
// HasFilePermission is used to determine whether user is alowed to read a file.
-func (u *User) HasFilePermission(filePath string) (hasPermission bool) {
+func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) {
+ logger.Debug(u, filePath, permissionType, "Checking config permissions")
+
cleanPath, err := filepath.EvalSymlinks(filePath)
if err != nil {
- logger.Error(u, filePath, "Unable to evaluate symlinks", err)
+ logger.Error(u, filePath, permissionType, "Unable to evaluate symlinks", err)
hasPermission = false
return
}
cleanPath, err = filepath.Abs(cleanPath)
if err != nil {
- logger.Error(u, cleanPath, "Unable to make file path absolute", err)
+ logger.Error(u, cleanPath, permissionType, "Unable to make file path absolute", err)
hasPermission = false
return
}
if cleanPath != filePath {
- logger.Info(u, filePath, cleanPath, "Calculated new clean path from original file path (possibly symlink)")
+ logger.Info(u, filePath, cleanPath, permissionType, "Calculated new clean path from original file path (possibly symlink)")
}
- hasPermission, err = u.hasFilePermission(cleanPath)
+ hasPermission, err = u.hasFilePermission(cleanPath, permissionType)
if err != nil {
logger.Warn(u, cleanPath, err)
}
@@ -64,12 +67,12 @@ func (u *User) HasFilePermission(filePath string) (hasPermission bool) {
return
}
-func (u *User) hasFilePermission(cleanPath string) (bool, error) {
+func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error) {
// First check file system Linux/UNIX permission.
if _, err := permissions.ToRead(u.Name, cleanPath); err != nil {
- return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err)
+ return false, fmt.Errorf("User without OS file system permissions to read path: '%v'", err)
}
- logger.Info(u, cleanPath, "User has OS file system permissions to read file")
+ logger.Info(u, cleanPath, permissionType, "User with OS file system permissions to path")
// If file system permission is given, also check permissions
// as configured in DTail config file.
@@ -84,7 +87,7 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) {
var hasPermission bool
var err error
- if hasPermission, err = u.iteratePaths(cleanPath); err != nil {
+ if hasPermission, err = u.iteratePaths(cleanPath, permissionType); err != nil {
return false, err
}
@@ -101,17 +104,28 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) {
return hasPermission, nil
}
-func (u *User) iteratePaths(cleanPath string) (bool, error) {
+func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) {
for _, permission := range u.permissions {
+ typeStr := "readfiles" // Assume ReadFiles by default.
+
var regexStr string
var negate bool
+ splitted := strings.Split(permission, ":")
+ if len(splitted) > 1 {
+ typeStr = splitted[0]
+ permission = strings.Join(splitted[1:], ":")
+ }
+
+ if typeStr != permissionType {
+ continue
+ }
+
+ regexStr = permission
if strings.HasPrefix(permission, "!") {
regexStr = permission[1:]
negate = true
}
- regexStr = permission
- negate = false
re, err := regexp.Compile(regexStr)
if err != nil {
diff --git a/internal/version/version.go b/internal/version/version.go
index 3a4a5dc..3c057df 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -7,18 +7,20 @@ import (
"github.com/mimecast/dtail/internal/color"
)
-// Name of DTail.
-const Name = "DTail"
-
-// Version of DTail.
-const Version = "1.1.0"
-
-// Additional information.
-const Additional = ""
+const (
+ // Name of DTail.
+ Name string = "DTail"
+ // Version of DTail.
+ Version string = "2.0.0"
+ // Additional information for DTail
+ Additional string = ""
+ // ProtocolCompat -ibility version.
+ ProtocolCompat string = "2"
+)
// String representation of the DTail version.
func String() string {
- return fmt.Sprintf("%s v%v %s", Name, Version, Additional)
+ return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, ProtocolCompat, Additional)
}
// PaintedString is a prettier string representation of the DTail version.
@@ -30,7 +32,7 @@ func PaintedString() string {
version := color.Paint(color.Blue, Version)
descr := color.Paint(color.Green, Additional)
- return fmt.Sprintf("%s %v %s", name, version, descr)
+ return fmt.Sprintf("%s %v Protocol %s %s", name, version, ProtocolCompat, descr)
}
// PrintAndExit prints the program version and exists.