diff options
| author | Paul Buetow <paul@buetow.org> | 2026-05-13 09:35:46 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-05-13 09:35:46 +0300 |
| commit | e9c747c607e14d8b8f69547aa3c0f3b079c20796 (patch) | |
| tree | b5c461f3ace5e7785ed9716fdc136ff985475440 /internal/ior_mode_test.go | |
| parent | b086d3680e845a94c6662811c7f40e9592c3dec6 (diff) | |
replace package-level test doubles in ior.go with constructor injection (DIP)
Introduce runnerDeps struct to bundle all injectable function dependencies
(getEUID, runTrace, runParquet, runTraceWithContext, runTUI*) that were
previously package-level vars overridden in tests. The modeRegistry now
carries a runnerDeps instance and passes it to each handler's run() method,
eliminating global state mutation in tests.
- Add runnerDeps struct and defaultRunnerDeps() constructor in ior_mode_registry.go
- Convert modeRegistry from a []modeHandler slice type to a struct with
handlers + deps fields; add newModeRegistry(deps) constructor
- Update modeHandler.run() signature to accept runnerDeps; handlers call
deps.getEUID / deps.runTrace etc. instead of globals
- Update SetTUIRunners to write into defaultRegistry.deps instead of
package-level vars
- Add dispatchRunWithDeps helper for test isolation without global mutation
- Remove root-privilege check from runTraceWithContext and runHeadlessParquet;
each mode handler owns the EUID gate via deps.getEUID
- Rewrite ior_mode_test.go: replace save/restore patterns with stubDeps()
helper and dispatchRunWithDeps; add three new root-privilege tests
replacing the removed TestRunTraceWithContextRequiresRoot
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/ior_mode_test.go')
| -rw-r--r-- | internal/ior_mode_test.go | 297 |
1 files changed, 138 insertions, 159 deletions
diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go index 0697ada..5ba2894 100644 --- a/internal/ior_mode_test.go +++ b/internal/ior_mode_test.go @@ -25,6 +25,22 @@ import ( parquetgo "github.com/parquet-go/parquet-go" ) +// stubDeps returns a runnerDeps with safe no-op stubs for every function +// field. Individual tests override only the functions they care about, +// keeping test setup concise and making it easy to add new fields without +// updating every test. +func stubDeps() runnerDeps { + return runnerDeps{ + getEUID: func() int { return 0 }, + runTrace: func(flags.Config) error { return nil }, + runParquet: func(flags.Config) error { return nil }, + runTraceWithContext: func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error { return nil }, + runTUI: func(flags.Config, runtime.TraceStarter) error { return nil }, + runTUITestFlames: func(flags.Config, runtime.TraceStarter) error { return nil }, + runTUITestLiveFlames: func(flags.Config, runtime.TraceStarter) error { return nil }, + } +} + func TestShouldRunTraceMode(t *testing.T) { base := flags.Config{} @@ -86,180 +102,123 @@ func TestShouldAutoStopByDuration(t *testing.T) { if shouldAutoStopByDuration(withPprof) { t.Fatalf("expected pprof flag alone not to auto-stop by duration") } - } func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunParquet := runParquetFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runParquetFn = origRunParquet - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false - tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runParquetFn = func(flags.Config) error { - t.Fatalf("runParquetFn should not be called in plain trace mode") + deps.runParquet = func(flags.Config) error { + t.Fatalf("runParquet should not be called in plain trace mode") return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + tuiCalled := false + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called in trace mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called in trace mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called in trace mode") return nil } cfg := flags.Config{PlainMode: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !traceCalled { - t.Fatalf("expected runTraceFn to be called") + t.Fatalf("expected runTrace to be called") } if tuiCalled { - t.Fatalf("did not expect runTUIFn to be called") + t.Fatalf("did not expect runTUI to be called") } } func TestDispatchRunUsesHeadlessParquetModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunParquet := runParquetFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runParquetFn = origRunParquet - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false parquetCalled := false tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runParquetFn = func(flags.Config) error { + deps.runParquet = func(flags.Config) error { parquetCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called in parquet mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called in parquet mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called in parquet mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called in parquet mode") return nil } cfg := flags.Config{ParquetPath: "trace.parquet"} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !parquetCalled { - t.Fatalf("expected runParquetFn to be called") + t.Fatalf("expected runParquet to be called") } if traceCalled { - t.Fatalf("did not expect runTraceFn to be called") + t.Fatalf("did not expect runTrace to be called") } if tuiCalled { - t.Fatalf("did not expect runTUIFn to be called") + t.Fatalf("did not expect runTUI to be called") } } func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for regular TUI mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for regular TUI mode") return nil } cfg := flags.Config{PprofEnable: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn when only -pprof is enabled") + t.Fatalf("did not expect runTrace when only -pprof is enabled") } if !tuiCalled { - t.Fatalf("expected runTUIFn to be called") + t.Fatalf("expected runTUI to be called") } } func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { - origRunTraceWithContext := runTraceWithContextFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceWithContextFn = origRunTraceWithContext - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceDone := make(chan struct{}, 1) - runTraceWithContextFn = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { + deps := stubDeps() + deps.runTraceWithContext = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { if configure != nil { configure(&eventLoop{}) } @@ -269,7 +228,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } tuiCalled := false - runTUIFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUI = func(_ flags.Config, starter runtime.TraceStarter) error { tuiCalled = true if starter == nil { t.Fatalf("expected non-nil starter") @@ -279,108 +238,88 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for normal starter path") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for normal starter path") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for normal starter path") return nil } cfg := flags.Config{} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !tuiCalled { - t.Fatalf("expected runTUIFn to be called") + t.Fatalf("expected runTUI to be called") } select { case <-traceDone: case <-time.After(200 * time.Millisecond): - t.Fatalf("expected starter to launch runTraceWithContextFn") + t.Fatalf("expected starter to launch runTraceWithContext") } } func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - }() - traceCalled := false regularTUICalled := false testFlamesCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUITestFlames = func(_ flags.Config, starter runtime.TraceStarter) error { testFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test flames mode") } return starter(context.Background()) } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for --testflames") return nil } cfg := flags.Config{TestFlames: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn for test flames mode") + t.Fatalf("did not expect runTrace for test flames mode") } if regularTUICalled { - t.Fatalf("did not expect runTUIFn for test flames mode") + t.Fatalf("did not expect runTUI for test flames mode") } if !testFlamesCalled { - t.Fatalf("expected runTUITestFlamesFn to be called") + t.Fatalf("expected runTUITestFlames to be called") } } func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - }() - traceCalled := false regularTUICalled := false testLiveFlamesCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for --testliveflames") return nil } - runTUITestLiveFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUITestLiveFlames = func(_ flags.Config, starter runtime.TraceStarter) error { testLiveFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test live flames mode") @@ -389,17 +328,68 @@ func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { } cfg := flags.Config{TestLiveFlames: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn for test live flames mode") + t.Fatalf("did not expect runTrace for test live flames mode") } if regularTUICalled { - t.Fatalf("did not expect runTUIFn for test live flames mode") + t.Fatalf("did not expect runTUI for test live flames mode") } if !testLiveFlamesCalled { - t.Fatalf("expected runTUITestLiveFlamesFn to be called") + t.Fatalf("expected runTUITestLiveFlames to be called") + } +} + +// TestDispatchRunRequiresRootForTUI verifies that the TUI mode handler +// enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForTUI(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUI must not be called when not root") + return nil + } + + cfg := flags.Config{} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + +// TestDispatchRunRequiresRootForPlainTrace verifies that the plain trace +// mode handler enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForPlainTrace(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runTrace = func(flags.Config) error { + t.Fatalf("runTrace must not be called when not root") + return nil + } + + cfg := flags.Config{PlainMode: true} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + +// TestDispatchRunRequiresRootForParquet verifies that the headless Parquet +// mode handler enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForParquet(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runParquet = func(flags.Config) error { + t.Fatalf("runParquet must not be called when not root") + return nil + } + + cfg := flags.Config{ParquetPath: "trace.parquet"} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) } } @@ -536,17 +526,6 @@ func TestBuildTestLiveFlamesRuntimeContinuouslyUpdatesLiveTrie(t *testing.T) { }) } -func TestRunTraceWithContextRequiresRoot(t *testing.T) { - origGetEUID := getEUID - defer func() { getEUID = origGetEUID }() - - getEUID = func() int { return 1000 } - err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil) - if !errors.Is(err, errRootPrivilegesRequired) { - t.Fatalf("expected root-required error, got %v", err) - } -} - func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { starter := tuiTraceStarterFromRunTrace( flags.NewFlags(), |
