diff options
| author | Paul Buetow <paul@buetow.org> | 2026-02-24 08:28:47 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-02-24 08:28:47 +0200 |
| commit | ba7af922d289a9d0fff1c4ef33764b1852c774f6 (patch) | |
| tree | 3d470e9d0dc967efab4b6a3e56ee361d0ca9cd96 /internal | |
| parent | b79a868fbc85cd7fb2829e978174629ab8a9c986 (diff) | |
ior: route default mode through tui and add plain flag
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/flags/flags.go | 2 | ||||
| -rw-r--r-- | internal/ior.go | 73 | ||||
| -rw-r--r-- | internal/ior_mode_test.go | 136 | ||||
| -rw-r--r-- | internal/tui/tui.go | 7 |
4 files changed, 210 insertions, 8 deletions
diff --git a/internal/flags/flags.go b/internal/flags/flags.go index 8760d33..1909b4a 100644 --- a/internal/flags/flags.go +++ b/internal/flags/flags.go @@ -53,6 +53,7 @@ type Flags struct { TracepointsToExclude []*regexp.Regexp // Flamegraph flags + PlainMode bool FlamegraphEnable bool FlamegraphName string @@ -96,6 +97,7 @@ func parse() { tracepointsToAttach := flag.String("tps", "", "Comma separated list regexes for tracepoints to load") tracepointsToExclude := flag.String("tpsExclude", "", "Comma separated list regexes for tracepoints to exclude") + flag.BoolVar(&singleton.PlainMode, "plain", false, "Enable plain CSV output mode (disable TUI)") flag.BoolVar(&singleton.FlamegraphEnable, "flamegraph", false, "Enable flamegraph builder") flag.StringVar(&singleton.FlamegraphName, "name", "default", "Name of the flamegraph, used to generate the SVG file") diff --git a/internal/ior.go b/internal/ior.go index 0299539..011d2fb 100644 --- a/internal/ior.go +++ b/internal/ior.go @@ -14,6 +14,7 @@ import ( "ior/internal/flags" "ior/internal/flamegraph" "ior/internal/tracepoints" + "ior/internal/tui" bpf "github.com/aquasecurity/libbpfgo" ) @@ -22,6 +23,12 @@ type tracepointProgram interface { attachTracepoint(category, name string) error } +var ( + runTraceFn = runTrace + runTraceWithContextFn = runTraceWithContext + runTUIFn = tui.RunWithTraceStarter +) + type tracepointModule interface { getProgram(progName string) (tracepointProgram, error) } @@ -77,12 +84,13 @@ func attachTracepointsWith(module tracepointModule, shouldAttach func(string) bo func Run() error { flags.PrintVersion() - iorFile := flags.Get().IorDataFile + cfg := flags.Get() + iorFile := cfg.IorDataFile var noTraceRun bool if iorFile != "" { noTraceRun = true - collapsed := flamegraph.NewCollapsed(iorFile, flags.Get().CollapsedFields, flags.Get().CountField) + collapsed := flamegraph.NewCollapsed(iorFile, cfg.CollapsedFields, cfg.CountField) collapsedFile, err := collapsed.Write(iorFile) if err != nil { return err @@ -100,10 +108,46 @@ func Run() error { if noTraceRun { return nil } - return runTrace() + return dispatchRun(cfg) +} + +func dispatchRun(cfg flags.Flags) error { + if shouldRunTraceMode(cfg) { + return runTraceFn() + } + return runTUIFn(tuiTraceStarterFromRunTrace(runTraceWithContextFn)) +} + +func shouldRunTraceMode(cfg flags.Flags) bool { + return cfg.PlainMode || cfg.FlamegraphEnable || cfg.PprofEnable +} + +func tuiTraceStarterFromRunTrace(startTrace func(context.Context, chan<- struct{}) error) tui.TraceStarter { + return func(ctx context.Context) error { + startedCh := make(chan struct{}) + errCh := make(chan error, 1) + + go func() { + errCh <- startTrace(ctx, startedCh) + close(errCh) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-startedCh: + return nil + case err := <-errCh: + return err + } + } } func runTrace() error { + return runTraceWithContext(context.Background(), nil) +} + +func runTraceWithContext(parentCtx context.Context, started chan<- struct{}) error { bpfModule, err := bpf.NewModuleFromFile("ior.bpf.o") if err != nil { return err @@ -148,18 +192,26 @@ func runTrace() error { close(pprofDone) } + signalTraceStarted(started) + el := newEventLoop() duration := time.Duration(flags.Get().Duration) * time.Second fmt.Println("Probing for", duration) - ctx, cancel := context.WithTimeout(context.Background(), duration) + ctx, cancel := context.WithTimeout(parentCtx, duration) + defer cancel() signalCh := make(chan os.Signal, 1) signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) + defer signal.Stop(signalCh) go func() { - <-signalCh - fmt.Println("Received signal, shutting down...") - cancel() + select { + case <-signalCh: + fmt.Println("Received signal, shutting down...") + cancel() + case <-ctx.Done(): + return + } }() go func() { @@ -180,3 +232,10 @@ func runTrace() error { fmt.Println("Good bye... (unloading BPF tracepoints will take a few seconds...) after", totalDuration) return nil } + +func signalTraceStarted(started chan<- struct{}) { + if started == nil { + return + } + close(started) +} diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go new file mode 100644 index 0000000..ddc915b --- /dev/null +++ b/internal/ior_mode_test.go @@ -0,0 +1,136 @@ +package internal + +import ( + "context" + "errors" + "testing" + "time" + + "ior/internal/flags" + "ior/internal/tui" +) + +func TestShouldRunTraceMode(t *testing.T) { + base := flags.Flags{} + + if shouldRunTraceMode(base) { + t.Fatalf("expected default mode to use TUI") + } + + withPlain := base + withPlain.PlainMode = true + if !shouldRunTraceMode(withPlain) { + t.Fatalf("expected plain mode to use trace mode") + } + + withFlamegraph := base + withFlamegraph.FlamegraphEnable = true + if !shouldRunTraceMode(withFlamegraph) { + t.Fatalf("expected flamegraph mode to use trace mode") + } + + withPprof := base + withPprof.PprofEnable = true + if !shouldRunTraceMode(withPprof) { + t.Fatalf("expected pprof mode to use trace mode") + } +} + +func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { + origRunTrace := runTraceFn + origRunTUI := runTUIFn + defer func() { + runTraceFn = origRunTrace + runTUIFn = origRunTUI + }() + + traceCalled := false + tuiCalled := false + runTraceFn = func() error { + traceCalled = true + return nil + } + runTUIFn = func(tui.TraceStarter) error { + tuiCalled = true + return nil + } + + cfg := flags.Flags{PlainMode: true} + if err := dispatchRun(cfg); err != nil { + t.Fatalf("dispatchRun returned error: %v", err) + } + if !traceCalled { + t.Fatalf("expected runTraceFn to be called") + } + if tuiCalled { + t.Fatalf("did not expect runTUIFn to be called") + } +} + +func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { + origRunTraceWithContext := runTraceWithContextFn + origRunTUI := runTUIFn + defer func() { + runTraceWithContextFn = origRunTraceWithContext + runTUIFn = origRunTUI + }() + + traceDone := make(chan struct{}, 1) + runTraceWithContextFn = func(_ context.Context, started chan<- struct{}) error { + close(started) + traceDone <- struct{}{} + return nil + } + + tuiCalled := false + runTUIFn = func(starter tui.TraceStarter) error { + tuiCalled = true + if starter == nil { + t.Fatalf("expected non-nil starter") + } + if err := starter(context.Background()); err != nil { + t.Fatalf("starter returned error: %v", err) + } + return nil + } + + cfg := flags.Flags{} + if err := dispatchRun(cfg); err != nil { + t.Fatalf("dispatchRun returned error: %v", err) + } + if !tuiCalled { + t.Fatalf("expected runTUIFn to be called") + } + + select { + case <-traceDone: + case <-time.After(200 * time.Millisecond): + t.Fatalf("expected starter to launch runTraceWithContextFn") + } +} + +func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { + starter := tuiTraceStarterFromRunTrace(func(context.Context, chan<- struct{}) error { + return errors.New("startup failed") + }) + + err := starter(context.Background()) + if err == nil || err.Error() != "startup failed" { + t.Fatalf("expected startup error, got %v", err) + } +} + +func TestTuiTraceStarterFromRunTraceRespectsCancel(t *testing.T) { + starter := tuiTraceStarterFromRunTrace(func(ctx context.Context, _ chan<- struct{}) error { + <-ctx.Done() + return ctx.Err() + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := starter(ctx) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context canceled, got %v", err) + } +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 758213d..49d2365 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -28,7 +28,12 @@ type TraceStarter func(context.Context) error // Run starts the TUI program in alternate screen mode. func Run() error { - model := NewModel(flags.Get().PidFilter, defaultTraceStarter) + return RunWithTraceStarter(defaultTraceStarter) +} + +// RunWithTraceStarter starts the TUI program with a custom trace starter. +func RunWithTraceStarter(starter TraceStarter) error { + model := NewModel(flags.Get().PidFilter, starter) program := tea.NewProgram(model, tea.WithAltScreen()) _, err := program.Run() return err |
