summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-02-27 18:33:40 +0200
committerPaul Buetow <paul@buetow.org>2026-02-27 18:33:40 +0200
commit3783d23b8d608c3bf4a2dedd6b4bfb9165439bed (patch)
tree69bf24794994d4cdd0e01e337de0510f7d5139b8
parent1cf64c3e43b1bdc2b6443fd24db8028f3c96c6da (diff)
internal: validate live CLI mode behavior
-rw-r--r--internal/flags/flags_test.go76
-rw-r--r--internal/flamegraph/liveserver_test.go58
-rw-r--r--internal/ior.go20
-rw-r--r--internal/ior_mode_test.go42
4 files changed, 193 insertions, 3 deletions
diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go
new file mode 100644
index 0000000..b4d47d2
--- /dev/null
+++ b/internal/flags/flags_test.go
@@ -0,0 +1,76 @@
+package flags
+
+import (
+ "flag"
+ "io"
+ "os"
+ "sync"
+ "testing"
+ "time"
+)
+
+func parseForTest(t *testing.T, args ...string) Flags {
+ t.Helper()
+
+ oldCommandLine := flag.CommandLine
+ oldArgs := os.Args
+ oldSingleton := singleton
+ oldOnce := once
+ oldPID := pidFilter.Load()
+ oldTID := tidFilter.Load()
+ oldTUIExport := tuiExportEnable.Load()
+
+ fs := flag.NewFlagSet("ior-test", flag.ContinueOnError)
+ fs.SetOutput(io.Discard)
+ flag.CommandLine = fs
+ os.Args = append([]string{"ior"}, args...)
+
+ singleton = Flags{TUIExportEnable: true}
+ once = sync.Once{}
+ pidFilter.Store(-1)
+ tidFilter.Store(-1)
+ tuiExportEnable.Store(true)
+
+ parse()
+ cfg := singleton
+
+ t.Cleanup(func() {
+ flag.CommandLine = oldCommandLine
+ os.Args = oldArgs
+ singleton = oldSingleton
+ once = oldOnce
+ pidFilter.Store(oldPID)
+ tidFilter.Store(oldTID)
+ tuiExportEnable.Store(oldTUIExport)
+ })
+
+ return cfg
+}
+
+func TestParseLiveFlagsAndInterval(t *testing.T) {
+ cfg := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234")
+
+ if !cfg.LiveFlamegraph {
+ t.Fatalf("expected -live to enable live mode")
+ }
+ if cfg.LiveInterval != 200*time.Millisecond {
+ t.Fatalf("live interval = %v, want %v", cfg.LiveInterval, 200*time.Millisecond)
+ }
+ if cfg.PidFilter != 1234 {
+ t.Fatalf("pid filter = %d, want 1234", cfg.PidFilter)
+ }
+ if got := int(pidFilter.Load()); got != 1234 {
+ t.Fatalf("global pid filter = %d, want 1234", got)
+ }
+}
+
+func TestParseLiveDefaults(t *testing.T) {
+ cfg := parseForTest(t)
+
+ if cfg.LiveFlamegraph {
+ t.Fatalf("expected live mode disabled by default")
+ }
+ if cfg.LiveInterval != time.Second {
+ t.Fatalf("default live interval = %v, want %v", cfg.LiveInterval, time.Second)
+ }
+}
diff --git a/internal/flamegraph/liveserver_test.go b/internal/flamegraph/liveserver_test.go
index 09472c5..0d55794 100644
--- a/internal/flamegraph/liveserver_test.go
+++ b/internal/flamegraph/liveserver_test.go
@@ -2,11 +2,13 @@ package flamegraph
import (
"bufio"
+ "context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
+ "os"
"strings"
"sync"
"testing"
@@ -130,6 +132,35 @@ func TestHandleSSEDelayedClientLargeTrieGetsValidSnapshot(t *testing.T) {
}
}
+func TestServeLivePrintsURLAndStopsOnCancel(t *testing.T) {
+ lt := NewLiveTrie([]string{"comm"}, "count")
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ output := captureStdout(t, func() {
+ errCh := make(chan error, 1)
+ go func() {
+ errCh <- ServeLive(ctx, lt, 5*time.Millisecond)
+ }()
+
+ time.Sleep(40 * time.Millisecond)
+ cancel()
+
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("ServeLive returned error: %v", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatalf("timeout waiting for ServeLive to return")
+ }
+ })
+
+ if !strings.Contains(output, "Live flamegraph available at http://") {
+ t.Fatalf("expected live URL in output, got %q", output)
+ }
+}
+
func connectSSE(t *testing.T, url string) *http.Response {
t.Helper()
client := &http.Client{Timeout: 5 * time.Second}
@@ -196,3 +227,30 @@ func decodeSSESnapshot(t *testing.T, data string) trieSnapshot {
}
return snap
}
+
+func captureStdout(t *testing.T, fn func()) string {
+ t.Helper()
+
+ oldStdout := os.Stdout
+ reader, writer, err := os.Pipe()
+ if err != nil {
+ t.Fatalf("create stdout pipe: %v", err)
+ }
+
+ os.Stdout = writer
+ defer func() { os.Stdout = oldStdout }()
+
+ outCh := make(chan string, 1)
+ go func() {
+ var b strings.Builder
+ _, _ = io.Copy(&b, reader)
+ outCh <- b.String()
+ }()
+
+ fn()
+
+ _ = writer.Close()
+ out := <-outCh
+ _ = reader.Close()
+ return out
+}
diff --git a/internal/ior.go b/internal/ior.go
index bf0fb1f..cdddc24 100644
--- a/internal/ior.go
+++ b/internal/ior.go
@@ -4,6 +4,7 @@ import "C"
import (
"context"
+ "errors"
"fmt"
"os"
"os/signal"
@@ -35,6 +36,9 @@ var (
runTraceFn = runTrace
runTraceWithContextFn = runTraceWithContext
runTUIFn = tui.RunWithTraceStarter
+ getEUID = os.Geteuid
+
+ errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)")
)
type tracepointModule interface {
@@ -140,12 +144,22 @@ func Run() error {
}
func dispatchRun(cfg flags.Flags) error {
+ if err := validateRunConfig(cfg); err != nil {
+ return err
+ }
if shouldRunTraceMode(cfg) {
return runTraceFn()
}
return runTUIFn(tuiTraceStarterFromRunTrace(runTraceWithContextFn))
}
+func validateRunConfig(cfg flags.Flags) error {
+ if cfg.LiveFlamegraph && cfg.FlamegraphEnable {
+ return errors.New("-live and -flamegraph are mutually exclusive")
+ }
+ return nil
+}
+
func shouldRunTraceMode(cfg flags.Flags) bool {
return cfg.PlainMode || cfg.FlamegraphEnable || cfg.LiveFlamegraph || cfg.PprofEnable
}
@@ -202,6 +216,10 @@ func runTrace() error {
}
func runTraceWithContext(parentCtx context.Context, started chan<- struct{}, configure func(*eventLoop)) error {
+ if getEUID() != 0 {
+ return errRootPrivilegesRequired
+ }
+
verbose := started == nil
logln := func(...any) {}
if verbose {
@@ -328,5 +346,5 @@ func signalTraceStarted(started chan<- struct{}) {
}
func shouldAutoStopByDuration(cfg flags.Flags) bool {
- return cfg.PlainMode || cfg.FlamegraphEnable || cfg.PprofEnable
+ return cfg.PlainMode || cfg.FlamegraphEnable || cfg.LiveFlamegraph || cfg.PprofEnable
}
diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go
index 35d7f43..bac54bd 100644
--- a/internal/ior_mode_test.go
+++ b/internal/ior_mode_test.go
@@ -68,8 +68,8 @@ func TestShouldAutoStopByDuration(t *testing.T) {
withLive := base
withLive.LiveFlamegraph = true
- if shouldAutoStopByDuration(withLive) {
- t.Fatalf("expected live mode not to auto-stop by duration")
+ if !shouldAutoStopByDuration(withLive) {
+ t.Fatalf("expected live mode to auto-stop by duration")
}
}
@@ -149,6 +149,44 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) {
}
}
+func TestDispatchRunRejectsLiveAndFlamegraph(t *testing.T) {
+ origRunTrace := runTraceFn
+ origRunTUI := runTUIFn
+ defer func() {
+ runTraceFn = origRunTrace
+ runTUIFn = origRunTUI
+ }()
+
+ runTraceFn = func() error {
+ t.Fatalf("runTraceFn should not be called for invalid flag combos")
+ return nil
+ }
+ runTUIFn = func(tui.TraceStarter) error {
+ t.Fatalf("runTUIFn should not be called for invalid flag combos")
+ return nil
+ }
+
+ cfg := flags.Flags{LiveFlamegraph: true, FlamegraphEnable: true}
+ err := dispatchRun(cfg)
+ if err == nil {
+ t.Fatalf("expected error for -live with -flamegraph")
+ }
+ if err.Error() != "-live and -flamegraph are mutually exclusive" {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestRunTraceWithContextRequiresRoot(t *testing.T) {
+ origGetEUID := getEUID
+ defer func() { getEUID = origGetEUID }()
+
+ getEUID = func() int { return 1000 }
+ err := runTraceWithContext(context.Background(), nil, nil)
+ if !errors.Is(err, errRootPrivilegesRequired) {
+ t.Fatalf("expected root-required error, got %v", err)
+ }
+}
+
func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) {
starter := tuiTraceStarterFromRunTrace(
func(context.Context, chan<- struct{}, func(*eventLoop)) error { return errors.New("startup failed") },