diff options
| author | Paul Bütow <pbuetow@mimecast.com> | 2020-01-20 18:41:05 +0000 |
|---|---|---|
| committer | Paul Bütow <pbuetow@mimecast.com> | 2020-01-21 14:35:23 +0000 |
| commit | c128865c4c7411c29a59fca9a3a2f95537686d7b (patch) | |
| tree | 193bccc70d942c8b70cc93fae2670263701e43aa /internal | |
| parent | 3755a9911ecb05886577095f2b8cc8b9e4066a3a (diff) | |
Move commands to cmd/ and move internal dependencies to internal/
Diffstat (limited to 'internal')
67 files changed, 6793 insertions, 0 deletions
diff --git a/internal/clients/args.go b/internal/clients/args.go new file mode 100644 index 0000000..5fe0a72 --- /dev/null +++ b/internal/clients/args.go @@ -0,0 +1,18 @@ +package clients + +import ( + "github.com/mimecast/dtail/internal/omode" +) + +// Args is a helper struct to summarize common client arguments. +type Args struct { + Mode omode.Mode + ServersStr string + UserName string + Files string + Regex string + TrustAllHosts bool + Discovery string + ConnectionsPerCPU int + PingTimeout int +} diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go new file mode 100644 index 0000000..574ae94 --- /dev/null +++ b/internal/clients/baseclient.go @@ -0,0 +1,137 @@ +package clients + +import ( + "regexp" + "sync" + "time" + + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/discovery" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +// This is the main client data structure. +type baseClient struct { + Args + // To display client side stats + stats *stats + // List of remote servers to connect to. + servers []string + // We have one connection per remote server. + connections []*remote.Connection + // SSH auth methods to use to connect to the remote servers. + sshAuthMethods []gossh.AuthMethod + // To deal with SSH host keys + hostKeyCallback *client.HostKeyCallback + // To stop the client. + stop chan struct{} + // To indicate that the client has stopped. + stopped chan struct{} + // Throttle how fast we initiate SSH connections concurrently + throttleCh chan struct{} + // Retry connection upon failure? + retry bool + // Connection helper. + maker connectionMaker +} + +func (c *baseClient) init(maker connectionMaker) { + logger.Info("Initiating base client") + + c.maker = maker + //c.connections = make(map[string]*remote.Connection) + c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + + // Retrieve a shuffled list of remote dtail servers. + shuffleServers := true + discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) + for _, server := range discoveryService.ServerList() { + c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + } + + if _, err := regexp.Compile(c.Regex); err != nil { + logger.FatalExit(c.Regex, "Can't test compile regex", err) + } + + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(c.stop) + + // Periodically print out connection stats to the client. + c.stats = newTailStats(len(c.connections)) + go c.stats.periodicLogStats(c.throttleCh, c.stop) +} + +func (c *baseClient) Start() (status int) { + active := make(chan struct{}, len(c.connections)) + + var wg sync.WaitGroup + wg.Add(len(c.connections)) + + for i, conn := range c.connections { + go func(i int, conn *remote.Connection) { + active <- struct{}{} + defer func() { + logger.Debug(conn.Server, "Disconnected completely...") + <-active + }() + wg.Done() + + for { + conn.Start(c.throttleCh, c.stats.connectionsEstCh) + if !c.retry { + return + } + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconencting") + conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn + } + }(i, conn) + } + + wg.Wait() + c.waitUntilDone(active) + + return +} + +func (c *baseClient) waitUntilDone(active chan struct{}) { + defer close(c.stopped) + + if c.Mode != omode.TailClient { + c.waitUntilZero(active) + logger.Info("All connections stopped") + return + } + + <-c.stop + logger.Info("Stopping client") + for _, conn := range c.connections { + conn.Stop() + } + + c.waitUntilZero(active) +} + +func (c *baseClient) waitUntilZero(active chan struct{}) { + for { + logger.Debug("Active connections", len(active)) + if len(active) == 0 { + return + } + time.Sleep(time.Second) + } +} + +func (c *baseClient) Stop() { + close(c.stop) + <-c.WaitC() +} + +func (c *baseClient) WaitC() <-chan struct{} { + return c.stopped +} diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go new file mode 100644 index 0000000..5ea701d --- /dev/null +++ b/internal/clients/catclient.go @@ -0,0 +1,53 @@ +package clients + +import ( + "errors" + "fmt" + "runtime" + "strings" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +// CatClient is a client for returning a whole file from the beginning to the end. +type CatClient struct { + baseClient +} + +// NewCatClient returns a new cat client. +func NewCatClient(args Args) (*CatClient, error) { + if args.Regex != "" { + return nil, errors.New("Can't use regex with 'cat' operating mode") + } + + args.Regex = "." + args.Mode = omode.CatClient + + c := CatClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, + } + + c.init(c) + + return &c, nil +} + +func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) + } + return conn +} diff --git a/internal/clients/client.go b/internal/clients/client.go new file mode 100644 index 0000000..85d1aae --- /dev/null +++ b/internal/clients/client.go @@ -0,0 +1,7 @@ +package clients + +// Client is the interface for the end user command line client. +type Client interface { + Start() int + Stop() +} diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go new file mode 100644 index 0000000..0617992 --- /dev/null +++ b/internal/clients/connectionmaker.go @@ -0,0 +1,12 @@ +package clients + +import ( + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +type connectionMaker interface { + makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection +} diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go new file mode 100644 index 0000000..c568f63 --- /dev/null +++ b/internal/clients/grepclient.go @@ -0,0 +1,53 @@ +package clients + +import ( + "errors" + "fmt" + "runtime" + "strings" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. +type GrepClient struct { + baseClient +} + +// NewGrepClient creates a new grep client. +func NewGrepClient(args Args) (*GrepClient, error) { + if args.Regex == "" { + return nil, errors.New("No regex specified, use '-regex' flag") + } + args.Mode = omode.GrepClient + + c := GrepClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, + } + + c.init(c) + + return &c, nil +} + +func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) + + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) + } + + return conn +} diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go new file mode 100644 index 0000000..19246f9 --- /dev/null +++ b/internal/clients/handlers/basehandler.go @@ -0,0 +1,134 @@ +package handlers + +import ( + "github.com/mimecast/dtail/internal/logger" + "errors" + "fmt" + "io" + "strings" + "time" +) + +type baseHandler struct { + server string + shellStarted bool + commands chan string + pong chan struct{} + receiveBuf []byte + stop chan struct{} + pingTimeout int +} + +func (h *baseHandler) Server() string { + return h.server +} + +// Used to determine whether server is still responding to requests or not. +func (h *baseHandler) Ping() error { + if h.pingTimeout == 0 { + // Server ping disabled + return nil + } + + if err := h.SendCommand("ping"); err != nil { + return err + } + + select { + case <-h.pong: + return nil + case <-time.After(time.Duration(h.pingTimeout) * time.Second): + } + + return errors.New("Didn't receive any server pongs (ping replies)") +} + +func (h *baseHandler) SendCommand(command string) error { + if command == "ping" { + logger.Trace("Sending command", h.server, command) + } else { + logger.Debug("Sending command", h.server, command) + } + + select { + case h.commands <- fmt.Sprintf("%s;", command): + case <-time.After(time.Second * 5): + return errors.New("Timed out sending command " + command) + case <-h.stop: + } + + return nil +} + +// Read data from the dtail server via Writer interface. +func (h *baseHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + h.receiveBuf = append(h.receiveBuf, b) + if b == '\n' { + if len(h.receiveBuf) == 0 { + continue + } + message := string(h.receiveBuf) + h.handleMessageType(message) + } + } + + return len(p), nil +} + +// Send data to the dtail server via Reader interface. +func (h *baseHandler) Read(p []byte) (n int, err error) { + select { + case command := <-h.commands: + n = copy(p, []byte(command)) + case <-h.stop: + return 0, io.EOF + } + return +} + +// Handle various message types. +func (h *baseHandler) handleMessageType(message string) { + if len(h.receiveBuf) == 0 { + return + } + // Hidden server commands starti with a dot "." + if h.receiveBuf[0] == '.' { + h.handleHiddenMessage(message) + h.receiveBuf = h.receiveBuf[:0] + return + } + + // Silent mode will only print out remote logs but not remote server + // commands. But remote server commands will be still logged to ./log/. + if logger.Mode == logger.SilentMode { + if h.receiveBuf[0] == 'R' { + logger.Raw(message) + } + h.receiveBuf = h.receiveBuf[:0] + return + } + logger.Raw(message) + h.receiveBuf = h.receiveBuf[:0] +} + +// Handle messages received from server which are not meant to be displayed +// to the end user. +func (h *baseHandler) handleHiddenMessage(message string) { + switch { + case strings.HasPrefix(message, ".pong"): + h.pong <- struct{}{} + case strings.HasPrefix(message, ".syn close connection"): + h.SendCommand("ack close connection") + } +} + +// Stop the handler. +func (h *baseHandler) Stop() { + select { + case <-h.stop: + default: + logger.Debug("Stopping base handler", h.server) + close(h.stop) + } +} diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go new file mode 100644 index 0000000..4738cd3 --- /dev/null +++ b/internal/clients/handlers/clienthandler.go @@ -0,0 +1,26 @@ +package handlers + +import ( + "github.com/mimecast/dtail/internal/logger" +) + +// ClientHandler is the basic client handler interface. +type ClientHandler struct { + baseHandler +} + +// NewClientHandler creates a new client handler. +func NewClientHandler(server string, pingTimeout int) *ClientHandler { + logger.Debug(server, "Creating new client handler") + + return &ClientHandler{ + baseHandler{ + server: server, + shellStarted: false, + commands: make(chan string), + pong: make(chan struct{}, 1), + stop: make(chan struct{}), + pingTimeout: pingTimeout, + }, + } +} diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go new file mode 100644 index 0000000..2013be0 --- /dev/null +++ b/internal/clients/handlers/handler.go @@ -0,0 +1,12 @@ +package handlers + +import "io" + +// Handler provides all methods which can be run on any client handler. +type Handler interface { + io.ReadWriter + Ping() error + Stop() + SendCommand(command string) error + Server() string +} diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go new file mode 100644 index 0000000..4051e2c --- /dev/null +++ b/internal/clients/handlers/healthhandler.go @@ -0,0 +1,75 @@ +package handlers + +import ( + "errors" + "fmt" + "time" +) + +// HealthHandler implements the handler required for health checks. +type HealthHandler struct { + // Buffer of incoming data from server. + receiveBuf []byte + // To send commands to the server. + commands chan string + // To receive messages from the server. + receive chan<- string + // The remote server address + server string +} + +// NewHealthHandler returns a new health check handler. +func NewHealthHandler(server string, receive chan<- string) *HealthHandler { + h := HealthHandler{ + server: server, + receive: receive, + commands: make(chan string), + } + + return &h +} + +// Server returns the remote server name. +func (h *HealthHandler) Server() string { + return h.server +} + +// Stop is not of use for health check handler. +func (h *HealthHandler) Stop() { + // Nothing done here. +} + +// Ping is not of use for health check handler. +func (h *HealthHandler) Ping() error { + return nil +} + +// SendCommand send a DTail command to the server. +func (h *HealthHandler) SendCommand(command string) error { + select { + case h.commands <- fmt.Sprintf("%s;", command): + case <-time.NewTimer(time.Second * 10).C: + return errors.New("Timed out sending command " + command) + } + + return nil +} + +// Server writes byte stream to client. +func (h *HealthHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + h.receiveBuf = append(h.receiveBuf, b) + if b == '\n' { + h.receive <- string(h.receiveBuf) + h.receiveBuf = h.receiveBuf[:0] + } + } + + return len(p), nil +} + +// Server reads byte stream from client. +func (h *HealthHandler) Read(p []byte) (n int, err error) { + n = copy(p, []byte(<-h.commands)) + return +} diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go new file mode 100644 index 0000000..d76cdfd --- /dev/null +++ b/internal/clients/handlers/maprhandler.go @@ -0,0 +1,74 @@ +package handlers + +import ( + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/mapr/client" + "strings" +) + +// MaprHandler is the handler used on the client side for running mapreduce aggregations. +type MaprHandler struct { + baseHandler + aggregate *client.Aggregate + query *mapr.Query + count uint64 +} + +// NewMaprHandler returns a new mapreduce client handler. +func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler { + return &MaprHandler{ + baseHandler: baseHandler{ + server: server, + shellStarted: false, + commands: make(chan string), + pong: make(chan struct{}, 1), + stop: make(chan struct{}), + pingTimeout: pingTimeout, + }, + query: query, + aggregate: client.NewAggregate(server, query, globalGroup), + } +} + +// Read data from the dtail server via Writer interface. +func (h *MaprHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + h.baseHandler.receiveBuf = append(h.baseHandler.receiveBuf, b) + if b == '\n' { + if len(h.baseHandler.receiveBuf) == 0 { + continue + } + message := string(h.baseHandler.receiveBuf) + + if h.baseHandler.receiveBuf[0] == 'A' { + h.handleAggregateMessage(strings.TrimSpace(message)) + h.baseHandler.receiveBuf = h.baseHandler.receiveBuf[:0] + continue + } + h.baseHandler.handleMessageType(message) + } + } + + return len(p), nil +} + +// Handle a message received from server including mapr aggregation +// related data. +func (h *MaprHandler) handleAggregateMessage(message string) { + h.count++ + parts := strings.Split(message, "|") + + // Index 0 contains 'AGGREGATE', 1 contains server host. + // Aggregation data begins from index 2. + logger.Debug("Received aggregate data", h.server, h.count) + h.aggregate.Aggregate(parts[2:]) + logger.Debug("Aggregated aggregate data", h.server, h.count) +} + +// Stop stops the mapreduce client handler. +func (h *MaprHandler) Stop() { + logger.Debug("Stopping mapreduce handler", h.server) + h.aggregate.Stop() + h.baseHandler.Stop() +} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go new file mode 100644 index 0000000..ff13b83 --- /dev/null +++ b/internal/clients/healthclient.go @@ -0,0 +1,95 @@ +package clients + +import ( + "fmt" + "runtime" + "strings" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/omode" + + gossh "golang.org/x/crypto/ssh" +) + +// HealthClient is used for health checking (e.g. via Nagios) +type HealthClient struct { + // Client operating mode + mode omode.Mode + // The remote server address + server string + // SSH user name + userName string + // SSH auth methods to use to connect to the remote servers. + sshAuthMethods []gossh.AuthMethod +} + +// NewHealthClient returns a new healh client. +func NewHealthClient(mode omode.Mode) (*HealthClient, error) { + c := HealthClient{ + mode: mode, + server: fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort), + userName: config.ControlUser, + } + c.initSSHAuthMethods() + + return &c, nil +} + +// Start the health client. +func (c *HealthClient) Start() (status int) { + receive := make(chan string) + + throttleCh := make(chan struct{}, runtime.NumCPU()) + statsCh := make(chan struct{}, 1) + + conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods) + conn.Handler = handlers.NewHealthHandler(c.server, receive) + conn.Commands = []string{c.mode.String()} + + go conn.Start(throttleCh, statsCh) + defer conn.Stop() + + for { + select { + case data := <-receive: + // Parse recieved data. + s := strings.Split(data, "|") + message := s[len(s)-1] + if strings.HasPrefix(message, "done;") { + return + } + + // Set severity. + s = strings.Split(message, ":") + switch s[0] { + case "OK": + case "WARNING": + if status < 1 { + status = 1 + } + case "CRITICAL": + status = 2 + case "UNKNOWN": + status = 3 + default: + fmt.Printf("CRITICAL: Unexpected server response: '%s'\n", message) + status = 2 + return + } + fmt.Print(message) + + case <-time.After(time.Second * 2): + status = 2 + fmt.Println("CRITICAL: Could not communicate with DTail server") + return + } + } +} + +// Initialize SSH auth methods. +func (c *HealthClient) initSSHAuthMethods() { + c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser)) +} diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go new file mode 100644 index 0000000..9070827 --- /dev/null +++ b/internal/clients/maprclient.go @@ -0,0 +1,152 @@ +package clients + +import ( + "errors" + "fmt" + "runtime" + "strings" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +// MaprClient is used for running mapreduce aggregations on remote files. +type MaprClient struct { + baseClient + // Query string for mapr aggregations + queryStr string + // Global group set for merged mapr aggregation results + globalGroup *mapr.GlobalGroupSet + // The query object (constructed from queryStr) + query *mapr.Query + // Additative result or new result every run? + additative bool +} + +// NewMaprClient returns a new mapreduce client. +func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { + if queryStr == "" { + return nil, errors.New("No mapreduce query specified, use '-query' flag") + } + + c := MaprClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: args.Mode == omode.TailClient, + }, + queryStr: queryStr, + additative: args.Mode == omode.MapClient, + } + + query, err := mapr.NewQuery(c.queryStr) + if err != nil { + logger.FatalExit(c.queryStr, "Can't parse mapr query", err) + } + + c.query = query + + switch c.query.Table { + case "*": + c.Regex = fmt.Sprintf("\\|MAPREDUCE:\\|") + case ".": + c.Regex = "." + default: + c.Regex = fmt.Sprintf("\\|MAPREDUCE:%s\\|", c.query.Table) + } + + c.globalGroup = mapr.NewGlobalGroupSet() + c.baseClient.init(c) + + return &c, nil +} + +func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout) + + conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery)) + commandStr := "tail" + if c.additative { + commandStr = "cat" + } + + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex)) + } + + return conn +} + +// Start starts the mapreduce client. +func (c *MaprClient) Start() (status int) { + if c.query.Outfile == "" { + // Only print out periodic results if we don't write an outfile + go c.periodicPrintResults() + } + + status = c.baseClient.Start() + if c.additative { + c.recievedFinalResult() + } + c.baseClient.Stop() + + return +} + +func (c *MaprClient) recievedFinalResult() { + logger.Info("Received final mapreduce result") + + if c.query.Outfile == "" { + c.printResults() + return + } + + logger.Info(fmt.Sprintf("Writing final mapreduce result to '%s'", c.query.Outfile)) + err := c.globalGroup.WriteResult(c.query) + if err != nil { + logger.FatalExit(err) + return + } + logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile)) +} + +func (c *MaprClient) periodicPrintResults() { + for { + select { + case <-time.After(c.query.Interval): + logger.Info("Gathering interim mapreduce result") + c.printResults() + case <-c.baseClient.stop: + return + } + } +} + +func (c *MaprClient) printResults() { + var result string + var err error + var numLines int + + if c.additative { + result, numLines, err = c.globalGroup.Result(c.query) + } else { + result, numLines, err = c.globalGroup.SwapOut().Result(c.query) + } + if err != nil { + logger.FatalExit(err) + } + if numLines > 0 { + logger.Raw(fmt.Sprintf("%s\n", c.query.RawQuery)) + logger.Raw(result) + } +} diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go new file mode 100644 index 0000000..bfc7bc5 --- /dev/null +++ b/internal/clients/remote/connection.go @@ -0,0 +1,230 @@ +package remote + +import ( + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/ssh/client" + "fmt" + "io" + "strconv" + "strings" + "time" + + "golang.org/x/crypto/ssh" +) + +// Connection represents a client connection connection to a single server. +type Connection struct { + // The remote server's hostname connected to. + Server string + // The remote server's port connected to. + port int + // The SSH client configuration used. + config *ssh.ClientConfig + // The SSH client handler to use. + Handler handlers.Handler + // DTail commands sent from client to server. When client loses + // connection to the server it re-connects automatically and sends the + // same commands again. + Commands []string + // Is it a persistent connection or a one-off? + isOneOff bool + // Used to stop the connection + stop chan struct{} + // To deal with SSH server host keys + hostKeyCallback *client.HostKeyCallback +} + +// NewConnection returns a new connection. +func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *Connection { + logger.Debug(server, "Creating new connection") + + c := Connection{ + hostKeyCallback: hostKeyCallback, + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: hostKeyCallback.Wrap(), + Timeout: time.Second * 3, + }, + stop: make(chan struct{}), + } + + c.initServerPort(server) + + return &c +} + +// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit). +func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection { + c := Connection{ + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + }, + stop: make(chan struct{}), + isOneOff: true, + } + + c.initServerPort(server) + + return &c +} + +// Attempt to parse the server port address from the provided server FQDN. +func (c *Connection) initServerPort(server string) { + c.Server = server + c.port = config.Common.SSHPort + parts := strings.Split(server, ":") + + if len(parts) == 2 { + logger.Debug("Parsing port from hostname", parts) + port, err := strconv.Atoi(parts[1]) + if err != nil { + logger.FatalExit("Unable to parse client port", server, parts, err) + } + c.Server = parts[0] + c.port = port + } +} + +// Start the server connection. Build up SSH session and send some DTail commandc. +func (c *Connection) Start(throttleCh, statsCh chan struct{}) { + select { + case <-c.stop: + logger.Info(c.Server, c.port, "Disconnecting client") + return + default: + } + + // Wait for SSH connection throttler + throttleCh <- struct{}{} + + // Wait until connection has been initiated or an error occured + // during initialization. + throttleStopCh := make(chan struct{}, 2) + go func() { + <-throttleStopCh + <-throttleCh + }() + + if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil { + logger.Warn(c.Server, c.port, err) + throttleStopCh <- struct{}{} + + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { + logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port) + return + } + } +} + +// Dail into a new SSH connection. Close connection in case of an error. +func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error { + statsCh <- struct{}{} + defer func() { <-statsCh }() + + logger.Debug(host, "dial") + address := fmt.Sprintf("%s:%d", host, port) + + client, err := ssh.Dial("tcp", address, c.config) + if err != nil { + return err + } + defer client.Close() + + return c.session(client, throttleStopCh) +} + +// Create the SSH session. Close the session in case of an error. +func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { + logger.Debug(c.Server, "session") + + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + + return c.handle(session, throttleStopCh) +} + +// Handle the SSH session. Also send periodic pings to the server in order +// to determine that session is still intact. +func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { + defer c.Handler.Stop() + + logger.Debug(c.Server, "handle") + + stdinPipe, err := session.StdinPipe() + if err != nil { + return err + } + + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + + if err := session.Shell(); err != nil { + return err + } + + // Establish Bi-directional pipe between SSH session and client handler. + brokenStdinPipe := make(chan struct{}) + go func() { + defer close(brokenStdinPipe) + io.Copy(stdinPipe, c.Handler) + }() + + brokenStdoutPipe := make(chan struct{}) + go func() { + defer close(brokenStdoutPipe) + io.Copy(c.Handler, stdoutPipe) + }() + + // SSH session established, other goroutine can initiate session now. + throttleStopCh <- struct{}{} + + // Send all commands to client. + for _, command := range c.Commands { + logger.Debug(command) + c.Handler.SendCommand(command) + } + + if !c.isOneOff { + return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) + } + + <-c.stop + + // Normal shutdown, all fine + return nil +} + +// Periodically check whether connection is still alive or not. +func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { + for { + select { + case <-time.After(time.Second * 3): + if err := c.Handler.Ping(); err != nil { + return err + } + case <-brokenStdinPipe: + logger.Debug("Broken stdin pipe", c.Server, c.port) + return nil + case <-brokenStdoutPipe: + logger.Debug("Broken stdout pipe", c.Server, c.port) + return nil + case <-c.stop: + return nil + } + } +} + +// Stop the connection. +func (c *Connection) Stop() { + close(c.stop) +} diff --git a/internal/clients/stats.go b/internal/clients/stats.go new file mode 100644 index 0000000..d36cef6 --- /dev/null +++ b/internal/clients/stats.go @@ -0,0 +1,81 @@ +package clients + +import ( + "github.com/mimecast/dtail/internal/logger" + "fmt" + "runtime" + "sync" + "time" +) + +// Used to collect and display various client stats. +type stats struct { + // Total amount servers to connect to. + connectionsTotal int + // To keep track of what connected and disconnected + connectionsEstCh chan struct{} + // Amount of servers connections are established. + connected int + // To synchronize concurrent access. + mutex sync.Mutex +} + +func newTailStats(connectionsTotal int) *stats { + return &stats{ + connectionsTotal: connectionsTotal, + connectionsEstCh: make(chan struct{}, connectionsTotal), + connected: 0, + } +} + +func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) { + connectedLast := 0 + statsInterval := 5 + + for { + select { + case <-time.After(time.Second * time.Duration(statsInterval)): + case <-stop: + return + } + + connected := len(s.connectionsEstCh) + throttle := len(throttleCh) + + newConnections := connected - connectedLast + connectionsPerSecond := float64(newConnections) / float64(statsInterval) + s.log(connected, newConnections, connectionsPerSecond, throttle) + + connectedLast = connected + + s.mutex.Lock() + s.connected = connected + s.mutex.Unlock() + } +} + +func (s *stats) numConnected() int { + s.mutex.Lock() + defer s.mutex.Unlock() + + return s.connected +} + +func (s *stats) log(connected, newConnections int, connectionsPerSecond float64, throttle int) { + percConnected := percentOf(float64(s.connectionsTotal), float64(connected)) + + connectedStr := fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected)) + newConnStr := fmt.Sprintf("new=%d", newConnections) + rateStr := fmt.Sprintf("rate=%2.2f/s", connectionsPerSecond) + throttleStr := fmt.Sprintf("throttle=%d", throttle) + cpusGoroutinesStr := fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine()) + + logger.Info("stats", connectedStr, newConnStr, rateStr, throttleStr, cpusGoroutinesStr) +} + +func percentOf(total float64, value float64) float64 { + if total == 0 || total == value { + return 100 + } + return value / (total / 100.0) +} diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go new file mode 100644 index 0000000..674ca36 --- /dev/null +++ b/internal/clients/tailclient.go @@ -0,0 +1,49 @@ +package clients + +import ( + "fmt" + "runtime" + "strings" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/ssh/client" + + gossh "golang.org/x/crypto/ssh" +) + +// TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines). +type TailClient struct { + baseClient +} + +// NewTailClient returns a new TailClient. +func NewTailClient(args Args) (*TailClient, error) { + args.Mode = omode.TailClient + + c := TailClient{ + baseClient: baseClient{ + Args: args, + stop: make(chan struct{}), + stopped: make(chan struct{}), + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: true, + }, + } + + c.init(c) + + return &c, nil +} + +func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) + + for _, file := range strings.Split(c.Files, ",") { + conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) + } + + return conn +} diff --git a/internal/color/color.go b/internal/color/color.go new file mode 100644 index 0000000..0736199 --- /dev/null +++ b/internal/color/color.go @@ -0,0 +1,70 @@ +// Package color is used to prettify console output via ANSII terminal colors. +package color + +import ( + "fmt" +) + +// Color name. +type Color string + +// Attribute of a color. +type Attribute string + +// The possible color variations. +const ( + escape = "\x1b" + reset = escape + "[0m" + seq string = "%s%s%s" + + Gray Color = escape + "[30m" + Red Color = escape + "[31m" + Green Color = escape + "[32m" + Orange Color = escape + "[33m" + Blue Color = escape + "[34m" + Magenta Color = escape + "[35m" + Yellow Color = escape + "[36m" + LightGray Color = escape + "[37m" + + BgGray Color = escape + "[40m" + BgRed Color = escape + "[41m" + BgGreen Color = escape + "[42m" + BgOrange Color = escape + "[43m" + BgBlue Color = escape + "[44m" + BgMagenta Color = escape + "[45m" + BgYellow Color = escape + "[46m" + BgLightGray Color = escape + "[47m" + + Bold Attribute = escape + "[1m" + Italic Attribute = escape + "[3m" + Underline Attribute = escape + "[4m" + ReverseColor Attribute = escape + "[7m" + + resetBold = escape + "[22m" + resetItalic = escape + "[23m" + resetUnderline = escape + "[24m" + + Test Color = BgYellow + TestAttr Attribute = Bold +) + +// Colored DTail client output enabled. +var Colored bool + +// Paint a given string in a given color. +func Paint(c Color, s string) string { + return fmt.Sprintf(seq, c, s, reset) +} + +// Attr adds a given attribute to a given string, such as "bold" or "italic". +func Attr(c Attribute, s string) string { + switch c { + case Bold: + return fmt.Sprintf(seq, Bold, s, resetBold) + case Italic: + return fmt.Sprintf(seq, Italic, s, resetItalic) + case Underline: + return fmt.Sprintf(seq, Underline, s, resetUnderline) + } + panic("Unknown attribute") +} diff --git a/internal/color/colorfy.go b/internal/color/colorfy.go new file mode 100644 index 0000000..9ae46f5 --- /dev/null +++ b/internal/color/colorfy.go @@ -0,0 +1,58 @@ +package color + +import ( + "fmt" + "strings" +) + +// Add some color to log lines received from remote servers. +func paintRemote(line string) string { + splitted := strings.Split(line, "|") + if splitted[2] == "100" { + splitted[2] = Paint(BgGreen, splitted[2]) + } else { + splitted[2] = Paint(BgRed, splitted[2]) + } + info := strings.Join(splitted[0:5], "|") + log := strings.Join(splitted[5:], "|") + + if strings.HasPrefix(log, "WARN") { + log = Paint(BgYellow, log) + } else if strings.HasPrefix(log, "ERROR") { + log = Paint(BgRed, log) + } else if strings.HasPrefix(log, "FATAL") { + log = Attr(Bold, Paint(BgRed, log)) + } else { + log = Paint(Blue, log) + } + + return fmt.Sprintf("%s|%s", info, log) +} + +// Add some color to stats generated by the client. +func paintClientStats(line string) string { + splitted := strings.Split(line, "|") + first := strings.Join(splitted[0:4], "|") + connected := Paint(BgBlue, splitted[4]) + last := strings.Join(splitted[5:], "|") + + return fmt.Sprintf("%s|%s|%s", first, connected, last) +} + +// Colorfy a given line based on the line's content. +func Colorfy(line string) string { + if strings.HasPrefix(line, "REMOTE") { + return paintRemote(line) + } + if strings.HasPrefix(line, "CLIENT") && strings.Contains(line, "|stats|") { + return paintClientStats(line) + } + if strings.Contains(line, "ERROR") { + return Paint(Magenta, line) + } + if strings.Contains(line, "WARN") { + return Paint(Magenta, line) + } + + return line +} diff --git a/internal/config/client.go b/internal/config/client.go new file mode 100644 index 0000000..1515aae --- /dev/null +++ b/internal/config/client.go @@ -0,0 +1,11 @@ +package config + +// ClientConfig represents a DTail client configuration (empty as of now as there +// are no available config options yet, but that may changes in the future). +type ClientConfig struct { +} + +// Create a new default client configuration. +func newDefaultClientConfig() *ClientConfig { + return &ClientConfig{} +} diff --git a/internal/config/common.go b/internal/config/common.go new file mode 100644 index 0000000..8c07710 --- /dev/null +++ b/internal/config/common.go @@ -0,0 +1,42 @@ +package config + +// CommonConfig stores configuration keys shared by DTail server and client. +type CommonConfig struct { + // The SSH server port number. + SSHPort int + // Enable experimental features. + ExperimentalFeaturesEnable bool `json:",omitempty"` + // Enable extra debug logging (used for deevlopment or debugging purpes only). + DebugEnable bool `json:",omitempty"` + // Enable extra trace logging (used for deevlopment or debugging purpes only). + TraceEnable bool `json:",omitempty"` + // The log strategy to use, one of + // stdout: only log to stdout (useful when used with systemd) + // daily: create a log file for every day + LogStrategy string + // The log directory + LogDir string + // The cache directory + CacheDir string + // Do we want to enable pperf http server? + PProfEnable bool `json:",omitempty"` + // The HTTP port used by PProf + PProfPort int `json:",omitempty"` + // The PProf HTTP server bind address + PProfBindAddress string `json:",omitempty"` +} + +// Create a new default configuration. +func newDefaultCommonConfig() *CommonConfig { + return &CommonConfig{ + SSHPort: 2222, + DebugEnable: false, + TraceEnable: false, + ExperimentalFeaturesEnable: false, + LogDir: "log", + CacheDir: "cache", + PProfEnable: false, + PProfPort: 6060, + PProfBindAddress: "0.0.0.0", + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..0f26635 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,45 @@ +package config + +import ( + "encoding/json" + "io/ioutil" + "os" +) + +// ControlUser is used for various DTail specific operations. +const ControlUser string = "DTAIL-CONTROL-USER" + +// Client holds a DTail client configuration. +var Client *ClientConfig + +// Server holds a DTail server configuration. +var Server *ServerConfig + +// Common holds common configs of both both, client and server. +var Common *CommonConfig + +// Used to initialize the configuration. +type configInitializer struct { + Common *CommonConfig + Server *ServerConfig + Client *ClientConfig +} + +// Parse and read a given config file in JSON format. +func (c *configInitializer) parseConfig(configFile string) { + fd, err := os.Open(configFile) + if err != nil { + panic(err) + } + defer fd.Close() + + cfgBytes, err := ioutil.ReadAll(fd) + if err != nil { + panic(err) + } + + err = json.Unmarshal([]byte(cfgBytes), c) + if err != nil { + panic(err) + } +} diff --git a/internal/config/read.go b/internal/config/read.go new file mode 100644 index 0000000..a4e605b --- /dev/null +++ b/internal/config/read.go @@ -0,0 +1,37 @@ +package config + +import ( + "os" +) + +// Read the DTail configuration. +func Read(configFile string, sshPort int) { + initializer := configInitializer{ + Common: newDefaultCommonConfig(), + Server: newDefaultServerConfig(), + Client: newDefaultClientConfig(), + } + + if configFile == "" { + configFile = "./cfg/dtail.json" + } + + if _, err := os.Stat(configFile); !os.IsNotExist(err) { + initializer.parseConfig(configFile) + } + + // Assign pointers to global variables, so that we can access the + // configuration from any place of the program. + Common = initializer.Common + Server = initializer.Server + Client = initializer.Client + + if Server.MapreduceLogFormat == "" { + Server.MapreduceLogFormat = "default" + } + + // If non-standard port specified, overwrite config + if sshPort != 2222 { + Common.SSHPort = sshPort + } +} diff --git a/internal/config/server.go b/internal/config/server.go new file mode 100644 index 0000000..7883b33 --- /dev/null +++ b/internal/config/server.go @@ -0,0 +1,66 @@ +package config + +import ( + "errors" +) + +// Permissions map. Each SSH user has a list of permissions which +// log files it is allowed to follow and which ones not. +type Permissions struct { + // The default user permissions. + Default []string + // The per user special permissions. + Users map[string][]string +} + +// ServerConfig represents the server configuration. +type ServerConfig struct { + // The SSH server bind port. + SSHBindAddress string + // The max amount of concurrent user connection allowed to connect to the server. + MaxConnections int + // The max amount of concurrent cats per server. + MaxConcurrentCats int + // The max amount of concurrent tails per server. + MaxConcurrentTails int + // The user permissions. + Permissions Permissions `json:",omitempty"` + // The mapr log format + MapreduceLogFormat string `json:",omitempty"` + // The default path of the server host key + HostKeyFile string + // The host key size in bits + HostKeyBits int +} + +// Create a new default server configuration. +func newDefaultServerConfig() *ServerConfig { + defaultPermissions := []string{"^/.*"} + defaultBindAddress := "0.0.0.0" + + return &ServerConfig{ + SSHBindAddress: defaultBindAddress, + MaxConnections: 10, + MaxConcurrentCats: 2, + MaxConcurrentTails: 50, + HostKeyFile: "./cache/ssh_host_key", + HostKeyBits: 4096, + Permissions: Permissions{ + Default: defaultPermissions, + }, + } +} + +// ServerUserPermissions retrieves the permission set of a given user. +func ServerUserPermissions(userName string) (permissions []string, err error) { + permissions = Server.Permissions.Default + if p, ok := Server.Permissions.Users[userName]; ok { + permissions = p + } + + if len(permissions) == 0 { + err = errors.New("Empty set of permission, user won't be able to open any files") + } + + return +} diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go new file mode 100644 index 0000000..ad18be0 --- /dev/null +++ b/internal/discovery/comma.go @@ -0,0 +1,12 @@ +package discovery + +import ( + "github.com/mimecast/dtail/internal/logger" + "strings" +) + +// ServerListFromCOMMA retrieves a list of servers from comma separated input list. +func (d *Discovery) ServerListFromCOMMA() []string { + logger.Debug("Retrieving server list from comma separated list", d.server) + return strings.Split(d.server, ",") +} diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go new file mode 100644 index 0000000..d76c1b2 --- /dev/null +++ b/internal/discovery/discovery.go @@ -0,0 +1,173 @@ +package discovery + +import ( + "github.com/mimecast/dtail/internal/logger" + "fmt" + "math/rand" + "os" + "reflect" + "regexp" + "strings" + "time" +) + +// Discovery method for discovering a list of available DTail servers. +type Discovery struct { + // To plug in a custom server discovery module. + module string + // To specifiy optional server discovery module options. + options string + // To either filter a server list or to secify an exact list. + server string + // To filter server list. + regex *regexp.Regexp + // To shuffle resulting server list. + shuffle bool +} + +// New returns a new discovery method. +func New(method, server string, shuffle bool) *Discovery { + module := method + options := "" + + if strings.Contains(module, ":") { + s := strings.Split(module, ":") + if len(s) != 2 { + logger.FatalExit("Unable to parse discovery module", module) + } + module = s[0] + options = s[1] + } + + d := Discovery{ + module: strings.ToUpper(module), + options: options, + server: server, + shuffle: shuffle, + } + + if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") { + d.initRegex() + } + + return &d +} + +func (d *Discovery) initRegex() { + var runes []rune + last := len(d.server) - 1 + for i, char := range d.server { + if i != 0 && i != last { + runes = append(runes, char) + } + } + + regexStr := string(runes) + logger.Debug("Using filter regex", regexStr) + + regex, err := regexp.Compile(regexStr) + if err != nil { + logger.FatalExit("Could not compile regex", regexStr, err) + } + + d.regex = regex + d.server = "" +} + +// ServerList to connect to via DTail client. +func (d *Discovery) ServerList() []string { + servers := d.serverListFromModule() + + if d.regex != nil { + servers = d.filterList(servers) + } + + servers = d.dedupList(servers) + + if d.shuffle { + servers = d.shuffleList(servers) + } + + logger.Debug("Discovered servers", len(servers), servers) + return servers +} + +func (d *Discovery) serverListFromModule() []string { + if d.module != "" { + return d.serverListFromReflectedModule() + } + + if _, err := os.Stat(d.server); err == nil { + // Appears to be a file name, now try to read from that file. + return d.ServerListFromFILE() + } + + // Appears to be a list of FQDNs (or a single FQDN) + return d.ServerListFromCOMMA() +} + +// The aim of this is that everyone can plug in their own server discovery +// method to DTail. Just add a method ServerListFrommMODULENAME to type +// Discovery. Whereas MODULENAME must be a upeprcase string. +func (d *Discovery) serverListFromReflectedModule() []string { + methodName := fmt.Sprintf("ServerListFrom%s", d.module) + + rt := reflect.TypeOf(d) + reflectedMethod, ok := rt.MethodByName(methodName) + if !ok { + logger.FatalExit("No such server discovery module", d.module, methodName) + } + + inputValues := make([]reflect.Value, 1) + // Thist input value is method receiver. + inputValues[0] = reflect.ValueOf(d) + returnValues := reflectedMethod.Func.Call(inputValues) + + // First return value is server list. + return returnValues[0].Interface().([]string) +} + +// Filter server list based on a regexp. +func (d *Discovery) filterList(servers []string) (filtered []string) { + logger.Debug("Filtering server list") + + for _, server := range servers { + if d.regex.MatchString(server) { + filtered = append(filtered, server) + } + } + + return +} + +// Deduplicate the server list. +func (d *Discovery) dedupList(servers []string) (deduped []string) { + serverMap := make(map[string]struct{}, len(servers)) + + for _, server := range servers { + if _, ok := serverMap[server]; !ok { + serverMap[server] = struct{}{} + deduped = append(deduped, server) + } + } + + logger.Info("Deduped server list", len(servers), len(deduped)) + return +} + +// Randomly shuffle the server list. +func (d *Discovery) shuffleList(servers []string) []string { + logger.Debug("Shuffling server list") + + r := rand.New(rand.NewSource(time.Now().Unix())) + shuffled := make([]string, len(servers)) + n := len(servers) + + for i := 0; i < n; i++ { + randIndex := r.Intn(len(servers)) + shuffled[i] = servers[randIndex] + servers = append(servers[:randIndex], servers[randIndex+1:]...) + } + + return shuffled +} diff --git a/internal/discovery/file.go b/internal/discovery/file.go new file mode 100644 index 0000000..2edc867 --- /dev/null +++ b/internal/discovery/file.go @@ -0,0 +1,28 @@ +package discovery + +import ( + "bufio" + "github.com/mimecast/dtail/internal/logger" + "os" +) + +// ServerListFromFILE retrieves a list of servers from a file. +func (d *Discovery) ServerListFromFILE() (servers []string) { + logger.Debug("Retrieving server list from file", d.server) + + file, err := os.Open(d.server) + if err != nil { + logger.FatalExit(d.server, err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + servers = append(servers, scanner.Text()) + } + if err := scanner.Err(); err != nil { + logger.FatalExit(d.server, err) + } + + return +} diff --git a/internal/fs/catfile.go b/internal/fs/catfile.go new file mode 100644 index 0000000..99f521f --- /dev/null +++ b/internal/fs/catfile.go @@ -0,0 +1,27 @@ +package fs + +import "sync" + +// CatFile is for reading a whole file. +type CatFile struct { + readFile +} + +// NewCatFile returns a new file catter. +func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { + var mutex sync.Mutex + + return CatFile{ + readFile: readFile{ + filePath: filePath, + stop: make(chan struct{}), + globID: globID, + serverMessages: serverMessages, + retry: false, + canSkipLines: false, + seekEOF: false, + limiter: limiter, + mutex: &mutex, + }, + } +} diff --git a/internal/fs/filereader.go b/internal/fs/filereader.go new file mode 100644 index 0000000..5a08e27 --- /dev/null +++ b/internal/fs/filereader.go @@ -0,0 +1,9 @@ +package fs + +// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. +type FileReader interface { + Start(lines chan<- LineRead, regex string) error + FilePath() string + Retry() bool + Stop() +} diff --git a/internal/fs/lineread.go b/internal/fs/lineread.go new file mode 100644 index 0000000..7ee558e --- /dev/null +++ b/internal/fs/lineread.go @@ -0,0 +1,28 @@ +package fs + +import ( + "fmt" +) + +// LineRead represents a read log line. +type LineRead struct { + // The content of the log line. + Content []byte + // Until now, how many log lines were processed? + Count uint64 + // Sometimes we produce too many log lines so that the client + // is too slow to process all of them. The server will drop log + // lines if that happens but it will signal to the client how + // many log lines in % could be transmitted to the client. + TransmittedPerc int + GlobID *string +} + +// Return a human readable representation of the followed line. +func (l LineRead) String() string { + return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)", + string(l.Content), + l.TransmittedPerc, + l.Count, + *l.GlobID) +} diff --git a/internal/fs/permissions/permission.go b/internal/fs/permissions/permission.go new file mode 100644 index 0000000..6e83309 --- /dev/null +++ b/internal/fs/permissions/permission.go @@ -0,0 +1,14 @@ +// +build !linux + +package permissions + +import ( + "github.com/mimecast/dtail/internal/logger" +) + +// ToRead is to check whether user has read permissions to a given file. +func ToRead(user, filePath string) (bool, error) { + // Only implemented for Linux, always expect true + logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") + return true, nil +} diff --git a/internal/fs/permissions/permission_linux.c b/internal/fs/permissions/permission_linux.c new file mode 100644 index 0000000..cd10525 --- /dev/null +++ b/internal/fs/permissions/permission_linux.c @@ -0,0 +1,395 @@ +#include "permission_linux.h" + +#ifdef DEBUG +void debug_print_checker(struct permission_checker *pc) { + fprintf(stderr, "DEBUG: user_name:%s (%d)\n", + pc->user_name, pc->uid); + + fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids); + int j; + for (j = 0; j < pc->ngids; j++) { + fprintf(stderr, "DEBUG: %d", pc->gids[j]); + struct group *gr = getgrgid(pc->gids[j]); + if (gr != NULL) + fprintf(stderr, " (%s)", gr->gr_name); + fprintf(stderr, "\n"); + } + + fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +} +#endif // DEBUG + +int stat_file(struct permission_checker *pc) { + if (stat(pc->file_path, &pc->file_stat) != 0) + return -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +#endif + + return 0; +} + +int get_user_uid(struct permission_checker *pc) { + struct passwd *result = NULL; + + size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX); + if (bufsize == -1) + bufsize = 16384; + + char *buf = malloc(bufsize); + if (buf == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name); +#endif + return -1; + } + + int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result); + + if (result == NULL) { +#ifdef DEBUG + if (rc == 0) { + fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name); + } else { + fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name); + } +#endif + + free(buf); + return -1; + } + + pc->uid = pc->pw.pw_uid; + + free(buf); + return 0; +} + +int get_user_groups(struct permission_checker *pc) { + // First assume we are in 10 groups max + pc->ngids = 10; + pc->gids = malloc(pc->ngids * sizeof(gid_t)); + + if (pc->gids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + return -1; + } + + // Try so many times to load group list until it fits into group array. + while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) { + // Too many groups, enlarge group array and try again + int newngids = pc->ngids + 100; + size_t newsize = newngids * sizeof(gid_t); + + if (SIZE_MAX / newngids < sizeof(gid_t)) { + // Overflow +#ifdef DEBUG + fprintf(stderr, "DEBUG: Overflow detected."); +#endif + return -1; + } + + gid_t *newgids = realloc(pc->gids, newsize); + if (newgids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + free(pc->gids); + return -1; + } + + pc->gids = newgids; + pc->ngids = newngids; + } + + return 0; +} + +int is_member_of_group(struct permission_checker *pc, gid_t gid) { + int j; + for (j = 0; j < pc->ngids; j++) + if (pc->gids[j] == gid) + return 1; + return 0; +} + +int check_acl_uid_matches(uid_t uid, acl_entry_t entry) { + int ret = -1; + uid_t *acl_uid = acl_get_qualifier(entry); + if (acl_uid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + ret = *acl_uid == uid ? 0 : -1; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret); +#endif + acl_free(acl_uid); + return ret; +} + +int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) { + int ret = -1; + gid_t *acl_gid = acl_get_qualifier(entry); + if (acl_gid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + int j; + for (j = 0; j < ngids; j++) { + if (*acl_gid == gids[j]) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User is in group %d", *acl_gid); +#endif + ret = 0; + break; + } + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret); +#endif + acl_free(acl_gid); + return ret; +} + +int check_acl(struct permission_checker *pc, const int flag) { + // By default user has no read perm. + int has_read_perm = 0; + + // By default mask tells that there are read perm. However in order to have + // read permissions both, has_read_perm and mask_allows_read_access must be 1! + int mask_allows_read_access = 1; + + acl_type_t type = ACL_TYPE_ACCESS; + acl_t acl = acl_get_file(pc->file_path, type); + + if (acl == NULL) + // Unable to retrieve ACL. + return -1; + + // Walk through each entry of this ACL. + int id; + for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) { + acl_entry_t entry; + if (acl_get_entry(acl, id, &entry) != 1) + // No more ACL entries. + break; + + acl_tag_t tag; + if (acl_get_tag_type(entry, &tag) == -1) + // Unable to retrieve ACL tag. + return -1; + + switch (tag) { + case ACL_USER_OBJ: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER_OBJ\n"); +#endif + // Ignore this ACL entry if user is not owner of file. + if (pc->uid != pc->file_stat.st_uid) + continue; + break; + case ACL_USER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER\n"); +#endif + // Ignore this ACL entry if uid does not match. + if (check_acl_uid_matches(pc->uid, entry) != 0) + continue; + break; + case ACL_GROUP_OBJ: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n"); +#endif + // Ignore ACL entry if user is not in group of file. + if (!is_member_of_group(pc, pc->file_stat.st_gid)) + continue; + break; + case ACL_GROUP: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP\n"); +#endif + // Ignore ACL entry if user is not in group of entry. + if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0) + continue; + break; + case ACL_OTHER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_OTHER\n"); +#endif + break; + case ACL_MASK: +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_MASK\n"); +#endif + break; + default: +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unknown ACL tag\n"); +#endif + return -1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: Retrieving permset\n"); +#endif + acl_permset_t permset; + int permission; + if (acl_get_permset(entry, &permset) == -1) + // Unable to retrieve permset. + return -1; + + if ((permission = acl_get_perm(permset, ACL_READ)) == -1) + // Unable to retrieve permset value. + return -1; + + if (permission == 1 && tag != ACL_MASK) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n"); +#endif + has_read_perm = 1; + } else if (permission == 0 && tag == ACL_MASK) { + // Mask says that there are no permissions to read. + mask_allows_read_access = 0; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n"); +#endif + } + } + + if (has_read_perm && mask_allows_read_access) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n"); +#endif + return 1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n"); +#endif + return 0; +} + +int check_traditional(struct permission_checker *pc, const int flag) { + mode_t mode = pc->file_stat.st_mode; + uid_t uid = pc->file_stat.st_uid; + gid_t gid = pc->file_stat.st_gid; + + if (flag == USER_CHECK && (mode & S_IROTH)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Others can read file '%s'\n", + pc->file_path); +#endif + return 1; + + } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + + } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + } + + return 0; +} + +int permission_to_read(char* user_name, char *file_path) { + int rc = -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path); +#endif + struct permission_checker pc = { + .user_name = user_name, + .gids = NULL, + .ngids = 0, + .file_path = file_path, + }; + + // Gather user's UID. + if ((rc = get_user_uid(&pc)) == -1) + // Could not retrieve UID. + goto cleanup; + + // Gather file owner (user and group). + if ((rc = stat_file(&pc)) == -1) + // Could not stat file. + goto cleanup; + + // Check whether there is an ACL entry which would allow the user + // to read the file. Don't check for any groups yet. The issue with + // groups is that it can be very slow to retrieve the list of groups + // of a specific user when done via a remote LDAP server! + if ((rc = check_acl(&pc, USER_CHECK)) == 1) + // Yes, has permissions. + goto cleanup; + + // Check whether ACLs of file could be retrieved. + if (rc == -1) { + if (errno != ENOTSUP) + // Unknown error. + goto cleanup; + + // File system does not support ACLs. + // Fallback to traditional permissions. + if ((rc = check_traditional(&pc, USER_CHECK)) == 1) + // Yes, has traditional permissions. + goto cleanup; + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve user's groups. + goto cleanup; + + rc = check_traditional(&pc, GROUP_CHECK); + goto cleanup; + } + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve use'r groups. + goto cleanup; + + // Check whether there is an ACL entry which would allow any of the + // user's groups to read the file. + rc = check_acl(&pc, GROUP_CHECK); + +cleanup: +#ifdef DEBUG + debug_print_checker(&pc); +#endif + + if (pc.ngids) + free(pc.gids); + + return rc; +} + +// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab diff --git a/internal/fs/permissions/permission_linux.go b/internal/fs/permissions/permission_linux.go new file mode 100644 index 0000000..feae729 --- /dev/null +++ b/internal/fs/permissions/permission_linux.go @@ -0,0 +1,33 @@ +package permissions + +/* +#include "permission_linux.h" +#cgo LDFLAGS: -L. -lacl +*/ +import "C" + +import ( + "errors" + "unsafe" +) + +// To check whether user has Linux file system permissions to read a given file. +func ToRead(user, filePath string) (bool, error) { + cUser := C.CString(user) + cFilePath := C.CString(filePath) + + defer C.free(unsafe.Pointer(cUser)) + defer C.free(unsafe.Pointer(cFilePath)) + + cOk, err := C.permission_to_read(cUser, cFilePath) + if cOk == 1 { + return true, nil + } + + if err != nil { + // err contains errno message + return false, err + } + + return false, errors.New("User without permission to read file") +} diff --git a/internal/fs/permissions/permission_linux.h b/internal/fs/permissions/permission_linux.h new file mode 100644 index 0000000..a2c266e --- /dev/null +++ b/internal/fs/permissions/permission_linux.h @@ -0,0 +1,60 @@ +#ifndef PERMISSION_LINUX_H +#define PERMISSION_LINUX_H + +#include <acl/libacl.h> +#include <errno.h> +#include <grp.h> +#include <pwd.h> +#include <stdio.h> +#include <stdint.h> +#include <stdlib.h> +#include <sys/acl.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +//#define DEBUG +#define USER_CHECK 0 +#define GROUP_CHECK 1 + +struct permission_checker { + char *user_name; + uid_t uid; + gid_t *gids; + int ngids; + char *file_path; + struct stat file_stat; + struct passwd pw; +}; + + +#ifdef DEBUG +// Print out permission_checker struct. +void debug_print_checker(struct permission_checker *pc); +#endif + +// Stat a given file to retrieve traditional UNIX permissions. +int stat_file(struct permission_checker *pc); + +// Retrieve UID of user. +int get_user_uid(struct permission_checker *pc); + +// Retrieve all groups of the user. +int get_user_groups(struct permission_checker *pc); + +// Check whether user is member of a group or not. +int is_member_of_group(struct permission_checker *pc, gid_t gid); + +// Check whether user can read file according Linux ACLs. +// As flag use either USER_CHECK or GROUP_CHECK. +int check_acl(struct permission_checker *pc, const int flag); + +// Check whether user has permissions to read file according traditional +// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK. +int check_traditional(struct permission_checker *pc, const int flag); + +// Returns 1 if user has permission to read file. +// Returns <0 on error and returns 0 if no permissions. +int permission_to_read(char* user, char *file_path); + +#endif // PERMISSION_LINUX_H diff --git a/internal/fs/permissions/permission_test.go b/internal/fs/permissions/permission_test.go new file mode 100644 index 0000000..d415ac2 --- /dev/null +++ b/internal/fs/permissions/permission_test.go @@ -0,0 +1,112 @@ +// +build linux + +package permissions + +import ( + "os" + "os/exec" + "os/user" + "strings" + "testing" +) + +const ( + setfacl string = "/usr/bin/setfacl" + file string = "/tmp/acltest" +) + +func TestLinuxACL(t *testing.T) { + setfacl := "/usr/bin/setfacl" + file := "/tmp/acltest" + + // Delete file if it exists. + if _, err := os.Stat(file); err == nil { + os.Remove(file) + } + + f, err := os.Create(file) + if err != nil { + t.Errorf("%v", err) + } + defer func() { + f.Close() + //os.Remove(file) + }() + + user, err := user.Current() + if err != nil { + t.Errorf("Unable to retrieve current user: %v", err) + } + + // Test 1: Remove all permissions and perform a permission check + cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + + // Test 2: Add read permission to file owner + cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 3: Add read permission to file group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 4: Add read permission to others + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 5: Remove read permission from mask + cmd = exec.Command(setfacl, "-m", "m::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + // Test 6: Add read permission to specific group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err) + } + + // Test 7: Remove all permissions but mask + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } +} diff --git a/internal/fs/readfile.go b/internal/fs/readfile.go new file mode 100644 index 0000000..312447a --- /dev/null +++ b/internal/fs/readfile.go @@ -0,0 +1,318 @@ +package fs + +import ( + "bufio" + "compress/gzip" + "github.com/mimecast/dtail/internal/logger" + "errors" + "io" + "os" + "regexp" + "strings" + "sync" + "time" + + "github.com/DataDog/zstd" +) + +// Used to tail and filter a local log file. +type readFile struct { + // Various statistics (e.g. regex hit percentage, transfer percentage). + stats + // Path of log file to tail. + filePath string + // Only consider all log lines matching this regular expression. + re *regexp.Regexp + // The glob identifier of the file. + globID string + // Channel to send a server message to the dtail client + serverMessages chan<- string + // Signals to stop tailing the log file. + stop chan struct{} + // Periodically retry reading file. + retry bool + // Can I skip messages when there are too many? + canSkipLines bool + // Seek to the EOF before processing file? + seekEOF bool + // Mutex to control the stopping of the file + mutex *sync.Mutex + limiter chan struct{} +} + +// FilePath returns the full file path. +func (f readFile) FilePath() string { + return f.filePath +} + +// Retry reading the file on error? +func (f readFile) Retry() bool { + return f.retry +} + +// Start tailing a log file. +func (f readFile) Start(lines chan<- LineRead, regex string) error { + defer func() { + select { + case <-f.limiter: + default: + } + }() + + select { + case f.limiter <- struct{}{}: + default: + select { + case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): + case <-f.stop: + return nil + } + f.limiter <- struct{}{} + } + + fd, err := os.Open(f.filePath) + if err != nil { + return err + } + defer fd.Close() + + if f.seekEOF { + fd.Seek(0, io.SeekEnd) + } + + rawLines := make(chan []byte, 100) + truncate := make(chan struct{}) + + var wg sync.WaitGroup + wg.Add(1) + + go f.periodicTruncateCheck(truncate) + go f.filter(&wg, rawLines, lines, regex) + + err = f.read(fd, rawLines, truncate) + close(rawLines) + wg.Wait() + + return err +} + +func (f readFile) periodicTruncateCheck(truncate chan struct{}) { + for { + select { + case <-time.After(time.Second * 3): + select { + case truncate <- struct{}{}: + case <-f.stop: + } + case <-f.stop: + return + } + } +} + +// Stop reading file. +func (f readFile) Stop() { + f.mutex.Lock() + defer f.mutex.Unlock() + + select { + case <-f.stop: + return + default: + } + + close(f.stop) +} + +func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { + switch { + case strings.HasSuffix(f.FilePath(), ".gz"): + fallthrough + case strings.HasSuffix(f.FilePath(), ".gzip"): + logger.Info(f.FilePath(), "Detected gzip compression format") + var gzipReader *gzip.Reader + gzipReader, err = gzip.NewReader(fd) + if err != nil { + return + } + reader = bufio.NewReader(gzipReader) + case strings.HasSuffix(f.FilePath(), ".zst"): + logger.Info(f.FilePath(), "Detected zstd compression format") + reader = bufio.NewReader(zstd.NewReader(fd)) + default: + reader = bufio.NewReader(fd) + } + + return +} + +func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { + reader, err := f.makeReader(fd) + if err != nil { + return err + } + rawLine := make([]byte, 0, 512) + var offset uint64 + + lineLengthThreshold := 1024 * 1024 // 1mb + longLineWarning := false + + for { + select { + case <-truncate: + if isTruncated, err := f.truncated(fd); isTruncated { + return err + } + logger.Info(f.filePath, "Current offset", offset) + + case <-f.stop: + return nil + default: + } + + // Read some bytes (max 4k at once as of go 1.12). isPrefix will + // be set if line does not fit into 4k buffer. + bytes, isPrefix, err := reader.ReadLine() + + if err != nil { + // If EOF, sleep a couple of ms and return with nil error. + // If other error, return with non-nil error. + if err != io.EOF { + return err + } + if !f.seekEOF { + logger.Debug(f.FilePath(), "End of file reached") + return nil + } + time.Sleep(time.Millisecond * 100) + continue + } + + rawLine = append(rawLine, bytes...) + offset += uint64(len(bytes)) + + if !isPrefix { + // last LineRead call returned contend until end of line. + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-f.stop: + return nil + } + rawLine = make([]byte, 0, 512) + if longLineWarning { + longLineWarning = false + } + continue + } + + // Last LineRead call could not read content until end of line, buffer + // was too small. Determine whether we exceed the max line length we + // want dtail to send to the client at once. Possibly split up log line + // into multiple log lines. + if len(rawLine) >= lineLengthThreshold { + if !longLineWarning { + f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") + // Only print out one warning per long log line. + longLineWarning = true + } + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-f.stop: + return nil + } + rawLine = make([]byte, 0, 512) + } + } +} + +// Filter log lines matching a given regular expression. +func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) { + defer wg.Done() + + if regex == "" { + regex = "." + } + + re, err := regexp.Compile(regex) + if err != nil { + logger.Error(regex, "Can't compile regex, using '.' instead", err) + re = regexp.MustCompile(".") + } + f.re = re + + for { + select { + case line, ok := <-rawLines: + f.updatePosition() + if !ok { + return + } + if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok { + select { + case lines <- filteredLine: + case <-f.stop: + return + } + } + } + } +} + +func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) { + var read LineRead + + if !f.re.Match(line) { + f.updateLineNotMatched() + f.updateLineNotTransmitted() + return read, false + } + f.updateLineMatched() + + // Can we actually send more messages, channel capacity reached? + if f.canSkipLines && length >= capacity { + f.updateLineNotTransmitted() + return read, false + } + f.updateLineTransmitted() + + read = LineRead{ + Content: line, + GlobID: &f.globID, + Count: f.totalLineCount(), + TransmittedPerc: f.transmittedPerc(), + } + + return read, true +} + +// Check wether log file is truncated. Returns nil if not. +func (f readFile) truncated(fd *os.File) (bool, error) { + logger.Debug(f.filePath, "File truncation check") + + // Can not seek currently open FD. + curPos, err := fd.Seek(0, os.SEEK_CUR) + if err != nil { + return true, err + } + + // Can not open file at original path. + pathFd, err := os.Open(f.filePath) + if err != nil { + return true, err + } + defer pathFd.Close() + + // Can not seek file at original path. + pathPos, err := pathFd.Seek(0, io.SeekEnd) + if err != nil { + return true, err + } + + if curPos > pathPos { + return true, errors.New("File got truncated") + } + + return false, nil +} diff --git a/internal/fs/stats.go b/internal/fs/stats.go new file mode 100644 index 0000000..4121ff7 --- /dev/null +++ b/internal/fs/stats.go @@ -0,0 +1,69 @@ +package fs + +// Used to calculate how many log lines matched the regular expression +// and how many log files could be transmitted from the server to the client. +// Hit and transmit percentage takes only the last 100 log lines into calculation. +type stats struct { + pos int + lineCount uint64 + matched [100]bool + matchCount uint64 + transmitted [100]bool + transmitCount int +} + +// Return the total line count. +func (f *stats) totalLineCount() uint64 { + return f.lineCount +} + +// Calculate the percentage of log lines transmitted to the client. +func (f *stats) transmittedPerc() int { + return int(percentOf(float64(f.matchCount), float64(f.transmitCount))) +} + +// Update bucket position. We only take into consideration the last 100 +// lines for stats. +func (f *stats) updatePosition() { + f.pos = (f.pos + 1) % 100 + f.lineCount++ +} + +// Increment match counter. +func (f *stats) updateLineMatched() { + if !f.matched[f.pos] { + f.matchCount++ + f.matched[f.pos] = true + } +} + +// Increment transmitted counter. +func (f *stats) updateLineTransmitted() { + if !f.transmitted[f.pos] { + f.transmitCount++ + f.transmitted[f.pos] = true + } +} + +// Decrement match counter. +func (f *stats) updateLineNotMatched() { + if f.matched[f.pos] { + f.matchCount-- + f.matched[f.pos] = false + } +} + +// Decrement transmitted counter. +func (f *stats) updateLineNotTransmitted() { + if f.transmitted[f.pos] { + f.transmitCount-- + f.transmitted[f.pos] = false + } +} + +func percentOf(total float64, value float64) float64 { + if total == 0 || total == value { + return 100 + } + return value / (total / 100.0) +} diff --git a/internal/fs/tailfile.go b/internal/fs/tailfile.go new file mode 100644 index 0000000..a19d4e6 --- /dev/null +++ b/internal/fs/tailfile.go @@ -0,0 +1,27 @@ +package fs + +import "sync" + +// TailFile is to tail and filter a log file. +type TailFile struct { + readFile +} + +// NewTailFile returns a new file tailer. +func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { + var mutex sync.Mutex + + return TailFile{ + readFile: readFile{ + filePath: filePath, + stop: make(chan struct{}), + globID: globID, + serverMessages: serverMessages, + retry: true, + canSkipLines: true, + seekEOF: true, + limiter: limiter, + mutex: &mutex, + }, + } +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 0000000..ca85e32 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,457 @@ +package logger + +import ( + "bufio" + "fmt" + "os" + "os/signal" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" +) + +const ( + clientStr string = "CLIENT" + serverStr string = "SERVER" + infoStr string = "INFO" + warnStr string = "WARN" + errorStr string = "ERROR" + fatalStr string = "FATAL" + debugStr string = "DEBUG" + traceStr string = "TRACE" +) + +// Synchronise access to logging. +var mutex sync.Mutex + +// File descriptor of log file when logToFile enabled. +var fd *os.File + +// File write buffer of log file when logToFile enabled. +var writer *bufio.Writer + +// File write buffer of stdout when logToStdout enabled. +var stdoutWriter *bufio.Writer + +// Current hostname. +var hostname string + +// Used to detect change of day (create one log file per day0 +var lastDateStr string + +// True if log in server mode, false if log in client mode. +var serverEnable bool + +// Used to make logging non-blocking. +var logBufCh chan buf +var stdoutBufCh chan string + +// Stdout channel, required to pause output +var pauseCh chan struct{} +var resumeCh chan struct{} + +// Tell the logger that we are done, program shuts down +var stop chan struct{} +var stdoutFlushed chan struct{} + +// Tell the logger about logrotation +var rotateCh chan os.Signal + +// LogMode allows to specify the verbosity of logging. +type LogMode int + +// Possible log modes. +const ( + NormalMode LogMode = iota + DebugMode LogMode = iota + SilentMode LogMode = iota + TraceMode LogMode = iota + NothingMode LogMode = iota +) + +// Mode is the current log mode in use. +var Mode LogMode + +// LogStrategy allows to specify a log rotation strategy. +type LogStrategy int + +// Possible log strategies. +const ( + NormalStrategy LogStrategy = iota + DailyStrategy LogStrategy = iota + StdoutStrategy LogStrategy = iota +) + +// Strategy is the current log strattegy used. +var Strategy LogStrategy + +// Enables logging to stdout. +var logToStdout bool + +// Enables logging to file. +var logToFile bool + +// Helper type to make logging non-blocking. +type buf struct { + time time.Time + message string +} + +// Start logging. +func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { + serverEnable = myServerEnable + + mode := logMode(debugEnable, silentEnable, nothingEnable) + strategy := logStrategy() + + stdoutWriter = bufio.NewWriter(os.Stdout) + Mode = mode + Strategy = strategy + + if Mode == NothingMode { + return + } + + switch Strategy { + case DailyStrategy: + _, err := os.Stat(config.Common.LogDir) + logToFile = !os.IsNotExist(err) + logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode + case StdoutStrategy: + fallthrough + default: + logToFile = false + logToStdout = true + } + + fqdn, err := os.Hostname() + if err != nil { + panic(err) + } + s := strings.Split(fqdn, ".") + hostname = s[0] + + pauseCh = make(chan struct{}) + resumeCh = make(chan struct{}) + stop = make(chan struct{}) + stdoutFlushed = make(chan struct{}) + + // Setup logrotation + rotateCh = make(chan os.Signal, 1) + signal.Notify(rotateCh, syscall.SIGHUP) + + if logToStdout { + stdoutBufCh = make(chan string, runtime.NumCPU()*100) + go writeToStdout() + } + + if logToFile { + logBufCh = make(chan buf, runtime.NumCPU()*100) + go writeToFile() + } +} + +func logMode(debugEnable, silentEnable, nothingEnable bool) LogMode { + switch { + case debugEnable: + return DebugMode + case nothingEnable: + return NothingMode + case config.Common.TraceEnable: + return TraceMode + case config.Common.DebugEnable: + return DebugMode + case silentEnable: + return SilentMode + default: + } + return NormalMode +} + +func logStrategy() LogStrategy { + switch config.Common.LogStrategy { + case "daily": + return DailyStrategy + default: + } + return StdoutStrategy +} + +// Info message logging. +func Info(args ...interface{}) string { + if serverEnable { + return log(serverStr, infoStr, args) + } + + return log(clientStr, infoStr, args) +} + +// Warn message logging. +func Warn(args ...interface{}) string { + if serverEnable { + return log(serverStr, warnStr, args) + } + + return log(clientStr, warnStr, args) +} + +// Error message logging. +func Error(args ...interface{}) string { + if serverEnable { + return log(serverStr, errorStr, args) + } + + return log(clientStr, errorStr, args) +} + +// FatalExit logs an error and exists the process. +func FatalExit(args ...interface{}) { + what := clientStr + if serverEnable { + what = serverStr + } + log(what, fatalStr, args) + + time.Sleep(time.Second) + mutex.Lock() + defer mutex.Unlock() + + closeWriter() + os.Exit(3) +} + +// Debug message logging. +func Debug(args ...interface{}) string { + if Mode == DebugMode || Mode == TraceMode { + if serverEnable { + return log(serverStr, debugStr, args) + } + return log(clientStr, debugStr, args) + } + + return "" +} + +// Trace message logging. +func Trace(args ...interface{}) string { + if Mode == TraceMode { + if serverEnable { + return log(serverStr, traceStr, args) + } + return log(clientStr, traceStr, args) + } + + return "" +} + +// Write log line to buffer and/or log file. +func write(what, severity, message string) { + if logToStdout && (Mode != SilentMode || severity != warnStr) { + line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) + + if color.Colored { + line = color.Colorfy(line) + } + + stdoutBufCh <- line + } + + if logToFile { + t := time.Now() + timeStr := t.Format("20060102-150405") + logBufCh <- buf{ + time: t, + message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), + } + } +} + +// Generig log message. +func log(what string, severity string, args []interface{}) string { + if Mode == NothingMode { + return "" + } + + var messages []string + + for _, arg := range args { + switch v := arg.(type) { + case string: + messages = append(messages, v) + case int: + messages = append(messages, fmt.Sprintf("%d", v)) + case error: + messages = append(messages, v.Error()) + default: + messages = append(messages, fmt.Sprintf("%v", v)) + } + } + + message := strings.Join(messages, "|") + write(what, severity, message) + + return fmt.Sprintf("%s|%s", severity, message) +} + +// Raw message logging. +func Raw(message string) { + if Mode == NothingMode { + return + } + + if logToStdout { + if color.Colored { + message = color.Colorfy(message) + } + stdoutBufCh <- message + } + + if logToFile { + logBufCh <- buf{time.Now(), message} + } +} + +// Close log writer (e.g. on change of day). +func closeWriter() { + if writer != nil { + writer.Flush() + fd.Close() + } +} + +// Return the correct log file writer +func fileWriter(dateStr string) *bufio.Writer { + if dateStr != lastDateStr { + return updateFileWriter(dateStr) + } + + // Check for log rotation signal + select { + case <-rotateCh: + stdoutWriter.WriteString("Received signal for logrotation\n") + return updateFileWriter(dateStr) + default: + } + + return writer +} + +// Update log file writer +func updateFileWriter(dateStr string) *bufio.Writer { + // Detected change of day. Close current writer and create a new one. + mutex.Lock() + defer mutex.Unlock() + closeWriter() + + if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { + if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { + panic(err) + } + } + + logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr) + newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) + if err != nil { + panic(err) + } + + fd = newFd + writer = bufio.NewWriterSize(fd, 1) + lastDateStr = dateStr + + return writer +} + +func flushStdout() { + defer close(stdoutFlushed) + + for { + select { + case message := <-stdoutBufCh: + stdoutWriter.WriteString(message) + default: + stdoutWriter.Flush() + return + } + } +} + +func writeToStdout() { + for { + select { + case message := <-stdoutBufCh: + stdoutWriter.WriteString(message) + case <-time.After(time.Millisecond * 100): + stdoutWriter.Flush() + case <-pauseCh: + PAUSE: + for { + select { + case <-stdoutBufCh: + case <-resumeCh: + break PAUSE + case <-stop: + return + } + } + case <-stop: + flushStdout() + return + } + } +} + +func writeToFile() { + for { + select { + case buf := <-logBufCh: + dateStr := buf.time.Format("20060102") + w := fileWriter(dateStr) + w.WriteString(buf.message) + case <-pauseCh: + PAUSE: + for { + select { + case <-stdoutBufCh: + case <-resumeCh: + break PAUSE + case <-stop: + return + } + } + case <-stop: + return + } + } +} + +// Pause logging. +func Pause() { + if logToStdout { + pauseCh <- struct{}{} + } + if logToFile { + pauseCh <- struct{}{} + } +} + +// Resume logging (after pausing). +func Resume() { + if logToStdout { + resumeCh <- struct{}{} + } + if logToFile { + resumeCh <- struct{}{} + } +} + +// Stop logging. +func Stop() { + close(stop) + <-stdoutFlushed +} diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go new file mode 100644 index 0000000..2096c3c --- /dev/null +++ b/internal/mapr/aggregateset.go @@ -0,0 +1,185 @@ +package mapr + +import ( + "fmt" + "strconv" + "strings" +) + +// AggregateSet represents aggregated key/value pairs from the +// MAPREDUCE log lines. These could be either string values or float +// values. +type AggregateSet struct { + Samples int + FValues map[string]float64 + SValues map[string]string +} + +// NewAggregateSet creates a new empty aggregate set. +func NewAggregateSet() *AggregateSet { + return &AggregateSet{ + FValues: make(map[string]float64), + SValues: make(map[string]string), + } +} + +// String representation of aggregate set. +func (s *AggregateSet) String() string { + return fmt.Sprintf("AggregateSet(Samples:%d,FValues:%v,SValues:%v)", + s.Samples, s.FValues, s.SValues) +} + +// Merge one aggregate set into this one. +func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { + s.Samples += set.Samples + //logger.Trace("Merge", set) + + for _, sc := range query.Select { + storage := sc.FieldStorage + switch sc.Operation { + case Count: + fallthrough + case Sum: + fallthrough + case Avg: + value := set.FValues[storage] + s.addFloat(storage, value) + case Min: + value := set.FValues[storage] + s.addFloatMin(storage, value) + case Max: + value := set.FValues[storage] + s.addFloatMax(storage, value) + case Last: + value := set.SValues[storage] + s.setString(storage, value) + case Len: + s.setString(storage, set.SValues[storage]) + s.setFloat(storage, set.FValues[storage]) + default: + return fmt.Errorf("Unknown aggregation method '%v'", sc.Operation) + } + } + return nil +} + +// Serialize the aggregate set so it can be sent over the wire. +func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) { + //logger.Trace("Serialising mapr.AggregateSet", s) + var sb strings.Builder + + sb.WriteString(groupKey) + sb.WriteString("|") + sb.WriteString(fmt.Sprintf("%d|", s.Samples)) + + for k, v := range s.FValues { + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(fmt.Sprintf("%v|", v)) + } + + for k, v := range s.SValues { + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(v) + sb.WriteString("|") + } + + select { + case ch <- sb.String(): + case <-stop: + } +} + +// Add a float value. +func (s *AggregateSet) addFloat(key string, value float64) { + if _, ok := s.FValues[key]; !ok { + s.FValues[key] = value + return + } + s.FValues[key] += value +} + +// Add a float minimum value. +func (s *AggregateSet) addFloatMin(key string, value float64) { + f, ok := s.FValues[key] + if !ok { + s.FValues[key] = value + return + } + + if f > value { + s.FValues[key] = value + } +} + +// Add a float maximum value. +func (s *AggregateSet) addFloatMax(key string, value float64) { + f, ok := s.FValues[key] + if !ok { + s.FValues[key] = value + return + } + + if f < value { + s.FValues[key] = value + } +} + +// Set a string. +func (s *AggregateSet) setString(key, value string) { + s.SValues[key] = value +} + +// Set a float. +func (s *AggregateSet) setFloat(key string, value float64) { + s.FValues[key] = value +} + +// Aggregate data to the aggregate set. +func (s *AggregateSet) Aggregate(key string, agg AggregateOperation, value string, clientAggregation bool) (err error) { + var f float64 + + // First check if we can aggregate anything without converting value to float. + switch agg { + case Count: + if clientAggregation { + f, err = strconv.ParseFloat(value, 64) + if err != nil { + return + } + s.addFloat(key, f) + return + } + s.addFloat(key, 1) + return + case Last: + s.setString(key, value) + return + case Len: + s.setString(key, value) + s.setFloat(key, float64(len(value))) + return + default: + } + + // No, we have to convert to float. + f, err = strconv.ParseFloat(value, 64) + if err != nil { + return + } + + switch agg { + case Sum: + fallthrough + case Avg: + s.addFloat(key, f) + case Min: + s.addFloatMin(key, f) + case Max: + s.addFloatMax(key, f) + default: + err = fmt.Errorf("Unknown aggregation method '%v'", agg) + } + return +} diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go new file mode 100644 index 0000000..3f2b7a5 --- /dev/null +++ b/internal/mapr/client/aggregate.go @@ -0,0 +1,100 @@ +package client + +import ( + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/mapr" + "strconv" + "strings" +) + +// Aggregate mapreduce data on the DTail client side. +type Aggregate struct { + // This is the mapr query specified on the command line. + query *mapr.Query + // This represents aggregated data of a single remote server. + group *mapr.GroupSet + // This represents the merged aggregated data of all servers. + globalGroup *mapr.GlobalGroupSet + stop chan struct{} + // The server we aggregate the data for (logging and debugging purposes only) + server string +} + +// NewAggregate create new client aggregator. +func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *Aggregate { + return &Aggregate{ + query: query, + group: mapr.NewGroupSet(), + globalGroup: globalGroup, + stop: make(chan struct{}), + server: server, + } +} + +// Aggregate data from mapr log line into local (and global) group sets. +func (a *Aggregate) Aggregate(parts []string) { + select { + case <-a.stop: + logger.Error("Client aggregator stopped for server, not processing new data", a.server) + return + default: + } + + groupKey := parts[0] + samples, err := strconv.Atoi(parts[1]) + if err != nil { + logger.FatalExit(parts, err) + } + fields := a.makeFields(parts[2:]) + set := a.group.GetSet(groupKey) + + var addedSamples bool + for _, sc := range a.query.Select { + if val, ok := fields[sc.FieldStorage]; ok { + if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, true); err != nil { + logger.Error(err) + continue + } + addedSamples = true + } + } + if addedSamples { + set.Samples += samples + } + + // Merge data from group into global group. + isMerged, err := a.globalGroup.MergeNoblock(a.query, a.group) + if err != nil { + panic(err) + } + if isMerged { + // Re-init local group (make it empty again). + a.group.InitSet() + } +} + +// Create a map of key-value pairs from a part list such as ["foo=bar", "bar=baz"]. +func (a *Aggregate) makeFields(parts []string) map[string]string { + fields := make(map[string]string, len(parts)) + + for _, part := range parts { + kv := strings.Split(part, "=") + if len(kv) != 2 { + continue + } + fields[kv[0]] = kv[1] + } + + return fields +} + +// Stop the client side mapreduce aggregator. +func (a *Aggregate) Stop() { + logger.Debug("Stopping client mapreduce aggregator") + close(a.stop) + + err := a.globalGroup.Merge(a.query, a.group) + if err != nil { + panic(err) + } +} diff --git a/internal/mapr/globalgroupset.go b/internal/mapr/globalgroupset.go new file mode 100644 index 0000000..cfab506 --- /dev/null +++ b/internal/mapr/globalgroupset.go @@ -0,0 +1,100 @@ +package mapr + +import ( + "fmt" +) + +// GlobalGroupSet is used on the dtail client to merge multiple group sets +// (one group set per remote server) to one single global group set. +type GlobalGroupSet struct { + GroupSet + semaphore chan struct{} +} + +// NewGlobalGroupSet creates a new empty global group set. +func NewGlobalGroupSet() *GlobalGroupSet { + g := GlobalGroupSet{ + semaphore: make(chan struct{}, 1), + } + g.InitSet() + + return &g +} + +// String representation of the global group set. +func (g *GlobalGroupSet) String() string { + return fmt.Sprintf("GlobalGroupSet(%s)", g.GroupSet.String()) +} + +// Merge (blocking) a group set into the global group set. +func (g *GlobalGroupSet) Merge(query *Query, group *GroupSet) error { + g.semaphore <- struct{}{} + defer func() { <-g.semaphore }() + + return g.merge(query, group) +} + +// MergeNoblock merges (non-blocking) a group set into the global group set. +func (g *GlobalGroupSet) MergeNoblock(query *Query, group *GroupSet) (bool, error) { + select { + case g.semaphore <- struct{}{}: + err := g.merge(query, group) + <-g.semaphore + return true, err + default: + return false, nil + } +} + +// Merge a group set into the global group set. +func (g *GlobalGroupSet) merge(query *Query, group *GroupSet) error { + for groupKey, set := range group.sets { + s := g.GetSet(groupKey) + if err := s.Merge(query, set); err != nil { + return err + } + } + + return nil +} + +// IsEmpty determines whether the global group set has any data in it. +func (g *GlobalGroupSet) IsEmpty() bool { + return g.NumSets() == 0 +} + +// NumSets determines the number of sets. +func (g *GlobalGroupSet) NumSets() int { + g.semaphore <- struct{}{} + defer func() { <-g.semaphore }() + + return len(g.sets) +} + +// SwapOut teturn the underlying group set and create a new empty one, so +// that the global group set is empty again and can aggregate new data. +func (g *GlobalGroupSet) SwapOut() *GroupSet { + g.semaphore <- struct{}{} + defer func() { <-g.semaphore }() + + set := &GroupSet{sets: g.sets} + g.InitSet() + + return set +} + +// WriteResult writes the result of a mapreduce aggregation to an outfile. +func (g *GlobalGroupSet) WriteResult(query *Query) error { + g.semaphore <- struct{}{} + defer func() { <-g.semaphore }() + + return g.GroupSet.WriteResult(query) +} + +// Result returns the result of the mapreduce aggregation as a string. +func (g *GlobalGroupSet) Result(query *Query) (string, int, error) { + g.semaphore <- struct{}{} + defer func() { <-g.semaphore }() + + return g.GroupSet.Result(query) +} diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go new file mode 100644 index 0000000..d8f9379 --- /dev/null +++ b/internal/mapr/groupset.go @@ -0,0 +1,178 @@ +package mapr + +import ( + "errors" + "fmt" + "io/ioutil" + "sort" + "strconv" + "strings" +) + +// GroupSet represents a map of aggregate sets. The group sets +// are requierd by the "group by" mapr clause, whereas the +// group set map keys are the values of the "group by" arguments. +// E.g. "group by $cid" would create one aggregate set and one map +// entry per customer id. +type GroupSet struct { + sets map[string]*AggregateSet +} + +// NewGroupSet returns a new empty group set. +func NewGroupSet() *GroupSet { + g := GroupSet{} + g.InitSet() + return &g +} + +// String representation of the group set. +func (g *GroupSet) String() string { + return fmt.Sprintf("GroupSet(%v)", g.sets) +} + +// InitSet makes the group set empty (initialize). +func (g *GroupSet) InitSet() { + g.sets = make(map[string]*AggregateSet) +} + +// GetSet gets a specific aggregate set from the group set. +func (g *GroupSet) GetSet(groupKey string) *AggregateSet { + set, ok := g.sets[groupKey] + if !ok { + set = NewAggregateSet() + g.sets[groupKey] = set + } + return set +} + +// Serialize the group set (e.g. to send it over the wire). +func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) { + for groupKey, set := range g.sets { + set.Serialize(groupKey, ch, stop) + } +} + +// Result returns a nicely formated result of the query from the group set. +func (g *GroupSet) Result(query *Query) (string, int, error) { + return g.limitedResult(query, query.Limit, "\t", " ", false) +} + +// WriteResult writes the result to an outfile. +func (g *GroupSet) WriteResult(query *Query) error { + if query.Outfile == "" { + return errors.New("No outfile specified") + } + + // -1: Don't limit the result, include all data sets + result, _, err := g.limitedResult(query, -1, "", ",", true) + if err != nil { + return err + } + + return ioutil.WriteFile(query.Outfile, []byte(result), 0644) +} + +// Return a nicely formated result of the query from the group set. +func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSeparator string, addHeader bool) (string, int, error) { + type result struct { + groupKey string + resultStr string + orderBy float64 + } + + var resultSlice []result + + for groupKey, set := range g.sets { + var sb strings.Builder + r := result{groupKey: groupKey} + + lastIndex := len(query.Select) - 1 + for i, sc := range query.Select { + storage := sc.FieldStorage + orderByThis := storage == query.OrderBy + + switch sc.Operation { + case Count: + value := set.FValues[storage] + sb.WriteString(fmt.Sprintf("%d", int(value))) + if orderByThis { + r.orderBy = value + } + case Len: + fallthrough + case Sum: + fallthrough + case Min: + fallthrough + case Max: + value := set.FValues[storage] + sb.WriteString(fmt.Sprintf("%f", value)) + if orderByThis { + r.orderBy = value + } + case Last: + value := set.SValues[storage] + if orderByThis { + f, err := strconv.ParseFloat(value, 64) + if err == nil { + r.orderBy = f + } + } + sb.WriteString(value) + case Avg: + value := set.FValues[storage] / float64(set.Samples) + sb.WriteString(fmt.Sprintf("%f", value)) + if orderByThis { + r.orderBy = value + } + default: + return "", 0, fmt.Errorf("Unknown aggregation method '%v'", sc.Operation) + } + if i != lastIndex { + sb.WriteString(fieldSeparator) + } + } + + r.resultStr = sb.String() + resultSlice = append(resultSlice, r) + } + + if query.OrderBy != "" { + if query.ReverseOrder { + sort.SliceStable(resultSlice, func(i, j int) bool { + return resultSlice[i].orderBy < resultSlice[j].orderBy + }) + } else { + sort.SliceStable(resultSlice, func(i, j int) bool { + return resultSlice[i].orderBy > resultSlice[j].orderBy + }) + } + } + + var sb strings.Builder + + // Write header first + if addHeader { + lastIndex := len(query.Select) - 1 + sb.WriteString(lineStarter) + for i, sc := range query.Select { + sb.WriteString(sc.FieldStorage) + if i != lastIndex { + sb.WriteString(fieldSeparator) + } + } + sb.WriteString("\n") + } + + // And now write the data + for i, r := range resultSlice { + if i == limit { + break + } + sb.WriteString(lineStarter) + sb.WriteString(r.resultStr) + sb.WriteString("\n") + } + + return sb.String(), len(resultSlice), nil +} diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go new file mode 100644 index 0000000..f0df5bc --- /dev/null +++ b/internal/mapr/logformat/default.go @@ -0,0 +1,23 @@ +package logformat + +import ( + "errors" + "strings" +) + +// MakeFieldsDEFAULT is the default log file mapreduce parser. +func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) { + fields := make(map[string]string, 20) + splitted := strings.Split(maprLine, "|") + + fields["$hostname"] = p.hostname + + for _, kv := range splitted { + keyAndValue := strings.SplitN(kv, "=", 2) + if len(keyAndValue) != 2 { + return fields, errors.New("Error parsing mapr token: " + kv) + } + fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1] + } + return fields, nil +} diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go new file mode 100644 index 0000000..a3c47fb --- /dev/null +++ b/internal/mapr/logformat/default_test.go @@ -0,0 +1,35 @@ +package logformat + +import ( + "testing" +) + +func TestDefaultLogFormat(t *testing.T) { + parser, err := NewParser("default") + if err != nil { + t.Errorf("Unable to create parser: %s", err.Error()) + } + + fields, err := parser.MakeFields("foo=bar|baz=bay") + + if err != nil { + t.Errorf("Unable to parse: %s", err.Error()) + } + + if bar, ok := fields["foo"]; !ok { + t.Errorf("Expected field 'foo', but no such field there\n") + } else if bar != "bar" { + t.Errorf("Expected 'bar' stored in field 'foo', but got '%s'\n", bar) + } + + if bay, ok := fields["baz"]; !ok { + t.Errorf("Expected field 'baz', but no such field there\n") + } else if bay != "bay" { + t.Errorf("Expected 'bay' stored in field 'baz', but got '%s'\n", bay) + } + + fields, err = parser.MakeFields("foo=bar|bazbay") + if err == nil { + t.Errorf("Expected error but didn't: %s", err.Error()) + } +} diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go new file mode 100644 index 0000000..5730d29 --- /dev/null +++ b/internal/mapr/logformat/parser.go @@ -0,0 +1,75 @@ +package logformat + +import ( + "github.com/mimecast/dtail/internal/logger" + "errors" + "fmt" + "os" + "reflect" + "strings" +) + +// Parser is used to parse the mapreduce information from the server log files. +type Parser struct { + hostname string + logFormatName string + makeFieldsFunc reflect.Value + makeFieldsReceiver reflect.Value +} + +// NewParser returns a new log parser. +func NewParser(logFormatName string) (*Parser, error) { + hostname, err := os.Hostname() + + if err != nil { + return nil, err + } + + p := Parser{ + hostname: hostname, + } + + err = p.reflectLogFormat(logFormatName) + if err != nil { + return nil, err + } + + return &p, nil +} + +// The aim of this is that everyone can plug in their own mapr log format +// parsing method to DTail. Just add a method MakeFieldsMODULENAME to type +// Parser. Whereas MODULENAME must be a upeprcase string. +func (p *Parser) reflectLogFormat(logFormatName string) error { + methodName := fmt.Sprintf("MakeFields%s", strings.ToUpper(logFormatName)) + + rt := reflect.TypeOf(p) + method, ok := rt.MethodByName(methodName) + if !ok { + return errors.New("No such mapr log format module: " + methodName) + } + + p.makeFieldsFunc = method.Func + p.makeFieldsReceiver = reflect.ValueOf(p) + + return nil +} + +// MakeFields is for returning the fields from a given log line. +func (p *Parser) MakeFields(maprLine string) (fields map[string]string, err error) { + inputValues := []reflect.Value{p.makeFieldsReceiver, reflect.ValueOf(maprLine)} + returnValues := p.makeFieldsFunc.Call(inputValues) + + errInterface := returnValues[1].Interface() + + if errInterface == nil { + fields, err = returnValues[0].Interface().(map[string]string), nil + logger.Trace("parser.MakeFields", fields, err) + return + } + + fields, err = returnValues[0].Interface().(map[string]string), errInterface.(error) + logger.Trace("parser.MakeFields", fields, err) + + return +} diff --git a/internal/mapr/query.go b/internal/mapr/query.go new file mode 100644 index 0000000..3805d15 --- /dev/null +++ b/internal/mapr/query.go @@ -0,0 +1,245 @@ +package mapr + +import ( + "github.com/mimecast/dtail/internal/logger" + "errors" + "fmt" + "strconv" + "strings" + "time" +) + +const ( + invalidQuery string = "Invalid query: " + unexpectedEnd string = "Unexpected end of query" +) + +// Query represents a parsed mapr query. +type Query struct { + Select []selectCondition + Table string + Where []whereCondition + GroupBy []string + OrderBy string + ReverseOrder bool + GroupKey string + Interval time.Duration + Limit int + Outfile string + RawQuery string + tokens []token +} + +func (q Query) String() string { + return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,GroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v)", + q.Select, + q.Table, + q.Where, + q.GroupBy, + q.GroupKey, + q.OrderBy, + q.ReverseOrder, + q.Interval, + q.Limit, + q.Outfile, + q.RawQuery, + q.tokens) +} + +// NewQuery returns a new mapreduce query. +func NewQuery(queryStr string) (*Query, error) { + if queryStr == "" { + return nil, nil + } + + tokens := tokenize(queryStr) + + q := Query{ + RawQuery: queryStr, + tokens: tokens, + Interval: time.Second * 5, + Limit: -1, + } + + err := q.parse(tokens) + + logger.Debug(q) + return &q, err +} + +func (q *Query) parse(tokens []token) error { + var found []token + var err error + + for tokens != nil && len(tokens) > 0 { + switch strings.ToLower(tokens[0].str) { + case "select": + tokens, found = tokensConsume(tokens[1:]) + q.Select, err = makeSelectConditions(found) + if err != nil { + return err + } + case "from": + tokens, found = tokensConsume(tokens[1:]) + if len(found) > 0 { + q.Table = strings.ToUpper(found[0].str) + } + case "where": + tokens, found = tokensConsume(tokens[1:]) + if q.Where, err = makeWhereConditions(found); err != nil { + return err + } + case "group": + tokens = tokensConsumeOptional(tokens[1:], "by") + if tokens == nil || len(tokens) < 1 { + return errors.New(invalidQuery + unexpectedEnd) + } + tokens, q.GroupBy = tokensConsumeStr(tokens) + q.GroupKey = strings.Join(q.GroupBy, ",") + case "rorder": + tokens = tokensConsumeOptional(tokens[1:], "by") + if tokens == nil || len(tokens) < 1 { + return errors.New(invalidQuery + unexpectedEnd) + } + tokens, found = tokensConsume(tokens) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + q.OrderBy = found[0].str + q.ReverseOrder = true + case "order": + tokens = tokensConsumeOptional(tokens[1:], "by") + if tokens == nil || len(tokens) < 1 { + return errors.New(invalidQuery + unexpectedEnd) + } + tokens, found = tokensConsume(tokens) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + q.OrderBy = found[0].str + case "interval": + tokens, found = tokensConsume(tokens[1:]) + if len(found) > 0 { + i, err := strconv.Atoi(found[0].str) + if err != nil { + return errors.New(invalidQuery + err.Error()) + } + q.Interval = time.Second * time.Duration(i) + } + case "limit": + tokens, found = tokensConsume(tokens[1:]) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + i, err := strconv.Atoi(found[0].str) + if err != nil { + return errors.New(invalidQuery + err.Error()) + } + q.Limit = i + case "outfile": + tokens, found = tokensConsume(tokens[1:]) + if len(found) == 0 { + return errors.New(invalidQuery + unexpectedEnd) + } + q.Outfile = found[0].str + default: + return errors.New(invalidQuery + "Unexpected keyword " + tokens[0].str) + } + } + + if q.Table == "" { + return errors.New(invalidQuery + "Empty table specified in 'from' clause") + } + if len(q.Select) < 1 { + return errors.New(invalidQuery + "Expected at least one field in 'select' clause but got none") + } + if len(q.GroupBy) == 0 { + field := q.Select[0].Field + q.GroupBy = append(q.GroupBy, field) + } + + if q.OrderBy != "" { + var orderFieldIsValid bool + for _, sc := range q.Select { + if q.OrderBy == sc.FieldStorage { + orderFieldIsValid = true + break + } + } + if !orderFieldIsValid { + return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s', must be present in 'select' clause", q.OrderBy)) + } + } + + return nil +} + +// WhereClause interprets the where clause of the mapreduce query. +func (q *Query) WhereClause(fields map[string]string) bool { + floatValue := func(str string, float float64, t whereType) (float64, bool) { + switch t { + case Float: + return float, true + case Field: + value, ok := fields[str] + if !ok { + return 0, false + } + f, err := strconv.ParseFloat(value, 64) + if err != nil { + return 0, false + } + return f, true + default: + logger.Error("Unexpected argument in 'where' clause", str, float, t) + return 0, false + } + } + + stringValue := func(str string, t whereType) (string, bool) { + switch t { + case Field: + value, ok := fields[str] + if !ok { + return str, false + } + return value, true + case String: + return str, true + default: + logger.Error("Unexpected argument in 'where' clause", str, t) + return str, false + } + } + + for _, wc := range q.Where { + var ok bool + + if wc.Operation > FloatOperation { + var lValue, rValue float64 + if lValue, ok = floatValue(wc.lString, wc.lFloat, wc.lType); !ok { + return false + } + if rValue, ok = floatValue(wc.rString, wc.rFloat, wc.rType); !ok { + return false + } + if ok = wc.floatClause(lValue, rValue); !ok { + return false + } + continue + } + + var lValue, rValue string + if lValue, ok = stringValue(wc.lString, wc.lType); !ok { + return false + } + if rValue, ok = stringValue(wc.rString, wc.rType); !ok { + return false + } + if ok = wc.stringClause(lValue, rValue); !ok { + return false + } + } + + return true +} diff --git a/internal/mapr/query_test.go b/internal/mapr/query_test.go new file mode 100644 index 0000000..6176461 --- /dev/null +++ b/internal/mapr/query_test.go @@ -0,0 +1,149 @@ +package mapr + +import ( + "testing" + "time" +) + +func TestParseQuerySimple(t *testing.T) { + errorQueries := []string{ + "select", + "select foo", + "select foo from", + "select foo from bar where baz", + "select foo from bar where baz <", + "select foo from bar where baz < 100 bay eq 12 group", + "select foo from bar where baz < 100 bay eq 12 group by foo order by", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit", + } + okQueries := []string{"select foo from bar", + "select foo from bar where", + "select foo from bar where baz < 100 bay eq 12", + "select foo from bar where baz < 100, bay eq 12", + "select foo from bar where baz < 100 and bay eq 12", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"", + } + + for _, queryStr := range errorQueries { + q, err := NewQuery(queryStr) + if err == nil { + t.Errorf("Expected a parse error: %s\n%v", queryStr, q) + continue + } + } + + for _, queryStr := range okQueries { + _, err := NewQuery(queryStr) + if err != nil { + t.Errorf("%s: %s", err.Error(), queryStr) + continue + } + } +} + +func TestParseQueryDeep(t *testing.T) { + dialects := []string{ + "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23", + "SELECT s1, `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23", + "select s1, `from` count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23", + "sElEct s1, `from` coUnt(s3) from taBle where w1 == 2 aNd w2 eq \"free beer\" Group By g1, g2 order bY count(s3) intervaL 10 LiMiT 23", + "SELECT s1 `from` COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP BY g1 g2 ORDER BY count(s3) INTERVAL 10 LIMIT 23", + "select s1 `from` count(s3) from table where w1 == 2 w2 eq \"free beer\" group g1 g2 order count(s3) interval 10 limit 23", + "limit 23 interval 10 order count(s3) group g1 g2 where w1 == 2 w2 eq \"free beer\" from table select s1 `from` count(s3)", + } + + for _, queryStr := range dialects { + q, err := NewQuery(queryStr) + if err != nil { + t.Errorf("%s: %s", err.Error(), queryStr) + } + + // 'select' clause + if len(q.Select) != 3 { + t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q) + } + + if q.Select[0].Field != "s1" { + t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Field, queryStr, q) + } + if q.Select[0].Operation != Last { + t.Errorf("Expected 'last' as aggregation function of first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q) + } + + if q.Select[1].Field != "from" { + t.Errorf("Expected 'from' as second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Field, queryStr, q) + } + if q.Select[1].Operation != Last { + t.Errorf("Expected 'last' as aggregation function of second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q) + } + + if q.Select[2].Field != "s3" { + t.Errorf("Expected 's3' as third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Field, queryStr, q) + } + if q.Select[2].Operation != Count { + t.Errorf("Expected 'count' as aggregation function of third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q) + } + if q.Select[2].FieldStorage != "count(s3)" { + t.Errorf("Expected 'count(s3)' as third element's storage in 'select' clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q) + } + + // 'from' clause + if q.Table != "TABLE" { + t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v", q.Table, queryStr, q) + } + + // 'where' clause + if len(q.Where) != 2 { + t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v", q.Where, queryStr, q) + } + if q.Where[0].lString != "w1" { + t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v", q.Where[0].lString, queryStr, q) + } + if q.Where[0].Operation != FloatEq { + t.Errorf("Expected FloatEq operation in first 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q) + } + if q.Where[0].rFloat != 2 { + t.Errorf("Expected '2' as float argument in first 'where' condition but got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q) + } + if q.Where[1].lString != "w2" { + t.Errorf("Expected w2 as second element in 'where' clause but got '%v': %s\n%v", q.Where[1].lString, queryStr, q) + } + if q.Where[1].Operation != StringEq { + t.Errorf("Expected StringEq operation in second 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q) + } + if q.Where[1].rString != "free beer" { + t.Errorf("Expected 'free beer' as string argument in second 'where' condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q) + } + + // 'group by' clause + if len(q.GroupBy) != 2 { + t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v", q.GroupBy, queryStr, q) + } + if q.GroupBy[0] != "g1" { + t.Errorf("Expected 'g1' as first element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[0], queryStr, q) + } + if q.GroupBy[1] != "g2" { + t.Errorf("Expected 'g2' as second element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[1], queryStr, q) + } + if q.GroupKey != "g1,g2" { + t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got '%v': %s\n%v", q.GroupKey, queryStr, q) + } + + // 'order by' clause + if q.OrderBy != "count(s3)" { + t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got '%v': %s\n%v", q.OrderBy, queryStr, q) + } + + // 'interval' clause + if q.Interval != time.Second*time.Duration(10) { + t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v", q.Interval, queryStr, q) + } + + // 'limit' clause + if q.Limit != 23 { + t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q) + } + } +} diff --git a/internal/mapr/selectcondition.go b/internal/mapr/selectcondition.go new file mode 100644 index 0000000..1882b7e --- /dev/null +++ b/internal/mapr/selectcondition.go @@ -0,0 +1,96 @@ +package mapr + +import ( + "errors" + "fmt" + "strings" +) + +// AggregateOperation is to specify the aggregate operation type. +type AggregateOperation int + +// Aggregate operation types +const ( + UndefAggregateOperation AggregateOperation = iota + Count AggregateOperation = iota + Sum AggregateOperation = iota + Min AggregateOperation = iota + Max AggregateOperation = iota + Last AggregateOperation = iota + Avg AggregateOperation = iota + Len AggregateOperation = iota +) + +// Represents a parsed "select" clause, used by mapr.Query. +type selectCondition struct { + Field string + FieldStorage string + Operation AggregateOperation +} + +func (sc selectCondition) String() string { + return fmt.Sprintf("selectCondition(Field:%s,FieldStorage:%s,Operation:%v)", + sc.Field, + sc.FieldStorage, + sc.Operation) +} + +func makeSelectConditions(tokens []token) ([]selectCondition, error) { + var sel []selectCondition + + // Parse select aggregation, e.g. sum(foo) + parse := func(token token) (selectCondition, error) { + var sc selectCondition + tokenStr := strings.ToLower(token.str) + + if !strings.Contains(tokenStr, "(") && !strings.Contains(tokenStr, ")") { + sc.Field = tokenStr + sc.FieldStorage = tokenStr + sc.Operation = Last + return sc, nil + } + + a := strings.Split(tokenStr, "(") + if len(a) != 2 { + return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " + token.str) + } + agg := a[0] // Aggregation, e.g. 'sum' + + b := strings.Split(a[1], ")") + if len(b) != 2 { + return sc, errors.New(invalidQuery + "Can't parse 'select' field name from aggregation: " + token.str) + } + sc.Field = b[0] // Field name, e.g. 'foo' + sc.FieldStorage = tokenStr // e.g. 'sum(foo)' + + switch agg { + case "count": + sc.Operation = Count + case "sum": + sc.Operation = Sum + case "min": + sc.Operation = Min + case "max": + sc.Operation = Max + case "last": + sc.Operation = Last + case "avg": + sc.Operation = Avg + case "len": + sc.Operation = Len + default: + return sc, errors.New(invalidQuery + "Unknown aggregation in 'select' clause: " + agg) + } + + return sc, nil + } + + for _, token := range tokens { + sc, err := parse(token) + if err != nil { + return nil, err + } + sel = append(sel, sc) + } + return sel, nil +} diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go new file mode 100644 index 0000000..900756e --- /dev/null +++ b/internal/mapr/server/aggregate.go @@ -0,0 +1,170 @@ +package server + +import ( + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/fs" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/mapr/logformat" + "os" + "strings" + "time" +) + +// Aggregate is for aggregating mapreduce data on the DTail server side. +type Aggregate struct { + // Log lines to process (parsing MAPREDUCE lines). + Lines chan fs.LineRead + // Hostname of the current server (used to populate $hostname field). + hostname string + // Signals to exit goroutine. + stop chan struct{} + // Signals to serialize data. + serialize chan struct{} + // The mapr query + query *mapr.Query + // The mapr log format parser + parser *logformat.Parser +} + +// NewAggregate return a new server side aggregator. +func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) { + query, err := mapr.NewQuery(queryStr) + if err != nil { + return nil, err + } + + fqdn, err := os.Hostname() + if err != nil { + logger.Error(err) + } + s := strings.Split(fqdn, ".") + + logger.Info("Creating mapr log format parser", config.Server.MapreduceLogFormat) + logParser, err := logformat.NewParser(config.Server.MapreduceLogFormat) + if err != nil { + logger.FatalExit("Could not create mapr log format parser", err) + } + + a := Aggregate{ + Lines: make(chan fs.LineRead, 100), + stop: make(chan struct{}), + serialize: make(chan struct{}), + hostname: s[0], + query: query, + parser: logParser, + } + + go a.periodicAggregateTimer() + + fieldsCh := make(chan map[string]string) + go a.readFields(fieldsCh, maprLines) + go a.readLines(fieldsCh) + + return &a, nil +} + +func (a *Aggregate) periodicAggregateTimer() { + for { + select { + case <-time.After(a.query.Interval): + a.Serialize() + case <-a.stop: + return + } + } +} + +func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) { + group := mapr.NewGroupSet() + + for { + select { + case fields := <-fieldsCh: + a.aggregate(group, fields) + case <-a.serialize: + logger.Info("Serializing mapreduce result") + group.Serialize(maprLines, a.stop) + logger.Info("Done serializing mapreduce result") + group = mapr.NewGroupSet() + case <-a.stop: + return + } + } +} + +func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) { + for { + select { + case line, ok := <-a.Lines: + if !ok { + return + } + + maprLine := strings.TrimSpace(string(line.Content)) + fields, err := a.parser.MakeFields(maprLine) + + if err != nil { + logger.Error(err) + continue + } + if !a.query.WhereClause(fields) { + continue + } + + select { + case fieldsCh <- fields: + case <-a.stop: + } + case <-a.stop: + return + } + } +} + +func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { + //logger.Trace("Aggregating", group, fields) + var sb strings.Builder + + for i, field := range a.query.GroupBy { + if i > 0 { + sb.WriteString(" ") + } + if val, ok := fields[field]; ok { + sb.WriteString(val) + } + } + groupKey := sb.String() + set := group.GetSet(groupKey) + + var addedSample bool + for _, sc := range a.query.Select { + if val, ok := fields[sc.Field]; ok { + if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil { + logger.Error(err) + continue + } + addedSample = true + } + } + + if addedSample { + set.Samples++ + return + } + + logger.Trace("Aggregated data locally without adding new samples") +} + +// Serialize all the aggregated data. +func (a *Aggregate) Serialize() { + select { + case a.serialize <- struct{}{}: + case <-a.stop: + } +} + +// Close the aggregator. +func (a *Aggregate) Close() { + close(a.stop) +} diff --git a/internal/mapr/token.go b/internal/mapr/token.go new file mode 100644 index 0000000..b8be4da --- /dev/null +++ b/internal/mapr/token.go @@ -0,0 +1,108 @@ +package mapr + +import ( + "strings" +) + +var keywords = [...]string{"select", "from", "where", "group", "rorder", "order", "interval", "limit", "outfile"} + +// Represents a parsed token, used to parse the mapr query. +type token struct { + str string + isBareword bool +} + +func (t token) isKeyword() bool { + if !t.isBareword { + return false + } + + for _, keyword := range keywords { + if strings.ToLower(t.str) == keyword { + return true + } + } + return false +} + +func (t token) String() string { + return t.str +} + +func tokenize(queryStr string) []token { + var tokens []token + + for i, part := range strings.Split(queryStr, "\"") { + // Even i, means that it is not a quoted string + if i%2 == 0 { + commasStripped := strings.Replace(part, ",", " ", -1) + for _, tokenStr := range strings.Fields(commasStripped) { + token := token{ + str: tokenStr, + isBareword: true, + } + tokens = append(tokens, token) + } + continue + } + // Add whole quoted string as a token + token := token{ + str: part, + isBareword: false, + } + tokens = append(tokens, token) + } + + return tokens +} + +func tokensConsume(tokens []token) ([]token, []token) { + //logger.Trace("=====================") + var consumed []token + + for i, t := range tokens { + if t.isKeyword() { + //logger.Trace("keyword", t) + return tokens[i:], consumed + } + // strip escapes, such as ` from `foo`, this allows to use keywords as field names + length := len(t.str) + if length == 0 { + continue + } + if t.str[0] == '`' && t.str[length-1] == '`' { + stripped := t.str[1 : length-1] + //logger.Trace("stripped", stripped) + t := token{ + str: stripped, + isBareword: t.isBareword, + } + consumed = append(consumed, t) + continue + } + //logger.Trace("bare", token) + consumed = append(consumed, t) + } + + //logger.Trace("result", consumed) + return nil, consumed +} + +func tokensConsumeStr(tokens []token) ([]token, []string) { + var strings []string + tokens, found := tokensConsume(tokens) + for _, token := range found { + strings = append(strings, token.str) + } + return tokens, strings +} + +func tokensConsumeOptional(tokens []token, optional string) []token { + if len(tokens) < 1 { + return tokens + } + if strings.ToLower(tokens[0].str) == strings.ToLower(optional) { + return tokens[1:] + } + return tokens +} diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go new file mode 100644 index 0000000..e1f4e5b --- /dev/null +++ b/internal/mapr/wherecondition.go @@ -0,0 +1,193 @@ +package mapr + +import ( + "github.com/mimecast/dtail/internal/logger" + "errors" + "fmt" + "strconv" + "strings" +) + +// QueryOperation determines the mapreduce operation. +type QueryOperation int + +// The possible mapreduce operation.s +const ( + UndefQueryOperation QueryOperation = iota + StringEq QueryOperation = iota + StringNe QueryOperation = iota + StringContains QueryOperation = iota + FloatOperation QueryOperation = iota + FloatEq QueryOperation = iota + FloatNe QueryOperation = iota + FloatLt QueryOperation = iota + FloatLe QueryOperation = iota + FloatGt QueryOperation = iota + FloatGe QueryOperation = iota +) + +type whereType int + +// The possible field types. +const ( + UndefWhereType whereType = iota + Field whereType = iota + String whereType = iota + Float whereType = iota +) + +func (w whereType) String() string { + switch w { + case Field: + return fmt.Sprintf("Field") + case String: + return fmt.Sprintf("String") + case Float: + return fmt.Sprintf("Float") + default: + return fmt.Sprintf("UndefWhereType") + } +} + +// Represent a parsed "where" clause, used by mapr.Query +type whereCondition struct { + lString string + lFloat float64 + lType whereType + + Operation QueryOperation + + rString string + rFloat float64 + rType whereType +} + +func (wc *whereCondition) String() string { + return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,lType:%s,rString:%s,rFloat:%v,rType:%s)", + wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString, wc.rFloat, wc.rType.String()) +} + +func makeWhereConditions(tokens []token) (where []whereCondition, err error) { + parse := func(tokens []token) (whereCondition, []token, error) { + var wc whereCondition + if len(tokens) < 3 { + return wc, nil, errors.New(invalidQuery + "Not enough arguments in 'where' clause") + } + + whereOp := strings.ToLower(tokens[1].str) + switch whereOp { + case "==": + wc.Operation = FloatEq + case "!=": + wc.Operation = FloatNe + case "<": + wc.Operation = FloatLt + case "<=": + wc.Operation = FloatLe + case "=<": + wc.Operation = FloatLe + case ">": + wc.Operation = FloatGt + case ">=": + wc.Operation = FloatGe + case "=>": + wc.Operation = FloatGe + case "eq": + wc.Operation = StringEq + case "ne": + wc.Operation = StringNe + case "contains": + wc.Operation = StringContains + default: + return wc, nil, errors.New(invalidQuery + "Unknown operation in 'where' clause: " + whereOp) + } + + wc.lString = tokens[0].str + wc.rString = tokens[2].str + + if wc.Operation > FloatOperation { + if !tokens[0].isBareword { + return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's lValue: " + tokens[0].str) + } + if f, err := strconv.ParseFloat(wc.lString, 64); err == nil { + wc.lFloat = f + wc.lType = Float + } else { + wc.lType = Field + } + + if !tokens[2].isBareword { + return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's rValue: " + tokens[2].str) + } + if f, err := strconv.ParseFloat(wc.rString, 64); err == nil { + wc.rFloat = f + wc.rType = Float + } else { + wc.rType = Field + } + return wc, tokens[3:], nil + } + + if tokens[0].isBareword { + wc.lType = Field + } else { + wc.lType = String + } + if tokens[2].isBareword { + wc.rType = Field + } else { + wc.rType = String + } + + return wc, tokens[3:], nil + } + + for len(tokens) > 0 { + var wc whereCondition + var err error + + wc, tokens, err = parse(tokens) + if err != nil { + return nil, err + } + + where = append(where, wc) + tokens = tokensConsumeOptional(tokens, "and") + } + + return +} + +func (wc *whereCondition) floatClause(lValue float64, rValue float64) bool { + switch wc.Operation { + case FloatEq: + return lValue == rValue + case FloatNe: + return lValue != rValue + case FloatLt: + return lValue < rValue + case FloatLe: + return lValue <= rValue + case FloatGt: + return lValue > rValue + case FloatGe: + return lValue >= rValue + default: + logger.Error("Unknown float operation", lValue, wc.Operation, rValue) + } + return false +} + +func (wc *whereCondition) stringClause(lValue string, rValue string) bool { + switch wc.Operation { + case StringEq: + return lValue == rValue + case StringNe: + return lValue != rValue + case StringContains: + return strings.Contains(lValue, rValue) + default: + logger.Error("Unknown string operation", lValue, wc.Operation, rValue) + } + return false +} diff --git a/internal/omode/mode.go b/internal/omode/mode.go new file mode 100644 index 0000000..4bdfc45 --- /dev/null +++ b/internal/omode/mode.go @@ -0,0 +1,81 @@ +package omode + +import ( + "fmt" + "os" + "path" +) + +// Mode used. +type Mode int + +// Possible modes. +const ( + Unknown Mode = iota + Server Mode = iota + TailClient Mode = iota + CatClient Mode = iota + GrepClient Mode = iota + MapClient Mode = iota + HealthClient Mode = iota +) + +// New returns the mode based on the mode string. +func New(modeStr string) Mode { + switch modeStr { + case "dserver": + return Server + case "server": + return Server + + case "dtail": + fallthrough + case "tail": + return TailClient + + case "grep": + fallthrough + case "dgrep": + return GrepClient + + case "cat": + fallthrough + case "dcat": + return CatClient + + case "map": + fallthrough + case "dmap": + return MapClient + + case "health": + return HealthClient + + default: + panic(fmt.Sprintf("Unknown mode: '%s'", modeStr)) + } +} + +// Default mode. +func Default() Mode { + return New(path.Base(os.Args[0])) +} + +func (m Mode) String() string { + switch m { + case Server: + return "server" + case TailClient: + return "tail" + case CatClient: + return "cat" + case GrepClient: + return "grep" + case MapClient: + return "map" + case HealthClient: + return "health" + default: + return "unknown" + } +} diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go new file mode 100644 index 0000000..f78bcf6 --- /dev/null +++ b/internal/pprof/pprof.go @@ -0,0 +1,17 @@ +package pprof + +import ( + "fmt" + "net/http" + _ "net/http" + _ "net/http/pprof" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" +) + +func Start() { + bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort) + logger.Info("Starting PProf server", bindAddr) + go http.ListenAndServe(bindAddr, nil) +} diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go new file mode 100644 index 0000000..76a2726 --- /dev/null +++ b/internal/prompt/prompt.go @@ -0,0 +1,95 @@ +package prompt + +import ( + "bufio" + "github.com/mimecast/dtail/internal/logger" + "fmt" + "os" + "strings" +) + +// Answer is a user input of a prompt question. +type Answer struct { + // Long version of the expected user input + Long string + // Short version of the expected user input + Short string + // Runs when user input matches + Callback func() + // Runs after Callback and after logging resumes + EndCallback func() + + AskAgain bool +} + +// Prompt used for interactive user input. +type Prompt struct { + question string + answers []Answer +} + +func (p *Prompt) askString() string { + var sb strings.Builder + + sb.WriteString(p.question) + sb.WriteString("? (") + + var ax []string + for _, a := range p.answers { + ax = append(ax, fmt.Sprintf("%s=%s", a.Short, a.Long)) + } + + sb.WriteString(strings.Join(ax, ",")) + sb.WriteString("): ") + + return sb.String() +} + +// New returns a new prompt. +func New(question string) *Prompt { + return &Prompt{question: question} +} + +// Add an answer. +func (p *Prompt) Add(answer Answer) { + p.answers = append(p.answers, answer) +} + +// Ask a question. +func (p *Prompt) Ask() { + reader := bufio.NewReader(os.Stdin) + logger.Pause() + + for { + fmt.Print(p.askString()) + answerStr, _ := reader.ReadString('\n') + + if a, ok := p.answer(strings.TrimSpace(answerStr)); ok { + if a.Callback != nil { + a.Callback() + } + + if !a.AskAgain { + logger.Resume() + if a.EndCallback != nil { + a.EndCallback() + } + return + } + } + } +} + +func (p *Prompt) answer(answerStr string) (*Answer, bool) { + for _, a := range p.answers { + switch answerStr { + case a.Long: + return &a, true + case a.Short: + return &a, true + default: + } + } + + return nil, false +} diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go new file mode 100644 index 0000000..482f759 --- /dev/null +++ b/internal/server/handlers/controlhandler.go @@ -0,0 +1,106 @@ +package handlers + +import ( + "fmt" + "io" + "os" + "strings" + + "github.com/mimecast/dtail/internal/logger" + user "github.com/mimecast/dtail/internal/user/server" +) + +// ControlHandler is used for control functions and health monitoring. +type ControlHandler struct { + serverMessages chan string + pong chan struct{} + stop chan struct{} + payload []byte + hostname string + user *user.User +} + +// NewControlHandler returns a new control handler. +func NewControlHandler(user *user.User) *ControlHandler { + logger.Debug(user, "Creating control handler") + + h := ControlHandler{ + serverMessages: make(chan string, 10), + pong: make(chan struct{}, 10), + stop: make(chan struct{}), + user: user, + } + + fqdn, err := os.Hostname() + if err != nil { + logger.FatalExit(err) + } + + s := strings.Split(fqdn, ".") + h.hostname = s[0] + return &h +} + +// Read is to send data to the client via the Reader interface. +func (h *ControlHandler) Read(p []byte) (n int, err error) { + for { + select { + case message := <-h.serverMessages: + wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) + n = copy(p, wholePayload) + return + case <-h.pong: + logger.Info(h.user, "Sending pong") + n = copy(p, []byte(".pong\n")) + return + case <-h.stop: + return 0, io.EOF + } + } +} + +// Write is to read data to the client via the Writer interface. +func (h *ControlHandler) Write(p []byte) (n int, err error) { + for _, c := range p { + switch c { + case ';': + wholePayload := strings.TrimSpace(string(h.payload)) + h.handleCommand(wholePayload) + h.payload = nil + + default: + h.payload = append(h.payload, c) + } + } + + n = len(p) + return +} + +// Close the control handler. +func (h *ControlHandler) Close() { + close(h.stop) +} + +// Wait returns the handler stop channel. +func (h *ControlHandler) Wait() <-chan struct{} { + return h.stop +} + +func (h *ControlHandler) handleCommand(command string) { + logger.Info(h.user, command) + s := strings.Split(command, " ") + logger.Debug(h.user, "Receiving command", command, s) + + switch s[0] { + case "health": + h.serverMessages <- "OK: DTail SSH Server seems fine" + h.serverMessages <- "done;" + case "ping": + h.pong <- struct{}{} + case "debug": + h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s) + default: + h.serverMessages <- logger.Warn(h.user, "Received unknown command", command, s) + } +} diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go new file mode 100644 index 0000000..8b1f73e --- /dev/null +++ b/internal/server/handlers/handler.go @@ -0,0 +1,10 @@ +package handlers + +import "io" + +// Handler interface for server side functionality. +type Handler interface { + io.ReadWriter + Close() + Wait() <-chan struct{} +} diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go new file mode 100644 index 0000000..bed8609 --- /dev/null +++ b/internal/server/handlers/serverhandler.go @@ -0,0 +1,492 @@ +package handlers + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/fs" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/mapr/server" + "github.com/mimecast/dtail/internal/omode" + user "github.com/mimecast/dtail/internal/user/server" + "github.com/mimecast/dtail/internal/version" +) + +const ( + commandParseWarning string = "Unable to parse command" +) + +// ServerHandler implements the Reader and Writer interfaces to handle +// the Bi-directional communication between SSH client and server. +// This handler implements the handler of the SSH server. +type ServerHandler struct { + // Local log file readers + fileReaders []fs.FileReader + fileReadersMtx *sync.Mutex + // Channel for read lines. + lines chan fs.LineRead + // Only process log lines matching this regex. + regex string + // Server side mapr log aggregation. + aggregate *server.Aggregate + // Channel of aggregated log lines. + aggregatedMessages chan string + // Channel for server messages to be sent to the client. + serverMessages chan string + // Channel for hidden messages to be sent to the client. + hiddenMessages chan string + // The current payload sent to the client. + payload []byte + // The current server hostname. + hostname string + // The user connecting to dtail. + user *user.User + // To limit the server wide max amount of concurrent cats + catLimiter chan struct{} + // To limit the server wide max amount of concurrent tails + tailLimiter chan struct{} + // Server can tell handler to stop the handler. + stop chan struct{} + // Indicate that client responded to server with "ack stop connection" + ackStopReceived chan struct{} + // Stop timeout. + stopTimeout chan struct{} +} + +// NewServerHandler returns the server handler. +func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler { + logger.Debug(user, "Creating tail handler") + h := ServerHandler{ + fileReadersMtx: &sync.Mutex{}, + lines: make(chan fs.LineRead, 100), + serverMessages: make(chan string, 10), + aggregatedMessages: make(chan string, 10), + hiddenMessages: make(chan string, 10), + ackStopReceived: make(chan struct{}), + stopTimeout: make(chan struct{}), + stop: make(chan struct{}), + catLimiter: catLimiter, + tailLimiter: tailLimiter, + regex: ".", + user: user, + } + + fqdn, err := os.Hostname() + if err != nil { + logger.FatalExit(err) + } + + s := strings.Split(fqdn, ".") + h.hostname = s[0] + + return &h +} + +// Read is to send data to the dtail client via Reader interface. +func (h *ServerHandler) Read(p []byte) (n int, err error) { + for { + select { + case message := <-h.serverMessages: + wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) + n = copy(p, wholePayload) + return + case message := <-h.aggregatedMessages: + data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message) + //logger.Debug("Sending aggregation data", data) + wholePayload := []byte(data) + n = copy(p, wholePayload) + return + case message := <-h.hiddenMessages: + //logger.Debug(h.user, "Sending hidden message", message) + wholePayload := []byte(fmt.Sprintf(".%s\n", message)) + n = copy(p, wholePayload) + return + case line := <-h.lines: + serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", + h.hostname, line.TransmittedPerc, line.Count, *line.GlobID)) + wholePayload := append(serverInfo, line.Content[:]...) + n = copy(p, wholePayload) + return + case <-time.After(time.Second): + select { + case <-h.stop: + return 0, io.EOF + default: + } + } + } +} + +// Write is to receive data from the dtail client via Writer interface. +func (h *ServerHandler) Write(p []byte) (n int, err error) { + for _, c := range p { + switch c { + case ';': + commandStr := strings.TrimSpace(string(h.payload)) + h.handleCommand(commandStr) + h.payload = nil + default: + h.payload = append(h.payload, c) + } + } + + n = len(p) + return +} + +// Close the server handler. +func (h *ServerHandler) Close() { + h.fileReadersMtx.Lock() + defer h.fileReadersMtx.Unlock() + + for _, reader := range h.fileReaders { + reader.Stop() + } + if h.aggregate != nil { + h.aggregate.Close() + } + + close(h.stop) +} + +func (h *ServerHandler) makeGlobID(path, glob string) string { + var idParts []string + pathParts := strings.Split(path, "/") + + for i, globPart := range strings.Split(glob, "/") { + if strings.Contains(globPart, "*") { + idParts = append(idParts, pathParts[i]) + } + } + + if len(idParts) > 0 { + return strings.Join(idParts, "/") + } + + if len(pathParts) > 0 { + return pathParts[len(pathParts)-1] + } + + h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob)) + return "" +} + +func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) { + retryInterval := time.Second * 5 + glob = filepath.Clean(glob) + + errors := make(chan struct{}) + stop := make(chan struct{}) + defer close(stop) + + go func() { + for { + select { + case <-errors: + h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs")) + case <-stop: + return + case <-h.stop: + return + } + } + }() + + maxRetries := 10 + for { + maxRetries-- + if maxRetries < 0 { + h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)")) + h.internalClose() + return + } + + paths, err := filepath.Glob(glob) + if err != nil { + logger.Warn(h.user, glob, err) + time.Sleep(retryInterval) + continue + } + + if numPaths := len(paths); numPaths == 0 { + logger.Error(h.user, "No such file(s) to read", glob) + select { + case errors <- struct{}{}: + case <-h.stop: + return + default: + } + time.Sleep(retryInterval) + continue + } + + h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors) + break + } +} + +func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) { + var wg sync.WaitGroup + wg.Add(len(paths)) + + read := func(path string, wg *sync.WaitGroup) { + defer wg.Done() + globID := h.makeGlobID(path, glob) + + if !h.user.HasFilePermission(path) { + logger.Error(h.user, "No permission to read file", path, globID) + select { + case errors <- struct{}{}: + default: + } + return + } + + h.startReadingFile(mode, path, globID, regex) + } + + for _, path := range paths { + go read(path, &wg) + } + + wg.Wait() +} + +func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) { + defer h.stopReadingFile(path) + logger.Info(h.user, "Start reading file", path, globID) + + var reader fs.FileReader + switch mode { + case omode.TailClient: + reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + case omode.GrepClient: + fallthrough + case omode.CatClient: + reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter) + default: + reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + } + + h.fileReadersMtx.Lock() + h.fileReaders = append(h.fileReaders, reader) + h.fileReadersMtx.Unlock() + + lines := h.lines + // Plugin mappreduce engine + if h.aggregate != nil { + lines = h.aggregate.Lines + } + + for { + if err := reader.Start(lines, regex); err != nil { + logger.Error(h.user, path, globID, err) + } + + select { + case <-h.stop: + return + default: + if !reader.Retry() { + return + } + } + + time.Sleep(time.Second * 2) + logger.Info(path, globID, "Reading file again") + } +} + +func (h *ServerHandler) stopReadingFile(path string) { + logger.Info(h.user, "Stop reading file", path) + + h.fileReadersMtx.Lock() + defer h.fileReadersMtx.Unlock() + + path = filepath.Clean(path) + var fileReaders []fs.FileReader + + for _, reader := range h.fileReaders { + if reader.FilePath() == path { + reader.Stop() + continue + } + fileReaders = append(fileReaders, reader) + } + + if len(fileReaders) == len(h.fileReaders) { + logger.Warn(h.user, "Didn't read file path", path) + return + } + + h.fileReaders = fileReaders + + if len(fileReaders) == 0 { + if h.aggregate != nil { + h.aggregate.Serialize() + } + h.allLinesSent() + } +} + +func (h *ServerHandler) numUnsentMessages() int { + return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages) +} + +func (h *ServerHandler) allLinesSent() { + defer h.internalClose() + + for i := 0; i < 3; i++ { + if h.numUnsentMessages() == 0 { + logger.Debug(h.user, "All lines sent") + return + } + logger.Debug(h.user, "Still lines to be sent") + time.Sleep(time.Second) + } + + logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages()) +} + +// Handler decides to shutdown the connection, not the server itself. +func (h *ServerHandler) internalClose() { + select { + case h.hiddenMessages <- "syn close connection": + case <-time.After(time.Second * 5): + logger.Debug(h.user, "Not waiting for ack close connection") + close(h.stopTimeout) + return + } + + select { + case <-h.Wait(): + case <-time.After(time.Second * 5): + logger.Debug(h.user, "Not waiting for ack close connection") + close(h.stopTimeout) + } +} + +func (h *ServerHandler) handleCommand(commandStr string) { + logger.Info(h.user, commandStr) + + args := strings.Split(commandStr, " ") + argc := len(args) + + logger.Debug(h.user, "Received command", commandStr, argc, args) + + if h.user.Name == config.ControlUser { + h.handleControlCommand(argc, args) + return + } + + h.handleUserCommand(argc, args) +} + +// Special (restricted) set of commands for anonymous ControlUser access. +func (h *ServerHandler) handleControlCommand(argc int, args []string) { + switch args[0] { + case "ping": + h.send(h.hiddenMessages, "pong") + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) + default: + logger.Warn(h.user, "Received unknown command", argc, args) + } +} + +// Commands for authed users. +func (h *ServerHandler) handleUserCommand(argc int, args []string) { + switch args[0] { + case "grep": + fallthrough + case "cat": + h.handleReadCommand(argc, args, omode.CatClient) + case "tail": + h.handleReadCommand(argc, args, omode.TailClient) + case "map": + h.handleMapCommand(argc, args) + case "ack": + h.handleAckCommand(argc, args) + case "ping": + h.send(h.hiddenMessages, "pong") + case "version": + h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String())) + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args)) + default: + h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args)) + } +} + +func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) { + regex := "." + if argc >= 4 { + regex = args[3] + } + if argc < 3 { + h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) + return + } + go h.processFileGlob(mode, args[1], regex) +} + +func (h *ServerHandler) handleMapCommand(argc int, args []string) { + if argc < 2 { + h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) + return + } + + queryStr := strings.Join(args[1:], " ") + logger.Info(h.user, "Creating new mapr aggregator", queryStr) + aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr) + + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return + } + + h.aggregate = aggregate +} + +func (h *ServerHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) + return + } + if args[1] == "close" && args[2] == "connection" { + close(h.ackStopReceived) + } +} + +func (h *ServerHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.stop: + } +} + +// Wait (block) until server handler is closed or a timeout has exceeded. +func (h *ServerHandler) Wait() <-chan struct{} { + wait := make(chan struct{}) + + go func() { + select { + case <-h.ackStopReceived: + logger.Debug(h.user, "Closing wait channel due to ACK stop received") + close(wait) + case <-h.stopTimeout: + logger.Debug(h.user, "Closing wait channel due to wait timeout") + close(wait) + case <-h.stop: + logger.Debug(h.user, "Closing wait channel due to stop") + } + }() + + return wait +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..27a98f5 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,212 @@ +package server + +import ( + "errors" + "fmt" + "io" + "net" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/server/handlers" + "github.com/mimecast/dtail/internal/ssh/server" + user "github.com/mimecast/dtail/internal/user/server" + "github.com/mimecast/dtail/internal/version" + + gossh "golang.org/x/crypto/ssh" +) + +// Server is the main server data structure. +type Server struct { + // Various server statistics counters. + stats stats + // SSH server configuration. + sshServerConfig *gossh.ServerConfig + // To control the max amount of concurrent cats (which can cause a lot of I/O on the server) + catLimiterCh chan struct{} + // To control the max amount of concurrent tails + tailLimiterCh chan struct{} + // Ask to shutdown the server + stop chan struct{} +} + +// New returns a new server. +func New() *Server { + logger.Info("Creating server", version.String()) + + s := Server{ + sshServerConfig: &gossh.ServerConfig{}, + catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), + tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), + stop: make(chan struct{}), + } + + s.sshServerConfig.PasswordCallback = s.controlUserCallback + s.sshServerConfig.PublicKeyCallback = server.PublicKeyCallback + + private, err := gossh.ParsePrivateKey(server.PrivateHostKey()) + if err != nil { + logger.FatalExit(err) + } + s.sshServerConfig.AddHostKey(private) + + return &s +} + +// Start the server. +func (s *Server) Start() int { + logger.Info("Starting server") + + bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) + logger.Info("Binding server", bindAt) + listener, err := net.Listen("tcp", bindAt) + if err != nil { + logger.FatalExit("Failed to open listening TCP socket", err) + } + + go s.stats.periodicLogServerStats(s.stop) + + for { + conn, err := listener.Accept() // Blocking + if err != nil { + logger.Error("Failed to accept incoming connection", err) + continue + } + + if err := s.stats.serverLimitExceeded(); err != nil { + logger.Error(err) + conn.Close() + continue + } + + go s.handleConnection(conn) + } +} + +func (s *Server) handleConnection(conn net.Conn) { + logger.Info("Handling connection") + + sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) + if err != nil { + logger.Error("Something just happened", err) + return + } + + s.stats.incrementConnections() + + go gossh.DiscardRequests(reqs) + for newChannel := range chans { + go s.handleChannel(sshConn, newChannel) + } +} + +func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { + user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) + logger.Info(user, "Invoking channel handler") + + if newChannel.ChannelType() != "session" { + err := errors.New("Don'w allow other channel types than session") + logger.Error(user, err) + newChannel.Reject(gossh.Prohibited, err.Error()) + return + } + + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Error(user, "Could not accept channel", err) + return + } + + if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + logger.Error(user, err) + sshConn.Close() + } +} + +func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { + logger.Info(user, "Invoking request handler") + + for req := range in { + var payload = struct{ Value string }{} + gossh.Unmarshal(req.Payload, &payload) + + switch req.Type { + case "shell": + var handler handlers.Handler + switch user.Name { + case config.ControlUser: + handler = handlers.NewControlHandler(user) + default: + handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh) + } + + // Bi-directionally connect SSH stream to SSH handler + brokenPipe1 := make(chan struct{}) + go func() { + defer close(brokenPipe1) + io.Copy(channel, handler) + }() + + brokenPipe2 := make(chan struct{}) + go func() { + defer close(brokenPipe2) + io.Copy(handler, channel) + }() + + // Ensure to close all fd's and stop all goroutines once ssh connection terminated + go func() { + defer s.stats.decrementConnections() + defer handler.Close() + + if err := sshConn.Wait(); err != nil && err != io.EOF { + logger.Error(user, err) + } + logger.Info(user, "Good bye Mister!") + }() + + // Close the underlying ssh socket when server shuts down + go func() { + select { + case <-s.stop: + logger.Debug(user, "Server initiating shutdown on handler") + case <-handler.Wait(): + logger.Debug(user, "Handler initiating shutdown by its own") + case <-brokenPipe1: + logger.Debug(user, "Broken pipe1") + case <-brokenPipe2: + logger.Debug(user, "Broken pipe2") + } + sshConn.Close() + logger.Info(user, "Closed SSH connection") + }() + + // Only serving shell type + req.Reply(true, nil) + + default: + req.Reply(false, nil) + + return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v", + req.Type, payload.Value) + } + } + + return nil +} + +func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) { + user := user.New(c.User(), c.RemoteAddr().String()) + + if user.Name == config.ControlUser && string(authPayload) == config.ControlUser { + logger.Debug(user, "Initiating master control program") + return nil, nil + } + + return nil, fmt.Errorf("Not authorized") +} + +// Stop the server. +func (s *Server) Stop() { + close(s.stop) + s.stats.waitForConnections() +} diff --git a/internal/server/stats.go b/internal/server/stats.go new file mode 100644 index 0000000..beb1885 --- /dev/null +++ b/internal/server/stats.go @@ -0,0 +1,88 @@ +package server + +import ( + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" + "fmt" + "runtime" + "sync" + "time" +) + +// Used to collect and display various server stats. +type stats struct { + mutex sync.Mutex + currentConnections int + lifetimeConnections uint64 +} + +func (s *stats) incrementConnections() { + defer s.logServerStats() + + s.mutex.Lock() + s.currentConnections++ + s.lifetimeConnections++ + s.mutex.Unlock() +} + +func (s *stats) decrementConnections() { + defer s.logServerStats() + + s.mutex.Lock() + s.currentConnections-- + s.mutex.Unlock() +} + +func (s *stats) hasConnections() bool { + s.mutex.Lock() + currentConnections := s.currentConnections + s.mutex.Unlock() + + has := currentConnections > 0 + logger.Info("stats", "Server with open connections?", has, currentConnections) + + return has +} + +func (s *stats) logServerStats() { + s.mutex.Lock() + defer s.mutex.Unlock() + + currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections) + lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections) + goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine()) + logger.Info("stats", currentConnections, lifetimeConnections, goroutines) +} + +func (s *stats) serverLimitExceeded() error { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.currentConnections >= config.Server.MaxConnections { + return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections) + } + + return nil +} + +func (s *stats) periodicLogServerStats(stop <-chan struct{}) { + for { + select { + case <-time.NewTimer(time.Second * 10).C: + s.logServerStats() + case <-stop: + return + } + } +} + +func (s *stats) waitForConnections() { + for { + select { + case <-time.NewTimer(time.Second).C: + if !s.hasConnections() { + return + } + } + } +} diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go new file mode 100644 index 0000000..3392eb1 --- /dev/null +++ b/internal/ssh/client/authmethods.go @@ -0,0 +1,45 @@ +package client + +import ( + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/ssh" + "os" + + gossh "golang.org/x/crypto/ssh" +) + +// InitSSHAuthMethods initialises all known SSH auth methods on othe client side. +func InitSSHAuthMethods(trustAllHosts bool, throttleCh chan struct{}) ([]gossh.AuthMethod, *HostKeyCallback) { + var sshAuthMethods []gossh.AuthMethod + + if config.Common.ExperimentalFeaturesEnable { + sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) + logger.Info("Added experimental method to list of auth methods") + } + + keyPath := os.Getenv("HOME") + "/.ssh/id_rsa" + if authMethod, err := ssh.PrivateKey(keyPath); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added path to list of auth methods", keyPath) + } + + keyPath = os.Getenv("HOME") + "/.ssh/id_dsa" + if authMethod, err := ssh.PrivateKey(keyPath); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added path to list of auth methods", keyPath) + } + + if authMethod, err := ssh.Agent(); err == nil { + sshAuthMethods = append(sshAuthMethods, authMethod) + logger.Info("Added SSH Agent to list of auth methods") + } + + knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts" + hostKeyCallback, err := NewHostKeyCallback(knownHostsPath, trustAllHosts, throttleCh) + if err != nil { + logger.FatalExit(knownHostsPath, err) + } + + return sshAuthMethods, hostKeyCallback +} diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go new file mode 100644 index 0000000..4023e59 --- /dev/null +++ b/internal/ssh/client/hostkeycallback.go @@ -0,0 +1,285 @@ +package client + +import ( + "bufio" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/prompt" + "fmt" + "net" + "os" + "strings" + "sync" + "time" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" +) + +type response int + +const ( + trustHost response = iota + dontTrustHost response = iota +) + +// Represents an unknown host. +type unknownHost struct { + server string + remote net.Addr + key ssh.PublicKey + hostLine string + ipLine string + responseCh chan response +} + +// HostKeyCallback is a wrapper around ssh.KnownHosts so that we can add all +// unknown hosts in a single batch to the known_hosts file. +type HostKeyCallback struct { + knownHostsPath string + unknownCh chan unknownHost + throttleCh chan struct{} + trustAllHostsCh chan struct{} + untrustedHosts map[string]bool + mutex sync.Mutex +} + +// NewHostKeyCallback returns a new wrapper. +func NewHostKeyCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (*HostKeyCallback, error) { + // Ensure file exists + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + + h := HostKeyCallback{ + knownHostsPath: knownHostsPath, + unknownCh: make(chan unknownHost), + trustAllHostsCh: make(chan struct{}), + throttleCh: throttleCh, + untrustedHosts: make(map[string]bool), + } + + if trustAllHosts { + close(h.trustAllHostsCh) + } + + return &h, nil +} + +// Wrap the host key callback. +func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { + return func(server string, remote net.Addr, key ssh.PublicKey) error { + // Parse known_hosts file + knownHostsCb, err := knownhosts.New(h.knownHostsPath) + if err != nil { + // Problem parsing it + return err + } + + // Check for valid entry in known_hosts file + err = knownHostsCb(server, remote, key) + if err == nil { + // OK + return nil + } + + // Make sure that interactive user callback does not interfere with + // SSH connection throttler. + <-h.throttleCh + defer func() { h.throttleCh <- struct{}{} }() + + unknown := unknownHost{ + server: server, + remote: remote, + key: key, + hostLine: knownhosts.Line([]string{server}, key), + ipLine: knownhosts.Line([]string{remote.String()}, key), + responseCh: make(chan response), + } + + logger.Warn("Encountered unknown host", unknown) + // Notify user that there is an unknown host + h.unknownCh <- unknown + + // Wait for user input. + switch <-unknown.responseCh { + case trustHost: + // End user acknowledged host key + return nil + case dontTrustHost: + } + + h.mutex.Lock() + defer h.mutex.Unlock() + h.untrustedHosts[server] = true + + return err + } +} + +// PromptAddHosts prompts a question to the user whether unknown hosts should +// be added to the known hosts or not. +func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { + var hosts []unknownHost + + for { + // Check whether there is a unknown host + select { + case unknown := <-h.unknownCh: + hosts = append(hosts, unknown) + // Ask every 50 unknown hosts + if len(hosts) >= 50 { + h.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-time.After(2 * time.Second): + // Or ask when after 2 seconds no new unknown hosts were added. + if len(hosts) > 0 { + h.promptAddHosts(hosts) + hosts = []unknownHost{} + } + case <-stop: + logger.Debug("Stopping goroutine prompting new hosts...") + return + } + } +} + +func (h *HostKeyCallback) promptAddHosts(hosts []unknownHost) { + var servers []string + + for _, host := range hosts { + servers = append(servers, host.server) + } + + select { + case <-h.trustAllHostsCh: + logger.Warn("Trusting host keys of servers", servers) + h.trustHosts(hosts) + return + default: + } + + question := fmt.Sprintf("Encountered %d unknown hosts: '%s'\n%s", + len(servers), + strings.Join(servers, ","), + "Do you want to trust these hosts?", + ) + + p := prompt.New(question) + + a := prompt.Answer{ + Long: "yes", + Short: "y", + Callback: func() { + h.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", h.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "all", + Short: "a", + Callback: func() { + close(h.trustAllHostsCh) + h.trustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Added hosts to known hosts file", h.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "no", + Short: "n", + Callback: func() { + h.dontTrustHosts(hosts) + }, + EndCallback: func() { + logger.Info("Didn't add hosts to known hosts file", h.knownHostsPath) + }, + } + p.Add(a) + + a = prompt.Answer{ + Long: "details", + Short: "d", + AskAgain: true, + Callback: func() { + for _, unknown := range hosts { + fmt.Println(unknown.hostLine) + fmt.Println(unknown.ipLine) + } + }, + } + p.Add(a) + + p.Ask() +} + +func (h *HostKeyCallback) trustHosts(hosts []unknownHost) { + tmpKnownHostsPath := fmt.Sprintf("%s.tmp", h.knownHostsPath) + + newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) + if err != nil { + panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) + } + defer newFd.Close() + + // Newly trusted hosts in normalized form + addresses := make(map[string]struct{}) + + // First write to new known hosts file, and keep track of addresses + for _, unknown := range hosts { + unknown.responseCh <- trustHost + + // Add once as [HOSTNAME]:PORT + addresses[knownhosts.Normalize(unknown.server)] = struct{}{} + // And once as [IP]:PORT + addresses[knownhosts.Normalize(unknown.remote.String())] = struct{}{} + + newFd.WriteString(fmt.Sprintf("%s\n", unknown.hostLine)) + newFd.WriteString(fmt.Sprintf("%s\n", unknown.ipLine)) + } + + // Read old known hosts file, to see which are old and new entries + os.OpenFile(h.knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) + oldFd, err := os.Open(h.knownHostsPath) + if err != nil { + panic(err) + } + defer oldFd.Close() + + scanner := bufio.NewScanner(oldFd) + + // Now, append all still valid old entries to the new host file + for scanner.Scan() { + line := scanner.Text() + address := strings.SplitN(line, " ", 2)[0] + + if _, ok := addresses[address]; !ok { + newFd.WriteString(fmt.Sprintf("%s\n", line)) + } + } + + // Now, replace old known hosts file + if err := os.Rename(tmpKnownHostsPath, h.knownHostsPath); err != nil { + panic(err) + } +} + +func (h *HostKeyCallback) dontTrustHosts(hosts []unknownHost) { + for _, unknown := range hosts { + unknown.responseCh <- dontTrustHost + } +} + +// Untrusted returns true if the host is not trusted. False otherwise. +func (h *HostKeyCallback) Untrusted(server string) bool { + h.mutex.Lock() + defer h.mutex.Unlock() + _, ok := h.untrustedHosts[server] + + return ok +} diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go new file mode 100644 index 0000000..7baa4aa --- /dev/null +++ b/internal/ssh/server/hostkey.go @@ -0,0 +1,37 @@ +package server + +import ( + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/ssh" + "io/ioutil" + "os" +) + +// PrivateHostKey retrieves the private server RSA host key. +func PrivateHostKey() []byte { + hostKeyFile := config.Server.HostKeyFile + _, err := os.Stat(hostKeyFile) + + if os.IsNotExist(err) { + logger.Info("Generating private server RSA host key") + privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits) + + if err != nil { + logger.FatalExit("Failed to generate private server RSA host key", err) + } + + pem := ssh.EncodePrivateKeyToPEM(privateKey) + if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil { + logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err) + } + return pem + } + + logger.Info("Reading private server RSA host key from file", hostKeyFile) + pem, err := ioutil.ReadFile(hostKeyFile) + if err != nil { + logger.FatalExit("Failed to load private server RSA host key", err) + } + return pem +} diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go new file mode 100644 index 0000000..c6929d7 --- /dev/null +++ b/internal/ssh/server/publickeycallback.go @@ -0,0 +1,62 @@ +package server + +import ( + "fmt" + "io/ioutil" + "os" + osUser "os/user" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/logger" + user "github.com/mimecast/dtail/internal/user/server" + + gossh "golang.org/x/crypto/ssh" +) + +// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not. +func PublicKeyCallback(c gossh.ConnMetadata, pubKey gossh.PublicKey) (*gossh.Permissions, error) { + user := user.New(c.User(), c.RemoteAddr().String()) + logger.Info(user, "Incoming authorization") + + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error()) + } + + authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name) + if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) { + user, err := osUser.Lookup(user.Name) + if err != nil { + return nil, fmt.Errorf("Unable to authorize|%s|%s|", user, err.Error()) + } + // Fallback to ~ + authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys" + } + + logger.Info(user, "Reading", authorizedKeysFile) + authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile) + if err != nil { + return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error()) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := gossh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error()) + } + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + if authorizedKeysMap[string(pubKey.Marshal())] { + logger.Debug("Public key fingerprint", gossh.FingerprintSHA256(pubKey), user) + return &gossh.Permissions{ + Extensions: map[string]string{ + "pubkey-fp": gossh.FingerprintSHA256(pubKey), + }, + }, nil + } + + return nil, fmt.Errorf("Unknown public key|%s", user) +} diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go new file mode 100644 index 0000000..77cc341 --- /dev/null +++ b/internal/ssh/ssh.go @@ -0,0 +1,112 @@ +package ssh + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "github.com/mimecast/dtail/internal/logger" + "encoding/pem" + "fmt" + "io/ioutil" + "net" + "os" + "syscall" + + gossh "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" + "golang.org/x/crypto/ssh/terminal" +) + +// GeneratePrivateRSAKey is used by the server to generate its key. +func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, size) + if err != nil { + return nil, err + } + + err = privateKey.Validate() + if err != nil { + return nil, err + } + + return privateKey, nil +} + +// EncodePrivateKeyToPEM is a helper function for converting a key to PEM format. +func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { + derFormat := x509.MarshalPKCS1PrivateKey(privateKey) + + block := pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: derFormat, + } + + return pem.EncodeToMemory(&block) +} + +// Agent used for SSH auth. +func Agent() (gossh.AuthMethod, error) { + sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) + if err != nil { + return nil, err + } + agentClient := agent.NewClient(sshAgent) + keys, err := agentClient.List() + if err != nil { + return nil, err + } + for i, key := range keys { + logger.Debug("Public key", i, key) + } + return gossh.PublicKeysCallback(agentClient.Signers), nil +} + +// EnterKeyPhrase is required to read phrase protected private keys. +func EnterKeyPhrase(keyFile string) []byte { + fmt.Printf("Enter phrase for key %s: ", keyFile) + phrase, err := terminal.ReadPassword(int(syscall.Stdin)) + if err != nil { + panic(err) + } + fmt.Printf("%s\n", string(phrase)) + return phrase +} + +// KeyFile returns the key as a SSH auth method. +func KeyFile(keyFile string) (gossh.AuthMethod, error) { + buffer, err := ioutil.ReadFile(keyFile) + if err != nil { + return nil, err + } + + key, err := gossh.ParsePrivateKey(buffer) + if err != nil { + return nil, err + } + + // Key phrase support disabled as password will be printed to stdout! + /* + if err == nil { + return gossh.PublicKeys(key), nil + } + + keyPhrase := EnterKeyPhrase(keyFile) + key, err = gossh.ParsePrivateKeyWithPassphrase(buffer, keyPhrase) + if err != nil { + return nil, err + } + */ + + return gossh.PublicKeys(key), nil +} + +// PrivateKey returns the private key as a SSH auth method. +func PrivateKey(keyFile string) (gossh.AuthMethod, error) { + signer, err := KeyFile(keyFile) + if err != nil { + logger.Debug(keyFile, err) + return nil, err + } + return gossh.AuthMethod(signer), nil +} diff --git a/internal/user/name.go b/internal/user/name.go new file mode 100644 index 0000000..5171ec7 --- /dev/null +++ b/internal/user/name.go @@ -0,0 +1,24 @@ +package user + +import ( + "os/user" + ) + + +func Name() string { + user, err := user.Current() + if err != nil { + panic(err) + } + + if user.Uid == "0" { + panic("Not allowed to run as UID 0") + } + + if user.Gid == "0" { + panic("Not allowed to run as GID 0") + } + + return user.Username +} + diff --git a/internal/user/server/user.go b/internal/user/server/user.go new file mode 100644 index 0000000..fad38d8 --- /dev/null +++ b/internal/user/server/user.go @@ -0,0 +1,131 @@ +package server + +import ( + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/fs/permissions" + "github.com/mimecast/dtail/internal/logger" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +const maxLinkDepth int = 100 + +// User represents an end-user which connected to the server via the DTail client. +type User struct { + // The user name. + Name string + // The remote address connected from. + remoteAddress string + // The permissions the user has. + permissions []string +} + +// New returns a new user. +func New(name, remoteAddress string) *User { + return &User{ + Name: name, + remoteAddress: remoteAddress, + } +} + +// String representation of the user. +func (u *User) String() string { + return fmt.Sprintf("%s@%s", u.Name, u.remoteAddress) +} + +// HasFilePermission is used to determine whether user is alowed to read a file. +func (u *User) HasFilePermission(filePath string) (hasPermission bool) { + cleanPath, err := filepath.EvalSymlinks(filePath) + if err != nil { + logger.Error(u, filePath, "Unable to evaluate symlinks", err) + hasPermission = false + return + } + + cleanPath, err = filepath.Abs(cleanPath) + if err != nil { + logger.Error(u, cleanPath, "Unable to make file path absolute", err) + hasPermission = false + return + } + + if cleanPath != filePath { + logger.Info(u, filePath, cleanPath, "Calculated new clean path from original file path (possibly symlink)") + } + + hasPermission, err = u.hasFilePermission(cleanPath) + if err != nil { + logger.Warn(u, cleanPath, err) + } + + return +} + +func (u *User) hasFilePermission(cleanPath string) (bool, error) { + // First check file system Linux/UNIX permission. + if _, err := permissions.ToRead(u.Name, cleanPath); err != nil { + return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err) + } + logger.Info(u, cleanPath, "User has OS file system permissions to read file") + + // If file system permission is given, also check permissions + // as configured in DTail config file. + if len(u.permissions) == 0 { + p, err := config.ServerUserPermissions(u.Name) + if err != nil { + return false, err + } + u.permissions = p + } + + var hasPermission bool + var err error + + if hasPermission, err = u.iteratePaths(cleanPath); err != nil { + return false, err + } + + // Only allow to follow regular files or symlinks. + info, err := os.Lstat(cleanPath) + if err != nil { + return false, fmt.Errorf("Unable to determine file type: '%v'", err) + } + + if !info.Mode().IsRegular() { + return false, fmt.Errorf("Can only open regular files or follow symlinks") + } + + return hasPermission, nil +} + +func (u *User) iteratePaths(cleanPath string) (bool, error) { + for _, permission := range u.permissions { + var regexStr string + var negate bool + + if strings.HasPrefix(permission, "!") { + regexStr = permission[1:] + negate = true + } + regexStr = permission + negate = false + + re, err := regexp.Compile(regexStr) + if err != nil { + return false, fmt.Errorf("Permission test failed, can't compile regex '%s': '%v'", regexStr, err) + } + + if negate && re.MatchString(cleanPath) { + return false, fmt.Errorf("Permission test failed, matching negative pattern '%s'", permission) + } + + if !negate && re.MatchString(cleanPath) { + logger.Info(u, cleanPath, "Permission test passed partially, matching positive pattern", permission) + } + } + + return true, nil +} diff --git a/internal/version/version.go b/internal/version/version.go new file mode 100644 index 0000000..d036a68 --- /dev/null +++ b/internal/version/version.go @@ -0,0 +1,40 @@ +package version + +import ( + "fmt" + "os" + + "github.com/mimecast/dtail/internal/color" +) + +// Name of DTail. +const Name = "DTail" + +// Version of DTail. +const Version = "1.1.0" + +// Additional information. +const Additional = "develop" + +// String representation of the DTail version. +func String() string { + return fmt.Sprintf("%s v%v %s", Name, Version, Additional) +} + +// PaintedString is a prettier string representation of the DTail version. +func PaintedString() string { + if !color.Colored { + return String() + } + name := color.Paint(color.Yellow, Name) + version := color.Paint(color.Blue, Version) + descr := color.Paint(color.Green, Additional) + + return fmt.Sprintf("%s %v %s", name, version, descr) +} + +// PrintAndExit prints the program version and exists. +func PrintAndExit() { + fmt.Println(PaintedString()) + os.Exit(0) +} |
