summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Bütow <pbuetow@mimecast.com>2020-01-20 18:41:05 +0000
committerPaul Bütow <pbuetow@mimecast.com>2020-01-21 14:35:23 +0000
commitc128865c4c7411c29a59fca9a3a2f95537686d7b (patch)
tree193bccc70d942c8b70cc93fae2670263701e43aa /internal
parent3755a9911ecb05886577095f2b8cc8b9e4066a3a (diff)
Move commands to cmd/ and move internal dependencies to internal/
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/args.go18
-rw-r--r--internal/clients/baseclient.go137
-rw-r--r--internal/clients/catclient.go53
-rw-r--r--internal/clients/client.go7
-rw-r--r--internal/clients/connectionmaker.go12
-rw-r--r--internal/clients/grepclient.go53
-rw-r--r--internal/clients/handlers/basehandler.go134
-rw-r--r--internal/clients/handlers/clienthandler.go26
-rw-r--r--internal/clients/handlers/handler.go12
-rw-r--r--internal/clients/handlers/healthhandler.go75
-rw-r--r--internal/clients/handlers/maprhandler.go74
-rw-r--r--internal/clients/healthclient.go95
-rw-r--r--internal/clients/maprclient.go152
-rw-r--r--internal/clients/remote/connection.go230
-rw-r--r--internal/clients/stats.go81
-rw-r--r--internal/clients/tailclient.go49
-rw-r--r--internal/color/color.go70
-rw-r--r--internal/color/colorfy.go58
-rw-r--r--internal/config/client.go11
-rw-r--r--internal/config/common.go42
-rw-r--r--internal/config/config.go45
-rw-r--r--internal/config/read.go37
-rw-r--r--internal/config/server.go66
-rw-r--r--internal/discovery/comma.go12
-rw-r--r--internal/discovery/discovery.go173
-rw-r--r--internal/discovery/file.go28
-rw-r--r--internal/fs/catfile.go27
-rw-r--r--internal/fs/filereader.go9
-rw-r--r--internal/fs/lineread.go28
-rw-r--r--internal/fs/permissions/permission.go14
-rw-r--r--internal/fs/permissions/permission_linux.c395
-rw-r--r--internal/fs/permissions/permission_linux.go33
-rw-r--r--internal/fs/permissions/permission_linux.h60
-rw-r--r--internal/fs/permissions/permission_test.go112
-rw-r--r--internal/fs/readfile.go318
-rw-r--r--internal/fs/stats.go69
-rw-r--r--internal/fs/tailfile.go27
-rw-r--r--internal/logger/logger.go457
-rw-r--r--internal/mapr/aggregateset.go185
-rw-r--r--internal/mapr/client/aggregate.go100
-rw-r--r--internal/mapr/globalgroupset.go100
-rw-r--r--internal/mapr/groupset.go178
-rw-r--r--internal/mapr/logformat/default.go23
-rw-r--r--internal/mapr/logformat/default_test.go35
-rw-r--r--internal/mapr/logformat/parser.go75
-rw-r--r--internal/mapr/query.go245
-rw-r--r--internal/mapr/query_test.go149
-rw-r--r--internal/mapr/selectcondition.go96
-rw-r--r--internal/mapr/server/aggregate.go170
-rw-r--r--internal/mapr/token.go108
-rw-r--r--internal/mapr/wherecondition.go193
-rw-r--r--internal/omode/mode.go81
-rw-r--r--internal/pprof/pprof.go17
-rw-r--r--internal/prompt/prompt.go95
-rw-r--r--internal/server/handlers/controlhandler.go106
-rw-r--r--internal/server/handlers/handler.go10
-rw-r--r--internal/server/handlers/serverhandler.go492
-rw-r--r--internal/server/server.go212
-rw-r--r--internal/server/stats.go88
-rw-r--r--internal/ssh/client/authmethods.go45
-rw-r--r--internal/ssh/client/hostkeycallback.go285
-rw-r--r--internal/ssh/server/hostkey.go37
-rw-r--r--internal/ssh/server/publickeycallback.go62
-rw-r--r--internal/ssh/ssh.go112
-rw-r--r--internal/user/name.go24
-rw-r--r--internal/user/server/user.go131
-rw-r--r--internal/version/version.go40
67 files changed, 6793 insertions, 0 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go
new file mode 100644
index 0000000..5fe0a72
--- /dev/null
+++ b/internal/clients/args.go
@@ -0,0 +1,18 @@
+package clients
+
+import (
+ "github.com/mimecast/dtail/internal/omode"
+)
+
+// Args is a helper struct to summarize common client arguments.
+type Args struct {
+ Mode omode.Mode
+ ServersStr string
+ UserName string
+ Files string
+ Regex string
+ TrustAllHosts bool
+ Discovery string
+ ConnectionsPerCPU int
+ PingTimeout int
+}
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
new file mode 100644
index 0000000..574ae94
--- /dev/null
+++ b/internal/clients/baseclient.go
@@ -0,0 +1,137 @@
+package clients
+
+import (
+ "regexp"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/discovery"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// This is the main client data structure.
+type baseClient struct {
+ Args
+ // To display client side stats
+ stats *stats
+ // List of remote servers to connect to.
+ servers []string
+ // We have one connection per remote server.
+ connections []*remote.Connection
+ // SSH auth methods to use to connect to the remote servers.
+ sshAuthMethods []gossh.AuthMethod
+ // To deal with SSH host keys
+ hostKeyCallback *client.HostKeyCallback
+ // To stop the client.
+ stop chan struct{}
+ // To indicate that the client has stopped.
+ stopped chan struct{}
+ // Throttle how fast we initiate SSH connections concurrently
+ throttleCh chan struct{}
+ // Retry connection upon failure?
+ retry bool
+ // Connection helper.
+ maker connectionMaker
+}
+
+func (c *baseClient) init(maker connectionMaker) {
+ logger.Info("Initiating base client")
+
+ c.maker = maker
+ //c.connections = make(map[string]*remote.Connection)
+ c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh)
+
+ // Retrieve a shuffled list of remote dtail servers.
+ shuffleServers := true
+ discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers)
+ for _, server := range discoveryService.ServerList() {
+ c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
+ }
+
+ if _, err := regexp.Compile(c.Regex); err != nil {
+ logger.FatalExit(c.Regex, "Can't test compile regex", err)
+ }
+
+ // Periodically check for unknown hosts, and ask the user whether to trust them or not.
+ go c.hostKeyCallback.PromptAddHosts(c.stop)
+
+ // Periodically print out connection stats to the client.
+ c.stats = newTailStats(len(c.connections))
+ go c.stats.periodicLogStats(c.throttleCh, c.stop)
+}
+
+func (c *baseClient) Start() (status int) {
+ active := make(chan struct{}, len(c.connections))
+
+ var wg sync.WaitGroup
+ wg.Add(len(c.connections))
+
+ for i, conn := range c.connections {
+ go func(i int, conn *remote.Connection) {
+ active <- struct{}{}
+ defer func() {
+ logger.Debug(conn.Server, "Disconnected completely...")
+ <-active
+ }()
+ wg.Done()
+
+ for {
+ conn.Start(c.throttleCh, c.stats.connectionsEstCh)
+ if !c.retry {
+ return
+ }
+ time.Sleep(time.Second * 2)
+ logger.Debug(conn.Server, "Reconencting")
+ conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
+ c.connections[i] = conn
+ }
+ }(i, conn)
+ }
+
+ wg.Wait()
+ c.waitUntilDone(active)
+
+ return
+}
+
+func (c *baseClient) waitUntilDone(active chan struct{}) {
+ defer close(c.stopped)
+
+ if c.Mode != omode.TailClient {
+ c.waitUntilZero(active)
+ logger.Info("All connections stopped")
+ return
+ }
+
+ <-c.stop
+ logger.Info("Stopping client")
+ for _, conn := range c.connections {
+ conn.Stop()
+ }
+
+ c.waitUntilZero(active)
+}
+
+func (c *baseClient) waitUntilZero(active chan struct{}) {
+ for {
+ logger.Debug("Active connections", len(active))
+ if len(active) == 0 {
+ return
+ }
+ time.Sleep(time.Second)
+ }
+}
+
+func (c *baseClient) Stop() {
+ close(c.stop)
+ <-c.WaitC()
+}
+
+func (c *baseClient) WaitC() <-chan struct{} {
+ return c.stopped
+}
diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go
new file mode 100644
index 0000000..5ea701d
--- /dev/null
+++ b/internal/clients/catclient.go
@@ -0,0 +1,53 @@
+package clients
+
+import (
+ "errors"
+ "fmt"
+ "runtime"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// CatClient is a client for returning a whole file from the beginning to the end.
+type CatClient struct {
+ baseClient
+}
+
+// NewCatClient returns a new cat client.
+func NewCatClient(args Args) (*CatClient, error) {
+ if args.Regex != "" {
+ return nil, errors.New("Can't use regex with 'cat' operating mode")
+ }
+
+ args.Regex = "."
+ args.Mode = omode.CatClient
+
+ c := CatClient{
+ baseClient: baseClient{
+ Args: args,
+ stop: make(chan struct{}),
+ stopped: make(chan struct{}),
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: false,
+ },
+ }
+
+ c.init(c)
+
+ return &c, nil
+}
+
+func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+ for _, file := range strings.Split(c.Files, ",") {
+ conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+ }
+ return conn
+}
diff --git a/internal/clients/client.go b/internal/clients/client.go
new file mode 100644
index 0000000..85d1aae
--- /dev/null
+++ b/internal/clients/client.go
@@ -0,0 +1,7 @@
+package clients
+
+// Client is the interface for the end user command line client.
+type Client interface {
+ Start() int
+ Stop()
+}
diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go
new file mode 100644
index 0000000..0617992
--- /dev/null
+++ b/internal/clients/connectionmaker.go
@@ -0,0 +1,12 @@
+package clients
+
+import (
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+type connectionMaker interface {
+ makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection
+}
diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go
new file mode 100644
index 0000000..c568f63
--- /dev/null
+++ b/internal/clients/grepclient.go
@@ -0,0 +1,53 @@
+package clients
+
+import (
+ "errors"
+ "fmt"
+ "runtime"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed.
+type GrepClient struct {
+ baseClient
+}
+
+// NewGrepClient creates a new grep client.
+func NewGrepClient(args Args) (*GrepClient, error) {
+ if args.Regex == "" {
+ return nil, errors.New("No regex specified, use '-regex' flag")
+ }
+ args.Mode = omode.GrepClient
+
+ c := GrepClient{
+ baseClient: baseClient{
+ Args: args,
+ stop: make(chan struct{}),
+ stopped: make(chan struct{}),
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: false,
+ },
+ }
+
+ c.init(c)
+
+ return &c, nil
+}
+
+func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+
+ for _, file := range strings.Split(c.Files, ",") {
+ conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+ }
+
+ return conn
+}
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
new file mode 100644
index 0000000..19246f9
--- /dev/null
+++ b/internal/clients/handlers/basehandler.go
@@ -0,0 +1,134 @@
+package handlers
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "errors"
+ "fmt"
+ "io"
+ "strings"
+ "time"
+)
+
+type baseHandler struct {
+ server string
+ shellStarted bool
+ commands chan string
+ pong chan struct{}
+ receiveBuf []byte
+ stop chan struct{}
+ pingTimeout int
+}
+
+func (h *baseHandler) Server() string {
+ return h.server
+}
+
+// Used to determine whether server is still responding to requests or not.
+func (h *baseHandler) Ping() error {
+ if h.pingTimeout == 0 {
+ // Server ping disabled
+ return nil
+ }
+
+ if err := h.SendCommand("ping"); err != nil {
+ return err
+ }
+
+ select {
+ case <-h.pong:
+ return nil
+ case <-time.After(time.Duration(h.pingTimeout) * time.Second):
+ }
+
+ return errors.New("Didn't receive any server pongs (ping replies)")
+}
+
+func (h *baseHandler) SendCommand(command string) error {
+ if command == "ping" {
+ logger.Trace("Sending command", h.server, command)
+ } else {
+ logger.Debug("Sending command", h.server, command)
+ }
+
+ select {
+ case h.commands <- fmt.Sprintf("%s;", command):
+ case <-time.After(time.Second * 5):
+ return errors.New("Timed out sending command " + command)
+ case <-h.stop:
+ }
+
+ return nil
+}
+
+// Read data from the dtail server via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ h.receiveBuf = append(h.receiveBuf, b)
+ if b == '\n' {
+ if len(h.receiveBuf) == 0 {
+ continue
+ }
+ message := string(h.receiveBuf)
+ h.handleMessageType(message)
+ }
+ }
+
+ return len(p), nil
+}
+
+// Send data to the dtail server via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ select {
+ case command := <-h.commands:
+ n = copy(p, []byte(command))
+ case <-h.stop:
+ return 0, io.EOF
+ }
+ return
+}
+
+// Handle various message types.
+func (h *baseHandler) handleMessageType(message string) {
+ if len(h.receiveBuf) == 0 {
+ return
+ }
+ // Hidden server commands starti with a dot "."
+ if h.receiveBuf[0] == '.' {
+ h.handleHiddenMessage(message)
+ h.receiveBuf = h.receiveBuf[:0]
+ return
+ }
+
+ // Silent mode will only print out remote logs but not remote server
+ // commands. But remote server commands will be still logged to ./log/.
+ if logger.Mode == logger.SilentMode {
+ if h.receiveBuf[0] == 'R' {
+ logger.Raw(message)
+ }
+ h.receiveBuf = h.receiveBuf[:0]
+ return
+ }
+ logger.Raw(message)
+ h.receiveBuf = h.receiveBuf[:0]
+}
+
+// Handle messages received from server which are not meant to be displayed
+// to the end user.
+func (h *baseHandler) handleHiddenMessage(message string) {
+ switch {
+ case strings.HasPrefix(message, ".pong"):
+ h.pong <- struct{}{}
+ case strings.HasPrefix(message, ".syn close connection"):
+ h.SendCommand("ack close connection")
+ }
+}
+
+// Stop the handler.
+func (h *baseHandler) Stop() {
+ select {
+ case <-h.stop:
+ default:
+ logger.Debug("Stopping base handler", h.server)
+ close(h.stop)
+ }
+}
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
new file mode 100644
index 0000000..4738cd3
--- /dev/null
+++ b/internal/clients/handlers/clienthandler.go
@@ -0,0 +1,26 @@
+package handlers
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+)
+
+// ClientHandler is the basic client handler interface.
+type ClientHandler struct {
+ baseHandler
+}
+
+// NewClientHandler creates a new client handler.
+func NewClientHandler(server string, pingTimeout int) *ClientHandler {
+ logger.Debug(server, "Creating new client handler")
+
+ return &ClientHandler{
+ baseHandler{
+ server: server,
+ shellStarted: false,
+ commands: make(chan string),
+ pong: make(chan struct{}, 1),
+ stop: make(chan struct{}),
+ pingTimeout: pingTimeout,
+ },
+ }
+}
diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go
new file mode 100644
index 0000000..2013be0
--- /dev/null
+++ b/internal/clients/handlers/handler.go
@@ -0,0 +1,12 @@
+package handlers
+
+import "io"
+
+// Handler provides all methods which can be run on any client handler.
+type Handler interface {
+ io.ReadWriter
+ Ping() error
+ Stop()
+ SendCommand(command string) error
+ Server() string
+}
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
new file mode 100644
index 0000000..4051e2c
--- /dev/null
+++ b/internal/clients/handlers/healthhandler.go
@@ -0,0 +1,75 @@
+package handlers
+
+import (
+ "errors"
+ "fmt"
+ "time"
+)
+
+// HealthHandler implements the handler required for health checks.
+type HealthHandler struct {
+ // Buffer of incoming data from server.
+ receiveBuf []byte
+ // To send commands to the server.
+ commands chan string
+ // To receive messages from the server.
+ receive chan<- string
+ // The remote server address
+ server string
+}
+
+// NewHealthHandler returns a new health check handler.
+func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
+ h := HealthHandler{
+ server: server,
+ receive: receive,
+ commands: make(chan string),
+ }
+
+ return &h
+}
+
+// Server returns the remote server name.
+func (h *HealthHandler) Server() string {
+ return h.server
+}
+
+// Stop is not of use for health check handler.
+func (h *HealthHandler) Stop() {
+ // Nothing done here.
+}
+
+// Ping is not of use for health check handler.
+func (h *HealthHandler) Ping() error {
+ return nil
+}
+
+// SendCommand send a DTail command to the server.
+func (h *HealthHandler) SendCommand(command string) error {
+ select {
+ case h.commands <- fmt.Sprintf("%s;", command):
+ case <-time.NewTimer(time.Second * 10).C:
+ return errors.New("Timed out sending command " + command)
+ }
+
+ return nil
+}
+
+// Server writes byte stream to client.
+func (h *HealthHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ h.receiveBuf = append(h.receiveBuf, b)
+ if b == '\n' {
+ h.receive <- string(h.receiveBuf)
+ h.receiveBuf = h.receiveBuf[:0]
+ }
+ }
+
+ return len(p), nil
+}
+
+// Server reads byte stream from client.
+func (h *HealthHandler) Read(p []byte) (n int, err error) {
+ n = copy(p, []byte(<-h.commands))
+ return
+}
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
new file mode 100644
index 0000000..d76cdfd
--- /dev/null
+++ b/internal/clients/handlers/maprhandler.go
@@ -0,0 +1,74 @@
+package handlers
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/mapr"
+ "github.com/mimecast/dtail/internal/mapr/client"
+ "strings"
+)
+
+// MaprHandler is the handler used on the client side for running mapreduce aggregations.
+type MaprHandler struct {
+ baseHandler
+ aggregate *client.Aggregate
+ query *mapr.Query
+ count uint64
+}
+
+// NewMaprHandler returns a new mapreduce client handler.
+func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler {
+ return &MaprHandler{
+ baseHandler: baseHandler{
+ server: server,
+ shellStarted: false,
+ commands: make(chan string),
+ pong: make(chan struct{}, 1),
+ stop: make(chan struct{}),
+ pingTimeout: pingTimeout,
+ },
+ query: query,
+ aggregate: client.NewAggregate(server, query, globalGroup),
+ }
+}
+
+// Read data from the dtail server via Writer interface.
+func (h *MaprHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ h.baseHandler.receiveBuf = append(h.baseHandler.receiveBuf, b)
+ if b == '\n' {
+ if len(h.baseHandler.receiveBuf) == 0 {
+ continue
+ }
+ message := string(h.baseHandler.receiveBuf)
+
+ if h.baseHandler.receiveBuf[0] == 'A' {
+ h.handleAggregateMessage(strings.TrimSpace(message))
+ h.baseHandler.receiveBuf = h.baseHandler.receiveBuf[:0]
+ continue
+ }
+ h.baseHandler.handleMessageType(message)
+ }
+ }
+
+ return len(p), nil
+}
+
+// Handle a message received from server including mapr aggregation
+// related data.
+func (h *MaprHandler) handleAggregateMessage(message string) {
+ h.count++
+ parts := strings.Split(message, "|")
+
+ // Index 0 contains 'AGGREGATE', 1 contains server host.
+ // Aggregation data begins from index 2.
+ logger.Debug("Received aggregate data", h.server, h.count)
+ h.aggregate.Aggregate(parts[2:])
+ logger.Debug("Aggregated aggregate data", h.server, h.count)
+}
+
+// Stop stops the mapreduce client handler.
+func (h *MaprHandler) Stop() {
+ logger.Debug("Stopping mapreduce handler", h.server)
+ h.aggregate.Stop()
+ h.baseHandler.Stop()
+}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
new file mode 100644
index 0000000..ff13b83
--- /dev/null
+++ b/internal/clients/healthclient.go
@@ -0,0 +1,95 @@
+package clients
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/omode"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// HealthClient is used for health checking (e.g. via Nagios)
+type HealthClient struct {
+ // Client operating mode
+ mode omode.Mode
+ // The remote server address
+ server string
+ // SSH user name
+ userName string
+ // SSH auth methods to use to connect to the remote servers.
+ sshAuthMethods []gossh.AuthMethod
+}
+
+// NewHealthClient returns a new healh client.
+func NewHealthClient(mode omode.Mode) (*HealthClient, error) {
+ c := HealthClient{
+ mode: mode,
+ server: fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort),
+ userName: config.ControlUser,
+ }
+ c.initSSHAuthMethods()
+
+ return &c, nil
+}
+
+// Start the health client.
+func (c *HealthClient) Start() (status int) {
+ receive := make(chan string)
+
+ throttleCh := make(chan struct{}, runtime.NumCPU())
+ statsCh := make(chan struct{}, 1)
+
+ conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods)
+ conn.Handler = handlers.NewHealthHandler(c.server, receive)
+ conn.Commands = []string{c.mode.String()}
+
+ go conn.Start(throttleCh, statsCh)
+ defer conn.Stop()
+
+ for {
+ select {
+ case data := <-receive:
+ // Parse recieved data.
+ s := strings.Split(data, "|")
+ message := s[len(s)-1]
+ if strings.HasPrefix(message, "done;") {
+ return
+ }
+
+ // Set severity.
+ s = strings.Split(message, ":")
+ switch s[0] {
+ case "OK":
+ case "WARNING":
+ if status < 1 {
+ status = 1
+ }
+ case "CRITICAL":
+ status = 2
+ case "UNKNOWN":
+ status = 3
+ default:
+ fmt.Printf("CRITICAL: Unexpected server response: '%s'\n", message)
+ status = 2
+ return
+ }
+ fmt.Print(message)
+
+ case <-time.After(time.Second * 2):
+ status = 2
+ fmt.Println("CRITICAL: Could not communicate with DTail server")
+ return
+ }
+ }
+}
+
+// Initialize SSH auth methods.
+func (c *HealthClient) initSSHAuthMethods() {
+ c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser))
+}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
new file mode 100644
index 0000000..9070827
--- /dev/null
+++ b/internal/clients/maprclient.go
@@ -0,0 +1,152 @@
+package clients
+
+import (
+ "errors"
+ "fmt"
+ "runtime"
+ "strings"
+ "time"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/mapr"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// MaprClient is used for running mapreduce aggregations on remote files.
+type MaprClient struct {
+ baseClient
+ // Query string for mapr aggregations
+ queryStr string
+ // Global group set for merged mapr aggregation results
+ globalGroup *mapr.GlobalGroupSet
+ // The query object (constructed from queryStr)
+ query *mapr.Query
+ // Additative result or new result every run?
+ additative bool
+}
+
+// NewMaprClient returns a new mapreduce client.
+func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
+ if queryStr == "" {
+ return nil, errors.New("No mapreduce query specified, use '-query' flag")
+ }
+
+ c := MaprClient{
+ baseClient: baseClient{
+ Args: args,
+ stop: make(chan struct{}),
+ stopped: make(chan struct{}),
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: args.Mode == omode.TailClient,
+ },
+ queryStr: queryStr,
+ additative: args.Mode == omode.MapClient,
+ }
+
+ query, err := mapr.NewQuery(c.queryStr)
+ if err != nil {
+ logger.FatalExit(c.queryStr, "Can't parse mapr query", err)
+ }
+
+ c.query = query
+
+ switch c.query.Table {
+ case "*":
+ c.Regex = fmt.Sprintf("\\|MAPREDUCE:\\|")
+ case ".":
+ c.Regex = "."
+ default:
+ c.Regex = fmt.Sprintf("\\|MAPREDUCE:%s\\|", c.query.Table)
+ }
+
+ c.globalGroup = mapr.NewGlobalGroupSet()
+ c.baseClient.init(c)
+
+ return &c, nil
+}
+
+func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout)
+
+ conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery))
+ commandStr := "tail"
+ if c.additative {
+ commandStr = "cat"
+ }
+
+ for _, file := range strings.Split(c.Files, ",") {
+ conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex))
+ }
+
+ return conn
+}
+
+// Start starts the mapreduce client.
+func (c *MaprClient) Start() (status int) {
+ if c.query.Outfile == "" {
+ // Only print out periodic results if we don't write an outfile
+ go c.periodicPrintResults()
+ }
+
+ status = c.baseClient.Start()
+ if c.additative {
+ c.recievedFinalResult()
+ }
+ c.baseClient.Stop()
+
+ return
+}
+
+func (c *MaprClient) recievedFinalResult() {
+ logger.Info("Received final mapreduce result")
+
+ if c.query.Outfile == "" {
+ c.printResults()
+ return
+ }
+
+ logger.Info(fmt.Sprintf("Writing final mapreduce result to '%s'", c.query.Outfile))
+ err := c.globalGroup.WriteResult(c.query)
+ if err != nil {
+ logger.FatalExit(err)
+ return
+ }
+ logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile))
+}
+
+func (c *MaprClient) periodicPrintResults() {
+ for {
+ select {
+ case <-time.After(c.query.Interval):
+ logger.Info("Gathering interim mapreduce result")
+ c.printResults()
+ case <-c.baseClient.stop:
+ return
+ }
+ }
+}
+
+func (c *MaprClient) printResults() {
+ var result string
+ var err error
+ var numLines int
+
+ if c.additative {
+ result, numLines, err = c.globalGroup.Result(c.query)
+ } else {
+ result, numLines, err = c.globalGroup.SwapOut().Result(c.query)
+ }
+ if err != nil {
+ logger.FatalExit(err)
+ }
+ if numLines > 0 {
+ logger.Raw(fmt.Sprintf("%s\n", c.query.RawQuery))
+ logger.Raw(result)
+ }
+}
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
new file mode 100644
index 0000000..bfc7bc5
--- /dev/null
+++ b/internal/clients/remote/connection.go
@@ -0,0 +1,230 @@
+package remote
+
+import (
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/ssh/client"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// Connection represents a client connection connection to a single server.
+type Connection struct {
+ // The remote server's hostname connected to.
+ Server string
+ // The remote server's port connected to.
+ port int
+ // The SSH client configuration used.
+ config *ssh.ClientConfig
+ // The SSH client handler to use.
+ Handler handlers.Handler
+ // DTail commands sent from client to server. When client loses
+ // connection to the server it re-connects automatically and sends the
+ // same commands again.
+ Commands []string
+ // Is it a persistent connection or a one-off?
+ isOneOff bool
+ // Used to stop the connection
+ stop chan struct{}
+ // To deal with SSH server host keys
+ hostKeyCallback *client.HostKeyCallback
+}
+
+// NewConnection returns a new connection.
+func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *Connection {
+ logger.Debug(server, "Creating new connection")
+
+ c := Connection{
+ hostKeyCallback: hostKeyCallback,
+ config: &ssh.ClientConfig{
+ User: userName,
+ Auth: authMethods,
+ HostKeyCallback: hostKeyCallback.Wrap(),
+ Timeout: time.Second * 3,
+ },
+ stop: make(chan struct{}),
+ }
+
+ c.initServerPort(server)
+
+ return &c
+}
+
+// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit).
+func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection {
+ c := Connection{
+ config: &ssh.ClientConfig{
+ User: userName,
+ Auth: authMethods,
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+ },
+ stop: make(chan struct{}),
+ isOneOff: true,
+ }
+
+ c.initServerPort(server)
+
+ return &c
+}
+
+// Attempt to parse the server port address from the provided server FQDN.
+func (c *Connection) initServerPort(server string) {
+ c.Server = server
+ c.port = config.Common.SSHPort
+ parts := strings.Split(server, ":")
+
+ if len(parts) == 2 {
+ logger.Debug("Parsing port from hostname", parts)
+ port, err := strconv.Atoi(parts[1])
+ if err != nil {
+ logger.FatalExit("Unable to parse client port", server, parts, err)
+ }
+ c.Server = parts[0]
+ c.port = port
+ }
+}
+
+// Start the server connection. Build up SSH session and send some DTail commandc.
+func (c *Connection) Start(throttleCh, statsCh chan struct{}) {
+ select {
+ case <-c.stop:
+ logger.Info(c.Server, c.port, "Disconnecting client")
+ return
+ default:
+ }
+
+ // Wait for SSH connection throttler
+ throttleCh <- struct{}{}
+
+ // Wait until connection has been initiated or an error occured
+ // during initialization.
+ throttleStopCh := make(chan struct{}, 2)
+ go func() {
+ <-throttleStopCh
+ <-throttleCh
+ }()
+
+ if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil {
+ logger.Warn(c.Server, c.port, err)
+ throttleStopCh <- struct{}{}
+
+ if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
+ logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port)
+ return
+ }
+ }
+}
+
+// Dail into a new SSH connection. Close connection in case of an error.
+func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error {
+ statsCh <- struct{}{}
+ defer func() { <-statsCh }()
+
+ logger.Debug(host, "dial")
+ address := fmt.Sprintf("%s:%d", host, port)
+
+ client, err := ssh.Dial("tcp", address, c.config)
+ if err != nil {
+ return err
+ }
+ defer client.Close()
+
+ return c.session(client, throttleStopCh)
+}
+
+// Create the SSH session. Close the session in case of an error.
+func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error {
+ logger.Debug(c.Server, "session")
+
+ session, err := client.NewSession()
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+
+ return c.handle(session, throttleStopCh)
+}
+
+// Handle the SSH session. Also send periodic pings to the server in order
+// to determine that session is still intact.
+func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error {
+ defer c.Handler.Stop()
+
+ logger.Debug(c.Server, "handle")
+
+ stdinPipe, err := session.StdinPipe()
+ if err != nil {
+ return err
+ }
+
+ stdoutPipe, err := session.StdoutPipe()
+ if err != nil {
+ return err
+ }
+
+ if err := session.Shell(); err != nil {
+ return err
+ }
+
+ // Establish Bi-directional pipe between SSH session and client handler.
+ brokenStdinPipe := make(chan struct{})
+ go func() {
+ defer close(brokenStdinPipe)
+ io.Copy(stdinPipe, c.Handler)
+ }()
+
+ brokenStdoutPipe := make(chan struct{})
+ go func() {
+ defer close(brokenStdoutPipe)
+ io.Copy(c.Handler, stdoutPipe)
+ }()
+
+ // SSH session established, other goroutine can initiate session now.
+ throttleStopCh <- struct{}{}
+
+ // Send all commands to client.
+ for _, command := range c.Commands {
+ logger.Debug(command)
+ c.Handler.SendCommand(command)
+ }
+
+ if !c.isOneOff {
+ return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe)
+ }
+
+ <-c.stop
+
+ // Normal shutdown, all fine
+ return nil
+}
+
+// Periodically check whether connection is still alive or not.
+func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error {
+ for {
+ select {
+ case <-time.After(time.Second * 3):
+ if err := c.Handler.Ping(); err != nil {
+ return err
+ }
+ case <-brokenStdinPipe:
+ logger.Debug("Broken stdin pipe", c.Server, c.port)
+ return nil
+ case <-brokenStdoutPipe:
+ logger.Debug("Broken stdout pipe", c.Server, c.port)
+ return nil
+ case <-c.stop:
+ return nil
+ }
+ }
+}
+
+// Stop the connection.
+func (c *Connection) Stop() {
+ close(c.stop)
+}
diff --git a/internal/clients/stats.go b/internal/clients/stats.go
new file mode 100644
index 0000000..d36cef6
--- /dev/null
+++ b/internal/clients/stats.go
@@ -0,0 +1,81 @@
+package clients
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "fmt"
+ "runtime"
+ "sync"
+ "time"
+)
+
+// Used to collect and display various client stats.
+type stats struct {
+ // Total amount servers to connect to.
+ connectionsTotal int
+ // To keep track of what connected and disconnected
+ connectionsEstCh chan struct{}
+ // Amount of servers connections are established.
+ connected int
+ // To synchronize concurrent access.
+ mutex sync.Mutex
+}
+
+func newTailStats(connectionsTotal int) *stats {
+ return &stats{
+ connectionsTotal: connectionsTotal,
+ connectionsEstCh: make(chan struct{}, connectionsTotal),
+ connected: 0,
+ }
+}
+
+func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) {
+ connectedLast := 0
+ statsInterval := 5
+
+ for {
+ select {
+ case <-time.After(time.Second * time.Duration(statsInterval)):
+ case <-stop:
+ return
+ }
+
+ connected := len(s.connectionsEstCh)
+ throttle := len(throttleCh)
+
+ newConnections := connected - connectedLast
+ connectionsPerSecond := float64(newConnections) / float64(statsInterval)
+ s.log(connected, newConnections, connectionsPerSecond, throttle)
+
+ connectedLast = connected
+
+ s.mutex.Lock()
+ s.connected = connected
+ s.mutex.Unlock()
+ }
+}
+
+func (s *stats) numConnected() int {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ return s.connected
+}
+
+func (s *stats) log(connected, newConnections int, connectionsPerSecond float64, throttle int) {
+ percConnected := percentOf(float64(s.connectionsTotal), float64(connected))
+
+ connectedStr := fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected))
+ newConnStr := fmt.Sprintf("new=%d", newConnections)
+ rateStr := fmt.Sprintf("rate=%2.2f/s", connectionsPerSecond)
+ throttleStr := fmt.Sprintf("throttle=%d", throttle)
+ cpusGoroutinesStr := fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine())
+
+ logger.Info("stats", connectedStr, newConnStr, rateStr, throttleStr, cpusGoroutinesStr)
+}
+
+func percentOf(total float64, value float64) float64 {
+ if total == 0 || total == value {
+ return 100
+ }
+ return value / (total / 100.0)
+}
diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go
new file mode 100644
index 0000000..674ca36
--- /dev/null
+++ b/internal/clients/tailclient.go
@@ -0,0 +1,49 @@
+package clients
+
+import (
+ "fmt"
+ "runtime"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines).
+type TailClient struct {
+ baseClient
+}
+
+// NewTailClient returns a new TailClient.
+func NewTailClient(args Args) (*TailClient, error) {
+ args.Mode = omode.TailClient
+
+ c := TailClient{
+ baseClient: baseClient{
+ Args: args,
+ stop: make(chan struct{}),
+ stopped: make(chan struct{}),
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: true,
+ },
+ }
+
+ c.init(c)
+
+ return &c, nil
+}
+
+func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection {
+ conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
+ conn.Handler = handlers.NewClientHandler(server, c.PingTimeout)
+
+ for _, file := range strings.Split(c.Files, ",") {
+ conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
+ }
+
+ return conn
+}
diff --git a/internal/color/color.go b/internal/color/color.go
new file mode 100644
index 0000000..0736199
--- /dev/null
+++ b/internal/color/color.go
@@ -0,0 +1,70 @@
+// Package color is used to prettify console output via ANSII terminal colors.
+package color
+
+import (
+ "fmt"
+)
+
+// Color name.
+type Color string
+
+// Attribute of a color.
+type Attribute string
+
+// The possible color variations.
+const (
+ escape = "\x1b"
+ reset = escape + "[0m"
+ seq string = "%s%s%s"
+
+ Gray Color = escape + "[30m"
+ Red Color = escape + "[31m"
+ Green Color = escape + "[32m"
+ Orange Color = escape + "[33m"
+ Blue Color = escape + "[34m"
+ Magenta Color = escape + "[35m"
+ Yellow Color = escape + "[36m"
+ LightGray Color = escape + "[37m"
+
+ BgGray Color = escape + "[40m"
+ BgRed Color = escape + "[41m"
+ BgGreen Color = escape + "[42m"
+ BgOrange Color = escape + "[43m"
+ BgBlue Color = escape + "[44m"
+ BgMagenta Color = escape + "[45m"
+ BgYellow Color = escape + "[46m"
+ BgLightGray Color = escape + "[47m"
+
+ Bold Attribute = escape + "[1m"
+ Italic Attribute = escape + "[3m"
+ Underline Attribute = escape + "[4m"
+ ReverseColor Attribute = escape + "[7m"
+
+ resetBold = escape + "[22m"
+ resetItalic = escape + "[23m"
+ resetUnderline = escape + "[24m"
+
+ Test Color = BgYellow
+ TestAttr Attribute = Bold
+)
+
+// Colored DTail client output enabled.
+var Colored bool
+
+// Paint a given string in a given color.
+func Paint(c Color, s string) string {
+ return fmt.Sprintf(seq, c, s, reset)
+}
+
+// Attr adds a given attribute to a given string, such as "bold" or "italic".
+func Attr(c Attribute, s string) string {
+ switch c {
+ case Bold:
+ return fmt.Sprintf(seq, Bold, s, resetBold)
+ case Italic:
+ return fmt.Sprintf(seq, Italic, s, resetItalic)
+ case Underline:
+ return fmt.Sprintf(seq, Underline, s, resetUnderline)
+ }
+ panic("Unknown attribute")
+}
diff --git a/internal/color/colorfy.go b/internal/color/colorfy.go
new file mode 100644
index 0000000..9ae46f5
--- /dev/null
+++ b/internal/color/colorfy.go
@@ -0,0 +1,58 @@
+package color
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Add some color to log lines received from remote servers.
+func paintRemote(line string) string {
+ splitted := strings.Split(line, "|")
+ if splitted[2] == "100" {
+ splitted[2] = Paint(BgGreen, splitted[2])
+ } else {
+ splitted[2] = Paint(BgRed, splitted[2])
+ }
+ info := strings.Join(splitted[0:5], "|")
+ log := strings.Join(splitted[5:], "|")
+
+ if strings.HasPrefix(log, "WARN") {
+ log = Paint(BgYellow, log)
+ } else if strings.HasPrefix(log, "ERROR") {
+ log = Paint(BgRed, log)
+ } else if strings.HasPrefix(log, "FATAL") {
+ log = Attr(Bold, Paint(BgRed, log))
+ } else {
+ log = Paint(Blue, log)
+ }
+
+ return fmt.Sprintf("%s|%s", info, log)
+}
+
+// Add some color to stats generated by the client.
+func paintClientStats(line string) string {
+ splitted := strings.Split(line, "|")
+ first := strings.Join(splitted[0:4], "|")
+ connected := Paint(BgBlue, splitted[4])
+ last := strings.Join(splitted[5:], "|")
+
+ return fmt.Sprintf("%s|%s|%s", first, connected, last)
+}
+
+// Colorfy a given line based on the line's content.
+func Colorfy(line string) string {
+ if strings.HasPrefix(line, "REMOTE") {
+ return paintRemote(line)
+ }
+ if strings.HasPrefix(line, "CLIENT") && strings.Contains(line, "|stats|") {
+ return paintClientStats(line)
+ }
+ if strings.Contains(line, "ERROR") {
+ return Paint(Magenta, line)
+ }
+ if strings.Contains(line, "WARN") {
+ return Paint(Magenta, line)
+ }
+
+ return line
+}
diff --git a/internal/config/client.go b/internal/config/client.go
new file mode 100644
index 0000000..1515aae
--- /dev/null
+++ b/internal/config/client.go
@@ -0,0 +1,11 @@
+package config
+
+// ClientConfig represents a DTail client configuration (empty as of now as there
+// are no available config options yet, but that may changes in the future).
+type ClientConfig struct {
+}
+
+// Create a new default client configuration.
+func newDefaultClientConfig() *ClientConfig {
+ return &ClientConfig{}
+}
diff --git a/internal/config/common.go b/internal/config/common.go
new file mode 100644
index 0000000..8c07710
--- /dev/null
+++ b/internal/config/common.go
@@ -0,0 +1,42 @@
+package config
+
+// CommonConfig stores configuration keys shared by DTail server and client.
+type CommonConfig struct {
+ // The SSH server port number.
+ SSHPort int
+ // Enable experimental features.
+ ExperimentalFeaturesEnable bool `json:",omitempty"`
+ // Enable extra debug logging (used for deevlopment or debugging purpes only).
+ DebugEnable bool `json:",omitempty"`
+ // Enable extra trace logging (used for deevlopment or debugging purpes only).
+ TraceEnable bool `json:",omitempty"`
+ // The log strategy to use, one of
+ // stdout: only log to stdout (useful when used with systemd)
+ // daily: create a log file for every day
+ LogStrategy string
+ // The log directory
+ LogDir string
+ // The cache directory
+ CacheDir string
+ // Do we want to enable pperf http server?
+ PProfEnable bool `json:",omitempty"`
+ // The HTTP port used by PProf
+ PProfPort int `json:",omitempty"`
+ // The PProf HTTP server bind address
+ PProfBindAddress string `json:",omitempty"`
+}
+
+// Create a new default configuration.
+func newDefaultCommonConfig() *CommonConfig {
+ return &CommonConfig{
+ SSHPort: 2222,
+ DebugEnable: false,
+ TraceEnable: false,
+ ExperimentalFeaturesEnable: false,
+ LogDir: "log",
+ CacheDir: "cache",
+ PProfEnable: false,
+ PProfPort: 6060,
+ PProfBindAddress: "0.0.0.0",
+ }
+}
diff --git a/internal/config/config.go b/internal/config/config.go
new file mode 100644
index 0000000..0f26635
--- /dev/null
+++ b/internal/config/config.go
@@ -0,0 +1,45 @@
+package config
+
+import (
+ "encoding/json"
+ "io/ioutil"
+ "os"
+)
+
+// ControlUser is used for various DTail specific operations.
+const ControlUser string = "DTAIL-CONTROL-USER"
+
+// Client holds a DTail client configuration.
+var Client *ClientConfig
+
+// Server holds a DTail server configuration.
+var Server *ServerConfig
+
+// Common holds common configs of both both, client and server.
+var Common *CommonConfig
+
+// Used to initialize the configuration.
+type configInitializer struct {
+ Common *CommonConfig
+ Server *ServerConfig
+ Client *ClientConfig
+}
+
+// Parse and read a given config file in JSON format.
+func (c *configInitializer) parseConfig(configFile string) {
+ fd, err := os.Open(configFile)
+ if err != nil {
+ panic(err)
+ }
+ defer fd.Close()
+
+ cfgBytes, err := ioutil.ReadAll(fd)
+ if err != nil {
+ panic(err)
+ }
+
+ err = json.Unmarshal([]byte(cfgBytes), c)
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/config/read.go b/internal/config/read.go
new file mode 100644
index 0000000..a4e605b
--- /dev/null
+++ b/internal/config/read.go
@@ -0,0 +1,37 @@
+package config
+
+import (
+ "os"
+)
+
+// Read the DTail configuration.
+func Read(configFile string, sshPort int) {
+ initializer := configInitializer{
+ Common: newDefaultCommonConfig(),
+ Server: newDefaultServerConfig(),
+ Client: newDefaultClientConfig(),
+ }
+
+ if configFile == "" {
+ configFile = "./cfg/dtail.json"
+ }
+
+ if _, err := os.Stat(configFile); !os.IsNotExist(err) {
+ initializer.parseConfig(configFile)
+ }
+
+ // Assign pointers to global variables, so that we can access the
+ // configuration from any place of the program.
+ Common = initializer.Common
+ Server = initializer.Server
+ Client = initializer.Client
+
+ if Server.MapreduceLogFormat == "" {
+ Server.MapreduceLogFormat = "default"
+ }
+
+ // If non-standard port specified, overwrite config
+ if sshPort != 2222 {
+ Common.SSHPort = sshPort
+ }
+}
diff --git a/internal/config/server.go b/internal/config/server.go
new file mode 100644
index 0000000..7883b33
--- /dev/null
+++ b/internal/config/server.go
@@ -0,0 +1,66 @@
+package config
+
+import (
+ "errors"
+)
+
+// Permissions map. Each SSH user has a list of permissions which
+// log files it is allowed to follow and which ones not.
+type Permissions struct {
+ // The default user permissions.
+ Default []string
+ // The per user special permissions.
+ Users map[string][]string
+}
+
+// ServerConfig represents the server configuration.
+type ServerConfig struct {
+ // The SSH server bind port.
+ SSHBindAddress string
+ // The max amount of concurrent user connection allowed to connect to the server.
+ MaxConnections int
+ // The max amount of concurrent cats per server.
+ MaxConcurrentCats int
+ // The max amount of concurrent tails per server.
+ MaxConcurrentTails int
+ // The user permissions.
+ Permissions Permissions `json:",omitempty"`
+ // The mapr log format
+ MapreduceLogFormat string `json:",omitempty"`
+ // The default path of the server host key
+ HostKeyFile string
+ // The host key size in bits
+ HostKeyBits int
+}
+
+// Create a new default server configuration.
+func newDefaultServerConfig() *ServerConfig {
+ defaultPermissions := []string{"^/.*"}
+ defaultBindAddress := "0.0.0.0"
+
+ return &ServerConfig{
+ SSHBindAddress: defaultBindAddress,
+ MaxConnections: 10,
+ MaxConcurrentCats: 2,
+ MaxConcurrentTails: 50,
+ HostKeyFile: "./cache/ssh_host_key",
+ HostKeyBits: 4096,
+ Permissions: Permissions{
+ Default: defaultPermissions,
+ },
+ }
+}
+
+// ServerUserPermissions retrieves the permission set of a given user.
+func ServerUserPermissions(userName string) (permissions []string, err error) {
+ permissions = Server.Permissions.Default
+ if p, ok := Server.Permissions.Users[userName]; ok {
+ permissions = p
+ }
+
+ if len(permissions) == 0 {
+ err = errors.New("Empty set of permission, user won't be able to open any files")
+ }
+
+ return
+}
diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go
new file mode 100644
index 0000000..ad18be0
--- /dev/null
+++ b/internal/discovery/comma.go
@@ -0,0 +1,12 @@
+package discovery
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "strings"
+)
+
+// ServerListFromCOMMA retrieves a list of servers from comma separated input list.
+func (d *Discovery) ServerListFromCOMMA() []string {
+ logger.Debug("Retrieving server list from comma separated list", d.server)
+ return strings.Split(d.server, ",")
+}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
new file mode 100644
index 0000000..d76c1b2
--- /dev/null
+++ b/internal/discovery/discovery.go
@@ -0,0 +1,173 @@
+package discovery
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "fmt"
+ "math/rand"
+ "os"
+ "reflect"
+ "regexp"
+ "strings"
+ "time"
+)
+
+// Discovery method for discovering a list of available DTail servers.
+type Discovery struct {
+ // To plug in a custom server discovery module.
+ module string
+ // To specifiy optional server discovery module options.
+ options string
+ // To either filter a server list or to secify an exact list.
+ server string
+ // To filter server list.
+ regex *regexp.Regexp
+ // To shuffle resulting server list.
+ shuffle bool
+}
+
+// New returns a new discovery method.
+func New(method, server string, shuffle bool) *Discovery {
+ module := method
+ options := ""
+
+ if strings.Contains(module, ":") {
+ s := strings.Split(module, ":")
+ if len(s) != 2 {
+ logger.FatalExit("Unable to parse discovery module", module)
+ }
+ module = s[0]
+ options = s[1]
+ }
+
+ d := Discovery{
+ module: strings.ToUpper(module),
+ options: options,
+ server: server,
+ shuffle: shuffle,
+ }
+
+ if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") {
+ d.initRegex()
+ }
+
+ return &d
+}
+
+func (d *Discovery) initRegex() {
+ var runes []rune
+ last := len(d.server) - 1
+ for i, char := range d.server {
+ if i != 0 && i != last {
+ runes = append(runes, char)
+ }
+ }
+
+ regexStr := string(runes)
+ logger.Debug("Using filter regex", regexStr)
+
+ regex, err := regexp.Compile(regexStr)
+ if err != nil {
+ logger.FatalExit("Could not compile regex", regexStr, err)
+ }
+
+ d.regex = regex
+ d.server = ""
+}
+
+// ServerList to connect to via DTail client.
+func (d *Discovery) ServerList() []string {
+ servers := d.serverListFromModule()
+
+ if d.regex != nil {
+ servers = d.filterList(servers)
+ }
+
+ servers = d.dedupList(servers)
+
+ if d.shuffle {
+ servers = d.shuffleList(servers)
+ }
+
+ logger.Debug("Discovered servers", len(servers), servers)
+ return servers
+}
+
+func (d *Discovery) serverListFromModule() []string {
+ if d.module != "" {
+ return d.serverListFromReflectedModule()
+ }
+
+ if _, err := os.Stat(d.server); err == nil {
+ // Appears to be a file name, now try to read from that file.
+ return d.ServerListFromFILE()
+ }
+
+ // Appears to be a list of FQDNs (or a single FQDN)
+ return d.ServerListFromCOMMA()
+}
+
+// The aim of this is that everyone can plug in their own server discovery
+// method to DTail. Just add a method ServerListFrommMODULENAME to type
+// Discovery. Whereas MODULENAME must be a upeprcase string.
+func (d *Discovery) serverListFromReflectedModule() []string {
+ methodName := fmt.Sprintf("ServerListFrom%s", d.module)
+
+ rt := reflect.TypeOf(d)
+ reflectedMethod, ok := rt.MethodByName(methodName)
+ if !ok {
+ logger.FatalExit("No such server discovery module", d.module, methodName)
+ }
+
+ inputValues := make([]reflect.Value, 1)
+ // Thist input value is method receiver.
+ inputValues[0] = reflect.ValueOf(d)
+ returnValues := reflectedMethod.Func.Call(inputValues)
+
+ // First return value is server list.
+ return returnValues[0].Interface().([]string)
+}
+
+// Filter server list based on a regexp.
+func (d *Discovery) filterList(servers []string) (filtered []string) {
+ logger.Debug("Filtering server list")
+
+ for _, server := range servers {
+ if d.regex.MatchString(server) {
+ filtered = append(filtered, server)
+ }
+ }
+
+ return
+}
+
+// Deduplicate the server list.
+func (d *Discovery) dedupList(servers []string) (deduped []string) {
+ serverMap := make(map[string]struct{}, len(servers))
+
+ for _, server := range servers {
+ if _, ok := serverMap[server]; !ok {
+ serverMap[server] = struct{}{}
+ deduped = append(deduped, server)
+ }
+ }
+
+ logger.Info("Deduped server list", len(servers), len(deduped))
+ return
+}
+
+// Randomly shuffle the server list.
+func (d *Discovery) shuffleList(servers []string) []string {
+ logger.Debug("Shuffling server list")
+
+ r := rand.New(rand.NewSource(time.Now().Unix()))
+ shuffled := make([]string, len(servers))
+ n := len(servers)
+
+ for i := 0; i < n; i++ {
+ randIndex := r.Intn(len(servers))
+ shuffled[i] = servers[randIndex]
+ servers = append(servers[:randIndex], servers[randIndex+1:]...)
+ }
+
+ return shuffled
+}
diff --git a/internal/discovery/file.go b/internal/discovery/file.go
new file mode 100644
index 0000000..2edc867
--- /dev/null
+++ b/internal/discovery/file.go
@@ -0,0 +1,28 @@
+package discovery
+
+import (
+ "bufio"
+ "github.com/mimecast/dtail/internal/logger"
+ "os"
+)
+
+// ServerListFromFILE retrieves a list of servers from a file.
+func (d *Discovery) ServerListFromFILE() (servers []string) {
+ logger.Debug("Retrieving server list from file", d.server)
+
+ file, err := os.Open(d.server)
+ if err != nil {
+ logger.FatalExit(d.server, err)
+ }
+ defer file.Close()
+
+ scanner := bufio.NewScanner(file)
+ for scanner.Scan() {
+ servers = append(servers, scanner.Text())
+ }
+ if err := scanner.Err(); err != nil {
+ logger.FatalExit(d.server, err)
+ }
+
+ return
+}
diff --git a/internal/fs/catfile.go b/internal/fs/catfile.go
new file mode 100644
index 0000000..99f521f
--- /dev/null
+++ b/internal/fs/catfile.go
@@ -0,0 +1,27 @@
+package fs
+
+import "sync"
+
+// CatFile is for reading a whole file.
+type CatFile struct {
+ readFile
+}
+
+// NewCatFile returns a new file catter.
+func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile {
+ var mutex sync.Mutex
+
+ return CatFile{
+ readFile: readFile{
+ filePath: filePath,
+ stop: make(chan struct{}),
+ globID: globID,
+ serverMessages: serverMessages,
+ retry: false,
+ canSkipLines: false,
+ seekEOF: false,
+ limiter: limiter,
+ mutex: &mutex,
+ },
+ }
+}
diff --git a/internal/fs/filereader.go b/internal/fs/filereader.go
new file mode 100644
index 0000000..5a08e27
--- /dev/null
+++ b/internal/fs/filereader.go
@@ -0,0 +1,9 @@
+package fs
+
+// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file.
+type FileReader interface {
+ Start(lines chan<- LineRead, regex string) error
+ FilePath() string
+ Retry() bool
+ Stop()
+}
diff --git a/internal/fs/lineread.go b/internal/fs/lineread.go
new file mode 100644
index 0000000..7ee558e
--- /dev/null
+++ b/internal/fs/lineread.go
@@ -0,0 +1,28 @@
+package fs
+
+import (
+ "fmt"
+)
+
+// LineRead represents a read log line.
+type LineRead struct {
+ // The content of the log line.
+ Content []byte
+ // Until now, how many log lines were processed?
+ Count uint64
+ // Sometimes we produce too many log lines so that the client
+ // is too slow to process all of them. The server will drop log
+ // lines if that happens but it will signal to the client how
+ // many log lines in % could be transmitted to the client.
+ TransmittedPerc int
+ GlobID *string
+}
+
+// Return a human readable representation of the followed line.
+func (l LineRead) String() string {
+ return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)",
+ string(l.Content),
+ l.TransmittedPerc,
+ l.Count,
+ *l.GlobID)
+}
diff --git a/internal/fs/permissions/permission.go b/internal/fs/permissions/permission.go
new file mode 100644
index 0000000..6e83309
--- /dev/null
+++ b/internal/fs/permissions/permission.go
@@ -0,0 +1,14 @@
+// +build !linux
+
+package permissions
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+)
+
+// ToRead is to check whether user has read permissions to a given file.
+func ToRead(user, filePath string) (bool, error) {
+ // Only implemented for Linux, always expect true
+ logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform")
+ return true, nil
+}
diff --git a/internal/fs/permissions/permission_linux.c b/internal/fs/permissions/permission_linux.c
new file mode 100644
index 0000000..cd10525
--- /dev/null
+++ b/internal/fs/permissions/permission_linux.c
@@ -0,0 +1,395 @@
+#include "permission_linux.h"
+
+#ifdef DEBUG
+void debug_print_checker(struct permission_checker *pc) {
+ fprintf(stderr, "DEBUG: user_name:%s (%d)\n",
+ pc->user_name, pc->uid);
+
+ fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids);
+ int j;
+ for (j = 0; j < pc->ngids; j++) {
+ fprintf(stderr, "DEBUG: %d", pc->gids[j]);
+ struct group *gr = getgrgid(pc->gids[j]);
+ if (gr != NULL)
+ fprintf(stderr, " (%s)", gr->gr_name);
+ fprintf(stderr, "\n");
+ }
+
+ fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n",
+ pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid);
+}
+#endif // DEBUG
+
+int stat_file(struct permission_checker *pc) {
+ if (stat(pc->file_path, &pc->file_stat) != 0)
+ return -1;
+
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n",
+ pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid);
+#endif
+
+ return 0;
+}
+
+int get_user_uid(struct permission_checker *pc) {
+ struct passwd *result = NULL;
+
+ size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX);
+ if (bufsize == -1)
+ bufsize = 16384;
+
+ char *buf = malloc(bufsize);
+ if (buf == NULL) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name);
+#endif
+ return -1;
+ }
+
+ int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result);
+
+ if (result == NULL) {
+#ifdef DEBUG
+ if (rc == 0) {
+ fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name);
+ } else {
+ fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name);
+ }
+#endif
+
+ free(buf);
+ return -1;
+ }
+
+ pc->uid = pc->pw.pw_uid;
+
+ free(buf);
+ return 0;
+}
+
+int get_user_groups(struct permission_checker *pc) {
+ // First assume we are in 10 groups max
+ pc->ngids = 10;
+ pc->gids = malloc(pc->ngids * sizeof(gid_t));
+
+ if (pc->gids == NULL) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Unable to allocate space for gids.");
+#endif
+ return -1;
+ }
+
+ // Try so many times to load group list until it fits into group array.
+ while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) {
+ // Too many groups, enlarge group array and try again
+ int newngids = pc->ngids + 100;
+ size_t newsize = newngids * sizeof(gid_t);
+
+ if (SIZE_MAX / newngids < sizeof(gid_t)) {
+ // Overflow
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Overflow detected.");
+#endif
+ return -1;
+ }
+
+ gid_t *newgids = realloc(pc->gids, newsize);
+ if (newgids == NULL) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Unable to allocate space for gids.");
+#endif
+ free(pc->gids);
+ return -1;
+ }
+
+ pc->gids = newgids;
+ pc->ngids = newngids;
+ }
+
+ return 0;
+}
+
+int is_member_of_group(struct permission_checker *pc, gid_t gid) {
+ int j;
+ for (j = 0; j < pc->ngids; j++)
+ if (pc->gids[j] == gid)
+ return 1;
+ return 0;
+}
+
+int check_acl_uid_matches(uid_t uid, acl_entry_t entry) {
+ int ret = -1;
+ uid_t *acl_uid = acl_get_qualifier(entry);
+ if (acl_uid == NULL) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry");
+#endif
+ return -1;
+ }
+
+ ret = *acl_uid == uid ? 0 : -1;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret);
+#endif
+ acl_free(acl_uid);
+ return ret;
+}
+
+int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) {
+ int ret = -1;
+ gid_t *acl_gid = acl_get_qualifier(entry);
+ if (acl_gid == NULL) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry");
+#endif
+ return -1;
+ }
+
+ int j;
+ for (j = 0; j < ngids; j++) {
+ if (*acl_gid == gids[j]) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: User is in group %d", *acl_gid);
+#endif
+ ret = 0;
+ break;
+ }
+ }
+
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret);
+#endif
+ acl_free(acl_gid);
+ return ret;
+}
+
+int check_acl(struct permission_checker *pc, const int flag) {
+ // By default user has no read perm.
+ int has_read_perm = 0;
+
+ // By default mask tells that there are read perm. However in order to have
+ // read permissions both, has_read_perm and mask_allows_read_access must be 1!
+ int mask_allows_read_access = 1;
+
+ acl_type_t type = ACL_TYPE_ACCESS;
+ acl_t acl = acl_get_file(pc->file_path, type);
+
+ if (acl == NULL)
+ // Unable to retrieve ACL.
+ return -1;
+
+ // Walk through each entry of this ACL.
+ int id;
+ for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) {
+ acl_entry_t entry;
+ if (acl_get_entry(acl, id, &entry) != 1)
+ // No more ACL entries.
+ break;
+
+ acl_tag_t tag;
+ if (acl_get_tag_type(entry, &tag) == -1)
+ // Unable to retrieve ACL tag.
+ return -1;
+
+ switch (tag) {
+ case ACL_USER_OBJ:
+ if (flag == GROUP_CHECK)
+ continue;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL_USER_OBJ\n");
+#endif
+ // Ignore this ACL entry if user is not owner of file.
+ if (pc->uid != pc->file_stat.st_uid)
+ continue;
+ break;
+ case ACL_USER:
+ if (flag == GROUP_CHECK)
+ continue;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL_USER\n");
+#endif
+ // Ignore this ACL entry if uid does not match.
+ if (check_acl_uid_matches(pc->uid, entry) != 0)
+ continue;
+ break;
+ case ACL_GROUP_OBJ:
+ if (flag == USER_CHECK)
+ continue;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n");
+#endif
+ // Ignore ACL entry if user is not in group of file.
+ if (!is_member_of_group(pc, pc->file_stat.st_gid))
+ continue;
+ break;
+ case ACL_GROUP:
+ if (flag == USER_CHECK)
+ continue;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL_GROUP\n");
+#endif
+ // Ignore ACL entry if user is not in group of entry.
+ if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0)
+ continue;
+ break;
+ case ACL_OTHER:
+ if (flag == GROUP_CHECK)
+ continue;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL_OTHER\n");
+#endif
+ break;
+ case ACL_MASK:
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL_MASK\n");
+#endif
+ break;
+ default:
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Unknown ACL tag\n");
+#endif
+ return -1;
+ }
+
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Retrieving permset\n");
+#endif
+ acl_permset_t permset;
+ int permission;
+ if (acl_get_permset(entry, &permset) == -1)
+ // Unable to retrieve permset.
+ return -1;
+
+ if ((permission = acl_get_perm(permset, ACL_READ)) == -1)
+ // Unable to retrieve permset value.
+ return -1;
+
+ if (permission == 1 && tag != ACL_MASK) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n");
+#endif
+ has_read_perm = 1;
+ } else if (permission == 0 && tag == ACL_MASK) {
+ // Mask says that there are no permissions to read.
+ mask_allows_read_access = 0;
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n");
+#endif
+ }
+ }
+
+ if (has_read_perm && mask_allows_read_access) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n");
+#endif
+ return 1;
+ }
+
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n");
+#endif
+ return 0;
+}
+
+int check_traditional(struct permission_checker *pc, const int flag) {
+ mode_t mode = pc->file_stat.st_mode;
+ uid_t uid = pc->file_stat.st_uid;
+ gid_t gid = pc->file_stat.st_gid;
+
+ if (flag == USER_CHECK && (mode & S_IROTH)) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: Others can read file '%s'\n",
+ pc->file_path);
+#endif
+ return 1;
+
+ } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n",
+ pc->user_name, pc->file_path);
+#endif
+ return 1;
+
+ } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) {
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n",
+ pc->user_name, pc->file_path);
+#endif
+ return 1;
+ }
+
+ return 0;
+}
+
+int permission_to_read(char* user_name, char *file_path) {
+ int rc = -1;
+
+#ifdef DEBUG
+ fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path);
+#endif
+ struct permission_checker pc = {
+ .user_name = user_name,
+ .gids = NULL,
+ .ngids = 0,
+ .file_path = file_path,
+ };
+
+ // Gather user's UID.
+ if ((rc = get_user_uid(&pc)) == -1)
+ // Could not retrieve UID.
+ goto cleanup;
+
+ // Gather file owner (user and group).
+ if ((rc = stat_file(&pc)) == -1)
+ // Could not stat file.
+ goto cleanup;
+
+ // Check whether there is an ACL entry which would allow the user
+ // to read the file. Don't check for any groups yet. The issue with
+ // groups is that it can be very slow to retrieve the list of groups
+ // of a specific user when done via a remote LDAP server!
+ if ((rc = check_acl(&pc, USER_CHECK)) == 1)
+ // Yes, has permissions.
+ goto cleanup;
+
+ // Check whether ACLs of file could be retrieved.
+ if (rc == -1) {
+ if (errno != ENOTSUP)
+ // Unknown error.
+ goto cleanup;
+
+ // File system does not support ACLs.
+ // Fallback to traditional permissions.
+ if ((rc = check_traditional(&pc, USER_CHECK)) == 1)
+ // Yes, has traditional permissions.
+ goto cleanup;
+
+ if ((rc = get_user_groups(&pc)) == -1)
+ // Can not retrieve user's groups.
+ goto cleanup;
+
+ rc = check_traditional(&pc, GROUP_CHECK);
+ goto cleanup;
+ }
+
+ if ((rc = get_user_groups(&pc)) == -1)
+ // Can not retrieve use'r groups.
+ goto cleanup;
+
+ // Check whether there is an ACL entry which would allow any of the
+ // user's groups to read the file.
+ rc = check_acl(&pc, GROUP_CHECK);
+
+cleanup:
+#ifdef DEBUG
+ debug_print_checker(&pc);
+#endif
+
+ if (pc.ngids)
+ free(pc.gids);
+
+ return rc;
+}
+
+// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab
diff --git a/internal/fs/permissions/permission_linux.go b/internal/fs/permissions/permission_linux.go
new file mode 100644
index 0000000..feae729
--- /dev/null
+++ b/internal/fs/permissions/permission_linux.go
@@ -0,0 +1,33 @@
+package permissions
+
+/*
+#include "permission_linux.h"
+#cgo LDFLAGS: -L. -lacl
+*/
+import "C"
+
+import (
+ "errors"
+ "unsafe"
+)
+
+// To check whether user has Linux file system permissions to read a given file.
+func ToRead(user, filePath string) (bool, error) {
+ cUser := C.CString(user)
+ cFilePath := C.CString(filePath)
+
+ defer C.free(unsafe.Pointer(cUser))
+ defer C.free(unsafe.Pointer(cFilePath))
+
+ cOk, err := C.permission_to_read(cUser, cFilePath)
+ if cOk == 1 {
+ return true, nil
+ }
+
+ if err != nil {
+ // err contains errno message
+ return false, err
+ }
+
+ return false, errors.New("User without permission to read file")
+}
diff --git a/internal/fs/permissions/permission_linux.h b/internal/fs/permissions/permission_linux.h
new file mode 100644
index 0000000..a2c266e
--- /dev/null
+++ b/internal/fs/permissions/permission_linux.h
@@ -0,0 +1,60 @@
+#ifndef PERMISSION_LINUX_H
+#define PERMISSION_LINUX_H
+
+#include <acl/libacl.h>
+#include <errno.h>
+#include <grp.h>
+#include <pwd.h>
+#include <stdio.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <sys/acl.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+//#define DEBUG
+#define USER_CHECK 0
+#define GROUP_CHECK 1
+
+struct permission_checker {
+ char *user_name;
+ uid_t uid;
+ gid_t *gids;
+ int ngids;
+ char *file_path;
+ struct stat file_stat;
+ struct passwd pw;
+};
+
+
+#ifdef DEBUG
+// Print out permission_checker struct.
+void debug_print_checker(struct permission_checker *pc);
+#endif
+
+// Stat a given file to retrieve traditional UNIX permissions.
+int stat_file(struct permission_checker *pc);
+
+// Retrieve UID of user.
+int get_user_uid(struct permission_checker *pc);
+
+// Retrieve all groups of the user.
+int get_user_groups(struct permission_checker *pc);
+
+// Check whether user is member of a group or not.
+int is_member_of_group(struct permission_checker *pc, gid_t gid);
+
+// Check whether user can read file according Linux ACLs.
+// As flag use either USER_CHECK or GROUP_CHECK.
+int check_acl(struct permission_checker *pc, const int flag);
+
+// Check whether user has permissions to read file according traditional
+// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK.
+int check_traditional(struct permission_checker *pc, const int flag);
+
+// Returns 1 if user has permission to read file.
+// Returns <0 on error and returns 0 if no permissions.
+int permission_to_read(char* user, char *file_path);
+
+#endif // PERMISSION_LINUX_H
diff --git a/internal/fs/permissions/permission_test.go b/internal/fs/permissions/permission_test.go
new file mode 100644
index 0000000..d415ac2
--- /dev/null
+++ b/internal/fs/permissions/permission_test.go
@@ -0,0 +1,112 @@
+// +build linux
+
+package permissions
+
+import (
+ "os"
+ "os/exec"
+ "os/user"
+ "strings"
+ "testing"
+)
+
+const (
+ setfacl string = "/usr/bin/setfacl"
+ file string = "/tmp/acltest"
+)
+
+func TestLinuxACL(t *testing.T) {
+ setfacl := "/usr/bin/setfacl"
+ file := "/tmp/acltest"
+
+ // Delete file if it exists.
+ if _, err := os.Stat(file); err == nil {
+ os.Remove(file)
+ }
+
+ f, err := os.Create(file)
+ if err != nil {
+ t.Errorf("%v", err)
+ }
+ defer func() {
+ f.Close()
+ //os.Remove(file)
+ }()
+
+ user, err := user.Current()
+ if err != nil {
+ t.Errorf("Unable to retrieve current user: %v", err)
+ }
+
+ // Test 1: Remove all permissions and perform a permission check
+ cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ if ok, _ := ToRead(user.Username, file); ok {
+ t.Errorf("Didn't expect permissions to read file!")
+ }
+
+ // Test 2: Add read permission to file owner
+ cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ if ok, err := ToRead(user.Username, file); !ok {
+ t.Errorf("Expected permissions to read file: %v", err)
+ }
+
+ // Test 3: Add read permission to file group
+ cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ if ok, err := ToRead(user.Username, file); !ok {
+ t.Errorf("Expected permissions to read file: %v", err)
+ }
+
+ // Test 4: Add read permission to others
+ cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+
+ if ok, err := ToRead(user.Username, file); !ok {
+ t.Errorf("Expected permissions to read file: %v", err)
+ }
+
+ // Test 5: Remove read permission from mask
+ cmd = exec.Command(setfacl, "-m", "m::---", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ if ok, _ := ToRead(user.Username, file); ok {
+ t.Errorf("Didn't expect permissions to read file!")
+ }
+ cmd = exec.Command(setfacl, "-m", "m::r--", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+
+ // Test 6: Add read permission to specific group
+ cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ if ok, err := ToRead(user.Username, file); !ok {
+ t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err)
+ }
+
+ // Test 7: Remove all permissions but mask
+ cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ cmd = exec.Command(setfacl, "-m", "m::r--", file)
+ if err := cmd.Run(); err != nil {
+ t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err)
+ }
+ if ok, _ := ToRead(user.Username, file); ok {
+ t.Errorf("Didn't expect permissions to read file!")
+ }
+}
diff --git a/internal/fs/readfile.go b/internal/fs/readfile.go
new file mode 100644
index 0000000..312447a
--- /dev/null
+++ b/internal/fs/readfile.go
@@ -0,0 +1,318 @@
+package fs
+
+import (
+ "bufio"
+ "compress/gzip"
+ "github.com/mimecast/dtail/internal/logger"
+ "errors"
+ "io"
+ "os"
+ "regexp"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/DataDog/zstd"
+)
+
+// Used to tail and filter a local log file.
+type readFile struct {
+ // Various statistics (e.g. regex hit percentage, transfer percentage).
+ stats
+ // Path of log file to tail.
+ filePath string
+ // Only consider all log lines matching this regular expression.
+ re *regexp.Regexp
+ // The glob identifier of the file.
+ globID string
+ // Channel to send a server message to the dtail client
+ serverMessages chan<- string
+ // Signals to stop tailing the log file.
+ stop chan struct{}
+ // Periodically retry reading file.
+ retry bool
+ // Can I skip messages when there are too many?
+ canSkipLines bool
+ // Seek to the EOF before processing file?
+ seekEOF bool
+ // Mutex to control the stopping of the file
+ mutex *sync.Mutex
+ limiter chan struct{}
+}
+
+// FilePath returns the full file path.
+func (f readFile) FilePath() string {
+ return f.filePath
+}
+
+// Retry reading the file on error?
+func (f readFile) Retry() bool {
+ return f.retry
+}
+
+// Start tailing a log file.
+func (f readFile) Start(lines chan<- LineRead, regex string) error {
+ defer func() {
+ select {
+ case <-f.limiter:
+ default:
+ }
+ }()
+
+ select {
+ case f.limiter <- struct{}{}:
+ default:
+ select {
+ case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."):
+ case <-f.stop:
+ return nil
+ }
+ f.limiter <- struct{}{}
+ }
+
+ fd, err := os.Open(f.filePath)
+ if err != nil {
+ return err
+ }
+ defer fd.Close()
+
+ if f.seekEOF {
+ fd.Seek(0, io.SeekEnd)
+ }
+
+ rawLines := make(chan []byte, 100)
+ truncate := make(chan struct{})
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+
+ go f.periodicTruncateCheck(truncate)
+ go f.filter(&wg, rawLines, lines, regex)
+
+ err = f.read(fd, rawLines, truncate)
+ close(rawLines)
+ wg.Wait()
+
+ return err
+}
+
+func (f readFile) periodicTruncateCheck(truncate chan struct{}) {
+ for {
+ select {
+ case <-time.After(time.Second * 3):
+ select {
+ case truncate <- struct{}{}:
+ case <-f.stop:
+ }
+ case <-f.stop:
+ return
+ }
+ }
+}
+
+// Stop reading file.
+func (f readFile) Stop() {
+ f.mutex.Lock()
+ defer f.mutex.Unlock()
+
+ select {
+ case <-f.stop:
+ return
+ default:
+ }
+
+ close(f.stop)
+}
+
+func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) {
+ switch {
+ case strings.HasSuffix(f.FilePath(), ".gz"):
+ fallthrough
+ case strings.HasSuffix(f.FilePath(), ".gzip"):
+ logger.Info(f.FilePath(), "Detected gzip compression format")
+ var gzipReader *gzip.Reader
+ gzipReader, err = gzip.NewReader(fd)
+ if err != nil {
+ return
+ }
+ reader = bufio.NewReader(gzipReader)
+ case strings.HasSuffix(f.FilePath(), ".zst"):
+ logger.Info(f.FilePath(), "Detected zstd compression format")
+ reader = bufio.NewReader(zstd.NewReader(fd))
+ default:
+ reader = bufio.NewReader(fd)
+ }
+
+ return
+}
+
+func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error {
+ reader, err := f.makeReader(fd)
+ if err != nil {
+ return err
+ }
+ rawLine := make([]byte, 0, 512)
+ var offset uint64
+
+ lineLengthThreshold := 1024 * 1024 // 1mb
+ longLineWarning := false
+
+ for {
+ select {
+ case <-truncate:
+ if isTruncated, err := f.truncated(fd); isTruncated {
+ return err
+ }
+ logger.Info(f.filePath, "Current offset", offset)
+
+ case <-f.stop:
+ return nil
+ default:
+ }
+
+ // Read some bytes (max 4k at once as of go 1.12). isPrefix will
+ // be set if line does not fit into 4k buffer.
+ bytes, isPrefix, err := reader.ReadLine()
+
+ if err != nil {
+ // If EOF, sleep a couple of ms and return with nil error.
+ // If other error, return with non-nil error.
+ if err != io.EOF {
+ return err
+ }
+ if !f.seekEOF {
+ logger.Debug(f.FilePath(), "End of file reached")
+ return nil
+ }
+ time.Sleep(time.Millisecond * 100)
+ continue
+ }
+
+ rawLine = append(rawLine, bytes...)
+ offset += uint64(len(bytes))
+
+ if !isPrefix {
+ // last LineRead call returned contend until end of line.
+ rawLine = append(rawLine, '\n')
+ select {
+ case rawLines <- rawLine:
+ case <-f.stop:
+ return nil
+ }
+ rawLine = make([]byte, 0, 512)
+ if longLineWarning {
+ longLineWarning = false
+ }
+ continue
+ }
+
+ // Last LineRead call could not read content until end of line, buffer
+ // was too small. Determine whether we exceed the max line length we
+ // want dtail to send to the client at once. Possibly split up log line
+ // into multiple log lines.
+ if len(rawLine) >= lineLengthThreshold {
+ if !longLineWarning {
+ f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines")
+ // Only print out one warning per long log line.
+ longLineWarning = true
+ }
+ rawLine = append(rawLine, '\n')
+ select {
+ case rawLines <- rawLine:
+ case <-f.stop:
+ return nil
+ }
+ rawLine = make([]byte, 0, 512)
+ }
+ }
+}
+
+// Filter log lines matching a given regular expression.
+func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) {
+ defer wg.Done()
+
+ if regex == "" {
+ regex = "."
+ }
+
+ re, err := regexp.Compile(regex)
+ if err != nil {
+ logger.Error(regex, "Can't compile regex, using '.' instead", err)
+ re = regexp.MustCompile(".")
+ }
+ f.re = re
+
+ for {
+ select {
+ case line, ok := <-rawLines:
+ f.updatePosition()
+ if !ok {
+ return
+ }
+ if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok {
+ select {
+ case lines <- filteredLine:
+ case <-f.stop:
+ return
+ }
+ }
+ }
+ }
+}
+
+func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) {
+ var read LineRead
+
+ if !f.re.Match(line) {
+ f.updateLineNotMatched()
+ f.updateLineNotTransmitted()
+ return read, false
+ }
+ f.updateLineMatched()
+
+ // Can we actually send more messages, channel capacity reached?
+ if f.canSkipLines && length >= capacity {
+ f.updateLineNotTransmitted()
+ return read, false
+ }
+ f.updateLineTransmitted()
+
+ read = LineRead{
+ Content: line,
+ GlobID: &f.globID,
+ Count: f.totalLineCount(),
+ TransmittedPerc: f.transmittedPerc(),
+ }
+
+ return read, true
+}
+
+// Check wether log file is truncated. Returns nil if not.
+func (f readFile) truncated(fd *os.File) (bool, error) {
+ logger.Debug(f.filePath, "File truncation check")
+
+ // Can not seek currently open FD.
+ curPos, err := fd.Seek(0, os.SEEK_CUR)
+ if err != nil {
+ return true, err
+ }
+
+ // Can not open file at original path.
+ pathFd, err := os.Open(f.filePath)
+ if err != nil {
+ return true, err
+ }
+ defer pathFd.Close()
+
+ // Can not seek file at original path.
+ pathPos, err := pathFd.Seek(0, io.SeekEnd)
+ if err != nil {
+ return true, err
+ }
+
+ if curPos > pathPos {
+ return true, errors.New("File got truncated")
+ }
+
+ return false, nil
+}
diff --git a/internal/fs/stats.go b/internal/fs/stats.go
new file mode 100644
index 0000000..4121ff7
--- /dev/null
+++ b/internal/fs/stats.go
@@ -0,0 +1,69 @@
+package fs
+
+// Used to calculate how many log lines matched the regular expression
+// and how many log files could be transmitted from the server to the client.
+// Hit and transmit percentage takes only the last 100 log lines into calculation.
+type stats struct {
+ pos int
+ lineCount uint64
+ matched [100]bool
+ matchCount uint64
+ transmitted [100]bool
+ transmitCount int
+}
+
+// Return the total line count.
+func (f *stats) totalLineCount() uint64 {
+ return f.lineCount
+}
+
+// Calculate the percentage of log lines transmitted to the client.
+func (f *stats) transmittedPerc() int {
+ return int(percentOf(float64(f.matchCount), float64(f.transmitCount)))
+}
+
+// Update bucket position. We only take into consideration the last 100
+// lines for stats.
+func (f *stats) updatePosition() {
+ f.pos = (f.pos + 1) % 100
+ f.lineCount++
+}
+
+// Increment match counter.
+func (f *stats) updateLineMatched() {
+ if !f.matched[f.pos] {
+ f.matchCount++
+ f.matched[f.pos] = true
+ }
+}
+
+// Increment transmitted counter.
+func (f *stats) updateLineTransmitted() {
+ if !f.transmitted[f.pos] {
+ f.transmitCount++
+ f.transmitted[f.pos] = true
+ }
+}
+
+// Decrement match counter.
+func (f *stats) updateLineNotMatched() {
+ if f.matched[f.pos] {
+ f.matchCount--
+ f.matched[f.pos] = false
+ }
+}
+
+// Decrement transmitted counter.
+func (f *stats) updateLineNotTransmitted() {
+ if f.transmitted[f.pos] {
+ f.transmitCount--
+ f.transmitted[f.pos] = false
+ }
+}
+
+func percentOf(total float64, value float64) float64 {
+ if total == 0 || total == value {
+ return 100
+ }
+ return value / (total / 100.0)
+}
diff --git a/internal/fs/tailfile.go b/internal/fs/tailfile.go
new file mode 100644
index 0000000..a19d4e6
--- /dev/null
+++ b/internal/fs/tailfile.go
@@ -0,0 +1,27 @@
+package fs
+
+import "sync"
+
+// TailFile is to tail and filter a log file.
+type TailFile struct {
+ readFile
+}
+
+// NewTailFile returns a new file tailer.
+func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile {
+ var mutex sync.Mutex
+
+ return TailFile{
+ readFile: readFile{
+ filePath: filePath,
+ stop: make(chan struct{}),
+ globID: globID,
+ serverMessages: serverMessages,
+ retry: true,
+ canSkipLines: true,
+ seekEOF: true,
+ limiter: limiter,
+ mutex: &mutex,
+ },
+ }
+}
diff --git a/internal/logger/logger.go b/internal/logger/logger.go
new file mode 100644
index 0000000..ca85e32
--- /dev/null
+++ b/internal/logger/logger.go
@@ -0,0 +1,457 @@
+package logger
+
+import (
+ "bufio"
+ "fmt"
+ "os"
+ "os/signal"
+ "runtime"
+ "strings"
+ "sync"
+ "syscall"
+ "time"
+
+ "github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+)
+
+const (
+ clientStr string = "CLIENT"
+ serverStr string = "SERVER"
+ infoStr string = "INFO"
+ warnStr string = "WARN"
+ errorStr string = "ERROR"
+ fatalStr string = "FATAL"
+ debugStr string = "DEBUG"
+ traceStr string = "TRACE"
+)
+
+// Synchronise access to logging.
+var mutex sync.Mutex
+
+// File descriptor of log file when logToFile enabled.
+var fd *os.File
+
+// File write buffer of log file when logToFile enabled.
+var writer *bufio.Writer
+
+// File write buffer of stdout when logToStdout enabled.
+var stdoutWriter *bufio.Writer
+
+// Current hostname.
+var hostname string
+
+// Used to detect change of day (create one log file per day0
+var lastDateStr string
+
+// True if log in server mode, false if log in client mode.
+var serverEnable bool
+
+// Used to make logging non-blocking.
+var logBufCh chan buf
+var stdoutBufCh chan string
+
+// Stdout channel, required to pause output
+var pauseCh chan struct{}
+var resumeCh chan struct{}
+
+// Tell the logger that we are done, program shuts down
+var stop chan struct{}
+var stdoutFlushed chan struct{}
+
+// Tell the logger about logrotation
+var rotateCh chan os.Signal
+
+// LogMode allows to specify the verbosity of logging.
+type LogMode int
+
+// Possible log modes.
+const (
+ NormalMode LogMode = iota
+ DebugMode LogMode = iota
+ SilentMode LogMode = iota
+ TraceMode LogMode = iota
+ NothingMode LogMode = iota
+)
+
+// Mode is the current log mode in use.
+var Mode LogMode
+
+// LogStrategy allows to specify a log rotation strategy.
+type LogStrategy int
+
+// Possible log strategies.
+const (
+ NormalStrategy LogStrategy = iota
+ DailyStrategy LogStrategy = iota
+ StdoutStrategy LogStrategy = iota
+)
+
+// Strategy is the current log strattegy used.
+var Strategy LogStrategy
+
+// Enables logging to stdout.
+var logToStdout bool
+
+// Enables logging to file.
+var logToFile bool
+
+// Helper type to make logging non-blocking.
+type buf struct {
+ time time.Time
+ message string
+}
+
+// Start logging.
+func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) {
+ serverEnable = myServerEnable
+
+ mode := logMode(debugEnable, silentEnable, nothingEnable)
+ strategy := logStrategy()
+
+ stdoutWriter = bufio.NewWriter(os.Stdout)
+ Mode = mode
+ Strategy = strategy
+
+ if Mode == NothingMode {
+ return
+ }
+
+ switch Strategy {
+ case DailyStrategy:
+ _, err := os.Stat(config.Common.LogDir)
+ logToFile = !os.IsNotExist(err)
+ logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode
+ case StdoutStrategy:
+ fallthrough
+ default:
+ logToFile = false
+ logToStdout = true
+ }
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ panic(err)
+ }
+ s := strings.Split(fqdn, ".")
+ hostname = s[0]
+
+ pauseCh = make(chan struct{})
+ resumeCh = make(chan struct{})
+ stop = make(chan struct{})
+ stdoutFlushed = make(chan struct{})
+
+ // Setup logrotation
+ rotateCh = make(chan os.Signal, 1)
+ signal.Notify(rotateCh, syscall.SIGHUP)
+
+ if logToStdout {
+ stdoutBufCh = make(chan string, runtime.NumCPU()*100)
+ go writeToStdout()
+ }
+
+ if logToFile {
+ logBufCh = make(chan buf, runtime.NumCPU()*100)
+ go writeToFile()
+ }
+}
+
+func logMode(debugEnable, silentEnable, nothingEnable bool) LogMode {
+ switch {
+ case debugEnable:
+ return DebugMode
+ case nothingEnable:
+ return NothingMode
+ case config.Common.TraceEnable:
+ return TraceMode
+ case config.Common.DebugEnable:
+ return DebugMode
+ case silentEnable:
+ return SilentMode
+ default:
+ }
+ return NormalMode
+}
+
+func logStrategy() LogStrategy {
+ switch config.Common.LogStrategy {
+ case "daily":
+ return DailyStrategy
+ default:
+ }
+ return StdoutStrategy
+}
+
+// Info message logging.
+func Info(args ...interface{}) string {
+ if serverEnable {
+ return log(serverStr, infoStr, args)
+ }
+
+ return log(clientStr, infoStr, args)
+}
+
+// Warn message logging.
+func Warn(args ...interface{}) string {
+ if serverEnable {
+ return log(serverStr, warnStr, args)
+ }
+
+ return log(clientStr, warnStr, args)
+}
+
+// Error message logging.
+func Error(args ...interface{}) string {
+ if serverEnable {
+ return log(serverStr, errorStr, args)
+ }
+
+ return log(clientStr, errorStr, args)
+}
+
+// FatalExit logs an error and exists the process.
+func FatalExit(args ...interface{}) {
+ what := clientStr
+ if serverEnable {
+ what = serverStr
+ }
+ log(what, fatalStr, args)
+
+ time.Sleep(time.Second)
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ closeWriter()
+ os.Exit(3)
+}
+
+// Debug message logging.
+func Debug(args ...interface{}) string {
+ if Mode == DebugMode || Mode == TraceMode {
+ if serverEnable {
+ return log(serverStr, debugStr, args)
+ }
+ return log(clientStr, debugStr, args)
+ }
+
+ return ""
+}
+
+// Trace message logging.
+func Trace(args ...interface{}) string {
+ if Mode == TraceMode {
+ if serverEnable {
+ return log(serverStr, traceStr, args)
+ }
+ return log(clientStr, traceStr, args)
+ }
+
+ return ""
+}
+
+// Write log line to buffer and/or log file.
+func write(what, severity, message string) {
+ if logToStdout && (Mode != SilentMode || severity != warnStr) {
+ line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message)
+
+ if color.Colored {
+ line = color.Colorfy(line)
+ }
+
+ stdoutBufCh <- line
+ }
+
+ if logToFile {
+ t := time.Now()
+ timeStr := t.Format("20060102-150405")
+ logBufCh <- buf{
+ time: t,
+ message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message),
+ }
+ }
+}
+
+// Generig log message.
+func log(what string, severity string, args []interface{}) string {
+ if Mode == NothingMode {
+ return ""
+ }
+
+ var messages []string
+
+ for _, arg := range args {
+ switch v := arg.(type) {
+ case string:
+ messages = append(messages, v)
+ case int:
+ messages = append(messages, fmt.Sprintf("%d", v))
+ case error:
+ messages = append(messages, v.Error())
+ default:
+ messages = append(messages, fmt.Sprintf("%v", v))
+ }
+ }
+
+ message := strings.Join(messages, "|")
+ write(what, severity, message)
+
+ return fmt.Sprintf("%s|%s", severity, message)
+}
+
+// Raw message logging.
+func Raw(message string) {
+ if Mode == NothingMode {
+ return
+ }
+
+ if logToStdout {
+ if color.Colored {
+ message = color.Colorfy(message)
+ }
+ stdoutBufCh <- message
+ }
+
+ if logToFile {
+ logBufCh <- buf{time.Now(), message}
+ }
+}
+
+// Close log writer (e.g. on change of day).
+func closeWriter() {
+ if writer != nil {
+ writer.Flush()
+ fd.Close()
+ }
+}
+
+// Return the correct log file writer
+func fileWriter(dateStr string) *bufio.Writer {
+ if dateStr != lastDateStr {
+ return updateFileWriter(dateStr)
+ }
+
+ // Check for log rotation signal
+ select {
+ case <-rotateCh:
+ stdoutWriter.WriteString("Received signal for logrotation\n")
+ return updateFileWriter(dateStr)
+ default:
+ }
+
+ return writer
+}
+
+// Update log file writer
+func updateFileWriter(dateStr string) *bufio.Writer {
+ // Detected change of day. Close current writer and create a new one.
+ mutex.Lock()
+ defer mutex.Unlock()
+ closeWriter()
+
+ if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) {
+ if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil {
+ panic(err)
+ }
+ }
+
+ logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr)
+ newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
+ if err != nil {
+ panic(err)
+ }
+
+ fd = newFd
+ writer = bufio.NewWriterSize(fd, 1)
+ lastDateStr = dateStr
+
+ return writer
+}
+
+func flushStdout() {
+ defer close(stdoutFlushed)
+
+ for {
+ select {
+ case message := <-stdoutBufCh:
+ stdoutWriter.WriteString(message)
+ default:
+ stdoutWriter.Flush()
+ return
+ }
+ }
+}
+
+func writeToStdout() {
+ for {
+ select {
+ case message := <-stdoutBufCh:
+ stdoutWriter.WriteString(message)
+ case <-time.After(time.Millisecond * 100):
+ stdoutWriter.Flush()
+ case <-pauseCh:
+ PAUSE:
+ for {
+ select {
+ case <-stdoutBufCh:
+ case <-resumeCh:
+ break PAUSE
+ case <-stop:
+ return
+ }
+ }
+ case <-stop:
+ flushStdout()
+ return
+ }
+ }
+}
+
+func writeToFile() {
+ for {
+ select {
+ case buf := <-logBufCh:
+ dateStr := buf.time.Format("20060102")
+ w := fileWriter(dateStr)
+ w.WriteString(buf.message)
+ case <-pauseCh:
+ PAUSE:
+ for {
+ select {
+ case <-stdoutBufCh:
+ case <-resumeCh:
+ break PAUSE
+ case <-stop:
+ return
+ }
+ }
+ case <-stop:
+ return
+ }
+ }
+}
+
+// Pause logging.
+func Pause() {
+ if logToStdout {
+ pauseCh <- struct{}{}
+ }
+ if logToFile {
+ pauseCh <- struct{}{}
+ }
+}
+
+// Resume logging (after pausing).
+func Resume() {
+ if logToStdout {
+ resumeCh <- struct{}{}
+ }
+ if logToFile {
+ resumeCh <- struct{}{}
+ }
+}
+
+// Stop logging.
+func Stop() {
+ close(stop)
+ <-stdoutFlushed
+}
diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go
new file mode 100644
index 0000000..2096c3c
--- /dev/null
+++ b/internal/mapr/aggregateset.go
@@ -0,0 +1,185 @@
+package mapr
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// AggregateSet represents aggregated key/value pairs from the
+// MAPREDUCE log lines. These could be either string values or float
+// values.
+type AggregateSet struct {
+ Samples int
+ FValues map[string]float64
+ SValues map[string]string
+}
+
+// NewAggregateSet creates a new empty aggregate set.
+func NewAggregateSet() *AggregateSet {
+ return &AggregateSet{
+ FValues: make(map[string]float64),
+ SValues: make(map[string]string),
+ }
+}
+
+// String representation of aggregate set.
+func (s *AggregateSet) String() string {
+ return fmt.Sprintf("AggregateSet(Samples:%d,FValues:%v,SValues:%v)",
+ s.Samples, s.FValues, s.SValues)
+}
+
+// Merge one aggregate set into this one.
+func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error {
+ s.Samples += set.Samples
+ //logger.Trace("Merge", set)
+
+ for _, sc := range query.Select {
+ storage := sc.FieldStorage
+ switch sc.Operation {
+ case Count:
+ fallthrough
+ case Sum:
+ fallthrough
+ case Avg:
+ value := set.FValues[storage]
+ s.addFloat(storage, value)
+ case Min:
+ value := set.FValues[storage]
+ s.addFloatMin(storage, value)
+ case Max:
+ value := set.FValues[storage]
+ s.addFloatMax(storage, value)
+ case Last:
+ value := set.SValues[storage]
+ s.setString(storage, value)
+ case Len:
+ s.setString(storage, set.SValues[storage])
+ s.setFloat(storage, set.FValues[storage])
+ default:
+ return fmt.Errorf("Unknown aggregation method '%v'", sc.Operation)
+ }
+ }
+ return nil
+}
+
+// Serialize the aggregate set so it can be sent over the wire.
+func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) {
+ //logger.Trace("Serialising mapr.AggregateSet", s)
+ var sb strings.Builder
+
+ sb.WriteString(groupKey)
+ sb.WriteString("|")
+ sb.WriteString(fmt.Sprintf("%d|", s.Samples))
+
+ for k, v := range s.FValues {
+ sb.WriteString(k)
+ sb.WriteString("=")
+ sb.WriteString(fmt.Sprintf("%v|", v))
+ }
+
+ for k, v := range s.SValues {
+ sb.WriteString(k)
+ sb.WriteString("=")
+ sb.WriteString(v)
+ sb.WriteString("|")
+ }
+
+ select {
+ case ch <- sb.String():
+ case <-stop:
+ }
+}
+
+// Add a float value.
+func (s *AggregateSet) addFloat(key string, value float64) {
+ if _, ok := s.FValues[key]; !ok {
+ s.FValues[key] = value
+ return
+ }
+ s.FValues[key] += value
+}
+
+// Add a float minimum value.
+func (s *AggregateSet) addFloatMin(key string, value float64) {
+ f, ok := s.FValues[key]
+ if !ok {
+ s.FValues[key] = value
+ return
+ }
+
+ if f > value {
+ s.FValues[key] = value
+ }
+}
+
+// Add a float maximum value.
+func (s *AggregateSet) addFloatMax(key string, value float64) {
+ f, ok := s.FValues[key]
+ if !ok {
+ s.FValues[key] = value
+ return
+ }
+
+ if f < value {
+ s.FValues[key] = value
+ }
+}
+
+// Set a string.
+func (s *AggregateSet) setString(key, value string) {
+ s.SValues[key] = value
+}
+
+// Set a float.
+func (s *AggregateSet) setFloat(key string, value float64) {
+ s.FValues[key] = value
+}
+
+// Aggregate data to the aggregate set.
+func (s *AggregateSet) Aggregate(key string, agg AggregateOperation, value string, clientAggregation bool) (err error) {
+ var f float64
+
+ // First check if we can aggregate anything without converting value to float.
+ switch agg {
+ case Count:
+ if clientAggregation {
+ f, err = strconv.ParseFloat(value, 64)
+ if err != nil {
+ return
+ }
+ s.addFloat(key, f)
+ return
+ }
+ s.addFloat(key, 1)
+ return
+ case Last:
+ s.setString(key, value)
+ return
+ case Len:
+ s.setString(key, value)
+ s.setFloat(key, float64(len(value)))
+ return
+ default:
+ }
+
+ // No, we have to convert to float.
+ f, err = strconv.ParseFloat(value, 64)
+ if err != nil {
+ return
+ }
+
+ switch agg {
+ case Sum:
+ fallthrough
+ case Avg:
+ s.addFloat(key, f)
+ case Min:
+ s.addFloatMin(key, f)
+ case Max:
+ s.addFloatMax(key, f)
+ default:
+ err = fmt.Errorf("Unknown aggregation method '%v'", agg)
+ }
+ return
+}
diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go
new file mode 100644
index 0000000..3f2b7a5
--- /dev/null
+++ b/internal/mapr/client/aggregate.go
@@ -0,0 +1,100 @@
+package client
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/mapr"
+ "strconv"
+ "strings"
+)
+
+// Aggregate mapreduce data on the DTail client side.
+type Aggregate struct {
+ // This is the mapr query specified on the command line.
+ query *mapr.Query
+ // This represents aggregated data of a single remote server.
+ group *mapr.GroupSet
+ // This represents the merged aggregated data of all servers.
+ globalGroup *mapr.GlobalGroupSet
+ stop chan struct{}
+ // The server we aggregate the data for (logging and debugging purposes only)
+ server string
+}
+
+// NewAggregate create new client aggregator.
+func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *Aggregate {
+ return &Aggregate{
+ query: query,
+ group: mapr.NewGroupSet(),
+ globalGroup: globalGroup,
+ stop: make(chan struct{}),
+ server: server,
+ }
+}
+
+// Aggregate data from mapr log line into local (and global) group sets.
+func (a *Aggregate) Aggregate(parts []string) {
+ select {
+ case <-a.stop:
+ logger.Error("Client aggregator stopped for server, not processing new data", a.server)
+ return
+ default:
+ }
+
+ groupKey := parts[0]
+ samples, err := strconv.Atoi(parts[1])
+ if err != nil {
+ logger.FatalExit(parts, err)
+ }
+ fields := a.makeFields(parts[2:])
+ set := a.group.GetSet(groupKey)
+
+ var addedSamples bool
+ for _, sc := range a.query.Select {
+ if val, ok := fields[sc.FieldStorage]; ok {
+ if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, true); err != nil {
+ logger.Error(err)
+ continue
+ }
+ addedSamples = true
+ }
+ }
+ if addedSamples {
+ set.Samples += samples
+ }
+
+ // Merge data from group into global group.
+ isMerged, err := a.globalGroup.MergeNoblock(a.query, a.group)
+ if err != nil {
+ panic(err)
+ }
+ if isMerged {
+ // Re-init local group (make it empty again).
+ a.group.InitSet()
+ }
+}
+
+// Create a map of key-value pairs from a part list such as ["foo=bar", "bar=baz"].
+func (a *Aggregate) makeFields(parts []string) map[string]string {
+ fields := make(map[string]string, len(parts))
+
+ for _, part := range parts {
+ kv := strings.Split(part, "=")
+ if len(kv) != 2 {
+ continue
+ }
+ fields[kv[0]] = kv[1]
+ }
+
+ return fields
+}
+
+// Stop the client side mapreduce aggregator.
+func (a *Aggregate) Stop() {
+ logger.Debug("Stopping client mapreduce aggregator")
+ close(a.stop)
+
+ err := a.globalGroup.Merge(a.query, a.group)
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/mapr/globalgroupset.go b/internal/mapr/globalgroupset.go
new file mode 100644
index 0000000..cfab506
--- /dev/null
+++ b/internal/mapr/globalgroupset.go
@@ -0,0 +1,100 @@
+package mapr
+
+import (
+ "fmt"
+)
+
+// GlobalGroupSet is used on the dtail client to merge multiple group sets
+// (one group set per remote server) to one single global group set.
+type GlobalGroupSet struct {
+ GroupSet
+ semaphore chan struct{}
+}
+
+// NewGlobalGroupSet creates a new empty global group set.
+func NewGlobalGroupSet() *GlobalGroupSet {
+ g := GlobalGroupSet{
+ semaphore: make(chan struct{}, 1),
+ }
+ g.InitSet()
+
+ return &g
+}
+
+// String representation of the global group set.
+func (g *GlobalGroupSet) String() string {
+ return fmt.Sprintf("GlobalGroupSet(%s)", g.GroupSet.String())
+}
+
+// Merge (blocking) a group set into the global group set.
+func (g *GlobalGroupSet) Merge(query *Query, group *GroupSet) error {
+ g.semaphore <- struct{}{}
+ defer func() { <-g.semaphore }()
+
+ return g.merge(query, group)
+}
+
+// MergeNoblock merges (non-blocking) a group set into the global group set.
+func (g *GlobalGroupSet) MergeNoblock(query *Query, group *GroupSet) (bool, error) {
+ select {
+ case g.semaphore <- struct{}{}:
+ err := g.merge(query, group)
+ <-g.semaphore
+ return true, err
+ default:
+ return false, nil
+ }
+}
+
+// Merge a group set into the global group set.
+func (g *GlobalGroupSet) merge(query *Query, group *GroupSet) error {
+ for groupKey, set := range group.sets {
+ s := g.GetSet(groupKey)
+ if err := s.Merge(query, set); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// IsEmpty determines whether the global group set has any data in it.
+func (g *GlobalGroupSet) IsEmpty() bool {
+ return g.NumSets() == 0
+}
+
+// NumSets determines the number of sets.
+func (g *GlobalGroupSet) NumSets() int {
+ g.semaphore <- struct{}{}
+ defer func() { <-g.semaphore }()
+
+ return len(g.sets)
+}
+
+// SwapOut teturn the underlying group set and create a new empty one, so
+// that the global group set is empty again and can aggregate new data.
+func (g *GlobalGroupSet) SwapOut() *GroupSet {
+ g.semaphore <- struct{}{}
+ defer func() { <-g.semaphore }()
+
+ set := &GroupSet{sets: g.sets}
+ g.InitSet()
+
+ return set
+}
+
+// WriteResult writes the result of a mapreduce aggregation to an outfile.
+func (g *GlobalGroupSet) WriteResult(query *Query) error {
+ g.semaphore <- struct{}{}
+ defer func() { <-g.semaphore }()
+
+ return g.GroupSet.WriteResult(query)
+}
+
+// Result returns the result of the mapreduce aggregation as a string.
+func (g *GlobalGroupSet) Result(query *Query) (string, int, error) {
+ g.semaphore <- struct{}{}
+ defer func() { <-g.semaphore }()
+
+ return g.GroupSet.Result(query)
+}
diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go
new file mode 100644
index 0000000..d8f9379
--- /dev/null
+++ b/internal/mapr/groupset.go
@@ -0,0 +1,178 @@
+package mapr
+
+import (
+ "errors"
+ "fmt"
+ "io/ioutil"
+ "sort"
+ "strconv"
+ "strings"
+)
+
+// GroupSet represents a map of aggregate sets. The group sets
+// are requierd by the "group by" mapr clause, whereas the
+// group set map keys are the values of the "group by" arguments.
+// E.g. "group by $cid" would create one aggregate set and one map
+// entry per customer id.
+type GroupSet struct {
+ sets map[string]*AggregateSet
+}
+
+// NewGroupSet returns a new empty group set.
+func NewGroupSet() *GroupSet {
+ g := GroupSet{}
+ g.InitSet()
+ return &g
+}
+
+// String representation of the group set.
+func (g *GroupSet) String() string {
+ return fmt.Sprintf("GroupSet(%v)", g.sets)
+}
+
+// InitSet makes the group set empty (initialize).
+func (g *GroupSet) InitSet() {
+ g.sets = make(map[string]*AggregateSet)
+}
+
+// GetSet gets a specific aggregate set from the group set.
+func (g *GroupSet) GetSet(groupKey string) *AggregateSet {
+ set, ok := g.sets[groupKey]
+ if !ok {
+ set = NewAggregateSet()
+ g.sets[groupKey] = set
+ }
+ return set
+}
+
+// Serialize the group set (e.g. to send it over the wire).
+func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) {
+ for groupKey, set := range g.sets {
+ set.Serialize(groupKey, ch, stop)
+ }
+}
+
+// Result returns a nicely formated result of the query from the group set.
+func (g *GroupSet) Result(query *Query) (string, int, error) {
+ return g.limitedResult(query, query.Limit, "\t", " ", false)
+}
+
+// WriteResult writes the result to an outfile.
+func (g *GroupSet) WriteResult(query *Query) error {
+ if query.Outfile == "" {
+ return errors.New("No outfile specified")
+ }
+
+ // -1: Don't limit the result, include all data sets
+ result, _, err := g.limitedResult(query, -1, "", ",", true)
+ if err != nil {
+ return err
+ }
+
+ return ioutil.WriteFile(query.Outfile, []byte(result), 0644)
+}
+
+// Return a nicely formated result of the query from the group set.
+func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSeparator string, addHeader bool) (string, int, error) {
+ type result struct {
+ groupKey string
+ resultStr string
+ orderBy float64
+ }
+
+ var resultSlice []result
+
+ for groupKey, set := range g.sets {
+ var sb strings.Builder
+ r := result{groupKey: groupKey}
+
+ lastIndex := len(query.Select) - 1
+ for i, sc := range query.Select {
+ storage := sc.FieldStorage
+ orderByThis := storage == query.OrderBy
+
+ switch sc.Operation {
+ case Count:
+ value := set.FValues[storage]
+ sb.WriteString(fmt.Sprintf("%d", int(value)))
+ if orderByThis {
+ r.orderBy = value
+ }
+ case Len:
+ fallthrough
+ case Sum:
+ fallthrough
+ case Min:
+ fallthrough
+ case Max:
+ value := set.FValues[storage]
+ sb.WriteString(fmt.Sprintf("%f", value))
+ if orderByThis {
+ r.orderBy = value
+ }
+ case Last:
+ value := set.SValues[storage]
+ if orderByThis {
+ f, err := strconv.ParseFloat(value, 64)
+ if err == nil {
+ r.orderBy = f
+ }
+ }
+ sb.WriteString(value)
+ case Avg:
+ value := set.FValues[storage] / float64(set.Samples)
+ sb.WriteString(fmt.Sprintf("%f", value))
+ if orderByThis {
+ r.orderBy = value
+ }
+ default:
+ return "", 0, fmt.Errorf("Unknown aggregation method '%v'", sc.Operation)
+ }
+ if i != lastIndex {
+ sb.WriteString(fieldSeparator)
+ }
+ }
+
+ r.resultStr = sb.String()
+ resultSlice = append(resultSlice, r)
+ }
+
+ if query.OrderBy != "" {
+ if query.ReverseOrder {
+ sort.SliceStable(resultSlice, func(i, j int) bool {
+ return resultSlice[i].orderBy < resultSlice[j].orderBy
+ })
+ } else {
+ sort.SliceStable(resultSlice, func(i, j int) bool {
+ return resultSlice[i].orderBy > resultSlice[j].orderBy
+ })
+ }
+ }
+
+ var sb strings.Builder
+
+ // Write header first
+ if addHeader {
+ lastIndex := len(query.Select) - 1
+ sb.WriteString(lineStarter)
+ for i, sc := range query.Select {
+ sb.WriteString(sc.FieldStorage)
+ if i != lastIndex {
+ sb.WriteString(fieldSeparator)
+ }
+ }
+ sb.WriteString("\n")
+ }
+
+ // And now write the data
+ for i, r := range resultSlice {
+ if i == limit {
+ break
+ }
+ sb.WriteString(lineStarter)
+ sb.WriteString(r.resultStr)
+ sb.WriteString("\n")
+ }
+
+ return sb.String(), len(resultSlice), nil
+}
diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go
new file mode 100644
index 0000000..f0df5bc
--- /dev/null
+++ b/internal/mapr/logformat/default.go
@@ -0,0 +1,23 @@
+package logformat
+
+import (
+ "errors"
+ "strings"
+)
+
+// MakeFieldsDEFAULT is the default log file mapreduce parser.
+func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) {
+ fields := make(map[string]string, 20)
+ splitted := strings.Split(maprLine, "|")
+
+ fields["$hostname"] = p.hostname
+
+ for _, kv := range splitted {
+ keyAndValue := strings.SplitN(kv, "=", 2)
+ if len(keyAndValue) != 2 {
+ return fields, errors.New("Error parsing mapr token: " + kv)
+ }
+ fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1]
+ }
+ return fields, nil
+}
diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go
new file mode 100644
index 0000000..a3c47fb
--- /dev/null
+++ b/internal/mapr/logformat/default_test.go
@@ -0,0 +1,35 @@
+package logformat
+
+import (
+ "testing"
+)
+
+func TestDefaultLogFormat(t *testing.T) {
+ parser, err := NewParser("default")
+ if err != nil {
+ t.Errorf("Unable to create parser: %s", err.Error())
+ }
+
+ fields, err := parser.MakeFields("foo=bar|baz=bay")
+
+ if err != nil {
+ t.Errorf("Unable to parse: %s", err.Error())
+ }
+
+ if bar, ok := fields["foo"]; !ok {
+ t.Errorf("Expected field 'foo', but no such field there\n")
+ } else if bar != "bar" {
+ t.Errorf("Expected 'bar' stored in field 'foo', but got '%s'\n", bar)
+ }
+
+ if bay, ok := fields["baz"]; !ok {
+ t.Errorf("Expected field 'baz', but no such field there\n")
+ } else if bay != "bay" {
+ t.Errorf("Expected 'bay' stored in field 'baz', but got '%s'\n", bay)
+ }
+
+ fields, err = parser.MakeFields("foo=bar|bazbay")
+ if err == nil {
+ t.Errorf("Expected error but didn't: %s", err.Error())
+ }
+}
diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go
new file mode 100644
index 0000000..5730d29
--- /dev/null
+++ b/internal/mapr/logformat/parser.go
@@ -0,0 +1,75 @@
+package logformat
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "errors"
+ "fmt"
+ "os"
+ "reflect"
+ "strings"
+)
+
+// Parser is used to parse the mapreduce information from the server log files.
+type Parser struct {
+ hostname string
+ logFormatName string
+ makeFieldsFunc reflect.Value
+ makeFieldsReceiver reflect.Value
+}
+
+// NewParser returns a new log parser.
+func NewParser(logFormatName string) (*Parser, error) {
+ hostname, err := os.Hostname()
+
+ if err != nil {
+ return nil, err
+ }
+
+ p := Parser{
+ hostname: hostname,
+ }
+
+ err = p.reflectLogFormat(logFormatName)
+ if err != nil {
+ return nil, err
+ }
+
+ return &p, nil
+}
+
+// The aim of this is that everyone can plug in their own mapr log format
+// parsing method to DTail. Just add a method MakeFieldsMODULENAME to type
+// Parser. Whereas MODULENAME must be a upeprcase string.
+func (p *Parser) reflectLogFormat(logFormatName string) error {
+ methodName := fmt.Sprintf("MakeFields%s", strings.ToUpper(logFormatName))
+
+ rt := reflect.TypeOf(p)
+ method, ok := rt.MethodByName(methodName)
+ if !ok {
+ return errors.New("No such mapr log format module: " + methodName)
+ }
+
+ p.makeFieldsFunc = method.Func
+ p.makeFieldsReceiver = reflect.ValueOf(p)
+
+ return nil
+}
+
+// MakeFields is for returning the fields from a given log line.
+func (p *Parser) MakeFields(maprLine string) (fields map[string]string, err error) {
+ inputValues := []reflect.Value{p.makeFieldsReceiver, reflect.ValueOf(maprLine)}
+ returnValues := p.makeFieldsFunc.Call(inputValues)
+
+ errInterface := returnValues[1].Interface()
+
+ if errInterface == nil {
+ fields, err = returnValues[0].Interface().(map[string]string), nil
+ logger.Trace("parser.MakeFields", fields, err)
+ return
+ }
+
+ fields, err = returnValues[0].Interface().(map[string]string), errInterface.(error)
+ logger.Trace("parser.MakeFields", fields, err)
+
+ return
+}
diff --git a/internal/mapr/query.go b/internal/mapr/query.go
new file mode 100644
index 0000000..3805d15
--- /dev/null
+++ b/internal/mapr/query.go
@@ -0,0 +1,245 @@
+package mapr
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ invalidQuery string = "Invalid query: "
+ unexpectedEnd string = "Unexpected end of query"
+)
+
+// Query represents a parsed mapr query.
+type Query struct {
+ Select []selectCondition
+ Table string
+ Where []whereCondition
+ GroupBy []string
+ OrderBy string
+ ReverseOrder bool
+ GroupKey string
+ Interval time.Duration
+ Limit int
+ Outfile string
+ RawQuery string
+ tokens []token
+}
+
+func (q Query) String() string {
+ return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,GroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v)",
+ q.Select,
+ q.Table,
+ q.Where,
+ q.GroupBy,
+ q.GroupKey,
+ q.OrderBy,
+ q.ReverseOrder,
+ q.Interval,
+ q.Limit,
+ q.Outfile,
+ q.RawQuery,
+ q.tokens)
+}
+
+// NewQuery returns a new mapreduce query.
+func NewQuery(queryStr string) (*Query, error) {
+ if queryStr == "" {
+ return nil, nil
+ }
+
+ tokens := tokenize(queryStr)
+
+ q := Query{
+ RawQuery: queryStr,
+ tokens: tokens,
+ Interval: time.Second * 5,
+ Limit: -1,
+ }
+
+ err := q.parse(tokens)
+
+ logger.Debug(q)
+ return &q, err
+}
+
+func (q *Query) parse(tokens []token) error {
+ var found []token
+ var err error
+
+ for tokens != nil && len(tokens) > 0 {
+ switch strings.ToLower(tokens[0].str) {
+ case "select":
+ tokens, found = tokensConsume(tokens[1:])
+ q.Select, err = makeSelectConditions(found)
+ if err != nil {
+ return err
+ }
+ case "from":
+ tokens, found = tokensConsume(tokens[1:])
+ if len(found) > 0 {
+ q.Table = strings.ToUpper(found[0].str)
+ }
+ case "where":
+ tokens, found = tokensConsume(tokens[1:])
+ if q.Where, err = makeWhereConditions(found); err != nil {
+ return err
+ }
+ case "group":
+ tokens = tokensConsumeOptional(tokens[1:], "by")
+ if tokens == nil || len(tokens) < 1 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ tokens, q.GroupBy = tokensConsumeStr(tokens)
+ q.GroupKey = strings.Join(q.GroupBy, ",")
+ case "rorder":
+ tokens = tokensConsumeOptional(tokens[1:], "by")
+ if tokens == nil || len(tokens) < 1 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ tokens, found = tokensConsume(tokens)
+ if len(found) == 0 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ q.OrderBy = found[0].str
+ q.ReverseOrder = true
+ case "order":
+ tokens = tokensConsumeOptional(tokens[1:], "by")
+ if tokens == nil || len(tokens) < 1 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ tokens, found = tokensConsume(tokens)
+ if len(found) == 0 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ q.OrderBy = found[0].str
+ case "interval":
+ tokens, found = tokensConsume(tokens[1:])
+ if len(found) > 0 {
+ i, err := strconv.Atoi(found[0].str)
+ if err != nil {
+ return errors.New(invalidQuery + err.Error())
+ }
+ q.Interval = time.Second * time.Duration(i)
+ }
+ case "limit":
+ tokens, found = tokensConsume(tokens[1:])
+ if len(found) == 0 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ i, err := strconv.Atoi(found[0].str)
+ if err != nil {
+ return errors.New(invalidQuery + err.Error())
+ }
+ q.Limit = i
+ case "outfile":
+ tokens, found = tokensConsume(tokens[1:])
+ if len(found) == 0 {
+ return errors.New(invalidQuery + unexpectedEnd)
+ }
+ q.Outfile = found[0].str
+ default:
+ return errors.New(invalidQuery + "Unexpected keyword " + tokens[0].str)
+ }
+ }
+
+ if q.Table == "" {
+ return errors.New(invalidQuery + "Empty table specified in 'from' clause")
+ }
+ if len(q.Select) < 1 {
+ return errors.New(invalidQuery + "Expected at least one field in 'select' clause but got none")
+ }
+ if len(q.GroupBy) == 0 {
+ field := q.Select[0].Field
+ q.GroupBy = append(q.GroupBy, field)
+ }
+
+ if q.OrderBy != "" {
+ var orderFieldIsValid bool
+ for _, sc := range q.Select {
+ if q.OrderBy == sc.FieldStorage {
+ orderFieldIsValid = true
+ break
+ }
+ }
+ if !orderFieldIsValid {
+ return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s', must be present in 'select' clause", q.OrderBy))
+ }
+ }
+
+ return nil
+}
+
+// WhereClause interprets the where clause of the mapreduce query.
+func (q *Query) WhereClause(fields map[string]string) bool {
+ floatValue := func(str string, float float64, t whereType) (float64, bool) {
+ switch t {
+ case Float:
+ return float, true
+ case Field:
+ value, ok := fields[str]
+ if !ok {
+ return 0, false
+ }
+ f, err := strconv.ParseFloat(value, 64)
+ if err != nil {
+ return 0, false
+ }
+ return f, true
+ default:
+ logger.Error("Unexpected argument in 'where' clause", str, float, t)
+ return 0, false
+ }
+ }
+
+ stringValue := func(str string, t whereType) (string, bool) {
+ switch t {
+ case Field:
+ value, ok := fields[str]
+ if !ok {
+ return str, false
+ }
+ return value, true
+ case String:
+ return str, true
+ default:
+ logger.Error("Unexpected argument in 'where' clause", str, t)
+ return str, false
+ }
+ }
+
+ for _, wc := range q.Where {
+ var ok bool
+
+ if wc.Operation > FloatOperation {
+ var lValue, rValue float64
+ if lValue, ok = floatValue(wc.lString, wc.lFloat, wc.lType); !ok {
+ return false
+ }
+ if rValue, ok = floatValue(wc.rString, wc.rFloat, wc.rType); !ok {
+ return false
+ }
+ if ok = wc.floatClause(lValue, rValue); !ok {
+ return false
+ }
+ continue
+ }
+
+ var lValue, rValue string
+ if lValue, ok = stringValue(wc.lString, wc.lType); !ok {
+ return false
+ }
+ if rValue, ok = stringValue(wc.rString, wc.rType); !ok {
+ return false
+ }
+ if ok = wc.stringClause(lValue, rValue); !ok {
+ return false
+ }
+ }
+
+ return true
+}
diff --git a/internal/mapr/query_test.go b/internal/mapr/query_test.go
new file mode 100644
index 0000000..6176461
--- /dev/null
+++ b/internal/mapr/query_test.go
@@ -0,0 +1,149 @@
+package mapr
+
+import (
+ "testing"
+ "time"
+)
+
+func TestParseQuerySimple(t *testing.T) {
+ errorQueries := []string{
+ "select",
+ "select foo",
+ "select foo from",
+ "select foo from bar where baz",
+ "select foo from bar where baz <",
+ "select foo from bar where baz < 100 bay eq 12 group",
+ "select foo from bar where baz < 100 bay eq 12 group by foo order by",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit",
+ }
+ okQueries := []string{"select foo from bar",
+ "select foo from bar where",
+ "select foo from bar where baz < 100 bay eq 12",
+ "select foo from bar where baz < 100, bay eq 12",
+ "select foo from bar where baz < 100 and bay eq 12",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"",
+ }
+
+ for _, queryStr := range errorQueries {
+ q, err := NewQuery(queryStr)
+ if err == nil {
+ t.Errorf("Expected a parse error: %s\n%v", queryStr, q)
+ continue
+ }
+ }
+
+ for _, queryStr := range okQueries {
+ _, err := NewQuery(queryStr)
+ if err != nil {
+ t.Errorf("%s: %s", err.Error(), queryStr)
+ continue
+ }
+ }
+}
+
+func TestParseQueryDeep(t *testing.T) {
+ dialects := []string{
+ "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23",
+ "SELECT s1, `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23",
+ "select s1, `from` count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23",
+ "sElEct s1, `from` coUnt(s3) from taBle where w1 == 2 aNd w2 eq \"free beer\" Group By g1, g2 order bY count(s3) intervaL 10 LiMiT 23",
+ "SELECT s1 `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP BY g1 g2 ORDER BY count(s3) INTERVAL 10 LIMIT 23",
+ "select s1 `from` count(s3) from table where w1 == 2 w2 eq \"free beer\" group g1 g2 order count(s3) interval 10 limit 23",
+ "limit 23 interval 10 order count(s3) group g1 g2 where w1 == 2 w2 eq \"free beer\" from table select s1 `from` count(s3)",
+ }
+
+ for _, queryStr := range dialects {
+ q, err := NewQuery(queryStr)
+ if err != nil {
+ t.Errorf("%s: %s", err.Error(), queryStr)
+ }
+
+ // 'select' clause
+ if len(q.Select) != 3 {
+ t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q)
+ }
+
+ if q.Select[0].Field != "s1" {
+ t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Field, queryStr, q)
+ }
+ if q.Select[0].Operation != Last {
+ t.Errorf("Expected 'last' as aggregation function of first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q)
+ }
+
+ if q.Select[1].Field != "from" {
+ t.Errorf("Expected 'from' as second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Field, queryStr, q)
+ }
+ if q.Select[1].Operation != Last {
+ t.Errorf("Expected 'last' as aggregation function of second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q)
+ }
+
+ if q.Select[2].Field != "s3" {
+ t.Errorf("Expected 's3' as third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Field, queryStr, q)
+ }
+ if q.Select[2].Operation != Count {
+ t.Errorf("Expected 'count' as aggregation function of third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q)
+ }
+ if q.Select[2].FieldStorage != "count(s3)" {
+ t.Errorf("Expected 'count(s3)' as third element's storage in 'select' clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q)
+ }
+
+ // 'from' clause
+ if q.Table != "TABLE" {
+ t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v", q.Table, queryStr, q)
+ }
+
+ // 'where' clause
+ if len(q.Where) != 2 {
+ t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v", q.Where, queryStr, q)
+ }
+ if q.Where[0].lString != "w1" {
+ t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v", q.Where[0].lString, queryStr, q)
+ }
+ if q.Where[0].Operation != FloatEq {
+ t.Errorf("Expected FloatEq operation in first 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q)
+ }
+ if q.Where[0].rFloat != 2 {
+ t.Errorf("Expected '2' as float argument in first 'where' condition but got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q)
+ }
+ if q.Where[1].lString != "w2" {
+ t.Errorf("Expected w2 as second element in 'where' clause but got '%v': %s\n%v", q.Where[1].lString, queryStr, q)
+ }
+ if q.Where[1].Operation != StringEq {
+ t.Errorf("Expected StringEq operation in second 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q)
+ }
+ if q.Where[1].rString != "free beer" {
+ t.Errorf("Expected 'free beer' as string argument in second 'where' condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q)
+ }
+
+ // 'group by' clause
+ if len(q.GroupBy) != 2 {
+ t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v", q.GroupBy, queryStr, q)
+ }
+ if q.GroupBy[0] != "g1" {
+ t.Errorf("Expected 'g1' as first element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[0], queryStr, q)
+ }
+ if q.GroupBy[1] != "g2" {
+ t.Errorf("Expected 'g2' as second element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[1], queryStr, q)
+ }
+ if q.GroupKey != "g1,g2" {
+ t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got '%v': %s\n%v", q.GroupKey, queryStr, q)
+ }
+
+ // 'order by' clause
+ if q.OrderBy != "count(s3)" {
+ t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got '%v': %s\n%v", q.OrderBy, queryStr, q)
+ }
+
+ // 'interval' clause
+ if q.Interval != time.Second*time.Duration(10) {
+ t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v", q.Interval, queryStr, q)
+ }
+
+ // 'limit' clause
+ if q.Limit != 23 {
+ t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q)
+ }
+ }
+}
diff --git a/internal/mapr/selectcondition.go b/internal/mapr/selectcondition.go
new file mode 100644
index 0000000..1882b7e
--- /dev/null
+++ b/internal/mapr/selectcondition.go
@@ -0,0 +1,96 @@
+package mapr
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+)
+
+// AggregateOperation is to specify the aggregate operation type.
+type AggregateOperation int
+
+// Aggregate operation types
+const (
+ UndefAggregateOperation AggregateOperation = iota
+ Count AggregateOperation = iota
+ Sum AggregateOperation = iota
+ Min AggregateOperation = iota
+ Max AggregateOperation = iota
+ Last AggregateOperation = iota
+ Avg AggregateOperation = iota
+ Len AggregateOperation = iota
+)
+
+// Represents a parsed "select" clause, used by mapr.Query.
+type selectCondition struct {
+ Field string
+ FieldStorage string
+ Operation AggregateOperation
+}
+
+func (sc selectCondition) String() string {
+ return fmt.Sprintf("selectCondition(Field:%s,FieldStorage:%s,Operation:%v)",
+ sc.Field,
+ sc.FieldStorage,
+ sc.Operation)
+}
+
+func makeSelectConditions(tokens []token) ([]selectCondition, error) {
+ var sel []selectCondition
+
+ // Parse select aggregation, e.g. sum(foo)
+ parse := func(token token) (selectCondition, error) {
+ var sc selectCondition
+ tokenStr := strings.ToLower(token.str)
+
+ if !strings.Contains(tokenStr, "(") && !strings.Contains(tokenStr, ")") {
+ sc.Field = tokenStr
+ sc.FieldStorage = tokenStr
+ sc.Operation = Last
+ return sc, nil
+ }
+
+ a := strings.Split(tokenStr, "(")
+ if len(a) != 2 {
+ return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " + token.str)
+ }
+ agg := a[0] // Aggregation, e.g. 'sum'
+
+ b := strings.Split(a[1], ")")
+ if len(b) != 2 {
+ return sc, errors.New(invalidQuery + "Can't parse 'select' field name from aggregation: " + token.str)
+ }
+ sc.Field = b[0] // Field name, e.g. 'foo'
+ sc.FieldStorage = tokenStr // e.g. 'sum(foo)'
+
+ switch agg {
+ case "count":
+ sc.Operation = Count
+ case "sum":
+ sc.Operation = Sum
+ case "min":
+ sc.Operation = Min
+ case "max":
+ sc.Operation = Max
+ case "last":
+ sc.Operation = Last
+ case "avg":
+ sc.Operation = Avg
+ case "len":
+ sc.Operation = Len
+ default:
+ return sc, errors.New(invalidQuery + "Unknown aggregation in 'select' clause: " + agg)
+ }
+
+ return sc, nil
+ }
+
+ for _, token := range tokens {
+ sc, err := parse(token)
+ if err != nil {
+ return nil, err
+ }
+ sel = append(sel, sc)
+ }
+ return sel, nil
+}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
new file mode 100644
index 0000000..900756e
--- /dev/null
+++ b/internal/mapr/server/aggregate.go
@@ -0,0 +1,170 @@
+package server
+
+import (
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/fs"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/mapr"
+ "github.com/mimecast/dtail/internal/mapr/logformat"
+ "os"
+ "strings"
+ "time"
+)
+
+// Aggregate is for aggregating mapreduce data on the DTail server side.
+type Aggregate struct {
+ // Log lines to process (parsing MAPREDUCE lines).
+ Lines chan fs.LineRead
+ // Hostname of the current server (used to populate $hostname field).
+ hostname string
+ // Signals to exit goroutine.
+ stop chan struct{}
+ // Signals to serialize data.
+ serialize chan struct{}
+ // The mapr query
+ query *mapr.Query
+ // The mapr log format parser
+ parser *logformat.Parser
+}
+
+// NewAggregate return a new server side aggregator.
+func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) {
+ query, err := mapr.NewQuery(queryStr)
+ if err != nil {
+ return nil, err
+ }
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ logger.Error(err)
+ }
+ s := strings.Split(fqdn, ".")
+
+ logger.Info("Creating mapr log format parser", config.Server.MapreduceLogFormat)
+ logParser, err := logformat.NewParser(config.Server.MapreduceLogFormat)
+ if err != nil {
+ logger.FatalExit("Could not create mapr log format parser", err)
+ }
+
+ a := Aggregate{
+ Lines: make(chan fs.LineRead, 100),
+ stop: make(chan struct{}),
+ serialize: make(chan struct{}),
+ hostname: s[0],
+ query: query,
+ parser: logParser,
+ }
+
+ go a.periodicAggregateTimer()
+
+ fieldsCh := make(chan map[string]string)
+ go a.readFields(fieldsCh, maprLines)
+ go a.readLines(fieldsCh)
+
+ return &a, nil
+}
+
+func (a *Aggregate) periodicAggregateTimer() {
+ for {
+ select {
+ case <-time.After(a.query.Interval):
+ a.Serialize()
+ case <-a.stop:
+ return
+ }
+ }
+}
+
+func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) {
+ group := mapr.NewGroupSet()
+
+ for {
+ select {
+ case fields := <-fieldsCh:
+ a.aggregate(group, fields)
+ case <-a.serialize:
+ logger.Info("Serializing mapreduce result")
+ group.Serialize(maprLines, a.stop)
+ logger.Info("Done serializing mapreduce result")
+ group = mapr.NewGroupSet()
+ case <-a.stop:
+ return
+ }
+ }
+}
+
+func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) {
+ for {
+ select {
+ case line, ok := <-a.Lines:
+ if !ok {
+ return
+ }
+
+ maprLine := strings.TrimSpace(string(line.Content))
+ fields, err := a.parser.MakeFields(maprLine)
+
+ if err != nil {
+ logger.Error(err)
+ continue
+ }
+ if !a.query.WhereClause(fields) {
+ continue
+ }
+
+ select {
+ case fieldsCh <- fields:
+ case <-a.stop:
+ }
+ case <-a.stop:
+ return
+ }
+ }
+}
+
+func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
+ //logger.Trace("Aggregating", group, fields)
+ var sb strings.Builder
+
+ for i, field := range a.query.GroupBy {
+ if i > 0 {
+ sb.WriteString(" ")
+ }
+ if val, ok := fields[field]; ok {
+ sb.WriteString(val)
+ }
+ }
+ groupKey := sb.String()
+ set := group.GetSet(groupKey)
+
+ var addedSample bool
+ for _, sc := range a.query.Select {
+ if val, ok := fields[sc.Field]; ok {
+ if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil {
+ logger.Error(err)
+ continue
+ }
+ addedSample = true
+ }
+ }
+
+ if addedSample {
+ set.Samples++
+ return
+ }
+
+ logger.Trace("Aggregated data locally without adding new samples")
+}
+
+// Serialize all the aggregated data.
+func (a *Aggregate) Serialize() {
+ select {
+ case a.serialize <- struct{}{}:
+ case <-a.stop:
+ }
+}
+
+// Close the aggregator.
+func (a *Aggregate) Close() {
+ close(a.stop)
+}
diff --git a/internal/mapr/token.go b/internal/mapr/token.go
new file mode 100644
index 0000000..b8be4da
--- /dev/null
+++ b/internal/mapr/token.go
@@ -0,0 +1,108 @@
+package mapr
+
+import (
+ "strings"
+)
+
+var keywords = [...]string{"select", "from", "where", "group", "rorder", "order", "interval", "limit", "outfile"}
+
+// Represents a parsed token, used to parse the mapr query.
+type token struct {
+ str string
+ isBareword bool
+}
+
+func (t token) isKeyword() bool {
+ if !t.isBareword {
+ return false
+ }
+
+ for _, keyword := range keywords {
+ if strings.ToLower(t.str) == keyword {
+ return true
+ }
+ }
+ return false
+}
+
+func (t token) String() string {
+ return t.str
+}
+
+func tokenize(queryStr string) []token {
+ var tokens []token
+
+ for i, part := range strings.Split(queryStr, "\"") {
+ // Even i, means that it is not a quoted string
+ if i%2 == 0 {
+ commasStripped := strings.Replace(part, ",", " ", -1)
+ for _, tokenStr := range strings.Fields(commasStripped) {
+ token := token{
+ str: tokenStr,
+ isBareword: true,
+ }
+ tokens = append(tokens, token)
+ }
+ continue
+ }
+ // Add whole quoted string as a token
+ token := token{
+ str: part,
+ isBareword: false,
+ }
+ tokens = append(tokens, token)
+ }
+
+ return tokens
+}
+
+func tokensConsume(tokens []token) ([]token, []token) {
+ //logger.Trace("=====================")
+ var consumed []token
+
+ for i, t := range tokens {
+ if t.isKeyword() {
+ //logger.Trace("keyword", t)
+ return tokens[i:], consumed
+ }
+ // strip escapes, such as ` from `foo`, this allows to use keywords as field names
+ length := len(t.str)
+ if length == 0 {
+ continue
+ }
+ if t.str[0] == '`' && t.str[length-1] == '`' {
+ stripped := t.str[1 : length-1]
+ //logger.Trace("stripped", stripped)
+ t := token{
+ str: stripped,
+ isBareword: t.isBareword,
+ }
+ consumed = append(consumed, t)
+ continue
+ }
+ //logger.Trace("bare", token)
+ consumed = append(consumed, t)
+ }
+
+ //logger.Trace("result", consumed)
+ return nil, consumed
+}
+
+func tokensConsumeStr(tokens []token) ([]token, []string) {
+ var strings []string
+ tokens, found := tokensConsume(tokens)
+ for _, token := range found {
+ strings = append(strings, token.str)
+ }
+ return tokens, strings
+}
+
+func tokensConsumeOptional(tokens []token, optional string) []token {
+ if len(tokens) < 1 {
+ return tokens
+ }
+ if strings.ToLower(tokens[0].str) == strings.ToLower(optional) {
+ return tokens[1:]
+ }
+ return tokens
+}
diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go
new file mode 100644
index 0000000..e1f4e5b
--- /dev/null
+++ b/internal/mapr/wherecondition.go
@@ -0,0 +1,193 @@
+package mapr
+
+import (
+ "github.com/mimecast/dtail/internal/logger"
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+)
+
+// QueryOperation determines the mapreduce operation.
+type QueryOperation int
+
+// The possible mapreduce operation.s
+const (
+ UndefQueryOperation QueryOperation = iota
+ StringEq QueryOperation = iota
+ StringNe QueryOperation = iota
+ StringContains QueryOperation = iota
+ FloatOperation QueryOperation = iota
+ FloatEq QueryOperation = iota
+ FloatNe QueryOperation = iota
+ FloatLt QueryOperation = iota
+ FloatLe QueryOperation = iota
+ FloatGt QueryOperation = iota
+ FloatGe QueryOperation = iota
+)
+
+type whereType int
+
+// The possible field types.
+const (
+ UndefWhereType whereType = iota
+ Field whereType = iota
+ String whereType = iota
+ Float whereType = iota
+)
+
+func (w whereType) String() string {
+ switch w {
+ case Field:
+ return fmt.Sprintf("Field")
+ case String:
+ return fmt.Sprintf("String")
+ case Float:
+ return fmt.Sprintf("Float")
+ default:
+ return fmt.Sprintf("UndefWhereType")
+ }
+}
+
+// Represent a parsed "where" clause, used by mapr.Query
+type whereCondition struct {
+ lString string
+ lFloat float64
+ lType whereType
+
+ Operation QueryOperation
+
+ rString string
+ rFloat float64
+ rType whereType
+}
+
+func (wc *whereCondition) String() string {
+ return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,lType:%s,rString:%s,rFloat:%v,rType:%s)",
+ wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString, wc.rFloat, wc.rType.String())
+}
+
+func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
+ parse := func(tokens []token) (whereCondition, []token, error) {
+ var wc whereCondition
+ if len(tokens) < 3 {
+ return wc, nil, errors.New(invalidQuery + "Not enough arguments in 'where' clause")
+ }
+
+ whereOp := strings.ToLower(tokens[1].str)
+ switch whereOp {
+ case "==":
+ wc.Operation = FloatEq
+ case "!=":
+ wc.Operation = FloatNe
+ case "<":
+ wc.Operation = FloatLt
+ case "<=":
+ wc.Operation = FloatLe
+ case "=<":
+ wc.Operation = FloatLe
+ case ">":
+ wc.Operation = FloatGt
+ case ">=":
+ wc.Operation = FloatGe
+ case "=>":
+ wc.Operation = FloatGe
+ case "eq":
+ wc.Operation = StringEq
+ case "ne":
+ wc.Operation = StringNe
+ case "contains":
+ wc.Operation = StringContains
+ default:
+ return wc, nil, errors.New(invalidQuery + "Unknown operation in 'where' clause: " + whereOp)
+ }
+
+ wc.lString = tokens[0].str
+ wc.rString = tokens[2].str
+
+ if wc.Operation > FloatOperation {
+ if !tokens[0].isBareword {
+ return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's lValue: " + tokens[0].str)
+ }
+ if f, err := strconv.ParseFloat(wc.lString, 64); err == nil {
+ wc.lFloat = f
+ wc.lType = Float
+ } else {
+ wc.lType = Field
+ }
+
+ if !tokens[2].isBareword {
+ return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's rValue: " + tokens[2].str)
+ }
+ if f, err := strconv.ParseFloat(wc.rString, 64); err == nil {
+ wc.rFloat = f
+ wc.rType = Float
+ } else {
+ wc.rType = Field
+ }
+ return wc, tokens[3:], nil
+ }
+
+ if tokens[0].isBareword {
+ wc.lType = Field
+ } else {
+ wc.lType = String
+ }
+ if tokens[2].isBareword {
+ wc.rType = Field
+ } else {
+ wc.rType = String
+ }
+
+ return wc, tokens[3:], nil
+ }
+
+ for len(tokens) > 0 {
+ var wc whereCondition
+ var err error
+
+ wc, tokens, err = parse(tokens)
+ if err != nil {
+ return nil, err
+ }
+
+ where = append(where, wc)
+ tokens = tokensConsumeOptional(tokens, "and")
+ }
+
+ return
+}
+
+func (wc *whereCondition) floatClause(lValue float64, rValue float64) bool {
+ switch wc.Operation {
+ case FloatEq:
+ return lValue == rValue
+ case FloatNe:
+ return lValue != rValue
+ case FloatLt:
+ return lValue < rValue
+ case FloatLe:
+ return lValue <= rValue
+ case FloatGt:
+ return lValue > rValue
+ case FloatGe:
+ return lValue >= rValue
+ default:
+ logger.Error("Unknown float operation", lValue, wc.Operation, rValue)
+ }
+ return false
+}
+
+func (wc *whereCondition) stringClause(lValue string, rValue string) bool {
+ switch wc.Operation {
+ case StringEq:
+ return lValue == rValue
+ case StringNe:
+ return lValue != rValue
+ case StringContains:
+ return strings.Contains(lValue, rValue)
+ default:
+ logger.Error("Unknown string operation", lValue, wc.Operation, rValue)
+ }
+ return false
+}
diff --git a/internal/omode/mode.go b/internal/omode/mode.go
new file mode 100644
index 0000000..4bdfc45
--- /dev/null
+++ b/internal/omode/mode.go
@@ -0,0 +1,81 @@
+package omode
+
+import (
+ "fmt"
+ "os"
+ "path"
+)
+
+// Mode used.
+type Mode int
+
+// Possible modes.
+const (
+ Unknown Mode = iota
+ Server Mode = iota
+ TailClient Mode = iota
+ CatClient Mode = iota
+ GrepClient Mode = iota
+ MapClient Mode = iota
+ HealthClient Mode = iota
+)
+
+// New returns the mode based on the mode string.
+func New(modeStr string) Mode {
+ switch modeStr {
+ case "dserver":
+ return Server
+ case "server":
+ return Server
+
+ case "dtail":
+ fallthrough
+ case "tail":
+ return TailClient
+
+ case "grep":
+ fallthrough
+ case "dgrep":
+ return GrepClient
+
+ case "cat":
+ fallthrough
+ case "dcat":
+ return CatClient
+
+ case "map":
+ fallthrough
+ case "dmap":
+ return MapClient
+
+ case "health":
+ return HealthClient
+
+ default:
+ panic(fmt.Sprintf("Unknown mode: '%s'", modeStr))
+ }
+}
+
+// Default mode.
+func Default() Mode {
+ return New(path.Base(os.Args[0]))
+}
+
+func (m Mode) String() string {
+ switch m {
+ case Server:
+ return "server"
+ case TailClient:
+ return "tail"
+ case CatClient:
+ return "cat"
+ case GrepClient:
+ return "grep"
+ case MapClient:
+ return "map"
+ case HealthClient:
+ return "health"
+ default:
+ return "unknown"
+ }
+}
diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go
new file mode 100644
index 0000000..f78bcf6
--- /dev/null
+++ b/internal/pprof/pprof.go
@@ -0,0 +1,17 @@
+package pprof
+
+import (
+ "fmt"
+ "net/http"
+ _ "net/http"
+ _ "net/http/pprof"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+)
+
+func Start() {
+ bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort)
+ logger.Info("Starting PProf server", bindAddr)
+ go http.ListenAndServe(bindAddr, nil)
+}
diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go
new file mode 100644
index 0000000..76a2726
--- /dev/null
+++ b/internal/prompt/prompt.go
@@ -0,0 +1,95 @@
+package prompt
+
+import (
+ "bufio"
+ "github.com/mimecast/dtail/internal/logger"
+ "fmt"
+ "os"
+ "strings"
+)
+
+// Answer is a user input of a prompt question.
+type Answer struct {
+ // Long version of the expected user input
+ Long string
+ // Short version of the expected user input
+ Short string
+ // Runs when user input matches
+ Callback func()
+ // Runs after Callback and after logging resumes
+ EndCallback func()
+
+ AskAgain bool
+}
+
+// Prompt used for interactive user input.
+type Prompt struct {
+ question string
+ answers []Answer
+}
+
+func (p *Prompt) askString() string {
+ var sb strings.Builder
+
+ sb.WriteString(p.question)
+ sb.WriteString("? (")
+
+ var ax []string
+ for _, a := range p.answers {
+ ax = append(ax, fmt.Sprintf("%s=%s", a.Short, a.Long))
+ }
+
+ sb.WriteString(strings.Join(ax, ","))
+ sb.WriteString("): ")
+
+ return sb.String()
+}
+
+// New returns a new prompt.
+func New(question string) *Prompt {
+ return &Prompt{question: question}
+}
+
+// Add an answer.
+func (p *Prompt) Add(answer Answer) {
+ p.answers = append(p.answers, answer)
+}
+
+// Ask a question.
+func (p *Prompt) Ask() {
+ reader := bufio.NewReader(os.Stdin)
+ logger.Pause()
+
+ for {
+ fmt.Print(p.askString())
+ answerStr, _ := reader.ReadString('\n')
+
+ if a, ok := p.answer(strings.TrimSpace(answerStr)); ok {
+ if a.Callback != nil {
+ a.Callback()
+ }
+
+ if !a.AskAgain {
+ logger.Resume()
+ if a.EndCallback != nil {
+ a.EndCallback()
+ }
+ return
+ }
+ }
+ }
+}
+
+func (p *Prompt) answer(answerStr string) (*Answer, bool) {
+ for _, a := range p.answers {
+ switch answerStr {
+ case a.Long:
+ return &a, true
+ case a.Short:
+ return &a, true
+ default:
+ }
+ }
+
+ return nil, false
+}
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
new file mode 100644
index 0000000..482f759
--- /dev/null
+++ b/internal/server/handlers/controlhandler.go
@@ -0,0 +1,106 @@
+package handlers
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/logger"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+// ControlHandler is used for control functions and health monitoring.
+type ControlHandler struct {
+ serverMessages chan string
+ pong chan struct{}
+ stop chan struct{}
+ payload []byte
+ hostname string
+ user *user.User
+}
+
+// NewControlHandler returns a new control handler.
+func NewControlHandler(user *user.User) *ControlHandler {
+ logger.Debug(user, "Creating control handler")
+
+ h := ControlHandler{
+ serverMessages: make(chan string, 10),
+ pong: make(chan struct{}, 10),
+ stop: make(chan struct{}),
+ user: user,
+ }
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ logger.FatalExit(err)
+ }
+
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+ return &h
+}
+
+// Read is to send data to the client via the Reader interface.
+func (h *ControlHandler) Read(p []byte) (n int, err error) {
+ for {
+ select {
+ case message := <-h.serverMessages:
+ wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
+ n = copy(p, wholePayload)
+ return
+ case <-h.pong:
+ logger.Info(h.user, "Sending pong")
+ n = copy(p, []byte(".pong\n"))
+ return
+ case <-h.stop:
+ return 0, io.EOF
+ }
+ }
+}
+
+// Write is to read data to the client via the Writer interface.
+func (h *ControlHandler) Write(p []byte) (n int, err error) {
+ for _, c := range p {
+ switch c {
+ case ';':
+ wholePayload := strings.TrimSpace(string(h.payload))
+ h.handleCommand(wholePayload)
+ h.payload = nil
+
+ default:
+ h.payload = append(h.payload, c)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+// Close the control handler.
+func (h *ControlHandler) Close() {
+ close(h.stop)
+}
+
+// Wait returns the handler stop channel.
+func (h *ControlHandler) Wait() <-chan struct{} {
+ return h.stop
+}
+
+func (h *ControlHandler) handleCommand(command string) {
+ logger.Info(h.user, command)
+ s := strings.Split(command, " ")
+ logger.Debug(h.user, "Receiving command", command, s)
+
+ switch s[0] {
+ case "health":
+ h.serverMessages <- "OK: DTail SSH Server seems fine"
+ h.serverMessages <- "done;"
+ case "ping":
+ h.pong <- struct{}{}
+ case "debug":
+ h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s)
+ default:
+ h.serverMessages <- logger.Warn(h.user, "Received unknown command", command, s)
+ }
+}
diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go
new file mode 100644
index 0000000..8b1f73e
--- /dev/null
+++ b/internal/server/handlers/handler.go
@@ -0,0 +1,10 @@
+package handlers
+
+import "io"
+
+// Handler interface for server side functionality.
+type Handler interface {
+ io.ReadWriter
+ Close()
+ Wait() <-chan struct{}
+}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
new file mode 100644
index 0000000..bed8609
--- /dev/null
+++ b/internal/server/handlers/serverhandler.go
@@ -0,0 +1,492 @@
+package handlers
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/fs"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/omode"
+ user "github.com/mimecast/dtail/internal/user/server"
+ "github.com/mimecast/dtail/internal/version"
+)
+
+const (
+ commandParseWarning string = "Unable to parse command"
+)
+
+// ServerHandler implements the Reader and Writer interfaces to handle
+// the Bi-directional communication between SSH client and server.
+// This handler implements the handler of the SSH server.
+type ServerHandler struct {
+ // Local log file readers
+ fileReaders []fs.FileReader
+ fileReadersMtx *sync.Mutex
+ // Channel for read lines.
+ lines chan fs.LineRead
+ // Only process log lines matching this regex.
+ regex string
+ // Server side mapr log aggregation.
+ aggregate *server.Aggregate
+ // Channel of aggregated log lines.
+ aggregatedMessages chan string
+ // Channel for server messages to be sent to the client.
+ serverMessages chan string
+ // Channel for hidden messages to be sent to the client.
+ hiddenMessages chan string
+ // The current payload sent to the client.
+ payload []byte
+ // The current server hostname.
+ hostname string
+ // The user connecting to dtail.
+ user *user.User
+ // To limit the server wide max amount of concurrent cats
+ catLimiter chan struct{}
+ // To limit the server wide max amount of concurrent tails
+ tailLimiter chan struct{}
+ // Server can tell handler to stop the handler.
+ stop chan struct{}
+ // Indicate that client responded to server with "ack stop connection"
+ ackStopReceived chan struct{}
+ // Stop timeout.
+ stopTimeout chan struct{}
+}
+
+// NewServerHandler returns the server handler.
+func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler {
+ logger.Debug(user, "Creating tail handler")
+ h := ServerHandler{
+ fileReadersMtx: &sync.Mutex{},
+ lines: make(chan fs.LineRead, 100),
+ serverMessages: make(chan string, 10),
+ aggregatedMessages: make(chan string, 10),
+ hiddenMessages: make(chan string, 10),
+ ackStopReceived: make(chan struct{}),
+ stopTimeout: make(chan struct{}),
+ stop: make(chan struct{}),
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ user: user,
+ }
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ logger.FatalExit(err)
+ }
+
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+
+ return &h
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *ServerHandler) Read(p []byte) (n int, err error) {
+ for {
+ select {
+ case message := <-h.serverMessages:
+ wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
+ n = copy(p, wholePayload)
+ return
+ case message := <-h.aggregatedMessages:
+ data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message)
+ //logger.Debug("Sending aggregation data", data)
+ wholePayload := []byte(data)
+ n = copy(p, wholePayload)
+ return
+ case message := <-h.hiddenMessages:
+ //logger.Debug(h.user, "Sending hidden message", message)
+ wholePayload := []byte(fmt.Sprintf(".%s\n", message))
+ n = copy(p, wholePayload)
+ return
+ case line := <-h.lines:
+ serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
+ h.hostname, line.TransmittedPerc, line.Count, *line.GlobID))
+ wholePayload := append(serverInfo, line.Content[:]...)
+ n = copy(p, wholePayload)
+ return
+ case <-time.After(time.Second):
+ select {
+ case <-h.stop:
+ return 0, io.EOF
+ default:
+ }
+ }
+ }
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *ServerHandler) Write(p []byte) (n int, err error) {
+ for _, c := range p {
+ switch c {
+ case ';':
+ commandStr := strings.TrimSpace(string(h.payload))
+ h.handleCommand(commandStr)
+ h.payload = nil
+ default:
+ h.payload = append(h.payload, c)
+ }
+ }
+
+ n = len(p)
+ return
+}
+
+// Close the server handler.
+func (h *ServerHandler) Close() {
+ h.fileReadersMtx.Lock()
+ defer h.fileReadersMtx.Unlock()
+
+ for _, reader := range h.fileReaders {
+ reader.Stop()
+ }
+ if h.aggregate != nil {
+ h.aggregate.Close()
+ }
+
+ close(h.stop)
+}
+
+func (h *ServerHandler) makeGlobID(path, glob string) string {
+ var idParts []string
+ pathParts := strings.Split(path, "/")
+
+ for i, globPart := range strings.Split(glob, "/") {
+ if strings.Contains(globPart, "*") {
+ idParts = append(idParts, pathParts[i])
+ }
+ }
+
+ if len(idParts) > 0 {
+ return strings.Join(idParts, "/")
+ }
+
+ if len(pathParts) > 0 {
+ return pathParts[len(pathParts)-1]
+ }
+
+ h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob))
+ return ""
+}
+
+func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) {
+ retryInterval := time.Second * 5
+ glob = filepath.Clean(glob)
+
+ errors := make(chan struct{})
+ stop := make(chan struct{})
+ defer close(stop)
+
+ go func() {
+ for {
+ select {
+ case <-errors:
+ h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs"))
+ case <-stop:
+ return
+ case <-h.stop:
+ return
+ }
+ }
+ }()
+
+ maxRetries := 10
+ for {
+ maxRetries--
+ if maxRetries < 0 {
+ h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)"))
+ h.internalClose()
+ return
+ }
+
+ paths, err := filepath.Glob(glob)
+ if err != nil {
+ logger.Warn(h.user, glob, err)
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ if numPaths := len(paths); numPaths == 0 {
+ logger.Error(h.user, "No such file(s) to read", glob)
+ select {
+ case errors <- struct{}{}:
+ case <-h.stop:
+ return
+ default:
+ }
+ time.Sleep(retryInterval)
+ continue
+ }
+
+ h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors)
+ break
+ }
+}
+
+func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) {
+ var wg sync.WaitGroup
+ wg.Add(len(paths))
+
+ read := func(path string, wg *sync.WaitGroup) {
+ defer wg.Done()
+ globID := h.makeGlobID(path, glob)
+
+ if !h.user.HasFilePermission(path) {
+ logger.Error(h.user, "No permission to read file", path, globID)
+ select {
+ case errors <- struct{}{}:
+ default:
+ }
+ return
+ }
+
+ h.startReadingFile(mode, path, globID, regex)
+ }
+
+ for _, path := range paths {
+ go read(path, &wg)
+ }
+
+ wg.Wait()
+}
+
+func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) {
+ defer h.stopReadingFile(path)
+ logger.Info(h.user, "Start reading file", path, globID)
+
+ var reader fs.FileReader
+ switch mode {
+ case omode.TailClient:
+ reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
+ case omode.GrepClient:
+ fallthrough
+ case omode.CatClient:
+ reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter)
+ default:
+ reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter)
+ }
+
+ h.fileReadersMtx.Lock()
+ h.fileReaders = append(h.fileReaders, reader)
+ h.fileReadersMtx.Unlock()
+
+ lines := h.lines
+ // Plugin mappreduce engine
+ if h.aggregate != nil {
+ lines = h.aggregate.Lines
+ }
+
+ for {
+ if err := reader.Start(lines, regex); err != nil {
+ logger.Error(h.user, path, globID, err)
+ }
+
+ select {
+ case <-h.stop:
+ return
+ default:
+ if !reader.Retry() {
+ return
+ }
+ }
+
+ time.Sleep(time.Second * 2)
+ logger.Info(path, globID, "Reading file again")
+ }
+}
+
+func (h *ServerHandler) stopReadingFile(path string) {
+ logger.Info(h.user, "Stop reading file", path)
+
+ h.fileReadersMtx.Lock()
+ defer h.fileReadersMtx.Unlock()
+
+ path = filepath.Clean(path)
+ var fileReaders []fs.FileReader
+
+ for _, reader := range h.fileReaders {
+ if reader.FilePath() == path {
+ reader.Stop()
+ continue
+ }
+ fileReaders = append(fileReaders, reader)
+ }
+
+ if len(fileReaders) == len(h.fileReaders) {
+ logger.Warn(h.user, "Didn't read file path", path)
+ return
+ }
+
+ h.fileReaders = fileReaders
+
+ if len(fileReaders) == 0 {
+ if h.aggregate != nil {
+ h.aggregate.Serialize()
+ }
+ h.allLinesSent()
+ }
+}
+
+func (h *ServerHandler) numUnsentMessages() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages)
+}
+
+func (h *ServerHandler) allLinesSent() {
+ defer h.internalClose()
+
+ for i := 0; i < 3; i++ {
+ if h.numUnsentMessages() == 0 {
+ logger.Debug(h.user, "All lines sent")
+ return
+ }
+ logger.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Second)
+ }
+
+ logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages())
+}
+
+// Handler decides to shutdown the connection, not the server itself.
+func (h *ServerHandler) internalClose() {
+ select {
+ case h.hiddenMessages <- "syn close connection":
+ case <-time.After(time.Second * 5):
+ logger.Debug(h.user, "Not waiting for ack close connection")
+ close(h.stopTimeout)
+ return
+ }
+
+ select {
+ case <-h.Wait():
+ case <-time.After(time.Second * 5):
+ logger.Debug(h.user, "Not waiting for ack close connection")
+ close(h.stopTimeout)
+ }
+}
+
+func (h *ServerHandler) handleCommand(commandStr string) {
+ logger.Info(h.user, commandStr)
+
+ args := strings.Split(commandStr, " ")
+ argc := len(args)
+
+ logger.Debug(h.user, "Received command", commandStr, argc, args)
+
+ if h.user.Name == config.ControlUser {
+ h.handleControlCommand(argc, args)
+ return
+ }
+
+ h.handleUserCommand(argc, args)
+}
+
+// Special (restricted) set of commands for anonymous ControlUser access.
+func (h *ServerHandler) handleControlCommand(argc int, args []string) {
+ switch args[0] {
+ case "ping":
+ h.send(h.hiddenMessages, "pong")
+ case "debug":
+ h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
+ default:
+ logger.Warn(h.user, "Received unknown command", argc, args)
+ }
+}
+
+// Commands for authed users.
+func (h *ServerHandler) handleUserCommand(argc int, args []string) {
+ switch args[0] {
+ case "grep":
+ fallthrough
+ case "cat":
+ h.handleReadCommand(argc, args, omode.CatClient)
+ case "tail":
+ h.handleReadCommand(argc, args, omode.TailClient)
+ case "map":
+ h.handleMapCommand(argc, args)
+ case "ack":
+ h.handleAckCommand(argc, args)
+ case "ping":
+ h.send(h.hiddenMessages, "pong")
+ case "version":
+ h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String()))
+ case "debug":
+ h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args))
+ default:
+ h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args))
+ }
+}
+
+func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) {
+ regex := "."
+ if argc >= 4 {
+ regex = args[3]
+ }
+ if argc < 3 {
+ h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
+ return
+ }
+ go h.processFileGlob(mode, args[1], regex)
+}
+
+func (h *ServerHandler) handleMapCommand(argc int, args []string) {
+ if argc < 2 {
+ h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
+ return
+ }
+
+ queryStr := strings.Join(args[1:], " ")
+ logger.Info(h.user, "Creating new mapr aggregator", queryStr)
+ aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr)
+
+ if err != nil {
+ h.send(h.serverMessages, logger.Error(h.user, err))
+ return
+ }
+
+ h.aggregate = aggregate
+}
+
+func (h *ServerHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc))
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ close(h.ackStopReceived)
+ }
+}
+
+func (h *ServerHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.stop:
+ }
+}
+
+// Wait (block) until server handler is closed or a timeout has exceeded.
+func (h *ServerHandler) Wait() <-chan struct{} {
+ wait := make(chan struct{})
+
+ go func() {
+ select {
+ case <-h.ackStopReceived:
+ logger.Debug(h.user, "Closing wait channel due to ACK stop received")
+ close(wait)
+ case <-h.stopTimeout:
+ logger.Debug(h.user, "Closing wait channel due to wait timeout")
+ close(wait)
+ case <-h.stop:
+ logger.Debug(h.user, "Closing wait channel due to stop")
+ }
+ }()
+
+ return wait
+}
diff --git a/internal/server/server.go b/internal/server/server.go
new file mode 100644
index 0000000..27a98f5
--- /dev/null
+++ b/internal/server/server.go
@@ -0,0 +1,212 @@
+package server
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/server/handlers"
+ "github.com/mimecast/dtail/internal/ssh/server"
+ user "github.com/mimecast/dtail/internal/user/server"
+ "github.com/mimecast/dtail/internal/version"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// Server is the main server data structure.
+type Server struct {
+ // Various server statistics counters.
+ stats stats
+ // SSH server configuration.
+ sshServerConfig *gossh.ServerConfig
+ // To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
+ catLimiterCh chan struct{}
+ // To control the max amount of concurrent tails
+ tailLimiterCh chan struct{}
+ // Ask to shutdown the server
+ stop chan struct{}
+}
+
+// New returns a new server.
+func New() *Server {
+ logger.Info("Creating server", version.String())
+
+ s := Server{
+ sshServerConfig: &gossh.ServerConfig{},
+ catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats),
+ tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails),
+ stop: make(chan struct{}),
+ }
+
+ s.sshServerConfig.PasswordCallback = s.controlUserCallback
+ s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback
+
+ private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
+ if err != nil {
+ logger.FatalExit(err)
+ }
+ s.sshServerConfig.AddHostKey(private)
+
+ return &s
+}
+
+// Start the server.
+func (s *Server) Start() int {
+ logger.Info("Starting server")
+
+ bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
+ logger.Info("Binding server", bindAt)
+ listener, err := net.Listen("tcp", bindAt)
+ if err != nil {
+ logger.FatalExit("Failed to open listening TCP socket", err)
+ }
+
+ go s.stats.periodicLogServerStats(s.stop)
+
+ for {
+ conn, err := listener.Accept() // Blocking
+ if err != nil {
+ logger.Error("Failed to accept incoming connection", err)
+ continue
+ }
+
+ if err := s.stats.serverLimitExceeded(); err != nil {
+ logger.Error(err)
+ conn.Close()
+ continue
+ }
+
+ go s.handleConnection(conn)
+ }
+}
+
+func (s *Server) handleConnection(conn net.Conn) {
+ logger.Info("Handling connection")
+
+ sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
+ if err != nil {
+ logger.Error("Something just happened", err)
+ return
+ }
+
+ s.stats.incrementConnections()
+
+ go gossh.DiscardRequests(reqs)
+ for newChannel := range chans {
+ go s.handleChannel(sshConn, newChannel)
+ }
+}
+
+func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) {
+ user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
+ logger.Info(user, "Invoking channel handler")
+
+ if newChannel.ChannelType() != "session" {
+ err := errors.New("Don'w allow other channel types than session")
+ logger.Error(user, err)
+ newChannel.Reject(gossh.Prohibited, err.Error())
+ return
+ }
+
+ channel, requests, err := newChannel.Accept()
+ if err != nil {
+ logger.Error(user, "Could not accept channel", err)
+ return
+ }
+
+ if err := s.handleRequests(sshConn, requests, channel, user); err != nil {
+ logger.Error(user, err)
+ sshConn.Close()
+ }
+}
+
+func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+ logger.Info(user, "Invoking request handler")
+
+ for req := range in {
+ var payload = struct{ Value string }{}
+ gossh.Unmarshal(req.Payload, &payload)
+
+ switch req.Type {
+ case "shell":
+ var handler handlers.Handler
+ switch user.Name {
+ case config.ControlUser:
+ handler = handlers.NewControlHandler(user)
+ default:
+ handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh)
+ }
+
+ // Bi-directionally connect SSH stream to SSH handler
+ brokenPipe1 := make(chan struct{})
+ go func() {
+ defer close(brokenPipe1)
+ io.Copy(channel, handler)
+ }()
+
+ brokenPipe2 := make(chan struct{})
+ go func() {
+ defer close(brokenPipe2)
+ io.Copy(handler, channel)
+ }()
+
+ // Ensure to close all fd's and stop all goroutines once ssh connection terminated
+ go func() {
+ defer s.stats.decrementConnections()
+ defer handler.Close()
+
+ if err := sshConn.Wait(); err != nil && err != io.EOF {
+ logger.Error(user, err)
+ }
+ logger.Info(user, "Good bye Mister!")
+ }()
+
+ // Close the underlying ssh socket when server shuts down
+ go func() {
+ select {
+ case <-s.stop:
+ logger.Debug(user, "Server initiating shutdown on handler")
+ case <-handler.Wait():
+ logger.Debug(user, "Handler initiating shutdown by its own")
+ case <-brokenPipe1:
+ logger.Debug(user, "Broken pipe1")
+ case <-brokenPipe2:
+ logger.Debug(user, "Broken pipe2")
+ }
+ sshConn.Close()
+ logger.Info(user, "Closed SSH connection")
+ }()
+
+ // Only serving shell type
+ req.Reply(true, nil)
+
+ default:
+ req.Reply(false, nil)
+
+ return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v",
+ req.Type, payload.Value)
+ }
+ }
+
+ return nil
+}
+
+func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
+ user := user.New(c.User(), c.RemoteAddr().String())
+
+ if user.Name == config.ControlUser && string(authPayload) == config.ControlUser {
+ logger.Debug(user, "Initiating master control program")
+ return nil, nil
+ }
+
+ return nil, fmt.Errorf("Not authorized")
+}
+
+// Stop the server.
+func (s *Server) Stop() {
+ close(s.stop)
+ s.stats.waitForConnections()
+}
diff --git a/internal/server/stats.go b/internal/server/stats.go
new file mode 100644
index 0000000..beb1885
--- /dev/null
+++ b/internal/server/stats.go
@@ -0,0 +1,88 @@
+package server
+
+import (
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "fmt"
+ "runtime"
+ "sync"
+ "time"
+)
+
+// Used to collect and display various server stats.
+type stats struct {
+ mutex sync.Mutex
+ currentConnections int
+ lifetimeConnections uint64
+}
+
+func (s *stats) incrementConnections() {
+ defer s.logServerStats()
+
+ s.mutex.Lock()
+ s.currentConnections++
+ s.lifetimeConnections++
+ s.mutex.Unlock()
+}
+
+func (s *stats) decrementConnections() {
+ defer s.logServerStats()
+
+ s.mutex.Lock()
+ s.currentConnections--
+ s.mutex.Unlock()
+}
+
+func (s *stats) hasConnections() bool {
+ s.mutex.Lock()
+ currentConnections := s.currentConnections
+ s.mutex.Unlock()
+
+ has := currentConnections > 0
+ logger.Info("stats", "Server with open connections?", has, currentConnections)
+
+ return has
+}
+
+func (s *stats) logServerStats() {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections)
+ lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections)
+ goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine())
+ logger.Info("stats", currentConnections, lifetimeConnections, goroutines)
+}
+
+func (s *stats) serverLimitExceeded() error {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ if s.currentConnections >= config.Server.MaxConnections {
+ return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections)
+ }
+
+ return nil
+}
+
+func (s *stats) periodicLogServerStats(stop <-chan struct{}) {
+ for {
+ select {
+ case <-time.NewTimer(time.Second * 10).C:
+ s.logServerStats()
+ case <-stop:
+ return
+ }
+ }
+}
+
+func (s *stats) waitForConnections() {
+ for {
+ select {
+ case <-time.NewTimer(time.Second).C:
+ if !s.hasConnections() {
+ return
+ }
+ }
+ }
+}
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
new file mode 100644
index 0000000..3392eb1
--- /dev/null
+++ b/internal/ssh/client/authmethods.go
@@ -0,0 +1,45 @@
+package client
+
+import (
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/ssh"
+ "os"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// InitSSHAuthMethods initialises all known SSH auth methods on othe client side.
+func InitSSHAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) {
+ var sshAuthMethods []gossh.AuthMethod
+
+ if config.Common.ExperimentalFeaturesEnable {
+ sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
+ logger.Info("Added experimental method to list of auth methods")
+ }
+
+ keyPath := os.Getenv("HOME") + "/.ssh/id_rsa"
+ if authMethod, err := ssh.PrivateKey(keyPath); err == nil {
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ logger.Info("Added path to list of auth methods", keyPath)
+ }
+
+ keyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
+ if authMethod, err := ssh.PrivateKey(keyPath); err == nil {
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ logger.Info("Added path to list of auth methods", keyPath)
+ }
+
+ if authMethod, err := ssh.Agent(); err == nil {
+ sshAuthMethods = append(sshAuthMethods, authMethod)
+ logger.Info("Added SSH Agent to list of auth methods")
+ }
+
+ knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts"
+ hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh)
+ if err != nil {
+ logger.FatalExit(knownHostsPath, err)
+ }
+
+ return sshAuthMethods, hostKeyCallback
+}
diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go
new file mode 100644
index 0000000..4023e59
--- /dev/null
+++ b/internal/ssh/client/hostkeycallback.go
@@ -0,0 +1,285 @@
+package client
+
+import (
+ "bufio"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/prompt"
+ "fmt"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "time"
+
+ "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/knownhosts"
+)
+
+type response int
+
+const (
+ trustHost response = iota
+ dontTrustHost response = iota
+)
+
+// Represents an unknown host.
+type unknownHost struct {
+ server string
+ remote net.Addr
+ key ssh.PublicKey
+ hostLine string
+ ipLine string
+ responseCh chan response
+}
+
+// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all
+// unknown hosts in a single batch to the known_hosts file.
+type HostKeyCallback struct {
+ knownHostsPath string
+ unknownCh chan unknownHost
+ throttleCh chan struct{}
+ trustAllHostsCh chan struct{}
+ untrustedHosts map[string]bool
+ mutex sync.Mutex
+}
+
+// NewHostKeyCallback returns a new wrapper.
+func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) {
+ // Ensure file exists
+ os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+
+ h := HostKeyCallback{
+ knownHostsPath: knownHostsPath,
+ unknownCh: make(chan unknownHost),
+ trustAllHostsCh: make(chan struct{}),
+ throttleCh: throttleCh,
+ untrustedHosts: make(map[string]bool),
+ }
+
+ if trustAllHosts {
+ close(h.trustAllHostsCh)
+ }
+
+ return &h, nil
+}
+
+// Wrap the host key callback.
+func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback {
+ return func(server string, remote net.Addr, key ssh.PublicKey) error {
+ // Parse known_hosts file
+ knownHostsCb, err := knownhosts.New(h.knownHostsPath)
+ if err != nil {
+ // Problem parsing it
+ return err
+ }
+
+ // Check for valid entry in known_hosts file
+ err = knownHostsCb(server, remote, key)
+ if err == nil {
+ // OK
+ return nil
+ }
+
+ // Make sure that interactive user callback does not interfere with
+ // SSH connection throttler.
+ <-h.throttleCh
+ defer func() { h.throttleCh <- struct{}{} }()
+
+ unknown := unknownHost{
+ server: server,
+ remote: remote,
+ key: key,
+ hostLine: knownhosts.Line([]string{server}, key),
+ ipLine: knownhosts.Line([]string{remote.String()}, key),
+ responseCh: make(chan response),
+ }
+
+ logger.Warn("Encountered unknown host", unknown)
+ // Notify user that there is an unknown host
+ h.unknownCh <- unknown
+
+ // Wait for user input.
+ switch <-unknown.responseCh {
+ case trustHost:
+ // End user acknowledged host key
+ return nil
+ case dontTrustHost:
+ }
+
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ h.untrustedHosts[server] = true
+
+ return err
+ }
+}
+
+// PromptAddHosts prompts a question to the user whether unknown hosts should
+// be added to the known hosts or not.
+func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) {
+ var hosts []unknownHost
+
+ for {
+ // Check whether there is a unknown host
+ select {
+ case unknown := <-h.unknownCh:
+ hosts = append(hosts, unknown)
+ // Ask every 50 unknown hosts
+ if len(hosts) >= 50 {
+ h.promptAddHosts(hosts)
+ hosts = []unknownHost{}
+ }
+ case <-time.After(2 * time.Second):
+ // Or ask when after 2 seconds no new unknown hosts were added.
+ if len(hosts) > 0 {
+ h.promptAddHosts(hosts)
+ hosts = []unknownHost{}
+ }
+ case <-stop:
+ logger.Debug("Stopping goroutine prompting new hosts...")
+ return
+ }
+ }
+}
+
+func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) {
+ var servers []string
+
+ for _, host := range hosts {
+ servers = append(servers, host.server)
+ }
+
+ select {
+ case <-h.trustAllHostsCh:
+ logger.Warn("Trusting host keys of servers", servers)
+ h.trustHosts(hosts)
+ return
+ default:
+ }
+
+ question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s",
+ len(servers),
+ strings.Join(servers, ","),
+ "Do you want to trust these hosts?",
+ )
+
+ p := prompt.New(question)
+
+ a := prompt.Answer{
+ Long: "yes",
+ Short: "y",
+ Callback: func() {
+ h.trustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Added hosts to known hosts file", h.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "all",
+ Short: "a",
+ Callback: func() {
+ close(h.trustAllHostsCh)
+ h.trustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Added hosts to known hosts file", h.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "no",
+ Short: "n",
+ Callback: func() {
+ h.dontTrustHosts(hosts)
+ },
+ EndCallback: func() {
+ logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath)
+ },
+ }
+ p.Add(a)
+
+ a = prompt.Answer{
+ Long: "details",
+ Short: "d",
+ AskAgain: true,
+ Callback: func() {
+ for _, unknown := range hosts {
+ fmt.Println(unknown.hostLine)
+ fmt.Println(unknown.ipLine)
+ }
+ },
+ }
+ p.Add(a)
+
+ p.Ask()
+}
+
+func (h *HostKeyCallback) trustHosts(hosts []unknownHost) {
+ tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath)
+
+ newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
+ if err != nil {
+ panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
+ }
+ defer newFd.Close()
+
+ // Newly trusted hosts in normalized form
+ addresses := make(map[string]struct{})
+
+ // First write to new known hosts file, and keep track of addresses
+ for _, unknown := range hosts {
+ unknown.responseCh <- trustHost
+
+ // Add once as [HOSTNAME]:PORT
+ addresses[knownhosts.Normalize(unknown.server)] = struct{}{}
+ // And once as [IP]:PORT
+ addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{}
+
+ newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine))
+ newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine))
+ }
+
+ // Read old known hosts file, to see which are old and new entries
+ os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
+ oldFd, err := os.Open(h.knownHostsPath)
+ if err != nil {
+ panic(err)
+ }
+ defer oldFd.Close()
+
+ scanner := bufio.NewScanner(oldFd)
+
+ // Now, append all still valid old entries to the new host file
+ for scanner.Scan() {
+ line := scanner.Text()
+ address := strings.SplitN(line, " ", 2)[0]
+
+ if _, ok := addresses[address]; !ok {
+ newFd.WriteString(fmt.Sprintf("%s\n", line))
+ }
+ }
+
+ // Now, replace old known hosts file
+ if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil {
+ panic(err)
+ }
+}
+
+func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) {
+ for _, unknown := range hosts {
+ unknown.responseCh <- dontTrustHost
+ }
+}
+
+// Untrusted returns true if the host is not trusted. False otherwise.
+func (h *HostKeyCallback) Untrusted(server string) bool {
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ _, ok := h.untrustedHosts[server]
+
+ return ok
+}
diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go
new file mode 100644
index 0000000..7baa4aa
--- /dev/null
+++ b/internal/ssh/server/hostkey.go
@@ -0,0 +1,37 @@
+package server
+
+import (
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ "github.com/mimecast/dtail/internal/ssh"
+ "io/ioutil"
+ "os"
+)
+
+// PrivateHostKey retrieves the private server RSA host key.
+func PrivateHostKey() []byte {
+ hostKeyFile := config.Server.HostKeyFile
+ _, err := os.Stat(hostKeyFile)
+
+ if os.IsNotExist(err) {
+ logger.Info("Generating private server RSA host key")
+ privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits)
+
+ if err != nil {
+ logger.FatalExit("Failed to generate private server RSA host key", err)
+ }
+
+ pem := ssh.EncodePrivateKeyToPEM(privateKey)
+ if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil {
+ logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err)
+ }
+ return pem
+ }
+
+ logger.Info("Reading private server RSA host key from file", hostKeyFile)
+ pem, err := ioutil.ReadFile(hostKeyFile)
+ if err != nil {
+ logger.FatalExit("Failed to load private server RSA host key", err)
+ }
+ return pem
+}
diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go
new file mode 100644
index 0000000..c6929d7
--- /dev/null
+++ b/internal/ssh/server/publickeycallback.go
@@ -0,0 +1,62 @@
+package server
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ osUser "os/user"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/logger"
+ user "github.com/mimecast/dtail/internal/user/server"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not.
+func PublicKeyCallback(c gossh.ConnMetadata, pubKey gossh.PublicKey) (*gossh.Permissions, error) {
+ user := user.New(c.User(), c.RemoteAddr().String())
+ logger.Info(user, "Incoming authorization")
+
+ cwd, err := os.Getwd()
+ if err != nil {
+ return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error())
+ }
+
+ authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name)
+ if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) {
+ user, err := osUser.Lookup(user.Name)
+ if err != nil {
+ return nil, fmt.Errorf("Unable to authorize|%s|%s|", user, err.Error())
+ }
+ // Fallback to ~
+ authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys"
+ }
+
+ logger.Info(user, "Reading", authorizedKeysFile)
+ authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile)
+ if err != nil {
+ return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error())
+ }
+
+ authorizedKeysMap := map[string]bool{}
+ for len(authorizedKeysBytes) > 0 {
+ pubKey, _, _, rest, err := gossh.ParseAuthorizedKey(authorizedKeysBytes)
+ if err != nil {
+ return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error())
+ }
+ authorizedKeysMap[string(pubKey.Marshal())] = true
+ authorizedKeysBytes = rest
+ }
+
+ if authorizedKeysMap[string(pubKey.Marshal())] {
+ logger.Debug("Public key fingerprint", gossh.FingerprintSHA256(pubKey), user)
+ return &gossh.Permissions{
+ Extensions: map[string]string{
+ "pubkey-fp": gossh.FingerprintSHA256(pubKey),
+ },
+ }, nil
+ }
+
+ return nil, fmt.Errorf("Unknown public key|%s", user)
+}
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
new file mode 100644
index 0000000..77cc341
--- /dev/null
+++ b/internal/ssh/ssh.go
@@ -0,0 +1,112 @@
+package ssh
+
+import (
+ "crypto/rand"
+ "crypto/rsa"
+ "crypto/x509"
+ "github.com/mimecast/dtail/internal/logger"
+ "encoding/pem"
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "syscall"
+
+ gossh "golang.org/x/crypto/ssh"
+ "golang.org/x/crypto/ssh/agent"
+ "golang.org/x/crypto/ssh/terminal"
+)
+
+// GeneratePrivateRSAKey is used by the server to generate its key.
+func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
+ privateKey, err := rsa.GenerateKey(rand.Reader, size)
+ if err != nil {
+ return nil, err
+ }
+
+ err = privateKey.Validate()
+ if err != nil {
+ return nil, err
+ }
+
+ return privateKey, nil
+}
+
+// EncodePrivateKeyToPEM is a helper function for converting a key to PEM format.
+func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
+ derFormat := x509.MarshalPKCS1PrivateKey(privateKey)
+
+ block := pem.Block{
+ Type: "RSA PRIVATE KEY",
+ Headers: nil,
+ Bytes: derFormat,
+ }
+
+ return pem.EncodeToMemory(&block)
+}
+
+// Agent used for SSH auth.
+func Agent() (gossh.AuthMethod, error) {
+ sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
+ if err != nil {
+ return nil, err
+ }
+ agentClient := agent.NewClient(sshAgent)
+ keys, err := agentClient.List()
+ if err != nil {
+ return nil, err
+ }
+ for i, key := range keys {
+ logger.Debug("Public key", i, key)
+ }
+ return gossh.PublicKeysCallback(agentClient.Signers), nil
+}
+
+// EnterKeyPhrase is required to read phrase protected private keys.
+func EnterKeyPhrase(keyFile string) []byte {
+ fmt.Printf("Enter phrase for key %s: ", keyFile)
+ phrase, err := terminal.ReadPassword(int(syscall.Stdin))
+ if err != nil {
+ panic(err)
+ }
+ fmt.Printf("%s\n", string(phrase))
+ return phrase
+}
+
+// KeyFile returns the key as a SSH auth method.
+func KeyFile(keyFile string) (gossh.AuthMethod, error) {
+ buffer, err := ioutil.ReadFile(keyFile)
+ if err != nil {
+ return nil, err
+ }
+
+ key, err := gossh.ParsePrivateKey(buffer)
+ if err != nil {
+ return nil, err
+ }
+
+ // Key phrase support disabled as password will be printed to stdout!
+ /*
+ if err == nil {
+ return gossh.PublicKeys(key), nil
+ }
+
+ keyPhrase := EnterKeyPhrase(keyFile)
+ key, err = gossh.ParsePrivateKeyWithPassphrase(buffer, keyPhrase)
+ if err != nil {
+ return nil, err
+ }
+ */
+
+ return gossh.PublicKeys(key), nil
+}
+
+// PrivateKey returns the private key as a SSH auth method.
+func PrivateKey(keyFile string) (gossh.AuthMethod, error) {
+ signer, err := KeyFile(keyFile)
+ if err != nil {
+ logger.Debug(keyFile, err)
+ return nil, err
+ }
+ return gossh.AuthMethod(signer), nil
+}
diff --git a/internal/user/name.go b/internal/user/name.go
new file mode 100644
index 0000000..5171ec7
--- /dev/null
+++ b/internal/user/name.go
@@ -0,0 +1,24 @@
+package user
+
+import (
+ "os/user"
+ )
+
+
+func Name() string {
+ user, err := user.Current()
+ if err != nil {
+ panic(err)
+ }
+
+ if user.Uid == "0" {
+ panic("Not allowed to run as UID 0")
+ }
+
+ if user.Gid == "0" {
+ panic("Not allowed to run as GID 0")
+ }
+
+ return user.Username
+}
+
diff --git a/internal/user/server/user.go b/internal/user/server/user.go
new file mode 100644
index 0000000..fad38d8
--- /dev/null
+++ b/internal/user/server/user.go
@@ -0,0 +1,131 @@
+package server
+
+import (
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/fs/permissions"
+ "github.com/mimecast/dtail/internal/logger"
+ "fmt"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+)
+
+const maxLinkDepth int = 100
+
+// User represents an end-user which connected to the server via the DTail client.
+type User struct {
+ // The user name.
+ Name string
+ // The remote address connected from.
+ remoteAddress string
+ // The permissions the user has.
+ permissions []string
+}
+
+// New returns a new user.
+func New(name, remoteAddress string) *User {
+ return &User{
+ Name: name,
+ remoteAddress: remoteAddress,
+ }
+}
+
+// String representation of the user.
+func (u *User) String() string {
+ return fmt.Sprintf("%s@%s", u.Name, u.remoteAddress)
+}
+
+// HasFilePermission is used to determine whether user is alowed to read a file.
+func (u *User) HasFilePermission(filePath string) (hasPermission bool) {
+ cleanPath, err := filepath.EvalSymlinks(filePath)
+ if err != nil {
+ logger.Error(u, filePath, "Unable to evaluate symlinks", err)
+ hasPermission = false
+ return
+ }
+
+ cleanPath, err = filepath.Abs(cleanPath)
+ if err != nil {
+ logger.Error(u, cleanPath, "Unable to make file path absolute", err)
+ hasPermission = false
+ return
+ }
+
+ if cleanPath != filePath {
+ logger.Info(u, filePath, cleanPath, "Calculated new clean path from original file path (possibly symlink)")
+ }
+
+ hasPermission, err = u.hasFilePermission(cleanPath)
+ if err != nil {
+ logger.Warn(u, cleanPath, err)
+ }
+
+ return
+}
+
+func (u *User) hasFilePermission(cleanPath string) (bool, error) {
+ // First check file system Linux/UNIX permission.
+ if _, err := permissions.ToRead(u.Name, cleanPath); err != nil {
+ return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err)
+ }
+ logger.Info(u, cleanPath, "User has OS file system permissions to read file")
+
+ // If file system permission is given, also check permissions
+ // as configured in DTail config file.
+ if len(u.permissions) == 0 {
+ p, err := config.ServerUserPermissions(u.Name)
+ if err != nil {
+ return false, err
+ }
+ u.permissions = p
+ }
+
+ var hasPermission bool
+ var err error
+
+ if hasPermission, err = u.iteratePaths(cleanPath); err != nil {
+ return false, err
+ }
+
+ // Only allow to follow regular files or symlinks.
+ info, err := os.Lstat(cleanPath)
+ if err != nil {
+ return false, fmt.Errorf("Unable to determine file type: '%v'", err)
+ }
+
+ if !info.Mode().IsRegular() {
+ return false, fmt.Errorf("Can only open regular files or follow symlinks")
+ }
+
+ return hasPermission, nil
+}
+
+func (u *User) iteratePaths(cleanPath string) (bool, error) {
+ for _, permission := range u.permissions {
+ var regexStr string
+ var negate bool
+
+ if strings.HasPrefix(permission, "!") {
+ regexStr = permission[1:]
+ negate = true
+ }
+ regexStr = permission
+ negate = false
+
+ re, err := regexp.Compile(regexStr)
+ if err != nil {
+ return false, fmt.Errorf("Permission test failed, can't compile regex '%s': '%v'", regexStr, err)
+ }
+
+ if negate && re.MatchString(cleanPath) {
+ return false, fmt.Errorf("Permission test failed, matching negative pattern '%s'", permission)
+ }
+
+ if !negate && re.MatchString(cleanPath) {
+ logger.Info(u, cleanPath, "Permission test passed partially, matching positive pattern", permission)
+ }
+ }
+
+ return true, nil
+}
diff --git a/internal/version/version.go b/internal/version/version.go
new file mode 100644
index 0000000..d036a68
--- /dev/null
+++ b/internal/version/version.go
@@ -0,0 +1,40 @@
+package version
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/mimecast/dtail/internal/color"
+)
+
+// Name of DTail.
+const Name = "DTail"
+
+// Version of DTail.
+const Version = "1.1.0"
+
+// Additional information.
+const Additional = "develop"
+
+// String representation of the DTail version.
+func String() string {
+ return fmt.Sprintf("%s v%v %s", Name, Version, Additional)
+}
+
+// PaintedString is a prettier string representation of the DTail version.
+func PaintedString() string {
+ if !color.Colored {
+ return String()
+ }
+ name := color.Paint(color.Yellow, Name)
+ version := color.Paint(color.Blue, Version)
+ descr := color.Paint(color.Green, Additional)
+
+ return fmt.Sprintf("%s %v %s", name, version, descr)
+}
+
+// PrintAndExit prints the program version and exists.
+func PrintAndExit() {
+ fmt.Println(PaintedString())
+ os.Exit(0)
+}