diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/clients/baseclient.go | 2 | ||||
| -rw-r--r-- | internal/clients/maprclient.go | 39 | ||||
| -rw-r--r-- | internal/clients/runclient.go | 2 | ||||
| -rw-r--r-- | internal/clients/tailclient.go | 1 | ||||
| -rw-r--r-- | internal/io/fs/readfile.go | 12 | ||||
| -rw-r--r-- | internal/mapr/server/aggregate.go | 59 | ||||
| -rw-r--r-- | internal/server/handlers/serverhandler.go | 70 | ||||
| -rw-r--r-- | internal/version/version.go | 2 |
8 files changed, 127 insertions, 60 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go index 75da187..10a5559 100644 --- a/internal/clients/baseclient.go +++ b/internal/clients/baseclient.go @@ -123,7 +123,7 @@ func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) { // Put it back on the channel active <- struct{}{} - if c.Mode == omode.TailClient { + if c.Mode == omode.TailClient && c.retry { <-ctx.Done() } diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go index b581844..32340b3 100644 --- a/internal/clients/maprclient.go +++ b/internal/clients/maprclient.go @@ -24,7 +24,7 @@ type MaprClient struct { // The query object (constructed from queryStr) query *mapr.Query // Additative result or new result every run? - additative bool + cumulative bool } // NewMaprClient returns a new mapreduce client. @@ -33,23 +33,28 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) { return nil, errors.New("No mapreduce query specified, use '-query' flag") } + query, err := mapr.NewQuery(queryStr) + if err != nil { + logger.FatalExit(queryStr, "Can't parse mapr query", err) + } + + // Don't retry connection if in tail mode and no outfile specified. + retry := args.Mode == omode.TailClient && query.Outfile == "" + + // Result is comulative if we are in MapClient mode or with outfile + cumulative := args.Mode == omode.MapClient || query.Outfile != "" + c := MaprClient{ baseClient: baseClient{ Args: args, throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()), - retry: args.Mode == omode.TailClient, + retry: retry, }, + query: query, queryStr: queryStr, - additative: args.Mode == omode.MapClient, + cumulative: cumulative, } - query, err := mapr.NewQuery(c.queryStr) - if err != nil { - logger.FatalExit(c.queryStr, "Can't parse mapr query", err) - } - - c.query = query - switch c.query.Table { case "*": c.Regex = fmt.Sprintf("\\|MAPREDUCE:\\|") @@ -73,7 +78,7 @@ func (c *MaprClient) Start(ctx context.Context) (status int) { } status = c.baseClient.Start(ctx) - if c.additative { + if c.cumulative { c.recievedFinalResult() } @@ -87,12 +92,16 @@ func (c MaprClient) makeHandler(server string) handlers.Handler { func (c MaprClient) makeCommands() (commands []string) { commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery)) - modeStr := "tail" - if c.additative { - modeStr = "cat" + modeStr := "cat" + if c.Mode == omode.TailClient { + modeStr = "tail" } for _, file := range strings.Split(c.What, ",") { + if c.Timeout > 0 { + commands = append(commands, fmt.Sprintf("timeout %d %s %s regex %s", c.Timeout, modeStr, file, c.Regex)) + continue + } commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex)) } @@ -133,7 +142,7 @@ func (c *MaprClient) printResults() { var err error var numLines int - if c.additative { + if c.cumulative { result, numLines, err = c.globalGroup.Result(c.query) } else { result, numLines, err = c.globalGroup.SwapOut().Result(c.query) diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go index 543df15..9f8e478 100644 --- a/internal/clients/runclient.go +++ b/internal/clients/runclient.go @@ -53,8 +53,6 @@ func (c RunClient) makeCommands() (commands []string) { } commands = append(commands, fmt.Sprintf("run%s %s", c.options(), c.What)) - logger.Debug(commands) - return } diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go index 4d81fd5..15e77cc 100644 --- a/internal/clients/tailclient.go +++ b/internal/clients/tailclient.go @@ -38,5 +38,6 @@ func (c TailClient) makeCommands() (commands []string) { for _, file := range strings.Split(c.What, ",") { commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex)) } + return } diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go index 321432e..cb16ec1 100644 --- a/internal/io/fs/readfile.go +++ b/internal/io/fs/readfile.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "context" "errors" + "fmt" "io" "os" "regexp" @@ -39,6 +40,16 @@ type readFile struct { limiter chan struct{} } +// String returns the string representation of the readFile +func (f readFile) String() string { + return fmt.Sprintf("readFile(filePath:%s,globID:%s,retry:%v,canSkipLines:%v,seekEOF:%v)", + f.filePath, + f.globID, + f.retry, + f.canSkipLines, + f.seekEOF) +} + // FilePath returns the full file path. func (f readFile) FilePath() string { return f.filePath @@ -51,6 +62,7 @@ func (f readFile) Retry() bool { // Start tailing a log file. func (f readFile) Start(ctx context.Context, lines chan<- line.Line, regex string) error { + logger.Debug("readFile", f) defer func() { select { case <-f.limiter: diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go index 922dcbd..fade689 100644 --- a/internal/mapr/server/aggregate.go +++ b/internal/mapr/server/aggregate.go @@ -27,6 +27,8 @@ type Aggregate struct { query *mapr.Query // The mapr log format parser parser *logformat.Parser + cancel context.CancelFunc + ctx context.Context } // NewAggregate return a new server side aggregator. @@ -48,6 +50,8 @@ func NewAggregate(queryStr string) (*Aggregate, error) { logger.FatalExit("Could not create mapr log format parser", err) } + ctx, cancel := context.WithCancel(context.Background()) + a := Aggregate{ Lines: make(chan line.Line, 100), serialize: make(chan struct{}), @@ -55,18 +59,27 @@ func NewAggregate(queryStr string) (*Aggregate, error) { hostname: s[0], query: query, parser: logParser, + ctx: ctx, + cancel: cancel, } return &a, nil } -// Start an aggregation run. +// Start an aggregation. func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) { + defer a.cancel() + fieldsCh := a.linesToFields(ctx) go a.fieldsToMaprLines(ctx, fieldsCh, maprLines) a.periodicAggregateTimer(ctx) } +// Cancel the aggregation. +func (a *Aggregate) Cancel() { + a.cancel() +} + func (a *Aggregate) periodicAggregateTimer(ctx context.Context) { for { select { @@ -74,6 +87,8 @@ func (a *Aggregate) periodicAggregateTimer(ctx context.Context) { a.Serialize(ctx) case <-ctx.Done(): return + case <-a.ctx.Done(): + return } } } @@ -108,6 +123,8 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string } case <-ctx.Done(): return + case <-a.ctx.Done(): + return } } }() @@ -118,30 +135,36 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) { group := mapr.NewGroupSet() + serialize := func() { + logger.Info("Serializing mapreduce result") + group.Serialize(ctx, maprLines) + group = mapr.NewGroupSet() + logger.Info("Done serializing mapreduce result") + } + for { select { case fields, ok := <-fieldsCh: if !ok { - logger.Info("Serializing mapreduce result (final)") - group.Serialize(ctx, maprLines) - group = mapr.NewGroupSet() - logger.Info("Done serializing mapreduce result (final)") + serialize() return } a.aggregate(group, fields) case <-a.serialize: - logger.Info("Serializing mapreduce result") - group.Serialize(ctx, maprLines) - group = mapr.NewGroupSet() - logger.Info("Done serializing mapreduce result") + serialize() case <-a.flush: logger.Info("Flushing mapreduce result") - group.Serialize(ctx, maprLines) - group = mapr.NewGroupSet() + serialize() a.flush <- struct{}{} logger.Info("Done flushing mapreduce result") case <-ctx.Done(): return + case <-a.ctx.Done(): + logger.Info("Flushing mapreduce result") + serialize() + a.flush <- struct{}{} + logger.Info("Done flushing mapreduce result") + return } } } @@ -190,6 +213,16 @@ func (a *Aggregate) Serialize(ctx context.Context) { // Flush all data. func (a *Aggregate) Flush() { - a.flush <- struct{}{} - <-a.flush + select { + case <-a.ctx.Done(): + return + case a.flush <- struct{}{}: + case <-time.After(time.Minute): + return + } + + select { + case <-a.flush: + case <-time.After(time.Minute): + } } diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go index 739696c..939388c 100644 --- a/internal/server/handlers/serverhandler.go +++ b/internal/server/handlers/serverhandler.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/mimecast/dtail/internal/config" @@ -30,7 +31,6 @@ const ( // the Bi-directional communication between SSH client and server. // This handler implements the handler of the SSH server. type ServerHandler struct { - mutex *sync.Mutex lines chan line.Line regex string aggregate *server.Aggregate @@ -47,7 +47,8 @@ type ServerHandler struct { serverCtx context.Context handlerCtx context.Context done chan struct{} - activeCommands int + activeCommands int32 + activeReaders int32 background background.Background } @@ -57,7 +58,6 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca serverCtx: serverCtx, handlerCtx: handlerCtx, done: make(chan struct{}), - mutex: &sync.Mutex{}, lines: make(chan line.Line, 100), serverMessages: make(chan string, 10), aggregatedMessages: make(chan string, 10), @@ -170,10 +170,11 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) { } if timeout > 0 { - logger.Debug("Command with timeout context", argc, args, timeout) + logger.Info(h.user, "Command with timeout context", argc, args, timeout) commandCtx, cancel := context.WithTimeout(ctx, timeout) go func() { <-commandCtx.Done() + logger.Info(h.user, "Command timed out, canceling it", args, args, timeout) cancel() }() h.handleUserCommand(commandCtx, argc, args, timeout) @@ -241,11 +242,19 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] logger.Debug(h.user, "handleUserCommand", argc, args) h.incrementActiveCommands() - finished := func() { + commandFinished := func() { if h.decrementActiveCommands() == 0 { h.shutdown() } } + readerFinished := func() { + if h.decrementActiveReaders() == 0 { + if h.aggregate == nil { + return + } + h.aggregate.Cancel() + } + } splitted := strings.Split(args[0], ":") commandName := splitted[0] @@ -253,24 +262,27 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] options, err := readOptions(splitted[1:]) if err != nil { h.sendServerMessage(logger.Error(h.user, err)) - finished() + commandFinished() return } switch commandName { case "grep", "cat": command := newReadCommand(h, omode.CatClient) - h.incrementActiveCommands() go func() { + h.incrementActiveReaders() command.Start(ctx, argc, args) - finished() + readerFinished() + commandFinished() }() case "tail": command := newReadCommand(h, omode.TailClient) go func() { + h.incrementActiveReaders() command.Start(ctx, argc, args) - finished() + readerFinished() + commandFinished() }() case "map": @@ -278,14 +290,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err != nil { h.sendServerMessage(err.Error()) logger.Error(h.user, err) - finished() + commandFinished() return } h.aggregate = aggregate go func() { command.Start(ctx, h.aggregatedMessages) - finished() + commandFinished() }() case "run": @@ -301,7 +313,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] } else { h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName)) } - finished() + commandFinished() return } @@ -313,7 +325,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] count++ } h.sendServerMessage(fmt.Sprintf("Found %d jobs", count)) - finished() + commandFinished() return } @@ -339,7 +351,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil { h.sendServerMessage(logger.Error(h.user, err, jobName, args)) - finished() + commandFinished() return } ctx = commandCtx @@ -347,7 +359,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil { h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err)) - finished() + commandFinished() return } @@ -360,21 +372,21 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args [] if background { h.sendServerMessage(logger.Info(h.user, jobName, "job started in background")) - finished() + commandFinished() return } // Command run in foreground, wait for it to complete before finishing the connection. wg.Wait() - finished() + commandFinished() case "ack", ".ack": h.handleAckCommand(argc, args) - finished() + commandFinished() default: h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options)) - finished() + commandFinished() } } @@ -450,19 +462,21 @@ func (h *ServerHandler) shutdown() { } func (h *ServerHandler) incrementActiveCommands() { - // TODO: Use atomic counter variable instead, so we can get rid of the mutex - h.mutex.Lock() - defer h.mutex.Unlock() + atomic.AddInt32(&h.activeCommands, 1) +} - h.activeCommands++ +func (h *ServerHandler) decrementActiveCommands() int32 { + atomic.AddInt32(&h.activeCommands, -1) + return atomic.LoadInt32(&h.activeCommands) } -func (h *ServerHandler) decrementActiveCommands() int { - h.mutex.Lock() - defer h.mutex.Unlock() +func (h *ServerHandler) incrementActiveReaders() { + atomic.AddInt32(&h.activeReaders, 1) +} - h.activeCommands-- - return h.activeCommands +func (h *ServerHandler) decrementActiveReaders() int32 { + atomic.AddInt32(&h.activeReaders, -1) + return atomic.LoadInt32(&h.activeReaders) } func readOptions(opts []string) (map[string]string, error) { diff --git a/internal/version/version.go b/internal/version/version.go index 6115453..290eb2f 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -13,7 +13,7 @@ const ( // Version of DTail. Version string = "2.1.1" // Additional information for DTail - Additional string = "develop2" + Additional string = "develop3" // ProtocolCompat -ibility version. ProtocolCompat string = "2" ) |
