diff options
| author | Paul Buetow <pbuetow@mimecast.com> | 2021-10-21 21:28:49 +0300 |
|---|---|---|
| committer | Paul Buetow <pbuetow@mimecast.com> | 2021-10-21 21:28:49 +0300 |
| commit | f4207a55f71bfbcfdc532d5cdd3befaa3474a157 (patch) | |
| tree | ea5e4a2d2a67035f645bdee496ae55a52034178a /internal | |
| parent | d80d6070557e3a800e3a54967af9eced518f116b (diff) | |
| parent | 739205206d63bf42f4e843b39d04d4c8cd8207c3 (diff) | |
merge develop
Diffstat (limited to 'internal')
104 files changed, 4357 insertions, 2708 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index f83fcfd..4a7bd84 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -2,15 +2,13 @@ package clients import ( "context" - "fmt" - "strings" "sync" "time" - "github.com/mimecast/dtail/internal/clients/remote" + "github.com/mimecast/dtail/internal/clients/connectors" + "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/omode" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/regex" "github.com/mimecast/dtail/internal/ssh/client" @@ -19,13 +17,13 @@ import ( // This is the main client data structure. type baseClient struct { - Args + config.Args // To display client side stats stats *stats // List of remote servers to connect to. servers []string // We have one connection per remote server. - connections []*remote.Connection + connections []connectors.Connector // SSH auth methods to use to connect to the remote servers. sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys @@ -41,7 +39,7 @@ type baseClient struct { } func (c *baseClient) init() { - logger.Debug("Initiating base client") + dlog.Client.Debug("Initiating base client", c.Args.String()) flag := regex.Default if c.Args.RegexInvert { @@ -49,12 +47,16 @@ func (c *baseClient) init() { } regex, err := regex.New(c.Args.RegexStr, flag) if err != nil { - logger.FatalExit(c.Regex, "invalid regex!", err, regex) + dlog.Client.FatalPanic(c.Regex, "Invalid regex!", err, regex) } c.Regex = regex - logger.Debug("Regex", c.Regex) - c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, c.throttleCh, c.Args.PrivateKeyPathFile) + if c.Args.Serverless { + return + } + c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods( + c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, + c.throttleCh, c.Args.PrivateKeyPathFile) } func (c *baseClient) makeConnections(maker maker) { @@ -62,26 +64,31 @@ func (c *baseClient) makeConnections(maker maker) { discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, + c.sshAuthMethods, c.hostKeyCallback)) } c.stats = newTailStats(len(c.connections)) } func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) { - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(ctx) + dlog.Client.Trace("Starting base client") + // Can be nil when serverless. + if c.hostKeyCallback != nil { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + } // Print client stats every time something on statsCh is recieved. go c.stats.Start(ctx, c.throttleCh, statsCh, c.Args.Quiet) - // Keep count of active connections - active := make(chan struct{}, len(c.connections)) + var wg sync.WaitGroup + wg.Add(len(c.connections)) var mutex sync.Mutex - for i, conn := range c.connections { - go func(i int, conn *remote.Connection) { - connStatus := c.start(ctx, active, i, conn) - // Update global status. + for i, conn := range c.connections { + go func(i int, conn connectors.Connector) { + defer wg.Done() + connStatus := c.startConnection(ctx, i, conn) mutex.Lock() defer mutex.Unlock() if connStatus > status { @@ -90,15 +97,12 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i }(i, conn) } - c.waitUntilDone(ctx, active) + wg.Wait() return } -func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { - // Increment connection count - active <- struct{}{} - // Derement connection count - defer func() { <-active }() +func (c *baseClient) startConnection(ctx context.Context, i int, + conn connectors.Connector) (status int) { for { connCtx, cancel := context.WithCancel(ctx) @@ -106,80 +110,25 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) // Retrieve status code from handler (dtail client will exit with that status) - status = conn.Handler.Status() + status = conn.Handler().Status() if !c.retry { return } time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconnecting") - - conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + dlog.Client.Debug(conn.Server(), "Reconnecting") + conn = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback) c.connections[i] = conn } } -func (c *baseClient) makeCommandOptions() map[string]string { - options := make(map[string]string) - - if c.Args.Quiet { - options["quiet"] = fmt.Sprintf("%v", c.Args.Quiet) - } - if c.Args.LContext.MaxCount != 0 { - options["max"] = fmt.Sprintf("%d", c.Args.LContext.MaxCount) - } - if c.Args.LContext.BeforeContext != 0 { - options["before"] = fmt.Sprintf("%d", c.Args.LContext.BeforeContext) - } - if c.Args.LContext.AfterContext != 0 { - options["after"] = fmt.Sprintf("%d", c.Args.LContext.AfterContext) - } - - return options -} - -func (c *baseClient) commandOptionsToString(options map[string]string) string { - var sb strings.Builder - - count := 0 - for k, v := range options { - if count > 0 { - sb.WriteString(":") - } - sb.WriteString(fmt.Sprintf("%s=%s", k, v)) - count++ - } - - return sb.String() -} - -func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = c.maker.makeHandler(server) - conn.Commands = c.maker.makeCommands(c.makeCommandOptions()) - - return conn -} - -func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { - defer logger.Debug("Terminated connection") - - // We want to have at least one active connection - <-active - // Put it back on the channel - active <- struct{}{} - - if c.Mode == omode.TailClient && c.retry { - <-ctx.Done() - } - - for { - numActive := len(active) - if numActive == 0 { - return - } - logger.Debug("Active connections", numActive) - time.Sleep(time.Second) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, + hostKeyCallback client.HostKeyCallback) connectors.Connector { + if c.Args.Serverless { + return connectors.NewServerless(c.UserName, c.maker.makeHandler(server), + c.maker.makeCommands()) } + return connectors.NewServerConnection(server, c.UserName, sshAuthMethods, + hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands()) } diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go index db892f1..bd65560 100644 --- a/internal/clients/catclient.go +++ b/internal/clients/catclient.go @@ -7,6 +7,8 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" ) @@ -16,11 +18,10 @@ type CatClient struct { } // NewCatClient returns a new cat client. -func NewCatClient(args Args) (*CatClient, error) { +func NewCatClient(args config.Args) (*CatClient, error) { if args.RegexStr != "" { return nil, errors.New("Can't use regex with 'cat' operating mode") } - args.Mode = omode.CatClient c := CatClient{ @@ -33,7 +34,6 @@ func NewCatClient(args Args) (*CatClient, error) { c.init() c.makeConnections(c) - return &c, nil } @@ -41,10 +41,14 @@ func (c CatClient) makeHandler(server string) handlers.Handler { return handlers.NewClientHandler(server) } -func (c CatClient) makeCommands(options map[string]string) (commands []string) { - optionsStr := c.commandOptionsToString(options) +func (c CatClient) makeCommands() (commands []string) { + regex, err := c.Regex.Serialize() + if err != nil { + dlog.Client.FatalPanic(err) + } for _, file := range strings.Split(c.What, ",") { - commands = append(commands, fmt.Sprintf("%s:%s %s %s", c.Mode.String(), optionsStr, file, c.Regex.Serialize())) + commands = append(commands, fmt.Sprintf("%s:%s %s %s", + c.Mode.String(), c.Args.SerializeOptions(), file, regex)) } return } diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go new file mode 100644 index 0000000..3ab6a08 --- /dev/null +++ b/internal/clients/connectors/connector.go @@ -0,0 +1,17 @@ +package connectors + +import ( + "context" + + "github.com/mimecast/dtail/internal/clients/handlers" +) + +// Connector interface. +type Connector interface { + // Start the connection. + Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) + // Server hostname. + Server() string + // Handler for the connection. + Handler() handlers.Handler +} diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go new file mode 100644 index 0000000..aeb2a41 --- /dev/null +++ b/internal/clients/connectors/serverconnection.go @@ -0,0 +1,206 @@ +package connectors + +import ( + "context" + "fmt" + "io" + "strconv" + "strings" + "time" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/ssh/client" + + "golang.org/x/crypto/ssh" +) + +// ServerConnection represents a connection to a single remote dtail server via +// SSH protocol. +type ServerConnection struct { + // The full server string as received from the server discovery (can be with port number) + server string + // Only the hostname or FQDN (without the port number) + hostname string + // Only the port number. + port int + config *ssh.ClientConfig + handler handlers.Handler + commands []string + hostKeyCallback client.HostKeyCallback + throttlingDone bool +} + +// NewServerConnection returns a new DTail SSH server connection. +func NewServerConnection(server string, userName string, + authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback, + handler handlers.Handler, commands []string) *ServerConnection { + + dlog.Client.Debug(server, "Creating new connection", server, handler, commands) + c := ServerConnection{ + hostKeyCallback: hostKeyCallback, + server: server, + handler: handler, + commands: commands, + config: &ssh.ClientConfig{ + User: userName, + Auth: authMethods, + HostKeyCallback: hostKeyCallback.Wrap(), + Timeout: time.Second * 2, + }, + } + + c.initServerPort() + return &c +} + +// Server returns the server hostname connected to. +func (c *ServerConnection) Server() string { return c.server } + +// Handler returns the handler used for the connection. +func (c *ServerConnection) Handler() handlers.Handler { return c.handler } + +// Attempt to parse the server port address from the provided server FQDN. +func (c *ServerConnection) initServerPort() { + parts := strings.Split(c.server, ":") + if len(parts) == 1 { + c.hostname = c.server + c.port = config.Common.SSHPort + return + } + + dlog.Client.Debug("Parsing port from hostname", parts) + port, err := strconv.Atoi(parts[1]) + if err != nil { + dlog.Client.FatalPanic("Unable to parse client port", c.server, parts, err) + } + c.hostname = parts[0] + c.port = port +} + +// Start the connection to the server. +func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc, + throttleCh, statsCh chan struct{}) { + + // Throttle how many connections can be established concurrently (based on ch length) + dlog.Client.Debug(c.server, "Throttling connection", len(throttleCh), cap(throttleCh)) + + select { + case throttleCh <- struct{}{}: + case <-ctx.Done(): + dlog.Client.Debug(c.server, "Not establishing connection as context is done", + len(throttleCh), cap(throttleCh)) + return + } + + dlog.Client.Debug(c.server, "Throttling says that the connection can be established", + len(throttleCh), cap(throttleCh)) + + go func() { + defer func() { + if !c.throttlingDone { + dlog.Client.Debug(c.server, "Unthrottling connection (1)", + len(throttleCh), cap(throttleCh)) + c.throttlingDone = true + <-throttleCh + } + cancel() + }() + + if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { + dlog.Client.Warn(c.server, err) + if c.hostKeyCallback.Untrusted(c.server) { + dlog.Client.Debug(c.server, "Not trusting host") + } + } + }() + + <-ctx.Done() +} + +// Dail into a new SSH connection. Close connection in case of an error. +func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc, + throttleCh, statsCh chan struct{}) error { + + dlog.Client.Debug(c.server, "Incrementing connection stats") + statsCh <- struct{}{} + defer func() { + dlog.Client.Debug(c.server, "Decrementing connection stats") + <-statsCh + }() + + address := fmt.Sprintf("%s:%d", c.hostname, c.port) + dlog.Client.Debug(c.server, "Dialing into the connection", address) + + client, err := ssh.Dial("tcp", address, c.config) + if err != nil { + return err + } + defer client.Close() + + return c.session(ctx, cancel, client, throttleCh) +} + +// Create the SSH session. Close the session in case of an error. +func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFunc, + client *ssh.Client, throttleCh chan struct{}) error { + + dlog.Client.Debug(c.server, "Creating SSH session") + session, err := client.NewSession() + if err != nil { + return err + } + defer session.Close() + return c.handle(ctx, cancel, session, throttleCh) +} + +func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc, + session *ssh.Session, throttleCh chan struct{}) error { + + dlog.Client.Debug(c.server, "Creating handler for SSH session") + stdinPipe, err := session.StdinPipe() + if err != nil { + return err + } + stdoutPipe, err := session.StdoutPipe() + if err != nil { + return err + } + if err := session.Shell(); err != nil { + return err + } + + go func() { + io.Copy(stdinPipe, c.handler) + cancel() + }() + go func() { + io.Copy(c.handler, stdoutPipe) + cancel() + }() + go func() { + select { + case <-c.handler.Done(): + case <-ctx.Done(): + } + cancel() + }() + + // Send all commands to client. + for _, command := range c.commands { + dlog.Client.Debug(command) + c.handler.SendMessage(command) + } + + if !c.throttlingDone { + dlog.Client.Debug(c.server, "Unthrottling connection (2)", + len(throttleCh), cap(throttleCh)) + c.throttlingDone = true + <-throttleCh + } + + <-ctx.Done() + c.handler.Shutdown() + return nil +} diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go new file mode 100644 index 0000000..2ff490a --- /dev/null +++ b/internal/clients/connectors/serverless.go @@ -0,0 +1,116 @@ +package connectors + +import ( + "context" + "io" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + serverHandlers "github.com/mimecast/dtail/internal/server/handlers" + user "github.com/mimecast/dtail/internal/user/server" +) + +// Serverless creates a server object directly without TCP. +type Serverless struct { + handler handlers.Handler + commands []string + userName string +} + +// NewServerless starts a new serverless session. +func NewServerless(userName string, handler handlers.Handler, + commands []string) *Serverless { + + dlog.Client.Debug("Creating new serverless connector", handler, commands) + return &Serverless{ + userName: userName, + handler: handler, + commands: commands, + } +} + +// Server returns serverless server indicator. +func (s *Serverless) Server() string { + return "local(serverless)" +} + +// Handler returns the handler used for the serverless connection. +func (s *Serverless) Handler() handlers.Handler { + return s.handler +} + +// Start the serverless connection. +func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc, + throttleCh, statsCh chan struct{}) { + + dlog.Client.Debug("Starting serverless connector") + go func() { + defer cancel() + + if err := s.handle(ctx, cancel); err != nil { + dlog.Client.Warn(err) + } + }() + <-ctx.Done() +} + +func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) error { + dlog.Client.Debug("Creating server handler for a serverless session") + + user, err := user.New(s.userName, s.Server()) + if err != nil { + return err + } + + var serverHandler serverHandlers.Handler + switch s.userName { + case config.HealthUser: + dlog.Client.Debug("Creating serverless health handler") + serverHandler = serverHandlers.NewHealthHandler(user) + default: + dlog.Client.Debug("Creating serverless server handler") + serverHandler = serverHandlers.NewServerHandler( + user, + make(chan struct{}, config.Server.MaxConcurrentCats), + make(chan struct{}, config.Server.MaxConcurrentTails), + ) + } + + terminate := func() { + dlog.Client.Debug("Terminating serverless connection") + serverHandler.Shutdown() + cancel() + } + + go func() { + io.Copy(serverHandler, s.handler) + dlog.Client.Trace("io.Copy(serverHandler, s.handler) => done") + terminate() + }() + go func() { + io.Copy(s.handler, serverHandler) + dlog.Client.Trace("io.Copy(s.handler, serverHandler) => done") + terminate() + }() + go func() { + select { + case <-s.handler.Done(): + dlog.Client.Trace("<-s.handler.Done()") + case <-ctx.Done(): + dlog.Client.Trace("<-ctx.Done()") + } + terminate() + }() + + // Send all commands to client. + for _, command := range s.commands { + dlog.Client.Debug("Sending command to serverless server", command) + s.handler.SendMessage(command) + } + + <-ctx.Done() + dlog.Client.Trace("s.handler.Shutdown()") + s.handler.Shutdown() + return nil +} diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go index 567193a..7521c67 100644 --- a/internal/clients/grepclient.go +++ b/internal/clients/grepclient.go @@ -7,16 +7,19 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" ) -// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. +// GrepClient searches a remote file for all lines matching a regular +// expression. Only the matching lines are displayed. type GrepClient struct { baseClient } // NewGrepClient creates a new grep client. -func NewGrepClient(args Args) (*GrepClient, error) { +func NewGrepClient(args config.Args) (*GrepClient, error) { if args.RegexStr == "" { return nil, errors.New("No regex specified, use '-regex' flag") } @@ -32,7 +35,6 @@ func NewGrepClient(args Args) (*GrepClient, error) { c.init() c.makeConnections(c) - return &c, nil } @@ -40,11 +42,14 @@ func (c GrepClient) makeHandler(server string) handlers.Handler { return handlers.NewClientHandler(server) } -func (c GrepClient) makeCommands(options map[string]string) (commands []string) { - optionsStr := c.commandOptionsToString(options) +func (c GrepClient) makeCommands() (commands []string) { + regex, err := c.Regex.Serialize() + if err != nil { + dlog.Client.FatalPanic(err) + } for _, file := range strings.Split(c.What, ",") { - commands = append(commands, fmt.Sprintf("%s:%s %s %s", c.Mode.String(), optionsStr, file, c.Regex.Serialize())) + commands = append(commands, fmt.Sprintf("%s:%s %s %s", + c.Mode.String(), c.Args.SerializeOptions(), file, regex)) } - return } diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 602a7ac..b520c25 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -1,6 +1,7 @@ package handlers import ( + "bytes" "encoding/base64" "fmt" "io" @@ -8,8 +9,8 @@ import ( "time" "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/version" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/protocol" ) type baseHandler struct { @@ -17,10 +18,20 @@ type baseHandler struct { server string shellStarted bool commands chan string - receiveBuf []byte + receiveBuf bytes.Buffer status int } +func (h *baseHandler) String() string { + return fmt.Sprintf("baseHandler(%s,server:%s,shellStarted:%v,status:%d)@%p", + h.done, + h.server, + h.shellStarted, + h.status, + h, + ) +} + func (h *baseHandler) Server() string { return h.server } @@ -29,21 +40,13 @@ func (h *baseHandler) Status() int { return h.status } -func (h *baseHandler) Done() <-chan struct{} { - return h.done.Done() -} - -func (h *baseHandler) Shutdown() { - h.done.Shutdown() -} - // SendMessage to the server. func (h *baseHandler) SendMessage(command string) error { encoded := base64.StdEncoding.EncodeToString([]byte(command)) - logger.Debug("Sending command", h.server, command, encoded) + dlog.Client.Debug("Sending command", h.server, command, encoded) select { - case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): + case h.commands <- fmt.Sprintf("protocol %s base64 %v;", protocol.ProtocolCompat, encoded): case <-time.After(time.Second * 5): return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) case <-h.Done(): @@ -56,13 +59,20 @@ func (h *baseHandler) SendMessage(command string) error { // Read data from the dtail server via Writer interface. func (h *baseHandler) Write(p []byte) (n int, err error) { for _, b := range p { - h.receiveBuf = append(h.receiveBuf, b) - if b == '\n' { - if len(h.receiveBuf) == 0 { + switch b { + /* + // NEXT: Next DTail version make it so that '\n' gets ignored. For now + // leave it for compatibility with older DTail server + ability to display + // the protocol mismatch warn message. + case '\n' { continue - } - message := string(h.receiveBuf) - h.handleMessageType(message) + */ + case '\n', protocol.MessageDelimiter: + message := h.receiveBuf.String() + h.handleMessage(message) + h.receiveBuf.Reset() + default: + h.receiveBuf.WriteByte(b) } } @@ -77,31 +87,32 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { case <-h.Done(): return 0, io.EOF } - return } -// Handle various message types. -func (h *baseHandler) handleMessageType(message string) { - if len(h.receiveBuf) == 0 { - return - } - - // Hidden server commands starti with a dot "." - if h.receiveBuf[0] == '.' { +func (h *baseHandler) handleMessage(message string) { + if len(message) > 0 && message[0] == '.' { h.handleHiddenMessage(message) - h.receiveBuf = h.receiveBuf[:0] return } - logger.Raw(message) - h.receiveBuf = h.receiveBuf[:0] + dlog.Client.Raw(message) } // Handle messages received from server which are not meant to be displayed // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { - if strings.HasPrefix(message, ".syn close connection") { - h.SendMessage(".ack close connection") + switch { + case strings.HasPrefix(message, ".syn close connection"): + go h.SendMessage(".ack close connection") + h.Shutdown() } } + +func (h *baseHandler) Done() <-chan struct{} { + return h.done.Done() +} + +func (h *baseHandler) Shutdown() { + h.done.Shutdown() +} diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 2bcb038..27ac85e 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -2,7 +2,7 @@ package handlers import ( "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // ClientHandler is the basic client handler interface. @@ -12,7 +12,7 @@ type ClientHandler struct { // NewClientHandler creates a new client handler. func NewClientHandler(server string) *ClientHandler { - logger.Debug(server, "Creating new client handler") + dlog.Client.Debug(server, "Creating new client handler") return &ClientHandler{ baseHandler{ diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index 0440706..47b594e 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -1,88 +1,56 @@ package handlers import ( - "errors" - "fmt" - "time" + "strings" "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/protocol" ) -// HealthHandler implements the handler required for health checks. +// HealthHandler is the handler used on the client side for running mapreduce +// aggregations. type HealthHandler struct { - done *internal.Done - // Buffer of incoming data from server. - receiveBuf []byte - // To send commands to the server. - commands chan string - // To receive messages from the server. - receive chan<- string - // The remote server address - server string - // The return status. - status int -} - -// NewHealthHandler returns a new health check handler. -func NewHealthHandler(server string, receive chan<- string) *HealthHandler { - h := HealthHandler{ - server: server, - receive: receive, - commands: make(chan string), - status: -1, - done: internal.NewDone(), - } - - return &h -} - -// Server returns the remote server name. -func (h *HealthHandler) Server() string { - return h.server -} - -// Status of the handler. -func (h *HealthHandler) Status() int { - return h.status -} - -// Done returns done channel of the handler. -func (h *HealthHandler) Done() <-chan struct{} { - return h.done.Done() -} - -// Shutdown the handler. -func (h *HealthHandler) Shutdown() { - h.done.Shutdown() -} - -// SendMessage sends a DTail command to the server. -func (h *HealthHandler) SendMessage(command string) error { - select { - case h.commands <- fmt.Sprintf("%s;", command): - case <-time.NewTimer(time.Second * 10).C: - return errors.New("Timed out sending command " + command) - case <-h.Done(): + baseHandler +} + +// NewHealthHandler returns a new health client handler. +func NewHealthHandler(server string) *HealthHandler { + dlog.Client.Debug(server, "Creating new health handler") + return &HealthHandler{ + baseHandler: baseHandler{ + server: server, + shellStarted: false, + commands: make(chan string), + status: 2, // Assume CRITICAL status by default. + done: internal.NewDone(), + }, } - - return nil } -// Server writes byte stream to client. +// Read data from the dtail server via Writer interface. func (h *HealthHandler) Write(p []byte) (n int, err error) { for _, b := range p { - h.receiveBuf = append(h.receiveBuf, b) - if b == '\n' { - h.receive <- string(h.receiveBuf) - h.receiveBuf = h.receiveBuf[:0] + switch b { + case '\n', protocol.MessageDelimiter: + message := h.baseHandler.receiveBuf.String() + h.handleMessage(message) + h.baseHandler.receiveBuf.Reset() + default: + h.baseHandler.receiveBuf.WriteByte(b) } } - return len(p), nil } -// Server reads byte stream from client. -func (h *HealthHandler) Read(p []byte) (n int, err error) { - n = copy(p, []byte(<-h.commands)) - return +func (h *HealthHandler) handleMessage(message string) { + if len(message) > 0 && message[0] == '.' { + h.baseHandler.handleHiddenMessage(message) + return + } + s := strings.Split(message, protocol.FieldDelimiter) + message = s[len(s)-1] + if message == "OK" { + h.baseHandler.status = 0 + } } diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index fb71c8f..8718b35 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -4,21 +4,24 @@ import ( "strings" "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/client" + "github.com/mimecast/dtail/internal/protocol" ) -// MaprHandler is the handler used on the client side for running mapreduce aggregations. +// MaprHandler is the handler used on the client side for running mapreduce +// aggregations. type MaprHandler struct { baseHandler aggregate *client.Aggregate query *mapr.Query - count uint64 } // NewMaprHandler returns a new mapreduce client handler. -func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler { +func NewMaprHandler(server string, query *mapr.Query, + globalGroup *mapr.GlobalGroupSet) *MaprHandler { + return &MaprHandler{ baseHandler: baseHandler{ server: server, @@ -35,34 +38,35 @@ func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGr // Read data from the dtail server via Writer interface. func (h *MaprHandler) Write(p []byte) (n int, err error) { for _, b := range p { - h.baseHandler.receiveBuf = append(h.baseHandler.receiveBuf, b) - if b == '\n' { - if len(h.baseHandler.receiveBuf) == 0 { - continue - } - message := string(h.baseHandler.receiveBuf) - - if h.baseHandler.receiveBuf[0] == 'A' { - h.handleAggregateMessage(strings.TrimSpace(message)) - h.baseHandler.receiveBuf = h.baseHandler.receiveBuf[:0] - continue + switch b { + case '\n': + continue + case protocol.MessageDelimiter: + message := h.baseHandler.receiveBuf.String() + dlog.Client.Debug(message) + if message[0] == 'A' { + h.handleAggregateMessage(message) + } else { + h.baseHandler.handleMessage(message) } - h.baseHandler.handleMessageType(message) + h.baseHandler.receiveBuf.Reset() + default: + h.baseHandler.receiveBuf.WriteByte(b) } } return len(p), nil } -// Handle a message received from server including mapr aggregation -// related data. +// Handle a message received from server including mapr aggregation related data. func (h *MaprHandler) handleAggregateMessage(message string) { - h.count++ - parts := strings.Split(message, "➔") - - // Index 0 contains 'AGGREGATE', 1 contains server host. - // Aggregation data begins from index 2. - logger.Debug("Received aggregate data", h.server, h.count, parts) - h.aggregate.Aggregate(parts[2:]) - logger.Debug("Aggregated aggregate data", h.server, h.count) + parts := strings.SplitN(message, protocol.FieldDelimiter, 3) + if len(parts) != 3 { + dlog.Client.Error("Unable to aggregate data", h.server, message, parts, + len(parts), "expected 3 parts") + return + } + if err := h.aggregate.Aggregate(parts[2]); err != nil { + dlog.Client.Error("Unable to aggregate data", h.server, message, err) + } } diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index e93f6be..1a02827 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -4,93 +4,75 @@ import ( "context" "fmt" "runtime" - "strings" - "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/config" "github.com/mimecast/dtail/internal/omode" gossh "golang.org/x/crypto/ssh" ) -// HealthClient is used for health checking (e.g. via Nagios) +// HealthClient is used to perform a basic server health check. type HealthClient struct { - // Client operating mode - mode omode.Mode - // The remote server address - server string - // SSH user name - userName string - // SSH auth methods to use to connect to the remote servers. - sshAuthMethods []gossh.AuthMethod + baseClient } -// NewHealthClient returns a new healh client. -func NewHealthClient(mode omode.Mode) (*HealthClient, error) { +// NewHealthClient returns a new health client. +func NewHealthClient(args config.Args) (*HealthClient, error) { + args.Mode = omode.HealthClient + args.UserName = config.HealthUser c := HealthClient{ - mode: mode, - server: fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort), - userName: config.ControlUser, + baseClient: baseClient{ + Args: args, + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, } - c.initSSHAuthMethods() + c.init() + c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.HealthUser)) + c.makeConnections(c) return &c, nil } -// Start the health client. -func (c *HealthClient) Start(ctx context.Context) (status int) { - receive := make(chan string) - - throttleCh := make(chan struct{}, runtime.NumCPU()) - statsCh := make(chan struct{}, 1) - - conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods) - conn.Handler = handlers.NewHealthHandler(c.server, receive) - conn.Commands = []string{c.mode.String()} - - connCtx, cancel := context.WithCancel(ctx) - go conn.Start(connCtx, cancel, throttleCh, statsCh) +func (c HealthClient) makeHandler(server string) handlers.Handler { + return handlers.NewHealthHandler(server) +} - for { - select { - case data := <-receive: - // Parse recieved data. - s := strings.Split(data, "|") - message := s[len(s)-1] - if strings.HasPrefix(message, "done;") { - return - } +func (c HealthClient) makeCommands() (commands []string) { + commands = append(commands, "health") + return +} - // Set severity. - s = strings.Split(message, ":") - switch s[0] { - case "OK": - case "WARNING": - if status < 1 { - status = 1 - } - case "CRITICAL": - status = 2 - case "UNKNOWN": - status = 3 - default: - fmt.Printf("CRITICAL: Unexpected server response: '%s'\n", message) - status = 2 - return - } - fmt.Print(message) +// Start the health client. +func (c *HealthClient) Start(ctx context.Context, statsCh <-chan string) int { + status := c.baseClient.Start(ctx, statsCh) - case <-time.After(time.Second * 2): - status = 2 - fmt.Println("CRITICAL: Could not communicate with DTail server") - return + switch status { + case 0: + if c.Serverless { + fmt.Printf("WARNING: All seems fine but the check only run in serverless mode" + + ", please specify a remote server via --server hostname:port\n") + return 1 + } + fmt.Printf("OK: All fine at %s :-)\n", c.ServersStr) + case 2: + if c.Serverless { + fmt.Printf("CRITICAL: DTail server not operating properly (using " + + "serverless connction)!\n") + return 2 } + fmt.Printf("CRITICAL: DTail server not operating properly at %s!\n", + c.ServersStr) + default: + if c.Serverless { + fmt.Printf("UNKNOWN: Received unknown status code %d (using serverless "+ + "connection)\n", status) + return status + } + fmt.Printf("UNKNOWN: Received unknown status code %d from %s!\n", + status, c.ServersStr) } -} -// Initialize SSH auth methods. -func (c *HealthClient) initSSHAuthMethods() { - c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser)) + return status } diff --git a/internal/clients/maker.go b/internal/clients/maker.go index a1d6864..d5ffd8b 100644 --- a/internal/clients/maker.go +++ b/internal/clients/maker.go @@ -9,5 +9,5 @@ import ( // and send different commands to the DTail server. type maker interface { makeHandler(server string) handlers.Handler - makeCommands(options map[string]string) (commands []string) + makeCommands() (commands []string) } diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index feb7e47..246946f 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -9,7 +9,9 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/omode" ) @@ -29,25 +31,25 @@ const ( // MaprClient is used for running mapreduce aggregations on remote files. type MaprClient struct { baseClient - // Query string for mapr aggregations - queryStr string // Global group set for merged mapr aggregation results globalGroup *mapr.GlobalGroupSet // The query object (constructed from queryStr) query *mapr.Query // Additative result or new result every interval run? cumulative bool + // The last result string received + lastResult string } // NewMaprClient returns a new mapreduce client. -func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (*MaprClient, error) { - if queryStr == "" { +func NewMaprClient(args config.Args, maprClientMode MaprClientMode) (*MaprClient, error) { + if args.QueryStr == "" { return nil, errors.New("No mapreduce query specified, use '-query' flag") } - query, err := mapr.NewQuery(queryStr) + query, err := mapr.NewQuery(args.QueryStr) if err != nil { - logger.FatalExit(queryStr, "Can't parse mapr query", err) + dlog.Client.FatalPanic(args.QueryStr, "Can't parse mapr query", err) } // Don't retry connection if in tail mode and no outfile specified. @@ -64,7 +66,7 @@ func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (* cumulative = args.Mode == omode.MapClient || query.HasOutfile() } - logger.Debug("Cumulative mapreduce mode?", cumulative) + dlog.Client.Debug("Cumulative mapreduce mode?", cumulative) c := MaprClient{ baseClient: baseClient{ @@ -73,7 +75,6 @@ func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (* retry: retry, }, query: query, - queryStr: queryStr, cumulative: cumulative, } @@ -99,46 +100,51 @@ func (c *MaprClient) Start(ctx context.Context, statsCh <-chan string) (status i status = c.baseClient.Start(ctx, statsCh) if c.cumulative { - logger.Debug("Received final mapreduce result") + dlog.Client.Debug("Received final mapreduce result") c.reportResults() } return } +// NEXT: Make this a callback function rather trying to use polymorphism to call +// this. This applies to all clients. It will make the code easier to read. func (c MaprClient) makeHandler(server string) handlers.Handler { return handlers.NewMaprHandler(server, c.query, c.globalGroup) } -func (c MaprClient) makeCommands(options map[string]string) (commands []string) { +func (c MaprClient) makeCommands() (commands []string) { commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery)) - modeStr := "cat" if c.Mode == omode.TailClient { modeStr = "tail" } - optionsStr := c.commandOptionsToString(options) for _, file := range strings.Split(c.What, ",") { + regex, err := c.Regex.Serialize() + if err != nil { + dlog.Client.FatalPanic(err) + } if c.Timeout > 0 { - commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", c.Timeout, modeStr, file, c.Regex.Serialize())) + commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", c.Timeout, + modeStr, file, regex)) continue } - commands = append(commands, fmt.Sprintf("%s:%s %s %s", modeStr, optionsStr, file, c.Regex.Serialize())) + commands = append(commands, fmt.Sprintf("%s:%s %s %s", + modeStr, c.Args.SerializeOptions(), file, regex)) } - return } func (c *MaprClient) periodicReportResults(ctx context.Context) { rampUpSleep := c.query.Interval / 2 - logger.Debug("Ramp up sleeping before processing mapreduce results", rampUpSleep) + dlog.Client.Debug("Ramp up sleeping before processing mapreduce results", rampUpSleep) time.Sleep(rampUpSleep) for { select { case <-time.After(c.query.Interval): - logger.Debug("Gathering interim mapreduce result") + dlog.Client.Debug("Gathering interim mapreduce result") c.reportResults() case <-ctx.Done(): return @@ -151,42 +157,65 @@ func (c *MaprClient) reportResults() { c.writeResultsToOutfile() return } - c.printResults() } func (c *MaprClient) printResults() { var result string var err error - var numLines int + var numRows int + rowsLimit := -1 + + if c.query.Limit == -1 { + // Limit output to 10 rows when the result is printed to stdout. + // This can be overriden with the limit clause though. + rowsLimit = 10 + } if c.cumulative { - result, numLines, err = c.globalGroup.Result(c.query) + result, numRows, err = c.globalGroup.Result(c.query, rowsLimit) } else { - result, numLines, err = c.globalGroup.SwapOut().Result(c.query) + result, numRows, err = c.globalGroup.SwapOut().Result(c.query, rowsLimit) } if err != nil { - logger.FatalExit(err) + dlog.Client.FatalPanic(err) } - if numLines == 0 { - logger.Warn("Empty result set this time...") + if result == c.lastResult { + dlog.Client.Debug("Result hasn't changed compared to last time...") return } + c.lastResult = result - logger.Raw(fmt.Sprintf("%s\n", c.query.RawQuery)) - logger.Raw(result) + if numRows == 0 { + dlog.Client.Debug("Empty result set this time...") + return + } + + rawQuery := c.query.RawQuery + if config.Client.TermColorsEnable { + rawQuery = color.PaintStrWithAttr(rawQuery, + config.Client.TermColors.MaprTable.RawQueryFg, + config.Client.TermColors.MaprTable.RawQueryBg, + config.Client.TermColors.MaprTable.RawQueryAttr) + } + dlog.Client.Raw(rawQuery) + + if rowsLimit > 0 && numRows > rowsLimit { + dlog.Client.Warn(fmt.Sprintf("Got %d results but limited terminal output "+ + "to %d rows! Use 'limit' clause to override!", numRows, rowsLimit)) + } + dlog.Client.Raw(result) } func (c *MaprClient) writeResultsToOutfile() { if c.cumulative { if err := c.globalGroup.WriteResult(c.query); err != nil { - logger.FatalExit(err) + dlog.Client.FatalPanic(err) } return } - if err := c.globalGroup.SwapOut().WriteResult(c.query); err != nil { - logger.FatalExit(err) + dlog.Client.FatalPanic(err) } } diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go deleted file mode 100644 index b29ffed..0000000 --- a/internal/clients/remote/connection.go +++ /dev/null @@ -1,212 +0,0 @@ -package remote - -import ( - "context" - "fmt" - "io" - "strconv" - "strings" - "time" - - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/ssh/client" - - "golang.org/x/crypto/ssh" -) - -// Connection represents a client connection connection to a single server. -type Connection struct { - // The remote server's hostname connected to. - Server string - // The remote server's port connected to. - port int - // The SSH client configuration used. - config *ssh.ClientConfig - // The SSH client handler to use. - Handler handlers.Handler - // DTail commands sent from client to server. When client loses - // connection to the server it re-connects automatically and sends the - // same commands again. - Commands []string - // Is it a persistent connection or a one-off? - isOneOff bool - // To deal with SSH server host keys - hostKeyCallback client.HostKeyCallback - // To determine if connection throttling has finished or not - throttlingDone bool -} - -// NewConnection returns a new connection. -func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback) *Connection { - logger.Debug(server, "Creating new connection") - - c := Connection{ - hostKeyCallback: hostKeyCallback, - config: &ssh.ClientConfig{ - User: userName, - Auth: authMethods, - HostKeyCallback: hostKeyCallback.Wrap(), - Timeout: time.Second * 3, - }, - } - - c.initServerPort(server) - - return &c -} - -// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit). -func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection { - c := Connection{ - config: &ssh.ClientConfig{ - User: userName, - Auth: authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - }, - isOneOff: true, - } - - c.initServerPort(server) - - return &c -} - -// Attempt to parse the server port address from the provided server FQDN. -func (c *Connection) initServerPort(server string) { - c.Server = server - c.port = config.Common.SSHPort - parts := strings.Split(server, ":") - - if len(parts) == 2 { - logger.Debug("Parsing port from hostname", parts) - port, err := strconv.Atoi(parts[1]) - if err != nil { - logger.FatalExit("Unable to parse client port", server, parts, err) - } - c.Server = parts[0] - c.port = port - } -} - -// Start the server connection. Build up SSH session and send some DTail commands. -func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { - // Throttle how many connections can be established concurrently (based on ch length) - logger.Debug(c.Server, "Throttling connection", len(throttleCh), cap(throttleCh)) - - select { - case throttleCh <- struct{}{}: - case <-ctx.Done(): - logger.Debug(c.Server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh)) - return - } - - logger.Debug(c.Server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh)) - - go func() { - defer func() { - if !c.throttlingDone { - logger.Debug(c.Server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh)) - c.throttlingDone = true - <-throttleCh - } - cancel() - }() - - if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug(c.Server, "Not trusting host") - } - } - }() - - <-ctx.Done() -} - -// Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error { - logger.Debug(c.Server, "Incrementing connection stats") - statsCh <- struct{}{} - defer func() { - logger.Debug(c.Server, "Decrementing connection stats") - <-statsCh - }() - - logger.Debug(c.Server, "Dialing into the connection") - address := fmt.Sprintf("%s:%d", c.Server, c.port) - - client, err := ssh.Dial("tcp", address, c.config) - if err != nil { - return err - } - defer client.Close() - - return c.session(ctx, cancel, client, throttleCh) -} - -// Create the SSH session. Close the session in case of an error. -func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error { - logger.Debug(c.Server, "session") - - session, err := client.NewSession() - if err != nil { - return err - } - defer session.Close() - - return c.handle(ctx, cancel, session, throttleCh) -} - -func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error { - logger.Debug(c.Server, "handle") - - stdinPipe, err := session.StdinPipe() - if err != nil { - return err - } - - stdoutPipe, err := session.StdoutPipe() - if err != nil { - return err - } - - if err := session.Shell(); err != nil { - return err - } - - go func() { - io.Copy(stdinPipe, c.Handler) - cancel() - }() - - go func() { - io.Copy(c.Handler, stdoutPipe) - cancel() - }() - - go func() { - select { - case <-c.Handler.Done(): - case <-ctx.Done(): - } - cancel() - }() - - // Send all commands to client. - for _, command := range c.Commands { - logger.Debug(command) - c.Handler.SendMessage(command) - } - - if !c.throttlingDone { - logger.Debug(c.Server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh)) - c.throttlingDone = true - <-throttleCh - } - - <-ctx.Done() - c.Handler.Shutdown() - return nil -} diff --git a/internal/clients/stats.go b/internal/clients/stats.go index d8163d4..1315aea 100644 --- a/internal/clients/stats.go +++ b/internal/clients/stats.go @@ -8,14 +8,16 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/color" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/protocol" ) // Used to collect and display various client stats. type stats struct { // Total amount servers to connect to. - connectionsTotal int + servers int // To keep track of what connected and disconnected connectionsEstCh chan struct{} // Amount of servers connections are established. @@ -24,19 +26,20 @@ type stats struct { mutex sync.Mutex } -func newTailStats(connectionsTotal int) *stats { +func newTailStats(servers int) *stats { return &stats{ - connectionsTotal: connectionsTotal, - connectionsEstCh: make(chan struct{}, connectionsTotal), + servers: servers, + connectionsEstCh: make(chan struct{}, servers), connected: 0, } } // Start starts printing client connection stats every time a signal is recieved or // connection count has changed. -func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh <-chan string, quiet bool) { - var connectedLast int +func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, + statsCh <-chan string, quiet bool) { + var connectedLast int for { var force bool var messages []string @@ -54,18 +57,18 @@ func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh < throttle := len(throttleCh) newConnections := connected - connectedLast - if (connected == connectedLast || quiet) && !force { continue } - stats := s.statsLine(connected, newConnections, throttle) switch force { case true: + stats := s.statsLine(connected, newConnections, throttle) messages = append(messages, fmt.Sprintf("Connection stats: %s", stats)) s.printStatsDueInterrupt(messages) default: - logger.Info(stats) + data := s.statsData(connected, newConnections, throttle) + dlog.Client.Mapreduce("STATS", data) } connectedLast = connected @@ -76,30 +79,58 @@ func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh < } func (s *stats) printStatsDueInterrupt(messages []string) { - logger.Pause() - for _, message := range messages { + dlog.Client.Pause() + for i, message := range messages { + if i > 0 && config.Client.TermColorsEnable { + fmt.Println(color.PaintStrWithAttr(message, + config.Client.TermColors.Client.ClientFg, + config.Client.TermColors.Client.ClientBg, + config.Client.TermColors.Client.ClientAttr, + )) + continue + } fmt.Println(fmt.Sprintf(" %s", message)) } time.Sleep(time.Second * time.Duration(config.InterruptTimeoutS)) - logger.Resume() + dlog.Client.Resume() } -func (s *stats) statsLine(connected, newConnections int, throttle int) string { - percConnected := percentOf(float64(s.connectionsTotal), float64(connected)) +func (s *stats) statsData(connected, newConnections int, + throttle int) map[string]interface{} { + + percConnected := percentOf(float64(s.servers), float64(connected)) - var stats []string - stats = append(stats, fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected))) - stats = append(stats, fmt.Sprintf("new=%d", newConnections)) - stats = append(stats, fmt.Sprintf("throttle=%d", throttle)) - stats = append(stats, fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine())) + data := make(map[string]interface{}) + data["connected"] = connected + data["servers"] = s.servers + data["connected%"] = int(percConnected) + data["new"] = newConnections + data["throttle"] = throttle + data["goroutines"] = runtime.NumGoroutine() + data["cgocalls"] = runtime.NumCgoCall() + data["cpu"] = runtime.NumCPU() - return strings.Join(stats, "|") + return data +} + +func (s *stats) statsLine(connected, newConnections int, throttle int) string { + sb := strings.Builder{} + i := 0 + for k, v := range s.statsData(connected, newConnections, throttle) { + if i > 0 { + sb.WriteString(protocol.FieldDelimiter) + } + sb.WriteString(k) + sb.WriteByte('=') + sb.WriteString(fmt.Sprintf("%v", v)) + i++ + } + return sb.String() } func (s *stats) numConnected() int { s.mutex.Lock() defer s.mutex.Unlock() - return s.connected } @@ -107,6 +138,5 @@ func percentOf(total float64, value float64) float64 { if total == 0 || total == value { return 100 } - return value / (total / 100.0) } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 853ef1d..35c01d4 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -6,7 +6,8 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" ) @@ -16,9 +17,8 @@ type TailClient struct { } // NewTailClient returns a new TailClient. -func NewTailClient(args Args) (*TailClient, error) { +func NewTailClient(args config.Args) (*TailClient, error) { args.Mode = omode.TailClient - c := TailClient{ baseClient: baseClient{ Args: args, @@ -29,7 +29,6 @@ func NewTailClient(args Args) (*TailClient, error) { c.init() c.makeConnections(c) - return &c, nil } @@ -37,12 +36,15 @@ func (c TailClient) makeHandler(server string) handlers.Handler { return handlers.NewClientHandler(server) } -func (c TailClient) makeCommands(options map[string]string) (commands []string) { - optionsStr := c.commandOptionsToString(options) +func (c TailClient) makeCommands() (commands []string) { + regex, err := c.Regex.Serialize() + if err != nil { + dlog.Client.FatalPanic(err) + } for _, file := range strings.Split(c.What, ",") { - commands = append(commands, fmt.Sprintf("%s:%s %s %s", c.Mode.String(), optionsStr, file, c.Regex.Serialize())) + commands = append(commands, fmt.Sprintf("%s:%s %s %s", + c.Mode.String(), c.Args.SerializeOptions(), file, regex)) } - logger.Debug(commands) - + dlog.Client.Debug(commands) return } diff --git a/internal/color/brush/brush.go b/internal/color/brush/brush.go new file mode 100644 index 0000000..63d63d8 --- /dev/null +++ b/internal/color/brush/brush.go @@ -0,0 +1,194 @@ +package brush + +import ( + "strings" + + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/protocol" +) + +func paintSeverity(sb *strings.Builder, text string) bool { + switch { + case strings.HasPrefix(text, "WARN"): + color.PaintWithAttr(sb, text, + config.Client.TermColors.Common.SeverityWarnFg, + config.Client.TermColors.Common.SeverityWarnBg, + config.Client.TermColors.Common.SeverityWarnAttr) + + case strings.HasPrefix(text, "ERROR"): + color.PaintWithAttr(sb, text, + config.Client.TermColors.Common.SeverityErrorFg, + config.Client.TermColors.Common.SeverityErrorBg, + config.Client.TermColors.Common.SeverityErrorAttr) + + case strings.HasPrefix(text, "FATAL"): + color.PaintWithAttr(sb, text, + config.Client.TermColors.Common.SeverityFatalFg, + config.Client.TermColors.Common.SeverityFatalBg, + config.Client.TermColors.Common.SeverityFatalAttr) + + default: + return false + } + return true +} + +func paintRemote(sb *strings.Builder, line string) { + splitted := strings.SplitN(line, protocol.FieldDelimiter, 6) + + color.PaintWithAttr(sb, splitted[0], + config.Client.TermColors.Remote.RemoteFg, + config.Client.TermColors.Remote.RemoteBg, + config.Client.TermColors.Remote.RemoteAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Remote.DelimiterFg, + config.Client.TermColors.Remote.DelimiterBg, + config.Client.TermColors.Remote.DelimiterAttr) + + color.PaintWithAttr(sb, splitted[1], + config.Client.TermColors.Remote.HostnameFg, + config.Client.TermColors.Remote.HostnameBg, + config.Client.TermColors.Remote.HostnameAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Remote.DelimiterFg, + config.Client.TermColors.Remote.DelimiterBg, + config.Client.TermColors.Remote.DelimiterAttr) + + if splitted[2] == "100" { + color.PaintWithAttr(sb, splitted[2], + config.Client.TermColors.Remote.StatsOkFg, + config.Client.TermColors.Remote.StatsOkBg, + config.Client.TermColors.Remote.StatsOkAttr) + } else { + color.PaintWithAttr(sb, splitted[2], + config.Client.TermColors.Remote.StatsWarnFg, + config.Client.TermColors.Remote.StatsWarnBg, + config.Client.TermColors.Remote.StatsWarnAttr) + } + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Remote.DelimiterFg, + config.Client.TermColors.Remote.DelimiterBg, + config.Client.TermColors.Remote.DelimiterAttr) + + color.PaintWithAttr(sb, splitted[3], + config.Client.TermColors.Remote.CountFg, + config.Client.TermColors.Remote.CountBg, + config.Client.TermColors.Remote.CountAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Remote.DelimiterFg, + config.Client.TermColors.Remote.DelimiterBg, + config.Client.TermColors.Remote.DelimiterAttr) + + color.PaintWithAttr(sb, splitted[4], + config.Client.TermColors.Remote.IDFg, + config.Client.TermColors.Remote.IDBg, + config.Client.TermColors.Remote.IDAttr) + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Remote.DelimiterFg, + config.Client.TermColors.Remote.DelimiterBg, + config.Client.TermColors.Remote.DelimiterAttr) + + if paintSeverity(sb, splitted[5]) { + return + } + color.PaintWithAttr(sb, splitted[5], + config.Client.TermColors.Remote.TextFg, + config.Client.TermColors.Remote.TextBg, + config.Client.TermColors.Remote.TextAttr) +} + +func paintClient(sb *strings.Builder, line string) { + splitted := strings.SplitN(line, protocol.FieldDelimiter, 3) + + color.PaintWithAttr(sb, splitted[0], + config.Client.TermColors.Client.ClientFg, + config.Client.TermColors.Client.ClientBg, + config.Client.TermColors.Client.ClientAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Client.DelimiterFg, + config.Client.TermColors.Client.DelimiterBg, + config.Client.TermColors.Client.DelimiterAttr) + + color.PaintWithAttr(sb, splitted[1], + config.Client.TermColors.Client.HostnameFg, + config.Client.TermColors.Client.HostnameBg, + config.Client.TermColors.Client.HostnameAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Client.DelimiterFg, + config.Client.TermColors.Client.DelimiterBg, + config.Client.TermColors.Client.DelimiterAttr) + + if paintSeverity(sb, splitted[2]) { + return + } + + color.PaintWithAttr(sb, splitted[2], + config.Client.TermColors.Client.TextFg, + config.Client.TermColors.Client.TextBg, + config.Client.TermColors.Client.TextAttr) +} + +func paintServer(sb *strings.Builder, line string) { + splitted := strings.SplitN(line, protocol.FieldDelimiter, 3) + + color.PaintWithAttr(sb, splitted[0], + config.Client.TermColors.Server.ServerFg, + config.Client.TermColors.Server.ServerBg, + config.Client.TermColors.Server.ServerAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Server.DelimiterFg, + config.Client.TermColors.Server.DelimiterBg, + config.Client.TermColors.Server.DelimiterAttr) + + color.PaintWithAttr(sb, splitted[1], + config.Client.TermColors.Server.HostnameFg, + config.Client.TermColors.Server.HostnameBg, + config.Client.TermColors.Server.HostnameAttr) + + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.Server.DelimiterFg, + config.Client.TermColors.Server.DelimiterBg, + config.Client.TermColors.Server.DelimiterAttr) + + if paintSeverity(sb, splitted[2]) { + return + } + + color.PaintWithAttr(sb, splitted[2], + config.Client.TermColors.Server.TextFg, + config.Client.TermColors.Server.TextBg, + config.Client.TermColors.Server.TextAttr) +} + +// Colorfy a given line based on the line's content. +func Colorfy(line string) string { + sb := pool.BuilderBuffer.Get().(*strings.Builder) + defer pool.RecycleBuilderBuffer(sb) + + switch { + case strings.HasPrefix(line, "REMOTE"): + paintRemote(sb, line) + + case strings.HasPrefix(line, "CLIENT"): + paintClient(sb, line) + + case strings.HasPrefix(line, "SERVER"): + paintServer(sb, line) + + default: + color.PaintWithAttr(sb, line, + color.FgDefault, + color.BgDefault, + color.AttrNone) + } + return sb.String() +} diff --git a/internal/color/color.go b/internal/color/color.go index 0736199..9d0bc2e 100644 --- a/internal/color/color.go +++ b/internal/color/color.go @@ -1,70 +1,148 @@ -// Package color is used to prettify console output via ANSII terminal colors. +// Package color contains all terminal color codes we know of. package color import ( "fmt" + "strings" ) -// Color name. -type Color string +// FgColor is the text foreground color. +type FgColor string -// Attribute of a color. +// BgColor is the text background color. +type BgColor string + +// Attribute of text. type Attribute string // The possible color variations. const ( - escape = "\x1b" - reset = escape + "[0m" - seq string = "%s%s%s" + escape = "\x1b" - Gray Color = escape + "[30m" - Red Color = escape + "[31m" - Green Color = escape + "[32m" - Orange Color = escape + "[33m" - Blue Color = escape + "[34m" - Magenta Color = escape + "[35m" - Yellow Color = escape + "[36m" - LightGray Color = escape + "[37m" + FgBlack FgColor = escape + "[30m" + FgRed FgColor = escape + "[31m" + FgGreen FgColor = escape + "[32m" + FgYellow FgColor = escape + "[33m" + FgBlue FgColor = escape + "[34m" + FgMagenta FgColor = escape + "[35m" + FgCyan FgColor = escape + "[36m" + FgWhite FgColor = escape + "[37m" + FgDefault FgColor = escape + "[39m" - BgGray Color = escape + "[40m" - BgRed Color = escape + "[41m" - BgGreen Color = escape + "[42m" - BgOrange Color = escape + "[43m" - BgBlue Color = escape + "[44m" - BgMagenta Color = escape + "[45m" - BgYellow Color = escape + "[46m" - BgLightGray Color = escape + "[47m" + BgBlack BgColor = escape + "[40m" + BgRed BgColor = escape + "[41m" + BgGreen BgColor = escape + "[42m" + BgYellow BgColor = escape + "[43m" + BgBlue BgColor = escape + "[44m" + BgMagenta BgColor = escape + "[45m" + BgCyan BgColor = escape + "[46m" + BgWhite BgColor = escape + "[47m" + BgDefault BgColor = escape + "[49m" - Bold Attribute = escape + "[1m" - Italic Attribute = escape + "[3m" - Underline Attribute = escape + "[4m" - ReverseColor Attribute = escape + "[7m" + AttrNone Attribute = "" + AttrReset Attribute = escape + "[0m" + AttrBold Attribute = escape + "[1m" + AttrDim Attribute = escape + "[2m" + AttrItalic Attribute = escape + "[3m" + AttrUnderline Attribute = escape + "[4m" + AttrBlink Attribute = escape + "[5m" + AttrSlowBlink Attribute = escape + "[5m" + AttrRapidBlink Attribute = escape + "[6m" + AttrReverse Attribute = escape + "[7m" + AttrHidden Attribute = escape + "[8m" +) - resetBold = escape + "[22m" - resetItalic = escape + "[23m" - resetUnderline = escape + "[24m" +// ColorNames is the list of all supported terminal colors. +var ColorNames = []string{ + "Black", "Red", "Green", "Yellow", "Blue", "Magenta", "Cyan", "White", "Default", +} - Test Color = BgYellow - TestAttr Attribute = Bold -) +// AttributeNames is the list of all supported terminal text attributes. +var AttributeNames = []string{ + "Bold", "Dim", "Italic", "Underline", "Blink", "SlowBlink", "RapidBlink", + "Reverse", "Hidden", "None", +} -// Colored DTail client output enabled. -var Colored bool +// ToFgColor converts a given string (e.g. from a config file) into a foreground +// color code. +func ToFgColor(s string) (FgColor, error) { + switch strings.ToLower(s) { + case "black": + return FgBlack, nil + case "red": + return FgRed, nil + case "green": + return FgGreen, nil + case "yellow": + return FgYellow, nil + case "blue": + return FgBlue, nil + case "magenta": + return FgMagenta, nil + case "cyan": + return FgCyan, nil + case "white": + return FgWhite, nil + case "default": + return FgDefault, nil + default: + return FgDefault, fmt.Errorf("unknown foreground text color '" + s + "'") + } +} -// Paint a given string in a given color. -func Paint(c Color, s string) string { - return fmt.Sprintf(seq, c, s, reset) +// ToBgColor converts a given string (e.g. from a config file) into a background +// color code. +func ToBgColor(s string) (BgColor, error) { + switch strings.ToLower(s) { + case "black": + return BgBlack, nil + case "red": + return BgRed, nil + case "green": + return BgGreen, nil + case "yellow": + return BgYellow, nil + case "blue": + return BgBlue, nil + case "magenta": + return BgMagenta, nil + case "cyan": + return BgCyan, nil + case "white": + return BgWhite, nil + case "default": + return BgDefault, nil + default: + return BgDefault, fmt.Errorf("unknown background text color '" + s + "'") + } } -// Attr adds a given attribute to a given string, such as "bold" or "italic". -func Attr(c Attribute, s string) string { - switch c { - case Bold: - return fmt.Sprintf(seq, Bold, s, resetBold) - case Italic: - return fmt.Sprintf(seq, Italic, s, resetItalic) - case Underline: - return fmt.Sprintf(seq, Underline, s, resetUnderline) +// ToAttribute converts a given string (e.g. from a config file) into a text attribute. +func ToAttribute(s string) (Attribute, error) { + switch strings.ToLower(s) { + case "bold": + return AttrBold, nil + case "dim": + return AttrDim, nil + case "italic": + return AttrItalic, nil + case "underline": + return AttrUnderline, nil + case "blink": + return AttrBlink, nil + case "slowblink": + return AttrSlowBlink, nil + case "rapidblink": + return AttrRapidBlink, nil + case "reverse": + return AttrReverse, nil + case "hidden": + return AttrHidden, nil + case "none": + fallthrough + case "": + return AttrNone, nil + default: + return AttrNone, fmt.Errorf("unknown text attribute '" + s + "'") } - panic("Unknown attribute") } diff --git a/internal/color/color_test.go b/internal/color/color_test.go new file mode 100644 index 0000000..7002052 --- /dev/null +++ b/internal/color/color_test.go @@ -0,0 +1,53 @@ +package color + +import ( + "strings" + "testing" +) + +func TestColors(t *testing.T) { + text := " Mimecast " + builder := strings.Builder{} + + for _, color := range ColorNames { + fgColor, err := ToFgColor(color) + if err != nil { + t.Errorf("unable to paint foreground : %s\n%v", text, err) + } + builder.WriteString(PaintStrFg(text, fgColor)) + + bgColor, err := ToBgColor(color) + if err != nil { + t.Errorf("unable to paint background: %s\n%v", text, err) + } + builder.WriteString(PaintStrBg(text, bgColor)) + } + + for _, fg := range ColorNames { + fgColor, _ := ToFgColor(fg) + for _, bg := range ColorNames { + if fg == bg { + continue + } + bgColor, _ := ToBgColor(bg) + builder.WriteString(PaintStr(text, fgColor, bgColor)) + } + } + + t.Log(builder.String()) +} + +func TestAttributes(t *testing.T) { + text := " Mimecast " + builder := strings.Builder{} + + for _, attribute := range AttributeNames { + att, err := ToAttribute(attribute) + if err != nil { + t.Errorf("unable to paint attribute: %s\n%v", text, err) + } + builder.WriteString(PaintStrWithAttr(text, FgWhite, BgBlue, att)) + } + + t.Log(builder.String()) +} diff --git a/internal/color/colorfy.go b/internal/color/colorfy.go deleted file mode 100644 index a2beb7a..0000000 --- a/internal/color/colorfy.go +++ /dev/null @@ -1,56 +0,0 @@ -package color - -import ( - "fmt" - "strings" -) - -// Add some color to log lines received from remote servers. -func paintRemote(line string) string { - splitted := strings.Split(line, "|") - if splitted[2] == "100" { - splitted[2] = Paint(BgGreen, splitted[2]) - } else { - splitted[2] = Paint(BgRed, splitted[2]) - } - info := strings.Join(splitted[0:5], "|") - log := strings.Join(splitted[5:], "|") - - if strings.HasPrefix(log, "WARN") { - log = Paint(BgYellow, log) - } else if strings.HasPrefix(log, "ERROR") { - log = Paint(BgRed, log) - } else if strings.HasPrefix(log, "FATAL") { - log = Attr(Bold, Paint(BgRed, log)) - } else { - log = Paint(Blue, log) - } - - return fmt.Sprintf("%s|%s", info, log) -} - -// Add some color to stats generated by the client. -func paintClientStats(line string) string { - splitted := strings.Split(line, "|") - first := strings.Join(splitted[0:4], "|") - connected := Paint(BgBlue, splitted[4]) - last := strings.Join(splitted[5:], "|") - - return fmt.Sprintf("%s|%s|%s", first, connected, last) -} - -// Colorfy a given line based on the line's content. -func Colorfy(line string) string { - switch { - case strings.HasPrefix(line, "REMOTE"): - return paintRemote(line) - case strings.HasPrefix(line, "CLIENT") && strings.Contains(line, "|stats|"): - return paintClientStats(line) - case strings.Contains(line, "ERROR"): - return Paint(Magenta, line) - case strings.Contains(line, "WARN"): - return Paint(Magenta, line) - } - - return line -} diff --git a/internal/color/paint.go b/internal/color/paint.go new file mode 100644 index 0000000..7735d87 --- /dev/null +++ b/internal/color/paint.go @@ -0,0 +1,91 @@ +package color + +import ( + "fmt" + "strings" +) + +// PaintStr paints a given text in a given foreground/background color combination. +func PaintStr(text string, fg FgColor, bg BgColor) string { + return fmt.Sprintf("%s%s%s%s%s", fg, bg, text, BgDefault, FgDefault) +} + +// PaintStrWithAttr paints a given text in a given foreground/background/attribute +// combination +func PaintStrWithAttr(text string, fg FgColor, bg BgColor, attr Attribute) string { + if attr == AttrNone { + return PaintStr(text, fg, bg) + } + return fmt.Sprintf("%s%s%s%s%s%s%s", fg, bg, attr, text, AttrReset, + BgDefault, FgDefault) +} + +// PaintStrFg paints a given text in a given foreground color. +func PaintStrFg(text string, fg FgColor) string { + return fmt.Sprintf("%s%s%s", fg, text, FgDefault) +} + +// PaintStrBg paints a given text in a given background color. +func PaintStrBg(text string, bg BgColor) string { + return fmt.Sprintf("%s%s%s", bg, text, BgDefault) +} + +// PaintStrAttr adds a given attribute to a given text, such as "bold" or "italic". +func PaintStrAttr(text string, attr Attribute) string { + return fmt.Sprintf("%s%s%s", attr, text, AttrReset) +} + +// Paint paints a given text in a given foreground/background color combination. +func Paint(sb *strings.Builder, text string, fg FgColor, bg BgColor) { + sb.WriteString(string(fg)) + sb.WriteString(string(bg)) + sb.WriteString(text) + sb.WriteString(string(BgDefault)) + sb.WriteString(string(FgDefault)) +} + +// Reset background and foreground colors. +func Reset(sb *strings.Builder) { + sb.WriteString(string(BgDefault)) + sb.WriteString(string(FgDefault)) +} + +// PaintWithAttr starts painting a given text in a given foreground/background/ +// attribute combination. +func PaintWithAttr(sb *strings.Builder, text string, fg FgColor, bg BgColor, + attr Attribute) { + + if attr == AttrNone { + Paint(sb, text, fg, bg) + return + } + sb.WriteString(string(fg)) + sb.WriteString(string(bg)) + sb.WriteString(string(attr)) + sb.WriteString(text) + sb.WriteString(string(AttrReset)) + sb.WriteString(string(BgDefault)) + sb.WriteString(string(FgDefault)) +} + +// PaintWithAttrs is similar to PaintWithAttr, but it takes multiple attributes. +func PaintWithAttrs(sb *strings.Builder, text string, fg FgColor, bg BgColor, + attrs []Attribute) { + + sb.WriteString(string(fg)) + sb.WriteString(string(bg)) + for _, attr := range attrs { + sb.WriteString(string(attr)) + } + sb.WriteString(text) + sb.WriteString(string(AttrReset)) + sb.WriteString(string(BgDefault)) + sb.WriteString(string(FgDefault)) +} + +// ResetWithAttr resets background, foreground and attributes. +func ResetWithAttr(sb *strings.Builder) { + sb.WriteString(string(AttrReset)) + sb.WriteString(string(BgDefault)) + sb.WriteString(string(FgDefault)) +} diff --git a/internal/color/table.go b/internal/color/table.go new file mode 100644 index 0000000..e0e4946 --- /dev/null +++ b/internal/color/table.go @@ -0,0 +1,53 @@ +package color + +import ( + "fmt" + "os" +) + +const sampleParagraph string = "Mimecast is Making Email Safer for Business. " + + "We believe that securely operating a business in the cloud requires new " + + "levels of IT preparedness, centered around cyber resilience. This is why " + + "we unify the delivery and management of security, continuity and data " + + "protection for email via one, simple-to-use cloud platform. Thousands of " + + "organizations trust us to increase their cyber resilience preparedness, " + + "streamline compliance, reduce IT complexity and keep their business running. " + + "We give employees fast and secure access to sensitive business information, " + + "and ensure email keeps flowing in the event of an outage. Mimecast will " + + "remain committed to protecting your IT assets through constant innovation " + + "and focus on your success." + +// TablePrintAndExit prints the color table and then exits the process. +func TablePrintAndExit(displaySampleParagraph bool) { + for _, attr := range AttributeNames { + if attr == "Hidden" || attr == "SlowBlink" { + continue + } + printColorTable(attr, displaySampleParagraph) + } + os.Exit(0) +} + +func printColorTable(attr string, displaySampleParagraph bool) { + for _, fg := range ColorNames { + fgColor, _ := ToFgColor(fg) + for _, bg := range ColorNames { + if fg == bg { + continue + } + + bgColor, _ := ToBgColor(bg) + attribute, _ := ToAttribute(attr) + text := fmt.Sprintf(" Foreground:%10s | Background:%10s | Attribute:%10s ", + fg, bg, attr) + fmt.Print(PaintStrWithAttr(text, fgColor, bgColor, attribute)) + + if displaySampleParagraph { + fmt.Print("\n") + fmt.Print(PaintStrWithAttr(sampleParagraph, fgColor, bgColor, attribute)) + fmt.Print("\n") + } + fmt.Print("\n") + } + } +} diff --git a/internal/config/args.go b/internal/config/args.go new file mode 100644 index 0000000..3d7ac7d --- /dev/null +++ b/internal/config/args.go @@ -0,0 +1,164 @@ +package config + +import ( + "encoding/base64" + "fmt" + "strconv" + "strings" + + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/omode" + + gossh "golang.org/x/crypto/ssh" +) + +// Args is a helper struct to summarize common client arguments. +type Args struct { + lcontext.LContext + Arguments []string + ConfigFile string + ConnectionsPerCPU int + Discovery string + LogDir string + Logger string + LogLevel string + Mode omode.Mode + NoColor bool + PrivateKeyPathFile string + QueryStr string + Quiet bool + RegexInvert bool + RegexStr string + Serverless bool + ServersStr string + Spartan bool + SSHAuthMethods []gossh.AuthMethod + SSHBindAddress string + SSHHostKeyCallback gossh.HostKeyCallback + SSHPort int + Timeout int + TrustAllHosts bool + UserName string + What string +} + +func (a *Args) String() string { + var sb strings.Builder + + sb.WriteString("Args(") + + sb.WriteString(fmt.Sprintf("%s:%v,", "Arguments", a.Arguments)) + sb.WriteString(fmt.Sprintf("%s:%v,", "ConfigFile", a.ConfigFile)) + sb.WriteString(fmt.Sprintf("%s:%v,", "ConnectionsPerCPU", a.ConnectionsPerCPU)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Discovery", a.Discovery)) + sb.WriteString(fmt.Sprintf("%s:%v,", "LogDir", a.LogDir)) + sb.WriteString(fmt.Sprintf("%s:%v,", "LogLevel", a.LogLevel)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Logger", a.Logger)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Mode", a.Mode)) + sb.WriteString(fmt.Sprintf("%s:%v,", "NoColor", a.NoColor)) + sb.WriteString(fmt.Sprintf("%s:%v,", "PrivateKeyPathFile", a.PrivateKeyPathFile)) + sb.WriteString(fmt.Sprintf("%s:%v,", "QueryStr", a.QueryStr)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Quiet", a.Quiet)) + sb.WriteString(fmt.Sprintf("%s:%v,", "RegexInvert", a.RegexInvert)) + sb.WriteString(fmt.Sprintf("%s:%v,", "RegexStr", a.RegexStr)) + sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAuthMethods", a.SSHAuthMethods)) + sb.WriteString(fmt.Sprintf("%s:%v,", "SSHBindAddress", a.SSHBindAddress)) + sb.WriteString(fmt.Sprintf("%s:%v,", "SSHHostKeyCallback", a.SSHHostKeyCallback)) + sb.WriteString(fmt.Sprintf("%s:%v,", "SSHPort", a.SSHPort)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Serverless", a.Serverless)) + sb.WriteString(fmt.Sprintf("%s:%v,", "ServersStr", a.ServersStr)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Spartan", a.Spartan)) + sb.WriteString(fmt.Sprintf("%s:%v,", "Timeout", a.Timeout)) + sb.WriteString(fmt.Sprintf("%s:%v,", "TrustAllHosts", a.TrustAllHosts)) + sb.WriteString(fmt.Sprintf("%s:%v,", "UserName", a.UserName)) + sb.WriteString(fmt.Sprintf("%s:%v", "What", a.What)) + sb.WriteString(")") + + return sb.String() +} + +// SerializeOptions returns a string ready to be sent over the wire to the server. +func (a *Args) SerializeOptions() string { + options := make(map[string]string) + + if a.Quiet { + options["quiet"] = fmt.Sprintf("%v", a.Quiet) + } + if a.Spartan { + options["spartan"] = fmt.Sprintf("%v", a.Spartan) + } + if a.Serverless { + options["serverless"] = fmt.Sprintf("%v", a.Serverless) + } + if a.LContext.MaxCount != 0 { + options["max"] = fmt.Sprintf("%d", a.LContext.MaxCount) + } + if a.LContext.BeforeContext != 0 { + options["before"] = fmt.Sprintf("%d", a.LContext.BeforeContext) + } + if a.LContext.AfterContext != 0 { + options["after"] = fmt.Sprintf("%d", a.LContext.AfterContext) + } + + var sb strings.Builder + var i int + for k, v := range options { + if i > 0 { + sb.WriteString(":") + } + sb.WriteString(k) + sb.WriteString("=") + sb.WriteString(v) + i++ + } + return sb.String() +} + +// DeserializeOptions deserializes the options, but into a map. +func DeserializeOptions(opts []string) (map[string]string, lcontext.LContext, error) { + options := make(map[string]string, len(opts)) + var ltx lcontext.LContext + + for _, o := range opts { + kv := strings.SplitN(o, "=", 2) + if len(kv) != 2 { + return options, ltx, fmt.Errorf("Unable to parse options: %v", kv) + } + key := kv[0] + val := kv[1] + + if strings.HasPrefix(val, "base64%") { + s := strings.SplitN(val, "%", 2) + decoded, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return options, ltx, err + } + val = string(decoded) + } + + switch key { + case "before": + iVal, err := strconv.Atoi(val) + if err != nil { + return options, ltx, err + } + ltx.BeforeContext = iVal + case "after": + iVal, err := strconv.Atoi(val) + if err != nil { + return options, ltx, err + } + ltx.AfterContext = iVal + case "max": + iVal, err := strconv.Atoi(val) + if err != nil { + return options, ltx, err + } + ltx.MaxCount = iVal + default: + options[key] = val + } + } + + return options, ltx, nil +} diff --git a/internal/config/client.go b/internal/config/client.go index 1515aae..86f97f0 100644 --- a/internal/config/client.go +++ b/internal/config/client.go @@ -1,11 +1,203 @@ package config +import "github.com/mimecast/dtail/internal/color" + +type remoteTermColors struct { + DelimiterAttr color.Attribute + DelimiterBg color.BgColor + DelimiterFg color.FgColor + RemoteAttr color.Attribute + RemoteBg color.BgColor + RemoteFg color.FgColor + CountAttr color.Attribute + CountBg color.BgColor + CountFg color.FgColor + HostnameAttr color.Attribute + HostnameBg color.BgColor + HostnameFg color.FgColor + IDAttr color.Attribute + IDBg color.BgColor + IDFg color.FgColor + StatsOkAttr color.Attribute + StatsOkBg color.BgColor + StatsOkFg color.FgColor + StatsWarnAttr color.Attribute + StatsWarnBg color.BgColor + StatsWarnFg color.FgColor + TextAttr color.Attribute + TextBg color.BgColor + TextFg color.FgColor +} + +type clientTermColors struct { + DelimiterAttr color.Attribute + DelimiterBg color.BgColor + DelimiterFg color.FgColor + ClientAttr color.Attribute + ClientBg color.BgColor + ClientFg color.FgColor + HostnameAttr color.Attribute + HostnameBg color.BgColor + HostnameFg color.FgColor + TextAttr color.Attribute + TextBg color.BgColor + TextFg color.FgColor +} + +type serverTermColors struct { + DelimiterAttr color.Attribute + DelimiterBg color.BgColor + DelimiterFg color.FgColor + ServerAttr color.Attribute + ServerBg color.BgColor + ServerFg color.FgColor + HostnameAttr color.Attribute + HostnameBg color.BgColor + HostnameFg color.FgColor + TextAttr color.Attribute + TextBg color.BgColor + TextFg color.FgColor +} + +type commonTermColors struct { + SeverityErrorAttr color.Attribute + SeverityErrorBg color.BgColor + SeverityErrorFg color.FgColor + SeverityFatalAttr color.Attribute + SeverityFatalBg color.BgColor + SeverityFatalFg color.FgColor + SeverityWarnAttr color.Attribute + SeverityWarnBg color.BgColor + SeverityWarnFg color.FgColor +} + +type maprTableTermColors struct { + DataAttr color.Attribute + DataBg color.BgColor + DataFg color.FgColor + DelimiterAttr color.Attribute + DelimiterBg color.BgColor + DelimiterFg color.FgColor + HeaderAttr color.Attribute + HeaderBg color.BgColor + HeaderDelimiterAttr color.Attribute + HeaderDelimiterBg color.BgColor + HeaderDelimiterFg color.FgColor + HeaderFg color.FgColor + HeaderGroupKeyAttr color.Attribute + HeaderSortKeyAttr color.Attribute + RawQueryAttr color.Attribute + RawQueryBg color.BgColor + RawQueryFg color.FgColor +} + +type termColors struct { + Remote remoteTermColors + Client clientTermColors + Server serverTermColors + Common commonTermColors + MaprTable maprTableTermColors +} + // ClientConfig represents a DTail client configuration (empty as of now as there // are no available config options yet, but that may changes in the future). type ClientConfig struct { + TermColorsEnable bool `json:",omitempty"` + TermColors termColors `json:",omitempty"` + // When unit testing in Jenkins you don't want to touch files in ~jenkins + // during integration tests really. + SSHDontAddHostsToKnownHostsFile bool `json:",omitempty"` } // Create a new default client configuration. func newDefaultClientConfig() *ClientConfig { - return &ClientConfig{} + return &ClientConfig{ + TermColorsEnable: true, + TermColors: termColors{ + Remote: remoteTermColors{ + DelimiterAttr: color.AttrDim, + DelimiterBg: color.BgBlue, + DelimiterFg: color.FgCyan, + RemoteAttr: color.AttrDim, + RemoteBg: color.BgBlue, + RemoteFg: color.FgWhite, + CountAttr: color.AttrDim, + CountBg: color.BgBlue, + CountFg: color.FgWhite, + HostnameAttr: color.AttrBold, + HostnameBg: color.BgBlue, + HostnameFg: color.FgWhite, + IDAttr: color.AttrDim, + IDBg: color.BgBlue, + IDFg: color.FgWhite, + StatsOkAttr: color.AttrNone, + StatsOkBg: color.BgGreen, + StatsOkFg: color.FgBlack, + StatsWarnAttr: color.AttrNone, + StatsWarnBg: color.BgRed, + StatsWarnFg: color.FgWhite, + TextAttr: color.AttrNone, + TextBg: color.BgBlack, + TextFg: color.FgWhite, + }, + Client: clientTermColors{ + DelimiterAttr: color.AttrDim, + DelimiterBg: color.BgYellow, + DelimiterFg: color.FgBlack, + ClientAttr: color.AttrDim, + ClientBg: color.BgYellow, + ClientFg: color.FgBlack, + HostnameAttr: color.AttrDim, + HostnameBg: color.BgYellow, + HostnameFg: color.FgBlack, + TextAttr: color.AttrNone, + TextBg: color.BgBlack, + TextFg: color.FgWhite, + }, + Server: serverTermColors{ + DelimiterAttr: color.AttrDim, + DelimiterBg: color.BgCyan, + DelimiterFg: color.FgBlack, + ServerAttr: color.AttrDim, + ServerBg: color.BgCyan, + ServerFg: color.FgBlack, + HostnameAttr: color.AttrBold, + HostnameBg: color.BgCyan, + HostnameFg: color.FgBlack, + TextAttr: color.AttrNone, + TextBg: color.BgBlack, + TextFg: color.FgWhite, + }, + Common: commonTermColors{ + SeverityErrorAttr: color.AttrBold, + SeverityErrorBg: color.BgRed, + SeverityErrorFg: color.FgWhite, + SeverityFatalAttr: color.AttrBold, + SeverityFatalBg: color.BgMagenta, + SeverityFatalFg: color.FgWhite, + SeverityWarnAttr: color.AttrBold, + SeverityWarnBg: color.BgBlack, + SeverityWarnFg: color.FgWhite, + }, + MaprTable: maprTableTermColors{ + DataAttr: color.AttrNone, + DataBg: color.BgBlue, + DataFg: color.FgWhite, + DelimiterAttr: color.AttrDim, + DelimiterBg: color.BgBlue, + DelimiterFg: color.FgWhite, + HeaderAttr: color.AttrBold, + HeaderBg: color.BgBlue, + HeaderFg: color.FgWhite, + HeaderDelimiterAttr: color.AttrDim, + HeaderDelimiterBg: color.BgBlue, + HeaderDelimiterFg: color.FgWhite, + HeaderSortKeyAttr: color.AttrUnderline, + HeaderGroupKeyAttr: color.AttrReverse, + RawQueryAttr: color.AttrDim, + RawQueryBg: color.BgBlack, + RawQueryFg: color.FgCyan, + }, + }, + } } diff --git a/internal/config/common.go b/internal/config/common.go index c3e203e..7a72cfe 100644 --- a/internal/config/common.go +++ b/internal/config/common.go @@ -6,31 +6,27 @@ type CommonConfig struct { SSHPort int // Enable experimental features (mainly for dev purposes) ExperimentalFeaturesEnable bool `json:",omitempty"` - // Enable debug logging. Don't enable in production. - DebugEnable bool `json:",omitempty"` - // Enable trace logging. Don't enable in production. - TraceEnable bool `json:",omitempty"` - // The log strategy to use, one of - // stdout: only log to stdout (useful when used with systemd) - // daily: create a log file for every day - LogStrategy string - // The log directory + // LogDir defines the log directory. LogDir string + // Logger defines the name of the logger implementation. + Logger string + // LogLevel defines how much is logged. + LogLevel string `json:",omitempty"` + // LogRotation strategy to be used. + LogRotation string // The cache directory CacheDir string - // The temp directory - TmpDir string `json:",omitempty"` } // Create a new default configuration. func newDefaultCommonConfig() *CommonConfig { return &CommonConfig{ - SSHPort: 2222, - DebugEnable: false, - TraceEnable: false, + SSHPort: DefaultSSHPort, ExperimentalFeaturesEnable: false, LogDir: "log", + Logger: "stdout", + LogLevel: DefaultLogLevel, + LogRotation: "daily", CacheDir: "cache", - TmpDir: "/tmp", } } diff --git a/internal/config/config.go b/internal/config/config.go index 276ddcf..ee23829 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,23 +1,30 @@ package config -import ( - "encoding/json" - "io/ioutil" - "os" +import "github.com/mimecast/dtail/internal/source" + +const ( + // HealthUser is used for the health check + HealthUser string = "DTAIL-HEALTH" + // ScheduleUser is used for non-interactive scheduled mapreduce queries. + ScheduleUser string = "DTAIL-SCHEDULE" + // ContinuousUser is used for non-interactive continuous mapreduce queries. + ContinuousUser string = "DTAIL-CONTINUOUS" + // InterruptTimeoutS specifies the Ctrl+C log pause interval. + InterruptTimeoutS int = 3 + // DefaultConnectionsPerCPU controls how many connections are established concurrently. + DefaultConnectionsPerCPU int = 10 + // DefaultSSHPort is the default DServer port. + DefaultSSHPort int = 2222 + // DefaultLogLevel specifies the default log level (obviously) + DefaultLogLevel string = "info" + // DefaultClientLogger specifies the default logger for the client commands. + DefaultClientLogger string = "fout" + // DefaultServerLogger specifies the default logger for dtail server. + DefaultServerLogger string = "file" + // DefaultHealthCheckLogger specifies the default logger used for health checks. + DefaultHealthCheckLogger string = "none" ) -// ControlUser is used for various DTail specific operations. -const ControlUser string = "DTAIL-CONTROL" - -// ScheduleUser is used for non-interactive scheduled mapreduce queries. -const ScheduleUser string = "DTAIL-SCHEDULE" - -// ContinuousUser is used for non-interactive continuous mapreduce queries. -const ContinuousUser string = "DTAIL-CONTINUOUS" - -// InterruptTimeoutS is used to terminate DTail when Ctrl+C was pressed twice within a given interval. -const InterruptTimeoutS int = 3 - // Client holds a DTail client configuration. var Client *ClientConfig @@ -27,28 +34,22 @@ var Server *ServerConfig // Common holds common configs of both both, client and server. var Common *CommonConfig -// Used to initialize the configuration. -type configInitializer struct { - Common *CommonConfig - Server *ServerConfig - Client *ClientConfig -} - -// Parse and read a given config file in JSON format. -func (c *configInitializer) parseConfig(configFile string) { - fd, err := os.Open(configFile) - if err != nil { - panic(err) +// Setup the DTail configuration. +func Setup(sourceProcess source.Source, args *Args, additionalArgs []string) { + initializer := initializer{ + Common: newDefaultCommonConfig(), + Server: newDefaultServerConfig(), + Client: newDefaultClientConfig(), } - defer fd.Close() - - cfgBytes, err := ioutil.ReadAll(fd) - if err != nil { + if err := initializer.parseConfig(args); err != nil { panic(err) } - - err = json.Unmarshal([]byte(cfgBytes), c) - if err != nil { + if err := initializer.transformConfig(sourceProcess, args, additionalArgs); err != nil { panic(err) } + + // Make config accessible globally + Server = initializer.Server + Client = initializer.Client + Common = initializer.Common } diff --git a/internal/config/env.go b/internal/config/env.go new file mode 100644 index 0000000..88b831d --- /dev/null +++ b/internal/config/env.go @@ -0,0 +1,7 @@ +package config + +import "os" + +func Env(env string) bool { + return "yes" == os.Getenv(env) +} diff --git a/internal/config/initializer.go b/internal/config/initializer.go new file mode 100644 index 0000000..4d6a73b --- /dev/null +++ b/internal/config/initializer.go @@ -0,0 +1,184 @@ +package config + +import ( + "encoding/json" + "flag" + "fmt" + "io/ioutil" + "os" + "strings" + + "github.com/mimecast/dtail/internal/source" +) + +// Used to initialize the configuration. +type initializer struct { + Common *CommonConfig + Server *ServerConfig + Client *ClientConfig +} + +type transformCb func(*initializer, *Args, []string) error + +func (in *initializer) parseConfig(args *Args) error { + if strings.ToUpper(args.ConfigFile) == "NONE" { + return nil + } + + if args.ConfigFile != "" { + return in.parseSpecificConfig(args.ConfigFile) + } + + if homeDir, err := os.UserHomeDir(); err != nil { + var paths []string + paths = append(paths, fmt.Sprintf("%s/.config/dtail/dtail.conf", homeDir)) + paths = append(paths, fmt.Sprintf("%s/.dtail.conf", homeDir)) + for _, configPath := range paths { + if _, err := os.Stat(configPath); !os.IsNotExist(err) { + in.parseSpecificConfig(configPath) + } + } + } + + return nil +} + +func (in *initializer) parseSpecificConfig(configFile string) error { + fd, err := os.Open(configFile) + if err != nil { + return fmt.Errorf("Unable to read config file: %v", err) + } + defer fd.Close() + + cfgBytes, err := ioutil.ReadAll(fd) + if err != nil { + return fmt.Errorf("Unable to read config file %s: %v", configFile, err) + } + + if err := json.Unmarshal([]byte(cfgBytes), in); err != nil { + return fmt.Errorf("Unable to parse config file %s: %v", configFile, err) + } + + return nil +} + +func (in *initializer) transformConfig(sourceProcess source.Source, args *Args, + additionalArgs []string) error { + + in.readEnvironmentVars() + + switch sourceProcess { + case source.Server: + return in.optimusPrime(transformServer, args, additionalArgs) + case source.Client: + return in.optimusPrime(transformClient, args, additionalArgs) + case source.HealthCheck: + return in.optimusPrime(transformHealthCheck, args, additionalArgs) + default: + return fmt.Errorf("Unable to transform config, unknown source '%s'", + sourceProcess) + } +} + +// There are some special options which can be set by environment variable. +func (in *initializer) readEnvironmentVars() { + if Env("DTAIL_SSH_DONT_ADD_HOSTS_TO_KNOWNHOSTS_FILE") || + Env("DTAIL_JENKINS") { + in.Client.SSHDontAddHostsToKnownHostsFile = true + } +} + +func (in *initializer) optimusPrime(sourceCb transformCb, args *Args, + additionalArgs []string) error { + + // Copy args to config objects. + // NEXT: Maybe unify args and config structs? + if args.SSHPort != DefaultSSHPort { + in.Common.SSHPort = args.SSHPort + } + if args.LogLevel != DefaultLogLevel { + in.Common.LogLevel = args.LogLevel + } + if args.NoColor { + in.Client.TermColorsEnable = false + } + if args.LogDir != "" { + in.Common.LogDir = args.LogDir + } + if args.Logger != "" { + in.Common.Logger = args.Logger + } + if args.ConnectionsPerCPU == 0 { + args.ConnectionsPerCPU = DefaultConnectionsPerCPU + } + + // Setup log directory. + if strings.Contains(in.Common.LogDir, "~/") { + homeDir, err := os.UserHomeDir() + if err != nil { + panic(err) + } + in.Common.LogDir = strings.ReplaceAll(in.Common.LogDir, "~/", + fmt.Sprintf("%s/", homeDir)) + } + + // Source type specific transormations. + sourceCb(in, args, additionalArgs) + + // Spartan mode. + if args.Spartan { + args.Quiet = true + args.NoColor = true + in.Client.TermColorsEnable = false + if args.LogLevel == "" { + args.LogLevel = "ERROR" + in.Common.LogLevel = "ERROR" + } + } + // Interpret additional args as file list or as query. + if args.What == "" { + var files []string + for _, arg := range flag.Args() { + if args.QueryStr == "" && strings.Contains(strings.ToLower(arg), "select ") { + args.QueryStr = arg + continue + } + files = append(files, arg) + } + args.What = strings.Join(files, ",") + } + + return nil +} + +func transformClient(in *initializer, args *Args, additionalArgs []string) error { + // Serverless mode. + if args.Discovery == "" && (args.ServersStr == "" || + strings.ToLower(args.ServersStr) == "serverless") { + // We are not connecting to any servers. + args.Serverless = true + if args.LogLevel == DefaultLogLevel { + in.Common.LogLevel = "warn" + } + } + return nil +} + +func transformServer(in *initializer, args *Args, additionalArgs []string) error { + if args.SSHBindAddress != "" { + in.Server.SSHBindAddress = args.SSHBindAddress + } + return nil +} + +func transformHealthCheck(in *initializer, args *Args, additionalArgs []string) error { + // Serverless mode. + if args.Discovery == "" && (args.ServersStr == "" || + strings.ToLower(args.ServersStr) == "serverless") { + // We are not connecting to any servers. + args.Serverless = true + in.Common.LogLevel = "warn" + } + args.TrustAllHosts = true + return nil +} diff --git a/internal/config/read.go b/internal/config/read.go deleted file mode 100644 index a4e605b..0000000 --- a/internal/config/read.go +++ /dev/null @@ -1,37 +0,0 @@ -package config - -import ( - "os" -) - -// Read the DTail configuration. -func Read(configFile string, sshPort int) { - initializer := configInitializer{ - Common: newDefaultCommonConfig(), - Server: newDefaultServerConfig(), - Client: newDefaultClientConfig(), - } - - if configFile == "" { - configFile = "./cfg/dtail.json" - } - - if _, err := os.Stat(configFile); !os.IsNotExist(err) { - initializer.parseConfig(configFile) - } - - // Assign pointers to global variables, so that we can access the - // configuration from any place of the program. - Common = initializer.Common - Server = initializer.Server - Client = initializer.Client - - if Server.MapreduceLogFormat == "" { - Server.MapreduceLogFormat = "default" - } - - // If non-standard port specified, overwrite config - if sshPort != 2222 { - Common.SSHPort = sshPort - } -} diff --git a/internal/config/server.go b/internal/config/server.go index dc0d587..254ea0c 100644 --- a/internal/config/server.go +++ b/internal/config/server.go @@ -4,8 +4,8 @@ import ( "errors" ) -// Permissions map. Each SSH user has a list of permissions which -// log files it is allowed to follow and which ones not. +// Permissions map. Each SSH user has a list of permissions which log files it +// is allowed to follow and which ones not. type Permissions struct { // The default user permissions. Default []string @@ -47,7 +47,7 @@ type ServerConfig struct { MaxConcurrentCats int // The max amount of concurrent tails per server. MaxConcurrentTails int - // The user permissions. + // The user permissions. TODO: Add to JSON schema Permissions Permissions `json:",omitempty"` // The mapr log format MapreduceLogFormat string `json:",omitempty"` @@ -68,7 +68,6 @@ var ServerRelaxedAuthEnable bool func newDefaultServerConfig() *ServerConfig { defaultPermissions := []string{"^/.*"} defaultBindAddress := "0.0.0.0" - return &ServerConfig{ SSHBindAddress: defaultBindAddress, MaxConnections: 10, @@ -76,6 +75,7 @@ func newDefaultServerConfig() *ServerConfig { MaxConcurrentTails: 50, HostKeyFile: "./cache/ssh_host_key", HostKeyBits: 4096, + MapreduceLogFormat: "default", Permissions: Permissions{ Default: defaultPermissions, }, @@ -88,10 +88,8 @@ func ServerUserPermissions(userName string) (permissions []string, err error) { if p, ok := Server.Permissions.Users[userName]; ok { permissions = p } - if len(permissions) == 0 { err = errors.New("Empty set of permission, user won't be able to open any files") } - return } diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go index 4344240..9bea89c 100644 --- a/internal/discovery/comma.go +++ b/internal/discovery/comma.go @@ -3,11 +3,11 @@ package discovery import ( "strings" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // ServerListFromCOMMA retrieves a list of servers from comma separated input list. func (d *Discovery) ServerListFromCOMMA() []string { - logger.Debug("Retrieving server list from comma separated list", d.server) + dlog.Common.Debug("Retrieving server list from comma separated list", d.server) return strings.Split(d.server, ",") } diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index a25b136..8bb1e85 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // ServerOrder to specify how to sort the server list. @@ -24,7 +24,7 @@ const ( type Discovery struct { // To plug in a custom server discovery module. module string - // To specify optional server discovery module options. + // To specifiy optional server discovery module options. options string // To either filter a server list or to secify an exact list. server string @@ -42,7 +42,7 @@ func New(method, server string, order ServerOrder) *Discovery { if strings.Contains(module, ":") { s := strings.Split(module, ":") if len(s) != 2 { - logger.FatalExit("Unable to parse discovery module", module) + dlog.Common.FatalPanic("Unable to parse discovery module", module) } module = s[0] options = s[1] @@ -72,11 +72,10 @@ func (d *Discovery) initRegex() { } regexStr := string(runes) - logger.Debug("Using filter regex", regexStr) - + dlog.Common.Debug("Using filter regex", regexStr) regex, err := regexp.Compile(regexStr) if err != nil { - logger.FatalExit("Could not compile regex", regexStr, err) + dlog.Common.FatalPanic("Could not compile regex", regexStr, err) } d.regex = regex @@ -90,14 +89,12 @@ func (d *Discovery) ServerList() []string { if d.regex != nil { servers = d.filterList(servers) } - servers = d.dedupList(servers) - if d.order == Shuffle { servers = d.shuffleList(servers) } - logger.Debug("Discovered servers", len(servers), servers) + dlog.Common.Debug("Discovered servers", len(servers), servers) return servers } @@ -105,12 +102,10 @@ func (d *Discovery) serverListFromModule() []string { if d.module != "" { return d.serverListFromReflectedModule() } - if _, err := os.Stat(d.server); err == nil { // Appears to be a file name, now try to read from that file. return d.ServerListFromFILE() } - // Appears to be a list of FQDNs (or a single FQDN) return d.ServerListFromCOMMA() } @@ -120,53 +115,47 @@ func (d *Discovery) serverListFromModule() []string { // Discovery. Whereas MODULENAME must be a upeprcase string. func (d *Discovery) serverListFromReflectedModule() []string { methodName := fmt.Sprintf("ServerListFrom%s", d.module) - + // Now we are reflecting the serve discovery function by it's name. rt := reflect.TypeOf(d) reflectedMethod, ok := rt.MethodByName(methodName) if !ok { - logger.FatalExit("No such server discovery module", d.module, methodName) + dlog.Common.FatalPanic("No such server discovery module", d.module, methodName) } - inputValues := make([]reflect.Value, 1) // Thist input value is method receiver. inputValues[0] = reflect.ValueOf(d) returnValues := reflectedMethod.Func.Call(inputValues) - // First return value is server list. return returnValues[0].Interface().([]string) } // Filter server list based on a regexp. func (d *Discovery) filterList(servers []string) (filtered []string) { - logger.Debug("Filtering server list") - + dlog.Common.Debug("Filtering server list") for _, server := range servers { if d.regex.MatchString(server) { filtered = append(filtered, server) } } - return } // Deduplicate the server list. func (d *Discovery) dedupList(servers []string) (deduped []string) { serverMap := make(map[string]struct{}, len(servers)) - for _, server := range servers { if _, ok := serverMap[server]; !ok { serverMap[server] = struct{}{} deduped = append(deduped, server) } } - - logger.Debug("Deduped server list", len(servers), len(deduped)) + dlog.Common.Debug("Deduped server list", len(servers), len(deduped)) return } // Randomly shuffle the server list. func (d *Discovery) shuffleList(servers []string) []string { - logger.Debug("Shuffling server list") + dlog.Common.Debug("Shuffling server list") r := rand.New(rand.NewSource(time.Now().Unix())) shuffled := make([]string, len(servers)) diff --git a/internal/discovery/file.go b/internal/discovery/file.go index 1250755..fb46eeb 100644 --- a/internal/discovery/file.go +++ b/internal/discovery/file.go @@ -4,16 +4,16 @@ import ( "bufio" "os" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // ServerListFromFILE retrieves a list of servers from a file. func (d *Discovery) ServerListFromFILE() (servers []string) { - logger.Debug("Retrieving server list from file", d.server) + dlog.Common.Debug("Retrieving server list from file", d.server) file, err := os.Open(d.server) if err != nil { - logger.FatalExit(d.server, err) + dlog.Common.FatalPanic(d.server, err) } defer file.Close() @@ -22,7 +22,7 @@ func (d *Discovery) ServerListFromFILE() (servers []string) { servers = append(servers, scanner.Text()) } if err := scanner.Err(); err != nil { - logger.FatalExit(d.server, err) + dlog.Common.FatalPanic(d.server, err) } return diff --git a/internal/done.go b/internal/done.go index 54e5e8e..94f9289 100644 --- a/internal/done.go +++ b/internal/done.go @@ -17,6 +17,15 @@ func NewDone() *Done { } } +func (d *Done) String() string { + select { + case <-d.Done(): + return "Done(yes)" + default: + return "Done(no)" + } +} + // Done returns the done channel (closed when done) func (d *Done) Done() <-chan struct{} { return d.ch @@ -26,7 +35,6 @@ func (d *Done) Done() <-chan struct{} { func (d *Done) Shutdown() { d.mutex.Lock() defer d.mutex.Unlock() - select { case <-d.ch: return diff --git a/internal/io/dlog/dlog.go b/internal/io/dlog/dlog.go new file mode 100644 index 0000000..5e0c3a1 --- /dev/null +++ b/internal/io/dlog/dlog.go @@ -0,0 +1,272 @@ +package dlog + +import ( + "context" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/color/brush" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog/loggers" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/protocol" + "github.com/mimecast/dtail/internal/source" +) + +// Client is the log handler for the client packages. +var Client *DLog + +// Server is the log handler for the server packages. +var Server *DLog + +// Common is the log handler for all other packages. +var Common *DLog + +var mutex sync.Mutex +var started bool + +// Start logger(s). +func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source) { + mutex.Lock() + defer mutex.Unlock() + + if started { + Common.FatalPanic("Logger already started") + } + + Client = new(sourceProcess, source.Client) + Server = new(sourceProcess, source.Server) + Common = Client + if sourceProcess == source.Server { + Common = Server + } + + var wg2 sync.WaitGroup + wg2.Add(2) + go Client.start(ctx, &wg2) + go Server.start(ctx, &wg2) + + go rotation(ctx) + go func() { + wg2.Wait() + wg.Done() + }() + + started = true +} + +// DLog is the DTail logger. +type DLog struct { + logger loggers.Logger + // Is this a DTail server or client process logging? + sourceProcess source.Source + // Is this a DTail server or client package logging? In serverless mode + // the client can also execute code from the server package. + sourcePackage source.Source + // Max log level to log. + maxLevel level + // Current hostname. + hostname string +} + +// new creates a new DTail logger. +func new(sourceProcess, sourcePackage source.Source) *DLog { + hostname, err := os.Hostname() + if err != nil { + panic(err) + } + logRotation := loggers.NewStrategy(config.Common.LogRotation) + loggerName := config.Common.Logger + level := newLevel(config.Common.LogLevel) + + return &DLog{ + logger: loggers.Factory(sourceProcess.String(), loggerName, logRotation), + sourceProcess: sourceProcess, + sourcePackage: sourcePackage, + maxLevel: level, + hostname: hostname, + } +} + +func (d *DLog) start(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() + var wg2 sync.WaitGroup + wg2.Add(1) + d.logger.Start(ctx, &wg2) + <-ctx.Done() + wg2.Wait() +} + +func (d *DLog) log(level level, args []interface{}) string { + if d.maxLevel < level { + return "" + } + sb := pool.BuilderBuffer.Get().(*strings.Builder) + defer pool.RecycleBuilderBuffer(sb) + now := time.Now() + + switch d.sourceProcess { + case source.Client: + sb.WriteString(d.sourcePackage.String()) + sb.WriteString(protocol.FieldDelimiter) + sb.WriteString(d.hostname) + sb.WriteString(protocol.FieldDelimiter) + sb.WriteString(level.String()) + default: + sb.WriteString(level.String()) + sb.WriteString(protocol.FieldDelimiter) + sb.WriteString(now.Format("20060102-150405")) + } + sb.WriteString(protocol.FieldDelimiter) + d.writeArgStrings(sb, args) + + message := sb.String() + if !config.Client.TermColorsEnable || !d.logger.SupportsColors() { + d.logger.Log(now, message) + return message + } + + d.logger.LogWithColors(now, message, brush.Colorfy(message)) + return message +} + +func (d *DLog) writeArgStrings(sb *strings.Builder, args []interface{}) { + for i, arg := range args { + if i > 0 { + sb.WriteString(protocol.FieldDelimiter) + } + switch v := arg.(type) { + case string: + sb.WriteString(v) + case error: + sb.WriteString(v.Error()) + default: + sb.WriteString(fmt.Sprintf("%v", v)) + } + } +} + +// FatalPanic terminates the process with a fatal error. +func (d *DLog) FatalPanic(args ...interface{}) { + d.log(Fatal, args) + d.Flush() + + var sb strings.Builder + d.writeArgStrings(&sb, args) + panic(sb.String()) +} + +// Fatal logs a fatal error. +func (d *DLog) Fatal(args ...interface{}) string { + return d.log(Fatal, args) +} + +// Error logging. +func (d *DLog) Error(args ...interface{}) string { + return d.log(Error, args) +} + +// Warn logs a warning message. +func (d *DLog) Warn(args ...interface{}) string { + return d.log(Warn, args) +} + +// Info logging. +func (d *DLog) Info(args ...interface{}) string { + return d.log(Info, args) +} + +// Verbose logging. +func (d *DLog) Verbose(args ...interface{}) string { + return d.log(Verbose, args) +} + +// Debug logging. +func (d *DLog) Debug(args ...interface{}) string { + return d.log(Debug, args) +} + +// Trace logging. +func (d *DLog) Trace(args ...interface{}) string { + _, file, line, _ := runtime.Caller(1) + args = append(args, fmt.Sprintf("at %s:%d", file, line)) + return d.log(Trace, args) +} + +// Devel used for development purpose only logging (e.g. "print" debugging). +func (d *DLog) Devel(args ...interface{}) string { + _, file, line, _ := runtime.Caller(1) + args = append(args, fmt.Sprintf("at %s:%d", file, line)) + return d.log(Devel, args) +} + +// Raw message logging. +func (d *DLog) Raw(message string) string { + if !config.Client.TermColorsEnable || !d.logger.SupportsColors() { + d.logger.Log(time.Now(), message) + return message + } + d.logger.LogWithColors(time.Now(), message, brush.Colorfy(message)) + return message +} + +// Mapreduce logging. +func (d *DLog) Mapreduce(table string, data map[string]interface{}) string { + args := make([]interface{}, len(data)+1) + + if d.sourceProcess == source.Server { + // level|date-time|process|caller|cpus|goroutines|cgocalls|loadavg|uptime|MAPREDUCE:TABLE|key=value|... + + var loadAvg string + if loadAvgBytes, err := ioutil.ReadFile("/proc/loadavg"); err == nil { + tmp := string(loadAvgBytes) + s := strings.SplitN(tmp, " ", 2) + loadAvg = s[0] + } + + var uptime string + if uptimeBytes, err := ioutil.ReadFile("/proc/uptime"); err == nil { + tmp := string(uptimeBytes) + s := strings.SplitN(tmp, ".", 2) + i, _ := strconv.ParseInt(s[0], 10, 64) + t := time.Duration(i) * time.Second + uptime = fmt.Sprintf("%v", t) + } + + _, file, line, _ := runtime.Caller(1) + args[0] = fmt.Sprintf("%d|%s:%d|%d|%d|%d|%s|%s|MAPREDUCE:%s", + os.Getpid(), + filepath.Base(file), line, + runtime.NumCPU(), + runtime.NumGoroutine(), + runtime.NumCgoCall(), + loadAvg, + uptime, + strings.ToUpper(table)) + } else { + args[0] = fmt.Sprintf("STATS:%s", strings.ToUpper(table)) + } + + i := 1 + for k, v := range data { + args[i] = fmt.Sprintf("%s=%v", k, v) + i++ + } + return d.log(Info, args) +} + +// Flush the log buffers. +func (d *DLog) Flush() { d.logger.Flush() } + +// Pause the logging. +func (d *DLog) Pause() { d.logger.Pause() } + +// Resume the logging. +func (d *DLog) Resume() { d.logger.Resume() } diff --git a/internal/io/dlog/level.go b/internal/io/dlog/level.go new file mode 100644 index 0000000..05d9ed9 --- /dev/null +++ b/internal/io/dlog/level.go @@ -0,0 +1,84 @@ +package dlog + +import ( + "fmt" + "strings" +) + +type level int + +// Available log levels. +const ( + None level = iota + Fatal level = iota + Error level = iota + Warn level = iota + Info level = iota + Default level = iota + Verbose level = iota + Debug level = iota + Devel level = iota + Trace level = iota + All level = iota +) + +var allLevels = []level{Fatal, Error, Warn, Info, Default, Verbose, Debug, + Devel, Trace, All} + +func newLevel(l string) level { + switch strings.ToLower(l) { + case "none": + return None + case "fatal": + return Fatal + case "error": + return Error + case "warn": + return Warn + case "info": + return Info + case "": + fallthrough + case "default": + return Default + case "verbose": + return Verbose + case "debug": + return Debug + case "devel": + return Devel + case "trace": + return Trace + case "all": + return All + } + panic(fmt.Sprintf("Unknown log level %s, must be one of: %v", l, allLevels)) +} + +func (l level) String() string { + switch l { + case None: + return "NONE" + case Fatal: + return "FATAL" + case Error: + return "ERROR" + case Warn: + return "WARN" + case Info: + return "INFO" + case Default: + return "DEFAULT" + case Verbose: + return "VERBOSE" + case Debug: + return "DEBUG" + case Devel: + return "DEVEL" + case Trace: + return "TRACE" + case All: + return "ALL" + } + panic("Unknown log level " + fmt.Sprintf("%d", l)) +} diff --git a/internal/io/dlog/loggers/factory.go b/internal/io/dlog/loggers/factory.go new file mode 100644 index 0000000..a5cc7cf --- /dev/null +++ b/internal/io/dlog/loggers/factory.go @@ -0,0 +1,54 @@ +package loggers + +import ( + "fmt" + "strings" + "sync" +) + +var factoryMap map[string]Logger +var factoryMutex sync.Mutex + +// Factory is there to retrieve a logger based on various settings. +func Factory(sourceName, loggerName string, logRotation Strategy) Logger { + factoryMutex.Lock() + defer factoryMutex.Unlock() + + id := fmt.Sprintf("sourceName:%s,fileBase:%s,loggerName:%s", sourceName, + logRotation.FileBase, loggerName) + if factoryMap == nil { + factoryMap = make(map[string]Logger) + } + + singleton, ok := factoryMap[id] + if !ok { + switch strings.ToLower(loggerName) { + case "none": + singleton = none{} + case "stdout": + singleton = newStdout() + factoryMap[id] = singleton + case "file": + singleton = newFile(logRotation) + factoryMap[id] = singleton + case "fout": + singleton = newFout(logRotation) + factoryMap[id] = singleton + default: + panic(fmt.Sprintf("Unsupported logger type '%s'", loggerName)) + } + } + return singleton +} + +// FactoryRotate invokes a log rotation of all loggers. +func FactoryRotate() { + factoryMutex.Lock() + defer factoryMutex.Unlock() + if factoryMap == nil { + return + } + for _, logger := range factoryMap { + logger.Rotate() + } +} diff --git a/internal/io/dlog/loggers/file.go b/internal/io/dlog/loggers/file.go new file mode 100644 index 0000000..94824fe --- /dev/null +++ b/internal/io/dlog/loggers/file.go @@ -0,0 +1,165 @@ +package loggers + +import ( + "bufio" + "context" + "fmt" + "os" + "runtime" + "sync" + "time" + + "github.com/mimecast/dtail/internal/config" +) + +type fileWriter struct{} + +type fileMessageBuf struct { + now time.Time + message string +} + +type file struct { + bufferCh chan *fileMessageBuf + pauseCh chan struct{} + resumeCh chan struct{} + rotateCh chan struct{} + flushCh chan struct{} + fd *os.File + writer *bufio.Writer + mutex sync.Mutex + started bool + lastFileName string + strategy Strategy +} + +func newFile(strategy Strategy) *file { + return &file{ + bufferCh: make(chan *fileMessageBuf, runtime.NumCPU()*100), + pauseCh: make(chan struct{}), + resumeCh: make(chan struct{}), + rotateCh: make(chan struct{}), + flushCh: make(chan struct{}), + strategy: strategy, + } +} + +func (f *file) Start(ctx context.Context, wg *sync.WaitGroup) { + f.mutex.Lock() + defer func() { + f.started = true + f.mutex.Unlock() + }() + + if f.started { + // Logger already started from another Goroutine. + wg.Done() + return + } + + pause := func(ctx context.Context) { + select { + case <-f.resumeCh: + return + case <-ctx.Done(): + return + } + } + + go func() { + defer wg.Done() + for { + select { + case m := <-f.bufferCh: + f.write(m) + case <-f.pauseCh: + pause(ctx) + case <-f.flushCh: + f.flush() + case <-ctx.Done(): + f.flush() + f.fd.Close() + return + } + } + }() +} + +func (f *file) Log(now time.Time, message string) { + f.bufferCh <- &fileMessageBuf{now, message} +} + +func (f *file) LogWithColors(now time.Time, message, coloredMessage string) { + panic("Colors not supported in file logger") +} + +func (f *file) Pause() { f.pauseCh <- struct{}{} } +func (f *file) Resume() { f.resumeCh <- struct{}{} } +func (f *file) Flush() { f.flushCh <- struct{}{} } + +func (f *file) Rotate() { f.rotateCh <- struct{}{} } +func (*file) SupportsColors() bool { return false } + +func (f *file) write(m *fileMessageBuf) { + select { + case <-f.rotateCh: + // Force re-opening the outfile next time in getWriter. + f.lastFileName = "" + default: + } + + var writer *bufio.Writer + if f.strategy.Rotation == DailyRotation { + writer = f.getWriter(m.now.Format("20060102")) + } else { + writer = f.getWriter(f.strategy.FileBase) + } + + writer.WriteString(m.message) + writer.WriteByte('\n') +} + +func (f *file) getWriter(name string) *bufio.Writer { + if f.lastFileName == name { + return f.writer + } + if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { + if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { + panic(err) + } + } + + logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, name) + newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) + if err != nil { + panic(err) + } + + // Close old writer. + if f.fd != nil { + f.writer.Flush() + f.fd.Close() + } + // Set new writer. + f.fd = newFd + f.writer = bufio.NewWriterSize(f.fd, 1) + f.lastFileName = name + + return f.writer +} + +func (f *file) flush() { + defer func() { + if f.writer != nil { + f.writer.Flush() + } + }() + for { + select { + case m := <-f.bufferCh: + f.write(m) + default: + return + } + } +} diff --git a/internal/io/dlog/loggers/fout.go b/internal/io/dlog/loggers/fout.go new file mode 100644 index 0000000..60c318d --- /dev/null +++ b/internal/io/dlog/loggers/fout.go @@ -0,0 +1,46 @@ +package loggers + +import ( + "context" + "sync" + "time" +) + +type fout struct { + file *file + stdout *stdout +} + +// Logs to both, a file and stdout +func newFout(strategy Strategy) *fout { + return &fout{file: newFile(strategy), stdout: newStdout()} +} + +func (f *fout) Start(ctx context.Context, wg *sync.WaitGroup) { + go func() { + defer wg.Done() + + var wg2 sync.WaitGroup + wg2.Add(2) + f.file.Start(ctx, &wg2) + f.stdout.Start(ctx, &wg2) + wg2.Wait() + }() +} + +func (f *fout) Log(now time.Time, message string) { + f.stdout.Log(now, message) + f.file.Log(now, message) +} + +func (f *fout) LogWithColors(now time.Time, message, coloredMessage string) { + f.stdout.LogWithColors(now, "", coloredMessage) + f.file.Log(now, message) +} + +func (f *fout) Flush() { f.stdout.Flush(); f.file.Flush() } +func (f *fout) Pause() { f.stdout.Pause(); f.file.Pause() } +func (f *fout) Resume() { f.stdout.Resume(); f.file.Resume() } +func (f *fout) Rotate() { f.file.Rotate() } + +func (fout) SupportsColors() bool { return true } diff --git a/internal/io/dlog/loggers/logger.go b/internal/io/dlog/loggers/logger.go new file mode 100644 index 0000000..d4e85de --- /dev/null +++ b/internal/io/dlog/loggers/logger.go @@ -0,0 +1,19 @@ +package loggers + +import ( + "context" + "sync" + "time" +) + +// Logger is there to plug in your own log implementation. +type Logger interface { + Log(now time.Time, message string) + LogWithColors(now time.Time, message, messageWithColors string) + Start(ctx context.Context, wg *sync.WaitGroup) + Flush() + Pause() + Resume() + Rotate() + SupportsColors() bool +} diff --git a/internal/io/dlog/loggers/none.go b/internal/io/dlog/loggers/none.go new file mode 100644 index 0000000..270027f --- /dev/null +++ b/internal/io/dlog/loggers/none.go @@ -0,0 +1,21 @@ +package loggers + +import ( + "context" + "sync" + "time" +) + +// don't log anything +type none struct{} + +func (none) Start(ctx context.Context, wg *sync.WaitGroup) { wg.Done() } +func (none) Log(now time.Time, message string) {} + +func (none) LogWithColors(now time.Time, message, coloredMessage string) {} + +func (none) Flush() {} +func (none) Pause() {} +func (none) Resume() {} +func (none) Rotate() {} +func (none) SupportsColors() bool { return false } diff --git a/internal/io/dlog/loggers/stdout.go b/internal/io/dlog/loggers/stdout.go new file mode 100644 index 0000000..05485c6 --- /dev/null +++ b/internal/io/dlog/loggers/stdout.go @@ -0,0 +1,54 @@ +package loggers + +import ( + "context" + "fmt" + "sync" + "time" +) + +type stdout struct { + pauseCh chan struct{} + resumeCh chan struct{} + mutex sync.Mutex +} + +func newStdout() *stdout { + return &stdout{ + pauseCh: make(chan struct{}), + resumeCh: make(chan struct{}), + } +} + +func (s *stdout) Start(ctx context.Context, wg *sync.WaitGroup) { + wg.Done() +} + +func (s *stdout) Log(now time.Time, message string) { + s.log(message) +} + +func (s *stdout) LogWithColors(now time.Time, message, coloredMessage string) { + s.log(coloredMessage) +} + +func (s *stdout) log(message string) { + s.mutex.Lock() + defer s.mutex.Unlock() + + select { + case <-s.pauseCh: + // Pause until resumed. + <-s.resumeCh + default: + } + + fmt.Println(message) +} + +func (s *stdout) Pause() { s.pauseCh <- struct{}{} } +func (s *stdout) Resume() { s.resumeCh <- struct{}{} } +func (s *stdout) Flush() {} +func (s *stdout) Rotate() {} + +func (stdout) SupportsColors() bool { return true } diff --git a/internal/io/dlog/loggers/strategy.go b/internal/io/dlog/loggers/strategy.go new file mode 100644 index 0000000..48e7d44 --- /dev/null +++ b/internal/io/dlog/loggers/strategy.go @@ -0,0 +1,35 @@ +package loggers + +import ( + "os" + "path/filepath" + "strings" +) + +// Rotation is the actual strategy used for log rotation.. +type Rotation int + +const ( + // DailyRotation tells DTail to rotate its logs on a daily basis or on SIGHUP. + DailyRotation Rotation = iota + // SignalRotation tells DTail to rotate its logs only on SIGHUP. + SignalRotation Rotation = iota +) + +// Strategy is a pair of the rotation and the file base. +type Strategy struct { + // Rotation is the actual rotation strategy used. + Rotation Rotation + // FileBase can be a name (e.g. "dserver", "dmap") when signal rotation is used. + FileBase string +} + +// NewStrategy returns the stratey based on its name. +func NewStrategy(name string) Strategy { + switch strings.ToLower(name) { + case "daily": + return Strategy{DailyRotation, ""} + default: + return Strategy{SignalRotation, filepath.Base(os.Args[0])} + } +} diff --git a/internal/io/dlog/rotation.go b/internal/io/dlog/rotation.go new file mode 100644 index 0000000..15ce1fd --- /dev/null +++ b/internal/io/dlog/rotation.go @@ -0,0 +1,27 @@ +package dlog + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/mimecast/dtail/internal/io/dlog/loggers" +) + +func rotation(ctx context.Context) { + rotateCh := make(chan os.Signal, 1) + signal.Notify(rotateCh, syscall.SIGHUP) + go func() { + for { + select { + case <-rotateCh: + Common.Debug("Invoking log rotation") + loggers.FactoryRotate() + return + case <-ctx.Done(): + return + } + } + }() +} diff --git a/internal/io/fs/catfile.go b/internal/io/fs/catfile.go index 7f387bc..01c15ba 100644 --- a/internal/io/fs/catfile.go +++ b/internal/io/fs/catfile.go @@ -6,7 +6,9 @@ type CatFile struct { } // NewCatFile returns a new file catter. -func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { +func NewCatFile(filePath string, globID string, serverMessages chan<- string, + limiter chan struct{}) CatFile { + return CatFile{ readFile: readFile{ filePath: filePath, diff --git a/internal/io/fs/filereader.go b/internal/io/fs/filereader.go index efd410e..b05fd39 100644 --- a/internal/io/fs/filereader.go +++ b/internal/io/fs/filereader.go @@ -8,9 +8,11 @@ import ( "github.com/mimecast/dtail/internal/regex" ) -// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. +// FileReader is the interface used on the dtail server to read/cat/grep/mapr... +// a file. type FileReader interface { - Start(ctx context.Context, lContext lcontext.LContext, lines chan<- line.Line, re regex.Regex) error + Start(ctx context.Context, ltx lcontext.LContext, lines chan<- line.Line, + re regex.Regex) error FilePath() string Retry() bool } diff --git a/internal/io/fs/filter.go b/internal/io/fs/filter.go deleted file mode 100644 index c4f605e..0000000 --- a/internal/io/fs/filter.go +++ /dev/null @@ -1,167 +0,0 @@ -package fs - -import ( - "context" - - "github.com/mimecast/dtail/internal/io/line" - "github.com/mimecast/dtail/internal/lcontext" - "github.com/mimecast/dtail/internal/regex" -) - -func (f readFile) filter(ctx context.Context, rawLines <-chan []byte, lines chan<- line.Line, re regex.Regex, lContext lcontext.LContext) { - // Do we have any kind of local context settings? If so then run the more complex - // filterWithLContext method. - if lContext.Has() { - // We can not skip transmitting any lines to the client with a local - // grep context specified. - f.canSkipLines = false - f.filterWithLContext(ctx, rawLines, lines, re, lContext) - return - } - - f.filterWithoutLContext(ctx, rawLines, lines, re) -} - -// Filter log lines matching a given regular expression, however with local grep context. -func (f readFile) filterWithLContext(ctx context.Context, rawLines <-chan []byte, lines chan<- line.Line, re regex.Regex, lContext lcontext.LContext) { - // Scenario 1: Finish once maxCount hits found - maxCount := lContext.MaxCount - processMaxCount := maxCount > 0 - maxReached := false - - // Scenario 2: Print prev. N lines when current line matches. - before := lContext.BeforeContext - processBefore := before > 0 - var beforeBuf chan []byte - if processBefore { - beforeBuf = make(chan []byte, before) - } - - // Screnario 3: Print next N lines when current line matches. - after := 0 - processAfter := lContext.AfterContext > 0 - - for rawLine := range rawLines { - // logger.Debug("rawLine", string(rawLine)) - f.updatePosition() - - if !re.Match(rawLine) { - f.updateLineNotMatched() - - if processAfter && after > 0 { - after-- - myLine := line.Line{Content: rawLine, SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: 100} - select { - case lines <- myLine: - case <-ctx.Done(): - return - } - - } else if processBefore { - // Keep last num BeforeContext raw messages. - select { - case beforeBuf <- rawLine: - default: - <-beforeBuf - beforeBuf <- rawLine - } - } - continue - } - - f.updateLineMatched() - - if processAfter { - if maxReached { - return - } - after = lContext.AfterContext - } - - if processBefore { - i := uint64(len(beforeBuf)) - for { - select { - case myRawLine := <-beforeBuf: - myLine := line.Line{Content: myRawLine, SourceID: f.globID, Count: f.totalLineCount() - i, TransmittedPerc: 100} - i-- - select { - case lines <- myLine: - case <-ctx.Done(): - return - } - default: - // beforeBuf is now empty. - } - if len(beforeBuf) == 0 { - break - } - } - } - - line := line.Line{Content: rawLine, SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: 100} - - select { - case lines <- line: - if processMaxCount { - maxCount-- - if maxCount == 0 { - if !processAfter || after == 0 { - return - } - // Unfortunatley we have to continue filter, as there might be more lines to print - maxReached = true - } - } - case <-ctx.Done(): - return - } - } -} - -// Filter log lines matching a given regular expression, there is no local grep context specified. -func (f readFile) filterWithoutLContext(ctx context.Context, rawLines <-chan []byte, lines chan<- line.Line, re regex.Regex) { - for { - select { - case rawLine, ok := <-rawLines: - f.updatePosition() - if !ok { - return - } - - if f.lineUntransmittable(rawLine, len(lines), cap(lines), re) { - continue - } - - line := line.Line{Content: rawLine, SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: f.transmittedPerc()} - - select { - case lines <- line: - continue - case <-ctx.Done(): - return - } - } - } -} - -func (f readFile) lineUntransmittable(rawLine []byte, length, capacity int, re regex.Regex) bool { - if !re.Match(rawLine) { - f.updateLineNotMatched() - f.updateLineNotTransmitted() - // Regex dosn't match, so not interested in it. - return true - } - f.updateLineMatched() - - // Can we actually send more messages, channel capacity reached? - if f.canSkipLines && length >= capacity { - f.updateLineNotTransmitted() - // Matching, not transmittable - return true - } - f.updateLineTransmitted() - - // Matching, transmittable - return false -} diff --git a/internal/io/fs/permissions/permission.go b/internal/io/fs/permissions/permission.go index cc5dd9b..d621c09 100644 --- a/internal/io/fs/permissions/permission.go +++ b/internal/io/fs/permissions/permission.go @@ -3,12 +3,12 @@ package permissions import ( - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // ToRead is to check whether user has read permissions to a given file. func ToRead(user, filePath string) (bool, error) { // Only implemented for Linux, always expect true - logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") + dlog.Common.Debug(user, filePath, "Not performing ACL check as not compiled in") return true, nil } diff --git a/internal/io/fs/permissions/permission_linuxacl.go b/internal/io/fs/permissions/permission_linuxacl.go index 7d2d7ca..904b90f 100644 --- a/internal/io/fs/permissions/permission_linuxacl.go +++ b/internal/io/fs/permissions/permission_linuxacl.go @@ -13,7 +13,7 @@ import ( "unsafe" ) -// ToRead checks whether user has Linux file system permissions to read a given file. +// ToRead checks whether user has Linux file system permissions to read a file. func ToRead(user, filePath string) (bool, error) { cUser := C.CString(user) cFilePath := C.CString(filePath) diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go index 161e3f0..28cbe58 100644 --- a/internal/io/fs/readfile.go +++ b/internal/io/fs/readfile.go @@ -2,16 +2,20 @@ package fs import ( "bufio" + "bytes" "compress/gzip" "context" + "errors" "fmt" "io" "os" "strings" + "sync" "time" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/pool" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/regex" @@ -37,31 +41,10 @@ type readFile struct { limiter chan struct{} } -func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { - switch { - case strings.HasSuffix(f.FilePath(), ".gz"): - fallthrough - case strings.HasSuffix(f.FilePath(), ".gzip"): - logger.Info(f.FilePath(), "Detected gzip compression format") - var gzipReader *gzip.Reader - gzipReader, err = gzip.NewReader(fd) - if err != nil { - return - } - reader = bufio.NewReader(gzipReader) - case strings.HasSuffix(f.FilePath(), ".zst"): - logger.Info(f.FilePath(), "Detected zstd compression format") - reader = bufio.NewReader(zstd.NewReader(fd)) - default: - reader = bufio.NewReader(fd) - } - - return -} - // String returns the string representation of the readFile func (f readFile) String() string { - return fmt.Sprintf("readFile(filePath:%s,globID:%s,retry:%v,canSkipLines:%v,seekEOF:%v)", + return fmt.Sprintf( + "readFile(filePath:%s,globID:%s,retry:%v,canSkipLines:%v,seekEOF:%v)", f.filePath, f.globID, f.retry, @@ -80,8 +63,10 @@ func (f readFile) Retry() bool { } // Start tailing a log file. -func (f readFile) Start(ctx context.Context, lContext lcontext.LContext, lines chan<- line.Line, re regex.Regex) error { - logger.Debug("readFile", f) +func (f readFile) Start(ctx context.Context, ltx lcontext.LContext, + lines chan<- line.Line, re regex.Regex) error { + + dlog.Common.Debug("readFile", f) defer func() { select { case <-f.limiter: @@ -93,7 +78,8 @@ func (f readFile) Start(ctx context.Context, lContext lcontext.LContext, lines c case f.limiter <- struct{}{}: default: select { - case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): + case f.serverMessages <- dlog.Common.Warn(f.filePath, f.globID, + "Server limit reached. Queuing file..."): case <-ctx.Done(): return nil } @@ -110,111 +96,335 @@ func (f readFile) Start(ctx context.Context, lContext lcontext.LContext, lines c fd.Seek(0, io.SeekEnd) } - rawLines := make(chan []byte, 100) + rawLines := make(chan *bytes.Buffer, 100) + truncate := make(chan struct{}) + readCtx, readCancel := context.WithCancel(ctx) + var filterWg sync.WaitGroup + filterWg.Add(1) - filterDone := make(chan struct{}) + go f.periodicTruncateCheck(ctx, truncate) go func() { - f.filter(ctx, rawLines, lines, re, lContext) - close(filterDone) + f.filter(ctx, ltx, rawLines, lines, re) + filterWg.Done() // If the filter stopped, make the reader stop too, no need to read // more data if there is nothing more the filter wants to filter for! // E.g. it could be that we only want to filter N matches but not more. readCancel() }() - err = f.read(readCtx, fd, rawLines) + err = f.read(readCtx, fd, rawLines, truncate) close(rawLines) - - // Filter may flushes some data still. So wait until it is done here. - <-filterDone + // Filter may sends some data still. So wait until it is done here. + filterWg.Wait() return err } -func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan []byte) error { - var offset uint64 +func (f readFile) periodicTruncateCheck(ctx context.Context, truncate chan struct{}) { + for { + select { + case <-time.After(time.Second * 3): + select { + case truncate <- struct{}{}: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } + } +} +func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { + switch { + case strings.HasSuffix(f.FilePath(), ".gz"): + fallthrough + case strings.HasSuffix(f.FilePath(), ".gzip"): + dlog.Common.Info(f.FilePath(), "Detected gzip compression format") + var gzipReader *gzip.Reader + gzipReader, err = gzip.NewReader(fd) + if err != nil { + return + } + reader = bufio.NewReader(gzipReader) + case strings.HasSuffix(f.FilePath(), ".zst"): + dlog.Common.Info(f.FilePath(), "Detected zstd compression format") + reader = bufio.NewReader(zstd.NewReader(fd)) + default: + reader = bufio.NewReader(fd) + } + return +} + +func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan *bytes.Buffer, truncate <-chan struct{}) error { + var offset uint64 reader, err := f.makeReader(fd) if err != nil { return err } - rawLine := make([]byte, 0, 512) lineLengthThreshold := 1024 * 1024 // 1mb - longLineWarning := false - - checkTruncate := f.truncateTimer(ctx) + warnedAboutLongLine := false + message := pool.BytesBuffer.Get().(*bytes.Buffer) for { - select { - case <-ctx.Done(): - return nil - default: - } - - select { - case <-checkTruncate: - if isTruncated, err := f.truncated(fd); isTruncated { - return err - } - logger.Info(f.filePath, "Current offset", offset) - default: - } - - // Read some bytes (max 4k at once as of go 1.12). isPrefix will - // be set if line does not fit into 4k buffer. - bytes, isPrefix, err := reader.ReadLine() + b, err := reader.ReadByte() if err != nil { - // If EOF, sleep a couple of ms and return with nil error. - // If other error, return with non-nil error. if err != io.EOF { return err } + select { + case <-truncate: + if isTruncated, err := f.truncated(fd); isTruncated { + return err + } + case <-ctx.Done(): + return nil + default: + } if !f.seekEOF { - logger.Debug(f.FilePath(), "End of file reached") + dlog.Common.Info(f.FilePath(), "End of file reached") return nil } time.Sleep(time.Millisecond * 100) continue } + offset++ - rawLine = append(rawLine, bytes...) - offset += uint64(len(bytes)) - - if !isPrefix { - // last LineRead call returned contend until end of line. - rawLine = append(rawLine, '\n') + switch b { + case '\n': select { - case rawLines <- rawLine: + case rawLines <- message: + message = pool.BytesBuffer.Get().(*bytes.Buffer) + //fmt.Printf("%d %d %p\n", message.Len(), message.Cap(), message) + warnedAboutLongLine = false case <-ctx.Done(): return nil } - rawLine = make([]byte, 0, 512) - if longLineWarning { - longLineWarning = false + default: + if message.Len() >= lineLengthThreshold { + if !warnedAboutLongLine { + f.serverMessages <- dlog.Common.Warn(f.filePath, + "Long log line, splitting into multiple lines") + warnedAboutLongLine = true + } + message.WriteString("\n") + select { + case rawLines <- message: + message = pool.BytesBuffer.Get().(*bytes.Buffer) + case <-ctx.Done(): + return nil + } + } + message.WriteByte(b) + } + } +} + +// Filter log lines matching a given regular expression. +func (f readFile) filter(ctx context.Context, ltx lcontext.LContext, + rawLines <-chan *bytes.Buffer, lines chan<- line.Line, re regex.Regex) { + + // Do we have any kind of local context settings? If so then run the more complex + // filterWithLContext method. + if ltx.Has() { + // We can not skip transmitting any lines to the client with a local + // grep context specified. + f.canSkipLines = false + f.filterWithLContext(ctx, ltx, rawLines, lines, re) + return + } + + f.filterWithoutLContext(ctx, rawLines, lines, re) +} + +func (f readFile) filterWithoutLContext(ctx context.Context, rawLines <-chan *bytes.Buffer, + lines chan<- line.Line, re regex.Regex) { + + for { + select { + case line, ok := <-rawLines: + f.updatePosition() + if !ok { + return + } + if filteredLine, ok := f.transmittable(line, len(lines), cap(lines), re); ok { + select { + case lines <- filteredLine: + case <-ctx.Done(): + return + } + } + } + } +} + +// Filter log lines matching a given regular expression, however with local grep context. +func (f readFile) filterWithLContext(ctx context.Context, ltx lcontext.LContext, + rawLines <-chan *bytes.Buffer, lines chan<- line.Line, re regex.Regex) { + + // Scenario 1: Finish once maxCount hits found + maxCount := ltx.MaxCount + processMaxCount := maxCount > 0 + maxReached := false + + // Scenario 2: Print prev. N lines when current line matches. + before := ltx.BeforeContext + processBefore := before > 0 + var beforeBuf chan *bytes.Buffer + if processBefore { + beforeBuf = make(chan *bytes.Buffer, before) + } + + // Screnario 3: Print next N lines when current line matches. + after := 0 + processAfter := ltx.AfterContext > 0 + + for lineBytesBuffer := range rawLines { + f.updatePosition() + + if !re.Match(lineBytesBuffer.Bytes()) { + f.updateLineNotMatched() + + if processAfter && after > 0 { + after-- + myLine := line.Line{ + Content: lineBytesBuffer, + SourceID: f.globID, + Count: f.totalLineCount(), + TransmittedPerc: 100, + } + + select { + case lines <- myLine: + case <-ctx.Done(): + return + } + + } else if processBefore { + // Keep last num BeforeContext raw messages. + select { + case beforeBuf <- lineBytesBuffer: + default: + pool.RecycleBytesBuffer(<-beforeBuf) + beforeBuf <- lineBytesBuffer + } } continue } - // Last LineRead call could not read content until end of line, buffer - // was too small. Determine whether we exceed the max line length we - // want dtail to send to the client at once. Possibly split up log line - // into multiple log lines. - if len(rawLine) >= lineLengthThreshold { - if !longLineWarning { - f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") - // Only print out one warning per long log line. - longLineWarning = true + f.updateLineMatched() + + if processAfter { + if maxReached { + return } - rawLine = append(rawLine, '\n') - select { - case rawLines <- rawLine: - case <-ctx.Done(): - return nil + after = ltx.AfterContext + } + + if processBefore { + i := uint64(len(beforeBuf)) + for { + select { + case lineBytesBuffer := <-beforeBuf: + myLine := line.Line{ + Content: lineBytesBuffer, + SourceID: f.globID, + Count: f.totalLineCount() - i, + TransmittedPerc: 100, + } + i-- + + select { + case lines <- myLine: + case <-ctx.Done(): + return + } + default: + // beforeBuf is now empty. + } + if len(beforeBuf) == 0 { + break + } } - rawLine = make([]byte, 0, 512) } + + line := line.Line{ + Content: lineBytesBuffer, + SourceID: f.globID, + Count: f.totalLineCount(), + TransmittedPerc: 100, + } + + select { + case lines <- line: + if processMaxCount { + maxCount-- + if maxCount == 0 { + if !processAfter || after == 0 { + return + } + // Unfortunatley we have to continue filter, as there might be more lines to print + maxReached = true + } + } + case <-ctx.Done(): + return + } + } +} + +func (f readFile) transmittable(lineBytesBuffer *bytes.Buffer, length, capacity int, + re regex.Regex) (line.Line, bool) { + + var read line.Line + if !re.Match(lineBytesBuffer.Bytes()) { + f.updateLineNotMatched() + f.updateLineNotTransmitted() + return read, false + } + f.updateLineMatched() + + // Can we actually send more messages, channel capacity reached? + if f.canSkipLines && length >= capacity { + f.updateLineNotTransmitted() + return read, false + } + f.updateLineTransmitted() + + read = line.Line{ + Content: lineBytesBuffer, + SourceID: f.globID, + Count: f.totalLineCount(), + TransmittedPerc: f.transmittedPerc(), + } + return read, true +} + +// Check wether log file is truncated. Returns nil if not. +func (f readFile) truncated(fd *os.File) (bool, error) { + dlog.Common.Debug(f.filePath, "File truncation check") + + // Can not seek currently open FD. + curPos, err := fd.Seek(0, os.SEEK_CUR) + if err != nil { + return true, err + } + // Can not open file at original path. + pathFd, err := os.Open(f.filePath) + if err != nil { + return true, err + } + defer pathFd.Close() + + // Can not seek file at original path. + pathPos, err := pathFd.Seek(0, io.SeekEnd) + if err != nil { + return true, err + } + if curPos > pathPos { + return true, errors.New("File got truncated") } + return false, nil } diff --git a/internal/io/fs/tailfile.go b/internal/io/fs/tailfile.go index 14994e5..b03b45d 100644 --- a/internal/io/fs/tailfile.go +++ b/internal/io/fs/tailfile.go @@ -6,7 +6,9 @@ type TailFile struct { } // NewTailFile returns a new file tailer. -func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { +func NewTailFile(filePath string, globID string, serverMessages chan<- string, + limiter chan struct{}) TailFile { + return TailFile{ readFile: readFile{ filePath: filePath, diff --git a/internal/io/fs/truncate.go b/internal/io/fs/truncate.go deleted file mode 100644 index a8d59ac..0000000 --- a/internal/io/fs/truncate.go +++ /dev/null @@ -1,61 +0,0 @@ -package fs - -import ( - "context" - "errors" - "io" - "os" - "time" - - "github.com/mimecast/dtail/internal/io/logger" -) - -func (f readFile) truncateTimer(ctx context.Context) (checkTruncate chan struct{}) { - checkTruncate = make(chan struct{}) - - go func() { - for { - select { - case <-time.After(time.Second * 3): - select { - case checkTruncate <- struct{}{}: - case <-ctx.Done(): - } - case <-ctx.Done(): - return - } - } - }() - - return -} - -// Check wether log file is truncated. Returns nil if not. -func (f readFile) truncated(fd *os.File) (bool, error) { - logger.Debug(f.filePath, "File truncation check") - - // Can not seek currently open FD. - curPos, err := fd.Seek(0, os.SEEK_CUR) - if err != nil { - return true, err - } - - // Can not open file at original path. - pathFd, err := os.Open(f.filePath) - if err != nil { - return true, err - } - defer pathFd.Close() - - // Can not seek file at original path. - pathPos, err := pathFd.Seek(0, io.SeekEnd) - if err != nil { - return true, err - } - - if curPos > pathPos { - return true, errors.New("File got truncated") - } - - return false, nil -} diff --git a/internal/io/line/line.go b/internal/io/line/line.go index 715be34..d306c88 100644 --- a/internal/io/line/line.go +++ b/internal/io/line/line.go @@ -1,13 +1,14 @@ package line import ( + "bytes" "fmt" ) // Line represents a read log line. type Line struct { // The content of the log line. - Content []byte + Content *bytes.Buffer // Until now, how many log lines were processed? Count uint64 // Sometimes we produce too many log lines so that the client @@ -25,7 +26,7 @@ type Line struct { // Return a human readable representation of the followed line. func (l Line) String() string { return fmt.Sprintf("Line(Content:%s,TransmittedPerc:%v,Count:%v,SourceID:%s)", - string(l.Content), + l.Content.String(), l.TransmittedPerc, l.Count, l.SourceID) diff --git a/internal/io/logger/logger.go b/internal/io/logger/logger.go deleted file mode 100644 index 4254eef..0000000 --- a/internal/io/logger/logger.go +++ /dev/null @@ -1,403 +0,0 @@ -package logger - -import ( - "bufio" - "context" - "fmt" - "os" - "os/signal" - "runtime" - "strings" - "sync" - "syscall" - "time" - - "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" -) - -const ( - clientStr string = "CLIENT" - serverStr string = "SERVER" - infoStr string = "INFO" - warnStr string = "WARN" - errorStr string = "ERROR" - fatalStr string = "FATAL" - debugStr string = "DEBUG" - traceStr string = "TRACE" -) - -// Mode specifies the configured logging mode(s) -var Mode Modes - -// Strategy is the current log strattegy used. -var strategy Strategy - -// Synchronise access to logging. -var mutex sync.Mutex - -// File descriptor of log file when Mode.logToFile enabled. -var fd *os.File - -// File write buffer of log file when Mode.logToFile enabled. -var writer *bufio.Writer - -// File write buffer of stdout when Mode.logToStdout enabled. -var stdoutWriter *bufio.Writer - -// Current hostname. -var hostname string - -// Used to detect change of day (create one log file per day0 -var lastDateStr string - -// Used to make logging non-blocking. -var fileLogBufCh chan buf -var stdoutBufCh chan string - -// Stdout channel, required to pause output -var pauseCh chan struct{} -var resumeCh chan struct{} - -// Tell the logger about logrotation -var rotateCh chan os.Signal - -// Helper type to make logging non-blocking. -type buf struct { - time time.Time - message string -} - -// Start logging. -func Start(ctx context.Context, mode Modes) { - Mode = mode - - switch { - case Mode.Nothing: - return - case Mode.Quiet: - Mode.Trace = false - Mode.Debug = false - case Mode.Trace: - Mode.Debug = true - default: - } - - strategy := logStrategy() - stdoutWriter = bufio.NewWriter(os.Stdout) - - switch strategy { - case DailyStrategy: - _, err := os.Stat(config.Common.LogDir) - Mode.logToFile = !os.IsNotExist(err) - Mode.logToStdout = !Mode.Server || Mode.Debug || Mode.Trace || Mode.Quiet - case StdoutStrategy: - fallthrough - default: - Mode.logToFile = !Mode.Server - Mode.logToStdout = true - } - - fqdn, err := os.Hostname() - if err != nil { - panic(err) - } - s := strings.Split(fqdn, ".") - hostname = s[0] - - pauseCh = make(chan struct{}) - resumeCh = make(chan struct{}) - - // Setup logrotation - rotateCh = make(chan os.Signal, 1) - signal.Notify(rotateCh, syscall.SIGHUP) - - if Mode.logToStdout { - stdoutBufCh = make(chan string, runtime.NumCPU()*100) - go writeToStdout(ctx) - } - - if Mode.logToFile { - fileLogBufCh = make(chan buf, runtime.NumCPU()*100) - go writeToFile(ctx) - } -} - -// Info message logging. -func Info(args ...interface{}) string { - if Mode.Server { - return log(serverStr, infoStr, args) - } - - return log(clientStr, infoStr, args) -} - -// Warn message logging. -func Warn(args ...interface{}) string { - if !Mode.Quiet { - if Mode.Server { - return log(serverStr, warnStr, args) - } - return log(clientStr, warnStr, args) - } - - return "" -} - -// Error message logging. -func Error(args ...interface{}) string { - if Mode.Server { - return log(serverStr, errorStr, args) - } - - return log(clientStr, errorStr, args) -} - -// Fatal message logging. -func Fatal(args ...interface{}) string { - if Mode.Server { - return log(serverStr, fatalStr, args) - } - - return log(clientStr, fatalStr, args) -} - -// FatalExit logs an error and exists the process. -func FatalExit(args ...interface{}) { - what := clientStr - if Mode.Server { - what = serverStr - } - log(what, fatalStr, args) - - time.Sleep(time.Second) - mutex.Lock() - defer mutex.Unlock() - - closeWriter() - os.Exit(3) -} - -// Debug message logging. -func Debug(args ...interface{}) string { - if Mode.Debug { - if Mode.Server { - return log(serverStr, debugStr, args) - } - return log(clientStr, debugStr, args) - } - - return "" -} - -// Trace message logging. -func Trace(args ...interface{}) string { - if Mode.Trace { - if Mode.Server { - return log(serverStr, traceStr, args) - } - return log(clientStr, traceStr, args) - } - - return "" -} - -// Write log line to buffer and/or log file. -func write(what, severity, message string) { - if Mode.logToStdout { - line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) - - if color.Colored { - line = color.Colorfy(line) - } - - stdoutBufCh <- line - } - - if Mode.logToFile { - t := time.Now() - timeStr := t.Format("20060102-150405") - fileLogBufCh <- buf{ - time: t, - message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), - } - } -} - -// Generig log message. -func log(what string, severity string, args []interface{}) string { - if Mode.Nothing { - return "" - } - - messages := []string{} - - for _, arg := range args { - switch v := arg.(type) { - case string: - messages = append(messages, v) - case int: - messages = append(messages, fmt.Sprintf("%d", v)) - case error: - messages = append(messages, v.Error()) - default: - messages = append(messages, fmt.Sprintf("%v", v)) - } - } - - message := strings.Join(messages, "|") - write(what, severity, message) - - return fmt.Sprintf("%s|%s", severity, message) -} - -// Raw message logging. -func Raw(message string) { - if Mode.Nothing { - return - } - - if Mode.logToFile { - fileLogBufCh <- buf{time.Now(), message} - } - - if Mode.logToStdout { - if color.Colored { - message = color.Colorfy(message) - } - stdoutBufCh <- message - } -} - -// Close log writer (e.g. on change of day). -func closeWriter() { - if writer != nil { - writer.Flush() - fd.Close() - } -} - -// Return the correct log file writer -func fileWriter(dateStr string) *bufio.Writer { - if dateStr != lastDateStr { - return updateFileWriter(dateStr) - } - - // Check for log rotation signal - select { - case <-rotateCh: - stdoutWriter.WriteString("Received signal for logrotation\n") - return updateFileWriter(dateStr) - default: - } - - return writer -} - -// Update log file writer -func updateFileWriter(dateStr string) *bufio.Writer { - // Detected change of day. Close current writer and create a new one. - mutex.Lock() - defer mutex.Unlock() - closeWriter() - - if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { - if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { - panic(err) - } - } - - logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr) - newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) - if err != nil { - panic(err) - } - - fd = newFd - writer = bufio.NewWriterSize(fd, 1) - lastDateStr = dateStr - - return writer -} - -// Flush all outstanding lines. -func Flush() { - for { - select { - case message := <-stdoutBufCh: - stdoutWriter.WriteString(message) - default: - stdoutWriter.Flush() - return - } - } -} - -func writeToStdout(ctx context.Context) { - for { - select { - case message := <-stdoutBufCh: - stdoutWriter.WriteString(message) - case <-time.After(time.Millisecond * 100): - stdoutWriter.Flush() - case <-pauseCh: - PAUSE: - for { - select { - case <-stdoutBufCh: - case <-resumeCh: - break PAUSE - case <-ctx.Done(): - return - } - } - case <-ctx.Done(): - Flush() - return - } - } -} - -func writeToFile(ctx context.Context) { - for { - select { - case buf := <-fileLogBufCh: - dateStr := buf.time.Format("20060102") - w := fileWriter(dateStr) - w.WriteString(buf.message) - case <-pauseCh: - PAUSE: - for { - select { - case <-stdoutBufCh: - case <-resumeCh: - break PAUSE - case <-ctx.Done(): - return - } - } - case <-ctx.Done(): - return - } - } -} - -// Pause logging. -func Pause() { - if Mode.logToStdout { - pauseCh <- struct{}{} - } - if Mode.logToFile { - pauseCh <- struct{}{} - } -} - -// Resume logging (after pausing). -func Resume() { - if Mode.logToStdout { - resumeCh <- struct{}{} - } - if Mode.logToFile { - resumeCh <- struct{}{} - } -} diff --git a/internal/io/logger/modes.go b/internal/io/logger/modes.go deleted file mode 100644 index 8864179..0000000 --- a/internal/io/logger/modes.go +++ /dev/null @@ -1,12 +0,0 @@ -package logger - -// Modes specifies the logging mode. -type Modes struct { - Server bool - Trace bool - Debug bool - Nothing bool - Quiet bool - logToStdout bool - logToFile bool -} diff --git a/internal/io/logger/strategy.go b/internal/io/logger/strategy.go deleted file mode 100644 index 44bf393..0000000 --- a/internal/io/logger/strategy.go +++ /dev/null @@ -1,22 +0,0 @@ -package logger - -import "github.com/mimecast/dtail/internal/config" - -// Strategy allows to specify a log rotation strategy. -type Strategy int - -// Possible log strategies. -const ( - NormalStrategy Strategy = iota - DailyStrategy Strategy = iota - StdoutStrategy Strategy = iota -) - -func logStrategy() Strategy { - switch config.Common.LogStrategy { - case "daily": - return DailyStrategy - default: - } - return StdoutStrategy -} diff --git a/internal/io/pool/builder.go b/internal/io/pool/builder.go new file mode 100644 index 0000000..89fcf81 --- /dev/null +++ b/internal/io/pool/builder.go @@ -0,0 +1,21 @@ +package pool + +import ( + "strings" + "sync" +) + +// BuilderBuffer is there to optimize memory allocations (DTail allocates a lot +// of memory while reading log data otherwise) +var BuilderBuffer = sync.Pool{ + New: func() interface{} { + sb := strings.Builder{} + return &sb + }, +} + +// RecycleBuilderBuffer recycles the buffer again. +func RecycleBuilderBuffer(sb *strings.Builder) { + sb.Reset() + BuilderBuffer.Put(sb) +} diff --git a/internal/io/pool/bytesbuffer.go b/internal/io/pool/bytesbuffer.go new file mode 100644 index 0000000..3d48f2c --- /dev/null +++ b/internal/io/pool/bytesbuffer.go @@ -0,0 +1,22 @@ +package pool + +import ( + "bytes" + "sync" +) + +// BytesBuffer is there to optimize memory allocations. DTail otherwise allocates +// a lot of memory while reading logs. +var BytesBuffer = sync.Pool{ + New: func() interface{} { + b := bytes.Buffer{} + b.Grow(128) + return &b + }, +} + +// RecycleBytesBuffer recycles the buffer again. +func RecycleBytesBuffer(b *bytes.Buffer) { + b.Reset() + BytesBuffer.Put(b) +} diff --git a/internal/io/prompt/prompt.go b/internal/io/prompt/prompt.go index 36ebdb5..e82132d 100644 --- a/internal/io/prompt/prompt.go +++ b/internal/io/prompt/prompt.go @@ -6,7 +6,7 @@ import ( "os" "strings" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // Answer is a user input of a prompt question. @@ -19,7 +19,8 @@ type Answer struct { Callback func() // Runs after Callback and after logging resumes EndCallback func() - AskAgain bool + // AskAgain can be used to not to ask again about the question. + AskAgain bool } // Prompt used for interactive user input. @@ -30,7 +31,6 @@ type Prompt struct { func (p *Prompt) askString() string { var sb strings.Builder - sb.WriteString(p.question) sb.WriteString("? (") @@ -41,7 +41,6 @@ func (p *Prompt) askString() string { sb.WriteString(strings.Join(ax, ",")) sb.WriteString("): ") - return sb.String() } @@ -58,7 +57,7 @@ func (p *Prompt) Add(answer Answer) { // Ask a question. func (p *Prompt) Ask() { reader := bufio.NewReader(os.Stdin) - logger.Pause() + dlog.Common.Pause() for { fmt.Print(p.askString()) @@ -68,9 +67,8 @@ func (p *Prompt) Ask() { if a.Callback != nil { a.Callback() } - if !a.AskAgain { - logger.Resume() + dlog.Common.Resume() if a.EndCallback != nil { a.EndCallback() } @@ -90,6 +88,5 @@ func (p *Prompt) answer(answerStr string) (*Answer, bool) { default: } } - return nil, false } diff --git a/internal/io/signal/signal.go b/internal/io/signal/signal.go index 500c530..584b59c 100644 --- a/internal/io/signal/signal.go +++ b/internal/io/signal/signal.go @@ -14,10 +14,8 @@ import ( func InterruptCh(ctx context.Context) <-chan string { sigIntCh := make(chan os.Signal) gosignal.Notify(sigIntCh, os.Interrupt) - sigOtherCh := make(chan os.Signal) gosignal.Notify(sigOtherCh, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT) - statsCh := make(chan string) go func() { @@ -41,6 +39,10 @@ func InterruptCh(ctx context.Context) <-chan string { } } }() - return statsCh } + +// NoCh doesn't listen on a signal. +func NoCh(ctx context.Context) <-chan string { + return make(chan string) +} diff --git a/internal/lcontext/lcontext.go b/internal/lcontext/lcontext.go index 89cb7c3..183ceb5 100644 --- a/internal/lcontext/lcontext.go +++ b/internal/lcontext/lcontext.go @@ -1,6 +1,6 @@ package lcontext -// LContext stands for line context and is here to help filtering out only specific lines. +// LContext stands for line context (used by context aware grep queries e.g.) type LContext struct { AfterContext int BeforeContext int diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go index a6cc6eb..c50c7a1 100644 --- a/internal/mapr/aggregateset.go +++ b/internal/mapr/aggregateset.go @@ -5,6 +5,10 @@ import ( "fmt" "strconv" "strings" + + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/protocol" ) // AggregateSet represents aggregated key/value pairs from the @@ -33,8 +37,7 @@ func (s *AggregateSet) String() string { // Merge one aggregate set into this one. func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { s.Samples += set.Samples - //logger.Trace("Merge", set) - + //dlog.Common.Trace("Merge", set) for _, sc := range query.Select { storage := sc.FieldStorage switch sc.Operation { @@ -66,24 +69,27 @@ func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { // Serialize the aggregate set so it can be sent over the wire. func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<- string) { - //logger.Trace("Serialising mapr.AggregateSet", s) - var sb strings.Builder + dlog.Common.Trace("Serialising mapr.AggregateSet", s) + sb := pool.BuilderBuffer.Get().(*strings.Builder) + defer pool.RecycleBuilderBuffer(sb) sb.WriteString(groupKey) - sb.WriteString("➔") - sb.WriteString(fmt.Sprintf("%d➔", s.Samples)) + sb.WriteString(protocol.AggregateDelimiter) + sb.WriteString(fmt.Sprintf("%d", s.Samples)) + sb.WriteString(protocol.AggregateDelimiter) for k, v := range s.FValues { sb.WriteString(k) - sb.WriteString("=") - sb.WriteString(fmt.Sprintf("%v➔", v)) + sb.WriteString(protocol.AggregateKVDelimiter) + sb.WriteString(fmt.Sprintf("%v", v)) + sb.WriteString(protocol.AggregateDelimiter) } for k, v := range s.SValues { sb.WriteString(k) - sb.WriteString("=") + sb.WriteString(protocol.AggregateKVDelimiter) sb.WriteString(v) - sb.WriteString("➔") + sb.WriteString(protocol.AggregateDelimiter) } select { @@ -108,7 +114,6 @@ func (s *AggregateSet) addFloatMin(key string, value float64) { s.FValues[key] = value return } - if f > value { s.FValues[key] = value } @@ -121,7 +126,6 @@ func (s *AggregateSet) addFloatMax(key string, value float64) { s.FValues[key] = value return } - if f < value { s.FValues[key] = value } @@ -140,7 +144,6 @@ func (s *AggregateSet) setFloat(key string, value float64) { // Aggregate data to the aggregate set. func (s *AggregateSet) Aggregate(key string, agg AggregateOperation, value string, clientAggregation bool) (err error) { var f float64 - // First check if we can aggregate anything without converting value to float. switch agg { case Count: diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go index 10b34d4..02a6a5a 100644 --- a/internal/mapr/client/aggregate.go +++ b/internal/mapr/client/aggregate.go @@ -1,11 +1,13 @@ package client import ( + "fmt" "strconv" "strings" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/protocol" ) // Aggregate mapreduce data on the DTail client side. @@ -21,7 +23,9 @@ type Aggregate struct { } // NewAggregate create new client aggregator. -func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *Aggregate { +func NewAggregate(server string, query *mapr.Query, + globalGroup *mapr.GlobalGroupSet) *Aggregate { + return &Aggregate{ query: query, group: mapr.NewGroupSet(), @@ -31,20 +35,26 @@ func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGrou } // Aggregate data from mapr log line into local (and global) group sets. -func (a *Aggregate) Aggregate(parts []string) { +func (a *Aggregate) Aggregate(message string) error { + parts := strings.Split(message, protocol.AggregateDelimiter) + if len(parts) < 4 { + return fmt.Errorf("aggregate message without any real data") + } + groupKey := parts[0] samples, err := strconv.Atoi(parts[1]) if err != nil { - logger.FatalExit("Unable to parse sample count", parts[1], err, parts) + return fmt.Errorf("unable to parse sample count '%s': %v", parts[1], err) } + fields := a.makeFields(parts[2:]) set := a.group.GetSet(groupKey) - var addedSamples bool + for _, sc := range a.query.Select { if val, ok := fields[sc.FieldStorage]; ok { if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, true); err != nil { - logger.Error(err) + dlog.Common.Error(err) continue } addedSamples = true @@ -63,19 +73,18 @@ func (a *Aggregate) Aggregate(parts []string) { // Re-init local group (make it empty again). a.group.InitSet() } + return nil } // Create a map of key-value pairs from a part list such as ["foo=bar", "bar=baz"]. func (a *Aggregate) makeFields(parts []string) map[string]string { fields := make(map[string]string, len(parts)) - for _, part := range parts { - kv := strings.SplitN(part, "=", 2) - if len(kv) < 2 { + kv := strings.SplitN(part, protocol.AggregateKVDelimiter, 2) + if len(kv) != 2 { continue } fields[kv[0]] = kv[1] } - return fields } diff --git a/internal/mapr/funcs/function.go b/internal/mapr/funcs/function.go index 1a89c3a..418d86f 100644 --- a/internal/mapr/funcs/function.go +++ b/internal/mapr/funcs/function.go @@ -19,13 +19,12 @@ type Function struct { // FunctionStack is a list of functions stacked each other type FunctionStack []Function -// NewFunctionStack parses the input string, e.g. foo(bar("arg")) and returns a corresponding function stack. +// NewFunctionStack parses the input string, e.g. foo(bar("arg")) and returns +// a corresponding function stack. func NewFunctionStack(in string) (FunctionStack, string, error) { var fs FunctionStack - getCallback := func(name string) (CallbackFunc, error) { var cb CallbackFunc - switch name { case "md5sum": return Md5Sum, nil @@ -51,17 +50,15 @@ func NewFunctionStack(in string) (FunctionStack, string, error) { fs = append(fs, Function{name, call}) aux = aux[index+1 : len(aux)-1] } - return fs, aux, nil } // Call the function stack. func (fs FunctionStack) Call(str string) string { for i := len(fs) - 1; i >= 0; i-- { - //logger.Debug("Call", fs[i].Name, str) + //dlog.Common.Debug("Call", fs[i].Name, str) str = fs[i].call(str) - //logger.Debug("Call.result", fs[i].Name, str) + //dlog.Common.Debug("Call.result", fs[i].Name, str) } - return str } diff --git a/internal/mapr/funcs/function_test.go b/internal/mapr/funcs/function_test.go index 415683c..8b5d8b7 100644 --- a/internal/mapr/funcs/function_test.go +++ b/internal/mapr/funcs/function_test.go @@ -6,16 +6,19 @@ func TestFunction(t *testing.T) { input := "md5sum($line)" fs, arg, err := NewFunctionStack(input) if err != nil { - t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs) + t.Errorf("error parsing function input '%s': %s (%v)\n", + input, err.Error(), fs) } if arg != "$line" { - t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs) + t.Errorf("error parsing function input '%s': expected argument '$line' but "+ + "got '%s' (%v)\n", input, arg, fs) } t.Log(input, fs, arg) result := fs.Call(input) if result != "b38699013d79e50d9d122433753959c1" { - t.Errorf("error executing function stack '%s': expected result 'b38699013d79e50d9d122433753959c1' but got '%s' (%v)\n", input, result, fs) + t.Errorf("error executing function stack '%s': expected result "+ + "'b38699013d79e50d9d122433753959c1' but got '%s' (%v)\n", input, result, fs) } input = "maskdigits(md5sum(maskdigits($line)))" @@ -24,22 +27,26 @@ func TestFunction(t *testing.T) { t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs) } if arg != "$line" { - t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs) + t.Errorf("error parsing function input '%s': expected argument '$line' but "+ + "got '%s' (%v)\n", input, arg, fs) } t.Log(input, fs, arg) result = fs.Call(input) if result != ".fac.bbe..bb.........d...a.c..b." { - t.Errorf("error executing function stack '%s': expected result '.fac.bbe..bb.........d...a.c..b.' but got '%s' (%v)\n", input, result, fs) + t.Errorf("error executing function stack '%s': expected result "+ + "'.fac.bbe..bb.........d...a.c..b.' but got '%s' (%v)\n", input, result, fs) } input = "md5sum$line)" if fs, _, err := NewFunctionStack(input); err == nil { - t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs) + t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", + input, fs) } input = "md5sum(makedigits$line))" if fs, _, err := NewFunctionStack(input); err == nil { - t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs) + t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", + input, fs) } } diff --git a/internal/mapr/funcs/maskdigits.go b/internal/mapr/funcs/maskdigits.go index d51f3d8..925ec4d 100644 --- a/internal/mapr/funcs/maskdigits.go +++ b/internal/mapr/funcs/maskdigits.go @@ -3,12 +3,10 @@ package funcs // MaskDigits masks all digits (replaces them with .) func MaskDigits(input string) string { s := []byte(input) - for i, b := range s { if '0' <= b && b <= '9' { s[i] = '.' } } - return string(s) } diff --git a/internal/mapr/globalgroupset.go b/internal/mapr/globalgroupset.go index cfab506..2d7f10b 100644 --- a/internal/mapr/globalgroupset.go +++ b/internal/mapr/globalgroupset.go @@ -17,7 +17,6 @@ func NewGlobalGroupSet() *GlobalGroupSet { semaphore: make(chan struct{}, 1), } g.InitSet() - return &g } @@ -30,7 +29,6 @@ func (g *GlobalGroupSet) String() string { func (g *GlobalGroupSet) Merge(query *Query, group *GroupSet) error { g.semaphore <- struct{}{} defer func() { <-g.semaphore }() - return g.merge(query, group) } @@ -54,7 +52,6 @@ func (g *GlobalGroupSet) merge(query *Query, group *GroupSet) error { return err } } - return nil } @@ -67,7 +64,6 @@ func (g *GlobalGroupSet) IsEmpty() bool { func (g *GlobalGroupSet) NumSets() int { g.semaphore <- struct{}{} defer func() { <-g.semaphore }() - return len(g.sets) } @@ -79,7 +75,6 @@ func (g *GlobalGroupSet) SwapOut() *GroupSet { set := &GroupSet{sets: g.sets} g.InitSet() - return set } @@ -87,14 +82,12 @@ func (g *GlobalGroupSet) SwapOut() *GroupSet { func (g *GlobalGroupSet) WriteResult(query *Query) error { g.semaphore <- struct{}{} defer func() { <-g.semaphore }() - return g.GroupSet.WriteResult(query) } // Result returns the result of the mapreduce aggregation as a string. -func (g *GlobalGroupSet) Result(query *Query) (string, int, error) { +func (g *GlobalGroupSet) Result(query *Query, rowsLimit int) (string, int, error) { g.semaphore <- struct{}{} defer func() { <-g.semaphore }() - - return g.GroupSet.Result(query) + return g.GroupSet.Result(query, rowsLimit) } diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go index b5c8a48..6ffc8b9 100644 --- a/internal/mapr/groupset.go +++ b/internal/mapr/groupset.go @@ -4,13 +4,16 @@ import ( "context" "errors" "fmt" - "io/ioutil" "os" "sort" "strconv" "strings" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/protocol" ) // GroupSet represents a map of aggregate sets. The group sets @@ -22,6 +25,14 @@ type GroupSet struct { sets map[string]*AggregateSet } +// Internal helper type +type result struct { + groupKey string + values []string + widths []int + orderBy float64 +} + // NewGroupSet returns a new empty group set. func NewGroupSet() *GroupSet { g := GroupSet{} @@ -57,28 +68,181 @@ func (g *GroupSet) Serialize(ctx context.Context, ch chan<- string) { } // Result returns a nicely formated result of the query from the group set. -func (g *GroupSet) Result(query *Query) (string, int, error) { - return g.limitedResult(query, query.Limit, "\t", " ", false) +func (g *GroupSet) Result(query *Query, rowsLimit int) (string, int, error) { + rows, widths, err := g.result(query, true) + if err != nil { + return "", 0, err + } + if query.Limit != -1 { + rowsLimit = query.Limit + } + + sb := pool.BuilderBuffer.Get().(*strings.Builder) + defer pool.RecycleBuilderBuffer(sb) + + // Generate header now + lastIndex := len(query.Select) - 1 + for i, sc := range query.Select { + format := fmt.Sprintf(" %%%ds ", widths[i]) + str := fmt.Sprintf(format, sc.FieldStorage) + if config.Client.TermColorsEnable { + attrs := []color.Attribute{config.Client.TermColors.MaprTable.HeaderAttr} + if sc.FieldStorage == query.OrderBy { + attrs = append(attrs, config.Client.TermColors.MaprTable.HeaderSortKeyAttr) + } + + for _, groupBy := range query.GroupBy { + if sc.FieldStorage == groupBy { + attrs = append(attrs, config.Client.TermColors.MaprTable.HeaderGroupKeyAttr) + break + } + } + + color.PaintWithAttrs(sb, str, + config.Client.TermColors.MaprTable.HeaderFg, + config.Client.TermColors.MaprTable.HeaderBg, + attrs) + } else { + sb.WriteString(str) + } + + if i == lastIndex { + continue + } + if config.Client.TermColorsEnable { + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.MaprTable.HeaderDelimiterFg, + config.Client.TermColors.MaprTable.HeaderDelimiterBg, + config.Client.TermColors.MaprTable.HeaderDelimiterAttr) + } else { + sb.WriteString(protocol.FieldDelimiter) + } + } + sb.WriteString("\n") + + for i := 0; i < len(query.Select); i++ { + str := fmt.Sprintf("-%s-", strings.Repeat("-", widths[i])) + if config.Client.TermColorsEnable { + color.PaintWithAttr(sb, str, + config.Client.TermColors.MaprTable.HeaderDelimiterFg, + config.Client.TermColors.MaprTable.HeaderDelimiterBg, + config.Client.TermColors.MaprTable.HeaderDelimiterAttr) + } else { + sb.WriteString(str) + } + if i == lastIndex { + continue + } + if config.Client.TermColorsEnable { + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.MaprTable.HeaderDelimiterFg, + config.Client.TermColors.MaprTable.HeaderDelimiterBg, + config.Client.TermColors.MaprTable.HeaderDelimiterAttr) + } else { + sb.WriteString(protocol.FieldDelimiter) + } + } + sb.WriteString("\n") + + // And now write the data + for i, r := range rows { + if i == rowsLimit { + break + } + for j, value := range r.values { + format := fmt.Sprintf(" %%%ds ", widths[j]) + str := fmt.Sprintf(format, value) + if config.Client.TermColorsEnable { + color.PaintWithAttr(sb, str, + config.Client.TermColors.MaprTable.DataFg, + config.Client.TermColors.MaprTable.DataBg, + config.Client.TermColors.MaprTable.DataAttr) + } else { + sb.WriteString(str) + } + + if j == lastIndex { + continue + } + if config.Client.TermColorsEnable { + color.PaintWithAttr(sb, protocol.FieldDelimiter, + config.Client.TermColors.MaprTable.DelimiterFg, + config.Client.TermColors.MaprTable.DelimiterBg, + config.Client.TermColors.MaprTable.DelimiterAttr) + } else { + sb.WriteString(protocol.FieldDelimiter) + } + } + sb.WriteString("\n") + } + + return sb.String(), len(rows), nil } -// WriteResult writes the result to an outfile. +func (*GroupSet) writeQueryFile(query *Query) error { + queryFile := fmt.Sprintf("%s.query", query.Outfile) + tmpQueryFile := fmt.Sprintf("%s.tmp", queryFile) + dlog.Common.Debug("Writing query file", queryFile) + + fd, err := os.Create(tmpQueryFile) + if err != nil { + return err + } + defer fd.Close() + + fd.WriteString(query.RawQuery) + os.Rename(tmpQueryFile, queryFile) + return nil +} + +// WriteResult writes the result to an CSV outfile. func (g *GroupSet) WriteResult(query *Query) error { if !query.HasOutfile() { return errors.New("No outfile specified") } + if err := g.writeQueryFile(query); err != nil { + return err + } - // -1: Don't limit the result, include all data sets - result, _, err := g.limitedResult(query, query.Limit, "", ",", true) + rows, _, err := g.result(query, false) if err != nil { return err } - logger.Info("Writing outfile", query.Outfile) + dlog.Common.Info("Writing outfile", query.Outfile) tmpOutfile := fmt.Sprintf("%s.tmp", query.Outfile) - if err := ioutil.WriteFile(tmpOutfile, []byte(result), 0644); err != nil { + fd, err := os.Create(tmpOutfile) + if err != nil { return err } + defer fd.Close() + + // Generate header now + lastIndex := len(query.Select) - 1 + for i, sc := range query.Select { + fd.WriteString(sc.FieldStorage) + if i == lastIndex { + continue + } + fd.WriteString(protocol.CSVDelimiter) + } + fd.WriteString("\n") + + // And now write the data + for i, r := range rows { + if i == query.Limit { + break + } + for j, value := range r.values { + fd.WriteString(value) + if j == lastIndex { + continue + } + fd.WriteString(protocol.CSVDelimiter) + } + fd.WriteString("\n") + } if err := os.Rename(tmpOutfile, query.Outfile); err != nil { os.Remove(tmpOutfile) @@ -88,32 +252,21 @@ func (g *GroupSet) WriteResult(query *Query) error { return nil } -// Return a nicely formated result of the query from the group set. -func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSeparator string, addHeader bool) (string, int, error) { - type result struct { - groupKey string - resultStr string - orderBy float64 - } - - var resultSlice []result +// Return a sorted result slice of the query from the group set. +func (g *GroupSet) result(query *Query, gatherWidths bool) ([]result, []int, error) { + var rows []result + widths := make([]int, len(query.Select)) + var valueStr string + var value float64 for groupKey, set := range g.sets { - var sb strings.Builder r := result{groupKey: groupKey} - lastIndex := len(query.Select) - 1 for i, sc := range query.Select { - storage := sc.FieldStorage - orderByThis := storage == query.OrderBy - switch sc.Operation { case Count: - value := set.FValues[storage] - sb.WriteString(fmt.Sprintf("%d", int(value))) - if orderByThis { - r.orderBy = value - } + value = set.FValues[sc.FieldStorage] + valueStr = fmt.Sprintf("%d", int(value)) case Len: fallthrough case Sum: @@ -121,74 +274,48 @@ func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSepa case Min: fallthrough case Max: - value := set.FValues[storage] - sb.WriteString(fmt.Sprintf("%f", value)) - if orderByThis { - r.orderBy = value - } + value = set.FValues[sc.FieldStorage] + valueStr = fmt.Sprintf("%f", value) case Last: - value := set.SValues[storage] - if orderByThis { - f, err := strconv.ParseFloat(value, 64) - if err == nil { - r.orderBy = f - } - } - sb.WriteString(value) + valueStr = set.SValues[sc.FieldStorage] + value, _ = strconv.ParseFloat(valueStr, 64) case Avg: - value := set.FValues[storage] / float64(set.Samples) - sb.WriteString(fmt.Sprintf("%f", value)) - if orderByThis { - r.orderBy = value - } + value = set.FValues[sc.FieldStorage] / float64(set.Samples) + valueStr = fmt.Sprintf("%f", value) default: - return "", 0, fmt.Errorf("Unknown aggregation method '%v'", sc.Operation) + return rows, widths, fmt.Errorf("Unknown aggregation method '%v'", + sc.Operation) } - if i != lastIndex { - sb.WriteString(fieldSeparator) + + if sc.FieldStorage == query.OrderBy { + r.orderBy = value } - } + r.values = append(r.values, valueStr) - r.resultStr = sb.String() - resultSlice = append(resultSlice, r) + if !gatherWidths { + continue + } + if widths[i] < len(sc.FieldStorage) { + widths[i] = len(sc.FieldStorage) + } + if widths[i] < len(valueStr) { + widths[i] = len(valueStr) + } + } + rows = append(rows, r) } if query.OrderBy != "" { if query.ReverseOrder { - sort.SliceStable(resultSlice, func(i, j int) bool { - return resultSlice[i].orderBy < resultSlice[j].orderBy + sort.SliceStable(rows, func(i, j int) bool { + return rows[i].orderBy < rows[j].orderBy }) } else { - sort.SliceStable(resultSlice, func(i, j int) bool { - return resultSlice[i].orderBy > resultSlice[j].orderBy + sort.SliceStable(rows, func(i, j int) bool { + return rows[i].orderBy > rows[j].orderBy }) } } - var sb strings.Builder - - // Write header first - if addHeader { - lastIndex := len(query.Select) - 1 - sb.WriteString(lineStarter) - for i, sc := range query.Select { - sb.WriteString(sc.FieldStorage) - if i != lastIndex { - sb.WriteString(fieldSeparator) - } - } - sb.WriteString("\n") - } - - // And now write the data - for i, r := range resultSlice { - if i == limit { - break - } - sb.WriteString(lineStarter) - sb.WriteString(r.resultStr) - sb.WriteString("\n") - } - - return sb.String(), len(resultSlice), nil + return rows, widths, nil } diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go index 44bf558..9b6c855 100644 --- a/internal/mapr/logformat/default.go +++ b/internal/mapr/logformat/default.go @@ -1,14 +1,23 @@ package logformat import ( - "errors" + "fmt" "strings" + + "github.com/mimecast/dtail/internal/protocol" ) -// MakeFieldsDEFAULT is the default log file mapreduce parser. +// MakeFieldsDEFAULT is the default DTail log file key-value parser. func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) { - fields := make(map[string]string, 20) - splitted := strings.Split(maprLine, "|") + splitted := strings.Split(maprLine, protocol.FieldDelimiter) + + if len(splitted) < 11 || !strings.HasPrefix(splitted[9], "MAPREDUCE:") || + !strings.HasPrefix(splitted[0], "INFO") { + // Not a DTail mapreduce log line. + return nil, ErrIgnoreFields + } + + fields := make(map[string]string, len(splitted)+8) fields["*"] = "*" fields["$line"] = maprLine @@ -17,10 +26,30 @@ func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) { fields["$timezone"] = p.timeZoneName fields["$timeoffset"] = p.timeZoneOffset - for _, kv := range splitted { + fields["$severity"] = splitted[0] + fields["$loglevel"] = splitted[0] + + time := splitted[1] + fields["$time"] = time + if len(time) == 15 { + // Example: 20211002-071209 + fields["$date"] = time[0:8] + fields["$hour"] = time[9:11] + fields["$minute"] = time[11:13] + fields["$second"] = time[13:] + } + fields["$pid"] = splitted[2] + fields["$caller"] = splitted[3] + fields["$cpus"] = splitted[4] + fields["$goroutines"] = splitted[5] + fields["$cgocalls"] = splitted[6] + fields["$loadavg"] = splitted[7] + fields["$uptime"] = splitted[8] + + for _, kv := range splitted[10:] { keyAndValue := strings.SplitN(kv, "=", 2) if len(keyAndValue) != 2 { - return fields, errors.New("Error parsing mapr token: " + kv) + return fields, fmt.Errorf("Unable to parse key-value token '%s'", kv) } fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1] } diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go index 10ec8b7..28e1acc 100644 --- a/internal/mapr/logformat/default_test.go +++ b/internal/mapr/logformat/default_test.go @@ -1,6 +1,7 @@ package logformat import ( + "fmt" "testing" ) @@ -10,26 +11,83 @@ func TestDefaultLogFormat(t *testing.T) { t.Errorf("Unable to create parser: %s", err.Error()) } - fields, err := parser.MakeFields("foo=bar|baz=bay") + date := "20211002" + hour := "07" + minute := "23" + second := "42" + time := fmt.Sprintf("%s-%s%s%s", date, hour, minute, second) - if err != nil { - t.Errorf("Unable to parse: %s", err.Error()) + inputs := []string{ + fmt.Sprintf("INFO|%s|1|default_test.go:0|8|14|7|0.21|471h0m21s|MAPREDUCE:STATS|foo=bar|bar=foo", time), + fmt.Sprintf("INFO|%s|1|default_test.go:0|8|14|7|0.21|471h0m21s|MAPREDUCE:STATS|bar=foo|foo=bar", time), } - if bar, ok := fields["foo"]; !ok { - t.Errorf("Expected field 'foo', but no such field there\n") - } else if bar != "bar" { - t.Errorf("Expected 'bar' stored in field 'foo', but got '%s'\n", bar) - } + for _, input := range inputs { + fields, err := parser.MakeFields(input) + + if err != nil { + t.Errorf("Parser unable to make fields: %s", err.Error()) + } + + if val, ok := fields["$severity"]; !ok { + t.Errorf("Expected field '$severity', but no such field there in '%s'\n", input) + } else if val != "INFO" { + t.Errorf("Expected 'Info' stored in field '$severity', but got '%s' in '%s'\n", + val, input) + } + + if val, ok := fields["$time"]; !ok { + t.Errorf("Expected field '$time', but no such field there in '%s'\n", input) + } else if val != time { + t.Errorf("Expected '%s' stored in field '$time', but got '%s' in '%s'\n", + time, val, input) + } + + if val, ok := fields["$date"]; !ok { + t.Errorf("Expected field '$date', but no such field there in '%s'\n", input) + } else if val != date { + t.Errorf("Expected '%s' stored in field '$date', but got '%s' in '%s'\n", + date, val, input) + } + + if val, ok := fields["$hour"]; !ok { + t.Errorf("Expected field '$hour', but no such field there in '%s'\n", input) + } else if val != hour { + t.Errorf("Expected '%s' stored in field '$hour', but got '%s' in '%s'\n", + hour, val, input) + } + + if val, ok := fields["$minute"]; !ok { + t.Errorf("Expected field '$minute', but no such field there in '%s'\n", input) + } else if val != minute { + t.Errorf("Expected '%s' stored in field '$minute', but got '%s' in '%s'\n", + minute, val, input) + } + + if val, ok := fields["$second"]; !ok { + t.Errorf("Expected field '$second', but no such field there in '%s'\n", input) + } else if val != second { + t.Errorf("Expected '%s' stored in field '$second', but got '%s' in '%s'\n", + second, val, input) + } + + if val, ok := fields["foo"]; !ok { + t.Errorf("Expected field 'foo', but no such field there in '%s'\n", input) + } else if val != "bar" { + t.Errorf("Expected 'bar' stored in field 'foo', but got '%s' in '%s'\n", + val, input) + } - if bay, ok := fields["baz"]; !ok { - t.Errorf("Expected field 'baz', but no such field there\n") - } else if bay != "bay" { - t.Errorf("Expected 'bay' stored in field 'baz', but got '%s'\n", bay) + if val, ok := fields["bar"]; !ok { + t.Errorf("Expected field 'bar', but no such field there in '%s'\n", input) + } else if val != "foo" { + t.Errorf("Expected 'foo' stored in field 'bar', but got '%s' in '%s'\n", + val, input) + } } - _, err = parser.MakeFields("foo=bar|bazbay") - if err == nil { - t.Errorf("Expected error but didn't: %s", err.Error()) + fields, err := parser.MakeFields("foozoo=bar|bazbay") + if _, ok := fields["foo"]; ok { + t.Errorf("Expected fiending field 'foo', but found it\n") } } diff --git a/internal/mapr/logformat/generickv.go b/internal/mapr/logformat/generickv.go new file mode 100644 index 0000000..433eb5f --- /dev/null +++ b/internal/mapr/logformat/generickv.go @@ -0,0 +1,31 @@ +package logformat + +import ( + "strings" + + "github.com/mimecast/dtail/internal/protocol" +) + +// MakeFieldsGENERIGKV is the generic key-value logfile parser. +func (p *Parser) MakeFieldsGENERIGKV(maprLine string) (map[string]string, error) { + splitted := strings.Split(maprLine, protocol.FieldDelimiter) + fields := make(map[string]string, len(splitted)) + + fields["*"] = "*" + fields["$line"] = maprLine + fields["$empty"] = "" + fields["$hostname"] = p.hostname + fields["$timezone"] = p.timeZoneName + fields["$timeoffset"] = p.timeZoneOffset + + for _, kv := range splitted[0:] { + keyAndValue := strings.SplitN(kv, "=", 2) + if len(keyAndValue) != 2 { + //dlog.Common.Debug("Unable to parse key-value token, ignoring it", kv) + continue + } + fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1] + } + + return fields, nil +} diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go index c53729a..129081d 100644 --- a/internal/mapr/logformat/parser.go +++ b/internal/mapr/logformat/parser.go @@ -8,10 +8,12 @@ import ( "strings" "time" - "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" ) +// ErrIgnoreFields indicates that the fields should be ignored. +var ErrIgnoreFields error = errors.New("Ignore this field set") + // Parser is used to parse the mapreduce information from the server log files. type Parser struct { hostname string @@ -25,11 +27,9 @@ type Parser struct { // NewParser returns a new log parser. func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) { hostname, err := os.Hostname() - if err != nil { return nil, err } - now := time.Now() zone, offset := now.Zone() @@ -43,7 +43,6 @@ func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) { if err != nil { return nil, err } - return &p, nil } @@ -52,7 +51,6 @@ func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) { // Parser. Whereas MODULENAME must be a upeprcase string. func (p *Parser) reflectLogFormat(logFormatName string) error { methodName := fmt.Sprintf("MakeFields%s", strings.ToUpper(logFormatName)) - rt := reflect.TypeOf(p) method, ok := rt.MethodByName(methodName) if !ok { @@ -61,7 +59,6 @@ func (p *Parser) reflectLogFormat(logFormatName string) error { p.makeFieldsFunc = method.Func p.makeFieldsReceiver = reflect.ValueOf(p) - return nil } @@ -69,17 +66,11 @@ func (p *Parser) reflectLogFormat(logFormatName string) error { func (p *Parser) MakeFields(maprLine string) (fields map[string]string, err error) { inputValues := []reflect.Value{p.makeFieldsReceiver, reflect.ValueOf(maprLine)} returnValues := p.makeFieldsFunc.Call(inputValues) - errInterface := returnValues[1].Interface() - if errInterface == nil { fields, err = returnValues[0].Interface().(map[string]string), nil - logger.Trace("parser.MakeFields", fields, err) return } - fields, err = returnValues[0].Interface().(map[string]string), errInterface.(error) - logger.Trace("parser.MakeFields", fields, err) - return } diff --git a/internal/mapr/query.go b/internal/mapr/query.go index 01852da..d7c32bd 100644 --- a/internal/mapr/query.go +++ b/internal/mapr/query.go @@ -6,8 +6,6 @@ import ( "strconv" "strings" "time" - - "github.com/mimecast/dtail/internal/io/logger" ) const ( @@ -34,7 +32,9 @@ type Query struct { } func (q Query) String() string { - return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v,LogFormat:%s)", + return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,"+ + "GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,"+ + "RawQuery:%s,tokens:%v,LogFormat:%s)", q.Select, q.Table, q.Where, @@ -56,20 +56,14 @@ func NewQuery(queryStr string) (*Query, error) { if queryStr == "" { return nil, nil } - tokens := tokenize(queryStr) - q := Query{ RawQuery: queryStr, tokens: tokens, Interval: time.Second * 5, Limit: -1, } - - err := q.parse(tokens) - - logger.Debug(q) - return &q, err + return &q, q.parse(tokens) } // HasOutfile returns true if query result will be written to a CVS output file. @@ -178,13 +172,13 @@ func (q *Query) parse(tokens []token) error { } if len(q.Select) < 1 { - return errors.New(invalidQuery + "Expected at least one field in 'select' clause but got none") + return errors.New(invalidQuery + "Expected at least one field in 'select' " + + "clause but got none") } if len(q.GroupBy) == 0 { field := q.Select[0].Field q.GroupBy = append(q.GroupBy, field) } - if q.OrderBy != "" { var orderFieldIsValid bool for _, sc := range q.Select { @@ -194,7 +188,8 @@ func (q *Query) parse(tokens []token) error { } } if !orderFieldIsValid { - return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s', must be present in 'select' clause", q.OrderBy)) + return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s',"+ + "must be present in 'select' clause", q.OrderBy)) } } diff --git a/internal/mapr/query_test.go b/internal/mapr/query_test.go index b0b6c3a..88f7387 100644 --- a/internal/mapr/query_test.go +++ b/internal/mapr/query_test.go @@ -13,18 +13,25 @@ func TestParseQuerySimple(t *testing.T) { "select foo from bar where baz <", "select foo from bar where baz < 100 bay eq 12 group", "select foo from bar where baz < 100 bay eq 12 group by foo order by", - "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit", - "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit set foo = bar;", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " + + "order by foo limit", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " + + "order by foo limit set foo = bar;", } okQueries := []string{"select foo from bar", "select foo from bar where", "select foo from bar where baz < 100 bay eq 12", "select foo from bar where baz < 100, bay eq 12", "select foo from bar where baz < 100 and bay eq 12", - "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo", - "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23", - "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"", - "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\" set $foo = maskdigits(bar), $baz = 12, $bay = $foo;", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " + + "order by foo", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " + + "order by foo limit 23", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " + + "order by foo limit 23 outfile \"result.csv\"", + "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " + + "order by foo limit 23 outfile \"result.csv\" " + + "set $foo = maskdigits(bar), $baz = 12, $bay = $foo;", } for _, queryStr := range errorQueries { @@ -46,8 +53,13 @@ func TestParseQuerySimple(t *testing.T) { func TestParseQueryDeep(t *testing.T) { dialects := []string{ - "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23 set $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic", - "SELECT s1, `from`, COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23 SET $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic", + "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq " + + "\"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23 " + + "set $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic", + + "SELECT s1, `from`, COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq " + + "\"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23 " + + "SET $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic", } for _, queryStr := range dialects { @@ -55,119 +67,144 @@ func TestParseQueryDeep(t *testing.T) { if err != nil { t.Errorf("%s: %s", err.Error(), queryStr) } - t.Log(q) // 'select' clause if len(q.Select) != 3 { - t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q) + t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", + q.Select, queryStr, q) } - if q.Select[0].Field != "s1" { - t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Field, queryStr, q) + t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v", + q.Select[0].Field, queryStr, q) } if q.Select[0].Operation != Last { - t.Errorf("Expected 'last' as aggregation function of first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q) + t.Errorf("Expected 'last' as aggregation function of first element in "+ + "'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q) } - if q.Select[1].Field != "from" { - t.Errorf("Expected 'from' as second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Field, queryStr, q) + t.Errorf("Expected 'from' as second element in 'select' clause but got "+ + "'%v': %s\n%v", q.Select[1].Field, queryStr, q) } if q.Select[1].Operation != Last { - t.Errorf("Expected 'last' as aggregation function of second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q) + t.Errorf("Expected 'last' as aggregation function of second element in "+ + "'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q) } - if q.Select[2].Field != "s3" { - t.Errorf("Expected 's3' as third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Field, queryStr, q) + t.Errorf("Expected 's3' as third element in 'select' clause but got "+ + "'%v': %s\n%v", q.Select[2].Field, queryStr, q) } if q.Select[2].Operation != Count { - t.Errorf("Expected 'count' as aggregation function of third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q) + t.Errorf("Expected 'count' as aggregation function of third element in "+ + "'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q) } if q.Select[2].FieldStorage != "count(s3)" { - t.Errorf("Expected 'count(s3)' as third element's storage in 'select' clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q) + t.Errorf("Expected 'count(s3)' as third element's storage in 'select' "+ + "clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q) } // 'from' clause if q.Table != "TABLE" { - t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v", q.Table, queryStr, q) + t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v", + q.Table, queryStr, q) } // 'where' clause if len(q.Where) != 2 { - t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v", q.Where, queryStr, q) + t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v", + q.Where, queryStr, q) } if q.Where[0].lString != "w1" { - t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v", q.Where[0].lString, queryStr, q) + t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v", + q.Where[0].lString, queryStr, q) } if q.Where[0].Operation != FloatEq { - t.Errorf("Expected FloatEq operation in first 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q) + t.Errorf("Expected FloatEq operation in first 'where' condition but got "+ + "'%v': %s\n%v", q.Where[0].Operation, queryStr, q) } if q.Where[0].rFloat != 2 { - t.Errorf("Expected '2' as float argument in first 'where' condition but got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q) + t.Errorf("Expected '2' as float argument in first 'where' condition but "+ + "got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q) } if q.Where[1].lString != "w2" { - t.Errorf("Expected w2 as second element in 'where' clause but got '%v': %s\n%v", q.Where[1].lString, queryStr, q) + t.Errorf("Expected w2 as second element in 'where' clause but got '%v': "+ + "%s\n%v", q.Where[1].lString, queryStr, q) } if q.Where[1].Operation != StringEq { - t.Errorf("Expected StringEq operation in second 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q) + t.Errorf("Expected StringEq operation in second 'where' condition but got "+ + "'%v': %s\n%v", q.Where[0].Operation, queryStr, q) } if q.Where[1].rString != "free beer" { - t.Errorf("Expected 'free beer' as string argument in second 'where' condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q) + t.Errorf("Expected 'free beer' as string argument in second 'where' "+ + "condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q) } // 'group by' clause if len(q.GroupBy) != 2 { - t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v", q.GroupBy, queryStr, q) + t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v", + q.GroupBy, queryStr, q) } if q.GroupBy[0] != "g1" { - t.Errorf("Expected 'g1' as first element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[0], queryStr, q) + t.Errorf("Expected 'g1' as first element in 'group by' clause but got "+ + "'%v': %s\n%v", q.GroupBy[0], queryStr, q) } if q.GroupBy[1] != "g2" { - t.Errorf("Expected 'g2' as second element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[1], queryStr, q) + t.Errorf("Expected 'g2' as second element in 'group by' clause but got "+ + "'%v': %s\n%v", q.GroupBy[1], queryStr, q) } if q.GroupKey != "g1,g2" { - t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got '%v': %s\n%v", q.GroupKey, queryStr, q) + t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got "+ + "'%v': %s\n%v", q.GroupKey, queryStr, q) } // 'order by' clause if q.OrderBy != "count(s3)" { - t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got '%v': %s\n%v", q.OrderBy, queryStr, q) + t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got "+ + "'%v': %s\n%v", q.OrderBy, queryStr, q) } // 'interval' clause if q.Interval != time.Second*time.Duration(10) { - t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v", q.Interval, queryStr, q) + t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v", + q.Interval, queryStr, q) } // 'limit' clause if q.Limit != 23 { - t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q) + t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", + q.Limit, queryStr, q) } // 'set' clause if q.Set[0].lString != "$foo" { - t.Errorf("Expected '$foo' lvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].lString, queryStr, q) + t.Errorf("Expected '$foo' lvalue in first 'set' condition clause but got "+ + "'%v': %s\n%v", q.Set[0].lString, queryStr, q) } if q.Set[0].rString != "bar" { - t.Errorf("Expected 'bar' rvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].rString, queryStr, q) + t.Errorf("Expected 'bar' rvalue in first 'set' condition clause but got "+ + "'%v': %s\n%v", q.Set[0].rString, queryStr, q) } - if q.Set[1].lString != "$baz" { - t.Errorf("Expected '$baz' lvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].lString, queryStr, q) + t.Errorf("Expected '$baz' lvalue in second 'set' condition clause but got "+ + "'%v': %s\n%v", q.Set[1].lString, queryStr, q) } if q.Set[1].rString != "12" { - t.Errorf("Expected '12' rvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].rString, queryStr, q) + t.Errorf("Expected '12' rvalue in second 'set' condition clause but got "+ + "'%v': %s\n%v", q.Set[1].rString, queryStr, q) } - if q.Set[2].lString != "$bay" { - t.Errorf("Expected '$bay' lvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].lString, queryStr, q) + t.Errorf("Expected '$bay' lvalue in third 'set' condition clause but got "+ + "'%v': %s\n%v", q.Set[2].lString, queryStr, q) } if q.Set[2].rString != "$foo" { - t.Errorf("Expected '$foo' rvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].rString, queryStr, q) + t.Errorf("Expected '$foo' rvalue in third 'set' condition clause but got "+ + "'%v': %s\n%v", q.Set[2].rString, queryStr, q) } + // 'logformat' clause if q.LogFormat != "generic" { - t.Errorf("Expected 'generic' logformat got '%v': %s\n%v", q.LogFormat, queryStr, q) + t.Errorf("Expected 'generic' logformat got '%v': %s\n%v", + q.LogFormat, queryStr, q) } } } diff --git a/internal/mapr/selectcondition.go b/internal/mapr/selectcondition.go index d6aa0d4..5cfb8c7 100644 --- a/internal/mapr/selectcondition.go +++ b/internal/mapr/selectcondition.go @@ -37,7 +37,6 @@ func (sc selectCondition) String() string { func makeSelectConditions(tokens []token) ([]selectCondition, error) { var sel []selectCondition - // Parse select aggregation, e.g. sum(foo) parse := func(token token) (selectCondition, error) { var sc selectCondition @@ -52,13 +51,15 @@ func makeSelectConditions(tokens []token) ([]selectCondition, error) { a := strings.Split(tokenStr, "(") if len(a) != 2 { - return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " + token.str) + return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " + + token.str) } agg := a[0] // Aggregation, e.g. 'sum' b := strings.Split(a[1], ")") if len(b) != 2 { - return sc, errors.New(invalidQuery + "Can't parse 'select' field name from aggregation: " + token.str) + return sc, errors.New(invalidQuery + "Can't parse 'select' field name " + + "from aggregation: " + token.str) } sc.Field = b[0] // Field name, e.g. 'foo' sc.FieldStorage = tokenStr // e.g. 'sum(foo)' @@ -79,9 +80,9 @@ func makeSelectConditions(tokens []token) ([]selectCondition, error) { case "len": sc.Operation = Len default: - return sc, errors.New(invalidQuery + "Unknown aggregation in 'select' clause: " + agg) + return sc, errors.New(invalidQuery + + "Unknown aggregation in 'select' clause: " + agg) } - return sc, nil } @@ -92,6 +93,5 @@ func makeSelectConditions(tokens []token) ([]selectCondition, error) { } sel = append(sel, sc) } - return sel, nil } diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go index 28bb074..97fee11 100644 --- a/internal/mapr/server/aggregate.go +++ b/internal/mapr/server/aggregate.go @@ -8,25 +8,22 @@ import ( "github.com/mimecast/dtail/internal" "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" - "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/logformat" + "github.com/mimecast/dtail/internal/protocol" ) // Aggregate is for aggregating mapreduce data on the DTail server side. type Aggregate struct { done *internal.Done - // Log lines to process (parsing MAPREDUCE lines). - Lines chan line.Line + // NextLinesCh can be used to use a new line ch. + NextLinesCh chan chan line.Line // Hostname of the current server (used to populate $hostname field). hostname string // Signals to serialize data. serialize chan struct{} - // Signals to flush data. - flush chan struct{} - // Signals that data has been flushed - flushed chan struct{} // The mapr query query *mapr.Query // The mapr log format parser @@ -42,7 +39,7 @@ func NewAggregate(queryStr string) (*Aggregate, error) { fqdn, err := os.Hostname() if err != nil { - logger.Error(err) + dlog.Common.Error(err) } s := strings.Split(fqdn, ".") @@ -57,38 +54,32 @@ func NewAggregate(queryStr string) (*Aggregate, error) { parserName = query.LogFormat } - logger.Info("Creating log format parser", parserName) + dlog.Common.Info("Creating log format parser", parserName) logParser, err := logformat.NewParser(parserName, query) if err != nil { - logger.Error("Could not create log format parser. Falling back to 'generic'", err) + dlog.Common.Error("Could not create log format parser. Falling back to 'generic'", err) if logParser, err = logformat.NewParser("generic", query); err != nil { - logger.FatalExit("Could not create log format parser", err) + dlog.Common.FatalPanic("Could not create log format parser", err) } } - a := Aggregate{ - done: internal.NewDone(), - Lines: make(chan line.Line, 100), - serialize: make(chan struct{}), - flush: make(chan struct{}), - flushed: make(chan struct{}), - hostname: s[0], - query: query, - parser: logParser, - } - - return &a, nil + return &Aggregate{ + done: internal.NewDone(), + NextLinesCh: make(chan chan line.Line, 10), + serialize: make(chan struct{}), + hostname: s[0], + query: query, + parser: logParser, + }, nil } // Shutdown the aggregation engine. func (a *Aggregate) Shutdown() { - a.Flush() a.done.Shutdown() } // Start an aggregation. -func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { - +func (a *Aggregate) Start(ctx context.Context, maprMessages chan<- string) { myCtx, cancel := context.WithCancel(ctx) defer cancel() @@ -101,15 +92,14 @@ func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { } }() - fieldsCh := a.makeFields(myCtx) - + fieldsCh := a.fieldsFromLines(myCtx) // Add fields (e.g. via 'set' clause) if len(a.query.Set) > 0 { - fieldsCh = a.addFields(myCtx, fieldsCh) + fieldsCh = a.setAdditionalFields(myCtx, fieldsCh) } - + // Periodically pre-aggregate data every a.query.Interval seconds. go a.aggregateTimer(myCtx) - a.makeMaprLines(myCtx, fieldsCh, maprLines) + a.aggregateAndSerialize(myCtx, fieldsCh, maprMessages) } func (a *Aggregate) aggregateTimer(ctx context.Context) { @@ -123,25 +113,46 @@ func (a *Aggregate) aggregateTimer(ctx context.Context) { } } -func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string { - ch := make(chan map[string]string) +func (a *Aggregate) fieldsFromLines(ctx context.Context) <-chan map[string]string { + fieldsCh := make(chan map[string]string) go func() { - defer close(ch) + defer close(fieldsCh) + var lines chan line.Line + + // Gather first lines channel (first input file) + select { + case lines = <-a.NextLinesCh: + case <-ctx.Done(): + return + } for { select { - case line, ok := <-a.Lines: + case line, ok := <-lines: if !ok { - return + select { + case lines = <-a.NextLinesCh: + // Have a new lines channel (e.g. new input file) + case <-ctx.Done(): + default: + // No new lines channel found. + return + } } - maprLine := strings.TrimSpace(string(line.Content)) + maprLine := strings.TrimSpace(line.Content.String()) fields, err := a.parser.MakeFields(maprLine) - logger.Debug(fields, err) + // Can't recycle it here yet, as field slices are still + // TODO: Add unit test reading from multiple mapreduce files lines. + // TODO: Add capability to recycle this bytes buffer. + //pool.RecycleBytesBuffer(line.Content) if err != nil { - logger.Error(err) + // Should fields be ignored anyway? + if err != logformat.ErrIgnoreFields { + dlog.Common.Error(fields, err) + } continue } if !a.query.WhereClause(fields) { @@ -149,7 +160,7 @@ func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string { } select { - case ch <- fields: + case fieldsCh <- fields: case <-ctx.Done(): } case <-ctx.Done(): @@ -158,45 +169,42 @@ func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string { } }() - return ch + return fieldsCh } -func (a *Aggregate) addFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string { - ch := make(chan map[string]string) +func (a *Aggregate) setAdditionalFields(ctx context.Context, + fieldsCh <-chan map[string]string) <-chan map[string]string { + newFieldsCh := make(chan map[string]string) go func() { - defer close(ch) - + defer close(newFieldsCh) for { - // fieldsCh will be closed via 'makeFields' if ctx is done fields, ok := <-fieldsCh if !ok { return } if err := a.query.SetClause(fields); err != nil { - logger.Error(err) + dlog.Common.Error(err) } select { - case ch <- fields: + case newFieldsCh <- fields: case <-ctx.Done(): } } }() - - return ch + return newFieldsCh } -func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { - group := mapr.NewGroupSet() +func (a *Aggregate) aggregateAndSerialize(ctx context.Context, + fieldsCh <-chan map[string]string, maprMessages chan<- string) { + group := mapr.NewGroupSet() serialize := func() { - logger.Info("Serializing mapreduce result") - group.Serialize(ctx, maprLines) + dlog.Common.Info("Serializing mapreduce result") + group.Serialize(ctx, maprMessages) group = mapr.NewGroupSet() - logger.Info("Done serializing mapreduce result") } - for { select { case fields, ok := <-fieldsCh: @@ -207,9 +215,6 @@ func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[strin a.aggregate(group, fields) case <-a.serialize: serialize() - case <-a.flush: - serialize() - a.flushed <- struct{}{} case <-ctx.Done(): return } @@ -217,12 +222,10 @@ func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[strin } func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { - //logger.Trace("Aggregating", group, fields) var sb strings.Builder - for i, field := range a.query.GroupBy { if i > 0 { - sb.WriteString(" ") + sb.WriteString(protocol.AggregateGroupKeyCombinator) } if val, ok := fields[field]; ok { sb.WriteString(val) @@ -235,7 +238,7 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { for _, sc := range a.query.Select { if val, ok := fields[sc.Field]; ok { if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil { - logger.Error(err) + dlog.Common.Error(err) continue } addedSample = true @@ -246,8 +249,7 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { set.Samples++ return } - - logger.Trace("Aggregated data locally without adding new samples") + dlog.Common.Trace("Aggregated data locally without adding new samples") } // Serialize all the aggregated data. @@ -255,28 +257,7 @@ func (a *Aggregate) Serialize(ctx context.Context) { select { case a.serialize <- struct{}{}: case <-time.After(time.Minute): - logger.Warn("Starting to serialize mapredice data takes over a minute") + dlog.Common.Warn("Starting to serialize mapredice data takes over a minute") case <-ctx.Done(): } } - -// Flush all data. -func (a *Aggregate) Flush() { - select { - case a.flush <- struct{}{}: - logger.Info("Flushing mapreduce data") - case <-time.After(time.Minute): - logger.Warn("Starting to flush mapreduce data takes over a minute") - return - case <-a.done.Done(): - return - } - - select { - case <-a.flushed: - logger.Info("Done flushing") - case <-time.After(time.Minute): - logger.Warn("Waiting for data to be flushed takes over a minute") - case <-a.done.Done(): - } -} diff --git a/internal/mapr/setclause.go b/internal/mapr/setclause.go index b4c2f73..1843d31 100644 --- a/internal/mapr/setclause.go +++ b/internal/mapr/setclause.go @@ -7,7 +7,6 @@ func (q *Query) SetClause(fields map[string]string) error { if !ok { continue } - switch sc.rType { case FunctionStack: fields[sc.lString] = sc.functionStack.Call(value) @@ -15,6 +14,5 @@ func (q *Query) SetClause(fields map[string]string) error { fields[sc.lString] = value } } - return nil } diff --git a/internal/mapr/setcondition.go b/internal/mapr/setcondition.go index 8c5cfc9..92b21f4 100644 --- a/internal/mapr/setcondition.go +++ b/internal/mapr/setcondition.go @@ -39,20 +39,22 @@ func makeSetConditions(tokens []token) (set []setCondition, err error) { switch setOp { case "=": default: - return sc, nil, errors.New(invalidQuery + "Unknown operation in 'set' clause: " + setOp) + return sc, nil, errors.New(invalidQuery + "Unknown operation in 'set' " + + "clause: " + setOp) } if !tokens[0].isBareword { - return sc, nil, errors.New(invalidQuery + "Expected bareword at 'set' clause's lValue: " + tokens[0].str) + return sc, nil, errors.New(invalidQuery + "Expected bareword at 'set' " + + "clause's lValue: " + tokens[0].str) } - sc.lString = tokens[0].str if !strings.HasPrefix(sc.lString, "$") { - return sc, nil, errors.New(invalidQuery + "Expected field variable name (starting with $) at 'set' clause's lValue: " + tokens[0].str) + return sc, nil, errors.New(invalidQuery + "Expected field variable name " + + "(starting with $) at 'set' clause's lValue: " + tokens[0].str) } sc.rType = Field - rString := tokens[2].str + // Seems like a function call? if strings.HasSuffix(rString, ")") { functionStack, functionArg, err := funcs.NewFunctionStack(tokens[2].str) @@ -72,7 +74,6 @@ func makeSetConditions(tokens []token) (set []setCondition, err error) { } else { sc.rType = Field } - return sc, tokens[3:], nil } @@ -84,10 +85,8 @@ func makeSetConditions(tokens []token) (set []setCondition, err error) { if err != nil { return nil, err } - set = append(set, sc) tokens = tokensConsumeOptional(tokens, ",") } - return } diff --git a/internal/mapr/token.go b/internal/mapr/token.go index 8972188..6ac7631 100644 --- a/internal/mapr/token.go +++ b/internal/mapr/token.go @@ -4,7 +4,8 @@ import ( "strings" ) -var keywords = [...]string{"select", "from", "where", "set", "group", "rorder", "order", "interval", "limit", "outfile", "logformat"} +var keywords = [...]string{"select", "from", "where", "set", "group", "rorder", + "order", "interval", "limit", "outfile", "logformat"} // Represents a parsed token, used to parse the mapr query. type token struct { @@ -16,13 +17,11 @@ func (t token) isKeyword() bool { if !t.isBareword { return false } - for _, keyword := range keywords { if strings.ToLower(t.str) == keyword { return true } } - return false } @@ -32,7 +31,6 @@ func (t token) String() string { func tokenize(queryStr string) []token { var tokens []token - for i, part := range strings.Split(queryStr, "\"") { // Even i, means that it is not a quoted string if i%2 == 0 { @@ -53,17 +51,15 @@ func tokenize(queryStr string) []token { } tokens = append(tokens, token) } - return tokens } func tokensConsume(tokens []token) ([]token, []token) { - //logger.Trace("=====================") + //dlog.Common.Trace("=====================") var consumed []token - for i, t := range tokens { if t.isKeyword() { - //logger.Trace("keyword", t) + //dlog.Common.Trace("keyword", t) return tokens[i:], consumed } // strip escapes, such as ` from `foo`, this allows to use keywords as field names @@ -73,7 +69,7 @@ func tokensConsume(tokens []token) ([]token, []token) { } if t.str[0] == '`' && t.str[length-1] == '`' { stripped := t.str[1 : length-1] - //logger.Trace("stripped", stripped) + //dlog.Common.Trace("stripped", stripped) t := token{ str: stripped, isBareword: t.isBareword, @@ -81,11 +77,10 @@ func tokensConsume(tokens []token) ([]token, []token) { consumed = append(consumed, t) continue } - //logger.Trace("bare", token) + //dlog.Common.Trace("bare", token) consumed = append(consumed, t) } - - //logger.Trace("result", consumed) + //dlog.Common.Trace("result", consumed) return nil, consumed } @@ -95,7 +90,6 @@ func tokensConsumeStr(tokens []token) ([]token, []string) { for _, token := range found { strings = append(strings, token.str) } - return tokens, strings } @@ -106,6 +100,5 @@ func tokensConsumeOptional(tokens []token, optional string) []token { if strings.ToLower(tokens[0].str) == strings.ToLower(optional) { return tokens[1:] } - return tokens } diff --git a/internal/mapr/whereclause.go b/internal/mapr/whereclause.go index cc1c164..d9f32eb 100644 --- a/internal/mapr/whereclause.go +++ b/internal/mapr/whereclause.go @@ -3,14 +3,13 @@ package mapr import ( "strconv" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // WhereClause interprets the where clause of the mapreduce query. func (q *Query) WhereClause(fields map[string]string) bool { for _, wc := range q.Where { var ok bool - if wc.Operation > FloatOperation { var lValue, rValue float64 if lValue, ok = whereClauseFloatValue(fields, wc.lString, wc.lFloat, wc.lType); !ok { @@ -36,11 +35,12 @@ func (q *Query) WhereClause(fields map[string]string) bool { return false } } - return true } -func whereClauseFloatValue(fields map[string]string, str string, float float64, t fieldType) (float64, bool) { +func whereClauseFloatValue(fields map[string]string, str string, float float64, + t fieldType) (float64, bool) { + switch t { case Float: return float, true @@ -55,12 +55,14 @@ func whereClauseFloatValue(fields map[string]string, str string, float float64, } return f, true default: - logger.Error("Unexpected argument in 'where' clause", str, float, t) + dlog.Common.Error("Unexpected argument in 'where' clause", str, float, t) return 0, false } } -func whereClauseStringValue(fields map[string]string, str string, t fieldType) (string, bool) { +func whereClauseStringValue(fields map[string]string, str string, + t fieldType) (string, bool) { + switch t { case Field: value, ok := fields[str] @@ -71,7 +73,7 @@ func whereClauseStringValue(fields map[string]string, str string, t fieldType) ( case String: return str, true default: - logger.Error("Unexpected argument in 'where' clause", str, t) + dlog.Common.Error("Unexpected argument in 'where' clause", str, t) return str, false } } diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go index 7a60dba..280dcfb 100644 --- a/internal/mapr/wherecondition.go +++ b/internal/mapr/wherecondition.go @@ -6,7 +6,7 @@ import ( "strconv" "strings" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // QueryOperation determines the mapreduce operation. @@ -46,15 +46,18 @@ type whereCondition struct { } func (wc *whereCondition) String() string { - return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,lType:%s,rString:%s,rFloat:%v,rType:%s)", - wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString, wc.rFloat, wc.rType.String()) + return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,"+ + "lType:%s,rString:%s,rFloat:%v,rType:%s)", + wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString, + wc.rFloat, wc.rType.String()) } func makeWhereConditions(tokens []token) (where []whereCondition, err error) { parse := func(tokens []token) (whereCondition, []token, error) { var wc whereCondition if len(tokens) < 3 { - return wc, nil, errors.New(invalidQuery + "Not enough arguments in 'where' clause") + err := errors.New(invalidQuery + "Not enough arguments in 'where' clause") + return wc, nil, err } whereOp := strings.ToLower(tokens[1].str) @@ -94,7 +97,8 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) { case "nhassuffix": wc.Operation = StringNotHasSuffix default: - return wc, nil, errors.New(invalidQuery + "Unknown operation in 'where' clause: " + whereOp) + return wc, nil, errors.New(invalidQuery + + "Unknown operation in 'where' clause: " + whereOp) } wc.lString = tokens[0].str @@ -102,7 +106,8 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) { if wc.Operation > FloatOperation { if !tokens[0].isBareword { - return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's lValue: " + tokens[0].str) + return wc, nil, errors.New(invalidQuery + + "Expected bareword at 'where' clause's lValue: " + tokens[0].str) } if f, err := strconv.ParseFloat(wc.lString, 64); err == nil { wc.lFloat = f @@ -112,7 +117,8 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) { } if !tokens[2].isBareword { - return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's rValue: " + tokens[2].str) + return wc, nil, errors.New(invalidQuery + + "Expected bareword at 'where' clause's rValue: " + tokens[2].str) } if f, err := strconv.ParseFloat(wc.rString, 64); err == nil { wc.rFloat = f @@ -133,23 +139,19 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) { } else { wc.rType = String } - return wc, tokens[3:], nil } for len(tokens) > 0 { var wc whereCondition var err error - wc, tokens, err = parse(tokens) if err != nil { return nil, err } - where = append(where, wc) tokens = tokensConsumeOptional(tokens, "and") } - return } @@ -168,9 +170,8 @@ func (wc *whereCondition) floatClause(lValue float64, rValue float64) bool { case FloatGe: return lValue >= rValue default: - logger.Error("Unknown float operation", lValue, wc.Operation, rValue) + dlog.Common.Error("Unknown float operation", lValue, wc.Operation, rValue) } - return false } @@ -193,8 +194,7 @@ func (wc *whereCondition) stringClause(lValue string, rValue string) bool { case StringNotHasSuffix: return !strings.HasSuffix(lValue, rValue) default: - logger.Error("Unknown string operation", lValue, wc.Operation, rValue) + dlog.Common.Error("Unknown string operation", lValue, wc.Operation, rValue) } - return false } diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go new file mode 100644 index 0000000..d29706c --- /dev/null +++ b/internal/protocol/protocol.go @@ -0,0 +1,18 @@ +package protocol + +const ( + // ProtocolCompat -ibility version + ProtocolCompat string = "4" + // MessageDelimiter delimits separate messages. + MessageDelimiter byte = '¬' + // FieldDelimiter delimits messagefields. + FieldDelimiter string = "|" + // CSVDelimiter delimits CSV file fields.kj:w + CSVDelimiter string = "," + // AggregateKVDelimiter delimits key-values of an aggregation message. + AggregateKVDelimiter string = "≔" + // AggregateDelimiter delimits parts of an aggregation message. + AggregateDelimiter string = "∥" + // AggregateGroupKeyCombinator combines the group set keys. + AggregateGroupKeyCombinator string = "," +) diff --git a/internal/regex/regex.go b/internal/regex/regex.go index 2561659..eb6e1b3 100644 --- a/internal/regex/regex.go +++ b/internal/regex/regex.go @@ -4,8 +4,6 @@ import ( "fmt" "regexp" "strings" - - "github.com/mimecast/dtail/internal/io/logger" ) // Regex for filtering lines. @@ -50,9 +48,7 @@ func new(regexStr string, flags []Flag) (Regex, error) { regexStr: regexStr, flags: flags, } - re, err := regexp.Compile(regexStr) - if err != nil { return r, err } @@ -91,17 +87,15 @@ func (r Regex) MatchString(str string) bool { } // Serialize the regex. -func (r Regex) Serialize() string { +func (r Regex) Serialize() (string, error) { var flags []string for _, flag := range r.flags { flags = append(flags, flag.String()) } - if !r.initialized { - logger.FatalExit("Unable to serialize regex as not initialized properly", r) + return "", fmt.Errorf("Unable to serialize regex as not initialized properly: %v", r) } - - return fmt.Sprintf("regex:%s %s", strings.Join(flags, ","), r.regexStr) + return fmt.Sprintf("regex:%s %s", strings.Join(flags, ","), r.regexStr), nil } // Deserialize the regex. @@ -109,15 +103,14 @@ func Deserialize(str string) (Regex, error) { // Get regex string s := strings.SplitN(str, " ", 2) if len(s) < 2 { - logger.Debug("Using noop regex", str) return NewNoop(), nil } - flagsStr := s[0] regexStr := s[1] if !strings.HasPrefix(flagsStr, "regex") { - return Regex{}, fmt.Errorf("unable to deserialize regex '%s': should start with string 'regex'", str) + return Regex{}, fmt.Errorf("unable to deserialize regex '%s': should start "+ + "with string 'regex'", str) } // Parse regex flags, e.g. "regex:flag1,flag2,flag3..." @@ -127,13 +120,10 @@ func Deserialize(str string) (Regex, error) { for _, flagStr := range strings.Split(s[1], ",") { flag, err := NewFlag(flagStr) if err != nil { - logger.Error("ignoring flag", err) continue } - logger.Debug("Adding regex flag", flag) flags = append(flags, flag) } } - return new(regexStr, flags) } diff --git a/internal/regex/regex_test.go b/internal/regex/regex_test.go index a5e7faf..033a286 100644 --- a/internal/regex/regex_test.go +++ b/internal/regex/regex_test.go @@ -9,7 +9,8 @@ func TestRegex(t *testing.T) { r := NewNoop() if !r.MatchString(input) { - t.Errorf("expected to match string '%s' with noop regex '%v' but didn't\n", input, r) + t.Errorf("expected to match string '%s' with noop regex '%v' but didn't\n", + input, r) } r, err := New(".hello", Default) @@ -17,16 +18,21 @@ func TestRegex(t *testing.T) { t.Errorf("unable to create regex: %v\n", err) } if r.MatchString(input) { - t.Errorf("expected to match string '%s' with regex '%v' but didn't\n", input, r) + t.Errorf("expected to match string '%s' with regex '%v' but didn't\n", + input, r) } - r2, err := Deserialize(r.Serialize()) + serialized, err := r.Serialize() if err != nil { - t.Errorf("unable to serialize deserialized regex: %v: %v\n", r.Serialize(), err) + t.Errorf("unable to serialize regex: %v: %v\n", serialized, err) + } + r2, err := Deserialize(serialized) + if err != nil { + t.Errorf("unable to serialize deserialized regex: %v: %v\n", serialized, err) } if r.String() != r2.String() { - t.Errorf("regex should be the same after deserialize(serialize(..)), got '%s' but expected '%s'.\n", - r2.String(), r.String()) + t.Errorf("regex should be the same after deserialize(serialize(..)), got "+ + "'%s' but expected '%s'.\n", r2.String(), r.String()) } r, err = New(".hello", Invert) @@ -34,15 +40,20 @@ func TestRegex(t *testing.T) { t.Errorf("unable to create regex: %v\n", err) } if !r.MatchString(input) { - t.Errorf("expected to not match string '%s' with regex '%v' but matched\n", input, r) + t.Errorf("expected to not match string '%s' with regex '%v' but matched\n", + input, r) } - r2, err = Deserialize(r.Serialize()) + serialized, err = r.Serialize() + if err != nil { + t.Errorf("unable to serialize regex: %v: %v\n", serialized, err) + } + r2, err = Deserialize(serialized) if err != nil { - t.Errorf("unable to serialize deserialized regex: %v: %v\n", r.Serialize(), err) + t.Errorf("unable to serialize deserialized regex: %v: %v\n", serialized, err) } if r.String() != r2.String() { - t.Errorf("regex should be the same after deserialize(serialize(..)), got '%s' but expected '%s'.\n", - r2.String(), r.String()) + t.Errorf("regex should be the same after deserialize(serialize(..)), got "+ + "'%s' but expected '%s'.\n", r2.String(), r.String()) } } diff --git a/internal/server/continuous.go b/internal/server/continuous.go index f75c732..93b3fcb 100644 --- a/internal/server/continuous.go +++ b/internal/server/continuous.go @@ -8,33 +8,29 @@ import ( "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" - gossh "golang.org/x/crypto/ssh" ) -type continuous struct { -} +type continuous struct{} func newContinuous() *continuous { return &continuous{} } func (c *continuous) start(ctx context.Context) { - logger.Info("Starting continuous job runner after 10s") + dlog.Server.Info("Starting continuous job runner after 10s") time.Sleep(time.Second * 10) - c.runJobs(ctx) } func (c *continuous) runJobs(ctx context.Context) { for _, job := range config.Server.Continuous { if !job.Enable { - logger.Debug(job.Name, "Not running job as not enabled") + dlog.Server.Debug(job.Name, "Not running job as not enabled") continue } - go func(job config.Continuous) { c.runJob(ctx, job) for { @@ -51,18 +47,17 @@ func (c *continuous) runJobs(ctx context.Context) { } func (c *continuous) runJob(ctx context.Context, job config.Continuous) { - logger.Debug(job.Name, "Processing job") + dlog.Server.Debug(job.Name, "Processing job") files := fillDates(job.Files) outfile := fillDates(job.Outfile) - servers := strings.Join(job.Servers, ",") if servers == "" { servers = config.Server.SSHBindAddress } - args := clients.Args{ - ConnectionsPerCPU: 10, + args := config.Args{ + ConnectionsPerCPU: config.DefaultConnectionsPerCPU, Discovery: job.Discovery, ServersStr: servers, What: files, @@ -71,35 +66,32 @@ func (c *continuous) runJob(ctx context.Context, job config.Continuous) { } args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name)) - - query := fmt.Sprintf("%s outfile %s", job.Query, outfile) - client, err := clients.NewMaprClient(args, query, clients.NonCumulativeMode) + args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile) + client, err := clients.NewMaprClient(args, clients.NonCumulativeMode) if err != nil { - logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) + dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) return } jobCtx, cancel := context.WithCancel(ctx) defer cancel() - if job.RestartOnDayChange { go func() { if c.waitForDayChange(ctx) { - logger.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name)) + dlog.Server.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name)) cancel() } }() } - logger.Info(fmt.Sprintf("Starting job %s", job.Name)) + dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name)) status := client.Start(jobCtx, make(chan string)) logMessage := fmt.Sprintf("Job exited with status %d", status) - if status != 0 { - logger.Warn(logMessage) + dlog.Server.Warn(logMessage) return } - logger.Info(logMessage) + dlog.Server.Info(logMessage) } func (c *continuous) waitForDayChange(ctx context.Context) bool { diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go new file mode 100644 index 0000000..6d10d17 --- /dev/null +++ b/internal/server/handlers/basehandler.go @@ -0,0 +1,320 @@ +package handlers + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "io" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/pool" + "github.com/mimecast/dtail/internal/lcontext" + "github.com/mimecast/dtail/internal/mapr/server" + "github.com/mimecast/dtail/internal/protocol" + user "github.com/mimecast/dtail/internal/user/server" +) + +type handleCommandCb func(context.Context, lcontext.LContext, int, []string, string) + +type baseHandler struct { + done *internal.Done + handleCommandCb handleCommandCb + lines chan line.Line + aggregate *server.Aggregate + maprMessages chan string + serverMessages chan string + hostname string + user *user.User + ackCloseReceived chan struct{} + activeCommands int32 + readBuf bytes.Buffer + writeBuf bytes.Buffer + + // Some global options + sync primitives required. + once sync.Once + mutex sync.Mutex + quiet bool + spartan bool + serverless bool +} + +// Shutdown the handler. +func (h *baseHandler) Shutdown() { + h.done.Shutdown() +} + +// Done channel of the handler. +func (h *baseHandler) Done() <-chan struct{} { + return h.done.Done() +} + +// Read is to send data to the dtail client via Reader interface. +func (h *baseHandler) Read(p []byte) (n int, err error) { + defer h.readBuf.Reset() + + select { + case message := <-h.serverMessages: + if len(message) > 0 && message[0] == '.' { + // Handle hidden message (don't display to the user) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + return + } + + if h.serverless { + return + } + + // Handle normal server message (display to the user) + h.readBuf.WriteString("SERVER") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + + case message := <-h.maprMessages: + // Send mapreduce-aggregated data as a message. + h.readBuf.WriteString("AGGREGATE") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(message) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + + case line := <-h.lines: + if !h.spartan { + h.readBuf.WriteString("REMOTE") + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(h.hostname) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc)) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(fmt.Sprintf("%v", line.Count)) + h.readBuf.WriteString(protocol.FieldDelimiter) + h.readBuf.WriteString(line.SourceID) + h.readBuf.WriteString(protocol.FieldDelimiter) + } + h.readBuf.WriteString(line.Content.String()) + h.readBuf.WriteByte(protocol.MessageDelimiter) + n = copy(p, h.readBuf.Bytes()) + pool.RecycleBytesBuffer(line.Content) + + case <-time.After(time.Second): + // Once in a while check whether we are done. + select { + case <-h.done.Done(): + err = io.EOF + return + default: + } + } + return +} + +// Write is to receive data from the dtail client via Writer interface. +func (h *baseHandler) Write(p []byte) (n int, err error) { + for _, b := range p { + switch b { + case ';': + h.handleCommand(string(h.writeBuf.Bytes())) + h.writeBuf.Reset() + default: + h.writeBuf.WriteByte(b) + } + } + n = len(p) + return +} + +func (h *baseHandler) handleCommand(commandStr string) { + dlog.Server.Debug(h.user, commandStr) + + args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add) + return + } + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)) + return + } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-h.done.Done() + cancel() + }() + + parts := strings.Split(args[0], ":") + commandName := parts[0] + + // Either no options or empty options provided. + if len(parts) == 1 || len(parts[1]) == 0 { + h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName) + return + } + + options, ltx, err := config.DeserializeOptions(parts[1:]) + if err != nil { + h.send(h.serverMessages, dlog.Server.Error(h.user, err)) + return + } + h.handleOptions(options) + h.handleCommandCb(ctx, ltx, argc, args, commandName) +} + +func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) { + argc := len(args) + var add string + + if argc <= 2 || args[0] != "protocol" { + return args, argc, add, errors.New("unable to determine protocol version") + } + + if args[1] != protocol.ProtocolCompat { + clientCompat, _ := strconv.Atoi(args[1]) + serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat) + if clientCompat <= 3 { + // Protocol version 3 or lower expect a newline as message separator + // One day (after 2 major versions) this exception may be removed! + add = "\n" + } + + toUpdate := "client" + if clientCompat > serverCompat { + toUpdate = "server" + } + err := fmt.Errorf("the DTail server protocol version '%s' does not match "+ + "client protocol version '%s', please update DTail %s", + protocol.ProtocolCompat, args[1], toUpdate) + return args, argc, add, err + } + + return args[2:], argc - 2, add, nil +} + +func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("unable to decode client message, DTail server and client " + + "versions may not be compatible") + if argc != 2 || args[0] != "base64" { + return args, argc, err + } + + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) + + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + dlog.Server.Trace(h.user, "Base64 decoded received command", + decodedStr, argc, args) + + return args, argc, nil +} + +func (h *baseHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + if !h.quiet { + h.send(h.serverMessages, dlog.Server.Warn(h.user, + "Unable to parse command", args, argc)) + } + return + } + if args[1] == "close" && args[2] == "connection" { + select { + case <-h.ackCloseReceived: + default: + close(h.ackCloseReceived) + } + } +} + +func (h *baseHandler) handleOptions(options map[string]string) { + // We have to make sure that this block is executed only once. + h.mutex.Lock() + defer h.mutex.Unlock() + // We can read the options only once, will cause a data race otherwise if + // changed multiple times for multiple incoming commands. + h.once.Do(func() { + if quiet, _ := options["quiet"]; quiet == "true" { + dlog.Server.Debug(h.user, "Enabling quiet mode") + h.quiet = true + } + if spartan, _ := options["spartan"]; spartan == "true" { + dlog.Server.Debug(h.user, "Enabling spartan mode") + h.spartan = true + } + if serverless, _ := options["serverless"]; serverless == "true" { + dlog.Server.Debug(h.user, "Enabling serverless mode") + h.serverless = true + } + }) +} + +func (h *baseHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.done.Done(): + } +} + +func (h *baseHandler) flush() { + dlog.Server.Trace(h.user, "flush()") + numUnsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.maprMessages) + } + for i := 0; i < 10; i++ { + if numUnsentMessages() == 0 { + dlog.Server.Debug(h.user, "ALL lines sent", fmt.Sprintf("%p", h)) + return + } + dlog.Server.Debug(h.user, "Still lines to be sent") + time.Sleep(time.Millisecond * 10) + } + dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages()) +} + +func (h *baseHandler) shutdown() { + dlog.Server.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessages <- ".syn close connection": + case <-h.done.Done(): + } + }() + + select { + case <-h.ackCloseReceived: + case <-time.After(time.Second * 5): + dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.done.Done(): + } + h.done.Shutdown() +} + +func (h *baseHandler) incrementActiveCommands() { + atomic.AddInt32(&h.activeCommands, 1) +} + +func (h *baseHandler) decrementActiveCommands() int32 { + atomic.AddInt32(&h.activeCommands, -1) + return atomic.LoadInt32(&h.activeCommands) +} diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go deleted file mode 100644 index 1e17c78..0000000 --- a/internal/server/handlers/controlhandler.go +++ /dev/null @@ -1,98 +0,0 @@ -package handlers - -import ( - "fmt" - "io" - "os" - "strings" - - "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/io/logger" - user "github.com/mimecast/dtail/internal/user/server" -) - -// ControlHandler is used for control functions and health monitoring. -type ControlHandler struct { - done *internal.Done - hostname string - payload []byte - serverMessages chan string - user *user.User -} - -// NewControlHandler returns a new control handler. -func NewControlHandler(user *user.User) *ControlHandler { - logger.Debug(user, "Creating control handler") - - h := ControlHandler{ - done: internal.NewDone(), - serverMessages: make(chan string, 10), - user: user, - } - - fqdn, err := os.Hostname() - if err != nil { - logger.FatalExit(err) - } - - s := strings.Split(fqdn, ".") - h.hostname = s[0] - - return &h -} - -// Shutdown the handler. -func (h *ControlHandler) Shutdown() { - h.done.Shutdown() -} - -// Done channel of the handler. -func (h *ControlHandler) Done() <-chan struct{} { - return h.done.Done() -} - -// Read is to send data to the client via the Reader interface. -func (h *ControlHandler) Read(p []byte) (n int, err error) { - for { - select { - case message := <-h.serverMessages: - wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) - n = copy(p, wholePayload) - return - case <-h.done.Done(): - return 0, io.EOF - } - } -} - -// Write is to read data to the client via the Writer interface. -func (h *ControlHandler) Write(p []byte) (n int, err error) { - for _, c := range p { - switch c { - case ';': - wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(wholePayload) - h.payload = nil - - default: - h.payload = append(h.payload, c) - } - } - - n = len(p) - return -} - -func (h *ControlHandler) handleCommand(command string) { - logger.Info(h.user, command) - s := strings.Split(command, " ") - logger.Debug(h.user, "Receiving command", command, s) - - switch s[0] { - case "health": - h.serverMessages <- "OK: DTail SSH Server seems fine" - h.serverMessages <- "done;" - default: - h.serverMessages <- logger.Error(h.user, "Received unknown control command", command, s) - } -} diff --git a/internal/server/handlers/healthhandler.go b/internal/server/handlers/healthhandler.go new file mode 100644 index 0000000..6dd9872 --- /dev/null +++ b/internal/server/handlers/healthhandler.go @@ -0,0 +1,58 @@ +package handlers + +import ( + "context" + "os" + "strings" + + "github.com/mimecast/dtail/internal" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/lcontext" + user "github.com/mimecast/dtail/internal/user/server" +) + +// HealthHandler is for the remote health check. +type HealthHandler struct { + baseHandler +} + +// NewHealthHandler returns the server handler. +func NewHealthHandler(user *user.User) *HealthHandler { + dlog.Server.Debug(user, "Creating new server health handler") + h := HealthHandler{ + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + maprMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + user: user, + }, + } + h.handleCommandCb = h.handleHealthCommand + + fqdn, err := os.Hostname() + if err != nil { + dlog.Server.FatalPanic(err) + } + s := strings.Split(fqdn, ".") + h.hostname = s[0] + return &h +} + +func (h *HealthHandler) handleHealthCommand(ctx context.Context, + ltx lcontext.LContext, argc int, args []string, commandName string) { + + dlog.Server.Debug(h.user, "Handling health command", argc, args) + switch commandName { + case "health": + h.send(h.serverMessages, "OK") + case ".ack": + h.handleAckCommand(argc, args) + default: + h.send(h.serverMessages, dlog.Server.Error(h.user, + "Received unknown health command", commandName, argc, args)) + } + h.shutdown() +} diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go index c3e600e..65e0ed8 100644 --- a/internal/server/handlers/mapcommand.go +++ b/internal/server/handlers/mapcommand.go @@ -14,18 +14,17 @@ type mapCommand struct { } // NewMapCommand returns a new server side mapreduce command. -func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) { - m := mapCommand{server: serverHandler} +func newMapCommand(serverHandler *ServerHandler, argc int, + args []string) (mapCommand, *server.Aggregate, error) { + m := mapCommand{server: serverHandler} queryStr := strings.Join(args[1:], " ") aggregate, err := server.NewAggregate(queryStr) if err != nil { return m, nil, err } - m.aggregate = aggregate return m, aggregate, nil - } func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go index b659c06..4728a55 100644 --- a/internal/server/handlers/readcommand.go +++ b/internal/server/handlers/readcommand.go @@ -7,8 +7,9 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/fs" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/line" "github.com/mimecast/dtail/internal/lcontext" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/regex" @@ -26,39 +27,45 @@ func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand { } } -func (r *readCommand) Start(ctx context.Context, lContext lcontext.LContext, argc int, args []string, retries int) { - re := regex.NewNoop() +func (r *readCommand) Start(ctx context.Context, ltx lcontext.LContext, + argc int, args []string, retries int) { + re := regex.NewNoop() if argc >= 4 { deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " ")) if err != nil { - r.server.sendServerMessage(logger.Error(r.server.user, commandParseWarning, err)) + r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user, + "Unable to parse command", err)) return } re = deserializedRegex } if argc < 3 { - r.server.sendServerWarnMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Unable to parse command", args, argc)) return } - r.readGlob(ctx, lContext, args[1], re, retries) + r.readGlob(ctx, ltx, args[1], re, retries) } -func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, glob string, re regex.Regex, retries int) { +func (r *readCommand) readGlob(ctx context.Context, ltx lcontext.LContext, + glob string, re regex.Regex, retries int) { + retryInterval := time.Second * 5 glob = filepath.Clean(glob) for retryCount := 0; retryCount < retries; retryCount++ { paths, err := filepath.Glob(glob) if err != nil { - logger.Warn(r.server.user, glob, err) + dlog.Server.Warn(r.server.user, glob, err) time.Sleep(retryInterval) continue } if numPaths := len(paths); numPaths == 0 { - logger.Error(r.server.user, "No such file(s) to read", glob) - r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + dlog.Server.Error(r.server.user, "No such file(s) to read", glob) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Unable to read file(s), check server logs")) select { case <-ctx.Done(): return @@ -68,41 +75,44 @@ func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, continue } - r.readFiles(ctx, lContext, paths, glob, re, retryInterval) + r.readFiles(ctx, ltx, paths, glob, re, retryInterval) return } - r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Giving up to read file(s)")) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Giving up to read file(s)")) return } -func (r *readCommand) readFiles(ctx context.Context, lContext lcontext.LContext, paths []string, glob string, re regex.Regex, retryInterval time.Duration) { +func (r *readCommand) readFiles(ctx context.Context, ltx lcontext.LContext, + paths []string, glob string, re regex.Regex, retryInterval time.Duration) { + var wg sync.WaitGroup wg.Add(len(paths)) - for _, path := range paths { - go r.readFileIfPermissions(ctx, lContext, &wg, path, glob, re) + go r.readFileIfPermissions(ctx, ltx, &wg, path, glob, re) } - wg.Wait() } -func (r *readCommand) readFileIfPermissions(ctx context.Context, lContext lcontext.LContext, wg *sync.WaitGroup, path, glob string, re regex.Regex) { +func (r *readCommand) readFileIfPermissions(ctx context.Context, ltx lcontext.LContext, + wg *sync.WaitGroup, path, glob string, re regex.Regex) { + defer wg.Done() globID := r.makeGlobID(path, glob) - if !r.server.user.HasFilePermission(path, "readfiles") { - logger.Error(r.server.user, "No permission to read file", path, globID) - r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + dlog.Server.Error(r.server.user, "No permission to read file", path, globID) + r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user, + "Unable to read file(s), check server logs")) return } - - r.readFile(ctx, lContext, path, globID, re) + r.readFile(ctx, ltx, path, globID, re) } -func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, path, globID string, re regex.Regex) { - logger.Info(r.server.user, "Start reading file", path, globID) +func (r *readCommand) readFile(ctx context.Context, ltx lcontext.LContext, + path, globID string, re regex.Regex) { + dlog.Server.Info(r.server.user, "Start reading file", path, globID) var reader fs.FileReader switch r.mode { case omode.TailClient: @@ -114,15 +124,19 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, } lines := r.server.lines - - // Plug in mappreduce engine - if r.server.aggregate != nil { - lines = r.server.aggregate.Lines - } + aggregate := r.server.aggregate for { - if err := reader.Start(ctx, lContext, lines, re); err != nil { - logger.Error(r.server.user, path, globID, err) + if aggregate != nil { + lines = make(chan line.Line, 100) + aggregate.NextLinesCh <- lines + } + if err := reader.Start(ctx, ltx, lines, re); err != nil { + dlog.Server.Error(r.server.user, path, globID, err) + } + if aggregate != nil { + // Also makes aggregate to Flush + close(lines) } select { @@ -133,9 +147,8 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, return } } - time.Sleep(time.Second * 2) - logger.Info(path, globID, "Reading file again") + dlog.Server.Info(path, globID, "Reading file again") } } @@ -152,11 +165,11 @@ func (r *readCommand) makeGlobID(path, glob string) string { if len(idParts) > 0 { return strings.Join(idParts, "/") } - if len(pathParts) > 0 { return pathParts[len(pathParts)-1] } - r.server.sendServerWarnMessage(logger.Warn("Empty file path given?", path, glob)) + r.server.send(r.server.serverMessages, + dlog.Server.Warn("Empty file path given?", path, glob)) return "" } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 39d5d5f..36574a9 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -2,69 +2,50 @@ package handlers import ( "context" - "encoding/base64" - "errors" - "fmt" - "io" "os" - "strconv" "strings" - "sync/atomic" - "time" "github.com/mimecast/dtail/internal" - "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/line" - "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/lcontext" - "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" user "github.com/mimecast/dtail/internal/user/server" - "github.com/mimecast/dtail/internal/version" -) - -const ( - commandParseWarning string = "Unable to parse command" ) // ServerHandler implements the Reader and Writer interfaces to handle // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - done *internal.Done - lines chan line.Line - regex string - aggregate *server.Aggregate - aggregatedMessages chan string - serverMessages chan string - payload []byte - hostname string - user *user.User - catLimiter chan struct{} - tailLimiter chan struct{} - ackCloseReceived chan struct{} - activeCommands int32 - activeReaders int32 - quiet bool + baseHandler + catLimiter chan struct{} + tailLimiter chan struct{} + regex string } // NewServerHandler returns the server handler. -func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler { +func NewServerHandler(user *user.User, catLimiter, + tailLimiter chan struct{}) *ServerHandler { + + dlog.Server.Debug(user, "Creating new server handler") h := ServerHandler{ - done: internal.NewDone(), - lines: make(chan line.Line, 100), - serverMessages: make(chan string, 10), - aggregatedMessages: make(chan string, 10), - ackCloseReceived: make(chan struct{}), - catLimiter: catLimiter, - tailLimiter: tailLimiter, - regex: ".", - user: user, - } + baseHandler: baseHandler{ + done: internal.NewDone(), + lines: make(chan line.Line, 100), + serverMessages: make(chan string, 10), + maprMessages: make(chan string, 10), + ackCloseReceived: make(chan struct{}), + user: user, + }, + catLimiter: catLimiter, + tailLimiter: tailLimiter, + regex: ".", + } + h.handleCommandCb = h.handleUserCommand fqdn, err := os.Hostname() if err != nil { - logger.FatalExit(err) + dlog.Server.FatalPanic(err) } s := strings.Split(fqdn, ".") @@ -73,374 +54,49 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S return &h } -// Shutdown the handler. -func (h *ServerHandler) Shutdown() { - h.done.Shutdown() -} - -// Done channel of the handler. -func (h *ServerHandler) Done() <-chan struct{} { - return h.done.Done() -} - -// Read is to send data to the dtail client via Reader interface. -func (h *ServerHandler) Read(p []byte) (n int, err error) { - for { - select { - case message := <-h.serverMessages: - if len(message) == 0 { - logger.Warn(h.user, "Empty message received") - return - } - if message[0] == '.' { - // Handle hidden message (don't display to the user, interpreted by dtail client) - wholePayload := []byte(fmt.Sprintf("%s\n", message)) - n = copy(p, wholePayload) - return - } - - // Handle normal server message (display to the user) - wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) - n = copy(p, wholePayload) - return - - case message := <-h.aggregatedMessages: - // Send mapreduce-aggregated data as a message. - data := fmt.Sprintf("AGGREGATE➔%s➔%s\n", h.hostname, message) - wholePayload := []byte(data) - n = copy(p, wholePayload) - return - - case line := <-h.lines: - // Send normal file content data as a message. - serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", - h.hostname, line.TransmittedPerc, line.Count, line.SourceID)) - wholePayload := append(serverInfo, line.Content[:]...) - n = copy(p, wholePayload) - return - - case <-time.After(time.Second): - // Once in a while check whether we are done. - select { - case <-h.done.Done(): - return 0, io.EOF - default: - } - } - } -} - -// Write is to receive data from the dtail client via Writer interface. -func (h *ServerHandler) Write(p []byte) (n int, err error) { - for _, c := range p { - switch c { - case ';': - commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(commandStr) - h.payload = nil - default: - h.payload = append(h.payload, c) - } - } - - n = len(p) - return -} - -func (h *ServerHandler) handleCommand(commandStr string) { - logger.Debug(h.user, commandStr) - ctx := context.Background() - - args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - args, argc, err = h.handleBase64(args, argc) - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return - } - - ctx, cancel := context.WithCancel(ctx) - go func() { - <-h.done.Done() - cancel() - }() - - h.handleUserCommand(ctx, argc, args) -} - -func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) { - argc := len(args) - - if argc <= 2 || args[0] != "protocol" { - return args, argc, errors.New("unable to determine protocol version") - } - - if args[1] != version.ProtocolCompat { - err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1]) - return args, argc, err - } - - return args[2:], argc - 2, nil -} - -func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { - err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible") - - if argc != 2 || args[0] != "base64" { - return args, argc, err - } - - decoded, err := base64.StdEncoding.DecodeString(args[1]) - if err != nil { - return args, argc, err - } - decodedStr := string(decoded) - - args = strings.Split(decodedStr, " ") - argc = len(decodedStr) - logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - - return args, argc, nil -} - -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) - default: - logger.Warn(h.user, "Received unknown control command", argc, args) - } -} - -func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { - logger.Debug(h.user, "handleUserCommand", argc, args) +func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LContext, + argc int, args []string, commandName string) { + dlog.Server.Debug(h.user, "Handling user command", argc, args) h.incrementActiveCommands() commandFinished := func() { if h.decrementActiveCommands() == 0 { h.shutdown() } } - readerFinished := func() { - if h.decrementActiveReaders() == 0 { - if h.aggregate == nil { - return - } - h.aggregate.Shutdown() - } - } - - splitted := strings.Split(args[0], ":") - commandName := splitted[0] - - options, lContext, err := readOptions(splitted[1:]) - if err != nil { - h.sendServerMessage(logger.Error(h.user, err)) - commandFinished() - return - } - if quiet, ok := options["quiet"]; ok { - if quiet == "true" { - logger.Debug(h.user, "Enabling quiet mode") - h.quiet = true - } - } switch commandName { case "grep", "cat": command := newReadCommand(h, omode.CatClient) go func() { - h.incrementActiveReaders() - command.Start(ctx, lContext, argc, args, 1) - readerFinished() + command.Start(ctx, ltx, argc, args, 1) commandFinished() }() - case "tail": command := newReadCommand(h, omode.TailClient) go func() { - h.incrementActiveReaders() - command.Start(ctx, lContext, argc, args, 10) - readerFinished() + command.Start(ctx, ltx, argc, args, 10) commandFinished() }() - case "map": command, aggregate, err := newMapCommand(h, argc, args) if err != nil { - h.sendServerMessage(err.Error()) - logger.Error(h.user, err) + h.send(h.serverMessages, err.Error()) + dlog.Server.Error(h.user, err) commandFinished() return } - h.aggregate = aggregate go func() { - command.Start(ctx, h.aggregatedMessages) + command.Start(ctx, h.maprMessages) commandFinished() }() - - case "ack", ".ack": + case ".ack": h.handleAckCommand(argc, args) commandFinished() - default: - h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options)) + h.send(h.serverMessages, dlog.Server.Error(h.user, + "Received unknown user command", commandName, argc, args)) commandFinished() } } - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - if args[1] == "close" && args[2] == "connection" { - close(h.ackCloseReceived) - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.done.Done(): - } -} - -func (h *ServerHandler) sendServerMessage(message string) { - h.send(h.serverMessageC(), message) -} - -func (h *ServerHandler) sendServerWarnMessage(message string) { - if h.quiet { - return - } - h.send(h.serverMessageC(), message) -} - -func (h *ServerHandler) serverMessageC() chan<- string { - return h.serverMessages -} - -func (h *ServerHandler) flush() { - logger.Debug(h.user, "flush()") - - if h.aggregate != nil { - h.aggregate.Flush() - } - - unsentMessages := func() int { - return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) - } - for i := 0; i < 3; i++ { - if unsentMessages() == 0 { - logger.Debug(h.user, "All lines sent") - return - } - logger.Debug(h.user, "Still lines to be sent") - time.Sleep(time.Second) - } - - logger.Warn(h.user, "Some lines remain unsent", unsentMessages()) -} - -func (h *ServerHandler) shutdown() { - logger.Debug(h.user, "shutdown()") - h.flush() - - go func() { - select { - case h.serverMessageC() <- ".syn close connection": - case <-h.done.Done(): - } - }() - - select { - case <-h.ackCloseReceived: - case <-time.After(time.Second * 5): - logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") - case <-h.done.Done(): - } - - h.done.Shutdown() -} - -func (h *ServerHandler) incrementActiveCommands() { - atomic.AddInt32(&h.activeCommands, 1) -} - -func (h *ServerHandler) decrementActiveCommands() int32 { - atomic.AddInt32(&h.activeCommands, -1) - return atomic.LoadInt32(&h.activeCommands) -} - -func (h *ServerHandler) incrementActiveReaders() { - atomic.AddInt32(&h.activeReaders, 1) -} - -func (h *ServerHandler) decrementActiveReaders() int32 { - atomic.AddInt32(&h.activeReaders, -1) - return atomic.LoadInt32(&h.activeReaders) -} - -// TODO: All options related code should be in its own package (client + server) -func readOptions(opts []string) (map[string]string, lcontext.LContext, error) { - options := make(map[string]string, len(opts)) - // Local search context - var lContext lcontext.LContext - - for _, o := range opts { - kv := strings.SplitN(o, "=", 2) - if len(kv) != 2 { - continue - } - key := kv[0] - val := kv[1] - - if strings.HasPrefix(val, "base64%") { - s := strings.SplitN(val, "%", 2) - decoded, err := base64.StdEncoding.DecodeString(s[1]) - if err != nil { - return options, lContext, err - } - val = string(decoded) - } - - switch key { - case "before": - iVal, err := strconv.Atoi(val) - if err != nil { - logger.Error(err) - continue - } - lContext.BeforeContext = iVal - case "after": - iVal, err := strconv.Atoi(val) - if err != nil { - logger.Error(err) - continue - } - lContext.AfterContext = iVal - case "max": - iVal, err := strconv.Atoi(val) - if err != nil { - logger.Error(err) - continue - } - lContext.MaxCount = iVal - default: - options[key] = val - } - } - - return options, lContext, nil -} diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go index a1e9e36..0ba65f7 100644 --- a/internal/server/scheduler.go +++ b/internal/server/scheduler.go @@ -10,25 +10,23 @@ import ( "github.com/mimecast/dtail/internal/clients" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/omode" gossh "golang.org/x/crypto/ssh" ) -type scheduler struct { -} +type scheduler struct{} func newScheduler() *scheduler { return &scheduler{} } func (s *scheduler) start(ctx context.Context) { - logger.Info("Starting scheduled job runner after 10s") + dlog.Server.Info("Starting scheduled job runner after 10s") // First run after just 10s! time.Sleep(time.Second * 10) s.runJobs(ctx) - for { select { case <-time.After(time.Minute): @@ -42,27 +40,24 @@ func (s *scheduler) start(ctx context.Context) { func (s *scheduler) runJobs(ctx context.Context) { for _, job := range config.Server.Schedule { if !job.Enable { - logger.Debug(job.Name, "Not running job as not enabled") + dlog.Server.Debug(job.Name, "Not running job as not enabled") continue } - hour, err := strconv.Atoi(time.Now().Format("15")) if err != nil { - logger.Error(job.Name, "Unable to create job", err) + dlog.Server.Error(job.Name, "Unable to create job", err) continue } - if hour < job.TimeRange[0] || hour >= job.TimeRange[1] { - logger.Debug(job.Name, "Not running job out of time range") + dlog.Server.Debug(job.Name, "Not running job out of time range") continue } files := fillDates(job.Files) outfile := fillDates(job.Outfile) - _, err = os.Stat(outfile) if !os.IsNotExist(err) { - logger.Debug(job.Name, "Not running job as outfile already exists", outfile) + dlog.Server.Debug(job.Name, "Not running job as outfile already exists", outfile) continue } @@ -70,9 +65,8 @@ func (s *scheduler) runJobs(ctx context.Context) { if servers == "" { servers = config.Server.SSHBindAddress } - - args := clients.Args{ - ConnectionsPerCPU: 10, + args := config.Args{ + ConnectionsPerCPU: config.DefaultConnectionsPerCPU, Discovery: job.Discovery, ServersStr: servers, What: files, @@ -81,25 +75,24 @@ func (s *scheduler) runJobs(ctx context.Context) { } args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name)) - - query := fmt.Sprintf("%s outfile %s", job.Query, outfile) - client, err := clients.NewMaprClient(args, query, clients.CumulativeMode) + args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile) + client, err := clients.NewMaprClient(args, clients.CumulativeMode) if err != nil { - logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) + dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err) continue } jobCtx, cancel := context.WithCancel(ctx) defer cancel() - logger.Info(fmt.Sprintf("Starting job %s", job.Name)) + dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name)) status := client.Start(jobCtx, make(chan string)) logMessage := fmt.Sprintf("Job exited with status %d", status) if status != 0 { - logger.Warn(logMessage) + dlog.Server.Warn(logMessage) continue } - logger.Info(logMessage) + dlog.Server.Info(logMessage) } } diff --git a/internal/server/server.go b/internal/server/server.go index 3640208..0cb5e27 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,7 +9,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -24,9 +24,9 @@ type Server struct { stats stats // SSH server configuration. sshServerConfig *gossh.ServerConfig - // To control the max amount of concurrent cats (which can cause a lot of I/O on the server) + // To control the max amount of concurrent cats. catLimiter chan struct{} - // To control the max amount of concurrent tails + // To control the max amount of concurrent tails. tailLimiter chan struct{} // To run scheduled tasks (if configured) sched *scheduler @@ -36,7 +36,7 @@ type Server struct { // New returns a new server. func New() *Server { - logger.Info("Creating server", version.String()) + dlog.Server.Info("Creating server", version.String()) s := Server{ sshServerConfig: &gossh.ServerConfig{}, @@ -51,7 +51,7 @@ func New() *Server { private, err := gossh.ParsePrivateKey(server.PrivateHostKey()) if err != nil { - logger.FatalExit(err) + dlog.Server.FatalPanic(err) } s.sshServerConfig.AddHostKey(private) @@ -60,14 +60,13 @@ func New() *Server { // Start the server. func (s *Server) Start(ctx context.Context) int { - logger.Info("Starting server") - + dlog.Server.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) - logger.Info("Binding server", bindAt) + dlog.Server.Info("Binding server", bindAt) listener, err := net.Listen("tcp", bindAt) if err != nil { - logger.FatalExit("Failed to open listening TCP socket", err) + dlog.Server.FatalPanic("Failed to open listening TCP socket", err) } go s.stats.start(ctx) @@ -76,14 +75,12 @@ func (s *Server) Start(ctx context.Context) int { go s.listenerLoop(ctx, listener) <-ctx.Done() - // For future use. return 0 } func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { - logger.Debug("Starting listener loop") - + dlog.Server.Debug("Starting listener loop") for { conn, err := listener.Accept() // Blocking if err != nil { @@ -92,63 +89,69 @@ func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) { return default: } - logger.Error("Failed to accept incoming connection", err) + dlog.Server.Error("Failed to accept incoming connection", err) continue } if err := s.stats.serverLimitExceeded(); err != nil { - logger.Error(err) + dlog.Server.Error(err) conn.Close() continue } - go s.handleConnection(ctx, conn) } } func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { - logger.Info("Handling connection") + dlog.Server.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) if err != nil { - logger.Error("Something just happened", err) + dlog.Server.Error("Something just happened", err) return } s.stats.incrementConnections() - go gossh.DiscardRequests(reqs) for newChannel := range chans { go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { - user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) - logger.Info(user, "Invoking channel handler") +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, + newChannel gossh.NewChannel) { + user, err := user.New(sshConn.User(), sshConn.RemoteAddr().String()) + if err != nil { + dlog.Server.Error(user, err) + newChannel.Reject(gossh.Prohibited, err.Error()) + return + } + + dlog.Server.Info(user, "Invoking channel handler") if newChannel.ChannelType() != "session" { err := errors.New("Don'w allow other channel types than session") - logger.Error(user, err) + dlog.Server.Error(user, err) newChannel.Reject(gossh.Prohibited, err.Error()) return } channel, requests, err := newChannel.Accept() if err != nil { - logger.Error(user, "Could not accept channel", err) + dlog.Server.Error(user, "Could not accept channel", err) return } if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { - logger.Error(user, "While handling request", err) + dlog.Server.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { - logger.Info(user, "Invoking request handler") +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, + in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { + dlog.Server.Info(user, "Invoking request handler") for req := range in { var payload = struct{ Value string }{} gossh.Unmarshal(req.Payload, &payload) @@ -157,12 +160,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch case "shell": var handler handlers.Handler switch user.Name { - case config.ControlUser: - handler = handlers.NewControlHandler(user) + case config.HealthUser: + handler = handlers.NewHealthHandler(user) default: handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter) } - terminate := func() { handler.Shutdown() sshConn.Close() @@ -173,13 +175,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch io.Copy(channel, handler) terminate() }() - go func() { // Broken pipe, cancel io.Copy(handler, channel) terminate() }() - go func() { select { case <-ctx.Done(): @@ -187,62 +187,61 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch } terminate() }() - go func() { if err := sshConn.Wait(); err != nil && err != io.EOF { - // Use of closed network connection. - logger.Debug(user, "While waiting for ssh connection", err) + dlog.Server.Error(user, err) } s.stats.decrementConnections() - logger.Info(user, "Good bye Mister!") + dlog.Server.Info(user, "Good bye Mister!") terminate() }() // Only serving shell type req.Reply(true, nil) - default: req.Reply(false, nil) - - return fmt.Errorf("Closing SSH connection as unknown request received|%s|%v", + return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v", req.Type, payload.Value) } } - return nil } // Callback for SSH authentication. -func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) { - user := user.New(c.User(), c.RemoteAddr().String()) +func (s *Server) Callback(c gossh.ConnMetadata, + authPayload []byte) (*gossh.Permissions, error) { + + user, err := user.New(c.User(), c.RemoteAddr().String()) + if err != nil { + return nil, err + } if config.ServerRelaxedAuthEnable { - logger.Fatal(user, "Granting permissions via relaxed-auth") + dlog.Server.Fatal(user, "Granting permissions via relaxed-auth") return nil, nil } authInfo := string(authPayload) - splitted := strings.Split(c.RemoteAddr().String(), ":") remoteIP := splitted[0] switch user.Name { - case config.ControlUser: - if authInfo == config.ControlUser { - logger.Debug(user, "Granting permissions to control user") + case config.HealthUser: + if authInfo == config.HealthUser { + dlog.Server.Debug(user, "Granting permissions to health user") return nil, nil } case config.ScheduleUser: for _, job := range config.Server.Schedule { if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) { - logger.Debug(user, "Granting SSH connection") + dlog.Server.Debug(user, "Granting SSH connection") return nil, nil } } case config.ContinuousUser: for _, job := range config.Server.Continuous { if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) { - logger.Debug(user, "Granting SSH connection") + dlog.Server.Debug(user, "Granting SSH connection") return nil, nil } } @@ -252,23 +251,26 @@ func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Perm return nil, fmt.Errorf("user %s not authorized", user) } -func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, allowedJobName string, allowFrom []string) bool { - logger.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom) +func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, + allowedJobName string, allowFrom []string) bool { + dlog.Server.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom) if jobName != allowedJobName { - logger.Debug(user, jobName, "backgroundCanSSH", "Job name does not match, skipping to next one...", allowedJobName) + dlog.Server.Debug(user, jobName, "backgroundCanSSH", + "Job name does not match, skipping to next one...", allowedJobName) return false } for _, myAddr := range allowFrom { ips, err := net.LookupIP(myAddr) if err != nil { - logger.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP address for allowed hosts lookup, skipping to next one...", myAddr, err) + dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP "+ + "address for allowed hosts lookup, skipping to next one...", myAddr, err) continue } - for _, ip := range ips { - logger.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", remoteIP, ip.String()) + dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", + remoteIP, ip.String()) if remoteIP == ip.String() { return true } diff --git a/internal/server/stats.go b/internal/server/stats.go index ac579ad..99a644a 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -3,12 +3,11 @@ package server import ( "context" "fmt" - "runtime" "sync" "time" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" ) // Used to collect and display various server stats. @@ -20,7 +19,6 @@ type stats struct { func (s *stats) incrementConnections() { defer s.logServerStats() - s.mutex.Lock() s.currentConnections++ s.lifetimeConnections++ @@ -29,7 +27,6 @@ func (s *stats) incrementConnections() { func (s *stats) decrementConnections() { defer s.logServerStats() - s.mutex.Lock() s.currentConnections-- s.mutex.Unlock() @@ -41,8 +38,8 @@ func (s *stats) hasConnections() bool { s.mutex.Unlock() has := currentConnections > 0 - logger.Info("stats", "Server with open connections?", has, currentConnections) - + dlog.Server.Info("stats", "Server with open connections?", + has, currentConnections) return has } @@ -50,10 +47,10 @@ func (s *stats) logServerStats() { s.mutex.Lock() defer s.mutex.Unlock() - currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections) - lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections) - goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine()) - logger.Info("stats", currentConnections, lifetimeConnections, goroutines) + data := make(map[string]interface{}) + data["currentConnections"] = s.currentConnections + data["lifetimeConnections"] = s.lifetimeConnections + dlog.Server.Mapreduce("STATS", data) } func (s *stats) serverLimitExceeded() error { @@ -61,9 +58,9 @@ func (s *stats) serverLimitExceeded() error { defer s.mutex.Unlock() if s.currentConnections >= config.Server.MaxConnections { - return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections) + return fmt.Errorf("Exceeded max allowed concurrent connections of %d", + config.Server.MaxConnections) } - return nil } diff --git a/internal/source/source.go b/internal/source/source.go new file mode 100644 index 0000000..4bb0784 --- /dev/null +++ b/internal/source/source.go @@ -0,0 +1,30 @@ +package source + +// Source specifies the origin of either the current process (dtail is a client +// process, dserver is a server process) or the source code package (e.g. +// dserver server side code or dtail client side code). Notice that dtail client +// may also executes server code directly (e.g. via serverless mode) and that +// the dserver may also executes client code (e.g. via scheduled server side +// mapreduce queries). +type Source int + +const ( + // Client process or source code package. + Client Source = iota + // Server process or source code package. + Server Source = iota + // HealthCheck process or client source code package. + HealthCheck Source = iota +) + +func (s Source) String() string { + switch s { + case Client: + return "CLIENT" + case Server: + return "SERVER" + case HealthCheck: + return "HEALTHCHECK" + } + panic("Unknown source type") +} diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index bbfb7be..37f8382 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -4,89 +4,106 @@ import ( "os" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/ssh" gossh "golang.org/x/crypto/ssh" ) // InitSSHAuthMethods initialises all known SSH auth methods on the client side. -func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { +func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, + hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, + privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + if len(sshAuthMethods) > 0 { simpleCallback, err := NewSimpleCallback() if err != nil { - logger.FatalExit(err) + dlog.Client.FatalPanic(err) } return sshAuthMethods, simpleCallback } - return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath) } -func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { - var sshAuthMethods []gossh.AuthMethod +func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, + privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) { + var sshAuthMethods []gossh.AuthMethod knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts" - knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, throttleCh) + knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, + throttleCh) if err != nil { - logger.FatalExit(knownHostsPath, err) - } - logger.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath) - - if config.Common.ExperimentalFeaturesEnable { - sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) - logger.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods") + dlog.Client.FatalPanic(knownHostsPath, err) } + dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath) + /* + if config.Client.ExperimentalFeaturesEnable { + sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test")) + dlog.Client.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods") + } + */ // First try to read custom private key path. if privateKeyPath != "" { authMethod, err := ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthMethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthMethods", + "Added path to list of auth methods, not adding further methods", + privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.FatalExit("Unable to use private SSH key", privateKeyPath, err) + dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err) } // Second, try SSH Agent authMethod, err := ssh.Agent() if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK) to list of auth methods, not adding further methods") + dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+ + "to list of auth methods, not adding further methods") return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err) + dlog.Client.Debug("initKnownHostsAuthMethods", + "Unable to init SSH Agent auth method", err) // Third, try Linux/UNIX default key paths privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa" authMethod, err = ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthmethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", + privateKeyPath, err) privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa" authMethod, err = ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthmethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa" authMethod, err = ssh.PrivateKey(privateKeyPath) if err == nil { sshAuthMethods = append(sshAuthMethods, authMethod) - logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath) + dlog.Client.Debug("initKnownHostsAuthmethods", + "Added path to list of auth methods, not adding further methods", privateKeyPath) return sshAuthMethods, knownHostsCallback } - logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err) - logger.FatalExit("Unable to find private SSH key information") + dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key", + privateKeyPath, err) + + // This is only a panic when we expect to do something about it. + if !config.Client.SSHDontAddHostsToKnownHostsFile { + dlog.Client.FatalPanic("Unable to find private SSH key information") + } // Never reach this point. return sshAuthMethods, knownHostsCallback diff --git a/internal/ssh/client/customkeycallback.go b/internal/ssh/client/customkeycallback.go index 73e5289..53b8e3c 100644 --- a/internal/ssh/client/customkeycallback.go +++ b/internal/ssh/client/customkeycallback.go @@ -7,8 +7,7 @@ import ( ) // CustomCallback is a custom host key callback wrapper. -type CustomCallback struct { -} +type CustomCallback struct{} // NewCustomCallback returns a new wrapper. func NewCustomCallback() (*CustomCallback, error) { diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go index 1ccf6c6..2aa0168 100644 --- a/internal/ssh/client/knownhostscallback.go +++ b/internal/ssh/client/knownhostscallback.go @@ -10,7 +10,8 @@ import ( "sync" "time" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/prompt" "golang.org/x/crypto/ssh" @@ -46,8 +47,9 @@ type KnownHostsCallback struct { } // NewKnownHostsCallback returns a new wrapper. -func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) { - // Ensure file exists +func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, + throttleCh chan struct{}) (HostKeyCallback, error) { + os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666) untrustedHosts := make(map[string]bool) @@ -59,11 +61,9 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh untrustedHosts: untrustedHosts, mutex: &sync.Mutex{}, } - if trustAllHosts { close(c.trustAllHostsCh) } - return c, nil } @@ -75,14 +75,12 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { if err != nil { return err } - // Check for valid entry in known_hosts file err = knownHostsCb(server, remote, key) if err == nil { // OK return nil } - // Make sure that interactive user callback does not interfere with // SSH connection throttler. <-c.throttleCh @@ -96,11 +94,9 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { ipLine: knownhosts.Line([]string{remote.String()}, key), responseCh: make(chan response), } - - logger.Warn("Encountered unknown host", unknown) + dlog.Common.Warn("Encountered unknown host", unknown) // Notify user that there is an unknown host c.unknownCh <- unknown - // Wait for user input. switch <-unknown.responseCh { case trustHost: @@ -112,7 +108,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { c.mutex.Lock() defer c.mutex.Unlock() c.untrustedHosts[server] = true - return err } } @@ -121,7 +116,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback { // be added to the known hosts or not. func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { var hosts []unknownHost - for { // Check whether there is a unknown host select { @@ -139,7 +133,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { hosts = []unknownHost{} } case <-ctx.Done(): - logger.Debug("Stopping goroutine prompting new hosts...") + dlog.Common.Debug("Stopping goroutine prompting new hosts...") return } } @@ -147,14 +141,13 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) { func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { var servers []string - for _, host := range hosts { servers = append(servers, host.server) } select { case <-c.trustAllHostsCh: - logger.Warn("Trusting host keys of servers", servers) + dlog.Common.Warn("Trusting host keys of servers", servers) c.trustHosts(hosts) return default: @@ -165,7 +158,6 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { strings.Join(servers, ","), "Do you want to trust these hosts?", ) - p := prompt.New(question) a := prompt.Answer{ @@ -175,7 +167,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { c.trustHosts(hosts) }, EndCallback: func() { - logger.Info("Added hosts to known hosts file", c.knownHostsPath) + dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath) }, } p.Add(a) @@ -188,7 +180,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { c.trustHosts(hosts) }, EndCallback: func() { - logger.Info("Added hosts to known hosts file", c.knownHostsPath) + dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath) }, } p.Add(a) @@ -200,7 +192,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { c.dontTrustHosts(hosts) }, EndCallback: func() { - logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath) + dlog.Common.Info("Didn't add hosts to known hosts file", c.knownHostsPath) }, } p.Add(a) @@ -224,6 +216,11 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) { func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath) + if config.Client.SSHDontAddHostsToKnownHostsFile { + dlog.Common.Verbose("Not adding hosts to known hosts file, as disabled by config") + return + } + newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600) if err != nil { panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error())) @@ -232,7 +229,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { // Newly trusted hosts in normalized form addresses := make(map[string]struct{}) - // First write to new known hosts file, and keep track of addresses for _, unknown := range hosts { unknown.responseCh <- trustHost @@ -255,7 +251,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) { defer oldFd.Close() scanner := bufio.NewScanner(oldFd) - // Now, append all still valid old entries to the new host file for scanner.Scan() { line := scanner.Text() @@ -283,6 +278,5 @@ func (c KnownHostsCallback) Untrusted(server string) bool { c.mutex.Lock() defer c.mutex.Unlock() _, ok := c.untrustedHosts[server] - return ok } diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 07790ad..33bd4e8 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -1,11 +1,12 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" - "github.com/mimecast/dtail/internal/ssh" "io/ioutil" "os" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" + "github.com/mimecast/dtail/internal/ssh" ) // PrivateHostKey retrieves the private server RSA host key. @@ -14,24 +15,25 @@ func PrivateHostKey() []byte { _, err := os.Stat(hostKeyFile) if os.IsNotExist(err) { - logger.Info("Generating private server RSA host key") + dlog.Common.Info("Generating private server RSA host key") privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits) if err != nil { - logger.FatalExit("Failed to generate private server RSA host key", err) + dlog.Common.FatalPanic("Failed to generate private server RSA host key", err) } pem := ssh.EncodePrivateKeyToPEM(privateKey) if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil { - logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err) + dlog.Common.Error("Unable to write private server RSA host key to file", + hostKeyFile, err) } return pem } - logger.Info("Reading private server RSA host key from file", hostKeyFile) + dlog.Common.Info("Reading private server RSA host key from file", hostKeyFile) pem, err := ioutil.ReadFile(hostKeyFile) if err != nil { - logger.FatalExit("Failed to load private server RSA host key", err) + dlog.Common.FatalPanic("Failed to load private server RSA host key", err) } return pem } diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index e81f019..ebc428a 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -7,28 +7,34 @@ import ( osUser "os/user" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/dlog" user "github.com/mimecast/dtail/internal/user/server" gossh "golang.org/x/crypto/ssh" ) -// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not. -func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) { - user := user.New(c.User(), c.RemoteAddr().String()) - logger.Info(user, "Incoming authorization") +// PublicKeyCallback is for the server to check whether a public SSH key is +// authorized ot not. +func PublicKeyCallback(c gossh.ConnMetadata, + offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) { + user, err := user.New(c.User(), c.RemoteAddr().String()) + if err != nil { + return nil, err + } + + dlog.Common.Info(user, "Incoming authorization") cwd, err := os.Getwd() if err != nil { return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error()) } - if config.ServerRelaxedAuthEnable { - logger.Fatal(user, "Granting permissions via relaxed-auth") + dlog.Common.Fatal(user, "Granting permissions via relaxed-auth") return nil, nil } - authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name) + authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, + config.Common.CacheDir, user.Name) if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) { user, err := osUser.Lookup(user.Name) if err != nil { @@ -38,26 +44,28 @@ func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*go authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys" } - logger.Info(user, "Reading", authorizedKeysFile) + dlog.Common.Info(user, "Reading", authorizedKeysFile) authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile) if err != nil { - return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error()) + return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", + authorizedKeysFile, user, err.Error()) } authorizedKeysMap := map[string]bool{} for len(authorizedKeysBytes) > 0 { authorizedPubKey, _, _, restBytes, err := gossh.ParseAuthorizedKey(authorizedKeysBytes) if err != nil { - return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error()) + return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", + user, err.Error()) } authorizedKeysMap[string(authorizedPubKey.Marshal())] = true authorizedKeysBytes = restBytes - - logger.Debug(user, "Authorized public key fingerprint", gossh.FingerprintSHA256(authorizedPubKey)) + dlog.Common.Debug(user, "Authorized public key fingerprint", + gossh.FingerprintSHA256(authorizedPubKey)) } - logger.Debug(user, "Offered public key fingerprint", gossh.FingerprintSHA256(offeredPubKey)) - + dlog.Common.Debug(user, "Offered public key fingerprint", + gossh.FingerprintSHA256(offeredPubKey)) if authorizedKeysMap[string(offeredPubKey.Marshal())] { return &gossh.Permissions{ Extensions: map[string]string{ diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 3a2e416..db5aaf1 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -6,12 +6,13 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "github.com/mimecast/dtail/internal/io/logger" "io/ioutil" "net" "os" "syscall" + "github.com/mimecast/dtail/internal/io/dlog" + gossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "golang.org/x/crypto/ssh/terminal" @@ -23,12 +24,10 @@ func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) { if err != nil { return nil, err } - err = privateKey.Validate() if err != nil { return nil, err } - return privateKey, nil } @@ -41,7 +40,6 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte { Headers: nil, Bytes: derFormat, } - return pem.EncodeToMemory(&block) } @@ -57,7 +55,7 @@ func Agent() (gossh.AuthMethod, error) { return nil, err } for i, key := range keys { - logger.Debug("Public key", i, key) + dlog.Common.Debug("Public key", i, key) } return gossh.PublicKeysCallback(agentClient.Signers), nil } @@ -79,7 +77,6 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) { if err != nil { return nil, err } - key, err := gossh.ParsePrivateKey(buffer) if err != nil { return nil, err @@ -105,7 +102,7 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) { func PrivateKey(keyFile string) (gossh.AuthMethod, error) { signer, err := KeyFile(keyFile) if err != nil { - logger.Debug(keyFile, err) + dlog.Common.Debug(keyFile, err) return nil, err } return gossh.AuthMethod(signer), nil diff --git a/internal/user/name.go b/internal/user/name.go index 28ab0a4..cd11907 100644 --- a/internal/user/name.go +++ b/internal/user/name.go @@ -10,11 +10,9 @@ func NoRootCheck() { if err != nil { panic(err) } - if user.Uid == "0" { panic("Not allowed to run as UID 0") } - if user.Gid == "0" { panic("Not allowed to run as GID 0") } @@ -26,6 +24,5 @@ func Name() string { if err != nil { panic(err) } - return user.Username } diff --git a/internal/user/server/user.go b/internal/user/server/user.go index af6b0d0..aa7f8b1 100644 --- a/internal/user/server/user.go +++ b/internal/user/server/user.go @@ -8,8 +8,8 @@ import ( "strings" "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/dlog" "github.com/mimecast/dtail/internal/io/fs/permissions" - "github.com/mimecast/dtail/internal/io/logger" ) const maxLinkDepth int = 100 @@ -25,11 +25,16 @@ type User struct { } // New returns a new user. -func New(name, remoteAddress string) *User { +func New(name, remoteAddress string) (*User, error) { + permissions, err := config.ServerUserPermissions(name) + if err != nil { + return nil, err + } return &User{ Name: name, remoteAddress: remoteAddress, - } + permissions: permissions, + }, nil } // String representation of the user. @@ -37,14 +42,13 @@ func (u *User) String() string { return fmt.Sprintf("%s@%s", u.Name, u.remoteAddress) } -// HasFilePermission is used to determine whether user is allowed to read a file. +// HasFilePermission is used to determine whether user is alowed to read a file. func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) { - logger.Debug(u, filePath, permissionType, "Checking config permissions") + dlog.Server.Debug(u, filePath, permissionType, "Checking config permissions") if config.ServerRelaxedAuthEnable { - logger.Fatal(u, filePath, permissionType, "Server releaxed auth enabled") + dlog.Server.Fatal(u, filePath, permissionType, "Server releaxed auth enabled") return true } - if u.Name == config.ScheduleUser || u.Name == config.ContinuousUser { // Background user has same permissions as dtail process itself. return true @@ -52,27 +56,29 @@ func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission cleanPath, err := filepath.EvalSymlinks(filePath) if err != nil { - logger.Error(u, filePath, permissionType, "Unable to evaluate symlinks", err) + dlog.Server.Error(u, filePath, permissionType, + "Unable to evaluate symlinks", err) hasPermission = false return } cleanPath, err = filepath.Abs(cleanPath) if err != nil { - logger.Error(u, cleanPath, permissionType, "Unable to make file path absolute", err) + dlog.Server.Error(u, cleanPath, permissionType, + "Unable to make file path absolute", err) hasPermission = false return } if cleanPath != filePath { - logger.Info(u, filePath, cleanPath, permissionType, "Calculated new clean path from original file path (possibly symlink)") + dlog.Server.Info(u, filePath, cleanPath, permissionType, + "Calculated new clean path from original file path (possibly symlink)") } hasPermission, err = u.hasFilePermission(cleanPath, permissionType) if err != nil { - logger.Warn(u, cleanPath, err) + dlog.Server.Warn(u, cleanPath, err) } - return } @@ -81,24 +87,17 @@ func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error) if _, err := permissions.ToRead(u.Name, cleanPath); err != nil { return false, fmt.Errorf("User without OS file system permissions to read path: '%v'", err) } - logger.Info(u, cleanPath, permissionType, "User with OS file system permissions to path") + dlog.Server.Info(u, cleanPath, permissionType, + "User with OS file system permissions to path") // Only allow to follow regular files or symlinks. info, err := os.Lstat(cleanPath) if err != nil { return false, fmt.Errorf("Unable to determine file type: '%v'", err) } - if !info.Mode().IsRegular() { return false, fmt.Errorf("Can only open regular files or follow symlinks") } - - permissions, err := config.ServerUserPermissions(u.Name) - if err != nil { - return false, err - } - u.permissions = permissions - hasPermission, err := u.iteratePaths(cleanPath, permissionType) if err != nil { return false, err @@ -110,10 +109,8 @@ func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error) func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) { // By default assume no permissions hasPermission := false - for _, permission := range u.permissions { typeStr := "readfiles" // Assume ReadFiles by default. - var regexStr string var negate bool @@ -123,8 +120,7 @@ func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) { permission = strings.Join(splitted[1:], ":") } - logger.Debug(u, cleanPath, typeStr, permission) - + dlog.Server.Debug(u, cleanPath, typeStr, permission) if typeStr != permissionType { continue } @@ -137,16 +133,17 @@ func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) { re, err := regexp.Compile(regexStr) if err != nil { - return false, fmt.Errorf("Permission test failed, can't compile regex '%s': '%v'", regexStr, err) + return false, fmt.Errorf("Permission test failed, can't compile regex "+ + "'%s': '%v'", regexStr, err) } - if negate && re.MatchString(cleanPath) { - logger.Info(u, cleanPath, "Permission test failed partially, matching negative pattern '%s'", permission) + dlog.Server.Info(u, cleanPath, "Permission test failed partially, "+ + "matching negative pattern '%s'", permission) hasPermission = false } - if !negate && re.MatchString(cleanPath) { - logger.Info(u, cleanPath, "Permission test passed partially, matching positive pattern", permission) + dlog.Server.Info(u, cleanPath, "Permission test passed partially, "+ + "matching positive pattern", permission) hasPermission = true } } diff --git a/internal/version/version.go b/internal/version/version.go index 3c31059..68b9e6e 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -5,38 +5,50 @@ import ( "os" "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/protocol" ) const ( // Name of DTail. Name string = "DTail" // Version of DTail. - Version string = "3.3.1" + Version string = "4.0.0-RC1" // Additional information for DTail - Additional string = "" - // ProtocolCompat -ibility version. - ProtocolCompat string = "3" + Additional string = "Have a lot of fun!" ) // String representation of the DTail version. func String() string { - return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, ProtocolCompat, Additional) + return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, + protocol.ProtocolCompat, Additional) } // PaintedString is a prettier string representation of the DTail version. func PaintedString() string { - if !color.Colored { + if !config.Client.TermColorsEnable { return String() } - name := color.Paint(color.Yellow, Name) - version := color.Paint(color.Blue, Version) - descr := color.Paint(color.Green, Additional) - return fmt.Sprintf("%s %v Protocol %s %s", name, version, ProtocolCompat, descr) + name := color.PaintStrWithAttr(fmt.Sprintf(" %s ", Name), + color.FgYellow, color.BgBlue, color.AttrBold) + version := color.PaintStrWithAttr(fmt.Sprintf(" %s ", Version), + color.FgBlue, color.BgYellow, color.AttrBold) + protocol := color.PaintStr(fmt.Sprintf(" Protocol %s ", protocol.ProtocolCompat), + color.FgBlack, color.BgGreen) + additional := color.PaintStrWithAttr(fmt.Sprintf(" %s ", Additional), + color.FgWhite, color.BgMagenta, color.AttrUnderline) + + return fmt.Sprintf("%s%v%s%s", name, version, protocol, additional) +} + +// Print the version. +func Print() { + fmt.Println(PaintedString()) } // PrintAndExit prints the program version and exists. func PrintAndExit() { - fmt.Println(PaintedString()) + Print() os.Exit(0) } |
