summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/clients/baseclient.go2
-rw-r--r--internal/clients/maprclient.go39
-rw-r--r--internal/clients/runclient.go2
-rw-r--r--internal/clients/tailclient.go1
-rw-r--r--internal/io/fs/readfile.go12
-rw-r--r--internal/mapr/server/aggregate.go59
-rw-r--r--internal/server/handlers/serverhandler.go70
-rw-r--r--internal/version/version.go2
8 files changed, 127 insertions, 60 deletions
diff --git a/internal/clients/baseclient.go b/internal/clients/baseclient.go
index 75da187..10a5559 100644
--- a/internal/clients/baseclient.go
+++ b/internal/clients/baseclient.go
@@ -123,7 +123,7 @@ func (c *baseClient) waitUntilDone(ctx context.Context, active chan struct{}) {
// Put it back on the channel
active <- struct{}{}
- if c.Mode == omode.TailClient {
+ if c.Mode == omode.TailClient && c.retry {
<-ctx.Done()
}
diff --git a/internal/clients/maprclient.go b/internal/clients/maprclient.go
index b581844..32340b3 100644
--- a/internal/clients/maprclient.go
+++ b/internal/clients/maprclient.go
@@ -24,7 +24,7 @@ type MaprClient struct {
// The query object (constructed from queryStr)
query *mapr.Query
// Additative result or new result every run?
- additative bool
+ cumulative bool
}
// NewMaprClient returns a new mapreduce client.
@@ -33,23 +33,28 @@ func NewMaprClient(args Args, queryStr string) (*MaprClient, error) {
return nil, errors.New("No mapreduce query specified, use '-query' flag")
}
+ query, err := mapr.NewQuery(queryStr)
+ if err != nil {
+ logger.FatalExit(queryStr, "Can't parse mapr query", err)
+ }
+
+ // Don't retry connection if in tail mode and no outfile specified.
+ retry := args.Mode == omode.TailClient && query.Outfile == ""
+
+ // Result is comulative if we are in MapClient mode or with outfile
+ cumulative := args.Mode == omode.MapClient || query.Outfile != ""
+
c := MaprClient{
baseClient: baseClient{
Args: args,
throttleCh: make(chan struct{}, args.ConnectionsPerCPU*runtime.NumCPU()),
- retry: args.Mode == omode.TailClient,
+ retry: retry,
},
+ query: query,
queryStr: queryStr,
- additative: args.Mode == omode.MapClient,
+ cumulative: cumulative,
}
- query, err := mapr.NewQuery(c.queryStr)
- if err != nil {
- logger.FatalExit(c.queryStr, "Can't parse mapr query", err)
- }
-
- c.query = query
-
switch c.query.Table {
case "*":
c.Regex = fmt.Sprintf("\\|MAPREDUCE:\\|")
@@ -73,7 +78,7 @@ func (c *MaprClient) Start(ctx context.Context) (status int) {
}
status = c.baseClient.Start(ctx)
- if c.additative {
+ if c.cumulative {
c.recievedFinalResult()
}
@@ -87,12 +92,16 @@ func (c MaprClient) makeHandler(server string) handlers.Handler {
func (c MaprClient) makeCommands() (commands []string) {
commands = append(commands, fmt.Sprintf("map %s", c.query.RawQuery))
- modeStr := "tail"
- if c.additative {
- modeStr = "cat"
+ modeStr := "cat"
+ if c.Mode == omode.TailClient {
+ modeStr = "tail"
}
for _, file := range strings.Split(c.What, ",") {
+ if c.Timeout > 0 {
+ commands = append(commands, fmt.Sprintf("timeout %d %s %s regex %s", c.Timeout, modeStr, file, c.Regex))
+ continue
+ }
commands = append(commands, fmt.Sprintf("%s %s regex %s", modeStr, file, c.Regex))
}
@@ -133,7 +142,7 @@ func (c *MaprClient) printResults() {
var err error
var numLines int
- if c.additative {
+ if c.cumulative {
result, numLines, err = c.globalGroup.Result(c.query)
} else {
result, numLines, err = c.globalGroup.SwapOut().Result(c.query)
diff --git a/internal/clients/runclient.go b/internal/clients/runclient.go
index 543df15..9f8e478 100644
--- a/internal/clients/runclient.go
+++ b/internal/clients/runclient.go
@@ -53,8 +53,6 @@ func (c RunClient) makeCommands() (commands []string) {
}
commands = append(commands, fmt.Sprintf("run%s %s", c.options(), c.What))
- logger.Debug(commands)
-
return
}
diff --git a/internal/clients/tailclient.go b/internal/clients/tailclient.go
index 4d81fd5..15e77cc 100644
--- a/internal/clients/tailclient.go
+++ b/internal/clients/tailclient.go
@@ -38,5 +38,6 @@ func (c TailClient) makeCommands() (commands []string) {
for _, file := range strings.Split(c.What, ",") {
commands = append(commands, fmt.Sprintf("%s %s regex %s", c.Mode.String(), file, c.Regex))
}
+
return
}
diff --git a/internal/io/fs/readfile.go b/internal/io/fs/readfile.go
index 321432e..cb16ec1 100644
--- a/internal/io/fs/readfile.go
+++ b/internal/io/fs/readfile.go
@@ -5,6 +5,7 @@ import (
"compress/gzip"
"context"
"errors"
+ "fmt"
"io"
"os"
"regexp"
@@ -39,6 +40,16 @@ type readFile struct {
limiter chan struct{}
}
+// String returns the string representation of the readFile
+func (f readFile) String() string {
+ return fmt.Sprintf("readFile(filePath:%s,globID:%s,retry:%v,canSkipLines:%v,seekEOF:%v)",
+ f.filePath,
+ f.globID,
+ f.retry,
+ f.canSkipLines,
+ f.seekEOF)
+}
+
// FilePath returns the full file path.
func (f readFile) FilePath() string {
return f.filePath
@@ -51,6 +62,7 @@ func (f readFile) Retry() bool {
// Start tailing a log file.
func (f readFile) Start(ctx context.Context, lines chan<- line.Line, regex string) error {
+ logger.Debug("readFile", f)
defer func() {
select {
case <-f.limiter:
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 922dcbd..fade689 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -27,6 +27,8 @@ type Aggregate struct {
query *mapr.Query
// The mapr log format parser
parser *logformat.Parser
+ cancel context.CancelFunc
+ ctx context.Context
}
// NewAggregate return a new server side aggregator.
@@ -48,6 +50,8 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
logger.FatalExit("Could not create mapr log format parser", err)
}
+ ctx, cancel := context.WithCancel(context.Background())
+
a := Aggregate{
Lines: make(chan line.Line, 100),
serialize: make(chan struct{}),
@@ -55,18 +59,27 @@ func NewAggregate(queryStr string) (*Aggregate, error) {
hostname: s[0],
query: query,
parser: logParser,
+ ctx: ctx,
+ cancel: cancel,
}
return &a, nil
}
-// Start an aggregation run.
+// Start an aggregation.
func (a *Aggregate) Start(ctx context.Context, maprLines chan<- string) {
+ defer a.cancel()
+
fieldsCh := a.linesToFields(ctx)
go a.fieldsToMaprLines(ctx, fieldsCh, maprLines)
a.periodicAggregateTimer(ctx)
}
+// Cancel the aggregation.
+func (a *Aggregate) Cancel() {
+ a.cancel()
+}
+
func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
for {
select {
@@ -74,6 +87,8 @@ func (a *Aggregate) periodicAggregateTimer(ctx context.Context) {
a.Serialize(ctx)
case <-ctx.Done():
return
+ case <-a.ctx.Done():
+ return
}
}
}
@@ -108,6 +123,8 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
}
case <-ctx.Done():
return
+ case <-a.ctx.Done():
+ return
}
}
}()
@@ -118,30 +135,36 @@ func (a *Aggregate) linesToFields(ctx context.Context) <-chan map[string]string
func (a *Aggregate) fieldsToMaprLines(ctx context.Context, fieldsCh <-chan map[string]string, maprLines chan<- string) {
group := mapr.NewGroupSet()
+ serialize := func() {
+ logger.Info("Serializing mapreduce result")
+ group.Serialize(ctx, maprLines)
+ group = mapr.NewGroupSet()
+ logger.Info("Done serializing mapreduce result")
+ }
+
for {
select {
case fields, ok := <-fieldsCh:
if !ok {
- logger.Info("Serializing mapreduce result (final)")
- group.Serialize(ctx, maprLines)
- group = mapr.NewGroupSet()
- logger.Info("Done serializing mapreduce result (final)")
+ serialize()
return
}
a.aggregate(group, fields)
case <-a.serialize:
- logger.Info("Serializing mapreduce result")
- group.Serialize(ctx, maprLines)
- group = mapr.NewGroupSet()
- logger.Info("Done serializing mapreduce result")
+ serialize()
case <-a.flush:
logger.Info("Flushing mapreduce result")
- group.Serialize(ctx, maprLines)
- group = mapr.NewGroupSet()
+ serialize()
a.flush <- struct{}{}
logger.Info("Done flushing mapreduce result")
case <-ctx.Done():
return
+ case <-a.ctx.Done():
+ logger.Info("Flushing mapreduce result")
+ serialize()
+ a.flush <- struct{}{}
+ logger.Info("Done flushing mapreduce result")
+ return
}
}
}
@@ -190,6 +213,16 @@ func (a *Aggregate) Serialize(ctx context.Context) {
// Flush all data.
func (a *Aggregate) Flush() {
- a.flush <- struct{}{}
- <-a.flush
+ select {
+ case <-a.ctx.Done():
+ return
+ case a.flush <- struct{}{}:
+ case <-time.After(time.Minute):
+ return
+ }
+
+ select {
+ case <-a.flush:
+ case <-time.After(time.Minute):
+ }
}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index 739696c..939388c 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/mimecast/dtail/internal/config"
@@ -30,7 +31,6 @@ const (
// the Bi-directional communication between SSH client and server.
// This handler implements the handler of the SSH server.
type ServerHandler struct {
- mutex *sync.Mutex
lines chan line.Line
regex string
aggregate *server.Aggregate
@@ -47,7 +47,8 @@ type ServerHandler struct {
serverCtx context.Context
handlerCtx context.Context
done chan struct{}
- activeCommands int
+ activeCommands int32
+ activeReaders int32
background background.Background
}
@@ -57,7 +58,6 @@ func NewServerHandler(handlerCtx, serverCtx context.Context, user *user.User, ca
serverCtx: serverCtx,
handlerCtx: handlerCtx,
done: make(chan struct{}),
- mutex: &sync.Mutex{},
lines: make(chan line.Line, 100),
serverMessages: make(chan string, 10),
aggregatedMessages: make(chan string, 10),
@@ -170,10 +170,11 @@ func (h *ServerHandler) handleCommand(ctx context.Context, commandStr string) {
}
if timeout > 0 {
- logger.Debug("Command with timeout context", argc, args, timeout)
+ logger.Info(h.user, "Command with timeout context", argc, args, timeout)
commandCtx, cancel := context.WithTimeout(ctx, timeout)
go func() {
<-commandCtx.Done()
+ logger.Info(h.user, "Command timed out, canceling it", args, args, timeout)
cancel()
}()
h.handleUserCommand(commandCtx, argc, args, timeout)
@@ -241,11 +242,19 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
logger.Debug(h.user, "handleUserCommand", argc, args)
h.incrementActiveCommands()
- finished := func() {
+ commandFinished := func() {
if h.decrementActiveCommands() == 0 {
h.shutdown()
}
}
+ readerFinished := func() {
+ if h.decrementActiveReaders() == 0 {
+ if h.aggregate == nil {
+ return
+ }
+ h.aggregate.Cancel()
+ }
+ }
splitted := strings.Split(args[0], ":")
commandName := splitted[0]
@@ -253,24 +262,27 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
options, err := readOptions(splitted[1:])
if err != nil {
h.sendServerMessage(logger.Error(h.user, err))
- finished()
+ commandFinished()
return
}
switch commandName {
case "grep", "cat":
command := newReadCommand(h, omode.CatClient)
- h.incrementActiveCommands()
go func() {
+ h.incrementActiveReaders()
command.Start(ctx, argc, args)
- finished()
+ readerFinished()
+ commandFinished()
}()
case "tail":
command := newReadCommand(h, omode.TailClient)
go func() {
+ h.incrementActiveReaders()
command.Start(ctx, argc, args)
- finished()
+ readerFinished()
+ commandFinished()
}()
case "map":
@@ -278,14 +290,14 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err != nil {
h.sendServerMessage(err.Error())
logger.Error(h.user, err)
- finished()
+ commandFinished()
return
}
h.aggregate = aggregate
go func() {
command.Start(ctx, h.aggregatedMessages)
- finished()
+ commandFinished()
}()
case "run":
@@ -301,7 +313,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
} else {
h.sendServerMessage(logger.Info(h.user, "job cancelled", jobName))
}
- finished()
+ commandFinished()
return
}
@@ -313,7 +325,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
count++
}
h.sendServerMessage(fmt.Sprintf("Found %d jobs", count))
- finished()
+ commandFinished()
return
}
@@ -339,7 +351,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err := h.background.Add(h.user.Name, jobName, cancel, &wg); err != nil {
h.sendServerMessage(logger.Error(h.user, err, jobName, args))
- finished()
+ commandFinished()
return
}
ctx = commandCtx
@@ -347,7 +359,7 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if err := command.StartBackground(ctx, &wg, argc, args, outerArgs); err != nil {
h.sendServerMessage(logger.Error(h.user, "Unable to execute command", argc, args, err))
- finished()
+ commandFinished()
return
}
@@ -360,21 +372,21 @@ func (h *ServerHandler) handleUserCommand(ctx context.Context, argc int, args []
if background {
h.sendServerMessage(logger.Info(h.user, jobName, "job started in background"))
- finished()
+ commandFinished()
return
}
// Command run in foreground, wait for it to complete before finishing the connection.
wg.Wait()
- finished()
+ commandFinished()
case "ack", ".ack":
h.handleAckCommand(argc, args)
- finished()
+ commandFinished()
default:
h.sendServerMessage(logger.Error(h.user, "Received unknown user command", commandName, argc, args, options))
- finished()
+ commandFinished()
}
}
@@ -450,19 +462,21 @@ func (h *ServerHandler) shutdown() {
}
func (h *ServerHandler) incrementActiveCommands() {
- // TODO: Use atomic counter variable instead, so we can get rid of the mutex
- h.mutex.Lock()
- defer h.mutex.Unlock()
+ atomic.AddInt32(&h.activeCommands, 1)
+}
- h.activeCommands++
+func (h *ServerHandler) decrementActiveCommands() int32 {
+ atomic.AddInt32(&h.activeCommands, -1)
+ return atomic.LoadInt32(&h.activeCommands)
}
-func (h *ServerHandler) decrementActiveCommands() int {
- h.mutex.Lock()
- defer h.mutex.Unlock()
+func (h *ServerHandler) incrementActiveReaders() {
+ atomic.AddInt32(&h.activeReaders, 1)
+}
- h.activeCommands--
- return h.activeCommands
+func (h *ServerHandler) decrementActiveReaders() int32 {
+ atomic.AddInt32(&h.activeReaders, -1)
+ return atomic.LoadInt32(&h.activeReaders)
}
func readOptions(opts []string) (map[string]string, error) {
diff --git a/internal/version/version.go b/internal/version/version.go
index 6115453..290eb2f 100644
--- a/internal/version/version.go
+++ b/internal/version/version.go
@@ -13,7 +13,7 @@ const (
// Version of DTail.
Version string = "2.1.1"
// Additional information for DTail
- Additional string = "develop2"
+ Additional string = "develop3"
// ProtocolCompat -ibility version.
ProtocolCompat string = "2"
)