From 0945da8dfefcbb723eecea0e5f4eafff63398253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paul=20B=C3=BCtow?= Date: Sun, 26 Jan 2020 11:26:53 +0000 Subject: Introduce drun command, refactor code to use context package --- internal/clients/args.go | 3 +- internal/clients/baseclient.go | 130 +++--- internal/clients/catclient.go | 20 +- internal/clients/client.go | 5 +- internal/clients/connectionmaker.go | 12 - internal/clients/execclient.go | 48 --- internal/clients/grepclient.go | 20 +- internal/clients/handlers/basehandler.go | 84 ++-- internal/clients/handlers/clienthandler.go | 11 +- internal/clients/handlers/handler.go | 12 +- internal/clients/handlers/healthhandler.go | 21 +- internal/clients/handlers/maprhandler.go | 21 +- internal/clients/handlers/withcancel.go | 24 ++ internal/clients/healthclient.go | 7 +- internal/clients/maker.go | 8 + internal/clients/maprclient.go | 52 ++- internal/clients/remote/connection.go | 116 ++---- internal/clients/runclient.go | 40 ++ internal/clients/stats.go | 8 +- internal/clients/tailclient.go | 21 +- internal/discovery/comma.go | 2 +- internal/discovery/discovery.go | 21 +- internal/discovery/file.go | 2 +- internal/fs/catfile.go | 27 -- internal/fs/filereader.go | 9 - internal/fs/lineread.go | 28 -- internal/fs/permissions/permission.go | 14 - internal/fs/permissions/permission_linux.c | 395 ------------------- internal/fs/permissions/permission_linux.go | 33 -- internal/fs/permissions/permission_linux.h | 60 --- internal/fs/permissions/permission_test.go | 112 ------ internal/fs/readfile.go | 318 --------------- internal/fs/stats.go | 69 ---- internal/fs/tailfile.go | 27 -- internal/io/fs/catfile.go | 21 + internal/io/fs/filereader.go | 14 + internal/io/fs/permissions/permission.go | 14 + internal/io/fs/permissions/permission_linux.c | 395 +++++++++++++++++++ internal/io/fs/permissions/permission_linux.go | 33 ++ internal/io/fs/permissions/permission_linux.h | 60 +++ internal/io/fs/permissions/permission_test.go | 112 ++++++ internal/io/fs/readfile.go | 307 +++++++++++++++ internal/io/fs/stats.go | 69 ++++ internal/io/fs/tailfile.go | 21 + internal/io/line/line.go | 28 ++ internal/io/logger/logger.go | 445 +++++++++++++++++++++ internal/io/run/run.go | 104 +++++ internal/logger/logger.go | 457 ---------------------- internal/mapr/aggregateset.go | 5 +- internal/mapr/client/aggregate.go | 25 +- internal/mapr/groupset.go | 5 +- internal/mapr/logformat/parser.go | 2 +- internal/mapr/query.go | 2 +- internal/mapr/server/aggregate.go | 141 ++++--- internal/mapr/wherecondition.go | 2 +- internal/omode/mode.go | 6 +- internal/pprof/pprof.go | 3 +- internal/prompt/prompt.go | 2 +- internal/server/handlers/controlhandler.go | 42 +- internal/server/handlers/handler.go | 2 - internal/server/handlers/mapcommand.go | 35 ++ internal/server/handlers/readcommand.go | 158 ++++++++ internal/server/handlers/runcommand.go | 73 ++++ internal/server/handlers/serverhandler.go | 521 +++++++++---------------- internal/server/server.go | 70 ++-- internal/server/stats.go | 10 +- internal/ssh/client/authmethods.go | 2 +- internal/ssh/client/hostkeycallback.go | 10 +- internal/ssh/server/hostkey.go | 2 +- internal/ssh/server/publickeycallback.go | 2 +- internal/ssh/ssh.go | 2 +- internal/user/name.go | 15 +- internal/user/server/user.go | 44 ++- internal/version/version.go | 22 +- 74 files changed, 2612 insertions(+), 2451 deletions(-) delete mode 100644 internal/clients/connectionmaker.go delete mode 100644 internal/clients/execclient.go create mode 100644 internal/clients/handlers/withcancel.go create mode 100644 internal/clients/maker.go create mode 100644 internal/clients/runclient.go delete mode 100644 internal/fs/catfile.go delete mode 100644 internal/fs/filereader.go delete mode 100644 internal/fs/lineread.go delete mode 100644 internal/fs/permissions/permission.go delete mode 100644 internal/fs/permissions/permission_linux.c delete mode 100644 internal/fs/permissions/permission_linux.go delete mode 100644 internal/fs/permissions/permission_linux.h delete mode 100644 internal/fs/permissions/permission_test.go delete mode 100644 internal/fs/readfile.go delete mode 100644 internal/fs/stats.go delete mode 100644 internal/fs/tailfile.go create mode 100644 internal/io/fs/catfile.go create mode 100644 internal/io/fs/filereader.go create mode 100644 internal/io/fs/permissions/permission.go create mode 100644 internal/io/fs/permissions/permission_linux.c create mode 100644 internal/io/fs/permissions/permission_linux.go create mode 100644 internal/io/fs/permissions/permission_linux.h create mode 100644 internal/io/fs/permissions/permission_test.go create mode 100644 internal/io/fs/readfile.go create mode 100644 internal/io/fs/stats.go create mode 100644 internal/io/fs/tailfile.go create mode 100644 internal/io/line/line.go create mode 100644 internal/io/logger/logger.go create mode 100644 internal/io/run/run.go delete mode 100644 internal/logger/logger.go create mode 100644 internal/server/handlers/mapcommand.go create mode 100644 internal/server/handlers/readcommand.go create mode 100644 internal/server/handlers/runcommand.go (limited to 'internal') diff --git a/internal/clients/args.go b/internal/clients/args.go index 5fe0a72..dea5a9e 100644 --- a/internal/clients/args.go +++ b/internal/clients/args.go @@ -9,10 +9,9 @@ type Args struct { Mode omode.Mode ServersStr string UserName string - Files string + What string Regex string TrustAllHosts bool Discovery string ConnectionsPerCPU int - PingTimeout int } diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 574ae94..b1540ea 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -1,13 +1,14 @@ package clients import ( + "context" "regexp" "sync" "time" "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/discovery" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/omode" "github.com/mimecast/dtail/internal/ssh/client" @@ -27,111 +28,110 @@ type baseClient struct { sshAuthMethods []gossh.AuthMethod // To deal with SSH host keys hostKeyCallback *client.HostKeyCallback - // To stop the client. - stop chan struct{} - // To indicate that the client has stopped. - stopped chan struct{} // Throttle how fast we initiate SSH connections concurrently throttleCh chan struct{} // Retry connection upon failure? retry bool - // Connection helper. - maker connectionMaker + // Connection maker helper. + maker maker } -func (c *baseClient) init(maker connectionMaker) { +func (c *baseClient) init(maker maker) { logger.Info("Initiating base client") c.maker = maker - //c.connections = make(map[string]*remote.Connection) c.sshAuthMethods, c.hostKeyCallback = client.InitSSHAuthMethods(c.TrustAllHosts, c.throttleCh) + discoveryService := discovery.New(c.Discovery, c.ServersStr, discovery.Shuffle) - // Retrieve a shuffled list of remote dtail servers. - shuffleServers := true - discoveryService := discovery.New(c.Discovery, c.ServersStr, shuffleServers) for _, server := range discoveryService.ServerList() { - c.connections = append(c.connections, c.maker.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) + c.connections = append(c.connections, c.makeConnection(server, c.sshAuthMethods, c.hostKeyCallback)) } if _, err := regexp.Compile(c.Regex); err != nil { logger.FatalExit(c.Regex, "Can't test compile regex", err) } - // Periodically check for unknown hosts, and ask the user whether to trust them or not. - go c.hostKeyCallback.PromptAddHosts(c.stop) - - // Periodically print out connection stats to the client. c.stats = newTailStats(len(c.connections)) - go c.stats.periodicLogStats(c.throttleCh, c.stop) } -func (c *baseClient) Start() (status int) { +func (c *baseClient) Start(ctx context.Context) (status int) { + // Periodically check for unknown hosts, and ask the user whether to trust them or not. + go c.hostKeyCallback.PromptAddHosts(ctx) + // Periodically print out connection stats to the client. + go c.stats.periodicLogStats(ctx, c.throttleCh) + // Keep count of active connections active := make(chan struct{}, len(c.connections)) - var wg sync.WaitGroup - wg.Add(len(c.connections)) - + var mutex sync.Mutex for i, conn := range c.connections { go func(i int, conn *remote.Connection) { - active <- struct{}{} - defer func() { - logger.Debug(conn.Server, "Disconnected completely...") - <-active - }() - wg.Done() - - for { - conn.Start(c.throttleCh, c.stats.connectionsEstCh) - if !c.retry { - return - } - time.Sleep(time.Second * 2) - logger.Debug(conn.Server, "Reconencting") - conn = c.maker.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) - c.connections[i] = conn + connStatus := c.start(ctx, active, i, conn) + + // Update global status. + mutex.Lock() + defer mutex.Unlock() + if connStatus > status { + status = connStatus } }(i, conn) } - wg.Wait() - c.waitUntilDone(active) - + c.waitUntilDone(ctx, active) return } -func (c *baseClient) waitUntilDone(active chan struct{}) { - defer close(c.stopped) +func (c *baseClient) start(ctx context.Context, active chan struct{}, i int, conn *remote.Connection) (status int) { + // Increment connection count + active <- struct{}{} + // Derement connection count + defer func() { <-active }() - if c.Mode != omode.TailClient { - c.waitUntilZero(active) - logger.Info("All connections stopped") - return - } + for { + connCtx, cancel := conn.Handler.WithCancel(ctx) + defer cancel() - <-c.stop - logger.Info("Stopping client") - for _, conn := range c.connections { - conn.Stop() + conn.Start(connCtx, cancel, c.throttleCh, c.stats.connectionsEstCh) + // Retrieve status code from handler (dtail client will exit with that status) + status = conn.Handler.Status() + + if !c.retry { + return + } + + time.Sleep(time.Second * 2) + logger.Debug(conn.Server, "Reconnecting") + + conn = c.makeConnection(conn.Server, c.sshAuthMethods, c.hostKeyCallback) + c.connections[i] = conn } +} - c.waitUntilZero(active) +func (c *baseClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { + conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) + conn.Handler = c.maker.makeHandler(server) + conn.Commands = c.maker.makeCommands() + + return conn } -func (c *baseClient) waitUntilZero(active chan struct{}) { +func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { + defer logger.Info("Terminated connection") + + // We want to have at least one active connection + <-active + // Put it back on the channel + active <- struct{}{} + + if c.Mode == omode.TailClient { + <-ctx.Done() + } + for { - logger.Debug("Active connections", len(active)) - if len(active) == 0 { + numActive := len(active) + if numActive == 0 { return } + logger.Debug("Active connections", numActive) time.Sleep(time.Second) } } - -func (c *baseClient) Stop() { - close(c.stop) - <-c.WaitC() -} - -func (c *baseClient) WaitC() <-chan struct{} { - return c.stopped -} diff --git a/internal/clients/catclient.go b/internal/clients/catclient.go index 5ea701d..7fd6bdc 100644 --- a/internal/clients/catclient.go +++ b/internal/clients/catclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // CatClient is a client for returning a whole file from the beginning to the end. @@ -31,8 +27,6 @@ func NewCatClient(args Args) (*CatClient, error) { c := CatClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -43,11 +37,13 @@ func NewCatClient(args Args) (*CatClient, error) { return &c, nil } -func (c CatClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c CatClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c CatClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - return conn + return } diff --git a/internal/clients/client.go b/internal/clients/client.go index 85d1aae..1fc5e23 100644 --- a/internal/clients/client.go +++ b/internal/clients/client.go @@ -1,7 +1,8 @@ package clients +import "context" + // Client is the interface for the end user command line client. type Client interface { - Start() int - Stop() + Start(ctx context.Context) int } diff --git a/internal/clients/connectionmaker.go b/internal/clients/connectionmaker.go deleted file mode 100644 index 0617992..0000000 --- a/internal/clients/connectionmaker.go +++ /dev/null @@ -1,12 +0,0 @@ -package clients - -import ( - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -type connectionMaker interface { - makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection -} diff --git a/internal/clients/execclient.go b/internal/clients/execclient.go deleted file mode 100644 index 10bd081..0000000 --- a/internal/clients/execclient.go +++ /dev/null @@ -1,48 +0,0 @@ -package clients - -import ( - "fmt" - "runtime" - "strings" - - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" -) - -// ExecClient is a client for execute various commands on the server. -type ExecClient struct { - baseClient -} - -// NewExecClient returns a new cat client. -func NewExecClient(args Args) (*ExecClient, error) { - args.Regex = "." - args.Mode = omode.ExecClient - - c := ExecClient{ - baseClient: baseClient{ - Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), - throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), - retry: false, - }, - } - - c.init(c) - - return &c, nil -} - -func (c ExecClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) - for _, file := range strings.Split(c.Files, ";") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s", c.Mode.String(), file)) - } - return conn -} diff --git a/internal/clients/grepclient.go b/internal/clients/grepclient.go index c568f63..8d11458 100644 --- a/internal/clients/grepclient.go +++ b/internal/clients/grepclient.go @@ -7,11 +7,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // GrepClient searches a remote file for all lines matching a regular expression. Only the matching lines are displayed. @@ -29,8 +25,6 @@ func NewGrepClient(args Args) (*GrepClient, error) { c := GrepClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: false, }, @@ -41,13 +35,13 @@ func NewGrepClient(args Args) (*GrepClient, error) { return &c, nil } -func (c GrepClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c GrepClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c GrepClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/clients/handlers/basehandler.go b/internal/clients/handlers/basehandler.go index 19246f9..68b8ddc 100644 --- a/internal/clients/handlers/basehandler.go +++ b/internal/clients/handlers/basehandler.go @@ -1,60 +1,44 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" - "errors" + "encoding/base64" "fmt" "io" + "strconv" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/version" ) type baseHandler struct { + withCancel server string shellStarted bool commands chan string - pong chan struct{} receiveBuf []byte - stop chan struct{} - pingTimeout int + status int } func (h *baseHandler) Server() string { return h.server } -// Used to determine whether server is still responding to requests or not. -func (h *baseHandler) Ping() error { - if h.pingTimeout == 0 { - // Server ping disabled - return nil - } - - if err := h.SendCommand("ping"); err != nil { - return err - } - - select { - case <-h.pong: - return nil - case <-time.After(time.Duration(h.pingTimeout) * time.Second): - } - - return errors.New("Didn't receive any server pongs (ping replies)") +func (h *baseHandler) Status() int { + return h.status } -func (h *baseHandler) SendCommand(command string) error { - if command == "ping" { - logger.Trace("Sending command", h.server, command) - } else { - logger.Debug("Sending command", h.server, command) - } +// SendMessage to the server. +func (h *baseHandler) SendMessage(command string) error { + encoded := base64.StdEncoding.EncodeToString([]byte(command)) + logger.Debug("Sending command", h.server, command, encoded) select { - case h.commands <- fmt.Sprintf("%s;", command): + case h.commands <- fmt.Sprintf("protocol %s base64 %v;", version.ProtocolCompat, encoded): case <-time.After(time.Second * 5): - return errors.New("Timed out sending command " + command) - case <-h.stop: + return fmt.Errorf("Timed out sending command '%s' (base64: '%s')", command, encoded) + case <-h.ctx.Done(): } return nil @@ -81,7 +65,7 @@ func (h *baseHandler) Read(p []byte) (n int, err error) { select { case command := <-h.commands: n = copy(p, []byte(command)) - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } return @@ -92,6 +76,7 @@ func (h *baseHandler) handleMessageType(message string) { if len(h.receiveBuf) == 0 { return } + // Hidden server commands starti with a dot "." if h.receiveBuf[0] == '.' { h.handleHiddenMessage(message) @@ -108,6 +93,7 @@ func (h *baseHandler) handleMessageType(message string) { h.receiveBuf = h.receiveBuf[:0] return } + logger.Raw(message) h.receiveBuf = h.receiveBuf[:0] } @@ -116,19 +102,27 @@ func (h *baseHandler) handleMessageType(message string) { // to the end user. func (h *baseHandler) handleHiddenMessage(message string) { switch { - case strings.HasPrefix(message, ".pong"): - h.pong <- struct{}{} case strings.HasPrefix(message, ".syn close connection"): - h.SendCommand("ack close connection") - } -} + h.SendMessage(".ack close connection") + select { + case <-time.After(time.Second * 1): + logger.Debug("Shutting down client after timeout and sending ack to server") + h.withCancel.shutdown() + case <-h.ctx.Done(): + } -// Stop the handler. -func (h *baseHandler) Stop() { - select { - case <-h.stop: - default: - logger.Debug("Stopping base handler", h.server) - close(h.stop) + case strings.HasPrefix(message, ".run exitstatus"): + splitted := strings.Split(strings.TrimSuffix(message, "\n"), " ") + if len(splitted) != 3 { + logger.Error("Unable to retrieve exitstatus", message) + return + } + i, err := strconv.Atoi(splitted[2]) + if err != nil { + logger.Error("Unable to retrieve exitstatus", message, err) + return + } + h.status = i + logger.Debug("Retrieved exitstatus", h.status) } } diff --git a/internal/clients/handlers/clienthandler.go b/internal/clients/handlers/clienthandler.go index 4738cd3..fcd8052 100644 --- a/internal/clients/handlers/clienthandler.go +++ b/internal/clients/handlers/clienthandler.go @@ -1,7 +1,7 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) // ClientHandler is the basic client handler interface. @@ -10,7 +10,7 @@ type ClientHandler struct { } // NewClientHandler creates a new client handler. -func NewClientHandler(server string, pingTimeout int) *ClientHandler { +func NewClientHandler(server string) *ClientHandler { logger.Debug(server, "Creating new client handler") return &ClientHandler{ @@ -18,9 +18,10 @@ func NewClientHandler(server string, pingTimeout int) *ClientHandler { server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, } } diff --git a/internal/clients/handlers/handler.go b/internal/clients/handlers/handler.go index 2013be0..c53ca34 100644 --- a/internal/clients/handlers/handler.go +++ b/internal/clients/handlers/handler.go @@ -1,12 +1,16 @@ package handlers -import "io" +import ( + "context" + "io" +) // Handler provides all methods which can be run on any client handler. type Handler interface { io.ReadWriter - Ping() error - Stop() - SendCommand(command string) error + SendMessage(command string) error Server() string + Status() int + WithCancel(ctx context.Context) (context.Context, context.CancelFunc) + Done() <-chan struct{} } diff --git a/internal/clients/handlers/healthhandler.go b/internal/clients/handlers/healthhandler.go index 4051e2c..9051015 100644 --- a/internal/clients/handlers/healthhandler.go +++ b/internal/clients/handlers/healthhandler.go @@ -8,6 +8,7 @@ import ( // HealthHandler implements the handler required for health checks. type HealthHandler struct { + withCancel // Buffer of incoming data from server. receiveBuf []byte // To send commands to the server. @@ -16,6 +17,7 @@ type HealthHandler struct { receive chan<- string // The remote server address server string + status int } // NewHealthHandler returns a new health check handler. @@ -24,6 +26,10 @@ func NewHealthHandler(server string, receive chan<- string) *HealthHandler { server: server, receive: receive, commands: make(chan string), + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, } return &h @@ -34,18 +40,13 @@ func (h *HealthHandler) Server() string { return h.server } -// Stop is not of use for health check handler. -func (h *HealthHandler) Stop() { - // Nothing done here. +// Status of the handler. +func (h *HealthHandler) Status() int { + return h.status } -// Ping is not of use for health check handler. -func (h *HealthHandler) Ping() error { - return nil -} - -// SendCommand send a DTail command to the server. -func (h *HealthHandler) SendCommand(command string) error { +// SendMessage sends a DTail command to the server. +func (h *HealthHandler) SendMessage(command string) error { select { case h.commands <- fmt.Sprintf("%s;", command): case <-time.NewTimer(time.Second * 10).C: diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go index d76cdfd..874bb7d 100644 --- a/internal/clients/handlers/maprhandler.go +++ b/internal/clients/handlers/maprhandler.go @@ -1,10 +1,11 @@ package handlers import ( - "github.com/mimecast/dtail/internal/logger" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/mapr/client" - "strings" ) // MaprHandler is the handler used on the client side for running mapreduce aggregations. @@ -16,15 +17,16 @@ type MaprHandler struct { } // NewMaprHandler returns a new mapreduce client handler. -func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet, pingTimeout int) *MaprHandler { +func NewMaprHandler(server string, query *mapr.Query, globalGroup *mapr.GlobalGroupSet) *MaprHandler { return &MaprHandler{ baseHandler: baseHandler{ server: server, shellStarted: false, commands: make(chan string), - pong: make(chan struct{}, 1), - stop: make(chan struct{}), - pingTimeout: pingTimeout, + status: -1, + withCancel: withCancel{ + done: make(chan struct{}), + }, }, query: query, aggregate: client.NewAggregate(server, query, globalGroup), @@ -65,10 +67,3 @@ func (h *MaprHandler) handleAggregateMessage(message string) { h.aggregate.Aggregate(parts[2:]) logger.Debug("Aggregated aggregate data", h.server, h.count) } - -// Stop stops the mapreduce client handler. -func (h *MaprHandler) Stop() { - logger.Debug("Stopping mapreduce handler", h.server) - h.aggregate.Stop() - h.baseHandler.Stop() -} diff --git a/internal/clients/handlers/withcancel.go b/internal/clients/handlers/withcancel.go new file mode 100644 index 0000000..7c9cf4e --- /dev/null +++ b/internal/clients/handlers/withcancel.go @@ -0,0 +1,24 @@ +package handlers + +import "context" + +type withCancel struct { + ctx context.Context + done chan struct{} +} + +// WithCancel sets and returns the context used. +func (w *withCancel) WithCancel(ctx context.Context) (context.Context, context.CancelFunc) { + cancelCtx, cancel := context.WithCancel(ctx) + w.ctx = cancelCtx + + return cancelCtx, cancel +} + +func (w *withCancel) Done() <-chan struct{} { + return w.done +} + +func (w *withCancel) shutdown() { + close(w.done) +} diff --git a/internal/clients/healthclient.go b/internal/clients/healthclient.go index ff13b83..7313583 100644 --- a/internal/clients/healthclient.go +++ b/internal/clients/healthclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "fmt" "runtime" "strings" @@ -39,7 +40,7 @@ func NewHealthClient(mode omode.Mode) (*HealthClient, error) { } // Start the health client. -func (c *HealthClient) Start() (status int) { +func (c *HealthClient) Start(ctx context.Context) (status int) { receive := make(chan string) throttleCh := make(chan struct{}, runtime.NumCPU()) @@ -49,8 +50,8 @@ func (c *HealthClient) Start() (status int) { conn.Handler = handlers.NewHealthHandler(c.server, receive) conn.Commands = []string{c.mode.String()} - go conn.Start(throttleCh, statsCh) - defer conn.Stop() + connCtx, cancel := conn.Handler.WithCancel(ctx) + go conn.Start(connCtx, cancel, throttleCh, statsCh) for { select { diff --git a/internal/clients/maker.go b/internal/clients/maker.go new file mode 100644 index 0000000..da9dfc9 --- /dev/null +++ b/internal/clients/maker.go @@ -0,0 +1,8 @@ +package clients + +import "github.com/mimecast/dtail/internal/clients/handlers" + +type maker interface { + makeHandler(server string) handlers.Handler + makeCommands() (commands []string) +} diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index 9070827..b581844 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -1,6 +1,7 @@ package clients import ( + "context" "errors" "fmt" "runtime" @@ -8,13 +9,9 @@ import ( "time" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // MaprClient is used for running mapreduce aggregations on remote files. @@ -39,8 +36,6 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { c := MaprClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: args.Mode == omode.TailClient, }, @@ -70,35 +65,36 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { return &c, nil } -func (c MaprClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewMaprHandler(conn.Server, c.query, c.globalGroup, c.PingTimeout) +// Start starts the mapreduce client. +func (c *MaprClient) Start(ctx context.Context) (status int) { + if c.query.Outfile == "" { + // Only print out periodic results if we don't write an outfile + go c.periodicPrintResults(ctx) + } - conn.Commands = append(conn.Commands, fmt.Sprintf("map %s", c.query.RawQuery)) - commandStr := "tail" + status = c.baseClient.Start(ctx) if c.additative { - commandStr = "cat" + c.recievedFinalResult() } - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", commandStr, file, c.Regex)) - } + return +} - return conn +func (c MaprClient) makeHandler(server string) handlers.Handler { + return handlers.NewMaprHandler(server, c.query, c.globalGroup) } -// Start starts the mapreduce client. -func (c *MaprClient) Start() (status int) { - if c.query.Outfile == "" { - // Only print out periodic results if we don't write an outfile - go c.periodicPrintResults() - } +func (c MaprClient) makeCommands() (commands []string) { + commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery)) - status = c.baseClient.Start() + modeStr := "tail" if c.additative { - c.recievedFinalResult() + modeStr = "cat" + } + + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex)) } - c.baseClient.Stop() return } @@ -120,13 +116,13 @@ func (c *MaprClient) recievedFinalResult() { logger.Info(fmt.Sprintf("Wrote final mapreduce result to '%s'", c.query.Outfile)) } -func (c *MaprClient) periodicPrintResults() { +func (c *MaprClient) periodicPrintResults(ctx context.Context) { for { select { case <-time.After(c.query.Interval): logger.Info("Gathering interim mapreduce result") c.printResults() - case <-c.baseClient.stop: + case <-ctx.Done(): return } } diff --git a/internal/clients/remote/connection.go b/internal/clients/remote/connection.go index bfc7bc5..71639b1 100644 --- a/internal/clients/remote/connection.go +++ b/internal/clients/remote/connection.go @@ -1,16 +1,18 @@ package remote import ( - "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/ssh/client" + "context" "fmt" "io" "strconv" "strings" "time" + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/ssh/client" + "golang.org/x/crypto/ssh" ) @@ -30,8 +32,6 @@ type Connection struct { Commands []string // Is it a persistent connection or a one-off? isOneOff bool - // Used to stop the connection - stop chan struct{} // To deal with SSH server host keys hostKeyCallback *client.HostKeyCallback } @@ -48,7 +48,6 @@ func NewConnection(server string, userName string, authMethods []ssh.AuthMethod, HostKeyCallback: hostKeyCallback.Wrap(), Timeout: time.Second * 3, }, - stop: make(chan struct{}), } c.initServerPort(server) @@ -64,7 +63,6 @@ func NewOneOffConnection(server string, userName string, authMethods []ssh.AuthM Auth: authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), }, - stop: make(chan struct{}), isOneOff: true, } @@ -90,39 +88,34 @@ func (c *Connection) initServerPort(server string) { } } -// Start the server connection. Build up SSH session and send some DTail commandc. -func (c *Connection) Start(throttleCh, statsCh chan struct{}) { +// Start the server connection. Build up SSH session and send some DTail commands. +func (c *Connection) Start(ctx context.Context, cancel context.CancelFunc, throttleCh, statsCh chan struct{}) { + // Throttle how many connections can be established concurrently (based on ch length) select { - case <-c.stop: - logger.Info(c.Server, c.port, "Disconnecting client") + case throttleCh <- struct{}{}: + defer func() { <-throttleCh }() + case <-ctx.Done(): return - default: } - // Wait for SSH connection throttler - throttleCh <- struct{}{} - - // Wait until connection has been initiated or an error occured - // during initialization. - throttleStopCh := make(chan struct{}, 2) go func() { - <-throttleStopCh - <-throttleCh - }() + defer cancel() - if err := c.dial(c.Server, c.port, throttleStopCh, statsCh); err != nil { - logger.Warn(c.Server, c.port, err) - throttleStopCh <- struct{}{} + if err := c.dial(ctx, cancel, c.Server, c.port, statsCh); err != nil { + logger.Warn(c.Server, c.port, err) - if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { - logger.Debug("Not trusting host, not trying to re-connect", c.Server, c.port) - return + if c.hostKeyCallback.Untrusted(fmt.Sprintf("%s:%d", c.Server, c.port)) { + logger.Debug("Not trusting host", c.Server, c.port) + return + } } - } + }() + + <-ctx.Done() } // Dail into a new SSH connection. Close connection in case of an error. -func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan struct{}) error { +func (c *Connection) dial(ctx context.Context, cancel context.CancelFunc, host string, port int, statsCh chan struct{}) error { statsCh <- struct{}{} defer func() { <-statsCh }() @@ -135,11 +128,11 @@ func (c *Connection) dial(host string, port int, throttleStopCh, statsCh chan st } defer client.Close() - return c.session(client, throttleStopCh) + return c.session(ctx, cancel, client) } // Create the SSH session. Close the session in case of an error. -func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) error { +func (c *Connection) session(ctx context.Context, cancel context.CancelFunc, client *ssh.Client) error { logger.Debug(c.Server, "session") session, err := client.NewSession() @@ -148,14 +141,10 @@ func (c *Connection) session(client *ssh.Client, throttleStopCh chan<- struct{}) } defer session.Close() - return c.handle(session, throttleStopCh) + return c.handle(ctx, cancel, session) } -// Handle the SSH session. Also send periodic pings to the server in order -// to determine that session is still intact. -func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{}) error { - defer c.Handler.Stop() - +func (c *Connection) handle(ctx context.Context, cancel context.CancelFunc, session *ssh.Session) error { logger.Debug(c.Server, "handle") stdinPipe, err := session.StdinPipe() @@ -172,59 +161,30 @@ func (c *Connection) handle(session *ssh.Session, throttleStopCh chan<- struct{} return err } - // Establish Bi-directional pipe between SSH session and client handler. - brokenStdinPipe := make(chan struct{}) go func() { - defer close(brokenStdinPipe) + defer cancel() io.Copy(stdinPipe, c.Handler) }() - brokenStdoutPipe := make(chan struct{}) go func() { - defer close(brokenStdoutPipe) + defer cancel() io.Copy(c.Handler, stdoutPipe) }() - // SSH session established, other goroutine can initiate session now. - throttleStopCh <- struct{}{} + go func() { + defer cancel() + select { + case <-c.Handler.Done(): + case <-ctx.Done(): + } + }() // Send all commands to client. for _, command := range c.Commands { logger.Debug(command) - c.Handler.SendCommand(command) + c.Handler.SendMessage(command) } - if !c.isOneOff { - return c.periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe) - } - - <-c.stop - - // Normal shutdown, all fine + <-ctx.Done() return nil } - -// Periodically check whether connection is still alive or not. -func (c *Connection) periodicAliveCheck(brokenStdinPipe, brokenStdoutPipe <-chan struct{}) error { - for { - select { - case <-time.After(time.Second * 3): - if err := c.Handler.Ping(); err != nil { - return err - } - case <-brokenStdinPipe: - logger.Debug("Broken stdin pipe", c.Server, c.port) - return nil - case <-brokenStdoutPipe: - logger.Debug("Broken stdout pipe", c.Server, c.port) - return nil - case <-c.stop: - return nil - } - } -} - -// Stop the connection. -func (c *Connection) Stop() { - close(c.stop) -} diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go new file mode 100644 index 0000000..7a62fcc --- /dev/null +++ b/internal/clients/runclient.go @@ -0,0 +1,40 @@ +package clients + +import ( + "fmt" + "runtime" + + "github.com/mimecast/dtail/internal/clients/handlers" + "github.com/mimecast/dtail/internal/omode" +) + +// RunClient is a client to run various commands on the server. +type RunClient struct { + baseClient +} + +// NewRunClient returns a new cat client. +func NewRunClient(args Args) (*RunClient, error) { + args.Mode = omode.RunClient + + c := RunClient{ + baseClient: baseClient{ + Args: args, + throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), + retry: false, + }, + } + + c.init(c) + return &c, nil +} + +func (c RunClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} + +func (c RunClient) makeCommands() (commands []string) { + // Send "run COMMAND" to server! + commands = append(commands, fmt.Sprintf("%s %s", c.Mode.String(), c.What)) + return +} diff --git a/internal/clients/stats.go b/internal/clients/stats.go index d36cef6..ec6adfe 100644 --- a/internal/clients/stats.go +++ b/internal/clients/stats.go @@ -1,11 +1,13 @@ package clients import ( - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various client stats. @@ -28,14 +30,14 @@ func newTailStats(connectionsTotal int) *stats { } } -func (s *stats) periodicLogStats(throttleCh chan struct{}, stop <-chan struct{}) { +func (s *stats) periodicLogStats(ctx context.Context, throttleCh chan struct{}) { connectedLast := 0 statsInterval := 5 for { select { case <-time.After(time.Second * time.Duration(statsInterval)): - case <-stop: + case <-ctx.Done(): return } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 674ca36..4d81fd5 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -6,11 +6,7 @@ import ( "strings" "github.com/mimecast/dtail/internal/clients/handlers" - "github.com/mimecast/dtail/internal/clients/remote" "github.com/mimecast/dtail/internal/omode" - "github.com/mimecast/dtail/internal/ssh/client" - - gossh "golang.org/x/crypto/ssh" ) // TailClient is used for tailing remote log files (opening, seeking to the end and returning only new incoming lines). @@ -25,25 +21,22 @@ func NewTailClient(args Args) (*TailClient, error) { c := TailClient{ baseClient: baseClient{ Args: args, - stop: make(chan struct{}), - stopped: make(chan struct{}), throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), retry: true, }, } c.init(c) - return &c, nil } -func (c TailClient) makeConnection(server string, sshAuthMethods []gossh.AuthMethod, hostKeyCallback *client.HostKeyCallback) *remote.Connection { - conn := remote.NewConnection(server, c.UserName, sshAuthMethods, hostKeyCallback) - conn.Handler = handlers.NewClientHandler(server, c.PingTimeout) +func (c TailClient) makeHandler(server string) handlers.Handler { + return handlers.NewClientHandler(server) +} - for _, file := range strings.Split(c.Files, ",") { - conn.Commands = append(conn.Commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) +func (c TailClient) makeCommands() (commands []string) { + for _, file := range strings.Split(c.What, ",") { + commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } - - return conn + return } diff --git a/internal/discovery/comma.go b/internal/discovery/comma.go index ad18be0..94276c7 100644 --- a/internal/discovery/comma.go +++ b/internal/discovery/comma.go @@ -1,7 +1,7 @@ package discovery import ( - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "strings" ) diff --git a/internal/discovery/discovery.go b/internal/discovery/discovery.go index d76c1b2..1090ea9 100644 --- a/internal/discovery/discovery.go +++ b/internal/discovery/discovery.go @@ -1,7 +1,6 @@ package discovery import ( - "github.com/mimecast/dtail/internal/logger" "fmt" "math/rand" "os" @@ -9,6 +8,16 @@ import ( "regexp" "strings" "time" + + "github.com/mimecast/dtail/internal/io/logger" +) + +// ServerOrder to specify how to sort the server list. +type ServerOrder int + +const ( + // Shuffle the server list? + Shuffle ServerOrder = iota ) // Discovery method for discovering a list of available DTail servers. @@ -21,12 +30,12 @@ type Discovery struct { server string // To filter server list. regex *regexp.Regexp - // To shuffle resulting server list. - shuffle bool + // How to order the server list. + order ServerOrder } // New returns a new discovery method. -func New(method, server string, shuffle bool) *Discovery { +func New(method, server string, order ServerOrder) *Discovery { module := method options := "" @@ -43,7 +52,7 @@ func New(method, server string, shuffle bool) *Discovery { module: strings.ToUpper(module), options: options, server: server, - shuffle: shuffle, + order: order, } if strings.HasPrefix(server, "/") && strings.HasSuffix(server, "/") { @@ -84,7 +93,7 @@ func (d *Discovery) ServerList() []string { servers = d.dedupList(servers) - if d.shuffle { + if d.order == Shuffle { servers = d.shuffleList(servers) } diff --git a/internal/discovery/file.go b/internal/discovery/file.go index 2edc867..c04173e 100644 --- a/internal/discovery/file.go +++ b/internal/discovery/file.go @@ -2,7 +2,7 @@ package discovery import ( "bufio" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "os" ) diff --git a/internal/fs/catfile.go b/internal/fs/catfile.go deleted file mode 100644 index 99f521f..0000000 --- a/internal/fs/catfile.go +++ /dev/null @@ -1,27 +0,0 @@ -package fs - -import "sync" - -// CatFile is for reading a whole file. -type CatFile struct { - readFile -} - -// NewCatFile returns a new file catter. -func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { - var mutex sync.Mutex - - return CatFile{ - readFile: readFile{ - filePath: filePath, - stop: make(chan struct{}), - globID: globID, - serverMessages: serverMessages, - retry: false, - canSkipLines: false, - seekEOF: false, - limiter: limiter, - mutex: &mutex, - }, - } -} diff --git a/internal/fs/filereader.go b/internal/fs/filereader.go deleted file mode 100644 index 5a08e27..0000000 --- a/internal/fs/filereader.go +++ /dev/null @@ -1,9 +0,0 @@ -package fs - -// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. -type FileReader interface { - Start(lines chan<- LineRead, regex string) error - FilePath() string - Retry() bool - Stop() -} diff --git a/internal/fs/lineread.go b/internal/fs/lineread.go deleted file mode 100644 index 7ee558e..0000000 --- a/internal/fs/lineread.go +++ /dev/null @@ -1,28 +0,0 @@ -package fs - -import ( - "fmt" -) - -// LineRead represents a read log line. -type LineRead struct { - // The content of the log line. - Content []byte - // Until now, how many log lines were processed? - Count uint64 - // Sometimes we produce too many log lines so that the client - // is too slow to process all of them. The server will drop log - // lines if that happens but it will signal to the client how - // many log lines in % could be transmitted to the client. - TransmittedPerc int - GlobID *string -} - -// Return a human readable representation of the followed line. -func (l LineRead) String() string { - return fmt.Sprintf("LineRead(Content:%s,TransmittedPerc:%v,Count:%v,GlobID:%s)", - string(l.Content), - l.TransmittedPerc, - l.Count, - *l.GlobID) -} diff --git a/internal/fs/permissions/permission.go b/internal/fs/permissions/permission.go deleted file mode 100644 index 6e83309..0000000 --- a/internal/fs/permissions/permission.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build !linux - -package permissions - -import ( - "github.com/mimecast/dtail/internal/logger" -) - -// ToRead is to check whether user has read permissions to a given file. -func ToRead(user, filePath string) (bool, error) { - // Only implemented for Linux, always expect true - logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") - return true, nil -} diff --git a/internal/fs/permissions/permission_linux.c b/internal/fs/permissions/permission_linux.c deleted file mode 100644 index cd10525..0000000 --- a/internal/fs/permissions/permission_linux.c +++ /dev/null @@ -1,395 +0,0 @@ -#include "permission_linux.h" - -#ifdef DEBUG -void debug_print_checker(struct permission_checker *pc) { - fprintf(stderr, "DEBUG: user_name:%s (%d)\n", - pc->user_name, pc->uid); - - fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids); - int j; - for (j = 0; j < pc->ngids; j++) { - fprintf(stderr, "DEBUG: %d", pc->gids[j]); - struct group *gr = getgrgid(pc->gids[j]); - if (gr != NULL) - fprintf(stderr, " (%s)", gr->gr_name); - fprintf(stderr, "\n"); - } - - fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n", - pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); -} -#endif // DEBUG - -int stat_file(struct permission_checker *pc) { - if (stat(pc->file_path, &pc->file_stat) != 0) - return -1; - -#ifdef DEBUG - fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n", - pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); -#endif - - return 0; -} - -int get_user_uid(struct permission_checker *pc) { - struct passwd *result = NULL; - - size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX); - if (bufsize == -1) - bufsize = 16384; - - char *buf = malloc(bufsize); - if (buf == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name); -#endif - return -1; - } - - int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result); - - if (result == NULL) { -#ifdef DEBUG - if (rc == 0) { - fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name); - } else { - fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name); - } -#endif - - free(buf); - return -1; - } - - pc->uid = pc->pw.pw_uid; - - free(buf); - return 0; -} - -int get_user_groups(struct permission_checker *pc) { - // First assume we are in 10 groups max - pc->ngids = 10; - pc->gids = malloc(pc->ngids * sizeof(gid_t)); - - if (pc->gids == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to allocate space for gids."); -#endif - return -1; - } - - // Try so many times to load group list until it fits into group array. - while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) { - // Too many groups, enlarge group array and try again - int newngids = pc->ngids + 100; - size_t newsize = newngids * sizeof(gid_t); - - if (SIZE_MAX / newngids < sizeof(gid_t)) { - // Overflow -#ifdef DEBUG - fprintf(stderr, "DEBUG: Overflow detected."); -#endif - return -1; - } - - gid_t *newgids = realloc(pc->gids, newsize); - if (newgids == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to allocate space for gids."); -#endif - free(pc->gids); - return -1; - } - - pc->gids = newgids; - pc->ngids = newngids; - } - - return 0; -} - -int is_member_of_group(struct permission_checker *pc, gid_t gid) { - int j; - for (j = 0; j < pc->ngids; j++) - if (pc->gids[j] == gid) - return 1; - return 0; -} - -int check_acl_uid_matches(uid_t uid, acl_entry_t entry) { - int ret = -1; - uid_t *acl_uid = acl_get_qualifier(entry); - if (acl_uid == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); -#endif - return -1; - } - - ret = *acl_uid == uid ? 0 : -1; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret); -#endif - acl_free(acl_uid); - return ret; -} - -int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) { - int ret = -1; - gid_t *acl_gid = acl_get_qualifier(entry); - if (acl_gid == NULL) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); -#endif - return -1; - } - - int j; - for (j = 0; j < ngids; j++) { - if (*acl_gid == gids[j]) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: User is in group %d", *acl_gid); -#endif - ret = 0; - break; - } - } - -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret); -#endif - acl_free(acl_gid); - return ret; -} - -int check_acl(struct permission_checker *pc, const int flag) { - // By default user has no read perm. - int has_read_perm = 0; - - // By default mask tells that there are read perm. However in order to have - // read permissions both, has_read_perm and mask_allows_read_access must be 1! - int mask_allows_read_access = 1; - - acl_type_t type = ACL_TYPE_ACCESS; - acl_t acl = acl_get_file(pc->file_path, type); - - if (acl == NULL) - // Unable to retrieve ACL. - return -1; - - // Walk through each entry of this ACL. - int id; - for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) { - acl_entry_t entry; - if (acl_get_entry(acl, id, &entry) != 1) - // No more ACL entries. - break; - - acl_tag_t tag; - if (acl_get_tag_type(entry, &tag) == -1) - // Unable to retrieve ACL tag. - return -1; - - switch (tag) { - case ACL_USER_OBJ: - if (flag == GROUP_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_USER_OBJ\n"); -#endif - // Ignore this ACL entry if user is not owner of file. - if (pc->uid != pc->file_stat.st_uid) - continue; - break; - case ACL_USER: - if (flag == GROUP_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_USER\n"); -#endif - // Ignore this ACL entry if uid does not match. - if (check_acl_uid_matches(pc->uid, entry) != 0) - continue; - break; - case ACL_GROUP_OBJ: - if (flag == USER_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n"); -#endif - // Ignore ACL entry if user is not in group of file. - if (!is_member_of_group(pc, pc->file_stat.st_gid)) - continue; - break; - case ACL_GROUP: - if (flag == USER_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_GROUP\n"); -#endif - // Ignore ACL entry if user is not in group of entry. - if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0) - continue; - break; - case ACL_OTHER: - if (flag == GROUP_CHECK) - continue; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_OTHER\n"); -#endif - break; - case ACL_MASK: -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL_MASK\n"); -#endif - break; - default: -#ifdef DEBUG - fprintf(stderr, "DEBUG: Unknown ACL tag\n"); -#endif - return -1; - } - -#ifdef DEBUG - fprintf(stderr, "DEBUG: Retrieving permset\n"); -#endif - acl_permset_t permset; - int permission; - if (acl_get_permset(entry, &permset) == -1) - // Unable to retrieve permset. - return -1; - - if ((permission = acl_get_perm(permset, ACL_READ)) == -1) - // Unable to retrieve permset value. - return -1; - - if (permission == 1 && tag != ACL_MASK) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n"); -#endif - has_read_perm = 1; - } else if (permission == 0 && tag == ACL_MASK) { - // Mask says that there are no permissions to read. - mask_allows_read_access = 0; -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n"); -#endif - } - } - - if (has_read_perm && mask_allows_read_access) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n"); -#endif - return 1; - } - -#ifdef DEBUG - fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n"); -#endif - return 0; -} - -int check_traditional(struct permission_checker *pc, const int flag) { - mode_t mode = pc->file_stat.st_mode; - uid_t uid = pc->file_stat.st_uid; - gid_t gid = pc->file_stat.st_gid; - - if (flag == USER_CHECK && (mode & S_IROTH)) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: Others can read file '%s'\n", - pc->file_path); -#endif - return 1; - - } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n", - pc->user_name, pc->file_path); -#endif - return 1; - - } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) { -#ifdef DEBUG - fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n", - pc->user_name, pc->file_path); -#endif - return 1; - } - - return 0; -} - -int permission_to_read(char* user_name, char *file_path) { - int rc = -1; - -#ifdef DEBUG - fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path); -#endif - struct permission_checker pc = { - .user_name = user_name, - .gids = NULL, - .ngids = 0, - .file_path = file_path, - }; - - // Gather user's UID. - if ((rc = get_user_uid(&pc)) == -1) - // Could not retrieve UID. - goto cleanup; - - // Gather file owner (user and group). - if ((rc = stat_file(&pc)) == -1) - // Could not stat file. - goto cleanup; - - // Check whether there is an ACL entry which would allow the user - // to read the file. Don't check for any groups yet. The issue with - // groups is that it can be very slow to retrieve the list of groups - // of a specific user when done via a remote LDAP server! - if ((rc = check_acl(&pc, USER_CHECK)) == 1) - // Yes, has permissions. - goto cleanup; - - // Check whether ACLs of file could be retrieved. - if (rc == -1) { - if (errno != ENOTSUP) - // Unknown error. - goto cleanup; - - // File system does not support ACLs. - // Fallback to traditional permissions. - if ((rc = check_traditional(&pc, USER_CHECK)) == 1) - // Yes, has traditional permissions. - goto cleanup; - - if ((rc = get_user_groups(&pc)) == -1) - // Can not retrieve user's groups. - goto cleanup; - - rc = check_traditional(&pc, GROUP_CHECK); - goto cleanup; - } - - if ((rc = get_user_groups(&pc)) == -1) - // Can not retrieve use'r groups. - goto cleanup; - - // Check whether there is an ACL entry which would allow any of the - // user's groups to read the file. - rc = check_acl(&pc, GROUP_CHECK); - -cleanup: -#ifdef DEBUG - debug_print_checker(&pc); -#endif - - if (pc.ngids) - free(pc.gids); - - return rc; -} - -// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab diff --git a/internal/fs/permissions/permission_linux.go b/internal/fs/permissions/permission_linux.go deleted file mode 100644 index feae729..0000000 --- a/internal/fs/permissions/permission_linux.go +++ /dev/null @@ -1,33 +0,0 @@ -package permissions - -/* -#include "permission_linux.h" -#cgo LDFLAGS: -L. -lacl -*/ -import "C" - -import ( - "errors" - "unsafe" -) - -// To check whether user has Linux file system permissions to read a given file. -func ToRead(user, filePath string) (bool, error) { - cUser := C.CString(user) - cFilePath := C.CString(filePath) - - defer C.free(unsafe.Pointer(cUser)) - defer C.free(unsafe.Pointer(cFilePath)) - - cOk, err := C.permission_to_read(cUser, cFilePath) - if cOk == 1 { - return true, nil - } - - if err != nil { - // err contains errno message - return false, err - } - - return false, errors.New("User without permission to read file") -} diff --git a/internal/fs/permissions/permission_linux.h b/internal/fs/permissions/permission_linux.h deleted file mode 100644 index a2c266e..0000000 --- a/internal/fs/permissions/permission_linux.h +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef PERMISSION_LINUX_H -#define PERMISSION_LINUX_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//#define DEBUG -#define USER_CHECK 0 -#define GROUP_CHECK 1 - -struct permission_checker { - char *user_name; - uid_t uid; - gid_t *gids; - int ngids; - char *file_path; - struct stat file_stat; - struct passwd pw; -}; - - -#ifdef DEBUG -// Print out permission_checker struct. -void debug_print_checker(struct permission_checker *pc); -#endif - -// Stat a given file to retrieve traditional UNIX permissions. -int stat_file(struct permission_checker *pc); - -// Retrieve UID of user. -int get_user_uid(struct permission_checker *pc); - -// Retrieve all groups of the user. -int get_user_groups(struct permission_checker *pc); - -// Check whether user is member of a group or not. -int is_member_of_group(struct permission_checker *pc, gid_t gid); - -// Check whether user can read file according Linux ACLs. -// As flag use either USER_CHECK or GROUP_CHECK. -int check_acl(struct permission_checker *pc, const int flag); - -// Check whether user has permissions to read file according traditional -// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK. -int check_traditional(struct permission_checker *pc, const int flag); - -// Returns 1 if user has permission to read file. -// Returns <0 on error and returns 0 if no permissions. -int permission_to_read(char* user, char *file_path); - -#endif // PERMISSION_LINUX_H diff --git a/internal/fs/permissions/permission_test.go b/internal/fs/permissions/permission_test.go deleted file mode 100644 index d415ac2..0000000 --- a/internal/fs/permissions/permission_test.go +++ /dev/null @@ -1,112 +0,0 @@ -// +build linux - -package permissions - -import ( - "os" - "os/exec" - "os/user" - "strings" - "testing" -) - -const ( - setfacl string = "/usr/bin/setfacl" - file string = "/tmp/acltest" -) - -func TestLinuxACL(t *testing.T) { - setfacl := "/usr/bin/setfacl" - file := "/tmp/acltest" - - // Delete file if it exists. - if _, err := os.Stat(file); err == nil { - os.Remove(file) - } - - f, err := os.Create(file) - if err != nil { - t.Errorf("%v", err) - } - defer func() { - f.Close() - //os.Remove(file) - }() - - user, err := user.Current() - if err != nil { - t.Errorf("Unable to retrieve current user: %v", err) - } - - // Test 1: Remove all permissions and perform a permission check - cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, _ := ToRead(user.Username, file); ok { - t.Errorf("Didn't expect permissions to read file!") - } - - // Test 2: Add read permission to file owner - cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file: %v", err) - } - - // Test 3: Add read permission to file group - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file: %v", err) - } - - // Test 4: Add read permission to others - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file: %v", err) - } - - // Test 5: Remove read permission from mask - cmd = exec.Command(setfacl, "-m", "m::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, _ := ToRead(user.Username, file); ok { - t.Errorf("Didn't expect permissions to read file!") - } - cmd = exec.Command(setfacl, "-m", "m::r--", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - - // Test 6: Add read permission to specific group - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, err := ToRead(user.Username, file); !ok { - t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err) - } - - // Test 7: Remove all permissions but mask - cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - cmd = exec.Command(setfacl, "-m", "m::r--", file) - if err := cmd.Run(); err != nil { - t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) - } - if ok, _ := ToRead(user.Username, file); ok { - t.Errorf("Didn't expect permissions to read file!") - } -} diff --git a/internal/fs/readfile.go b/internal/fs/readfile.go deleted file mode 100644 index 312447a..0000000 --- a/internal/fs/readfile.go +++ /dev/null @@ -1,318 +0,0 @@ -package fs - -import ( - "bufio" - "compress/gzip" - "github.com/mimecast/dtail/internal/logger" - "errors" - "io" - "os" - "regexp" - "strings" - "sync" - "time" - - "github.com/DataDog/zstd" -) - -// Used to tail and filter a local log file. -type readFile struct { - // Various statistics (e.g. regex hit percentage, transfer percentage). - stats - // Path of log file to tail. - filePath string - // Only consider all log lines matching this regular expression. - re *regexp.Regexp - // The glob identifier of the file. - globID string - // Channel to send a server message to the dtail client - serverMessages chan<- string - // Signals to stop tailing the log file. - stop chan struct{} - // Periodically retry reading file. - retry bool - // Can I skip messages when there are too many? - canSkipLines bool - // Seek to the EOF before processing file? - seekEOF bool - // Mutex to control the stopping of the file - mutex *sync.Mutex - limiter chan struct{} -} - -// FilePath returns the full file path. -func (f readFile) FilePath() string { - return f.filePath -} - -// Retry reading the file on error? -func (f readFile) Retry() bool { - return f.retry -} - -// Start tailing a log file. -func (f readFile) Start(lines chan<- LineRead, regex string) error { - defer func() { - select { - case <-f.limiter: - default: - } - }() - - select { - case f.limiter <- struct{}{}: - default: - select { - case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): - case <-f.stop: - return nil - } - f.limiter <- struct{}{} - } - - fd, err := os.Open(f.filePath) - if err != nil { - return err - } - defer fd.Close() - - if f.seekEOF { - fd.Seek(0, io.SeekEnd) - } - - rawLines := make(chan []byte, 100) - truncate := make(chan struct{}) - - var wg sync.WaitGroup - wg.Add(1) - - go f.periodicTruncateCheck(truncate) - go f.filter(&wg, rawLines, lines, regex) - - err = f.read(fd, rawLines, truncate) - close(rawLines) - wg.Wait() - - return err -} - -func (f readFile) periodicTruncateCheck(truncate chan struct{}) { - for { - select { - case <-time.After(time.Second * 3): - select { - case truncate <- struct{}{}: - case <-f.stop: - } - case <-f.stop: - return - } - } -} - -// Stop reading file. -func (f readFile) Stop() { - f.mutex.Lock() - defer f.mutex.Unlock() - - select { - case <-f.stop: - return - default: - } - - close(f.stop) -} - -func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { - switch { - case strings.HasSuffix(f.FilePath(), ".gz"): - fallthrough - case strings.HasSuffix(f.FilePath(), ".gzip"): - logger.Info(f.FilePath(), "Detected gzip compression format") - var gzipReader *gzip.Reader - gzipReader, err = gzip.NewReader(fd) - if err != nil { - return - } - reader = bufio.NewReader(gzipReader) - case strings.HasSuffix(f.FilePath(), ".zst"): - logger.Info(f.FilePath(), "Detected zstd compression format") - reader = bufio.NewReader(zstd.NewReader(fd)) - default: - reader = bufio.NewReader(fd) - } - - return -} - -func (f readFile) read(fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { - reader, err := f.makeReader(fd) - if err != nil { - return err - } - rawLine := make([]byte, 0, 512) - var offset uint64 - - lineLengthThreshold := 1024 * 1024 // 1mb - longLineWarning := false - - for { - select { - case <-truncate: - if isTruncated, err := f.truncated(fd); isTruncated { - return err - } - logger.Info(f.filePath, "Current offset", offset) - - case <-f.stop: - return nil - default: - } - - // Read some bytes (max 4k at once as of go 1.12). isPrefix will - // be set if line does not fit into 4k buffer. - bytes, isPrefix, err := reader.ReadLine() - - if err != nil { - // If EOF, sleep a couple of ms and return with nil error. - // If other error, return with non-nil error. - if err != io.EOF { - return err - } - if !f.seekEOF { - logger.Debug(f.FilePath(), "End of file reached") - return nil - } - time.Sleep(time.Millisecond * 100) - continue - } - - rawLine = append(rawLine, bytes...) - offset += uint64(len(bytes)) - - if !isPrefix { - // last LineRead call returned contend until end of line. - rawLine = append(rawLine, '\n') - select { - case rawLines <- rawLine: - case <-f.stop: - return nil - } - rawLine = make([]byte, 0, 512) - if longLineWarning { - longLineWarning = false - } - continue - } - - // Last LineRead call could not read content until end of line, buffer - // was too small. Determine whether we exceed the max line length we - // want dtail to send to the client at once. Possibly split up log line - // into multiple log lines. - if len(rawLine) >= lineLengthThreshold { - if !longLineWarning { - f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") - // Only print out one warning per long log line. - longLineWarning = true - } - rawLine = append(rawLine, '\n') - select { - case rawLines <- rawLine: - case <-f.stop: - return nil - } - rawLine = make([]byte, 0, 512) - } - } -} - -// Filter log lines matching a given regular expression. -func (f readFile) filter(wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- LineRead, regex string) { - defer wg.Done() - - if regex == "" { - regex = "." - } - - re, err := regexp.Compile(regex) - if err != nil { - logger.Error(regex, "Can't compile regex, using '.' instead", err) - re = regexp.MustCompile(".") - } - f.re = re - - for { - select { - case line, ok := <-rawLines: - f.updatePosition() - if !ok { - return - } - if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok { - select { - case lines <- filteredLine: - case <-f.stop: - return - } - } - } - } -} - -func (f readFile) transmittable(line []byte, length, capacity int) (LineRead, bool) { - var read LineRead - - if !f.re.Match(line) { - f.updateLineNotMatched() - f.updateLineNotTransmitted() - return read, false - } - f.updateLineMatched() - - // Can we actually send more messages, channel capacity reached? - if f.canSkipLines && length >= capacity { - f.updateLineNotTransmitted() - return read, false - } - f.updateLineTransmitted() - - read = LineRead{ - Content: line, - GlobID: &f.globID, - Count: f.totalLineCount(), - TransmittedPerc: f.transmittedPerc(), - } - - return read, true -} - -// Check wether log file is truncated. Returns nil if not. -func (f readFile) truncated(fd *os.File) (bool, error) { - logger.Debug(f.filePath, "File truncation check") - - // Can not seek currently open FD. - curPos, err := fd.Seek(0, os.SEEK_CUR) - if err != nil { - return true, err - } - - // Can not open file at original path. - pathFd, err := os.Open(f.filePath) - if err != nil { - return true, err - } - defer pathFd.Close() - - // Can not seek file at original path. - pathPos, err := pathFd.Seek(0, io.SeekEnd) - if err != nil { - return true, err - } - - if curPos > pathPos { - return true, errors.New("File got truncated") - } - - return false, nil -} diff --git a/internal/fs/stats.go b/internal/fs/stats.go deleted file mode 100644 index 4121ff7..0000000 --- a/internal/fs/stats.go +++ /dev/null @@ -1,69 +0,0 @@ -package fs - -// Used to calculate how many log lines matched the regular expression -// and how many log files could be transmitted from the server to the client. -// Hit and transmit percentage takes only the last 100 log lines into calculation. -type stats struct { - pos int - lineCount uint64 - matched [100]bool - matchCount uint64 - transmitted [100]bool - transmitCount int -} - -// Return the total line count. -func (f *stats) totalLineCount() uint64 { - return f.lineCount -} - -// Calculate the percentage of log lines transmitted to the client. -func (f *stats) transmittedPerc() int { - return int(percentOf(float64(f.matchCount), float64(f.transmitCount))) -} - -// Update bucket position. We only take into consideration the last 100 -// lines for stats. -func (f *stats) updatePosition() { - f.pos = (f.pos + 1) % 100 - f.lineCount++ -} - -// Increment match counter. -func (f *stats) updateLineMatched() { - if !f.matched[f.pos] { - f.matchCount++ - f.matched[f.pos] = true - } -} - -// Increment transmitted counter. -func (f *stats) updateLineTransmitted() { - if !f.transmitted[f.pos] { - f.transmitCount++ - f.transmitted[f.pos] = true - } -} - -// Decrement match counter. -func (f *stats) updateLineNotMatched() { - if f.matched[f.pos] { - f.matchCount-- - f.matched[f.pos] = false - } -} - -// Decrement transmitted counter. -func (f *stats) updateLineNotTransmitted() { - if f.transmitted[f.pos] { - f.transmitCount-- - f.transmitted[f.pos] = false - } -} - -func percentOf(total float64, value float64) float64 { - if total == 0 || total == value { - return 100 - } - return value / (total / 100.0) -} diff --git a/internal/fs/tailfile.go b/internal/fs/tailfile.go deleted file mode 100644 index a19d4e6..0000000 --- a/internal/fs/tailfile.go +++ /dev/null @@ -1,27 +0,0 @@ -package fs - -import "sync" - -// TailFile is to tail and filter a log file. -type TailFile struct { - readFile -} - -// NewTailFile returns a new file tailer. -func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { - var mutex sync.Mutex - - return TailFile{ - readFile: readFile{ - filePath: filePath, - stop: make(chan struct{}), - globID: globID, - serverMessages: serverMessages, - retry: true, - canSkipLines: true, - seekEOF: true, - limiter: limiter, - mutex: &mutex, - }, - } -} diff --git a/internal/io/fs/catfile.go b/internal/io/fs/catfile.go new file mode 100644 index 0000000..7f387bc --- /dev/null +++ b/internal/io/fs/catfile.go @@ -0,0 +1,21 @@ +package fs + +// CatFile is for reading a whole file. +type CatFile struct { + readFile +} + +// NewCatFile returns a new file catter. +func NewCatFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) CatFile { + return CatFile{ + readFile: readFile{ + filePath: filePath, + globID: globID, + serverMessages: serverMessages, + retry: false, + canSkipLines: false, + seekEOF: false, + limiter: limiter, + }, + } +} diff --git a/internal/io/fs/filereader.go b/internal/io/fs/filereader.go new file mode 100644 index 0000000..05e58a1 --- /dev/null +++ b/internal/io/fs/filereader.go @@ -0,0 +1,14 @@ +package fs + +import ( + "context" + + "github.com/mimecast/dtail/internal/io/line" +) + +// FileReader is the interface used on the dtail server to read/cat/grep/mapr... a file. +type FileReader interface { + Start(ctx context.Context, lines chan<- line.Line, regex string) error + FilePath() string + Retry() bool +} diff --git a/internal/io/fs/permissions/permission.go b/internal/io/fs/permissions/permission.go new file mode 100644 index 0000000..0ed4f17 --- /dev/null +++ b/internal/io/fs/permissions/permission.go @@ -0,0 +1,14 @@ +// +build !linux + +package permissions + +import ( + "github.com/mimecast/dtail/internal/io/logger" +) + +// ToRead is to check whether user has read permissions to a given file. +func ToRead(user, filePath string) (bool, error) { + // Only implemented for Linux, always expect true + logger.Warn(user, filePath, "Not performing ACL check, not supported on this platform") + return true, nil +} diff --git a/internal/io/fs/permissions/permission_linux.c b/internal/io/fs/permissions/permission_linux.c new file mode 100644 index 0000000..cd10525 --- /dev/null +++ b/internal/io/fs/permissions/permission_linux.c @@ -0,0 +1,395 @@ +#include "permission_linux.h" + +#ifdef DEBUG +void debug_print_checker(struct permission_checker *pc) { + fprintf(stderr, "DEBUG: user_name:%s (%d)\n", + pc->user_name, pc->uid); + + fprintf(stderr, "DEBUG: ngids:%d\n", pc->ngids); + int j; + for (j = 0; j < pc->ngids; j++) { + fprintf(stderr, "DEBUG: %d", pc->gids[j]); + struct group *gr = getgrgid(pc->gids[j]); + if (gr != NULL) + fprintf(stderr, " (%s)", gr->gr_name); + fprintf(stderr, "\n"); + } + + fprintf(stderr, "DEBUG: file_path:%s (%d:%d)\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +} +#endif // DEBUG + +int stat_file(struct permission_checker *pc) { + if (stat(pc->file_path, &pc->file_stat) != 0) + return -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: File'%s' is owned by '%d:%d'\n", + pc->file_path, pc->file_stat.st_uid, pc->file_stat.st_gid); +#endif + + return 0; +} + +int get_user_uid(struct permission_checker *pc) { + struct passwd *result = NULL; + + size_t bufsize = sysconf(_SC_GETPW_R_SIZE_MAX); + if (bufsize == -1) + bufsize = 16384; + + char *buf = malloc(bufsize); + if (buf == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unabel to allocate bufer while retrieving user '%s'\n", pc->user_name); +#endif + return -1; + } + + int rc = getpwnam_r(pc->user_name, &pc->pw, buf, bufsize, &result); + + if (result == NULL) { +#ifdef DEBUG + if (rc == 0) { + fprintf(stderr, "DEBUG: No user '%s' found\n", pc->user_name); + } else { + fprintf(stderr, "DEBUG: Unknown error while retrieving user '%s'\n", pc->user_name); + } +#endif + + free(buf); + return -1; + } + + pc->uid = pc->pw.pw_uid; + + free(buf); + return 0; +} + +int get_user_groups(struct permission_checker *pc) { + // First assume we are in 10 groups max + pc->ngids = 10; + pc->gids = malloc(pc->ngids * sizeof(gid_t)); + + if (pc->gids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + return -1; + } + + // Try so many times to load group list until it fits into group array. + while (getgrouplist(pc->user_name, pc->pw.pw_gid, pc->gids, &pc->ngids) == -1) { + // Too many groups, enlarge group array and try again + int newngids = pc->ngids + 100; + size_t newsize = newngids * sizeof(gid_t); + + if (SIZE_MAX / newngids < sizeof(gid_t)) { + // Overflow +#ifdef DEBUG + fprintf(stderr, "DEBUG: Overflow detected."); +#endif + return -1; + } + + gid_t *newgids = realloc(pc->gids, newsize); + if (newgids == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to allocate space for gids."); +#endif + free(pc->gids); + return -1; + } + + pc->gids = newgids; + pc->ngids = newngids; + } + + return 0; +} + +int is_member_of_group(struct permission_checker *pc, gid_t gid) { + int j; + for (j = 0; j < pc->ngids; j++) + if (pc->gids[j] == gid) + return 1; + return 0; +} + +int check_acl_uid_matches(uid_t uid, acl_entry_t entry) { + int ret = -1; + uid_t *acl_uid = acl_get_qualifier(entry); + if (acl_uid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + ret = *acl_uid == uid ? 0 : -1; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL user match?: %d <=> %d: %d\n", *acl_uid, uid, ret); +#endif + acl_free(acl_uid); + return ret; +} + +int check_acl_gid_matches(gid_t *gids, int ngids, acl_entry_t entry) { + int ret = -1; + gid_t *acl_gid = acl_get_qualifier(entry); + if (acl_gid == NULL) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unable to retrieve user uid from ACL entry"); +#endif + return -1; + } + + int j; + for (j = 0; j < ngids; j++) { + if (*acl_gid == gids[j]) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User is in group %d", *acl_gid); +#endif + ret = 0; + break; + } + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL group match?: %d <=> ...: %d\n", *acl_gid, ret); +#endif + acl_free(acl_gid); + return ret; +} + +int check_acl(struct permission_checker *pc, const int flag) { + // By default user has no read perm. + int has_read_perm = 0; + + // By default mask tells that there are read perm. However in order to have + // read permissions both, has_read_perm and mask_allows_read_access must be 1! + int mask_allows_read_access = 1; + + acl_type_t type = ACL_TYPE_ACCESS; + acl_t acl = acl_get_file(pc->file_path, type); + + if (acl == NULL) + // Unable to retrieve ACL. + return -1; + + // Walk through each entry of this ACL. + int id; + for (id = ACL_FIRST_ENTRY; ; id = ACL_NEXT_ENTRY) { + acl_entry_t entry; + if (acl_get_entry(acl, id, &entry) != 1) + // No more ACL entries. + break; + + acl_tag_t tag; + if (acl_get_tag_type(entry, &tag) == -1) + // Unable to retrieve ACL tag. + return -1; + + switch (tag) { + case ACL_USER_OBJ: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER_OBJ\n"); +#endif + // Ignore this ACL entry if user is not owner of file. + if (pc->uid != pc->file_stat.st_uid) + continue; + break; + case ACL_USER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_USER\n"); +#endif + // Ignore this ACL entry if uid does not match. + if (check_acl_uid_matches(pc->uid, entry) != 0) + continue; + break; + case ACL_GROUP_OBJ: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP_OBJ\n"); +#endif + // Ignore ACL entry if user is not in group of file. + if (!is_member_of_group(pc, pc->file_stat.st_gid)) + continue; + break; + case ACL_GROUP: + if (flag == USER_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_GROUP\n"); +#endif + // Ignore ACL entry if user is not in group of entry. + if (check_acl_gid_matches(pc->gids, pc->ngids, entry) != 0) + continue; + break; + case ACL_OTHER: + if (flag == GROUP_CHECK) + continue; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_OTHER\n"); +#endif + break; + case ACL_MASK: +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL_MASK\n"); +#endif + break; + default: +#ifdef DEBUG + fprintf(stderr, "DEBUG: Unknown ACL tag\n"); +#endif + return -1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: Retrieving permset\n"); +#endif + acl_permset_t permset; + int permission; + if (acl_get_permset(entry, &permset) == -1) + // Unable to retrieve permset. + return -1; + + if ((permission = acl_get_perm(permset, ACL_READ)) == -1) + // Unable to retrieve permset value. + return -1; + + if (permission == 1 && tag != ACL_MASK) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL says user has permission to read file.\n"); +#endif + has_read_perm = 1; + } else if (permission == 0 && tag == ACL_MASK) { + // Mask says that there are no permissions to read. + mask_allows_read_access = 0; +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL mask says no permission to read file.\n"); +#endif + } + } + + if (has_read_perm && mask_allows_read_access) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has permission to read file.\n"); +#endif + return 1; + } + +#ifdef DEBUG + fprintf(stderr, "DEBUG: ACL end result: User has no permission to read file.\n"); +#endif + return 0; +} + +int check_traditional(struct permission_checker *pc, const int flag) { + mode_t mode = pc->file_stat.st_mode; + uid_t uid = pc->file_stat.st_uid; + gid_t gid = pc->file_stat.st_gid; + + if (flag == USER_CHECK && (mode & S_IROTH)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: Others can read file '%s'\n", + pc->file_path); +#endif + return 1; + + } else if (flag == USER_CHECK && (mode & S_IRUSR) && uid == pc->uid) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User '%s' can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + + } else if (flag == GROUP_CHECK && (mode & S_IRGRP) && is_member_of_group(pc, gid)) { +#ifdef DEBUG + fprintf(stderr, "DEBUG: User's '%s' group can read file '%s'\n", + pc->user_name, pc->file_path); +#endif + return 1; + } + + return 0; +} + +int permission_to_read(char* user_name, char *file_path) { + int rc = -1; + +#ifdef DEBUG + fprintf(stderr, "DEBUG: User check '%s' for file '%s'\n", user_name, file_path); +#endif + struct permission_checker pc = { + .user_name = user_name, + .gids = NULL, + .ngids = 0, + .file_path = file_path, + }; + + // Gather user's UID. + if ((rc = get_user_uid(&pc)) == -1) + // Could not retrieve UID. + goto cleanup; + + // Gather file owner (user and group). + if ((rc = stat_file(&pc)) == -1) + // Could not stat file. + goto cleanup; + + // Check whether there is an ACL entry which would allow the user + // to read the file. Don't check for any groups yet. The issue with + // groups is that it can be very slow to retrieve the list of groups + // of a specific user when done via a remote LDAP server! + if ((rc = check_acl(&pc, USER_CHECK)) == 1) + // Yes, has permissions. + goto cleanup; + + // Check whether ACLs of file could be retrieved. + if (rc == -1) { + if (errno != ENOTSUP) + // Unknown error. + goto cleanup; + + // File system does not support ACLs. + // Fallback to traditional permissions. + if ((rc = check_traditional(&pc, USER_CHECK)) == 1) + // Yes, has traditional permissions. + goto cleanup; + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve user's groups. + goto cleanup; + + rc = check_traditional(&pc, GROUP_CHECK); + goto cleanup; + } + + if ((rc = get_user_groups(&pc)) == -1) + // Can not retrieve use'r groups. + goto cleanup; + + // Check whether there is an ACL entry which would allow any of the + // user's groups to read the file. + rc = check_acl(&pc, GROUP_CHECK); + +cleanup: +#ifdef DEBUG + debug_print_checker(&pc); +#endif + + if (pc.ngids) + free(pc.gids); + + return rc; +} + +// vim: set tabstop=8 softtabstop=0 expandtab shiftwidth=4 smarttab diff --git a/internal/io/fs/permissions/permission_linux.go b/internal/io/fs/permissions/permission_linux.go new file mode 100644 index 0000000..feae729 --- /dev/null +++ b/internal/io/fs/permissions/permission_linux.go @@ -0,0 +1,33 @@ +package permissions + +/* +#include "permission_linux.h" +#cgo LDFLAGS: -L. -lacl +*/ +import "C" + +import ( + "errors" + "unsafe" +) + +// To check whether user has Linux file system permissions to read a given file. +func ToRead(user, filePath string) (bool, error) { + cUser := C.CString(user) + cFilePath := C.CString(filePath) + + defer C.free(unsafe.Pointer(cUser)) + defer C.free(unsafe.Pointer(cFilePath)) + + cOk, err := C.permission_to_read(cUser, cFilePath) + if cOk == 1 { + return true, nil + } + + if err != nil { + // err contains errno message + return false, err + } + + return false, errors.New("User without permission to read file") +} diff --git a/internal/io/fs/permissions/permission_linux.h b/internal/io/fs/permissions/permission_linux.h new file mode 100644 index 0000000..a2c266e --- /dev/null +++ b/internal/io/fs/permissions/permission_linux.h @@ -0,0 +1,60 @@ +#ifndef PERMISSION_LINUX_H +#define PERMISSION_LINUX_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//#define DEBUG +#define USER_CHECK 0 +#define GROUP_CHECK 1 + +struct permission_checker { + char *user_name; + uid_t uid; + gid_t *gids; + int ngids; + char *file_path; + struct stat file_stat; + struct passwd pw; +}; + + +#ifdef DEBUG +// Print out permission_checker struct. +void debug_print_checker(struct permission_checker *pc); +#endif + +// Stat a given file to retrieve traditional UNIX permissions. +int stat_file(struct permission_checker *pc); + +// Retrieve UID of user. +int get_user_uid(struct permission_checker *pc); + +// Retrieve all groups of the user. +int get_user_groups(struct permission_checker *pc); + +// Check whether user is member of a group or not. +int is_member_of_group(struct permission_checker *pc, gid_t gid); + +// Check whether user can read file according Linux ACLs. +// As flag use either USER_CHECK or GROUP_CHECK. +int check_acl(struct permission_checker *pc, const int flag); + +// Check whether user has permissions to read file according traditional +// UNIX permissions. As flag use either USER_CHECK or GROUP_CHECK. +int check_traditional(struct permission_checker *pc, const int flag); + +// Returns 1 if user has permission to read file. +// Returns <0 on error and returns 0 if no permissions. +int permission_to_read(char* user, char *file_path); + +#endif // PERMISSION_LINUX_H diff --git a/internal/io/fs/permissions/permission_test.go b/internal/io/fs/permissions/permission_test.go new file mode 100644 index 0000000..d415ac2 --- /dev/null +++ b/internal/io/fs/permissions/permission_test.go @@ -0,0 +1,112 @@ +// +build linux + +package permissions + +import ( + "os" + "os/exec" + "os/user" + "strings" + "testing" +) + +const ( + setfacl string = "/usr/bin/setfacl" + file string = "/tmp/acltest" +) + +func TestLinuxACL(t *testing.T) { + setfacl := "/usr/bin/setfacl" + file := "/tmp/acltest" + + // Delete file if it exists. + if _, err := os.Stat(file); err == nil { + os.Remove(file) + } + + f, err := os.Create(file) + if err != nil { + t.Errorf("%v", err) + } + defer func() { + f.Close() + //os.Remove(file) + }() + + user, err := user.Current() + if err != nil { + t.Errorf("Unable to retrieve current user: %v", err) + } + + // Test 1: Remove all permissions and perform a permission check + cmd := exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + + // Test 2: Add read permission to file owner + cmd = exec.Command(setfacl, "-b", "-m", "u::r--,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 3: Add read permission to file group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 4: Add read permission to others + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file: %v", err) + } + + // Test 5: Remove read permission from mask + cmd = exec.Command(setfacl, "-m", "m::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + + // Test 6: Add read permission to specific group + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g:"+user.Username+":r--,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, err := ToRead(user.Username, file); !ok { + t.Errorf("Expected permissions to read file for user %v: %v", user.Username, err) + } + + // Test 7: Remove all permissions but mask + cmd = exec.Command(setfacl, "-b", "-m", "u::---,g::---,o::---", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + cmd = exec.Command(setfacl, "-m", "m::r--", file) + if err := cmd.Run(); err != nil { + t.Errorf("%s -> %v", strings.Join(cmd.Args, " "), err) + } + if ok, _ := ToRead(user.Username, file); ok { + t.Errorf("Didn't expect permissions to read file!") + } +} diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go new file mode 100644 index 0000000..321432e --- /dev/null +++ b/internal/io/fs/readfile.go @@ -0,0 +1,307 @@ +package fs + +import ( + "bufio" + "compress/gzip" + "context" + "errors" + "io" + "os" + "regexp" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" + + "github.com/DataDog/zstd" +) + +// Used to tail and filter a local log file. +type readFile struct { + // Various statistics (e.g. regex hit percentage, transfer percentage). + stats + // Path of log file to tail. + filePath string + // Only consider all log lines matching this regular expression. + re *regexp.Regexp + // The glob identifier of the file. + globID string + // Channel to send a server message to the dtail client + serverMessages chan<- string + // Periodically retry reading file. + retry bool + // Can I skip messages when there are too many? + canSkipLines bool + // Seek to the EOF before processing file? + seekEOF bool + limiter chan struct{} +} + +// FilePath returns the full file path. +func (f readFile) FilePath() string { + return f.filePath +} + +// Retry reading the file on error? +func (f readFile) Retry() bool { + return f.retry +} + +// Start tailing a log file. +func (f readFile) Start(ctx context.Context, lines chan<- line.Line, regex string) error { + defer func() { + select { + case <-f.limiter: + default: + } + }() + + select { + case f.limiter <- struct{}{}: + default: + select { + case f.serverMessages <- logger.Warn(f.filePath, f.globID, "Server limit reached. Queuing file..."): + case <-ctx.Done(): + return nil + } + f.limiter <- struct{}{} + } + + fd, err := os.Open(f.filePath) + if err != nil { + return err + } + defer fd.Close() + + if f.seekEOF { + fd.Seek(0, io.SeekEnd) + } + + rawLines := make(chan []byte, 100) + truncate := make(chan struct{}) + + var wg sync.WaitGroup + wg.Add(1) + + go f.periodicTruncateCheck(ctx, truncate) + go f.filter(ctx, &wg, rawLines, lines, regex) + + err = f.read(ctx, fd, rawLines, truncate) + close(rawLines) + wg.Wait() + + return err +} + +func (f readFile) periodicTruncateCheck(ctx context.Context, truncate chan struct{}) { + for { + select { + case <-time.After(time.Second * 3): + select { + case truncate <- struct{}{}: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } + } +} + +func (f readFile) makeReader(fd *os.File) (reader *bufio.Reader, err error) { + switch { + case strings.HasSuffix(f.FilePath(), ".gz"): + fallthrough + case strings.HasSuffix(f.FilePath(), ".gzip"): + logger.Info(f.FilePath(), "Detected gzip compression format") + var gzipReader *gzip.Reader + gzipReader, err = gzip.NewReader(fd) + if err != nil { + return + } + reader = bufio.NewReader(gzipReader) + case strings.HasSuffix(f.FilePath(), ".zst"): + logger.Info(f.FilePath(), "Detected zstd compression format") + reader = bufio.NewReader(zstd.NewReader(fd)) + default: + reader = bufio.NewReader(fd) + } + + return +} + +func (f readFile) read(ctx context.Context, fd *os.File, rawLines chan []byte, truncate <-chan struct{}) error { + var offset uint64 + + reader, err := f.makeReader(fd) + if err != nil { + return err + } + rawLine := make([]byte, 0, 512) + + lineLengthThreshold := 1024 * 1024 // 1mb + longLineWarning := false + + for { + select { + case <-ctx.Done(): + return nil + default: + } + + select { + case <-truncate: + if isTruncated, err := f.truncated(fd); isTruncated { + return err + } + logger.Info(f.filePath, "Current offset", offset) + default: + } + + // Read some bytes (max 4k at once as of go 1.12). isPrefix will + // be set if line does not fit into 4k buffer. + bytes, isPrefix, err := reader.ReadLine() + + if err != nil { + // If EOF, sleep a couple of ms and return with nil error. + // If other error, return with non-nil error. + if err != io.EOF { + return err + } + if !f.seekEOF { + logger.Debug(f.FilePath(), "End of file reached") + return nil + } + time.Sleep(time.Millisecond * 100) + continue + } + + rawLine = append(rawLine, bytes...) + offset += uint64(len(bytes)) + + if !isPrefix { + // last LineRead call returned contend until end of line. + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-ctx.Done(): + return nil + } + rawLine = make([]byte, 0, 512) + if longLineWarning { + longLineWarning = false + } + continue + } + + // Last LineRead call could not read content until end of line, buffer + // was too small. Determine whether we exceed the max line length we + // want dtail to send to the client at once. Possibly split up log line + // into multiple log lines. + if len(rawLine) >= lineLengthThreshold { + if !longLineWarning { + f.serverMessages <- logger.Warn(f.filePath, "Long log line, splitting into multiple lines") + // Only print out one warning per long log line. + longLineWarning = true + } + rawLine = append(rawLine, '\n') + select { + case rawLines <- rawLine: + case <-ctx.Done(): + return nil + } + rawLine = make([]byte, 0, 512) + } + } +} + +// Filter log lines matching a given regular expression. +func (f readFile) filter(ctx context.Context, wg *sync.WaitGroup, rawLines <-chan []byte, lines chan<- line.Line, regex string) { + defer wg.Done() + + if regex == "" { + regex = "." + } + + re, err := regexp.Compile(regex) + if err != nil { + logger.Error(regex, "Can't compile regex, using '.' instead", err) + re = regexp.MustCompile(".") + } + f.re = re + + for { + select { + case line, ok := <-rawLines: + f.updatePosition() + if !ok { + return + } + if filteredLine, ok := f.transmittable(line, len(lines), cap(lines)); ok { + select { + case lines <- filteredLine: + case <-ctx.Done(): + return + } + } + } + } +} + +func (f readFile) transmittable(lineBytes []byte, length, capacity int) (line.Line, bool) { + var read line.Line + + if !f.re.Match(lineBytes) { + f.updateLineNotMatched() + f.updateLineNotTransmitted() + return read, false + } + f.updateLineMatched() + + // Can we actually send more messages, channel capacity reached? + if f.canSkipLines && length >= capacity { + f.updateLineNotTransmitted() + return read, false + } + f.updateLineTransmitted() + + read = line.Line{ + Content: lineBytes, + SourceID: f.globID, + Count: f.totalLineCount(), + TransmittedPerc: f.transmittedPerc(), + } + + return read, true +} + +// Check wether log file is truncated. Returns nil if not. +func (f readFile) truncated(fd *os.File) (bool, error) { + logger.Debug(f.filePath, "File truncation check") + + // Can not seek currently open FD. + curPos, err := fd.Seek(0, os.SEEK_CUR) + if err != nil { + return true, err + } + + // Can not open file at original path. + pathFd, err := os.Open(f.filePath) + if err != nil { + return true, err + } + defer pathFd.Close() + + // Can not seek file at original path. + pathPos, err := pathFd.Seek(0, io.SeekEnd) + if err != nil { + return true, err + } + + if curPos > pathPos { + return true, errors.New("File got truncated") + } + + return false, nil +} diff --git a/internal/io/fs/stats.go b/internal/io/fs/stats.go new file mode 100644 index 0000000..4121ff7 --- /dev/null +++ b/internal/io/fs/stats.go @@ -0,0 +1,69 @@ +package fs + +// Used to calculate how many log lines matched the regular expression +// and how many log files could be transmitted from the server to the client. +// Hit and transmit percentage takes only the last 100 log lines into calculation. +type stats struct { + pos int + lineCount uint64 + matched [100]bool + matchCount uint64 + transmitted [100]bool + transmitCount int +} + +// Return the total line count. +func (f *stats) totalLineCount() uint64 { + return f.lineCount +} + +// Calculate the percentage of log lines transmitted to the client. +func (f *stats) transmittedPerc() int { + return int(percentOf(float64(f.matchCount), float64(f.transmitCount))) +} + +// Update bucket position. We only take into consideration the last 100 +// lines for stats. +func (f *stats) updatePosition() { + f.pos = (f.pos + 1) % 100 + f.lineCount++ +} + +// Increment match counter. +func (f *stats) updateLineMatched() { + if !f.matched[f.pos] { + f.matchCount++ + f.matched[f.pos] = true + } +} + +// Increment transmitted counter. +func (f *stats) updateLineTransmitted() { + if !f.transmitted[f.pos] { + f.transmitCount++ + f.transmitted[f.pos] = true + } +} + +// Decrement match counter. +func (f *stats) updateLineNotMatched() { + if f.matched[f.pos] { + f.matchCount-- + f.matched[f.pos] = false + } +} + +// Decrement transmitted counter. +func (f *stats) updateLineNotTransmitted() { + if f.transmitted[f.pos] { + f.transmitCount-- + f.transmitted[f.pos] = false + } +} + +func percentOf(total float64, value float64) float64 { + if total == 0 || total == value { + return 100 + } + return value / (total / 100.0) +} diff --git a/internal/io/fs/tailfile.go b/internal/io/fs/tailfile.go new file mode 100644 index 0000000..14994e5 --- /dev/null +++ b/internal/io/fs/tailfile.go @@ -0,0 +1,21 @@ +package fs + +// TailFile is to tail and filter a log file. +type TailFile struct { + readFile +} + +// NewTailFile returns a new file tailer. +func NewTailFile(filePath string, globID string, serverMessages chan<- string, limiter chan struct{}) TailFile { + return TailFile{ + readFile: readFile{ + filePath: filePath, + globID: globID, + serverMessages: serverMessages, + retry: true, + canSkipLines: true, + seekEOF: true, + limiter: limiter, + }, + } +} diff --git a/internal/io/line/line.go b/internal/io/line/line.go new file mode 100644 index 0000000..9db93c0 --- /dev/null +++ b/internal/io/line/line.go @@ -0,0 +1,28 @@ +package line + +import ( + "fmt" +) + +// Line represents a read log line. +type Line struct { + // The content of the log line. + Content []byte + // Until now, how many log lines were processed? + Count uint64 + // Sometimes we produce too many log lines so that the client + // is too slow to process all of them. The server will drop log + // lines if that happens but it will signal to the client how + // many log lines in % could be transmitted to the client. + TransmittedPerc int + SourceID string +} + +// Return a human readable representation of the followed line. +func (l Line) String() string { + return fmt.Sprintf("Line(Content:%s,TransmittedPerc:%v,Count:%v,SourceID:%s)", + string(l.Content), + l.TransmittedPerc, + l.Count, + l.SourceID) +} diff --git a/internal/io/logger/logger.go b/internal/io/logger/logger.go new file mode 100644 index 0000000..e30b907 --- /dev/null +++ b/internal/io/logger/logger.go @@ -0,0 +1,445 @@ +package logger + +import ( + "bufio" + "context" + "fmt" + "os" + "os/signal" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/mimecast/dtail/internal/color" + "github.com/mimecast/dtail/internal/config" +) + +const ( + clientStr string = "CLIENT" + serverStr string = "SERVER" + infoStr string = "INFO" + warnStr string = "WARN" + errorStr string = "ERROR" + fatalStr string = "FATAL" + debugStr string = "DEBUG" + traceStr string = "TRACE" +) + +// Synchronise access to logging. +var mutex sync.Mutex + +// File descriptor of log file when logToFile enabled. +var fd *os.File + +// File write buffer of log file when logToFile enabled. +var writer *bufio.Writer + +// File write buffer of stdout when logToStdout enabled. +var stdoutWriter *bufio.Writer + +// Current hostname. +var hostname string + +// Used to detect change of day (create one log file per day0 +var lastDateStr string + +// True if log in server mode, false if log in client mode. +var serverEnable bool + +// Used to make logging non-blocking. +var fileLogBufCh chan buf +var stdoutBufCh chan string + +// Stdout channel, required to pause output +var pauseCh chan struct{} +var resumeCh chan struct{} + +// Tell the logger about logrotation +var rotateCh chan os.Signal + +// LogMode allows to specify the verbosity of logging. +type LogMode int + +// Possible log modes. +const ( + NormalMode LogMode = iota + DebugMode LogMode = iota + SilentMode LogMode = iota + TraceMode LogMode = iota + NothingMode LogMode = iota +) + +// Mode is the current log mode in use. +var Mode LogMode + +// LogStrategy allows to specify a log rotation strategy. +type LogStrategy int + +// Possible log strategies. +const ( + NormalStrategy LogStrategy = iota + DailyStrategy LogStrategy = iota + StdoutStrategy LogStrategy = iota +) + +// Strategy is the current log strattegy used. +var Strategy LogStrategy + +// Enables logging to stdout. +var logToStdout bool + +// Enables logging to file. +var logToFile bool + +// Helper type to make logging non-blocking. +type buf struct { + time time.Time + message string +} + +// Start logging. +func Start(ctx context.Context, myServerEnable, debugEnable, silentEnable, nothingEnable bool) { + serverEnable = myServerEnable + + mode := logMode(debugEnable, silentEnable, nothingEnable) + strategy := logStrategy() + + stdoutWriter = bufio.NewWriter(os.Stdout) + Mode = mode + Strategy = strategy + + if Mode == NothingMode { + return + } + + switch Strategy { + case DailyStrategy: + _, err := os.Stat(config.Common.LogDir) + logToFile = !os.IsNotExist(err) + logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode + case StdoutStrategy: + fallthrough + default: + logToFile = !serverEnable + logToStdout = true + } + + fqdn, err := os.Hostname() + if err != nil { + panic(err) + } + s := strings.Split(fqdn, ".") + hostname = s[0] + + pauseCh = make(chan struct{}) + resumeCh = make(chan struct{}) + + // Setup logrotation + rotateCh = make(chan os.Signal, 1) + signal.Notify(rotateCh, syscall.SIGHUP) + + if logToStdout { + stdoutBufCh = make(chan string, runtime.NumCPU()*100) + go writeToStdout(ctx) + } + + if logToFile { + fileLogBufCh = make(chan buf, runtime.NumCPU()*100) + go writeToFile(ctx) + } +} + +func logMode(debugEnable, silentEnable, nothingEnable bool) LogMode { + switch { + case debugEnable: + return DebugMode + case nothingEnable: + return NothingMode + case config.Common.TraceEnable: + return TraceMode + case config.Common.DebugEnable: + return DebugMode + case silentEnable: + return SilentMode + default: + } + return NormalMode +} + +func logStrategy() LogStrategy { + switch config.Common.LogStrategy { + case "daily": + return DailyStrategy + default: + } + return StdoutStrategy +} + +// Info message logging. +func Info(args ...interface{}) string { + if serverEnable { + return log(serverStr, infoStr, args) + } + + return log(clientStr, infoStr, args) +} + +// Warn message logging. +func Warn(args ...interface{}) string { + if serverEnable { + return log(serverStr, warnStr, args) + } + + return log(clientStr, warnStr, args) +} + +// Error message logging. +func Error(args ...interface{}) string { + if serverEnable { + return log(serverStr, errorStr, args) + } + + return log(clientStr, errorStr, args) +} + +// FatalExit logs an error and exists the process. +func FatalExit(args ...interface{}) { + what := clientStr + if serverEnable { + what = serverStr + } + log(what, fatalStr, args) + + time.Sleep(time.Second) + mutex.Lock() + defer mutex.Unlock() + + closeWriter() + os.Exit(3) +} + +// Debug message logging. +func Debug(args ...interface{}) string { + if Mode == DebugMode || Mode == TraceMode { + if serverEnable { + return log(serverStr, debugStr, args) + } + return log(clientStr, debugStr, args) + } + + return "" +} + +// Trace message logging. +func Trace(args ...interface{}) string { + if Mode == TraceMode { + if serverEnable { + return log(serverStr, traceStr, args) + } + return log(clientStr, traceStr, args) + } + + return "" +} + +// Write log line to buffer and/or log file. +func write(what, severity, message string) { + if logToStdout && (Mode != SilentMode || severity != warnStr) { + line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) + + if color.Colored { + line = color.Colorfy(line) + } + + stdoutBufCh <- line + } + + if logToFile { + t := time.Now() + timeStr := t.Format("20060102-150405") + fileLogBufCh <- buf{ + time: t, + message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), + } + } +} + +// Generig log message. +func log(what string, severity string, args []interface{}) string { + if Mode == NothingMode { + return "" + } + + var messages []string + + for _, arg := range args { + switch v := arg.(type) { + case string: + messages = append(messages, v) + case int: + messages = append(messages, fmt.Sprintf("%d", v)) + case error: + messages = append(messages, v.Error()) + default: + messages = append(messages, fmt.Sprintf("%v", v)) + } + } + + message := strings.Join(messages, "|") + write(what, severity, message) + + return fmt.Sprintf("%s|%s", severity, message) +} + +// Raw message logging. +func Raw(message string) { + if Mode == NothingMode { + return + } + + if logToFile { + fileLogBufCh <- buf{time.Now(), message} + } + + if logToStdout { + if color.Colored { + message = color.Colorfy(message) + } + stdoutBufCh <- message + } +} + +// Close log writer (e.g. on change of day). +func closeWriter() { + if writer != nil { + writer.Flush() + fd.Close() + } +} + +// Return the correct log file writer +func fileWriter(dateStr string) *bufio.Writer { + if dateStr != lastDateStr { + return updateFileWriter(dateStr) + } + + // Check for log rotation signal + select { + case <-rotateCh: + stdoutWriter.WriteString("Received signal for logrotation\n") + return updateFileWriter(dateStr) + default: + } + + return writer +} + +// Update log file writer +func updateFileWriter(dateStr string) *bufio.Writer { + // Detected change of day. Close current writer and create a new one. + mutex.Lock() + defer mutex.Unlock() + closeWriter() + + if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { + if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { + panic(err) + } + } + + logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr) + newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) + if err != nil { + panic(err) + } + + fd = newFd + writer = bufio.NewWriterSize(fd, 1) + lastDateStr = dateStr + + return writer +} + +// Flush all outstanding lines. +func Flush() { + for { + select { + case message := <-stdoutBufCh: + stdoutWriter.WriteString(message) + default: + stdoutWriter.Flush() + return + } + } +} + +func writeToStdout(ctx context.Context) { + for { + select { + case message := <-stdoutBufCh: + stdoutWriter.WriteString(message) + case <-time.After(time.Millisecond * 100): + stdoutWriter.Flush() + case <-pauseCh: + PAUSE: + for { + select { + case <-stdoutBufCh: + case <-resumeCh: + break PAUSE + case <-ctx.Done(): + return + } + } + case <-ctx.Done(): + Flush() + return + } + } +} + +func writeToFile(ctx context.Context) { + for { + select { + case buf := <-fileLogBufCh: + dateStr := buf.time.Format("20060102") + w := fileWriter(dateStr) + w.WriteString(buf.message) + case <-pauseCh: + PAUSE: + for { + select { + case <-stdoutBufCh: + case <-resumeCh: + break PAUSE + case <-ctx.Done(): + return + } + } + case <-ctx.Done(): + return + } + } +} + +// Pause logging. +func Pause() { + if logToStdout { + pauseCh <- struct{}{} + } + if logToFile { + pauseCh <- struct{}{} + } +} + +// Resume logging (after pausing). +func Resume() { + if logToStdout { + resumeCh <- struct{}{} + } + if logToFile { + resumeCh <- struct{}{} + } +} diff --git a/internal/io/run/run.go b/internal/io/run/run.go new file mode 100644 index 0000000..b608639 --- /dev/null +++ b/internal/io/run/run.go @@ -0,0 +1,104 @@ +package run + +import ( + "bufio" + "context" + "io" + "os/exec" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" +) + +// Run is for execute a command. +type Run struct { + commandPath string + args []string + cmd *exec.Cmd +} + +// New returns a new command runner. +func New(commandPath string, args []string) Run { + return Run{ + commandPath: commandPath, + args: args, + } +} + +// Start running the command. +func (r Run) Start(ctx context.Context, lines chan<- line.Line) (pid int, ec int, err error) { + done := make(chan struct{}) + defer close(done) + + ec = -1 + pid = -1 + + if len(r.args) > 0 { + logger.Debug(r.commandPath, strings.Join(r.args, " ")) + r.cmd = exec.CommandContext(ctx, r.commandPath, strings.Join(r.args, " ")) + } else { + logger.Debug(r.commandPath) + r.cmd = exec.CommandContext(ctx, r.commandPath) + } + + stdoutPipe, myErr := r.cmd.StdoutPipe() + if err != nil { + err = myErr + return + } + + stderrPipe, myErr := r.cmd.StderrPipe() + if myErr != nil { + err = myErr + return + } + + if myErr := r.cmd.Start(); err != nil { + err = myErr + return + } + + pid = r.cmd.Process.Pid + ec = 0 + + var wg sync.WaitGroup + wg.Add(2) + + go r.pipeToLines(done, &wg, pid, stdoutPipe, "STDOUT", lines) + go r.pipeToLines(done, &wg, pid, stderrPipe, "STDERR", lines) + + if err = r.cmd.Wait(); err != nil { + if exitError, ok := err.(*exec.ExitError); ok { + ec = exitError.ExitCode() + } + } + + return +} + +func (r Run) pipeToLines(done chan struct{}, wg *sync.WaitGroup, pid int, reader io.Reader, what string, lines chan<- line.Line) { + defer wg.Done() + bufReader := bufio.NewReader(reader) + + for { + lineStr, err := bufReader.ReadString('\n') + for err == nil { + lines <- line.Line{ + Content: []byte(lineStr), + Count: uint64(pid), + TransmittedPerc: 100, + SourceID: what, + } + lineStr, err = bufReader.ReadString('\n') + } + select { + case <-done: + return + default: + } + time.Sleep(time.Millisecond * 10) + } +} diff --git a/internal/logger/logger.go b/internal/logger/logger.go deleted file mode 100644 index ca85e32..0000000 --- a/internal/logger/logger.go +++ /dev/null @@ -1,457 +0,0 @@ -package logger - -import ( - "bufio" - "fmt" - "os" - "os/signal" - "runtime" - "strings" - "sync" - "syscall" - "time" - - "github.com/mimecast/dtail/internal/color" - "github.com/mimecast/dtail/internal/config" -) - -const ( - clientStr string = "CLIENT" - serverStr string = "SERVER" - infoStr string = "INFO" - warnStr string = "WARN" - errorStr string = "ERROR" - fatalStr string = "FATAL" - debugStr string = "DEBUG" - traceStr string = "TRACE" -) - -// Synchronise access to logging. -var mutex sync.Mutex - -// File descriptor of log file when logToFile enabled. -var fd *os.File - -// File write buffer of log file when logToFile enabled. -var writer *bufio.Writer - -// File write buffer of stdout when logToStdout enabled. -var stdoutWriter *bufio.Writer - -// Current hostname. -var hostname string - -// Used to detect change of day (create one log file per day0 -var lastDateStr string - -// True if log in server mode, false if log in client mode. -var serverEnable bool - -// Used to make logging non-blocking. -var logBufCh chan buf -var stdoutBufCh chan string - -// Stdout channel, required to pause output -var pauseCh chan struct{} -var resumeCh chan struct{} - -// Tell the logger that we are done, program shuts down -var stop chan struct{} -var stdoutFlushed chan struct{} - -// Tell the logger about logrotation -var rotateCh chan os.Signal - -// LogMode allows to specify the verbosity of logging. -type LogMode int - -// Possible log modes. -const ( - NormalMode LogMode = iota - DebugMode LogMode = iota - SilentMode LogMode = iota - TraceMode LogMode = iota - NothingMode LogMode = iota -) - -// Mode is the current log mode in use. -var Mode LogMode - -// LogStrategy allows to specify a log rotation strategy. -type LogStrategy int - -// Possible log strategies. -const ( - NormalStrategy LogStrategy = iota - DailyStrategy LogStrategy = iota - StdoutStrategy LogStrategy = iota -) - -// Strategy is the current log strattegy used. -var Strategy LogStrategy - -// Enables logging to stdout. -var logToStdout bool - -// Enables logging to file. -var logToFile bool - -// Helper type to make logging non-blocking. -type buf struct { - time time.Time - message string -} - -// Start logging. -func Start(myServerEnable, debugEnable, silentEnable, nothingEnable bool) { - serverEnable = myServerEnable - - mode := logMode(debugEnable, silentEnable, nothingEnable) - strategy := logStrategy() - - stdoutWriter = bufio.NewWriter(os.Stdout) - Mode = mode - Strategy = strategy - - if Mode == NothingMode { - return - } - - switch Strategy { - case DailyStrategy: - _, err := os.Stat(config.Common.LogDir) - logToFile = !os.IsNotExist(err) - logToStdout = !serverEnable || Mode == DebugMode || Mode == TraceMode - case StdoutStrategy: - fallthrough - default: - logToFile = false - logToStdout = true - } - - fqdn, err := os.Hostname() - if err != nil { - panic(err) - } - s := strings.Split(fqdn, ".") - hostname = s[0] - - pauseCh = make(chan struct{}) - resumeCh = make(chan struct{}) - stop = make(chan struct{}) - stdoutFlushed = make(chan struct{}) - - // Setup logrotation - rotateCh = make(chan os.Signal, 1) - signal.Notify(rotateCh, syscall.SIGHUP) - - if logToStdout { - stdoutBufCh = make(chan string, runtime.NumCPU()*100) - go writeToStdout() - } - - if logToFile { - logBufCh = make(chan buf, runtime.NumCPU()*100) - go writeToFile() - } -} - -func logMode(debugEnable, silentEnable, nothingEnable bool) LogMode { - switch { - case debugEnable: - return DebugMode - case nothingEnable: - return NothingMode - case config.Common.TraceEnable: - return TraceMode - case config.Common.DebugEnable: - return DebugMode - case silentEnable: - return SilentMode - default: - } - return NormalMode -} - -func logStrategy() LogStrategy { - switch config.Common.LogStrategy { - case "daily": - return DailyStrategy - default: - } - return StdoutStrategy -} - -// Info message logging. -func Info(args ...interface{}) string { - if serverEnable { - return log(serverStr, infoStr, args) - } - - return log(clientStr, infoStr, args) -} - -// Warn message logging. -func Warn(args ...interface{}) string { - if serverEnable { - return log(serverStr, warnStr, args) - } - - return log(clientStr, warnStr, args) -} - -// Error message logging. -func Error(args ...interface{}) string { - if serverEnable { - return log(serverStr, errorStr, args) - } - - return log(clientStr, errorStr, args) -} - -// FatalExit logs an error and exists the process. -func FatalExit(args ...interface{}) { - what := clientStr - if serverEnable { - what = serverStr - } - log(what, fatalStr, args) - - time.Sleep(time.Second) - mutex.Lock() - defer mutex.Unlock() - - closeWriter() - os.Exit(3) -} - -// Debug message logging. -func Debug(args ...interface{}) string { - if Mode == DebugMode || Mode == TraceMode { - if serverEnable { - return log(serverStr, debugStr, args) - } - return log(clientStr, debugStr, args) - } - - return "" -} - -// Trace message logging. -func Trace(args ...interface{}) string { - if Mode == TraceMode { - if serverEnable { - return log(serverStr, traceStr, args) - } - return log(clientStr, traceStr, args) - } - - return "" -} - -// Write log line to buffer and/or log file. -func write(what, severity, message string) { - if logToStdout && (Mode != SilentMode || severity != warnStr) { - line := fmt.Sprintf("%s|%s|%s|%s\n", what, hostname, severity, message) - - if color.Colored { - line = color.Colorfy(line) - } - - stdoutBufCh <- line - } - - if logToFile { - t := time.Now() - timeStr := t.Format("20060102-150405") - logBufCh <- buf{ - time: t, - message: fmt.Sprintf("%s|%s|%s|%s\n", severity, timeStr, what, message), - } - } -} - -// Generig log message. -func log(what string, severity string, args []interface{}) string { - if Mode == NothingMode { - return "" - } - - var messages []string - - for _, arg := range args { - switch v := arg.(type) { - case string: - messages = append(messages, v) - case int: - messages = append(messages, fmt.Sprintf("%d", v)) - case error: - messages = append(messages, v.Error()) - default: - messages = append(messages, fmt.Sprintf("%v", v)) - } - } - - message := strings.Join(messages, "|") - write(what, severity, message) - - return fmt.Sprintf("%s|%s", severity, message) -} - -// Raw message logging. -func Raw(message string) { - if Mode == NothingMode { - return - } - - if logToStdout { - if color.Colored { - message = color.Colorfy(message) - } - stdoutBufCh <- message - } - - if logToFile { - logBufCh <- buf{time.Now(), message} - } -} - -// Close log writer (e.g. on change of day). -func closeWriter() { - if writer != nil { - writer.Flush() - fd.Close() - } -} - -// Return the correct log file writer -func fileWriter(dateStr string) *bufio.Writer { - if dateStr != lastDateStr { - return updateFileWriter(dateStr) - } - - // Check for log rotation signal - select { - case <-rotateCh: - stdoutWriter.WriteString("Received signal for logrotation\n") - return updateFileWriter(dateStr) - default: - } - - return writer -} - -// Update log file writer -func updateFileWriter(dateStr string) *bufio.Writer { - // Detected change of day. Close current writer and create a new one. - mutex.Lock() - defer mutex.Unlock() - closeWriter() - - if _, err := os.Stat(config.Common.LogDir); os.IsNotExist(err) { - if err = os.MkdirAll(config.Common.LogDir, 0755); err != nil { - panic(err) - } - } - - logFile := fmt.Sprintf("%s/%s.log", config.Common.LogDir, dateStr) - newFd, err := os.OpenFile(logFile, os.O_CREATE|os.O_RDWR|os.O_APPEND, 0644) - if err != nil { - panic(err) - } - - fd = newFd - writer = bufio.NewWriterSize(fd, 1) - lastDateStr = dateStr - - return writer -} - -func flushStdout() { - defer close(stdoutFlushed) - - for { - select { - case message := <-stdoutBufCh: - stdoutWriter.WriteString(message) - default: - stdoutWriter.Flush() - return - } - } -} - -func writeToStdout() { - for { - select { - case message := <-stdoutBufCh: - stdoutWriter.WriteString(message) - case <-time.After(time.Millisecond * 100): - stdoutWriter.Flush() - case <-pauseCh: - PAUSE: - for { - select { - case <-stdoutBufCh: - case <-resumeCh: - break PAUSE - case <-stop: - return - } - } - case <-stop: - flushStdout() - return - } - } -} - -func writeToFile() { - for { - select { - case buf := <-logBufCh: - dateStr := buf.time.Format("20060102") - w := fileWriter(dateStr) - w.WriteString(buf.message) - case <-pauseCh: - PAUSE: - for { - select { - case <-stdoutBufCh: - case <-resumeCh: - break PAUSE - case <-stop: - return - } - } - case <-stop: - return - } - } -} - -// Pause logging. -func Pause() { - if logToStdout { - pauseCh <- struct{}{} - } - if logToFile { - pauseCh <- struct{}{} - } -} - -// Resume logging (after pausing). -func Resume() { - if logToStdout { - resumeCh <- struct{}{} - } - if logToFile { - resumeCh <- struct{}{} - } -} - -// Stop logging. -func Stop() { - close(stop) - <-stdoutFlushed -} diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go index 2096c3c..7fb4c17 100644 --- a/internal/mapr/aggregateset.go +++ b/internal/mapr/aggregateset.go @@ -1,6 +1,7 @@ package mapr import ( + "context" "fmt" "strconv" "strings" @@ -64,7 +65,7 @@ func (s *AggregateSet) Merge(query *Query, set *AggregateSet) error { } // Serialize the aggregate set so it can be sent over the wire. -func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan struct{}) { +func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<- string) { //logger.Trace("Serialising mapr.AggregateSet", s) var sb strings.Builder @@ -87,7 +88,7 @@ func (s *AggregateSet) Serialize(groupKey string, ch chan<- string, stop chan st select { case ch <- sb.String(): - case <-stop: + case <-ctx.Done(): } } diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go index 3f2b7a5..1272a19 100644 --- a/internal/mapr/client/aggregate.go +++ b/internal/mapr/client/aggregate.go @@ -1,10 +1,11 @@ package client import ( - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/mapr" "strconv" "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/mapr" ) // Aggregate mapreduce data on the DTail client side. @@ -15,7 +16,6 @@ type Aggregate struct { group *mapr.GroupSet // This represents the merged aggregated data of all servers. globalGroup *mapr.GlobalGroupSet - stop chan struct{} // The server we aggregate the data for (logging and debugging purposes only) server string } @@ -26,20 +26,12 @@ func NewAggregate(server string, query *mapr.Query, globalGroup *mapr.GlobalGrou query: query, group: mapr.NewGroupSet(), globalGroup: globalGroup, - stop: make(chan struct{}), server: server, } } // Aggregate data from mapr log line into local (and global) group sets. func (a *Aggregate) Aggregate(parts []string) { - select { - case <-a.stop: - logger.Error("Client aggregator stopped for server, not processing new data", a.server) - return - default: - } - groupKey := parts[0] samples, err := strconv.Atoi(parts[1]) if err != nil { @@ -87,14 +79,3 @@ func (a *Aggregate) makeFields(parts []string) map[string]string { return fields } - -// Stop the client side mapreduce aggregator. -func (a *Aggregate) Stop() { - logger.Debug("Stopping client mapreduce aggregator") - close(a.stop) - - err := a.globalGroup.Merge(a.query, a.group) - if err != nil { - panic(err) - } -} diff --git a/internal/mapr/groupset.go b/internal/mapr/groupset.go index d8f9379..e9e0d37 100644 --- a/internal/mapr/groupset.go +++ b/internal/mapr/groupset.go @@ -1,6 +1,7 @@ package mapr import ( + "context" "errors" "fmt" "io/ioutil" @@ -46,9 +47,9 @@ func (g *GroupSet) GetSet(groupKey string) *AggregateSet { } // Serialize the group set (e.g. to send it over the wire). -func (g *GroupSet) Serialize(ch chan<- string, stop chan struct{}) { +func (g *GroupSet) Serialize(ctx context.Context, ch chan<- string) { for groupKey, set := range g.sets { - set.Serialize(groupKey, ch, stop) + set.Serialize(ctx, groupKey, ch) } } diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go index 5730d29..09c706b 100644 --- a/internal/mapr/logformat/parser.go +++ b/internal/mapr/logformat/parser.go @@ -1,9 +1,9 @@ package logformat import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "os" "reflect" "strings" diff --git a/internal/mapr/query.go b/internal/mapr/query.go index 3805d15..0127be3 100644 --- a/internal/mapr/query.go +++ b/internal/mapr/query.go @@ -1,9 +1,9 @@ package mapr import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "strconv" "strings" "time" diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go index 900756e..922dcbd 100644 --- a/internal/mapr/server/aggregate.go +++ b/internal/mapr/server/aggregate.go @@ -1,26 +1,28 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/mapr" - "github.com/mimecast/dtail/internal/mapr/logformat" + "context" "os" "strings" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/mapr" + "github.com/mimecast/dtail/internal/mapr/logformat" ) // Aggregate is for aggregating mapreduce data on the DTail server side. type Aggregate struct { // Log lines to process (parsing MAPREDUCE lines). - Lines chan fs.LineRead + Lines chan line.Line // Hostname of the current server (used to populate $hostname field). hostname string - // Signals to exit goroutine. - stop chan struct{} // Signals to serialize data. serialize chan struct{} + // Signals to flush data. + flush chan struct{} // The mapr query query *mapr.Query // The mapr log format parser @@ -28,7 +30,7 @@ type Aggregate struct { } // NewAggregate return a new server side aggregator. -func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) { +func NewAggregate(queryStr string) (*Aggregate, error) { query, err := mapr.NewQuery(queryStr) if err != nil { return nil, err @@ -47,76 +49,98 @@ func NewAggregate(maprLines chan<- string, queryStr string) (*Aggregate, error) } a := Aggregate{ - Lines: make(chan fs.LineRead, 100), - stop: make(chan struct{}), + Lines: make(chan line.Line, 100), serialize: make(chan struct{}), + flush: make(chan struct{}), hostname: s[0], query: query, parser: logParser, } - go a.periodicAggregateTimer() - - fieldsCh := make(chan map[string]string) - go a.readFields(fieldsCh, maprLines) - go a.readLines(fieldsCh) - return &a, nil } -func (a *Aggregate) periodicAggregateTimer() { +// Start an aggregation run. +func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { + fieldsCh := a.linesToFields(ctx) + go a.fieldsToMaprLines(ctx, fieldsCh, maprLines) + a.periodicAggregateTimer(ctx) +} + +func (a *Aggregate) periodicAggregateTimer(ctx context.Context) { for { select { case <-time.After(a.query.Interval): - a.Serialize() - case <-a.stop: + a.Serialize(ctx) + case <-ctx.Done(): return } } } -func (a *Aggregate) readFields(fieldsCh <-chan map[string]string, maprLines chan<- string) { - group := mapr.NewGroupSet() +func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string { + fieldsCh := make(chan map[string]string) - for { - select { - case fields := <-fieldsCh: - a.aggregate(group, fields) - case <-a.serialize: - logger.Info("Serializing mapreduce result") - group.Serialize(maprLines, a.stop) - logger.Info("Done serializing mapreduce result") - group = mapr.NewGroupSet() - case <-a.stop: - return + go func() { + defer close(fieldsCh) + + for { + select { + case line, ok := <-a.Lines: + if !ok { + return + } + + maprLine := strings.TrimSpace(string(line.Content)) + fields, err := a.parser.MakeFields(maprLine) + + if err != nil { + logger.Error(err) + continue + } + if !a.query.WhereClause(fields) { + continue + } + + select { + case fieldsCh <- fields: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } } - } + }() + + return fieldsCh } -func (a *Aggregate) readLines(fieldsCh chan<- map[string]string) { +func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { + group := mapr.NewGroupSet() + for { select { - case line, ok := <-a.Lines: + case fields, ok := <-fieldsCh: if !ok { + logger.Info("Serializing mapreduce result (final)") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result (final)") return } - - maprLine := strings.TrimSpace(string(line.Content)) - fields, err := a.parser.MakeFields(maprLine) - - if err != nil { - logger.Error(err) - continue - } - if !a.query.WhereClause(fields) { - continue - } - - select { - case fieldsCh <- fields: - case <-a.stop: - } - case <-a.stop: + a.aggregate(group, fields) + case <-a.serialize: + logger.Info("Serializing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result") + case <-a.flush: + logger.Info("Flushing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + a.flush <- struct{}{} + logger.Info("Done flushing mapreduce result") + case <-ctx.Done(): return } } @@ -157,14 +181,15 @@ func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) { } // Serialize all the aggregated data. -func (a *Aggregate) Serialize() { +func (a *Aggregate) Serialize(ctx context.Context) { select { case a.serialize <- struct{}{}: - case <-a.stop: + case <-ctx.Done(): } } -// Close the aggregator. -func (a *Aggregate) Close() { - close(a.stop) +// Flush all data. +func (a *Aggregate) Flush() { + a.flush <- struct{}{} + <-a.flush } diff --git a/internal/mapr/wherecondition.go b/internal/mapr/wherecondition.go index e1f4e5b..ab46bed 100644 --- a/internal/mapr/wherecondition.go +++ b/internal/mapr/wherecondition.go @@ -1,9 +1,9 @@ package mapr import ( - "github.com/mimecast/dtail/internal/logger" "errors" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "strconv" "strings" ) diff --git a/internal/omode/mode.go b/internal/omode/mode.go index 57366d2..e29aacc 100644 --- a/internal/omode/mode.go +++ b/internal/omode/mode.go @@ -12,7 +12,7 @@ const ( GrepClient Mode = iota MapClient Mode = iota HealthClient Mode = iota - ExecClient Mode = iota + RunClient Mode = iota ) func (m Mode) String() string { @@ -29,8 +29,8 @@ func (m Mode) String() string { return "map" case HealthClient: return "health" - case ExecClient: - return "exec" + case RunClient: + return "run" default: return "unknown" } diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index f78bcf6..c6d11ca 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -7,9 +7,10 @@ import ( _ "net/http/pprof" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" ) +// Start the profiler HTTP server. func Start() { bindAddr := fmt.Sprintf("%s:%d", config.Common.PProfBindAddress, config.Common.PProfPort) logger.Info("Starting PProf server", bindAddr) diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go index 76a2726..a438d33 100644 --- a/internal/prompt/prompt.go +++ b/internal/prompt/prompt.go @@ -2,8 +2,8 @@ package prompt import ( "bufio" - "github.com/mimecast/dtail/internal/logger" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "os" "strings" ) diff --git a/internal/server/handlers/controlhandler.go b/internal/server/handlers/controlhandler.go index 482f759..a33a78b 100644 --- a/internal/server/handlers/controlhandler.go +++ b/internal/server/handlers/controlhandler.go @@ -1,33 +1,34 @@ package handlers import ( + "context" "fmt" "io" "os" "strings" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" ) // ControlHandler is used for control functions and health monitoring. type ControlHandler struct { - serverMessages chan string - pong chan struct{} - stop chan struct{} - payload []byte + ctx context.Context + done chan struct{} hostname string + payload []byte + serverMessages chan string user *user.User } // NewControlHandler returns a new control handler. -func NewControlHandler(user *user.User) *ControlHandler { +func NewControlHandler(ctx context.Context, user *user.User) (*ControlHandler, <-chan struct{}) { logger.Debug(user, "Creating control handler") h := ControlHandler{ + ctx: ctx, + done: make(chan struct{}), serverMessages: make(chan string, 10), - pong: make(chan struct{}, 10), - stop: make(chan struct{}), user: user, } @@ -38,7 +39,8 @@ func NewControlHandler(user *user.User) *ControlHandler { s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + + return &h, h.done } // Read is to send data to the client via the Reader interface. @@ -49,11 +51,7 @@ func (h *ControlHandler) Read(p []byte) (n int, err error) { wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return - case <-h.pong: - logger.Info(h.user, "Sending pong") - n = copy(p, []byte(".pong\n")) - return - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF } } @@ -65,7 +63,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { switch c { case ';': wholePayload := strings.TrimSpace(string(h.payload)) - h.handleCommand(wholePayload) + h.handleCommand(h.ctx, wholePayload) h.payload = nil default: @@ -77,17 +75,7 @@ func (h *ControlHandler) Write(p []byte) (n int, err error) { return } -// Close the control handler. -func (h *ControlHandler) Close() { - close(h.stop) -} - -// Wait returns the handler stop channel. -func (h *ControlHandler) Wait() <-chan struct{} { - return h.stop -} - -func (h *ControlHandler) handleCommand(command string) { +func (h *ControlHandler) handleCommand(ctx context.Context, command string) { logger.Info(h.user, command) s := strings.Split(command, " ") logger.Debug(h.user, "Receiving command", command, s) @@ -96,8 +84,6 @@ func (h *ControlHandler) handleCommand(command string) { case "health": h.serverMessages <- "OK: DTail SSH Server seems fine" h.serverMessages <- "done;" - case "ping": - h.pong <- struct{}{} case "debug": h.serverMessages <- logger.Debug(h.user, "Receiving debug command", command, s) default: diff --git a/internal/server/handlers/handler.go b/internal/server/handlers/handler.go index 8b1f73e..c42ceb9 100644 --- a/internal/server/handlers/handler.go +++ b/internal/server/handlers/handler.go @@ -5,6 +5,4 @@ import "io" // Handler interface for server side functionality. type Handler interface { io.ReadWriter - Close() - Wait() <-chan struct{} } diff --git a/internal/server/handlers/mapcommand.go b/internal/server/handlers/mapcommand.go new file mode 100644 index 0000000..10372da --- /dev/null +++ b/internal/server/handlers/mapcommand.go @@ -0,0 +1,35 @@ +package handlers + +import ( + "context" + "strings" + + "github.com/mimecast/dtail/internal/mapr/server" +) + +// Map command implements the mapreduce command server side. +type mapCommand struct { + aggregate *server.Aggregate + server *ServerHandler +} + +// NewMapCommand returns a new server side mapreduce command. +func newMapCommand(serverHandler *ServerHandler, argc int, args []string) (mapCommand, *server.Aggregate, error) { + mapCommand := mapCommand{ + server: serverHandler, + } + + queryStr := strings.Join(args[1:], " ") + aggregate, err := server.NewAggregate(queryStr) + if err != nil { + return mapCommand, nil, err + } + + mapCommand.aggregate = aggregate + return mapCommand, aggregate, nil + +} + +func (m mapCommand) Start(ctx context.Context, aggregatedMessages chan<- string) { + m.aggregate.Start(ctx, aggregatedMessages) +} diff --git a/internal/server/handlers/readcommand.go b/internal/server/handlers/readcommand.go new file mode 100644 index 0000000..e4079e8 --- /dev/null +++ b/internal/server/handlers/readcommand.go @@ -0,0 +1,158 @@ +package handlers + +import ( + "context" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/mimecast/dtail/internal/io/fs" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/omode" +) + +type readCommand struct { + server *ServerHandler + mode omode.Mode +} + +func newReadCommand(server *ServerHandler, mode omode.Mode) *readCommand { + return &readCommand{ + server: server, + mode: mode, + } +} + +func (r *readCommand) Start(ctx context.Context, argc int, args []string) { + regex := "." + if argc >= 4 { + regex = args[3] + } + if argc < 3 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + r.readGlob(ctx, args[1], regex) +} + +func (r *readCommand) readGlob(ctx context.Context, glob string, regex string) { + retryInterval := time.Second * 5 + glob = filepath.Clean(glob) + + maxRetries := 10 + for { + maxRetries-- + if maxRetries < 0 { + r.server.sendServerMessage(logger.Warn(r.server.user, "Giving up to read file(s)")) + return + } + + paths, err := filepath.Glob(glob) + if err != nil { + logger.Warn(r.server.user, glob, err) + time.Sleep(retryInterval) + continue + } + + if numPaths := len(paths); numPaths == 0 { + logger.Error(r.server.user, "No such file(s) to read", glob) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + select { + case <-ctx.Done(): + return + default: + } + time.Sleep(retryInterval) + continue + } + + r.readFiles(ctx, paths, glob, regex, retryInterval) + break + } +} + +func (r *readCommand) readFiles(ctx context.Context, paths []string, glob string, regex string, retryInterval time.Duration) { + var wg sync.WaitGroup + wg.Add(len(paths)) + + for _, path := range paths { + go r.readFileIfPermissions(ctx, &wg, path, glob, regex) + } + + wg.Wait() +} + +func (r *readCommand) readFileIfPermissions(ctx context.Context, wg *sync.WaitGroup, path, glob, regex string) { + defer wg.Done() + globID := r.makeGlobID(path, glob) + + if !r.server.user.HasFilePermission(path, "readfiles") { + logger.Error(r.server.user, "No permission to read file", path, globID) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to read file(s), check server logs")) + return + } + + r.readFile(ctx, path, globID, regex) +} + +func (r *readCommand) readFile(ctx context.Context, path, globID, regex string) { + logger.Info(r.server.user, "Start reading file", path, globID) + + var reader fs.FileReader + switch r.mode { + case omode.TailClient: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + case omode.GrepClient, omode.CatClient: + reader = fs.NewCatFile(path, globID, r.server.serverMessages, r.server.catLimiter) + default: + reader = fs.NewTailFile(path, globID, r.server.serverMessages, r.server.tailLimiter) + } + + lines := r.server.lines + + // Plug in mappreduce engine + if r.server.aggregate != nil { + lines = r.server.aggregate.Lines + } + + for { + if err := reader.Start(ctx, lines, regex); err != nil { + logger.Error(r.server.user, path, globID, err) + } + + select { + case <-ctx.Done(): + return + default: + if !reader.Retry() { + return + } + } + + time.Sleep(time.Second * 2) + logger.Info(path, globID, "Reading file again") + } +} + +func (r *readCommand) makeGlobID(path, glob string) string { + var idParts []string + pathParts := strings.Split(path, "/") + + for i, globPart := range strings.Split(glob, "/") { + if strings.Contains(globPart, "*") { + idParts = append(idParts, pathParts[i]) + } + } + + if len(idParts) > 0 { + return strings.Join(idParts, "/") + } + + if len(pathParts) > 0 { + return pathParts[len(pathParts)-1] + } + + r.server.sendServerMessage(logger.Error("Empty file path given?", path, glob)) + return "" +} diff --git a/internal/server/handlers/runcommand.go b/internal/server/handlers/runcommand.go new file mode 100644 index 0000000..e260060 --- /dev/null +++ b/internal/server/handlers/runcommand.go @@ -0,0 +1,73 @@ +package handlers + +import ( + "context" + "fmt" + "os/exec" + "strings" + + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/io/run" +) + +type runCommand struct { + server *ServerHandler + run run.Run +} + +func newRunCommand(server *ServerHandler) runCommand { + return runCommand{ + server: server, + } +} + +func (r runCommand) Start(ctx context.Context, argc int, args []string) { + if argc < 2 { + r.server.sendServerMessage(logger.Warn(r.server.user, commandParseWarning, args, argc)) + return + } + commands := strings.Split(strings.Join(args[1:], " "), ";") + r.start(ctx, commands) +} + +func (r runCommand) start(ctx context.Context, commands []string) { + for _, command := range commands { + command = strings.TrimSpace(command) + if len(command) == 0 { + continue + } + splitted := strings.Split(command, " ") + path := splitted[0] + args := splitted[1:] + + qualifiedPath, err := exec.LookPath(path) + if err != nil { + logger.Error(r.server.user, err) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + if !r.server.user.HasFilePermission(qualifiedPath, "runcommands") { + logger.Error(r.server.user, "No permission to execute path", qualifiedPath) + r.server.sendServerMessage(logger.Warn(r.server.user, "Unable to execute command(s), check server logs")) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", -1)) + return + } + + r.run = run.New(qualifiedPath, args) + pid, ec, err := r.run.Start(ctx, r.server.lines) + + if err != nil { + message := fmt.Sprintf("Unable to execute remote command '%s'", command) + logger.Error(r.server.user, message, ec, pid, err) + r.server.sendServerMessage(logger.Error(message, ec, pid, err)) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus -%d", ec)) + return + } + + message := fmt.Sprintf("Remote process '%d' exited with status '%d'", pid, ec) + r.server.sendServerMessage(fmt.Sprintf(".run exitstatus %d", ec)) + r.server.sendServerMessage(logger.Info("run", pid, ec, message)) + } +} diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index bed8609..3f0d6ce 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -1,17 +1,19 @@ package handlers import ( + "context" + "encoding/base64" + "errors" "fmt" "io" "os" - "path/filepath" "strings" "sync" "time" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/line" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/mapr/server" "github.com/mimecast/dtail/internal/omode" user "github.com/mimecast/dtail/internal/user/server" @@ -26,51 +28,33 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - // Local log file readers - fileReaders []fs.FileReader - fileReadersMtx *sync.Mutex - // Channel for read lines. - lines chan fs.LineRead - // Only process log lines matching this regex. - regex string - // Server side mapr log aggregation. - aggregate *server.Aggregate - // Channel of aggregated log lines. + mutex *sync.Mutex + lines chan line.Line + regex string + aggregate *server.Aggregate aggregatedMessages chan string - // Channel for server messages to be sent to the client. - serverMessages chan string - // Channel for hidden messages to be sent to the client. - hiddenMessages chan string - // The current payload sent to the client. - payload []byte - // The current server hostname. - hostname string - // The user connecting to dtail. - user *user.User - // To limit the server wide max amount of concurrent cats - catLimiter chan struct{} - // To limit the server wide max amount of concurrent tails - tailLimiter chan struct{} - // Server can tell handler to stop the handler. - stop chan struct{} - // Indicate that client responded to server with "ack stop connection" - ackStopReceived chan struct{} - // Stop timeout. - stopTimeout chan struct{} + serverMessages chan string + payload []byte + hostname string + user *user.User + catLimiter chan struct{} + tailLimiter chan struct{} + ackCloseReceived chan struct{} + ctx context.Context + done chan struct{} + activeReaders int } // NewServerHandler returns the server handler. -func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) *ServerHandler { - logger.Debug(user, "Creating tail handler") +func NewServerHandler(ctx context.Context, user *user.User, catLimiter chan struct{}, tailLimiter chan struct{}) (*ServerHandler, <-chan struct{}) { h := ServerHandler{ - fileReadersMtx: &sync.Mutex{}, - lines: make(chan fs.LineRead, 100), + ctx: ctx, + done: make(chan struct{}), + mutex: &sync.Mutex{}, + lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), - hiddenMessages: make(chan string, 10), - ackStopReceived: make(chan struct{}), - stopTimeout: make(chan struct{}), - stop: make(chan struct{}), + ackCloseReceived: make(chan struct{}), catLimiter: catLimiter, tailLimiter: tailLimiter, regex: ".", @@ -85,37 +69,46 @@ func NewServerHandler(user *user.User, catLimiter chan struct{}, tailLimiter cha s := strings.Split(fqdn, ".") h.hostname = s[0] - return &h + return &h, h.done } // Read is to send data to the dtail client via Reader interface. func (h *ServerHandler) Read(p []byte) (n int, err error) { for { select { + case message := <-h.serverMessages: + if message[0] == '.' { + // Handle hidden message (don't display to the user, interpreted by dtail client) + wholePayload := []byte(fmt.Sprintf("%s\n", message)) + n = copy(p, wholePayload) + return + } + + // Handle normal server message (display to the user) wholePayload := []byte(fmt.Sprintf("SERVER|%s|%s\n", h.hostname, message)) n = copy(p, wholePayload) return + case message := <-h.aggregatedMessages: + // Send mapreduce-aggregated data as a message. data := fmt.Sprintf("AGGREGATE|%s|%s\n", h.hostname, message) - //logger.Debug("Sending aggregation data", data) wholePayload := []byte(data) n = copy(p, wholePayload) return - case message := <-h.hiddenMessages: - //logger.Debug(h.user, "Sending hidden message", message) - wholePayload := []byte(fmt.Sprintf(".%s\n", message)) - n = copy(p, wholePayload) - return + case line := <-h.lines: + // Send normal file content data as a message. serverInfo := []byte(fmt.Sprintf("REMOTE|%s|%3d|%v|%s|", - h.hostname, line.TransmittedPerc, line.Count, *line.GlobID)) + h.hostname, line.TransmittedPerc, line.Count, line.SourceID)) wholePayload := append(serverInfo, line.Content[:]...) n = copy(p, wholePayload) return + case <-time.After(time.Second): + // Once in a while check whether we are done. select { - case <-h.stop: + case <-h.ctx.Done(): return 0, io.EOF default: } @@ -129,7 +122,7 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { switch c { case ';': commandStr := strings.TrimSpace(string(h.payload)) - h.handleCommand(commandStr) + h.handleCommand(h.ctx, commandStr) h.payload = nil default: h.payload = append(h.payload, c) @@ -140,210 +133,167 @@ func (h *ServerHandler) Write(p []byte) (n int, err error) { return } -// Close the server handler. -func (h *ServerHandler) Close() { - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { + logger.Debug(h.user, commandStr) - for _, reader := range h.fileReaders { - reader.Stop() + args, argc, err := h.handleProtocolVersion(strings.Split(commandStr, " ")) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - if h.aggregate != nil { - h.aggregate.Close() + + args, argc, err = h.handleBase64(args, argc) + if err != nil { + h.send(h.serverMessages, logger.Error(h.user, err)) + return } - close(h.stop) -} + if h.user.Name == config.ControlUser { + h.handleControlCommand(argc, args) + return + } -func (h *ServerHandler) makeGlobID(path, glob string) string { - var idParts []string - pathParts := strings.Split(path, "/") + h.handleUserCommand(ctx, argc, args) +} - for i, globPart := range strings.Split(glob, "/") { - if strings.Contains(globPart, "*") { - idParts = append(idParts, pathParts[i]) - } - } +func (h *ServerHandler) handleProtocolVersion(args []string) ([]string, int, error) { + argc := len(args) - if len(idParts) > 0 { - return strings.Join(idParts, "/") + if argc <= 2 || args[0] != "protocol" { + return args, argc, errors.New("unable to determine protocol version") } - if len(pathParts) > 0 { - return pathParts[len(pathParts)-1] + if args[1] != version.ProtocolCompat { + err := fmt.Errorf("server with protool version '%s' but client with '%s', please update DTail", version.ProtocolCompat, args[1]) + return args, argc, err } - h.send(h.serverMessages, logger.Error("Empty file path given?", path, glob)) - return "" + return args[2:], argc - 2, nil } -func (h *ServerHandler) processFileGlob(mode omode.Mode, glob string, regex string) { - retryInterval := time.Second * 5 - glob = filepath.Clean(glob) - - errors := make(chan struct{}) - stop := make(chan struct{}) - defer close(stop) +func (h *ServerHandler) handleBase64(args []string, argc int) ([]string, int, error) { + err := errors.New("Unable to decode client message") - go func() { - for { - select { - case <-errors: - h.send(h.serverMessages, logger.Warn(h.user, "Unable to read file(s), check server logs")) - case <-stop: - return - case <-h.stop: - return - } - } - }() + if argc != 2 || args[0] != "base64" { + return args, argc, err + } - maxRetries := 10 - for { - maxRetries-- - if maxRetries < 0 { - h.send(h.serverMessages, logger.Warn(h.user, "Giving up to read file(s)")) - h.internalClose() - return - } + decoded, err := base64.StdEncoding.DecodeString(args[1]) + if err != nil { + return args, argc, err + } + decodedStr := string(decoded) - paths, err := filepath.Glob(glob) - if err != nil { - logger.Warn(h.user, glob, err) - time.Sleep(retryInterval) - continue - } + args = strings.Split(decodedStr, " ") + argc = len(decodedStr) + logger.Trace(h.user, "Base64 decoded received command", decodedStr, argc, args) - if numPaths := len(paths); numPaths == 0 { - logger.Error(h.user, "No such file(s) to read", glob) - select { - case errors <- struct{}{}: - case <-h.stop: - return - default: - } - time.Sleep(retryInterval) - continue - } + return args, argc, nil +} - h.startReadingFiles(mode, paths, glob, regex, retryInterval, errors) - break +func (h *ServerHandler) handleControlCommand(argc int, args []string) { + switch args[0] { + case "debug": + h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) + default: + logger.Warn(h.user, "Received unknown command", argc, args) } } -func (h *ServerHandler) startReadingFiles(mode omode.Mode, paths []string, glob string, regex string, retryInterval time.Duration, errors chan<- struct{}) { - var wg sync.WaitGroup - wg.Add(len(paths)) +func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []string) { + logger.Debug(h.user, "handleUserCommand", argc, args) - read := func(path string, wg *sync.WaitGroup) { - defer wg.Done() - globID := h.makeGlobID(path, glob) + switch args[0] { + case "grep", "cat": + command := newReadCommand(h, omode.CatClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - if !h.user.HasFilePermission(path) { - logger.Error(h.user, "No permission to read file", path, globID) - select { - case errors <- struct{}{}: - default: + case "tail": + command := newReadCommand(h, omode.TailClient) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() } + }() + + case "map": + command, aggregate, err := newMapCommand(h, argc, args) + if err != nil { + h.sendServerMessage(err.Error()) + logger.Error(h.user, err) return } - h.startReadingFile(mode, path, globID, regex) - } - - for _, path := range paths { - go read(path, &wg) - } + h.aggregate = aggregate + go func() { + command.Start(ctx, h.aggregatedMessages) + h.shutdown() + }() + + case "run": + command := newRunCommand(h) + h.incrementActiveReaders() + go func() { + command.Start(ctx, argc, args) + if h.decrementActiveReaders() == 0 { + h.shutdown() + } + }() - wg.Wait() -} + case "ack", ".ack": + h.handleAckCommand(argc, args) -func (h *ServerHandler) startReadingFile(mode omode.Mode, path, globID, regex string) { - defer h.stopReadingFile(path) - logger.Info(h.user, "Start reading file", path, globID) - - var reader fs.FileReader - switch mode { - case omode.TailClient: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) - case omode.GrepClient: - fallthrough - case omode.CatClient: - reader = fs.NewCatFile(path, globID, h.serverMessages, h.catLimiter) default: - reader = fs.NewTailFile(path, globID, h.serverMessages, h.tailLimiter) + h.sendServerMessage(logger.Error(h.user, "Received unknown command", argc, args)) } +} - h.fileReadersMtx.Lock() - h.fileReaders = append(h.fileReaders, reader) - h.fileReadersMtx.Unlock() - - lines := h.lines - // Plugin mappreduce engine - if h.aggregate != nil { - lines = h.aggregate.Lines +func (h *ServerHandler) handleAckCommand(argc int, args []string) { + if argc < 3 { + h.sendServerMessage(logger.Warn(h.user, commandParseWarning, args, argc)) + return } - - for { - if err := reader.Start(lines, regex); err != nil { - logger.Error(h.user, path, globID, err) - } - - select { - case <-h.stop: - return - default: - if !reader.Retry() { - return - } - } - - time.Sleep(time.Second * 2) - logger.Info(path, globID, "Reading file again") + if args[1] == "close" && args[2] == "connection" { + close(h.ackCloseReceived) } } -func (h *ServerHandler) stopReadingFile(path string) { - logger.Info(h.user, "Stop reading file", path) +func (h *ServerHandler) send(ch chan<- string, message string) { + select { + case ch <- message: + case <-h.ctx.Done(): + } +} - h.fileReadersMtx.Lock() - defer h.fileReadersMtx.Unlock() +func (h *ServerHandler) sendServerMessage(message string) { + h.send(h.serverMessageC(), message) +} - path = filepath.Clean(path) - var fileReaders []fs.FileReader +func (h *ServerHandler) serverMessageC() chan<- string { + return h.serverMessages +} - for _, reader := range h.fileReaders { - if reader.FilePath() == path { - reader.Stop() - continue - } - fileReaders = append(fileReaders, reader) - } +func (h *ServerHandler) flush() { + logger.Debug(h.user, "flush()") - if len(fileReaders) == len(h.fileReaders) { - logger.Warn(h.user, "Didn't read file path", path) - return + if h.aggregate != nil { + h.aggregate.Flush() } - h.fileReaders = fileReaders - - if len(fileReaders) == 0 { - if h.aggregate != nil { - h.aggregate.Serialize() - } - h.allLinesSent() + unsentMessages := func() int { + return len(h.lines) + len(h.serverMessages) + len(h.aggregatedMessages) } -} - -func (h *ServerHandler) numUnsentMessages() int { - return len(h.lines) + len(h.serverMessages) + len(h.hiddenMessages) + len(h.aggregatedMessages) -} - -func (h *ServerHandler) allLinesSent() { - defer h.internalClose() for i := 0; i < 3; i++ { - if h.numUnsentMessages() == 0 { + if unsentMessages() == 0 { logger.Debug(h.user, "All lines sent") return } @@ -351,142 +301,43 @@ func (h *ServerHandler) allLinesSent() { time.Sleep(time.Second) } - logger.Warn(h.user, "Some lines remain unsent", h.numUnsentMessages()) + logger.Warn(h.user, "Some lines remain unsent", unsentMessages()) } -// Handler decides to shutdown the connection, not the server itself. -func (h *ServerHandler) internalClose() { - select { - case h.hiddenMessages <- "syn close connection": - case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - return - } +func (h *ServerHandler) shutdown() { + logger.Debug(h.user, "shutdown()") + h.flush() + + go func() { + select { + case h.serverMessageC() <- ".syn close connection": + case <-h.ctx.Done(): + } + }() select { - case <-h.Wait(): + case <-h.ackCloseReceived: case <-time.After(time.Second * 5): - logger.Debug(h.user, "Not waiting for ack close connection") - close(h.stopTimeout) - } -} - -func (h *ServerHandler) handleCommand(commandStr string) { - logger.Info(h.user, commandStr) - - args := strings.Split(commandStr, " ") - argc := len(args) - - logger.Debug(h.user, "Received command", commandStr, argc, args) - - if h.user.Name == config.ControlUser { - h.handleControlCommand(argc, args) - return + logger.Debug(h.user, "Shutdown timeout reached, enforcing shutdown") + case <-h.ctx.Done(): } - h.handleUserCommand(argc, args) -} - -// Special (restricted) set of commands for anonymous ControlUser access. -func (h *ServerHandler) handleControlCommand(argc int, args []string) { - switch args[0] { - case "ping": - h.send(h.hiddenMessages, "pong") - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Receiving debug command", argc, args)) - default: - logger.Warn(h.user, "Received unknown command", argc, args) - } -} - -// Commands for authed users. -func (h *ServerHandler) handleUserCommand(argc int, args []string) { - switch args[0] { - case "grep": - fallthrough - case "cat": - h.handleReadCommand(argc, args, omode.CatClient) - case "tail": - h.handleReadCommand(argc, args, omode.TailClient) - case "map": - h.handleMapCommand(argc, args) - case "ack": - h.handleAckCommand(argc, args) - case "ping": - h.send(h.hiddenMessages, "pong") - case "version": - h.send(h.serverMessages, fmt.Sprintf("Server version is "+version.String())) - case "debug": - h.send(h.serverMessages, logger.Debug(h.user, "Received debug command", argc, args)) + select { + case h.done <- struct{}{}: default: - h.send(h.serverMessages, logger.Warn(h.user, "Received unknown command", argc, args)) } } -func (h *ServerHandler) handleReadCommand(argc int, args []string, mode omode.Mode) { - regex := "." - if argc >= 4 { - regex = args[3] - } - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - go h.processFileGlob(mode, args[1], regex) +func (h *ServerHandler) incrementActiveReaders() { + // TODO: Use atomic counter variable instead, so we can get rid of the mutex + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders++ } +func (h *ServerHandler) decrementActiveReaders() int { + h.mutex.Lock() + defer h.mutex.Unlock() + h.activeReaders-- -func (h *ServerHandler) handleMapCommand(argc int, args []string) { - if argc < 2 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - - queryStr := strings.Join(args[1:], " ") - logger.Info(h.user, "Creating new mapr aggregator", queryStr) - aggregate, err := server.NewAggregate(h.aggregatedMessages, queryStr) - - if err != nil { - h.send(h.serverMessages, logger.Error(h.user, err)) - return - } - - h.aggregate = aggregate -} - -func (h *ServerHandler) handleAckCommand(argc int, args []string) { - if argc < 3 { - h.send(h.serverMessages, logger.Warn(h.user, commandParseWarning, args, argc)) - return - } - if args[1] == "close" && args[2] == "connection" { - close(h.ackStopReceived) - } -} - -func (h *ServerHandler) send(ch chan<- string, message string) { - select { - case ch <- message: - case <-h.stop: - } -} - -// Wait (block) until server handler is closed or a timeout has exceeded. -func (h *ServerHandler) Wait() <-chan struct{} { - wait := make(chan struct{}) - - go func() { - select { - case <-h.ackStopReceived: - logger.Debug(h.user, "Closing wait channel due to ACK stop received") - close(wait) - case <-h.stopTimeout: - logger.Debug(h.user, "Closing wait channel due to wait timeout") - close(wait) - case <-h.stop: - logger.Debug(h.user, "Closing wait channel due to stop") - } - }() - - return wait + return h.activeReaders } diff --git a/internal/server/server.go b/internal/server/server.go index 27a98f5..42eb74c 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,13 +1,14 @@ package server import ( + "context" "errors" "fmt" "io" "net" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/server/handlers" "github.com/mimecast/dtail/internal/ssh/server" user "github.com/mimecast/dtail/internal/user/server" @@ -26,8 +27,6 @@ type Server struct { catLimiterCh chan struct{} // To control the max amount of concurrent tails tailLimiterCh chan struct{} - // Ask to shutdown the server - stop chan struct{} } // New returns a new server. @@ -38,7 +37,6 @@ func New() *Server { sshServerConfig: &gossh.ServerConfig{}, catLimiterCh: make(chan struct{}, config.Server.MaxConcurrentCats), tailLimiterCh: make(chan struct{}, config.Server.MaxConcurrentTails), - stop: make(chan struct{}), } s.sshServerConfig.PasswordCallback = s.controlUserCallback @@ -54,7 +52,7 @@ func New() *Server { } // Start the server. -func (s *Server) Start() int { +func (s *Server) Start(ctx context.Context) int { logger.Info("Starting server") bindAt := fmt.Sprintf("%s:%d", config.Server.SSHBindAddress, config.Common.SSHPort) @@ -64,7 +62,7 @@ func (s *Server) Start() int { logger.FatalExit("Failed to open listening TCP socket", err) } - go s.stats.periodicLogServerStats(s.stop) + go s.stats.periodicLogServerStats(ctx) for { conn, err := listener.Accept() // Blocking @@ -79,11 +77,11 @@ func (s *Server) Start() int { continue } - go s.handleConnection(conn) + go s.handleConnection(ctx, conn) } } -func (s *Server) handleConnection(conn net.Conn) { +func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { logger.Info("Handling connection") sshConn, chans, reqs, err := gossh.NewServerConn(conn, s.sshServerConfig) @@ -96,11 +94,11 @@ func (s *Server) handleConnection(conn net.Conn) { go gossh.DiscardRequests(reqs) for newChannel := range chans { - go s.handleChannel(sshConn, newChannel) + go s.handleChannel(ctx, sshConn, newChannel) } } -func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) { +func (s *Server) handleChannel(ctx context.Context, sshConn gossh.Conn, newChannel gossh.NewChannel) { user := user.New(sshConn.User(), sshConn.RemoteAddr().String()) logger.Info(user, "Invoking channel handler") @@ -117,13 +115,13 @@ func (s *Server) handleChannel(sshConn gossh.Conn, newChannel gossh.NewChannel) return } - if err := s.handleRequests(sshConn, requests, channel, user); err != nil { + if err := s.handleRequests(ctx, sshConn, requests, channel, user); err != nil { logger.Error(user, err) sshConn.Close() } } -func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { +func (s *Server) handleRequests(ctx context.Context, sshConn gossh.Conn, in <-chan *gossh.Request, channel gossh.Channel, user *user.User) error { logger.Info(user, "Invoking request handler") for req := range in { @@ -132,50 +130,50 @@ func (s *Server) handleRequests(sshConn gossh.Conn, in <-chan *gossh.Request, ch switch req.Type { case "shell": + handlerCtx, cancel := context.WithCancel(ctx) + var handler handlers.Handler + var done <-chan struct{} + switch user.Name { case config.ControlUser: - handler = handlers.NewControlHandler(user) + handler, done = handlers.NewControlHandler(handlerCtx, user) default: - handler = handlers.NewServerHandler(user, s.catLimiterCh, s.tailLimiterCh) + handler, done = handlers.NewServerHandler(handlerCtx, user, s.catLimiterCh, s.tailLimiterCh) } - // Bi-directionally connect SSH stream to SSH handler - brokenPipe1 := make(chan struct{}) go func() { - defer close(brokenPipe1) + // Handler finished work, cancel all remaining routines + defer cancel() + <-done + }() + + go func() { + // Broken pipe, cancel + defer cancel() + io.Copy(channel, handler) }() - brokenPipe2 := make(chan struct{}) go func() { - defer close(brokenPipe2) + // Broken pipe, cancel + defer cancel() + io.Copy(handler, channel) }() - // Ensure to close all fd's and stop all goroutines once ssh connection terminated go func() { - defer s.stats.decrementConnections() - defer handler.Close() + defer cancel() if err := sshConn.Wait(); err != nil && err != io.EOF { logger.Error(user, err) } + s.stats.decrementConnections() logger.Info(user, "Good bye Mister!") }() - // Close the underlying ssh socket when server shuts down go func() { - select { - case <-s.stop: - logger.Debug(user, "Server initiating shutdown on handler") - case <-handler.Wait(): - logger.Debug(user, "Handler initiating shutdown by its own") - case <-brokenPipe1: - logger.Debug(user, "Broken pipe1") - case <-brokenPipe2: - logger.Debug(user, "Broken pipe2") - } + <-handlerCtx.Done() sshConn.Close() logger.Info(user, "Closed SSH connection") }() @@ -204,9 +202,3 @@ func (*Server) controlUserCallback(c gossh.ConnMetadata, authPayload []byte) (*g return nil, fmt.Errorf("Not authorized") } - -// Stop the server. -func (s *Server) Stop() { - close(s.stop) - s.stats.waitForConnections() -} diff --git a/internal/server/stats.go b/internal/server/stats.go index beb1885..4d661f7 100644 --- a/internal/server/stats.go +++ b/internal/server/stats.go @@ -1,12 +1,14 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "context" "fmt" "runtime" "sync" "time" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/logger" ) // Used to collect and display various server stats. @@ -65,12 +67,12 @@ func (s *stats) serverLimitExceeded() error { return nil } -func (s *stats) periodicLogServerStats(stop <-chan struct{}) { +func (s *stats) periodicLogServerStats(ctx context.Context) { for { select { case <-time.NewTimer(time.Second * 10).C: s.logServerStats() - case <-stop: + case <-ctx.Done(): return } } diff --git a/internal/ssh/client/authmethods.go b/internal/ssh/client/authmethods.go index 3392eb1..967866f 100644 --- a/internal/ssh/client/authmethods.go +++ b/internal/ssh/client/authmethods.go @@ -2,7 +2,7 @@ package client import ( "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh" "os" diff --git a/internal/ssh/client/hostkeycallback.go b/internal/ssh/client/hostkeycallback.go index 4023e59..7ae2396 100644 --- a/internal/ssh/client/hostkeycallback.go +++ b/internal/ssh/client/hostkeycallback.go @@ -2,8 +2,7 @@ package client import ( "bufio" - "github.com/mimecast/dtail/internal/logger" - "github.com/mimecast/dtail/internal/prompt" + "context" "fmt" "net" "os" @@ -11,6 +10,9 @@ import ( "sync" "time" + "github.com/mimecast/dtail/internal/io/logger" + "github.com/mimecast/dtail/internal/prompt" + "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" ) @@ -116,7 +118,7 @@ func (h *HostKeyCallback) Wrap() ssh.HostKeyCallback { // PromptAddHosts prompts a question to the user whether unknown hosts should // be added to the known hosts or not. -func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { +func (h *HostKeyCallback) PromptAddHosts(ctx context.Context) { var hosts []unknownHost for { @@ -135,7 +137,7 @@ func (h *HostKeyCallback) PromptAddHosts(stop <-chan struct{}) { h.promptAddHosts(hosts) hosts = []unknownHost{} } - case <-stop: + case <-ctx.Done(): logger.Debug("Stopping goroutine prompting new hosts...") return } diff --git a/internal/ssh/server/hostkey.go b/internal/ssh/server/hostkey.go index 7baa4aa..07790ad 100644 --- a/internal/ssh/server/hostkey.go +++ b/internal/ssh/server/hostkey.go @@ -2,7 +2,7 @@ package server import ( "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" "github.com/mimecast/dtail/internal/ssh" "io/ioutil" "os" diff --git a/internal/ssh/server/publickeycallback.go b/internal/ssh/server/publickeycallback.go index c6929d7..757def7 100644 --- a/internal/ssh/server/publickeycallback.go +++ b/internal/ssh/server/publickeycallback.go @@ -7,7 +7,7 @@ import ( osUser "os/user" "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/logger" + "github.com/mimecast/dtail/internal/io/logger" user "github.com/mimecast/dtail/internal/user/server" gossh "golang.org/x/crypto/ssh" diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 77cc341..3a2e416 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -4,9 +4,9 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" - "github.com/mimecast/dtail/internal/logger" "encoding/pem" "fmt" + "github.com/mimecast/dtail/internal/io/logger" "io/ioutil" "net" "os" diff --git a/internal/user/name.go b/internal/user/name.go index 5171ec7..28ab0a4 100644 --- a/internal/user/name.go +++ b/internal/user/name.go @@ -2,10 +2,10 @@ package user import ( "os/user" - ) +) - -func Name() string { +// NoRootCheck verifies that the DTail run user is not with UID or GID 0. +func NoRootCheck() { user, err := user.Current() if err != nil { panic(err) @@ -18,7 +18,14 @@ func Name() string { if user.Gid == "0" { panic("Not allowed to run as GID 0") } +} + +// Name of the current run user. +func Name() string { + user, err := user.Current() + if err != nil { + panic(err) + } return user.Username } - diff --git a/internal/user/server/user.go b/internal/user/server/user.go index fad38d8..271a4ac 100644 --- a/internal/user/server/user.go +++ b/internal/user/server/user.go @@ -1,14 +1,15 @@ package server import ( - "github.com/mimecast/dtail/internal/config" - "github.com/mimecast/dtail/internal/fs/permissions" - "github.com/mimecast/dtail/internal/logger" "fmt" "os" "path/filepath" "regexp" "strings" + + "github.com/mimecast/dtail/internal/config" + "github.com/mimecast/dtail/internal/io/fs/permissions" + "github.com/mimecast/dtail/internal/io/logger" ) const maxLinkDepth int = 100 @@ -37,26 +38,28 @@ func (u *User) String() string { } // HasFilePermission is used to determine whether user is alowed to read a file. -func (u *User) HasFilePermission(filePath string) (hasPermission bool) { +func (u *User) HasFilePermission(filePath, permissionType string) (hasPermission bool) { + logger.Debug(u, filePath, permissionType, "Checking config permissions") + cleanPath, err := filepath.EvalSymlinks(filePath) if err != nil { - logger.Error(u, filePath, "Unable to evaluate symlinks", err) + logger.Error(u, filePath, permissionType, "Unable to evaluate symlinks", err) hasPermission = false return } cleanPath, err = filepath.Abs(cleanPath) if err != nil { - logger.Error(u, cleanPath, "Unable to make file path absolute", err) + logger.Error(u, cleanPath, permissionType, "Unable to make file path absolute", err) hasPermission = false return } if cleanPath != filePath { - logger.Info(u, filePath, cleanPath, "Calculated new clean path from original file path (possibly symlink)") + logger.Info(u, filePath, cleanPath, permissionType, "Calculated new clean path from original file path (possibly symlink)") } - hasPermission, err = u.hasFilePermission(cleanPath) + hasPermission, err = u.hasFilePermission(cleanPath, permissionType) if err != nil { logger.Warn(u, cleanPath, err) } @@ -64,12 +67,12 @@ func (u *User) HasFilePermission(filePath string) (hasPermission bool) { return } -func (u *User) hasFilePermission(cleanPath string) (bool, error) { +func (u *User) hasFilePermission(cleanPath, permissionType string) (bool, error) { // First check file system Linux/UNIX permission. if _, err := permissions.ToRead(u.Name, cleanPath); err != nil { - return false, fmt.Errorf("User without OS file system permissions to read file: '%v'", err) + return false, fmt.Errorf("User without OS file system permissions to read path: '%v'", err) } - logger.Info(u, cleanPath, "User has OS file system permissions to read file") + logger.Info(u, cleanPath, permissionType, "User with OS file system permissions to path") // If file system permission is given, also check permissions // as configured in DTail config file. @@ -84,7 +87,7 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) { var hasPermission bool var err error - if hasPermission, err = u.iteratePaths(cleanPath); err != nil { + if hasPermission, err = u.iteratePaths(cleanPath, permissionType); err != nil { return false, err } @@ -101,17 +104,28 @@ func (u *User) hasFilePermission(cleanPath string) (bool, error) { return hasPermission, nil } -func (u *User) iteratePaths(cleanPath string) (bool, error) { +func (u *User) iteratePaths(cleanPath, permissionType string) (bool, error) { for _, permission := range u.permissions { + typeStr := "readfiles" // Assume ReadFiles by default. + var regexStr string var negate bool + splitted := strings.Split(permission, ":") + if len(splitted) > 1 { + typeStr = splitted[0] + permission = strings.Join(splitted[1:], ":") + } + + if typeStr != permissionType { + continue + } + + regexStr = permission if strings.HasPrefix(permission, "!") { regexStr = permission[1:] negate = true } - regexStr = permission - negate = false re, err := regexp.Compile(regexStr) if err != nil { diff --git a/internal/version/version.go b/internal/version/version.go index 3a4a5dc..3c057df 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -7,18 +7,20 @@ import ( "github.com/mimecast/dtail/internal/color" ) -// Name of DTail. -const Name = "DTail" - -// Version of DTail. -const Version = "1.1.0" - -// Additional information. -const Additional = "" +const ( + // Name of DTail. + Name string = "DTail" + // Version of DTail. + Version string = "2.0.0" + // Additional information for DTail + Additional string = "" + // ProtocolCompat -ibility version. + ProtocolCompat string = "2" +) // String representation of the DTail version. func String() string { - return fmt.Sprintf("%s v%v %s", Name, Version, Additional) + return fmt.Sprintf("%s %v Protocol %s %s", Name, Version, ProtocolCompat, Additional) } // PaintedString is a prettier string representation of the DTail version. @@ -30,7 +32,7 @@ func PaintedString() string { version := color.Paint(color.Blue, Version) descr := color.Paint(color.Green, Additional) - return fmt.Sprintf("%s %v %s", name, version, descr) + return fmt.Sprintf("%s %v Protocol %s %s", name, version, ProtocolCompat, descr) } // PrintAndExit prints the program version and exists. -- cgit v1.2.3