summaryrefslogtreecommitdiff
path: root/internal/ior_mode_test.go
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-05-13 09:35:46 +0300
committerPaul Buetow <paul@buetow.org>2026-05-13 09:35:46 +0300
commite9c747c607e14d8b8f69547aa3c0f3b079c20796 (patch)
treeb5c461f3ace5e7785ed9716fdc136ff985475440 /internal/ior_mode_test.go
parentb086d3680e845a94c6662811c7f40e9592c3dec6 (diff)
replace package-level test doubles in ior.go with constructor injection (DIP)
Introduce runnerDeps struct to bundle all injectable function dependencies (getEUID, runTrace, runParquet, runTraceWithContext, runTUI*) that were previously package-level vars overridden in tests. The modeRegistry now carries a runnerDeps instance and passes it to each handler's run() method, eliminating global state mutation in tests. - Add runnerDeps struct and defaultRunnerDeps() constructor in ior_mode_registry.go - Convert modeRegistry from a []modeHandler slice type to a struct with handlers + deps fields; add newModeRegistry(deps) constructor - Update modeHandler.run() signature to accept runnerDeps; handlers call deps.getEUID / deps.runTrace etc. instead of globals - Update SetTUIRunners to write into defaultRegistry.deps instead of package-level vars - Add dispatchRunWithDeps helper for test isolation without global mutation - Remove root-privilege check from runTraceWithContext and runHeadlessParquet; each mode handler owns the EUID gate via deps.getEUID - Rewrite ior_mode_test.go: replace save/restore patterns with stubDeps() helper and dispatchRunWithDeps; add three new root-privilege tests replacing the removed TestRunTraceWithContextRequiresRoot Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Diffstat (limited to 'internal/ior_mode_test.go')
-rw-r--r--internal/ior_mode_test.go297
1 files changed, 138 insertions, 159 deletions
diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go
index 0697ada..5ba2894 100644
--- a/internal/ior_mode_test.go
+++ b/internal/ior_mode_test.go
@@ -25,6 +25,22 @@ import (
parquetgo "github.com/parquet-go/parquet-go"
)
+// stubDeps returns a runnerDeps with safe no-op stubs for every function
+// field. Individual tests override only the functions they care about,
+// keeping test setup concise and making it easy to add new fields without
+// updating every test.
+func stubDeps() runnerDeps {
+ return runnerDeps{
+ getEUID: func() int { return 0 },
+ runTrace: func(flags.Config) error { return nil },
+ runParquet: func(flags.Config) error { return nil },
+ runTraceWithContext: func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error { return nil },
+ runTUI: func(flags.Config, runtime.TraceStarter) error { return nil },
+ runTUITestFlames: func(flags.Config, runtime.TraceStarter) error { return nil },
+ runTUITestLiveFlames: func(flags.Config, runtime.TraceStarter) error { return nil },
+ }
+}
+
func TestShouldRunTraceMode(t *testing.T) {
base := flags.Config{}
@@ -86,180 +102,123 @@ func TestShouldAutoStopByDuration(t *testing.T) {
if shouldAutoStopByDuration(withPprof) {
t.Fatalf("expected pprof flag alone not to auto-stop by duration")
}
-
}
func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunParquet := runParquetFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceFn = origRunTrace
- runParquetFn = origRunParquet
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceCalled := false
- tuiCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runParquetFn = func(flags.Config) error {
- t.Fatalf("runParquetFn should not be called in plain trace mode")
+ deps.runParquet = func(flags.Config) error {
+ t.Fatalf("runParquet should not be called in plain trace mode")
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ tuiCalled := false
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
tuiCalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called in trace mode")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called in trace mode")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called in trace mode")
return nil
}
cfg := flags.Config{PlainMode: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if !traceCalled {
- t.Fatalf("expected runTraceFn to be called")
+ t.Fatalf("expected runTrace to be called")
}
if tuiCalled {
- t.Fatalf("did not expect runTUIFn to be called")
+ t.Fatalf("did not expect runTUI to be called")
}
}
func TestDispatchRunUsesHeadlessParquetModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunParquet := runParquetFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceFn = origRunTrace
- runParquetFn = origRunParquet
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceCalled := false
parquetCalled := false
tuiCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runParquetFn = func(flags.Config) error {
+ deps.runParquet = func(flags.Config) error {
parquetCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
tuiCalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called in parquet mode")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called in parquet mode")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called in parquet mode")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called in parquet mode")
return nil
}
cfg := flags.Config{ParquetPath: "trace.parquet"}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if !parquetCalled {
- t.Fatalf("expected runParquetFn to be called")
+ t.Fatalf("expected runParquet to be called")
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn to be called")
+ t.Fatalf("did not expect runTrace to be called")
}
if tuiCalled {
- t.Fatalf("did not expect runTUIFn to be called")
+ t.Fatalf("did not expect runTUI to be called")
}
}
func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) {
- origRunTrace := runTraceFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceFn = origRunTrace
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceCalled := false
tuiCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
tuiCalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called for regular TUI mode")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called for regular TUI mode")
return nil
}
cfg := flags.Config{PprofEnable: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn when only -pprof is enabled")
+ t.Fatalf("did not expect runTrace when only -pprof is enabled")
}
if !tuiCalled {
- t.Fatalf("expected runTUIFn to be called")
+ t.Fatalf("expected runTUI to be called")
}
}
func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
- origRunTraceWithContext := runTraceWithContextFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceWithContextFn = origRunTraceWithContext
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceDone := make(chan struct{}, 1)
- runTraceWithContextFn = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error {
+ deps := stubDeps()
+ deps.runTraceWithContext = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error {
if configure != nil {
configure(&eventLoop{})
}
@@ -269,7 +228,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
}
tuiCalled := false
- runTUIFn = func(_ flags.Config, starter runtime.TraceStarter) error {
+ deps.runTUI = func(_ flags.Config, starter runtime.TraceStarter) error {
tuiCalled = true
if starter == nil {
t.Fatalf("expected non-nil starter")
@@ -279,108 +238,88 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
}
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called for normal starter path")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called for normal starter path")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called for normal starter path")
return nil
}
cfg := flags.Config{}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if !tuiCalled {
- t.Fatalf("expected runTUIFn to be called")
+ t.Fatalf("expected runTUI to be called")
}
select {
case <-traceDone:
case <-time.After(200 * time.Millisecond):
- t.Fatalf("expected starter to launch runTraceWithContextFn")
+ t.Fatalf("expected starter to launch runTraceWithContext")
}
}
func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- defer func() {
- runTraceFn = origRunTrace
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- }()
-
traceCalled := false
regularTUICalled := false
testFlamesCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
regularTUICalled = true
return nil
}
- runTUITestFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error {
+ deps.runTUITestFlames = func(_ flags.Config, starter runtime.TraceStarter) error {
testFlamesCalled = true
if starter == nil {
t.Fatalf("expected non-nil starter for test flames mode")
}
return starter(context.Background())
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called for --testflames")
return nil
}
cfg := flags.Config{TestFlames: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn for test flames mode")
+ t.Fatalf("did not expect runTrace for test flames mode")
}
if regularTUICalled {
- t.Fatalf("did not expect runTUIFn for test flames mode")
+ t.Fatalf("did not expect runTUI for test flames mode")
}
if !testFlamesCalled {
- t.Fatalf("expected runTUITestFlamesFn to be called")
+ t.Fatalf("expected runTUITestFlames to be called")
}
}
func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- defer func() {
- runTraceFn = origRunTrace
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- }()
-
traceCalled := false
regularTUICalled := false
testLiveFlamesCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
regularTUICalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called for --testliveflames")
return nil
}
- runTUITestLiveFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error {
+ deps.runTUITestLiveFlames = func(_ flags.Config, starter runtime.TraceStarter) error {
testLiveFlamesCalled = true
if starter == nil {
t.Fatalf("expected non-nil starter for test live flames mode")
@@ -389,17 +328,68 @@ func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) {
}
cfg := flags.Config{TestLiveFlames: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn for test live flames mode")
+ t.Fatalf("did not expect runTrace for test live flames mode")
}
if regularTUICalled {
- t.Fatalf("did not expect runTUIFn for test live flames mode")
+ t.Fatalf("did not expect runTUI for test live flames mode")
}
if !testLiveFlamesCalled {
- t.Fatalf("expected runTUITestLiveFlamesFn to be called")
+ t.Fatalf("expected runTUITestLiveFlames to be called")
+ }
+}
+
+// TestDispatchRunRequiresRootForTUI verifies that the TUI mode handler
+// enforces the root-privilege gate via deps.getEUID.
+func TestDispatchRunRequiresRootForTUI(t *testing.T) {
+ deps := stubDeps()
+ deps.getEUID = func() int { return 1000 } // non-root
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUI must not be called when not root")
+ return nil
+ }
+
+ cfg := flags.Config{}
+ err := dispatchRunWithDeps(cfg, deps)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
+ }
+}
+
+// TestDispatchRunRequiresRootForPlainTrace verifies that the plain trace
+// mode handler enforces the root-privilege gate via deps.getEUID.
+func TestDispatchRunRequiresRootForPlainTrace(t *testing.T) {
+ deps := stubDeps()
+ deps.getEUID = func() int { return 1000 } // non-root
+ deps.runTrace = func(flags.Config) error {
+ t.Fatalf("runTrace must not be called when not root")
+ return nil
+ }
+
+ cfg := flags.Config{PlainMode: true}
+ err := dispatchRunWithDeps(cfg, deps)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
+ }
+}
+
+// TestDispatchRunRequiresRootForParquet verifies that the headless Parquet
+// mode handler enforces the root-privilege gate via deps.getEUID.
+func TestDispatchRunRequiresRootForParquet(t *testing.T) {
+ deps := stubDeps()
+ deps.getEUID = func() int { return 1000 } // non-root
+ deps.runParquet = func(flags.Config) error {
+ t.Fatalf("runParquet must not be called when not root")
+ return nil
+ }
+
+ cfg := flags.Config{ParquetPath: "trace.parquet"}
+ err := dispatchRunWithDeps(cfg, deps)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
}
}
@@ -536,17 +526,6 @@ func TestBuildTestLiveFlamesRuntimeContinuouslyUpdatesLiveTrie(t *testing.T) {
})
}
-func TestRunTraceWithContextRequiresRoot(t *testing.T) {
- origGetEUID := getEUID
- defer func() { getEUID = origGetEUID }()
-
- getEUID = func() int { return 1000 }
- err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil)
- if !errors.Is(err, errRootPrivilegesRequired) {
- t.Fatalf("expected root-required error, got %v", err)
- }
-}
-
func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) {
starter := tuiTraceStarterFromRunTrace(
flags.NewFlags(),