summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/ior.go39
-rw-r--r--internal/ior_mode_registry.go114
-rw-r--r--internal/ior_mode_test.go297
-rw-r--r--internal/ior_parquet_sink.go6
4 files changed, 243 insertions, 213 deletions
diff --git a/internal/ior.go b/internal/ior.go
index f1b159d..cddb657 100644
--- a/internal/ior.go
+++ b/internal/ior.go
@@ -29,32 +29,20 @@ import (
// never imports the TUI layer.
type tuiRunFunc func(flags.Config, runtime.TraceStarter) error
-var (
- runTraceFn = runTrace
- runParquetFn = runHeadlessParquet
- runTraceWithContextFn = runTraceWithContext
- // runTUIFn, runTUITestFlamesFn, runTUITestLiveFlamesFn are injected by
- // main (via SetTUIRunners) before Run is called. They default to nil so
- // that test files can replace individual runners without importing tui.
- runTUIFn tuiRunFunc
- runTUITestFlamesFn tuiRunFunc
- runTUITestLiveFlamesFn tuiRunFunc
- getEUID = os.Geteuid
-
- errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)")
-)
+var errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)")
// SetTUIRunners injects the concrete TUI runner functions from the cmd layer
-// so the core internal package does not need to import the TUI packages.
-// This must be called before Run when running in TUI mode.
+// into the default registry so the core internal package does not need to
+// import the TUI packages. This must be called before Run when running in
+// TUI mode.
func SetTUIRunners(
runTUI tuiRunFunc,
runTUITestFlames tuiRunFunc,
runTUITestLiveFlames tuiRunFunc,
) {
- runTUIFn = runTUI
- runTUITestFlamesFn = runTUITestFlames
- runTUITestLiveFlamesFn = runTUITestLiveFlames
+ defaultRegistry.deps.runTUI = runTUI
+ defaultRegistry.deps.runTUITestFlames = runTUITestFlames
+ defaultRegistry.deps.runTUITestLiveFlames = runTUITestLiveFlames
}
// streamEventSink is the write-side contract for the stream ring buffer used
@@ -75,6 +63,13 @@ func dispatchRun(cfg flags.Config) error {
return defaultRegistry.dispatch(cfg)
}
+// dispatchRunWithDeps constructs an isolated registry from the given deps and
+// dispatches cfg through it. Used by tests to inject stub functions without
+// mutating the global defaultRegistry.
+func dispatchRunWithDeps(cfg flags.Config, deps runnerDeps) error {
+ return newModeRegistry(deps).dispatch(cfg)
+}
+
// validateRunConfig runs all cross-mode constraint checks without running
// any mode. It is a thin wrapper around defaultRegistry.validate so that
// callers (and tests) that only want validation do not need to know about
@@ -448,10 +443,10 @@ func finaliseTrace(recorder *flamegraph.Recorder, profiling *profilingControl, t
return nil
}
+// runTraceWithContext is the concrete BPF trace implementation. Root privilege
+// is checked by the mode handler (via runnerDeps.getEUID) before calling this
+// function; the handler is the authoritative place for the EUID gate.
func runTraceWithContext(parentCtx context.Context, cfg flags.Config, started chan<- struct{}, configure func(*eventLoop)) error {
- if getEUID() != 0 {
- return errRootPrivilegesRequired
- }
verbose := started == nil
logln := newLogger(verbose)
diff --git a/internal/ior_mode_registry.go b/internal/ior_mode_registry.go
index 6d04052..3eb59b5 100644
--- a/internal/ior_mode_registry.go
+++ b/internal/ior_mode_registry.go
@@ -1,11 +1,55 @@
package internal
import (
+ "context"
"errors"
+ "os"
"ior/internal/flags"
)
+// runnerDeps bundles all injectable function dependencies used by the mode
+// registry and its handlers. Using a struct instead of package-level vars
+// allows tests to substitute individual functions without mutating global
+// state (Dependency Inversion Principle).
+type runnerDeps struct {
+ // getEUID returns the effective user ID of the calling process.
+ // Overridden in tests to simulate root or non-root execution.
+ getEUID func() int
+
+ // runTrace executes a headless plain/flamegraph trace (no TUI).
+ runTrace func(flags.Config) error
+
+ // runParquet executes a headless Parquet recording run (no TUI).
+ runParquet func(flags.Config) error
+
+ // runTraceWithContext drives a BPF trace with a parent context, started
+ // signal channel, and event-loop configurator. Used by the TUI starter.
+ runTraceWithContext func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error
+
+ // runTUI launches the interactive TUI backed by a live BPF trace.
+ // Injected at startup via SetTUIRunners so that the core package never
+ // imports the TUI layer.
+ runTUI tuiRunFunc
+
+ // runTUITestFlames launches the TUI seeded with static synthetic flame data.
+ runTUITestFlames tuiRunFunc
+
+ // runTUITestLiveFlames launches the TUI fed by a live synthetic flame goroutine.
+ runTUITestLiveFlames tuiRunFunc
+}
+
+// defaultRunnerDeps returns the production function set.
+func defaultRunnerDeps() runnerDeps {
+ return runnerDeps{
+ getEUID: os.Geteuid,
+ runTrace: runTrace,
+ runParquet: runHeadlessParquet,
+ runTraceWithContext: runTraceWithContext,
+ // TUI runners are nil until SetTUIRunners is called from cmd/ior/main.go.
+ }
+}
+
// modeHandler describes a single execution mode for the ior binary.
// Each mode knows how to recognise itself (match), enforce its
// invariants (validate), and run (run). The registry evaluates
@@ -18,33 +62,47 @@ type modeHandler interface {
// (pre-root modes are checked first and return early before requiring root).
validate(cfg flags.Config) error
// run executes the mode using the supplied config.
- run(cfg flags.Config) error
+ run(cfg flags.Config, deps runnerDeps) error
}
-// modeRegistry is an ordered list of modeHandlers.
-// dispatchRun and validateRunConfig iterate through it.
-type modeRegistry []modeHandler
+// modeRegistry is an ordered list of modeHandlers paired with the
+// injectable function dependencies they share. Storing deps on the registry
+// (rather than as package-level vars) lets tests construct isolated
+// registries without mutating global state.
+type modeRegistry struct {
+ handlers []modeHandler
+ deps runnerDeps
+}
-// defaultRegistry is the canonical ordered registry used at runtime.
+// newModeRegistry constructs a registry with the standard handler order and
+// the provided dependencies.
// Modes are evaluated first-match-wins, so more specific modes (e.g.,
-// testFlames) are registered before more general ones (e.g., TUI default).
-var defaultRegistry = modeRegistry{
- &testFlamesModeHandler{},
- &testLiveFlamesModeHandler{},
- &headlessParquetModeHandler{},
- &plainTraceModeHandler{},
- &tuiModeHandler{},
+// testFlames) must be registered before more general ones (e.g., TUI default).
+func newModeRegistry(deps runnerDeps) modeRegistry {
+ return modeRegistry{
+ handlers: []modeHandler{
+ &testFlamesModeHandler{},
+ &testLiveFlamesModeHandler{},
+ &headlessParquetModeHandler{},
+ &plainTraceModeHandler{},
+ &tuiModeHandler{},
+ },
+ deps: deps,
+ }
}
+// defaultRegistry is the canonical ordered registry used at runtime.
+var defaultRegistry = newModeRegistry(defaultRunnerDeps())
+
// dispatch validates cross-mode constraints, requires root when necessary,
// then delegates to the first matching handler in the registry.
func (reg modeRegistry) dispatch(cfg flags.Config) error {
if err := reg.validate(cfg); err != nil {
return err
}
- for _, h := range reg {
+ for _, h := range reg.handlers {
if h.match(cfg) {
- return h.run(cfg)
+ return h.run(cfg, reg.deps)
}
}
// Registry must always include a catch-all (tuiModeHandler matches everything).
@@ -56,7 +114,7 @@ func (reg modeRegistry) dispatch(cfg flags.Config) error {
// combination errors (e.g., parquet + plain is rejected regardless of which
// handler ultimately runs).
func (reg modeRegistry) validate(cfg flags.Config) error {
- for _, h := range reg {
+ for _, h := range reg.handlers {
if err := h.validate(cfg); err != nil {
return err
}
@@ -93,8 +151,8 @@ func (h *testFlamesModeHandler) validate(cfg flags.Config) error {
return nil
}
-func (h *testFlamesModeHandler) run(cfg flags.Config) error {
- return runTUITestFlamesFn(cfg, tuiTestFlamesStarter(cfg))
+func (h *testFlamesModeHandler) run(cfg flags.Config, deps runnerDeps) error {
+ return deps.runTUITestFlames(cfg, tuiTestFlamesStarter(cfg))
}
// --- testLiveFlamesModeHandler ---
@@ -123,8 +181,8 @@ func (h *testLiveFlamesModeHandler) validate(cfg flags.Config) error {
return nil
}
-func (h *testLiveFlamesModeHandler) run(cfg flags.Config) error {
- return runTUITestLiveFlamesFn(cfg, tuiTestLiveFlamesStarter(cfg))
+func (h *testLiveFlamesModeHandler) run(cfg flags.Config, deps runnerDeps) error {
+ return deps.runTUITestLiveFlames(cfg, tuiTestLiveFlamesStarter(cfg))
}
// --- headlessParquetModeHandler ---
@@ -161,11 +219,11 @@ func (h *headlessParquetModeHandler) validate(cfg flags.Config) error {
return nil
}
-func (h *headlessParquetModeHandler) run(cfg flags.Config) error {
- if getEUID() != 0 {
+func (h *headlessParquetModeHandler) run(cfg flags.Config, deps runnerDeps) error {
+ if deps.getEUID() != 0 {
return errRootPrivilegesRequired
}
- return runParquetFn(cfg)
+ return deps.runParquet(cfg)
}
// --- plainTraceModeHandler ---
@@ -186,11 +244,11 @@ func (h *plainTraceModeHandler) validate(cfg flags.Config) error {
return nil
}
-func (h *plainTraceModeHandler) run(cfg flags.Config) error {
- if getEUID() != 0 {
+func (h *plainTraceModeHandler) run(cfg flags.Config, deps runnerDeps) error {
+ if deps.getEUID() != 0 {
return errRootPrivilegesRequired
}
- return runTraceFn(cfg)
+ return deps.runTrace(cfg)
}
// --- tuiModeHandler ---
@@ -208,9 +266,9 @@ func (h *tuiModeHandler) validate(_ flags.Config) error {
return nil
}
-func (h *tuiModeHandler) run(cfg flags.Config) error {
- if getEUID() != 0 {
+func (h *tuiModeHandler) run(cfg flags.Config, deps runnerDeps) error {
+ if deps.getEUID() != 0 {
return errRootPrivilegesRequired
}
- return runTUIFn(cfg, tuiTraceStarterFromRunTrace(cfg, runTraceWithContextFn))
+ return deps.runTUI(cfg, tuiTraceStarterFromRunTrace(cfg, deps.runTraceWithContext))
}
diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go
index 0697ada..5ba2894 100644
--- a/internal/ior_mode_test.go
+++ b/internal/ior_mode_test.go
@@ -25,6 +25,22 @@ import (
parquetgo "github.com/parquet-go/parquet-go"
)
+// stubDeps returns a runnerDeps with safe no-op stubs for every function
+// field. Individual tests override only the functions they care about,
+// keeping test setup concise and making it easy to add new fields without
+// updating every test.
+func stubDeps() runnerDeps {
+ return runnerDeps{
+ getEUID: func() int { return 0 },
+ runTrace: func(flags.Config) error { return nil },
+ runParquet: func(flags.Config) error { return nil },
+ runTraceWithContext: func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error { return nil },
+ runTUI: func(flags.Config, runtime.TraceStarter) error { return nil },
+ runTUITestFlames: func(flags.Config, runtime.TraceStarter) error { return nil },
+ runTUITestLiveFlames: func(flags.Config, runtime.TraceStarter) error { return nil },
+ }
+}
+
func TestShouldRunTraceMode(t *testing.T) {
base := flags.Config{}
@@ -86,180 +102,123 @@ func TestShouldAutoStopByDuration(t *testing.T) {
if shouldAutoStopByDuration(withPprof) {
t.Fatalf("expected pprof flag alone not to auto-stop by duration")
}
-
}
func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunParquet := runParquetFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceFn = origRunTrace
- runParquetFn = origRunParquet
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceCalled := false
- tuiCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runParquetFn = func(flags.Config) error {
- t.Fatalf("runParquetFn should not be called in plain trace mode")
+ deps.runParquet = func(flags.Config) error {
+ t.Fatalf("runParquet should not be called in plain trace mode")
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ tuiCalled := false
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
tuiCalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called in trace mode")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called in trace mode")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called in trace mode")
return nil
}
cfg := flags.Config{PlainMode: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if !traceCalled {
- t.Fatalf("expected runTraceFn to be called")
+ t.Fatalf("expected runTrace to be called")
}
if tuiCalled {
- t.Fatalf("did not expect runTUIFn to be called")
+ t.Fatalf("did not expect runTUI to be called")
}
}
func TestDispatchRunUsesHeadlessParquetModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunParquet := runParquetFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceFn = origRunTrace
- runParquetFn = origRunParquet
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceCalled := false
parquetCalled := false
tuiCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runParquetFn = func(flags.Config) error {
+ deps.runParquet = func(flags.Config) error {
parquetCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
tuiCalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called in parquet mode")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called in parquet mode")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called in parquet mode")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called in parquet mode")
return nil
}
cfg := flags.Config{ParquetPath: "trace.parquet"}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if !parquetCalled {
- t.Fatalf("expected runParquetFn to be called")
+ t.Fatalf("expected runParquet to be called")
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn to be called")
+ t.Fatalf("did not expect runTrace to be called")
}
if tuiCalled {
- t.Fatalf("did not expect runTUIFn to be called")
+ t.Fatalf("did not expect runTUI to be called")
}
}
func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) {
- origRunTrace := runTraceFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceFn = origRunTrace
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceCalled := false
tuiCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
tuiCalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called for regular TUI mode")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called for regular TUI mode")
return nil
}
cfg := flags.Config{PprofEnable: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn when only -pprof is enabled")
+ t.Fatalf("did not expect runTrace when only -pprof is enabled")
}
if !tuiCalled {
- t.Fatalf("expected runTUIFn to be called")
+ t.Fatalf("expected runTUI to be called")
}
}
func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
- origRunTraceWithContext := runTraceWithContextFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- origGetEUID := getEUID
- defer func() {
- runTraceWithContextFn = origRunTraceWithContext
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- getEUID = origGetEUID
- }()
- getEUID = func() int { return 0 }
-
traceDone := make(chan struct{}, 1)
- runTraceWithContextFn = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error {
+ deps := stubDeps()
+ deps.runTraceWithContext = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error {
if configure != nil {
configure(&eventLoop{})
}
@@ -269,7 +228,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
}
tuiCalled := false
- runTUIFn = func(_ flags.Config, starter runtime.TraceStarter) error {
+ deps.runTUI = func(_ flags.Config, starter runtime.TraceStarter) error {
tuiCalled = true
if starter == nil {
t.Fatalf("expected non-nil starter")
@@ -279,108 +238,88 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
}
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called for normal starter path")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called for normal starter path")
return nil
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called for normal starter path")
return nil
}
cfg := flags.Config{}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if !tuiCalled {
- t.Fatalf("expected runTUIFn to be called")
+ t.Fatalf("expected runTUI to be called")
}
select {
case <-traceDone:
case <-time.After(200 * time.Millisecond):
- t.Fatalf("expected starter to launch runTraceWithContextFn")
+ t.Fatalf("expected starter to launch runTraceWithContext")
}
}
func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- defer func() {
- runTraceFn = origRunTrace
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- }()
-
traceCalled := false
regularTUICalled := false
testFlamesCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
regularTUICalled = true
return nil
}
- runTUITestFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error {
+ deps.runTUITestFlames = func(_ flags.Config, starter runtime.TraceStarter) error {
testFlamesCalled = true
if starter == nil {
t.Fatalf("expected non-nil starter for test flames mode")
}
return starter(context.Background())
}
- runTUITestLiveFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames")
+ deps.runTUITestLiveFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestLiveFlames should not be called for --testflames")
return nil
}
cfg := flags.Config{TestFlames: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn for test flames mode")
+ t.Fatalf("did not expect runTrace for test flames mode")
}
if regularTUICalled {
- t.Fatalf("did not expect runTUIFn for test flames mode")
+ t.Fatalf("did not expect runTUI for test flames mode")
}
if !testFlamesCalled {
- t.Fatalf("expected runTUITestFlamesFn to be called")
+ t.Fatalf("expected runTUITestFlames to be called")
}
}
func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) {
- origRunTrace := runTraceFn
- origRunTUI := runTUIFn
- origRunTUITestFlames := runTUITestFlamesFn
- origRunTUITestLiveFlames := runTUITestLiveFlamesFn
- defer func() {
- runTraceFn = origRunTrace
- runTUIFn = origRunTUI
- runTUITestFlamesFn = origRunTUITestFlames
- runTUITestLiveFlamesFn = origRunTUITestLiveFlames
- }()
-
traceCalled := false
regularTUICalled := false
testLiveFlamesCalled := false
- runTraceFn = func(flags.Config) error {
+ deps := stubDeps()
+ deps.runTrace = func(flags.Config) error {
traceCalled = true
return nil
}
- runTUIFn = func(flags.Config, runtime.TraceStarter) error {
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
regularTUICalled = true
return nil
}
- runTUITestFlamesFn = func(flags.Config, runtime.TraceStarter) error {
- t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames")
+ deps.runTUITestFlames = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUITestFlames should not be called for --testliveflames")
return nil
}
- runTUITestLiveFlamesFn = func(_ flags.Config, starter runtime.TraceStarter) error {
+ deps.runTUITestLiveFlames = func(_ flags.Config, starter runtime.TraceStarter) error {
testLiveFlamesCalled = true
if starter == nil {
t.Fatalf("expected non-nil starter for test live flames mode")
@@ -389,17 +328,68 @@ func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) {
}
cfg := flags.Config{TestLiveFlames: true}
- if err := dispatchRun(cfg); err != nil {
- t.Fatalf("dispatchRun returned error: %v", err)
+ if err := dispatchRunWithDeps(cfg, deps); err != nil {
+ t.Fatalf("dispatchRunWithDeps returned error: %v", err)
}
if traceCalled {
- t.Fatalf("did not expect runTraceFn for test live flames mode")
+ t.Fatalf("did not expect runTrace for test live flames mode")
}
if regularTUICalled {
- t.Fatalf("did not expect runTUIFn for test live flames mode")
+ t.Fatalf("did not expect runTUI for test live flames mode")
}
if !testLiveFlamesCalled {
- t.Fatalf("expected runTUITestLiveFlamesFn to be called")
+ t.Fatalf("expected runTUITestLiveFlames to be called")
+ }
+}
+
+// TestDispatchRunRequiresRootForTUI verifies that the TUI mode handler
+// enforces the root-privilege gate via deps.getEUID.
+func TestDispatchRunRequiresRootForTUI(t *testing.T) {
+ deps := stubDeps()
+ deps.getEUID = func() int { return 1000 } // non-root
+ deps.runTUI = func(flags.Config, runtime.TraceStarter) error {
+ t.Fatalf("runTUI must not be called when not root")
+ return nil
+ }
+
+ cfg := flags.Config{}
+ err := dispatchRunWithDeps(cfg, deps)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
+ }
+}
+
+// TestDispatchRunRequiresRootForPlainTrace verifies that the plain trace
+// mode handler enforces the root-privilege gate via deps.getEUID.
+func TestDispatchRunRequiresRootForPlainTrace(t *testing.T) {
+ deps := stubDeps()
+ deps.getEUID = func() int { return 1000 } // non-root
+ deps.runTrace = func(flags.Config) error {
+ t.Fatalf("runTrace must not be called when not root")
+ return nil
+ }
+
+ cfg := flags.Config{PlainMode: true}
+ err := dispatchRunWithDeps(cfg, deps)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
+ }
+}
+
+// TestDispatchRunRequiresRootForParquet verifies that the headless Parquet
+// mode handler enforces the root-privilege gate via deps.getEUID.
+func TestDispatchRunRequiresRootForParquet(t *testing.T) {
+ deps := stubDeps()
+ deps.getEUID = func() int { return 1000 } // non-root
+ deps.runParquet = func(flags.Config) error {
+ t.Fatalf("runParquet must not be called when not root")
+ return nil
+ }
+
+ cfg := flags.Config{ParquetPath: "trace.parquet"}
+ err := dispatchRunWithDeps(cfg, deps)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
}
}
@@ -536,17 +526,6 @@ func TestBuildTestLiveFlamesRuntimeContinuouslyUpdatesLiveTrie(t *testing.T) {
})
}
-func TestRunTraceWithContextRequiresRoot(t *testing.T) {
- origGetEUID := getEUID
- defer func() { getEUID = origGetEUID }()
-
- getEUID = func() int { return 1000 }
- err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil)
- if !errors.Is(err, errRootPrivilegesRequired) {
- t.Fatalf("expected root-required error, got %v", err)
- }
-}
-
func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) {
starter := tuiTraceStarterFromRunTrace(
flags.NewFlags(),
diff --git a/internal/ior_parquet_sink.go b/internal/ior_parquet_sink.go
index 83ca769..b2a1439 100644
--- a/internal/ior_parquet_sink.go
+++ b/internal/ior_parquet_sink.go
@@ -90,11 +90,9 @@ func headlessParquetTraceConfig(cfg flags.Config) flags.Config {
}
// runHeadlessParquet records all traced syscalls directly to a Parquet file
-// without starting the TUI.
+// without starting the TUI. Root privilege is checked by the mode handler
+// (via runnerDeps.getEUID) before this function is invoked.
func runHeadlessParquet(cfg flags.Config) error {
- if getEUID() != 0 {
- return errRootPrivilegesRequired
- }
cfg = headlessParquetTraceConfig(cfg)
logln := newLogger(true)