summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-02-24 08:28:47 +0200
committerPaul Buetow <paul@buetow.org>2026-02-24 08:28:47 +0200
commitba7af922d289a9d0fff1c4ef33764b1852c774f6 (patch)
tree3d470e9d0dc967efab4b6a3e56ee361d0ca9cd96 /internal
parentb79a868fbc85cd7fb2829e978174629ab8a9c986 (diff)
ior: route default mode through tui and add plain flag
Diffstat (limited to 'internal')
-rw-r--r--internal/flags/flags.go2
-rw-r--r--internal/ior.go73
-rw-r--r--internal/ior_mode_test.go136
-rw-r--r--internal/tui/tui.go7
4 files changed, 210 insertions, 8 deletions
diff --git a/internal/flags/flags.go b/internal/flags/flags.go
index 8760d33..1909b4a 100644
--- a/internal/flags/flags.go
+++ b/internal/flags/flags.go
@@ -53,6 +53,7 @@ type Flags struct {
TracepointsToExclude []*regexp.Regexp
// Flamegraph flags
+ PlainMode bool
FlamegraphEnable bool
FlamegraphName string
@@ -96,6 +97,7 @@ func parse() {
tracepointsToAttach := flag.String("tps", "", "Comma separated list regexes for tracepoints to load")
tracepointsToExclude := flag.String("tpsExclude", "", "Comma separated list regexes for tracepoints to exclude")
+ flag.BoolVar(&singleton.PlainMode, "plain", false, "Enable plain CSV output mode (disable TUI)")
flag.BoolVar(&singleton.FlamegraphEnable, "flamegraph", false, "Enable flamegraph builder")
flag.StringVar(&singleton.FlamegraphName, "name", "default", "Name of the flamegraph, used to generate the SVG file")
diff --git a/internal/ior.go b/internal/ior.go
index 0299539..011d2fb 100644
--- a/internal/ior.go
+++ b/internal/ior.go
@@ -14,6 +14,7 @@ import (
"ior/internal/flags"
"ior/internal/flamegraph"
"ior/internal/tracepoints"
+ "ior/internal/tui"
bpf "github.com/aquasecurity/libbpfgo"
)
@@ -22,6 +23,12 @@ type tracepointProgram interface {
attachTracepoint(category, name string) error
}
+var (
+ runTraceFn = runTrace
+ runTraceWithContextFn = runTraceWithContext
+ runTUIFn = tui.RunWithTraceStarter
+)
+
type tracepointModule interface {
getProgram(progName string) (tracepointProgram, error)
}
@@ -77,12 +84,13 @@ func attachTracepointsWith(module tracepointModule, shouldAttach func(string) bo
func Run() error {
flags.PrintVersion()
- iorFile := flags.Get().IorDataFile
+ cfg := flags.Get()
+ iorFile := cfg.IorDataFile
var noTraceRun bool
if iorFile != "" {
noTraceRun = true
- collapsed := flamegraph.NewCollapsed(iorFile, flags.Get().CollapsedFields, flags.Get().CountField)
+ collapsed := flamegraph.NewCollapsed(iorFile, cfg.CollapsedFields, cfg.CountField)
collapsedFile, err := collapsed.Write(iorFile)
if err != nil {
return err
@@ -100,10 +108,46 @@ func Run() error {
if noTraceRun {
return nil
}
- return runTrace()
+ return dispatchRun(cfg)
+}
+
+func dispatchRun(cfg flags.Flags) error {
+ if shouldRunTraceMode(cfg) {
+ return runTraceFn()
+ }
+ return runTUIFn(tuiTraceStarterFromRunTrace(runTraceWithContextFn))
+}
+
+func shouldRunTraceMode(cfg flags.Flags) bool {
+ return cfg.PlainMode || cfg.FlamegraphEnable || cfg.PprofEnable
+}
+
+func tuiTraceStarterFromRunTrace(startTrace func(context.Context, chan<- struct{}) error) tui.TraceStarter {
+ return func(ctx context.Context) error {
+ startedCh := make(chan struct{})
+ errCh := make(chan error, 1)
+
+ go func() {
+ errCh <- startTrace(ctx, startedCh)
+ close(errCh)
+ }()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-startedCh:
+ return nil
+ case err := <-errCh:
+ return err
+ }
+ }
}
func runTrace() error {
+ return runTraceWithContext(context.Background(), nil)
+}
+
+func runTraceWithContext(parentCtx context.Context, started chan<- struct{}) error {
bpfModule, err := bpf.NewModuleFromFile("ior.bpf.o")
if err != nil {
return err
@@ -148,18 +192,26 @@ func runTrace() error {
close(pprofDone)
}
+ signalTraceStarted(started)
+
el := newEventLoop()
duration := time.Duration(flags.Get().Duration) * time.Second
fmt.Println("Probing for", duration)
- ctx, cancel := context.WithTimeout(context.Background(), duration)
+ ctx, cancel := context.WithTimeout(parentCtx, duration)
+ defer cancel()
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
+ defer signal.Stop(signalCh)
go func() {
- <-signalCh
- fmt.Println("Received signal, shutting down...")
- cancel()
+ select {
+ case <-signalCh:
+ fmt.Println("Received signal, shutting down...")
+ cancel()
+ case <-ctx.Done():
+ return
+ }
}()
go func() {
@@ -180,3 +232,10 @@ func runTrace() error {
fmt.Println("Good bye... (unloading BPF tracepoints will take a few seconds...) after", totalDuration)
return nil
}
+
+func signalTraceStarted(started chan<- struct{}) {
+ if started == nil {
+ return
+ }
+ close(started)
+}
diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go
new file mode 100644
index 0000000..ddc915b
--- /dev/null
+++ b/internal/ior_mode_test.go
@@ -0,0 +1,136 @@
+package internal
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "ior/internal/flags"
+ "ior/internal/tui"
+)
+
+func TestShouldRunTraceMode(t *testing.T) {
+ base := flags.Flags{}
+
+ if shouldRunTraceMode(base) {
+ t.Fatalf("expected default mode to use TUI")
+ }
+
+ withPlain := base
+ withPlain.PlainMode = true
+ if !shouldRunTraceMode(withPlain) {
+ t.Fatalf("expected plain mode to use trace mode")
+ }
+
+ withFlamegraph := base
+ withFlamegraph.FlamegraphEnable = true
+ if !shouldRunTraceMode(withFlamegraph) {
+ t.Fatalf("expected flamegraph mode to use trace mode")
+ }
+
+ withPprof := base
+ withPprof.PprofEnable = true
+ if !shouldRunTraceMode(withPprof) {
+ t.Fatalf("expected pprof mode to use trace mode")
+ }
+}
+
+func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) {
+ origRunTrace := runTraceFn
+ origRunTUI := runTUIFn
+ defer func() {
+ runTraceFn = origRunTrace
+ runTUIFn = origRunTUI
+ }()
+
+ traceCalled := false
+ tuiCalled := false
+ runTraceFn = func() error {
+ traceCalled = true
+ return nil
+ }
+ runTUIFn = func(tui.TraceStarter) error {
+ tuiCalled = true
+ return nil
+ }
+
+ cfg := flags.Flags{PlainMode: true}
+ if err := dispatchRun(cfg); err != nil {
+ t.Fatalf("dispatchRun returned error: %v", err)
+ }
+ if !traceCalled {
+ t.Fatalf("expected runTraceFn to be called")
+ }
+ if tuiCalled {
+ t.Fatalf("did not expect runTUIFn to be called")
+ }
+}
+
+func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
+ origRunTraceWithContext := runTraceWithContextFn
+ origRunTUI := runTUIFn
+ defer func() {
+ runTraceWithContextFn = origRunTraceWithContext
+ runTUIFn = origRunTUI
+ }()
+
+ traceDone := make(chan struct{}, 1)
+ runTraceWithContextFn = func(_ context.Context, started chan<- struct{}) error {
+ close(started)
+ traceDone <- struct{}{}
+ return nil
+ }
+
+ tuiCalled := false
+ runTUIFn = func(starter tui.TraceStarter) error {
+ tuiCalled = true
+ if starter == nil {
+ t.Fatalf("expected non-nil starter")
+ }
+ if err := starter(context.Background()); err != nil {
+ t.Fatalf("starter returned error: %v", err)
+ }
+ return nil
+ }
+
+ cfg := flags.Flags{}
+ if err := dispatchRun(cfg); err != nil {
+ t.Fatalf("dispatchRun returned error: %v", err)
+ }
+ if !tuiCalled {
+ t.Fatalf("expected runTUIFn to be called")
+ }
+
+ select {
+ case <-traceDone:
+ case <-time.After(200 * time.Millisecond):
+ t.Fatalf("expected starter to launch runTraceWithContextFn")
+ }
+}
+
+func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) {
+ starter := tuiTraceStarterFromRunTrace(func(context.Context, chan<- struct{}) error {
+ return errors.New("startup failed")
+ })
+
+ err := starter(context.Background())
+ if err == nil || err.Error() != "startup failed" {
+ t.Fatalf("expected startup error, got %v", err)
+ }
+}
+
+func TestTuiTraceStarterFromRunTraceRespectsCancel(t *testing.T) {
+ starter := tuiTraceStarterFromRunTrace(func(ctx context.Context, _ chan<- struct{}) error {
+ <-ctx.Done()
+ return ctx.Err()
+ })
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ err := starter(ctx)
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("expected context canceled, got %v", err)
+ }
+}
diff --git a/internal/tui/tui.go b/internal/tui/tui.go
index 758213d..49d2365 100644
--- a/internal/tui/tui.go
+++ b/internal/tui/tui.go
@@ -28,7 +28,12 @@ type TraceStarter func(context.Context) error
// Run starts the TUI program in alternate screen mode.
func Run() error {
- model := NewModel(flags.Get().PidFilter, defaultTraceStarter)
+ return RunWithTraceStarter(defaultTraceStarter)
+}
+
+// RunWithTraceStarter starts the TUI program with a custom trace starter.
+func RunWithTraceStarter(starter TraceStarter) error {
+ model := NewModel(flags.Get().PidFilter, starter)
program := tea.NewProgram(model, tea.WithAltScreen())
_, err := program.Run()
return err