summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
committerPaul Buetow <pbuetow@mimecast.com>2021-10-21 21:28:49 +0300
commitf4207a55f71bfbcfdc532d5cdd3befaa3474a157 (patch)
treeea5e4a2d2a67035f645bdee496ae55a52034178a /internal
parentd80d6070557e3a800e3a54967af9eced518f116b (diff)
parent739205206d63bf42f4e843b39d04d4c8cd8207c3 (diff)
merge develop
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/baseclient.go131
-rw-r--r--internal/clients/catclient.go16
-rw-r--r--internal/clients/connectors/connector.go17
-rw-r--r--internal/clients/connectors/serverconnection.go206
-rw-r--r--internal/clients/connectors/serverless.go116
-rw-r--r--internal/clients/grepclient.go19
-rw-r--r--internal/clients/handlers/basehandler.go77
-rw-r--r--internal/clients/handlers/clienthandler.go4
-rw-r--r--internal/clients/handlers/healthhandler.go106
-rw-r--r--internal/clients/handlers/maprhandler.go56
-rw-r--r--internal/clients/healthclient.go114
-rw-r--r--internal/clients/maker.go2
-rw-r--r--internal/clients/maprclient.go89
-rw-r--r--internal/clients/remote/connection.go212
-rw-r--r--internal/clients/stats.go76
-rw-r--r--internal/clients/tailclient.go20
-rw-r--r--internal/color/brush/brush.go194
-rw-r--r--internal/color/color.go174
-rw-r--r--internal/color/color_test.go53
-rw-r--r--internal/color/colorfy.go56
-rw-r--r--internal/color/paint.go91
-rw-r--r--internal/color/table.go53
-rw-r--r--internal/config/args.go164
-rw-r--r--internal/config/client.go194
-rw-r--r--internal/config/common.go26
-rw-r--r--internal/config/config.go71
-rw-r--r--internal/config/env.go7
-rw-r--r--internal/config/initializer.go184
-rw-r--r--internal/config/read.go37
-rw-r--r--internal/config/server.go10
-rw-r--r--internal/discovery/comma.go4
-rw-r--r--internal/discovery/discovery.go33
-rw-r--r--internal/discovery/file.go8
-rw-r--r--internal/done.go10
-rw-r--r--internal/io/dlog/dlog.go272
-rw-r--r--internal/io/dlog/level.go84
-rw-r--r--internal/io/dlog/loggers/factory.go54
-rw-r--r--internal/io/dlog/loggers/file.go165
-rw-r--r--internal/io/dlog/loggers/fout.go46
-rw-r--r--internal/io/dlog/loggers/logger.go19
-rw-r--r--internal/io/dlog/loggers/none.go21
-rw-r--r--internal/io/dlog/loggers/stdout.go54
-rw-r--r--internal/io/dlog/loggers/strategy.go35
-rw-r--r--internal/io/dlog/rotation.go27
-rw-r--r--internal/io/fs/catfile.go4
-rw-r--r--internal/io/fs/filereader.go6
-rw-r--r--internal/io/fs/filter.go167
-rw-r--r--internal/io/fs/permissions/permission.go4
-rw-r--r--internal/io/fs/permissions/permission_linuxacl.go2
-rw-r--r--internal/io/fs/readfile.go384
-rw-r--r--internal/io/fs/tailfile.go4
-rw-r--r--internal/io/fs/truncate.go61
-rw-r--r--internal/io/line/line.go5
-rw-r--r--internal/io/logger/logger.go403
-rw-r--r--internal/io/logger/modes.go12
-rw-r--r--internal/io/logger/strategy.go22
-rw-r--r--internal/io/pool/builder.go21
-rw-r--r--internal/io/pool/bytesbuffer.go22
-rw-r--r--internal/io/prompt/prompt.go13
-rw-r--r--internal/io/signal/signal.go8
-rw-r--r--internal/lcontext/lcontext.go2
-rw-r--r--internal/mapr/aggregateset.go29
-rw-r--r--internal/mapr/client/aggregate.go29
-rw-r--r--internal/mapr/funcs/function.go11
-rw-r--r--internal/mapr/funcs/function_test.go21
-rw-r--r--internal/mapr/funcs/maskdigits.go2
-rw-r--r--internal/mapr/globalgroupset.go11
-rw-r--r--internal/mapr/groupset.go291
-rw-r--r--internal/mapr/logformat/default.go41
-rw-r--r--internal/mapr/logformat/default_test.go88
-rw-r--r--internal/mapr/logformat/generickv.go31
-rw-r--r--internal/mapr/logformat/parser.go15
-rw-r--r--internal/mapr/query.go21
-rw-r--r--internal/mapr/query_test.go125
-rw-r--r--internal/mapr/selectcondition.go12
-rw-r--r--internal/mapr/server/aggregate.go155
-rw-r--r--internal/mapr/setclause.go2
-rw-r--r--internal/mapr/setcondition.go15
-rw-r--r--internal/mapr/token.go21
-rw-r--r--internal/mapr/whereclause.go16
-rw-r--r--internal/mapr/wherecondition.go30
-rw-r--r--internal/protocol/protocol.go18
-rw-r--r--internal/regex/regex.go20
-rw-r--r--internal/regex/regex_test.go33
-rw-r--r--internal/server/continuous.go36
-rw-r--r--internal/server/handlers/basehandler.go320
-rw-r--r--internal/server/handlers/controlhandler.go98
-rw-r--r--internal/server/handlers/healthhandler.go58
-rw-r--r--internal/server/handlers/mapcommand.go7
-rw-r--r--internal/server/handlers/readcommand.go83
-rw-r--r--internal/server/handlers/serverhandler.go412
-rw-r--r--internal/server/scheduler.go37
-rw-r--r--internal/server/server.go110
-rw-r--r--internal/server/stats.go21
-rw-r--r--internal/source/source.go30
-rw-r--r--internal/ssh/client/authmethods.go67
-rw-r--r--internal/ssh/client/customkeycallback.go3
-rw-r--r--internal/ssh/client/knownhostscallback.go38
-rw-r--r--internal/ssh/server/hostkey.go18
-rw-r--r--internal/ssh/server/publickeycallback.go38
-rw-r--r--internal/ssh/ssh.go11
-rw-r--r--internal/user/name.go3
-rw-r--r--internal/user/server/user.go57
-rw-r--r--internal/version/version.go34
104 files changed, 4357 insertions, 2708 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index f83fcfd..4a7bd84 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -2,15 +2,13 @@ package clients
import (
"context"
- "fmt"
- "strings"
"sync"
"time"
- "github.com/mimecast/dtail/internal/clients/remote"
+ "github.com/mimecast/dtail/internal/clients/connectors"
+ "github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/discovery"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/omode"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/regex"
"github.com/mimecast/dtail/internal/ssh/client"
@@ -19,13 +17,13 @@ import (
// This is the main client data structure.
type baseClient struct {
- Args
+ config.Args
// To display client side stats
stats *stats
// List of remote servers to connect to.
servers []string
// We have one connection per remote server.
- connections []*remote.Connection
+ connections []connectors.Connector
// SSH auth methods to use to connect to the remote servers.
sshAuthMethods []gossh.AuthMethod
// To deal with SSH host keys
@@ -41,7 +39,7 @@ type baseClient struct {
}
func (c *baseClient) init() {
- logger.Debug("Initiating base client")
+ dlog.Client.Debug("Initiating base client", c.Args.String())
flag := regex.Default
if c.Args.RegexInvert {
@@ -49,12 +47,16 @@ func (c *baseClient) init() {
}
regex, err := regex.New(c.Args.RegexStr, flag)
if err != nil {
- logger.FatalExit(c.Regex, "invalid regex!", err, regex)
+ dlog.Client.FatalPanic(c.Regex, "Invalid regex!", err, regex)
}
c.Regex = regex
- logger.Debug("Regex", c.Regex)
- c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts, c.throttleCh, c.Args.PrivateKeyPathFile)
+ if c.Args.Serverless {
+ return
+ }
+ c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(
+ c.Args.SSHAuthMethods, c.Args.SSHHostKeyCallback, c.Args.TrustAllHosts,
+ c.throttleCh, c.Args.PrivateKeyPathFile)
}
func (c *baseClient) makeConnections(maker maker) {
@@ -62,26 +64,31 @@ func (c *baseClient) makeConnections(maker maker) {
discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle)
for _, server := range discoveryService.ServerList() {
- c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback))
+ c.connections = append(c.connections, c.makeConnection(server,
+ c.sshAuthMethods, c.hostKeyCallback))
}
c.stats = newTailStats(len(c.connections))
}
func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status int) {
- // Periodically check for unknown hosts, and ask the user whether to trust them or not.
- go c.hostKeyCallback.PromptAddHosts(ctx)
+ dlog.Client.Trace("Starting base client")
+ // Can be nil when serverless.
+ if c.hostKeyCallback != nil {
+ // Periodically check for unknown hosts, and ask the user whether to trust them or not.
+ go c.hostKeyCallback.PromptAddHosts(ctx)
+ }
// Print client stats every time something on statsCh is recieved.
go c.stats.Start(ctx, c.throttleCh, statsCh, c.Args.Quiet)
- // Keep count of active connections
- active := make(chan struct{}, len(c.connections))
+ var wg sync.WaitGroup
+ wg.Add(len(c.connections))
var mutex sync.Mutex
- for i, conn := range c.connections {
- go func(i int, conn *remote.Connection) {
- connStatus := c.start(ctx, active, i, conn)
- // Update global status.
+ for i, conn := range c.connections {
+ go func(i int, conn connectors.Connector) {
+ defer wg.Done()
+ connStatus := c.startConnection(ctx, i, conn)
mutex.Lock()
defer mutex.Unlock()
if connStatus > status {
@@ -90,15 +97,12 @@ func (c *baseClient) Start(ctx context.Context, statsCh <-chan string) (status i
}(i, conn)
}
- c.waitUntilDone(ctx, active)
+ wg.Wait()
return
}
-func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) {
- // Increment connection count
- active <- struct{}{}
- // Derement connection count
- defer func() { <-active }()
+func (c *baseClient) startConnection(ctx context.Context, i int,
+ conn connectors.Connector) (status int) {
for {
connCtx, cancel := context.WithCancel(ctx)
@@ -106,80 +110,25 @@ func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, con
conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh)
// Retrieve status code from handler (dtail client will exit with that status)
- status = conn.Handler.Status()
+ status = conn.Handler().Status()
if !c.retry {
return
}
time.Sleep(time.Second * 2)
- logger.Debug(conn.Server, "Reconnecting")
-
- conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback)
+ dlog.Client.Debug(conn.Server(), "Reconnecting")
+ conn = c.makeConnection(conn.Server(), c.sshAuthMethods, c.hostKeyCallback)
c.connections[i] = conn
}
}
-func (c *baseClient) makeCommandOptions() map[string]string {
- options := make(map[string]string)
-
- if c.Args.Quiet {
- options["quiet"] = fmt.Sprintf("%v", c.Args.Quiet)
- }
- if c.Args.LContext.MaxCount != 0 {
- options["max"] = fmt.Sprintf("%d", c.Args.LContext.MaxCount)
- }
- if c.Args.LContext.BeforeContext != 0 {
- options["before"] = fmt.Sprintf("%d", c.Args.LContext.BeforeContext)
- }
- if c.Args.LContext.AfterContext != 0 {
- options["after"] = fmt.Sprintf("%d", c.Args.LContext.AfterContext)
- }
-
- return options
-}
-
-func (c *baseClient) commandOptionsToString(options map[string]string) string {
- var sb strings.Builder
-
- count := 0
- for k, v := range options {
- if count > 0 {
- sb.WriteString(":")
- }
- sb.WriteString(fmt.Sprintf("%s=%s", k, v))
- count++
- }
-
- return sb.String()
-}
-
-func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback client.HostKeyCallback) *remote.Connection {
- conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback)
- conn.Handler = c.maker.makeHandler(server)
- conn.Commands = c.maker.makeCommands(c.makeCommandOptions())
-
- return conn
-}
-
-func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
- defer logger.Debug("Terminated connection")
-
- // We want to have at least one active connection
- <-active
- // Put it back on the channel
- active <- struct{}{}
-
- if c.Mode == omode.TailClient && c.retry {
- <-ctx.Done()
- }
-
- for {
- numActive := len(active)
- if numActive == 0 {
- return
- }
- logger.Debug("Active connections", numActive)
- time.Sleep(time.Second)
+func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod,
+ hostKeyCallback client.HostKeyCallback) connectors.Connector {
+ if c.Args.Serverless {
+ return connectors.NewServerless(c.UserName, c.maker.makeHandler(server),
+ c.maker.makeCommands())
}
+ return connectors.NewServerConnection(server, c.UserName, sshAuthMethods,
+ hostKeyCallback, c.maker.makeHandler(server), c.maker.makeCommands())
}
diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go
index db892f1..bd65560 100644
--- a/internal/clients/catclient.go
+++ b/internal/clients/catclient.go
@@ -7,6 +7,8 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
)
@@ -16,11 +18,10 @@ type CatClient struct {
}
// NewCatClient returns a new cat client.
-func NewCatClient(args Args) (*CatClient, error) {
+func NewCatClient(args config.Args) (*CatClient, error) {
if args.RegexStr != "" {
return nil, errors.New("Can't use regex with 'cat' operating mode")
}
-
args.Mode = omode.CatClient
c := CatClient{
@@ -33,7 +34,6 @@ func NewCatClient(args Args) (*CatClient, error) {
c.init()
c.makeConnections(c)
-
return &c, nil
}
@@ -41,10 +41,14 @@ func (c CatClient) makeHandler(server string) handlers.Handler {
return handlers.NewClientHandler(server)
}
-func (c CatClient) makeCommands(options map[string]string) (commands []string) {
- optionsStr := c.commandOptionsToString(options)
+func (c CatClient) makeCommands() (commands []string) {
+ regex, err := c.Regex.Serialize()
+ if err != nil {
+ dlog.Client.FatalPanic(err)
+ }
for _, file := range strings.Split(c.What, ",") {
- commands = append(commands, fmt.Sprintf("%s:%s %s %s", c.Mode.String(), optionsStr, file, c.Regex.Serialize()))
+ commands = append(commands, fmt.Sprintf("%s:%s %s %s",
+ c.Mode.String(), c.Args.SerializeOptions(), file, regex))
}
return
}
diff --git a/internal/clients/connectors/connector.go b/internal/clients/connectors/connector.go
new file mode 100644
index 0000000..3ab6a08
--- /dev/null
+++ b/internal/clients/connectors/connector.go
@@ -0,0 +1,17 @@
+package connectors
+
+import (
+ "context"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+)
+
+// Connector interface.
+type Connector interface {
+ // Start the connection.
+ Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{})
+ // Server hostname.
+ Server() string
+ // Handler for the connection.
+ Handler() handlers.Handler
+}
diff --git a/internal/clients/connectors/serverconnection.go b/internal/clients/connectors/serverconnection.go
new file mode 100644
index 0000000..aeb2a41
--- /dev/null
+++ b/internal/clients/connectors/serverconnection.go
@@ -0,0 +1,206 @@
+package connectors
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "time"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/ssh/client"
+
+ "golang.org/x/crypto/ssh"
+)
+
+// ServerConnection represents a connection to a single remote dtail server via
+// SSH protocol.
+type ServerConnection struct {
+ // The full server string as received from the server discovery (can be with port number)
+ server string
+ // Only the hostname or FQDN (without the port number)
+ hostname string
+ // Only the port number.
+ port int
+ config *ssh.ClientConfig
+ handler handlers.Handler
+ commands []string
+ hostKeyCallback client.HostKeyCallback
+ throttlingDone bool
+}
+
+// NewServerConnection returns a new DTail SSH server connection.
+func NewServerConnection(server string, userName string,
+ authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback,
+ handler handlers.Handler, commands []string) *ServerConnection {
+
+ dlog.Client.Debug(server, "Creating new connection", server, handler, commands)
+ c := ServerConnection{
+ hostKeyCallback: hostKeyCallback,
+ server: server,
+ handler: handler,
+ commands: commands,
+ config: &ssh.ClientConfig{
+ User: userName,
+ Auth: authMethods,
+ HostKeyCallback: hostKeyCallback.Wrap(),
+ Timeout: time.Second * 2,
+ },
+ }
+
+ c.initServerPort()
+ return &c
+}
+
+// Server returns the server hostname connected to.
+func (c *ServerConnection) Server() string { return c.server }
+
+// Handler returns the handler used for the connection.
+func (c *ServerConnection) Handler() handlers.Handler { return c.handler }
+
+// Attempt to parse the server port address from the provided server FQDN.
+func (c *ServerConnection) initServerPort() {
+ parts := strings.Split(c.server, ":")
+ if len(parts) == 1 {
+ c.hostname = c.server
+ c.port = config.Common.SSHPort
+ return
+ }
+
+ dlog.Client.Debug("Parsing port from hostname", parts)
+ port, err := strconv.Atoi(parts[1])
+ if err != nil {
+ dlog.Client.FatalPanic("Unable to parse client port", c.server, parts, err)
+ }
+ c.hostname = parts[0]
+ c.port = port
+}
+
+// Start the connection to the server.
+func (c *ServerConnection) Start(ctx context.Context, cancel context.CancelFunc,
+ throttleCh, statsCh chan struct{}) {
+
+ // Throttle how many connections can be established concurrently (based on ch length)
+ dlog.Client.Debug(c.server, "Throttling connection", len(throttleCh), cap(throttleCh))
+
+ select {
+ case throttleCh <- struct{}{}:
+ case <-ctx.Done():
+ dlog.Client.Debug(c.server, "Not establishing connection as context is done",
+ len(throttleCh), cap(throttleCh))
+ return
+ }
+
+ dlog.Client.Debug(c.server, "Throttling says that the connection can be established",
+ len(throttleCh), cap(throttleCh))
+
+ go func() {
+ defer func() {
+ if !c.throttlingDone {
+ dlog.Client.Debug(c.server, "Unthrottling connection (1)",
+ len(throttleCh), cap(throttleCh))
+ c.throttlingDone = true
+ <-throttleCh
+ }
+ cancel()
+ }()
+
+ if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil {
+ dlog.Client.Warn(c.server, err)
+ if c.hostKeyCallback.Untrusted(c.server) {
+ dlog.Client.Debug(c.server, "Not trusting host")
+ }
+ }
+ }()
+
+ <-ctx.Done()
+}
+
+// Dail into a new SSH connection. Close connection in case of an error.
+func (c *ServerConnection) dial(ctx context.Context, cancel context.CancelFunc,
+ throttleCh, statsCh chan struct{}) error {
+
+ dlog.Client.Debug(c.server, "Incrementing connection stats")
+ statsCh <- struct{}{}
+ defer func() {
+ dlog.Client.Debug(c.server, "Decrementing connection stats")
+ <-statsCh
+ }()
+
+ address := fmt.Sprintf("%s:%d", c.hostname, c.port)
+ dlog.Client.Debug(c.server, "Dialing into the connection", address)
+
+ client, err := ssh.Dial("tcp", address, c.config)
+ if err != nil {
+ return err
+ }
+ defer client.Close()
+
+ return c.session(ctx, cancel, client, throttleCh)
+}
+
+// Create the SSH session. Close the session in case of an error.
+func (c *ServerConnection) session(ctx context.Context, cancel context.CancelFunc,
+ client *ssh.Client, throttleCh chan struct{}) error {
+
+ dlog.Client.Debug(c.server, "Creating SSH session")
+ session, err := client.NewSession()
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+ return c.handle(ctx, cancel, session, throttleCh)
+}
+
+func (c *ServerConnection) handle(ctx context.Context, cancel context.CancelFunc,
+ session *ssh.Session, throttleCh chan struct{}) error {
+
+ dlog.Client.Debug(c.server, "Creating handler for SSH session")
+ stdinPipe, err := session.StdinPipe()
+ if err != nil {
+ return err
+ }
+ stdoutPipe, err := session.StdoutPipe()
+ if err != nil {
+ return err
+ }
+ if err := session.Shell(); err != nil {
+ return err
+ }
+
+ go func() {
+ io.Copy(stdinPipe, c.handler)
+ cancel()
+ }()
+ go func() {
+ io.Copy(c.handler, stdoutPipe)
+ cancel()
+ }()
+ go func() {
+ select {
+ case <-c.handler.Done():
+ case <-ctx.Done():
+ }
+ cancel()
+ }()
+
+ // Send all commands to client.
+ for _, command := range c.commands {
+ dlog.Client.Debug(command)
+ c.handler.SendMessage(command)
+ }
+
+ if !c.throttlingDone {
+ dlog.Client.Debug(c.server, "Unthrottling connection (2)",
+ len(throttleCh), cap(throttleCh))
+ c.throttlingDone = true
+ <-throttleCh
+ }
+
+ <-ctx.Done()
+ c.handler.Shutdown()
+ return nil
+}
diff --git a/internal/clients/connectors/serverless.go b/internal/clients/connectors/serverless.go
new file mode 100644
index 0000000..2ff490a
--- /dev/null
+++ b/internal/clients/connectors/serverless.go
@@ -0,0 +1,116 @@
+package connectors
+
+import (
+ "context"
+ "io"
+
+ "github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ serverHandlers "github.com/mimecast/dtail/internal/server/handlers"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+// Serverless creates a server object directly without TCP.
+type Serverless struct {
+ handler handlers.Handler
+ commands []string
+ userName string
+}
+
+// NewServerless starts a new serverless session.
+func NewServerless(userName string, handler handlers.Handler,
+ commands []string) *Serverless {
+
+ dlog.Client.Debug("Creating new serverless connector", handler, commands)
+ return &Serverless{
+ userName: userName,
+ handler: handler,
+ commands: commands,
+ }
+}
+
+// Server returns serverless server indicator.
+func (s *Serverless) Server() string {
+ return "local(serverless)"
+}
+
+// Handler returns the handler used for the serverless connection.
+func (s *Serverless) Handler() handlers.Handler {
+ return s.handler
+}
+
+// Start the serverless connection.
+func (s *Serverless) Start(ctx context.Context, cancel context.CancelFunc,
+ throttleCh, statsCh chan struct{}) {
+
+ dlog.Client.Debug("Starting serverless connector")
+ go func() {
+ defer cancel()
+
+ if err := s.handle(ctx, cancel); err != nil {
+ dlog.Client.Warn(err)
+ }
+ }()
+ <-ctx.Done()
+}
+
+func (s *Serverless) handle(ctx context.Context, cancel context.CancelFunc) error {
+ dlog.Client.Debug("Creating server handler for a serverless session")
+
+ user, err := user.New(s.userName, s.Server())
+ if err != nil {
+ return err
+ }
+
+ var serverHandler serverHandlers.Handler
+ switch s.userName {
+ case config.HealthUser:
+ dlog.Client.Debug("Creating serverless health handler")
+ serverHandler = serverHandlers.NewHealthHandler(user)
+ default:
+ dlog.Client.Debug("Creating serverless server handler")
+ serverHandler = serverHandlers.NewServerHandler(
+ user,
+ make(chan struct{}, config.Server.MaxConcurrentCats),
+ make(chan struct{}, config.Server.MaxConcurrentTails),
+ )
+ }
+
+ terminate := func() {
+ dlog.Client.Debug("Terminating serverless connection")
+ serverHandler.Shutdown()
+ cancel()
+ }
+
+ go func() {
+ io.Copy(serverHandler, s.handler)
+ dlog.Client.Trace("io.Copy(serverHandler, s.handler) => done")
+ terminate()
+ }()
+ go func() {
+ io.Copy(s.handler, serverHandler)
+ dlog.Client.Trace("io.Copy(s.handler, serverHandler) => done")
+ terminate()
+ }()
+ go func() {
+ select {
+ case <-s.handler.Done():
+ dlog.Client.Trace("<-s.handler.Done()")
+ case <-ctx.Done():
+ dlog.Client.Trace("<-ctx.Done()")
+ }
+ terminate()
+ }()
+
+ // Send all commands to client.
+ for _, command := range s.commands {
+ dlog.Client.Debug("Sending command to serverless server", command)
+ s.handler.SendMessage(command)
+ }
+
+ <-ctx.Done()
+ dlog.Client.Trace("s.handler.Shutdown()")
+ s.handler.Shutdown()
+ return nil
+}
diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go
index 567193a..7521c67 100644
--- a/internal/clients/grepclient.go
+++ b/internal/clients/grepclient.go
@@ -7,16 +7,19 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
)
-// GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed.
+// GrepClient searches a remote file for all lines matching a regular
+// expression. Only the matching lines are displayed.
type GrepClient struct {
baseClient
}
// NewGrepClient creates a new grep client.
-func NewGrepClient(args Args) (*GrepClient, error) {
+func NewGrepClient(args config.Args) (*GrepClient, error) {
if args.RegexStr == "" {
return nil, errors.New("No regex specified, use '-regex' flag")
}
@@ -32,7 +35,6 @@ func NewGrepClient(args Args) (*GrepClient, error) {
c.init()
c.makeConnections(c)
-
return &c, nil
}
@@ -40,11 +42,14 @@ func (c GrepClient) makeHandler(server string) handlers.Handler {
return handlers.NewClientHandler(server)
}
-func (c GrepClient) makeCommands(options map[string]string) (commands []string) {
- optionsStr := c.commandOptionsToString(options)
+func (c GrepClient) makeCommands() (commands []string) {
+ regex, err := c.Regex.Serialize()
+ if err != nil {
+ dlog.Client.FatalPanic(err)
+ }
for _, file := range strings.Split(c.What, ",") {
- commands = append(commands, fmt.Sprintf("%s:%s %s %s", c.Mode.String(), optionsStr, file, c.Regex.Serialize()))
+ commands = append(commands, fmt.Sprintf("%s:%s %s %s",
+ c.Mode.String(), c.Args.SerializeOptions(), file, regex))
}
-
return
}
diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go
index 602a7ac..b520c25 100644
--- a/internal/clients/handlers/basehandler.go
+++ b/internal/clients/handlers/basehandler.go
@@ -1,6 +1,7 @@
package handlers
import (
+ "bytes"
"encoding/base64"
"fmt"
"io"
@@ -8,8 +9,8 @@ import (
"time"
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/version"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/protocol"
)
type baseHandler struct {
@@ -17,10 +18,20 @@ type baseHandler struct {
server string
shellStarted bool
commands chan string
- receiveBuf []byte
+ receiveBuf bytes.Buffer
status int
}
+func (h *baseHandler) String() string {
+ return fmt.Sprintf("baseHandler(%s,server:%s,shellStarted:%v,status:%d)@%p",
+ h.done,
+ h.server,
+ h.shellStarted,
+ h.status,
+ h,
+ )
+}
+
func (h *baseHandler) Server() string {
return h.server
}
@@ -29,21 +40,13 @@ func (h *baseHandler) Status() int {
return h.status
}
-func (h *baseHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-func (h *baseHandler) Shutdown() {
- h.done.Shutdown()
-}
-
// SendMessage to the server.
func (h *baseHandler) SendMessage(command string) error {
encoded := base64.StdEncoding.EncodeToString([]byte(command))
- logger.Debug("Sending command", h.server, command, encoded)
+ dlog.Client.Debug("Sending command", h.server, command, encoded)
select {
- case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded):
+ case h.commands <- fmt.Sprintf("protocol %s base64 %v;", protocol.ProtocolCompat, encoded):
case <-time.After(time.Second * 5):
return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded)
case <-h.Done():
@@ -56,13 +59,20 @@ func (h *baseHandler) SendMessage(command string) error {
// Read data from the dtail server via Writer interface.
func (h *baseHandler) Write(p []byte) (n int, err error) {
for _, b := range p {
- h.receiveBuf = append(h.receiveBuf, b)
- if b == '\n' {
- if len(h.receiveBuf) == 0 {
+ switch b {
+ /*
+ // NEXT: Next DTail version make it so that '\n' gets ignored. For now
+ // leave it for compatibility with older DTail server + ability to display
+ // the protocol mismatch warn message.
+ case '\n' {
continue
- }
- message := string(h.receiveBuf)
- h.handleMessageType(message)
+ */
+ case '\n', protocol.MessageDelimiter:
+ message := h.receiveBuf.String()
+ h.handleMessage(message)
+ h.receiveBuf.Reset()
+ default:
+ h.receiveBuf.WriteByte(b)
}
}
@@ -77,31 +87,32 @@ func (h *baseHandler) Read(p []byte) (n int, err error) {
case <-h.Done():
return 0, io.EOF
}
-
return
}
-// Handle various message types.
-func (h *baseHandler) handleMessageType(message string) {
- if len(h.receiveBuf) == 0 {
- return
- }
-
- // Hidden server commands starti with a dot "."
- if h.receiveBuf[0] == '.' {
+func (h *baseHandler) handleMessage(message string) {
+ if len(message) > 0 && message[0] == '.' {
h.handleHiddenMessage(message)
- h.receiveBuf = h.receiveBuf[:0]
return
}
- logger.Raw(message)
- h.receiveBuf = h.receiveBuf[:0]
+ dlog.Client.Raw(message)
}
// Handle messages received from server which are not meant to be displayed
// to the end user.
func (h *baseHandler) handleHiddenMessage(message string) {
- if strings.HasPrefix(message, ".syn close connection") {
- h.SendMessage(".ack close connection")
+ switch {
+ case strings.HasPrefix(message, ".syn close connection"):
+ go h.SendMessage(".ack close connection")
+ h.Shutdown()
}
}
+
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go
index 2bcb038..27ac85e 100644
--- a/internal/clients/handlers/clienthandler.go
+++ b/internal/clients/handlers/clienthandler.go
@@ -2,7 +2,7 @@ package handlers
import (
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// ClientHandler is the basic client handler interface.
@@ -12,7 +12,7 @@ type ClientHandler struct {
// NewClientHandler creates a new client handler.
func NewClientHandler(server string) *ClientHandler {
- logger.Debug(server, "Creating new client handler")
+ dlog.Client.Debug(server, "Creating new client handler")
return &ClientHandler{
baseHandler{
diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go
index 0440706..47b594e 100644
--- a/internal/clients/handlers/healthhandler.go
+++ b/internal/clients/handlers/healthhandler.go
@@ -1,88 +1,56 @@
package handlers
import (
- "errors"
- "fmt"
- "time"
+ "strings"
"github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/protocol"
)
-// HealthHandler implements the handler required for health checks.
+// HealthHandler is the handler used on the client side for running mapreduce
+// aggregations.
type HealthHandler struct {
- done *internal.Done
- // Buffer of incoming data from server.
- receiveBuf []byte
- // To send commands to the server.
- commands chan string
- // To receive messages from the server.
- receive chan<- string
- // The remote server address
- server string
- // The return status.
- status int
-}
-
-// NewHealthHandler returns a new health check handler.
-func NewHealthHandler(server string, receive chan<- string) *HealthHandler {
- h := HealthHandler{
- server: server,
- receive: receive,
- commands: make(chan string),
- status: -1,
- done: internal.NewDone(),
- }
-
- return &h
-}
-
-// Server returns the remote server name.
-func (h *HealthHandler) Server() string {
- return h.server
-}
-
-// Status of the handler.
-func (h *HealthHandler) Status() int {
- return h.status
-}
-
-// Done returns done channel of the handler.
-func (h *HealthHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Shutdown the handler.
-func (h *HealthHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// SendMessage sends a DTail command to the server.
-func (h *HealthHandler) SendMessage(command string) error {
- select {
- case h.commands <- fmt.Sprintf("%s;", command):
- case <-time.NewTimer(time.Second * 10).C:
- return errors.New("Timed out sending command " + command)
- case <-h.Done():
+ baseHandler
+}
+
+// NewHealthHandler returns a new health client handler.
+func NewHealthHandler(server string) *HealthHandler {
+ dlog.Client.Debug(server, "Creating new health handler")
+ return &HealthHandler{
+ baseHandler: baseHandler{
+ server: server,
+ shellStarted: false,
+ commands: make(chan string),
+ status: 2, // Assume CRITICAL status by default.
+ done: internal.NewDone(),
+ },
}
-
- return nil
}
-// Server writes byte stream to client.
+// Read data from the dtail server via Writer interface.
func (h *HealthHandler) Write(p []byte) (n int, err error) {
for _, b := range p {
- h.receiveBuf = append(h.receiveBuf, b)
- if b == '\n' {
- h.receive <- string(h.receiveBuf)
- h.receiveBuf = h.receiveBuf[:0]
+ switch b {
+ case '\n', protocol.MessageDelimiter:
+ message := h.baseHandler.receiveBuf.String()
+ h.handleMessage(message)
+ h.baseHandler.receiveBuf.Reset()
+ default:
+ h.baseHandler.receiveBuf.WriteByte(b)
}
}
-
return len(p), nil
}
-// Server reads byte stream from client.
-func (h *HealthHandler) Read(p []byte) (n int, err error) {
- n = copy(p, []byte(<-h.commands))
- return
+func (h *HealthHandler) handleMessage(message string) {
+ if len(message) > 0 && message[0] == '.' {
+ h.baseHandler.handleHiddenMessage(message)
+ return
+ }
+ s := strings.Split(message, protocol.FieldDelimiter)
+ message = s[len(s)-1]
+ if message == "OK" {
+ h.baseHandler.status = 0
+ }
}
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index fb71c8f..8718b35 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -4,21 +4,24 @@ import (
"strings"
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/client"
+ "github.com/mimecast/dtail/internal/protocol"
)
-// MaprHandler is the handler used on the client side for running mapreduce aggregations.
+// MaprHandler is the handler used on the client side for running mapreduce
+// aggregations.
type MaprHandler struct {
baseHandler
aggregate *client.Aggregate
query *mapr.Query
- count uint64
}
// NewMaprHandler returns a new mapreduce client handler.
-func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler {
+func NewMaprHandler(server string, query *mapr.Query,
+ globalGroup *mapr.GlobalGroupSet) *MaprHandler {
+
return &MaprHandler{
baseHandler: baseHandler{
server: server,
@@ -35,34 +38,35 @@ func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGr
// Read data from the dtail server via Writer interface.
func (h *MaprHandler) Write(p []byte) (n int, err error) {
for _, b := range p {
- h.baseHandler.receiveBuf = append(h.baseHandler.receiveBuf, b)
- if b == '\n' {
- if len(h.baseHandler.receiveBuf) == 0 {
- continue
- }
- message := string(h.baseHandler.receiveBuf)
-
- if h.baseHandler.receiveBuf[0] == 'A' {
- h.handleAggregateMessage(strings.TrimSpace(message))
- h.baseHandler.receiveBuf = h.baseHandler.receiveBuf[:0]
- continue
+ switch b {
+ case '\n':
+ continue
+ case protocol.MessageDelimiter:
+ message := h.baseHandler.receiveBuf.String()
+ dlog.Client.Debug(message)
+ if message[0] == 'A' {
+ h.handleAggregateMessage(message)
+ } else {
+ h.baseHandler.handleMessage(message)
}
- h.baseHandler.handleMessageType(message)
+ h.baseHandler.receiveBuf.Reset()
+ default:
+ h.baseHandler.receiveBuf.WriteByte(b)
}
}
return len(p), nil
}
-// Handle a message received from server including mapr aggregation
-// related data.
+// Handle a message received from server including mapr aggregation related data.
func (h *MaprHandler) handleAggregateMessage(message string) {
- h.count++
- parts := strings.Split(message, "âž”")
-
- // Index 0 contains 'AGGREGATE', 1 contains server host.
- // Aggregation data begins from index 2.
- logger.Debug("Received aggregate data", h.server, h.count, parts)
- h.aggregate.Aggregate(parts[2:])
- logger.Debug("Aggregated aggregate data", h.server, h.count)
+ parts := strings.SplitN(message, protocol.FieldDelimiter, 3)
+ if len(parts) != 3 {
+ dlog.Client.Error("Unable to aggregate data", h.server, message, parts,
+ len(parts), "expected 3 parts")
+ return
+ }
+ if err := h.aggregate.Aggregate(parts[2]); err != nil {
+ dlog.Client.Error("Unable to aggregate data", h.server, message, err)
+ }
}
diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go
index e93f6be..1a02827 100644
--- a/internal/clients/healthclient.go
+++ b/internal/clients/healthclient.go
@@ -4,93 +4,75 @@ import (
"context"
"fmt"
"runtime"
- "strings"
- "time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/clients/remote"
"github.com/mimecast/dtail/internal/config"
"github.com/mimecast/dtail/internal/omode"
gossh "golang.org/x/crypto/ssh"
)
-// HealthClient is used for health checking (e.g. via Nagios)
+// HealthClient is used to perform a basic server health check.
type HealthClient struct {
- // Client operating mode
- mode omode.Mode
- // The remote server address
- server string
- // SSH user name
- userName string
- // SSH auth methods to use to connect to the remote servers.
- sshAuthMethods []gossh.AuthMethod
+ baseClient
}
-// NewHealthClient returns a new healh client.
-func NewHealthClient(mode omode.Mode) (*HealthClient, error) {
+// NewHealthClient returns a new health client.
+func NewHealthClient(args config.Args) (*HealthClient, error) {
+ args.Mode = omode.HealthClient
+ args.UserName = config.HealthUser
c := HealthClient{
- mode: mode,
- server: fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort),
- userName: config.ControlUser,
+ baseClient: baseClient{
+ Args: args,
+ throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
+ retry: false,
+ },
}
- c.initSSHAuthMethods()
+ c.init()
+ c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.HealthUser))
+ c.makeConnections(c)
return &c, nil
}
-// Start the health client.
-func (c *HealthClient) Start(ctx context.Context) (status int) {
- receive := make(chan string)
-
- throttleCh := make(chan struct{}, runtime.NumCPU())
- statsCh := make(chan struct{}, 1)
-
- conn := remote.NewOneOffConnection(c.server, c.userName, c.sshAuthMethods)
- conn.Handler = handlers.NewHealthHandler(c.server, receive)
- conn.Commands = []string{c.mode.String()}
-
- connCtx, cancel := context.WithCancel(ctx)
- go conn.Start(connCtx, cancel, throttleCh, statsCh)
+func (c HealthClient) makeHandler(server string) handlers.Handler {
+ return handlers.NewHealthHandler(server)
+}
- for {
- select {
- case data := <-receive:
- // Parse recieved data.
- s := strings.Split(data, "|")
- message := s[len(s)-1]
- if strings.HasPrefix(message, "done;") {
- return
- }
+func (c HealthClient) makeCommands() (commands []string) {
+ commands = append(commands, "health")
+ return
+}
- // Set severity.
- s = strings.Split(message, ":")
- switch s[0] {
- case "OK":
- case "WARNING":
- if status < 1 {
- status = 1
- }
- case "CRITICAL":
- status = 2
- case "UNKNOWN":
- status = 3
- default:
- fmt.Printf("CRITICAL: Unexpected server response: '%s'\n", message)
- status = 2
- return
- }
- fmt.Print(message)
+// Start the health client.
+func (c *HealthClient) Start(ctx context.Context, statsCh <-chan string) int {
+ status := c.baseClient.Start(ctx, statsCh)
- case <-time.After(time.Second * 2):
- status = 2
- fmt.Println("CRITICAL: Could not communicate with DTail server")
- return
+ switch status {
+ case 0:
+ if c.Serverless {
+ fmt.Printf("WARNING: All seems fine but the check only run in serverless mode" +
+ ", please specify a remote server via --server hostname:port\n")
+ return 1
+ }
+ fmt.Printf("OK: All fine at %s :-)\n", c.ServersStr)
+ case 2:
+ if c.Serverless {
+ fmt.Printf("CRITICAL: DTail server not operating properly (using " +
+ "serverless connction)!\n")
+ return 2
}
+ fmt.Printf("CRITICAL: DTail server not operating properly at %s!\n",
+ c.ServersStr)
+ default:
+ if c.Serverless {
+ fmt.Printf("UNKNOWN: Received unknown status code %d (using serverless "+
+ "connection)\n", status)
+ return status
+ }
+ fmt.Printf("UNKNOWN: Received unknown status code %d from %s!\n",
+ status, c.ServersStr)
}
-}
-// Initialize SSH auth methods.
-func (c *HealthClient) initSSHAuthMethods() {
- c.sshAuthMethods = append(c.sshAuthMethods, gossh.Password(config.ControlUser))
+ return status
}
diff --git a/internal/clients/maker.go b/internal/clients/maker.go
index a1d6864..d5ffd8b 100644
--- a/internal/clients/maker.go
+++ b/internal/clients/maker.go
@@ -9,5 +9,5 @@ import (
// and send different commands to the DTail server.
type maker interface {
makeHandler(server string) handlers.Handler
- makeCommands(options map[string]string) (commands []string)
+ makeCommands() (commands []string)
}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index feb7e47..246946f 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -9,7 +9,9 @@ import (
"time"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/omode"
)
@@ -29,25 +31,25 @@ const (
// MaprClient is used for running mapreduce aggregations on remote files.
type MaprClient struct {
baseClient
- // Query string for mapr aggregations
- queryStr string
// Global group set for merged mapr aggregation results
globalGroup *mapr.GlobalGroupSet
// The query object (constructed from queryStr)
query *mapr.Query
// Additative result or new result every interval run?
cumulative bool
+ // The last result string received
+ lastResult string
}
// NewMaprClient returns a new mapreduce client.
-func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (*MaprClient, error) {
- if queryStr == "" {
+func NewMaprClient(args config.Args, maprClientMode MaprClientMode) (*MaprClient, error) {
+ if args.QueryStr == "" {
return nil, errors.New("No mapreduce query specified, use '-query' flag")
}
- query, err := mapr.NewQuery(queryStr)
+ query, err := mapr.NewQuery(args.QueryStr)
if err != nil {
- logger.FatalExit(queryStr, "Can't parse mapr query", err)
+ dlog.Client.FatalPanic(args.QueryStr, "Can't parse mapr query", err)
}
// Don't retry connection if in tail mode and no outfile specified.
@@ -64,7 +66,7 @@ func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (*
cumulative = args.Mode == omode.MapClient || query.HasOutfile()
}
- logger.Debug("Cumulative mapreduce mode?", cumulative)
+ dlog.Client.Debug("Cumulative mapreduce mode?", cumulative)
c := MaprClient{
baseClient: baseClient{
@@ -73,7 +75,6 @@ func NewMaprClient(args Args, queryStr string, maprClientMode MaprClientMode) (*
retry: retry,
},
query: query,
- queryStr: queryStr,
cumulative: cumulative,
}
@@ -99,46 +100,51 @@ func (c *MaprClient) Start(ctx context.Context, statsCh <-chan string) (status i
status = c.baseClient.Start(ctx, statsCh)
if c.cumulative {
- logger.Debug("Received final mapreduce result")
+ dlog.Client.Debug("Received final mapreduce result")
c.reportResults()
}
return
}
+// NEXT: Make this a callback function rather trying to use polymorphism to call
+// this. This applies to all clients. It will make the code easier to read.
func (c MaprClient) makeHandler(server string) handlers.Handler {
return handlers.NewMaprHandler(server, c.query, c.globalGroup)
}
-func (c MaprClient) makeCommands(options map[string]string) (commands []string) {
+func (c MaprClient) makeCommands() (commands []string) {
commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery))
-
modeStr := "cat"
if c.Mode == omode.TailClient {
modeStr = "tail"
}
- optionsStr := c.commandOptionsToString(options)
for _, file := range strings.Split(c.What, ",") {
+ regex, err := c.Regex.Serialize()
+ if err != nil {
+ dlog.Client.FatalPanic(err)
+ }
if c.Timeout > 0 {
- commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", c.Timeout, modeStr, file, c.Regex.Serialize()))
+ commands = append(commands, fmt.Sprintf("timeout %d %s %s %s", c.Timeout,
+ modeStr, file, regex))
continue
}
- commands = append(commands, fmt.Sprintf("%s:%s %s %s", modeStr, optionsStr, file, c.Regex.Serialize()))
+ commands = append(commands, fmt.Sprintf("%s:%s %s %s",
+ modeStr, c.Args.SerializeOptions(), file, regex))
}
-
return
}
func (c *MaprClient) periodicReportResults(ctx context.Context) {
rampUpSleep := c.query.Interval / 2
- logger.Debug("Ramp up sleeping before processing mapreduce results", rampUpSleep)
+ dlog.Client.Debug("Ramp up sleeping before processing mapreduce results", rampUpSleep)
time.Sleep(rampUpSleep)
for {
select {
case <-time.After(c.query.Interval):
- logger.Debug("Gathering interim mapreduce result")
+ dlog.Client.Debug("Gathering interim mapreduce result")
c.reportResults()
case <-ctx.Done():
return
@@ -151,42 +157,65 @@ func (c *MaprClient) reportResults() {
c.writeResultsToOutfile()
return
}
-
c.printResults()
}
func (c *MaprClient) printResults() {
var result string
var err error
- var numLines int
+ var numRows int
+ rowsLimit := -1
+
+ if c.query.Limit == -1 {
+ // Limit output to 10 rows when the result is printed to stdout.
+ // This can be overriden with the limit clause though.
+ rowsLimit = 10
+ }
if c.cumulative {
- result, numLines, err = c.globalGroup.Result(c.query)
+ result, numRows, err = c.globalGroup.Result(c.query, rowsLimit)
} else {
- result, numLines, err = c.globalGroup.SwapOut().Result(c.query)
+ result, numRows, err = c.globalGroup.SwapOut().Result(c.query, rowsLimit)
}
if err != nil {
- logger.FatalExit(err)
+ dlog.Client.FatalPanic(err)
}
- if numLines == 0 {
- logger.Warn("Empty result set this time...")
+ if result == c.lastResult {
+ dlog.Client.Debug("Result hasn't changed compared to last time...")
return
}
+ c.lastResult = result
- logger.Raw(fmt.Sprintf("%s\n", c.query.RawQuery))
- logger.Raw(result)
+ if numRows == 0 {
+ dlog.Client.Debug("Empty result set this time...")
+ return
+ }
+
+ rawQuery := c.query.RawQuery
+ if config.Client.TermColorsEnable {
+ rawQuery = color.PaintStrWithAttr(rawQuery,
+ config.Client.TermColors.MaprTable.RawQueryFg,
+ config.Client.TermColors.MaprTable.RawQueryBg,
+ config.Client.TermColors.MaprTable.RawQueryAttr)
+ }
+ dlog.Client.Raw(rawQuery)
+
+ if rowsLimit > 0 && numRows > rowsLimit {
+ dlog.Client.Warn(fmt.Sprintf("Got %d results but limited terminal output "+
+ "to %d rows! Use 'limit' clause to override!", numRows, rowsLimit))
+ }
+ dlog.Client.Raw(result)
}
func (c *MaprClient) writeResultsToOutfile() {
if c.cumulative {
if err := c.globalGroup.WriteResult(c.query); err != nil {
- logger.FatalExit(err)
+ dlog.Client.FatalPanic(err)
}
return
}
-
if err := c.globalGroup.SwapOut().WriteResult(c.query); err != nil {
- logger.FatalExit(err)
+ dlog.Client.FatalPanic(err)
}
}
diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go
deleted file mode 100644
index b29ffed..0000000
--- a/internal/clients/remote/connection.go
+++ /dev/null
@@ -1,212 +0,0 @@
-package remote
-
-import (
- "context"
- "fmt"
- "io"
- "strconv"
- "strings"
- "time"
-
- "github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/ssh/client"
-
- "golang.org/x/crypto/ssh"
-)
-
-// Connection represents a client connection connection to a single server.
-type Connection struct {
- // The remote server's hostname connected to.
- Server string
- // The remote server's port connected to.
- port int
- // The SSH client configuration used.
- config *ssh.ClientConfig
- // The SSH client handler to use.
- Handler handlers.Handler
- // DTail commands sent from client to server. When client loses
- // connection to the server it re-connects automatically and sends the
- // same commands again.
- Commands []string
- // Is it a persistent connection or a one-off?
- isOneOff bool
- // To deal with SSH server host keys
- hostKeyCallback client.HostKeyCallback
- // To determine if connection throttling has finished or not
- throttlingDone bool
-}
-
-// NewConnection returns a new connection.
-func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, hostKeyCallback client.HostKeyCallback) *Connection {
- logger.Debug(server, "Creating new connection")
-
- c := Connection{
- hostKeyCallback: hostKeyCallback,
- config: &ssh.ClientConfig{
- User: userName,
- Auth: authMethods,
- HostKeyCallback: hostKeyCallback.Wrap(),
- Timeout: time.Second * 3,
- },
- }
-
- c.initServerPort(server)
-
- return &c
-}
-
-// NewOneOffConnection creates new one-off connection (only for sending a series of commands and then quit).
-func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthMethod) *Connection {
- c := Connection{
- config: &ssh.ClientConfig{
- User: userName,
- Auth: authMethods,
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
- },
- isOneOff: true,
- }
-
- c.initServerPort(server)
-
- return &c
-}
-
-// Attempt to parse the server port address from the provided server FQDN.
-func (c *Connection) initServerPort(server string) {
- c.Server = server
- c.port = config.Common.SSHPort
- parts := strings.Split(server, ":")
-
- if len(parts) == 2 {
- logger.Debug("Parsing port from hostname", parts)
- port, err := strconv.Atoi(parts[1])
- if err != nil {
- logger.FatalExit("Unable to parse client port", server, parts, err)
- }
- c.Server = parts[0]
- c.port = port
- }
-}
-
-// Start the server connection. Build up SSH session and send some DTail commands.
-func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) {
- // Throttle how many connections can be established concurrently (based on ch length)
- logger.Debug(c.Server, "Throttling connection", len(throttleCh), cap(throttleCh))
-
- select {
- case throttleCh <- struct{}{}:
- case <-ctx.Done():
- logger.Debug(c.Server, "Not establishing connection as context is done", len(throttleCh), cap(throttleCh))
- return
- }
-
- logger.Debug(c.Server, "Throttling says that the connection can be established", len(throttleCh), cap(throttleCh))
-
- go func() {
- defer func() {
- if !c.throttlingDone {
- logger.Debug(c.Server, "Unthrottling connection (1)", len(throttleCh), cap(throttleCh))
- c.throttlingDone = true
- <-throttleCh
- }
- cancel()
- }()
-
- if err := c.dial(ctx, cancel, throttleCh, statsCh); err != nil {
- logger.Warn(c.Server, c.port, err)
- if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) {
- logger.Debug(c.Server, "Not trusting host")
- }
- }
- }()
-
- <-ctx.Done()
-}
-
-// Dail into a new SSH connection. Close connection in case of an error.
-func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) error {
- logger.Debug(c.Server, "Incrementing connection stats")
- statsCh <- struct{}{}
- defer func() {
- logger.Debug(c.Server, "Decrementing connection stats")
- <-statsCh
- }()
-
- logger.Debug(c.Server, "Dialing into the connection")
- address := fmt.Sprintf("%s:%d", c.Server, c.port)
-
- client, err := ssh.Dial("tcp", address, c.config)
- if err != nil {
- return err
- }
- defer client.Close()
-
- return c.session(ctx, cancel, client, throttleCh)
-}
-
-// Create the SSH session. Close the session in case of an error.
-func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client, throttleCh chan struct{}) error {
- logger.Debug(c.Server, "session")
-
- session, err := client.NewSession()
- if err != nil {
- return err
- }
- defer session.Close()
-
- return c.handle(ctx, cancel, session, throttleCh)
-}
-
-func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session, throttleCh chan struct{}) error {
- logger.Debug(c.Server, "handle")
-
- stdinPipe, err := session.StdinPipe()
- if err != nil {
- return err
- }
-
- stdoutPipe, err := session.StdoutPipe()
- if err != nil {
- return err
- }
-
- if err := session.Shell(); err != nil {
- return err
- }
-
- go func() {
- io.Copy(stdinPipe, c.Handler)
- cancel()
- }()
-
- go func() {
- io.Copy(c.Handler, stdoutPipe)
- cancel()
- }()
-
- go func() {
- select {
- case <-c.Handler.Done():
- case <-ctx.Done():
- }
- cancel()
- }()
-
- // Send all commands to client.
- for _, command := range c.Commands {
- logger.Debug(command)
- c.Handler.SendMessage(command)
- }
-
- if !c.throttlingDone {
- logger.Debug(c.Server, "Unthrottling connection (2)", len(throttleCh), cap(throttleCh))
- c.throttlingDone = true
- <-throttleCh
- }
-
- <-ctx.Done()
- c.Handler.Shutdown()
- return nil
-}
diff --git a/internal/clients/stats.go b/internal/clients/stats.go
index d8163d4..1315aea 100644
--- a/internal/clients/stats.go
+++ b/internal/clients/stats.go
@@ -8,14 +8,16 @@ import (
"sync"
"time"
+ "github.com/mimecast/dtail/internal/color"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/protocol"
)
// Used to collect and display various client stats.
type stats struct {
// Total amount servers to connect to.
- connectionsTotal int
+ servers int
// To keep track of what connected and disconnected
connectionsEstCh chan struct{}
// Amount of servers connections are established.
@@ -24,19 +26,20 @@ type stats struct {
mutex sync.Mutex
}
-func newTailStats(connectionsTotal int) *stats {
+func newTailStats(servers int) *stats {
return &stats{
- connectionsTotal: connectionsTotal,
- connectionsEstCh: make(chan struct{}, connectionsTotal),
+ servers: servers,
+ connectionsEstCh: make(chan struct{}, servers),
connected: 0,
}
}
// Start starts printing client connection stats every time a signal is recieved or
// connection count has changed.
-func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh <-chan string, quiet bool) {
- var connectedLast int
+func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{},
+ statsCh <-chan string, quiet bool) {
+ var connectedLast int
for {
var force bool
var messages []string
@@ -54,18 +57,18 @@ func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh <
throttle := len(throttleCh)
newConnections := connected - connectedLast
-
if (connected == connectedLast || quiet) && !force {
continue
}
- stats := s.statsLine(connected, newConnections, throttle)
switch force {
case true:
+ stats := s.statsLine(connected, newConnections, throttle)
messages = append(messages, fmt.Sprintf("Connection stats: %s", stats))
s.printStatsDueInterrupt(messages)
default:
- logger.Info(stats)
+ data := s.statsData(connected, newConnections, throttle)
+ dlog.Client.Mapreduce("STATS", data)
}
connectedLast = connected
@@ -76,30 +79,58 @@ func (s *stats) Start(ctx context.Context, throttleCh <-chan struct{}, statsCh <
}
func (s *stats) printStatsDueInterrupt(messages []string) {
- logger.Pause()
- for _, message := range messages {
+ dlog.Client.Pause()
+ for i, message := range messages {
+ if i > 0 && config.Client.TermColorsEnable {
+ fmt.Println(color.PaintStrWithAttr(message,
+ config.Client.TermColors.Client.ClientFg,
+ config.Client.TermColors.Client.ClientBg,
+ config.Client.TermColors.Client.ClientAttr,
+ ))
+ continue
+ }
fmt.Println(fmt.Sprintf(" %s", message))
}
time.Sleep(time.Second * time.Duration(config.InterruptTimeoutS))
- logger.Resume()
+ dlog.Client.Resume()
}
-func (s *stats) statsLine(connected, newConnections int, throttle int) string {
- percConnected := percentOf(float64(s.connectionsTotal), float64(connected))
+func (s *stats) statsData(connected, newConnections int,
+ throttle int) map[string]interface{} {
+
+ percConnected := percentOf(float64(s.servers), float64(connected))
- var stats []string
- stats = append(stats, fmt.Sprintf("connected=%d/%d(%d%%)", connected, s.connectionsTotal, int(percConnected)))
- stats = append(stats, fmt.Sprintf("new=%d", newConnections))
- stats = append(stats, fmt.Sprintf("throttle=%d", throttle))
- stats = append(stats, fmt.Sprintf("cpus/goroutines=%d/%d", runtime.NumCPU(), runtime.NumGoroutine()))
+ data := make(map[string]interface{})
+ data["connected"] = connected
+ data["servers"] = s.servers
+ data["connected%"] = int(percConnected)
+ data["new"] = newConnections
+ data["throttle"] = throttle
+ data["goroutines"] = runtime.NumGoroutine()
+ data["cgocalls"] = runtime.NumCgoCall()
+ data["cpu"] = runtime.NumCPU()
- return strings.Join(stats, "|")
+ return data
+}
+
+func (s *stats) statsLine(connected, newConnections int, throttle int) string {
+ sb := strings.Builder{}
+ i := 0
+ for k, v := range s.statsData(connected, newConnections, throttle) {
+ if i > 0 {
+ sb.WriteString(protocol.FieldDelimiter)
+ }
+ sb.WriteString(k)
+ sb.WriteByte('=')
+ sb.WriteString(fmt.Sprintf("%v", v))
+ i++
+ }
+ return sb.String()
}
func (s *stats) numConnected() int {
s.mutex.Lock()
defer s.mutex.Unlock()
-
return s.connected
}
@@ -107,6 +138,5 @@ func percentOf(total float64, value float64) float64 {
if total == 0 || total == value {
return 100
}
-
return value / (total / 100.0)
}
diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go
index 853ef1d..35c01d4 100644
--- a/internal/clients/tailclient.go
+++ b/internal/clients/tailclient.go
@@ -6,7 +6,8 @@ import (
"strings"
"github.com/mimecast/dtail/internal/clients/handlers"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
)
@@ -16,9 +17,8 @@ type TailClient struct {
}
// NewTailClient returns a new TailClient.
-func NewTailClient(args Args) (*TailClient, error) {
+func NewTailClient(args config.Args) (*TailClient, error) {
args.Mode = omode.TailClient
-
c := TailClient{
baseClient: baseClient{
Args: args,
@@ -29,7 +29,6 @@ func NewTailClient(args Args) (*TailClient, error) {
c.init()
c.makeConnections(c)
-
return &c, nil
}
@@ -37,12 +36,15 @@ func (c TailClient) makeHandler(server string) handlers.Handler {
return handlers.NewClientHandler(server)
}
-func (c TailClient) makeCommands(options map[string]string) (commands []string) {
- optionsStr := c.commandOptionsToString(options)
+func (c TailClient) makeCommands() (commands []string) {
+ regex, err := c.Regex.Serialize()
+ if err != nil {
+ dlog.Client.FatalPanic(err)
+ }
for _, file := range strings.Split(c.What, ",") {
- commands = append(commands, fmt.Sprintf("%s:%s %s %s", c.Mode.String(), optionsStr, file, c.Regex.Serialize()))
+ commands = append(commands, fmt.Sprintf("%s:%s %s %s",
+ c.Mode.String(), c.Args.SerializeOptions(), file, regex))
}
- logger.Debug(commands)
-
+ dlog.Client.Debug(commands)
return
}
diff --git a/internal/color/brush/brush.go b/internal/color/brush/brush.go
new file mode 100644
index 0000000..63d63d8
--- /dev/null
+++ b/internal/color/brush/brush.go
@@ -0,0 +1,194 @@
+package brush
+
+import (
+ "strings"
+
+ "github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/protocol"
+)
+
+func paintSeverity(sb *strings.Builder, text string) bool {
+ switch {
+ case strings.HasPrefix(text, "WARN"):
+ color.PaintWithAttr(sb, text,
+ config.Client.TermColors.Common.SeverityWarnFg,
+ config.Client.TermColors.Common.SeverityWarnBg,
+ config.Client.TermColors.Common.SeverityWarnAttr)
+
+ case strings.HasPrefix(text, "ERROR"):
+ color.PaintWithAttr(sb, text,
+ config.Client.TermColors.Common.SeverityErrorFg,
+ config.Client.TermColors.Common.SeverityErrorBg,
+ config.Client.TermColors.Common.SeverityErrorAttr)
+
+ case strings.HasPrefix(text, "FATAL"):
+ color.PaintWithAttr(sb, text,
+ config.Client.TermColors.Common.SeverityFatalFg,
+ config.Client.TermColors.Common.SeverityFatalBg,
+ config.Client.TermColors.Common.SeverityFatalAttr)
+
+ default:
+ return false
+ }
+ return true
+}
+
+func paintRemote(sb *strings.Builder, line string) {
+ splitted := strings.SplitN(line, protocol.FieldDelimiter, 6)
+
+ color.PaintWithAttr(sb, splitted[0],
+ config.Client.TermColors.Remote.RemoteFg,
+ config.Client.TermColors.Remote.RemoteBg,
+ config.Client.TermColors.Remote.RemoteAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Remote.DelimiterFg,
+ config.Client.TermColors.Remote.DelimiterBg,
+ config.Client.TermColors.Remote.DelimiterAttr)
+
+ color.PaintWithAttr(sb, splitted[1],
+ config.Client.TermColors.Remote.HostnameFg,
+ config.Client.TermColors.Remote.HostnameBg,
+ config.Client.TermColors.Remote.HostnameAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Remote.DelimiterFg,
+ config.Client.TermColors.Remote.DelimiterBg,
+ config.Client.TermColors.Remote.DelimiterAttr)
+
+ if splitted[2] == "100" {
+ color.PaintWithAttr(sb, splitted[2],
+ config.Client.TermColors.Remote.StatsOkFg,
+ config.Client.TermColors.Remote.StatsOkBg,
+ config.Client.TermColors.Remote.StatsOkAttr)
+ } else {
+ color.PaintWithAttr(sb, splitted[2],
+ config.Client.TermColors.Remote.StatsWarnFg,
+ config.Client.TermColors.Remote.StatsWarnBg,
+ config.Client.TermColors.Remote.StatsWarnAttr)
+ }
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Remote.DelimiterFg,
+ config.Client.TermColors.Remote.DelimiterBg,
+ config.Client.TermColors.Remote.DelimiterAttr)
+
+ color.PaintWithAttr(sb, splitted[3],
+ config.Client.TermColors.Remote.CountFg,
+ config.Client.TermColors.Remote.CountBg,
+ config.Client.TermColors.Remote.CountAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Remote.DelimiterFg,
+ config.Client.TermColors.Remote.DelimiterBg,
+ config.Client.TermColors.Remote.DelimiterAttr)
+
+ color.PaintWithAttr(sb, splitted[4],
+ config.Client.TermColors.Remote.IDFg,
+ config.Client.TermColors.Remote.IDBg,
+ config.Client.TermColors.Remote.IDAttr)
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Remote.DelimiterFg,
+ config.Client.TermColors.Remote.DelimiterBg,
+ config.Client.TermColors.Remote.DelimiterAttr)
+
+ if paintSeverity(sb, splitted[5]) {
+ return
+ }
+ color.PaintWithAttr(sb, splitted[5],
+ config.Client.TermColors.Remote.TextFg,
+ config.Client.TermColors.Remote.TextBg,
+ config.Client.TermColors.Remote.TextAttr)
+}
+
+func paintClient(sb *strings.Builder, line string) {
+ splitted := strings.SplitN(line, protocol.FieldDelimiter, 3)
+
+ color.PaintWithAttr(sb, splitted[0],
+ config.Client.TermColors.Client.ClientFg,
+ config.Client.TermColors.Client.ClientBg,
+ config.Client.TermColors.Client.ClientAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Client.DelimiterFg,
+ config.Client.TermColors.Client.DelimiterBg,
+ config.Client.TermColors.Client.DelimiterAttr)
+
+ color.PaintWithAttr(sb, splitted[1],
+ config.Client.TermColors.Client.HostnameFg,
+ config.Client.TermColors.Client.HostnameBg,
+ config.Client.TermColors.Client.HostnameAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Client.DelimiterFg,
+ config.Client.TermColors.Client.DelimiterBg,
+ config.Client.TermColors.Client.DelimiterAttr)
+
+ if paintSeverity(sb, splitted[2]) {
+ return
+ }
+
+ color.PaintWithAttr(sb, splitted[2],
+ config.Client.TermColors.Client.TextFg,
+ config.Client.TermColors.Client.TextBg,
+ config.Client.TermColors.Client.TextAttr)
+}
+
+func paintServer(sb *strings.Builder, line string) {
+ splitted := strings.SplitN(line, protocol.FieldDelimiter, 3)
+
+ color.PaintWithAttr(sb, splitted[0],
+ config.Client.TermColors.Server.ServerFg,
+ config.Client.TermColors.Server.ServerBg,
+ config.Client.TermColors.Server.ServerAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Server.DelimiterFg,
+ config.Client.TermColors.Server.DelimiterBg,
+ config.Client.TermColors.Server.DelimiterAttr)
+
+ color.PaintWithAttr(sb, splitted[1],
+ config.Client.TermColors.Server.HostnameFg,
+ config.Client.TermColors.Server.HostnameBg,
+ config.Client.TermColors.Server.HostnameAttr)
+
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.Server.DelimiterFg,
+ config.Client.TermColors.Server.DelimiterBg,
+ config.Client.TermColors.Server.DelimiterAttr)
+
+ if paintSeverity(sb, splitted[2]) {
+ return
+ }
+
+ color.PaintWithAttr(sb, splitted[2],
+ config.Client.TermColors.Server.TextFg,
+ config.Client.TermColors.Server.TextBg,
+ config.Client.TermColors.Server.TextAttr)
+}
+
+// Colorfy a given line based on the line's content.
+func Colorfy(line string) string {
+ sb := pool.BuilderBuffer.Get().(*strings.Builder)
+ defer pool.RecycleBuilderBuffer(sb)
+
+ switch {
+ case strings.HasPrefix(line, "REMOTE"):
+ paintRemote(sb, line)
+
+ case strings.HasPrefix(line, "CLIENT"):
+ paintClient(sb, line)
+
+ case strings.HasPrefix(line, "SERVER"):
+ paintServer(sb, line)
+
+ default:
+ color.PaintWithAttr(sb, line,
+ color.FgDefault,
+ color.BgDefault,
+ color.AttrNone)
+ }
+ return sb.String()
+}
diff --git a/internal/color/color.go b/internal/color/color.go
index 0736199..9d0bc2e 100644
--- a/internal/color/color.go
+++ b/internal/color/color.go
@@ -1,70 +1,148 @@
-// Package color is used to prettify console output via ANSII terminal colors.
+// Package color contains all terminal color codes we know of.
package color
import (
"fmt"
+ "strings"
)
-// Color name.
-type Color string
+// FgColor is the text foreground color.
+type FgColor string
-// Attribute of a color.
+// BgColor is the text background color.
+type BgColor string
+
+// Attribute of text.
type Attribute string
// The possible color variations.
const (
- escape = "\x1b"
- reset = escape + "[0m"
- seq string = "%s%s%s"
+ escape = "\x1b"
- Gray Color = escape + "[30m"
- Red Color = escape + "[31m"
- Green Color = escape + "[32m"
- Orange Color = escape + "[33m"
- Blue Color = escape + "[34m"
- Magenta Color = escape + "[35m"
- Yellow Color = escape + "[36m"
- LightGray Color = escape + "[37m"
+ FgBlack FgColor = escape + "[30m"
+ FgRed FgColor = escape + "[31m"
+ FgGreen FgColor = escape + "[32m"
+ FgYellow FgColor = escape + "[33m"
+ FgBlue FgColor = escape + "[34m"
+ FgMagenta FgColor = escape + "[35m"
+ FgCyan FgColor = escape + "[36m"
+ FgWhite FgColor = escape + "[37m"
+ FgDefault FgColor = escape + "[39m"
- BgGray Color = escape + "[40m"
- BgRed Color = escape + "[41m"
- BgGreen Color = escape + "[42m"
- BgOrange Color = escape + "[43m"
- BgBlue Color = escape + "[44m"
- BgMagenta Color = escape + "[45m"
- BgYellow Color = escape + "[46m"
- BgLightGray Color = escape + "[47m"
+ BgBlack BgColor = escape + "[40m"
+ BgRed BgColor = escape + "[41m"
+ BgGreen BgColor = escape + "[42m"
+ BgYellow BgColor = escape + "[43m"
+ BgBlue BgColor = escape + "[44m"
+ BgMagenta BgColor = escape + "[45m"
+ BgCyan BgColor = escape + "[46m"
+ BgWhite BgColor = escape + "[47m"
+ BgDefault BgColor = escape + "[49m"
- Bold Attribute = escape + "[1m"
- Italic Attribute = escape + "[3m"
- Underline Attribute = escape + "[4m"
- ReverseColor Attribute = escape + "[7m"
+ AttrNone Attribute = ""
+ AttrReset Attribute = escape + "[0m"
+ AttrBold Attribute = escape + "[1m"
+ AttrDim Attribute = escape + "[2m"
+ AttrItalic Attribute = escape + "[3m"
+ AttrUnderline Attribute = escape + "[4m"
+ AttrBlink Attribute = escape + "[5m"
+ AttrSlowBlink Attribute = escape + "[5m"
+ AttrRapidBlink Attribute = escape + "[6m"
+ AttrReverse Attribute = escape + "[7m"
+ AttrHidden Attribute = escape + "[8m"
+)
- resetBold = escape + "[22m"
- resetItalic = escape + "[23m"
- resetUnderline = escape + "[24m"
+// ColorNames is the list of all supported terminal colors.
+var ColorNames = []string{
+ "Black", "Red", "Green", "Yellow", "Blue", "Magenta", "Cyan", "White", "Default",
+}
- Test Color = BgYellow
- TestAttr Attribute = Bold
-)
+// AttributeNames is the list of all supported terminal text attributes.
+var AttributeNames = []string{
+ "Bold", "Dim", "Italic", "Underline", "Blink", "SlowBlink", "RapidBlink",
+ "Reverse", "Hidden", "None",
+}
-// Colored DTail client output enabled.
-var Colored bool
+// ToFgColor converts a given string (e.g. from a config file) into a foreground
+// color code.
+func ToFgColor(s string) (FgColor, error) {
+ switch strings.ToLower(s) {
+ case "black":
+ return FgBlack, nil
+ case "red":
+ return FgRed, nil
+ case "green":
+ return FgGreen, nil
+ case "yellow":
+ return FgYellow, nil
+ case "blue":
+ return FgBlue, nil
+ case "magenta":
+ return FgMagenta, nil
+ case "cyan":
+ return FgCyan, nil
+ case "white":
+ return FgWhite, nil
+ case "default":
+ return FgDefault, nil
+ default:
+ return FgDefault, fmt.Errorf("unknown foreground text color '" + s + "'")
+ }
+}
-// Paint a given string in a given color.
-func Paint(c Color, s string) string {
- return fmt.Sprintf(seq, c, s, reset)
+// ToBgColor converts a given string (e.g. from a config file) into a background
+// color code.
+func ToBgColor(s string) (BgColor, error) {
+ switch strings.ToLower(s) {
+ case "black":
+ return BgBlack, nil
+ case "red":
+ return BgRed, nil
+ case "green":
+ return BgGreen, nil
+ case "yellow":
+ return BgYellow, nil
+ case "blue":
+ return BgBlue, nil
+ case "magenta":
+ return BgMagenta, nil
+ case "cyan":
+ return BgCyan, nil
+ case "white":
+ return BgWhite, nil
+ case "default":
+ return BgDefault, nil
+ default:
+ return BgDefault, fmt.Errorf("unknown background text color '" + s + "'")
+ }
}
-// Attr adds a given attribute to a given string, such as "bold" or "italic".
-func Attr(c Attribute, s string) string {
- switch c {
- case Bold:
- return fmt.Sprintf(seq, Bold, s, resetBold)
- case Italic:
- return fmt.Sprintf(seq, Italic, s, resetItalic)
- case Underline:
- return fmt.Sprintf(seq, Underline, s, resetUnderline)
+// ToAttribute converts a given string (e.g. from a config file) into a text attribute.
+func ToAttribute(s string) (Attribute, error) {
+ switch strings.ToLower(s) {
+ case "bold":
+ return AttrBold, nil
+ case "dim":
+ return AttrDim, nil
+ case "italic":
+ return AttrItalic, nil
+ case "underline":
+ return AttrUnderline, nil
+ case "blink":
+ return AttrBlink, nil
+ case "slowblink":
+ return AttrSlowBlink, nil
+ case "rapidblink":
+ return AttrRapidBlink, nil
+ case "reverse":
+ return AttrReverse, nil
+ case "hidden":
+ return AttrHidden, nil
+ case "none":
+ fallthrough
+ case "":
+ return AttrNone, nil
+ default:
+ return AttrNone, fmt.Errorf("unknown text attribute '" + s + "'")
}
- panic("Unknown attribute")
}
diff --git a/internal/color/color_test.go b/internal/color/color_test.go
new file mode 100644
index 0000000..7002052
--- /dev/null
+++ b/internal/color/color_test.go
@@ -0,0 +1,53 @@
+package color
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestColors(t *testing.T) {
+ text := " Mimecast "
+ builder := strings.Builder{}
+
+ for _, color := range ColorNames {
+ fgColor, err := ToFgColor(color)
+ if err != nil {
+ t.Errorf("unable to paint foreground : %s\n%v", text, err)
+ }
+ builder.WriteString(PaintStrFg(text, fgColor))
+
+ bgColor, err := ToBgColor(color)
+ if err != nil {
+ t.Errorf("unable to paint background: %s\n%v", text, err)
+ }
+ builder.WriteString(PaintStrBg(text, bgColor))
+ }
+
+ for _, fg := range ColorNames {
+ fgColor, _ := ToFgColor(fg)
+ for _, bg := range ColorNames {
+ if fg == bg {
+ continue
+ }
+ bgColor, _ := ToBgColor(bg)
+ builder.WriteString(PaintStr(text, fgColor, bgColor))
+ }
+ }
+
+ t.Log(builder.String())
+}
+
+func TestAttributes(t *testing.T) {
+ text := " Mimecast "
+ builder := strings.Builder{}
+
+ for _, attribute := range AttributeNames {
+ att, err := ToAttribute(attribute)
+ if err != nil {
+ t.Errorf("unable to paint attribute: %s\n%v", text, err)
+ }
+ builder.WriteString(PaintStrWithAttr(text, FgWhite, BgBlue, att))
+ }
+
+ t.Log(builder.String())
+}
diff --git a/internal/color/colorfy.go b/internal/color/colorfy.go
deleted file mode 100644
index a2beb7a..0000000
--- a/internal/color/colorfy.go
+++ /dev/null
@@ -1,56 +0,0 @@
-package color
-
-import (
- "fmt"
- "strings"
-)
-
-// Add some color to log lines received from remote servers.
-func paintRemote(line string) string {
- splitted := strings.Split(line, "|")
- if splitted[2] == "100" {
- splitted[2] = Paint(BgGreen, splitted[2])
- } else {
- splitted[2] = Paint(BgRed, splitted[2])
- }
- info := strings.Join(splitted[0:5], "|")
- log := strings.Join(splitted[5:], "|")
-
- if strings.HasPrefix(log, "WARN") {
- log = Paint(BgYellow, log)
- } else if strings.HasPrefix(log, "ERROR") {
- log = Paint(BgRed, log)
- } else if strings.HasPrefix(log, "FATAL") {
- log = Attr(Bold, Paint(BgRed, log))
- } else {
- log = Paint(Blue, log)
- }
-
- return fmt.Sprintf("%s|%s", info, log)
-}
-
-// Add some color to stats generated by the client.
-func paintClientStats(line string) string {
- splitted := strings.Split(line, "|")
- first := strings.Join(splitted[0:4], "|")
- connected := Paint(BgBlue, splitted[4])
- last := strings.Join(splitted[5:], "|")
-
- return fmt.Sprintf("%s|%s|%s", first, connected, last)
-}
-
-// Colorfy a given line based on the line's content.
-func Colorfy(line string) string {
- switch {
- case strings.HasPrefix(line, "REMOTE"):
- return paintRemote(line)
- case strings.HasPrefix(line, "CLIENT") && strings.Contains(line, "|stats|"):
- return paintClientStats(line)
- case strings.Contains(line, "ERROR"):
- return Paint(Magenta, line)
- case strings.Contains(line, "WARN"):
- return Paint(Magenta, line)
- }
-
- return line
-}
diff --git a/internal/color/paint.go b/internal/color/paint.go
new file mode 100644
index 0000000..7735d87
--- /dev/null
+++ b/internal/color/paint.go
@@ -0,0 +1,91 @@
+package color
+
+import (
+ "fmt"
+ "strings"
+)
+
+// PaintStr paints a given text in a given foreground/background color combination.
+func PaintStr(text string, fg FgColor, bg BgColor) string {
+ return fmt.Sprintf("%s%s%s%s%s", fg, bg, text, BgDefault, FgDefault)
+}
+
+// PaintStrWithAttr paints a given text in a given foreground/background/attribute
+// combination
+func PaintStrWithAttr(text string, fg FgColor, bg BgColor, attr Attribute) string {
+ if attr == AttrNone {
+ return PaintStr(text, fg, bg)
+ }
+ return fmt.Sprintf("%s%s%s%s%s%s%s", fg, bg, attr, text, AttrReset,
+ BgDefault, FgDefault)
+}
+
+// PaintStrFg paints a given text in a given foreground color.
+func PaintStrFg(text string, fg FgColor) string {
+ return fmt.Sprintf("%s%s%s", fg, text, FgDefault)
+}
+
+// PaintStrBg paints a given text in a given background color.
+func PaintStrBg(text string, bg BgColor) string {
+ return fmt.Sprintf("%s%s%s", bg, text, BgDefault)
+}
+
+// PaintStrAttr adds a given attribute to a given text, such as "bold" or "italic".
+func PaintStrAttr(text string, attr Attribute) string {
+ return fmt.Sprintf("%s%s%s", attr, text, AttrReset)
+}
+
+// Paint paints a given text in a given foreground/background color combination.
+func Paint(sb *strings.Builder, text string, fg FgColor, bg BgColor) {
+ sb.WriteString(string(fg))
+ sb.WriteString(string(bg))
+ sb.WriteString(text)
+ sb.WriteString(string(BgDefault))
+ sb.WriteString(string(FgDefault))
+}
+
+// Reset background and foreground colors.
+func Reset(sb *strings.Builder) {
+ sb.WriteString(string(BgDefault))
+ sb.WriteString(string(FgDefault))
+}
+
+// PaintWithAttr starts painting a given text in a given foreground/background/
+// attribute combination.
+func PaintWithAttr(sb *strings.Builder, text string, fg FgColor, bg BgColor,
+ attr Attribute) {
+
+ if attr == AttrNone {
+ Paint(sb, text, fg, bg)
+ return
+ }
+ sb.WriteString(string(fg))
+ sb.WriteString(string(bg))
+ sb.WriteString(string(attr))
+ sb.WriteString(text)
+ sb.WriteString(string(AttrReset))
+ sb.WriteString(string(BgDefault))
+ sb.WriteString(string(FgDefault))
+}
+
+// PaintWithAttrs is similar to PaintWithAttr, but it takes multiple attributes.
+func PaintWithAttrs(sb *strings.Builder, text string, fg FgColor, bg BgColor,
+ attrs []Attribute) {
+
+ sb.WriteString(string(fg))
+ sb.WriteString(string(bg))
+ for _, attr := range attrs {
+ sb.WriteString(string(attr))
+ }
+ sb.WriteString(text)
+ sb.WriteString(string(AttrReset))
+ sb.WriteString(string(BgDefault))
+ sb.WriteString(string(FgDefault))
+}
+
+// ResetWithAttr resets background, foreground and attributes.
+func ResetWithAttr(sb *strings.Builder) {
+ sb.WriteString(string(AttrReset))
+ sb.WriteString(string(BgDefault))
+ sb.WriteString(string(FgDefault))
+}
diff --git a/internal/color/table.go b/internal/color/table.go
new file mode 100644
index 0000000..e0e4946
--- /dev/null
+++ b/internal/color/table.go
@@ -0,0 +1,53 @@
+package color
+
+import (
+ "fmt"
+ "os"
+)
+
+const sampleParagraph string = "Mimecast is Making Email Safer for Business. " +
+ "We believe that securely operating a business in the cloud requires new " +
+ "levels of IT preparedness, centered around cyber resilience. This is why " +
+ "we unify the delivery and management of security, continuity and data " +
+ "protection for email via one, simple-to-use cloud platform. Thousands of " +
+ "organizations trust us to increase their cyber resilience preparedness, " +
+ "streamline compliance, reduce IT complexity and keep their business running. " +
+ "We give employees fast and secure access to sensitive business information, " +
+ "and ensure email keeps flowing in the event of an outage. Mimecast will " +
+ "remain committed to protecting your IT assets through constant innovation " +
+ "and focus on your success."
+
+// TablePrintAndExit prints the color table and then exits the process.
+func TablePrintAndExit(displaySampleParagraph bool) {
+ for _, attr := range AttributeNames {
+ if attr == "Hidden" || attr == "SlowBlink" {
+ continue
+ }
+ printColorTable(attr, displaySampleParagraph)
+ }
+ os.Exit(0)
+}
+
+func printColorTable(attr string, displaySampleParagraph bool) {
+ for _, fg := range ColorNames {
+ fgColor, _ := ToFgColor(fg)
+ for _, bg := range ColorNames {
+ if fg == bg {
+ continue
+ }
+
+ bgColor, _ := ToBgColor(bg)
+ attribute, _ := ToAttribute(attr)
+ text := fmt.Sprintf(" Foreground:%10s | Background:%10s | Attribute:%10s ",
+ fg, bg, attr)
+ fmt.Print(PaintStrWithAttr(text, fgColor, bgColor, attribute))
+
+ if displaySampleParagraph {
+ fmt.Print("\n")
+ fmt.Print(PaintStrWithAttr(sampleParagraph, fgColor, bgColor, attribute))
+ fmt.Print("\n")
+ }
+ fmt.Print("\n")
+ }
+ }
+}
diff --git a/internal/config/args.go b/internal/config/args.go
new file mode 100644
index 0000000..3d7ac7d
--- /dev/null
+++ b/internal/config/args.go
@@ -0,0 +1,164 @@
+package config
+
+import (
+ "encoding/base64"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/lcontext"
+ "github.com/mimecast/dtail/internal/omode"
+
+ gossh "golang.org/x/crypto/ssh"
+)
+
+// Args is a helper struct to summarize common client arguments.
+type Args struct {
+ lcontext.LContext
+ Arguments []string
+ ConfigFile string
+ ConnectionsPerCPU int
+ Discovery string
+ LogDir string
+ Logger string
+ LogLevel string
+ Mode omode.Mode
+ NoColor bool
+ PrivateKeyPathFile string
+ QueryStr string
+ Quiet bool
+ RegexInvert bool
+ RegexStr string
+ Serverless bool
+ ServersStr string
+ Spartan bool
+ SSHAuthMethods []gossh.AuthMethod
+ SSHBindAddress string
+ SSHHostKeyCallback gossh.HostKeyCallback
+ SSHPort int
+ Timeout int
+ TrustAllHosts bool
+ UserName string
+ What string
+}
+
+func (a *Args) String() string {
+ var sb strings.Builder
+
+ sb.WriteString("Args(")
+
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Arguments", a.Arguments))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "ConfigFile", a.ConfigFile))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "ConnectionsPerCPU", a.ConnectionsPerCPU))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Discovery", a.Discovery))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "LogDir", a.LogDir))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "LogLevel", a.LogLevel))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Logger", a.Logger))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Mode", a.Mode))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "NoColor", a.NoColor))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "PrivateKeyPathFile", a.PrivateKeyPathFile))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "QueryStr", a.QueryStr))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Quiet", a.Quiet))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "RegexInvert", a.RegexInvert))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "RegexStr", a.RegexStr))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "SSHAuthMethods", a.SSHAuthMethods))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "SSHBindAddress", a.SSHBindAddress))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "SSHHostKeyCallback", a.SSHHostKeyCallback))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "SSHPort", a.SSHPort))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Serverless", a.Serverless))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "ServersStr", a.ServersStr))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Spartan", a.Spartan))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "Timeout", a.Timeout))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "TrustAllHosts", a.TrustAllHosts))
+ sb.WriteString(fmt.Sprintf("%s:%v,", "UserName", a.UserName))
+ sb.WriteString(fmt.Sprintf("%s:%v", "What", a.What))
+ sb.WriteString(")")
+
+ return sb.String()
+}
+
+// SerializeOptions returns a string ready to be sent over the wire to the server.
+func (a *Args) SerializeOptions() string {
+ options := make(map[string]string)
+
+ if a.Quiet {
+ options["quiet"] = fmt.Sprintf("%v", a.Quiet)
+ }
+ if a.Spartan {
+ options["spartan"] = fmt.Sprintf("%v", a.Spartan)
+ }
+ if a.Serverless {
+ options["serverless"] = fmt.Sprintf("%v", a.Serverless)
+ }
+ if a.LContext.MaxCount != 0 {
+ options["max"] = fmt.Sprintf("%d", a.LContext.MaxCount)
+ }
+ if a.LContext.BeforeContext != 0 {
+ options["before"] = fmt.Sprintf("%d", a.LContext.BeforeContext)
+ }
+ if a.LContext.AfterContext != 0 {
+ options["after"] = fmt.Sprintf("%d", a.LContext.AfterContext)
+ }
+
+ var sb strings.Builder
+ var i int
+ for k, v := range options {
+ if i > 0 {
+ sb.WriteString(":")
+ }
+ sb.WriteString(k)
+ sb.WriteString("=")
+ sb.WriteString(v)
+ i++
+ }
+ return sb.String()
+}
+
+// DeserializeOptions deserializes the options, but into a map.
+func DeserializeOptions(opts []string) (map[string]string, lcontext.LContext, error) {
+ options := make(map[string]string, len(opts))
+ var ltx lcontext.LContext
+
+ for _, o := range opts {
+ kv := strings.SplitN(o, "=", 2)
+ if len(kv) != 2 {
+ return options, ltx, fmt.Errorf("Unable to parse options: %v", kv)
+ }
+ key := kv[0]
+ val := kv[1]
+
+ if strings.HasPrefix(val, "base64%") {
+ s := strings.SplitN(val, "%", 2)
+ decoded, err := base64.StdEncoding.DecodeString(s[1])
+ if err != nil {
+ return options, ltx, err
+ }
+ val = string(decoded)
+ }
+
+ switch key {
+ case "before":
+ iVal, err := strconv.Atoi(val)
+ if err != nil {
+ return options, ltx, err
+ }
+ ltx.BeforeContext = iVal
+ case "after":
+ iVal, err := strconv.Atoi(val)
+ if err != nil {
+ return options, ltx, err
+ }
+ ltx.AfterContext = iVal
+ case "max":
+ iVal, err := strconv.Atoi(val)
+ if err != nil {
+ return options, ltx, err
+ }
+ ltx.MaxCount = iVal
+ default:
+ options[key] = val
+ }
+ }
+
+ return options, ltx, nil
+}
diff --git a/internal/config/client.go b/internal/config/client.go
index 1515aae..86f97f0 100644
--- a/internal/config/client.go
+++ b/internal/config/client.go
@@ -1,11 +1,203 @@
package config
+import "github.com/mimecast/dtail/internal/color"
+
+type remoteTermColors struct {
+ DelimiterAttr color.Attribute
+ DelimiterBg color.BgColor
+ DelimiterFg color.FgColor
+ RemoteAttr color.Attribute
+ RemoteBg color.BgColor
+ RemoteFg color.FgColor
+ CountAttr color.Attribute
+ CountBg color.BgColor
+ CountFg color.FgColor
+ HostnameAttr color.Attribute
+ HostnameBg color.BgColor
+ HostnameFg color.FgColor
+ IDAttr color.Attribute
+ IDBg color.BgColor
+ IDFg color.FgColor
+ StatsOkAttr color.Attribute
+ StatsOkBg color.BgColor
+ StatsOkFg color.FgColor
+ StatsWarnAttr color.Attribute
+ StatsWarnBg color.BgColor
+ StatsWarnFg color.FgColor
+ TextAttr color.Attribute
+ TextBg color.BgColor
+ TextFg color.FgColor
+}
+
+type clientTermColors struct {
+ DelimiterAttr color.Attribute
+ DelimiterBg color.BgColor
+ DelimiterFg color.FgColor
+ ClientAttr color.Attribute
+ ClientBg color.BgColor
+ ClientFg color.FgColor
+ HostnameAttr color.Attribute
+ HostnameBg color.BgColor
+ HostnameFg color.FgColor
+ TextAttr color.Attribute
+ TextBg color.BgColor
+ TextFg color.FgColor
+}
+
+type serverTermColors struct {
+ DelimiterAttr color.Attribute
+ DelimiterBg color.BgColor
+ DelimiterFg color.FgColor
+ ServerAttr color.Attribute
+ ServerBg color.BgColor
+ ServerFg color.FgColor
+ HostnameAttr color.Attribute
+ HostnameBg color.BgColor
+ HostnameFg color.FgColor
+ TextAttr color.Attribute
+ TextBg color.BgColor
+ TextFg color.FgColor
+}
+
+type commonTermColors struct {
+ SeverityErrorAttr color.Attribute
+ SeverityErrorBg color.BgColor
+ SeverityErrorFg color.FgColor
+ SeverityFatalAttr color.Attribute
+ SeverityFatalBg color.BgColor
+ SeverityFatalFg color.FgColor
+ SeverityWarnAttr color.Attribute
+ SeverityWarnBg color.BgColor
+ SeverityWarnFg color.FgColor
+}
+
+type maprTableTermColors struct {
+ DataAttr color.Attribute
+ DataBg color.BgColor
+ DataFg color.FgColor
+ DelimiterAttr color.Attribute
+ DelimiterBg color.BgColor
+ DelimiterFg color.FgColor
+ HeaderAttr color.Attribute
+ HeaderBg color.BgColor
+ HeaderDelimiterAttr color.Attribute
+ HeaderDelimiterBg color.BgColor
+ HeaderDelimiterFg color.FgColor
+ HeaderFg color.FgColor
+ HeaderGroupKeyAttr color.Attribute
+ HeaderSortKeyAttr color.Attribute
+ RawQueryAttr color.Attribute
+ RawQueryBg color.BgColor
+ RawQueryFg color.FgColor
+}
+
+type termColors struct {
+ Remote remoteTermColors
+ Client clientTermColors
+ Server serverTermColors
+ Common commonTermColors
+ MaprTable maprTableTermColors
+}
+
// ClientConfig represents a DTail client configuration (empty as of now as there
// are no available config options yet, but that may changes in the future).
type ClientConfig struct {
+ TermColorsEnable bool `json:",omitempty"`
+ TermColors termColors `json:",omitempty"`
+ // When unit testing in Jenkins you don't want to touch files in ~jenkins
+ // during integration tests really.
+ SSHDontAddHostsToKnownHostsFile bool `json:",omitempty"`
}
// Create a new default client configuration.
func newDefaultClientConfig() *ClientConfig {
- return &ClientConfig{}
+ return &ClientConfig{
+ TermColorsEnable: true,
+ TermColors: termColors{
+ Remote: remoteTermColors{
+ DelimiterAttr: color.AttrDim,
+ DelimiterBg: color.BgBlue,
+ DelimiterFg: color.FgCyan,
+ RemoteAttr: color.AttrDim,
+ RemoteBg: color.BgBlue,
+ RemoteFg: color.FgWhite,
+ CountAttr: color.AttrDim,
+ CountBg: color.BgBlue,
+ CountFg: color.FgWhite,
+ HostnameAttr: color.AttrBold,
+ HostnameBg: color.BgBlue,
+ HostnameFg: color.FgWhite,
+ IDAttr: color.AttrDim,
+ IDBg: color.BgBlue,
+ IDFg: color.FgWhite,
+ StatsOkAttr: color.AttrNone,
+ StatsOkBg: color.BgGreen,
+ StatsOkFg: color.FgBlack,
+ StatsWarnAttr: color.AttrNone,
+ StatsWarnBg: color.BgRed,
+ StatsWarnFg: color.FgWhite,
+ TextAttr: color.AttrNone,
+ TextBg: color.BgBlack,
+ TextFg: color.FgWhite,
+ },
+ Client: clientTermColors{
+ DelimiterAttr: color.AttrDim,
+ DelimiterBg: color.BgYellow,
+ DelimiterFg: color.FgBlack,
+ ClientAttr: color.AttrDim,
+ ClientBg: color.BgYellow,
+ ClientFg: color.FgBlack,
+ HostnameAttr: color.AttrDim,
+ HostnameBg: color.BgYellow,
+ HostnameFg: color.FgBlack,
+ TextAttr: color.AttrNone,
+ TextBg: color.BgBlack,
+ TextFg: color.FgWhite,
+ },
+ Server: serverTermColors{
+ DelimiterAttr: color.AttrDim,
+ DelimiterBg: color.BgCyan,
+ DelimiterFg: color.FgBlack,
+ ServerAttr: color.AttrDim,
+ ServerBg: color.BgCyan,
+ ServerFg: color.FgBlack,
+ HostnameAttr: color.AttrBold,
+ HostnameBg: color.BgCyan,
+ HostnameFg: color.FgBlack,
+ TextAttr: color.AttrNone,
+ TextBg: color.BgBlack,
+ TextFg: color.FgWhite,
+ },
+ Common: commonTermColors{
+ SeverityErrorAttr: color.AttrBold,
+ SeverityErrorBg: color.BgRed,
+ SeverityErrorFg: color.FgWhite,
+ SeverityFatalAttr: color.AttrBold,
+ SeverityFatalBg: color.BgMagenta,
+ SeverityFatalFg: color.FgWhite,
+ SeverityWarnAttr: color.AttrBold,
+ SeverityWarnBg: color.BgBlack,
+ SeverityWarnFg: color.FgWhite,
+ },
+ MaprTable: maprTableTermColors{
+ DataAttr: color.AttrNone,
+ DataBg: color.BgBlue,
+ DataFg: color.FgWhite,
+ DelimiterAttr: color.AttrDim,
+ DelimiterBg: color.BgBlue,
+ DelimiterFg: color.FgWhite,
+ HeaderAttr: color.AttrBold,
+ HeaderBg: color.BgBlue,
+ HeaderFg: color.FgWhite,
+ HeaderDelimiterAttr: color.AttrDim,
+ HeaderDelimiterBg: color.BgBlue,
+ HeaderDelimiterFg: color.FgWhite,
+ HeaderSortKeyAttr: color.AttrUnderline,
+ HeaderGroupKeyAttr: color.AttrReverse,
+ RawQueryAttr: color.AttrDim,
+ RawQueryBg: color.BgBlack,
+ RawQueryFg: color.FgCyan,
+ },
+ },
+ }
}
diff --git a/internal/config/common.go b/internal/config/common.go
index c3e203e..7a72cfe 100644
--- a/internal/config/common.go
+++ b/internal/config/common.go
@@ -6,31 +6,27 @@ type CommonConfig struct {
SSHPort int
// Enable experimental features (mainly for dev purposes)
ExperimentalFeaturesEnable bool `json:",omitempty"`
- // Enable debug logging. Don't enable in production.
- DebugEnable bool `json:",omitempty"`
- // Enable trace logging. Don't enable in production.
- TraceEnable bool `json:",omitempty"`
- // The log strategy to use, one of
- // stdout: only log to stdout (useful when used with systemd)
- // daily: create a log file for every day
- LogStrategy string
- // The log directory
+ // LogDir defines the log directory.
LogDir string
+ // Logger defines the name of the logger implementation.
+ Logger string
+ // LogLevel defines how much is logged.
+ LogLevel string `json:",omitempty"`
+ // LogRotation strategy to be used.
+ LogRotation string
// The cache directory
CacheDir string
- // The temp directory
- TmpDir string `json:",omitempty"`
}
// Create a new default configuration.
func newDefaultCommonConfig() *CommonConfig {
return &CommonConfig{
- SSHPort: 2222,
- DebugEnable: false,
- TraceEnable: false,
+ SSHPort: DefaultSSHPort,
ExperimentalFeaturesEnable: false,
LogDir: "log",
+ Logger: "stdout",
+ LogLevel: DefaultLogLevel,
+ LogRotation: "daily",
CacheDir: "cache",
- TmpDir: "/tmp",
}
}
diff --git a/internal/config/config.go b/internal/config/config.go
index 276ddcf..ee23829 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -1,23 +1,30 @@
package config
-import (
- "encoding/json"
- "io/ioutil"
- "os"
+import "github.com/mimecast/dtail/internal/source"
+
+const (
+ // HealthUser is used for the health check
+ HealthUser string = "DTAIL-HEALTH"
+ // ScheduleUser is used for non-interactive scheduled mapreduce queries.
+ ScheduleUser string = "DTAIL-SCHEDULE"
+ // ContinuousUser is used for non-interactive continuous mapreduce queries.
+ ContinuousUser string = "DTAIL-CONTINUOUS"
+ // InterruptTimeoutS specifies the Ctrl+C log pause interval.
+ InterruptTimeoutS int = 3
+ // DefaultConnectionsPerCPU controls how many connections are established concurrently.
+ DefaultConnectionsPerCPU int = 10
+ // DefaultSSHPort is the default DServer port.
+ DefaultSSHPort int = 2222
+ // DefaultLogLevel specifies the default log level (obviously)
+ DefaultLogLevel string = "info"
+ // DefaultClientLogger specifies the default logger for the client commands.
+ DefaultClientLogger string = "fout"
+ // DefaultServerLogger specifies the default logger for dtail server.
+ DefaultServerLogger string = "file"
+ // DefaultHealthCheckLogger specifies the default logger used for health checks.
+ DefaultHealthCheckLogger string = "none"
)
-// ControlUser is used for various DTail specific operations.
-const ControlUser string = "DTAIL-CONTROL"
-
-// ScheduleUser is used for non-interactive scheduled mapreduce queries.
-const ScheduleUser string = "DTAIL-SCHEDULE"
-
-// ContinuousUser is used for non-interactive continuous mapreduce queries.
-const ContinuousUser string = "DTAIL-CONTINUOUS"
-
-// InterruptTimeoutS is used to terminate DTail when Ctrl+C was pressed twice within a given interval.
-const InterruptTimeoutS int = 3
-
// Client holds a DTail client configuration.
var Client *ClientConfig
@@ -27,28 +34,22 @@ var Server *ServerConfig
// Common holds common configs of both both, client and server.
var Common *CommonConfig
-// Used to initialize the configuration.
-type configInitializer struct {
- Common *CommonConfig
- Server *ServerConfig
- Client *ClientConfig
-}
-
-// Parse and read a given config file in JSON format.
-func (c *configInitializer) parseConfig(configFile string) {
- fd, err := os.Open(configFile)
- if err != nil {
- panic(err)
+// Setup the DTail configuration.
+func Setup(sourceProcess source.Source, args *Args, additionalArgs []string) {
+ initializer := initializer{
+ Common: newDefaultCommonConfig(),
+ Server: newDefaultServerConfig(),
+ Client: newDefaultClientConfig(),
}
- defer fd.Close()
-
- cfgBytes, err := ioutil.ReadAll(fd)
- if err != nil {
+ if err := initializer.parseConfig(args); err != nil {
panic(err)
}
-
- err = json.Unmarshal([]byte(cfgBytes), c)
- if err != nil {
+ if err := initializer.transformConfig(sourceProcess, args, additionalArgs); err != nil {
panic(err)
}
+
+ // Make config accessible globally
+ Server = initializer.Server
+ Client = initializer.Client
+ Common = initializer.Common
}
diff --git a/internal/config/env.go b/internal/config/env.go
new file mode 100644
index 0000000..88b831d
--- /dev/null
+++ b/internal/config/env.go
@@ -0,0 +1,7 @@
+package config
+
+import "os"
+
+func Env(env string) bool {
+ return "yes" == os.Getenv(env)
+}
diff --git a/internal/config/initializer.go b/internal/config/initializer.go
new file mode 100644
index 0000000..4d6a73b
--- /dev/null
+++ b/internal/config/initializer.go
@@ -0,0 +1,184 @@
+package config
+
+import (
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+
+ "github.com/mimecast/dtail/internal/source"
+)
+
+// Used to initialize the configuration.
+type initializer struct {
+ Common *CommonConfig
+ Server *ServerConfig
+ Client *ClientConfig
+}
+
+type transformCb func(*initializer, *Args, []string) error
+
+func (in *initializer) parseConfig(args *Args) error {
+ if strings.ToUpper(args.ConfigFile) == "NONE" {
+ return nil
+ }
+
+ if args.ConfigFile != "" {
+ return in.parseSpecificConfig(args.ConfigFile)
+ }
+
+ if homeDir, err := os.UserHomeDir(); err != nil {
+ var paths []string
+ paths = append(paths, fmt.Sprintf("%s/.config/dtail/dtail.conf", homeDir))
+ paths = append(paths, fmt.Sprintf("%s/.dtail.conf", homeDir))
+ for _, configPath := range paths {
+ if _, err := os.Stat(configPath); !os.IsNotExist(err) {
+ in.parseSpecificConfig(configPath)
+ }
+ }
+ }
+
+ return nil
+}
+
+func (in *initializer) parseSpecificConfig(configFile string) error {
+ fd, err := os.Open(configFile)
+ if err != nil {
+ return fmt.Errorf("Unable to read config file: %v", err)
+ }
+ defer fd.Close()
+
+ cfgBytes, err := ioutil.ReadAll(fd)
+ if err != nil {
+ return fmt.Errorf("Unable to read config file %s: %v", configFile, err)
+ }
+
+ if err := json.Unmarshal([]byte(cfgBytes), in); err != nil {
+ return fmt.Errorf("Unable to parse config file %s: %v", configFile, err)
+ }
+
+ return nil
+}
+
+func (in *initializer) transformConfig(sourceProcess source.Source, args *Args,
+ additionalArgs []string) error {
+
+ in.readEnvironmentVars()
+
+ switch sourceProcess {
+ case source.Server:
+ return in.optimusPrime(transformServer, args, additionalArgs)
+ case source.Client:
+ return in.optimusPrime(transformClient, args, additionalArgs)
+ case source.HealthCheck:
+ return in.optimusPrime(transformHealthCheck, args, additionalArgs)
+ default:
+ return fmt.Errorf("Unable to transform config, unknown source '%s'",
+ sourceProcess)
+ }
+}
+
+// There are some special options which can be set by environment variable.
+func (in *initializer) readEnvironmentVars() {
+ if Env("DTAIL_SSH_DONT_ADD_HOSTS_TO_KNOWNHOSTS_FILE") ||
+ Env("DTAIL_JENKINS") {
+ in.Client.SSHDontAddHostsToKnownHostsFile = true
+ }
+}
+
+func (in *initializer) optimusPrime(sourceCb transformCb, args *Args,
+ additionalArgs []string) error {
+
+ // Copy args to config objects.
+ // NEXT: Maybe unify args and config structs?
+ if args.SSHPort != DefaultSSHPort {
+ in.Common.SSHPort = args.SSHPort
+ }
+ if args.LogLevel != DefaultLogLevel {
+ in.Common.LogLevel = args.LogLevel
+ }
+ if args.NoColor {
+ in.Client.TermColorsEnable = false
+ }
+ if args.LogDir != "" {
+ in.Common.LogDir = args.LogDir
+ }
+ if args.Logger != "" {
+ in.Common.Logger = args.Logger
+ }
+ if args.ConnectionsPerCPU == 0 {
+ args.ConnectionsPerCPU = DefaultConnectionsPerCPU
+ }
+
+ // Setup log directory.
+ if strings.Contains(in.Common.LogDir, "~/") {
+ homeDir, err := os.UserHomeDir()
+ if err != nil {
+ panic(err)
+ }
+ in.Common.LogDir = strings.ReplaceAll(in.Common.LogDir, "~/",
+ fmt.Sprintf("%s/", homeDir))
+ }
+
+ // Source type specific transormations.
+ sourceCb(in, args, additionalArgs)
+
+ // Spartan mode.
+ if args.Spartan {
+ args.Quiet = true
+ args.NoColor = true
+ in.Client.TermColorsEnable = false
+ if args.LogLevel == "" {
+ args.LogLevel = "ERROR"
+ in.Common.LogLevel = "ERROR"
+ }
+ }
+ // Interpret additional args as file list or as query.
+ if args.What == "" {
+ var files []string
+ for _, arg := range flag.Args() {
+ if args.QueryStr == "" && strings.Contains(strings.ToLower(arg), "select ") {
+ args.QueryStr = arg
+ continue
+ }
+ files = append(files, arg)
+ }
+ args.What = strings.Join(files, ",")
+ }
+
+ return nil
+}
+
+func transformClient(in *initializer, args *Args, additionalArgs []string) error {
+ // Serverless mode.
+ if args.Discovery == "" && (args.ServersStr == "" ||
+ strings.ToLower(args.ServersStr) == "serverless") {
+ // We are not connecting to any servers.
+ args.Serverless = true
+ if args.LogLevel == DefaultLogLevel {
+ in.Common.LogLevel = "warn"
+ }
+ }
+ return nil
+}
+
+func transformServer(in *initializer, args *Args, additionalArgs []string) error {
+ if args.SSHBindAddress != "" {
+ in.Server.SSHBindAddress = args.SSHBindAddress
+ }
+ return nil
+}
+
+func transformHealthCheck(in *initializer, args *Args, additionalArgs []string) error {
+ // Serverless mode.
+ if args.Discovery == "" && (args.ServersStr == "" ||
+ strings.ToLower(args.ServersStr) == "serverless") {
+ // We are not connecting to any servers.
+ args.Serverless = true
+ in.Common.LogLevel = "warn"
+ }
+ args.TrustAllHosts = true
+ return nil
+}
diff --git a/internal/config/read.go b/internal/config/read.go
deleted file mode 100644
index a4e605b..0000000
--- a/internal/config/read.go
+++ /dev/null
@@ -1,37 +0,0 @@
-package config
-
-import (
- "os"
-)
-
-// Read the DTail configuration.
-func Read(configFile string, sshPort int) {
- initializer := configInitializer{
- Common: newDefaultCommonConfig(),
- Server: newDefaultServerConfig(),
- Client: newDefaultClientConfig(),
- }
-
- if configFile == "" {
- configFile = "./cfg/dtail.json"
- }
-
- if _, err := os.Stat(configFile); !os.IsNotExist(err) {
- initializer.parseConfig(configFile)
- }
-
- // Assign pointers to global variables, so that we can access the
- // configuration from any place of the program.
- Common = initializer.Common
- Server = initializer.Server
- Client = initializer.Client
-
- if Server.MapreduceLogFormat == "" {
- Server.MapreduceLogFormat = "default"
- }
-
- // If non-standard port specified, overwrite config
- if sshPort != 2222 {
- Common.SSHPort = sshPort
- }
-}
diff --git a/internal/config/server.go b/internal/config/server.go
index dc0d587..254ea0c 100644
--- a/internal/config/server.go
+++ b/internal/config/server.go
@@ -4,8 +4,8 @@ import (
"errors"
)
-// Permissions map. Each SSH user has a list of permissions which
-// log files it is allowed to follow and which ones not.
+// Permissions map. Each SSH user has a list of permissions which log files it
+// is allowed to follow and which ones not.
type Permissions struct {
// The default user permissions.
Default []string
@@ -47,7 +47,7 @@ type ServerConfig struct {
MaxConcurrentCats int
// The max amount of concurrent tails per server.
MaxConcurrentTails int
- // The user permissions.
+ // The user permissions. TODO: Add to JSON schema
Permissions Permissions `json:",omitempty"`
// The mapr log format
MapreduceLogFormat string `json:",omitempty"`
@@ -68,7 +68,6 @@ var ServerRelaxedAuthEnable bool
func newDefaultServerConfig() *ServerConfig {
defaultPermissions := []string{"^/.*"}
defaultBindAddress := "0.0.0.0"
-
return &ServerConfig{
SSHBindAddress: defaultBindAddress,
MaxConnections: 10,
@@ -76,6 +75,7 @@ func newDefaultServerConfig() *ServerConfig {
MaxConcurrentTails: 50,
HostKeyFile: "./cache/ssh_host_key",
HostKeyBits: 4096,
+ MapreduceLogFormat: "default",
Permissions: Permissions{
Default: defaultPermissions,
},
@@ -88,10 +88,8 @@ func ServerUserPermissions(userName string) (permissions []string, err error) {
if p, ok := Server.Permissions.Users[userName]; ok {
permissions = p
}
-
if len(permissions) == 0 {
err = errors.New("Empty set of permission, user won't be able to open any files")
}
-
return
}
diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go
index 4344240..9bea89c 100644
--- a/internal/discovery/comma.go
+++ b/internal/discovery/comma.go
@@ -3,11 +3,11 @@ package discovery
import (
"strings"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// ServerListFromCOMMA retrieves a list of servers from comma separated input list.
func (d *Discovery) ServerListFromCOMMA() []string {
- logger.Debug("Retrieving server list from comma separated list", d.server)
+ dlog.Common.Debug("Retrieving server list from comma separated list", d.server)
return strings.Split(d.server, ",")
}
diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go
index a25b136..8bb1e85 100644
--- a/internal/discovery/discovery.go
+++ b/internal/discovery/discovery.go
@@ -9,7 +9,7 @@ import (
"strings"
"time"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// ServerOrder to specify how to sort the server list.
@@ -24,7 +24,7 @@ const (
type Discovery struct {
// To plug in a custom server discovery module.
module string
- // To specify optional server discovery module options.
+ // To specifiy optional server discovery module options.
options string
// To either filter a server list or to secify an exact list.
server string
@@ -42,7 +42,7 @@ func New(method, server string, order ServerOrder) *Discovery {
if strings.Contains(module, ":") {
s := strings.Split(module, ":")
if len(s) != 2 {
- logger.FatalExit("Unable to parse discovery module", module)
+ dlog.Common.FatalPanic("Unable to parse discovery module", module)
}
module = s[0]
options = s[1]
@@ -72,11 +72,10 @@ func (d *Discovery) initRegex() {
}
regexStr := string(runes)
- logger.Debug("Using filter regex", regexStr)
-
+ dlog.Common.Debug("Using filter regex", regexStr)
regex, err := regexp.Compile(regexStr)
if err != nil {
- logger.FatalExit("Could not compile regex", regexStr, err)
+ dlog.Common.FatalPanic("Could not compile regex", regexStr, err)
}
d.regex = regex
@@ -90,14 +89,12 @@ func (d *Discovery) ServerList() []string {
if d.regex != nil {
servers = d.filterList(servers)
}
-
servers = d.dedupList(servers)
-
if d.order == Shuffle {
servers = d.shuffleList(servers)
}
- logger.Debug("Discovered servers", len(servers), servers)
+ dlog.Common.Debug("Discovered servers", len(servers), servers)
return servers
}
@@ -105,12 +102,10 @@ func (d *Discovery) serverListFromModule() []string {
if d.module != "" {
return d.serverListFromReflectedModule()
}
-
if _, err := os.Stat(d.server); err == nil {
// Appears to be a file name, now try to read from that file.
return d.ServerListFromFILE()
}
-
// Appears to be a list of FQDNs (or a single FQDN)
return d.ServerListFromCOMMA()
}
@@ -120,53 +115,47 @@ func (d *Discovery) serverListFromModule() []string {
// Discovery. Whereas MODULENAME must be a upeprcase string.
func (d *Discovery) serverListFromReflectedModule() []string {
methodName := fmt.Sprintf("ServerListFrom%s", d.module)
-
+ // Now we are reflecting the serve discovery function by it's name.
rt := reflect.TypeOf(d)
reflectedMethod, ok := rt.MethodByName(methodName)
if !ok {
- logger.FatalExit("No such server discovery module", d.module, methodName)
+ dlog.Common.FatalPanic("No such server discovery module", d.module, methodName)
}
-
inputValues := make([]reflect.Value, 1)
// Thist input value is method receiver.
inputValues[0] = reflect.ValueOf(d)
returnValues := reflectedMethod.Func.Call(inputValues)
-
// First return value is server list.
return returnValues[0].Interface().([]string)
}
// Filter server list based on a regexp.
func (d *Discovery) filterList(servers []string) (filtered []string) {
- logger.Debug("Filtering server list")
-
+ dlog.Common.Debug("Filtering server list")
for _, server := range servers {
if d.regex.MatchString(server) {
filtered = append(filtered, server)
}
}
-
return
}
// Deduplicate the server list.
func (d *Discovery) dedupList(servers []string) (deduped []string) {
serverMap := make(map[string]struct{}, len(servers))
-
for _, server := range servers {
if _, ok := serverMap[server]; !ok {
serverMap[server] = struct{}{}
deduped = append(deduped, server)
}
}
-
- logger.Debug("Deduped server list", len(servers), len(deduped))
+ dlog.Common.Debug("Deduped server list", len(servers), len(deduped))
return
}
// Randomly shuffle the server list.
func (d *Discovery) shuffleList(servers []string) []string {
- logger.Debug("Shuffling server list")
+ dlog.Common.Debug("Shuffling server list")
r := rand.New(rand.NewSource(time.Now().Unix()))
shuffled := make([]string, len(servers))
diff --git a/internal/discovery/file.go b/internal/discovery/file.go
index 1250755..fb46eeb 100644
--- a/internal/discovery/file.go
+++ b/internal/discovery/file.go
@@ -4,16 +4,16 @@ import (
"bufio"
"os"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// ServerListFromFILE retrieves a list of servers from a file.
func (d *Discovery) ServerListFromFILE() (servers []string) {
- logger.Debug("Retrieving server list from file", d.server)
+ dlog.Common.Debug("Retrieving server list from file", d.server)
file, err := os.Open(d.server)
if err != nil {
- logger.FatalExit(d.server, err)
+ dlog.Common.FatalPanic(d.server, err)
}
defer file.Close()
@@ -22,7 +22,7 @@ func (d *Discovery) ServerListFromFILE() (servers []string) {
servers = append(servers, scanner.Text())
}
if err := scanner.Err(); err != nil {
- logger.FatalExit(d.server, err)
+ dlog.Common.FatalPanic(d.server, err)
}
return
diff --git a/internal/done.go b/internal/done.go
index 54e5e8e..94f9289 100644
--- a/internal/done.go
+++ b/internal/done.go
@@ -17,6 +17,15 @@ func NewDone() *Done {
}
}
+func (d *Done) String() string {
+ select {
+ case <-d.Done():
+ return "Done(yes)"
+ default:
+ return "Done(no)"
+ }
+}
+
// Done returns the done channel (closed when done)
func (d *Done) Done() <-chan struct{} {
return d.ch
@@ -26,7 +35,6 @@ func (d *Done) Done() <-chan struct{} {
func (d *Done) Shutdown() {
d.mutex.Lock()
defer d.mutex.Unlock()
-
select {
case <-d.ch:
return
diff --git a/internal/io/dlog/dlog.go b/internal/io/dlog/dlog.go
new file mode 100644
index 0000000..5e0c3a1
--- /dev/null
+++ b/internal/io/dlog/dlog.go
@@ -0,0 +1,272 @@
+package dlog
+
+import (
+ "context"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/color/brush"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog/loggers"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/protocol"
+ "github.com/mimecast/dtail/internal/source"
+)
+
+// Client is the log handler for the client packages.
+var Client *DLog
+
+// Server is the log handler for the server packages.
+var Server *DLog
+
+// Common is the log handler for all other packages.
+var Common *DLog
+
+var mutex sync.Mutex
+var started bool
+
+// Start logger(s).
+func Start(ctx context.Context, wg *sync.WaitGroup, sourceProcess source.Source) {
+ mutex.Lock()
+ defer mutex.Unlock()
+
+ if started {
+ Common.FatalPanic("Logger already started")
+ }
+
+ Client = new(sourceProcess, source.Client)
+ Server = new(sourceProcess, source.Server)
+ Common = Client
+ if sourceProcess == source.Server {
+ Common = Server
+ }
+
+ var wg2 sync.WaitGroup
+ wg2.Add(2)
+ go Client.start(ctx, &wg2)
+ go Server.start(ctx, &wg2)
+
+ go rotation(ctx)
+ go func() {
+ wg2.Wait()
+ wg.Done()
+ }()
+
+ started = true
+}
+
+// DLog is the DTail logger.
+type DLog struct {
+ logger loggers.Logger
+ // Is this a DTail server or client process logging?
+ sourceProcess source.Source
+ // Is this a DTail server or client package logging? In serverless mode
+ // the client can also execute code from the server package.
+ sourcePackage source.Source
+ // Max log level to log.
+ maxLevel level
+ // Current hostname.
+ hostname string
+}
+
+// new creates a new DTail logger.
+func new(sourceProcess, sourcePackage source.Source) *DLog {
+ hostname, err := os.Hostname()
+ if err != nil {
+ panic(err)
+ }
+ logRotation := loggers.NewStrategy(config.Common.LogRotation)
+ loggerName := config.Common.Logger
+ level := newLevel(config.Common.LogLevel)
+
+ return &DLog{
+ logger: loggers.Factory(sourceProcess.String(), loggerName, logRotation),
+ sourceProcess: sourceProcess,
+ sourcePackage: sourcePackage,
+ maxLevel: level,
+ hostname: hostname,
+ }
+}
+
+func (d *DLog) start(ctx context.Context, wg *sync.WaitGroup) {
+ defer wg.Done()
+ var wg2 sync.WaitGroup
+ wg2.Add(1)
+ d.logger.Start(ctx, &wg2)
+ <-ctx.Done()
+ wg2.Wait()
+}
+
+func (d *DLog) log(level level, args []interface{}) string {
+ if d.maxLevel < level {
+ return ""
+ }
+ sb := pool.BuilderBuffer.Get().(*strings.Builder)
+ defer pool.RecycleBuilderBuffer(sb)
+ now := time.Now()
+
+ switch d.sourceProcess {
+ case source.Client:
+ sb.WriteString(d.sourcePackage.String())
+ sb.WriteString(protocol.FieldDelimiter)
+ sb.WriteString(d.hostname)
+ sb.WriteString(protocol.FieldDelimiter)
+ sb.WriteString(level.String())
+ default:
+ sb.WriteString(level.String())
+ sb.WriteString(protocol.FieldDelimiter)
+ sb.WriteString(now.Format("20060102-150405"))
+ }
+ sb.WriteString(protocol.FieldDelimiter)
+ d.writeArgStrings(sb, args)
+
+ message := sb.String()
+ if !config.Client.TermColorsEnable || !d.logger.SupportsColors() {
+ d.logger.Log(now, message)
+ return message
+ }
+
+ d.logger.LogWithColors(now, message, brush.Colorfy(message))
+ return message
+}
+
+func (d *DLog) writeArgStrings(sb *strings.Builder, args []interface{}) {
+ for i, arg := range args {
+ if i > 0 {
+ sb.WriteString(protocol.FieldDelimiter)
+ }
+ switch v := arg.(type) {
+ case string:
+ sb.WriteString(v)
+ case error:
+ sb.WriteString(v.Error())
+ default:
+ sb.WriteString(fmt.Sprintf("%v", v))
+ }
+ }
+}
+
+// FatalPanic terminates the process with a fatal error.
+func (d *DLog) FatalPanic(args ...interface{}) {
+ d.log(Fatal, args)
+ d.Flush()
+
+ var sb strings.Builder
+ d.writeArgStrings(&sb, args)
+ panic(sb.String())
+}
+
+// Fatal logs a fatal error.
+func (d *DLog) Fatal(args ...interface{}) string {
+ return d.log(Fatal, args)
+}
+
+// Error logging.
+func (d *DLog) Error(args ...interface{}) string {
+ return d.log(Error, args)
+}
+
+// Warn logs a warning message.
+func (d *DLog) Warn(args ...interface{}) string {
+ return d.log(Warn, args)
+}
+
+// Info logging.
+func (d *DLog) Info(args ...interface{}) string {
+ return d.log(Info, args)
+}
+
+// Verbose logging.
+func (d *DLog) Verbose(args ...interface{}) string {
+ return d.log(Verbose, args)
+}
+
+// Debug logging.
+func (d *DLog) Debug(args ...interface{}) string {
+ return d.log(Debug, args)
+}
+
+// Trace logging.
+func (d *DLog) Trace(args ...interface{}) string {
+ _, file, line, _ := runtime.Caller(1)
+ args = append(args, fmt.Sprintf("at %s:%d", file, line))
+ return d.log(Trace, args)
+}
+
+// Devel used for development purpose only logging (e.g. "print" debugging).
+func (d *DLog) Devel(args ...interface{}) string {
+ _, file, line, _ := runtime.Caller(1)
+ args = append(args, fmt.Sprintf("at %s:%d", file, line))
+ return d.log(Devel, args)
+}
+
+// Raw message logging.
+func (d *DLog) Raw(message string) string {
+ if !config.Client.TermColorsEnable || !d.logger.SupportsColors() {
+ d.logger.Log(time.Now(), message)
+ return message
+ }
+ d.logger.LogWithColors(time.Now(), message, brush.Colorfy(message))
+ return message
+}
+
+// Mapreduce logging.
+func (d *DLog) Mapreduce(table string, data map[string]interface{}) string {
+ args := make([]interface{}, len(data)+1)
+
+ if d.sourceProcess == source.Server {
+ // level|date-time|process|caller|cpus|goroutines|cgocalls|loadavg|uptime|MAPREDUCE:TABLE|key=value|...
+
+ var loadAvg string
+ if loadAvgBytes, err := ioutil.ReadFile("/proc/loadavg"); err == nil {
+ tmp := string(loadAvgBytes)
+ s := strings.SplitN(tmp, " ", 2)
+ loadAvg = s[0]
+ }
+
+ var uptime string
+ if uptimeBytes, err := ioutil.ReadFile("/proc/uptime"); err == nil {
+ tmp := string(uptimeBytes)
+ s := strings.SplitN(tmp, ".", 2)
+ i, _ := strconv.ParseInt(s[0], 10, 64)
+ t := time.Duration(i) * time.Second
+ uptime = fmt.Sprintf("%v", t)
+ }
+
+ _, file, line, _ := runtime.Caller(1)
+ args[0] = fmt.Sprintf("%d|%s:%d|%d|%d|%d|%s|%s|MAPREDUCE:%s",
+ os.Getpid(),
+ filepath.Base(file), line,
+ runtime.NumCPU(),
+ runtime.NumGoroutine(),
+ runtime.NumCgoCall(),
+ loadAvg,
+ uptime,
+ strings.ToUpper(table))
+ } else {
+ args[0] = fmt.Sprintf("STATS:%s", strings.ToUpper(table))
+ }
+
+ i := 1
+ for k, v := range data {
+ args[i] = fmt.Sprintf("%s=%v", k, v)
+ i++
+ }
+ return d.log(Info, args)
+}
+
+// Flush the log buffers.
+func (d *DLog) Flush() { d.logger.Flush() }
+
+// Pause the logging.
+func (d *DLog) Pause() { d.logger.Pause() }
+
+// Resume the logging.
+func (d *DLog) Resume() { d.logger.Resume() }
diff --git a/internal/io/dlog/level.go b/internal/io/dlog/level.go
new file mode 100644
index 0000000..05d9ed9
--- /dev/null
+++ b/internal/io/dlog/level.go
@@ -0,0 +1,84 @@
+package dlog
+
+import (
+ "fmt"
+ "strings"
+)
+
+type level int
+
+// Available log levels.
+const (
+ None level = iota
+ Fatal level = iota
+ Error level = iota
+ Warn level = iota
+ Info level = iota
+ Default level = iota
+ Verbose level = iota
+ Debug level = iota
+ Devel level = iota
+ Trace level = iota
+ All level = iota
+)
+
+var allLevels = []level{Fatal, Error, Warn, Info, Default, Verbose, Debug,
+ Devel, Trace, All}
+
+func newLevel(l string) level {
+ switch strings.ToLower(l) {
+ case "none":
+ return None
+ case "fatal":
+ return Fatal
+ case "error":
+ return Error
+ case "warn":
+ return Warn
+ case "info":
+ return Info
+ case "":
+ fallthrough
+ case "default":
+ return Default
+ case "verbose":
+ return Verbose
+ case "debug":
+ return Debug
+ case "devel":
+ return Devel
+ case "trace":
+ return Trace
+ case "all":
+ return All
+ }
+ panic(fmt.Sprintf("Unknown log level %s, must be one of: %v", l, allLevels))
+}
+
+func (l level) String() string {
+ switch l {
+ case None:
+ return "NONE"
+ case Fatal:
+ return "FATAL"
+ case Error:
+ return "ERROR"
+ case Warn:
+ return "WARN"
+ case Info:
+ return "INFO"
+ case Default:
+ return "DEFAULT"
+ case Verbose:
+ return "VERBOSE"
+ case Debug:
+ return "DEBUG"
+ case Devel:
+ return "DEVEL"
+ case Trace:
+ return "TRACE"
+ case All:
+ return "ALL"
+ }
+ panic("Unknown log level " + fmt.Sprintf("%d", l))
+}
diff --git a/internal/io/dlog/loggers/factory.go b/internal/io/dlog/loggers/factory.go
new file mode 100644
index 0000000..a5cc7cf
--- /dev/null
+++ b/internal/io/dlog/loggers/factory.go
@@ -0,0 +1,54 @@
+package loggers
+
+import (
+ "fmt"
+ "strings"
+ "sync"
+)
+
+var factoryMap map[string]Logger
+var factoryMutex sync.Mutex
+
+// Factory is there to retrieve a logger based on various settings.
+func Factory(sourceName, loggerName string, logRotation Strategy) Logger {
+ factoryMutex.Lock()
+ defer factoryMutex.Unlock()
+
+ id := fmt.Sprintf("sourceName:%s,fileBase:%s,loggerName:%s", sourceName,
+ logRotation.FileBase, loggerName)
+ if factoryMap == nil {
+ factoryMap = make(map[string]Logger)
+ }
+
+ singleton, ok := factoryMap[id]
+ if !ok {
+ switch strings.ToLower(loggerName) {
+ case "none":
+ singleton = none{}
+ case "stdout":
+ singleton = newStdout()
+ factoryMap[id] = singleton
+ case "file":
+ singleton = newFile(logRotation)
+ factoryMap[id] = singleton
+ case "fout":
+ singleton = newFout(logRotation)
+ factoryMap[id] = singleton
+ default:
+ panic(fmt.Sprintf("Unsupported logger type '%s'", loggerName))
+ }
+ }
+ return singleton
+}
+
+// FactoryRotate invokes a log rotation of all loggers.
+func FactoryRotate() {
+ factoryMutex.Lock()
+ defer factoryMutex.Unlock()
+ if factoryMap == nil {
+ return
+ }
+ for _, logger := range factoryMap {
+ logger.Rotate()
+ }
+}
diff --git a/internal/io/dlog/loggers/file.go b/internal/io/dlog/loggers/file.go
new file mode 100644
index 0000000..94824fe
--- /dev/null
+++ b/internal/io/dlog/loggers/file.go
@@ -0,0 +1,165 @@
+package loggers
+
+import (
+ "bufio"
+ "context"
+ "fmt"
+ "os"
+ "runtime"
+ "sync"
+ "time"
+
+ "github.com/mimecast/dtail/internal/config"
+)
+
+type fileWriter struct{}
+
+type fileMessageBuf struct {
+ now time.Time
+ message string
+}
+
+type file struct {
+ bufferCh chan *fileMessageBuf
+ pauseCh chan struct{}
+ resumeCh chan struct{}
+ rotateCh chan struct{}
+ flushCh chan struct{}
+ fd *os.File
+ writer *bufio.Writer
+ mutex sync.Mutex
+ started bool
+ lastFileName string
+ strategy Strategy
+}
+
+func newFile(strategy Strategy) *file {
+ return &file{
+ bufferCh: make(chan *fileMessageBuf, runtime.NumCPU()*100),
+ pauseCh: make(chan struct{}),
+ resumeCh: make(chan struct{}),
+ rotateCh: make(chan struct{}),
+ flushCh: make(chan struct{}),
+ strategy: strategy,
+ }
+}
+
+func (f *file) Start(ctx context.Context, wg *sync.WaitGroup) {
+ f.mutex.Lock()
+ defer func() {
+ f.started = true
+ f.mutex.Unlock()
+ }()
+
+ if f.started {
+ // Logger already started from another Goroutine.
+ wg.Done()
+ return
+ }
+
+ pause := func(ctx context.Context) {
+ select {
+ case <-f.resumeCh:
+ return
+ case <-ctx.Done():
+ return
+ }
+ }
+
+ go func() {
+ defer wg.Done()
+ for {
+ select {
+ case m := <-f.bufferCh:
+ f.write(m)
+ case <-f.pauseCh:
+ pause(ctx)
+ case <-f.flushCh:
+ f.flush()
+ case <-ctx.Done():
+ f.flush()
+ f.fd.Close()
+ return
+ }
+ }
+ }()
+}
+
+func (f *file) Log(now time.Time, message string) {
+ f.bufferCh <- &fileMessageBuf{now, message}
+}
+
+func (f *file) LogWithColors(now time.Time, message, coloredMessage string) {
+ panic("Colors not supported in file logger")
+}
+
+func (f *file) Pause() { f.pauseCh <- struct{}{} }
+func (f *file) Resume() { f.resumeCh <- struct{}{} }
+func (f *file) Flush() { f.flushCh <- struct{}{} }
+
+func (f *file) Rotate() { f.rotateCh <- struct{}{} }
+func (*file) SupportsColors() bool { return false }
+
+func (f *file) write(m *fileMessageBuf) {
+ select {
+ case <-f.rotateCh:
+ // Force re-opening the outfile next time in getWriter.
+ f.lastFileName = ""
+ default:
+ }
+
+ var writer *bufio.Writer
+ if f.strategy.Rotation == DailyRotation {
+ writer = f.getWriter(m.now.Format("20060102"))
+ } else {
+ writer = f.getWriter(f.strategy.FileBase)
+ }
+
+ writer.WriteString(m.message)
+ writer.WriteByte('\n')
+}
+
+func (f *file) getWriter(name string) *bufio.Writer {
+ if f.lastFileName == name {
+ return f.writer
+ }
+ if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) {
+ if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil {
+ panic(err)
+ }
+ }
+
+ logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, name)
+ newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
+ if err != nil {
+ panic(err)
+ }
+
+ // Close old writer.
+ if f.fd != nil {
+ f.writer.Flush()
+ f.fd.Close()
+ }
+ // Set new writer.
+ f.fd = newFd
+ f.writer = bufio.NewWriterSize(f.fd, 1)
+ f.lastFileName = name
+
+ return f.writer
+}
+
+func (f *file) flush() {
+ defer func() {
+ if f.writer != nil {
+ f.writer.Flush()
+ }
+ }()
+ for {
+ select {
+ case m := <-f.bufferCh:
+ f.write(m)
+ default:
+ return
+ }
+ }
+}
diff --git a/internal/io/dlog/loggers/fout.go b/internal/io/dlog/loggers/fout.go
new file mode 100644
index 0000000..60c318d
--- /dev/null
+++ b/internal/io/dlog/loggers/fout.go
@@ -0,0 +1,46 @@
+package loggers
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+type fout struct {
+ file *file
+ stdout *stdout
+}
+
+// Logs to both, a file and stdout
+func newFout(strategy Strategy) *fout {
+ return &fout{file: newFile(strategy), stdout: newStdout()}
+}
+
+func (f *fout) Start(ctx context.Context, wg *sync.WaitGroup) {
+ go func() {
+ defer wg.Done()
+
+ var wg2 sync.WaitGroup
+ wg2.Add(2)
+ f.file.Start(ctx, &wg2)
+ f.stdout.Start(ctx, &wg2)
+ wg2.Wait()
+ }()
+}
+
+func (f *fout) Log(now time.Time, message string) {
+ f.stdout.Log(now, message)
+ f.file.Log(now, message)
+}
+
+func (f *fout) LogWithColors(now time.Time, message, coloredMessage string) {
+ f.stdout.LogWithColors(now, "", coloredMessage)
+ f.file.Log(now, message)
+}
+
+func (f *fout) Flush() { f.stdout.Flush(); f.file.Flush() }
+func (f *fout) Pause() { f.stdout.Pause(); f.file.Pause() }
+func (f *fout) Resume() { f.stdout.Resume(); f.file.Resume() }
+func (f *fout) Rotate() { f.file.Rotate() }
+
+func (fout) SupportsColors() bool { return true }
diff --git a/internal/io/dlog/loggers/logger.go b/internal/io/dlog/loggers/logger.go
new file mode 100644
index 0000000..d4e85de
--- /dev/null
+++ b/internal/io/dlog/loggers/logger.go
@@ -0,0 +1,19 @@
+package loggers
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+// Logger is there to plug in your own log implementation.
+type Logger interface {
+ Log(now time.Time, message string)
+ LogWithColors(now time.Time, message, messageWithColors string)
+ Start(ctx context.Context, wg *sync.WaitGroup)
+ Flush()
+ Pause()
+ Resume()
+ Rotate()
+ SupportsColors() bool
+}
diff --git a/internal/io/dlog/loggers/none.go b/internal/io/dlog/loggers/none.go
new file mode 100644
index 0000000..270027f
--- /dev/null
+++ b/internal/io/dlog/loggers/none.go
@@ -0,0 +1,21 @@
+package loggers
+
+import (
+ "context"
+ "sync"
+ "time"
+)
+
+// don't log anything
+type none struct{}
+
+func (none) Start(ctx context.Context, wg *sync.WaitGroup) { wg.Done() }
+func (none) Log(now time.Time, message string) {}
+
+func (none) LogWithColors(now time.Time, message, coloredMessage string) {}
+
+func (none) Flush() {}
+func (none) Pause() {}
+func (none) Resume() {}
+func (none) Rotate() {}
+func (none) SupportsColors() bool { return false }
diff --git a/internal/io/dlog/loggers/stdout.go b/internal/io/dlog/loggers/stdout.go
new file mode 100644
index 0000000..05485c6
--- /dev/null
+++ b/internal/io/dlog/loggers/stdout.go
@@ -0,0 +1,54 @@
+package loggers
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+)
+
+type stdout struct {
+ pauseCh chan struct{}
+ resumeCh chan struct{}
+ mutex sync.Mutex
+}
+
+func newStdout() *stdout {
+ return &stdout{
+ pauseCh: make(chan struct{}),
+ resumeCh: make(chan struct{}),
+ }
+}
+
+func (s *stdout) Start(ctx context.Context, wg *sync.WaitGroup) {
+ wg.Done()
+}
+
+func (s *stdout) Log(now time.Time, message string) {
+ s.log(message)
+}
+
+func (s *stdout) LogWithColors(now time.Time, message, coloredMessage string) {
+ s.log(coloredMessage)
+}
+
+func (s *stdout) log(message string) {
+ s.mutex.Lock()
+ defer s.mutex.Unlock()
+
+ select {
+ case <-s.pauseCh:
+ // Pause until resumed.
+ <-s.resumeCh
+ default:
+ }
+
+ fmt.Println(message)
+}
+
+func (s *stdout) Pause() { s.pauseCh <- struct{}{} }
+func (s *stdout) Resume() { s.resumeCh <- struct{}{} }
+func (s *stdout) Flush() {}
+func (s *stdout) Rotate() {}
+
+func (stdout) SupportsColors() bool { return true }
diff --git a/internal/io/dlog/loggers/strategy.go b/internal/io/dlog/loggers/strategy.go
new file mode 100644
index 0000000..48e7d44
--- /dev/null
+++ b/internal/io/dlog/loggers/strategy.go
@@ -0,0 +1,35 @@
+package loggers
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+)
+
+// Rotation is the actual strategy used for log rotation..
+type Rotation int
+
+const (
+ // DailyRotation tells DTail to rotate its logs on a daily basis or on SIGHUP.
+ DailyRotation Rotation = iota
+ // SignalRotation tells DTail to rotate its logs only on SIGHUP.
+ SignalRotation Rotation = iota
+)
+
+// Strategy is a pair of the rotation and the file base.
+type Strategy struct {
+ // Rotation is the actual rotation strategy used.
+ Rotation Rotation
+ // FileBase can be a name (e.g. "dserver", "dmap") when signal rotation is used.
+ FileBase string
+}
+
+// NewStrategy returns the stratey based on its name.
+func NewStrategy(name string) Strategy {
+ switch strings.ToLower(name) {
+ case "daily":
+ return Strategy{DailyRotation, ""}
+ default:
+ return Strategy{SignalRotation, filepath.Base(os.Args[0])}
+ }
+}
diff --git a/internal/io/dlog/rotation.go b/internal/io/dlog/rotation.go
new file mode 100644
index 0000000..15ce1fd
--- /dev/null
+++ b/internal/io/dlog/rotation.go
@@ -0,0 +1,27 @@
+package dlog
+
+import (
+ "context"
+ "os"
+ "os/signal"
+ "syscall"
+
+ "github.com/mimecast/dtail/internal/io/dlog/loggers"
+)
+
+func rotation(ctx context.Context) {
+ rotateCh := make(chan os.Signal, 1)
+ signal.Notify(rotateCh, syscall.SIGHUP)
+ go func() {
+ for {
+ select {
+ case <-rotateCh:
+ Common.Debug("Invoking log rotation")
+ loggers.FactoryRotate()
+ return
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+}
diff --git a/internal/io/fs/catfile.go b/internal/io/fs/catfile.go
index 7f387bc..01c15ba 100644
--- a/internal/io/fs/catfile.go
+++ b/internal/io/fs/catfile.go
@@ -6,7 +6,9 @@ type CatFile struct {
}
// NewCatFile returns a new file catter.
-func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile {
+func NewCatFile(filePath string, globID string, serverMessages chan<- string,
+ limiter chan struct{}) CatFile {
+
return CatFile{
readFile: readFile{
filePath: filePath,
diff --git a/internal/io/fs/filereader.go b/internal/io/fs/filereader.go
index efd410e..b05fd39 100644
--- a/internal/io/fs/filereader.go
+++ b/internal/io/fs/filereader.go
@@ -8,9 +8,11 @@ import (
"github.com/mimecast/dtail/internal/regex"
)
-// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file.
+// FileReader is the interface used on the dtail server to read/cat/grep/mapr...
+// a file.
type FileReader interface {
- Start(ctx context.Context, lContext lcontext.LContext, lines chan<- line.Line, re regex.Regex) error
+ Start(ctx context.Context, ltx lcontext.LContext, lines chan<- line.Line,
+ re regex.Regex) error
FilePath() string
Retry() bool
}
diff --git a/internal/io/fs/filter.go b/internal/io/fs/filter.go
deleted file mode 100644
index c4f605e..0000000
--- a/internal/io/fs/filter.go
+++ /dev/null
@@ -1,167 +0,0 @@
-package fs
-
-import (
- "context"
-
- "github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/lcontext"
- "github.com/mimecast/dtail/internal/regex"
-)
-
-func (f readFile) filter(ctx context.Context, rawLines <-chan []byte, lines chan<- line.Line, re regex.Regex, lContext lcontext.LContext) {
- // Do we have any kind of local context settings? If so then run the more complex
- // filterWithLContext method.
- if lContext.Has() {
- // We can not skip transmitting any lines to the client with a local
- // grep context specified.
- f.canSkipLines = false
- f.filterWithLContext(ctx, rawLines, lines, re, lContext)
- return
- }
-
- f.filterWithoutLContext(ctx, rawLines, lines, re)
-}
-
-// Filter log lines matching a given regular expression, however with local grep context.
-func (f readFile) filterWithLContext(ctx context.Context, rawLines <-chan []byte, lines chan<- line.Line, re regex.Regex, lContext lcontext.LContext) {
- // Scenario 1: Finish once maxCount hits found
- maxCount := lContext.MaxCount
- processMaxCount := maxCount > 0
- maxReached := false
-
- // Scenario 2: Print prev. N lines when current line matches.
- before := lContext.BeforeContext
- processBefore := before > 0
- var beforeBuf chan []byte
- if processBefore {
- beforeBuf = make(chan []byte, before)
- }
-
- // Screnario 3: Print next N lines when current line matches.
- after := 0
- processAfter := lContext.AfterContext > 0
-
- for rawLine := range rawLines {
- // logger.Debug("rawLine", string(rawLine))
- f.updatePosition()
-
- if !re.Match(rawLine) {
- f.updateLineNotMatched()
-
- if processAfter && after > 0 {
- after--
- myLine := line.Line{Content: rawLine, SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: 100}
- select {
- case lines <- myLine:
- case <-ctx.Done():
- return
- }
-
- } else if processBefore {
- // Keep last num BeforeContext raw messages.
- select {
- case beforeBuf <- rawLine:
- default:
- <-beforeBuf
- beforeBuf <- rawLine
- }
- }
- continue
- }
-
- f.updateLineMatched()
-
- if processAfter {
- if maxReached {
- return
- }
- after = lContext.AfterContext
- }
-
- if processBefore {
- i := uint64(len(beforeBuf))
- for {
- select {
- case myRawLine := <-beforeBuf:
- myLine := line.Line{Content: myRawLine, SourceID: f.globID, Count: f.totalLineCount() - i, TransmittedPerc: 100}
- i--
- select {
- case lines <- myLine:
- case <-ctx.Done():
- return
- }
- default:
- // beforeBuf is now empty.
- }
- if len(beforeBuf) == 0 {
- break
- }
- }
- }
-
- line := line.Line{Content: rawLine, SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: 100}
-
- select {
- case lines <- line:
- if processMaxCount {
- maxCount--
- if maxCount == 0 {
- if !processAfter || after == 0 {
- return
- }
- // Unfortunatley we have to continue filter, as there might be more lines to print
- maxReached = true
- }
- }
- case <-ctx.Done():
- return
- }
- }
-}
-
-// Filter log lines matching a given regular expression, there is no local grep context specified.
-func (f readFile) filterWithoutLContext(ctx context.Context, rawLines <-chan []byte, lines chan<- line.Line, re regex.Regex) {
- for {
- select {
- case rawLine, ok := <-rawLines:
- f.updatePosition()
- if !ok {
- return
- }
-
- if f.lineUntransmittable(rawLine, len(lines), cap(lines), re) {
- continue
- }
-
- line := line.Line{Content: rawLine, SourceID: f.globID, Count: f.totalLineCount(), TransmittedPerc: f.transmittedPerc()}
-
- select {
- case lines <- line:
- continue
- case <-ctx.Done():
- return
- }
- }
- }
-}
-
-func (f readFile) lineUntransmittable(rawLine []byte, length, capacity int, re regex.Regex) bool {
- if !re.Match(rawLine) {
- f.updateLineNotMatched()
- f.updateLineNotTransmitted()
- // Regex dosn't match, so not interested in it.
- return true
- }
- f.updateLineMatched()
-
- // Can we actually send more messages, channel capacity reached?
- if f.canSkipLines && length >= capacity {
- f.updateLineNotTransmitted()
- // Matching, not transmittable
- return true
- }
- f.updateLineTransmitted()
-
- // Matching, transmittable
- return false
-}
diff --git a/internal/io/fs/permissions/permission.go b/internal/io/fs/permissions/permission.go
index cc5dd9b..d621c09 100644
--- a/internal/io/fs/permissions/permission.go
+++ b/internal/io/fs/permissions/permission.go
@@ -3,12 +3,12 @@
package permissions
import (
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// ToRead is to check whether user has read permissions to a given file.
func ToRead(user, filePath string) (bool, error) {
// Only implemented for Linux, always expect true
- logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform")
+ dlog.Common.Debug(user, filePath, "Not performing ACL check as not compiled in")
return true, nil
}
diff --git a/internal/io/fs/permissions/permission_linuxacl.go b/internal/io/fs/permissions/permission_linuxacl.go
index 7d2d7ca..904b90f 100644
--- a/internal/io/fs/permissions/permission_linuxacl.go
+++ b/internal/io/fs/permissions/permission_linuxacl.go
@@ -13,7 +13,7 @@ import (
"unsafe"
)
-// ToRead checks whether user has Linux file system permissions to read a given file.
+// ToRead checks whether user has Linux file system permissions to read a file.
func ToRead(user, filePath string) (bool, error) {
cUser := C.CString(user)
cFilePath := C.CString(filePath)
diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go
index 161e3f0..28cbe58 100644
--- a/internal/io/fs/readfile.go
+++ b/internal/io/fs/readfile.go
@@ -2,16 +2,20 @@ package fs
import (
"bufio"
+ "bytes"
"compress/gzip"
"context"
+ "errors"
"fmt"
"io"
"os"
"strings"
+ "sync"
"time"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/pool"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/regex"
@@ -37,31 +41,10 @@ type readFile struct {
limiter chan struct{}
}
-func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) {
- switch {
- case strings.HasSuffix(f.FilePath(), ".gz"):
- fallthrough
- case strings.HasSuffix(f.FilePath(), ".gzip"):
- logger.Info(f.FilePath(), "Detected gzip compression format")
- var gzipReader *gzip.Reader
- gzipReader, err = gzip.NewReader(fd)
- if err != nil {
- return
- }
- reader = bufio.NewReader(gzipReader)
- case strings.HasSuffix(f.FilePath(), ".zst"):
- logger.Info(f.FilePath(), "Detected zstd compression format")
- reader = bufio.NewReader(zstd.NewReader(fd))
- default:
- reader = bufio.NewReader(fd)
- }
-
- return
-}
-
// String returns the string representation of the readFile
func (f readFile) String() string {
- return fmt.Sprintf("readFile(filePath:%s,globID:%s,retry:%v,canSkipLines:%v,seekEOF:%v)",
+ return fmt.Sprintf(
+ "readFile(filePath:%s,globID:%s,retry:%v,canSkipLines:%v,seekEOF:%v)",
f.filePath,
f.globID,
f.retry,
@@ -80,8 +63,10 @@ func (f readFile) Retry() bool {
}
// Start tailing a log file.
-func (f readFile) Start(ctx context.Context, lContext lcontext.LContext, lines chan<- line.Line, re regex.Regex) error {
- logger.Debug("readFile", f)
+func (f readFile) Start(ctx context.Context, ltx lcontext.LContext,
+ lines chan<- line.Line, re regex.Regex) error {
+
+ dlog.Common.Debug("readFile", f)
defer func() {
select {
case <-f.limiter:
@@ -93,7 +78,8 @@ func (f readFile) Start(ctx context.Context, lContext lcontext.LContext, lines c
case f.limiter <- struct{}{}:
default:
select {
- case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."):
+ case f.serverMessages <- dlog.Common.Warn(f.filePath, f.globID,
+ "Server limit reached. Queuing file..."):
case <-ctx.Done():
return nil
}
@@ -110,111 +96,335 @@ func (f readFile) Start(ctx context.Context, lContext lcontext.LContext, lines c
fd.Seek(0, io.SeekEnd)
}
- rawLines := make(chan []byte, 100)
+ rawLines := make(chan *bytes.Buffer, 100)
+ truncate := make(chan struct{})
+
readCtx, readCancel := context.WithCancel(ctx)
+ var filterWg sync.WaitGroup
+ filterWg.Add(1)
- filterDone := make(chan struct{})
+ go f.periodicTruncateCheck(ctx, truncate)
go func() {
- f.filter(ctx, rawLines, lines, re, lContext)
- close(filterDone)
+ f.filter(ctx, ltx, rawLines, lines, re)
+ filterWg.Done()
// If the filter stopped, make the reader stop too, no need to read
// more data if there is nothing more the filter wants to filter for!
// E.g. it could be that we only want to filter N matches but not more.
readCancel()
}()
- err = f.read(readCtx, fd, rawLines)
+ err = f.read(readCtx, fd, rawLines, truncate)
close(rawLines)
-
- // Filter may flushes some data still. So wait until it is done here.
- <-filterDone
+ // Filter may sends some data still. So wait until it is done here.
+ filterWg.Wait()
return err
}
-func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan []byte) error {
- var offset uint64
+func (f readFile) periodicTruncateCheck(ctx context.Context, truncate chan struct{}) {
+ for {
+ select {
+ case <-time.After(time.Second * 3):
+ select {
+ case truncate <- struct{}{}:
+ case <-ctx.Done():
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) {
+ switch {
+ case strings.HasSuffix(f.FilePath(), ".gz"):
+ fallthrough
+ case strings.HasSuffix(f.FilePath(), ".gzip"):
+ dlog.Common.Info(f.FilePath(), "Detected gzip compression format")
+ var gzipReader *gzip.Reader
+ gzipReader, err = gzip.NewReader(fd)
+ if err != nil {
+ return
+ }
+ reader = bufio.NewReader(gzipReader)
+ case strings.HasSuffix(f.FilePath(), ".zst"):
+ dlog.Common.Info(f.FilePath(), "Detected zstd compression format")
+ reader = bufio.NewReader(zstd.NewReader(fd))
+ default:
+ reader = bufio.NewReader(fd)
+ }
+ return
+}
+
+func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan *bytes.Buffer, truncate <-chan struct{}) error {
+ var offset uint64
reader, err := f.makeReader(fd)
if err != nil {
return err
}
- rawLine := make([]byte, 0, 512)
lineLengthThreshold := 1024 * 1024 // 1mb
- longLineWarning := false
-
- checkTruncate := f.truncateTimer(ctx)
+ warnedAboutLongLine := false
+ message := pool.BytesBuffer.Get().(*bytes.Buffer)
for {
- select {
- case <-ctx.Done():
- return nil
- default:
- }
-
- select {
- case <-checkTruncate:
- if isTruncated, err := f.truncated(fd); isTruncated {
- return err
- }
- logger.Info(f.filePath, "Current offset", offset)
- default:
- }
-
- // Read some bytes (max 4k at once as of go 1.12). isPrefix will
- // be set if line does not fit into 4k buffer.
- bytes, isPrefix, err := reader.ReadLine()
+ b, err := reader.ReadByte()
if err != nil {
- // If EOF, sleep a couple of ms and return with nil error.
- // If other error, return with non-nil error.
if err != io.EOF {
return err
}
+ select {
+ case <-truncate:
+ if isTruncated, err := f.truncated(fd); isTruncated {
+ return err
+ }
+ case <-ctx.Done():
+ return nil
+ default:
+ }
if !f.seekEOF {
- logger.Debug(f.FilePath(), "End of file reached")
+ dlog.Common.Info(f.FilePath(), "End of file reached")
return nil
}
time.Sleep(time.Millisecond * 100)
continue
}
+ offset++
- rawLine = append(rawLine, bytes...)
- offset += uint64(len(bytes))
-
- if !isPrefix {
- // last LineRead call returned contend until end of line.
- rawLine = append(rawLine, '\n')
+ switch b {
+ case '\n':
select {
- case rawLines <- rawLine:
+ case rawLines <- message:
+ message = pool.BytesBuffer.Get().(*bytes.Buffer)
+ //fmt.Printf("%d %d %p\n", message.Len(), message.Cap(), message)
+ warnedAboutLongLine = false
case <-ctx.Done():
return nil
}
- rawLine = make([]byte, 0, 512)
- if longLineWarning {
- longLineWarning = false
+ default:
+ if message.Len() >= lineLengthThreshold {
+ if !warnedAboutLongLine {
+ f.serverMessages <- dlog.Common.Warn(f.filePath,
+ "Long log line, splitting into multiple lines")
+ warnedAboutLongLine = true
+ }
+ message.WriteString("\n")
+ select {
+ case rawLines <- message:
+ message = pool.BytesBuffer.Get().(*bytes.Buffer)
+ case <-ctx.Done():
+ return nil
+ }
+ }
+ message.WriteByte(b)
+ }
+ }
+}
+
+// Filter log lines matching a given regular expression.
+func (f readFile) filter(ctx context.Context, ltx lcontext.LContext,
+ rawLines <-chan *bytes.Buffer, lines chan<- line.Line, re regex.Regex) {
+
+ // Do we have any kind of local context settings? If so then run the more complex
+ // filterWithLContext method.
+ if ltx.Has() {
+ // We can not skip transmitting any lines to the client with a local
+ // grep context specified.
+ f.canSkipLines = false
+ f.filterWithLContext(ctx, ltx, rawLines, lines, re)
+ return
+ }
+
+ f.filterWithoutLContext(ctx, rawLines, lines, re)
+}
+
+func (f readFile) filterWithoutLContext(ctx context.Context, rawLines <-chan *bytes.Buffer,
+ lines chan<- line.Line, re regex.Regex) {
+
+ for {
+ select {
+ case line, ok := <-rawLines:
+ f.updatePosition()
+ if !ok {
+ return
+ }
+ if filteredLine, ok := f.transmittable(line, len(lines), cap(lines), re); ok {
+ select {
+ case lines <- filteredLine:
+ case <-ctx.Done():
+ return
+ }
+ }
+ }
+ }
+}
+
+// Filter log lines matching a given regular expression, however with local grep context.
+func (f readFile) filterWithLContext(ctx context.Context, ltx lcontext.LContext,
+ rawLines <-chan *bytes.Buffer, lines chan<- line.Line, re regex.Regex) {
+
+ // Scenario 1: Finish once maxCount hits found
+ maxCount := ltx.MaxCount
+ processMaxCount := maxCount > 0
+ maxReached := false
+
+ // Scenario 2: Print prev. N lines when current line matches.
+ before := ltx.BeforeContext
+ processBefore := before > 0
+ var beforeBuf chan *bytes.Buffer
+ if processBefore {
+ beforeBuf = make(chan *bytes.Buffer, before)
+ }
+
+ // Screnario 3: Print next N lines when current line matches.
+ after := 0
+ processAfter := ltx.AfterContext > 0
+
+ for lineBytesBuffer := range rawLines {
+ f.updatePosition()
+
+ if !re.Match(lineBytesBuffer.Bytes()) {
+ f.updateLineNotMatched()
+
+ if processAfter && after > 0 {
+ after--
+ myLine := line.Line{
+ Content: lineBytesBuffer,
+ SourceID: f.globID,
+ Count: f.totalLineCount(),
+ TransmittedPerc: 100,
+ }
+
+ select {
+ case lines <- myLine:
+ case <-ctx.Done():
+ return
+ }
+
+ } else if processBefore {
+ // Keep last num BeforeContext raw messages.
+ select {
+ case beforeBuf <- lineBytesBuffer:
+ default:
+ pool.RecycleBytesBuffer(<-beforeBuf)
+ beforeBuf <- lineBytesBuffer
+ }
}
continue
}
- // Last LineRead call could not read content until end of line, buffer
- // was too small. Determine whether we exceed the max line length we
- // want dtail to send to the client at once. Possibly split up log line
- // into multiple log lines.
- if len(rawLine) >= lineLengthThreshold {
- if !longLineWarning {
- f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines")
- // Only print out one warning per long log line.
- longLineWarning = true
+ f.updateLineMatched()
+
+ if processAfter {
+ if maxReached {
+ return
}
- rawLine = append(rawLine, '\n')
- select {
- case rawLines <- rawLine:
- case <-ctx.Done():
- return nil
+ after = ltx.AfterContext
+ }
+
+ if processBefore {
+ i := uint64(len(beforeBuf))
+ for {
+ select {
+ case lineBytesBuffer := <-beforeBuf:
+ myLine := line.Line{
+ Content: lineBytesBuffer,
+ SourceID: f.globID,
+ Count: f.totalLineCount() - i,
+ TransmittedPerc: 100,
+ }
+ i--
+
+ select {
+ case lines <- myLine:
+ case <-ctx.Done():
+ return
+ }
+ default:
+ // beforeBuf is now empty.
+ }
+ if len(beforeBuf) == 0 {
+ break
+ }
}
- rawLine = make([]byte, 0, 512)
}
+
+ line := line.Line{
+ Content: lineBytesBuffer,
+ SourceID: f.globID,
+ Count: f.totalLineCount(),
+ TransmittedPerc: 100,
+ }
+
+ select {
+ case lines <- line:
+ if processMaxCount {
+ maxCount--
+ if maxCount == 0 {
+ if !processAfter || after == 0 {
+ return
+ }
+ // Unfortunatley we have to continue filter, as there might be more lines to print
+ maxReached = true
+ }
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (f readFile) transmittable(lineBytesBuffer *bytes.Buffer, length, capacity int,
+ re regex.Regex) (line.Line, bool) {
+
+ var read line.Line
+ if !re.Match(lineBytesBuffer.Bytes()) {
+ f.updateLineNotMatched()
+ f.updateLineNotTransmitted()
+ return read, false
+ }
+ f.updateLineMatched()
+
+ // Can we actually send more messages, channel capacity reached?
+ if f.canSkipLines && length >= capacity {
+ f.updateLineNotTransmitted()
+ return read, false
+ }
+ f.updateLineTransmitted()
+
+ read = line.Line{
+ Content: lineBytesBuffer,
+ SourceID: f.globID,
+ Count: f.totalLineCount(),
+ TransmittedPerc: f.transmittedPerc(),
+ }
+ return read, true
+}
+
+// Check wether log file is truncated. Returns nil if not.
+func (f readFile) truncated(fd *os.File) (bool, error) {
+ dlog.Common.Debug(f.filePath, "File truncation check")
+
+ // Can not seek currently open FD.
+ curPos, err := fd.Seek(0, os.SEEK_CUR)
+ if err != nil {
+ return true, err
+ }
+ // Can not open file at original path.
+ pathFd, err := os.Open(f.filePath)
+ if err != nil {
+ return true, err
+ }
+ defer pathFd.Close()
+
+ // Can not seek file at original path.
+ pathPos, err := pathFd.Seek(0, io.SeekEnd)
+ if err != nil {
+ return true, err
+ }
+ if curPos > pathPos {
+ return true, errors.New("File got truncated")
}
+ return false, nil
}
diff --git a/internal/io/fs/tailfile.go b/internal/io/fs/tailfile.go
index 14994e5..b03b45d 100644
--- a/internal/io/fs/tailfile.go
+++ b/internal/io/fs/tailfile.go
@@ -6,7 +6,9 @@ type TailFile struct {
}
// NewTailFile returns a new file tailer.
-func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile {
+func NewTailFile(filePath string, globID string, serverMessages chan<- string,
+ limiter chan struct{}) TailFile {
+
return TailFile{
readFile: readFile{
filePath: filePath,
diff --git a/internal/io/fs/truncate.go b/internal/io/fs/truncate.go
deleted file mode 100644
index a8d59ac..0000000
--- a/internal/io/fs/truncate.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package fs
-
-import (
- "context"
- "errors"
- "io"
- "os"
- "time"
-
- "github.com/mimecast/dtail/internal/io/logger"
-)
-
-func (f readFile) truncateTimer(ctx context.Context) (checkTruncate chan struct{}) {
- checkTruncate = make(chan struct{})
-
- go func() {
- for {
- select {
- case <-time.After(time.Second * 3):
- select {
- case checkTruncate <- struct{}{}:
- case <-ctx.Done():
- }
- case <-ctx.Done():
- return
- }
- }
- }()
-
- return
-}
-
-// Check wether log file is truncated. Returns nil if not.
-func (f readFile) truncated(fd *os.File) (bool, error) {
- logger.Debug(f.filePath, "File truncation check")
-
- // Can not seek currently open FD.
- curPos, err := fd.Seek(0, os.SEEK_CUR)
- if err != nil {
- return true, err
- }
-
- // Can not open file at original path.
- pathFd, err := os.Open(f.filePath)
- if err != nil {
- return true, err
- }
- defer pathFd.Close()
-
- // Can not seek file at original path.
- pathPos, err := pathFd.Seek(0, io.SeekEnd)
- if err != nil {
- return true, err
- }
-
- if curPos > pathPos {
- return true, errors.New("File got truncated")
- }
-
- return false, nil
-}
diff --git a/internal/io/line/line.go b/internal/io/line/line.go
index 715be34..d306c88 100644
--- a/internal/io/line/line.go
+++ b/internal/io/line/line.go
@@ -1,13 +1,14 @@
package line
import (
+ "bytes"
"fmt"
)
// Line represents a read log line.
type Line struct {
// The content of the log line.
- Content []byte
+ Content *bytes.Buffer
// Until now, how many log lines were processed?
Count uint64
// Sometimes we produce too many log lines so that the client
@@ -25,7 +26,7 @@ type Line struct {
// Return a human readable representation of the followed line.
func (l Line) String() string {
return fmt.Sprintf("Line(Content:%s,TransmittedPerc:%v,Count:%v,SourceID:%s)",
- string(l.Content),
+ l.Content.String(),
l.TransmittedPerc,
l.Count,
l.SourceID)
diff --git a/internal/io/logger/logger.go b/internal/io/logger/logger.go
deleted file mode 100644
index 4254eef..0000000
--- a/internal/io/logger/logger.go
+++ /dev/null
@@ -1,403 +0,0 @@
-package logger
-
-import (
- "bufio"
- "context"
- "fmt"
- "os"
- "os/signal"
- "runtime"
- "strings"
- "sync"
- "syscall"
- "time"
-
- "github.com/mimecast/dtail/internal/color"
- "github.com/mimecast/dtail/internal/config"
-)
-
-const (
- clientStr string = "CLIENT"
- serverStr string = "SERVER"
- infoStr string = "INFO"
- warnStr string = "WARN"
- errorStr string = "ERROR"
- fatalStr string = "FATAL"
- debugStr string = "DEBUG"
- traceStr string = "TRACE"
-)
-
-// Mode specifies the configured logging mode(s)
-var Mode Modes
-
-// Strategy is the current log strattegy used.
-var strategy Strategy
-
-// Synchronise access to logging.
-var mutex sync.Mutex
-
-// File descriptor of log file when Mode.logToFile enabled.
-var fd *os.File
-
-// File write buffer of log file when Mode.logToFile enabled.
-var writer *bufio.Writer
-
-// File write buffer of stdout when Mode.logToStdout enabled.
-var stdoutWriter *bufio.Writer
-
-// Current hostname.
-var hostname string
-
-// Used to detect change of day (create one log file per day0
-var lastDateStr string
-
-// Used to make logging non-blocking.
-var fileLogBufCh chan buf
-var stdoutBufCh chan string
-
-// Stdout channel, required to pause output
-var pauseCh chan struct{}
-var resumeCh chan struct{}
-
-// Tell the logger about logrotation
-var rotateCh chan os.Signal
-
-// Helper type to make logging non-blocking.
-type buf struct {
- time time.Time
- message string
-}
-
-// Start logging.
-func Start(ctx context.Context, mode Modes) {
- Mode = mode
-
- switch {
- case Mode.Nothing:
- return
- case Mode.Quiet:
- Mode.Trace = false
- Mode.Debug = false
- case Mode.Trace:
- Mode.Debug = true
- default:
- }
-
- strategy := logStrategy()
- stdoutWriter = bufio.NewWriter(os.Stdout)
-
- switch strategy {
- case DailyStrategy:
- _, err := os.Stat(config.Common.LogDir)
- Mode.logToFile = !os.IsNotExist(err)
- Mode.logToStdout = !Mode.Server || Mode.Debug || Mode.Trace || Mode.Quiet
- case StdoutStrategy:
- fallthrough
- default:
- Mode.logToFile = !Mode.Server
- Mode.logToStdout = true
- }
-
- fqdn, err := os.Hostname()
- if err != nil {
- panic(err)
- }
- s := strings.Split(fqdn, ".")
- hostname = s[0]
-
- pauseCh = make(chan struct{})
- resumeCh = make(chan struct{})
-
- // Setup logrotation
- rotateCh = make(chan os.Signal, 1)
- signal.Notify(rotateCh, syscall.SIGHUP)
-
- if Mode.logToStdout {
- stdoutBufCh = make(chan string, runtime.NumCPU()*100)
- go writeToStdout(ctx)
- }
-
- if Mode.logToFile {
- fileLogBufCh = make(chan buf, runtime.NumCPU()*100)
- go writeToFile(ctx)
- }
-}
-
-// Info message logging.
-func Info(args ...interface{}) string {
- if Mode.Server {
- return log(serverStr, infoStr, args)
- }
-
- return log(clientStr, infoStr, args)
-}
-
-// Warn message logging.
-func Warn(args ...interface{}) string {
- if !Mode.Quiet {
- if Mode.Server {
- return log(serverStr, warnStr, args)
- }
- return log(clientStr, warnStr, args)
- }
-
- return ""
-}
-
-// Error message logging.
-func Error(args ...interface{}) string {
- if Mode.Server {
- return log(serverStr, errorStr, args)
- }
-
- return log(clientStr, errorStr, args)
-}
-
-// Fatal message logging.
-func Fatal(args ...interface{}) string {
- if Mode.Server {
- return log(serverStr, fatalStr, args)
- }
-
- return log(clientStr, fatalStr, args)
-}
-
-// FatalExit logs an error and exists the process.
-func FatalExit(args ...interface{}) {
- what := clientStr
- if Mode.Server {
- what = serverStr
- }
- log(what, fatalStr, args)
-
- time.Sleep(time.Second)
- mutex.Lock()
- defer mutex.Unlock()
-
- closeWriter()
- os.Exit(3)
-}
-
-// Debug message logging.
-func Debug(args ...interface{}) string {
- if Mode.Debug {
- if Mode.Server {
- return log(serverStr, debugStr, args)
- }
- return log(clientStr, debugStr, args)
- }
-
- return ""
-}
-
-// Trace message logging.
-func Trace(args ...interface{}) string {
- if Mode.Trace {
- if Mode.Server {
- return log(serverStr, traceStr, args)
- }
- return log(clientStr, traceStr, args)
- }
-
- return ""
-}
-
-// Write log line to buffer and/or log file.
-func write(what, severity, message string) {
- if Mode.logToStdout {
- line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message)
-
- if color.Colored {
- line = color.Colorfy(line)
- }
-
- stdoutBufCh <- line
- }
-
- if Mode.logToFile {
- t := time.Now()
- timeStr := t.Format("20060102-150405")
- fileLogBufCh <- buf{
- time: t,
- message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message),
- }
- }
-}
-
-// Generig log message.
-func log(what string, severity string, args []interface{}) string {
- if Mode.Nothing {
- return ""
- }
-
- messages := []string{}
-
- for _, arg := range args {
- switch v := arg.(type) {
- case string:
- messages = append(messages, v)
- case int:
- messages = append(messages, fmt.Sprintf("%d", v))
- case error:
- messages = append(messages, v.Error())
- default:
- messages = append(messages, fmt.Sprintf("%v", v))
- }
- }
-
- message := strings.Join(messages, "|")
- write(what, severity, message)
-
- return fmt.Sprintf("%s|%s", severity, message)
-}
-
-// Raw message logging.
-func Raw(message string) {
- if Mode.Nothing {
- return
- }
-
- if Mode.logToFile {
- fileLogBufCh <- buf{time.Now(), message}
- }
-
- if Mode.logToStdout {
- if color.Colored {
- message = color.Colorfy(message)
- }
- stdoutBufCh <- message
- }
-}
-
-// Close log writer (e.g. on change of day).
-func closeWriter() {
- if writer != nil {
- writer.Flush()
- fd.Close()
- }
-}
-
-// Return the correct log file writer
-func fileWriter(dateStr string) *bufio.Writer {
- if dateStr != lastDateStr {
- return updateFileWriter(dateStr)
- }
-
- // Check for log rotation signal
- select {
- case <-rotateCh:
- stdoutWriter.WriteString("Received signal for logrotation\n")
- return updateFileWriter(dateStr)
- default:
- }
-
- return writer
-}
-
-// Update log file writer
-func updateFileWriter(dateStr string) *bufio.Writer {
- // Detected change of day. Close current writer and create a new one.
- mutex.Lock()
- defer mutex.Unlock()
- closeWriter()
-
- if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) {
- if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil {
- panic(err)
- }
- }
-
- logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr)
- newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644)
- if err != nil {
- panic(err)
- }
-
- fd = newFd
- writer = bufio.NewWriterSize(fd, 1)
- lastDateStr = dateStr
-
- return writer
-}
-
-// Flush all outstanding lines.
-func Flush() {
- for {
- select {
- case message := <-stdoutBufCh:
- stdoutWriter.WriteString(message)
- default:
- stdoutWriter.Flush()
- return
- }
- }
-}
-
-func writeToStdout(ctx context.Context) {
- for {
- select {
- case message := <-stdoutBufCh:
- stdoutWriter.WriteString(message)
- case <-time.After(time.Millisecond * 100):
- stdoutWriter.Flush()
- case <-pauseCh:
- PAUSE:
- for {
- select {
- case <-stdoutBufCh:
- case <-resumeCh:
- break PAUSE
- case <-ctx.Done():
- return
- }
- }
- case <-ctx.Done():
- Flush()
- return
- }
- }
-}
-
-func writeToFile(ctx context.Context) {
- for {
- select {
- case buf := <-fileLogBufCh:
- dateStr := buf.time.Format("20060102")
- w := fileWriter(dateStr)
- w.WriteString(buf.message)
- case <-pauseCh:
- PAUSE:
- for {
- select {
- case <-stdoutBufCh:
- case <-resumeCh:
- break PAUSE
- case <-ctx.Done():
- return
- }
- }
- case <-ctx.Done():
- return
- }
- }
-}
-
-// Pause logging.
-func Pause() {
- if Mode.logToStdout {
- pauseCh <- struct{}{}
- }
- if Mode.logToFile {
- pauseCh <- struct{}{}
- }
-}
-
-// Resume logging (after pausing).
-func Resume() {
- if Mode.logToStdout {
- resumeCh <- struct{}{}
- }
- if Mode.logToFile {
- resumeCh <- struct{}{}
- }
-}
diff --git a/internal/io/logger/modes.go b/internal/io/logger/modes.go
deleted file mode 100644
index 8864179..0000000
--- a/internal/io/logger/modes.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package logger
-
-// Modes specifies the logging mode.
-type Modes struct {
- Server bool
- Trace bool
- Debug bool
- Nothing bool
- Quiet bool
- logToStdout bool
- logToFile bool
-}
diff --git a/internal/io/logger/strategy.go b/internal/io/logger/strategy.go
deleted file mode 100644
index 44bf393..0000000
--- a/internal/io/logger/strategy.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package logger
-
-import "github.com/mimecast/dtail/internal/config"
-
-// Strategy allows to specify a log rotation strategy.
-type Strategy int
-
-// Possible log strategies.
-const (
- NormalStrategy Strategy = iota
- DailyStrategy Strategy = iota
- StdoutStrategy Strategy = iota
-)
-
-func logStrategy() Strategy {
- switch config.Common.LogStrategy {
- case "daily":
- return DailyStrategy
- default:
- }
- return StdoutStrategy
-}
diff --git a/internal/io/pool/builder.go b/internal/io/pool/builder.go
new file mode 100644
index 0000000..89fcf81
--- /dev/null
+++ b/internal/io/pool/builder.go
@@ -0,0 +1,21 @@
+package pool
+
+import (
+ "strings"
+ "sync"
+)
+
+// BuilderBuffer is there to optimize memory allocations (DTail allocates a lot
+// of memory while reading log data otherwise)
+var BuilderBuffer = sync.Pool{
+ New: func() interface{} {
+ sb := strings.Builder{}
+ return &sb
+ },
+}
+
+// RecycleBuilderBuffer recycles the buffer again.
+func RecycleBuilderBuffer(sb *strings.Builder) {
+ sb.Reset()
+ BuilderBuffer.Put(sb)
+}
diff --git a/internal/io/pool/bytesbuffer.go b/internal/io/pool/bytesbuffer.go
new file mode 100644
index 0000000..3d48f2c
--- /dev/null
+++ b/internal/io/pool/bytesbuffer.go
@@ -0,0 +1,22 @@
+package pool
+
+import (
+ "bytes"
+ "sync"
+)
+
+// BytesBuffer is there to optimize memory allocations. DTail otherwise allocates
+// a lot of memory while reading logs.
+var BytesBuffer = sync.Pool{
+ New: func() interface{} {
+ b := bytes.Buffer{}
+ b.Grow(128)
+ return &b
+ },
+}
+
+// RecycleBytesBuffer recycles the buffer again.
+func RecycleBytesBuffer(b *bytes.Buffer) {
+ b.Reset()
+ BytesBuffer.Put(b)
+}
diff --git a/internal/io/prompt/prompt.go b/internal/io/prompt/prompt.go
index 36ebdb5..e82132d 100644
--- a/internal/io/prompt/prompt.go
+++ b/internal/io/prompt/prompt.go
@@ -6,7 +6,7 @@ import (
"os"
"strings"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// Answer is a user input of a prompt question.
@@ -19,7 +19,8 @@ type Answer struct {
Callback func()
// Runs after Callback and after logging resumes
EndCallback func()
- AskAgain bool
+ // AskAgain can be used to not to ask again about the question.
+ AskAgain bool
}
// Prompt used for interactive user input.
@@ -30,7 +31,6 @@ type Prompt struct {
func (p *Prompt) askString() string {
var sb strings.Builder
-
sb.WriteString(p.question)
sb.WriteString("? (")
@@ -41,7 +41,6 @@ func (p *Prompt) askString() string {
sb.WriteString(strings.Join(ax, ","))
sb.WriteString("): ")
-
return sb.String()
}
@@ -58,7 +57,7 @@ func (p *Prompt) Add(answer Answer) {
// Ask a question.
func (p *Prompt) Ask() {
reader := bufio.NewReader(os.Stdin)
- logger.Pause()
+ dlog.Common.Pause()
for {
fmt.Print(p.askString())
@@ -68,9 +67,8 @@ func (p *Prompt) Ask() {
if a.Callback != nil {
a.Callback()
}
-
if !a.AskAgain {
- logger.Resume()
+ dlog.Common.Resume()
if a.EndCallback != nil {
a.EndCallback()
}
@@ -90,6 +88,5 @@ func (p *Prompt) answer(answerStr string) (*Answer, bool) {
default:
}
}
-
return nil, false
}
diff --git a/internal/io/signal/signal.go b/internal/io/signal/signal.go
index 500c530..584b59c 100644
--- a/internal/io/signal/signal.go
+++ b/internal/io/signal/signal.go
@@ -14,10 +14,8 @@ import (
func InterruptCh(ctx context.Context) <-chan string {
sigIntCh := make(chan os.Signal)
gosignal.Notify(sigIntCh, os.Interrupt)
-
sigOtherCh := make(chan os.Signal)
gosignal.Notify(sigOtherCh, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT)
-
statsCh := make(chan string)
go func() {
@@ -41,6 +39,10 @@ func InterruptCh(ctx context.Context) <-chan string {
}
}
}()
-
return statsCh
}
+
+// NoCh doesn't listen on a signal.
+func NoCh(ctx context.Context) <-chan string {
+ return make(chan string)
+}
diff --git a/internal/lcontext/lcontext.go b/internal/lcontext/lcontext.go
index 89cb7c3..183ceb5 100644
--- a/internal/lcontext/lcontext.go
+++ b/internal/lcontext/lcontext.go
@@ -1,6 +1,6 @@
package lcontext
-// LContext stands for line context and is here to help filtering out only specific lines.
+// LContext stands for line context (used by context aware grep queries e.g.)
type LContext struct {
AfterContext int
BeforeContext int
diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go
index a6cc6eb..c50c7a1 100644
--- a/internal/mapr/aggregateset.go
+++ b/internal/mapr/aggregateset.go
@@ -5,6 +5,10 @@ import (
"fmt"
"strconv"
"strings"
+
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/protocol"
)
// AggregateSet represents aggregated key/value pairs from the
@@ -33,8 +37,7 @@ func (s *AggregateSet) String() string {
// Merge one aggregate set into this one.
func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error {
s.Samples += set.Samples
- //logger.Trace("Merge", set)
-
+ //dlog.Common.Trace("Merge", set)
for _, sc := range query.Select {
storage := sc.FieldStorage
switch sc.Operation {
@@ -66,24 +69,27 @@ func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error {
// Serialize the aggregate set so it can be sent over the wire.
func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<- string) {
- //logger.Trace("Serialising mapr.AggregateSet", s)
- var sb strings.Builder
+ dlog.Common.Trace("Serialising mapr.AggregateSet", s)
+ sb := pool.BuilderBuffer.Get().(*strings.Builder)
+ defer pool.RecycleBuilderBuffer(sb)
sb.WriteString(groupKey)
- sb.WriteString("âž”")
- sb.WriteString(fmt.Sprintf("%dâž”", s.Samples))
+ sb.WriteString(protocol.AggregateDelimiter)
+ sb.WriteString(fmt.Sprintf("%d", s.Samples))
+ sb.WriteString(protocol.AggregateDelimiter)
for k, v := range s.FValues {
sb.WriteString(k)
- sb.WriteString("=")
- sb.WriteString(fmt.Sprintf("%vâž”", v))
+ sb.WriteString(protocol.AggregateKVDelimiter)
+ sb.WriteString(fmt.Sprintf("%v", v))
+ sb.WriteString(protocol.AggregateDelimiter)
}
for k, v := range s.SValues {
sb.WriteString(k)
- sb.WriteString("=")
+ sb.WriteString(protocol.AggregateKVDelimiter)
sb.WriteString(v)
- sb.WriteString("âž”")
+ sb.WriteString(protocol.AggregateDelimiter)
}
select {
@@ -108,7 +114,6 @@ func (s *AggregateSet) addFloatMin(key string, value float64) {
s.FValues[key] = value
return
}
-
if f > value {
s.FValues[key] = value
}
@@ -121,7 +126,6 @@ func (s *AggregateSet) addFloatMax(key string, value float64) {
s.FValues[key] = value
return
}
-
if f < value {
s.FValues[key] = value
}
@@ -140,7 +144,6 @@ func (s *AggregateSet) setFloat(key string, value float64) {
// Aggregate data to the aggregate set.
func (s *AggregateSet) Aggregate(key string, agg AggregateOperation, value string, clientAggregation bool) (err error) {
var f float64
-
// First check if we can aggregate anything without converting value to float.
switch agg {
case Count:
diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go
index 10b34d4..02a6a5a 100644
--- a/internal/mapr/client/aggregate.go
+++ b/internal/mapr/client/aggregate.go
@@ -1,11 +1,13 @@
package client
import (
+ "fmt"
"strconv"
"strings"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/mapr"
+ "github.com/mimecast/dtail/internal/protocol"
)
// Aggregate mapreduce data on the DTail client side.
@@ -21,7 +23,9 @@ type Aggregate struct {
}
// NewAggregate create new client aggregator.
-func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *Aggregate {
+func NewAggregate(server string, query *mapr.Query,
+ globalGroup *mapr.GlobalGroupSet) *Aggregate {
+
return &Aggregate{
query: query,
group: mapr.NewGroupSet(),
@@ -31,20 +35,26 @@ func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGrou
}
// Aggregate data from mapr log line into local (and global) group sets.
-func (a *Aggregate) Aggregate(parts []string) {
+func (a *Aggregate) Aggregate(message string) error {
+ parts := strings.Split(message, protocol.AggregateDelimiter)
+ if len(parts) < 4 {
+ return fmt.Errorf("aggregate message without any real data")
+ }
+
groupKey := parts[0]
samples, err := strconv.Atoi(parts[1])
if err != nil {
- logger.FatalExit("Unable to parse sample count", parts[1], err, parts)
+ return fmt.Errorf("unable to parse sample count '%s': %v", parts[1], err)
}
+
fields := a.makeFields(parts[2:])
set := a.group.GetSet(groupKey)
-
var addedSamples bool
+
for _, sc := range a.query.Select {
if val, ok := fields[sc.FieldStorage]; ok {
if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, true); err != nil {
- logger.Error(err)
+ dlog.Common.Error(err)
continue
}
addedSamples = true
@@ -63,19 +73,18 @@ func (a *Aggregate) Aggregate(parts []string) {
// Re-init local group (make it empty again).
a.group.InitSet()
}
+ return nil
}
// Create a map of key-value pairs from a part list such as ["foo=bar", "bar=baz"].
func (a *Aggregate) makeFields(parts []string) map[string]string {
fields := make(map[string]string, len(parts))
-
for _, part := range parts {
- kv := strings.SplitN(part, "=", 2)
- if len(kv) < 2 {
+ kv := strings.SplitN(part, protocol.AggregateKVDelimiter, 2)
+ if len(kv) != 2 {
continue
}
fields[kv[0]] = kv[1]
}
-
return fields
}
diff --git a/internal/mapr/funcs/function.go b/internal/mapr/funcs/function.go
index 1a89c3a..418d86f 100644
--- a/internal/mapr/funcs/function.go
+++ b/internal/mapr/funcs/function.go
@@ -19,13 +19,12 @@ type Function struct {
// FunctionStack is a list of functions stacked each other
type FunctionStack []Function
-// NewFunctionStack parses the input string, e.g. foo(bar("arg")) and returns a corresponding function stack.
+// NewFunctionStack parses the input string, e.g. foo(bar("arg")) and returns
+// a corresponding function stack.
func NewFunctionStack(in string) (FunctionStack, string, error) {
var fs FunctionStack
-
getCallback := func(name string) (CallbackFunc, error) {
var cb CallbackFunc
-
switch name {
case "md5sum":
return Md5Sum, nil
@@ -51,17 +50,15 @@ func NewFunctionStack(in string) (FunctionStack, string, error) {
fs = append(fs, Function{name, call})
aux = aux[index+1 : len(aux)-1]
}
-
return fs, aux, nil
}
// Call the function stack.
func (fs FunctionStack) Call(str string) string {
for i := len(fs) - 1; i >= 0; i-- {
- //logger.Debug("Call", fs[i].Name, str)
+ //dlog.Common.Debug("Call", fs[i].Name, str)
str = fs[i].call(str)
- //logger.Debug("Call.result", fs[i].Name, str)
+ //dlog.Common.Debug("Call.result", fs[i].Name, str)
}
-
return str
}
diff --git a/internal/mapr/funcs/function_test.go b/internal/mapr/funcs/function_test.go
index 415683c..8b5d8b7 100644
--- a/internal/mapr/funcs/function_test.go
+++ b/internal/mapr/funcs/function_test.go
@@ -6,16 +6,19 @@ func TestFunction(t *testing.T) {
input := "md5sum($line)"
fs, arg, err := NewFunctionStack(input)
if err != nil {
- t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs)
+ t.Errorf("error parsing function input '%s': %s (%v)\n",
+ input, err.Error(), fs)
}
if arg != "$line" {
- t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs)
+ t.Errorf("error parsing function input '%s': expected argument '$line' but "+
+ "got '%s' (%v)\n", input, arg, fs)
}
t.Log(input, fs, arg)
result := fs.Call(input)
if result != "b38699013d79e50d9d122433753959c1" {
- t.Errorf("error executing function stack '%s': expected result 'b38699013d79e50d9d122433753959c1' but got '%s' (%v)\n", input, result, fs)
+ t.Errorf("error executing function stack '%s': expected result "+
+ "'b38699013d79e50d9d122433753959c1' but got '%s' (%v)\n", input, result, fs)
}
input = "maskdigits(md5sum(maskdigits($line)))"
@@ -24,22 +27,26 @@ func TestFunction(t *testing.T) {
t.Errorf("error parsing function input '%s': %s (%v)\n", input, err.Error(), fs)
}
if arg != "$line" {
- t.Errorf("error parsing function input '%s': expected argument '$line' but got '%s' (%v)\n", input, arg, fs)
+ t.Errorf("error parsing function input '%s': expected argument '$line' but "+
+ "got '%s' (%v)\n", input, arg, fs)
}
t.Log(input, fs, arg)
result = fs.Call(input)
if result != ".fac.bbe..bb.........d...a.c..b." {
- t.Errorf("error executing function stack '%s': expected result '.fac.bbe..bb.........d...a.c..b.' but got '%s' (%v)\n", input, result, fs)
+ t.Errorf("error executing function stack '%s': expected result "+
+ "'.fac.bbe..bb.........d...a.c..b.' but got '%s' (%v)\n", input, result, fs)
}
input = "md5sum$line)"
if fs, _, err := NewFunctionStack(input); err == nil {
- t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs)
+ t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n",
+ input, fs)
}
input = "md5sum(makedigits$line))"
if fs, _, err := NewFunctionStack(input); err == nil {
- t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n", input, fs)
+ t.Errorf("Expected error parsing function input '%s' (%v) but got no error\n",
+ input, fs)
}
}
diff --git a/internal/mapr/funcs/maskdigits.go b/internal/mapr/funcs/maskdigits.go
index d51f3d8..925ec4d 100644
--- a/internal/mapr/funcs/maskdigits.go
+++ b/internal/mapr/funcs/maskdigits.go
@@ -3,12 +3,10 @@ package funcs
// MaskDigits masks all digits (replaces them with .)
func MaskDigits(input string) string {
s := []byte(input)
-
for i, b := range s {
if '0' <= b && b <= '9' {
s[i] = '.'
}
}
-
return string(s)
}
diff --git a/internal/mapr/globalgroupset.go b/internal/mapr/globalgroupset.go
index cfab506..2d7f10b 100644
--- a/internal/mapr/globalgroupset.go
+++ b/internal/mapr/globalgroupset.go
@@ -17,7 +17,6 @@ func NewGlobalGroupSet() *GlobalGroupSet {
semaphore: make(chan struct{}, 1),
}
g.InitSet()
-
return &g
}
@@ -30,7 +29,6 @@ func (g *GlobalGroupSet) String() string {
func (g *GlobalGroupSet) Merge(query *Query, group *GroupSet) error {
g.semaphore <- struct{}{}
defer func() { <-g.semaphore }()
-
return g.merge(query, group)
}
@@ -54,7 +52,6 @@ func (g *GlobalGroupSet) merge(query *Query, group *GroupSet) error {
return err
}
}
-
return nil
}
@@ -67,7 +64,6 @@ func (g *GlobalGroupSet) IsEmpty() bool {
func (g *GlobalGroupSet) NumSets() int {
g.semaphore <- struct{}{}
defer func() { <-g.semaphore }()
-
return len(g.sets)
}
@@ -79,7 +75,6 @@ func (g *GlobalGroupSet) SwapOut() *GroupSet {
set := &GroupSet{sets: g.sets}
g.InitSet()
-
return set
}
@@ -87,14 +82,12 @@ func (g *GlobalGroupSet) SwapOut() *GroupSet {
func (g *GlobalGroupSet) WriteResult(query *Query) error {
g.semaphore <- struct{}{}
defer func() { <-g.semaphore }()
-
return g.GroupSet.WriteResult(query)
}
// Result returns the result of the mapreduce aggregation as a string.
-func (g *GlobalGroupSet) Result(query *Query) (string, int, error) {
+func (g *GlobalGroupSet) Result(query *Query, rowsLimit int) (string, int, error) {
g.semaphore <- struct{}{}
defer func() { <-g.semaphore }()
-
- return g.GroupSet.Result(query)
+ return g.GroupSet.Result(query, rowsLimit)
}
diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go
index b5c8a48..6ffc8b9 100644
--- a/internal/mapr/groupset.go
+++ b/internal/mapr/groupset.go
@@ -4,13 +4,16 @@ import (
"context"
"errors"
"fmt"
- "io/ioutil"
"os"
"sort"
"strconv"
"strings"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/protocol"
)
// GroupSet represents a map of aggregate sets. The group sets
@@ -22,6 +25,14 @@ type GroupSet struct {
sets map[string]*AggregateSet
}
+// Internal helper type
+type result struct {
+ groupKey string
+ values []string
+ widths []int
+ orderBy float64
+}
+
// NewGroupSet returns a new empty group set.
func NewGroupSet() *GroupSet {
g := GroupSet{}
@@ -57,28 +68,181 @@ func (g *GroupSet) Serialize(ctx context.Context, ch chan<- string) {
}
// Result returns a nicely formated result of the query from the group set.
-func (g *GroupSet) Result(query *Query) (string, int, error) {
- return g.limitedResult(query, query.Limit, "\t", " ", false)
+func (g *GroupSet) Result(query *Query, rowsLimit int) (string, int, error) {
+ rows, widths, err := g.result(query, true)
+ if err != nil {
+ return "", 0, err
+ }
+ if query.Limit != -1 {
+ rowsLimit = query.Limit
+ }
+
+ sb := pool.BuilderBuffer.Get().(*strings.Builder)
+ defer pool.RecycleBuilderBuffer(sb)
+
+ // Generate header now
+ lastIndex := len(query.Select) - 1
+ for i, sc := range query.Select {
+ format := fmt.Sprintf(" %%%ds ", widths[i])
+ str := fmt.Sprintf(format, sc.FieldStorage)
+ if config.Client.TermColorsEnable {
+ attrs := []color.Attribute{config.Client.TermColors.MaprTable.HeaderAttr}
+ if sc.FieldStorage == query.OrderBy {
+ attrs = append(attrs, config.Client.TermColors.MaprTable.HeaderSortKeyAttr)
+ }
+
+ for _, groupBy := range query.GroupBy {
+ if sc.FieldStorage == groupBy {
+ attrs = append(attrs, config.Client.TermColors.MaprTable.HeaderGroupKeyAttr)
+ break
+ }
+ }
+
+ color.PaintWithAttrs(sb, str,
+ config.Client.TermColors.MaprTable.HeaderFg,
+ config.Client.TermColors.MaprTable.HeaderBg,
+ attrs)
+ } else {
+ sb.WriteString(str)
+ }
+
+ if i == lastIndex {
+ continue
+ }
+ if config.Client.TermColorsEnable {
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.MaprTable.HeaderDelimiterFg,
+ config.Client.TermColors.MaprTable.HeaderDelimiterBg,
+ config.Client.TermColors.MaprTable.HeaderDelimiterAttr)
+ } else {
+ sb.WriteString(protocol.FieldDelimiter)
+ }
+ }
+ sb.WriteString("\n")
+
+ for i := 0; i < len(query.Select); i++ {
+ str := fmt.Sprintf("-%s-", strings.Repeat("-", widths[i]))
+ if config.Client.TermColorsEnable {
+ color.PaintWithAttr(sb, str,
+ config.Client.TermColors.MaprTable.HeaderDelimiterFg,
+ config.Client.TermColors.MaprTable.HeaderDelimiterBg,
+ config.Client.TermColors.MaprTable.HeaderDelimiterAttr)
+ } else {
+ sb.WriteString(str)
+ }
+ if i == lastIndex {
+ continue
+ }
+ if config.Client.TermColorsEnable {
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.MaprTable.HeaderDelimiterFg,
+ config.Client.TermColors.MaprTable.HeaderDelimiterBg,
+ config.Client.TermColors.MaprTable.HeaderDelimiterAttr)
+ } else {
+ sb.WriteString(protocol.FieldDelimiter)
+ }
+ }
+ sb.WriteString("\n")
+
+ // And now write the data
+ for i, r := range rows {
+ if i == rowsLimit {
+ break
+ }
+ for j, value := range r.values {
+ format := fmt.Sprintf(" %%%ds ", widths[j])
+ str := fmt.Sprintf(format, value)
+ if config.Client.TermColorsEnable {
+ color.PaintWithAttr(sb, str,
+ config.Client.TermColors.MaprTable.DataFg,
+ config.Client.TermColors.MaprTable.DataBg,
+ config.Client.TermColors.MaprTable.DataAttr)
+ } else {
+ sb.WriteString(str)
+ }
+
+ if j == lastIndex {
+ continue
+ }
+ if config.Client.TermColorsEnable {
+ color.PaintWithAttr(sb, protocol.FieldDelimiter,
+ config.Client.TermColors.MaprTable.DelimiterFg,
+ config.Client.TermColors.MaprTable.DelimiterBg,
+ config.Client.TermColors.MaprTable.DelimiterAttr)
+ } else {
+ sb.WriteString(protocol.FieldDelimiter)
+ }
+ }
+ sb.WriteString("\n")
+ }
+
+ return sb.String(), len(rows), nil
}
-// WriteResult writes the result to an outfile.
+func (*GroupSet) writeQueryFile(query *Query) error {
+ queryFile := fmt.Sprintf("%s.query", query.Outfile)
+ tmpQueryFile := fmt.Sprintf("%s.tmp", queryFile)
+ dlog.Common.Debug("Writing query file", queryFile)
+
+ fd, err := os.Create(tmpQueryFile)
+ if err != nil {
+ return err
+ }
+ defer fd.Close()
+
+ fd.WriteString(query.RawQuery)
+ os.Rename(tmpQueryFile, queryFile)
+ return nil
+}
+
+// WriteResult writes the result to an CSV outfile.
func (g *GroupSet) WriteResult(query *Query) error {
if !query.HasOutfile() {
return errors.New("No outfile specified")
}
+ if err := g.writeQueryFile(query); err != nil {
+ return err
+ }
- // -1: Don't limit the result, include all data sets
- result, _, err := g.limitedResult(query, query.Limit, "", ",", true)
+ rows, _, err := g.result(query, false)
if err != nil {
return err
}
- logger.Info("Writing outfile", query.Outfile)
+ dlog.Common.Info("Writing outfile", query.Outfile)
tmpOutfile := fmt.Sprintf("%s.tmp", query.Outfile)
- if err := ioutil.WriteFile(tmpOutfile, []byte(result), 0644); err != nil {
+ fd, err := os.Create(tmpOutfile)
+ if err != nil {
return err
}
+ defer fd.Close()
+
+ // Generate header now
+ lastIndex := len(query.Select) - 1
+ for i, sc := range query.Select {
+ fd.WriteString(sc.FieldStorage)
+ if i == lastIndex {
+ continue
+ }
+ fd.WriteString(protocol.CSVDelimiter)
+ }
+ fd.WriteString("\n")
+
+ // And now write the data
+ for i, r := range rows {
+ if i == query.Limit {
+ break
+ }
+ for j, value := range r.values {
+ fd.WriteString(value)
+ if j == lastIndex {
+ continue
+ }
+ fd.WriteString(protocol.CSVDelimiter)
+ }
+ fd.WriteString("\n")
+ }
if err := os.Rename(tmpOutfile, query.Outfile); err != nil {
os.Remove(tmpOutfile)
@@ -88,32 +252,21 @@ func (g *GroupSet) WriteResult(query *Query) error {
return nil
}
-// Return a nicely formated result of the query from the group set.
-func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSeparator string, addHeader bool) (string, int, error) {
- type result struct {
- groupKey string
- resultStr string
- orderBy float64
- }
-
- var resultSlice []result
+// Return a sorted result slice of the query from the group set.
+func (g *GroupSet) result(query *Query, gatherWidths bool) ([]result, []int, error) {
+ var rows []result
+ widths := make([]int, len(query.Select))
+ var valueStr string
+ var value float64
for groupKey, set := range g.sets {
- var sb strings.Builder
r := result{groupKey: groupKey}
- lastIndex := len(query.Select) - 1
for i, sc := range query.Select {
- storage := sc.FieldStorage
- orderByThis := storage == query.OrderBy
-
switch sc.Operation {
case Count:
- value := set.FValues[storage]
- sb.WriteString(fmt.Sprintf("%d", int(value)))
- if orderByThis {
- r.orderBy = value
- }
+ value = set.FValues[sc.FieldStorage]
+ valueStr = fmt.Sprintf("%d", int(value))
case Len:
fallthrough
case Sum:
@@ -121,74 +274,48 @@ func (g *GroupSet) limitedResult(query *Query, limit int, lineStarter, fieldSepa
case Min:
fallthrough
case Max:
- value := set.FValues[storage]
- sb.WriteString(fmt.Sprintf("%f", value))
- if orderByThis {
- r.orderBy = value
- }
+ value = set.FValues[sc.FieldStorage]
+ valueStr = fmt.Sprintf("%f", value)
case Last:
- value := set.SValues[storage]
- if orderByThis {
- f, err := strconv.ParseFloat(value, 64)
- if err == nil {
- r.orderBy = f
- }
- }
- sb.WriteString(value)
+ valueStr = set.SValues[sc.FieldStorage]
+ value, _ = strconv.ParseFloat(valueStr, 64)
case Avg:
- value := set.FValues[storage] / float64(set.Samples)
- sb.WriteString(fmt.Sprintf("%f", value))
- if orderByThis {
- r.orderBy = value
- }
+ value = set.FValues[sc.FieldStorage] / float64(set.Samples)
+ valueStr = fmt.Sprintf("%f", value)
default:
- return "", 0, fmt.Errorf("Unknown aggregation method '%v'", sc.Operation)
+ return rows, widths, fmt.Errorf("Unknown aggregation method '%v'",
+ sc.Operation)
}
- if i != lastIndex {
- sb.WriteString(fieldSeparator)
+
+ if sc.FieldStorage == query.OrderBy {
+ r.orderBy = value
}
- }
+ r.values = append(r.values, valueStr)
- r.resultStr = sb.String()
- resultSlice = append(resultSlice, r)
+ if !gatherWidths {
+ continue
+ }
+ if widths[i] < len(sc.FieldStorage) {
+ widths[i] = len(sc.FieldStorage)
+ }
+ if widths[i] < len(valueStr) {
+ widths[i] = len(valueStr)
+ }
+ }
+ rows = append(rows, r)
}
if query.OrderBy != "" {
if query.ReverseOrder {
- sort.SliceStable(resultSlice, func(i, j int) bool {
- return resultSlice[i].orderBy < resultSlice[j].orderBy
+ sort.SliceStable(rows, func(i, j int) bool {
+ return rows[i].orderBy < rows[j].orderBy
})
} else {
- sort.SliceStable(resultSlice, func(i, j int) bool {
- return resultSlice[i].orderBy > resultSlice[j].orderBy
+ sort.SliceStable(rows, func(i, j int) bool {
+ return rows[i].orderBy > rows[j].orderBy
})
}
}
- var sb strings.Builder
-
- // Write header first
- if addHeader {
- lastIndex := len(query.Select) - 1
- sb.WriteString(lineStarter)
- for i, sc := range query.Select {
- sb.WriteString(sc.FieldStorage)
- if i != lastIndex {
- sb.WriteString(fieldSeparator)
- }
- }
- sb.WriteString("\n")
- }
-
- // And now write the data
- for i, r := range resultSlice {
- if i == limit {
- break
- }
- sb.WriteString(lineStarter)
- sb.WriteString(r.resultStr)
- sb.WriteString("\n")
- }
-
- return sb.String(), len(resultSlice), nil
+ return rows, widths, nil
}
diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go
index 44bf558..9b6c855 100644
--- a/internal/mapr/logformat/default.go
+++ b/internal/mapr/logformat/default.go
@@ -1,14 +1,23 @@
package logformat
import (
- "errors"
+ "fmt"
"strings"
+
+ "github.com/mimecast/dtail/internal/protocol"
)
-// MakeFieldsDEFAULT is the default log file mapreduce parser.
+// MakeFieldsDEFAULT is the default DTail log file key-value parser.
func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) {
- fields := make(map[string]string, 20)
- splitted := strings.Split(maprLine, "|")
+ splitted := strings.Split(maprLine, protocol.FieldDelimiter)
+
+ if len(splitted) < 11 || !strings.HasPrefix(splitted[9], "MAPREDUCE:") ||
+ !strings.HasPrefix(splitted[0], "INFO") {
+ // Not a DTail mapreduce log line.
+ return nil, ErrIgnoreFields
+ }
+
+ fields := make(map[string]string, len(splitted)+8)
fields["*"] = "*"
fields["$line"] = maprLine
@@ -17,10 +26,30 @@ func (p *Parser) MakeFieldsDEFAULT(maprLine string) (map[string]string, error) {
fields["$timezone"] = p.timeZoneName
fields["$timeoffset"] = p.timeZoneOffset
- for _, kv := range splitted {
+ fields["$severity"] = splitted[0]
+ fields["$loglevel"] = splitted[0]
+
+ time := splitted[1]
+ fields["$time"] = time
+ if len(time) == 15 {
+ // Example: 20211002-071209
+ fields["$date"] = time[0:8]
+ fields["$hour"] = time[9:11]
+ fields["$minute"] = time[11:13]
+ fields["$second"] = time[13:]
+ }
+ fields["$pid"] = splitted[2]
+ fields["$caller"] = splitted[3]
+ fields["$cpus"] = splitted[4]
+ fields["$goroutines"] = splitted[5]
+ fields["$cgocalls"] = splitted[6]
+ fields["$loadavg"] = splitted[7]
+ fields["$uptime"] = splitted[8]
+
+ for _, kv := range splitted[10:] {
keyAndValue := strings.SplitN(kv, "=", 2)
if len(keyAndValue) != 2 {
- return fields, errors.New("Error parsing mapr token: " + kv)
+ return fields, fmt.Errorf("Unable to parse key-value token '%s'", kv)
}
fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1]
}
diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go
index 10ec8b7..28e1acc 100644
--- a/internal/mapr/logformat/default_test.go
+++ b/internal/mapr/logformat/default_test.go
@@ -1,6 +1,7 @@
package logformat
import (
+ "fmt"
"testing"
)
@@ -10,26 +11,83 @@ func TestDefaultLogFormat(t *testing.T) {
t.Errorf("Unable to create parser: %s", err.Error())
}
- fields, err := parser.MakeFields("foo=bar|baz=bay")
+ date := "20211002"
+ hour := "07"
+ minute := "23"
+ second := "42"
+ time := fmt.Sprintf("%s-%s%s%s", date, hour, minute, second)
- if err != nil {
- t.Errorf("Unable to parse: %s", err.Error())
+ inputs := []string{
+ fmt.Sprintf("INFO|%s|1|default_test.go:0|8|14|7|0.21|471h0m21s|MAPREDUCE:STATS|foo=bar|bar=foo", time),
+ fmt.Sprintf("INFO|%s|1|default_test.go:0|8|14|7|0.21|471h0m21s|MAPREDUCE:STATS|bar=foo|foo=bar", time),
}
- if bar, ok := fields["foo"]; !ok {
- t.Errorf("Expected field 'foo', but no such field there\n")
- } else if bar != "bar" {
- t.Errorf("Expected 'bar' stored in field 'foo', but got '%s'\n", bar)
- }
+ for _, input := range inputs {
+ fields, err := parser.MakeFields(input)
+
+ if err != nil {
+ t.Errorf("Parser unable to make fields: %s", err.Error())
+ }
+
+ if val, ok := fields["$severity"]; !ok {
+ t.Errorf("Expected field '$severity', but no such field there in '%s'\n", input)
+ } else if val != "INFO" {
+ t.Errorf("Expected 'Info' stored in field '$severity', but got '%s' in '%s'\n",
+ val, input)
+ }
+
+ if val, ok := fields["$time"]; !ok {
+ t.Errorf("Expected field '$time', but no such field there in '%s'\n", input)
+ } else if val != time {
+ t.Errorf("Expected '%s' stored in field '$time', but got '%s' in '%s'\n",
+ time, val, input)
+ }
+
+ if val, ok := fields["$date"]; !ok {
+ t.Errorf("Expected field '$date', but no such field there in '%s'\n", input)
+ } else if val != date {
+ t.Errorf("Expected '%s' stored in field '$date', but got '%s' in '%s'\n",
+ date, val, input)
+ }
+
+ if val, ok := fields["$hour"]; !ok {
+ t.Errorf("Expected field '$hour', but no such field there in '%s'\n", input)
+ } else if val != hour {
+ t.Errorf("Expected '%s' stored in field '$hour', but got '%s' in '%s'\n",
+ hour, val, input)
+ }
+
+ if val, ok := fields["$minute"]; !ok {
+ t.Errorf("Expected field '$minute', but no such field there in '%s'\n", input)
+ } else if val != minute {
+ t.Errorf("Expected '%s' stored in field '$minute', but got '%s' in '%s'\n",
+ minute, val, input)
+ }
+
+ if val, ok := fields["$second"]; !ok {
+ t.Errorf("Expected field '$second', but no such field there in '%s'\n", input)
+ } else if val != second {
+ t.Errorf("Expected '%s' stored in field '$second', but got '%s' in '%s'\n",
+ second, val, input)
+ }
+
+ if val, ok := fields["foo"]; !ok {
+ t.Errorf("Expected field 'foo', but no such field there in '%s'\n", input)
+ } else if val != "bar" {
+ t.Errorf("Expected 'bar' stored in field 'foo', but got '%s' in '%s'\n",
+ val, input)
+ }
- if bay, ok := fields["baz"]; !ok {
- t.Errorf("Expected field 'baz', but no such field there\n")
- } else if bay != "bay" {
- t.Errorf("Expected 'bay' stored in field 'baz', but got '%s'\n", bay)
+ if val, ok := fields["bar"]; !ok {
+ t.Errorf("Expected field 'bar', but no such field there in '%s'\n", input)
+ } else if val != "foo" {
+ t.Errorf("Expected 'foo' stored in field 'bar', but got '%s' in '%s'\n",
+ val, input)
+ }
}
- _, err = parser.MakeFields("foo=bar|bazbay")
- if err == nil {
- t.Errorf("Expected error but didn't: %s", err.Error())
+ fields, err := parser.MakeFields("foozoo=bar|bazbay")
+ if _, ok := fields["foo"]; ok {
+ t.Errorf("Expected fiending field 'foo', but found it\n")
}
}
diff --git a/internal/mapr/logformat/generickv.go b/internal/mapr/logformat/generickv.go
new file mode 100644
index 0000000..433eb5f
--- /dev/null
+++ b/internal/mapr/logformat/generickv.go
@@ -0,0 +1,31 @@
+package logformat
+
+import (
+ "strings"
+
+ "github.com/mimecast/dtail/internal/protocol"
+)
+
+// MakeFieldsGENERIGKV is the generic key-value logfile parser.
+func (p *Parser) MakeFieldsGENERIGKV(maprLine string) (map[string]string, error) {
+ splitted := strings.Split(maprLine, protocol.FieldDelimiter)
+ fields := make(map[string]string, len(splitted))
+
+ fields["*"] = "*"
+ fields["$line"] = maprLine
+ fields["$empty"] = ""
+ fields["$hostname"] = p.hostname
+ fields["$timezone"] = p.timeZoneName
+ fields["$timeoffset"] = p.timeZoneOffset
+
+ for _, kv := range splitted[0:] {
+ keyAndValue := strings.SplitN(kv, "=", 2)
+ if len(keyAndValue) != 2 {
+ //dlog.Common.Debug("Unable to parse key-value token, ignoring it", kv)
+ continue
+ }
+ fields[strings.ToLower(keyAndValue[0])] = keyAndValue[1]
+ }
+
+ return fields, nil
+}
diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go
index c53729a..129081d 100644
--- a/internal/mapr/logformat/parser.go
+++ b/internal/mapr/logformat/parser.go
@@ -8,10 +8,12 @@ import (
"strings"
"time"
- "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
)
+// ErrIgnoreFields indicates that the fields should be ignored.
+var ErrIgnoreFields error = errors.New("Ignore this field set")
+
// Parser is used to parse the mapreduce information from the server log files.
type Parser struct {
hostname string
@@ -25,11 +27,9 @@ type Parser struct {
// NewParser returns a new log parser.
func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) {
hostname, err := os.Hostname()
-
if err != nil {
return nil, err
}
-
now := time.Now()
zone, offset := now.Zone()
@@ -43,7 +43,6 @@ func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) {
if err != nil {
return nil, err
}
-
return &p, nil
}
@@ -52,7 +51,6 @@ func NewParser(logFormatName string, query *mapr.Query) (*Parser, error) {
// Parser. Whereas MODULENAME must be a upeprcase string.
func (p *Parser) reflectLogFormat(logFormatName string) error {
methodName := fmt.Sprintf("MakeFields%s", strings.ToUpper(logFormatName))
-
rt := reflect.TypeOf(p)
method, ok := rt.MethodByName(methodName)
if !ok {
@@ -61,7 +59,6 @@ func (p *Parser) reflectLogFormat(logFormatName string) error {
p.makeFieldsFunc = method.Func
p.makeFieldsReceiver = reflect.ValueOf(p)
-
return nil
}
@@ -69,17 +66,11 @@ func (p *Parser) reflectLogFormat(logFormatName string) error {
func (p *Parser) MakeFields(maprLine string) (fields map[string]string, err error) {
inputValues := []reflect.Value{p.makeFieldsReceiver, reflect.ValueOf(maprLine)}
returnValues := p.makeFieldsFunc.Call(inputValues)
-
errInterface := returnValues[1].Interface()
-
if errInterface == nil {
fields, err = returnValues[0].Interface().(map[string]string), nil
- logger.Trace("parser.MakeFields", fields, err)
return
}
-
fields, err = returnValues[0].Interface().(map[string]string), errInterface.(error)
- logger.Trace("parser.MakeFields", fields, err)
-
return
}
diff --git a/internal/mapr/query.go b/internal/mapr/query.go
index 01852da..d7c32bd 100644
--- a/internal/mapr/query.go
+++ b/internal/mapr/query.go
@@ -6,8 +6,6 @@ import (
"strconv"
"strings"
"time"
-
- "github.com/mimecast/dtail/internal/io/logger"
)
const (
@@ -34,7 +32,9 @@ type Query struct {
}
func (q Query) String() string {
- return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,RawQuery:%s,tokens:%v,LogFormat:%s)",
+ return fmt.Sprintf("Query(Select:%v,Table:%s,Where:%v,Set:%vGroupBy:%v,"+
+ "GroupKey:%s,OrderBy:%v,ReverseOrder:%v,Interval:%v,Limit:%d,Outfile:%s,"+
+ "RawQuery:%s,tokens:%v,LogFormat:%s)",
q.Select,
q.Table,
q.Where,
@@ -56,20 +56,14 @@ func NewQuery(queryStr string) (*Query, error) {
if queryStr == "" {
return nil, nil
}
-
tokens := tokenize(queryStr)
-
q := Query{
RawQuery: queryStr,
tokens: tokens,
Interval: time.Second * 5,
Limit: -1,
}
-
- err := q.parse(tokens)
-
- logger.Debug(q)
- return &q, err
+ return &q, q.parse(tokens)
}
// HasOutfile returns true if query result will be written to a CVS output file.
@@ -178,13 +172,13 @@ func (q *Query) parse(tokens []token) error {
}
if len(q.Select) < 1 {
- return errors.New(invalidQuery + "Expected at least one field in 'select' clause but got none")
+ return errors.New(invalidQuery + "Expected at least one field in 'select' " +
+ "clause but got none")
}
if len(q.GroupBy) == 0 {
field := q.Select[0].Field
q.GroupBy = append(q.GroupBy, field)
}
-
if q.OrderBy != "" {
var orderFieldIsValid bool
for _, sc := range q.Select {
@@ -194,7 +188,8 @@ func (q *Query) parse(tokens []token) error {
}
}
if !orderFieldIsValid {
- return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s', must be present in 'select' clause", q.OrderBy))
+ return errors.New(invalidQuery + fmt.Sprintf("Can not '(r)order by' '%s',"+
+ "must be present in 'select' clause", q.OrderBy))
}
}
diff --git a/internal/mapr/query_test.go b/internal/mapr/query_test.go
index b0b6c3a..88f7387 100644
--- a/internal/mapr/query_test.go
+++ b/internal/mapr/query_test.go
@@ -13,18 +13,25 @@ func TestParseQuerySimple(t *testing.T) {
"select foo from bar where baz <",
"select foo from bar where baz < 100 bay eq 12 group",
"select foo from bar where baz < 100 bay eq 12 group by foo order by",
- "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit",
- "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit set foo = bar;",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " +
+ "order by foo limit",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " +
+ "order by foo limit set foo = bar;",
}
okQueries := []string{"select foo from bar",
"select foo from bar where",
"select foo from bar where baz < 100 bay eq 12",
"select foo from bar where baz < 100, bay eq 12",
"select foo from bar where baz < 100 and bay eq 12",
- "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo",
- "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23",
- "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\"",
- "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz order by foo limit 23 outfile \"result.csv\" set $foo = maskdigits(bar), $baz = 12, $bay = $foo;",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " +
+ "order by foo",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " +
+ "order by foo limit 23",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " +
+ "order by foo limit 23 outfile \"result.csv\"",
+ "select foo from bar where baz < 100 bay eq 12 group by foo, bar, baz " +
+ "order by foo limit 23 outfile \"result.csv\" " +
+ "set $foo = maskdigits(bar), $baz = 12, $bay = $foo;",
}
for _, queryStr := range errorQueries {
@@ -46,8 +53,13 @@ func TestParseQuerySimple(t *testing.T) {
func TestParseQueryDeep(t *testing.T) {
dialects := []string{
- "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq \"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23 set $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
- "SELECT s1, `from`, COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq \"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23 SET $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
+ "select s1, `from`, count(s3) from table where w1 == 2 and w2 eq " +
+ "\"free beer\" group by g1, g2 order by count(s3) interval 10 limit 23 " +
+ "set $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
+
+ "SELECT s1, `from`, COUNT(s3) FROM table WHERE w1 == 2 AND w2 eq " +
+ "\"free beer\" GROUP g1, g2 ORDER count(s3) INTERVAL 10 LIMIT 23 " +
+ "SET $foo = maskdigits(bar), $baz = 12, $bay = $foo logformat generic",
}
for _, queryStr := range dialects {
@@ -55,119 +67,144 @@ func TestParseQueryDeep(t *testing.T) {
if err != nil {
t.Errorf("%s: %s", err.Error(), queryStr)
}
-
t.Log(q)
// 'select' clause
if len(q.Select) != 3 {
- t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v", q.Select, queryStr, q)
+ t.Errorf("Expected three elements in 'select' clause but got '%v': %s\n%v",
+ q.Select, queryStr, q)
}
-
if q.Select[0].Field != "s1" {
- t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Field, queryStr, q)
+ t.Errorf("Expected 's1' as first element in 'select' clause but got '%v': %s\n%v",
+ q.Select[0].Field, queryStr, q)
}
if q.Select[0].Operation != Last {
- t.Errorf("Expected 'last' as aggregation function of first element in 'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q)
+ t.Errorf("Expected 'last' as aggregation function of first element in "+
+ "'select' clause but got '%v': %s\n%v", q.Select[0].Operation, queryStr, q)
}
-
if q.Select[1].Field != "from" {
- t.Errorf("Expected 'from' as second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Field, queryStr, q)
+ t.Errorf("Expected 'from' as second element in 'select' clause but got "+
+ "'%v': %s\n%v", q.Select[1].Field, queryStr, q)
}
if q.Select[1].Operation != Last {
- t.Errorf("Expected 'last' as aggregation function of second element in 'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q)
+ t.Errorf("Expected 'last' as aggregation function of second element in "+
+ "'select' clause but got '%v': %s\n%v", q.Select[1].Operation, queryStr, q)
}
-
if q.Select[2].Field != "s3" {
- t.Errorf("Expected 's3' as third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Field, queryStr, q)
+ t.Errorf("Expected 's3' as third element in 'select' clause but got "+
+ "'%v': %s\n%v", q.Select[2].Field, queryStr, q)
}
if q.Select[2].Operation != Count {
- t.Errorf("Expected 'count' as aggregation function of third element in 'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q)
+ t.Errorf("Expected 'count' as aggregation function of third element in "+
+ "'select' clause but got '%v': %s\n%v", q.Select[2].Operation, queryStr, q)
}
if q.Select[2].FieldStorage != "count(s3)" {
- t.Errorf("Expected 'count(s3)' as third element's storage in 'select' clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q)
+ t.Errorf("Expected 'count(s3)' as third element's storage in 'select' "+
+ "clause but got '%v': %s\n%v", q.Select[2].FieldStorage, queryStr, q)
}
// 'from' clause
if q.Table != "TABLE" {
- t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v", q.Table, queryStr, q)
+ t.Errorf("Expected 'TABLE' in 'from' clause but got '%v': %s\n%v",
+ q.Table, queryStr, q)
}
// 'where' clause
if len(q.Where) != 2 {
- t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v", q.Where, queryStr, q)
+ t.Errorf("Expected two elements in 'where' clause but got '%v': %s\n%v",
+ q.Where, queryStr, q)
}
if q.Where[0].lString != "w1" {
- t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v", q.Where[0].lString, queryStr, q)
+ t.Errorf("Expected w1 as first element in 'where' clause but got '%v': %s\n%v",
+ q.Where[0].lString, queryStr, q)
}
if q.Where[0].Operation != FloatEq {
- t.Errorf("Expected FloatEq operation in first 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q)
+ t.Errorf("Expected FloatEq operation in first 'where' condition but got "+
+ "'%v': %s\n%v", q.Where[0].Operation, queryStr, q)
}
if q.Where[0].rFloat != 2 {
- t.Errorf("Expected '2' as float argument in first 'where' condition but got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q)
+ t.Errorf("Expected '2' as float argument in first 'where' condition but "+
+ "got '%v': %s\n%v", q.Where[0].rFloat, queryStr, q)
}
if q.Where[1].lString != "w2" {
- t.Errorf("Expected w2 as second element in 'where' clause but got '%v': %s\n%v", q.Where[1].lString, queryStr, q)
+ t.Errorf("Expected w2 as second element in 'where' clause but got '%v': "+
+ "%s\n%v", q.Where[1].lString, queryStr, q)
}
if q.Where[1].Operation != StringEq {
- t.Errorf("Expected StringEq operation in second 'where' condition but got '%v': %s\n%v", q.Where[0].Operation, queryStr, q)
+ t.Errorf("Expected StringEq operation in second 'where' condition but got "+
+ "'%v': %s\n%v", q.Where[0].Operation, queryStr, q)
}
if q.Where[1].rString != "free beer" {
- t.Errorf("Expected 'free beer' as string argument in second 'where' condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q)
+ t.Errorf("Expected 'free beer' as string argument in second 'where' "+
+ "condition but got '%v': %s\n%v", q.Where[0].rString, queryStr, q)
}
// 'group by' clause
if len(q.GroupBy) != 2 {
- t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v", q.GroupBy, queryStr, q)
+ t.Errorf("Expected two elements in 'group by' clause but got '%v': %s\n%v",
+ q.GroupBy, queryStr, q)
}
if q.GroupBy[0] != "g1" {
- t.Errorf("Expected 'g1' as first element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[0], queryStr, q)
+ t.Errorf("Expected 'g1' as first element in 'group by' clause but got "+
+ "'%v': %s\n%v", q.GroupBy[0], queryStr, q)
}
if q.GroupBy[1] != "g2" {
- t.Errorf("Expected 'g2' as second element in 'group by' clause but got '%v': %s\n%v", q.GroupBy[1], queryStr, q)
+ t.Errorf("Expected 'g2' as second element in 'group by' clause but got "+
+ "'%v': %s\n%v", q.GroupBy[1], queryStr, q)
}
if q.GroupKey != "g1,g2" {
- t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got '%v': %s\n%v", q.GroupKey, queryStr, q)
+ t.Errorf("Expected 'g1,g2' as group key in 'group by' clause but got "+
+ "'%v': %s\n%v", q.GroupKey, queryStr, q)
}
// 'order by' clause
if q.OrderBy != "count(s3)" {
- t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got '%v': %s\n%v", q.OrderBy, queryStr, q)
+ t.Errorf("Expected 'count(s3)' as element in 'order by' clause but got "+
+ "'%v': %s\n%v", q.OrderBy, queryStr, q)
}
// 'interval' clause
if q.Interval != time.Second*time.Duration(10) {
- t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v", q.Interval, queryStr, q)
+ t.Errorf("Expected '10s' as duration 'interval' clause but got '%v': %s\n%v",
+ q.Interval, queryStr, q)
}
// 'limit' clause
if q.Limit != 23 {
- t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v", q.Limit, queryStr, q)
+ t.Errorf("Expected '23' as limit in 'limit' clause but got '%v': %s\n%v",
+ q.Limit, queryStr, q)
}
// 'set' clause
if q.Set[0].lString != "$foo" {
- t.Errorf("Expected '$foo' lvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].lString, queryStr, q)
+ t.Errorf("Expected '$foo' lvalue in first 'set' condition clause but got "+
+ "'%v': %s\n%v", q.Set[0].lString, queryStr, q)
}
if q.Set[0].rString != "bar" {
- t.Errorf("Expected 'bar' rvalue in first 'set' condition clause but got '%v': %s\n%v", q.Set[0].rString, queryStr, q)
+ t.Errorf("Expected 'bar' rvalue in first 'set' condition clause but got "+
+ "'%v': %s\n%v", q.Set[0].rString, queryStr, q)
}
-
if q.Set[1].lString != "$baz" {
- t.Errorf("Expected '$baz' lvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].lString, queryStr, q)
+ t.Errorf("Expected '$baz' lvalue in second 'set' condition clause but got "+
+ "'%v': %s\n%v", q.Set[1].lString, queryStr, q)
}
if q.Set[1].rString != "12" {
- t.Errorf("Expected '12' rvalue in second 'set' condition clause but got '%v': %s\n%v", q.Set[1].rString, queryStr, q)
+ t.Errorf("Expected '12' rvalue in second 'set' condition clause but got "+
+ "'%v': %s\n%v", q.Set[1].rString, queryStr, q)
}
-
if q.Set[2].lString != "$bay" {
- t.Errorf("Expected '$bay' lvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].lString, queryStr, q)
+ t.Errorf("Expected '$bay' lvalue in third 'set' condition clause but got "+
+ "'%v': %s\n%v", q.Set[2].lString, queryStr, q)
}
if q.Set[2].rString != "$foo" {
- t.Errorf("Expected '$foo' rvalue in third 'set' condition clause but got '%v': %s\n%v", q.Set[2].rString, queryStr, q)
+ t.Errorf("Expected '$foo' rvalue in third 'set' condition clause but got "+
+ "'%v': %s\n%v", q.Set[2].rString, queryStr, q)
}
+ // 'logformat' clause
if q.LogFormat != "generic" {
- t.Errorf("Expected 'generic' logformat got '%v': %s\n%v", q.LogFormat, queryStr, q)
+ t.Errorf("Expected 'generic' logformat got '%v': %s\n%v",
+ q.LogFormat, queryStr, q)
}
}
}
diff --git a/internal/mapr/selectcondition.go b/internal/mapr/selectcondition.go
index d6aa0d4..5cfb8c7 100644
--- a/internal/mapr/selectcondition.go
+++ b/internal/mapr/selectcondition.go
@@ -37,7 +37,6 @@ func (sc selectCondition) String() string {
func makeSelectConditions(tokens []token) ([]selectCondition, error) {
var sel []selectCondition
-
// Parse select aggregation, e.g. sum(foo)
parse := func(token token) (selectCondition, error) {
var sc selectCondition
@@ -52,13 +51,15 @@ func makeSelectConditions(tokens []token) ([]selectCondition, error) {
a := strings.Split(tokenStr, "(")
if len(a) != 2 {
- return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " + token.str)
+ return sc, errors.New(invalidQuery + "Can't parse 'select' aggregation: " +
+ token.str)
}
agg := a[0] // Aggregation, e.g. 'sum'
b := strings.Split(a[1], ")")
if len(b) != 2 {
- return sc, errors.New(invalidQuery + "Can't parse 'select' field name from aggregation: " + token.str)
+ return sc, errors.New(invalidQuery + "Can't parse 'select' field name " +
+ "from aggregation: " + token.str)
}
sc.Field = b[0] // Field name, e.g. 'foo'
sc.FieldStorage = tokenStr // e.g. 'sum(foo)'
@@ -79,9 +80,9 @@ func makeSelectConditions(tokens []token) ([]selectCondition, error) {
case "len":
sc.Operation = Len
default:
- return sc, errors.New(invalidQuery + "Unknown aggregation in 'select' clause: " + agg)
+ return sc, errors.New(invalidQuery +
+ "Unknown aggregation in 'select' clause: " + agg)
}
-
return sc, nil
}
@@ -92,6 +93,5 @@ func makeSelectConditions(tokens []token) ([]selectCondition, error) {
}
sel = append(sel, sc)
}
-
return sel, nil
}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 28bb074..97fee11 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -8,25 +8,22 @@ import (
"github.com/mimecast/dtail/internal"
"github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/logformat"
+ "github.com/mimecast/dtail/internal/protocol"
)
// Aggregate is for aggregating mapreduce data on the DTail server side.
type Aggregate struct {
done *internal.Done
- // Log lines to process (parsing MAPREDUCE lines).
- Lines chan line.Line
+ // NextLinesCh can be used to use a new line ch.
+ NextLinesCh chan chan line.Line
// Hostname of the current server (used to populate $hostname field).
hostname string
// Signals to serialize data.
serialize chan struct{}
- // Signals to flush data.
- flush chan struct{}
- // Signals that data has been flushed
- flushed chan struct{}
// The mapr query
query *mapr.Query
// The mapr log format parser
@@ -42,7 +39,7 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
fqdn, err := os.Hostname()
if err != nil {
- logger.Error(err)
+ dlog.Common.Error(err)
}
s := strings.Split(fqdn, ".")
@@ -57,38 +54,32 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
parserName = query.LogFormat
}
- logger.Info("Creating log format parser", parserName)
+ dlog.Common.Info("Creating log format parser", parserName)
logParser, err := logformat.NewParser(parserName, query)
if err != nil {
- logger.Error("Could not create log format parser. Falling back to 'generic'", err)
+ dlog.Common.Error("Could not create log format parser. Falling back to 'generic'", err)
if logParser, err = logformat.NewParser("generic", query); err != nil {
- logger.FatalExit("Could not create log format parser", err)
+ dlog.Common.FatalPanic("Could not create log format parser", err)
}
}
- a := Aggregate{
- done: internal.NewDone(),
- Lines: make(chan line.Line, 100),
- serialize: make(chan struct{}),
- flush: make(chan struct{}),
- flushed: make(chan struct{}),
- hostname: s[0],
- query: query,
- parser: logParser,
- }
-
- return &a, nil
+ return &Aggregate{
+ done: internal.NewDone(),
+ NextLinesCh: make(chan chan line.Line, 10),
+ serialize: make(chan struct{}),
+ hostname: s[0],
+ query: query,
+ parser: logParser,
+ }, nil
}
// Shutdown the aggregation engine.
func (a *Aggregate) Shutdown() {
- a.Flush()
a.done.Shutdown()
}
// Start an aggregation.
-func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
-
+func (a *Aggregate) Start(ctx context.Context, maprMessages chan<- string) {
myCtx, cancel := context.WithCancel(ctx)
defer cancel()
@@ -101,15 +92,14 @@ func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
}
}()
- fieldsCh := a.makeFields(myCtx)
-
+ fieldsCh := a.fieldsFromLines(myCtx)
// Add fields (e.g. via 'set' clause)
if len(a.query.Set) > 0 {
- fieldsCh = a.addFields(myCtx, fieldsCh)
+ fieldsCh = a.setAdditionalFields(myCtx, fieldsCh)
}
-
+ // Periodically pre-aggregate data every a.query.Interval seconds.
go a.aggregateTimer(myCtx)
- a.makeMaprLines(myCtx, fieldsCh, maprLines)
+ a.aggregateAndSerialize(myCtx, fieldsCh, maprMessages)
}
func (a *Aggregate) aggregateTimer(ctx context.Context) {
@@ -123,25 +113,46 @@ func (a *Aggregate) aggregateTimer(ctx context.Context) {
}
}
-func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string {
- ch := make(chan map[string]string)
+func (a *Aggregate) fieldsFromLines(ctx context.Context) <-chan map[string]string {
+ fieldsCh := make(chan map[string]string)
go func() {
- defer close(ch)
+ defer close(fieldsCh)
+ var lines chan line.Line
+
+ // Gather first lines channel (first input file)
+ select {
+ case lines = <-a.NextLinesCh:
+ case <-ctx.Done():
+ return
+ }
for {
select {
- case line, ok := <-a.Lines:
+ case line, ok := <-lines:
if !ok {
- return
+ select {
+ case lines = <-a.NextLinesCh:
+ // Have a new lines channel (e.g. new input file)
+ case <-ctx.Done():
+ default:
+ // No new lines channel found.
+ return
+ }
}
- maprLine := strings.TrimSpace(string(line.Content))
+ maprLine := strings.TrimSpace(line.Content.String())
fields, err := a.parser.MakeFields(maprLine)
- logger.Debug(fields, err)
+ // Can't recycle it here yet, as field slices are still
+ // TODO: Add unit test reading from multiple mapreduce files lines.
+ // TODO: Add capability to recycle this bytes buffer.
+ //pool.RecycleBytesBuffer(line.Content)
if err != nil {
- logger.Error(err)
+ // Should fields be ignored anyway?
+ if err != logformat.ErrIgnoreFields {
+ dlog.Common.Error(fields, err)
+ }
continue
}
if !a.query.WhereClause(fields) {
@@ -149,7 +160,7 @@ func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string {
}
select {
- case ch <- fields:
+ case fieldsCh <- fields:
case <-ctx.Done():
}
case <-ctx.Done():
@@ -158,45 +169,42 @@ func (a *Aggregate) makeFields(ctx context.Context) <-chan map[string]string {
}
}()
- return ch
+ return fieldsCh
}
-func (a *Aggregate) addFields(ctx context.Context, fieldsCh <-chan map[string]string) <-chan map[string]string {
- ch := make(chan map[string]string)
+func (a *Aggregate) setAdditionalFields(ctx context.Context,
+ fieldsCh <-chan map[string]string) <-chan map[string]string {
+ newFieldsCh := make(chan map[string]string)
go func() {
- defer close(ch)
-
+ defer close(newFieldsCh)
for {
- // fieldsCh will be closed via 'makeFields' if ctx is done
fields, ok := <-fieldsCh
if !ok {
return
}
if err := a.query.SetClause(fields); err != nil {
- logger.Error(err)
+ dlog.Common.Error(err)
}
select {
- case ch <- fields:
+ case newFieldsCh <- fields:
case <-ctx.Done():
}
}
}()
-
- return ch
+ return newFieldsCh
}
-func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
- group := mapr.NewGroupSet()
+func (a *Aggregate) aggregateAndSerialize(ctx context.Context,
+ fieldsCh <-chan map[string]string, maprMessages chan<- string) {
+ group := mapr.NewGroupSet()
serialize := func() {
- logger.Info("Serializing mapreduce result")
- group.Serialize(ctx, maprLines)
+ dlog.Common.Info("Serializing mapreduce result")
+ group.Serialize(ctx, maprMessages)
group = mapr.NewGroupSet()
- logger.Info("Done serializing mapreduce result")
}
-
for {
select {
case fields, ok := <-fieldsCh:
@@ -207,9 +215,6 @@ func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[strin
a.aggregate(group, fields)
case <-a.serialize:
serialize()
- case <-a.flush:
- serialize()
- a.flushed <- struct{}{}
case <-ctx.Done():
return
}
@@ -217,12 +222,10 @@ func (a *Aggregate) makeMaprLines(ctx context.Context, fieldsCh <-chan map[strin
}
func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
- //logger.Trace("Aggregating", group, fields)
var sb strings.Builder
-
for i, field := range a.query.GroupBy {
if i > 0 {
- sb.WriteString(" ")
+ sb.WriteString(protocol.AggregateGroupKeyCombinator)
}
if val, ok := fields[field]; ok {
sb.WriteString(val)
@@ -235,7 +238,7 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
for _, sc := range a.query.Select {
if val, ok := fields[sc.Field]; ok {
if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil {
- logger.Error(err)
+ dlog.Common.Error(err)
continue
}
addedSample = true
@@ -246,8 +249,7 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
set.Samples++
return
}
-
- logger.Trace("Aggregated data locally without adding new samples")
+ dlog.Common.Trace("Aggregated data locally without adding new samples")
}
// Serialize all the aggregated data.
@@ -255,28 +257,7 @@ func (a *Aggregate) Serialize(ctx context.Context) {
select {
case a.serialize <- struct{}{}:
case <-time.After(time.Minute):
- logger.Warn("Starting to serialize mapredice data takes over a minute")
+ dlog.Common.Warn("Starting to serialize mapredice data takes over a minute")
case <-ctx.Done():
}
}
-
-// Flush all data.
-func (a *Aggregate) Flush() {
- select {
- case a.flush <- struct{}{}:
- logger.Info("Flushing mapreduce data")
- case <-time.After(time.Minute):
- logger.Warn("Starting to flush mapreduce data takes over a minute")
- return
- case <-a.done.Done():
- return
- }
-
- select {
- case <-a.flushed:
- logger.Info("Done flushing")
- case <-time.After(time.Minute):
- logger.Warn("Waiting for data to be flushed takes over a minute")
- case <-a.done.Done():
- }
-}
diff --git a/internal/mapr/setclause.go b/internal/mapr/setclause.go
index b4c2f73..1843d31 100644
--- a/internal/mapr/setclause.go
+++ b/internal/mapr/setclause.go
@@ -7,7 +7,6 @@ func (q *Query) SetClause(fields map[string]string) error {
if !ok {
continue
}
-
switch sc.rType {
case FunctionStack:
fields[sc.lString] = sc.functionStack.Call(value)
@@ -15,6 +14,5 @@ func (q *Query) SetClause(fields map[string]string) error {
fields[sc.lString] = value
}
}
-
return nil
}
diff --git a/internal/mapr/setcondition.go b/internal/mapr/setcondition.go
index 8c5cfc9..92b21f4 100644
--- a/internal/mapr/setcondition.go
+++ b/internal/mapr/setcondition.go
@@ -39,20 +39,22 @@ func makeSetConditions(tokens []token) (set []setCondition, err error) {
switch setOp {
case "=":
default:
- return sc, nil, errors.New(invalidQuery + "Unknown operation in 'set' clause: " + setOp)
+ return sc, nil, errors.New(invalidQuery + "Unknown operation in 'set' " +
+ "clause: " + setOp)
}
if !tokens[0].isBareword {
- return sc, nil, errors.New(invalidQuery + "Expected bareword at 'set' clause's lValue: " + tokens[0].str)
+ return sc, nil, errors.New(invalidQuery + "Expected bareword at 'set' " +
+ "clause's lValue: " + tokens[0].str)
}
-
sc.lString = tokens[0].str
if !strings.HasPrefix(sc.lString, "$") {
- return sc, nil, errors.New(invalidQuery + "Expected field variable name (starting with $) at 'set' clause's lValue: " + tokens[0].str)
+ return sc, nil, errors.New(invalidQuery + "Expected field variable name " +
+ "(starting with $) at 'set' clause's lValue: " + tokens[0].str)
}
sc.rType = Field
-
rString := tokens[2].str
+
// Seems like a function call?
if strings.HasSuffix(rString, ")") {
functionStack, functionArg, err := funcs.NewFunctionStack(tokens[2].str)
@@ -72,7 +74,6 @@ func makeSetConditions(tokens []token) (set []setCondition, err error) {
} else {
sc.rType = Field
}
-
return sc, tokens[3:], nil
}
@@ -84,10 +85,8 @@ func makeSetConditions(tokens []token) (set []setCondition, err error) {
if err != nil {
return nil, err
}
-
set = append(set, sc)
tokens = tokensConsumeOptional(tokens, ",")
}
-
return
}
diff --git a/internal/mapr/token.go b/internal/mapr/token.go
index 8972188..6ac7631 100644
--- a/internal/mapr/token.go
+++ b/internal/mapr/token.go
@@ -4,7 +4,8 @@ import (
"strings"
)
-var keywords = [...]string{"select", "from", "where", "set", "group", "rorder", "order", "interval", "limit", "outfile", "logformat"}
+var keywords = [...]string{"select", "from", "where", "set", "group", "rorder",
+ "order", "interval", "limit", "outfile", "logformat"}
// Represents a parsed token, used to parse the mapr query.
type token struct {
@@ -16,13 +17,11 @@ func (t token) isKeyword() bool {
if !t.isBareword {
return false
}
-
for _, keyword := range keywords {
if strings.ToLower(t.str) == keyword {
return true
}
}
-
return false
}
@@ -32,7 +31,6 @@ func (t token) String() string {
func tokenize(queryStr string) []token {
var tokens []token
-
for i, part := range strings.Split(queryStr, "\"") {
// Even i, means that it is not a quoted string
if i%2 == 0 {
@@ -53,17 +51,15 @@ func tokenize(queryStr string) []token {
}
tokens = append(tokens, token)
}
-
return tokens
}
func tokensConsume(tokens []token) ([]token, []token) {
- //logger.Trace("=====================")
+ //dlog.Common.Trace("=====================")
var consumed []token
-
for i, t := range tokens {
if t.isKeyword() {
- //logger.Trace("keyword", t)
+ //dlog.Common.Trace("keyword", t)
return tokens[i:], consumed
}
// strip escapes, such as ` from `foo`, this allows to use keywords as field names
@@ -73,7 +69,7 @@ func tokensConsume(tokens []token) ([]token, []token) {
}
if t.str[0] == '`' && t.str[length-1] == '`' {
stripped := t.str[1 : length-1]
- //logger.Trace("stripped", stripped)
+ //dlog.Common.Trace("stripped", stripped)
t := token{
str: stripped,
isBareword: t.isBareword,
@@ -81,11 +77,10 @@ func tokensConsume(tokens []token) ([]token, []token) {
consumed = append(consumed, t)
continue
}
- //logger.Trace("bare", token)
+ //dlog.Common.Trace("bare", token)
consumed = append(consumed, t)
}
-
- //logger.Trace("result", consumed)
+ //dlog.Common.Trace("result", consumed)
return nil, consumed
}
@@ -95,7 +90,6 @@ func tokensConsumeStr(tokens []token) ([]token, []string) {
for _, token := range found {
strings = append(strings, token.str)
}
-
return tokens, strings
}
@@ -106,6 +100,5 @@ func tokensConsumeOptional(tokens []token, optional string) []token {
if strings.ToLower(tokens[0].str) == strings.ToLower(optional) {
return tokens[1:]
}
-
return tokens
}
diff --git a/internal/mapr/whereclause.go b/internal/mapr/whereclause.go
index cc1c164..d9f32eb 100644
--- a/internal/mapr/whereclause.go
+++ b/internal/mapr/whereclause.go
@@ -3,14 +3,13 @@ package mapr
import (
"strconv"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// WhereClause interprets the where clause of the mapreduce query.
func (q *Query) WhereClause(fields map[string]string) bool {
for _, wc := range q.Where {
var ok bool
-
if wc.Operation > FloatOperation {
var lValue, rValue float64
if lValue, ok = whereClauseFloatValue(fields, wc.lString, wc.lFloat, wc.lType); !ok {
@@ -36,11 +35,12 @@ func (q *Query) WhereClause(fields map[string]string) bool {
return false
}
}
-
return true
}
-func whereClauseFloatValue(fields map[string]string, str string, float float64, t fieldType) (float64, bool) {
+func whereClauseFloatValue(fields map[string]string, str string, float float64,
+ t fieldType) (float64, bool) {
+
switch t {
case Float:
return float, true
@@ -55,12 +55,14 @@ func whereClauseFloatValue(fields map[string]string, str string, float float64,
}
return f, true
default:
- logger.Error("Unexpected argument in 'where' clause", str, float, t)
+ dlog.Common.Error("Unexpected argument in 'where' clause", str, float, t)
return 0, false
}
}
-func whereClauseStringValue(fields map[string]string, str string, t fieldType) (string, bool) {
+func whereClauseStringValue(fields map[string]string, str string,
+ t fieldType) (string, bool) {
+
switch t {
case Field:
value, ok := fields[str]
@@ -71,7 +73,7 @@ func whereClauseStringValue(fields map[string]string, str string, t fieldType) (
case String:
return str, true
default:
- logger.Error("Unexpected argument in 'where' clause", str, t)
+ dlog.Common.Error("Unexpected argument in 'where' clause", str, t)
return str, false
}
}
diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go
index 7a60dba..280dcfb 100644
--- a/internal/mapr/wherecondition.go
+++ b/internal/mapr/wherecondition.go
@@ -6,7 +6,7 @@ import (
"strconv"
"strings"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// QueryOperation determines the mapreduce operation.
@@ -46,15 +46,18 @@ type whereCondition struct {
}
func (wc *whereCondition) String() string {
- return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,lType:%s,rString:%s,rFloat:%v,rType:%s)",
- wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString, wc.rFloat, wc.rType.String())
+ return fmt.Sprintf("whereCondition(Operation:%v,lString:%s,lFloat:%v,"+
+ "lType:%s,rString:%s,rFloat:%v,rType:%s)",
+ wc.Operation, wc.lString, wc.lFloat, wc.lType.String(), wc.rString,
+ wc.rFloat, wc.rType.String())
}
func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
parse := func(tokens []token) (whereCondition, []token, error) {
var wc whereCondition
if len(tokens) < 3 {
- return wc, nil, errors.New(invalidQuery + "Not enough arguments in 'where' clause")
+ err := errors.New(invalidQuery + "Not enough arguments in 'where' clause")
+ return wc, nil, err
}
whereOp := strings.ToLower(tokens[1].str)
@@ -94,7 +97,8 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
case "nhassuffix":
wc.Operation = StringNotHasSuffix
default:
- return wc, nil, errors.New(invalidQuery + "Unknown operation in 'where' clause: " + whereOp)
+ return wc, nil, errors.New(invalidQuery +
+ "Unknown operation in 'where' clause: " + whereOp)
}
wc.lString = tokens[0].str
@@ -102,7 +106,8 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
if wc.Operation > FloatOperation {
if !tokens[0].isBareword {
- return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's lValue: " + tokens[0].str)
+ return wc, nil, errors.New(invalidQuery +
+ "Expected bareword at 'where' clause's lValue: " + tokens[0].str)
}
if f, err := strconv.ParseFloat(wc.lString, 64); err == nil {
wc.lFloat = f
@@ -112,7 +117,8 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
}
if !tokens[2].isBareword {
- return wc, nil, errors.New(invalidQuery + "Expected bareword at 'where' clause's rValue: " + tokens[2].str)
+ return wc, nil, errors.New(invalidQuery +
+ "Expected bareword at 'where' clause's rValue: " + tokens[2].str)
}
if f, err := strconv.ParseFloat(wc.rString, 64); err == nil {
wc.rFloat = f
@@ -133,23 +139,19 @@ func makeWhereConditions(tokens []token) (where []whereCondition, err error) {
} else {
wc.rType = String
}
-
return wc, tokens[3:], nil
}
for len(tokens) > 0 {
var wc whereCondition
var err error
-
wc, tokens, err = parse(tokens)
if err != nil {
return nil, err
}
-
where = append(where, wc)
tokens = tokensConsumeOptional(tokens, "and")
}
-
return
}
@@ -168,9 +170,8 @@ func (wc *whereCondition) floatClause(lValue float64, rValue float64) bool {
case FloatGe:
return lValue >= rValue
default:
- logger.Error("Unknown float operation", lValue, wc.Operation, rValue)
+ dlog.Common.Error("Unknown float operation", lValue, wc.Operation, rValue)
}
-
return false
}
@@ -193,8 +194,7 @@ func (wc *whereCondition) stringClause(lValue string, rValue string) bool {
case StringNotHasSuffix:
return !strings.HasSuffix(lValue, rValue)
default:
- logger.Error("Unknown string operation", lValue, wc.Operation, rValue)
+ dlog.Common.Error("Unknown string operation", lValue, wc.Operation, rValue)
}
-
return false
}
diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go
new file mode 100644
index 0000000..d29706c
--- /dev/null
+++ b/internal/protocol/protocol.go
@@ -0,0 +1,18 @@
+package protocol
+
+const (
+ // ProtocolCompat -ibility version
+ ProtocolCompat string = "4"
+ // MessageDelimiter delimits separate messages.
+ MessageDelimiter byte = '¬'
+ // FieldDelimiter delimits messagefields.
+ FieldDelimiter string = "|"
+ // CSVDelimiter delimits CSV file fields.kj:w
+ CSVDelimiter string = ","
+ // AggregateKVDelimiter delimits key-values of an aggregation message.
+ AggregateKVDelimiter string = "≔"
+ // AggregateDelimiter delimits parts of an aggregation message.
+ AggregateDelimiter string = "∥"
+ // AggregateGroupKeyCombinator combines the group set keys.
+ AggregateGroupKeyCombinator string = ","
+)
diff --git a/internal/regex/regex.go b/internal/regex/regex.go
index 2561659..eb6e1b3 100644
--- a/internal/regex/regex.go
+++ b/internal/regex/regex.go
@@ -4,8 +4,6 @@ import (
"fmt"
"regexp"
"strings"
-
- "github.com/mimecast/dtail/internal/io/logger"
)
// Regex for filtering lines.
@@ -50,9 +48,7 @@ func new(regexStr string, flags []Flag) (Regex, error) {
regexStr: regexStr,
flags: flags,
}
-
re, err := regexp.Compile(regexStr)
-
if err != nil {
return r, err
}
@@ -91,17 +87,15 @@ func (r Regex) MatchString(str string) bool {
}
// Serialize the regex.
-func (r Regex) Serialize() string {
+func (r Regex) Serialize() (string, error) {
var flags []string
for _, flag := range r.flags {
flags = append(flags, flag.String())
}
-
if !r.initialized {
- logger.FatalExit("Unable to serialize regex as not initialized properly", r)
+ return "", fmt.Errorf("Unable to serialize regex as not initialized properly: %v", r)
}
-
- return fmt.Sprintf("regex:%s %s", strings.Join(flags, ","), r.regexStr)
+ return fmt.Sprintf("regex:%s %s", strings.Join(flags, ","), r.regexStr), nil
}
// Deserialize the regex.
@@ -109,15 +103,14 @@ func Deserialize(str string) (Regex, error) {
// Get regex string
s := strings.SplitN(str, " ", 2)
if len(s) < 2 {
- logger.Debug("Using noop regex", str)
return NewNoop(), nil
}
-
flagsStr := s[0]
regexStr := s[1]
if !strings.HasPrefix(flagsStr, "regex") {
- return Regex{}, fmt.Errorf("unable to deserialize regex '%s': should start with string 'regex'", str)
+ return Regex{}, fmt.Errorf("unable to deserialize regex '%s': should start "+
+ "with string 'regex'", str)
}
// Parse regex flags, e.g. "regex:flag1,flag2,flag3..."
@@ -127,13 +120,10 @@ func Deserialize(str string) (Regex, error) {
for _, flagStr := range strings.Split(s[1], ",") {
flag, err := NewFlag(flagStr)
if err != nil {
- logger.Error("ignoring flag", err)
continue
}
- logger.Debug("Adding regex flag", flag)
flags = append(flags, flag)
}
}
-
return new(regexStr, flags)
}
diff --git a/internal/regex/regex_test.go b/internal/regex/regex_test.go
index a5e7faf..033a286 100644
--- a/internal/regex/regex_test.go
+++ b/internal/regex/regex_test.go
@@ -9,7 +9,8 @@ func TestRegex(t *testing.T) {
r := NewNoop()
if !r.MatchString(input) {
- t.Errorf("expected to match string '%s' with noop regex '%v' but didn't\n", input, r)
+ t.Errorf("expected to match string '%s' with noop regex '%v' but didn't\n",
+ input, r)
}
r, err := New(".hello", Default)
@@ -17,16 +18,21 @@ func TestRegex(t *testing.T) {
t.Errorf("unable to create regex: %v\n", err)
}
if r.MatchString(input) {
- t.Errorf("expected to match string '%s' with regex '%v' but didn't\n", input, r)
+ t.Errorf("expected to match string '%s' with regex '%v' but didn't\n",
+ input, r)
}
- r2, err := Deserialize(r.Serialize())
+ serialized, err := r.Serialize()
if err != nil {
- t.Errorf("unable to serialize deserialized regex: %v: %v\n", r.Serialize(), err)
+ t.Errorf("unable to serialize regex: %v: %v\n", serialized, err)
+ }
+ r2, err := Deserialize(serialized)
+ if err != nil {
+ t.Errorf("unable to serialize deserialized regex: %v: %v\n", serialized, err)
}
if r.String() != r2.String() {
- t.Errorf("regex should be the same after deserialize(serialize(..)), got '%s' but expected '%s'.\n",
- r2.String(), r.String())
+ t.Errorf("regex should be the same after deserialize(serialize(..)), got "+
+ "'%s' but expected '%s'.\n", r2.String(), r.String())
}
r, err = New(".hello", Invert)
@@ -34,15 +40,20 @@ func TestRegex(t *testing.T) {
t.Errorf("unable to create regex: %v\n", err)
}
if !r.MatchString(input) {
- t.Errorf("expected to not match string '%s' with regex '%v' but matched\n", input, r)
+ t.Errorf("expected to not match string '%s' with regex '%v' but matched\n",
+ input, r)
}
- r2, err = Deserialize(r.Serialize())
+ serialized, err = r.Serialize()
+ if err != nil {
+ t.Errorf("unable to serialize regex: %v: %v\n", serialized, err)
+ }
+ r2, err = Deserialize(serialized)
if err != nil {
- t.Errorf("unable to serialize deserialized regex: %v: %v\n", r.Serialize(), err)
+ t.Errorf("unable to serialize deserialized regex: %v: %v\n", serialized, err)
}
if r.String() != r2.String() {
- t.Errorf("regex should be the same after deserialize(serialize(..)), got '%s' but expected '%s'.\n",
- r2.String(), r.String())
+ t.Errorf("regex should be the same after deserialize(serialize(..)), got "+
+ "'%s' but expected '%s'.\n", r2.String(), r.String())
}
}
diff --git a/internal/server/continuous.go b/internal/server/continuous.go
index f75c732..93b3fcb 100644
--- a/internal/server/continuous.go
+++ b/internal/server/continuous.go
@@ -8,33 +8,29 @@ import (
"github.com/mimecast/dtail/internal/clients"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
-
gossh "golang.org/x/crypto/ssh"
)
-type continuous struct {
-}
+type continuous struct{}
func newContinuous() *continuous {
return &continuous{}
}
func (c *continuous) start(ctx context.Context) {
- logger.Info("Starting continuous job runner after 10s")
+ dlog.Server.Info("Starting continuous job runner after 10s")
time.Sleep(time.Second * 10)
-
c.runJobs(ctx)
}
func (c *continuous) runJobs(ctx context.Context) {
for _, job := range config.Server.Continuous {
if !job.Enable {
- logger.Debug(job.Name, "Not running job as not enabled")
+ dlog.Server.Debug(job.Name, "Not running job as not enabled")
continue
}
-
go func(job config.Continuous) {
c.runJob(ctx, job)
for {
@@ -51,18 +47,17 @@ func (c *continuous) runJobs(ctx context.Context) {
}
func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
- logger.Debug(job.Name, "Processing job")
+ dlog.Server.Debug(job.Name, "Processing job")
files := fillDates(job.Files)
outfile := fillDates(job.Outfile)
-
servers := strings.Join(job.Servers, ",")
if servers == "" {
servers = config.Server.SSHBindAddress
}
- args := clients.Args{
- ConnectionsPerCPU: 10,
+ args := config.Args{
+ ConnectionsPerCPU: config.DefaultConnectionsPerCPU,
Discovery: job.Discovery,
ServersStr: servers,
What: files,
@@ -71,35 +66,32 @@ func (c *continuous) runJob(ctx context.Context, job config.Continuous) {
}
args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
-
- query := fmt.Sprintf("%s outfile %s", job.Query, outfile)
- client, err := clients.NewMaprClient(args, query, clients.NonCumulativeMode)
+ args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile)
+ client, err := clients.NewMaprClient(args, clients.NonCumulativeMode)
if err != nil {
- logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
+ dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
return
}
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
-
if job.RestartOnDayChange {
go func() {
if c.waitForDayChange(ctx) {
- logger.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name))
+ dlog.Server.Info(fmt.Sprintf("Canceling job %s due to day change", job.Name))
cancel()
}
}()
}
- logger.Info(fmt.Sprintf("Starting job %s", job.Name))
+ dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name))
status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
-
if status != 0 {
- logger.Warn(logMessage)
+ dlog.Server.Warn(logMessage)
return
}
- logger.Info(logMessage)
+ dlog.Server.Info(logMessage)
}
func (c *continuous) waitForDayChange(ctx context.Context) bool {
diff --git a/internal/server/handlers/basehandler.go b/internal/server/handlers/basehandler.go
new file mode 100644
index 0000000..6d10d17
--- /dev/null
+++ b/internal/server/handlers/basehandler.go
@@ -0,0 +1,320 @@
+package handlers
+
+import (
+ "bytes"
+ "context"
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/io/pool"
+ "github.com/mimecast/dtail/internal/lcontext"
+ "github.com/mimecast/dtail/internal/mapr/server"
+ "github.com/mimecast/dtail/internal/protocol"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+type handleCommandCb func(context.Context, lcontext.LContext, int, []string, string)
+
+type baseHandler struct {
+ done *internal.Done
+ handleCommandCb handleCommandCb
+ lines chan line.Line
+ aggregate *server.Aggregate
+ maprMessages chan string
+ serverMessages chan string
+ hostname string
+ user *user.User
+ ackCloseReceived chan struct{}
+ activeCommands int32
+ readBuf bytes.Buffer
+ writeBuf bytes.Buffer
+
+ // Some global options + sync primitives required.
+ once sync.Once
+ mutex sync.Mutex
+ quiet bool
+ spartan bool
+ serverless bool
+}
+
+// Shutdown the handler.
+func (h *baseHandler) Shutdown() {
+ h.done.Shutdown()
+}
+
+// Done channel of the handler.
+func (h *baseHandler) Done() <-chan struct{} {
+ return h.done.Done()
+}
+
+// Read is to send data to the dtail client via Reader interface.
+func (h *baseHandler) Read(p []byte) (n int, err error) {
+ defer h.readBuf.Reset()
+
+ select {
+ case message := <-h.serverMessages:
+ if len(message) > 0 && message[0] == '.' {
+ // Handle hidden message (don't display to the user)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ return
+ }
+
+ if h.serverless {
+ return
+ }
+
+ // Handle normal server message (display to the user)
+ h.readBuf.WriteString("SERVER")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case message := <-h.maprMessages:
+ // Send mapreduce-aggregated data as a message.
+ h.readBuf.WriteString("AGGREGATE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(message)
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+
+ case line := <-h.lines:
+ if !h.spartan {
+ h.readBuf.WriteString("REMOTE")
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(h.hostname)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%3d", line.TransmittedPerc))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(fmt.Sprintf("%v", line.Count))
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ h.readBuf.WriteString(line.SourceID)
+ h.readBuf.WriteString(protocol.FieldDelimiter)
+ }
+ h.readBuf.WriteString(line.Content.String())
+ h.readBuf.WriteByte(protocol.MessageDelimiter)
+ n = copy(p, h.readBuf.Bytes())
+ pool.RecycleBytesBuffer(line.Content)
+
+ case <-time.After(time.Second):
+ // Once in a while check whether we are done.
+ select {
+ case <-h.done.Done():
+ err = io.EOF
+ return
+ default:
+ }
+ }
+ return
+}
+
+// Write is to receive data from the dtail client via Writer interface.
+func (h *baseHandler) Write(p []byte) (n int, err error) {
+ for _, b := range p {
+ switch b {
+ case ';':
+ h.handleCommand(string(h.writeBuf.Bytes()))
+ h.writeBuf.Reset()
+ default:
+ h.writeBuf.WriteByte(b)
+ }
+ }
+ n = len(p)
+ return
+}
+
+func (h *baseHandler) handleCommand(commandStr string) {
+ dlog.Server.Debug(h.user, commandStr)
+
+ args, argc, add, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err)+add)
+ return
+ }
+ args, argc, err = h.handleBase64(args, argc)
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ go func() {
+ <-h.done.Done()
+ cancel()
+ }()
+
+ parts := strings.Split(args[0], ":")
+ commandName := parts[0]
+
+ // Either no options or empty options provided.
+ if len(parts) == 1 || len(parts[1]) == 0 {
+ h.handleCommandCb(ctx, lcontext.LContext{}, argc, args, commandName)
+ return
+ }
+
+ options, ltx, err := config.DeserializeOptions(parts[1:])
+ if err != nil {
+ h.send(h.serverMessages, dlog.Server.Error(h.user, err))
+ return
+ }
+ h.handleOptions(options)
+ h.handleCommandCb(ctx, ltx, argc, args, commandName)
+}
+
+func (h *baseHandler) handleProtocolVersion(args []string) ([]string, int, string, error) {
+ argc := len(args)
+ var add string
+
+ if argc <= 2 || args[0] != "protocol" {
+ return args, argc, add, errors.New("unable to determine protocol version")
+ }
+
+ if args[1] != protocol.ProtocolCompat {
+ clientCompat, _ := strconv.Atoi(args[1])
+ serverCompat, _ := strconv.Atoi(protocol.ProtocolCompat)
+ if clientCompat <= 3 {
+ // Protocol version 3 or lower expect a newline as message separator
+ // One day (after 2 major versions) this exception may be removed!
+ add = "\n"
+ }
+
+ toUpdate := "client"
+ if clientCompat > serverCompat {
+ toUpdate = "server"
+ }
+ err := fmt.Errorf("the DTail server protocol version '%s' does not match "+
+ "client protocol version '%s', please update DTail %s",
+ protocol.ProtocolCompat, args[1], toUpdate)
+ return args, argc, add, err
+ }
+
+ return args[2:], argc - 2, add, nil
+}
+
+func (h *baseHandler) handleBase64(args []string, argc int) ([]string, int, error) {
+ err := errors.New("unable to decode client message, DTail server and client " +
+ "versions may not be compatible")
+ if argc != 2 || args[0] != "base64" {
+ return args, argc, err
+ }
+
+ decoded, err := base64.StdEncoding.DecodeString(args[1])
+ if err != nil {
+ return args, argc, err
+ }
+ decodedStr := string(decoded)
+
+ args = strings.Split(decodedStr, " ")
+ argc = len(decodedStr)
+ dlog.Server.Trace(h.user, "Base64 decoded received command",
+ decodedStr, argc, args)
+
+ return args, argc, nil
+}
+
+func (h *baseHandler) handleAckCommand(argc int, args []string) {
+ if argc < 3 {
+ if !h.quiet {
+ h.send(h.serverMessages, dlog.Server.Warn(h.user,
+ "Unable to parse command", args, argc))
+ }
+ return
+ }
+ if args[1] == "close" && args[2] == "connection" {
+ select {
+ case <-h.ackCloseReceived:
+ default:
+ close(h.ackCloseReceived)
+ }
+ }
+}
+
+func (h *baseHandler) handleOptions(options map[string]string) {
+ // We have to make sure that this block is executed only once.
+ h.mutex.Lock()
+ defer h.mutex.Unlock()
+ // We can read the options only once, will cause a data race otherwise if
+ // changed multiple times for multiple incoming commands.
+ h.once.Do(func() {
+ if quiet, _ := options["quiet"]; quiet == "true" {
+ dlog.Server.Debug(h.user, "Enabling quiet mode")
+ h.quiet = true
+ }
+ if spartan, _ := options["spartan"]; spartan == "true" {
+ dlog.Server.Debug(h.user, "Enabling spartan mode")
+ h.spartan = true
+ }
+ if serverless, _ := options["serverless"]; serverless == "true" {
+ dlog.Server.Debug(h.user, "Enabling serverless mode")
+ h.serverless = true
+ }
+ })
+}
+
+func (h *baseHandler) send(ch chan<- string, message string) {
+ select {
+ case ch <- message:
+ case <-h.done.Done():
+ }
+}
+
+func (h *baseHandler) flush() {
+ dlog.Server.Trace(h.user, "flush()")
+ numUnsentMessages := func() int {
+ return len(h.lines) + len(h.serverMessages) + len(h.maprMessages)
+ }
+ for i := 0; i < 10; i++ {
+ if numUnsentMessages() == 0 {
+ dlog.Server.Debug(h.user, "ALL lines sent", fmt.Sprintf("%p", h))
+ return
+ }
+ dlog.Server.Debug(h.user, "Still lines to be sent")
+ time.Sleep(time.Millisecond * 10)
+ }
+ dlog.Server.Warn(h.user, "Some lines remain unsent", numUnsentMessages())
+}
+
+func (h *baseHandler) shutdown() {
+ dlog.Server.Debug(h.user, "shutdown()")
+ h.flush()
+
+ go func() {
+ select {
+ case h.serverMessages <- ".syn close connection":
+ case <-h.done.Done():
+ }
+ }()
+
+ select {
+ case <-h.ackCloseReceived:
+ case <-time.After(time.Second * 5):
+ dlog.Server.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
+ case <-h.done.Done():
+ }
+ h.done.Shutdown()
+}
+
+func (h *baseHandler) incrementActiveCommands() {
+ atomic.AddInt32(&h.activeCommands, 1)
+}
+
+func (h *baseHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
+}
diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go
deleted file mode 100644
index 1e17c78..0000000
--- a/internal/server/handlers/controlhandler.go
+++ /dev/null
@@ -1,98 +0,0 @@
-package handlers
-
-import (
- "fmt"
- "io"
- "os"
- "strings"
-
- "github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/io/logger"
- user "github.com/mimecast/dtail/internal/user/server"
-)
-
-// ControlHandler is used for control functions and health monitoring.
-type ControlHandler struct {
- done *internal.Done
- hostname string
- payload []byte
- serverMessages chan string
- user *user.User
-}
-
-// NewControlHandler returns a new control handler.
-func NewControlHandler(user *user.User) *ControlHandler {
- logger.Debug(user, "Creating control handler")
-
- h := ControlHandler{
- done: internal.NewDone(),
- serverMessages: make(chan string, 10),
- user: user,
- }
-
- fqdn, err := os.Hostname()
- if err != nil {
- logger.FatalExit(err)
- }
-
- s := strings.Split(fqdn, ".")
- h.hostname = s[0]
-
- return &h
-}
-
-// Shutdown the handler.
-func (h *ControlHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ControlHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the client via the Reader interface.
-func (h *ControlHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
- case <-h.done.Done():
- return 0, io.EOF
- }
- }
-}
-
-// Write is to read data to the client via the Writer interface.
-func (h *ControlHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- wholePayload := strings.TrimSpace(string(h.payload))
- h.handleCommand(wholePayload)
- h.payload = nil
-
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ControlHandler) handleCommand(command string) {
- logger.Info(h.user, command)
- s := strings.Split(command, " ")
- logger.Debug(h.user, "Receiving command", command, s)
-
- switch s[0] {
- case "health":
- h.serverMessages <- "OK: DTail SSH Server seems fine"
- h.serverMessages <- "done;"
- default:
- h.serverMessages <- logger.Error(h.user, "Received unknown control command", command, s)
- }
-}
diff --git a/internal/server/handlers/healthhandler.go b/internal/server/handlers/healthhandler.go
new file mode 100644
index 0000000..6dd9872
--- /dev/null
+++ b/internal/server/handlers/healthhandler.go
@@ -0,0 +1,58 @@
+package handlers
+
+import (
+ "context"
+ "os"
+ "strings"
+
+ "github.com/mimecast/dtail/internal"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/io/line"
+ "github.com/mimecast/dtail/internal/lcontext"
+ user "github.com/mimecast/dtail/internal/user/server"
+)
+
+// HealthHandler is for the remote health check.
+type HealthHandler struct {
+ baseHandler
+}
+
+// NewHealthHandler returns the server handler.
+func NewHealthHandler(user *user.User) *HealthHandler {
+ dlog.Server.Debug(user, "Creating new server health handler")
+ h := HealthHandler{
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ }
+ h.handleCommandCb = h.handleHealthCommand
+
+ fqdn, err := os.Hostname()
+ if err != nil {
+ dlog.Server.FatalPanic(err)
+ }
+ s := strings.Split(fqdn, ".")
+ h.hostname = s[0]
+ return &h
+}
+
+func (h *HealthHandler) handleHealthCommand(ctx context.Context,
+ ltx lcontext.LContext, argc int, args []string, commandName string) {
+
+ dlog.Server.Debug(h.user, "Handling health command", argc, args)
+ switch commandName {
+ case "health":
+ h.send(h.serverMessages, "OK")
+ case ".ack":
+ h.handleAckCommand(argc, args)
+ default:
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown health command", commandName, argc, args))
+ }
+ h.shutdown()
+}
diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go
index c3e600e..65e0ed8 100644
--- a/internal/server/handlers/mapcommand.go
+++ b/internal/server/handlers/mapcommand.go
@@ -14,18 +14,17 @@ type mapCommand struct {
}
// NewMapCommand returns a new server side mapreduce command.
-func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) {
- m := mapCommand{server: serverHandler}
+func newMapCommand(serverHandler *ServerHandler, argc int,
+ args []string) (mapCommand, *server.Aggregate, error) {
+ m := mapCommand{server: serverHandler}
queryStr := strings.Join(args[1:], " ")
aggregate, err := server.NewAggregate(queryStr)
if err != nil {
return m, nil, err
}
-
m.aggregate = aggregate
return m, aggregate, nil
-
}
func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) {
diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go
index b659c06..4728a55 100644
--- a/internal/server/handlers/readcommand.go
+++ b/internal/server/handlers/readcommand.go
@@ -7,8 +7,9 @@ import (
"sync"
"time"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/fs"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/lcontext"
"github.com/mimecast/dtail/internal/omode"
"github.com/mimecast/dtail/internal/regex"
@@ -26,39 +27,45 @@ func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand {
}
}
-func (r *readCommand) Start(ctx context.Context, lContext lcontext.LContext, argc int, args []string, retries int) {
- re := regex.NewNoop()
+func (r *readCommand) Start(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, retries int) {
+ re := regex.NewNoop()
if argc >= 4 {
deserializedRegex, err := regex.Deserialize(strings.Join(args[2:], " "))
if err != nil {
- r.server.sendServerMessage(logger.Error(r.server.user, commandParseWarning, err))
+ r.server.send(r.server.serverMessages, dlog.Server.Error(r.server.user,
+ "Unable to parse command", err))
return
}
re = deserializedRegex
}
if argc < 3 {
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, commandParseWarning, args, argc))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to parse command", args, argc))
return
}
- r.readGlob(ctx, lContext, args[1], re, retries)
+ r.readGlob(ctx, ltx, args[1], re, retries)
}
-func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext, glob string, re regex.Regex, retries int) {
+func (r *readCommand) readGlob(ctx context.Context, ltx lcontext.LContext,
+ glob string, re regex.Regex, retries int) {
+
retryInterval := time.Second * 5
glob = filepath.Clean(glob)
for retryCount := 0; retryCount < retries; retryCount++ {
paths, err := filepath.Glob(glob)
if err != nil {
- logger.Warn(r.server.user, glob, err)
+ dlog.Server.Warn(r.server.user, glob, err)
time.Sleep(retryInterval)
continue
}
if numPaths := len(paths); numPaths == 0 {
- logger.Error(r.server.user, "No such file(s) to read", glob)
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ dlog.Server.Error(r.server.user, "No such file(s) to read", glob)
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to read file(s), check server logs"))
select {
case <-ctx.Done():
return
@@ -68,41 +75,44 @@ func (r *readCommand) readGlob(ctx context.Context, lContext lcontext.LContext,
continue
}
- r.readFiles(ctx, lContext, paths, glob, re, retryInterval)
+ r.readFiles(ctx, ltx, paths, glob, re, retryInterval)
return
}
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Giving up to read file(s)"))
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Giving up to read file(s)"))
return
}
-func (r *readCommand) readFiles(ctx context.Context, lContext lcontext.LContext, paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
+func (r *readCommand) readFiles(ctx context.Context, ltx lcontext.LContext,
+ paths []string, glob string, re regex.Regex, retryInterval time.Duration) {
+
var wg sync.WaitGroup
wg.Add(len(paths))
-
for _, path := range paths {
- go r.readFileIfPermissions(ctx, lContext, &wg, path, glob, re)
+ go r.readFileIfPermissions(ctx, ltx, &wg, path, glob, re)
}
-
wg.Wait()
}
-func (r *readCommand) readFileIfPermissions(ctx context.Context, lContext lcontext.LContext, wg *sync.WaitGroup, path, glob string, re regex.Regex) {
+func (r *readCommand) readFileIfPermissions(ctx context.Context, ltx lcontext.LContext,
+ wg *sync.WaitGroup, path, glob string, re regex.Regex) {
+
defer wg.Done()
globID := r.makeGlobID(path, glob)
-
if !r.server.user.HasFilePermission(path, "readfiles") {
- logger.Error(r.server.user, "No permission to read file", path, globID)
- r.server.sendServerWarnMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs"))
+ dlog.Server.Error(r.server.user, "No permission to read file", path, globID)
+ r.server.send(r.server.serverMessages, dlog.Server.Warn(r.server.user,
+ "Unable to read file(s), check server logs"))
return
}
-
- r.readFile(ctx, lContext, path, globID, re)
+ r.readFile(ctx, ltx, path, globID, re)
}
-func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext, path, globID string, re regex.Regex) {
- logger.Info(r.server.user, "Start reading file", path, globID)
+func (r *readCommand) readFile(ctx context.Context, ltx lcontext.LContext,
+ path, globID string, re regex.Regex) {
+ dlog.Server.Info(r.server.user, "Start reading file", path, globID)
var reader fs.FileReader
switch r.mode {
case omode.TailClient:
@@ -114,15 +124,19 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext,
}
lines := r.server.lines
-
- // Plug in mappreduce engine
- if r.server.aggregate != nil {
- lines = r.server.aggregate.Lines
- }
+ aggregate := r.server.aggregate
for {
- if err := reader.Start(ctx, lContext, lines, re); err != nil {
- logger.Error(r.server.user, path, globID, err)
+ if aggregate != nil {
+ lines = make(chan line.Line, 100)
+ aggregate.NextLinesCh <- lines
+ }
+ if err := reader.Start(ctx, ltx, lines, re); err != nil {
+ dlog.Server.Error(r.server.user, path, globID, err)
+ }
+ if aggregate != nil {
+ // Also makes aggregate to Flush
+ close(lines)
}
select {
@@ -133,9 +147,8 @@ func (r *readCommand) readFile(ctx context.Context, lContext lcontext.LContext,
return
}
}
-
time.Sleep(time.Second * 2)
- logger.Info(path, globID, "Reading file again")
+ dlog.Server.Info(path, globID, "Reading file again")
}
}
@@ -152,11 +165,11 @@ func (r *readCommand) makeGlobID(path, glob string) string {
if len(idParts) > 0 {
return strings.Join(idParts, "/")
}
-
if len(pathParts) > 0 {
return pathParts[len(pathParts)-1]
}
- r.server.sendServerWarnMessage(logger.Warn("Empty file path given?", path, glob))
+ r.server.send(r.server.serverMessages,
+ dlog.Server.Warn("Empty file path given?", path, glob))
return ""
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 39d5d5f..36574a9 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -2,69 +2,50 @@ package handlers
import (
"context"
- "encoding/base64"
- "errors"
- "fmt"
- "io"
"os"
- "strconv"
"strings"
- "sync/atomic"
- "time"
"github.com/mimecast/dtail/internal"
- "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/line"
- "github.com/mimecast/dtail/internal/io/logger"
"github.com/mimecast/dtail/internal/lcontext"
- "github.com/mimecast/dtail/internal/mapr/server"
"github.com/mimecast/dtail/internal/omode"
user "github.com/mimecast/dtail/internal/user/server"
- "github.com/mimecast/dtail/internal/version"
-)
-
-const (
- commandParseWarning string = "Unable to parse command"
)
// ServerHandler implements the Reader and Writer interfaces to handle
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- done *internal.Done
- lines chan line.Line
- regex string
- aggregate *server.Aggregate
- aggregatedMessages chan string
- serverMessages chan string
- payload []byte
- hostname string
- user *user.User
- catLimiter chan struct{}
- tailLimiter chan struct{}
- ackCloseReceived chan struct{}
- activeCommands int32
- activeReaders int32
- quiet bool
+ baseHandler
+ catLimiter chan struct{}
+ tailLimiter chan struct{}
+ regex string
}
// NewServerHandler returns the server handler.
-func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *ServerHandler {
+func NewServerHandler(user *user.User, catLimiter,
+ tailLimiter chan struct{}) *ServerHandler {
+
+ dlog.Server.Debug(user, "Creating new server handler")
h := ServerHandler{
- done: internal.NewDone(),
- lines: make(chan line.Line, 100),
- serverMessages: make(chan string, 10),
- aggregatedMessages: make(chan string, 10),
- ackCloseReceived: make(chan struct{}),
- catLimiter: catLimiter,
- tailLimiter: tailLimiter,
- regex: ".",
- user: user,
- }
+ baseHandler: baseHandler{
+ done: internal.NewDone(),
+ lines: make(chan line.Line, 100),
+ serverMessages: make(chan string, 10),
+ maprMessages: make(chan string, 10),
+ ackCloseReceived: make(chan struct{}),
+ user: user,
+ },
+ catLimiter: catLimiter,
+ tailLimiter: tailLimiter,
+ regex: ".",
+ }
+ h.handleCommandCb = h.handleUserCommand
fqdn, err := os.Hostname()
if err != nil {
- logger.FatalExit(err)
+ dlog.Server.FatalPanic(err)
}
s := strings.Split(fqdn, ".")
@@ -73,374 +54,49 @@ func NewServerHandler(user *user.User, catLimiter, tailLimiter chan struct{}) *S
return &h
}
-// Shutdown the handler.
-func (h *ServerHandler) Shutdown() {
- h.done.Shutdown()
-}
-
-// Done channel of the handler.
-func (h *ServerHandler) Done() <-chan struct{} {
- return h.done.Done()
-}
-
-// Read is to send data to the dtail client via Reader interface.
-func (h *ServerHandler) Read(p []byte) (n int, err error) {
- for {
- select {
- case message := <-h.serverMessages:
- if len(message) == 0 {
- logger.Warn(h.user, "Empty message received")
- return
- }
- if message[0] == '.' {
- // Handle hidden message (don't display to the user, interpreted by dtail client)
- wholePayload := []byte(fmt.Sprintf("%s\n", message))
- n = copy(p, wholePayload)
- return
- }
-
- // Handle normal server message (display to the user)
- wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message))
- n = copy(p, wholePayload)
- return
-
- case message := <-h.aggregatedMessages:
- // Send mapreduce-aggregated data as a message.
- data := fmt.Sprintf("AGGREGATEâž”%sâž”%s\n", h.hostname, message)
- wholePayload := []byte(data)
- n = copy(p, wholePayload)
- return
-
- case line := <-h.lines:
- // Send normal file content data as a message.
- serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|",
- h.hostname, line.TransmittedPerc, line.Count, line.SourceID))
- wholePayload := append(serverInfo, line.Content[:]...)
- n = copy(p, wholePayload)
- return
-
- case <-time.After(time.Second):
- // Once in a while check whether we are done.
- select {
- case <-h.done.Done():
- return 0, io.EOF
- default:
- }
- }
- }
-}
-
-// Write is to receive data from the dtail client via Writer interface.
-func (h *ServerHandler) Write(p []byte) (n int, err error) {
- for _, c := range p {
- switch c {
- case ';':
- commandStr := strings.TrimSpace(string(h.payload))
- h.handleCommand(commandStr)
- h.payload = nil
- default:
- h.payload = append(h.payload, c)
- }
- }
-
- n = len(p)
- return
-}
-
-func (h *ServerHandler) handleCommand(commandStr string) {
- logger.Debug(h.user, commandStr)
- ctx := context.Background()
-
- args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " "))
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- args, argc, err = h.handleBase64(args, argc)
- if err != nil {
- h.send(h.serverMessages, logger.Error(h.user, err))
- return
- }
-
- if h.user.Name == config.ControlUser {
- h.handleControlCommand(argc, args)
- return
- }
-
- ctx, cancel := context.WithCancel(ctx)
- go func() {
- <-h.done.Done()
- cancel()
- }()
-
- h.handleUserCommand(ctx, argc, args)
-}
-
-func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) {
- argc := len(args)
-
- if argc <= 2 || args[0] != "protocol" {
- return args, argc, errors.New("unable to determine protocol version")
- }
-
- if args[1] != version.ProtocolCompat {
- err := fmt.Errorf("server with protocol version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1])
- return args, argc, err
- }
-
- return args[2:], argc - 2, nil
-}
-
-func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) {
- err := errors.New("Unable to decode client message, DTail server and client versions may not be compatible")
-
- if argc != 2 || args[0] != "base64" {
- return args, argc, err
- }
-
- decoded, err := base64.StdEncoding.DecodeString(args[1])
- if err != nil {
- return args, argc, err
- }
- decodedStr := string(decoded)
-
- args = strings.Split(decodedStr, " ")
- argc = len(decodedStr)
- logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args)
-
- return args, argc, nil
-}
-
-func (h *ServerHandler) handleControlCommand(argc int, args []string) {
- switch args[0] {
- case "debug":
- h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args))
- default:
- logger.Warn(h.user, "Received unknown control command", argc, args)
- }
-}
-
-func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) {
- logger.Debug(h.user, "handleUserCommand", argc, args)
+func (h *ServerHandler) handleUserCommand(ctx context.Context, ltx lcontext.LContext,
+ argc int, args []string, commandName string) {
+ dlog.Server.Debug(h.user, "Handling user command", argc, args)
h.incrementActiveCommands()
commandFinished := func() {
if h.decrementActiveCommands() == 0 {
h.shutdown()
}
}
- readerFinished := func() {
- if h.decrementActiveReaders() == 0 {
- if h.aggregate == nil {
- return
- }
- h.aggregate.Shutdown()
- }
- }
-
- splitted := strings.Split(args[0], ":")
- commandName := splitted[0]
-
- options, lContext, err := readOptions(splitted[1:])
- if err != nil {
- h.sendServerMessage(logger.Error(h.user, err))
- commandFinished()
- return
- }
- if quiet, ok := options["quiet"]; ok {
- if quiet == "true" {
- logger.Debug(h.user, "Enabling quiet mode")
- h.quiet = true
- }
- }
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 1)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 1)
commandFinished()
}()
-
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
- h.incrementActiveReaders()
- command.Start(ctx, lContext, argc, args, 10)
- readerFinished()
+ command.Start(ctx, ltx, argc, args, 10)
commandFinished()
}()
-
case "map":
command, aggregate, err := newMapCommand(h, argc, args)
if err != nil {
- h.sendServerMessage(err.Error())
- logger.Error(h.user, err)
+ h.send(h.serverMessages, err.Error())
+ dlog.Server.Error(h.user, err)
commandFinished()
return
}
-
h.aggregate = aggregate
go func() {
- command.Start(ctx, h.aggregatedMessages)
+ command.Start(ctx, h.maprMessages)
commandFinished()
}()
-
- case "ack", ".ack":
+ case ".ack":
h.handleAckCommand(argc, args)
commandFinished()
-
default:
- h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options))
+ h.send(h.serverMessages, dlog.Server.Error(h.user,
+ "Received unknown user command", commandName, argc, args))
commandFinished()
}
}
-
-func (h *ServerHandler) handleAckCommand(argc int, args []string) {
- if argc < 3 {
- h.sendServerWarnMessage(logger.Warn(h.user, commandParseWarning, args, argc))
- return
- }
- if args[1] == "close" && args[2] == "connection" {
- close(h.ackCloseReceived)
- }
-}
-
-func (h *ServerHandler) send(ch chan<- string, message string) {
- select {
- case ch <- message:
- case <-h.done.Done():
- }
-}
-
-func (h *ServerHandler) sendServerMessage(message string) {
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) sendServerWarnMessage(message string) {
- if h.quiet {
- return
- }
- h.send(h.serverMessageC(), message)
-}
-
-func (h *ServerHandler) serverMessageC() chan<- string {
- return h.serverMessages
-}
-
-func (h *ServerHandler) flush() {
- logger.Debug(h.user, "flush()")
-
- if h.aggregate != nil {
- h.aggregate.Flush()
- }
-
- unsentMessages := func() int {
- return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages)
- }
- for i := 0; i < 3; i++ {
- if unsentMessages() == 0 {
- logger.Debug(h.user, "All lines sent")
- return
- }
- logger.Debug(h.user, "Still lines to be sent")
- time.Sleep(time.Second)
- }
-
- logger.Warn(h.user, "Some lines remain unsent", unsentMessages())
-}
-
-func (h *ServerHandler) shutdown() {
- logger.Debug(h.user, "shutdown()")
- h.flush()
-
- go func() {
- select {
- case h.serverMessageC() <- ".syn close connection":
- case <-h.done.Done():
- }
- }()
-
- select {
- case <-h.ackCloseReceived:
- case <-time.After(time.Second * 5):
- logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown")
- case <-h.done.Done():
- }
-
- h.done.Shutdown()
-}
-
-func (h *ServerHandler) incrementActiveCommands() {
- atomic.AddInt32(&h.activeCommands, 1)
-}
-
-func (h *ServerHandler) decrementActiveCommands() int32 {
- atomic.AddInt32(&h.activeCommands, -1)
- return atomic.LoadInt32(&h.activeCommands)
-}
-
-func (h *ServerHandler) incrementActiveReaders() {
- atomic.AddInt32(&h.activeReaders, 1)
-}
-
-func (h *ServerHandler) decrementActiveReaders() int32 {
- atomic.AddInt32(&h.activeReaders, -1)
- return atomic.LoadInt32(&h.activeReaders)
-}
-
-// TODO: All options related code should be in its own package (client + server)
-func readOptions(opts []string) (map[string]string, lcontext.LContext, error) {
- options := make(map[string]string, len(opts))
- // Local search context
- var lContext lcontext.LContext
-
- for _, o := range opts {
- kv := strings.SplitN(o, "=", 2)
- if len(kv) != 2 {
- continue
- }
- key := kv[0]
- val := kv[1]
-
- if strings.HasPrefix(val, "base64%") {
- s := strings.SplitN(val, "%", 2)
- decoded, err := base64.StdEncoding.DecodeString(s[1])
- if err != nil {
- return options, lContext, err
- }
- val = string(decoded)
- }
-
- switch key {
- case "before":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.BeforeContext = iVal
- case "after":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.AfterContext = iVal
- case "max":
- iVal, err := strconv.Atoi(val)
- if err != nil {
- logger.Error(err)
- continue
- }
- lContext.MaxCount = iVal
- default:
- options[key] = val
- }
- }
-
- return options, lContext, nil
-}
diff --git a/internal/server/scheduler.go b/internal/server/scheduler.go
index a1e9e36..0ba65f7 100644
--- a/internal/server/scheduler.go
+++ b/internal/server/scheduler.go
@@ -10,25 +10,23 @@ import (
"github.com/mimecast/dtail/internal/clients"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/omode"
gossh "golang.org/x/crypto/ssh"
)
-type scheduler struct {
-}
+type scheduler struct{}
func newScheduler() *scheduler {
return &scheduler{}
}
func (s *scheduler) start(ctx context.Context) {
- logger.Info("Starting scheduled job runner after 10s")
+ dlog.Server.Info("Starting scheduled job runner after 10s")
// First run after just 10s!
time.Sleep(time.Second * 10)
s.runJobs(ctx)
-
for {
select {
case <-time.After(time.Minute):
@@ -42,27 +40,24 @@ func (s *scheduler) start(ctx context.Context) {
func (s *scheduler) runJobs(ctx context.Context) {
for _, job := range config.Server.Schedule {
if !job.Enable {
- logger.Debug(job.Name, "Not running job as not enabled")
+ dlog.Server.Debug(job.Name, "Not running job as not enabled")
continue
}
-
hour, err := strconv.Atoi(time.Now().Format("15"))
if err != nil {
- logger.Error(job.Name, "Unable to create job", err)
+ dlog.Server.Error(job.Name, "Unable to create job", err)
continue
}
-
if hour < job.TimeRange[0] || hour >= job.TimeRange[1] {
- logger.Debug(job.Name, "Not running job out of time range")
+ dlog.Server.Debug(job.Name, "Not running job out of time range")
continue
}
files := fillDates(job.Files)
outfile := fillDates(job.Outfile)
-
_, err = os.Stat(outfile)
if !os.IsNotExist(err) {
- logger.Debug(job.Name, "Not running job as outfile already exists", outfile)
+ dlog.Server.Debug(job.Name, "Not running job as outfile already exists", outfile)
continue
}
@@ -70,9 +65,8 @@ func (s *scheduler) runJobs(ctx context.Context) {
if servers == "" {
servers = config.Server.SSHBindAddress
}
-
- args := clients.Args{
- ConnectionsPerCPU: 10,
+ args := config.Args{
+ ConnectionsPerCPU: config.DefaultConnectionsPerCPU,
Discovery: job.Discovery,
ServersStr: servers,
What: files,
@@ -81,25 +75,24 @@ func (s *scheduler) runJobs(ctx context.Context) {
}
args.SSHAuthMethods = append(args.SSHAuthMethods, gossh.Password(job.Name))
-
- query := fmt.Sprintf("%s outfile %s", job.Query, outfile)
- client, err := clients.NewMaprClient(args, query, clients.CumulativeMode)
+ args.QueryStr = fmt.Sprintf("%s outfile %s", job.Query, outfile)
+ client, err := clients.NewMaprClient(args, clients.CumulativeMode)
if err != nil {
- logger.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
+ dlog.Server.Error(fmt.Sprintf("Unable to create job %s", job.Name), err)
continue
}
jobCtx, cancel := context.WithCancel(ctx)
defer cancel()
- logger.Info(fmt.Sprintf("Starting job %s", job.Name))
+ dlog.Server.Info(fmt.Sprintf("Starting job %s", job.Name))
status := client.Start(jobCtx, make(chan string))
logMessage := fmt.Sprintf("Job exited with status %d", status)
if status != 0 {
- logger.Warn(logMessage)
+ dlog.Server.Warn(logMessage)
continue
}
- logger.Info(logMessage)
+ dlog.Server.Info(logMessage)
}
}
diff --git a/internal/server/server.go b/internal/server/server.go
index 3640208..0cb5e27 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -9,7 +9,7 @@ import (
"strings"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/server/handlers"
"github.com/mimecast/dtail/internal/ssh/server"
user "github.com/mimecast/dtail/internal/user/server"
@@ -24,9 +24,9 @@ type Server struct {
stats stats
// SSH server configuration.
sshServerConfig *gossh.ServerConfig
- // To control the max amount of concurrent cats (which can cause a lot of I/O on the server)
+ // To control the max amount of concurrent cats.
catLimiter chan struct{}
- // To control the max amount of concurrent tails
+ // To control the max amount of concurrent tails.
tailLimiter chan struct{}
// To run scheduled tasks (if configured)
sched *scheduler
@@ -36,7 +36,7 @@ type Server struct {
// New returns a new server.
func New() *Server {
- logger.Info("Creating server", version.String())
+ dlog.Server.Info("Creating server", version.String())
s := Server{
sshServerConfig: &gossh.ServerConfig{},
@@ -51,7 +51,7 @@ func New() *Server {
private, err := gossh.ParsePrivateKey(server.PrivateHostKey())
if err != nil {
- logger.FatalExit(err)
+ dlog.Server.FatalPanic(err)
}
s.sshServerConfig.AddHostKey(private)
@@ -60,14 +60,13 @@ func New() *Server {
// Start the server.
func (s *Server) Start(ctx context.Context) int {
- logger.Info("Starting server")
-
+ dlog.Server.Info("Starting server")
bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort)
- logger.Info("Binding server", bindAt)
+ dlog.Server.Info("Binding server", bindAt)
listener, err := net.Listen("tcp", bindAt)
if err != nil {
- logger.FatalExit("Failed to open listening TCP socket", err)
+ dlog.Server.FatalPanic("Failed to open listening TCP socket", err)
}
go s.stats.start(ctx)
@@ -76,14 +75,12 @@ func (s *Server) Start(ctx context.Context) int {
go s.listenerLoop(ctx, listener)
<-ctx.Done()
-
// For future use.
return 0
}
func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
- logger.Debug("Starting listener loop")
-
+ dlog.Server.Debug("Starting listener loop")
for {
conn, err := listener.Accept() // Blocking
if err != nil {
@@ -92,63 +89,69 @@ func (s *Server) listenerLoop(ctx context.Context, listener net.Listener) {
return
default:
}
- logger.Error("Failed to accept incoming connection", err)
+ dlog.Server.Error("Failed to accept incoming connection", err)
continue
}
if err := s.stats.serverLimitExceeded(); err != nil {
- logger.Error(err)
+ dlog.Server.Error(err)
conn.Close()
continue
}
-
go s.handleConnection(ctx, conn)
}
}
func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
- logger.Info("Handling connection")
+ dlog.Server.Info("Handling connection")
sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig)
if err != nil {
- logger.Error("Something just happened", err)
+ dlog.Server.Error("Something just happened", err)
return
}
s.stats.incrementConnections()
-
go gossh.DiscardRequests(reqs)
for newChannel := range chans {
go s.handleChannel(ctx, sshConn, newChannel)
}
}
-func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) {
- user := user.New(sshConn.User(), sshConn.RemoteAddr().String())
- logger.Info(user, "Invoking channel handler")
+func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn,
+ newChannel gossh.NewChannel) {
+ user, err := user.New(sshConn.User(), sshConn.RemoteAddr().String())
+ if err != nil {
+ dlog.Server.Error(user, err)
+ newChannel.Reject(gossh.Prohibited, err.Error())
+ return
+ }
+
+ dlog.Server.Info(user, "Invoking channel handler")
if newChannel.ChannelType() != "session" {
err := errors.New("Don'w allow other channel types than session")
- logger.Error(user, err)
+ dlog.Server.Error(user, err)
newChannel.Reject(gossh.Prohibited, err.Error())
return
}
channel, requests, err := newChannel.Accept()
if err != nil {
- logger.Error(user, "Could not accept channel", err)
+ dlog.Server.Error(user, "Could not accept channel", err)
return
}
if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil {
- logger.Error(user, "While handling request", err)
+ dlog.Server.Error(user, err)
sshConn.Close()
}
}
-func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
- logger.Info(user, "Invoking request handler")
+func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn,
+ in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error {
+ dlog.Server.Info(user, "Invoking request handler")
for req := range in {
var payload = struct{ Value string }{}
gossh.Unmarshal(req.Payload, &payload)
@@ -157,12 +160,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
case "shell":
var handler handlers.Handler
switch user.Name {
- case config.ControlUser:
- handler = handlers.NewControlHandler(user)
+ case config.HealthUser:
+ handler = handlers.NewHealthHandler(user)
default:
handler = handlers.NewServerHandler(user, s.catLimiter, s.tailLimiter)
}
-
terminate := func() {
handler.Shutdown()
sshConn.Close()
@@ -173,13 +175,11 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
io.Copy(channel, handler)
terminate()
}()
-
go func() {
// Broken pipe, cancel
io.Copy(handler, channel)
terminate()
}()
-
go func() {
select {
case <-ctx.Done():
@@ -187,62 +187,61 @@ func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-ch
}
terminate()
}()
-
go func() {
if err := sshConn.Wait(); err != nil && err != io.EOF {
- // Use of closed network connection.
- logger.Debug(user, "While waiting for ssh connection", err)
+ dlog.Server.Error(user, err)
}
s.stats.decrementConnections()
- logger.Info(user, "Good bye Mister!")
+ dlog.Server.Info(user, "Good bye Mister!")
terminate()
}()
// Only serving shell type
req.Reply(true, nil)
-
default:
req.Reply(false, nil)
-
- return fmt.Errorf("Closing SSH connection as unknown request received|%s|%v",
+ return fmt.Errorf("Closing SSH connection as unknown request recieved|%s|%v",
req.Type, payload.Value)
}
}
-
return nil
}
// Callback for SSH authentication.
-func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Permissions, error) {
- user := user.New(c.User(), c.RemoteAddr().String())
+func (s *Server) Callback(c gossh.ConnMetadata,
+ authPayload []byte) (*gossh.Permissions, error) {
+
+ user, err := user.New(c.User(), c.RemoteAddr().String())
+ if err != nil {
+ return nil, err
+ }
if config.ServerRelaxedAuthEnable {
- logger.Fatal(user, "Granting permissions via relaxed-auth")
+ dlog.Server.Fatal(user, "Granting permissions via relaxed-auth")
return nil, nil
}
authInfo := string(authPayload)
-
splitted := strings.Split(c.RemoteAddr().String(), ":")
remoteIP := splitted[0]
switch user.Name {
- case config.ControlUser:
- if authInfo == config.ControlUser {
- logger.Debug(user, "Granting permissions to control user")
+ case config.HealthUser:
+ if authInfo == config.HealthUser {
+ dlog.Server.Debug(user, "Granting permissions to health user")
return nil, nil
}
case config.ScheduleUser:
for _, job := range config.Server.Schedule {
if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) {
- logger.Debug(user, "Granting SSH connection")
+ dlog.Server.Debug(user, "Granting SSH connection")
return nil, nil
}
}
case config.ContinuousUser:
for _, job := range config.Server.Continuous {
if s.backgroundCanSSH(user, authInfo, remoteIP, job.Name, job.AllowFrom) {
- logger.Debug(user, "Granting SSH connection")
+ dlog.Server.Debug(user, "Granting SSH connection")
return nil, nil
}
}
@@ -252,23 +251,26 @@ func (s *Server) Callback(c gossh.ConnMetadata, authPayload []byte) (*gossh.Perm
return nil, fmt.Errorf("user %s not authorized", user)
}
-func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP, allowedJobName string, allowFrom []string) bool {
- logger.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
+func (s *Server) backgroundCanSSH(user *user.User, jobName, remoteIP,
+ allowedJobName string, allowFrom []string) bool {
+ dlog.Server.Debug("backgroundCanSSH", user, jobName, remoteIP, allowedJobName, allowFrom)
if jobName != allowedJobName {
- logger.Debug(user, jobName, "backgroundCanSSH", "Job name does not match, skipping to next one...", allowedJobName)
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH",
+ "Job name does not match, skipping to next one...", allowedJobName)
return false
}
for _, myAddr := range allowFrom {
ips, err := net.LookupIP(myAddr)
if err != nil {
- logger.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP address for allowed hosts lookup, skipping to next one...", myAddr, err)
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Unable to lookup IP "+
+ "address for allowed hosts lookup, skipping to next one...", myAddr, err)
continue
}
-
for _, ip := range ips {
- logger.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses", remoteIP, ip.String())
+ dlog.Server.Debug(user, jobName, "backgroundCanSSH", "Comparing IP addresses",
+ remoteIP, ip.String())
if remoteIP == ip.String() {
return true
}
diff --git a/internal/server/stats.go b/internal/server/stats.go
index ac579ad..99a644a 100644
--- a/internal/server/stats.go
+++ b/internal/server/stats.go
@@ -3,12 +3,11 @@ package server
import (
"context"
"fmt"
- "runtime"
"sync"
"time"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
)
// Used to collect and display various server stats.
@@ -20,7 +19,6 @@ type stats struct {
func (s *stats) incrementConnections() {
defer s.logServerStats()
-
s.mutex.Lock()
s.currentConnections++
s.lifetimeConnections++
@@ -29,7 +27,6 @@ func (s *stats) incrementConnections() {
func (s *stats) decrementConnections() {
defer s.logServerStats()
-
s.mutex.Lock()
s.currentConnections--
s.mutex.Unlock()
@@ -41,8 +38,8 @@ func (s *stats) hasConnections() bool {
s.mutex.Unlock()
has := currentConnections > 0
- logger.Info("stats", "Server with open connections?", has, currentConnections)
-
+ dlog.Server.Info("stats", "Server with open connections?",
+ has, currentConnections)
return has
}
@@ -50,10 +47,10 @@ func (s *stats) logServerStats() {
s.mutex.Lock()
defer s.mutex.Unlock()
- currentConnections := fmt.Sprintf("currentConnections=%d", s.currentConnections)
- lifetimeConnections := fmt.Sprintf("lifetimeConnections=%d", s.lifetimeConnections)
- goroutines := fmt.Sprintf("goroutines=%d", runtime.NumGoroutine())
- logger.Info("stats", currentConnections, lifetimeConnections, goroutines)
+ data := make(map[string]interface{})
+ data["currentConnections"] = s.currentConnections
+ data["lifetimeConnections"] = s.lifetimeConnections
+ dlog.Server.Mapreduce("STATS", data)
}
func (s *stats) serverLimitExceeded() error {
@@ -61,9 +58,9 @@ func (s *stats) serverLimitExceeded() error {
defer s.mutex.Unlock()
if s.currentConnections >= config.Server.MaxConnections {
- return fmt.Errorf("Exceeded max allowed concurrent connections of %d", config.Server.MaxConnections)
+ return fmt.Errorf("Exceeded max allowed concurrent connections of %d",
+ config.Server.MaxConnections)
}
-
return nil
}
diff --git a/internal/source/source.go b/internal/source/source.go
new file mode 100644
index 0000000..4bb0784
--- /dev/null
+++ b/internal/source/source.go
@@ -0,0 +1,30 @@
+package source
+
+// Source specifies the origin of either the current process (dtail is a client
+// process, dserver is a server process) or the source code package (e.g.
+// dserver server side code or dtail client side code). Notice that dtail client
+// may also executes server code directly (e.g. via serverless mode) and that
+// the dserver may also executes client code (e.g. via scheduled server side
+// mapreduce queries).
+type Source int
+
+const (
+ // Client process or source code package.
+ Client Source = iota
+ // Server process or source code package.
+ Server Source = iota
+ // HealthCheck process or client source code package.
+ HealthCheck Source = iota
+)
+
+func (s Source) String() string {
+ switch s {
+ case Client:
+ return "CLIENT"
+ case Server:
+ return "SERVER"
+ case HealthCheck:
+ return "HEALTHCHECK"
+ }
+ panic("Unknown source type")
+}
diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go
index bbfb7be..37f8382 100644
--- a/internal/ssh/client/authmethods.go
+++ b/internal/ssh/client/authmethods.go
@@ -4,89 +4,106 @@ import (
"os"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/ssh"
gossh "golang.org/x/crypto/ssh"
)
// InitSSHAuthMethods initialises all known SSH auth methods on the client side.
-func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod, hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+func InitSSHAuthMethods(sshAuthMethods []gossh.AuthMethod,
+ hostKeyCallback gossh.HostKeyCallback, trustAllHosts bool, throttleCh chan struct{},
+ privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+
if len(sshAuthMethods) > 0 {
simpleCallback, err := NewSimpleCallback()
if err != nil {
- logger.FatalExit(err)
+ dlog.Client.FatalPanic(err)
}
return sshAuthMethods, simpleCallback
}
-
return initKnownHostsAuthMethods(trustAllHosts, throttleCh, privateKeyPath)
}
-func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{}, privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
- var sshAuthMethods []gossh.AuthMethod
+func initKnownHostsAuthMethods(trustAllHosts bool, throttleCh chan struct{},
+ privateKeyPath string) ([]gossh.AuthMethod, HostKeyCallback) {
+ var sshAuthMethods []gossh.AuthMethod
knownHostsPath := os.Getenv("HOME") + "/.ssh/known_hosts"
- knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts, throttleCh)
+ knownHostsCallback, err := NewKnownHostsCallback(knownHostsPath, trustAllHosts,
+ throttleCh)
if err != nil {
- logger.FatalExit(knownHostsPath, err)
- }
- logger.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath)
-
- if config.Common.ExperimentalFeaturesEnable {
- sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
- logger.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods")
+ dlog.Client.FatalPanic(knownHostsPath, err)
}
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added known hosts file path", knownHostsPath)
+ /*
+ if config.Client.ExperimentalFeaturesEnable {
+ sshAuthMethods = append(sshAuthMethods, gossh.Password("experimental feature test"))
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added experimental method to list of auth methods")
+ }
+ */
// First try to read custom private key path.
if privateKeyPath != "" {
authMethod, err := ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthMethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthMethods",
+ "Added path to list of auth methods, not adding further methods",
+ privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.FatalExit("Unable to use private SSH key", privateKeyPath, err)
+ dlog.Client.FatalPanic("Unable to use private SSH key", privateKeyPath, err)
}
// Second, try SSH Agent
authMethod, err := ssh.Agent()
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK) to list of auth methods, not adding further methods")
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Added SSH Agent (SSH_AUTH_SOCK)"+
+ "to list of auth methods, not adding further methods")
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to init SSH Agent auth method", err)
+ dlog.Client.Debug("initKnownHostsAuthMethods",
+ "Unable to init SSH Agent auth method", err)
// Third, try Linux/UNIX default key paths
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_rsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthmethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key",
+ privateKeyPath, err)
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_dsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthmethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
privateKeyPath = os.Getenv("HOME") + "/.ssh/id_ecdsa"
authMethod, err = ssh.PrivateKey(privateKeyPath)
if err == nil {
sshAuthMethods = append(sshAuthMethods, authMethod)
- logger.Debug("initKnownHostsAuthmethods", "Added path to list of auth methods, not adding further methods", privateKeyPath)
+ dlog.Client.Debug("initKnownHostsAuthmethods",
+ "Added path to list of auth methods, not adding further methods", privateKeyPath)
return sshAuthMethods, knownHostsCallback
}
- logger.Debug("initKnownHostsAuthMethods", "Unable to use private key", privateKeyPath, err)
- logger.FatalExit("Unable to find private SSH key information")
+ dlog.Client.Debug("initKnownHostsAuthMethods", "Unable to use private key",
+ privateKeyPath, err)
+
+ // This is only a panic when we expect to do something about it.
+ if !config.Client.SSHDontAddHostsToKnownHostsFile {
+ dlog.Client.FatalPanic("Unable to find private SSH key information")
+ }
// Never reach this point.
return sshAuthMethods, knownHostsCallback
diff --git a/internal/ssh/client/customkeycallback.go b/internal/ssh/client/customkeycallback.go
index 73e5289..53b8e3c 100644
--- a/internal/ssh/client/customkeycallback.go
+++ b/internal/ssh/client/customkeycallback.go
@@ -7,8 +7,7 @@ import (
)
// CustomCallback is a custom host key callback wrapper.
-type CustomCallback struct {
-}
+type CustomCallback struct{}
// NewCustomCallback returns a new wrapper.
func NewCustomCallback() (*CustomCallback, error) {
diff --git a/internal/ssh/client/knownhostscallback.go b/internal/ssh/client/knownhostscallback.go
index 1ccf6c6..2aa0168 100644
--- a/internal/ssh/client/knownhostscallback.go
+++ b/internal/ssh/client/knownhostscallback.go
@@ -10,7 +10,8 @@ import (
"sync"
"time"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/prompt"
"golang.org/x/crypto/ssh"
@@ -46,8 +47,9 @@ type KnownHostsCallback struct {
}
// NewKnownHostsCallback returns a new wrapper.
-func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh chan struct{}) (HostKeyCallback, error) {
- // Ensure file exists
+func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool,
+ throttleCh chan struct{}) (HostKeyCallback, error) {
+
os.OpenFile(knownHostsPath, os.O_RDONLY|os.O_CREATE, 0666)
untrustedHosts := make(map[string]bool)
@@ -59,11 +61,9 @@ func NewKnownHostsCallback(knownHostsPath string, trustAllHosts bool, throttleCh
untrustedHosts: untrustedHosts,
mutex: &sync.Mutex{},
}
-
if trustAllHosts {
close(c.trustAllHostsCh)
}
-
return c, nil
}
@@ -75,14 +75,12 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
if err != nil {
return err
}
-
// Check for valid entry in known_hosts file
err = knownHostsCb(server, remote, key)
if err == nil {
// OK
return nil
}
-
// Make sure that interactive user callback does not interfere with
// SSH connection throttler.
<-c.throttleCh
@@ -96,11 +94,9 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
ipLine: knownhosts.Line([]string{remote.String()}, key),
responseCh: make(chan response),
}
-
- logger.Warn("Encountered unknown host", unknown)
+ dlog.Common.Warn("Encountered unknown host", unknown)
// Notify user that there is an unknown host
c.unknownCh <- unknown
-
// Wait for user input.
switch <-unknown.responseCh {
case trustHost:
@@ -112,7 +108,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
c.mutex.Lock()
defer c.mutex.Unlock()
c.untrustedHosts[server] = true
-
return err
}
}
@@ -121,7 +116,6 @@ func (c KnownHostsCallback) Wrap() ssh.HostKeyCallback {
// be added to the known hosts or not.
func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
var hosts []unknownHost
-
for {
// Check whether there is a unknown host
select {
@@ -139,7 +133,7 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
hosts = []unknownHost{}
}
case <-ctx.Done():
- logger.Debug("Stopping goroutine prompting new hosts...")
+ dlog.Common.Debug("Stopping goroutine prompting new hosts...")
return
}
}
@@ -147,14 +141,13 @@ func (c KnownHostsCallback) PromptAddHosts(ctx context.Context) {
func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
var servers []string
-
for _, host := range hosts {
servers = append(servers, host.server)
}
select {
case <-c.trustAllHostsCh:
- logger.Warn("Trusting host keys of servers", servers)
+ dlog.Common.Warn("Trusting host keys of servers", servers)
c.trustHosts(hosts)
return
default:
@@ -165,7 +158,6 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
strings.Join(servers, ","),
"Do you want to trust these hosts?",
)
-
p := prompt.New(question)
a := prompt.Answer{
@@ -175,7 +167,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
c.trustHosts(hosts)
},
EndCallback: func() {
- logger.Info("Added hosts to known hosts file", c.knownHostsPath)
+ dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath)
},
}
p.Add(a)
@@ -188,7 +180,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
c.trustHosts(hosts)
},
EndCallback: func() {
- logger.Info("Added hosts to known hosts file", c.knownHostsPath)
+ dlog.Common.Info("Added hosts to known hosts file", c.knownHostsPath)
},
}
p.Add(a)
@@ -200,7 +192,7 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
c.dontTrustHosts(hosts)
},
EndCallback: func() {
- logger.Info("Didn't add hosts to known hosts file", c.knownHostsPath)
+ dlog.Common.Info("Didn't add hosts to known hosts file", c.knownHostsPath)
},
}
p.Add(a)
@@ -224,6 +216,11 @@ func (c KnownHostsCallback) promptAddHosts(hosts []unknownHost) {
func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
tmpKnownHostsPath := fmt.Sprintf("%s.tmp", c.knownHostsPath)
+ if config.Client.SSHDontAddHostsToKnownHostsFile {
+ dlog.Common.Verbose("Not adding hosts to known hosts file, as disabled by config")
+ return
+ }
+
newFd, err := os.OpenFile(tmpKnownHostsPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0600)
if err != nil {
panic(fmt.Sprintf("%s: %s", tmpKnownHostsPath, err.Error()))
@@ -232,7 +229,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
// Newly trusted hosts in normalized form
addresses := make(map[string]struct{})
-
// First write to new known hosts file, and keep track of addresses
for _, unknown := range hosts {
unknown.responseCh <- trustHost
@@ -255,7 +251,6 @@ func (c KnownHostsCallback) trustHosts(hosts []unknownHost) {
defer oldFd.Close()
scanner := bufio.NewScanner(oldFd)
-
// Now, append all still valid old entries to the new host file
for scanner.Scan() {
line := scanner.Text()
@@ -283,6 +278,5 @@ func (c KnownHostsCallback) Untrusted(server string) bool {
c.mutex.Lock()
defer c.mutex.Unlock()
_, ok := c.untrustedHosts[server]
-
return ok
}
diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go
index 07790ad..33bd4e8 100644
--- a/internal/ssh/server/hostkey.go
+++ b/internal/ssh/server/hostkey.go
@@ -1,11 +1,12 @@
package server
import (
- "github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
- "github.com/mimecast/dtail/internal/ssh"
"io/ioutil"
"os"
+
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
+ "github.com/mimecast/dtail/internal/ssh"
)
// PrivateHostKey retrieves the private server RSA host key.
@@ -14,24 +15,25 @@ func PrivateHostKey() []byte {
_, err := os.Stat(hostKeyFile)
if os.IsNotExist(err) {
- logger.Info("Generating private server RSA host key")
+ dlog.Common.Info("Generating private server RSA host key")
privateKey, err := ssh.GeneratePrivateRSAKey(config.Server.HostKeyBits)
if err != nil {
- logger.FatalExit("Failed to generate private server RSA host key", err)
+ dlog.Common.FatalPanic("Failed to generate private server RSA host key", err)
}
pem := ssh.EncodePrivateKeyToPEM(privateKey)
if err := ioutil.WriteFile(hostKeyFile, pem, 0600); err != nil {
- logger.Error("Unable to write private server RSA host key to file", hostKeyFile, err)
+ dlog.Common.Error("Unable to write private server RSA host key to file",
+ hostKeyFile, err)
}
return pem
}
- logger.Info("Reading private server RSA host key from file", hostKeyFile)
+ dlog.Common.Info("Reading private server RSA host key from file", hostKeyFile)
pem, err := ioutil.ReadFile(hostKeyFile)
if err != nil {
- logger.FatalExit("Failed to load private server RSA host key", err)
+ dlog.Common.FatalPanic("Failed to load private server RSA host key", err)
}
return pem
}
diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go
index e81f019..ebc428a 100644
--- a/internal/ssh/server/publickeycallback.go
+++ b/internal/ssh/server/publickeycallback.go
@@ -7,28 +7,34 @@ import (
osUser "os/user"
"github.com/mimecast/dtail/internal/config"
- "github.com/mimecast/dtail/internal/io/logger"
+ "github.com/mimecast/dtail/internal/io/dlog"
user "github.com/mimecast/dtail/internal/user/server"
gossh "golang.org/x/crypto/ssh"
)
-// PublicKeyCallback is for the server to check whether a public SSH key is authorized ot not.
-func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {
- user := user.New(c.User(), c.RemoteAddr().String())
- logger.Info(user, "Incoming authorization")
+// PublicKeyCallback is for the server to check whether a public SSH key is
+// authorized ot not.
+func PublicKeyCallback(c gossh.ConnMetadata,
+ offeredPubKey gossh.PublicKey) (*gossh.Permissions, error) {
+ user, err := user.New(c.User(), c.RemoteAddr().String())
+ if err != nil {
+ return nil, err
+ }
+
+ dlog.Common.Info(user, "Incoming authorization")
cwd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("Unable to get current working directory|%s|", err.Error())
}
-
if config.ServerRelaxedAuthEnable {
- logger.Fatal(user, "Granting permissions via relaxed-auth")
+ dlog.Common.Fatal(user, "Granting permissions via relaxed-auth")
return nil, nil
}
- authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd, config.Common.CacheDir, user.Name)
+ authorizedKeysFile := fmt.Sprintf("%s/%s/%s.authorized_keys", cwd,
+ config.Common.CacheDir, user.Name)
if _, err := os.Stat(authorizedKeysFile); os.IsNotExist(err) {
user, err := osUser.Lookup(user.Name)
if err != nil {
@@ -38,26 +44,28 @@ func PublicKeyCallback(c gossh.ConnMetadata, offeredPubKey gossh.PublicKey) (*go
authorizedKeysFile = user.HomeDir + "/.ssh/authorized_keys"
}
- logger.Info(user, "Reading", authorizedKeysFile)
+ dlog.Common.Info(user, "Reading", authorizedKeysFile)
authorizedKeysBytes, err := ioutil.ReadFile(authorizedKeysFile)
if err != nil {
- return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s", authorizedKeysFile, user, err.Error())
+ return nil, fmt.Errorf("Unable to read authorized keys file|%s|%s|%s",
+ authorizedKeysFile, user, err.Error())
}
authorizedKeysMap := map[string]bool{}
for len(authorizedKeysBytes) > 0 {
authorizedPubKey, _, _, restBytes, err := gossh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
- return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s", user, err.Error())
+ return nil, fmt.Errorf("Unable to parse authorized keys bytes|%s|%s",
+ user, err.Error())
}
authorizedKeysMap[string(authorizedPubKey.Marshal())] = true
authorizedKeysBytes = restBytes
-
- logger.Debug(user, "Authorized public key fingerprint", gossh.FingerprintSHA256(authorizedPubKey))
+ dlog.Common.Debug(user, "Authorized public key fingerprint",
+ gossh.FingerprintSHA256(authorizedPubKey))
}
- logger.Debug(user, "Offered public key fingerprint", gossh.FingerprintSHA256(offeredPubKey))
-
+ dlog.Common.Debug(user, "Offered public key fingerprint",
+ gossh.FingerprintSHA256(offeredPubKey))
if authorizedKeysMap[string(offeredPubKey.Marshal())] {
return &gossh.Permissions{
Extensions: map[string]string{
diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go
index 3a2e416..db5aaf1 100644
--- a/internal/ssh/ssh.go
+++ b/internal/ssh/ssh.go
@@ -6,12 +6,13 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
- "github.com/mimecast/dtail/internal/io/logger"
"io/ioutil"
"net"
"os"
"syscall"
+ "github.com/mimecast/dtail/internal/io/dlog"
+
gossh "golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/terminal"
@@ -23,12 +24,10 @@ func GeneratePrivateRSAKey(size int) (*rsa.PrivateKey, error) {
if err != nil {
return nil, err
}
-
err = privateKey.Validate()
if err != nil {
return nil, err
}
-
return privateKey, nil
}
@@ -41,7 +40,6 @@ func EncodePrivateKeyToPEM(privateKey *rsa.PrivateKey) []byte {
Headers: nil,
Bytes: derFormat,
}
-
return pem.EncodeToMemory(&block)
}
@@ -57,7 +55,7 @@ func Agent() (gossh.AuthMethod, error) {
return nil, err
}
for i, key := range keys {
- logger.Debug("Public key", i, key)
+ dlog.Common.Debug("Public key", i, key)
}
return gossh.PublicKeysCallback(agentClient.Signers), nil
}
@@ -79,7 +77,6 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) {
if err != nil {
return nil, err
}
-
key, err := gossh.ParsePrivateKey(buffer)
if err != nil {
return nil, err
@@ -105,7 +102,7 @@ func KeyFile(keyFile string) (gossh.AuthMethod, error) {
func PrivateKey(keyFile string) (gossh.AuthMethod, error) {
signer, err := KeyFile(keyFile)
if err != nil {
- logger.Debug(keyFile, err)
+ dlog.Common.Debug(keyFile, err)
return nil, err
}
return gossh.AuthMethod(signer), nil
diff --git a/internal/user/name.go b/internal/user/name.go
index 28ab0a4..cd11907 100644
--- a/internal/user/name.go
+++ b/internal/user/name.go
@@ -10,11 +10,9 @@ func NoRootCheck() {
if err != nil {
panic(err)
}
-
if user.Uid == "0" {
panic("Not allowed to run as UID 0")
}
-
if user.Gid == "0" {
panic("Not allowed to run as GID 0")
}
@@ -26,6 +24,5 @@ func Name() string {
if err != nil {
panic(err)
}
-
return user.Username
}
diff --git a/internal/user/server/user.go b/internal/user/server/user.go
index af6b0d0..aa7f8b1 100644
--- a/internal/user/server/user.go
+++ b/internal/user/server/user.go
@@ -8,8 +8,8 @@ import (
"strings"
"github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/io/dlog"
"github.com/mimecast/dtail/internal/io/fs/permissions"
- "github.com/mimecast/dtail/internal/io/logger"
)
const maxLinkDepth int = 100
@@ -25,11 +25,16 @@ type User struct {
}
// New returns a new user.
-func New(name, remoteAddress string) *User {
+func New(name, remoteAddress string) (*User, error) {
+ permissions, err := config.ServerUserPermissions(name)
+ if err != nil {
+ return nil, err
+ }
return &User{
Name: name,
remoteAddress: remoteAddress,
- }
+ permissions: permissions,
+ }, nil
}
// String representation of the user.
@@ -37,14 +42,13 @@ func (u *User) String() string {
return fmt.Sprintf("%s@%s", u.Name, u.remoteAddress)
}
-// HasFilePermission is used to determine whether user is allowed to read a file.
+// HasFilePermission is used to determine whether user is alowed to read a file.
func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) {
- logger.Debug(u, filePath, permissionType, "Checking config permissions")
+ dlog.Server.Debug(u, filePath, permissionType, "Checking config permissions")
if config.ServerRelaxedAuthEnable {
- logger.Fatal(u, filePath, permissionType, "Server releaxed auth enabled")
+ dlog.Server.Fatal(u, filePath, permissionType, "Server releaxed auth enabled")
return true
}
-
if u.Name == config.ScheduleUser || u.Name == config.ContinuousUser {
// Background user has same permissions as dtail process itself.
return true
@@ -52,27 +56,29 @@ func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission
cleanPath, err := filepath.EvalSymlinks(filePath)
if err != nil {
- logger.Error(u, filePath, permissionType, "Unable to evaluate symlinks", err)
+ dlog.Server.Error(u, filePath, permissionType,
+ "Unable to evaluate symlinks", err)
hasPermission = false
return
}
cleanPath, err = filepath.Abs(cleanPath)
if err != nil {
- logger.Error(u, cleanPath, permissionType, "Unable to make file path absolute", err)
+ dlog.Server.Error(u, cleanPath, permissionType,
+ "Unable to make file path absolute", err)
hasPermission = false
return
}
if cleanPath != filePath {
- logger.Info(u, filePath, cleanPath, permissionType, "Calculated new clean path from original file path (possibly symlink)")
+ dlog.Server.Info(u, filePath, cleanPath, permissionType,
+ "Calculated new clean path from original file path (possibly symlink)")
}
hasPermission, err = u.hasFilePermission(cleanPath, permissionType)
if err != nil {
- logger.Warn(u, cleanPath, err)
+ dlog.Server.Warn(u, cleanPath, err)
}
-
return
}
@@ -81,24 +87,17 @@ func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error)
if _, err := permissions.ToRead(u.Name, cleanPath); err != nil {
return false, fmt.Errorf("User without OS file system permissions to read path: '%v'", err)
}
- logger.Info(u, cleanPath, permissionType, "User with OS file system permissions to path")
+ dlog.Server.Info(u, cleanPath, permissionType,
+ "User with OS file system permissions to path")
// Only allow to follow regular files or symlinks.
info, err := os.Lstat(cleanPath)
if err != nil {
return false, fmt.Errorf("Unable to determine file type: '%v'", err)
}
-
if !info.Mode().IsRegular() {
return false, fmt.Errorf("Can only open regular files or follow symlinks")
}
-
- permissions, err := config.ServerUserPermissions(u.Name)
- if err != nil {
- return false, err
- }
- u.permissions = permissions
-
hasPermission, err := u.iteratePaths(cleanPath, permissionType)
if err != nil {
return false, err
@@ -110,10 +109,8 @@ func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error)
func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) {
// By default assume no permissions
hasPermission := false
-
for _, permission := range u.permissions {
typeStr := "readfiles" // Assume ReadFiles by default.
-
var regexStr string
var negate bool
@@ -123,8 +120,7 @@ func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) {
permission = strings.Join(splitted[1:], ":")
}
- logger.Debug(u, cleanPath, typeStr, permission)
-
+ dlog.Server.Debug(u, cleanPath, typeStr, permission)
if typeStr != permissionType {
continue
}
@@ -137,16 +133,17 @@ func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) {
re, err := regexp.Compile(regexStr)
if err != nil {
- return false, fmt.Errorf("Permission test failed, can't compile regex '%s': '%v'", regexStr, err)
+ return false, fmt.Errorf("Permission test failed, can't compile regex "+
+ "'%s': '%v'", regexStr, err)
}
-
if negate && re.MatchString(cleanPath) {
- logger.Info(u, cleanPath, "Permission test failed partially, matching negative pattern '%s'", permission)
+ dlog.Server.Info(u, cleanPath, "Permission test failed partially, "+
+ "matching negative pattern '%s'", permission)
hasPermission = false
}
-
if !negate && re.MatchString(cleanPath) {
- logger.Info(u, cleanPath, "Permission test passed partially, matching positive pattern", permission)
+ dlog.Server.Info(u, cleanPath, "Permission test passed partially, "+
+ "matching positive pattern", permission)
hasPermission = true
}
}
diff --git a/internal/version/version.go b/internal/version/version.go
index 3c31059..68b9e6e 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -5,38 +5,50 @@ import (
"os"
"github.com/mimecast/dtail/internal/color"
+ "github.com/mimecast/dtail/internal/config"
+ "github.com/mimecast/dtail/internal/protocol"
)
const (
// Name of DTail.
Name string = "DTail"
// Version of DTail.
- Version string = "3.3.1"
+ Version string = "4.0.0-RC1"
// Additional information for DTail
- Additional string = ""
- // ProtocolCompat -ibility version.
- ProtocolCompat string = "3"
+ Additional string = "Have a lot of fun!"
)
// String representation of the DTail version.
func String() string {
- return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, ProtocolCompat, Additional)
+ return fmt.Sprintf("%s %v Protocol %s %s", Name, Version,
+ protocol.ProtocolCompat, Additional)
}
// PaintedString is a prettier string representation of the DTail version.
func PaintedString() string {
- if !color.Colored {
+ if !config.Client.TermColorsEnable {
return String()
}
- name := color.Paint(color.Yellow, Name)
- version := color.Paint(color.Blue, Version)
- descr := color.Paint(color.Green, Additional)
- return fmt.Sprintf("%s %v Protocol %s %s", name, version, ProtocolCompat, descr)
+ name := color.PaintStrWithAttr(fmt.Sprintf(" %s ", Name),
+ color.FgYellow, color.BgBlue, color.AttrBold)
+ version := color.PaintStrWithAttr(fmt.Sprintf(" %s ", Version),
+ color.FgBlue, color.BgYellow, color.AttrBold)
+ protocol := color.PaintStr(fmt.Sprintf(" Protocol %s ", protocol.ProtocolCompat),
+ color.FgBlack, color.BgGreen)
+ additional := color.PaintStrWithAttr(fmt.Sprintf(" %s ", Additional),
+ color.FgWhite, color.BgMagenta, color.AttrUnderline)
+
+ return fmt.Sprintf("%s%v%s%s", name, version, protocol, additional)
+}
+
+// Print the version.
+func Print() {
+ fmt.Println(PaintedString())
}
// PrintAndExit prints the program version and exists.
func PrintAndExit() {
- fmt.Println(PaintedString())
+ Print()
os.Exit(0)
}