summaryrefslogtreecommitdiff
path: root/internal/flags
diff options
context:
space:
mode:
Diffstat (limited to 'internal/flags')
-rw-r--r--internal/flags/flags.go143
-rw-r--r--internal/flags/flags_test.go32
2 files changed, 48 insertions, 127 deletions
diff --git a/internal/flags/flags.go b/internal/flags/flags.go
index 1844cfe..9d05390 100644
--- a/internal/flags/flags.go
+++ b/internal/flags/flags.go
@@ -7,8 +7,6 @@ import (
"regexp"
"slices"
"strings"
- "sync"
- "sync/atomic"
"time"
"ior/internal/collapse"
@@ -76,17 +74,6 @@ type Config struct {
ShowVersion bool
}
-var (
- current atomic.Pointer[Config]
- once sync.Once
- parseErr error
-)
-
-func init() {
- defaults := NewFlags()
- current.Store(&defaults)
-}
-
// DefaultResetTimer is the default cadence for the dashboard's auto-reset
// timer. It periodically clears aggregate state (live flamegraph trie and
// stats engine) — the same effect as pressing `r` — to prevent unbounded
@@ -123,7 +110,9 @@ func (f Config) GetTUIExportEnable() bool {
return f.TUIExportEnable
}
-func (f Config) clone() Config {
+// Clone returns a deep copy of the Config, duplicating all slice and filter
+// fields so that modifications to the copy do not affect the original.
+func (f Config) Clone() Config {
out := f
out.TracepointsToAttach = slices.Clone(f.TracepointsToAttach)
out.TracepointsToExclude = slices.Clone(f.TracepointsToExclude)
@@ -132,109 +121,62 @@ func (f Config) clone() Config {
return out
}
-// Get returns a copy of the currently active runtime configuration.
-func Get() Config {
- cfg := current.Load()
- if cfg == nil {
- return NewFlags()
- }
- return cfg.clone()
+// Parse parses CLI flags from os.Args and returns the resulting Config.
+// It uses the global flag.CommandLine set, so it must be called once at
+// program startup before any other flag parsing occurs.
+func Parse() (Config, error) {
+ return parseFromFlagSet(flag.CommandLine, os.Args[1:])
}
-func setCurrent(cfg Config) {
- snapshot := cfg.clone()
- current.Store(&snapshot)
-}
-
-func updateCurrent(update func(*Config)) {
- for {
- old := current.Load()
- next := NewFlags()
- if old != nil {
- next = old.clone()
- }
- update(&next)
- snapshot := next.clone()
- if current.CompareAndSwap(old, &snapshot) {
- return
- }
- }
-}
-
-// SetPidFilter updates the active PID filter used for subsequent tracing runs.
-func SetPidFilter(pid int) {
- updateCurrent(func(cfg *Config) {
- cfg.PidFilter = pid
- })
-}
-
-// SetTidFilter updates the active TID filter used for subsequent tracing runs.
-func SetTidFilter(tid int) {
- updateCurrent(func(cfg *Config) {
- cfg.TidFilter = tid
- })
-}
-
-// SetTUIExportEnable toggles TUI stream export file writing.
-func SetTUIExportEnable(enabled bool) {
- updateCurrent(func(cfg *Config) {
- cfg.TUIExportEnable = enabled
- })
-}
-
-// Parse parses CLI flags once and updates the current runtime configuration.
-func Parse() error {
- once.Do(func() {
- parseErr = parse()
- })
- return parseErr
-}
-
-func parse() error {
+// parseFromFlagSet parses flags into a new Config using the provided FlagSet
+// and argument list. It is factored out of Parse to allow tests to inject a
+// fresh FlagSet and custom argument slices without touching global state.
+func parseFromFlagSet(fs *flag.FlagSet, args []string) (Config, error) {
cfg := NewFlags()
validFields := collapse.ValidFields()
validCounts := collapse.ValidCountFields()
- flag.IntVar(&cfg.PidFilter, "pid", cfg.PidFilter, "Filter for processes ID")
- flag.IntVar(&cfg.TidFilter, "tid", cfg.TidFilter, "Filter for thread ID")
- flag.IntVar(&cfg.EventMapSize, "mapSize", cfg.EventMapSize, "BPF FD event ring buffer map size")
- flag.IntVar(&cfg.Duration, "duration", cfg.Duration, "Probe duration in seconds")
+ fs.IntVar(&cfg.PidFilter, "pid", cfg.PidFilter, "Filter for processes ID")
+ fs.IntVar(&cfg.TidFilter, "tid", cfg.TidFilter, "Filter for thread ID")
+ fs.IntVar(&cfg.EventMapSize, "mapSize", cfg.EventMapSize, "BPF FD event ring buffer map size")
+ fs.IntVar(&cfg.Duration, "duration", cfg.Duration, "Probe duration in seconds")
- flag.StringVar(&cfg.CommFilter, "comm", "", "Command to filter for")
- flag.StringVar(&cfg.PathFilter, "path", "", "Path to filter for")
+ fs.StringVar(&cfg.CommFilter, "comm", "", "Command to filter for")
+ fs.StringVar(&cfg.PathFilter, "path", "", "Path to filter for")
- flag.BoolVar(&cfg.PprofEnable, "pprof", false, "Enable profiling")
+ fs.BoolVar(&cfg.PprofEnable, "pprof", false, "Enable profiling")
- tracepointsToAttach := flag.String("tps", "", "Comma separated list regexes for tracepoints to load")
- tracepointsToExclude := flag.String("tpsExclude", "", "Comma separated list regexes for tracepoints to exclude")
+ tracepointsToAttach := fs.String("tps", "", "Comma separated list regexes for tracepoints to load")
+ tracepointsToExclude := fs.String("tpsExclude", "", "Comma separated list regexes for tracepoints to exclude")
- flag.BoolVar(&cfg.PlainMode, "plain", false, "Enable plain CSV output mode (disable TUI)")
- flag.BoolVar(&cfg.FlamegraphOutput, "flamegraph", false, "Write aggregated .ior.zst output for trace/integration workflows")
- flag.StringVar(&cfg.ParquetPath, "parquet", cfg.ParquetPath, "Write all traced syscall rows directly to a parquet file in headless mode (skip the TUI; incompatible with -plain, -flamegraph, --testflames, --testliveflames, and content filters)")
- flag.StringVar(&cfg.OutputName, "name", cfg.OutputName, "Base name for .ior.zst trace output files")
- flag.BoolVar(&cfg.TestFlames, "testflames", false, "Run TUI with static synthetic flamegraph data for keyboard-navigation testing")
- flag.BoolVar(&cfg.TestLiveFlames, "testliveflames", false, "Run TUI with continuously-updating synthetic flamegraph data for live keyboard-navigation testing")
- flag.DurationVar(&cfg.LiveInterval, "live-interval", cfg.LiveInterval, "Synthetic live flamegraph refresh interval for --testliveflames")
- flag.BoolVar(&cfg.TUIExportEnable, "tuiExport", cfg.TUIExportEnable, "Enable TUI CSV snapshot export files (separate from Parquet recording)")
- flag.DurationVar(&cfg.ResetTimer, "resetTimer", cfg.ResetTimer,
+ fs.BoolVar(&cfg.PlainMode, "plain", false, "Enable plain CSV output mode (disable TUI)")
+ fs.BoolVar(&cfg.FlamegraphOutput, "flamegraph", false, "Write aggregated .ior.zst output for trace/integration workflows")
+ fs.StringVar(&cfg.ParquetPath, "parquet", cfg.ParquetPath, "Write all traced syscall rows directly to a parquet file in headless mode (skip the TUI; incompatible with -plain, -flamegraph, --testflames, --testliveflames, and content filters)")
+ fs.StringVar(&cfg.OutputName, "name", cfg.OutputName, "Base name for .ior.zst trace output files")
+ fs.BoolVar(&cfg.TestFlames, "testflames", false, "Run TUI with static synthetic flamegraph data for keyboard-navigation testing")
+ fs.BoolVar(&cfg.TestLiveFlames, "testliveflames", false, "Run TUI with continuously-updating synthetic flamegraph data for live keyboard-navigation testing")
+ fs.DurationVar(&cfg.LiveInterval, "live-interval", cfg.LiveInterval, "Synthetic live flamegraph refresh interval for --testliveflames")
+ fs.BoolVar(&cfg.TUIExportEnable, "tuiExport", cfg.TUIExportEnable, "Enable TUI CSV snapshot export files (separate from Parquet recording)")
+ fs.DurationVar(&cfg.ResetTimer, "resetTimer", cfg.ResetTimer,
"Auto-reset interval for aggregate dashboard state (flamegraph trie + stats engine); set to 0 to disable")
- flag.BoolVar(&cfg.ShowVersion, "version", false, "Print version banner and exit")
- fields := flag.String("fields", "",
+ fs.BoolVar(&cfg.ShowVersion, "version", false, "Print version banner and exit")
+ fields := fs.String("fields", "",
fmt.Sprintf("Comma separated list of fields to collapse, valid are: %v", validFields))
- flag.StringVar(&cfg.CountField, "count", cfg.CountField,
+ fs.StringVar(&cfg.CountField, "count", cfg.CountField,
fmt.Sprintf("Count field to collapse, valid are: %v", validCounts))
- if err := flag.CommandLine.Parse(os.Args[1:]); err != nil {
- return err
+
+ if err := fs.Parse(args); err != nil {
+ return Config{}, err
}
var err error
cfg.TracepointsToAttach, err = extractTracepointFlags(*tracepointsToAttach)
if err != nil {
- return err
+ return Config{}, err
}
cfg.TracepointsToExclude, err = extractTracepointFlags(*tracepointsToExclude)
if err != nil {
- return err
+ return Config{}, err
}
// Keep this list empty by default.
@@ -250,22 +192,21 @@ func parse() error {
for _, field := range cfg.CollapsedFields {
if !collapse.IsValidField(field) {
- return fmt.Errorf("invalid field for collapse: %s", field)
+ return Config{}, fmt.Errorf("invalid field for collapse: %s", field)
}
}
if !collapse.IsValidCountField(cfg.CountField) {
- return fmt.Errorf("invalid count field: %s", cfg.CountField)
+ return Config{}, fmt.Errorf("invalid count field: %s", cfg.CountField)
}
// A negative reset timer would imply auto-resets in the past, which is
// nonsensical. 0 disables, anything positive enables.
if cfg.ResetTimer < 0 {
- return fmt.Errorf("invalid resetTimer: %s (must be >= 0; 0 disables)", cfg.ResetTimer)
+ return Config{}, fmt.Errorf("invalid resetTimer: %s (must be >= 0; 0 disables)", cfg.ResetTimer)
}
- setCurrent(cfg)
- return nil
+ return cfg, nil
}
func extractTracepointFlags(tracepoints string) (regexes []*regexp.Regexp, err error) {
diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go
index 77c167c..1630554 100644
--- a/internal/flags/flags_test.go
+++ b/internal/flags/flags_test.go
@@ -3,39 +3,19 @@ package flags
import (
"flag"
"io"
- "os"
"strings"
"testing"
"time"
)
+// parseForTest builds a fresh FlagSet and parses the given args, returning
+// the resulting Config. It avoids touching any global state so tests can run
+// in parallel without interfering with each other.
func parseForTest(t *testing.T, args ...string) (Config, error) {
t.Helper()
-
- oldCommandLine := flag.CommandLine
- oldArgs := os.Args
- oldCurrent := Get()
- oldParseErr := parseErr
-
fs := flag.NewFlagSet("ior-test", flag.ContinueOnError)
fs.SetOutput(io.Discard)
- flag.CommandLine = fs
- os.Args = append([]string{"ior"}, args...)
-
- setCurrent(NewFlags())
- parseErr = nil
-
- err := parse()
- cfg := Get()
-
- t.Cleanup(func() {
- flag.CommandLine = oldCommandLine
- os.Args = oldArgs
- setCurrent(oldCurrent)
- parseErr = oldParseErr
- })
-
- return cfg, err
+ return parseFromFlagSet(fs, args)
}
func TestParseLiveIntervalAndPID(t *testing.T) {
@@ -50,8 +30,8 @@ func TestParseLiveIntervalAndPID(t *testing.T) {
if cfg.PidFilter != 1234 {
t.Fatalf("pid filter = %d, want 1234", cfg.PidFilter)
}
- if got := Get().GetPidFilter(); got != 1234 {
- t.Fatalf("Get().GetPidFilter() = %d, want 1234", got)
+ if got := cfg.GetPidFilter(); got != 1234 {
+ t.Fatalf("cfg.GetPidFilter() = %d, want 1234", got)
}
}