diff options
Diffstat (limited to 'internal/flags')
| -rw-r--r-- | internal/flags/flags.go | 40 | ||||
| -rw-r--r-- | internal/flags/flags_test.go | 55 |
2 files changed, 74 insertions, 21 deletions
diff --git a/internal/flags/flags.go b/internal/flags/flags.go index 3f6bfc3..19f9a63 100644 --- a/internal/flags/flags.go +++ b/internal/flags/flags.go @@ -19,6 +19,7 @@ var ( TUIExportEnable: true, } once sync.Once + parseErr error pidFilter atomic.Int64 tidFilter atomic.Int64 tuiExportEnable atomic.Bool @@ -98,13 +99,14 @@ func SetTUIExportEnable(enabled bool) { tuiExportEnable.Store(enabled) } -func Parse() { +func Parse() error { once.Do(func() { - parse() + parseErr = parse() }) + return parseErr } -func parse() { +func parse() error { flag.IntVar(&singleton.PidFilter, "pid", -1, "Filter for processes ID") flag.IntVar(&singleton.TidFilter, "tid", -1, "Filter for thread ID") flag.IntVar(&singleton.EventMapSize, "mapSize", 4096*16, "BPF FD event ring buffer map size") @@ -130,13 +132,22 @@ func parse() { fmt.Sprintf("Comma separated list of fields to collapse, valid are: %v", validCollapsedFields)) flag.StringVar(&singleton.CountField, "count", "count", fmt.Sprintf("Count field to collapse, valid are: %v", validCollapsedCounts)) - flag.Parse() + if err := flag.CommandLine.Parse(os.Args[1:]); err != nil { + return err + } pidFilter.Store(int64(singleton.PidFilter)) tidFilter.Store(int64(singleton.TidFilter)) tuiExportEnable.Store(singleton.TUIExportEnable) - singleton.TracepointsToAttach = extractTracepointFlags(*tracepointsToAttach) - singleton.TracepointsToExclude = extractTracepointFlags(*tracepointsToExclude) + var err error + singleton.TracepointsToAttach, err = extractTracepointFlags(*tracepointsToAttach) + if err != nil { + return err + } + singleton.TracepointsToExclude, err = extractTracepointFlags(*tracepointsToExclude) + if err != nil { + return err + } // Keep this list empty by default. // As of February 23, 2026, open_by_handle_at and name_to_handle_at were @@ -151,30 +162,29 @@ func parse() { for _, field := range singleton.CollapsedFields { if !slices.Contains(validCollapsedFields, field) { - fmt.Println("Invalid field for collapse:", field) - os.Exit(2) + return fmt.Errorf("invalid field for collapse: %s", field) } } if !slices.Contains(validCollapsedCounts, singleton.CountField) { - fmt.Println("Invalid count field:", singleton.CountField) - os.Exit(2) + return fmt.Errorf("invalid count field: %s", singleton.CountField) } + + return nil } -func extractTracepointFlags(tracepoints string) (regexes []*regexp.Regexp) { +func extractTracepointFlags(tracepoints string) (regexes []*regexp.Regexp, err error) { if len(tracepoints) == 0 { - return regexes + return regexes, nil } for _, name := range strings.Split(tracepoints, ",") { re, err := regexp.Compile(name) if err != nil { - fmt.Println("Unable to compile regex", name, ": ", err) - os.Exit(2) + return nil, fmt.Errorf("unable to compile regex %q: %w", name, err) } regexes = append(regexes, re) } - return regexes + return regexes, nil } func (flags Flags) ShouldIAttachTracepoint(tracepointName string) bool { diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go index 9fc6570..a4feb5d 100644 --- a/internal/flags/flags_test.go +++ b/internal/flags/flags_test.go @@ -4,18 +4,20 @@ import ( "flag" "io" "os" + "strings" "sync" "testing" "time" ) -func parseForTest(t *testing.T, args ...string) Flags { +func parseForTest(t *testing.T, args ...string) (Flags, error) { t.Helper() oldCommandLine := flag.CommandLine oldArgs := os.Args oldSingleton := singleton oldOnce := once + oldParseErr := parseErr oldPID := pidFilter.Load() oldTID := tidFilter.Load() oldTUIExport := tuiExportEnable.Load() @@ -27,11 +29,12 @@ func parseForTest(t *testing.T, args ...string) Flags { singleton = Flags{TUIExportEnable: true} once = sync.Once{} + parseErr = nil pidFilter.Store(-1) tidFilter.Store(-1) tuiExportEnable.Store(true) - parse() + err := parse() cfg := singleton t.Cleanup(func() { @@ -39,16 +42,20 @@ func parseForTest(t *testing.T, args ...string) Flags { os.Args = oldArgs singleton = oldSingleton once = oldOnce + parseErr = oldParseErr pidFilter.Store(oldPID) tidFilter.Store(oldTID) tuiExportEnable.Store(oldTUIExport) }) - return cfg + return cfg, err } func TestParseLiveFlagsAndInterval(t *testing.T) { - cfg := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234") + cfg, err := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234") + if err != nil { + t.Fatalf("parse returned error: %v", err) + } if !cfg.LiveFlamegraph { t.Fatalf("expected -live to enable live mode") @@ -65,7 +72,10 @@ func TestParseLiveFlagsAndInterval(t *testing.T) { } func TestParseLiveDefaults(t *testing.T) { - cfg := parseForTest(t) + cfg, err := parseForTest(t) + if err != nil { + t.Fatalf("parse returned error: %v", err) + } if cfg.LiveFlamegraph { t.Fatalf("expected live mode disabled by default") @@ -76,7 +86,10 @@ func TestParseLiveDefaults(t *testing.T) { } func TestParseDefaultCollapsedFieldsOrder(t *testing.T) { - cfg := parseForTest(t) + cfg, err := parseForTest(t) + if err != nil { + t.Fatalf("parse returned error: %v", err) + } want := []string{"comm", "path", "tracepoint"} if len(cfg.CollapsedFields) != len(want) { @@ -88,3 +101,33 @@ func TestParseDefaultCollapsedFieldsOrder(t *testing.T) { } } } + +func TestParseInvalidCollapsedFieldReturnsError(t *testing.T) { + _, err := parseForTest(t, "-fields", "comm,invalid") + if err == nil { + t.Fatalf("expected parse error for invalid collapsed field") + } + if !strings.Contains(err.Error(), "invalid field for collapse: invalid") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParseInvalidCountFieldReturnsError(t *testing.T) { + _, err := parseForTest(t, "-count", "invalid") + if err == nil { + t.Fatalf("expected parse error for invalid count field") + } + if !strings.Contains(err.Error(), "invalid count field: invalid") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParseInvalidTracepointRegexReturnsError(t *testing.T) { + _, err := parseForTest(t, "-tps", "[") + if err == nil { + t.Fatalf("expected parse error for invalid tracepoint regex") + } + if !strings.Contains(err.Error(), "unable to compile regex") { + t.Fatalf("unexpected error: %v", err) + } +} |
