From e9c747c607e14d8b8f69547aa3c0f3b079c20796 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Wed, 13 May 2026 09:35:46 +0300 Subject: replace package-level test doubles in ior.go with constructor injection (DIP) Introduce runnerDeps struct to bundle all injectable function dependencies (getEUID, runTrace, runParquet, runTraceWithContext, runTUI*) that were previously package-level vars overridden in tests. The modeRegistry now carries a runnerDeps instance and passes it to each handler's run() method, eliminating global state mutation in tests. - Add runnerDeps struct and defaultRunnerDeps() constructor in ior_mode_registry.go - Convert modeRegistry from a []modeHandler slice type to a struct with handlers + deps fields; add newModeRegistry(deps) constructor - Update modeHandler.run() signature to accept runnerDeps; handlers call deps.getEUID / deps.runTrace etc. instead of globals - Update SetTUIRunners to write into defaultRegistry.deps instead of package-level vars - Add dispatchRunWithDeps helper for test isolation without global mutation - Remove root-privilege check from runTraceWithContext and runHeadlessParquet; each mode handler owns the EUID gate via deps.getEUID - Rewrite ior_mode_test.go: replace save/restore patterns with stubDeps() helper and dispatchRunWithDeps; add three new root-privilege tests replacing the removed TestRunTraceWithContextRequiresRoot Co-Authored-By: Claude Sonnet 4.6 --- internal/ior_mode_test.go | 297 +++++++++++++++++++++------------------------- 1 file changed, 138 insertions(+), 159 deletions(-) (limited to 'internal/ior_mode_test.go') diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go index 0697ada..5ba2894 100644 --- a/internal/ior_mode_test.go +++ b/internal/ior_mode_test.go @@ -25,6 +25,22 @@ import ( parquetgo "github.com/parquet-go/parquet-go" ) +// stubDeps returns a runnerDeps with safe no-op stubs for every function +// field. Individual tests override only the functions they care about, +// keeping test setup concise and making it easy to add new fields without +// updating every test. +func stubDeps() runnerDeps { + return runnerDeps{ + getEUID: func() int { return 0 }, + runTrace: func(flags.Config) error { return nil }, + runParquet: func(flags.Config) error { return nil }, + runTraceWithContext: func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error { return nil }, + runTUI: func(flags.Config, runtime.TraceStarter) error { return nil }, + runTUITestFlames: func(flags.Config, runtime.TraceStarter) error { return nil }, + runTUITestLiveFlames: func(flags.Config, runtime.TraceStarter) error { return nil }, + } +} + func TestShouldRunTraceMode(t *testing.T) { base := flags.Config{} @@ -86,180 +102,123 @@ func TestShouldAutoStopByDuration(t *testing.T) { if shouldAutoStopByDuration(withPprof) { t.Fatalf("expected pprof flag alone not to auto-stop by duration") } - } func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunParquet := runParquetFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runParquetFn = origRunParquet - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false - tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runParquetFn = func(flags.Config) error { - t.Fatalf("runParquetFn should not be called in plain trace mode") + deps.runParquet = func(flags.Config) error { + t.Fatalf("runParquet should not be called in plain trace mode") return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + tuiCalled := false + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called in trace mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called in trace mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called in trace mode") return nil } cfg := flags.Config{PlainMode: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !traceCalled { - t.Fatalf("expected runTraceFn to be called") + t.Fatalf("expected runTrace to be called") } if tuiCalled { - t.Fatalf("did not expect runTUIFn to be called") + t.Fatalf("did not expect runTUI to be called") } } func TestDispatchRunUsesHeadlessParquetModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunParquet := runParquetFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runParquetFn = origRunParquet - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false parquetCalled := false tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runParquetFn = func(flags.Config) error { + deps.runParquet = func(flags.Config) error { parquetCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called in parquet mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called in parquet mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called in parquet mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called in parquet mode") return nil } cfg := flags.Config{ParquetPath: "trace.parquet"} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !parquetCalled { - t.Fatalf("expected runParquetFn to be called") + t.Fatalf("expected runParquet to be called") } if traceCalled { - t.Fatalf("did not expect runTraceFn to be called") + t.Fatalf("did not expect runTrace to be called") } if tuiCalled { - t.Fatalf("did not expect runTUIFn to be called") + t.Fatalf("did not expect runTUI to be called") } } func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceCalled := false tuiCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for regular TUI mode") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for regular TUI mode") return nil } cfg := flags.Config{PprofEnable: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn when only -pprof is enabled") + t.Fatalf("did not expect runTrace when only -pprof is enabled") } if !tuiCalled { - t.Fatalf("expected runTUIFn to be called") + t.Fatalf("expected runTUI to be called") } } func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { - origRunTraceWithContext := runTraceWithContextFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - origGetEUID := getEUID - defer func() { - runTraceWithContextFn = origRunTraceWithContext - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - getEUID = origGetEUID - }() - getEUID = func() int { return 0 } - traceDone := make(chan struct{}, 1) - runTraceWithContextFn = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { + deps := stubDeps() + deps.runTraceWithContext = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { if configure != nil { configure(&eventLoop{}) } @@ -269,7 +228,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } tuiCalled := false - runTUIFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUI = func(_ flags.Config, starter runtime.TraceStarter) error { tuiCalled = true if starter == nil { t.Fatalf("expected non-nil starter") @@ -279,108 +238,88 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for normal starter path") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for normal starter path") return nil } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for normal starter path") return nil } cfg := flags.Config{} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if !tuiCalled { - t.Fatalf("expected runTUIFn to be called") + t.Fatalf("expected runTUI to be called") } select { case <-traceDone: case <-time.After(200 * time.Millisecond): - t.Fatalf("expected starter to launch runTraceWithContextFn") + t.Fatalf("expected starter to launch runTraceWithContext") } } func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - }() - traceCalled := false regularTUICalled := false testFlamesCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUITestFlames = func(_ flags.Config, starter runtime.TraceStarter) error { testFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test flames mode") } return starter(context.Background()) } - runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames") + deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestLiveFlames should not be called for --testflames") return nil } cfg := flags.Config{TestFlames: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn for test flames mode") + t.Fatalf("did not expect runTrace for test flames mode") } if regularTUICalled { - t.Fatalf("did not expect runTUIFn for test flames mode") + t.Fatalf("did not expect runTUI for test flames mode") } if !testFlamesCalled { - t.Fatalf("expected runTUITestFlamesFn to be called") + t.Fatalf("expected runTUITestFlames to be called") } } func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { - origRunTrace := runTraceFn - origRunTUI := runTUIFn - origRunTUITestFlames := runTUITestFlamesFn - origRunTUITestLiveFlames := runTUITestLiveFlamesFn - defer func() { - runTraceFn = origRunTrace - runTUIFn = origRunTUI - runTUITestFlamesFn = origRunTUITestFlames - runTUITestLiveFlamesFn = origRunTUITestLiveFlames - }() - traceCalled := false regularTUICalled := false testLiveFlamesCalled := false - runTraceFn = func(flags.Config) error { + deps := stubDeps() + deps.runTrace = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(flags.Config, runtime.TraceStarter) error { + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error { - t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames") + deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUITestFlames should not be called for --testliveflames") return nil } - runTUITestLiveFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error { + deps.runTUITestLiveFlames = func(_ flags.Config, starter runtime.TraceStarter) error { testLiveFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test live flames mode") @@ -389,17 +328,68 @@ func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { } cfg := flags.Config{TestLiveFlames: true} - if err := dispatchRun(cfg); err != nil { - t.Fatalf("dispatchRun returned error: %v", err) + if err := dispatchRunWithDeps(cfg, deps); err != nil { + t.Fatalf("dispatchRunWithDeps returned error: %v", err) } if traceCalled { - t.Fatalf("did not expect runTraceFn for test live flames mode") + t.Fatalf("did not expect runTrace for test live flames mode") } if regularTUICalled { - t.Fatalf("did not expect runTUIFn for test live flames mode") + t.Fatalf("did not expect runTUI for test live flames mode") } if !testLiveFlamesCalled { - t.Fatalf("expected runTUITestLiveFlamesFn to be called") + t.Fatalf("expected runTUITestLiveFlames to be called") + } +} + +// TestDispatchRunRequiresRootForTUI verifies that the TUI mode handler +// enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForTUI(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runTUI = func(flags.Config, runtime.TraceStarter) error { + t.Fatalf("runTUI must not be called when not root") + return nil + } + + cfg := flags.Config{} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + +// TestDispatchRunRequiresRootForPlainTrace verifies that the plain trace +// mode handler enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForPlainTrace(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runTrace = func(flags.Config) error { + t.Fatalf("runTrace must not be called when not root") + return nil + } + + cfg := flags.Config{PlainMode: true} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) + } +} + +// TestDispatchRunRequiresRootForParquet verifies that the headless Parquet +// mode handler enforces the root-privilege gate via deps.getEUID. +func TestDispatchRunRequiresRootForParquet(t *testing.T) { + deps := stubDeps() + deps.getEUID = func() int { return 1000 } // non-root + deps.runParquet = func(flags.Config) error { + t.Fatalf("runParquet must not be called when not root") + return nil + } + + cfg := flags.Config{ParquetPath: "trace.parquet"} + err := dispatchRunWithDeps(cfg, deps) + if !errors.Is(err, errRootPrivilegesRequired) { + t.Fatalf("expected root-required error, got %v", err) } } @@ -536,17 +526,6 @@ func TestBuildTestLiveFlamesRuntimeContinuouslyUpdatesLiveTrie(t *testing.T) { }) } -func TestRunTraceWithContextRequiresRoot(t *testing.T) { - origGetEUID := getEUID - defer func() { getEUID = origGetEUID }() - - getEUID = func() int { return 1000 } - err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil) - if !errors.Is(err, errRootPrivilegesRequired) { - t.Fatalf("expected root-required error, got %v", err) - } -} - func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { starter := tuiTraceStarterFromRunTrace( flags.NewFlags(), -- cgit v1.2.3