diff options
| author | Paul Buetow <paul@buetow.org> | 2026-05-13 09:35:46 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-05-13 09:35:46 +0300 |
| commit | e9c747c607e14d8b8f69547aa3c0f3b079c20796 (patch) | |
| tree | b5c461f3ace5e7785ed9716fdc136ff985475440 | |
| parent | b086d3680e845a94c6662811c7f40e9592c3dec6 (diff) | |
replace package-level test doubles in ior.go with constructor injection (DIP)
Introduce runnerDeps struct to bundle all injectable function dependencies
(getEUID, runTrace, runParquet, runTraceWithContext, runTUI*) that were
previously package-level vars overridden in tests. The modeRegistry now
carries a runnerDeps instance and passes it to each handler's run() method,
eliminating global state mutation in tests.
- Add runnerDeps struct and defaultRunnerDeps() constructor in ior_mode_registry.go
- Convert modeRegistry from a []modeHandler slice type to a struct with
handlers + deps fields; add newModeRegistry(deps) constructor
- Update modeHandler.run() signature to accept runnerDeps; handlers call
deps.getEUID / deps.runTrace etc. instead of globals
- Update SetTUIRunners to write into defaultRegistry.deps instead of
package-level vars
- Add dispatchRunWithDeps helper for test isolation without global mutation
- Remove root-privilege check from runTraceWithContext and runHeadlessParquet;
each mode handler owns the EUID gate via deps.getEUID
- Rewrite ior_mode_test.go: replace save/restore patterns with stubDeps()
helper and dispatchRunWithDeps; add three new root-privilege tests
replacing the removed TestRunTraceWithContextRequiresRoot
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
| -rw-r--r-- | internal/ior.go | 39 | ||||
| -rw-r--r-- | internal/ior_mode_registry.go | 114 | ||||
| -rw-r--r-- | internal/ior_mode_test.go | 297 | ||||
| -rw-r--r-- | internal/ior_parquet_sink.go | 6 |
4 files changed, 243 insertions, 213 deletions
diff --git a/internal/ior.go b/internal/ior.go index f1b159d..cddb657 100644 --- a/internal/ior.go +++ b/internal/ior.go @@ -29,32 +29,20 @@ import ( // never imports the TUI layer. type tuiRunFunc func(flags.Config, runtime.TraceStarter) error -var ( - runTraceFn = runTrace - runParquetFn = runHeadlessParquet - runTraceWithContextFn = runTraceWithContext - // runTUIFn, runTUITestFlamesFn, runTUITestLiveFlamesFn are injected by - // main (via SetTUIRunners) before Run is called. They default to nil so - // that test files can replace individual runners without importing tui. - runTUIFn tuiRunFunc - runTUITestFlamesFn tuiRunFunc - runTUITestLiveFlamesFn tuiRunFunc - getEUID = os.Geteuid - - errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)") -) +var errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)") // SetTUIRunners injects the concrete TUI runner functions from the cmd layer -// so the core internal package does not need to import the TUI packages. -// This must be called before Run when running in TUI mode. +// into the default registry so the core internal package does not need to +// import the TUI packages. This must be called before Run when running in +// TUI mode. func SetTUIRunners( runTUI tuiRunFunc, runTUITestFlames tuiRunFunc, runTUITestLiveFlames tuiRunFunc, ) { - runTUIFn = runTUI - runTUITestFlamesFn = runTUITestFlames - runTUITestLiveFlamesFn = runTUITestLiveFlames + defaultRegistry.deps.runTUI = runTUI + defaultRegistry.deps.runTUITestFlames = runTUITestFlames + defaultRegistry.deps.runTUITestLiveFlames = runTUITestLiveFlames } // streamEventSink is the write-side contract for the stream ring buffer used @@ -75,6 +63,13 @@ func dispatchRun(cfg flags.Config) error { return defaultRegistry.dispatch(cfg) } +// dispatchRunWithDeps constructs an isolated registry from the given deps and +// dispatches cfg through it. Used by tests to inject stub functions without +// mutating the global defaultRegistry. +func dispatchRunWithDeps(cfg flags.Config, deps runnerDeps) error { + return newModeRegistry(deps).dispatch(cfg) +} + // validateRunConfig runs all cross-mode constraint checks without running // any mode. It is a thin wrapper around defaultRegistry.validate so that // callers (and tests) that only want validation do not need to know about @@ -448,10 +443,10 @@ func finaliseTrace(recorder *flamegraph.Recorder, profiling *profilingControl, t return nil } +// runTraceWithContext is the concrete BPF trace implementation. Root privilege +// is checked by the mode handler (via runnerDeps.getEUID) before calling this +// function; the handler is the authoritative place for the EUID gate. func runTraceWithContext(parentCtx context.Context, cfg flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { - if getEUID() != 0 { - return errRootPrivilegesRequired - } verbose := started == nil logln := newLogger(verbose) diff --git a/internal/ior_mode_registry.go b/internal/ior_mode_registry.go index 6d04052..3eb59b5 100644 --- a/internal/ior_mode_registry.go +++ b/internal/ior_mode_registry.go @@ -1,11 +1,55 @@ package internal import ( + "context" "errors" + "os" "ior/internal/flags" ) +// runnerDeps bundles all injectable function dependencies used by the mode +// registry and its handlers. Using a struct instead of package-level vars +// allows tests to substitute individual functions without mutating global +// state (Dependency Inversion Principle). +type runnerDeps struct { + // getEUID returns the effective user ID of the calling process. + // Overridden in tests to simulate root or non-root execution. + getEUID func() int + + // runTrace executes a headless plain/flamegraph trace (no TUI). + runTrace func(flags.Config) error + + // runParquet executes a headless Parquet recording run (no TUI). + runParquet func(flags.Config) error + + // runTraceWithContext drives a BPF trace with a parent context, started + // signal channel, and event-loop configurator. Used by the TUI starter. + runTraceWithContext func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error + + // runTUI launches the interactive TUI backed by a live BPF trace. + // Injected at startup via SetTUIRunners so that the core package never + // imports the TUI layer. + runTUI tuiRunFunc + + // runTUITestFlames launches the TUI seeded with static synthetic flame data. + runTUITestFlames tuiRunFunc + + // runTUITestLiveFlames launches the TUI fed by a live synthetic flame goroutine. + runTUITestLiveFlames tuiRunFunc +} + +// defaultRunnerDeps returns the production function set. +func defaultRunnerDeps() runnerDeps { + return runnerDeps{ + getEUID: os.Geteuid, + runTrace: runTrace, + runParquet: runHeadlessParquet, + runTraceWithContext: runTraceWithContext, + // TUI runners are nil until SetTUIRunners is called from cmd/ior/main.go. + } +} + // modeHandler describes a single execution mode for the ior binary. // Each mode knows how to recognise itself (match), enforce its // invariants (validate), and run (run). The registry evaluates @@ -18,33 +62,47 @@ type modeHandler interface { // (pre-root modes are checked first and return early before requiring root). validate(cfg flags.Config) error // run executes the mode using the supplied config. - run(cfg flags.Config) error + run(cfg flags.Config, deps runnerDeps) error } -// modeRegistry is an ordered list of modeHandlers. -// dispatchRun and validateRunConfig iterate through it. -type modeRegistry []modeHandler +// modeRegistry is an ordered list of modeHandlers paired with the +// injectable function dependencies they share. Storing deps on the registry +// (rather than as package-level vars) lets tests construct isolated +// registries without mutating global state. +type modeRegistry struct { + handlers []modeHandler + deps runnerDeps +} -// defaultRegistry is the canonical ordered registry used at runtime. +// newModeRegistry constructs a registry with the standard handler order and +// the provided dependencies. // Modes are evaluated first-match-wins, so more specific modes (e.g., -// testFlames) are registered before more general ones (e.g., TUI default). -var defaultRegistry = modeRegistry{ - &testFlamesModeHandler{}, - &testLiveFlamesModeHandler{}, - &headlessParquetModeHandler{}, - &plainTraceModeHandler{}, - &tuiModeHandler{}, +// testFlames) must be registered before more general ones (e.g., TUI default). +func newModeRegistry(deps runnerDeps) modeRegistry { + return modeRegistry{ + handlers: []modeHandler{ + &testFlamesModeHandler{}, + &testLiveFlamesModeHandler{}, + &headlessParquetModeHandler{}, + &plainTraceModeHandler{}, + &tuiModeHandler{}, + }, + deps: deps, + } } +// defaultRegistry is the canonical ordered registry used at runtime. +var defaultRegistry = newModeRegistry(defaultRunnerDeps()) + // dispatch validates cross-mode constraints, requires root when necessary, // then delegates to the first matching handler in the registry. func (reg modeRegistry) dispatch(cfg flags.Config) error { if err := reg.validate(cfg); err != nil { return err } - for _, h := range reg { + for _, h := range reg.handlers { if h.match(cfg) { - return h.run(cfg) + return h.run(cfg, reg.deps) } } // Registry must always include a catch-all (tuiModeHandler matches everything). @@ -56,7 +114,7 @@ func (reg modeRegistry) dispatch(cfg flags.Config) error { // combination errors (e.g., parquet + plain is rejected regardless of which // handler ultimately runs). func (reg modeRegistry) validate(cfg flags.Config) error { - for _, h := range reg { + for _, h := range reg.handlers { if err := h.validate(cfg); err != nil { return err } @@ -93,8 +151,8 @@ func (h *testFlamesModeHandler) validate(cfg flags.Config) error { return nil } -func (h *testFlamesModeHandler) run(cfg flags.Config) error { - return runTUITestFlamesFn(cfg, tuiTestFlamesStarter(cfg)) +func (h *testFlamesModeHandler) run(cfg flags.Config, deps runnerDeps) error { + return deps.runTUITestFlames(cfg, tuiTestFlamesStarter(cfg)) } // --- testLiveFlamesModeHandler --- @@ -123,8 +181,8 @@ func (h *testLiveFlamesModeHandler) validate(cfg flags.Config) error { return nil } -func (h *testLiveFlamesModeHandler) run(cfg flags.Config) error { - return runTUITestLiveFlamesFn(cfg, tuiTestLiveFlamesStarter(cfg)) +func (h *testLiveFlamesModeHandler) run(cfg flags.Config, deps runnerDeps) error { + return deps.runTUITestLiveFlames(cfg, tuiTestLiveFlamesStarter(cfg)) } // --- headlessParquetModeHandler --- @@ -161,11 +219,11 @@ func (h *headlessParquetModeHandler) validate(cfg flags.Config) error { return nil } -func (h *headlessParquetModeHandler) run(cfg flags.Config) error { - if getEUID() != 0 { +func (h *headlessParquetModeHandler) run(cfg flags.Config, deps runnerDeps) error { + if deps.getEUID() != 0 { return errRootPrivilegesRequired } - return runParquetFn(cfg) + return deps.runParquet(cfg) } // --- plainTraceModeHandler --- @@ -186,11 +244,11 @@ func (h *plainTraceModeHandler) validate(cfg flags.Config) error { return nil } -func (h *plainTraceModeHandler) run(cfg flags.Config) error { - if getEUID() != 0 { +func (h *plainTraceModeHandler) run(cfg flags.Config, deps runnerDeps) error { + if deps.getEUID() != 0 { return errRootPrivilegesRequired } - return runTraceFn(cfg) + return deps.runTrace(cfg) } // --- tuiModeHandler --- @@ -208,9 +266,9 @@ func (h *tuiModeHandler) validate(_ flags.Config) error { return nil } -func (h *tuiModeHandler) run(cfg flags.Config) error { - if getEUID() != 0 { +func (h *tuiModeHandler) run(cfg flags.Config, deps runnerDeps) error { + if deps.getEUID() != 0 { return errRootPrivilegesRequired } - return runTUIFn(cfg, tuiTraceStarterFromRunTrace(cfg, runTraceWithContextFn)) + return deps.runTUI(cfg, tuiTraceStarterFromRunTrace(cfg, deps.runTraceWithContext)) } diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go index 0697ada..5ba2894 100644 --- a/internal/ior_mode_test.go +++ b/internal/ior_mode_test.go @@ -25,6 +25,22 @@ import ( parquetgo "github.com/parquet-go/parquet-go" ) +// stubDeps returns a runnerDeps with safe no-op stubs for every function +// field. Individual tests override only the functions they care about, +// keeping test setup concise and making it easy to add new fields without +// updating every test. +func stubDeps() runnerDeps { + return runnerDeps{ + getEUID: func() int { return 0 }, + runTrace: func(flags.Config) error { return nil }, + runParquet: func(flags.Config) error { return nil }, + runTraceWithContext: func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error { return nil }, + runTUI: func(flags.Config, runtime.TraceStarter) error { return nil }, + runTUITestFlames: func(flags.Config, runtime.TraceStarter) error { return nil }, + runTUITestLiveFlames: func(flags.Config, runtime.TraceStarter) error { return nil }, + } +} + func TestShouldRunTraceMode(t *testing.T) { base := flags.Config{} @@ -86,180 +102,123 @@ func TestShouldAutoStopByDuration(t *testing.T) { if shouldAutoStopByDuration(withPprof) { t.Fatalf("expected pprof flag alone not to auto-stop by duration") } - } func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunParquet := runParquetFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runParquetFn = origRunParquet - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false - tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runParquetFn = func(flags.Config) error { - t.Fatalf("runParquetFn should not be called in plain trace mode") + deps.runParquet = func(flags.Config) error { + t.Fatalf("runParquet should not be called in plain trace mode") return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + tuiCalled := false + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called in trace mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called in trace mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called in trace mode") return nil } cfg := flags.Config{PlainMode: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !traceCalled { - t.Fatalf("expected runTraceFn to be called") + t.Fatalf("expected runTrace to be called") } if tuiCalled { - t.Fatalf("did not expect runTUIFn to be called") + t.Fatalf("did not expect runTUI to be called") } } func TestDispatchRunUsesHeadlessParquetModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunParquet := runParquetFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runParquetFn = origRunParquet - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false parquetCalled := false tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runParquetFn = func(flags.Config) error { + deps.runParquet = func(flags.Config) error { parquetCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called in parquet mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called in parquet mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called in parquet mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called in parquet mode") return nil } cfg := flags.Config{ParquetPath: "trace.parquet"} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !parquetCalled { - t.Fatalf("expected runParquetFn to be called") + t.Fatalf("expected runParquet to be called") } if traceCalled { - t.Fatalf("did not expect runTraceFn to be called") + t.Fatalf("did not expect runTrace to be called") } if tuiCalled { - t.Fatalf("did not expect runTUIFn to be called") + t.Fatalf("did not expect runTUI to be called") } } func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for regular TUI mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for regular TUI mode") return nil } cfg := flags.Config{PprofEnable: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn when only -pprof is enabled") + t.Fatalf("did not expect runTrace when only -pprof is enabled") } if !tuiCalled { - t.Fatalf("expected runTUIFn to be called") + t.Fatalf("expected runTUI to be called") } } func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { - origRunTraceWithContext := runTraceWithContextFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceWithContextFn = origRunTraceWithContext - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceDone := make(chan struct{}, 1) - runTraceWithContextFn = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { + deps := stubDeps() + deps.runTraceWithContext = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { if configure != nil { configure(&eventLoop{}) } @@ -269,7 +228,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } tuiCalled := false - runTUIFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUI = func(_ flags.Config, starter runtime.TraceStarter) error { tuiCalled = true if starter == nil { t.Fatalf("expected non-nil starter") @@ -279,108 +238,88 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for normal starter path") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for normal starter path") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for normal starter path") return nil } cfg := flags.Config{} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !tuiCalled { - t.Fatalf("expected runTUIFn to be called") + t.Fatalf("expected runTUI to be called") } select { case <-traceDone: case <-time.After(200 * time.Millisecond): - t.Fatalf("expected starter to launch runTraceWithContextFn") + t.Fatalf("expected starter to launch runTraceWithContext") } } func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - }() - traceCalled := false regularTUICalled := false testFlamesCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUITestFlames = func(_ flags.Config, starter runtime.TraceStarter) error { testFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test flames mode") } return starter(context.Background()) } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for --testflames") return nil } cfg := flags.Config{TestFlames: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn for test flames mode") + t.Fatalf("did not expect runTrace for test flames mode") } if regularTUICalled { - t.Fatalf("did not expect runTUIFn for test flames mode") + t.Fatalf("did not expect runTUI for test flames mode") } if !testFlamesCalled { - t.Fatalf("expected runTUITestFlamesFn to be called") + t.Fatalf("expected runTUITestFlames to be called") } } func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - }() - traceCalled := false regularTUICalled := false testLiveFlamesCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for --testliveflames") return nil } - runTUITestLiveFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUITestLiveFlames = func(_ flags.Config, starter runtime.TraceStarter) error { testLiveFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test live flames mode") @@ -389,17 +328,68 @@ func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { } cfg := flags.Config{TestLiveFlames: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn for test live flames mode") + t.Fatalf("did not expect runTrace for test live flames mode") } if regularTUICalled { - t.Fatalf("did not expect runTUIFn for test live flames mode") + t.Fatalf("did not expect runTUI for test live flames mode") } if !testLiveFlamesCalled { - t.Fatalf("expected runTUITestLiveFlamesFn to be called") + t.Fatalf("expected runTUITestLiveFlames to be called") + } +} + +// TestDispatchRunRequiresRootForTUI verifies that the TUI mode handler +// enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForTUI(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUI must not be called when not root") + return nil + } + + cfg := flags.Config{} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + +// TestDispatchRunRequiresRootForPlainTrace verifies that the plain trace +// mode handler enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForPlainTrace(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runTrace = func(flags.Config) error { + t.Fatalf("runTrace must not be called when not root") + return nil + } + + cfg := flags.Config{PlainMode: true} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + +// TestDispatchRunRequiresRootForParquet verifies that the headless Parquet +// mode handler enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForParquet(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runParquet = func(flags.Config) error { + t.Fatalf("runParquet must not be called when not root") + return nil + } + + cfg := flags.Config{ParquetPath: "trace.parquet"} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) } } @@ -536,17 +526,6 @@ func TestBuildTestLiveFlamesRuntimeContinuouslyUpdatesLiveTrie(t *testing.T) { }) } -func TestRunTraceWithContextRequiresRoot(t *testing.T) { - origGetEUID := getEUID - defer func() { getEUID = origGetEUID }() - - getEUID = func() int { return 1000 } - err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil) - if !errors.Is(err, errRootPrivilegesRequired) { - t.Fatalf("expected root-required error, got %v", err) - } -} - func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { starter := tuiTraceStarterFromRunTrace( flags.NewFlags(), diff --git a/internal/ior_parquet_sink.go b/internal/ior_parquet_sink.go index 83ca769..b2a1439 100644 --- a/internal/ior_parquet_sink.go +++ b/internal/ior_parquet_sink.go @@ -90,11 +90,9 @@ func headlessParquetTraceConfig(cfg flags.Config) flags.Config { } // runHeadlessParquet records all traced syscalls directly to a Parquet file -// without starting the TUI. +// without starting the TUI. Root privilege is checked by the mode handler +// (via runnerDeps.getEUID) before this function is invoked. func runHeadlessParquet(cfg flags.Config) error { - if getEUID() != 0 { - return errRootPrivilegesRequired - } cfg = headlessParquetTraceConfig(cfg) logln := newLogger(true) |
