diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-06 16:01:46 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-06 16:01:46 +0200 |
| commit | fcee8baac995b25ffb9ab06567f010df105c3db1 (patch) | |
| tree | caeba14673afcdc66698d63ccfee8d1b535a41ed /internal | |
| parent | aca5e2205b4dd18a13706c725daa0f326e10000b (diff) | |
refactor: thread runtime flags through ior and tui (task 385)
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/ior.go | 40 | ||||
| -rw-r--r-- | internal/ior_mode_test.go | 77 | ||||
| -rw-r--r-- | internal/tui/tui.go | 61 | ||||
| -rw-r--r-- | internal/tui/tui_test.go | 29 |
4 files changed, 139 insertions, 68 deletions
diff --git a/internal/ior.go b/internal/ior.go index ee0663e..4d5aea1 100644 --- a/internal/ior.go +++ b/internal/ior.go @@ -38,9 +38,9 @@ type tracepointLink interface { var ( runTraceFn = runTrace runTraceWithContextFn = runTraceWithContext - runTUIFn = tui.RunWithTraceStarter - runTUITestFlamesFn = tui.RunTestFlamesWithTraceStarter - runTUITestLiveFlamesFn = tui.RunTestFlamesWithTraceStarter + runTUIFn = tui.RunWithTraceStarterConfig + runTUITestFlamesFn = tui.RunTestFlamesWithTraceStarterConfig + runTUITestLiveFlamesFn = tui.RunTestFlamesWithTraceStarterConfig getEUID = os.Geteuid errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)") @@ -124,15 +124,15 @@ func dispatchRun(cfg flags.Flags) error { return err } if cfg.TestFlames { - return runTUITestFlamesFn(tuiTestFlamesStarter()) + return runTUITestFlamesFn(cfg, tuiTestFlamesStarter(cfg)) } if cfg.TestLiveFlames { - return runTUITestLiveFlamesFn(tuiTestLiveFlamesStarter()) + return runTUITestLiveFlamesFn(cfg, tuiTestLiveFlamesStarter(cfg)) } if shouldRunTraceMode(cfg) { - return runTraceFn() + return runTraceFn(cfg) } - return runTUIFn(tuiTraceStarterFromRunTrace(runTraceWithContextFn)) + return runTUIFn(cfg, tuiTraceStarterFromRunTrace(cfg, runTraceWithContextFn)) } func validateRunConfig(cfg flags.Flags) error { @@ -148,9 +148,9 @@ func validateRunConfig(cfg flags.Flags) error { return nil } -func tuiTestFlamesStarter() tui.TraceStarter { +func tuiTestFlamesStarter(cfg flags.Flags) tui.TraceStarter { return func(ctx context.Context) error { - engine, streamBuf, liveTrie := buildTestFlamesRuntime(flags.Get()) + engine, streamBuf, liveTrie := buildTestFlamesRuntime(cfg) if bindings, ok := tui.RuntimeBindingsFromContext(ctx); ok { bindings.SetDashboardSnapshotSource(engine) bindings.SetEventStreamSource(streamBuf) @@ -160,9 +160,9 @@ func tuiTestFlamesStarter() tui.TraceStarter { } } -func tuiTestLiveFlamesStarter() tui.TraceStarter { +func tuiTestLiveFlamesStarter(cfg flags.Flags) tui.TraceStarter { return func(ctx context.Context) error { - engine, streamBuf, liveTrie := buildTestLiveFlamesRuntime(ctx, flags.Get()) + engine, streamBuf, liveTrie := buildTestLiveFlamesRuntime(ctx, cfg) if bindings, ok := tui.RuntimeBindingsFromContext(ctx); ok { bindings.SetDashboardSnapshotSource(engine) bindings.SetEventStreamSource(streamBuf) @@ -220,14 +220,19 @@ func shouldRunTraceMode(cfg flags.Flags) bool { } func tuiTraceStarterFromRunTrace( - startTrace func(context.Context, chan<- struct{}, func(*eventLoop)) error, + baseCfg flags.Flags, + startTrace func(context.Context, flags.Flags, chan<- struct{}, func(*eventLoop)) error, ) tui.TraceStarter { return func(ctx context.Context) error { bpf.SetLoggerCbs(bpf.Callbacks{ Log: func(int, string) {}, }) - cfg := flags.Get() + cfg := baseCfg + if pidFilter, tidFilter, ok := tui.TraceFiltersFromContext(ctx); ok { + cfg.PidFilter = pidFilter + cfg.TidFilter = tidFilter + } engine := statsengine.NewEngine(64) streamBuf := eventstream.NewRingBuffer() liveTrie := flamegraph.NewLiveTrie(cfg.CollapsedFields, cfg.CountField) @@ -248,7 +253,7 @@ func tuiTraceStarterFromRunTrace( errCh := make(chan error, 1) go func() { - err := startTrace(ctx, startedCh, func(el *eventLoop) { + err := startTrace(ctx, cfg, startedCh, func(el *eventLoop) { el.printCb = func(ep *event.Pair) { engine.Ingest(ep) streamEvents <- eventstream.NewStreamEvent(ep.EnterEv.GetTime(), ep) @@ -278,8 +283,8 @@ func tuiTraceStarterFromRunTrace( } } -func runTrace() error { - return runTraceWithContext(context.Background(), nil, nil) +func runTrace(cfg flags.Flags) error { + return runTraceWithContext(context.Background(), cfg, nil, nil) } func newEventLoopConfig(cfg flags.Flags) eventLoopConfig { @@ -296,7 +301,7 @@ func newEventLoopConfig(cfg flags.Flags) eventLoopConfig { } } -func runTraceWithContext(parentCtx context.Context, started chan<- struct{}, configure func(*eventLoop)) error { +func runTraceWithContext(parentCtx context.Context, cfg flags.Flags, started chan<- struct{}, configure func(*eventLoop)) error { if getEUID() != 0 { return errRootPrivilegesRequired } @@ -306,7 +311,6 @@ func runTraceWithContext(parentCtx context.Context, started chan<- struct{}, con if verbose { logln = func(args ...any) { _, _ = fmt.Println(args...) } } - cfg := flags.Get() bpfModule, err := bpf.NewModuleFromFile("ior.bpf.o") if err != nil { diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go index 9ea8b61..617e567 100644 --- a/internal/ior_mode_test.go +++ b/internal/ior_mode_test.go @@ -78,19 +78,19 @@ func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { traceCalled := false tuiCalled := false - runTraceFn = func() error { + runTraceFn = func(flags.Flags) error { traceCalled = true return nil } - runTUIFn = func(tui.TraceStarter) error { + runTUIFn = func(flags.Flags, tui.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(tui.TraceStarter) error { + runTUITestFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestFlamesFn should not be called in trace mode") return nil } - runTUITestLiveFlamesFn = func(tui.TraceStarter) error { + runTUITestLiveFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode") return nil } @@ -121,19 +121,19 @@ func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) { traceCalled := false tuiCalled := false - runTraceFn = func() error { + runTraceFn = func(flags.Flags) error { traceCalled = true return nil } - runTUIFn = func(tui.TraceStarter) error { + runTUIFn = func(flags.Flags, tui.TraceStarter) error { tuiCalled = true return nil } - runTUITestFlamesFn = func(tui.TraceStarter) error { + runTUITestFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode") return nil } - runTUITestLiveFlamesFn = func(tui.TraceStarter) error { + runTUITestLiveFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode") return nil } @@ -163,7 +163,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { }() traceDone := make(chan struct{}, 1) - runTraceWithContextFn = func(_ context.Context, started chan<- struct{}, configure func(*eventLoop)) error { + runTraceWithContextFn = func(_ context.Context, _ flags.Flags, started chan<- struct{}, configure func(*eventLoop)) error { if configure != nil { configure(&eventLoop{}) } @@ -173,7 +173,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } tuiCalled := false - runTUIFn = func(starter tui.TraceStarter) error { + runTUIFn = func(_ flags.Flags, starter tui.TraceStarter) error { tuiCalled = true if starter == nil { t.Fatalf("expected non-nil starter") @@ -183,11 +183,11 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } return nil } - runTUITestFlamesFn = func(tui.TraceStarter) error { + runTUITestFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestFlamesFn should not be called for normal starter path") return nil } - runTUITestLiveFlamesFn = func(tui.TraceStarter) error { + runTUITestLiveFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path") return nil } @@ -222,22 +222,22 @@ func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) { traceCalled := false regularTUICalled := false testFlamesCalled := false - runTraceFn = func() error { + runTraceFn = func(flags.Flags) error { traceCalled = true return nil } - runTUIFn = func(tui.TraceStarter) error { + runTUIFn = func(flags.Flags, tui.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(starter tui.TraceStarter) error { + runTUITestFlamesFn = func(_ flags.Flags, starter tui.TraceStarter) error { testFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test flames mode") } return starter(context.Background()) } - runTUITestLiveFlamesFn = func(tui.TraceStarter) error { + runTUITestLiveFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames") return nil } @@ -272,19 +272,19 @@ func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { traceCalled := false regularTUICalled := false testLiveFlamesCalled := false - runTraceFn = func() error { + runTraceFn = func(flags.Flags) error { traceCalled = true return nil } - runTUIFn = func(tui.TraceStarter) error { + runTUIFn = func(flags.Flags, tui.TraceStarter) error { regularTUICalled = true return nil } - runTUITestFlamesFn = func(tui.TraceStarter) error { + runTUITestFlamesFn = func(flags.Flags, tui.TraceStarter) error { t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames") return nil } - runTUITestLiveFlamesFn = func(starter tui.TraceStarter) error { + runTUITestLiveFlamesFn = func(_ flags.Flags, starter tui.TraceStarter) error { testLiveFlamesCalled = true if starter == nil { t.Fatalf("expected non-nil starter for test live flames mode") @@ -409,7 +409,7 @@ func TestRunTraceWithContextRequiresRoot(t *testing.T) { defer func() { getEUID = origGetEUID }() getEUID = func() int { return 1000 } - err := runTraceWithContext(context.Background(), nil, nil) + err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil) if !errors.Is(err, errRootPrivilegesRequired) { t.Fatalf("expected root-required error, got %v", err) } @@ -417,7 +417,10 @@ func TestRunTraceWithContextRequiresRoot(t *testing.T) { func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { starter := tuiTraceStarterFromRunTrace( - func(context.Context, chan<- struct{}, func(*eventLoop)) error { return errors.New("startup failed") }, + flags.NewFlags(), + func(context.Context, flags.Flags, chan<- struct{}, func(*eventLoop)) error { + return errors.New("startup failed") + }, ) err := starter(context.Background()) @@ -426,6 +429,33 @@ func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { } } +func TestTuiTraceStarterFromRunTraceUsesContextFilters(t *testing.T) { + base := flags.NewFlags() + base.PidFilter = 11 + base.TidFilter = 12 + + var gotCfg flags.Flags + starter := tuiTraceStarterFromRunTrace( + base, + func(_ context.Context, cfg flags.Flags, started chan<- struct{}, _ func(*eventLoop)) error { + gotCfg = cfg + close(started) + return nil + }, + ) + + ctx := tui.ContextWithTraceFilters(context.Background(), 2222, 3333) + if err := starter(ctx); err != nil { + t.Fatalf("starter returned error: %v", err) + } + if gotCfg.PidFilter != 2222 { + t.Fatalf("expected pid filter from context, got %d", gotCfg.PidFilter) + } + if gotCfg.TidFilter != 3333 { + t.Fatalf("expected tid filter from context, got %d", gotCfg.TidFilter) + } +} + func TestProfilingFilesForMode(t *testing.T) { cpu, mem, execTrace, duration := profilingFilesForMode(false) if cpu != "ior.cpuprofile" || mem != "ior.memprofile" { @@ -446,7 +476,8 @@ func TestProfilingFilesForMode(t *testing.T) { func TestTuiTraceStarterFromRunTraceRespectsCancel(t *testing.T) { starter := tuiTraceStarterFromRunTrace( - func(ctx context.Context, _ chan<- struct{}, _ func(*eventLoop)) error { + flags.NewFlags(), + func(ctx context.Context, _ flags.Flags, _ chan<- struct{}, _ func(*eventLoop)) error { <-ctx.Done() return ctx.Err() }, diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 12c904d..c1ba700 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -65,6 +65,7 @@ type TraceRuntimeBindings interface { } type runtimeBindingsContextKey struct{} +type traceFiltersContextKey struct{} type runtimeBindings struct { mu sync.RWMutex @@ -75,6 +76,11 @@ type runtimeBindings struct { probeManager ProbeManager } +type traceFilters struct { + pidFilter int + tidFilter int +} + func newRuntimeBindings() *runtimeBindings { return &runtimeBindings{} } @@ -152,6 +158,21 @@ func RuntimeBindingsFromContext(ctx context.Context) (TraceRuntimeBindings, bool return bindings, true } +// ContextWithTraceFilters stores the active PID/TID filters for the trace starter. +func ContextWithTraceFilters(ctx context.Context, pidFilter, tidFilter int) context.Context { + filters := traceFilters{pidFilter: pidFilter, tidFilter: tidFilter} + return context.WithValue(ctx, traceFiltersContextKey{}, filters) +} + +// TraceFiltersFromContext returns the active PID/TID filters when provided by the TUI model. +func TraceFiltersFromContext(ctx context.Context) (pidFilter, tidFilter int, ok bool) { + filters, ok := ctx.Value(traceFiltersContextKey{}).(traceFilters) + if !ok { + return 0, 0, false + } + return filters.pidFilter, filters.tidFilter, true +} + // Run starts the TUI program in alternate screen mode. func Run() error { return RunWithTraceStarter(defaultTraceStarter) @@ -159,8 +180,12 @@ func Run() error { // RunWithTraceStarter starts the TUI program with a custom trace starter. func RunWithTraceStarter(starter TraceStarter) error { - cfg := flags.Get() - model := newModelWithRuntimeConfig(cfg.PidFilter, cfg.PidFilter, cfg.TUIExportEnable, starter) + return RunWithTraceStarterConfig(flags.Get(), starter) +} + +// RunWithTraceStarterConfig starts the TUI with explicit runtime flags. +func RunWithTraceStarterConfig(cfg flags.Flags, starter TraceStarter) error { + model := newModelWithRuntimeConfig(cfg.PidFilter, cfg.PidFilter, cfg.TidFilter, cfg.TUIExportEnable, starter) program := tea.NewProgram(model) _, err := program.Run() return err @@ -169,8 +194,12 @@ func RunWithTraceStarter(starter TraceStarter) error { // RunTestFlamesWithTraceStarter starts the TUI directly on dashboard/flame view // with a synthetic static flamegraph source. func RunTestFlamesWithTraceStarter(starter TraceStarter) error { - cfg := flags.Get() - model := newModelWithRuntimeConfig(1, 1, cfg.TUIExportEnable, starter) + return RunTestFlamesWithTraceStarterConfig(flags.Get(), starter) +} + +// RunTestFlamesWithTraceStarterConfig starts test-flames mode with explicit runtime flags. +func RunTestFlamesWithTraceStarterConfig(cfg flags.Flags, starter TraceStarter) error { + model := newModelWithRuntimeConfig(1, 1, -1, cfg.TUIExportEnable, starter) program := tea.NewProgram(model) _, err := program.Run() return err @@ -201,6 +230,7 @@ type Model struct { traceStop context.CancelFunc pidFilter int + tidFilter int exportEnabled bool isDark bool focused bool @@ -220,11 +250,15 @@ type Model struct { // NewModel creates the top-level TUI model. func NewModel(initialPID int, startTrace TraceStarter) Model { - cfg := flags.Get() - return newModelWithRuntimeConfig(initialPID, cfg.PidFilter, cfg.TUIExportEnable, startTrace) + return NewModelWithConfig(flags.Get(), initialPID, startTrace) } -func newModelWithRuntimeConfig(initialPID, startupPidFilter int, exportEnabled bool, startTrace TraceStarter) Model { +// NewModelWithConfig creates the top-level TUI model with explicit runtime flags. +func NewModelWithConfig(cfg flags.Flags, initialPID int, startTrace TraceStarter) Model { + return newModelWithRuntimeConfig(initialPID, cfg.PidFilter, cfg.TidFilter, cfg.TUIExportEnable, startTrace) +} + +func newModelWithRuntimeConfig(initialPID, startupPidFilter, startupTidFilter int, exportEnabled bool, startTrace TraceStarter) Model { common.ApplyPalette(true) syncStylesFromCommon() @@ -245,6 +279,10 @@ func newModelWithRuntimeConfig(initialPID, startupPidFilter int, exportEnabled b if initialPID > 0 { pidFilter = selectedPIDFilter(initialPID) } + tidFilter := selectedPIDFilter(startupTidFilter) + if initialPID > 0 { + tidFilter = -1 + } dashboard.SetPidFilter(pidFilter) model := Model{ @@ -258,13 +296,13 @@ func newModelWithRuntimeConfig(initialPID, startupPidFilter int, exportEnabled b spin: spin, startTrace: startTrace, pidFilter: pidFilter, + tidFilter: tidFilter, exportEnabled: exportEnabled, isDark: true, focused: true, } if initialPID > 0 { - flags.SetPidFilter(initialPID) model.screen = ScreenDashboard model.attaching = true } @@ -529,9 +567,8 @@ func (m Model) updateActiveModel(msg tea.Msg) (tea.Model, tea.Cmd) { func (m Model) handlePidSelected(msg PidSelectedMsg) (tea.Model, tea.Cmd) { pid := selectedPIDFilter(msg.Pid) m.stopTrace() - flags.SetPidFilter(pid) - flags.SetTidFilter(-1) m.pidFilter = pid + m.tidFilter = -1 m.dashboard.SetPidFilter(pid) m.screen = ScreenDashboard m.attaching = true @@ -546,9 +583,8 @@ func (m Model) handleTidSelected(msg TidSelectedMsg) (tea.Model, tea.Cmd) { pid = msg.Pid } m.stopTrace() - flags.SetPidFilter(pid) - flags.SetTidFilter(tid) m.pidFilter = pid + m.tidFilter = tid m.dashboard.SetPidFilter(pid) m.screen = ScreenDashboard m.attaching = true @@ -607,6 +643,7 @@ func (m *Model) beginTraceCmd() tea.Cmd { ctx, cancel := context.WithCancel(context.Background()) m.traceStop = cancel ctx = context.WithValue(ctx, runtimeBindingsContextKey{}, m.runtime) + ctx = ContextWithTraceFilters(ctx, m.pidFilter, m.tidFilter) return startTraceCmd(m.startTrace, ctx) } diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go index cd9e6cd..ad529fc 100644 --- a/internal/tui/tui_test.go +++ b/internal/tui/tui_test.go @@ -51,11 +51,11 @@ func TestPidSelectedTransitionsToDashboardAndSetsPIDFilter(t *testing.T) { if !updated.attaching { t.Fatalf("expected attaching state to be true") } - if got := flags.Get().PidFilter; got != 42 { - t.Fatalf("expected pid filter 42, got %d", got) + if updated.pidFilter != 42 { + t.Fatalf("expected pid filter 42, got %d", updated.pidFilter) } - if got := flags.Get().TidFilter; got != -1 { - t.Fatalf("expected tid filter reset to -1, got %d", got) + if updated.tidFilter != -1 { + t.Fatalf("expected tid filter reset to -1, got %d", updated.tidFilter) } } @@ -80,10 +80,9 @@ func TestPidSelectedAllSetsNoFilter(t *testing.T) { next, _ := m.Update(PidSelectedMsg{Pid: 0}) updated := next.(Model) - if got := flags.Get().PidFilter; got != -1 { - t.Fatalf("expected pid filter -1 for all pids, got %d", got) + if updated.pidFilter != -1 { + t.Fatalf("expected pid filter -1 for all pids, got %d", updated.pidFilter) } - _ = updated } func TestTracingErrorMessageClearsAttachingState(t *testing.T) { @@ -623,11 +622,11 @@ func TestTidSelectedTransitionsToDashboardAndSetsTIDFilter(t *testing.T) { if !updated.attaching { t.Fatalf("expected attaching state to be true") } - if got := flags.Get().TidFilter; got != 3333 { - t.Fatalf("expected tid filter 3333, got %d", got) + if updated.tidFilter != 3333 { + t.Fatalf("expected tid filter 3333, got %d", updated.tidFilter) } - if got := flags.Get().PidFilter; got != 2222 { - t.Fatalf("expected pid filter to remain 2222, got %d", got) + if updated.pidFilter != 2222 { + t.Fatalf("expected pid filter to remain 2222, got %d", updated.pidFilter) } } @@ -644,11 +643,11 @@ func TestTidSelectedFromAllPIDModeSetsOwningPID(t *testing.T) { if updated.screen != ScreenDashboard { t.Fatalf("expected dashboard screen, got %v", updated.screen) } - if got := flags.Get().PidFilter; got != 4444 { - t.Fatalf("expected pid filter switched to owning pid 4444, got %d", got) + if updated.pidFilter != 4444 { + t.Fatalf("expected pid filter switched to owning pid 4444, got %d", updated.pidFilter) } - if got := flags.Get().TidFilter; got != 5555 { - t.Fatalf("expected tid filter 5555, got %d", got) + if updated.tidFilter != 5555 { + t.Fatalf("expected tid filter 5555, got %d", updated.tidFilter) } } |
