summary refs log tree commit diff
path: root/internal
diff options
context:
space:
mode:
author Paul Buetow <paul@buetow.org> 2026-03-20 11:06:50 +0200
committer Paul Buetow <paul@buetow.org> 2026-03-20 11:06:50 +0200
commit 13b21feb07c86f65760f7338f284f3b492364cd9 (patch)
tree c9fa6fc4fb0c7fe8b927297d26e5f3b1448a3518 /internal
parent da8e581617a0240626d2bc922916416440e65bae (diff)
Optimize mapr parsing and stabilize aggregate shutdown
Diffstat (limited to 'internal')
-rw-r--r-- internal/clients/handlers/maprhandler.go 10
-rw-r--r-- internal/clients/handlers/maprhandler_test.go 55
-rw-r--r-- internal/mapr/aggregateset.go 4
-rw-r--r-- internal/mapr/client/aggregate.go 31
-rw-r--r-- internal/mapr/client/aggregate_test.go 26
-rw-r--r-- internal/mapr/logformat/csv.go 40
-rw-r--r-- internal/mapr/logformat/default.go 237
-rw-r--r-- internal/mapr/logformat/default_benchmark_test.go 44
-rw-r--r-- internal/mapr/logformat/default_test.go 35
-rw-r--r-- internal/mapr/logformat/delimited.go 12
-rw-r--r-- internal/mapr/logformat/generic.go 11
-rw-r--r-- internal/mapr/logformat/generickv.go 34
-rw-r--r-- internal/mapr/logformat/parser.go 20
-rw-r--r-- internal/mapr/parserfieldplan.go 81
-rw-r--r-- internal/mapr/parserfieldplan_test.go 32
-rw-r--r-- internal/mapr/server/aggregate.go 14
-rw-r--r-- internal/mapr/server/groupkey.go 31
-rw-r--r-- internal/mapr/server/turbo_aggregate.go 528
-rw-r--r-- internal/mapr/server/turbo_aggregate_test.go 83
-rw-r--r-- internal/server/handlers/serverhandler.go 2
-rw-r--r-- internal/tools/profile/profile.go 107
-rw-r--r-- internal/tools/profile/profile_test.go 30
22 files changed, 864 insertions, 603 deletions
diff --git a/internal/clients/handlers/maprhandler.go b/internal/clients/handlers/maprhandler.go
index d4e171c..4391a34 100644
--- a/internal/clients/handlers/maprhandler.go
+++ b/internal/clients/handlers/maprhandler.go
@@ -75,3 +75,13 @@ func (h *MaprHandler) handleAggregateMessage(message string) {
dlog.Client.Error("Unable to aggregate data", h.server, message, err)
}
}
+
+// Shutdown flushes any pending aggregate state before marking the handler done.
+func (h *MaprHandler) Shutdown() {
+ if h.aggregate != nil {
+ if err := h.aggregate.Flush(); err != nil {
+ dlog.Client.Error("Unable to flush aggregate data on shutdown", h.server, err)
+ }
+ }
+ h.baseHandler.Shutdown()
+}
diff --git a/internal/clients/handlers/maprhandler_test.go b/internal/clients/handlers/maprhandler_test.go
new file mode 100644
index 0000000..7b7b211
--- /dev/null
+++ b/internal/clients/handlers/maprhandler_test.go
@@ -0,0 +1,55 @@
+package handlers
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/mimecast/dtail/internal/mapr"
+ maprclient "github.com/mimecast/dtail/internal/mapr/client"
+ "github.com/mimecast/dtail/internal/protocol"
+)
+
+func TestMaprHandlerShutdownFlushesPendingAggregateState(t *testing.T) {
+ query, err := mapr.NewQuery("select status,count(status) from stats group by status")
+ if err != nil {
+ t.Fatalf("NewQuery() error = %v", err)
+ }
+
+ session := maprclient.NewSessionState(query)
+ handler := NewMaprHandler("srv1", session)
+ countStorage := handlerCountStorage(t, query)
+
+ message := strings.Join([]string{
+ "ERROR",
+ "2",
+ countStorage + protocol.AggregateKVDelimiter + "2",
+ "",
+ }, protocol.AggregateDelimiter)
+ if err := handler.aggregate.Aggregate(message); err != nil {
+ t.Fatalf("Aggregate() error = %v", err)
+ }
+
+ handler.Shutdown()
+
+ result, numRows, err := session.Snapshot().GlobalGroup.Result(query, 10, nil)
+ if err != nil {
+ t.Fatalf("Result() error = %v", err)
+ }
+ if numRows != 1 {
+ t.Fatalf("numRows = %d, want 1", numRows)
+ }
+ if !strings.Contains(result, "2") {
+ t.Fatalf("expected flushed aggregate row, got %q", result)
+ }
+}
+
+func handlerCountStorage(t *testing.T, query *mapr.Query) string {
+ t.Helper()
+ for _, selectCondition := range query.Select {
+ if selectCondition.Operation == mapr.Count {
+ return selectCondition.FieldStorage
+ }
+ }
+ t.Fatalf("query %q does not contain count() storage", query.RawQuery)
+ return ""
+}
diff --git a/internal/mapr/aggregateset.go b/internal/mapr/aggregateset.go
index 263ef35..fc2354f 100644
--- a/internal/mapr/aggregateset.go
+++ b/internal/mapr/aggregateset.go
@@ -79,13 +79,13 @@ func (s *AggregateSet) Serialize(ctx context.Context, groupKey string, ch chan<-
sb.WriteString(groupKey)
sb.WriteString(protocol.AggregateDelimiter)
- sb.WriteString(fmt.Sprintf("%d", s.Samples))
+ sb.WriteString(strconv.Itoa(s.Samples))
sb.WriteString(protocol.AggregateDelimiter)
for k, v := range s.FValues {
sb.WriteString(k)
sb.WriteString(protocol.AggregateKVDelimiter)
- sb.WriteString(fmt.Sprintf("%v", v))
+ sb.WriteString(strconv.FormatFloat(v, 'f', -1, 64))
sb.WriteString(protocol.AggregateDelimiter)
}
diff --git a/internal/mapr/client/aggregate.go b/internal/mapr/client/aggregate.go
index 8cbd339..9989e8f 100644
--- a/internal/mapr/client/aggregate.go
+++ b/internal/mapr/client/aggregate.go
@@ -92,15 +92,40 @@ func (a *Aggregate) Aggregate(message string) error {
return nil
}
+// Flush merges any pending per-server aggregate state into the shared global group.
+// The normal hot path uses MergeNoblock to avoid stalling on the global merge lock.
+// During shutdown we need a blocking flush so the last local batch is not lost.
+func (a *Aggregate) Flush() error {
+ if a.session == nil {
+ return fmt.Errorf("missing client mapreduce session state")
+ }
+
+ snapshot := a.session.Snapshot()
+ if snapshot.Query == nil || snapshot.GlobalGroup == nil {
+ return nil
+ }
+ if snapshot.Generation != a.generation {
+ a.group.InitSet()
+ a.generation = snapshot.Generation
+ return nil
+ }
+
+ if err := snapshot.GlobalGroup.Merge(snapshot.Query, a.group); err != nil {
+ return fmt.Errorf("unable to flush aggregate data for server %s: %w", a.server, err)
+ }
+ a.group.InitSet()
+ return nil
+}
+
// Create a map of key-value pairs from a part list such as ["foo=bar", "bar=baz"].
func (a *Aggregate) makeFields(parts []string) map[string]string {
fields := make(map[string]string, len(parts))
for _, part := range parts {
- kv := strings.SplitN(part, protocol.AggregateKVDelimiter, 2)
- if len(kv) != 2 {
+ key, value, ok := strings.Cut(part, protocol.AggregateKVDelimiter)
+ if !ok {
continue
}
- fields[kv[0]] = kv[1]
+ fields[key] = value
}
return fields
}
diff --git a/internal/mapr/client/aggregate_test.go b/internal/mapr/client/aggregate_test.go
index 8ac94a1..3387a63 100644
--- a/internal/mapr/client/aggregate_test.go
+++ b/internal/mapr/client/aggregate_test.go
@@ -57,6 +57,32 @@ func TestAggregateRejectsMalformedMessage(t *testing.T) {
}
}
+func TestAggregateFlushMergesPendingLocalState(t *testing.T) {
+ query := mustSessionStateQuery(t, "select status,count(status) from stats group by status")
+ state := NewSessionState(query)
+ aggregate := NewAggregate("srv1", state)
+ countStorage := aggregateCountStorage(t, query)
+
+ set := aggregate.group.GetSet("ERROR")
+ set.Samples = 3
+ set.FValues[countStorage] = 3
+
+ if err := aggregate.Flush(); err != nil {
+ t.Fatalf("Flush() error = %v", err)
+ }
+
+ result, numRows, err := state.Snapshot().GlobalGroup.Result(query, 10, nil)
+ if err != nil {
+ t.Fatalf("Result() error = %v", err)
+ }
+ if numRows != 1 {
+ t.Fatalf("numRows = %d, want 1", numRows)
+ }
+ if !strings.Contains(result, "3") {
+ t.Fatalf("expected flushed aggregate row, got %q", result)
+ }
+}
+
func aggregateCountStorage(t *testing.T, query *mapr.Query) string {
t.Helper()
diff --git a/internal/mapr/logformat/csv.go b/internal/mapr/logformat/csv.go
index b8f565c..ecb1f8b 100644
--- a/internal/mapr/logformat/csv.go
+++ b/internal/mapr/logformat/csv.go
@@ -2,7 +2,6 @@ package logformat
import (
"fmt"
- "strings"
"github.com/mimecast/dtail/internal/protocol"
)
@@ -29,27 +28,38 @@ func (p *csvParser) MakeFields(maprLine string) (map[string]string, error) {
return nil, ErrIgnoreFields
}
- fields := make(map[string]string, 7+len(p.header))
- fields["*"] = "*"
- fields["$hostname"] = p.hostname
- fields["$server"] = p.hostname
- fields["$line"] = maprLine
- fields["$empty"] = ""
- fields["$timezone"] = p.timeZoneName
- fields["$timeoffset"] = p.timeZoneOffset
-
- splitted := strings.Split(maprLine, protocol.CSVDelimiter)
- for i, value := range splitted {
- if i >= len(p.header) {
+ fields := make(map[string]string, p.fieldsCapacity)
+ p.addDefaultFields(fields, maprLine)
+ start := 0
+ column := 0
+ delimiter := protocol.CSVDelimiter[0]
+
+ for {
+ value, next, done := scanDelimitedField(maprLine, start, delimiter)
+ if column >= len(p.header) {
return fields, fmt.Errorf("CSV file seems corrupted, more fields than header values?")
}
- fields[p.header[i]] = value
+ p.addDynamicField(fields, p.header[column], value)
+ column++
+ if done {
+ break
+ }
+ start = next
}
return fields, nil
}
func (p *csvParser) parseHeader(maprLine string) {
- p.header = strings.Split(maprLine, protocol.CSVDelimiter)
+ start := 0
+ delimiter := protocol.CSVDelimiter[0]
+ for {
+ header, next, done := scanDelimitedField(maprLine, start, delimiter)
+ p.header = append(p.header, header)
+ if done {
+ break
+ }
+ start = next
+ }
p.hasHeader = true
}
diff --git a/internal/mapr/logformat/default.go b/internal/mapr/logformat/default.go
index a499bc5..396a589 100644
--- a/internal/mapr/logformat/default.go
+++ b/internal/mapr/logformat/default.go
@@ -4,6 +4,7 @@ import (
"fmt"
"strings"
+ "github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/protocol"
)
@@ -11,62 +12,216 @@ type defaultParser struct {
hostname string
timeZoneName string
timeZoneOffset string
+ fieldsCapacity int
+
+ wantStar bool
+ wantLine bool
+ wantEmpty bool
+ wantHostname bool
+ wantServer bool
+ wantTimezone bool
+ wantTimeOffset bool
+ wantSeverity bool
+ wantLogLevel bool
+ wantTime bool
+ wantDate bool
+ wantHour bool
+ wantMinute bool
+ wantSecond bool
+ wantPID bool
+ wantCaller bool
+ wantCPUs bool
+ wantGoroutines bool
+ wantCGOCalls bool
+ wantLoadAvg bool
+ wantUptime bool
+
+ allDynamicFields bool
+ dynamicFields map[string]struct{}
}
func newDefaultParser(hostname, timeZoneName string, timeZoneOffset int) (*defaultParser, error) {
- return &defaultParser{
+ parser := &defaultParser{
hostname: hostname,
timeZoneName: timeZoneName,
timeZoneOffset: fmt.Sprintf("%d", timeZoneOffset),
- }, nil
+ }
+ parser.configureFieldPlan(mapr.ParserFieldPlan{AllFields: true})
+ return parser, nil
+}
+
+func (p *defaultParser) setQuery(query *mapr.Query) {
+ p.configureFieldPlan(query.ParserFieldPlan())
}
func (p *defaultParser) MakeFields(maprLine string) (map[string]string, error) {
- splitted := strings.Split(maprLine, protocol.FieldDelimiter)
+ fields := make(map[string]string, p.fieldsCapacity)
+ tokenIndex := 0
+ start := 0
+ delimiter := protocol.FieldDelimiter[0]
- if len(splitted) < 11 || !strings.HasPrefix(splitted[9], "MAPREDUCE:") ||
- !strings.HasPrefix(splitted[0], "INFO") {
+ for {
+ token, next, done := scanDelimitedField(maprLine, start, delimiter)
+ switch {
+ case tokenIndex == 0:
+ if !strings.HasPrefix(token, "INFO") {
+ return nil, ErrIgnoreFields
+ }
+ p.addDefaultFields(fields, maprLine)
+ if p.wantSeverity {
+ fields["$severity"] = token
+ }
+ if p.wantLogLevel {
+ fields["$loglevel"] = token
+ }
+ case tokenIndex == 1:
+ if p.wantTime {
+ fields["$time"] = token
+ }
+ if len(token) == 15 {
+ // Example: 20211002-071209
+ if p.wantDate {
+ fields["$date"] = token[0:8]
+ }
+ if p.wantHour {
+ fields["$hour"] = token[9:11]
+ }
+ if p.wantMinute {
+ fields["$minute"] = token[11:13]
+ }
+ if p.wantSecond {
+ fields["$second"] = token[13:]
+ }
+ }
+ case tokenIndex == 2:
+ if p.wantPID {
+ fields["$pid"] = token
+ }
+ case tokenIndex == 3:
+ if p.wantCaller {
+ fields["$caller"] = token
+ }
+ case tokenIndex == 4:
+ if p.wantCPUs {
+ fields["$cpus"] = token
+ }
+ case tokenIndex == 5:
+ if p.wantGoroutines {
+ fields["$goroutines"] = token
+ }
+ case tokenIndex == 6:
+ if p.wantCGOCalls {
+ fields["$cgocalls"] = token
+ }
+ case tokenIndex == 7:
+ if p.wantLoadAvg {
+ fields["$loadavg"] = token
+ }
+ case tokenIndex == 8:
+ if p.wantUptime {
+ fields["$uptime"] = token
+ }
+ case tokenIndex == 9:
+ if !strings.HasPrefix(token, "MAPREDUCE:") {
+ return nil, ErrIgnoreFields
+ }
+ default:
+ if err := p.addKeyValueField(fields, token); err != nil {
+ return fields, err
+ }
+ }
+
+ tokenIndex++
+ if done {
+ break
+ }
+ start = next
+ }
+
+ if tokenIndex < 11 {
// Not a DTail mapreduce log line.
return nil, ErrIgnoreFields
}
- fields := make(map[string]string, len(splitted)+8)
-
- fields["*"] = "*"
- fields["$line"] = maprLine
- fields["$empty"] = ""
- fields["$hostname"] = p.hostname
- fields["$server"] = p.hostname
- fields["$timezone"] = p.timeZoneName
- fields["$timeoffset"] = p.timeZoneOffset
-
- fields["$severity"] = splitted[0]
- fields["$loglevel"] = splitted[0]
-
- time := splitted[1]
- fields["$time"] = time
- if len(time) == 15 {
- // Example: 20211002-071209
- fields["$date"] = time[0:8]
- fields["$hour"] = time[9:11]
- fields["$minute"] = time[11:13]
- fields["$second"] = time[13:]
+ return fields, nil
+}
+
+func (p *defaultParser) addDefaultFields(fields map[string]string, maprLine string) {
+ if p.wantStar {
+ fields["*"] = "*"
}
- fields["$pid"] = splitted[2]
- fields["$caller"] = splitted[3]
- fields["$cpus"] = splitted[4]
- fields["$goroutines"] = splitted[5]
- fields["$cgocalls"] = splitted[6]
- fields["$loadavg"] = splitted[7]
- fields["$uptime"] = splitted[8]
-
- for _, kv := range splitted[10:] {
- keyAndValue := strings.SplitN(kv, "=", 2)
- if len(keyAndValue) != 2 {
- return fields, fmt.Errorf("Unable to parse key-value token '%s'", kv)
- }
- fields[keyAndValue[0]] = keyAndValue[1]
+ if p.wantLine {
+ fields["$line"] = maprLine
+ }
+ if p.wantEmpty {
+ fields["$empty"] = ""
+ }
+ if p.wantHostname {
+ fields["$hostname"] = p.hostname
+ }
+ if p.wantServer {
+ fields["$server"] = p.hostname
}
+ if p.wantTimezone {
+ fields["$timezone"] = p.timeZoneName
+ }
+ if p.wantTimeOffset {
+ fields["$timeoffset"] = p.timeZoneOffset
+ }
+}
- return fields, nil
+func (p *defaultParser) addDynamicField(fields map[string]string, key string, value string) {
+ if p.allDynamicFields {
+ fields[key] = value
+ return
+ }
+ if _, ok := p.dynamicFields[key]; ok {
+ fields[key] = value
+ }
+}
+
+func (p *defaultParser) addKeyValueField(fields map[string]string, token string) error {
+ keyAndValueIndex := strings.IndexByte(token, '=')
+ if keyAndValueIndex < 0 {
+ return fmt.Errorf("Unable to parse key-value token '%s'", token)
+ }
+ p.addDynamicField(fields, token[:keyAndValueIndex], token[keyAndValueIndex+1:])
+ return nil
+}
+
+func (p *defaultParser) configureFieldPlan(plan mapr.ParserFieldPlan) {
+ p.fieldsCapacity = plan.Capacity()
+ p.dynamicFields = nil
+ p.allDynamicFields = plan.AllFields
+
+ p.wantStar = plan.Needs("*")
+ p.wantLine = plan.Needs("$line")
+ p.wantEmpty = plan.Needs("$empty")
+ p.wantHostname = plan.Needs("$hostname")
+ p.wantServer = plan.Needs("$server")
+ p.wantTimezone = plan.Needs("$timezone")
+ p.wantTimeOffset = plan.Needs("$timeoffset")
+ p.wantSeverity = plan.Needs("$severity")
+ p.wantLogLevel = plan.Needs("$loglevel")
+ p.wantTime = plan.Needs("$time")
+ p.wantDate = plan.Needs("$date")
+ p.wantHour = plan.Needs("$hour")
+ p.wantMinute = plan.Needs("$minute")
+ p.wantSecond = plan.Needs("$second")
+ p.wantPID = plan.Needs("$pid")
+ p.wantCaller = plan.Needs("$caller")
+ p.wantCPUs = plan.Needs("$cpus")
+ p.wantGoroutines = plan.Needs("$goroutines")
+ p.wantCGOCalls = plan.Needs("$cgocalls")
+ p.wantLoadAvg = plan.Needs("$loadavg")
+ p.wantUptime = plan.Needs("$uptime")
+
+ if plan.AllFields {
+ return
+ }
+
+ p.dynamicFields = make(map[string]struct{}, len(plan.Fields))
+ for field := range plan.Fields {
+ p.dynamicFields[field] = struct{}{}
+ }
}
diff --git a/internal/mapr/logformat/default_benchmark_test.go b/internal/mapr/logformat/default_benchmark_test.go
new file mode 100644
index 0000000..b3ae400
--- /dev/null
+++ b/internal/mapr/logformat/default_benchmark_test.go
@@ -0,0 +1,44 @@
+package logformat
+
+import (
+ "testing"
+
+ "github.com/mimecast/dtail/internal/mapr"
+)
+
+func BenchmarkDefaultParserMakeFields(b *testing.B) {
+ input := "INFO|20211002-072342|1|default_benchmark_test.go:0|8|14|7|0.21|471h0m21s|" +
+ "MAPREDUCE:STATS|foo=bar|bar=baz|qux=quux|alpha=beta|gamma=delta"
+
+ b.Run("all_fields", func(b *testing.B) {
+ parser, err := NewParser("default", nil)
+ if err != nil {
+ b.Fatalf("Unable to create parser: %s", err.Error())
+ }
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ if _, err := parser.MakeFields(input); err != nil {
+ b.Fatalf("Unable to parse input: %s", err.Error())
+ }
+ }
+ })
+
+ b.Run("query_specific", func(b *testing.B) {
+ q, err := mapr.NewQuery(`select count(foo) from STATS where bar eq "baz"`)
+ if err != nil {
+ b.Fatalf("Unable to create query: %s", err.Error())
+ }
+ parser, err := NewParser("default", q)
+ if err != nil {
+ b.Fatalf("Unable to create parser: %s", err.Error())
+ }
+
+ b.ReportAllocs()
+ for i := 0; i < b.N; i++ {
+ if _, err := parser.MakeFields(input); err != nil {
+ b.Fatalf("Unable to parse input: %s", err.Error())
+ }
+ }
+ })
+}
diff --git a/internal/mapr/logformat/default_test.go b/internal/mapr/logformat/default_test.go
index edf238f..6417c2f 100644
--- a/internal/mapr/logformat/default_test.go
+++ b/internal/mapr/logformat/default_test.go
@@ -3,6 +3,8 @@ package logformat
import (
"fmt"
"testing"
+
+ "github.com/mimecast/dtail/internal/mapr"
)
func TestDefaultLogFormat(t *testing.T) {
@@ -95,3 +97,36 @@ func TestDefaultLogFormat(t *testing.T) {
t.Errorf("Expected fiending field 'foo', but found it\n")
}
}
+
+func TestDefaultLogFormatQuerySpecificFields(t *testing.T) {
+ q, err := mapr.NewQuery(`select count(foo) from STATS where $hostname eq "testhost"`)
+ if err != nil {
+ t.Fatalf("Unable to create query: %s", err.Error())
+ }
+
+ parser, err := NewParser("default", q)
+ if err != nil {
+ t.Fatalf("Unable to create parser: %s", err.Error())
+ }
+
+ fields, err := parser.MakeFields(
+ "INFO|20211002-072342|1|default_test.go:0|8|14|7|0.21|471h0m21s|MAPREDUCE:STATS|foo=bar|bar=baz",
+ )
+ if err != nil {
+ t.Fatalf("Parser unable to make fields: %s", err.Error())
+ }
+
+ requiredFields := []string{"foo", "$hostname"}
+ for _, field := range requiredFields {
+ if _, ok := fields[field]; !ok {
+ t.Errorf("Expected query-specific field '%s' to be present", field)
+ }
+ }
+
+ omittedFields := []string{"bar", "$time", "$pid", "$line"}
+ for _, field := range omittedFields {
+ if _, ok := fields[field]; ok {
+ t.Errorf("Expected query-specific field '%s' to be omitted", field)
+ }
+ }
+}
diff --git a/internal/mapr/logformat/delimited.go b/internal/mapr/logformat/delimited.go
new file mode 100644
index 0000000..2fa0639
--- /dev/null
+++ b/internal/mapr/logformat/delimited.go
@@ -0,0 +1,12 @@
+package logformat
+
+import "strings"
+
+func scanDelimitedField(input string, start int, delimiter byte) (token string, next int, done bool) {
+ index := strings.IndexByte(input[start:], delimiter)
+ if index < 0 {
+ return input[start:], len(input), true
+ }
+ end := start + index
+ return input[start:end], end + 1, false
+}
diff --git a/internal/mapr/logformat/generic.go b/internal/mapr/logformat/generic.go
index 1350eff..ecb9e75 100644
--- a/internal/mapr/logformat/generic.go
+++ b/internal/mapr/logformat/generic.go
@@ -15,15 +15,8 @@ func newGenericParser(hostname, timeZoneName string, timeZoneOffset int) (*gener
}
func (p *genericParser) MakeFields(maprLine string) (map[string]string, error) {
- fields := make(map[string]string, 3)
-
- fields["*"] = "*"
- fields["$hostname"] = p.hostname
- fields["$server"] = p.hostname
- fields["$line"] = maprLine
- fields["$empty"] = ""
- fields["$timezone"] = p.timeZoneName
- fields["$timeoffset"] = p.timeZoneOffset
+ fields := make(map[string]string, p.fieldsCapacity)
+ p.addDefaultFields(fields, maprLine)
return fields, nil
}
diff --git a/internal/mapr/logformat/generickv.go b/internal/mapr/logformat/generickv.go
index bd9aad5..b5da8c1 100644
--- a/internal/mapr/logformat/generickv.go
+++ b/internal/mapr/logformat/generickv.go
@@ -1,10 +1,6 @@
package logformat
-import (
- "strings"
-
- "github.com/mimecast/dtail/internal/protocol"
-)
+import "github.com/mimecast/dtail/internal/protocol"
type genericKVParser struct {
defaultParser
@@ -21,24 +17,20 @@ func newGenericKVParser(hostname, timeZoneName string, timeZoneOffset int) (*gen
}
func (p *genericKVParser) MakeFields(maprLine string) (map[string]string, error) {
- splitted := strings.Split(maprLine, protocol.FieldDelimiter)
- fields := make(map[string]string, len(splitted))
-
- fields["*"] = "*"
- fields["$line"] = maprLine
- fields["$empty"] = ""
- fields["$hostname"] = p.hostname
- fields["$server"] = p.hostname
- fields["$timezone"] = p.timeZoneName
- fields["$timeoffset"] = p.timeZoneOffset
-
- for _, kv := range splitted[0:] {
- keyAndValue := strings.SplitN(kv, "=", 2)
- if len(keyAndValue) != 2 {
- //dlog.Common.Debug("Unable to parse key-value token, ignoring it", kv)
+ fields := make(map[string]string, p.fieldsCapacity)
+ p.addDefaultFields(fields, maprLine)
+ start := 0
+ delimiter := protocol.FieldDelimiter[0]
+
+ for {
+ token, next, done := scanDelimitedField(maprLine, start, delimiter)
+ if err := p.addKeyValueField(fields, token); err != nil {
continue
}
- fields[keyAndValue[0]] = keyAndValue[1]
+ if done {
+ break
+ }
+ start = next
}
return fields, nil
diff --git a/internal/mapr/logformat/parser.go b/internal/mapr/logformat/parser.go
index b6ed87d..d7db826 100644
--- a/internal/mapr/logformat/parser.go
+++ b/internal/mapr/logformat/parser.go
@@ -20,6 +20,10 @@ type Parser interface {
MakeFields(string) (map[string]string, error)
}
+type queryAwareParser interface {
+ setQuery(*mapr.Query)
+}
+
// ParserFactory builds a Parser for a specific log format.
type ParserFactory func(hostname, timeZoneName string, timeZoneOffset int) (Parser, error)
@@ -86,7 +90,9 @@ func NewParser(logFormatName string, query *mapr.Query) (Parser, error) {
timeZoneName, timeZoneOffset := now.Zone()
if parserFactory, found := getParserFactory(logFormatName); found {
- return parserFactory(hostname, timeZoneName, timeZoneOffset)
+ parser, err := parserFactory(hostname, timeZoneName, timeZoneOffset)
+ configureParserQuery(parser, query)
+ return parser, err
}
defaultFactory, found := getParserFactory("default")
@@ -99,5 +105,17 @@ func NewParser(logFormatName string, query *mapr.Query) (Parser, error) {
return p, fmt.Errorf("No '%s' mapr log format and problem creating default one: %v",
logFormatName, err)
}
+ configureParserQuery(p, query)
return p, fmt.Errorf("No '%s' mapr log format", logFormatName)
}
+
+func configureParserQuery(parser Parser, query *mapr.Query) {
+ if parser == nil {
+ return
+ }
+ queryAware, ok := parser.(queryAwareParser)
+ if !ok {
+ return
+ }
+ queryAware.setQuery(query)
+}
diff --git a/internal/mapr/parserfieldplan.go b/internal/mapr/parserfieldplan.go
new file mode 100644
index 0000000..fd831a5
--- /dev/null
+++ b/internal/mapr/parserfieldplan.go
@@ -0,0 +1,81 @@
+package mapr
+
+// ParserFieldPlan describes which raw fields a parser needs to materialize.
+type ParserFieldPlan struct {
+ AllFields bool
+ Fields map[string]struct{}
+}
+
+// Needs reports whether the parser plan requires a field.
+func (p ParserFieldPlan) Needs(field string) bool {
+ if p.AllFields {
+ return true
+ }
+ _, ok := p.Fields[field]
+ return ok
+}
+
+// Capacity returns a reasonable initial capacity for a parsed field map.
+func (p ParserFieldPlan) Capacity() int {
+ if p.AllFields {
+ return 20
+ }
+ if len(p.Fields) == 0 {
+ return 4
+ }
+ return len(p.Fields) + 2
+}
+
+// ParserFieldPlan returns the raw fields required to evaluate the query.
+func (q *Query) ParserFieldPlan() ParserFieldPlan {
+ if q == nil {
+ return ParserFieldPlan{AllFields: true}
+ }
+
+ fields := make(map[string]struct{}, len(q.Select)+len(q.GroupBy)+len(q.Where)*2+len(q.Set))
+ producedBySet := make(map[string]struct{}, len(q.Set))
+
+ add := func(field string) {
+ if field == "" {
+ return
+ }
+ fields[field] = struct{}{}
+ }
+ isProduced := func(field string) bool {
+ _, ok := producedBySet[field]
+ return ok
+ }
+
+ for _, wc := range q.Where {
+ if wc.lType == Field {
+ add(wc.lString)
+ }
+ if wc.rType == Field {
+ add(wc.rString)
+ }
+ }
+
+ for _, sc := range q.Set {
+ switch sc.rType {
+ case Field, FunctionStack:
+ if !isProduced(sc.rString) {
+ add(sc.rString)
+ }
+ }
+ producedBySet[sc.lString] = struct{}{}
+ }
+
+ for _, groupBy := range q.GroupBy {
+ if !isProduced(groupBy) {
+ add(groupBy)
+ }
+ }
+
+ for _, sc := range q.Select {
+ if !isProduced(sc.Field) {
+ add(sc.Field)
+ }
+ }
+
+ return ParserFieldPlan{Fields: fields}
+}
diff --git a/internal/mapr/parserfieldplan_test.go b/internal/mapr/parserfieldplan_test.go
new file mode 100644
index 0000000..f6d664f
--- /dev/null
+++ b/internal/mapr/parserfieldplan_test.go
@@ -0,0 +1,32 @@
+package mapr
+
+import "testing"
+
+func TestParserFieldPlan(t *testing.T) {
+ q, err := NewQuery(
+ "select count($derived) from STATS where $goroutines > 10 " +
+ "set $derived = md5sum(foo), $other = $derived group by $derived",
+ )
+ if err != nil {
+ t.Fatalf("Unable to create query: %s", err.Error())
+ }
+
+ plan := q.ParserFieldPlan()
+ if plan.AllFields {
+ t.Fatalf("Expected query-specific field plan")
+ }
+
+ requiredFields := []string{"foo", "$goroutines"}
+ for _, field := range requiredFields {
+ if !plan.Needs(field) {
+ t.Errorf("Expected field '%s' to be required", field)
+ }
+ }
+
+ notRequiredFields := []string{"$derived", "$other", "$time"}
+ for _, field := range notRequiredFields {
+ if plan.Needs(field) {
+ t.Errorf("Expected field '%s' to not be required", field)
+ }
+ }
+}
diff --git a/internal/mapr/server/aggregate.go b/internal/mapr/server/aggregate.go
index 9a736a5..c9d4641 100644
--- a/internal/mapr/server/aggregate.go
+++ b/internal/mapr/server/aggregate.go
@@ -12,7 +12,6 @@ import (
"github.com/mimecast/dtail/internal/io/line"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/logformat"
- "github.com/mimecast/dtail/internal/protocol"
)
// Aggregate is for aggregating mapreduce data on the DTail server side.
@@ -282,7 +281,7 @@ func (a *Aggregate) aggregateAndSerialize(ctx context.Context,
serialize := func() {
dlog.Server.Info("Serializing mapreduce result")
group.Serialize(ctx, maprMessages)
- group = mapr.NewGroupSet()
+ group.InitSet()
}
for {
select {
@@ -301,16 +300,7 @@ func (a *Aggregate) aggregateAndSerialize(ctx context.Context,
}
func (a *Aggregate) aggregate(group *mapr.GroupSet, fields map[string]string) {
- var sb strings.Builder
- for i, field := range a.query.GroupBy {
- if i > 0 {
- sb.WriteString(protocol.AggregateGroupKeyCombinator)
- }
- if val, ok := fields[field]; ok {
- sb.WriteString(val)
- }
- }
- groupKey := sb.String()
+ groupKey := buildGroupKey(a.query.GroupBy, fields)
set := group.GetSet(groupKey)
var addedSample bool
diff --git a/internal/mapr/server/groupkey.go b/internal/mapr/server/groupkey.go
new file mode 100644
index 0000000..0963e4f
--- /dev/null
+++ b/internal/mapr/server/groupkey.go
@@ -0,0 +1,31 @@
+package server
+
+import (
+ "strings"
+
+ "github.com/mimecast/dtail/internal/protocol"
+)
+
+func buildGroupKey(groupBy []string, fields map[string]string) string {
+ if len(groupBy) == 0 {
+ return ""
+ }
+
+ total := 0
+ for _, field := range groupBy {
+ total += len(fields[field])
+ }
+ total += (len(groupBy) - 1) * len(protocol.AggregateGroupKeyCombinator)
+
+ var sb strings.Builder
+ sb.Grow(total)
+
+ for i, field := range groupBy {
+ if i > 0 {
+ sb.WriteString(protocol.AggregateGroupKeyCombinator)
+ }
+ sb.WriteString(fields[field])
+ }
+
+ return sb.String()
+}
diff --git a/internal/mapr/server/turbo_aggregate.go b/internal/mapr/server/turbo_aggregate.go
index 188be1c..c3aaf32 100644
--- a/internal/mapr/server/turbo_aggregate.go
+++ b/internal/mapr/server/turbo_aggregate.go
@@ -3,7 +3,6 @@ package server
import (
"bytes"
"context"
- "fmt"
"strings"
"sync"
"sync/atomic"
@@ -15,7 +14,6 @@ import (
"github.com/mimecast/dtail/internal/io/pool"
"github.com/mimecast/dtail/internal/mapr"
"github.com/mimecast/dtail/internal/mapr/logformat"
- "github.com/mimecast/dtail/internal/protocol"
)
// TurboAggregate is a high-performance aggregator for MapReduce operations in turbo mode.
@@ -28,9 +26,11 @@ type TurboAggregate struct {
query *mapr.Query
// The mapr log format parser
parser logformat.Parser
- // Group sets protected by mutex during serialization
- groupSets sync.Map // map[string]*mapr.SafeAggregateSet
- bufferMu sync.Mutex // Protects serialization
+ // Group sets are swapped out during serialization to avoid clone-heavy flushes.
+ groupMu sync.Mutex
+ groupSets map[string]*mapr.AggregateSet
+ // serializeMu ensures only one serialization runs at a time.
+ serializeMu sync.Mutex
// Batch processing
batchMu sync.Mutex
batch []rawLine
@@ -43,16 +43,16 @@ type TurboAggregate struct {
linesProcessed atomic.Uint64
errors atomic.Uint64
filesProcessed atomic.Uint64
- // Field map pool to reduce allocations
- fieldPool sync.Pool
- // Synchronization for clean shutdown
- processingWg sync.WaitGroup
+ // Synchronization for clean shutdown.
+ processorsWg sync.WaitGroup
// Track active file processors
activeProcessors atomic.Int32
+ startOnce sync.Once
+ started chan struct{}
}
type rawLine struct {
- content []byte
+ content *bytes.Buffer
sourceID string
}
@@ -104,109 +104,34 @@ func NewTurboAggregate(queryStr string, defaultLogFormat string) (*TurboAggregat
hostname: s[0],
query: query,
parser: logParser,
- groupSets: sync.Map{},
+ groupSets: make(map[string]*mapr.AggregateSet),
batchSize: 100, // Process 100 lines at a time
batch: make([]rawLine, 0, 100),
- fieldPool: sync.Pool{
- New: func() interface{} {
- return make(map[string]string, 20)
- },
- },
+ started: make(chan struct{}),
}, nil
}
// countGroups returns the current number of groups in the aggregation.
func (a *TurboAggregate) countGroups() int {
- count := 0
- a.groupSets.Range(func(_, _ interface{}) bool {
- count++
- return true
- })
- return count
-}
-
-// min returns the minimum of two integers.
-func min(a, b int) int {
- if a < b {
- return a
- }
- return b
+ a.groupMu.Lock()
+ defer a.groupMu.Unlock()
+ return len(a.groupSets)
}
// Shutdown the aggregation engine.
func (a *TurboAggregate) Shutdown() {
- dlog.Server.Info("TurboAggregate: Shutdown called",
- "linesProcessed", a.linesProcessed.Load(),
- "filesProcessed", a.filesProcessed.Load(),
- "activeProcessors", a.activeProcessors.Load(),
- "currentGroups", a.countGroups())
-
- // Signal shutdown
a.done.Shutdown()
-
- // Stop the ticker
a.stopSerializeTicker()
-
- // Wait for active processors to finish
- for a.activeProcessors.Load() > 0 {
- dlog.Server.Info("TurboAggregate: Waiting for active processors",
- "activeProcessors", a.activeProcessors.Load())
- time.Sleep(10 * time.Millisecond)
- }
-
- // Process any remaining batch synchronously
- dlog.Server.Info("TurboAggregate: Processing final batch")
+ a.processorsWg.Wait()
a.processBatchAndWait()
-
- // Wait for all processing to complete
- dlog.Server.Info("TurboAggregate: Waiting for all processing to complete")
- a.processingWg.Wait()
-
- dlog.Server.Info("TurboAggregate: All processing complete, groups before final serialization",
- "groupCount", a.countGroups())
-
- // Trigger final serialization after all processing is done
- // Use a longer timeout to ensure data gets through
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
-
- dlog.Server.Info("TurboAggregate: Triggering final serialization")
a.doSerialize(ctx)
-
- // Give more time for messages to be sent and processed
- // This is crucial to ensure the baseHandler's Read method picks up the messages
- dlog.Server.Info("TurboAggregate: Waiting for message delivery",
- "channelLen", len(a.maprMessages))
-
- // Wait for channel to drain or timeout
- timeout := time.After(2 * time.Second)
- ticker := time.NewTicker(50 * time.Millisecond)
- defer ticker.Stop()
-
- for {
- select {
- case <-timeout:
- dlog.Server.Warn("TurboAggregate: Timeout waiting for message delivery",
- "remainingMessages", len(a.maprMessages))
- return
- case <-ticker.C:
- if len(a.maprMessages) == 0 {
- dlog.Server.Info("TurboAggregate: All messages delivered")
- return
- }
- }
- }
}
// Abort stops background processing without waiting for final serialization.
// Session generation replacement uses this to preempt old query work immediately.
func (a *TurboAggregate) Abort() {
- dlog.Server.Info("TurboAggregate: Abort called",
- "linesProcessed", a.linesProcessed.Load(),
- "filesProcessed", a.filesProcessed.Load(),
- "activeProcessors", a.activeProcessors.Load(),
- "currentGroups", a.countGroups())
-
a.done.Shutdown()
a.stopSerializeTicker()
}
@@ -214,326 +139,148 @@ func (a *TurboAggregate) Abort() {
// Start the turbo aggregation.
func (a *TurboAggregate) Start(ctx context.Context, maprMessages chan<- string) {
a.maprMessages = maprMessages
+ interval := a.query.Interval
+ if interval <= 0 {
+ interval = time.Second
+ }
+ a.serializeTicker = time.NewTicker(interval)
+ a.startOnce.Do(func() {
+ if a.started != nil {
+ close(a.started)
+ }
+ })
+ defer a.stopSerializeTicker()
- dlog.Server.Info("TurboAggregate: Starting",
- "interval", a.query.Interval)
-
- // Start periodic serialization
- a.serializeTicker = time.NewTicker(a.query.Interval)
go a.serializationLoop(ctx)
- // Start batch processor
- go a.batchProcessorLoop(ctx)
-
- // Debug: Don't trigger immediate serialization - let data accumulate first
- dlog.Server.Info("TurboAggregate: Started, waiting for data")
+ select {
+ case <-ctx.Done():
+ case <-a.done.Done():
+ }
}
// ProcessLineDirect processes a line directly without channels.
// This is called from the TurboAggregateProcessor.
-func (a *TurboAggregate) ProcessLineDirect(lineContent []byte, sourceID string) error {
+func (a *TurboAggregate) ProcessLineDirect(lineContent *bytes.Buffer, sourceID string) error {
if a.stopping() {
+ pool.RecycleBytesBuffer(lineContent)
return nil
}
- // Increment counter first
a.linesProcessed.Add(1)
- // Debug: Track when lines are received
- totalLines := a.linesProcessed.Load()
- if totalLines == 1 || totalLines%1000 == 0 {
- dlog.Server.Info("TurboAggregate: ProcessLineDirect called",
- "totalLinesReceived", totalLines,
- "sourceID", sourceID,
- "lineLength", len(lineContent),
- "linePreview", string(lineContent[:min(50, len(lineContent))]))
- }
-
- // Make a copy of the line content as the buffer will be recycled
- content := make([]byte, len(lineContent))
- copy(content, lineContent)
-
// Add to batch
a.batchMu.Lock()
- a.batch = append(a.batch, rawLine{content: content, sourceID: sourceID})
+ a.batch = append(a.batch, rawLine{content: lineContent, sourceID: sourceID})
shouldProcess := len(a.batch) >= a.batchSize
- batchLen := len(a.batch)
a.batchMu.Unlock()
- if batchLen == 1 {
- dlog.Server.Info("TurboAggregate: First line received in batch",
- "sourceID", sourceID,
- "batchSize", a.batchSize)
- }
-
- // Process batch if full
if shouldProcess {
- dlog.Server.Debug("TurboAggregate: Batch full, processing",
- "batchLen", batchLen)
a.processBatch()
}
return nil
}
-// batchProcessorLoop continuously processes batches.
-func (a *TurboAggregate) batchProcessorLoop(ctx context.Context) {
- dlog.Server.Info("TurboAggregate: Batch processor loop started")
- defer dlog.Server.Info("TurboAggregate: Batch processor loop ended")
-
- ticker := time.NewTicker(100 * time.Millisecond)
- defer ticker.Stop()
-
- for {
- select {
- case <-a.done.Done():
- dlog.Server.Info("TurboAggregate: Batch processor stopped by shutdown")
- // Process any remaining batch synchronously before exiting
- a.processBatchAndWait()
- return
- case <-ticker.C:
- // Periodically process any accumulated batch
- a.processBatch()
-
- // Check if context is done but only exit if no pending work
- select {
- case <-ctx.Done():
- a.batchMu.Lock()
- batchLen := len(a.batch)
- a.batchMu.Unlock()
-
- activeProcs := a.activeProcessors.Load()
-
- if batchLen > 0 || activeProcs > 0 {
- dlog.Server.Info("TurboAggregate: Context cancelled but work pending",
- "batchLen", batchLen,
- "activeProcessors", activeProcs)
- // Continue processing
- } else {
- dlog.Server.Info("TurboAggregate: Context cancelled, no pending work")
- return
- }
- default:
- // Context not done, continue
- }
- }
- }
-}
-
-// processBatch processes a batch of lines asynchronously.
+// processBatch processes a full batch immediately.
func (a *TurboAggregate) processBatch() {
- a.batchMu.Lock()
- if len(a.batch) == 0 {
- a.batchMu.Unlock()
- return
- }
- batch := a.batch
- batchSize := len(batch)
- a.batch = make([]rawLine, 0, a.batchSize)
- a.batchMu.Unlock()
-
- dlog.Server.Info("TurboAggregate: Processing batch",
- "batchSize", batchSize,
- "totalLinesProcessed", a.linesProcessed.Load())
-
- // Track this batch processing
- a.processingWg.Add(1)
- defer a.processingWg.Done()
-
- // Process each line in the batch
- successCount := 0
- errorCount := 0
- for i, line := range batch {
- if err := a.processLine(line.content, line.sourceID); err != nil {
- a.errors.Add(1)
- errorCount++
- dlog.Server.Error("Error processing line:", err, "lineIndex", i)
- } else {
- successCount++
- }
- // Note: line count is already incremented in ProcessLineDirect
- }
-
- dlog.Server.Info("TurboAggregate: Batch processed",
- "successCount", successCount,
- "errorCount", errorCount,
- "totalLinesProcessed", a.linesProcessed.Load())
+ a.processRawBatch(a.takeBatch())
}
// processBatchAndWait processes a batch of lines synchronously and waits for completion.
// This is used when flushing to ensure all data is processed before continuing.
func (a *TurboAggregate) processBatchAndWait() {
+ a.processRawBatch(a.takeBatch())
+}
+
+func (a *TurboAggregate) takeBatch() []rawLine {
a.batchMu.Lock()
if len(a.batch) == 0 {
a.batchMu.Unlock()
- return
+ return nil
}
batch := a.batch
- batchSize := len(batch)
a.batch = make([]rawLine, 0, a.batchSize)
a.batchMu.Unlock()
+ return batch
+}
- dlog.Server.Info("TurboAggregate: Processing batch synchronously",
- "batchSize", batchSize,
- "totalLinesProcessed", a.linesProcessed.Load())
-
- // Process each line in the batch (no goroutine, synchronous)
- successCount := 0
- errorCount := 0
- for i, line := range batch {
- if err := a.processLine(line.content, line.sourceID); err != nil {
+func (a *TurboAggregate) processRawBatch(batch []rawLine) {
+ for i := range batch {
+ if err := a.processLine(batch[i].content, batch[i].sourceID); err != nil {
a.errors.Add(1)
- errorCount++
dlog.Server.Error("Error processing line:", err, "lineIndex", i)
- } else {
- successCount++
}
- // Note: line count is already incremented in ProcessLineDirect
+ if batch[i].content != nil {
+ pool.RecycleBytesBuffer(batch[i].content)
+ }
}
-
- dlog.Server.Info("TurboAggregate: Batch processed synchronously",
- "successCount", successCount,
- "errorCount", errorCount,
- "totalLinesProcessed", a.linesProcessed.Load())
}
// processLine processes a single line and aggregates it.
-func (a *TurboAggregate) processLine(lineContent []byte, sourceID string) error {
- // Trim whitespace
- maprLine := strings.TrimSpace(string(lineContent))
-
- // Debug: Log sample lines
- if a.linesProcessed.Load()%1000 == 0 {
- dlog.Server.Debug("TurboAggregate: Processing line",
- "lineNumber", a.linesProcessed.Load(),
- "linePreview", maprLine[:min(100, len(maprLine))])
- }
-
- // Get a field map from the pool
- fields := a.fieldPool.Get().(map[string]string)
- defer func() {
- // Clear the map before returning to pool
- for k := range fields {
- delete(fields, k)
- }
- a.fieldPool.Put(fields)
- }()
-
- // Parse the line
+func (a *TurboAggregate) processLine(lineContent *bytes.Buffer, _ string) error {
+ maprLine := strings.TrimSpace(lineContent.String())
parsedFields, err := a.parser.MakeFields(maprLine)
if err != nil {
if err != logformat.ErrIgnoreFields {
- dlog.Server.Debug("TurboAggregate: Parser error",
- "error", err,
- "line", maprLine[:min(100, len(maprLine))])
return err
}
return nil
}
- // Copy parsed fields to our pooled map
- for k, v := range parsedFields {
- fields[k] = v
- }
-
- // Debug: Log parsed fields for first few lines
- if a.linesProcessed.Load() < 5 {
- dlog.Server.Info("TurboAggregate: Parsed fields",
- "lineNumber", a.linesProcessed.Load(),
- "fieldCount", len(fields),
- "fields", fields)
- }
-
// Apply where clause
- if !a.query.WhereClause(fields) {
- dlog.Server.Debug("TurboAggregate: Line filtered by WHERE clause")
+ if !a.query.WhereClause(parsedFields) {
return nil
}
// Apply set clause if needed
if len(a.query.Set) > 0 {
- if err := a.query.SetClause(fields); err != nil {
- dlog.Server.Error("TurboAggregate: SET clause error", err)
+ if err := a.query.SetClause(parsedFields); err != nil {
return err
}
}
// Aggregate the fields
- a.aggregate(fields)
+ a.aggregate(parsedFields)
return nil
}
// aggregate adds fields to the appropriate group.
func (a *TurboAggregate) aggregate(fields map[string]string) {
- // Build group key
- var sb strings.Builder
- for i, field := range a.query.GroupBy {
- if i > 0 {
- sb.WriteString(protocol.AggregateGroupKeyCombinator)
- }
- if val, ok := fields[field]; ok {
- sb.WriteString(val)
- }
+ groupKey := buildGroupKey(a.query.GroupBy, fields)
+ a.groupMu.Lock()
+ set, ok := a.groupSets[groupKey]
+ if !ok {
+ set = mapr.NewAggregateSet()
+ a.groupSets[groupKey] = set
}
- groupKey := sb.String()
-
- // Get or create the aggregate set
- setInterface, loaded := a.groupSets.LoadOrStore(groupKey, mapr.NewSafeAggregateSet())
- set := setInterface.(*mapr.SafeAggregateSet)
-
- if !loaded {
- dlog.Server.Info("TurboAggregate: New group created",
- "groupKey", groupKey,
- "totalGroups", a.countGroups())
- }
-
- // Aggregate the values
var addedSample bool
- aggregatedFields := []string{}
for _, sc := range a.query.Select {
if val, ok := fields[sc.Field]; ok {
if err := set.Aggregate(sc.FieldStorage, sc.Operation, val, false); err != nil {
- dlog.Server.Error("TurboAggregate: Aggregation error",
- "field", sc.Field,
- "operation", sc.Operation,
- "error", err)
+ dlog.Server.Error("TurboAggregate aggregation error", err, "field", sc.Field, "operation", sc.Operation)
continue
}
addedSample = true
- aggregatedFields = append(aggregatedFields, sc.Field)
}
}
-
if addedSample {
- set.IncrementSamples()
- // Debug: Log aggregation details for first few samples
- if a.linesProcessed.Load() < 10 {
- dlog.Server.Info("TurboAggregate: Aggregated sample",
- "groupKey", groupKey,
- "aggregatedFields", aggregatedFields,
- "sampleCount", set.GetSamples())
- }
+ set.Samples++
}
+ a.groupMu.Unlock()
}
// serializationLoop handles periodic serialization.
func (a *TurboAggregate) serializationLoop(ctx context.Context) {
- dlog.Server.Info("TurboAggregate: Serialization loop started")
- defer dlog.Server.Info("TurboAggregate: Serialization loop ended")
-
for {
select {
case <-ctx.Done():
- dlog.Server.Info("TurboAggregate: Serialization loop stopped by context")
return
case <-a.done.Done():
- dlog.Server.Info("TurboAggregate: Serialization loop stopped by shutdown")
return
case <-a.serializeTicker.C:
- dlog.Server.Info("TurboAggregate: Periodic serialization triggered")
a.Serialize(ctx)
case <-a.serialize:
- dlog.Server.Info("TurboAggregate: Manual serialization triggered")
a.doSerialize(ctx)
}
}
@@ -551,119 +298,60 @@ func (a *TurboAggregate) Serialize(ctx context.Context) {
// doSerialize performs the actual serialization.
func (a *TurboAggregate) doSerialize(ctx context.Context) {
- dlog.Server.Info("TurboAggregate: Starting serialization",
- "linesProcessed", a.linesProcessed.Load(),
- "currentGroups", a.countGroups())
+ a.serializeMu.Lock()
+ defer a.serializeMu.Unlock()
- // Process any remaining batch synchronously before serialization
- dlog.Server.Info("TurboAggregate: Processing remaining batch before serialization")
a.processBatchAndWait()
+ if a.maprMessages == nil {
+ dlog.Server.Error("TurboAggregate maprMessages channel is nil")
+ return
+ }
- // Wait a moment for any in-progress batch processing
- dlog.Server.Info("TurboAggregate: Waiting for batch processing to complete")
- time.Sleep(50 * time.Millisecond) // Increased wait time
-
- // Lock to prevent concurrent modifications during serialization
- a.bufferMu.Lock()
- defer a.bufferMu.Unlock()
-
- // Count groups before serialization
- groupsBeforeSerialization := a.countGroups()
- dlog.Server.Info("TurboAggregate: Groups before serialization",
- "count", groupsBeforeSerialization)
-
- if groupsBeforeSerialization == 0 {
- dlog.Server.Warn("TurboAggregate: No groups to serialize!")
+ snapshot := a.swapGroupSets()
+ if len(snapshot) == 0 {
return
}
- // Create a new group set for serialization
group := mapr.NewGroupSet()
-
- // Copy all aggregate sets from the groupSets
- groupCount := 0
- sampleDetails := make([]string, 0)
- a.groupSets.Range(func(key, value interface{}) bool {
- groupKey := key.(string)
- safeSet := value.(*mapr.SafeAggregateSet)
-
- // Clone the safe set to get a regular AggregateSet
- clonedSet := safeSet.Clone()
-
- // Debug: Log details of first few groups
- if groupCount < 5 {
- sampleDetails = append(sampleDetails,
- fmt.Sprintf("group=%s, samples=%d", groupKey, clonedSet.Samples))
- }
-
- // Add to the group set
+ for groupKey, aggregateSet := range snapshot {
groupSet := group.GetSet(groupKey)
- *groupSet = *clonedSet
- groupCount++
-
- return true
- })
-
- dlog.Server.Info("TurboAggregate: Serialization details",
- "groupCount", groupCount,
- "sampleGroups", sampleDetails,
- "maprMessagesChannel", a.maprMessages != nil)
-
- // Check if we have a valid channel
- if a.maprMessages == nil {
- dlog.Server.Error("TurboAggregate: maprMessages channel is nil!")
- return
+ *groupSet = *aggregateSet
}
- // Serialize the group - use a longer timeout context for final serialization
- // to ensure data is sent even during shutdown
serializeCtx := ctx
if _, ok := ctx.Deadline(); ok {
- // If context has a deadline, extend it for serialization
var cancel context.CancelFunc
serializeCtx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
}
-
- dlog.Server.Info("TurboAggregate: Calling group.Serialize",
- "channelCap", cap(a.maprMessages),
- "channelLen", len(a.maprMessages))
-
group.Serialize(serializeCtx, a.maprMessages)
+}
- dlog.Server.Info("TurboAggregate: group.Serialize completed",
- "sentGroups", groupCount,
- "channelLen", len(a.maprMessages))
+func (a *TurboAggregate) swapGroupSets() map[string]*mapr.AggregateSet {
+ a.groupMu.Lock()
+ defer a.groupMu.Unlock()
- // Clear the groupSets after serialization only if not shutting down
- select {
- case <-a.done.Done():
- // During shutdown, keep the data for potential final serialization
- dlog.Server.Info("TurboAggregate: Keeping groupSets during shutdown")
- default:
- // Normal operation - clear for next interval
- dlog.Server.Info("TurboAggregate: Clearing groupSets for next interval")
- a.groupSets = sync.Map{}
+ if len(a.groupSets) == 0 {
+ return nil
}
- // Log the state after serialization
- groupsAfterSerialization := a.countGroups()
- dlog.Server.Info("TurboAggregate: After serialization",
- "groupsRemaining", groupsAfterSerialization)
+ snapshot := a.groupSets
+ a.groupSets = make(map[string]*mapr.AggregateSet, len(snapshot))
+ return snapshot
}
// TurboAggregateProcessor implements the line processor interface for turbo mode aggregation.
type TurboAggregateProcessor struct {
aggregate *TurboAggregate
globID string
+ flushOnce sync.Once
+ closeOnce sync.Once
}
// NewTurboAggregateProcessor creates a new turbo aggregate processor.
func NewTurboAggregateProcessor(aggregate *TurboAggregate, globID string) *TurboAggregateProcessor {
+ aggregate.processorsWg.Add(1)
aggregate.activeProcessors.Add(1)
- dlog.Server.Debug("TurboAggregate: New processor created",
- "globID", globID,
- "activeProcessors", aggregate.activeProcessors.Load())
return &TurboAggregateProcessor{
aggregate: aggregate,
globID: globID,
@@ -671,27 +359,12 @@ func NewTurboAggregateProcessor(aggregate *TurboAggregate, globID string) *Turbo
}
// ProcessLine processes a line directly to the turbo aggregate.
-func (p *TurboAggregateProcessor) ProcessLine(lineContent *bytes.Buffer, lineNum uint64, sourceID string) error {
+func (p *TurboAggregateProcessor) ProcessLine(lineContent *bytes.Buffer, _ uint64, sourceID string) error {
if p.aggregate.stopping() {
pool.RecycleBytesBuffer(lineContent)
return nil
}
-
- // Debug: Log when ProcessLine is called
- if lineNum == 1 || lineNum%1000 == 0 {
- dlog.Server.Info("TurboAggregateProcessor: ProcessLine called",
- "lineNum", lineNum,
- "sourceID", sourceID,
- "contentLen", lineContent.Len())
- }
-
- // Process the line directly
- err := p.aggregate.ProcessLineDirect(lineContent.Bytes(), sourceID)
-
- // Recycle the buffer
- pool.RecycleBytesBuffer(lineContent)
-
- return err
+ return p.aggregate.ProcessLineDirect(lineContent, sourceID)
}
// Flush ensures all buffered data is processed.
@@ -700,30 +373,19 @@ func (p *TurboAggregateProcessor) Flush() error {
return nil
}
- // Log flush call for debugging
- dlog.Server.Info("TurboAggregateProcessor: Flush called",
- "globID", p.globID,
- "linesProcessed", p.aggregate.linesProcessed.Load())
-
- // Process any remaining batch synchronously
- p.aggregate.processBatchAndWait()
-
- // Increment files processed counter
- p.aggregate.filesProcessed.Add(1)
-
- dlog.Server.Info("TurboAggregateProcessor: Flush completed",
- "globID", p.globID,
- "linesProcessed", p.aggregate.linesProcessed.Load(),
- "filesProcessed", p.aggregate.filesProcessed.Load())
+ p.flushOnce.Do(func() {
+ p.aggregate.processBatchAndWait()
+ p.aggregate.filesProcessed.Add(1)
+ })
return nil
}
// Close flushes any remaining data.
func (p *TurboAggregateProcessor) Close() error {
err := p.Flush()
- p.aggregate.activeProcessors.Add(-1)
- dlog.Server.Debug("TurboAggregate: Processor closed",
- "globID", p.globID,
- "activeProcessors", p.aggregate.activeProcessors.Load())
+ p.closeOnce.Do(func() {
+ p.aggregate.activeProcessors.Add(-1)
+ p.aggregate.processorsWg.Done()
+ })
return err
}
diff --git a/internal/mapr/server/turbo_aggregate_test.go b/internal/mapr/server/turbo_aggregate_test.go
index 7ae4b5a..e674252 100644
--- a/internal/mapr/server/turbo_aggregate_test.go
+++ b/internal/mapr/server/turbo_aggregate_test.go
@@ -62,8 +62,12 @@ func TestTurboAggregateVsRegular(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- // Start the turbo aggregate
- turboAgg.Start(ctx, messages)
+ startDone := make(chan struct{})
+ go func() {
+ defer close(startDone)
+ turboAgg.Start(ctx, messages)
+ }()
+ waitForTurboAggregateStart(t, turboAgg)
// Process lines
processor := NewTurboAggregateProcessor(turboAgg, "test")
@@ -92,6 +96,7 @@ func TestTurboAggregateVsRegular(t *testing.T) {
// Cancel context to stop background goroutines
cancel()
+ <-startDone
// Collect results with timeout
done := make(chan struct{})
@@ -169,16 +174,21 @@ func TestTurboAggregateVsRegular(t *testing.T) {
}
close(lines)
- // Wait for processing
- time.Sleep(100 * time.Millisecond)
-
- // Shutdown
- regularAgg.Shutdown()
+ // Wait for the aggregate to drain the closed line channel and serialize naturally.
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+ select {
+ case <-done:
+ case <-time.After(2 * time.Second):
+ regularAgg.Shutdown()
+ cancel()
+ t.Fatal("Timeout waiting for regular aggregate to finish")
+ }
cancel()
- // Wait for the Start goroutine to finish
- wg.Wait()
-
// Collect results
close(messages)
@@ -232,10 +242,15 @@ func TestTurboAggregateConcurrency(t *testing.T) {
// Channel to collect messages
messages := make(chan string, 1000)
- ctx := context.Background()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
- // Start the turbo aggregate
- turboAgg.Start(ctx, messages)
+ startDone := make(chan struct{})
+ go func() {
+ defer close(startDone)
+ turboAgg.Start(ctx, messages)
+ }()
+ waitForTurboAggregateStart(t, turboAgg)
// Process multiple "files" concurrently
var wg sync.WaitGroup
@@ -269,6 +284,8 @@ func TestTurboAggregateConcurrency(t *testing.T) {
// Shutdown and get results
turboAgg.Shutdown()
+ cancel()
+ <-startDone
// Collect results
time.Sleep(200 * time.Millisecond)
@@ -291,9 +308,8 @@ func TestTurboAggregateConcurrency(t *testing.T) {
t.Errorf("Expected %d lines processed, got %d", expectedLines, turboAgg.linesProcessed.Load())
}
- // Verify file count (may be higher if test was run multiple times)
- if turboAgg.filesProcessed.Load() < uint64(numFiles) {
- t.Errorf("Expected at least %d files processed, got %d", numFiles, turboAgg.filesProcessed.Load())
+ if turboAgg.filesProcessed.Load() != uint64(numFiles) {
+ t.Errorf("Expected %d files processed, got %d", numFiles, turboAgg.filesProcessed.Load())
}
// Parse result to check count
@@ -330,3 +346,38 @@ func TestTurboAggregateAbortReturnsPromptlyWithActiveProcessors(t *testing.T) {
t.Fatal("Abort did not return promptly while processors were still active")
}
}
+
+func TestTurboAggregateProcessorCountsFlushOnce(t *testing.T) {
+ aggregate := &TurboAggregate{
+ done: internal.NewDone(),
+ batchSize: 16,
+ }
+
+ processor := NewTurboAggregateProcessor(aggregate, "test")
+ if err := processor.Flush(); err != nil {
+ t.Fatalf("Flush failed: %v", err)
+ }
+ if err := processor.Close(); err != nil {
+ t.Fatalf("Close failed: %v", err)
+ }
+
+ if got := aggregate.filesProcessed.Load(); got != 1 {
+ t.Fatalf("expected filesProcessed to be 1, got %d", got)
+ }
+ if got := aggregate.activeProcessors.Load(); got != 0 {
+ t.Fatalf("expected activeProcessors to be 0, got %d", got)
+ }
+}
+
+func waitForTurboAggregateStart(t *testing.T, aggregate *TurboAggregate) {
+ t.Helper()
+
+ if aggregate.started == nil {
+ t.Fatal("turbo aggregate missing start signal")
+ }
+ select {
+ case <-aggregate.started:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("turbo aggregate did not finish Start initialization")
+ }
+}
diff --git a/internal/server/handlers/serverhandler.go b/internal/server/handlers/serverhandler.go
index cd930f9..e008473 100644
--- a/internal/server/handlers/serverhandler.go
+++ b/internal/server/handlers/serverhandler.go
@@ -163,8 +163,8 @@ func (h *ServerHandler) handleMapCommand(ctx context.Context, _ lcontext.LContex
h.turboAggregate = turboAggregate
maprMessages, closeMaprMessages := h.newGeneratedMaprMessagesChannel(ctx, sessionGenerationFromContext(ctx))
go func() {
- defer closeMaprMessages()
command.Start(ctx, maprMessages)
+ closeMaprMessages()
commandFinished()
}()
}
diff --git a/internal/tools/profile/profile.go b/internal/tools/profile/profile.go
index d87662c..21508b2 100644
--- a/internal/tools/profile/profile.go
+++ b/internal/tools/profile/profile.go
@@ -64,29 +64,29 @@ func parseFlags() *Config {
flag.IntVar(&cfg.Runs, "runs", 1, "Number of profiling runs")
flag.BoolVar(&cfg.NoColor, "nocolor", false, "Disable colored output")
flag.DurationVar(&cfg.Timeout, "timeout", cfg.Timeout, "Timeout for profiling runs")
-
+
// Custom command list
var cmdList string
flag.StringVar(&cmdList, "commands", "", "Comma-separated list of commands to profile")
-
+
flag.Parse()
-
+
if cmdList != "" {
cfg.Commands = strings.Split(cmdList, ",")
}
-
+
return cfg
}
func runQuickProfile(cfg *Config) error {
common.PrintSection("DTail Quick Profiling")
-
+
// Generate test data
gen := common.NewDataGenerator()
-
+
logFile := filepath.Join(cfg.TestDataDir, "quick_test.log")
csvFile := filepath.Join(cfg.TestDataDir, "quick_test.csv")
-
+
common.PrintInfo("Generating test data...\n")
if err := gen.GenerateFile(logFile, "10MB", common.FormatLog); err != nil {
return fmt.Errorf("failed to generate log file: %w", err)
@@ -94,31 +94,31 @@ func runQuickProfile(cfg *Config) error {
if err := gen.GenerateFile(csvFile, "10MB", common.FormatCSV); err != nil {
return fmt.Errorf("failed to generate CSV file: %w", err)
}
-
+
// Build commands
common.PrintInfo("Building commands...\n")
if err := common.BuildCommands("dcat", "dgrep", "dmap"); err != nil {
return err
}
-
+
// Profile each command
common.PrintSection("Running quick profiles...")
-
+
// Profile dcat
if err := profileCommand("dcat", "dcat",
[]string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none", logFile},
cfg.Timeout); err != nil {
return err
}
-
+
// Profile dgrep
if err := profileCommand("dgrep", "dgrep",
- []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
+ []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
"-regex", "user[0-9]+", logFile},
cfg.Timeout); err != nil {
return err
}
-
+
// Profile dmap
query := `select count($line),avg($duration) group by $user logformat csv`
if err := profileCommand("dmap", "dmap",
@@ -127,24 +127,24 @@ func runQuickProfile(cfg *Config) error {
cfg.Timeout); err != nil {
return err
}
-
+
// Analyze results
return analyzeLatestProfiles(cfg)
}
func runFullProfile(cfg *Config) error {
common.PrintSection("DTail Full Profiling")
-
+
// Generate test data
gen := common.NewDataGenerator()
-
+
testFiles := map[string]string{
"small.log": "10MB",
"medium.log": "100MB",
"test.csv": "50MB",
"dtail_format.log": "100000", // lines
}
-
+
common.PrintInfo("Generating test data...\n")
for filename, size := range testFiles {
fullPath := filepath.Join(cfg.TestDataDir, filename)
@@ -163,16 +163,16 @@ func runFullProfile(cfg *Config) error {
}
}
}
-
+
// Build commands
common.PrintInfo("Building commands...\n")
if err := common.BuildCommands("dcat", "dgrep", "dmap"); err != nil {
return err
}
-
+
// Run profiling
common.PrintSection("Running full profiling suite...")
-
+
// Profile configurations
profiles := []struct {
cmd string
@@ -184,24 +184,24 @@ func runFullProfile(cfg *Config) error {
filepath.Join(cfg.TestDataDir, "small.log")}},
{"dcat", "medium_file", []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
filepath.Join(cfg.TestDataDir, "medium.log")}},
-
+
// dgrep profiles
{"dgrep", "simple_pattern", []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
"-regex", "ERROR", filepath.Join(cfg.TestDataDir, "medium.log")}},
{"dgrep", "complex_pattern", []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
"-regex", "(ERROR|WARN).*user[0-9]+", filepath.Join(cfg.TestDataDir, "medium.log")}},
-
+
// dmap profiles
{"dmap", "simple_count", []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
"-query", "from STATS select count(*)", "-files", filepath.Join(cfg.TestDataDir, "dtail_format.log")}},
{"dmap", "aggregations", []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
- "-query", "from STATS select sum($goroutines),avg($cgocalls),max(lifetimeConnections)",
+ "-query", "from STATS select sum($goroutines),avg($cgocalls),max(lifetimeConnections)",
"-files", filepath.Join(cfg.TestDataDir, "dtail_format.log")}},
{"dmap", "csv_query", []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
"-query", `select count($line),count($user),count($action) group by $user,$action where $status eq "success" logformat csv`,
"-files", filepath.Join(cfg.TestDataDir, "test.csv")}},
}
-
+
for _, p := range profiles {
common.PrintInfo("\nProfiling %s - %s\n", p.cmd, p.name)
for i := 1; i <= cfg.Runs; i++ {
@@ -216,19 +216,19 @@ func runFullProfile(cfg *Config) error {
}
}
}
-
+
return analyzeLatestProfiles(cfg)
}
func runDMapProfile(cfg *Config) error {
common.PrintSection("DTail dmap Profiling")
-
+
// Generate MapReduce test data
gen := common.NewDataGenerator()
-
+
smallFile := filepath.Join(cfg.TestDataDir, "stats_small.log")
mediumFile := filepath.Join(cfg.TestDataDir, "stats_medium.log")
-
+
common.PrintInfo("Preparing MapReduce test data...\n")
if err := gen.GenerateLogFileWithLines(smallFile, 1000, common.FormatDTail); err != nil {
return fmt.Errorf("failed to generate small file: %w", err)
@@ -236,16 +236,16 @@ func runDMapProfile(cfg *Config) error {
if err := gen.GenerateLogFileWithLines(mediumFile, 1000000, common.FormatDTail); err != nil {
return fmt.Errorf("failed to generate medium file: %w", err)
}
-
+
// Build dmap
common.PrintInfo("Building dmap...\n")
if err := common.BuildCommand("dmap"); err != nil {
return err
}
-
+
// Profile different queries
common.PrintSection("Profiling dmap queries...")
-
+
queries := []struct {
name string
query string
@@ -256,7 +256,7 @@ func runDMapProfile(cfg *Config) error {
{"Min and max", "from STATS select min(currentConnections),max(lifetimeConnections) group by hostname", smallFile},
{"Large file processing", "from STATS select count($line),avg($goroutines) group by hostname", mediumFile},
}
-
+
for _, q := range queries {
common.PrintInfo("\nQuery: %s\n", q.name)
args := []string{"-profile", "-profiledir", cfg.ProfileDir, "-plain", "-cfg", "none",
@@ -265,26 +265,26 @@ func runDMapProfile(cfg *Config) error {
return fmt.Errorf("failed to profile query %s: %w", q.name, err)
}
}
-
+
return analyzeLatestProfiles(cfg)
}
func profileCommand(name, cmd string, args []string, timeout time.Duration) error {
fmt.Printf("Command: %s %s\n", cmd, strings.Join(args, " "))
-
+
command := exec.Command("./"+cmd, args...)
command.Stdout = nil // Suppress output during profiling
command.Stderr = os.Stderr
-
+
if err := command.Start(); err != nil {
return err
}
-
+
done := make(chan error, 1)
go func() {
done <- command.Wait()
}()
-
+
select {
case <-time.After(timeout):
command.Process.Kill()
@@ -294,9 +294,9 @@ func profileCommand(name, cmd string, args []string, timeout time.Duration) erro
return err
}
}
-
+
// Find generated profile
- pattern := filepath.Join("profiles", fmt.Sprintf("%s_cpu_*.prof", name))
+ pattern := filepath.Join(profileDirFromArgs(args), fmt.Sprintf("%s_cpu_*.prof", name))
matches, _ := filepath.Glob(pattern)
if len(matches) > 0 {
// Sort by modification time and get the latest
@@ -307,52 +307,61 @@ func profileCommand(name, cmd string, args []string, timeout time.Duration) erro
})
fmt.Printf(" Generated: %s\n", filepath.Base(matches[0]))
}
-
+
return nil
}
+func profileDirFromArgs(args []string) string {
+ for i := 0; i < len(args)-1; i++ {
+ if args[i] == "-profiledir" {
+ return args[i+1]
+ }
+ }
+ return "profiles"
+}
+
func analyzeLatestProfiles(cfg *Config) error {
common.PrintSection("Profile Analysis")
-
+
// Find latest profiles for each command
for _, cmd := range cfg.Commands {
cpuPattern := filepath.Join(cfg.ProfileDir, fmt.Sprintf("%s_cpu_*.prof", cmd))
memPattern := filepath.Join(cfg.ProfileDir, fmt.Sprintf("%s_mem_*.prof", cmd))
-
+
cpuProfiles, _ := filepath.Glob(cpuPattern)
memProfiles, _ := filepath.Glob(memPattern)
-
+
if len(cpuProfiles) > 0 {
sort.Slice(cpuProfiles, func(i, j int) bool {
fi, _ := os.Stat(cpuProfiles[i])
fj, _ := os.Stat(cpuProfiles[j])
return fi.ModTime().After(fj.ModTime())
})
-
+
fmt.Printf("\n%s CPU Profile: %s\n", cmd, filepath.Base(cpuProfiles[0]))
if err := showTopFunctions(cpuProfiles[0], 5, false); err != nil {
fmt.Printf(" Analysis failed: %v\n", err)
}
}
-
+
if len(memProfiles) > 0 {
sort.Slice(memProfiles, func(i, j int) bool {
fi, _ := os.Stat(memProfiles[i])
fj, _ := os.Stat(memProfiles[j])
return fi.ModTime().After(fj.ModTime())
})
-
+
fmt.Printf("\n%s Memory Profile: %s\n", cmd, filepath.Base(memProfiles[0]))
if err := showTopFunctions(memProfiles[0], 5, true); err != nil {
fmt.Printf(" Analysis failed: %v\n", err)
}
}
}
-
+
common.PrintSuccess("\nProfiling complete!\n")
fmt.Println("\nTo analyze profiles in detail:")
fmt.Printf(" go tool pprof %s/<profile_file>\n", cfg.ProfileDir)
fmt.Printf(" dtail-tools profile -mode analyze <profile_file>\n")
-
+
return nil
-} \ No newline at end of file
+}
diff --git a/internal/tools/profile/profile_test.go b/internal/tools/profile/profile_test.go
new file mode 100644
index 0000000..1a11fdd
--- /dev/null
+++ b/internal/tools/profile/profile_test.go
@@ -0,0 +1,30 @@
+package profile
+
+import "testing"
+
+func TestProfileDirFromArgs(t *testing.T) {
+ tests := []struct {
+ name string
+ args []string
+ want string
+ }{
+ {
+ name: "explicit profile dir",
+ args: []string{"-profile", "-profiledir", "custom-profiles", "-plain"},
+ want: "custom-profiles",
+ },
+ {
+ name: "missing profile dir falls back to default",
+ args: []string{"-profile", "-plain"},
+ want: "profiles",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := profileDirFromArgs(tt.args); got != tt.want {
+ t.Fatalf("profileDirFromArgs(%v) = %q, want %q", tt.args, got, tt.want)
+ }
+ })
+ }
+}