summaryrefslogtreecommitdiff
path: root/internal/flags
diff options
context:
space:
mode:
Diffstat (limited to 'internal/flags')
-rw-r--r--internal/flags/flags.go40
-rw-r--r--internal/flags/flags_test.go55
2 files changed, 74 insertions, 21 deletions
diff --git a/internal/flags/flags.go b/internal/flags/flags.go
index 3f6bfc3..19f9a63 100644
--- a/internal/flags/flags.go
+++ b/internal/flags/flags.go
@@ -19,6 +19,7 @@ var (
TUIExportEnable: true,
}
once sync.Once
+ parseErr error
pidFilter atomic.Int64
tidFilter atomic.Int64
tuiExportEnable atomic.Bool
@@ -98,13 +99,14 @@ func SetTUIExportEnable(enabled bool) {
tuiExportEnable.Store(enabled)
}
-func Parse() {
+func Parse() error {
once.Do(func() {
- parse()
+ parseErr = parse()
})
+ return parseErr
}
-func parse() {
+func parse() error {
flag.IntVar(&singleton.PidFilter, "pid", -1, "Filter for processes ID")
flag.IntVar(&singleton.TidFilter, "tid", -1, "Filter for thread ID")
flag.IntVar(&singleton.EventMapSize, "mapSize", 4096*16, "BPF FD event ring buffer map size")
@@ -130,13 +132,22 @@ func parse() {
fmt.Sprintf("Comma separated list of fields to collapse, valid are: %v", validCollapsedFields))
flag.StringVar(&singleton.CountField, "count", "count",
fmt.Sprintf("Count field to collapse, valid are: %v", validCollapsedCounts))
- flag.Parse()
+ if err := flag.CommandLine.Parse(os.Args[1:]); err != nil {
+ return err
+ }
pidFilter.Store(int64(singleton.PidFilter))
tidFilter.Store(int64(singleton.TidFilter))
tuiExportEnable.Store(singleton.TUIExportEnable)
- singleton.TracepointsToAttach = extractTracepointFlags(*tracepointsToAttach)
- singleton.TracepointsToExclude = extractTracepointFlags(*tracepointsToExclude)
+ var err error
+ singleton.TracepointsToAttach, err = extractTracepointFlags(*tracepointsToAttach)
+ if err != nil {
+ return err
+ }
+ singleton.TracepointsToExclude, err = extractTracepointFlags(*tracepointsToExclude)
+ if err != nil {
+ return err
+ }
// Keep this list empty by default.
// As of February 23, 2026, open_by_handle_at and name_to_handle_at were
@@ -151,30 +162,29 @@ func parse() {
for _, field := range singleton.CollapsedFields {
if !slices.Contains(validCollapsedFields, field) {
- fmt.Println("Invalid field for collapse:", field)
- os.Exit(2)
+ return fmt.Errorf("invalid field for collapse: %s", field)
}
}
if !slices.Contains(validCollapsedCounts, singleton.CountField) {
- fmt.Println("Invalid count field:", singleton.CountField)
- os.Exit(2)
+ return fmt.Errorf("invalid count field: %s", singleton.CountField)
}
+
+ return nil
}
-func extractTracepointFlags(tracepoints string) (regexes []*regexp.Regexp) {
+func extractTracepointFlags(tracepoints string) (regexes []*regexp.Regexp, err error) {
if len(tracepoints) == 0 {
- return regexes
+ return regexes, nil
}
for _, name := range strings.Split(tracepoints, ",") {
re, err := regexp.Compile(name)
if err != nil {
- fmt.Println("Unable to compile regex", name, ": ", err)
- os.Exit(2)
+ return nil, fmt.Errorf("unable to compile regex %q: %w", name, err)
}
regexes = append(regexes, re)
}
- return regexes
+ return regexes, nil
}
func (flags Flags) ShouldIAttachTracepoint(tracepointName string) bool {
diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go
index 9fc6570..a4feb5d 100644
--- a/internal/flags/flags_test.go
+++ b/internal/flags/flags_test.go
@@ -4,18 +4,20 @@ import (
"flag"
"io"
"os"
+ "strings"
"sync"
"testing"
"time"
)
-func parseForTest(t *testing.T, args ...string) Flags {
+func parseForTest(t *testing.T, args ...string) (Flags, error) {
t.Helper()
oldCommandLine := flag.CommandLine
oldArgs := os.Args
oldSingleton := singleton
oldOnce := once
+ oldParseErr := parseErr
oldPID := pidFilter.Load()
oldTID := tidFilter.Load()
oldTUIExport := tuiExportEnable.Load()
@@ -27,11 +29,12 @@ func parseForTest(t *testing.T, args ...string) Flags {
singleton = Flags{TUIExportEnable: true}
once = sync.Once{}
+ parseErr = nil
pidFilter.Store(-1)
tidFilter.Store(-1)
tuiExportEnable.Store(true)
- parse()
+ err := parse()
cfg := singleton
t.Cleanup(func() {
@@ -39,16 +42,20 @@ func parseForTest(t *testing.T, args ...string) Flags {
os.Args = oldArgs
singleton = oldSingleton
once = oldOnce
+ parseErr = oldParseErr
pidFilter.Store(oldPID)
tidFilter.Store(oldTID)
tuiExportEnable.Store(oldTUIExport)
})
- return cfg
+ return cfg, err
}
func TestParseLiveFlagsAndInterval(t *testing.T) {
- cfg := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234")
+ cfg, err := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234")
+ if err != nil {
+ t.Fatalf("parse returned error: %v", err)
+ }
if !cfg.LiveFlamegraph {
t.Fatalf("expected -live to enable live mode")
@@ -65,7 +72,10 @@ func TestParseLiveFlagsAndInterval(t *testing.T) {
}
func TestParseLiveDefaults(t *testing.T) {
- cfg := parseForTest(t)
+ cfg, err := parseForTest(t)
+ if err != nil {
+ t.Fatalf("parse returned error: %v", err)
+ }
if cfg.LiveFlamegraph {
t.Fatalf("expected live mode disabled by default")
@@ -76,7 +86,10 @@ func TestParseLiveDefaults(t *testing.T) {
}
func TestParseDefaultCollapsedFieldsOrder(t *testing.T) {
- cfg := parseForTest(t)
+ cfg, err := parseForTest(t)
+ if err != nil {
+ t.Fatalf("parse returned error: %v", err)
+ }
want := []string{"comm", "path", "tracepoint"}
if len(cfg.CollapsedFields) != len(want) {
@@ -88,3 +101,33 @@ func TestParseDefaultCollapsedFieldsOrder(t *testing.T) {
}
}
}
+
+func TestParseInvalidCollapsedFieldReturnsError(t *testing.T) {
+ _, err := parseForTest(t, "-fields", "comm,invalid")
+ if err == nil {
+ t.Fatalf("expected parse error for invalid collapsed field")
+ }
+ if !strings.Contains(err.Error(), "invalid field for collapse: invalid") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestParseInvalidCountFieldReturnsError(t *testing.T) {
+ _, err := parseForTest(t, "-count", "invalid")
+ if err == nil {
+ t.Fatalf("expected parse error for invalid count field")
+ }
+ if !strings.Contains(err.Error(), "invalid count field: invalid") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestParseInvalidTracepointRegexReturnsError(t *testing.T) {
+ _, err := parseForTest(t, "-tps", "[")
+ if err == nil {
+ t.Fatalf("expected parse error for invalid tracepoint regex")
+ }
+ if !strings.Contains(err.Error(), "unable to compile regex") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}