diff options
| author | Paul Buetow <paul@buetow.org> | 2026-02-27 18:33:40 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-02-27 18:33:40 +0200 |
| commit | 3783d23b8d608c3bf4a2dedd6b4bfb9165439bed (patch) | |
| tree | 69bf24794994d4cdd0e01e337de0510f7d5139b8 | |
| parent | 1cf64c3e43b1bdc2b6443fd24db8028f3c96c6da (diff) | |
internal: validate live CLI mode behavior
| -rw-r--r-- | internal/flags/flags_test.go | 76 | ||||
| -rw-r--r-- | internal/flamegraph/liveserver_test.go | 58 | ||||
| -rw-r--r-- | internal/ior.go | 20 | ||||
| -rw-r--r-- | internal/ior_mode_test.go | 42 |
4 files changed, 193 insertions, 3 deletions
diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go new file mode 100644 index 0000000..b4d47d2 --- /dev/null +++ b/internal/flags/flags_test.go @@ -0,0 +1,76 @@ +package flags + +import ( + "flag" + "io" + "os" + "sync" + "testing" + "time" +) + +func parseForTest(t *testing.T, args ...string) Flags { + t.Helper() + + oldCommandLine := flag.CommandLine + oldArgs := os.Args + oldSingleton := singleton + oldOnce := once + oldPID := pidFilter.Load() + oldTID := tidFilter.Load() + oldTUIExport := tuiExportEnable.Load() + + fs := flag.NewFlagSet("ior-test", flag.ContinueOnError) + fs.SetOutput(io.Discard) + flag.CommandLine = fs + os.Args = append([]string{"ior"}, args...) + + singleton = Flags{TUIExportEnable: true} + once = sync.Once{} + pidFilter.Store(-1) + tidFilter.Store(-1) + tuiExportEnable.Store(true) + + parse() + cfg := singleton + + t.Cleanup(func() { + flag.CommandLine = oldCommandLine + os.Args = oldArgs + singleton = oldSingleton + once = oldOnce + pidFilter.Store(oldPID) + tidFilter.Store(oldTID) + tuiExportEnable.Store(oldTUIExport) + }) + + return cfg +} + +func TestParseLiveFlagsAndInterval(t *testing.T) { + cfg := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234") + + if !cfg.LiveFlamegraph { + t.Fatalf("expected -live to enable live mode") + } + if cfg.LiveInterval != 200*time.Millisecond { + t.Fatalf("live interval = %v, want %v", cfg.LiveInterval, 200*time.Millisecond) + } + if cfg.PidFilter != 1234 { + t.Fatalf("pid filter = %d, want 1234", cfg.PidFilter) + } + if got := int(pidFilter.Load()); got != 1234 { + t.Fatalf("global pid filter = %d, want 1234", got) + } +} + +func TestParseLiveDefaults(t *testing.T) { + cfg := parseForTest(t) + + if cfg.LiveFlamegraph { + t.Fatalf("expected live mode disabled by default") + } + if cfg.LiveInterval != time.Second { + t.Fatalf("default live interval = %v, want %v", cfg.LiveInterval, time.Second) + } +} diff --git a/internal/flamegraph/liveserver_test.go b/internal/flamegraph/liveserver_test.go index 09472c5..0d55794 100644 --- a/internal/flamegraph/liveserver_test.go +++ b/internal/flamegraph/liveserver_test.go @@ -2,11 +2,13 @@ package flamegraph import ( "bufio" + "context" "encoding/json" "fmt" "io" "net/http" "net/http/httptest" + "os" "strings" "sync" "testing" @@ -130,6 +132,35 @@ func TestHandleSSEDelayedClientLargeTrieGetsValidSnapshot(t *testing.T) { } } +func TestServeLivePrintsURLAndStopsOnCancel(t *testing.T) { + lt := NewLiveTrie([]string{"comm"}, "count") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + output := captureStdout(t, func() { + errCh := make(chan error, 1) + go func() { + errCh <- ServeLive(ctx, lt, 5*time.Millisecond) + }() + + time.Sleep(40 * time.Millisecond) + cancel() + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("ServeLive returned error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for ServeLive to return") + } + }) + + if !strings.Contains(output, "Live flamegraph available at http://") { + t.Fatalf("expected live URL in output, got %q", output) + } +} + func connectSSE(t *testing.T, url string) *http.Response { t.Helper() client := &http.Client{Timeout: 5 * time.Second} @@ -196,3 +227,30 @@ func decodeSSESnapshot(t *testing.T, data string) trieSnapshot { } return snap } + +func captureStdout(t *testing.T, fn func()) string { + t.Helper() + + oldStdout := os.Stdout + reader, writer, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + + os.Stdout = writer + defer func() { os.Stdout = oldStdout }() + + outCh := make(chan string, 1) + go func() { + var b strings.Builder + _, _ = io.Copy(&b, reader) + outCh <- b.String() + }() + + fn() + + _ = writer.Close() + out := <-outCh + _ = reader.Close() + return out +} diff --git a/internal/ior.go b/internal/ior.go index bf0fb1f..cdddc24 100644 --- a/internal/ior.go +++ b/internal/ior.go @@ -4,6 +4,7 @@ import "C" import ( "context" + "errors" "fmt" "os" "os/signal" @@ -35,6 +36,9 @@ var ( runTraceFn = runTrace runTraceWithContextFn = runTraceWithContext runTUIFn = tui.RunWithTraceStarter + getEUID = os.Geteuid + + errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)") ) type tracepointModule interface { @@ -140,12 +144,22 @@ func Run() error { } func dispatchRun(cfg flags.Flags) error { + if err := validateRunConfig(cfg); err != nil { + return err + } if shouldRunTraceMode(cfg) { return runTraceFn() } return runTUIFn(tuiTraceStarterFromRunTrace(runTraceWithContextFn)) } +func validateRunConfig(cfg flags.Flags) error { + if cfg.LiveFlamegraph && cfg.FlamegraphEnable { + return errors.New("-live and -flamegraph are mutually exclusive") + } + return nil +} + func shouldRunTraceMode(cfg flags.Flags) bool { return cfg.PlainMode || cfg.FlamegraphEnable || cfg.LiveFlamegraph || cfg.PprofEnable } @@ -202,6 +216,10 @@ func runTrace() error { } func runTraceWithContext(parentCtx context.Context, started chan<- struct{}, configure func(*eventLoop)) error { + if getEUID() != 0 { + return errRootPrivilegesRequired + } + verbose := started == nil logln := func(...any) {} if verbose { @@ -328,5 +346,5 @@ func signalTraceStarted(started chan<- struct{}) { } func shouldAutoStopByDuration(cfg flags.Flags) bool { - return cfg.PlainMode || cfg.FlamegraphEnable || cfg.PprofEnable + return cfg.PlainMode || cfg.FlamegraphEnable || cfg.LiveFlamegraph || cfg.PprofEnable } diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go index 35d7f43..bac54bd 100644 --- a/internal/ior_mode_test.go +++ b/internal/ior_mode_test.go @@ -68,8 +68,8 @@ func TestShouldAutoStopByDuration(t *testing.T) { withLive := base withLive.LiveFlamegraph = true - if shouldAutoStopByDuration(withLive) { - t.Fatalf("expected live mode not to auto-stop by duration") + if !shouldAutoStopByDuration(withLive) { + t.Fatalf("expected live mode to auto-stop by duration") } } @@ -149,6 +149,44 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } } +func TestDispatchRunRejectsLiveAndFlamegraph(t *testing.T) { + origRunTrace := runTraceFn + origRunTUI := runTUIFn + defer func() { + runTraceFn = origRunTrace + runTUIFn = origRunTUI + }() + + runTraceFn = func() error { + t.Fatalf("runTraceFn should not be called for invalid flag combos") + return nil + } + runTUIFn = func(tui.TraceStarter) error { + t.Fatalf("runTUIFn should not be called for invalid flag combos") + return nil + } + + cfg := flags.Flags{LiveFlamegraph: true, FlamegraphEnable: true} + err := dispatchRun(cfg) + if err == nil { + t.Fatalf("expected error for -live with -flamegraph") + } + if err.Error() != "-live and -flamegraph are mutually exclusive" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestRunTraceWithContextRequiresRoot(t *testing.T) { + origGetEUID := getEUID + defer func() { getEUID = origGetEUID }() + + getEUID = func() int { return 1000 } + err := runTraceWithContext(context.Background(), nil, nil) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { starter := tuiTraceStarterFromRunTrace( func(context.Context, chan<- struct{}, func(*eventLoop)) error { return errors.New("startup failed") }, |
