diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-06 17:32:24 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-06 17:32:24 +0200 |
| commit | 1561987330cb898f5ff64383a9c78e7e6559f118 (patch) | |
| tree | 69a823e8f98dce572566c97e6879c11c9d591bda /internal | |
| parent | 96225fb6159212a8851043a08d781aba721b4e78 (diff) | |
| parent | 110a193e04b81abb8d8e159abd73f9f6ed1acd7e (diff) | |
Merge branch 'feat/bubbletea-v2-migration'
Diffstat (limited to 'internal')
142 files changed, 8829 insertions, 5840 deletions
diff --git a/internal/bench_components_test.go b/internal/bench_components_test.go index 1f9ccb5..54c6f2e 100644 --- a/internal/bench_components_test.go +++ b/internal/bench_components_test.go @@ -71,7 +71,7 @@ func BenchmarkDeserializeDup3Event(b *testing.B) { func BenchmarkRawHandlerLookup(b *testing.B) { b.ReportAllocs() - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(b, eventLoopConfig{}) eventTypes := []types.EventType{ types.ENTER_OPEN_EVENT, types.EXIT_OPEN_EVENT, @@ -97,7 +97,7 @@ func BenchmarkTracepointEntered(b *testing.B) { gen := benchutil.NewEventGenerator() _, raw := gen.EnterOpenEvent(1, componentBenchPID, componentBenchTID) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -118,7 +118,7 @@ func BenchmarkTracepointExited(b *testing.B) { gen := benchutil.NewEventGenerator() _, enterRaw := gen.EnterNullEvent(1, componentBenchPID, componentBenchTID, types.SYS_ENTER_SYNC) _, exitRaw := gen.ExitNullEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_SYNC) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) out := make(chan *event.Pair, 1) b.ResetTimer() @@ -137,7 +137,7 @@ func BenchmarkHandleOpenExit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterOpenEvent(1, componentBenchPID, componentBenchTID) exitTemplate, _ := gen.ExitOpenEvent(2, componentBenchPID, componentBenchTID) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -160,7 +160,7 @@ func BenchmarkHandleFdExit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterFdEvent(1, componentBenchPID, componentBenchTID, 99, types.SYS_ENTER_READ) exitTemplate, _ := gen.ExitRetEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_READ, 128) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) el.fdState().set(99, file.NewFd(99, "/tmp/fd", syscall.O_RDONLY)) b.ResetTimer() @@ -184,7 +184,7 @@ func BenchmarkHandlePathExit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterPathEvent(1, componentBenchPID, componentBenchTID, "/tmp/path", types.SYS_ENTER_MKDIR) exitTemplate, _ := gen.ExitRetEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_MKDIR, 0) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -207,7 +207,7 @@ func BenchmarkHandleNameExit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterNameEvent(1, componentBenchPID, componentBenchTID, "/tmp/a", "/tmp/b", types.SYS_ENTER_RENAME) exitTemplate, _ := gen.ExitRetEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_RENAME, 0) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -230,7 +230,7 @@ func BenchmarkHandleNullExit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterNullEvent(1, componentBenchPID, componentBenchTID, types.SYS_ENTER_SYNC) exitTemplate, _ := gen.ExitNullEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_SYNC) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -253,7 +253,7 @@ func BenchmarkHandleFcntlExit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterFcntlEvent(1, componentBenchPID, componentBenchTID, 7, syscall.F_SETFL, syscall.O_NONBLOCK) exitTemplate, _ := gen.ExitRetEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_FCNTL, 0) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) el.fdState().set(7, file.NewFd(7, "/tmp/fcntl", syscall.O_RDONLY)) b.ResetTimer() @@ -277,7 +277,7 @@ func BenchmarkHandleDup3Exit(b *testing.B) { gen := benchutil.NewEventGenerator() enterTemplate, _ := gen.EnterDup3Event(1, componentBenchPID, componentBenchTID, 9, syscall.O_CLOEXEC) exitTemplate, _ := gen.ExitRetEvent(2, componentBenchPID, componentBenchTID, types.SYS_EXIT_DUP3, 10) - el := newComponentBenchEventLoop(componentBenchTID) + el := newComponentBenchEventLoop(b, componentBenchTID) el.fdState().set(9, file.NewFd(9, "/tmp/dup3", syscall.O_RDONLY)) b.ResetTimer() @@ -354,8 +354,9 @@ func benchmarkDeserialize[T recyclable](b *testing.B, raw []byte, decode func([] } } -func newComponentBenchEventLoop(tids ...uint32) *eventLoop { - el := newEventLoop(eventLoopConfig{}) +func newComponentBenchEventLoop(tb testing.TB, tids ...uint32) *eventLoop { + tb.Helper() + el := mustNewEventLoop(tb, eventLoopConfig{}) for _, tid := range tids { el.setCachedComm(tid, fmt.Sprintf("bench-%d", tid)) } diff --git a/internal/bench_pipeline_test.go b/internal/bench_pipeline_test.go index 822e5a2..aa48302 100644 --- a/internal/bench_pipeline_test.go +++ b/internal/bench_pipeline_test.go @@ -70,7 +70,7 @@ func benchmarkPipelineMix(b *testing.B, mix benchutil.EventMix, events, numThrea close(rawCh) var pairCount int64 - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(b, eventLoopConfig{}) preseedBenchComms(el, numThreads) el.printCb = func(ep *event.Pair) { pairCount++ diff --git a/internal/benchutil/doc.go b/internal/benchutil/doc.go new file mode 100644 index 0000000..eec2017 --- /dev/null +++ b/internal/benchutil/doc.go @@ -0,0 +1,2 @@ +// Package benchutil provides fixtures and helpers used by benchmark suites. +package benchutil diff --git a/internal/bpfsetup.go b/internal/bpfsetup.go index 9791930..f8f2c60 100644 --- a/internal/bpfsetup.go +++ b/internal/bpfsetup.go @@ -9,7 +9,7 @@ import ( bpf "github.com/aquasecurity/libbpfgo" ) -func setBPFGlobals(cfg flags.Flags, bpfModule *bpf.Module) error { +func setBPFGlobals(cfg flags.Config, bpfModule *bpf.Module) error { // Ignore `ior` process itself from the filter. if err := bpfModule.InitGlobalVariable("IOR_PID_FILTER", uint32(os.Getpid())); err != nil { return fmt.Errorf("unable set IOR_PID_FILTER: %w", err) @@ -23,7 +23,7 @@ func setBPFGlobals(cfg flags.Flags, bpfModule *bpf.Module) error { return nil } -func resizeBPFMaps(cfg flags.Flags, bpfModule *bpf.Module) error { +func resizeBPFMaps(cfg flags.Config, bpfModule *bpf.Module) error { if err := resizeBPFMap(bpfModule, "event_map", uint32(cfg.EventMapSize)); err != nil { return fmt.Errorf("event_map: %w", err) } diff --git a/internal/collapse/fields.go b/internal/collapse/fields.go new file mode 100644 index 0000000..b87fe11 --- /dev/null +++ b/internal/collapse/fields.go @@ -0,0 +1,39 @@ +package collapse + +import "slices" + +var validFields = []string{ + "path", + "comm", + "tracepoint", + "pid", + "tid", + "flags", +} + +var validCountFields = []string{ + "count", + "duration", + "durationToPrev", + "bytes", +} + +// ValidFields returns a copy of supported collapse fields. +func ValidFields() []string { + return slices.Clone(validFields) +} + +// ValidCountFields returns a copy of supported collapse count fields. +func ValidCountFields() []string { + return slices.Clone(validCountFields) +} + +// IsValidField reports whether a collapse field is supported. +func IsValidField(field string) bool { + return slices.Contains(validFields, field) +} + +// IsValidCountField reports whether a collapse count field is supported. +func IsValidCountField(field string) bool { + return slices.Contains(validCountFields, field) +} diff --git a/internal/doc.go b/internal/doc.go new file mode 100644 index 0000000..adf259b --- /dev/null +++ b/internal/doc.go @@ -0,0 +1,2 @@ +// Package internal contains runtime orchestration helpers shared by command entrypoints. +package internal diff --git a/internal/event/doc.go b/internal/event/doc.go new file mode 100644 index 0000000..1504e6f --- /dev/null +++ b/internal/event/doc.go @@ -0,0 +1,2 @@ +// Package event decodes and formats kernel event payloads for downstream consumers. +package event diff --git a/internal/event/event.go b/internal/event/event.go index 165ffe3..48bde48 100644 --- a/internal/event/event.go +++ b/internal/event/event.go @@ -1,17 +1,19 @@ package event import ( - . "ior/internal/types" "sync" + + "ior/internal/types" ) var poolOfEventPairs = sync.Pool{ New: func() interface{} { return &Pair{} }, } +// Event is the common contract implemented by decoded syscall trace events. type Event interface { String() string - GetTraceId() TraceId + GetTraceId() types.TraceId GetPid() uint32 GetTid() uint32 GetTime() uint64 diff --git a/internal/event/pair.go b/internal/event/pair.go index 1d1e9ce..131c6b3 100644 --- a/internal/event/pair.go +++ b/internal/event/pair.go @@ -2,10 +2,11 @@ package event import ( "fmt" - "ior/internal/file" - "ior/internal/types" "strconv" "strings" + + "ior/internal/file" + "ior/internal/types" ) // Pair represents a matched syscall enter/exit pair together with derived metadata. @@ -60,7 +61,7 @@ const EventStreamHeader = "durationToPrevNs,durationNs,comm,pid.tid,name,ret,not func (e *Pair) String() string { var sb strings.Builder - sb.WriteString(fmt.Sprintf("%08d,%08d", e.DurationToPrev, e.Duration)) + _, _ = fmt.Fprintf(&sb, "%08d,%08d", e.DurationToPrev, e.Duration) sb.WriteString(",") sb.WriteString(e.Comm) diff --git a/internal/event/pair_test.go b/internal/event/pair_test.go index a1cb8ab..43e9945 100644 --- a/internal/event/pair_test.go +++ b/internal/event/pair_test.go @@ -1,8 +1,9 @@ package event import ( - "ior/internal/types" "testing" + + "ior/internal/types" ) func TestPairCalculateDurationsFirstEvent(t *testing.T) { diff --git a/internal/eventfilter.go b/internal/eventfilter.go index 4ff0385..43a8f51 100644 --- a/internal/eventfilter.go +++ b/internal/eventfilter.go @@ -1,39 +1,45 @@ package internal import ( + "bytes" "fmt" + "strings" + "ior/internal/event" "ior/internal/types" - "strings" ) type eventFilter struct { commFilterEnable bool commFilter string + commFilterBytes []byte pathFilterEnable bool pathFilter string + pathFilterBytes []byte } -func newEventFilter(commFilter, pathFilter string) *eventFilter { +func newEventFilter(commFilter, pathFilter string) (*eventFilter, error) { var ef eventFilter if commFilter != "" { - if len(commFilter) > types.MAX_FILENAME_LENGTH { - panic(fmt.Sprintf("Comm filter's max size is %d", types.MAX_PROGNAME_LENGTH)) + if len(commFilter) > types.MAX_PROGNAME_LENGTH { + return nil, fmt.Errorf("comm filter max size is %d (got %d)", types.MAX_PROGNAME_LENGTH, len(commFilter)) } ef.commFilterEnable = true ef.commFilter = commFilter + ef.commFilterBytes = []byte(commFilter) } if pathFilter != "" { if len(pathFilter) > types.MAX_FILENAME_LENGTH { - panic(fmt.Sprintf("Path filter's max size is %d", types.MAX_FILENAME_LENGTH)) + return nil, fmt.Errorf("path filter max size is %d (got %d)", types.MAX_FILENAME_LENGTH, len(pathFilter)) } ef.pathFilterEnable = true ef.pathFilter = pathFilter + ef.pathFilterBytes = []byte(pathFilter) } - return &ef + return &ef, nil } func (ef *eventFilter) eventPair(ev *event.Pair) bool { @@ -47,11 +53,11 @@ func (ef *eventFilter) eventPair(ev *event.Pair) bool { } func (ef *eventFilter) openEvent(ev *types.OpenEvent) (*types.OpenEvent, bool) { - if ef.commFilterEnable && !strings.Contains(string(ev.Comm[:]), ef.commFilter) { + if ef.commFilterEnable && !bytes.Contains(ev.Comm[:], ef.commFilterBytes) { return ev, false } - if ef.pathFilterEnable && !strings.Contains(string(ev.Filename[:]), ef.pathFilter) { + if ef.pathFilterEnable && !bytes.Contains(ev.Filename[:], ef.pathFilterBytes) { return ev, false } return ev, true @@ -59,14 +65,14 @@ func (ef *eventFilter) openEvent(ev *types.OpenEvent) (*types.OpenEvent, bool) { func (ef *eventFilter) pathEvent(ev *types.PathEvent) (*types.PathEvent, bool) { if ef.pathFilterEnable { - return ev, strings.Contains(string(ev.Pathname[:]), ef.pathFilter) + return ev, bytes.Contains(ev.Pathname[:], ef.pathFilterBytes) } return ev, true } func (ef *eventFilter) nameEvent(ev *types.NameEvent) (*types.NameEvent, bool) { if ef.pathFilterEnable { - return ev, strings.Contains(string(ev.Oldname[:]), ef.pathFilter) || strings.Contains(string(ev.Newname[:]), ef.pathFilter) + return ev, bytes.Contains(ev.Oldname[:], ef.pathFilterBytes) || bytes.Contains(ev.Newname[:], ef.pathFilterBytes) } return ev, true } diff --git a/internal/eventloop.go b/internal/eventloop.go index 7d33f87..26eaafc 100644 --- a/internal/eventloop.go +++ b/internal/eventloop.go @@ -7,32 +7,33 @@ import ( "fmt" "os" "path/filepath" + "reflect" "sync" "syscall" "time" "ior/internal/event" "ior/internal/file" - "ior/internal/flamegraph" "ior/internal/types" - . "ior/internal/types" ) const sysEnterNameToHandleAtName = "name_to_handle_at" +const ( + defaultCommLookupWorkers = 4 + defaultCommLookupQueueSize = 512 +) + type eventLoopConfig struct { - pidFilter int - commFilter string - pathFilter string - liveFlamegraph bool - liveInterval time.Duration - liveOpenCommand string - collapsedFields []string - countField string - flamegraphName string - flamegraphEnable bool - pprofEnable bool - plainMode bool + pidFilter int + commFilter string + pathFilter string + collapsedFields []string + countField string + pprofEnable bool + plainMode bool + fdTracker *fdTracker + commResolver *commResolver } type fdTracker struct { @@ -72,16 +73,56 @@ type commResolver struct { mu sync.RWMutex pending map[uint32]struct{} + + lookupQueue chan uint32 + lookupWorkers int + resolveFn func(uint32) string + startWorkersOnce sync.Once } func newCommResolver(comms map[uint32]string) *commResolver { if comms == nil { comms = make(map[uint32]string) } - return &commResolver{ + r := &commResolver{ comms: comms, pending: make(map[uint32]struct{}), } + r.ensureLookupConfig() + return r +} + +func (r *commResolver) ensureLookupConfig() { + if r.lookupWorkers <= 0 { + r.lookupWorkers = defaultCommLookupWorkers + } + if r.lookupQueue == nil { + r.lookupQueue = make(chan uint32, defaultCommLookupQueueSize) + } + if r.resolveFn == nil { + r.resolveFn = resolveCommFromProc + } +} + +func (r *commResolver) startLookupWorkers() { + r.ensureLookupConfig() + r.startWorkersOnce.Do(func() { + for i := 0; i < r.lookupWorkers; i++ { + go r.lookupWorker() + } + }) +} + +func (r *commResolver) lookupWorker() { + for tid := range r.lookupQueue { + comm := r.resolveFn(tid) + r.mu.Lock() + delete(r.pending, tid) + if comm != "" { + r.comms[tid] = comm + } + r.mu.Unlock() + } } func (r *commResolver) seedTrackedPidComm(pidFilter int) { @@ -150,32 +191,31 @@ func (r *commResolver) queueLookup(tid uint32) { r.pending[tid] = struct{}{} r.mu.Unlock() - go func() { - comm := resolveCommFromProc(tid) + r.startLookupWorkers() + + // Keep event processing non-blocking if resolver workers are saturated. + select { + case r.lookupQueue <- tid: + default: r.mu.Lock() delete(r.pending, tid) - if comm != "" { - r.comms[tid] = comm - } r.mu.Unlock() - }() + } } type rawEventHandler func(raw []byte, ch chan<- *event.Pair) +type tracepointExitHandler func(ep *event.Pair) bool type eventLoop struct { filter *eventFilter enterEvs map[uint32]*event.Pair // Temp. store of sys_enter tracepoints per Tid. pendingHandles map[uint32]string // map of TID to pathname from name_to_handle_at - files map[int32]file.File // Track all open files by file descriptor. fdTracker *fdTracker - procFdCache map[uint64]file.FdFile // Cache procfs-resolved metadata for unknown fds. - comms map[uint32]string // Program or thread name of the current Tid. + procFdCache map[uint64]*file.FdFile // Cache procfs-resolved metadata for unknown fds. commResolver *commResolver prevPairTimes map[uint32]uint64 // Previous event's time (to calculate time differences between two events) - rawHandlers map[EventType]rawEventHandler - flamegraph flamegraph.IorDataCollector // Storing all paths in a map structure for analysis - liveTrie *flamegraph.LiveTrie + rawHandlers map[types.EventType]rawEventHandler + exitHandlers map[reflect.Type]tracepointExitHandler printCb func(ep *event.Pair) // Callback to print the event warningCb func(message string) // Optional callback for non-fatal event processing warnings cfg eventLoopConfig @@ -189,33 +229,57 @@ type eventLoop struct { done chan struct{} } -func newEventLoop(cfg eventLoopConfig) *eventLoop { - filesByFD := make(map[int32]file.File) - commsByTID := make(map[uint32]string) +func newEventLoop(cfg eventLoopConfig) (*eventLoop, error) { + fdState := configuredFDTracker(cfg.fdTracker) + commState := configuredCommResolver(cfg.commResolver) + filter, err := newEventFilter(cfg.commFilter, cfg.pathFilter) + if err != nil { + return nil, fmt.Errorf("create event filter: %w", err) + } el := &eventLoop{ - filter: newEventFilter(cfg.commFilter, cfg.pathFilter), + filter: filter, enterEvs: make(map[uint32]*event.Pair), pendingHandles: make(map[uint32]string), - files: filesByFD, - fdTracker: newFDTracker(filesByFD), - procFdCache: make(map[uint64]file.FdFile), - comms: commsByTID, - commResolver: newCommResolver(commsByTID), + fdTracker: fdState, + procFdCache: make(map[uint64]*file.FdFile), + commResolver: commState, prevPairTimes: make(map[uint32]uint64), - rawHandlers: make(map[EventType]rawEventHandler), + rawHandlers: make(map[types.EventType]rawEventHandler), + exitHandlers: make(map[reflect.Type]tracepointExitHandler), printCb: func(ep *event.Pair) { fmt.Println(ep); ep.Recycle() }, - flamegraph: flamegraph.New(cfg.flamegraphName), cfg: cfg, done: make(chan struct{}), } el.initRawHandlers() - if cfg.liveFlamegraph { - el.liveTrie = flamegraph.NewLiveTrie(cfg.collapsedFields, cfg.countField) - } + el.initExitHandlers() el.configureOutputCallback() el.seedTrackedPidComm() - return el + return el, nil +} + +func configuredFDTracker(injected *fdTracker) *fdTracker { + if injected == nil { + return newFDTracker(nil) + } + if injected.files == nil { + injected.files = make(map[int32]file.File) + } + return injected +} + +func configuredCommResolver(injected *commResolver) *commResolver { + if injected == nil { + return newCommResolver(nil) + } + if injected.comms == nil { + injected.comms = make(map[uint32]string) + } + if injected.pending == nil { + injected.pending = make(map[uint32]struct{}) + } + injected.ensureLookupConfig() + return injected } func (e *eventLoop) seedTrackedPidComm() { @@ -223,35 +287,31 @@ func (e *eventLoop) seedTrackedPidComm() { } func (e *eventLoop) fdState() *fdTracker { - if e.files == nil { - e.files = make(map[int32]file.File) - } if e.fdTracker == nil { - e.fdTracker = newFDTracker(e.files) + e.fdTracker = newFDTracker(nil) + } + if e.fdTracker.files == nil { + e.fdTracker.files = make(map[int32]file.File) } return e.fdTracker } func (e *eventLoop) commState() *commResolver { - if e.comms == nil { - e.comms = make(map[uint32]string) - } if e.commResolver == nil { - e.commResolver = newCommResolver(e.comms) + e.commResolver = newCommResolver(nil) + } + if e.commResolver.comms == nil { + e.commResolver.comms = make(map[uint32]string) } + if e.commResolver.pending == nil { + e.commResolver.pending = make(map[uint32]struct{}) + } + e.commResolver.ensureLookupConfig() return e.commResolver } func (e *eventLoop) configureOutputCallback() { switch { - case e.cfg.flamegraphEnable: - e.printCb = func(ep *event.Pair) { - e.flamegraph.Ch <- ep - } - case e.liveTrie != nil: - e.printCb = func(ep *event.Pair) { - e.liveTrie.Ingest(ep) - } case e.cfg.pprofEnable: e.printCb = func(ep *event.Pair) { ep.Recycle() @@ -282,29 +342,10 @@ func (e *eventLoop) stats() string { func (e *eventLoop) run(ctx context.Context, rawCh <-chan []byte) { defer close(e.done) - if e.liveTrie != nil { - fmt.Println("Starting live flamegraph server") - go func() { - liveOptions := flamegraph.LiveServerOptions{ - OpenCommand: e.cfg.liveOpenCommand, - } - if e.warningCb != nil { - liveOptions.WarningCb = e.notifyWarning - } - if err := flamegraph.ServeLiveWithOptions(ctx, e.liveTrie, e.cfg.liveInterval, liveOptions); err != nil && ctx.Err() == nil { - fmt.Println("Live flamegraph server error:", err) - } - }() - } - - if e.cfg.flamegraphEnable { - fmt.Println("Collecting flame graph stats, press Ctrl+C to stop") - e.flamegraph.Start(ctx) - } if e.cfg.pprofEnable { fmt.Println("Profiling, press Ctrl+C to stop") } - if e.cfg.plainMode && !e.cfg.flamegraphEnable && !e.cfg.pprofEnable { + if e.cfg.plainMode && !e.cfg.pprofEnable { fmt.Println(event.EventStreamHeader) } @@ -316,16 +357,6 @@ func (e *eventLoop) run(ctx context.Context, rawCh <-chan []byte) { e.printCb(ep) e.numSyscallsAfterFilter++ } - - if e.cfg.flamegraphEnable { - fmt.Println("Waiting for flamegraph") - if err := <-e.flamegraph.Done; err != nil { - e.notifyWarning(fmt.Sprintf("Flamegraph generation failed: %v", err)) - if e.warningCb == nil { - fmt.Println("Flamegraph generation failed:", err) - } - } - } } func (e *eventLoop) events(ctx context.Context, rawCh <-chan []byte) <-chan *event.Pair { @@ -359,7 +390,7 @@ func (e *eventLoop) events(ctx context.Context, rawCh <-chan []byte) <-chan *eve func (e *eventLoop) processRawEvent(raw []byte, ch chan<- *event.Pair) { e.numTracepoints++ e.initRawHandlers() - evType := EventType(raw[0]) + evType := types.EventType(raw[0]) handler, ok := e.rawHandlers[evType] if !ok { e.notifyWarning(fmt.Sprintf("Dropped unhandled raw event type %d", evType)) @@ -370,53 +401,53 @@ func (e *eventLoop) processRawEvent(raw []byte, ch chan<- *event.Pair) { func (e *eventLoop) initRawHandlers() { if e.rawHandlers == nil { - e.rawHandlers = make(map[EventType]rawEventHandler) + e.rawHandlers = make(map[types.EventType]rawEventHandler) } if len(e.rawHandlers) != 0 { return } - e.rawHandlers[ENTER_OPEN_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - if ev, ok := e.filter.openEvent(NewOpenEventFast(raw)); ok { + e.rawHandlers[types.ENTER_OPEN_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + if ev, ok := e.filter.openEvent(types.NewOpenEventFast(raw)); ok { e.tracepointEntered(ev) } } - e.rawHandlers[EXIT_OPEN_EVENT] = func(raw []byte, ch chan<- *event.Pair) { - e.tracepointExited(NewRetEventFast(raw), ch) + e.rawHandlers[types.EXIT_OPEN_EVENT] = func(raw []byte, ch chan<- *event.Pair) { + e.tracepointExited(types.NewRetEventFast(raw), ch) } - e.rawHandlers[ENTER_FD_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - e.tracepointEntered(NewFdEventFast(raw)) + e.rawHandlers[types.ENTER_FD_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + e.tracepointEntered(types.NewFdEventFast(raw)) } - e.rawHandlers[EXIT_FD_EVENT] = func(raw []byte, ch chan<- *event.Pair) { - e.tracepointExited(NewFdEventFast(raw), ch) + e.rawHandlers[types.EXIT_FD_EVENT] = func(raw []byte, ch chan<- *event.Pair) { + e.tracepointExited(types.NewFdEventFast(raw), ch) } - e.rawHandlers[ENTER_NULL_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - e.tracepointEntered(NewNullEventFast(raw)) + e.rawHandlers[types.ENTER_NULL_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + e.tracepointEntered(types.NewNullEventFast(raw)) } - e.rawHandlers[EXIT_NULL_EVENT] = func(raw []byte, ch chan<- *event.Pair) { - e.tracepointExited(NewNullEventFast(raw), ch) + e.rawHandlers[types.EXIT_NULL_EVENT] = func(raw []byte, ch chan<- *event.Pair) { + e.tracepointExited(types.NewNullEventFast(raw), ch) } - e.rawHandlers[EXIT_RET_EVENT] = func(raw []byte, ch chan<- *event.Pair) { - e.tracepointExited(NewRetEventFast(raw), ch) + e.rawHandlers[types.EXIT_RET_EVENT] = func(raw []byte, ch chan<- *event.Pair) { + e.tracepointExited(types.NewRetEventFast(raw), ch) } - e.rawHandlers[ENTER_NAME_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - if ev, ok := e.filter.nameEvent(NewNameEventFast(raw)); ok { + e.rawHandlers[types.ENTER_NAME_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + if ev, ok := e.filter.nameEvent(types.NewNameEventFast(raw)); ok { e.tracepointEntered(ev) } } - e.rawHandlers[ENTER_PATH_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - if ev, ok := e.filter.pathEvent(NewPathEventFast(raw)); ok { + e.rawHandlers[types.ENTER_PATH_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + if ev, ok := e.filter.pathEvent(types.NewPathEventFast(raw)); ok { e.tracepointEntered(ev) } } - e.rawHandlers[ENTER_FCNTL_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - e.tracepointEntered(NewFcntlEventFast(raw)) + e.rawHandlers[types.ENTER_FCNTL_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + e.tracepointEntered(types.NewFcntlEventFast(raw)) } - e.rawHandlers[ENTER_OPEN_BY_HANDLE_AT_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - e.tracepointEntered(NewOpenByHandleAtEventFast(raw)) + e.rawHandlers[types.ENTER_OPEN_BY_HANDLE_AT_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + e.tracepointEntered(types.NewOpenByHandleAtEventFast(raw)) } - e.rawHandlers[ENTER_DUP3_EVENT] = func(raw []byte, _ chan<- *event.Pair) { - e.tracepointEntered(NewDup3EventFast(raw)) + e.rawHandlers[types.ENTER_DUP3_EVENT] = func(raw []byte, _ chan<- *event.Pair) { + e.tracepointEntered(types.NewDup3EventFast(raw)) } } @@ -430,7 +461,7 @@ func (e *eventLoop) tracepointEntered(enterEv event.Event) { } switch enterEv.(type) { - case *OpenEvent: + case *types.OpenEvent: e.enterEvs[tid] = event.NewPair(enterEv) default: // Only, when we have a comm name @@ -471,32 +502,93 @@ func (e *eventLoop) tracepointExited(exitEv event.Event, ch chan<- *event.Pair) ch <- ep } +func (e *eventLoop) initExitHandlers() { + e.exitHandlers = map[reflect.Type]tracepointExitHandler{ + reflect.TypeOf(&types.OpenEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.OpenEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed open enter event") + return false + } + return e.handleOpenExit(ep, enterEv) + }, + reflect.TypeOf(&types.NameEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.NameEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed name enter event") + return false + } + return e.handleNameExit(ep, enterEv) + }, + reflect.TypeOf(&types.PathEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.PathEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed path enter event") + return false + } + return e.handlePathExit(ep, enterEv) + }, + reflect.TypeOf(&types.FdEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.FdEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed fd enter event") + return false + } + return e.handleFdExit(ep, enterEv) + }, + reflect.TypeOf(&types.Dup3Event{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.Dup3Event) + if !ok { + e.recyclePair(ep, "Dropped malformed dup3 enter event") + return false + } + return e.handleDup3Exit(ep, enterEv) + }, + reflect.TypeOf(&types.OpenByHandleAtEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.OpenByHandleAtEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed open_by_handle_at enter event") + return false + } + return e.handleOpenByHandleAtExit(ep, enterEv) + }, + reflect.TypeOf(&types.NullEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.NullEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed null enter event") + return false + } + return e.handleNullExit(ep, enterEv) + }, + reflect.TypeOf(&types.FcntlEvent{}): func(ep *event.Pair) bool { + enterEv, ok := ep.EnterEv.(*types.FcntlEvent) + if !ok { + e.recyclePair(ep, "Dropped malformed fcntl enter event") + return false + } + return e.handleFcntlExit(ep, enterEv) + }, + } +} + +func (e *eventLoop) exitHandlerRegistry() map[reflect.Type]tracepointExitHandler { + if e.exitHandlers == nil { + e.initExitHandlers() + } + return e.exitHandlers +} + func (e *eventLoop) handleTracepointExit(ep *event.Pair) bool { - switch enterEv := ep.EnterEv.(type) { - case *OpenEvent: - return e.handleOpenExit(ep, enterEv) - case *NameEvent: - return e.handleNameExit(ep, enterEv) - case *PathEvent: - return e.handlePathExit(ep, enterEv) - case *FdEvent: - return e.handleFdExit(ep, enterEv) - case *Dup3Event: - return e.handleDup3Exit(ep, enterEv) - case *OpenByHandleAtEvent: - return e.handleOpenByHandleAtExit(ep, enterEv) - case *NullEvent: - return e.handleNullExit(ep, enterEv) - case *FcntlEvent: - return e.handleFcntlExit(ep, enterEv) - default: + handler, ok := e.exitHandlerRegistry()[reflect.TypeOf(ep.EnterEv)] + if !ok { e.recyclePair(ep, "Dropped malformed enter event") return false } + return handler(ep) } -func (e *eventLoop) handleOpenExit(ep *event.Pair, openEv *OpenEvent) bool { - retEvent, ok := ep.ExitEv.(*RetEvent) +func (e *eventLoop) handleOpenExit(ep *event.Pair, openEv *types.OpenEvent) bool { + retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed open exit event") return false @@ -516,13 +608,13 @@ func (e *eventLoop) handleOpenExit(ep *event.Pair, openEv *OpenEvent) bool { return true } -func (e *eventLoop) handleNameExit(ep *event.Pair, nameEv *NameEvent) bool { +func (e *eventLoop) handleNameExit(ep *event.Pair, nameEv *types.NameEvent) bool { ep.File = file.NewOldnameNewname(nameEv.Oldname[:], nameEv.Newname[:]) ep.Comm = e.comm(nameEv.GetTid()) return true } -func (e *eventLoop) handlePathExit(ep *event.Pair, pathEv *PathEvent) bool { +func (e *eventLoop) handlePathExit(ep *event.Pair, pathEv *types.PathEvent) bool { if pathEv.GetTraceId().Name() == sysEnterNameToHandleAtName { retEv, ok := ep.ExitEv.(*types.RetEvent) if !ok || retEv.Ret < 0 { @@ -534,8 +626,8 @@ func (e *eventLoop) handlePathExit(ep *event.Pair, pathEv *PathEvent) bool { return false } - if ep.Is(SYS_ENTER_CREAT) { - retEvent, ok := ep.ExitEv.(*RetEvent) + if ep.Is(types.SYS_ENTER_CREAT) { + retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed creat exit event") return false @@ -553,14 +645,14 @@ func (e *eventLoop) handlePathExit(ep *event.Pair, pathEv *PathEvent) bool { return true } -func (e *eventLoop) handleFdExit(ep *event.Pair, fdEv *FdEvent) bool { +func (e *eventLoop) handleFdExit(ep *event.Pair, fdEv *types.FdEvent) bool { fd := fdEv.Fd ep.File = e.resolveFdFile(fd, fdEv.Pid) - if ep.Is(SYS_ENTER_CLOSE) { + if ep.Is(types.SYS_ENTER_CLOSE) { e.fdState().delete(fd) e.deleteProcFdCache(fd, fdEv.Pid) } - if ep.Is(SYS_ENTER_CLOSE_RANGE) { + if ep.Is(types.SYS_ENTER_CLOSE_RANGE) { // close_range provides (first, last), but fd_event only carries the first // argument, so we approximate by closing all tracked fds >= first. retEv, ok := ep.ExitEv.(*types.RetEvent) @@ -575,13 +667,13 @@ func (e *eventLoop) handleFdExit(ep *event.Pair, fdEv *FdEvent) bool { return false } - if ep.Is(SYS_ENTER_DUP) || ep.Is(SYS_ENTER_DUP2) { - fdFile, ok := ep.File.(file.FdFile) + if ep.Is(types.SYS_ENTER_DUP) || ep.Is(types.SYS_ENTER_DUP2) { + fdFile, ok := ep.File.(*file.FdFile) if !ok { e.recyclePair(ep, "Dropped malformed dup source event") return false } - retEvent, ok := ep.ExitEv.(*RetEvent) + retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed dup exit event") return false @@ -589,8 +681,8 @@ func (e *eventLoop) handleFdExit(ep *event.Pair, fdEv *FdEvent) bool { // Duplicating fd e.registerDup(fdFile, int32(retEvent.Ret), 0) } - if ep.Is(SYS_ENTER_PIDFD_GETFD) { - retEv, ok := ep.ExitEv.(*RetEvent) + if ep.Is(types.SYS_ENTER_PIDFD_GETFD) { + retEv, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed pidfd_getfd exit event") return false @@ -601,13 +693,13 @@ func (e *eventLoop) handleFdExit(ep *event.Pair, fdEv *FdEvent) bool { ep.File = transferredFile } } - if retEv, ok := ep.ExitEv.(*RetEvent); ok { + if retEv, ok := ep.ExitEv.(*types.RetEvent); ok { ep.Bytes = bytesFromRet(retEv) } return true } -func (e *eventLoop) handleDup3Exit(ep *event.Pair, dup3Ev *Dup3Event) bool { +func (e *eventLoop) handleDup3Exit(ep *event.Pair, dup3Ev *types.Dup3Event) bool { fd := int32(dup3Ev.Fd) ep.File = e.resolveFdFile(fd, dup3Ev.Pid) ep.Comm = e.comm(dup3Ev.GetTid()) @@ -616,12 +708,12 @@ func (e *eventLoop) handleDup3Exit(ep *event.Pair, dup3Ev *Dup3Event) bool { return false } - fdFile, ok := ep.File.(file.FdFile) + fdFile, ok := ep.File.(*file.FdFile) if !ok { e.recyclePair(ep, "Dropped malformed dup3 source event") return false } - retEvent, ok := ep.ExitEv.(*RetEvent) + retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed dup3 exit event") return false @@ -630,9 +722,9 @@ func (e *eventLoop) handleDup3Exit(ep *event.Pair, dup3Ev *Dup3Event) bool { return true } -func (e *eventLoop) handleOpenByHandleAtExit(ep *event.Pair, openByHandleEv *OpenByHandleAtEvent) bool { +func (e *eventLoop) handleOpenByHandleAtExit(ep *event.Pair, openByHandleEv *types.OpenByHandleAtEvent) bool { tid := openByHandleEv.GetTid() - retEvent, ok := ep.ExitEv.(*RetEvent) + retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed open_by_handle_at exit event") return false @@ -661,8 +753,8 @@ func (e *eventLoop) handleOpenByHandleAtExit(ep *event.Pair, openByHandleEv *Ope return true } -func (e *eventLoop) handleNullExit(ep *event.Pair, nullEv *NullEvent) bool { - if ep.Is(SYS_ENTER_IO_URING_SETUP) { +func (e *eventLoop) handleNullExit(ep *event.Pair, nullEv *types.NullEvent) bool { + if ep.Is(types.SYS_ENTER_IO_URING_SETUP) { retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed io_uring_setup exit event") @@ -674,7 +766,7 @@ func (e *eventLoop) handleNullExit(ep *event.Pair, nullEv *NullEvent) bool { ep.File = fdFile } } - if ep.Is(SYS_ENTER_GETCWD) { + if ep.Is(types.SYS_ENTER_GETCWD) { retEvent, ok := ep.ExitEv.(*types.RetEvent) if !ok { e.recyclePair(ep, "Dropped malformed getcwd exit event") @@ -694,7 +786,7 @@ func (e *eventLoop) handleNullExit(ep *event.Pair, nullEv *NullEvent) bool { return true } -func (e *eventLoop) handleFcntlExit(ep *event.Pair, fcntlEv *FcntlEvent) bool { +func (e *eventLoop) handleFcntlExit(ep *event.Pair, fcntlEv *types.FcntlEvent) bool { ep.Comm = e.comm(fcntlEv.GetTid()) fd := int32(fcntlEv.Fd) ep.File = e.resolveFdFile(fd, fcntlEv.Pid) @@ -713,7 +805,7 @@ func (e *eventLoop) handleFcntlExit(ep *event.Pair, fcntlEv *FcntlEvent) bool { return true } - fdFile, ok := ep.File.(file.FdFile) + fdFile, ok := ep.File.(*file.FdFile) if !ok { e.recyclePair(ep, "Dropped malformed fcntl file event") return false @@ -734,7 +826,7 @@ func (e *eventLoop) handleFcntlExit(ep *event.Pair, fcntlEv *FcntlEvent) bool { return true } -func (e *eventLoop) registerDup(fdFile file.FdFile, newFd int32, extraFlags int32) { +func (e *eventLoop) registerDup(fdFile *file.FdFile, newFd int32, extraFlags int32) { if newFd < 0 { return } @@ -775,12 +867,12 @@ func (e *eventLoop) resolveFdFile(fd int32, pid uint32) file.File { return discovered } -func (e *eventLoop) cachedProcFdFile(fd int32, pid uint32) (file.FdFile, bool) { +func (e *eventLoop) cachedProcFdFile(fd int32, pid uint32) (*file.FdFile, bool) { cache, ok := e.procFdCacheState()[procFdCacheKey(pid, fd)] return cache, ok } -func (e *eventLoop) setProcFdCache(fd int32, pid uint32, resolved file.FdFile) { +func (e *eventLoop) setProcFdCache(fd int32, pid uint32, resolved *file.FdFile) { e.procFdCacheState()[procFdCacheKey(pid, fd)] = resolved } @@ -799,9 +891,9 @@ func (e *eventLoop) deleteProcFdCacheFrom(first int32, pid uint32) { } } -func (e *eventLoop) procFdCacheState() map[uint64]file.FdFile { +func (e *eventLoop) procFdCacheState() map[uint64]*file.FdFile { if e.procFdCache == nil { - e.procFdCache = make(map[uint64]file.FdFile) + e.procFdCache = make(map[uint64]*file.FdFile) } return e.procFdCache } diff --git a/internal/eventloop_commresolver_test.go b/internal/eventloop_commresolver_test.go new file mode 100644 index 0000000..0f10db8 --- /dev/null +++ b/internal/eventloop_commresolver_test.go @@ -0,0 +1,202 @@ +package internal + +import ( + "fmt" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestCommResolverQueueLookupRespectsWorkerLimit(t *testing.T) { + const workers = 2 + const lookups = 6 + + started := make(chan struct{}, lookups) + release := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(lookups) + + var running int32 + var maxRunning int32 + + resolver := newCommResolver(nil) + resolver.lookupWorkers = workers + resolver.lookupQueue = make(chan uint32, lookups) + resolver.resolveFn = func(tid uint32) string { + current := atomic.AddInt32(&running, 1) + setMaxInt32(&maxRunning, current) + started <- struct{}{} + <-release + atomic.AddInt32(&running, -1) + wg.Done() + return fmt.Sprintf("comm-%d", tid) + } + + for i := 1; i <= lookups; i++ { + resolver.queueLookup(uint32(i)) + } + + waitForStarts(t, started, workers, 2*time.Second) + select { + case <-started: + t.Fatalf("expected at most %d concurrent lookups", workers) + case <-time.After(75 * time.Millisecond): + } + + close(release) + waitForWaitGroup(t, &wg, 2*time.Second) + waitForCondition(t, 2*time.Second, "expected all queued tids to be cached", func() bool { + for i := 1; i <= lookups; i++ { + if _, ok := resolver.cached(uint32(i)); !ok { + return false + } + } + return pendingCount(resolver) == 0 + }) + + if got := atomic.LoadInt32(&maxRunning); got > workers { + t.Fatalf("expected max concurrent lookups <= %d, got %d", workers, got) + } + + for i := 1; i <= lookups; i++ { + want := fmt.Sprintf("comm-%d", i) + got, ok := resolver.cached(uint32(i)) + if !ok { + t.Fatalf("expected cached comm for tid %d", i) + } + if got != want { + t.Fatalf("expected tid %d comm %q, got %q", i, want, got) + } + } + + if pending := pendingCount(resolver); pending != 0 { + t.Fatalf("expected no pending lookups after completion, got %d", pending) + } +} + +func TestCommResolverQueueLookupQueueFullClearsPending(t *testing.T) { + started := make(chan struct{}, 1) + release := make(chan struct{}) + + resolver := newCommResolver(nil) + resolver.lookupWorkers = 1 + resolver.lookupQueue = make(chan uint32, 1) + resolver.resolveFn = func(tid uint32) string { + select { + case started <- struct{}{}: + default: + } + <-release + return fmt.Sprintf("comm-%d", tid) + } + + const tid1 uint32 = 101 + const tid2 uint32 = 102 + const tid3 uint32 = 103 + + resolver.queueLookup(tid1) + waitForStarts(t, started, 1, 2*time.Second) + + resolver.queueLookup(tid2) + resolver.queueLookup(tid3) + + if !hasPending(resolver, tid1) { + t.Fatalf("expected tid %d to remain pending while worker is blocked", tid1) + } + if !hasPending(resolver, tid2) { + t.Fatalf("expected tid %d to remain pending while queued", tid2) + } + if hasPending(resolver, tid3) { + t.Fatalf("expected tid %d pending flag to be cleared when queue is full", tid3) + } + + close(release) + + waitForCondition(t, 2*time.Second, "expected first two tids to resolve", func() bool { + _, ok1 := resolver.cached(tid1) + _, ok2 := resolver.cached(tid2) + return ok1 && ok2 + }) + + if _, ok := resolver.cached(tid3); ok { + t.Fatalf("did not expect tid %d to resolve from the dropped queue request", tid3) + } + + resolver.queueLookup(tid3) + waitForCondition(t, 2*time.Second, "expected dropped tid to be retried successfully", func() bool { + _, ok := resolver.cached(tid3) + return ok + }) +} + +func hasPending(r *commResolver, tid uint32) bool { + r.mu.RLock() + defer r.mu.RUnlock() + _, ok := r.pending[tid] + return ok +} + +func pendingCount(r *commResolver) int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.pending) +} + +func setMaxInt32(target *int32, candidate int32) { + for { + current := atomic.LoadInt32(target) + if candidate <= current { + return + } + if atomic.CompareAndSwapInt32(target, current, candidate) { + return + } + } +} + +func waitForStarts(t *testing.T, ch <-chan struct{}, count int, timeout time.Duration) { + t.Helper() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + for i := 0; i < count; i++ { + select { + case <-ch: + case <-timer.C: + t.Fatalf("timed out waiting for %d resolver lookups to start", count) + } + } +} + +func waitForWaitGroup(t *testing.T, wg *sync.WaitGroup, timeout time.Duration) { + t.Helper() + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + t.Fatal("timed out waiting for resolver lookups to complete") + } +} + +func waitForCondition(t *testing.T, timeout time.Duration, message string, fn func() bool) { + t.Helper() + + deadline := time.Now().Add(timeout) + for { + if fn() { + return + } + if time.Now().After(deadline) { + t.Fatal(message) + } + time.Sleep(10 * time.Millisecond) + } +} diff --git a/internal/eventloop_constructor_test.go b/internal/eventloop_constructor_test.go new file mode 100644 index 0000000..52ca570 --- /dev/null +++ b/internal/eventloop_constructor_test.go @@ -0,0 +1,41 @@ +package internal + +import ( + "strings" + "testing" + + "ior/internal/types" +) + +func mustNewEventLoop(tb testing.TB, cfg eventLoopConfig) *eventLoop { + tb.Helper() + el, err := newEventLoop(cfg) + if err != nil { + tb.Fatalf("newEventLoop() error = %v", err) + } + return el +} + +func TestNewEventFilterRejectsTooLongCommFilter(t *testing.T) { + tooLong := strings.Repeat("a", types.MAX_PROGNAME_LENGTH+1) + _, err := newEventFilter(tooLong, "") + if err == nil { + t.Fatalf("expected error for comm filter longer than %d", types.MAX_PROGNAME_LENGTH) + } +} + +func TestNewEventFilterRejectsTooLongPathFilter(t *testing.T) { + tooLong := strings.Repeat("a", types.MAX_FILENAME_LENGTH+1) + _, err := newEventFilter("", tooLong) + if err == nil { + t.Fatalf("expected error for path filter longer than %d", types.MAX_FILENAME_LENGTH) + } +} + +func TestNewEventLoopPropagatesFilterError(t *testing.T) { + tooLong := strings.Repeat("a", types.MAX_PROGNAME_LENGTH+1) + _, err := newEventLoop(eventLoopConfig{commFilter: tooLong}) + if err == nil { + t.Fatalf("expected newEventLoop to propagate invalid filter error") + } +} diff --git a/internal/eventloop_error_handling_test.go b/internal/eventloop_error_handling_test.go index 12f9b2f..8361dea 100644 --- a/internal/eventloop_error_handling_test.go +++ b/internal/eventloop_error_handling_test.go @@ -8,7 +8,7 @@ import ( ) func TestTracepointExitedMalformedOpenExitDoesNotPanicAndNotifies(t *testing.T) { - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) warnings := make(chan string, 1) el.warningCb = func(message string) { warnings <- message } @@ -48,7 +48,7 @@ func TestTracepointExitedMalformedOpenExitDoesNotPanicAndNotifies(t *testing.T) } func TestTracepointExitedMalformedOpenByHandleAtExitDoesNotPanicAndNotifies(t *testing.T) { - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) warnings := make(chan string, 1) el.warningCb = func(message string) { warnings <- message } @@ -84,7 +84,7 @@ func TestTracepointExitedMalformedOpenByHandleAtExitDoesNotPanicAndNotifies(t *t } func TestProcessRawEventUnknownTypeDoesNotPanicAndNotifies(t *testing.T) { - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) warnings := make(chan string, 1) el.warningCb = func(message string) { warnings <- message } diff --git a/internal/eventloop_filter_test.go b/internal/eventloop_filter_test.go index c3eef1f..d0b7933 100644 --- a/internal/eventloop_filter_test.go +++ b/internal/eventloop_filter_test.go @@ -3,12 +3,12 @@ package internal import ( "context" "fmt" + "testing" + "time" + "ior/internal/event" "ior/internal/file" - "ior/internal/flamegraph" "ior/internal/types" - "testing" - "time" ) // Test that comm names are properly propagated across syscalls @@ -21,7 +21,7 @@ func TestCommPropagation(t *testing.T) { inCh := make(chan []byte) outCh := make(chan *event.Pair) - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) el.printCb = func(ev *event.Pair) { outCh <- ev } go el.run(ctx, inCh) @@ -123,7 +123,7 @@ func makeCommPropagationTestData(t *testing.T) (td testData) { t.Errorf("Expected no comm name for different thread but got '%s'", ep.Comm) } // Verify comm map doesn't have entry for this tid - if _, ok := el.comms[differentTid]; ok { + if _, ok := el.cachedComm(differentTid); ok { t.Errorf("Expected no comm entry for tid %d but one was found", differentTid) } }) @@ -438,11 +438,10 @@ func TestCommFilterToggle(t *testing.T) { commFilterEnable: false, }, enterEvs: make(map[uint32]*event.Pair), - files: make(map[int32]file.File), - comms: make(map[uint32]string), + fdTracker: newFDTracker(make(map[int32]file.File)), + commResolver: newCommResolver(make(map[uint32]string)), prevPairTimes: make(map[uint32]uint64), printCb: func(ep *event.Pair) { outCh <- ep }, - flamegraph: flamegraph.New(), done: make(chan struct{}), } go el.run(ctx, inCh) @@ -476,13 +475,13 @@ func TestCommFilterToggle(t *testing.T) { filter: &eventFilter{ commFilterEnable: true, commFilter: "test", + commFilterBytes: []byte("test"), }, enterEvs: make(map[uint32]*event.Pair), - files: make(map[int32]file.File), - comms: make(map[uint32]string), + fdTracker: newFDTracker(make(map[int32]file.File)), + commResolver: newCommResolver(make(map[uint32]string)), prevPairTimes: make(map[uint32]uint64), printCb: func(ep *event.Pair) { outCh <- ep }, - flamegraph: flamegraph.New(), done: make(chan struct{}), } go el.run(ctx, inCh) @@ -509,15 +508,16 @@ func newEventLoopWithFilter(commFilter, pathFilter string) *eventLoop { filter: &eventFilter{ commFilterEnable: commFilter != "", commFilter: commFilter, + commFilterBytes: []byte(commFilter), pathFilterEnable: pathFilter != "", pathFilter: pathFilter, + pathFilterBytes: []byte(pathFilter), }, enterEvs: make(map[uint32]*event.Pair), - files: make(map[int32]file.File), - comms: make(map[uint32]string), + fdTracker: newFDTracker(make(map[int32]file.File)), + commResolver: newCommResolver(make(map[uint32]string)), prevPairTimes: make(map[uint32]uint64), printCb: func(ep *event.Pair) { fmt.Println(ep); ep.Recycle() }, - flamegraph: flamegraph.New(), done: make(chan struct{}), } return el diff --git a/internal/eventloop_seed_test.go b/internal/eventloop_seed_test.go index 3447da2..427869e 100644 --- a/internal/eventloop_seed_test.go +++ b/internal/eventloop_seed_test.go @@ -16,7 +16,7 @@ func TestSeedTrackedPidCommCachesTrackedPidComm(t *testing.T) { cfg: eventLoopConfig{ pidFilter: int(pid), }, - comms: make(map[uint32]string), + commResolver: newCommResolver(make(map[uint32]string)), } el.seedTrackedPidComm() @@ -35,7 +35,7 @@ func TestSeedTrackedPidCommSeedsCurrentProcessWhenPidFilterDisabled(t *testing.T cfg: eventLoopConfig{ pidFilter: -1, }, - comms: make(map[uint32]string), + commResolver: newCommResolver(make(map[uint32]string)), } el.seedTrackedPidComm() diff --git a/internal/eventloop_test.go b/internal/eventloop_test.go index 4ae8597..7fcd438 100644 --- a/internal/eventloop_test.go +++ b/internal/eventloop_test.go @@ -96,7 +96,7 @@ func TestEventloop(t *testing.T) { inCh := make(chan []byte) outCh := make(chan *event.Pair) - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) el.printCb = func(ev *event.Pair) { outCh <- ev } go el.run(ctx, inCh) @@ -142,7 +142,7 @@ func TestEventloop(t *testing.T) { } func TestHandleFdExitCloseClearsProcFdCache(t *testing.T) { - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) pid := uint32(1001) fd := int32(55) @@ -170,7 +170,7 @@ func TestHandleFdExitCloseClearsProcFdCache(t *testing.T) { } func TestHandleFdExitCloseRangeClearsProcFdCacheRange(t *testing.T) { - el := newEventLoop(eventLoopConfig{}) + el := mustNewEventLoop(t, eventLoopConfig{}) pid := uint32(2002) el.setProcFdCache(10, pid, file.NewFd(10, "keep", syscall.O_RDONLY)) @@ -774,7 +774,7 @@ func makePidfdGetfdEventTestData(t *testing.T) (td testData) { if got, want := ep.File.Name(), path; got != want { t.Errorf("Expected transferred file '%v' but got '%v'", want, got) } - if _, ok := el.files[int32(fd)]; !ok { + if _, ok := el.fdState().files[int32(fd)]; !ok { t.Errorf("Expected transferred fd %d to be tracked", fd) } }) @@ -796,7 +796,7 @@ func makePidfdGetfdFailureTestData(t *testing.T) (td testData) { if !exitEv.Equals(ep.ExitEv) { t.Errorf("Expected '%v' but got '%v'", exitEv, ep.ExitEv) } - if _, ok := el.files[9999]; ok { + if _, ok := el.fdState().files[9999]; ok { t.Errorf("Expected no tracked fd for failed pidfd_getfd") } }) @@ -1382,7 +1382,7 @@ func makeIoUringSetupEventTestData(t *testing.T) (td testData) { if ep.File == nil { t.Errorf("Expected io_uring fd to be tracked") } - if _, ok := el.files[48]; !ok { + if _, ok := el.fdState().files[48]; !ok { t.Errorf("Expected io_uring fd 48 to be tracked") } }) @@ -1408,7 +1408,7 @@ func makeIoUringSetupFailureTestData(t *testing.T) (td testData) { if ep.File != nil { t.Errorf("Expected io_uring_setup failure to have no file tracked") } - if len(el.files) != 0 { + if len(el.fdState().files) != 0 { t.Errorf("Expected no fds to be tracked after io_uring_setup failure") } }) @@ -1434,7 +1434,7 @@ func makeIoUringEnterEventTestData(t *testing.T) (td testData) { if ep.File == nil { t.Errorf("Expected io_uring_enter to have a file") } - if _, ok := el.files[fd]; ok { + if _, ok := el.fdState().files[fd]; ok { t.Errorf("Expected io_uring_enter to not track fd %d", fd) } }) @@ -1460,7 +1460,7 @@ func makeIoUringRegisterEventTestData(t *testing.T) (td testData) { if ep.File == nil { t.Errorf("Expected io_uring_register to have a file") } - if _, ok := el.files[fd]; ok { + if _, ok := el.fdState().files[fd]; ok { t.Errorf("Expected io_uring_register to not track fd %d", fd) } }) @@ -1526,8 +1526,8 @@ func makeDup3WithCloexecTestData(t *testing.T) (td testData) { verifyFileDescriptor(t, el, newFd, filename) // Verify the new fd has O_CLOEXEC flag - if newFile, ok := el.files[newFd]; ok { - fdFile, ok := newFile.(file.FdFile) + if newFile, ok := el.fdState().files[newFd]; ok { + fdFile, ok := newFile.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if !fdFile.Flags().Is(syscall.O_CLOEXEC) { @@ -1611,8 +1611,8 @@ func makeDup2TestData(t *testing.T) (td testData) { verifyFileDescriptor(t, el, targetFd, filename) // Verify the new fd does NOT have O_CLOEXEC flag (unlike dup3) - if newFile, ok := el.files[targetFd]; ok { - fdFile, ok := newFile.(file.FdFile) + if newFile, ok := el.fdState().files[targetFd]; ok { + fdFile, ok := newFile.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if fdFile.Flags().Is(syscall.O_CLOEXEC) { @@ -1662,7 +1662,7 @@ func makeDup2TestData(t *testing.T) (td testData) { // Helper functions for FD lifecycle tests func verifyFileDescriptor(t *testing.T, el *eventLoop, fd int32, expectedFileName string) { - if file, ok := el.files[fd]; ok { + if file, ok := el.fdState().files[fd]; ok { if file.Name() != expectedFileName { t.Errorf("Expected fd %d to map to file '%s' but got '%s'", fd, expectedFileName, file.Name()) } @@ -1672,7 +1672,7 @@ func verifyFileDescriptor(t *testing.T, el *eventLoop, fd int32, expectedFileNam } func verifyFdNotTracked(t *testing.T, el *eventLoop, fd int32) { - if _, ok := el.files[fd]; ok { + if _, ok := el.fdState().files[fd]; ok { t.Errorf("Expected fd %d to not be tracked but it was found", fd) } } @@ -1718,7 +1718,7 @@ func verifyMismatchCount(t *testing.T, el *eventLoop, expectedCount uint) { } func verifyCommName(t *testing.T, el *eventLoop, tid uint32, expectedComm string) { - if comm, ok := el.comms[tid]; !ok { + if comm, ok := el.commState().comms[tid]; !ok { t.Errorf("Expected comm name for tid %d but it wasn't found", tid) } else if comm != expectedComm { t.Errorf("Expected comm name '%s' for tid %d but got '%s'", expectedComm, tid, comm) @@ -1764,8 +1764,8 @@ func makeFcntlSetFlagsTestData(t *testing.T) (td testData) { } // Verify flags were updated on the file descriptor - if f, ok := el.files[int32(fd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(fd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else { @@ -1800,8 +1800,8 @@ func makeFcntlSetFlagsTestData(t *testing.T) (td testData) { } // Verify flags were updated correctly - if f, ok := el.files[int32(fd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(fd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else { @@ -1878,8 +1878,8 @@ func makeFcntlDupfdTestData(t *testing.T) (td testData) { verifyFileDescriptor(t, el, int32(newFd), filename) // Verify the new fd does NOT have O_CLOEXEC flag (F_DUPFD doesn't set it) - if f, ok := el.files[int32(newFd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(newFd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if fdFile.Flags().Is(syscall.O_CLOEXEC) { @@ -1962,8 +1962,8 @@ func makeFcntlDupfdCloexecTestData(t *testing.T) (td testData) { td.validates = append(td.validates, func(t *testing.T, el *eventLoop, ep *event.Pair) { verifyFileDescriptor(t, el, int32(origFd), filename) // Verify original fd doesn't have O_CLOEXEC - if f, ok := el.files[int32(origFd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(origFd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if fdFile.Flags().Is(syscall.O_CLOEXEC) { @@ -1993,8 +1993,8 @@ func makeFcntlDupfdCloexecTestData(t *testing.T) (td testData) { verifyFileDescriptor(t, el, int32(newFd), filename) // Verify the new fd has O_CLOEXEC flag - if f, ok := el.files[int32(newFd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(newFd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if !fdFile.Flags().Is(syscall.O_CLOEXEC) { @@ -2003,8 +2003,8 @@ func makeFcntlDupfdCloexecTestData(t *testing.T) (td testData) { } // Verify original fd still doesn't have O_CLOEXEC - if f, ok := el.files[int32(origFd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(origFd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if fdFile.Flags().Is(syscall.O_CLOEXEC) { @@ -2122,8 +2122,8 @@ func makeFcntlErrorTestData(t *testing.T) (td testData) { } // Verify flags were NOT updated due to error - if f, ok := el.files[int32(fd)]; ok { - fdFile, ok := f.(file.FdFile) + if f, ok := el.fdState().files[int32(fd)]; ok { + fdFile, ok := f.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } else if fdFile.Flags().Is(syscall.O_NONBLOCK) { @@ -2149,8 +2149,8 @@ func makeFcntlErrorTestData(t *testing.T) (td testData) { } // Only original fd should be tracked - if len(el.files) != 1 { - t.Errorf("Expected only 1 fd to be tracked, got %d", len(el.files)) + if len(el.fdState().files) != 1 { + t.Errorf("Expected only 1 fd to be tracked, got %d", len(el.fdState().files)) } verifyFileDescriptor(t, el, int32(fd), filename) }) @@ -2193,7 +2193,7 @@ func makeFcntlInvalidFdTestData(t *testing.T) (td testData) { if ep.File == nil { t.Errorf("Expected file to be created for invalid fd") } else { - _, ok := ep.File.(file.FdFile) + _, ok := ep.File.(*file.FdFile) if !ok { t.Errorf("Expected file to be FdFile type") } diff --git a/internal/export/snapshot_csv.go b/internal/export/snapshot_csv.go new file mode 100644 index 0000000..591bd67 --- /dev/null +++ b/internal/export/snapshot_csv.go @@ -0,0 +1,108 @@ +package export + +import ( + "encoding/csv" + "errors" + "fmt" + "os" + "time" + + "ior/internal/statsengine" +) + +// SnapshotCSV writes a dashboard snapshot to a timestamped CSV file. +func SnapshotCSV(snap *statsengine.Snapshot) (filename string, retErr error) { + filename = fmt.Sprintf("ior-snapshot-%s.csv", time.Now().Format("20060102-150405")) + f, err := os.Create(filename) + if err != nil { + return "", err + } + defer func() { + if err := f.Close(); err != nil { + retErr = errors.Join(retErr, fmt.Errorf("close %s: %w", filename, err)) + } + }() + + w := csv.NewWriter(f) + + rows := [][]string{ + {"section", "name", "value1", "value2", "value3"}, + {"summary", "totals", fmt.Sprint(snapValue(snap, func(s *statsengine.Snapshot) uint64 { return s.TotalSyscalls })), fmt.Sprint(snapValue(snap, func(s *statsengine.Snapshot) uint64 { return s.TotalErrors })), fmt.Sprint(snapValue(snap, func(s *statsengine.Snapshot) uint64 { return s.TotalBytes }))}, + {"summary", "rates_per_sec", fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.SyscallRatePerSec })), fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.ReadBytesPerSec })), fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.WriteBytesPerSec }))}, + {"summary", "latency_gap_mean_ns", fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.LatencyMeanNs })), fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.GapMeanNs })), ""}, + {"summary", "trend", trendSummary(snap, func(s *statsengine.Snapshot) statsengine.Trend { return s.LatencyTrend }), trendSummary(snap, func(s *statsengine.Snapshot) statsengine.Trend { return s.GapTrend }), trendSummary(snap, func(s *statsengine.Snapshot) statsengine.Trend { return s.ThroughputTrend })}, + } + for _, row := range rows { + if err := w.Write(row); err != nil { + return "", err + } + } + + if snap != nil { + for _, s := range snap.Syscalls() { + if err := w.Write([]string{"syscall", s.Name, fmt.Sprint(s.Count), fmt.Sprintf("%.2f", s.RatePerSec), fmt.Sprint(s.Bytes)}); err != nil { + return "", err + } + if err := w.Write([]string{"syscall_latency_ns", s.Name, fmt.Sprintf("%.2f", s.LatencyMeanNs), fmt.Sprint(s.LatencyMinNs), fmt.Sprint(s.LatencyMaxNs)}); err != nil { + return "", err + } + if err := w.Write([]string{"syscall_percentiles_ns", s.Name, fmt.Sprint(s.LatencyP50Ns), fmt.Sprint(s.LatencyP95Ns), fmt.Sprint(s.LatencyP99Ns)}); err != nil { + return "", err + } + } + for _, r := range snap.Files() { + if err := w.Write([]string{"file", r.Path, fmt.Sprint(r.Accesses), fmt.Sprint(r.BytesRead), fmt.Sprint(r.BytesWritten)}); err != nil { + return "", err + } + if err := w.Write([]string{"file_latency_ns", r.Path, fmt.Sprintf("%.2f", r.AvgLatencyNs), fmt.Sprint(r.MaxLatencyNs), ""}); err != nil { + return "", err + } + } + for _, p := range snap.Processes() { + if err := w.Write([]string{"process", fmt.Sprint(p.PID), fmt.Sprint(p.Syscalls), fmt.Sprintf("%.2f", p.RatePerSec), fmt.Sprint(p.Bytes)}); err != nil { + return "", err + } + if err := w.Write([]string{"process_latency_ns", fmt.Sprint(p.PID), fmt.Sprintf("%.2f", p.AvgLatencyNs), "", ""}); err != nil { + return "", err + } + } + for _, b := range snap.LatencyHistogram.Buckets() { + if err := w.Write([]string{"latency_hist", b.Label, fmt.Sprint(b.Count), fmt.Sprint(b.LowerNs), fmt.Sprint(b.UpperNs)}); err != nil { + return "", err + } + } + for _, b := range snap.GapHistogram.Buckets() { + if err := w.Write([]string{"gap_hist", b.Label, fmt.Sprint(b.Count), fmt.Sprint(b.LowerNs), fmt.Sprint(b.UpperNs)}); err != nil { + return "", err + } + } + } + + w.Flush() + if err := w.Error(); err != nil { + return "", err + } + return filename, nil +} + +func snapValue(snap *statsengine.Snapshot, get func(*statsengine.Snapshot) uint64) uint64 { + if snap == nil { + return 0 + } + return get(snap) +} + +func snapValueF(snap *statsengine.Snapshot, get func(*statsengine.Snapshot) float64) float64 { + if snap == nil { + return 0 + } + return get(snap) +} + +func trendSummary(snap *statsengine.Snapshot, get func(*statsengine.Snapshot) statsengine.Trend) string { + if snap == nil { + return "stable:0.00" + } + trend := get(snap) + return fmt.Sprintf("%s:%.2f", trend.Direction, trend.DeltaPercent) +} diff --git a/internal/file/doc.go b/internal/file/doc.go new file mode 100644 index 0000000..dd3c2dd --- /dev/null +++ b/internal/file/doc.go @@ -0,0 +1,2 @@ +// Package file provides file metadata helpers used by trace rendering and export code. +package file diff --git a/internal/file/file.go b/internal/file/file.go index b1bd84c..b95d40b 100644 --- a/internal/file/file.go +++ b/internal/file/file.go @@ -4,12 +4,14 @@ import ( "bufio" "bytes" "fmt" - "ior/internal/types" "os" "strconv" "strings" + + "ior/internal/types" ) +// File is the common interface for file-like syscall payload representations. type File interface { String() string Name() string @@ -17,6 +19,7 @@ type File interface { FD() int32 } +// FdFile represents a file descriptor-backed file reference. type FdFile struct { fd int32 name string @@ -24,8 +27,9 @@ type FdFile struct { flagsFromProcFS bool } -func NewFd(fd int32, name string, flags int32) FdFile { - f := FdFile{ +// NewFd constructs an FdFile from explicit descriptor metadata. +func NewFd(fd int32, name string, flags int32) *FdFile { + f := &FdFile{ fd: fd, name: name, flags: Flags(flags), @@ -36,10 +40,13 @@ func NewFd(fd int32, name string, flags int32) FdFile { return f } -func NewFdWithPid(fd int32, pid uint32) (f FdFile) { +// NewFdWithPid resolves descriptor metadata from /proc/<pid>/fd. +func NewFdWithPid(fd int32, pid uint32) *FdFile { + f := &FdFile{ + fd: fd, + } var err error - f.fd = fd procPath := fmt.Sprintf("/proc/%d/fd/%d", pid, fd) f.name, err = os.Readlink(procPath) if err != nil { @@ -55,10 +62,10 @@ func NewFdWithPid(fd int32, pid uint32) (f FdFile) { return f } -func (f FdFile) Dup(fd int32) FdFile { - dupFd := f +func (f *FdFile) Dup(fd int32) *FdFile { + dupFd := *f dupFd.fd = fd - return dupFd + return &dupFd } func readFlagsFromFdInfo(fd int32, pid uint32) (Flags, error) { @@ -78,11 +85,11 @@ func readFlagsFromFdInfo(fd int32, pid uint32) (Flags, error) { return unknownFlag, scanner.Err() } -func (f FdFile) Name() string { +func (f *FdFile) Name() string { return f.name } -func (f FdFile) String() string { +func (f *FdFile) String() string { var sb strings.Builder if len(f.name) == 0 { @@ -99,11 +106,11 @@ func (f FdFile) String() string { return sb.String() } -func (f FdFile) Flags() Flags { +func (f *FdFile) Flags() Flags { return f.flags } -func (f FdFile) FD() int32 { +func (f *FdFile) FD() int32 { return f.fd } @@ -119,6 +126,7 @@ type oldnameNewnameFile struct { Oldname, Newname string } +// NewOldnameNewname creates a file representation for rename-like syscalls. func NewOldnameNewname(oldname, newname []byte) oldnameNewnameFile { return oldnameNewnameFile{types.StringValue(oldname), types.StringValue(newname)} } @@ -153,6 +161,7 @@ type pathnameFile struct { Pathname string } +// NewPathname creates a path-only file representation. func NewPathname(pathname []byte) pathnameFile { return pathnameFile{types.StringValue(pathname)} } diff --git a/internal/file/file_test.go b/internal/file/file_test.go index 684a7d8..f9025fe 100644 --- a/internal/file/file_test.go +++ b/internal/file/file_test.go @@ -1,108 +1,109 @@ package file import ( - "ior/internal/types" "strings" "syscall" "testing" + + "ior/internal/types" ) func TestStringValue(t *testing.T) { - var array [128]byte - copy(array[:], "test string") + var array [128]byte + copy(array[:], "test string") - if str := types.StringValue(array[:]); str != "test string" { - t.Errorf("epxected 'test string' but got '%s' with bytes '%v'", str, []byte(str)) - } + if str := types.StringValue(array[:]); str != "test string" { + t.Errorf("epxected 'test string' but got '%s' with bytes '%v'", str, []byte(str)) + } } func TestNewFdUnknownFlags(t *testing.T) { - fdFile := NewFd(1, "test.txt", -1) - if fdFile.Flags() != unknownFlag { - t.Errorf("expected unknown flags, got %v", fdFile.Flags()) - } + fdFile := NewFd(1, "test.txt", -1) + if fdFile.Flags() != unknownFlag { + t.Errorf("expected unknown flags, got %v", fdFile.Flags()) + } } func TestNewFdEmptyName(t *testing.T) { - fdFile := NewFd(1, "", 0) - str := fdFile.String() - if !strings.Contains(str, "E:name") { - t.Errorf("expected String() to contain 'E:name' for empty name, got '%s'", str) - } + fdFile := NewFd(1, "", 0) + str := fdFile.String() + if !strings.Contains(str, "E:name") { + t.Errorf("expected String() to contain 'E:name' for empty name, got '%s'", str) + } } func TestFlagsIsUnknown(t *testing.T) { - f := unknownFlag - if f.Is(syscall.O_RDONLY) { - t.Errorf("expected Is(O_RDONLY) to be false for unknownFlag") - } - if f.Is(syscall.O_WRONLY) { - t.Errorf("expected Is(O_WRONLY) to be false for unknownFlag") - } - if f.Is(syscall.O_RDWR) { - t.Errorf("expected Is(O_RDWR) to be false for unknownFlag") - } + f := unknownFlag + if f.Is(syscall.O_RDONLY) { + t.Errorf("expected Is(O_RDONLY) to be false for unknownFlag") + } + if f.Is(syscall.O_WRONLY) { + t.Errorf("expected Is(O_WRONLY) to be false for unknownFlag") + } + if f.Is(syscall.O_RDWR) { + t.Errorf("expected Is(O_RDWR) to be false for unknownFlag") + } } func TestFlagsStringUnknown(t *testing.T) { - f := Flags(-1) - if f.String() != "O_NONE" { - t.Errorf("expected 'O_NONE' for unknown flags, got '%s'", f.String()) - } + f := Flags(-1) + if f.String() != "O_NONE" { + t.Errorf("expected 'O_NONE' for unknown flags, got '%s'", f.String()) + } } func TestNewOldnameNewnameEmpty(t *testing.T) { - var oldname, newname [128]byte - f := NewOldnameNewname(oldname[:], newname[:]) - if f.Name() != "" { - t.Errorf("expected empty Name(), got '%s'", f.Name()) - } - if !strings.Contains(f.String(), "old:") || !strings.Contains(f.String(), "->new:") { - t.Errorf("expected String() to contain 'old:' and '->new:', got '%s'", f.String()) - } + var oldname, newname [128]byte + f := NewOldnameNewname(oldname[:], newname[:]) + if f.Name() != "" { + t.Errorf("expected empty Name(), got '%s'", f.Name()) + } + if !strings.Contains(f.String(), "old:") || !strings.Contains(f.String(), "->new:") { + t.Errorf("expected String() to contain 'old:' and '->new:', got '%s'", f.String()) + } } func TestNewPathnameEmpty(t *testing.T) { - var pathname [128]byte - f := NewPathname(pathname[:]) - if f.Name() != "" { - t.Errorf("expected empty Name(), got '%s'", f.Name()) - } - if !strings.Contains(f.String(), "pathname:") { - t.Errorf("expected String() to contain 'pathname:', got '%s'", f.String()) - } + var pathname [128]byte + f := NewPathname(pathname[:]) + if f.Name() != "" { + t.Errorf("expected empty Name(), got '%s'", f.Name()) + } + if !strings.Contains(f.String(), "pathname:") { + t.Errorf("expected String() to contain 'pathname:', got '%s'", f.String()) + } } func TestFdFileSetFlags(t *testing.T) { - fdFile := NewFd(1, "test.txt", 0) - if fdFile.Flags() != Flags(0) { - t.Errorf("expected flags 0, got %v", fdFile.Flags()) - } - fdFile.SetFlags(syscall.O_WRONLY) - if fdFile.Flags() != Flags(syscall.O_WRONLY) { - t.Errorf("expected O_WRONLY after SetFlags, got %v", fdFile.Flags()) - } + fdFile := NewFd(1, "test.txt", 0) + if fdFile.Flags() != Flags(0) { + t.Errorf("expected flags 0, got %v", fdFile.Flags()) + } + fdFile.SetFlags(syscall.O_WRONLY) + if fdFile.Flags() != Flags(syscall.O_WRONLY) { + t.Errorf("expected O_WRONLY after SetFlags, got %v", fdFile.Flags()) + } } func TestFdFileAddFlags(t *testing.T) { - fdFile := NewFd(1, "test.txt", syscall.O_RDWR) - fdFile.AddFlags(syscall.O_APPEND) - expected := Flags(syscall.O_RDWR | syscall.O_APPEND) - if fdFile.Flags() != expected { - t.Errorf("expected O_RDWR|O_APPEND after AddFlags, got %v", fdFile.Flags()) - } + fdFile := NewFd(1, "test.txt", syscall.O_RDWR) + fdFile.AddFlags(syscall.O_APPEND) + expected := Flags(syscall.O_RDWR | syscall.O_APPEND) + if fdFile.Flags() != expected { + t.Errorf("expected O_RDWR|O_APPEND after AddFlags, got %v", fdFile.Flags()) + } } func TestFdFileDup(t *testing.T) { - fdFile := NewFd(1, "original.txt", syscall.O_RDONLY) - duped := fdFile.Dup(42) - if duped.Name() != "original.txt" { - t.Errorf("expected duped name 'original.txt', got '%s'", duped.Name()) - } - if !strings.Contains(duped.String(), "42") { - t.Errorf("expected duped String() to contain fd 42, got '%s'", duped.String()) - } - if strings.Contains(duped.String(), "%(1,") { - t.Errorf("expected duped String() to NOT contain old fd 1, got '%s'", duped.String()) - } + fdFile := NewFd(1, "original.txt", syscall.O_RDONLY) + duped := fdFile.Dup(42) + if duped.Name() != "original.txt" { + t.Errorf("expected duped name 'original.txt', got '%s'", duped.Name()) + } + if !strings.Contains(duped.String(), "42") { + t.Errorf("expected duped String() to contain fd 42, got '%s'", duped.String()) + } + if strings.Contains(duped.String(), "%(1,") { + t.Errorf("expected duped String() to NOT contain old fd 1, got '%s'", duped.String()) + } } diff --git a/internal/file/flags.go b/internal/file/flags.go index ca749e1..c06c27b 100644 --- a/internal/file/flags.go +++ b/internal/file/flags.go @@ -3,12 +3,13 @@ package file import ( "os" "strings" + "sync" "syscall" ) type Flags int32 -var flagsToHumanCache = map[Flags]string{} +var flagsToHumanCache sync.Map var unknownFlag = Flags(-1) type tuple struct { @@ -49,12 +50,16 @@ func (f Flags) Is(flag int) bool { } func (f Flags) BuildString(sb *strings.Builder) { - if str, ok := flagsToHumanCache[f]; ok { + if cached, ok := flagsToHumanCache.Load(f); ok { + str, _ := cached.(string) sb.WriteString(str) return } str := f.String() - flagsToHumanCache[f] = str + cached, loaded := flagsToHumanCache.LoadOrStore(f, str) + if loaded { + str, _ = cached.(string) + } sb.WriteString(str) } diff --git a/internal/file/flags_test.go b/internal/file/flags_test.go new file mode 100644 index 0000000..120e432 --- /dev/null +++ b/internal/file/flags_test.go @@ -0,0 +1,40 @@ +package file + +import ( + "strings" + "sync" + "syscall" + "testing" +) + +func TestFlagsBuildStringConcurrent(t *testing.T) { + flagsToHumanCache = sync.Map{} + + const workers = 32 + const iterations = 500 + const want = "O_WRONLY|O_APPEND" + flag := Flags(syscall.O_WRONLY | syscall.O_APPEND) + + var wg sync.WaitGroup + errs := make(chan string, workers) + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < iterations; j++ { + var sb strings.Builder + flag.BuildString(&sb) + if got := sb.String(); got != want { + errs <- got + return + } + } + }() + } + wg.Wait() + close(errs) + + for got := range errs { + t.Fatalf("unexpected BuildString output %q, want %q", got, want) + } +} diff --git a/internal/flags/doc.go b/internal/flags/doc.go new file mode 100644 index 0000000..103b6d4 --- /dev/null +++ b/internal/flags/doc.go @@ -0,0 +1,2 @@ +// Package flags parses CLI options and exposes runtime configuration snapshots. +package flags diff --git a/internal/flags/flags.go b/internal/flags/flags.go index b2d9dce..af8f84c 100644 --- a/internal/flags/flags.go +++ b/internal/flags/flags.go @@ -10,44 +10,23 @@ import ( "sync" "sync/atomic" "time" + + "ior/internal/collapse" ) var ( - singleton = Flags{ - TUIExportEnable: true, - } - once sync.Once - parseErr error - pidFilter atomic.Int64 - tidFilter atomic.Int64 - tuiExportEnable atomic.Bool + current atomic.Pointer[Config] + once sync.Once + parseErr error ) func init() { - pidFilter.Store(-1) - tidFilter.Store(-1) - tuiExportEnable.Store(true) + defaults := NewFlags() + current.Store(&defaults) } -var ( - validCollapsedFields = []string{ - "path", - "comm", - "tracepoint", - "pid", - "tid", - "flags", - } - - validCollapsedCounts = []string{ - "count", - "duration", - "durationToPrev", - "bytes", - } -) - -type Flags struct { +// Config captures runtime configuration parsed from CLI flags. +type Config struct { PidFilter int TidFilter int EventMapSize int @@ -60,46 +39,104 @@ type Flags struct { TracepointsToAttach []*regexp.Regexp TracepointsToExclude []*regexp.Regexp - // Flamegraph flags - PlainMode bool - FlamegraphEnable bool - LiveFlamegraph bool - LiveInterval time.Duration - OpenCommand string - FlamegraphName string - FlamegraphJSON bool - TUIExportEnable bool - - // To convert ior data into native SVG format - IorDataFile string - IorWatchInterval time.Duration - CollapsedFields []string - CountField string -} - -func Get() Flags { - out := singleton - out.PidFilter = int(pidFilter.Load()) - out.TidFilter = int(tidFilter.Load()) - out.TUIExportEnable = tuiExportEnable.Load() + // Output/runtime flags + PlainMode bool + TestFlames bool + TestLiveFlames bool + LiveInterval time.Duration + TUIExportEnable bool + CollapsedFields []string + CountField string +} + +// NewFlags returns a configuration instance initialized with project defaults. +func NewFlags() Config { + return Config{ + PidFilter: -1, + TidFilter: -1, + EventMapSize: 4096 * 16, + Duration: 900, + LiveInterval: 200 * time.Millisecond, + TUIExportEnable: true, + CollapsedFields: []string{"comm", "tracepoint", "path"}, + CountField: "count", + } +} + +// GetPidFilter returns the active process filter. +func (f Config) GetPidFilter() int { + return f.PidFilter +} + +// GetTidFilter returns the active thread filter. +func (f Config) GetTidFilter() int { + return f.TidFilter +} + +// GetTUIExportEnable reports whether TUI CSV export is enabled. +func (f Config) GetTUIExportEnable() bool { + return f.TUIExportEnable +} + +func (f Config) clone() Config { + out := f + out.TracepointsToAttach = slices.Clone(f.TracepointsToAttach) + out.TracepointsToExclude = slices.Clone(f.TracepointsToExclude) + out.CollapsedFields = slices.Clone(f.CollapsedFields) return out } +// Get returns a copy of the currently active runtime configuration. +func Get() Config { + cfg := current.Load() + if cfg == nil { + return NewFlags() + } + return cfg.clone() +} + +func setCurrent(cfg Config) { + snapshot := cfg.clone() + current.Store(&snapshot) +} + +func updateCurrent(update func(*Config)) { + for { + old := current.Load() + next := NewFlags() + if old != nil { + next = old.clone() + } + update(&next) + snapshot := next.clone() + if current.CompareAndSwap(old, &snapshot) { + return + } + } +} + // SetPidFilter updates the active PID filter used for subsequent tracing runs. func SetPidFilter(pid int) { - pidFilter.Store(int64(pid)) + updateCurrent(func(cfg *Config) { + cfg.PidFilter = pid + }) } // SetTidFilter updates the active TID filter used for subsequent tracing runs. func SetTidFilter(tid int) { - tidFilter.Store(int64(tid)) + updateCurrent(func(cfg *Config) { + cfg.TidFilter = tid + }) } // SetTUIExportEnable toggles TUI snapshot export file writing. func SetTUIExportEnable(enabled bool) { - tuiExportEnable.Store(enabled) + updateCurrent(func(cfg *Config) { + cfg.TUIExportEnable = enabled + }) } +// Parse parses CLI flags once and updates the current runtime configuration. func Parse() error { once.Do(func() { parseErr = parse() @@ -108,48 +145,42 @@ func Parse() error { } func parse() error { - flag.IntVar(&singleton.PidFilter, "pid", -1, "Filter for processes ID") - flag.IntVar(&singleton.TidFilter, "tid", -1, "Filter for thread ID") - flag.IntVar(&singleton.EventMapSize, "mapSize", 4096*16, "BPF FD event ring buffer map size") - flag.IntVar(&singleton.Duration, "duration", 900, "Probe duration in seconds") + cfg := NewFlags() + validFields := collapse.ValidFields() + validCounts := collapse.ValidCountFields() + + flag.IntVar(&cfg.PidFilter, "pid", cfg.PidFilter, "Filter for processes ID") + flag.IntVar(&cfg.TidFilter, "tid", cfg.TidFilter, "Filter for thread ID") + flag.IntVar(&cfg.EventMapSize, "mapSize", cfg.EventMapSize, "BPF FD event ring buffer map size") + flag.IntVar(&cfg.Duration, "duration", cfg.Duration, "Probe duration in seconds") - flag.StringVar(&singleton.CommFilter, "comm", "", "Command to filter for") - flag.StringVar(&singleton.PathFilter, "path", "", "Path to filter for") + flag.StringVar(&cfg.CommFilter, "comm", "", "Command to filter for") + flag.StringVar(&cfg.PathFilter, "path", "", "Path to filter for") - flag.BoolVar(&singleton.PprofEnable, "pprof", false, "Enable profiling") + flag.BoolVar(&cfg.PprofEnable, "pprof", false, "Enable profiling") tracepointsToAttach := flag.String("tps", "", "Comma separated list regexes for tracepoints to load") tracepointsToExclude := flag.String("tpsExclude", "", "Comma separated list regexes for tracepoints to exclude") - flag.BoolVar(&singleton.PlainMode, "plain", false, "Enable plain CSV output mode (disable TUI)") - flag.BoolVar(&singleton.FlamegraphEnable, "flamegraph", false, "Enable flamegraph builder") - flag.BoolVar(&singleton.LiveFlamegraph, "live", false, "Enable live flamegraph mode") - flag.DurationVar(&singleton.LiveInterval, "live-interval", 200*time.Millisecond, "Live flamegraph refresh interval") - flag.StringVar(&singleton.OpenCommand, "open", "", "Command to open live flamegraph URL (used with -live); use {url} placeholder or URL is appended") - flag.StringVar(&singleton.FlamegraphName, "name", "default", "Name of the flamegraph, used to generate the SVG file") - flag.BoolVar(&singleton.FlamegraphJSON, "flamegraphJson", false, "Also export flamegraph tree as JSON in -ior mode (experimental WASM-ready output)") - flag.BoolVar(&singleton.TUIExportEnable, "tuiExport", true, "Enable writing TUI snapshot export files") - - flag.StringVar(&singleton.IorDataFile, "ior", "", "IOR data file to convert into native SVG flamegraph") - flag.DurationVar(&singleton.IorWatchInterval, "iorWatchInterval", 0, - "In -ior mode, poll input file for changes and regenerate outputs; also enables auto-reloading viewer") + flag.BoolVar(&cfg.PlainMode, "plain", false, "Enable plain CSV output mode (disable TUI)") + flag.BoolVar(&cfg.TestFlames, "testflames", false, "Run TUI with static synthetic flamegraph data for keyboard-navigation testing") + flag.BoolVar(&cfg.TestLiveFlames, "testliveflames", false, "Run TUI with continuously-updating synthetic flamegraph data for live keyboard-navigation testing") + flag.DurationVar(&cfg.LiveInterval, "live-interval", cfg.LiveInterval, "Synthetic live flamegraph refresh interval for --testliveflames") + flag.BoolVar(&cfg.TUIExportEnable, "tuiExport", cfg.TUIExportEnable, "Enable writing TUI snapshot export files") fields := flag.String("fields", "", - fmt.Sprintf("Comma separated list of fields to collapse, valid are: %v", validCollapsedFields)) - flag.StringVar(&singleton.CountField, "count", "count", - fmt.Sprintf("Count field to collapse, valid are: %v", validCollapsedCounts)) + fmt.Sprintf("Comma separated list of fields to collapse, valid are: %v", validFields)) + flag.StringVar(&cfg.CountField, "count", cfg.CountField, + fmt.Sprintf("Count field to collapse, valid are: %v", validCounts)) if err := flag.CommandLine.Parse(os.Args[1:]); err != nil { return err } - pidFilter.Store(int64(singleton.PidFilter)) - tidFilter.Store(int64(singleton.TidFilter)) - tuiExportEnable.Store(singleton.TUIExportEnable) var err error - singleton.TracepointsToAttach, err = extractTracepointFlags(*tracepointsToAttach) + cfg.TracepointsToAttach, err = extractTracepointFlags(*tracepointsToAttach) if err != nil { return err } - singleton.TracepointsToExclude, err = extractTracepointFlags(*tracepointsToExclude) + cfg.TracepointsToExclude, err = extractTracepointFlags(*tracepointsToExclude) if err != nil { return err } @@ -160,21 +191,22 @@ func parse() error { // If future kernels regress, add targeted exclusions here. if *fields == "" { - singleton.CollapsedFields = []string{"comm", "path", "tracepoint"} + cfg.CollapsedFields = []string{"comm", "tracepoint", "path"} } else { - singleton.CollapsedFields = strings.Split(*fields, ",") + cfg.CollapsedFields = strings.Split(*fields, ",") } - for _, field := range singleton.CollapsedFields { - if !slices.Contains(validCollapsedFields, field) { + for _, field := range cfg.CollapsedFields { + if !collapse.IsValidField(field) { return fmt.Errorf("invalid field for collapse: %s", field) } } - if !slices.Contains(validCollapsedCounts, singleton.CountField) { - return fmt.Errorf("invalid count field: %s", singleton.CountField) + if !collapse.IsValidCountField(cfg.CountField) { + return fmt.Errorf("invalid count field: %s", cfg.CountField) } + setCurrent(cfg) return nil } @@ -192,7 +224,7 @@ func extractTracepointFlags(tracepoints string) (regexes []*regexp.Regexp, err e return regexes, nil } -func (flags Flags) ShouldIAttachTracepoint(tracepointName string) bool { +func (flags Config) ShouldIAttachTracepoint(tracepointName string) bool { for _, re := range flags.TracepointsToExclude { if re.MatchString(tracepointName) { return false diff --git a/internal/flags/flags_test.go b/internal/flags/flags_test.go index 54c65b8..2469068 100644 --- a/internal/flags/flags_test.go +++ b/internal/flags/flags_test.go @@ -9,114 +9,96 @@ import ( "time" ) -func parseForTest(t *testing.T, args ...string) (Flags, error) { +func parseForTest(t *testing.T, args ...string) (Config, error) { t.Helper() oldCommandLine := flag.CommandLine oldArgs := os.Args - oldSingleton := singleton + oldCurrent := Get() oldParseErr := parseErr - oldPID := pidFilter.Load() - oldTID := tidFilter.Load() - oldTUIExport := tuiExportEnable.Load() fs := flag.NewFlagSet("ior-test", flag.ContinueOnError) fs.SetOutput(io.Discard) flag.CommandLine = fs os.Args = append([]string{"ior"}, args...) - singleton = Flags{TUIExportEnable: true} + setCurrent(NewFlags()) parseErr = nil - pidFilter.Store(-1) - tidFilter.Store(-1) - tuiExportEnable.Store(true) err := parse() - cfg := singleton + cfg := Get() t.Cleanup(func() { flag.CommandLine = oldCommandLine os.Args = oldArgs - singleton = oldSingleton + setCurrent(oldCurrent) parseErr = oldParseErr - pidFilter.Store(oldPID) - tidFilter.Store(oldTID) - tuiExportEnable.Store(oldTUIExport) }) return cfg, err } -func TestParseLiveFlagsAndInterval(t *testing.T) { - cfg, err := parseForTest(t, "-live", "-live-interval", "200ms", "-pid", "1234") +func TestParseLiveIntervalAndPID(t *testing.T) { + cfg, err := parseForTest(t, "-live-interval", "200ms", "-pid", "1234") if err != nil { t.Fatalf("parse returned error: %v", err) } - if !cfg.LiveFlamegraph { - t.Fatalf("expected -live to enable live mode") - } if cfg.LiveInterval != 200*time.Millisecond { t.Fatalf("live interval = %v, want %v", cfg.LiveInterval, 200*time.Millisecond) } if cfg.PidFilter != 1234 { t.Fatalf("pid filter = %d, want 1234", cfg.PidFilter) } - if got := int(pidFilter.Load()); got != 1234 { - t.Fatalf("global pid filter = %d, want 1234", got) - } - if cfg.OpenCommand != "" { - t.Fatalf("expected empty open command by default") + if got := Get().GetPidFilter(); got != 1234 { + t.Fatalf("Get().GetPidFilter() = %d, want 1234", got) } } -func TestParseLiveDefaults(t *testing.T) { - cfg, err := parseForTest(t) - if err != nil { - t.Fatalf("parse returned error: %v", err) +func TestNewFlagsDefaultsAndGetters(t *testing.T) { + cfg := NewFlags() + if cfg.GetPidFilter() != -1 { + t.Fatalf("GetPidFilter() = %d, want -1", cfg.GetPidFilter()) } - - if cfg.LiveFlamegraph { - t.Fatalf("expected live mode disabled by default") + if cfg.GetTidFilter() != -1 { + t.Fatalf("GetTidFilter() = %d, want -1", cfg.GetTidFilter()) } - if cfg.LiveInterval != 200*time.Millisecond { - t.Fatalf("default live interval = %v, want %v", cfg.LiveInterval, 200*time.Millisecond) + if !cfg.GetTUIExportEnable() { + t.Fatalf("GetTUIExportEnable() = false, want true") } - if cfg.OpenCommand != "" { - t.Fatalf("expected empty open command by default") + if cfg.CountField != "count" { + t.Fatalf("CountField = %q, want count", cfg.CountField) } } -func TestParseOpenFlags(t *testing.T) { - cfg, err := parseForTest(t, "-live", "-open", "chromium --new-window") +func TestParseLiveDefaults(t *testing.T) { + cfg, err := parseForTest(t) if err != nil { t.Fatalf("parse returned error: %v", err) } - if !cfg.LiveFlamegraph { - t.Fatalf("expected live mode enabled") - } - if cfg.OpenCommand != "chromium --new-window" { - t.Fatalf("open command = %q, want %q", cfg.OpenCommand, "chromium --new-window") + + if cfg.LiveInterval != 200*time.Millisecond { + t.Fatalf("default live interval = %v, want %v", cfg.LiveInterval, 200*time.Millisecond) } } -func TestParseFlamegraphJSONFlag(t *testing.T) { - cfg, err := parseForTest(t, "-flamegraphJson") +func TestParseTestFlamesFlag(t *testing.T) { + cfg, err := parseForTest(t, "--testflames") if err != nil { t.Fatalf("parse returned error: %v", err) } - if !cfg.FlamegraphJSON { - t.Fatalf("expected -flamegraphJson to enable JSON export") + if !cfg.TestFlames { + t.Fatalf("expected --testflames to enable static flamegraph test mode") } } -func TestParseIorWatchIntervalFlag(t *testing.T) { - cfg, err := parseForTest(t, "-iorWatchInterval", "2s") +func TestParseTestLiveFlamesFlag(t *testing.T) { + cfg, err := parseForTest(t, "--testliveflames") if err != nil { t.Fatalf("parse returned error: %v", err) } - if cfg.IorWatchInterval != 2*time.Second { - t.Fatalf("ior watch interval = %v, want %v", cfg.IorWatchInterval, 2*time.Second) + if !cfg.TestLiveFlames { + t.Fatalf("expected --testliveflames to enable synthetic live flamegraph test mode") } } @@ -126,7 +108,7 @@ func TestParseDefaultCollapsedFieldsOrder(t *testing.T) { t.Fatalf("parse returned error: %v", err) } - want := []string{"comm", "path", "tracepoint"} + want := []string{"comm", "tracepoint", "path"} if len(cfg.CollapsedFields) != len(want) { t.Fatalf("default collapsed fields len = %d, want %d", len(cfg.CollapsedFields), len(want)) } diff --git a/internal/flamegraph/counter.go b/internal/flamegraph/counter.go index ae727d4..441db68 100644 --- a/internal/flamegraph/counter.go +++ b/internal/flamegraph/counter.go @@ -10,6 +10,7 @@ import ( // - Duration is the syscall runtime on the same thread. // - DurationToPrev is the inter-syscall gap on the same thread and is attributed // to the current node; there is no separate "idle" pseudo-node. +// // Bytes is only populated for read/write/transfer syscalls. type Counter struct { Count uint64 @@ -27,17 +28,17 @@ func (c Counter) add(other Counter) Counter { return c } -func (c Counter) ValueByName(name string) uint64 { +func (c Counter) ValueByName(name string) (uint64, error) { switch name { case "count": - return c.Count + return c.Count, nil case "duration": - return c.Duration + return c.Duration, nil case "durationToPrev": - return c.DurationToPrev + return c.DurationToPrev, nil case "bytes": - return c.Bytes + return c.Bytes, nil default: - panic(fmt.Sprintln("No", name, "in count record")) + return 0, fmt.Errorf("unknown counter field %q", name) } } diff --git a/internal/flamegraph/doc.go b/internal/flamegraph/doc.go new file mode 100644 index 0000000..02429d3 --- /dev/null +++ b/internal/flamegraph/doc.go @@ -0,0 +1,2 @@ +// Package flamegraph provides TUI flamegraph aggregation primitives. +package flamegraph diff --git a/internal/flamegraph/iordata.go b/internal/flamegraph/iordata.go index 61a65a9..4a562e3 100644 --- a/internal/flamegraph/iordata.go +++ b/internal/flamegraph/iordata.go @@ -3,16 +3,17 @@ package flamegraph import ( "bytes" "encoding/gob" + "errors" "fmt" - "io" - "ior/internal/event" - "ior/internal/file" - "ior/internal/types" "iter" "os" "strings" "time" + "ior/internal/event" + "ior/internal/file" + "ior/internal/types" + // Is there a zstd library part of Go 1.25 "github.com/DataDog/zstd" ) @@ -23,7 +24,8 @@ type commType = string type pidType = uint32 type tidType = uint32 type flagsType = file.Flags -type pathMap map[pathType]map[traceIdType]map[commType]map[pidType]map[tidType]map[flagsType]Counter + +var hostnameFn = os.Hostname type recordKey struct { Path pathType @@ -97,10 +99,10 @@ func (iod iorData) merge(other iorData) iorData { return iod } -func (iod iorData) serializeToFile(flamegraphName string) error { - hostname, err := os.Hostname() +func (iod iorData) serializeToFile(flamegraphName string) (retErr error) { + hostname, err := hostnameFn() if err != nil { - panic(err) + return fmt.Errorf("get hostname: %w", err) } if flamegraphName == "" { flamegraphName = "default" @@ -113,22 +115,33 @@ func (iod iorData) serializeToFile(flamegraphName string) error { file, err := os.Create(tmpFilename) if err != nil { - return err + return fmt.Errorf("create temp file %s: %w", tmpFilename, err) } - defer file.Close() + defer func() { + if err := file.Close(); err != nil { + retErr = errors.Join(retErr, fmt.Errorf("close temp file %s: %w", tmpFilename, err)) + } + }() encoder := zstd.NewWriter(file) - defer encoder.Close() + defer func() { + if err := encoder.Close(); err != nil { + retErr = errors.Join(retErr, fmt.Errorf("close zstd writer for %s: %w", tmpFilename, err)) + } + }() gobEncoder := gob.NewEncoder(encoder) if err := gobEncoder.Encode(iod.records); err != nil { - return err + return fmt.Errorf("encode ior records: %w", err) } if err := encoder.Flush(); err != nil { - return err + return fmt.Errorf("flush ior records: %w", err) } - return os.Rename(tmpFilename, filename) + if err := os.Rename(tmpFilename, filename); err != nil { + return fmt.Errorf("rename %s to %s: %w", tmpFilename, filename, err) + } + return nil } func (iod *iorData) loadFromFile(filename string) error { @@ -142,23 +155,14 @@ func (iod *iorData) loadFromFile(filename string) error { defer decoder.Close() var records map[recordKey]Counter - if err := gob.NewDecoder(decoder).Decode(&records); err == nil && len(records) > 0 { - iod.records = records - return nil - } - - // Fallback path for legacy payloads and empty-map ambiguity. - if _, err := file.Seek(0, io.SeekStart); err != nil { + if err := gob.NewDecoder(decoder).Decode(&records); err != nil { return err } - decoder = zstd.NewReader(file) - defer decoder.Close() - - var buffer bytes.Buffer - if _, err = io.Copy(&buffer, decoder); err != nil { - return err + if records == nil { + records = make(map[recordKey]Counter) } - return iod.deserialize(&buffer) + iod.records = records + return nil } func (iod iorData) serialize() ([]byte, error) { @@ -169,36 +173,14 @@ func (iod iorData) serialize() ([]byte, error) { } func (iod *iorData) deserialize(buf *bytes.Buffer) error { - raw := append([]byte(nil), buf.Bytes()...) - dec := gob.NewDecoder(bytes.NewReader(raw)) var records map[recordKey]Counter - if err := dec.Decode(&records); err == nil && len(records) > 0 { - iod.records = records - return nil - } - - var legacy pathMap - if err := gob.NewDecoder(bytes.NewReader(raw)).Decode(&legacy); err != nil { + if err := gob.NewDecoder(bytes.NewReader(buf.Bytes())).Decode(&records); err != nil { return err } - - iod.records = make(map[recordKey]Counter) - for path, traceIDMap := range legacy { - for traceID, commMap := range traceIDMap { - for comm, pidMap := range commMap { - for pid, tidMap := range pidMap { - for tid, flagsMap := range tidMap { - for f, cnt := range flagsMap { - iod.add(path, traceID, comm, pid, tid, f, cnt) - } - } - } - } - } - } - if len(iod.records) == 0 && records != nil { - iod.records = records + if records == nil { + records = make(map[recordKey]Counter) } + iod.records = records return nil } diff --git a/internal/flamegraph/iordata_test.go b/internal/flamegraph/iordata_test.go index 54f1ed5..ee07a90 100644 --- a/internal/flamegraph/iordata_test.go +++ b/internal/flamegraph/iordata_test.go @@ -2,9 +2,12 @@ package flamegraph import ( "bytes" - "ior/internal/types" + "errors" + "strings" "syscall" "testing" + + "ior/internal/types" ) func counterAt(iod iorData, path pathType, traceID traceIdType, comm commType, pid pidType, tid tidType, flags flagsType) (Counter, bool) { @@ -167,16 +170,36 @@ func TestStringByNameValidFields(t *testing.T) { } } -func TestCounterValueByNamePanic(t *testing.T) { +func TestCounterValueByNameUnknownField(t *testing.T) { c := Counter{Count: 1, Duration: 100, DurationToPrev: 10, Bytes: 64} - defer func() { - if r := recover(); r == nil { - t.Error("Expected panic for unknown counter name, got none") - } - }() + _, err := c.ValueByName("nonexistent") + if err == nil { + t.Error("Expected error for unknown counter name, got nil") + } +} - c.ValueByName("nonexistent") +func TestCounterValueByNameValidFields(t *testing.T) { + c := Counter{Count: 1, Duration: 100, DurationToPrev: 10, Bytes: 64} + + tests := map[string]uint64{ + "count": c.Count, + "duration": c.Duration, + "durationToPrev": c.DurationToPrev, + "bytes": c.Bytes, + } + + for field, want := range tests { + t.Run(field, func(t *testing.T) { + got, err := c.ValueByName(field) + if err != nil { + t.Fatalf("Expected no error for field %q, got %v", field, err) + } + if got != want { + t.Fatalf("Expected %d for field %q, got %d", want, field, got) + } + }) + } } func TestMergeEmpty(t *testing.T) { @@ -287,6 +310,24 @@ func TestDeserializeInvalidData(t *testing.T) { } } +func TestSerializeToFileHostnameErrorReturnsError(t *testing.T) { + origHostnameFn := hostnameFn + t.Cleanup(func() { hostnameFn = origHostnameFn }) + + hostnameFn = func() (string, error) { + return "", errors.New("hostname unavailable") + } + + iod := newIorData() + err := iod.serializeToFile("test") + if err == nil { + t.Fatal("Expected error when hostname lookup fails, got nil") + } + if !strings.Contains(err.Error(), "get hostname") { + t.Fatalf("Expected get hostname context, got %v", err) + } +} + func bothArraysHaveSameElements(a, b []string) bool { if len(a) != len(b) { return false diff --git a/internal/flamegraph/iordatacollector.go b/internal/flamegraph/iordatacollector.go deleted file mode 100644 index 9e92b63..0000000 --- a/internal/flamegraph/iordatacollector.go +++ /dev/null @@ -1,64 +0,0 @@ -package flamegraph - -import ( - "context" - "fmt" - "ior/internal/event" - "runtime" - "sync" -) - -type IorDataCollector struct { - flamegraphName string - Ch chan *event.Pair - Done chan error - workers []worker -} - -func New(flamegraphName ...string) IorDataCollector { - name := "default" - if len(flamegraphName) > 0 && flamegraphName[0] != "" { - name = flamegraphName[0] - } - - f := IorDataCollector{ - flamegraphName: name, - Ch: make(chan *event.Pair, 4096), - Done: make(chan error, 1), - } - numWorkers := runtime.NumCPU() / 4 - if numWorkers == 0 { - numWorkers = 1 - } - for range numWorkers { - f.workers = append(f.workers, newWorker()) - } - return f -} - -func (f IorDataCollector) Start(ctx context.Context) { - go func() { - defer close(f.Done) - var wg sync.WaitGroup - wg.Add(len(f.workers)) - - for i, worker := range f.workers { - fmt.Println("Starting flamegraph worker", i) - go worker.run(ctx, &wg, f.Ch) - } - wg.Wait() - - iod := f.workers[0].iod - if len(f.workers) > 1 { - for i, w := range f.workers[1:] { - iod = iod.merge(w.iod) - fmt.Println("Worker", i+1, "merged") - } - } - if err := iod.serializeToFile(f.flamegraphName); err != nil { - f.Done <- err - return - } - f.Done <- nil - }() -} diff --git a/internal/flamegraph/layout.go b/internal/flamegraph/layout.go deleted file mode 100644 index c319800..0000000 --- a/internal/flamegraph/layout.go +++ /dev/null @@ -1,78 +0,0 @@ -package flamegraph - -import "fmt" - -// FrameLayout captures renderer-agnostic flamegraph geometry for a single frame. -// -// The layout is reusable by non-SVG renderers (for example SDL or WASM UIs) so -// they can render the same hierarchy without depending on SVG internals. -type FrameLayout struct { - Name string - Title string - Fill string - X float64 - Y float64 - Width float64 - Height float64 - Depth int - Total uint64 - Percent float64 -} - -func sanitizeSVGConfig(cfg SVGConfig) SVGConfig { - if cfg.Width <= 0 || cfg.FrameHeight <= 0 || cfg.FontSize <= 0 || cfg.MinWidthPx <= 0 { - return defaultSVGConfig() - } - if cfg.Title == "" { - cfg.Title = defaultSVGConfig().Title - } - return cfg -} - -func canvasHeightFor(cfg SVGConfig, t *trie) int { - return cfg.FrameHeight*(t.maxDepth+1) + 80 -} - -// BuildFrameLayout builds renderer-agnostic frame coordinates from a flamegraph trie. -func BuildFrameLayout(t *trie, cfg SVGConfig) []FrameLayout { - if t == nil || t.root == nil || t.root.total == 0 { - return nil - } - cfg = sanitizeSVGConfig(cfg) - canvasHeight := canvasHeightFor(cfg, t) - out := make([]FrameLayout, 0, len(t.root.children)) - collectFrameLayout(&out, t.root, t.root.total, cfg, 0, 0, canvasHeight, true) - return out -} - -func collectFrameLayout(out *[]FrameLayout, node *trieNode, rootTotal uint64, - cfg SVGConfig, x float64, depth int, canvasHeight int, isRoot bool) { - - if !isRoot { - w := float64(cfg.Width) * (float64(node.total) / float64(rootTotal)) - if w < cfg.MinWidthPx { - return - } - y := float64(canvasHeight - (depth+1)*cfg.FrameHeight) - pct := 100 * float64(node.total) / float64(rootTotal) - *out = append(*out, FrameLayout{ - Name: node.name, - Title: fmt.Sprintf("%s (%d, %.2f%%)", node.name, node.total, pct), - Fill: frameColor(node.name), - X: x, - Y: y, - Width: w, - Height: float64(cfg.FrameHeight - 1), - Depth: depth, - Total: node.total, - Percent: pct, - }) - } - - cursor := x - for _, child := range node.children { - cw := float64(cfg.Width) * (float64(child.total) / float64(rootTotal)) - collectFrameLayout(out, child, rootTotal, cfg, cursor, depth+1, canvasHeight, false) - cursor += cw - } -} diff --git a/internal/flamegraph/layout_test.go b/internal/flamegraph/layout_test.go deleted file mode 100644 index 8fa7398..0000000 --- a/internal/flamegraph/layout_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package flamegraph - -import ( - "math" - "testing" -) - -func almostEqual(a, b float64) bool { - return math.Abs(a-b) < 1e-6 -} - -func TestBuildFrameLayoutBasicGeometry(t *testing.T) { - tr := newTrie() - tr.add([]string{"A"}, 4) - tr.add([]string{"B"}, 1) - tr.computeTotals() - - cfg := defaultSVGConfig() - cfg.Width = 100 - cfg.FrameHeight = 10 - cfg.FontSize = 10 - cfg.MinWidthPx = 1 - - frames := BuildFrameLayout(tr, cfg) - if len(frames) != 2 { - t.Fatalf("frames len = %d, want 2", len(frames)) - } - - a := frames[0] - if a.Name != "A" { - t.Fatalf("first frame name = %q, want %q", a.Name, "A") - } - if !almostEqual(a.X, 0) { - t.Fatalf("A x = %f, want 0", a.X) - } - if !almostEqual(a.Width, 80) { - t.Fatalf("A width = %f, want 80", a.Width) - } - if !almostEqual(a.Percent, 80) { - t.Fatalf("A percent = %f, want 80", a.Percent) - } - if a.Depth != 1 { - t.Fatalf("A depth = %d, want 1", a.Depth) - } - - b := frames[1] - if b.Name != "B" { - t.Fatalf("second frame name = %q, want %q", b.Name, "B") - } - if !almostEqual(b.X, 80) { - t.Fatalf("B x = %f, want 80", b.X) - } - if !almostEqual(b.Width, 20) { - t.Fatalf("B width = %f, want 20", b.Width) - } -} - -func TestBuildFrameLayoutSkipsFramesBelowMinWidth(t *testing.T) { - tr := newTrie() - tr.add([]string{"A"}, 999) - tr.add([]string{"B"}, 1) - tr.computeTotals() - - cfg := defaultSVGConfig() - cfg.Width = 100 - cfg.FrameHeight = 10 - cfg.FontSize = 10 - cfg.MinWidthPx = 1 - - frames := BuildFrameLayout(tr, cfg) - if len(frames) != 1 { - t.Fatalf("frames len = %d, want 1", len(frames)) - } - if frames[0].Name != "A" { - t.Fatalf("remaining frame name = %q, want %q", frames[0].Name, "A") - } -} diff --git a/internal/flamegraph/livehtml.go b/internal/flamegraph/livehtml.go deleted file mode 100644 index 90a6d3d..0000000 --- a/internal/flamegraph/livehtml.go +++ /dev/null @@ -1,841 +0,0 @@ -package flamegraph - -const liveHTML = `<!doctype html> -<html lang="en"> -<head> - <meta charset="utf-8"> - <meta name="viewport" content="width=device-width, initial-scale=1"> - <title>I/O Flame Graph (Live)</title> - <style> - :root { - --fg-bg: #f6f1ea; - --fg-panel: #fbf7f1; - --fg-border: #d8cdc0; - --fg-text: #232323; - --fg-muted: #5f5f5f; - --fg-accent: #7b2d1f; - --fg-btn: #efe2d2; - --fg-btn-hover: #e6d5c1; - --fg-paused: #b02222; - } - - * { box-sizing: border-box; } - - body { - margin: 0; - min-height: 100vh; - background: linear-gradient(180deg, #f8f2ea 0%, #f2e9dc 100%); - color: var(--fg-text); - font-family: monospace; - } - - #controls { - position: sticky; - top: 0; - z-index: 1; - display: flex; - gap: 8px; - align-items: center; - flex-wrap: wrap; - padding: 10px 12px; - background: var(--fg-panel); - border-bottom: 1px solid var(--fg-border); - } - - #controls button { - border: 1px solid var(--fg-border); - background: var(--fg-btn); - color: var(--fg-text); - font: inherit; - font-size: 12px; - line-height: 1.2; - padding: 6px 10px; - cursor: pointer; - } - - #controls button:hover { - background: var(--fg-btn-hover); - } - - #controls .order-toggle { - min-width: 220px; - text-align: left; - } - - #status { - margin-left: 8px; - font-size: 12px; - color: var(--fg-muted); - } - - .paused #status { - color: var(--fg-paused); - font-weight: 700; - letter-spacing: 0.03em; - text-transform: uppercase; - } - - #flamegraph { - display: block; - width: 100%; - height: calc(100vh - 56px); - min-height: calc(100vh - 56px); - background: transparent; - } - - .title { - font-size: 14px; - font-family: monospace; - } - - .controls text { - font-size: 12px; - font-family: monospace; - cursor: pointer; - fill: #444; - } - - .frame text { - font-size: 11px; - font-family: monospace; - pointer-events: none; - fill: #111; - } - - .frame rect { - stroke: rgba(0, 0, 0, 0.18); - stroke-width: 0.5; - } - </style> -</head> -<body> - <div id="controls"> - <button id="btn-pause" type="button">Pause</button> - <button id="btn-search" type="button">Search</button> - <button id="btn-reset-search" type="button">Reset Search</button> - <button id="btn-undo-zoom" type="button">Undo Zoom</button> - <button id="btn-reset-zoom" type="button">Reset Zoom</button> - <button id="btn-reset-baseline" type="button">Reset Baseline</button> - <button id="btn-toggle-order" class="order-toggle" type="button">Order: comm > path > tracepoint</button> - <span id="status">LIVE</span> - </div> - - <svg id="flamegraph" xmlns="http://www.w3.org/2000/svg"></svg> - - <script> - (function () { - var fg = { - paused: false, - resetting: false, - lastTreeData: null, - pendingData: null, - searchQuery: '', - zoomStack: [], - zoomRange: null, - frames: [], - rootWidth: 0, - matchColor: 'rgb(220,30,70)', - eventSource: null, - svg: document.getElementById('flamegraph'), - status: document.getElementById('status'), - pauseBtn: document.getElementById('btn-pause'), - searchBtn: document.getElementById('btn-search'), - resetSearchBtn: document.getElementById('btn-reset-search'), - undoZoomBtn: document.getElementById('btn-undo-zoom'), - resetZoomBtn: document.getElementById('btn-reset-zoom'), - resetBaselineBtn: document.getElementById('btn-reset-baseline'), - toggleOrderBtn: document.getElementById('btn-toggle-order'), - orderPresets: [ - 'comm,path,tracepoint', - 'path,tracepoint,comm', - 'tracepoint,comm,path', - 'pid,path,tracepoint' - ], - orderIndex: 0, - cfg: { - baseWidth: 1200, - baseFrameHeight: 16, - width: 1200, - frameHeight: 16, - fontSize: 12, - minWidthPx: 1.0 - } - }; - - function fgFrameColor(name) { - var bytes = new TextEncoder().encode(name || ''); - var h = 2166136261 >>> 0; - for (var i = 0; i < bytes.length; i++) { - h ^= bytes[i]; - h = Math.imul(h, 16777619) >>> 0; - } - var r = 200 + (h % 35); - var g = 80 + ((h >>> 8) % 120); - var b = 40 + ((h >>> 16) % 90); - return 'rgb(' + r + ',' + g + ',' + b + ')'; - } - - function fgMaxDepth(node, depth) { - if (!node || !Array.isArray(node.c) || node.c.length === 0) { - return depth; - } - var maxDepth = depth; - for (var i = 0; i < node.c.length; i++) { - var childDepth = fgMaxDepth(node.c[i], depth + 1); - if (childDepth > maxDepth) { - maxDepth = childDepth; - } - } - return maxDepth; - } - - function fgDefaultCanvasHeight(maxDepth) { - return (fg.cfg.baseFrameHeight * (maxDepth + 1)) + 80; - } - - function fgViewportLayout(maxDepth) { - var rows = Math.max(maxDepth + 1, 1); - var defaultCanvasHeight = fgDefaultCanvasHeight(maxDepth); - var viewportWidth = Number(window.innerWidth || 0); - if (viewportWidth <= 0 && document && document.documentElement) { - viewportWidth = Number(document.documentElement.clientWidth || 0); - } - if (viewportWidth <= 0) { - viewportWidth = fg.cfg.baseWidth; - } - var viewportHeight = Number(window.innerHeight || 0); - if (viewportHeight <= 0) { - return { - width: viewportWidth, - frameHeight: fg.cfg.baseFrameHeight, - canvasHeight: defaultCanvasHeight - }; - } - - var controls = document.getElementById('controls'); - var controlsHeight = 56; - if (controls && typeof controls.getBoundingClientRect === 'function') { - controlsHeight = Number(controls.getBoundingClientRect().height || controlsHeight); - } - - var availableHeight = viewportHeight - controlsHeight; - if (availableHeight <= 0) { - return { - width: viewportWidth, - frameHeight: fg.cfg.baseFrameHeight, - canvasHeight: defaultCanvasHeight - }; - } - - var canvasHeight = Math.max(defaultCanvasHeight, availableHeight); - var frameHeight = (canvasHeight - 80) / rows; - if (frameHeight < fg.cfg.baseFrameHeight) { - frameHeight = fg.cfg.baseFrameHeight; - } - return { - width: viewportWidth, - frameHeight: frameHeight, - canvasHeight: canvasHeight - }; - } - - function fgVisibleChildrenTotal(node) { - var children = Array.isArray(node && node.c) ? node.c : []; - var total = 0; - for (var i = 0; i < children.length; i++) { - total += Number(children[i].t || 0); - } - if (total > 0) { - return total; - } - return Number(node && node.t || 0); - } - - function fgBuildFrames(node, rootTotal, x, width, depth, canvasHeight, isRoot, out, path) { - if (!node || rootTotal <= 0 || width <= 0) { - return; - } - var currentPath = path || ''; - if (!isRoot) { - var w = width; - if (w < fg.cfg.minWidthPx) { - return; - } - var name = node.n || ''; - currentPath = currentPath ? (currentPath + '\u001f' + name) : name; - var y = canvasHeight - ((depth + 1) * fg.cfg.frameHeight); - var total = Number(node.t || 0); - var pct = 100 * total / Number(rootTotal); - out.push({ - name: name, - path: currentPath, - x: x, - y: y, - w: w, - h: fg.cfg.frameHeight - 1, - depth: depth, - total: total, - pct: pct, - fill: fgFrameColor(name) - }); - } - var cursor = x; - var children = Array.isArray(node.c) ? node.c : []; - var childrenTotal = fgVisibleChildrenTotal(node); - if (childrenTotal <= 0) { - return; - } - for (var i = 0; i < children.length; i++) { - var child = children[i]; - var childTotal = Number(child.t || 0); - if (childTotal <= 0) { - continue; - } - var childWidth = width * (childTotal / childrenTotal); - fgBuildFrames(child, rootTotal, cursor, childWidth, depth + 1, canvasHeight, false, out, currentPath); - cursor += childWidth; - } - } - - function fgEscape(value) { - return String(value || '') - .replace(/&/g, '&') - .replace(/</g, '<') - .replace(/>/g, '>') - .replace(/"/g, '"') - .replace(/'/g, '''); - } - - function fgSetStatus(suffix) { - var prefix = fg.paused ? 'PAUSED' : 'LIVE'; - fg.status.textContent = suffix ? (prefix + ' | ' + suffix) : prefix; - } - - function fgOrderLabel(csv) { - return String(csv || '').split(',').join(' > '); - } - - function fgOrderFields(csv) { - return String(csv || '').split(',').filter(function (s) { return s; }); - } - - function fgSetOrderIndexByCSV(csv) { - for (var i = 0; i < fg.orderPresets.length; i++) { - if (fg.orderPresets[i] === csv) { - fg.orderIndex = i; - return; - } - } - } - - function fgUpdateOrderButton() { - fg.toggleOrderBtn.textContent = 'Order: ' + fgOrderLabel(fg.orderPresets[fg.orderIndex] || ''); - } - - function fgHover(frame) { - var title = frame.querySelector('title'); - fgSetStatus(title ? title.textContent : ''); - } - - function fgDetectRootWidth() { - var maxEnd = 0; - for (var i = 0; i < fg.frames.length; i++) { - var x = Number(fg.frames[i].dataset.x || '0'); - var w = Number(fg.frames[i].dataset.w || '0'); - if (x + w > maxEnd) { - maxEnd = x + w; - } - } - return maxEnd; - } - - function fgSnapshotOriginalGeometry(frame) { - var rect = frame.querySelector('rect'); - var text = frame.querySelector('text'); - frame.dataset.ox = frame.dataset.x || '0'; - frame.dataset.ow = frame.dataset.w || '0'; - if (rect) { - rect.dataset.ox = rect.getAttribute('x') || '0'; - rect.dataset.ow = rect.getAttribute('width') || '0'; - } - if (text) { - text.dataset.ox = text.getAttribute('x') || '0'; - text.dataset.hidden = text.style.display === 'none' ? '1' : '0'; - text.dataset.full = text.textContent || frame.dataset.name || ''; - } - } - - function fgOriginalX(frame) { - return Number(frame.dataset.ox || frame.dataset.x || '0'); - } - - function fgOriginalW(frame) { - return Number(frame.dataset.ow || frame.dataset.w || '0'); - } - - function fgFitLabel(text, width) { - var full = text.dataset.full || text.textContent || ''; - var maxChars = Math.floor((width - 6) / 7); - if (maxChars < 3) { - text.style.display = 'none'; - text.textContent = full; - return; - } - text.style.display = ''; - if (full.length <= maxChars) { - text.textContent = full; - return; - } - text.textContent = full.slice(0, maxChars - 1) + '...'; - } - - function fgSetFrameGeometry(frame, x, w) { - var rect = frame.querySelector('rect'); - var text = frame.querySelector('text'); - if (rect) { - rect.setAttribute('x', String(x)); - rect.setAttribute('width', String(w)); - } - if (text) { - text.setAttribute('x', String(x + 3)); - fgFitLabel(text, w); - } - } - - function fgRestoreFrameGeometry(frame) { - var rect = frame.querySelector('rect'); - var text = frame.querySelector('text'); - if (rect) { - rect.setAttribute('x', rect.dataset.ox || '0'); - rect.setAttribute('width', rect.dataset.ow || '0'); - } - if (text) { - text.setAttribute('x', text.dataset.ox || '0'); - if (text.dataset.hidden === '1') { - text.style.display = 'none'; - text.textContent = text.dataset.full || ''; - } else { - fgFitLabel(text, Number(rect ? (rect.dataset.ow || '0') : '0')); - } - } - } - - function fgZoom(frame) { - var width = fgOriginalW(frame); - if (width <= 0) { - return; - } - if (fg.zoomRange) { - fg.zoomStack.push(fg.zoomRange); - } - fg.zoomRange = { - x: fgOriginalX(frame), - w: width, - depth: Number(frame.dataset.depth || '0'), - path: frame.dataset.path || '' - }; - fgApplyZoom(); - } - - function fgFindFrameByPath(path) { - for (var i = 0; i < fg.frames.length; i++) { - if ((fg.frames[i].dataset.path || '') === path) { - return fg.frames[i]; - } - } - return null; - } - - function fgRefreshZoomRange() { - if (!fg.zoomRange || !fg.zoomRange.path) { - return; - } - var candidatePath = fg.zoomRange.path; - var match = null; - while (candidatePath) { - match = fgFindFrameByPath(candidatePath); - if (match) { - break; - } - var cut = candidatePath.lastIndexOf('\u001f'); - if (cut < 0) { - break; - } - candidatePath = candidatePath.slice(0, cut); - } - if (!match) { - return; - } - var width = fgOriginalW(match); - if (width <= 0) { - return; - } - fg.zoomRange.path = match.dataset.path || candidatePath; - fg.zoomRange.x = fgOriginalX(match); - fg.zoomRange.w = width; - fg.zoomRange.depth = Number(match.dataset.depth || String(fg.zoomRange.depth || 0)); - } - - function fgApplyZoom() { - if (!fg.zoomRange) { - for (var i = 0; i < fg.frames.length; i++) { - fgRestoreFrameGeometry(fg.frames[i]); - fg.frames[i].style.display = ''; - } - return; - } - fgRefreshZoomRange(); - var x = fg.zoomRange.x; - var end = x + fg.zoomRange.w; - var width = fg.zoomRange.w; - var minDepth = fg.zoomRange.depth; - var scale = fg.rootWidth > 0 ? fg.rootWidth / width : 1; - var eps = 1e-6; - for (var i = 0; i < fg.frames.length; i++) { - var frame = fg.frames[i]; - var ox = fgOriginalX(frame); - var ow = fgOriginalW(frame); - var depth = Number(frame.dataset.depth || '0'); - var inRange = (ox >= x - eps) && (ox + ow <= end + eps); - var isAncestor = depth < minDepth && ox <= x + eps && ox + ow >= end - eps; - if (isAncestor || (depth >= minDepth && inRange)) { - if (isAncestor) { - fgSetFrameGeometry(frame, 0, fg.rootWidth); - } else { - fgSetFrameGeometry(frame, (ox - x) * scale, ow * scale); - } - frame.style.display = ''; - } else { - frame.style.display = 'none'; - } - } - } - - function fgUndoZoom() { - if (fg.zoomStack.length === 0) { - fgResetZoom(); - return; - } - fg.zoomRange = fg.zoomStack.pop(); - fgApplyZoom(); - } - - function fgResetZoom() { - fg.zoomStack = []; - fg.zoomRange = null; - fgApplyZoom(); - } - - function fgResetSearch() { - for (var i = 0; i < fg.frames.length; i++) { - var rect = fg.frames[i].querySelector('rect'); - if (!rect) { - continue; - } - rect.setAttribute('fill', fg.frames[i].dataset.baseFill || ''); - } - } - - function fgApplySearch() { - fgResetSearch(); - if (!fg.searchQuery) { - return; - } - var query = fg.searchQuery.toLowerCase(); - for (var i = 0; i < fg.frames.length; i++) { - var name = (fg.frames[i].dataset.name || '').toLowerCase(); - if (name.indexOf(query) < 0) { - continue; - } - var rect = fg.frames[i].querySelector('rect'); - if (rect) { - rect.setAttribute('fill', fg.matchColor); - } - } - } - - function fgSearch() { - var query = window.prompt('Search frame substring:', fg.searchQuery || ''); - if (query === null) { - return; - } - fg.searchQuery = query.trim(); - fgApplySearch(); - } - - function fgTogglePause() { - fg.paused = !fg.paused; - document.body.classList.toggle('paused', fg.paused); - fg.pauseBtn.textContent = fg.paused ? 'Resume' : 'Pause'; - fgSetStatus(''); - if (!fg.paused && fg.pendingData) { - var pending = fg.pendingData; - fg.pendingData = null; - requestAnimationFrame(function () { - fgProcessUpdate(pending); - }); - } - } - - function fgClearLocalState() { - fg.pendingData = null; - fg.searchQuery = ''; - fg.zoomStack = []; - fg.zoomRange = null; - } - - function fgResetBaseline() { - if (fg.resetting) { - return; - } - fg.resetting = true; - fgSetStatus('resetting baseline...'); - fetch('/reset', { method: 'POST' }) - .then(function (resp) { - if (!resp.ok) { - throw new Error('reset failed'); - } - return resp.text(); - }) - .then(function (payload) { - fgClearLocalState(); - fgProcessUpdate(payload); - fgSetStatus('baseline reset'); - }) - .catch(function () { - fgSetStatus('reset failed'); - }) - .then(function () { - fg.resetting = false; - }); - } - - function fgBindFrameEvents() { - for (var i = 0; i < fg.frames.length; i++) { - fg.frames[i].addEventListener('mouseenter', function () { fgHover(this); }); - fg.frames[i].addEventListener('mouseleave', function () { fgSetStatus(''); }); - fg.frames[i].addEventListener('click', function (ev) { - if (ev.detail > 1) { - return; - } - ev.stopPropagation(); - fgZoom(ev.currentTarget); - }); - fg.frames[i].addEventListener('dblclick', function (ev) { - ev.preventDefault(); - ev.stopPropagation(); - fgResetZoom(); - }); - } - } - - function fgRender(treeData) { - var maxDepth = fgMaxDepth(treeData, 0); - var layout = fgViewportLayout(maxDepth); - fg.cfg.width = layout.width; - fg.cfg.frameHeight = layout.frameHeight; - - if (!treeData || Number(treeData.t || 0) <= 0) { - fg.svg.style.height = String(layout.canvasHeight) + 'px'; - fg.svg.setAttribute('viewBox', '0 0 ' + fg.cfg.width + ' ' + layout.canvasHeight); - fg.svg.setAttribute('preserveAspectRatio', 'xMinYMin meet'); - fg.frames = []; - fg.svg.innerHTML = ''; - fgSetStatus(''); - return; - } - - var rootTotal = fgVisibleChildrenTotal(treeData); - if (rootTotal <= 0) { - rootTotal = Number(treeData.t || 0); - } - var canvasHeight = layout.canvasHeight; - var frames = []; - fgBuildFrames(treeData, rootTotal, 0, fg.cfg.width, 0, canvasHeight, true, frames, ''); - - var parts = []; - parts.push('<text class="title" x="10" y="22">I/O Flame Graph (Live)</text>'); - for (var i = 0; i < frames.length; i++) { - var frame = frames[i]; - var textStyle = frame.w <= (fg.cfg.fontSize * 2) ? ' style="display:none"' : ''; - var title = fgEscape(frame.name + ' (' + frame.total + ', ' + frame.pct.toFixed(2) + '%)'); - parts.push('<g class="frame" data-name="' + fgEscape(frame.name) + '" data-path="' + fgEscape(frame.path) + - '" data-x="' + frame.x.toFixed(3) + '" data-w="' + frame.w.toFixed(3) + - '" data-depth="' + frame.depth + '" data-base-fill="' + frame.fill + '">'); - parts.push('<title>' + title + '</title>'); - parts.push('<rect x="' + frame.x.toFixed(3) + '" y="' + frame.y.toFixed(3) + '" width="' + frame.w.toFixed(3) + - '" height="' + frame.h.toFixed(3) + '" fill="' + frame.fill + '"></rect>'); - parts.push('<text x="' + (frame.x + 3).toFixed(3) + '" y="' + (frame.y + fg.cfg.fontSize).toFixed(3) + '"' + - textStyle + '>' + fgEscape(frame.name) + '</text>'); - parts.push('</g>'); - } - - fg.svg.setAttribute('viewBox', '0 0 ' + fg.cfg.width + ' ' + canvasHeight); - fg.svg.setAttribute('preserveAspectRatio', 'xMinYMin meet'); - fg.svg.style.height = String(canvasHeight) + 'px'; - fg.svg.innerHTML = parts.join(''); - fg.frames = Array.prototype.slice.call(fg.svg.querySelectorAll('g.frame')); - fg.rootWidth = fgDetectRootWidth(); - for (var j = 0; j < fg.frames.length; j++) { - fgSnapshotOriginalGeometry(fg.frames[j]); - } - fgBindFrameEvents(); - } - - function fgProcessUpdate(jsonStr) { - var treeData; - try { - treeData = JSON.parse(jsonStr); - } catch (err) { - fgSetStatus('parse error'); - return; - } - fg.lastTreeData = treeData; - fgRender(treeData); - fgApplyZoom(); - fgApplySearch(); - } - - function fgApplyOrder(csv, expectedIndex) { - if (fg.resetting) { - return; - } - fg.resetting = true; - fgSetStatus('changing order...'); - fetch('/order', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ fields: fgOrderFields(csv) }) - }) - .then(function (resp) { - if (!resp.ok) { - throw new Error('order change failed'); - } - return resp.json(); - }) - .then(function (payload) { - fg.orderIndex = expectedIndex; - if (payload && Array.isArray(payload.fields)) { - fgSetOrderIndexByCSV(payload.fields.join(',')); - } - fgUpdateOrderButton(); - fgClearLocalState(); - if (payload && payload.snapshot) { - fg.lastTreeData = payload.snapshot; - fgRender(payload.snapshot); - } else { - fgProcessUpdate('{"n":"","v":0,"t":0}'); - } - fgSetStatus('order: ' + fgOrderLabel(fg.orderPresets[fg.orderIndex] || csv)); - }) - .catch(function () { - fgSetStatus('order change failed'); - }) - .then(function () { - fg.resetting = false; - }); - } - - function fgToggleOrder() { - var nextIndex = (fg.orderIndex + 1) % fg.orderPresets.length; - fgApplyOrder(fg.orderPresets[nextIndex], nextIndex); - } - - function fgConnect() { - fg.eventSource = new EventSource('/events'); - fg.eventSource.onmessage = function (e) { - if (fg.paused) { - fg.pendingData = e.data; - return; - } - requestAnimationFrame(function () { - fgProcessUpdate(e.data); - }); - }; - fg.eventSource.onerror = function () { - fgSetStatus('stream error'); - }; - } - - function fgHandleResize() { - if (!fg.lastTreeData) { - return; - } - requestAnimationFrame(function () { - fgRender(fg.lastTreeData); - fgApplyZoom(); - fgApplySearch(); - }); - } - - function fgIsTextEntryTarget(target) { - if (!target) { - return false; - } - if (target.isContentEditable) { - return true; - } - var tag = (target.tagName || '').toUpperCase(); - return tag === 'INPUT' || tag === 'TEXTAREA' || tag === 'SELECT'; - } - - function fgHandleKeydown(ev) { - if (fgIsTextEntryTarget(ev.target)) { - return; - } - if (ev.key === ' ' || ev.code === 'Space') { - ev.preventDefault(); - fgTogglePause(); - return; - } - if (ev.key === '/') { - ev.preventDefault(); - fgSearch(); - return; - } - if (ev.key === 'r') { - ev.preventDefault(); - fgResetBaseline(); - return; - } - if (ev.key === 'Escape') { - ev.preventDefault(); - fgResetZoom(); - fgResetSearch(); - } - } - - fg.pauseBtn.addEventListener('click', fgTogglePause); - fg.searchBtn.addEventListener('click', fgSearch); - fg.resetSearchBtn.addEventListener('click', fgResetSearch); - fg.undoZoomBtn.addEventListener('click', fgUndoZoom); - fg.resetZoomBtn.addEventListener('click', fgResetZoom); - fg.resetBaselineBtn.addEventListener('click', fgResetBaseline); - fg.toggleOrderBtn.addEventListener('click', fgToggleOrder); - document.addEventListener('keydown', fgHandleKeydown); - window.addEventListener('resize', fgHandleResize); - - fgUpdateOrderButton(); - fgSetStatus(''); - fgConnect(); - - window.fgFrameColor = fgFrameColor; - window.fgBuildFrames = fgBuildFrames; - window.fgMaxDepth = fgMaxDepth; - window.fgRender = fgRender; - window.fgProcessUpdate = fgProcessUpdate; - window.fgZoom = fgZoom; - window.fgApplyZoom = fgApplyZoom; - window.fgUndoZoom = fgUndoZoom; - window.fgResetZoom = fgResetZoom; - window.fgSearch = fgSearch; - window.fgResetSearch = fgResetSearch; - window.fgTogglePause = fgTogglePause; - window.fgResetBaseline = fgResetBaseline; - window.fgToggleOrder = fgToggleOrder; - window.liveFlamegraphState = fg; - })(); - </script> -</body> -</html> -` diff --git a/internal/flamegraph/livehtml_browser_test.go b/internal/flamegraph/livehtml_browser_test.go deleted file mode 100644 index c7a16c7..0000000 --- a/internal/flamegraph/livehtml_browser_test.go +++ /dev/null @@ -1,314 +0,0 @@ -package flamegraph - -import ( - "encoding/json" - "fmt" - "os" - "os/exec" - "strings" - "testing" -) - -type jsFrame struct { - Name string `json:"name"` - X float64 `json:"x"` - Y float64 `json:"y"` - W float64 `json:"w"` - H float64 `json:"h"` - Depth int `json:"depth"` -} - -type liveJSResult struct { - Colors map[string]string `json:"colors"` - KnownFrames []jsFrame `json:"knownFrames"` - SVGHTML string `json:"svgHTML"` - ViewBox string `json:"viewBox"` - TallViewBox string `json:"tallViewBox"` - TallHeight string `json:"tallHeight"` - PrunedMaxEnd float64 `json:"prunedMaxEnd"` - SingleCount int `json:"singleCount"` - DeepMaxDepth int `json:"deepMaxDepth"` - WideFrameCount int `json:"wideFrameCount"` -} - -func TestLiveHTMLJSRenderingParity(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skip("node not available") - } - - out := runLiveHTMLJSHarness(t) - var got liveJSResult - if err := json.Unmarshal([]byte(out), &got); err != nil { - t.Fatalf("unmarshal node output: %v\nraw:\n%s", err, out) - } - - names := []string{"read", "write", "io_uring_enter", "nested/path"} - for _, name := range names { - want := frameColor(name) - if got.Colors[name] != want { - t.Fatalf("fgFrameColor(%q) = %q, want %q", name, got.Colors[name], want) - } - } - - if len(got.KnownFrames) != 3 { - t.Fatalf("known frame count = %d, want 3", len(got.KnownFrames)) - } - assertFrame(t, got.KnownFrames[0], "A", 0, 96, 720, 15, 1) - assertFrame(t, got.KnownFrames[1], "A1", 0, 80, 720, 15, 2) - assertFrame(t, got.KnownFrames[2], "B", 720, 96, 480, 15, 1) - - if !strings.Contains(got.SVGHTML, `<g class="frame"`) { - t.Fatalf("svg markup missing frame group") - } - if !strings.Contains(got.SVGHTML, `data-name="A"`) { - t.Fatalf("svg markup missing data-name for A") - } - if !strings.Contains(got.SVGHTML, `data-x="0.000"`) { - t.Fatalf("svg markup missing data-x") - } - if !strings.Contains(got.SVGHTML, `data-w="720.000"`) { - t.Fatalf("svg markup missing data-w") - } - if !strings.Contains(got.SVGHTML, `data-depth="1"`) { - t.Fatalf("svg markup missing data-depth") - } - if !strings.Contains(got.SVGHTML, `data-base-fill="rgb(`) { - t.Fatalf("svg markup missing data-base-fill") - } - if got.ViewBox != "0 0 1200 128" { - t.Fatalf("viewBox = %q, want %q", got.ViewBox, "0 0 1200 128") - } - if got.TallViewBox != "0 0 1600 844" { - t.Fatalf("tall viewBox = %q, want %q", got.TallViewBox, "0 0 1600 844") - } - if got.TallHeight != "844px" { - t.Fatalf("tall style height = %q, want %q", got.TallHeight, "844px") - } - if diff(got.PrunedMaxEnd, 1600) > 0.01 { - t.Fatalf("pruned max end = %f, want 1600", got.PrunedMaxEnd) - } - - if got.SingleCount != 1 { - t.Fatalf("single-frame case count = %d, want 1", got.SingleCount) - } - if got.DeepMaxDepth < 50 { - t.Fatalf("deep max depth = %d, want at least 50", got.DeepMaxDepth) - } - if got.WideFrameCount != 1000 { - t.Fatalf("wide frame count = %d, want 1000", got.WideFrameCount) - } -} - -func assertFrame(t *testing.T, got jsFrame, name string, x, y, w, h float64, depth int) { - t.Helper() - if got.Name != name { - t.Fatalf("frame name = %q, want %q", got.Name, name) - } - if got.Depth != depth { - t.Fatalf("frame %q depth = %d, want %d", got.Name, got.Depth, depth) - } - const eps = 0.001 - if diff(got.X, x) > eps || diff(got.Y, y) > eps || diff(got.W, w) > eps || diff(got.H, h) > eps { - t.Fatalf("frame %q geometry = {x:%f y:%f w:%f h:%f}, want {x:%f y:%f w:%f h:%f}", - got.Name, got.X, got.Y, got.W, got.H, x, y, w, h) - } -} - -func diff(a, b float64) float64 { - if a > b { - return a - b - } - return b - a -} - -func runLiveHTMLJSHarness(t *testing.T) string { - t.Helper() - - script := extractLiveHTMLScript(t) - harness := fmt.Sprintf(` -const vm = require("vm"); -const liveScript = %q; - -function makeElement(id) { - return { - id, - textContent: "", - innerHTML: "", - style: {}, - dataset: {}, - attrs: {}, - classList: { toggle: function(){}, add: function(){}, remove: function(){} }, - addEventListener: function(){}, - getBoundingClientRect: function() { return { height: id === "controls" ? 56 : 0 }; }, - setAttribute: function(k, v) { this.attrs[k] = String(v); }, - getAttribute: function(k) { return this.attrs[k] || ""; }, - querySelectorAll: function() { return []; }, - querySelector: function() { return null; } - }; -} - -const elements = {}; -["controls", "flamegraph", "status", "btn-pause", "btn-search", "btn-reset-search", "btn-undo-zoom", "btn-reset-zoom", "btn-reset-baseline", "btn-toggle-order"].forEach((id) => { - elements[id] = makeElement(id); -}); -elements["body"] = makeElement("body"); - -global.document = { - body: elements["body"], - getElementById: function(id) { - if (!elements[id]) elements[id] = makeElement(id); - return elements[id]; - }, - addEventListener: function(){}, -}; -global.window = global; -global.prompt = function(){ return ""; }; -global.fetch = function() { - return Promise.resolve({ - ok: true, - json: function() { return Promise.resolve({ fields: ["comm", "path", "tracepoint"], snapshot: { n: "", v: 0, t: 0 } }); }, - text: function() { return Promise.resolve("{\"n\":\"\",\"v\":0,\"t\":0}"); } - }); -}; -global.requestAnimationFrame = function(cb){ cb(); }; -global.EventSource = function() { - this.onmessage = null; - this.onerror = null; -}; -window.addEventListener = function(){}; - -vm.runInThisContext(liveScript); - -const names = ["read", "write", "io_uring_enter", "nested/path"]; -const colors = {}; -for (const n of names) { - colors[n] = fgFrameColor(n); -} - -const knownTree = { - n: "", - v: 0, - t: 10, - c: [ - { n: "A", v: 0, t: 6, c: [{ n: "A1", v: 6, t: 6 }] }, - { n: "B", v: 4, t: 4 } - ] -}; -const maxDepth = fgMaxDepth(knownTree, 0); -const canvasHeight = (liveFlamegraphState.cfg.frameHeight * (maxDepth + 1)) + 80; -const knownFramesRaw = []; -fgBuildFrames(knownTree, knownTree.t, 0, 1200, 0, canvasHeight, true, knownFramesRaw, ""); -const knownFrames = knownFramesRaw.map((f) => ({ - name: f.name, - x: Number(f.x.toFixed(3)), - y: Number(f.y.toFixed(3)), - w: Number(f.w.toFixed(3)), - h: Number(f.h.toFixed(3)), - depth: f.depth, -})); - -fgRender(knownTree); -const svgHTML = elements["flamegraph"].innerHTML; -const viewBox = elements["flamegraph"].attrs["viewBox"] || ""; - -window.innerWidth = 1600; -window.innerHeight = 900; -fgRender(knownTree); -const tallViewBox = elements["flamegraph"].attrs["viewBox"] || ""; -const tallHeight = elements["flamegraph"].style.height || ""; - -const singleTree = { n: "", v: 0, t: 1, c: [{ n: "only", v: 1, t: 1 }] }; -const singleFrames = []; -const singleCanvas = (liveFlamegraphState.cfg.frameHeight * (fgMaxDepth(singleTree, 0) + 1)) + 80; -fgBuildFrames(singleTree, singleTree.t, 0, 1200, 0, singleCanvas, true, singleFrames, ""); - -let deepTree = { n: "", v: 0, t: 1, c: [] }; -let cursor = deepTree; -for (let i = 0; i < 55; i++) { - const child = { n: "d" + i, v: i === 54 ? 1 : 0, t: 1, c: [] }; - cursor.c = [child]; - cursor = child; -} -const deepMaxDepth = fgMaxDepth(deepTree, 0); - -const wideChildren = []; -for (let i = 0; i < 1000; i++) { - wideChildren.push({ n: "w" + i, v: 1, t: 1 }); -} -const wideTree = { n: "", v: 0, t: 1000, c: wideChildren }; -const wideCanvas = (liveFlamegraphState.cfg.frameHeight * (fgMaxDepth(wideTree, 0) + 1)) + 80; -const wideFrames = []; -fgBuildFrames(wideTree, wideTree.t, 0, 1200, 0, wideCanvas, true, wideFrames, ""); - -const prunedTree = { - n: "", - v: 0, - t: 100, - c: [ - { n: "A", v: 0, t: 60 }, - { n: "B", v: 0, t: 20 } - ] -}; -fgRender(prunedTree); -const prunedHTML = elements["flamegraph"].innerHTML; -const prunedMatches = prunedHTML.match(/data-x=\"([0-9.]+)\" data-w=\"([0-9.]+)\"/g) || []; -let prunedMaxEnd = 0; -for (const m of prunedMatches) { - const parts = m.match(/data-x=\"([0-9.]+)\" data-w=\"([0-9.]+)\"/); - if (!parts) continue; - const end = Number(parts[1]) + Number(parts[2]); - if (end > prunedMaxEnd) { - prunedMaxEnd = end; - } -} - -console.log(JSON.stringify({ - colors, - knownFrames, - svgHTML, - viewBox, - tallViewBox, - tallHeight, - prunedMaxEnd, - singleCount: singleFrames.length, - deepMaxDepth, - wideFrameCount: wideFrames.length, -})); -`, script) - - tmp, err := os.CreateTemp("", "livehtml-js-*.cjs") - if err != nil { - t.Fatalf("create temp script: %v", err) - } - defer os.Remove(tmp.Name()) - - if _, err := tmp.WriteString(harness); err != nil { - _ = tmp.Close() - t.Fatalf("write temp script: %v", err) - } - if err := tmp.Close(); err != nil { - t.Fatalf("close temp script: %v", err) - } - - out, err := exec.Command("node", tmp.Name()).CombinedOutput() - if err != nil { - t.Fatalf("node harness failed: %v\n%s", err, string(out)) - } - return strings.TrimSpace(string(out)) -} - -func extractLiveHTMLScript(t *testing.T) string { - t.Helper() - const openTag = "<script>" - const closeTag = "</script>" - start := strings.Index(liveHTML, openTag) - if start < 0 { - t.Fatalf("script tag not found in liveHTML") - } - start += len(openTag) - end := strings.Index(liveHTML[start:], closeTag) - if end < 0 { - t.Fatalf("closing script tag not found in liveHTML") - } - return strings.TrimSpace(liveHTML[start : start+end]) -} diff --git a/internal/flamegraph/livehtml_interaction_test.go b/internal/flamegraph/livehtml_interaction_test.go deleted file mode 100644 index 0de1466..0000000 --- a/internal/flamegraph/livehtml_interaction_test.go +++ /dev/null @@ -1,615 +0,0 @@ -package flamegraph - -import ( - "encoding/json" - "fmt" - "os" - "os/exec" - "strings" - "testing" -) - -type zoomSearchStateResult struct { - BeforePath string `json:"beforePath"` - AfterPath string `json:"afterPath"` - DeepPathStable bool `json:"deepPathStable"` - SearchPersisted bool `json:"searchPersisted"` - ZoomedBranchStable bool `json:"zoomedBranchStable"` - NonZoomedHidden bool `json:"nonZoomedHidden"` - NewChildVisible bool `json:"newChildVisible"` - PauseUnpauseKeeps bool `json:"pauseUnpauseKeeps"` -} - -type pauseKeyboardResult struct { - PausedBySpace bool `json:"pausedBySpace"` - NoUpdateWhilePaused bool `json:"noUpdateWhilePaused"` - ZoomSearchWhilePaused bool `json:"zoomSearchWhilePaused"` - UnpauseRendersLatest bool `json:"unpauseRendersLatest"` - RapidToggleStable bool `json:"rapidToggleStable"` - SlashSearchWorks bool `json:"slashSearchWorks"` - EscapeResets bool `json:"escapeResets"` - ButtonMatchesKeyboard bool `json:"buttonMatchesKeyboard"` - TypingIgnoresShortcuts bool `json:"typingIgnoresShortcuts"` -} - -type resetBaselineResult struct { - HotkeyPrevented bool `json:"hotkeyPrevented"` - ShiftHotkeyIgnored bool `json:"shiftHotkeyIgnored"` - HotkeyResetApplied bool `json:"hotkeyResetApplied"` - ButtonResetApplied bool `json:"buttonResetApplied"` - ResetCallsValid bool `json:"resetCallsValid"` -} - -type orderToggleResult struct { - OrderButtonUpdated bool `json:"orderButtonUpdated"` - OrderCallValid bool `json:"orderCallValid"` - OrderSnapshotShown bool `json:"orderSnapshotShown"` -} - -func TestLiveHTMLJSZoomSearchStatePreservedAcrossUpdates(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skip("node not available") - } - - snippet := ` -const fg = liveFlamegraphState; - -const frameA = makeFrame("A", "A", 1, 0, 700); -const frameAChild = makeFrame("Achild", "A\u001fAchild", 2, 0, 400); -const frameB = makeFrame("B", "B", 1, 700, 500); -fg.frames = [frameA, frameAChild, frameB]; -fg.rootWidth = 1200; - -fgZoom(frameA); -const beforePath = fg.zoomRange.path; -prompt = function(){ return "A"; }; -fgSearch(); - -const frameA2 = makeFrame("A", "A", 1, 0, 800); -const frameAChild2 = makeFrame("Achild", "A\u001fAchild", 2, 0, 500); -const frameAnew2 = makeFrame("Anew", "A\u001fAnew", 2, 500, 300); -const frameB2 = makeFrame("B", "B", 1, 800, 400); -fg.frames = [frameA2, frameAChild2, frameAnew2, frameB2]; -fg.rootWidth = 1200; -fgApplyZoom(); -prompt = function(_msg, prev){ return prev || "A"; }; -fgSearch(); - -const afterPath = fg.zoomRange.path; -const searchPersisted = frameA2.querySelector("rect").getAttribute("fill") === fg.matchColor; -const nonZoomedHidden = frameB2.style.display === "none"; -const newChildVisible = frameAnew2.style.display !== "none"; -const zoomedBranchStable = frameA2.style.display !== "none" && frameAChild2.style.display !== "none"; - -const deep1 = makeFrame("A2", "A\u001fA1\u001fA2", 3, 100, 200); -fg.frames = [deep1]; -fg.rootWidth = 1200; -fgZoom(deep1); -const deepPath = fg.zoomRange.path; - -const deep2 = makeFrame("A2", "A\u001fA1\u001fA2", 3, 120, 240); -fg.frames = [deep2]; -fgApplyZoom(); -const deep3 = makeFrame("A2", "A\u001fA1\u001fA2", 3, 140, 260); -fg.frames = [deep3]; -fgApplyZoom(); -const deepPathStable = fg.zoomRange.path === deepPath && deep3.style.display !== "none"; - -fg.pendingData = "{\"n\":\"\",\"v\":0,\"t\":0}"; -fgTogglePause(); -fgTogglePause(); -const pauseUnpauseKeeps = fg.zoomRange.path === deepPath; - -console.log(JSON.stringify({ - beforePath, - afterPath, - deepPathStable, - searchPersisted, - zoomedBranchStable, - nonZoomedHidden, - newChildVisible, - pauseUnpauseKeeps -})); -` - - out := runLiveHTMLNodeSnippet(t, snippet) - var got zoomSearchStateResult - if err := json.Unmarshal([]byte(out), &got); err != nil { - t.Fatalf("decode node result: %v\nraw:\n%s", err, out) - } - - if got.BeforePath != "A" || got.AfterPath != "A" { - t.Fatalf("zoom path changed unexpectedly: before=%q after=%q", got.BeforePath, got.AfterPath) - } - if !got.SearchPersisted { - t.Fatalf("expected search highlight to persist across update") - } - if !got.ZoomedBranchStable { - t.Fatalf("expected zoomed branch to remain visible across update") - } - if !got.NonZoomedHidden { - t.Fatalf("expected non-zoomed branch to be hidden while zoomed") - } - if !got.NewChildVisible { - t.Fatalf("expected newly added child in zoomed branch to remain visible") - } - if !got.DeepPathStable { - t.Fatalf("expected deep zoom path to remain stable across multiple updates") - } - if !got.PauseUnpauseKeeps { - t.Fatalf("expected pause/unpause to preserve zoom state") - } -} - -func TestLiveHTMLJSPauseResumeAndKeyboard(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skip("node not available") - } - - snippet := ` -const fg = liveFlamegraphState; -const keydown = __docListeners["keydown"]; - -function keyEvent(key, code, target) { - let prevented = false; - keydown({ - key: key, - code: code, - target: target || { tagName: "BODY", isContentEditable: false }, - preventDefault: function(){ prevented = true; } - }); - return prevented; -} - -let promptCalls = 0; -prompt = function(_msg, prev) { - promptCalls++; - return prev || "needle"; -}; - -const pausePayload = "{\"n\":\"\",\"v\":0,\"t\":10,\"c\":[{\"n\":\"latest\",\"v\":10,\"t\":10}]}"; -const beforeHTML = fg.svg.innerHTML; -const pausedBySpacePrevented = keyEvent(" ", "Space"); -const pausedBySpace = pausedBySpacePrevented && fg.paused && fg.pauseBtn.textContent === "Resume" && fg.status.textContent.indexOf("PAUSED") === 0; - -fg.eventSource.onmessage({ data: pausePayload }); -const noUpdateWhilePaused = fg.pendingData === pausePayload && fg.svg.innerHTML === beforeHTML; - -const pausedFrame = makeFrame("needle", "needle", 1, 0, 1200); -fg.frames = [pausedFrame]; -fg.rootWidth = 1200; -fgZoom(pausedFrame); -prompt = function(_msg, prev) { - promptCalls++; - return prev || "needle"; -}; -fgSearch(); -const zoomSearchWhilePaused = fg.zoomRange && fg.zoomRange.path === "needle" && - pausedFrame.querySelector("rect").getAttribute("fill") === fg.matchColor; - -const resumedBySpacePrevented = keyEvent(" ", "Space"); -const unpauseRendersLatest = resumedBySpacePrevented && !fg.paused && fg.pendingData === null && - fg.pauseBtn.textContent === "Pause" && fg.svg.innerHTML.indexOf('data-name="latest"') >= 0; - -let rapidToggleStable = true; -for (let i = 0; i < 20; i++) { - try { - fgTogglePause(); - } catch (err) { - rapidToggleStable = false; - } -} -if (fg.paused) { - fgTogglePause(); -} -rapidToggleStable = rapidToggleStable && !fg.paused && fg.pauseBtn.textContent === "Pause"; - -promptCalls = 0; -prompt = function() { - promptCalls++; - return "slash"; -}; -const slashPrevented = keyEvent("/", "Slash"); -const slashSearchWorks = slashPrevented && promptCalls === 1 && fg.searchQuery === "slash"; - -const escFrame = makeFrame("slash", "slash", 1, 0, 1200); -fg.frames = [escFrame]; -fg.rootWidth = 1200; -fgZoom(escFrame); -fgSearch(); -const escapePrevented = keyEvent("Escape", "Escape"); -const escapeResets = escapePrevented && fg.zoomRange === null && - escFrame.querySelector("rect").getAttribute("fill") === escFrame.dataset.baseFill; - -let buttonPromptCalls = 0; -prompt = function() { - buttonPromptCalls++; - return "button"; -}; -document.getElementById("btn-pause").listeners.click(); -const pauseViaButton = fg.paused && fg.pauseBtn.textContent === "Resume"; -document.getElementById("btn-pause").listeners.click(); -const resumeViaButton = !fg.paused && fg.pauseBtn.textContent === "Pause"; -document.getElementById("btn-search").listeners.click(); -const searchViaButton = buttonPromptCalls === 1 && fg.searchQuery === "button"; - -const btnFrame = makeFrame("button", "button", 1, 0, 1200); -fg.frames = [btnFrame]; -fg.rootWidth = 1200; -fgZoom(btnFrame); -document.getElementById("btn-reset-search").listeners.click(); -document.getElementById("btn-reset-zoom").listeners.click(); -const resetViaButton = fg.zoomRange === null && - btnFrame.querySelector("rect").getAttribute("fill") === btnFrame.dataset.baseFill; -const buttonMatchesKeyboard = pauseViaButton && resumeViaButton && searchViaButton && resetViaButton; - -const typingTarget = { tagName: "INPUT", isContentEditable: false }; -fg.searchQuery = "typed"; -fg.zoomRange = { path: "typed", x: 0, w: 1200, depth: 1 }; -promptCalls = 0; -const typingSpacePrevented = keyEvent(" ", "Space", typingTarget); -const typingSlashPrevented = keyEvent("/", "Slash", typingTarget); -const typingEscapePrevented = keyEvent("Escape", "Escape", typingTarget); -const typingIgnoresShortcuts = !typingSpacePrevented && !typingSlashPrevented && !typingEscapePrevented && - !fg.paused && promptCalls === 0 && fg.zoomRange !== null && fg.searchQuery === "typed"; - -console.log(JSON.stringify({ - pausedBySpace, - noUpdateWhilePaused, - zoomSearchWhilePaused, - unpauseRendersLatest, - rapidToggleStable, - slashSearchWorks, - escapeResets, - buttonMatchesKeyboard, - typingIgnoresShortcuts -})); -` - - out := runLiveHTMLNodeSnippet(t, snippet) - var got pauseKeyboardResult - if err := json.Unmarshal([]byte(out), &got); err != nil { - t.Fatalf("decode node result: %v\nraw:\n%s", err, out) - } - - if !got.PausedBySpace { - t.Fatalf("expected Space shortcut to pause and update status/button state") - } - if !got.NoUpdateWhilePaused { - t.Fatalf("expected stream updates to queue while paused without rerendering") - } - if !got.ZoomSearchWhilePaused { - t.Fatalf("expected zoom and search to work while paused") - } - if !got.UnpauseRendersLatest { - t.Fatalf("expected unpause to render latest queued update immediately") - } - if !got.RapidToggleStable { - t.Fatalf("expected rapid pause/unpause toggles to remain stable") - } - if !got.SlashSearchWorks { - t.Fatalf("expected '/' shortcut to open search flow") - } - if !got.EscapeResets { - t.Fatalf("expected Escape shortcut to reset zoom/search highlighting") - } - if !got.ButtonMatchesKeyboard { - t.Fatalf("expected button actions to match keyboard behavior") - } - if !got.TypingIgnoresShortcuts { - t.Fatalf("expected keyboard shortcuts to be ignored while typing in an input") - } -} - -func TestLiveHTMLJSResetBaselineHotkeyAndButton(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skip("node not available") - } - - snippet := ` -const fg = liveFlamegraphState; -const keydown = __docListeners["keydown"]; - -function keyEvent(key, code, target) { - let prevented = false; - keydown({ - key: key, - code: code, - target: target || { tagName: "BODY", isContentEditable: false }, - preventDefault: function(){ prevented = true; } - }); - return prevented; -} - -const frame = makeFrame("needle", "needle", 1, 0, 1200); -fg.frames = [frame]; -fg.rootWidth = 1200; -fgZoom(frame); -prompt = function(){ return "needle"; }; -fgSearch(); - -const resetPayload = "{\"n\":\"\",\"v\":0,\"t\":0}"; -const resetCalls = []; -fetch = function(url, opts) { - resetCalls.push({ - url: url, - method: (opts && opts.method) || "GET" - }); - return Promise.resolve({ - ok: true, - text: function() { return Promise.resolve(resetPayload); } - }); -}; - -const hotkeyPrevented = keyEvent("r", "KeyR"); -const shiftHotkeyPrevented = keyEvent("R", "KeyR"); -const shiftHotkeyIgnored = !shiftHotkeyPrevented && resetCalls.length === 1; - -setTimeout(function() { - const hotkeyResetApplied = fg.zoomRange === null && fg.searchQuery === "" && fg.frames.length === 0; - - const frame2 = makeFrame("again", "again", 1, 0, 1200); - fg.frames = [frame2]; - fg.rootWidth = 1200; - fgZoom(frame2); - fg.searchQuery = "again"; - document.getElementById("btn-reset-baseline").listeners.click(); - - setTimeout(function() { - const buttonResetApplied = fg.zoomRange === null && fg.searchQuery === "" && fg.frames.length === 0; - const resetCallsValid = resetCalls.length === 2 && - resetCalls[0].url === "/reset" && resetCalls[0].method === "POST" && - resetCalls[1].url === "/reset" && resetCalls[1].method === "POST"; - - console.log(JSON.stringify({ - hotkeyPrevented, - shiftHotkeyIgnored, - hotkeyResetApplied, - buttonResetApplied, - resetCallsValid - })); - }, 0); -}, 0); -` - - out := runLiveHTMLNodeSnippet(t, snippet) - var got resetBaselineResult - if err := json.Unmarshal([]byte(out), &got); err != nil { - t.Fatalf("decode node result: %v\nraw:\n%s", err, out) - } - - if !got.HotkeyPrevented { - t.Fatalf("expected reset hotkey to prevent default browser handling") - } - if !got.ShiftHotkeyIgnored { - t.Fatalf("expected uppercase 'R' to be ignored for baseline reset") - } - if !got.HotkeyResetApplied { - t.Fatalf("expected 'r' hotkey to reset baseline and clear UI state") - } - if !got.ButtonResetApplied { - t.Fatalf("expected Reset Baseline button to clear UI state") - } - if !got.ResetCallsValid { - t.Fatalf("expected reset interactions to POST /reset") - } -} - -func TestLiveHTMLJSOrderToggle(t *testing.T) { - if _, err := exec.LookPath("node"); err != nil { - t.Skip("node not available") - } - - snippet := ` -const fg = liveFlamegraphState; -const orderCalls = []; -fetch = function(url, opts) { - orderCalls.push({ - url: url, - method: (opts && opts.method) || "GET", - body: (opts && opts.body) || "" - }); - return Promise.resolve({ - ok: true, - json: function() { - return Promise.resolve({ - fields: ["path", "tracepoint", "comm"], - snapshot: { - n: "", - v: 0, - t: 1, - c: [{ n: "/tmp", v: 1, t: 1 }] - } - }); - } - }); -}; - -document.getElementById("btn-toggle-order").listeners.click(); - -setTimeout(function() { - const orderButtonUpdated = document.getElementById("btn-toggle-order").textContent.indexOf("path > tracepoint > comm") >= 0; - const orderSnapshotShown = fg.svg.innerHTML.indexOf('data-name="/tmp"') >= 0; - const req = orderCalls[0] || {}; - let bodyFields = []; - try { - bodyFields = JSON.parse(req.body || "{}").fields || []; - } catch (err) { - bodyFields = []; - } - const orderCallValid = orderCalls.length === 1 && - req.url === "/order" && - req.method === "POST" && - JSON.stringify(bodyFields) === JSON.stringify(["path", "tracepoint", "comm"]); - - console.log(JSON.stringify({ - orderButtonUpdated, - orderCallValid, - orderSnapshotShown - })); -}, 0); -` - - out := runLiveHTMLNodeSnippet(t, snippet) - var got orderToggleResult - if err := json.Unmarshal([]byte(out), &got); err != nil { - t.Fatalf("decode node result: %v\nraw:\n%s", err, out) - } - - if !got.OrderButtonUpdated { - t.Fatalf("expected toggle button label to update to next order") - } - if !got.OrderCallValid { - t.Fatalf("expected toggle to POST /order with next preset fields") - } - if !got.OrderSnapshotShown { - t.Fatalf("expected returned order snapshot to render immediately") - } -} - -func runLiveHTMLNodeSnippet(t *testing.T, snippet string) string { - t.Helper() - - script := extractLiveHTMLScript(t) - harness := fmt.Sprintf(` -const vm = require("vm"); -const liveScript = %q; - -function makeElement(id) { - return { - id, - textContent: "", - innerHTML: "", - style: {}, - dataset: {}, - attrs: {}, - classList: { toggle: function(){}, add: function(){}, remove: function(){} }, - listeners: {}, - addEventListener: function(event, cb) { this.listeners[event] = cb; }, - getBoundingClientRect: function() { return { height: id === "controls" ? 56 : 0 }; }, - setAttribute: function(k, v) { this.attrs[k] = String(v); }, - getAttribute: function(k) { return this.attrs[k] || ""; }, - querySelectorAll: function() { return []; }, - querySelector: function() { return null; } - }; -} - -function makeRect(fill) { - return { - attrs: { fill: fill || "" }, - dataset: {}, - style: {}, - setAttribute: function(k, v) { this.attrs[k] = String(v); }, - getAttribute: function(k) { return this.attrs[k] || ""; } - }; -} - -function makeText(name) { - return { - textContent: name || "", - dataset: { full: name || "", hidden: "0", ox: "0" }, - style: {}, - setAttribute: function(k, v) { - this[k] = String(v); - }, - getAttribute: function(k) { - return this[k] || ""; - } - }; -} - -function makeFrame(name, path, depth, x, w) { - const rect = makeRect("rgb(1,2,3)"); - rect.dataset.ox = String(x); - rect.dataset.ow = String(w); - rect.setAttribute("x", String(x)); - rect.setAttribute("width", String(w)); - - const text = makeText(name); - text.dataset.ox = String(x + 3); - text.setAttribute("x", String(x + 3)); - - const title = { textContent: name + " title" }; - return { - dataset: { - name: name, - path: path, - depth: String(depth), - x: String(x), - w: String(w), - ox: String(x), - ow: String(w), - baseFill: "rgb(1,2,3)" - }, - style: {}, - listeners: {}, - addEventListener: function(event, cb) { this.listeners[event] = cb; }, - querySelector: function(selector) { - if (selector === "rect") return rect; - if (selector === "text") return text; - if (selector === "title") return title; - return null; - }, - querySelectorAll: function() { return []; }, - }; -} - -const elements = {}; -["controls", "flamegraph", "status", "btn-pause", "btn-search", "btn-reset-search", "btn-undo-zoom", "btn-reset-zoom", "btn-reset-baseline", "btn-toggle-order"].forEach((id) => { - elements[id] = makeElement(id); -}); -elements["body"] = makeElement("body"); - -const docListeners = {}; -global.document = { - body: elements["body"], - getElementById: function(id) { - if (!elements[id]) elements[id] = makeElement(id); - return elements[id]; - }, - addEventListener: function(event, cb) { docListeners[event] = cb; }, -}; -global.window = global; -global.prompt = function(){ return ""; }; -global.fetch = function() { - return Promise.resolve({ - ok: true, - json: function() { return Promise.resolve({ fields: ["comm", "path", "tracepoint"], snapshot: { n: "", v: 0, t: 0 } }); }, - text: function() { return Promise.resolve("{\"n\":\"\",\"v\":0,\"t\":0}"); } - }); -}; -global.requestAnimationFrame = function(cb){ cb(); }; -global.EventSource = function() { - this.onmessage = null; - this.onerror = null; -}; -window.addEventListener = function(){}; - -vm.runInThisContext(liveScript); - -global.makeFrame = makeFrame; -global.__docListeners = docListeners; - -%s -`, script, snippet) - - tmp, err := os.CreateTemp("", "livehtml-node-snippet-*.cjs") - if err != nil { - t.Fatalf("create temp script: %v", err) - } - defer os.Remove(tmp.Name()) - - if _, err := tmp.WriteString(harness); err != nil { - _ = tmp.Close() - t.Fatalf("write temp script: %v", err) - } - if err := tmp.Close(); err != nil { - t.Fatalf("close temp script: %v", err) - } - - out, err := exec.Command("node", tmp.Name()).CombinedOutput() - if err != nil { - t.Fatalf("node snippet failed: %v\n%s", err, string(out)) - } - return strings.TrimSpace(string(out)) -} diff --git a/internal/flamegraph/liveserver.go b/internal/flamegraph/liveserver.go deleted file mode 100644 index 8ae2b82..0000000 --- a/internal/flamegraph/liveserver.go +++ /dev/null @@ -1,314 +0,0 @@ -package flamegraph - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "os" - "os/exec" - "os/user" - "path/filepath" - "strconv" - "strings" - "syscall" - "time" -) - -var liveServerTimeouts = serverTimeouts{ - readTimeout: 10 * time.Second, - writeTimeout: 5 * time.Minute, - idleTimeout: 60 * time.Second, -} - -type LiveServerOptions struct { - OpenCommand string - WarningCb func(message string) -} - -var openBrowserURLFn = openBrowserURL - -// ServeLive starts the live flamegraph HTTP server and blocks until ctx is canceled. -func ServeLive(ctx context.Context, lt *LiveTrie, interval time.Duration) error { - return ServeLiveWithOptions(ctx, lt, interval, LiveServerOptions{}) -} - -// ServeLiveWithOptions starts the live flamegraph server with runtime options. -func ServeLiveWithOptions(ctx context.Context, lt *LiveTrie, interval time.Duration, options LiveServerOptions) error { - mux := http.NewServeMux() - mux.HandleFunc("/", handleLivePage()) - mux.HandleFunc("/events", handleSSE(lt, interval)) - mux.HandleFunc("/reset", handleReset(lt)) - mux.HandleFunc("/order", handleOrder(lt)) - return runServer(ctx, mux, liveServerTimeouts, func(hostname string, port int) { - url := fmt.Sprintf("http://%s:%d/", hostname, port) - fmt.Printf("Live flamegraph available at %s\n", url) - if err := maybeOpenLiveBrowser(url, options); err != nil { - notifyLiveWarning(options.WarningCb, fmt.Sprintf("Live flamegraph browser auto-open failed: %v", err)) - } - }) -} - -func maybeOpenLiveBrowser(url string, options LiveServerOptions) error { - if strings.TrimSpace(options.OpenCommand) == "" { - return nil - } - return openBrowserURLFn(url, options.OpenCommand) -} - -func openBrowserURL(url, openCommand string) error { - parts, err := browserOpenCommandParts(openCommand, url) - if err != nil { - return err - } - cmd := exec.Command(parts[0], parts[1:]...) - applySudoInvokerContext(cmd) - if err := cmd.Start(); err != nil { - return err - } - - waitCh := make(chan error, 1) - go func() { waitCh <- cmd.Wait() }() - - timer := time.NewTimer(750 * time.Millisecond) - defer stopAndDrainTimer(timer) - - select { - case waitErr := <-waitCh: - if waitErr != nil { - return fmt.Errorf("browser command exited early: %w", waitErr) - } - case <-timer.C: - } - return nil -} - -func stopAndDrainTimer(timer *time.Timer) { - if timer == nil { - return - } - if timer.Stop() { - return - } - select { - case <-timer.C: - default: - } -} - -func notifyLiveWarning(warningCb func(string), message string) { - if message == "" { - return - } - if warningCb != nil { - warningCb(message) - return - } - fmt.Println(message) -} - -func applySudoInvokerContext(cmd *exec.Cmd) { - applySudoInvokerContextWithEnv(cmd, os.Geteuid(), os.Environ()) -} - -func applySudoInvokerContextWithEnv(cmd *exec.Cmd, euid int, env []string) { - if cmd == nil || euid != 0 { - return - } - - sudoUIDStr, okUID := lookupEnvValue(env, "SUDO_UID") - sudoGIDStr, okGID := lookupEnvValue(env, "SUDO_GID") - if !okUID || !okGID { - return - } - - uid, errUID := strconv.ParseUint(strings.TrimSpace(sudoUIDStr), 10, 32) - gid, errGID := strconv.ParseUint(strings.TrimSpace(sudoGIDStr), 10, 32) - if errUID != nil || errGID != nil { - return - } - - cmd.SysProcAttr = &syscall.SysProcAttr{ - Credential: &syscall.Credential{ - Uid: uint32(uid), - Gid: uint32(gid), - }, - } - - launchEnv := append([]string(nil), env...) - if sudoUser, ok := lookupEnvValue(env, "SUDO_USER"); ok && strings.TrimSpace(sudoUser) != "" { - launchEnv = upsertEnvValue(launchEnv, "USER", sudoUser) - launchEnv = upsertEnvValue(launchEnv, "LOGNAME", sudoUser) - } - - if sudoUser, err := user.LookupId(strconv.FormatUint(uid, 10)); err == nil && strings.TrimSpace(sudoUser.HomeDir) != "" { - launchEnv = upsertEnvValue(launchEnv, "HOME", sudoUser.HomeDir) - if _, ok := lookupEnvValue(launchEnv, "XAUTHORITY"); !ok { - xauth := filepath.Join(sudoUser.HomeDir, ".Xauthority") - if info, statErr := os.Stat(xauth); statErr == nil && !info.IsDir() { - launchEnv = upsertEnvValue(launchEnv, "XAUTHORITY", xauth) - } - } - } - - if _, ok := lookupEnvValue(launchEnv, "XDG_RUNTIME_DIR"); !ok { - runtimeDir := fmt.Sprintf("/run/user/%d", uid) - if info, statErr := os.Stat(runtimeDir); statErr == nil && info.IsDir() { - launchEnv = upsertEnvValue(launchEnv, "XDG_RUNTIME_DIR", runtimeDir) - } - } - - cmd.Env = launchEnv -} - -func lookupEnvValue(env []string, key string) (string, bool) { - prefix := key + "=" - for _, entry := range env { - if strings.HasPrefix(entry, prefix) { - return strings.TrimPrefix(entry, prefix), true - } - } - return "", false -} - -func upsertEnvValue(env []string, key, value string) []string { - prefix := key + "=" - for i := range env { - if strings.HasPrefix(env[i], prefix) { - env[i] = prefix + value - return env - } - } - return append(env, prefix+value) -} - -func browserOpenCommandParts(openCommand, url string) ([]string, error) { - parts := strings.Fields(strings.TrimSpace(openCommand)) - if len(parts) == 0 { - return nil, errors.New("empty browser open command") - } - - containsURL := false - for i := range parts { - if strings.Contains(parts[i], "{url}") { - parts[i] = strings.ReplaceAll(parts[i], "{url}", url) - containsURL = true - } - } - if !containsURL { - parts = append(parts, url) - } - return parts, nil -} - -func handleLivePage() http.HandlerFunc { - return func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, _ = w.Write([]byte(liveHTML)) - } -} - -func handleSSE(lt *LiveTrie, interval time.Duration) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "streaming unsupported", http.StatusInternalServerError) - return - } - if interval <= 0 { - interval = 200 * time.Millisecond - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - - lastVersion, err := sendSnapshot(w, flusher, lt, ^uint64(0)) - if err != nil { - return - } - - ticker := time.NewTicker(interval) - defer ticker.Stop() - - for { - select { - case <-r.Context().Done(): - return - case <-ticker.C: - if lt.Version() == lastVersion { - continue - } - lastVersion, err = sendSnapshot(w, flusher, lt, lastVersion) - if err != nil { - return - } - } - } - } -} - -func handleReset(lt *LiveTrie) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - w.Header().Set("Allow", http.MethodPost) - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - lt.Reset() - payload, _ := lt.SnapshotJSON() - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(payload) - } -} - -type orderRequest struct { - Fields []string `json:"fields"` -} - -type orderResponse struct { - Fields []string `json:"fields"` - Snapshot json.RawMessage `json:"snapshot,omitempty"` -} - -func handleOrder(lt *LiveTrie) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(orderResponse{Fields: lt.Fields()}) - case http.MethodPost: - var req orderRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid json body", http.StatusBadRequest) - return - } - if err := lt.Reconfigure(req.Fields); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - snap, _ := lt.SnapshotJSON() - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(orderResponse{ - Fields: lt.Fields(), - Snapshot: snap, - }) - default: - w.Header().Set("Allow", http.MethodGet+", "+http.MethodPost) - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } - } -} - -func sendSnapshot(w http.ResponseWriter, flusher http.Flusher, lt *LiveTrie, lastVersion uint64) (uint64, error) { - payload, version := lt.SnapshotJSON() - if version == lastVersion { - return lastVersion, nil - } - if _, err := fmt.Fprintf(w, "data: %s\n\n", payload); err != nil { - return lastVersion, err - } - flusher.Flush() - return version, nil -} diff --git a/internal/flamegraph/liveserver_open_test.go b/internal/flamegraph/liveserver_open_test.go deleted file mode 100644 index aa9340a..0000000 --- a/internal/flamegraph/liveserver_open_test.go +++ /dev/null @@ -1,179 +0,0 @@ -package flamegraph - -import ( - "errors" - "os" - "os/exec" - "strconv" - "testing" -) - -func TestBrowserOpenCommandPartsRequiresCommand(t *testing.T) { - _, err := browserOpenCommandParts("", "http://localhost:1234/") - if err == nil { - t.Fatalf("expected error for empty open command") - } -} - -func TestBrowserOpenCommandPartsAppendsURLWhenMissing(t *testing.T) { - got, err := browserOpenCommandParts("chromium --new-window", "http://localhost:1234/") - if err != nil { - t.Fatalf("browserOpenCommandParts returned error: %v", err) - } - want := []string{"chromium", "--new-window", "http://localhost:1234/"} - if len(got) != len(want) { - t.Fatalf("len(parts) = %d, want %d", len(got), len(want)) - } - for i := range want { - if got[i] != want[i] { - t.Fatalf("parts[%d] = %q, want %q", i, got[i], want[i]) - } - } -} - -func TestBrowserOpenCommandPartsReplacesURLPlaceholder(t *testing.T) { - got, err := browserOpenCommandParts("open-browser --target={url}", "http://localhost:1234/") - if err != nil { - t.Fatalf("browserOpenCommandParts returned error: %v", err) - } - want := []string{"open-browser", "--target=http://localhost:1234/"} - if len(got) != len(want) { - t.Fatalf("len(parts) = %d, want %d", len(got), len(want)) - } - for i := range want { - if got[i] != want[i] { - t.Fatalf("parts[%d] = %q, want %q", i, got[i], want[i]) - } - } -} - -func TestMaybeOpenLiveBrowserWithoutCommandSkipsOpen(t *testing.T) { - called := false - orig := openBrowserURLFn - openBrowserURLFn = func(url, openCommand string) error { - called = true - return nil - } - t.Cleanup(func() { openBrowserURLFn = orig }) - - err := maybeOpenLiveBrowser("http://localhost:1234/", LiveServerOptions{}) - if err != nil { - t.Fatalf("maybeOpenLiveBrowser returned error: %v", err) - } - if called { - t.Fatalf("expected browser opener not to be called without open command") - } -} - -func TestMaybeOpenLiveBrowserWithCommandCallsOpen(t *testing.T) { - called := false - orig := openBrowserURLFn - openBrowserURLFn = func(url, openCommand string) error { - called = true - if url != "http://localhost:1234/" { - t.Fatalf("url = %q, want %q", url, "http://localhost:1234/") - } - if openCommand != "chromium" { - t.Fatalf("openCommand = %q, want %q", openCommand, "chromium") - } - return nil - } - t.Cleanup(func() { openBrowserURLFn = orig }) - - err := maybeOpenLiveBrowser("http://localhost:1234/", LiveServerOptions{ - OpenCommand: "chromium", - }) - if err != nil { - t.Fatalf("maybeOpenLiveBrowser returned error: %v", err) - } - if !called { - t.Fatalf("expected browser opener to be called") - } -} - -func TestMaybeOpenLiveBrowserPropagatesOpenError(t *testing.T) { - orig := openBrowserURLFn - openBrowserURLFn = func(url, openCommand string) error { - return errors.New("launch failed") - } - t.Cleanup(func() { openBrowserURLFn = orig }) - - err := maybeOpenLiveBrowser("http://localhost:1234/", LiveServerOptions{ - OpenCommand: "chromium", - }) - if err == nil || err.Error() != "launch failed" { - t.Fatalf("expected launch failed error, got %v", err) - } -} - -func TestOpenBrowserURLReturnsErrorWhenCommandExitsNonZero(t *testing.T) { - err := openBrowserURL("http://localhost:1234/", "false") - if err == nil { - t.Fatalf("expected non-nil error") - } -} - -func TestOpenBrowserURLReturnsNilWhenCommandExitsZero(t *testing.T) { - err := openBrowserURL("http://localhost:1234/", "true") - if err != nil { - t.Fatalf("expected nil error, got %v", err) - } -} - -func TestApplySudoInvokerContextWithEnvSetsCredential(t *testing.T) { - cmd := exec.Command("echo") - uid := os.Getuid() - gid := os.Getgid() - env := []string{ - "SUDO_UID=" + strconv.Itoa(uid), - "SUDO_GID=" + strconv.Itoa(gid), - "SUDO_USER=tester", - "HOME=/root", - } - - applySudoInvokerContextWithEnv(cmd, 0, env) - - if cmd.SysProcAttr == nil || cmd.SysProcAttr.Credential == nil { - t.Fatalf("expected process credentials to be configured") - } - if got := cmd.SysProcAttr.Credential.Uid; got != uint32(uid) { - t.Fatalf("credential uid = %d, want %d", got, uint32(uid)) - } - if got := cmd.SysProcAttr.Credential.Gid; got != uint32(gid) { - t.Fatalf("credential gid = %d, want %d", got, uint32(gid)) - } - if got, ok := lookupEnvValue(cmd.Env, "USER"); !ok || got != "tester" { - t.Fatalf("USER env = %q (ok=%v), want %q", got, ok, "tester") - } - if got, ok := lookupEnvValue(cmd.Env, "LOGNAME"); !ok || got != "tester" { - t.Fatalf("LOGNAME env = %q (ok=%v), want %q", got, ok, "tester") - } -} - -func TestApplySudoInvokerContextWithEnvSkipsWhenNotRoot(t *testing.T) { - cmd := exec.Command("echo") - env := []string{ - "SUDO_UID=1000", - "SUDO_GID=1000", - "SUDO_USER=tester", - } - - applySudoInvokerContextWithEnv(cmd, 1000, env) - - if cmd.SysProcAttr != nil { - t.Fatalf("expected credentials to remain nil for non-root euid") - } - if cmd.Env != nil { - t.Fatalf("expected environment to remain nil for non-root euid") - } -} - -func TestNotifyLiveWarningUsesCallback(t *testing.T) { - var got string - notifyLiveWarning(func(message string) { - got = message - }, "open failed") - if got != "open failed" { - t.Fatalf("warning callback got %q, want %q", got, "open failed") - } -} diff --git a/internal/flamegraph/liveserver_test.go b/internal/flamegraph/liveserver_test.go deleted file mode 100644 index 59a3782..0000000 --- a/internal/flamegraph/liveserver_test.go +++ /dev/null @@ -1,380 +0,0 @@ -package flamegraph - -import ( - "bufio" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "os" - "strings" - "sync" - "testing" - "time" -) - -func TestHandleSSEContentTypeFormatAndEmptyTrie(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - srv := httptest.NewServer(handleSSE(lt, 5*time.Millisecond)) - defer srv.Close() - - resp := connectSSE(t, srv.URL) - defer resp.Body.Close() - - contentType := resp.Header.Get("Content-Type") - if !strings.HasPrefix(contentType, "text/event-stream") { - t.Fatalf("Content-Type = %q, want text/event-stream", contentType) - } - - data := readFirstSSEData(t, resp.Body) - snap := decodeSSESnapshot(t, data) - if snap.Total != 0 { - t.Fatalf("empty trie snapshot total = %d, want 0", snap.Total) - } -} - -func TestHandleSSEMultipleClientsReceiveInitialSnapshot(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - lt.Ingest(newTestPair("multi", 42, 1001, "/tmp/multi", 1, 1, 1)) - srv := httptest.NewServer(handleSSE(lt, 5*time.Millisecond)) - defer srv.Close() - - const clients = 4 - var wg sync.WaitGroup - errCh := make(chan error, clients) - - wg.Add(clients) - for i := 0; i < clients; i++ { - go func() { - defer wg.Done() - resp := connectSSE(t, srv.URL) - defer resp.Body.Close() - data := readFirstSSEData(t, resp.Body) - snap := decodeSSESnapshot(t, data) - if snap.Total == 0 { - errCh <- fmt.Errorf("received empty snapshot") - } - }() - } - - wg.Wait() - close(errCh) - for err := range errCh { - t.Fatal(err) - } -} - -func TestHandleSSEReconnectAfterDisconnectGetsLatestSnapshot(t *testing.T) { - lt := NewLiveTrie([]string{"path"}, "count") - lt.Ingest(newTestPair("reconnect", 1, 1001, "/tmp/a", 1, 1, 1)) - srv := httptest.NewServer(handleSSE(lt, 5*time.Millisecond)) - defer srv.Close() - - resp1 := connectSSE(t, srv.URL) - first := decodeSSESnapshot(t, readFirstSSEData(t, resp1.Body)) - _ = resp1.Body.Close() - if first.Total != 1 { - t.Fatalf("first snapshot total = %d, want 1", first.Total) - } - - lt.Ingest(newTestPair("reconnect", 1, 1002, "/tmp/b", 1, 1, 1)) - - resp2 := connectSSE(t, srv.URL) - defer resp2.Body.Close() - second := decodeSSESnapshot(t, readFirstSSEData(t, resp2.Body)) - if second.Total != 2 { - t.Fatalf("reconnected snapshot total = %d, want 2", second.Total) - } -} - -func TestHandleSSERestartedServerAcceptsNewConnection(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - lt.Ingest(newTestPair("restart", 1, 1001, "/tmp/a", 1, 1, 1)) - - srv1 := httptest.NewServer(handleSSE(lt, 5*time.Millisecond)) - resp1 := connectSSE(t, srv1.URL) - first := decodeSSESnapshot(t, readFirstSSEData(t, resp1.Body)) - _ = resp1.Body.Close() - srv1.Close() - if first.Total != 1 { - t.Fatalf("first server snapshot total = %d, want 1", first.Total) - } - - lt.Ingest(newTestPair("restart", 1, 1002, "/tmp/b", 1, 1, 1)) - - srv2 := httptest.NewServer(handleSSE(lt, 5*time.Millisecond)) - defer srv2.Close() - resp2 := connectSSE(t, srv2.URL) - defer resp2.Body.Close() - second := decodeSSESnapshot(t, readFirstSSEData(t, resp2.Body)) - if second.Total != 2 { - t.Fatalf("second server snapshot total = %d, want 2", second.Total) - } -} - -func TestHandleSSEDelayedClientLargeTrieGetsValidSnapshot(t *testing.T) { - lt := NewLiveTrie([]string{"path"}, "count") - const events = 12000 - for i := 0; i < events; i++ { - lt.Ingest(newTestPair("late", 7, uint32(10000+i), fmt.Sprintf("/late/%05d", i), 1, 1, 1)) - } - - srv := httptest.NewServer(handleSSE(lt, 5*time.Millisecond)) - defer srv.Close() - - resp := connectSSE(t, srv.URL) - defer resp.Body.Close() - snap := decodeSSESnapshot(t, readFirstSSEData(t, resp.Body)) - if snap.Total != events { - t.Fatalf("late client snapshot total = %d, want %d", snap.Total, events) - } -} - -func TestHandleResetRequiresPost(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - req := httptest.NewRequest(http.MethodGet, "/reset", nil) - rec := httptest.NewRecorder() - - handleReset(lt).ServeHTTP(rec, req) - - if rec.Code != http.StatusMethodNotAllowed { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed) - } - if allow := rec.Header().Get("Allow"); allow != http.MethodPost { - t.Fatalf("allow = %q, want %q", allow, http.MethodPost) - } -} - -func TestHandleResetClearsTrieAndReturnsEmptySnapshot(t *testing.T) { - lt := NewLiveTrie([]string{"path"}, "count") - lt.Ingest(newTestPair("reset", 1, 1001, "/tmp/a", 1, 1, 1)) - lt.Ingest(newTestPair("reset", 1, 1002, "/tmp/b", 1, 1, 1)) - if before := decodeLiveSnapshot(t, lt); before.Total == 0 { - t.Fatalf("expected non-empty trie before reset") - } - - req := httptest.NewRequest(http.MethodPost, "/reset", nil) - rec := httptest.NewRecorder() - handleReset(lt).ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) - } - if ctype := rec.Header().Get("Content-Type"); !strings.Contains(ctype, "application/json") { - t.Fatalf("content-type = %q, want application/json", ctype) - } - var snap trieSnapshot - if err := json.Unmarshal(rec.Body.Bytes(), &snap); err != nil { - t.Fatalf("decode reset snapshot: %v", err) - } - if snap.Total != 0 { - t.Fatalf("reset snapshot total = %d, want 0", snap.Total) - } - - after := decodeLiveSnapshot(t, lt) - if after.Total != 0 { - t.Fatalf("trie total after reset = %d, want 0", after.Total) - } -} - -func TestHandleOrderGetReturnsCurrentFields(t *testing.T) { - lt := NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") - req := httptest.NewRequest(http.MethodGet, "/order", nil) - rec := httptest.NewRecorder() - handleOrder(lt).ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) - } - var resp orderResponse - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("decode response: %v", err) - } - if strings.Join(resp.Fields, ",") != "comm,path,tracepoint" { - t.Fatalf("fields = %v, want [comm path tracepoint]", resp.Fields) - } -} - -func TestHandleOrderPostReconfiguresAndResets(t *testing.T) { - lt := NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") - lt.Ingest(newTestPair("svc", 42, 1001, "/tmp/a", 1, 1, 1)) - - req := httptest.NewRequest(http.MethodPost, "/order", strings.NewReader(`{"fields":["path","tracepoint","comm"]}`)) - rec := httptest.NewRecorder() - handleOrder(lt).ServeHTTP(rec, req) - - if rec.Code != http.StatusOK { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) - } - var resp orderResponse - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("decode response: %v", err) - } - if strings.Join(resp.Fields, ",") != "path,tracepoint,comm" { - t.Fatalf("fields = %v, want [path tracepoint comm]", resp.Fields) - } - var snap trieSnapshot - if err := json.Unmarshal(resp.Snapshot, &snap); err != nil { - t.Fatalf("decode snapshot: %v", err) - } - if snap.Total != 0 { - t.Fatalf("snapshot total after reconfigure = %d, want 0", snap.Total) - } -} - -func TestHandleOrderPostRejectsInvalidRequest(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - - req := httptest.NewRequest(http.MethodPost, "/order", strings.NewReader(`{"fields":["comm","bogus"]}`)) - rec := httptest.NewRecorder() - handleOrder(lt).ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) - } - - req = httptest.NewRequest(http.MethodPost, "/order", strings.NewReader(`{"fields":[}`)) - rec = httptest.NewRecorder() - handleOrder(lt).ServeHTTP(rec, req) - if rec.Code != http.StatusBadRequest { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) - } -} - -func TestHandleOrderRequiresGetOrPost(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - req := httptest.NewRequest(http.MethodPut, "/order", nil) - rec := httptest.NewRecorder() - handleOrder(lt).ServeHTTP(rec, req) - - if rec.Code != http.StatusMethodNotAllowed { - t.Fatalf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed) - } - if allow := rec.Header().Get("Allow"); allow != http.MethodGet+", "+http.MethodPost { - t.Fatalf("allow = %q, want %q", allow, http.MethodGet+", "+http.MethodPost) - } -} - -func TestServeLivePrintsURLAndStopsOnCancel(t *testing.T) { - lt := NewLiveTrie([]string{"comm"}, "count") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - output := captureStdout(t, func() { - errCh := make(chan error, 1) - go func() { - errCh <- ServeLive(ctx, lt, 5*time.Millisecond) - }() - - time.Sleep(40 * time.Millisecond) - cancel() - - select { - case err := <-errCh: - if err != nil { - t.Fatalf("ServeLive returned error: %v", err) - } - case <-time.After(2 * time.Second): - t.Fatalf("timeout waiting for ServeLive to return") - } - }) - - if !strings.Contains(output, "Live flamegraph available at http://") { - t.Fatalf("expected live URL in output, got %q", output) - } -} - -func connectSSE(t *testing.T, url string) *http.Response { - t.Helper() - client := &http.Client{Timeout: 5 * time.Second} - resp, err := client.Get(url) - if err != nil { - t.Fatalf("connect sse: %v", err) - } - if resp.StatusCode != http.StatusOK { - _ = resp.Body.Close() - t.Fatalf("unexpected status: %s", resp.Status) - } - return resp -} - -func readFirstSSEData(t *testing.T, body io.ReadCloser) string { - t.Helper() - type result struct { - data string - err error - } - ch := make(chan result, 1) - - go func() { - reader := bufio.NewReader(body) - line, err := reader.ReadString('\n') - if err != nil { - ch <- result{err: err} - return - } - if !strings.HasPrefix(line, "data: ") { - ch <- result{err: fmt.Errorf("invalid sse data line: %q", line)} - return - } - separator, err := reader.ReadString('\n') - if err != nil { - ch <- result{err: err} - return - } - if separator != "\n" { - ch <- result{err: fmt.Errorf("missing sse blank-line separator: %q", separator)} - return - } - ch <- result{data: strings.TrimSuffix(strings.TrimPrefix(line, "data: "), "\n")} - }() - - select { - case out := <-ch: - if out.err != nil { - t.Fatalf("read sse event: %v", out.err) - } - return out.data - case <-time.After(3 * time.Second): - _ = body.Close() - t.Fatalf("timeout waiting for first sse event") - return "" - } -} - -func decodeSSESnapshot(t *testing.T, data string) trieSnapshot { - t.Helper() - var snap trieSnapshot - if err := json.Unmarshal([]byte(data), &snap); err != nil { - t.Fatalf("invalid snapshot json: %v", err) - } - return snap -} - -func captureStdout(t *testing.T, fn func()) string { - t.Helper() - - oldStdout := os.Stdout - reader, writer, err := os.Pipe() - if err != nil { - t.Fatalf("create stdout pipe: %v", err) - } - - os.Stdout = writer - defer func() { os.Stdout = oldStdout }() - - outCh := make(chan string, 1) - go func() { - var b strings.Builder - _, _ = io.Copy(&b, reader) - outCh <- b.String() - }() - - fn() - - _ = writer.Close() - out := <-outCh - _ = reader.Close() - return out -} diff --git a/internal/flamegraph/livetrie.go b/internal/flamegraph/livetrie.go index 100e03b..600e404 100644 --- a/internal/flamegraph/livetrie.go +++ b/internal/flamegraph/livetrie.go @@ -3,15 +3,21 @@ package flamegraph import ( "encoding/json" "fmt" - "ior/internal/event" "slices" "sort" "strings" "sync" "sync/atomic" + + "ior/internal/collapse" + "ior/internal/event" ) -const liveTrieMinFraction = 0.001 +const ( + liveTrieMinFraction = 0.001 + liveTrieMinVisibleChildrenWhenPruned = 8 + liveTrieVisibleChildrenFallbackMaxDepth = 1 +) type trieSnapshot struct { Name string `json:"n"` @@ -37,6 +43,9 @@ type LiveTrie struct { // NewLiveTrie constructs an empty live trie with the configured frame/count fields. func NewLiveTrie(fields []string, countField string) *LiveTrie { + if !isLiveTrieCountField(countField) { + countField = "count" + } return &LiveTrie{ root: &trieNode{ childMap: make(map[string]*trieNode), @@ -47,23 +56,7 @@ func NewLiveTrie(fields []string, countField string) *LiveTrie { } func (lt *LiveTrie) addLocked(frames []string, value uint64) { - node := lt.root - for _, frame := range frames { - if node.childMap == nil { - node.childMap = make(map[string]*trieNode) - } - child, ok := node.childMap[frame] - if !ok { - child = &trieNode{ - name: frame, - childMap: make(map[string]*trieNode), - } - node.children = append(node.children, child) - node.childMap[frame] = child - } - node = child - } - node.value += value + insertTriePath(lt.root, frames, value) if len(frames) > lt.maxDepth { lt.maxDepth = len(frames) } @@ -84,18 +77,24 @@ func (lt *LiveTrie) invalidateCache() { lt.cacheMu.Unlock() } -// Ingest adds one event pair into the live trie and recycles the pair. +// Ingest adds one event pair into the live trie. func (lt *LiveTrie) Ingest(ep *event.Pair) { record := eventPairToRecord(ep) - value := record.Cnt.ValueByName(lt.countField) + lt.AddRecord(record) +} + +// AddRecord adds one already-decoded flamegraph record into the live trie. +func (lt *LiveTrie) AddRecord(record IterRecord) { + value, err := record.Cnt.ValueByName(lt.countField) + if err != nil { + return + } lt.mu.Lock() frames := lt.buildFrames(record) lt.addLocked(frames, value) lt.version.Add(1) lt.mu.Unlock() - - ep.Recycle() } // Reset clears the trie so live snapshots start from a new baseline. @@ -114,6 +113,33 @@ func (lt *LiveTrie) Fields() []string { return out } +// CountField returns the active metric used to aggregate node values. +func (lt *LiveTrie) CountField() string { + lt.mu.RLock() + field := lt.countField + lt.mu.RUnlock() + return field +} + +// SetCountField changes the active aggregation metric and starts a new baseline. +func (lt *LiveTrie) SetCountField(countField string) error { + field := strings.TrimSpace(countField) + if !isLiveTrieCountField(field) { + return fmt.Errorf("invalid count field %q", countField) + } + + lt.mu.Lock() + if lt.countField == field { + lt.mu.Unlock() + return nil + } + lt.countField = field + lt.resetLocked() + lt.mu.Unlock() + lt.invalidateCache() + return nil +} + // Reconfigure changes frame fields and clears accumulated data for a new baseline. func (lt *LiveTrie) Reconfigure(fields []string) error { normalized, err := normalizeLiveTrieFields(fields) @@ -222,12 +248,11 @@ func normalizeLiveTrieFields(fields []string) ([]string, error) { } func isLiveTrieField(field string) bool { - switch field { - case "path", "comm", "tracepoint", "pid", "tid", "flags": - return true - default: - return false - } + return collapse.IsValidField(field) +} + +func isLiveTrieCountField(field string) bool { + return collapse.IsValidCountField(field) } func subtreeTotal(node *trieNode) uint64 { @@ -239,29 +264,45 @@ func subtreeTotal(node *trieNode) uint64 { } func buildSnapshot(node *trieNode, depth int, minFraction float64, rootTotal uint64) *trieSnapshot { - snapshot, _ := buildSnapshotWithTotal(node, depth, minFraction, rootTotal) + snapshot, _ := buildSnapshotWithTotal(node, depth, minFraction, rootTotal, false) return snapshot } -func buildSnapshotWithTotal(node *trieNode, depth int, minFraction float64, rootTotal uint64) (*trieSnapshot, uint64) { +type childSnapshotState struct { + node *trieNode + snapshot *trieSnapshot + total uint64 +} + +func buildSnapshotWithTotal(node *trieNode, depth int, minFraction float64, rootTotal uint64, forceKeep bool) (*trieSnapshot, uint64) { total := node.value children := slices.Clone(node.children) sort.Slice(children, func(i, j int) bool { return children[i].name < children[j].name }) - childSnapshots := make([]*trieSnapshot, 0, len(children)) + childStates := make([]childSnapshotState, 0, len(children)) for _, child := range children { - childSnapshot, childTotal := buildSnapshotWithTotal(child, depth+1, minFraction, rootTotal) + childSnapshot, childTotal := buildSnapshotWithTotal(child, depth+1, minFraction, rootTotal, false) total += childTotal - if childSnapshot != nil { - childSnapshots = append(childSnapshots, childSnapshot) - } + childStates = append(childStates, childSnapshotState{ + node: child, + snapshot: childSnapshot, + total: childTotal, + }) } - if depth > 0 && rootTotal > 0 && float64(total)/float64(rootTotal) < minFraction { + if !forceKeep && depth > 0 && rootTotal > 0 && float64(total)/float64(rootTotal) < minFraction { return nil, total } + ensureFallbackVisibleChildren(childStates, depth, minFraction, rootTotal) + + childSnapshots := make([]*trieSnapshot, 0, len(childStates)) + for _, child := range childStates { + if child.snapshot != nil { + childSnapshots = append(childSnapshots, child.snapshot) + } + } snapshot := &trieSnapshot{ Name: node.name, @@ -273,3 +314,43 @@ func buildSnapshotWithTotal(node *trieNode, depth int, minFraction float64, root } return snapshot, total } + +func ensureFallbackVisibleChildren(children []childSnapshotState, depth int, minFraction float64, rootTotal uint64) { + if depth > liveTrieVisibleChildrenFallbackMaxDepth { + return + } + visible := 0 + for _, child := range children { + if child.snapshot != nil { + visible++ + } + } + if visible > 0 { + return + } + + candidates := make([]int, 0, len(children)) + for idx, child := range children { + if child.total > 0 { + candidates = append(candidates, idx) + } + } + sort.Slice(candidates, func(i, j int) bool { + left := children[candidates[i]] + right := children[candidates[j]] + if left.total == right.total { + return left.node.name < right.node.name + } + return left.total > right.total + }) + + limit := liveTrieMinVisibleChildrenWhenPruned + if len(candidates) < limit { + limit = len(candidates) + } + for i := 0; i < limit; i++ { + idx := candidates[i] + forced, _ := buildSnapshotWithTotal(children[idx].node, depth+1, minFraction, rootTotal, true) + children[idx].snapshot = forced + } +} diff --git a/internal/flamegraph/livetrie_test.go b/internal/flamegraph/livetrie_test.go index 1315c71..53bdf1f 100644 --- a/internal/flamegraph/livetrie_test.go +++ b/internal/flamegraph/livetrie_test.go @@ -4,15 +4,16 @@ import ( "bytes" "encoding/json" "fmt" - "ior/internal/event" - "ior/internal/file" - "ior/internal/types" "os" "runtime" "sync" "sync/atomic" "testing" "time" + + "ior/internal/event" + "ior/internal/file" + "ior/internal/types" ) func TestLiveTrieIngestAndSnapshotRoundTrip(t *testing.T) { @@ -47,6 +48,39 @@ func TestLiveTrieIngestIsAdditive(t *testing.T) { } } +func TestLiveTrieCommTracepointPathAggregatesSameSyscallAcrossPaths(t *testing.T) { + lt := NewLiveTrie([]string{"comm", "tracepoint", "path"}, "count") + lt.AddRecord(IterRecord{ + Path: "/srv/a", + TraceID: types.SYS_ENTER_READ, + Comm: "svc", + Pid: 1001, + Tid: 1001, + Cnt: Counter{Count: 1}, + }) + lt.AddRecord(IterRecord{ + Path: "/srv/b", + TraceID: types.SYS_ENTER_READ, + Comm: "svc", + Pid: 1002, + Tid: 1002, + Cnt: Counter{Count: 1}, + }) + + snap := decodeLiveSnapshot(t, lt) + commNode := findSnapshotPath(t, &snap, "svc") + if len(commNode.Children) != 1 { + t.Fatalf("expected one syscall child under comm node, got %d", len(commNode.Children)) + } + syscallNode := commNode.Children[0] + if got, want := syscallNode.Name, "enter_read"; got != want { + t.Fatalf("syscall child name = %q, want %q", got, want) + } + if got, want := syscallNode.Total, uint64(2); got != want { + t.Fatalf("syscall aggregate total = %d, want %d", got, want) + } +} + func TestLiveTrieVersionIncrementsPerIngest(t *testing.T) { lt := NewLiveTrie([]string{"comm"}, "count") if got := lt.Version(); got != 0 { @@ -60,6 +94,70 @@ func TestLiveTrieVersionIncrementsPerIngest(t *testing.T) { } } +func TestLiveTrieAddRecordIncrementsVersion(t *testing.T) { + lt := NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + lt.AddRecord(IterRecord{ + Path: "/tmp/demo/read", + TraceID: types.SYS_ENTER_READ, + Comm: "demo", + Pid: 1001, + Tid: 1001, + Cnt: Counter{Count: 7, Duration: 70, DurationToPrev: 14, Bytes: 28}, + }) + + if got := lt.Version(); got != 1 { + t.Fatalf("version = %d, want 1", got) + } + snap := decodeLiveSnapshot(t, lt) + if snap.Total != 7 { + t.Fatalf("root total = %d, want 7", snap.Total) + } +} + +func TestSeedTestFlameDataBuildsStaticFixture(t *testing.T) { + lt := NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + SeedTestFlameData(lt) + + if got := lt.Version(); got == 0 { + t.Fatalf("expected seed fixture to add records") + } + snap := decodeLiveSnapshot(t, lt) + if snap.Total == 0 { + t.Fatalf("expected non-empty seeded snapshot") + } + if findSnapshotChild(&snap, "api") == nil { + t.Fatalf("expected seeded snapshot to include api branch") + } + if findSnapshotChild(&snap, "worker") == nil { + t.Fatalf("expected seeded snapshot to include worker branch") + } +} + +func TestSeedTestLiveFlameDataVariesByTick(t *testing.T) { + lt := NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + + SeedTestLiveFlameData(lt, 0) + snapTick0 := decodeLiveSnapshot(t, lt) + apiTick0 := findSnapshotPath(t, &snapTick0, "api").Total + workerTick0 := findSnapshotPath(t, &snapTick0, "worker").Total + + lt.Reset() + SeedTestLiveFlameData(lt, 1) + snapTick1 := decodeLiveSnapshot(t, lt) + apiTick1 := findSnapshotPath(t, &snapTick1, "api").Total + workerTick1 := findSnapshotPath(t, &snapTick1, "worker").Total + + if apiTick0 == apiTick1 && workerTick0 == workerTick1 { + t.Fatalf("expected phase shift to alter branch totals, got api=%d worker=%d for both ticks", apiTick0, workerTick0) + } + if apiTick0 <= workerTick0 { + t.Fatalf("expected api to dominate at tick 0, got api=%d worker=%d", apiTick0, workerTick0) + } + if workerTick1 <= apiTick1 { + t.Fatalf("expected worker to dominate at tick 1, got worker=%d api=%d", workerTick1, apiTick1) + } +} + func TestLiveTrieResetClearsDataAndAdvancesVersion(t *testing.T) { lt := NewLiveTrie([]string{"comm"}, "count") lt.Ingest(newTestPair("svc", 42, 1001, "/tmp/a", 1, 1, 1)) @@ -125,6 +223,54 @@ func TestLiveTrieReconfigureRejectsInvalidFields(t *testing.T) { } } +func TestLiveTrieSetCountFieldSwitchesMetricAndResetsBaseline(t *testing.T) { + lt := NewLiveTrie([]string{"comm"}, "count") + lt.Ingest(newTestPair("svc", 42, 1001, "/tmp/a", 10, 1, 64)) + + initial := decodeLiveSnapshot(t, lt) + if got, want := initial.Total, uint64(1); got != want { + t.Fatalf("count snapshot total = %d, want %d", got, want) + } + + if err := lt.SetCountField("bytes"); err != nil { + t.Fatalf("set count field: %v", err) + } + if got, want := lt.CountField(), "bytes"; got != want { + t.Fatalf("count field = %q, want %q", got, want) + } + + empty := decodeLiveSnapshot(t, lt) + if got := empty.Total; got != 0 { + t.Fatalf("expected reset baseline after metric switch, total=%d", got) + } + + lt.Ingest(newTestPair("svc", 42, 1002, "/tmp/b", 10, 1, 64)) + bytesSnap := decodeLiveSnapshot(t, lt) + if got, want := bytesSnap.Total, uint64(64); got != want { + t.Fatalf("bytes snapshot total = %d, want %d", got, want) + } + leaf := findSnapshotPath(t, &bytesSnap, "svc") + if got, want := leaf.Total, uint64(64); got != want { + t.Fatalf("bytes leaf total = %d, want %d", got, want) + } +} + +func TestLiveTrieSetCountFieldRejectsInvalidValue(t *testing.T) { + lt := NewLiveTrie([]string{"comm"}, "count") + lt.Ingest(newTestPair("svc", 42, 1001, "/tmp/a", 1, 1, 1)) + beforeVersion := lt.Version() + + if err := lt.SetCountField("bogus"); err == nil { + t.Fatalf("expected invalid count field error") + } + if got, want := lt.CountField(), "count"; got != want { + t.Fatalf("count field changed unexpectedly: got %q want %q", got, want) + } + if got := lt.Version(); got != beforeVersion { + t.Fatalf("version changed on invalid count field: got %d want %d", got, beforeVersion) + } +} + func TestLiveTrieSnapshotJSONCaching(t *testing.T) { lt := NewLiveTrie([]string{"comm"}, "count") lt.Ingest(newTestPair("svc", 42, 1001, "/tmp/a", 1, 1, 1)) @@ -156,6 +302,41 @@ func TestLiveTrieSnapshotJSONPrunesTinyNodes(t *testing.T) { } } +func TestLiveTrieSnapshotJSONKeepsFallbackChildrenWhenAllAreTinyAtRoot(t *testing.T) { + lt := NewLiveTrie([]string{"comm"}, "count") + const total = 6000 + for i := 0; i < total; i++ { + comm := fmt.Sprintf("svc-%04d", i) + lt.Ingest(newTestPair(comm, 42, uint32(100000+i), "/tmp/a", 1, 1, 1)) + } + + snap := decodeLiveSnapshot(t, lt) + if len(snap.Children) == 0 { + t.Fatalf("expected fallback root children when pruning would hide every branch") + } + if got, want := len(snap.Children), liveTrieMinVisibleChildrenWhenPruned; got != want { + t.Fatalf("expected fallback to keep %d root children, got %d", want, got) + } +} + +func TestLiveTrieSnapshotJSONKeepsFallbackChildrenAtDepthOne(t *testing.T) { + lt := NewLiveTrie([]string{"comm", "pid"}, "count") + const total = 6000 + for i := 0; i < total; i++ { + pid := uint32(100000 + i) + lt.Ingest(newTestPair("svc", pid, pid, "/tmp/a", 1, 1, 1)) + } + + snap := decodeLiveSnapshot(t, lt) + commNode := findSnapshotPath(t, &snap, "svc") + if len(commNode.Children) == 0 { + t.Fatalf("expected fallback depth-one children for pid branches") + } + if got, want := len(commNode.Children), liveTrieMinVisibleChildrenWhenPruned; got != want { + t.Fatalf("expected fallback to keep %d depth-one children, got %d", want, got) + } +} + func TestLiveTrieConcurrentIngestAndSnapshot(t *testing.T) { lt := NewLiveTrie([]string{"comm", "pid"}, "count") diff --git a/internal/flamegraph/nativejson.go b/internal/flamegraph/nativejson.go deleted file mode 100644 index 088bcfc..0000000 --- a/internal/flamegraph/nativejson.go +++ /dev/null @@ -1,86 +0,0 @@ -package flamegraph - -import ( - "encoding/json" - "fmt" - "io" - "iter" - "os" - "strings" -) - -type jsonNode struct { - Name string `json:"name"` - Value uint64 `json:"value"` - Total uint64 `json:"total"` - Children []jsonNode `json:"children,omitempty"` -} - -type jsonFlamegraph struct { - Fields []string `json:"fields"` - CountField string `json:"countField"` - Root jsonNode `json:"root"` -} - -func (n NativeSVG) WriteJSONFromFile(iorDataFile string) (outFile string, err error) { - outFile = fmt.Sprintf("%s.%s-by-%s.json", - strings.TrimSuffix(iorDataFile, ".ior.zst"), - strings.Join(n.fields, ":"), - n.countField, - ) - defer func() { - if err != nil { - _ = os.Remove(outFile) - } - }() - - iod, err := newIorDataFromFile(iorDataFile) - if err != nil { - return outFile, fmt.Errorf("read ior data: %w", err) - } - - fd, err := os.Create(outFile) - if err != nil { - return outFile, fmt.Errorf("create output %s: %w", outFile, err) - } - defer fd.Close() - - if err := n.WriteJSONFromIter(iod.iter(), fd); err != nil { - return outFile, err - } - return outFile, nil -} - -func (n NativeSVG) WriteJSONFromIter(records iter.Seq[IterRecord], w io.Writer) error { - tr, err := n.buildTrieFromIter(records) - if err != nil { - return err - } - - payload := jsonFlamegraph{ - Fields: append([]string(nil), n.fields...), - CountField: n.countField, - Root: jsonNodeFromTrieNode(tr.root, "root"), - } - - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - return enc.Encode(payload) -} - -func jsonNodeFromTrieNode(node *trieNode, name string) jsonNode { - out := jsonNode{ - Name: name, - Value: node.value, - Total: node.total, - } - if len(node.children) == 0 { - return out - } - - out.Children = make([]jsonNode, 0, len(node.children)) - for _, child := range node.children { - out.Children = append(out.Children, jsonNodeFromTrieNode(child, child.name)) - } - return out -} diff --git a/internal/flamegraph/nativejson_test.go b/internal/flamegraph/nativejson_test.go deleted file mode 100644 index c76d327..0000000 --- a/internal/flamegraph/nativejson_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package flamegraph - -import ( - "encoding/json" - "os" - "testing" -) - -type jsonNodeForTest struct { - Name string `json:"name"` - Value uint64 `json:"value"` - Total uint64 `json:"total"` - Children []jsonNodeForTest `json:"children"` -} - -type jsonFlamegraphForTest struct { - Fields []string `json:"fields"` - CountField string `json:"countField"` - Root jsonNodeForTest `json:"root"` -} - -func TestWriteJSONFromFileContainsFlamegraphTree(t *testing.T) { - dir := t.TempDir() - iorFile := writeTestIorZst(t, dir) - - n := NewNativeSVG([]string{"comm", "path", "tracepoint"}, "count") - outFile, err := n.WriteJSONFromFile(iorFile) - if err != nil { - t.Fatalf("WriteJSONFromFile returned error: %v", err) - } - - data, err := os.ReadFile(outFile) - if err != nil { - t.Fatalf("read output json: %v", err) - } - - var payload jsonFlamegraphForTest - if err := json.Unmarshal(data, &payload); err != nil { - t.Fatalf("unmarshal output json: %v", err) - } - - if payload.CountField != "count" { - t.Fatalf("count field = %q, want %q", payload.CountField, "count") - } - if len(payload.Fields) != 3 { - t.Fatalf("fields len = %d, want 3", len(payload.Fields)) - } - if payload.Root.Name != "root" { - t.Fatalf("root name = %q, want %q", payload.Root.Name, "root") - } - if payload.Root.Total != 1 { - t.Fatalf("root total = %d, want 1", payload.Root.Total) - } - if len(payload.Root.Children) != 1 { - t.Fatalf("root children len = %d, want 1", len(payload.Root.Children)) - } - if payload.Root.Children[0].Name != "tester" { - t.Fatalf("root child name = %q, want %q", payload.Root.Children[0].Name, "tester") - } -} - -func TestWriteJSONFromFileCleansUpPartialOutputOnError(t *testing.T) { - dir := t.TempDir() - iorFile := writeTestIorZst(t, dir) - - n := NewNativeSVG([]string{"invalidField"}, "count") - outFile, err := n.WriteJSONFromFile(iorFile) - if err == nil { - t.Fatal("expected error for invalid field, got nil") - } - - if _, statErr := os.Stat(outFile); !os.IsNotExist(statErr) { - t.Fatalf("expected partial output to be removed, stat err=%v", statErr) - } -} diff --git a/internal/flamegraph/nativesvg.go b/internal/flamegraph/nativesvg.go deleted file mode 100644 index 80061b4..0000000 --- a/internal/flamegraph/nativesvg.go +++ /dev/null @@ -1,97 +0,0 @@ -package flamegraph - -import ( - "fmt" - "io" - "iter" - "os" - "strings" -) - -// NativeSVG generates interactive flamegraph SVGs directly from .ior.zst data files. -// -// Flamegraphs are generated natively by ior from .ior.zst data files; no external -// flamegraph tool is required. The CLI typically drives this via the -ior flag, -// which reads trace data, aggregates it into a trie of stack frames (e.g. comm,path,tracepoint) -// and renders a self-contained SVG that can be viewed in a browser. -type NativeSVG struct { - fields []string - countField string - config SVGConfig -} - -func NewNativeSVG(fields []string, countField string) NativeSVG { - return NativeSVG{ - fields: fields, - countField: countField, - config: defaultSVGConfig(), - } -} - -func (n NativeSVG) WriteSVGFromFile(iorDataFile string) (outFile string, err error) { - outFile = fmt.Sprintf("%s.%s-by-%s.svg", - strings.TrimSuffix(iorDataFile, ".ior.zst"), - strings.Join(n.fields, ":"), - n.countField, - ) - defer func() { - if err != nil { - _ = os.Remove(outFile) - } - }() - - iod, err := newIorDataFromFile(iorDataFile) - if err != nil { - return outFile, fmt.Errorf("read ior data: %w", err) - } - - fd, err := os.Create(outFile) - if err != nil { - return outFile, fmt.Errorf("create output %s: %w", outFile, err) - } - defer fd.Close() - - if err := n.WriteSVGFromIter(iod.iter(), fd); err != nil { - return outFile, err - } - return outFile, nil -} - -func (n NativeSVG) WriteSVGFromIter(records iter.Seq[IterRecord], w io.Writer) error { - tr, err := n.buildTrieFromIter(records) - if err != nil { - return err - } - return WriteSVG(w, tr, n.config) -} - -func (n NativeSVG) buildTrieFromIter(records iter.Seq[IterRecord]) (*trie, error) { - tr := newTrie() - var framesBuf []string - for record := range records { - frames, err := n.recordFrames(record, framesBuf) - if err != nil { - return nil, err - } - framesBuf = frames - tr.add(frames, record.Cnt.ValueByName(n.countField)) - } - tr.computeTotals() - return tr, nil -} - -func (n NativeSVG) recordFrames(record IterRecord, framesBuf []string) ([]string, error) { - frames := framesBuf[:0] - for _, fieldName := range n.fields { - value, err := record.StringByName(fieldName) - if err != nil { - return nil, fmt.Errorf("field %s: %w", fieldName, err) - } - for _, part := range strings.Split(value, ";") { - if part != "" { - frames = append(frames, part) - } - } - } - return frames, nil -} diff --git a/internal/flamegraph/nativesvg_test.go b/internal/flamegraph/nativesvg_test.go deleted file mode 100644 index 36e88bf..0000000 --- a/internal/flamegraph/nativesvg_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package flamegraph - -import ( - "os" - "path/filepath" - "syscall" - "testing" - - "ior/internal/types" - - "github.com/DataDog/zstd" -) - -func writeTestIorZst(t *testing.T, dir string) string { - t.Helper() - - iod := newIorData() - iod.add("/tmp/test", types.SYS_ENTER_OPENAT, "tester", 100, 200, flagsType(syscall.O_RDONLY), Counter{ - Count: 1, - Duration: 10, - DurationToPrev: 2, - Bytes: 0, - }) - serialized, err := iod.serialize() - if err != nil { - t.Fatalf("serialize: %v", err) - } - - path := filepath.Join(dir, "sample.ior.zst") - fd, err := os.Create(path) - if err != nil { - t.Fatalf("create test ior file: %v", err) - } - defer fd.Close() - - enc := zstd.NewWriter(fd) - if _, err := enc.Write(serialized); err != nil { - t.Fatalf("write zstd payload: %v", err) - } - if err := enc.Close(); err != nil { - t.Fatalf("close zstd writer: %v", err) - } - - return path -} - -func TestWriteSVGFromFileCleansUpPartialOutputOnError(t *testing.T) { - dir := t.TempDir() - iorFile := writeTestIorZst(t, dir) - - n := NewNativeSVG([]string{"invalidField"}, "count") - outFile, err := n.WriteSVGFromFile(iorFile) - if err == nil { - t.Fatal("expected error for invalid field, got nil") - } - - if _, statErr := os.Stat(outFile); !os.IsNotExist(statErr) { - t.Fatalf("expected partial output to be removed, stat err=%v", statErr) - } -} diff --git a/internal/flamegraph/svgwriter.go b/internal/flamegraph/svgwriter.go deleted file mode 100644 index 7fd699e..0000000 --- a/internal/flamegraph/svgwriter.go +++ /dev/null @@ -1,151 +0,0 @@ -package flamegraph - -import ( - "bufio" - "fmt" - "hash" - "hash/fnv" - "io" - "strings" - "sync" -) - -var svgEscaper = strings.NewReplacer( - "&", "&", - "<", "<", - ">", ">", - `"`, """, - "'", "'", -) - -var fnv32aPool = sync.Pool{ - New: func() any { - return fnv.New32a() - }, -} - -// SVGConfig controls the layout and styling of generated flamegraph SVGs. -// -// Width is the virtual canvas width in pixels, FrameHeight is the height of each -// stack frame row, FontSize is the base font size, and MinWidthPx controls the -// minimum rendered width for a frame (smaller frames are skipped to avoid noise). -type SVGConfig struct { - Title string - Width int - FrameHeight int - FontSize int - MinWidthPx float64 -} - -func defaultSVGConfig() SVGConfig { - return SVGConfig{ - Title: "I/O Flame Graph", - Width: 1200, - FrameHeight: 16, - FontSize: 12, - MinWidthPx: 1.0, - } -} - -// DefaultSVGConfig returns the default SVG configuration values. -func DefaultSVGConfig() SVGConfig { - return defaultSVGConfig() -} - -// WriteSVG renders a flamegraph trie into an interactive SVG document. -// -// The output is a self-contained SVG that includes embedded CSS and JavaScript -// for zoom, search, and highlighting, and is designed to be served directly to -// a browser (for example via ServeSVG) without any external assets. -func WriteSVG(w io.Writer, t *trie, cfg SVGConfig) error { - cfg = sanitizeSVGConfig(cfg) - - canvasHeight := canvasHeightFor(cfg, t) - bw := bufio.NewWriter(w) - if err := writeSVGHeader(bw, cfg, canvasHeight); err != nil { - return err - } - for _, frame := range BuildFrameLayout(t, cfg) { - if err := writeFrame(bw, frame.Name, frame.Title, frame.Fill, - frame.X, frame.Y, frame.Width, frame.Height, frame.Depth, cfg.FontSize); err != nil { - return err - } - } - if err := writeSVGFooter(bw); err != nil { - return err - } - return bw.Flush() -} - -func writeSVGHeader(bw *bufio.Writer, cfg SVGConfig, height int) error { - _, err := fmt.Fprintf(bw, `<svg xmlns="http://www.w3.org/2000/svg" width="100%%" height="%d" viewBox="0 0 %d %d" preserveAspectRatio="xMinYMin meet">`+"\n", - height, cfg.Width, height) - if err != nil { - return err - } - _, err = fmt.Fprintf(bw, "<style><![CDATA[%s]]></style>\n", flamegraphCSS(cfg)) - if err != nil { - return err - } - _, err = fmt.Fprintf(bw, "<script><![CDATA[%s]]></script>\n", flamegraphJS) - if err != nil { - return err - } - _, err = fmt.Fprintf(bw, `<text class="title" x="10" y="22">%s</text>`+"\n", svgEscape(cfg.Title)) - if err != nil { - return err - } - _, err = fmt.Fprintf(bw, `<g class="controls"><text x="10" y="42" onclick="fgSearch()">Search</text><text x="80" y="42" onclick="fgResetSearch()">Reset Search</text><text x="190" y="42" onclick="fgUndoZoom()">Undo Zoom</text><text x="280" y="42" onclick="fgResetZoom()">Reset Zoom</text><text id="fg-info" x="390" y="42"></text></g>`+"\n") - return err -} - -func writeSVGFooter(bw *bufio.Writer) error { - _, err := fmt.Fprintln(bw, "</svg>") - return err -} - -func writeFrame(bw *bufio.Writer, name, title, fill string, x, y, w, h float64, depth, fontSize int) error { - textStyle := "" - labelStyle := "" - if w <= float64(fontSize*2) { - labelStyle = ` style="display:none"` - } - if labelStyle != "" { - textStyle = labelStyle - } - _, err := fmt.Fprintf(bw, `<g class="frame" data-name="%s" data-x="%.3f" data-w="%.3f" data-depth="%d" data-base-fill="%s"> -<title>%s</title><rect x="%.3f" y="%.3f" width="%.3f" height="%.3f" fill="%s"/> -<text x="%.3f" y="%.3f"%s>%s</text> -</g> -`, - svgEscape(name), x, w, depth, fill, - svgEscape(title), x, y, w, h, fill, - x+3, y+float64(fontSize), textStyle, svgEscape(name)) - return err -} - -func frameColor(name string) string { - hasher := fnv32aPool.Get().(hash.Hash32) - hasher.Reset() - _, _ = io.WriteString(hasher, name) - h := hasher.Sum32() - fnv32aPool.Put(hasher) - r := 200 + int(h%35) - g := 80 + int((h>>8)%120) - b := 40 + int((h>>16)%90) - return fmt.Sprintf("rgb(%d,%d,%d)", r, g, b) -} - -func flamegraphCSS(cfg SVGConfig) string { - return fmt.Sprintf(` -.title { font-size: %dpx; font-family: monospace; } -.controls text { font-size: %dpx; font-family: monospace; cursor: pointer; fill: #444; } -.frame text { font-size: %dpx; font-family: monospace; pointer-events: none; fill: #111; } -.frame rect { stroke: rgba(0,0,0,0.18); stroke-width: 0.5; } -.title, .controls text, .frame text { user-select: none; -webkit-user-select: none; } -`, cfg.FontSize+2, cfg.FontSize, cfg.FontSize-1) -} - -func svgEscape(s string) string { - return svgEscaper.Replace(s) -} diff --git a/internal/flamegraph/svgwriter_js.go b/internal/flamegraph/svgwriter_js.go deleted file mode 100644 index bf8bfd2..0000000 --- a/internal/flamegraph/svgwriter_js.go +++ /dev/null @@ -1,212 +0,0 @@ -package flamegraph - -const flamegraphJS = ` -const fg = { - frames: [], - info: null, - matchColor: "rgb(220, 30, 70)", - zoomStack: [], - zoomRange: null, - rootWidth: 0, -}; - -function fgInit() { - fg.frames = Array.from(document.querySelectorAll("g.frame")); - fg.info = document.getElementById("fg-info"); - fg.rootWidth = fgDetectRootWidth(); - fg.frames.forEach((frame) => { - fgSnapshotOriginalGeometry(frame); - frame.addEventListener("click", (ev) => { - if (ev.detail > 1) return; - ev.stopPropagation(); - fgZoom(ev.currentTarget); - }); - frame.addEventListener("dblclick", (ev) => { - ev.preventDefault(); - ev.stopPropagation(); - fgResetZoom(); - }); - frame.addEventListener("mouseenter", (ev) => fgHover(ev.currentTarget)); - }); - document.addEventListener("dblclick", (ev) => { - ev.preventDefault(); - fgResetZoom(); - }); -} - -function fgHover(frame) { - if (!fg.info) return; - const title = frame.querySelector("title"); - fg.info.textContent = title ? title.textContent : ""; -} - -function fgZoom(frame) { - const x = fgOriginalX(frame); - const w = fgOriginalW(frame); - if (w <= 0) return; - if (fg.zoomRange) { - fg.zoomStack.push(fg.zoomRange); - } - fg.zoomRange = { x: x, w: w, depth: Number(frame.dataset.depth || "0") }; - fgApplyZoom(); -} - -function fgApplyZoom() { - if (!fg.zoomRange) { - fg.frames.forEach((frame) => { - frame.style.display = ""; - }); - return; - } - const x = fg.zoomRange.x; - const end = x + fg.zoomRange.w; - const width = fg.zoomRange.w; - const minDepth = fg.zoomRange.depth; - const eps = 1e-6; - const scale = fg.rootWidth / width; - fg.frames.forEach((other) => { - const ox = fgOriginalX(other); - const ow = fgOriginalW(other); - const depth = Number(other.dataset.depth || "0"); - const inSelectedRange = ox >= x-eps && ox+ow <= end+eps; - const isAncestor = depth < minDepth && ox <= x+eps && ox+ow >= end-eps; - - if (isAncestor || (depth >= minDepth && inSelectedRange)) { - if (isAncestor) { - fgSetFrameGeometry(other, 0, fg.rootWidth); - } else { - fgSetFrameGeometry(other, (ox-x)*scale, ow*scale); - } - other.style.display = ""; - } else { - other.style.display = "none"; - } - }); -} - -function fgUndoZoom() { - if (fg.zoomStack.length === 0) { - fgResetZoom(); - return; - } - fg.zoomRange = fg.zoomStack.pop(); - fgApplyZoom(); -} - -function fgResetZoom() { - fg.zoomStack = []; - fg.zoomRange = null; - fg.frames.forEach((frame) => { - fgRestoreFrameGeometry(frame); - frame.style.display = ""; - }); -} - -function fgSearch() { - const needle = prompt("Search frames (substring):", ""); - if (needle === null) return; - const q = needle.trim().toLowerCase(); - fg.frames.forEach((frame) => { - const rect = frame.querySelector("rect"); - const base = frame.dataset.baseFill || ""; - const name = (frame.dataset.name || "").toLowerCase(); - if (!rect) return; - if (q !== "" && name.includes(q)) { - rect.style.fill = fg.matchColor; - } else { - rect.style.fill = base; - } - }); -} - -function fgResetSearch() { - fg.frames.forEach((frame) => { - const rect = frame.querySelector("rect"); - if (!rect) return; - rect.style.fill = frame.dataset.baseFill || ""; - }); -} - -function fgDetectRootWidth() { - let maxEnd = 0; - fg.frames.forEach((frame) => { - const x = Number(frame.dataset.x || "0"); - const w = Number(frame.dataset.w || "0"); - maxEnd = Math.max(maxEnd, x + w); - }); - return maxEnd; -} - -function fgSnapshotOriginalGeometry(frame) { - const rect = frame.querySelector("rect"); - const text = frame.querySelector("text"); - frame.dataset.ox = frame.dataset.x || "0"; - frame.dataset.ow = frame.dataset.w || "0"; - if (rect) { - rect.dataset.ox = rect.getAttribute("x") || "0"; - rect.dataset.ow = rect.getAttribute("width") || "0"; - } - if (text) { - text.dataset.ox = text.getAttribute("x") || "0"; - text.dataset.hidden = text.style.display === "none" ? "1" : "0"; - text.dataset.full = text.textContent || frame.dataset.name || ""; - } -} - -function fgOriginalX(frame) { - return Number(frame.dataset.ox || frame.dataset.x || "0"); -} - -function fgOriginalW(frame) { - return Number(frame.dataset.ow || frame.dataset.w || "0"); -} - -function fgSetFrameGeometry(frame, x, w) { - const rect = frame.querySelector("rect"); - const text = frame.querySelector("text"); - if (rect) { - rect.setAttribute("x", String(x)); - rect.setAttribute("width", String(w)); - } - if (text) { - text.setAttribute("x", String(x + 3)); - fgFitLabel(text, w); - } -} - -function fgRestoreFrameGeometry(frame) { - const rect = frame.querySelector("rect"); - const text = frame.querySelector("text"); - if (rect) { - rect.setAttribute("x", rect.dataset.ox || "0"); - rect.setAttribute("width", rect.dataset.ow || "0"); - } - if (text) { - text.setAttribute("x", text.dataset.ox || "0"); - if (text.dataset.hidden === "1") { - text.style.display = "none"; - text.textContent = text.dataset.full || ""; - } else { - fgFitLabel(text, Number(rect ? (rect.dataset.ow || "0") : "0")); - } - } -} - -function fgFitLabel(text, width) { - const full = text.dataset.full || text.textContent || ""; - const maxChars = Math.floor((width - 6) / 7); - if (maxChars < 3) { - text.style.display = "none"; - text.textContent = full; - return; - } - text.style.display = ""; - if (full.length <= maxChars) { - text.textContent = full; - return; - } - text.textContent = full.slice(0, maxChars - 1) + "…"; -} - -window.addEventListener("DOMContentLoaded", fgInit); -` diff --git a/internal/flamegraph/svgwriter_jscode.go b/internal/flamegraph/svgwriter_jscode.go deleted file mode 100644 index 3ac00fd..0000000 --- a/internal/flamegraph/svgwriter_jscode.go +++ /dev/null @@ -1,214 +0,0 @@ -//go:build !js - -package flamegraph - -const flamegraphJS = ` -const fg = { - frames: [], - info: null, - matchColor: "rgb(220, 30, 70)", - zoomStack: [], - zoomRange: null, - rootWidth: 0, -}; - -function fgInit() { - fg.frames = Array.from(document.querySelectorAll("g.frame")); - fg.info = document.getElementById("fg-info"); - fg.rootWidth = fgDetectRootWidth(); - fg.frames.forEach((frame) => { - fgSnapshotOriginalGeometry(frame); - frame.addEventListener("click", (ev) => { - if (ev.detail > 1) return; - ev.stopPropagation(); - fgZoom(ev.currentTarget); - }); - frame.addEventListener("dblclick", (ev) => { - ev.preventDefault(); - ev.stopPropagation(); - fgResetZoom(); - }); - frame.addEventListener("mouseenter", (ev) => fgHover(ev.currentTarget)); - }); - document.addEventListener("dblclick", (ev) => { - ev.preventDefault(); - fgResetZoom(); - }); -} - -function fgHover(frame) { - if (!fg.info) return; - const title = frame.querySelector("title"); - fg.info.textContent = title ? title.textContent : ""; -} - -function fgZoom(frame) { - const x = fgOriginalX(frame); - const w = fgOriginalW(frame); - if (w <= 0) return; - if (fg.zoomRange) { - fg.zoomStack.push(fg.zoomRange); - } - fg.zoomRange = { x: x, w: w, depth: Number(frame.dataset.depth || "0") }; - fgApplyZoom(); -} - -function fgApplyZoom() { - if (!fg.zoomRange) { - fg.frames.forEach((frame) => { - frame.style.display = ""; - }); - return; - } - const x = fg.zoomRange.x; - const end = x + fg.zoomRange.w; - const width = fg.zoomRange.w; - const minDepth = fg.zoomRange.depth; - const eps = 1e-6; - const scale = fg.rootWidth / width; - fg.frames.forEach((other) => { - const ox = fgOriginalX(other); - const ow = fgOriginalW(other); - const depth = Number(other.dataset.depth || "0"); - const inSelectedRange = ox >= x-eps && ox+ow <= end+eps; - const isAncestor = depth < minDepth && ox <= x+eps && ox+ow >= end-eps; - - if (isAncestor || (depth >= minDepth && inSelectedRange)) { - if (isAncestor) { - fgSetFrameGeometry(other, 0, fg.rootWidth); - } else { - fgSetFrameGeometry(other, (ox-x)*scale, ow*scale); - } - other.style.display = ""; - } else { - other.style.display = "none"; - } - }); -} - -function fgUndoZoom() { - if (fg.zoomStack.length === 0) { - fgResetZoom(); - return; - } - fg.zoomRange = fg.zoomStack.pop(); - fgApplyZoom(); -} - -function fgResetZoom() { - fg.zoomStack = []; - fg.zoomRange = null; - fg.frames.forEach((frame) => { - fgRestoreFrameGeometry(frame); - frame.style.display = ""; - }); -} - -function fgSearch() { - const needle = prompt("Search frames (substring):", ""); - if (needle === null) return; - const q = needle.trim().toLowerCase(); - fg.frames.forEach((frame) => { - const rect = frame.querySelector("rect"); - const base = frame.dataset.baseFill || ""; - const name = (frame.dataset.name || "").toLowerCase(); - if (!rect) return; - if (q !== "" && name.includes(q)) { - rect.style.fill = fg.matchColor; - } else { - rect.style.fill = base; - } - }); -} - -function fgResetSearch() { - fg.frames.forEach((frame) => { - const rect = frame.querySelector("rect"); - if (!rect) return; - rect.style.fill = frame.dataset.baseFill || ""; - }); -} - -function fgDetectRootWidth() { - let maxEnd = 0; - fg.frames.forEach((frame) => { - const x = Number(frame.dataset.x || "0"); - const w = Number(frame.dataset.w || "0"); - maxEnd = Math.max(maxEnd, x + w); - }); - return maxEnd; -} - -function fgSnapshotOriginalGeometry(frame) { - const rect = frame.querySelector("rect"); - const text = frame.querySelector("text"); - frame.dataset.ox = frame.dataset.x || "0"; - frame.dataset.ow = frame.dataset.w || "0"; - if (rect) { - rect.dataset.ox = rect.getAttribute("x") || "0"; - rect.dataset.ow = rect.getAttribute("width") || "0"; - } - if (text) { - text.dataset.ox = text.getAttribute("x") || "0"; - text.dataset.hidden = text.style.display === "none" ? "1" : "0"; - text.dataset.full = text.textContent || frame.dataset.name || ""; - } -} - -function fgOriginalX(frame) { - return Number(frame.dataset.ox || frame.dataset.x || "0"); -} - -function fgOriginalW(frame) { - return Number(frame.dataset.ow || frame.dataset.w || "0"); -} - -function fgSetFrameGeometry(frame, x, w) { - const rect = frame.querySelector("rect"); - const text = frame.querySelector("text"); - if (rect) { - rect.setAttribute("x", String(x)); - rect.setAttribute("width", String(w)); - } - if (text) { - text.setAttribute("x", String(x + 3)); - fgFitLabel(text, w); - } -} - -function fgRestoreFrameGeometry(frame) { - const rect = frame.querySelector("rect"); - const text = frame.querySelector("text"); - if (rect) { - rect.setAttribute("x", rect.dataset.ox || "0"); - rect.setAttribute("width", rect.dataset.ow || "0"); - } - if (text) { - text.setAttribute("x", text.dataset.ox || "0"); - if (text.dataset.hidden === "1") { - text.style.display = "none"; - text.textContent = text.dataset.full || ""; - } else { - fgFitLabel(text, Number(rect ? (rect.dataset.ow || "0") : "0")); - } - } -} - -function fgFitLabel(text, width) { - const full = text.dataset.full || text.textContent || ""; - const maxChars = Math.floor((width - 6) / 7); - if (maxChars < 3) { - text.style.display = "none"; - text.textContent = full; - return; - } - text.style.display = ""; - if (full.length <= maxChars) { - text.textContent = full; - return; - } - text.textContent = full.slice(0, maxChars - 1) + "…"; -} - -window.addEventListener("DOMContentLoaded", fgInit); -` diff --git a/internal/flamegraph/svgwriter_test.go b/internal/flamegraph/svgwriter_test.go deleted file mode 100644 index 56f2c20..0000000 --- a/internal/flamegraph/svgwriter_test.go +++ /dev/null @@ -1,112 +0,0 @@ -package flamegraph - -import ( - "bytes" - "strings" - "testing" -) - -func renderSVGForTest(t *testing.T, tr *trie, cfg SVGConfig) string { - t.Helper() - var buf bytes.Buffer - if err := WriteSVG(&buf, tr, cfg); err != nil { - t.Fatalf("WriteSVG failed: %v", err) - } - return buf.String() -} - -func TestWriteSVGBasic(t *testing.T) { - tr := newTrie() - tr.add([]string{"a", "b"}, 3) - tr.add([]string{"a", "c"}, 2) - tr.computeTotals() - - svg := renderSVGForTest(t, tr, defaultSVGConfig()) - if !strings.Contains(svg, "<svg") || !strings.Contains(svg, "</svg>") { - t.Fatalf("expected valid svg wrapper, got: %s", svg) - } - if !strings.Contains(svg, "data-name=\"a\"") || !strings.Contains(svg, "data-name=\"b\"") { - t.Fatalf("expected rendered frame names, got: %s", svg) - } -} - -func TestWriteSVGEmptyTrie(t *testing.T) { - tr := newTrie() - tr.computeTotals() - - svg := renderSVGForTest(t, tr, defaultSVGConfig()) - if !strings.Contains(svg, "<svg") || !strings.Contains(svg, "</svg>") { - t.Fatalf("expected valid svg wrapper, got: %s", svg) - } - if strings.Contains(svg, "class=\"frame\"") { - t.Fatalf("expected no rendered frames for empty trie, got: %s", svg) - } -} - -func TestWriteSVGMinWidth(t *testing.T) { - tr := newTrie() - tr.add([]string{"wide"}, 100) - tr.add([]string{"tiny"}, 1) - tr.computeTotals() - - cfg := defaultSVGConfig() - cfg.Width = 120 - cfg.MinWidthPx = 2.0 - svg := renderSVGForTest(t, tr, cfg) - - if !strings.Contains(svg, "data-name=\"wide\"") { - t.Fatalf("expected wide frame to be rendered, got: %s", svg) - } - if strings.Contains(svg, "data-name=\"tiny\"") { - t.Fatalf("expected tiny frame to be skipped by min width, got: %s", svg) - } -} - -func TestWriteSVGTitle(t *testing.T) { - tr := newTrie() - tr.add([]string{"a"}, 1) - tr.computeTotals() - - cfg := defaultSVGConfig() - cfg.Title = "Custom Flamegraph" - svg := renderSVGForTest(t, tr, cfg) - - if !strings.Contains(svg, "Custom Flamegraph") { - t.Fatalf("expected custom title in output, got: %s", svg) - } -} - -func TestFrameColor(t *testing.T) { - colorA1 := frameColor("read") - colorA2 := frameColor("read") - colorB := frameColor("write") - - if colorA1 != colorA2 { - t.Fatalf("expected deterministic color for identical names, got %q vs %q", colorA1, colorA2) - } - if !strings.HasPrefix(colorA1, "rgb(") || !strings.HasSuffix(colorA1, ")") { - t.Fatalf("expected rgb() format, got %q", colorA1) - } - if colorA1 == colorB { - t.Fatalf("expected different colors for different names, got %q", colorA1) - } -} - -func TestWriteSVGInvalidConfigFallsBack(t *testing.T) { - tr := newTrie() - tr.add([]string{"a"}, 1) - tr.computeTotals() - - cfg := SVGConfig{Title: "x", Width: 0, FrameHeight: 0, FontSize: 0, MinWidthPx: 0} - svg := renderSVGForTest(t, tr, cfg) - - if !strings.Contains(svg, `width="100%"`) { - t.Fatalf("expected responsive svg width, got: %s", svg) - } - if !strings.Contains(svg, `viewBox="0 0 1200 `) { - t.Fatalf("expected fallback viewBox width, got: %s", svg) - } - if !strings.Contains(svg, "I/O Flame Graph") { - t.Fatalf("expected fallback title, got: %s", svg) - } -} diff --git a/internal/flamegraph/testfixture.go b/internal/flamegraph/testfixture.go new file mode 100644 index 0000000..2774925 --- /dev/null +++ b/internal/flamegraph/testfixture.go @@ -0,0 +1,120 @@ +package flamegraph + +import ( + "ior/internal/types" + "strings" +) + +// SeedTestFlameData populates a deterministic static flamegraph fixture. +// Intended for keyboard-navigation validation in TUI test-flame mode. +func SeedTestFlameData(liveTrie *LiveTrie) { + if liveTrie == nil { + return + } + for _, record := range testFlameRecords() { + liveTrie.AddRecord(record) + } +} + +// SeedTestLiveFlameData populates deterministic synthetic data for a given live tick. +// The data shape stays navigable while branch weights shift by phase so the +// terminal flamegraph visibly changes over time. +func SeedTestLiveFlameData(liveTrie *LiveTrie, tick uint64) { + if liveTrie == nil { + return + } + phase := tick % 4 + for _, base := range testFlameRecords() { + weight := liveTestWeight(base, phase) + liveTrie.AddRecord(withTestFlameWeight(base, weight)) + } +} + +func testFlameRecords() []IterRecord { + return []IterRecord{ + newTestFlameRecord("api", "/srv/api/lib/http/client/read", 2001, 2201, types.SYS_ENTER_READ, 180), + newTestFlameRecord("api", "/srv/api/lib/json/encode/write", 2001, 2201, types.SYS_ENTER_WRITE, 120), + newTestFlameRecord("api", "/srv/api/storage/postgres/query/read", 2001, 2201, types.SYS_ENTER_READ, 240), + newTestFlameRecord("api", "/srv/api/storage/postgres/commit/fsync", 2001, 2201, types.SYS_ENTER_FSYNC, 70), + newTestFlameRecord("worker", "/srv/worker/queue/pop/read", 2002, 2202, types.SYS_ENTER_READ, 160), + newTestFlameRecord("worker", "/srv/worker/queue/push/write", 2002, 2202, types.SYS_ENTER_WRITE, 145), + newTestFlameRecord("worker", "/srv/worker/cache/redis/get/read", 2002, 2202, types.SYS_ENTER_READ, 95), + newTestFlameRecord("worker", "/srv/worker/cache/redis/set/write", 2002, 2202, types.SYS_ENTER_WRITE, 90), + newTestFlameRecord("ingest", "/srv/ingest/parser/csv/read", 2003, 2203, types.SYS_ENTER_READ, 110), + newTestFlameRecord("ingest", "/srv/ingest/parser/csv/normalize/write", 2003, 2203, types.SYS_ENTER_WRITE, 80), + newTestFlameRecord("ingest", "/srv/ingest/uploader/s3/put/writev", 2003, 2203, types.SYS_ENTER_WRITEV, 75), + newTestFlameRecord("batch", "/srv/batch/jobs/report/open", 2004, 2204, types.SYS_ENTER_OPENAT, 55), + newTestFlameRecord("batch", "/srv/batch/jobs/report/close", 2004, 2204, types.SYS_ENTER_CLOSE, 35), + newTestFlameRecord("batch", "/srv/batch/jobs/report/rename", 2004, 2204, types.SYS_ENTER_RENAMEAT, 20), + } +} + +func newTestFlameRecord(comm, path string, pid, tid uint32, traceID types.TraceId, weight uint64) IterRecord { + return IterRecord{ + Path: path, + TraceID: traceID, + Comm: comm, + Pid: pid, + Tid: tid, + Cnt: Counter{ + Count: weight, + Duration: weight * 1000, + DurationToPrev: weight * 350, + Bytes: weight * 4096, + }, + } +} + +func withTestFlameWeight(record IterRecord, weight uint64) IterRecord { + record.Cnt = Counter{ + Count: weight, + Duration: weight * 1000, + DurationToPrev: weight * 350, + Bytes: weight * 4096, + } + return record +} + +func liveTestWeight(record IterRecord, phase uint64) uint64 { + base := record.Cnt.Count + multiplier := uint64(1) + + switch phase { + case 0: + if record.Comm == "api" { + multiplier += 4 + } + if strings.Contains(record.Path, "/lib/") { + multiplier += 2 + } + case 1: + if record.Comm == "worker" { + multiplier += 4 + } + if strings.Contains(record.Path, "/queue/") { + multiplier += 2 + } + case 2: + if record.Comm == "ingest" { + multiplier += 4 + } + if strings.Contains(record.Path, "/uploader/") || strings.Contains(record.Path, "/parser/") { + multiplier += 2 + } + case 3: + if record.Comm == "batch" { + multiplier += 4 + } + if strings.Contains(record.Path, "/report/") { + multiplier += 2 + } + } + + if strings.Contains(record.Path, "/storage/") && phase%2 == 0 { + multiplier++ + } + if strings.Contains(record.Path, "/cache/") && phase%2 == 1 { + multiplier++ + } + return base * multiplier +} diff --git a/internal/flamegraph/trie.go b/internal/flamegraph/trie.go index dbd3de6..022b846 100644 --- a/internal/flamegraph/trie.go +++ b/internal/flamegraph/trie.go @@ -24,20 +24,7 @@ func newTrie() *trie { } func (t *trie) add(frames []string, value uint64) { - node := t.root - for _, frame := range frames { - child, ok := node.childMap[frame] - if !ok { - child = &trieNode{ - name: frame, - childMap: make(map[string]*trieNode), - } - node.children = append(node.children, child) - node.childMap[frame] = child - } - node = child - } - node.value += value + insertTriePath(t.root, frames, value) } func (t *trie) computeTotals() { diff --git a/internal/flamegraph/trie_insert.go b/internal/flamegraph/trie_insert.go new file mode 100644 index 0000000..7748b4a --- /dev/null +++ b/internal/flamegraph/trie_insert.go @@ -0,0 +1,22 @@ +package flamegraph + +// insertTriePath follows or creates nodes for frames and adds value at the leaf. +func insertTriePath(root *trieNode, frames []string, value uint64) { + node := root + for _, frame := range frames { + if node.childMap == nil { + node.childMap = make(map[string]*trieNode) + } + child, ok := node.childMap[frame] + if !ok { + child = &trieNode{ + name: frame, + childMap: make(map[string]*trieNode), + } + node.children = append(node.children, child) + node.childMap[frame] = child + } + node = child + } + node.value += value +} diff --git a/internal/flamegraph/webserver.go b/internal/flamegraph/webserver.go deleted file mode 100644 index c472dfb..0000000 --- a/internal/flamegraph/webserver.go +++ /dev/null @@ -1,199 +0,0 @@ -package flamegraph - -import ( - "context" - "fmt" - "net" - "net/http" - "os" - "os/signal" - "path/filepath" - "strings" - "syscall" - "time" -) - -type serverTimeouts struct { - readTimeout time.Duration - writeTimeout time.Duration - idleTimeout time.Duration -} - -var defaultServerTimeouts = serverTimeouts{ - readTimeout: 10 * time.Second, - writeTimeout: 30 * time.Second, - idleTimeout: 60 * time.Second, -} - -// ServeSVG starts a small HTTP server that serves a single flamegraph SVG. -// -// It prints a URL of the form http://HOSTNAME:PORT/abs/path/to.svg and blocks until -// the user presses Ctrl+C or the process receives SIGTERM, at which point the server -// is shut down gracefully. -func ServeSVG(svgFile string) error { - absPath, err := filepath.Abs(svgFile) - if err != nil { - return fmt.Errorf("resolve svg path: %w", err) - } - urlPath := buildURLPath(absPath) - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - return runServer(ctx, buildSVGHandler(absPath, urlPath), defaultServerTimeouts, func(hostname string, port int) { - printServerURL(hostname, port, urlPath) - }) -} - -// ServeSVGAutoReload serves an SVG viewer page that periodically reloads the SVG. -// -// The SVG file itself is still served directly at its absolute URL path, while "/" -// serves a small HTML wrapper that appends a cache-busting query parameter on each -// refresh interval to pick up newly written SVG content. -func ServeSVGAutoReload(svgFile string, refreshInterval time.Duration) error { - if refreshInterval <= 0 { - return fmt.Errorf("refresh interval must be > 0") - } - - absPath, err := filepath.Abs(svgFile) - if err != nil { - return fmt.Errorf("resolve svg path: %w", err) - } - urlPath := buildURLPath(absPath) - ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) - defer stop() - - mux := buildSVGAutoReloadHandler(absPath, urlPath, refreshInterval) - return runServer(ctx, mux, defaultServerTimeouts, func(hostname string, port int) { - printServerURL(hostname, port, "/") - }) -} - -func buildURLPath(absPath string) string { - urlPath := filepath.ToSlash(absPath) - if !strings.HasPrefix(urlPath, "/") { - return "/" + urlPath - } - return urlPath -} - -func buildSVGHandler(absPath, urlPath string) *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, urlPath, http.StatusFound) - }) - mux.HandleFunc(urlPath, func(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, absPath) - }) - return mux -} - -func buildSVGAutoReloadHandler(absPath, urlPath string, refreshInterval time.Duration) *http.ServeMux { - intervalMs := refreshInterval.Milliseconds() - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/html; charset=utf-8") - _, _ = fmt.Fprintf(w, `<!doctype html> -<html> -<head> - <meta charset="utf-8"/> - <meta name="viewport" content="width=device-width, initial-scale=1"/> - <title>I/O Flamegraph (Auto-Reload)</title> - <style> - body { margin: 0; font-family: monospace; } - .bar { padding: 8px 12px; border-bottom: 1px solid #ddd; } - .viewer { width: 100%%; height: calc(100vh - 42px); border: 0; display: block; } - </style> -</head> -<body> - <div class="bar"> - Auto-refresh every %d ms. - <button type="button" onclick="refreshNow()">Refresh now</button> - </div> - <iframe id="fg" class="viewer" src="%s"></iframe> - <script> - const base = %q; - function refreshNow() { - document.getElementById("fg").src = base + "?t=" + Date.now(); - } - setInterval(refreshNow, %d); - </script> -</body> -</html> -`, intervalMs, urlPath, urlPath, intervalMs) - }) - mux.HandleFunc(urlPath, func(w http.ResponseWriter, r *http.Request) { - http.ServeFile(w, r, absPath) - }) - return mux -} - -func listenRandomPort() (net.Listener, error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return nil, fmt.Errorf("start web server: %w", err) - } - return listener, nil -} - -func serverHostPort(listener net.Listener) (string, int) { - hostname, err := os.Hostname() - if err != nil { - hostname = "localhost" - } - port := listener.Addr().(*net.TCPAddr).Port - return hostname, port -} - -func printServerURL(hostname string, port int, urlPath string) { - fmt.Printf("Flamegraph available at http://%s:%d%s\n", hostname, port, urlPath) - fmt.Println("Press Ctrl+C to stop the web server.") -} - -func newHTTPServer(mux *http.ServeMux, timeouts serverTimeouts) *http.Server { - return &http.Server{ - Handler: mux, - ReadTimeout: timeouts.readTimeout, - WriteTimeout: timeouts.writeTimeout, - IdleTimeout: timeouts.idleTimeout, - } -} - -func runServer(ctx context.Context, mux *http.ServeMux, timeouts serverTimeouts, printURL func(hostname string, port int)) error { - srv := newHTTPServer(mux, timeouts) - - listener, err := listenRandomPort() - if err != nil { - return err - } - defer listener.Close() - - hostname, port := serverHostPort(listener) - if printURL != nil { - printURL(hostname, port) - } - - errCh := make(chan error, 1) - go func() { - errCh <- srv.Serve(listener) - }() - - select { - case <-ctx.Done(): - case serveErr := <-errCh: - if serveErr != nil && serveErr != http.ErrServerClosed { - return serveErr - } - return nil - } - - shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := srv.Shutdown(shutdownCtx); err != nil { - return fmt.Errorf("shutdown web server: %w", err) - } - - serveErr := <-errCh - if serveErr != nil && serveErr != http.ErrServerClosed { - return serveErr - } - return nil -} diff --git a/internal/flamegraph/webserver_autoreload_test.go b/internal/flamegraph/webserver_autoreload_test.go deleted file mode 100644 index ed4c907..0000000 --- a/internal/flamegraph/webserver_autoreload_test.go +++ /dev/null @@ -1,37 +0,0 @@ -package flamegraph - -import ( - "net/http/httptest" - "strings" - "testing" - "time" -) - -func TestBuildSVGAutoReloadHandlerServesViewerPage(t *testing.T) { - mux := buildSVGAutoReloadHandler("/tmp/fake.svg", "/tmp/fake.svg", 2*time.Second) - - rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "http://localhost/", nil) - mux.ServeHTTP(rec, req) - - if rec.Code != 200 { - t.Fatalf("status code = %d, want 200", rec.Code) - } - if got := rec.Header().Get("Content-Type"); !strings.HasPrefix(got, "text/html") { - t.Fatalf("content type = %q, want text/html", got) - } - body := rec.Body.String() - if !strings.Contains(body, "Auto-refresh every 2000 ms.") { - t.Fatalf("viewer page missing refresh interval, body=%q", body) - } - if !strings.Contains(body, `id="fg"`) { - t.Fatalf("viewer page missing iframe, body=%q", body) - } -} - -func TestServeSVGAutoReloadRejectsNonPositiveInterval(t *testing.T) { - err := ServeSVGAutoReload("ignored.svg", 0) - if err == nil { - t.Fatal("expected error for non-positive interval") - } -} diff --git a/internal/flamegraph/webserver_timeout_test.go b/internal/flamegraph/webserver_timeout_test.go deleted file mode 100644 index c1df7e5..0000000 --- a/internal/flamegraph/webserver_timeout_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package flamegraph - -import ( - "net/http" - "testing" - "time" -) - -func TestNewHTTPServerUsesConfiguredTimeouts(t *testing.T) { - mux := http.NewServeMux() - timeouts := serverTimeouts{ - readTimeout: 11 * time.Second, - writeTimeout: 44 * time.Second, - idleTimeout: 66 * time.Second, - } - - srv := newHTTPServer(mux, timeouts) - - if srv.Handler != mux { - t.Fatalf("Handler not set from mux") - } - if srv.ReadTimeout != timeouts.readTimeout { - t.Fatalf("ReadTimeout = %v, want %v", srv.ReadTimeout, timeouts.readTimeout) - } - if srv.WriteTimeout != timeouts.writeTimeout { - t.Fatalf("WriteTimeout = %v, want %v", srv.WriteTimeout, timeouts.writeTimeout) - } - if srv.IdleTimeout != timeouts.idleTimeout { - t.Fatalf("IdleTimeout = %v, want %v", srv.IdleTimeout, timeouts.idleTimeout) - } -} - -func TestLiveServerWriteTimeoutIsLongerThanDefault(t *testing.T) { - if liveServerTimeouts.readTimeout != defaultServerTimeouts.readTimeout { - t.Fatalf("read timeout mismatch: live=%v default=%v", liveServerTimeouts.readTimeout, defaultServerTimeouts.readTimeout) - } - if liveServerTimeouts.idleTimeout != defaultServerTimeouts.idleTimeout { - t.Fatalf("idle timeout mismatch: live=%v default=%v", liveServerTimeouts.idleTimeout, defaultServerTimeouts.idleTimeout) - } - if liveServerTimeouts.writeTimeout <= defaultServerTimeouts.writeTimeout { - t.Fatalf("expected live write timeout > default write timeout, got live=%v default=%v", liveServerTimeouts.writeTimeout, defaultServerTimeouts.writeTimeout) - } -} diff --git a/internal/flamegraph/worker.go b/internal/flamegraph/worker.go deleted file mode 100644 index 0f49568..0000000 --- a/internal/flamegraph/worker.go +++ /dev/null @@ -1,33 +0,0 @@ -package flamegraph - -import ( - "context" - "ior/internal/event" - "sync" -) - -type worker struct { - iod iorData - done chan struct{} -} - -func newWorker() worker { - return worker{iod: newIorData()} -} - -func (w worker) run(ctx context.Context, wg *sync.WaitGroup, ch <-chan *event.Pair) { - defer wg.Done() - - for { - select { - case ev, ok := <-ch: - if !ok { - return - } - w.iod.addEventPair(ev) - ev.Recycle() - case <-ctx.Done(): - return - } - } -} diff --git a/internal/generate/doc.go b/internal/generate/doc.go new file mode 100644 index 0000000..4a02b7b --- /dev/null +++ b/internal/generate/doc.go @@ -0,0 +1,2 @@ +// Package generate contains code-generation logic for tracepoint handlers and Go types. +package generate diff --git a/internal/generate/typesgo.go b/internal/generate/typesgo.go index ee24845..06ed49a 100644 --- a/internal/generate/typesgo.go +++ b/internal/generate/typesgo.go @@ -186,7 +186,7 @@ func writeTraceIdStringMethod(b *strings.Builder) { b.WriteString(`func (s TraceId) String() string { str, ok := traceId2String[s] if !ok { - panic(fmt.Sprintf("no string representation for trace ID %d found", s)) + return fmt.Sprintf("unknown_trace_id_%d", s) } return str } @@ -198,7 +198,7 @@ func writeTraceIdNameMethod(b *strings.Builder) { b.WriteString(`func (s TraceId) Name() string { str, ok := traceId2Name[s] if !ok { - panic(fmt.Sprintf("no name for trace ID %d found", s)) + return fmt.Sprintf("unknown_trace_id_%d", s) } return str } diff --git a/internal/generate/typesgo_test.go b/internal/generate/typesgo_test.go index 89dafa8..f600582 100644 --- a/internal/generate/typesgo_test.go +++ b/internal/generate/typesgo_test.go @@ -240,7 +240,7 @@ func TestGenerateTypesGoTraceIdMethods(t *testing.T) { requireContains(t, output, "func (s TraceId) String() string") requireContains(t, output, "func (s TraceId) Name() string") - requireContains(t, output, `panic(fmt.Sprintf("no string representation for trace ID %d found", s))`) + requireContains(t, output, `return fmt.Sprintf("unknown_trace_id_%d", s)`) } func TestGenerateTypesGoPackageDecl(t *testing.T) { diff --git a/internal/ior.go b/internal/ior.go index a6fdbc4..7836ef2 100644 --- a/internal/ior.go +++ b/internal/ior.go @@ -8,7 +8,10 @@ import ( "fmt" "os" "os/signal" + "runtime" "runtime/pprof" + "runtime/trace" + "sync" "syscall" "time" @@ -24,27 +27,17 @@ import ( bpf "github.com/aquasecurity/libbpfgo" ) -type tracepointProgram interface { - attachTracepoint(category, name string) (tracepointLink, error) -} - -type tracepointLink interface { - Destroy() error -} - var ( - runTraceFn = runTrace - runTraceWithContextFn = runTraceWithContext - runTUIFn = tui.RunWithTraceStarter - getEUID = os.Geteuid + runTraceFn = runTrace + runTraceWithContextFn = runTraceWithContext + runTUIFn = tui.RunWithTraceStarterConfig + runTUITestFlamesFn = tui.RunTestFlamesWithTraceStarterConfig + runTUITestLiveFlamesFn = tui.RunTestFlamesWithTraceStarterConfig + getEUID = os.Geteuid errRootPrivilegesRequired = errors.New("tracing requires root privileges (run with sudo)") ) -type tracepointModule interface { - getProgram(progName string) (tracepointProgram, error) -} - type libbpfTracepointProgram struct { prog *bpf.BPFProg } @@ -53,22 +46,10 @@ func (p libbpfTracepointProgram) AttachTracepoint(category, name string) (probem return p.prog.AttachTracepoint(category, name) } -func (p libbpfTracepointProgram) attachTracepoint(category, name string) (tracepointLink, error) { - return p.AttachTracepoint(category, name) -} - type libbpfTracepointModule struct { module *bpf.Module } -func (m libbpfTracepointModule) getProgram(progName string) (tracepointProgram, error) { - prog, err := m.module.GetProgram(progName) - if err != nil { - return nil, err - } - return libbpfTracepointProgram{prog: prog}, nil -} - func (m libbpfTracepointModule) GetProgram(progName string) (probemanager.Program, error) { prog, err := m.module.GetProgram(progName) if err != nil { @@ -77,167 +58,133 @@ func (m libbpfTracepointModule) GetProgram(progName string) (probemanager.Progra return libbpfTracepointProgram{prog: prog}, nil } -func attachTracepointsWith(module tracepointModule, shouldAttach func(string) bool, tracepointNames []string, verbose bool) error { - logln := func(...any) {} - logf := func(string, ...any) {} - if verbose { - logln = func(args ...any) { _, _ = fmt.Println(args...) } - logf = func(format string, args ...any) { _, _ = fmt.Printf(format, args...) } - } - - for _, name := range tracepointNames { - if !shouldAttach(name) { - continue - } - logln("Attaching tracepoint", name) - - prog, err := module.getProgram(fmt.Sprintf("handle_%s", name)) - if err != nil { - return fmt.Errorf("Failed to get BPF program handle_%s: %v", name, err) - } - logln("Attached prog handle_", name) - - if _, err = prog.attachTracepoint("syscalls", name); err != nil { - // OK, older Kernel versions may not have this tracepoint! - logf("Failed to attach to %s tracepoint: %v, kernel version may be too old, skipping", name, err) - continue - } - logln("Attached tracepoint ", name) - } - - return nil -} - // Run is the main entry point for the ior binary. -// -// When -ior=<trace.ior.zst> is provided it reads the compressed trace data, generates -// a native flamegraph SVG (using the selected fields and count metric) and then serves -// it via an embedded HTTP server. Without -ior, Run either executes trace mode or -// starts the TUI, depending on the active flags. func Run() error { flags.PrintVersion() - cfg := flags.Get() - iorFile := cfg.IorDataFile - var noTraceRun bool - - if iorFile != "" { - if cfg.IorWatchInterval < 0 { - return errors.New("-iorWatchInterval must be >= 0") - } - noTraceRun = true - native := flamegraph.NewNativeSVG(cfg.CollapsedFields, cfg.CountField) - svgFile, err := writeIorOutputs(native, iorFile, cfg.FlamegraphJSON) - if err != nil { - return err - } + return dispatchRun(flags.Get()) +} - done := make(chan struct{}) - defer close(done) - if cfg.IorWatchInterval > 0 { - go watchIorOutputs(done, cfg.IorWatchInterval, iorFile, native, cfg.FlamegraphJSON) - err = flamegraph.ServeSVGAutoReload(svgFile, cfg.IorWatchInterval) - } else { - err = flamegraph.ServeSVG(svgFile) - } - if err != nil { - return err - } +func dispatchRun(cfg flags.Config) error { + if err := validateRunConfig(cfg); err != nil { + return err } - - if noTraceRun { - return nil + if cfg.TestFlames { + return runTUITestFlamesFn(cfg, tuiTestFlamesStarter(cfg)) } - return dispatchRun(cfg) + if cfg.TestLiveFlames { + return runTUITestLiveFlamesFn(cfg, tuiTestLiveFlamesStarter(cfg)) + } + if shouldRunTraceMode(cfg) { + return runTraceFn(cfg) + } + return runTUIFn(cfg, tuiTraceStarterFromRunTrace(cfg, runTraceWithContextFn)) } -func writeIorOutputs(native flamegraph.NativeSVG, iorFile string, writeJSON bool) (string, error) { - svgFile, err := native.WriteSVGFromFile(iorFile) - if err != nil { - return "", err +func validateRunConfig(cfg flags.Config) error { + if cfg.TestFlames && cfg.PlainMode { + return errors.New("--testflames cannot be combined with -plain") } - if !writeJSON { - return svgFile, nil + if cfg.TestLiveFlames && cfg.PlainMode { + return errors.New("--testliveflames cannot be combined with -plain") } - if _, err := native.WriteJSONFromFile(iorFile); err != nil { - return "", err + if cfg.TestFlames && cfg.TestLiveFlames { + return errors.New("--testflames and --testliveflames are mutually exclusive") } - return svgFile, nil + return nil } -func watchIorOutputs(done <-chan struct{}, interval time.Duration, iorFile string, - native flamegraph.NativeSVG, writeJSON bool) { - - ticker := time.NewTicker(interval) - defer ticker.Stop() - lastMod := fileModTime(iorFile) - - for { - select { - case <-done: - return - case <-ticker.C: - mod := fileModTime(iorFile) - if !mod.After(lastMod) { - continue - } - if _, err := writeIorOutputs(native, iorFile, writeJSON); err != nil { - _, _ = fmt.Printf("Failed to refresh flamegraph outputs: %v\n", err) - continue - } - lastMod = mod - _, _ = fmt.Printf("Refreshed flamegraph outputs at %s\n", time.Now().Format(time.RFC3339)) +func tuiTestFlamesStarter(cfg flags.Config) tui.TraceStarter { + return func(ctx context.Context) error { + engine, streamBuf, liveTrie := buildTestFlamesRuntime(cfg) + if bindings, ok := tui.RuntimeBindingsFromContext(ctx); ok { + bindings.SetDashboardSnapshotSource(engine) + bindings.SetEventStreamSource(streamBuf) + bindings.SetLiveTrie(liveTrie) } + return nil } } -func fileModTime(path string) time.Time { - stat, err := os.Stat(path) - if err != nil { - return time.Time{} +func tuiTestLiveFlamesStarter(cfg flags.Config) tui.TraceStarter { + return func(ctx context.Context) error { + engine, streamBuf, liveTrie := buildTestLiveFlamesRuntime(ctx, cfg) + if bindings, ok := tui.RuntimeBindingsFromContext(ctx); ok { + bindings.SetDashboardSnapshotSource(engine) + bindings.SetEventStreamSource(streamBuf) + bindings.SetLiveTrie(liveTrie) + } + return nil } - return stat.ModTime() } -func dispatchRun(cfg flags.Flags) error { - if err := validateRunConfig(cfg); err != nil { - return err - } - if shouldRunTraceMode(cfg) { - return runTraceFn() - } - return runTUIFn(tuiTraceStarterFromRunTrace(runTraceWithContextFn)) +func buildTestFlamesRuntime(cfg flags.Config) (*statsengine.Engine, *eventstream.RingBuffer, *flamegraph.LiveTrie) { + engine := statsengine.NewEngine(64) + streamBuf := eventstream.NewRingBuffer() + liveTrie := flamegraph.NewLiveTrie(cfg.CollapsedFields, cfg.CountField) + flamegraph.SeedTestFlameData(liveTrie) + return engine, streamBuf, liveTrie } -func validateRunConfig(cfg flags.Flags) error { - if cfg.LiveFlamegraph && cfg.FlamegraphEnable { - return errors.New("-live and -flamegraph are mutually exclusive") +func buildTestLiveFlamesRuntime(ctx context.Context, cfg flags.Config) (*statsengine.Engine, *eventstream.RingBuffer, *flamegraph.LiveTrie) { + engine := statsengine.NewEngine(64) + streamBuf := eventstream.NewRingBuffer() + liveTrie := flamegraph.NewLiveTrie(cfg.CollapsedFields, cfg.CountField) + flamegraph.SeedTestLiveFlameData(liveTrie, 0) + + interval := cfg.LiveInterval + if interval <= 0 { + interval = 200 * time.Millisecond } - if cfg.IorWatchInterval > 0 && cfg.IorDataFile == "" { - return errors.New("-iorWatchInterval requires -ior") + go runSyntheticLiveFlames(ctx, liveTrie, interval) + return engine, streamBuf, liveTrie +} + +func runSyntheticLiveFlames(ctx context.Context, liveTrie *flamegraph.LiveTrie, interval time.Duration) { + if liveTrie == nil { + return } - if cfg.IorWatchInterval < 0 { - return errors.New("-iorWatchInterval must be >= 0") + ticker := time.NewTicker(interval) + defer ticker.Stop() + tick := uint64(1) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Keep a moving synthetic workload profile so the live test flamegraph + // visibly changes shape over time instead of only increasing totals. + liveTrie.Reset() + flamegraph.SeedTestLiveFlameData(liveTrie, tick) + tick++ + } } - return nil } -func shouldRunTraceMode(cfg flags.Flags) bool { - return cfg.PlainMode || cfg.FlamegraphEnable || cfg.LiveFlamegraph || cfg.PprofEnable +func shouldRunTraceMode(cfg flags.Config) bool { + return cfg.PlainMode } func tuiTraceStarterFromRunTrace( - startTrace func(context.Context, chan<- struct{}, func(*eventLoop)) error, + baseCfg flags.Config, + startTrace func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error, ) tui.TraceStarter { return func(ctx context.Context) error { bpf.SetLoggerCbs(bpf.Callbacks{ Log: func(int, string) {}, }) + cfg := baseCfg + if pidFilter, tidFilter, ok := tui.TraceFiltersFromContext(ctx); ok { + cfg.PidFilter = pidFilter + cfg.TidFilter = tidFilter + } engine := statsengine.NewEngine(64) streamBuf := eventstream.NewRingBuffer() + liveTrie := flamegraph.NewLiveTrie(cfg.CollapsedFields, cfg.CountField) if bindings, ok := tui.RuntimeBindingsFromContext(ctx); ok { bindings.SetDashboardSnapshotSource(engine) bindings.SetEventStreamSource(streamBuf) + bindings.SetLiveTrie(liveTrie) } streamEvents := make(chan eventstream.StreamEvent, 4096) @@ -251,10 +198,11 @@ func tuiTraceStarterFromRunTrace( errCh := make(chan error, 1) go func() { - err := startTrace(ctx, startedCh, func(el *eventLoop) { + err := startTrace(ctx, cfg, startedCh, func(el *eventLoop) { el.printCb = func(ep *event.Pair) { engine.Ingest(ep) streamEvents <- eventstream.NewStreamEvent(ep.EnterEv.GetTime(), ep) + liveTrie.Ingest(ep) ep.Recycle() } el.warningCb = func(message string) { @@ -281,107 +229,85 @@ func tuiTraceStarterFromRunTrace( } } -func runTrace() error { - return runTraceWithContext(context.Background(), nil, nil) +func runTrace(cfg flags.Config) error { + return runTraceWithContext(context.Background(), cfg, nil, nil) } -func newEventLoopConfig(cfg flags.Flags) eventLoopConfig { +func newEventLoopConfig(cfg flags.Config) eventLoopConfig { fields := make([]string, len(cfg.CollapsedFields)) copy(fields, cfg.CollapsedFields) return eventLoopConfig{ - pidFilter: cfg.PidFilter, - commFilter: cfg.CommFilter, - pathFilter: cfg.PathFilter, - liveFlamegraph: cfg.LiveFlamegraph, - liveInterval: cfg.LiveInterval, - liveOpenCommand: cfg.OpenCommand, - collapsedFields: fields, - countField: cfg.CountField, - flamegraphName: cfg.FlamegraphName, - flamegraphEnable: cfg.FlamegraphEnable, - pprofEnable: cfg.PprofEnable, - plainMode: cfg.PlainMode, + pidFilter: cfg.PidFilter, + commFilter: cfg.CommFilter, + pathFilter: cfg.PathFilter, + collapsedFields: fields, + countField: cfg.CountField, + pprofEnable: cfg.PprofEnable, + plainMode: cfg.PlainMode, } } -func runTraceWithContext(parentCtx context.Context, started chan<- struct{}, configure func(*eventLoop)) error { - if getEUID() != 0 { - return errRootPrivilegesRequired - } +type profilingControl struct { + done chan struct{} + enabled bool + cpuProfile *os.File + memProfile *os.File + stopExecTrace func() + stopOnce sync.Once +} - verbose := started == nil - logln := func(...any) {} - if verbose { - logln = func(args ...any) { _, _ = fmt.Println(args...) } +func newLogger(verbose bool) func(...any) { + if !verbose { + return func(...any) {} } - cfg := flags.Get() + return func(args ...any) { _, _ = fmt.Println(args...) } +} + +func setupBPFModule(parentCtx context.Context, cfg flags.Config) (*bpf.Module, *probemanager.Manager, func(), error) { + releaseBindings := func() {} bpfModule, err := bpf.NewModuleFromFile("ior.bpf.o") if err != nil { - return err + return nil, nil, releaseBindings, err } - defer bpfModule.Close() - if err := resizeBPFMaps(cfg, bpfModule); err != nil { - return err + bpfModule.Close() + return nil, nil, releaseBindings, err } - if err := setBPFGlobals(cfg, bpfModule); err != nil { - return err + bpfModule.Close() + return nil, nil, releaseBindings, err } - if err := bpfModule.BPFLoadObject(); err != nil { - return err + bpfModule.Close() + return nil, nil, releaseBindings, err } mgr := probemanager.NewManager(libbpfTracepointModule{module: bpfModule}) - defer mgr.Close() if err := mgr.AttachAll(cfg.ShouldIAttachTracepoint, tracepoints.List); err != nil { - return err + mgr.Close() + bpfModule.Close() + return nil, nil, releaseBindings, err } if bindings, ok := tui.RuntimeBindingsFromContext(parentCtx); ok { bindings.SetProbeManager(mgr) - defer bindings.SetProbeManager(nil) + releaseBindings = func() { bindings.SetProbeManager(nil) } } + return bpfModule, mgr, releaseBindings, nil +} - // 4096 channel size, minimises event drops +func setupEventChannel(bpfModule *bpf.Module) (chan []byte, error) { + // 4096 channel size minimizes event drops. ch := make(chan []byte, 4096) rb, err := bpfModule.InitRingBuf("event_map", ch) if err != nil { - return err + return nil, err } rb.Poll(300) + return ch, nil +} - pprofDone := make(chan struct{}) - var cpuProfile, memProfile *os.File - if cfg.PprofEnable { - if cpuProfile, err = os.Create("ior.cpuprofile"); err != nil { - return err - } - if memProfile, err = os.Create("ior.memprofile"); err != nil { - return err - } - pprof.StartCPUProfile(cpuProfile) - } else { - close(pprofDone) - } - - signalTraceStarted(started) - - el := newEventLoop(newEventLoopConfig(cfg)) - if configure != nil { - configure(el) - } - origPrintCb := el.printCb - el.printCb = func(ep *event.Pair) { - if !mgr.IsActive(ep.EnterEv.GetTraceId().Name()) { - ep.Recycle() - return - } - if origPrintCb != nil { - origPrintCb(ep) - } - } +func setupTraceContext(parentCtx context.Context, cfg flags.Config, logln func(...any)) (context.Context, context.CancelFunc, func()) { ctx := parentCtx cancel := func() {} if shouldAutoStopByDuration(cfg) { @@ -392,39 +318,173 @@ func runTraceWithContext(parentCtx context.Context, started chan<- struct{}, con logln("Probing until stopped...") ctx, cancel = context.WithCancel(parentCtx) } - defer cancel() signalCh := make(chan os.Signal, 1) signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM) - defer signal.Stop(signalCh) - + stopSignals := func() { + signal.Stop(signalCh) + } go func() { select { case <-signalCh: logln("Received signal, shutting down...") cancel() case <-ctx.Done(): - return } }() + return ctx, cancel, stopSignals +} + +func setupProfiling(ctx context.Context, cfg flags.Config, started chan<- struct{}) (*profilingControl, error) { + control := &profilingControl{ + done: make(chan struct{}), + stopExecTrace: func() {}, + } + if !cfg.PprofEnable { + close(control.done) + return control, nil + } + + control.enabled = true + isTUIMode := started != nil + cpuProfilePath, memProfilePath, execTracePath, execTraceDuration := profilingFilesForMode(isTUIMode) + + cpuProfile, err := os.Create(cpuProfilePath) + if err != nil { + return nil, err + } + memProfile, err := os.Create(memProfilePath) + if err != nil { + _ = cpuProfile.Close() + return nil, err + } + control.cpuProfile = cpuProfile + control.memProfile = memProfile + + if execTracePath != "" { + execTraceProfile, err := os.Create(execTracePath) + if err != nil { + _ = cpuProfile.Close() + _ = memProfile.Close() + return nil, err + } + if err := trace.Start(execTraceProfile); err != nil { + _ = cpuProfile.Close() + _ = memProfile.Close() + _ = execTraceProfile.Close() + return nil, err + } + var stopOnce sync.Once + control.stopExecTrace = func() { + stopOnce.Do(func() { + trace.Stop() + _ = execTraceProfile.Close() + }) + } + go func() { + timer := time.NewTimer(execTraceDuration) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + } + control.stopExecTrace() + }() + } + + if err := pprof.StartCPUProfile(cpuProfile); err != nil { + control.stopExecTrace() + _ = cpuProfile.Close() + _ = memProfile.Close() + return nil, err + } + return control, nil +} +func (p *profilingControl) stop(logln func(...any)) { + p.stopOnce.Do(func() { + if !p.enabled { + return + } + logln("Stopping profiling and writing profile files") + pprof.StopCPUProfile() + runtime.GC() + _ = pprof.WriteHeapProfile(p.memProfile) + p.stopExecTrace() + _ = p.cpuProfile.Close() + _ = p.memProfile.Close() + close(p.done) + }) +} + +func configureEventLoopOutput(el *eventLoop, mgr *probemanager.Manager, configure func(*eventLoop)) { + if configure != nil { + configure(el) + } + origPrintCb := el.printCb + el.printCb = func(ep *event.Pair) { + if !mgr.IsActive(ep.EnterEv.GetTraceId().Name()) { + ep.Recycle() + return + } + if origPrintCb != nil { + origPrintCb(ep) + } + } +} + +func startTraceShutdownWatcher(ctx context.Context, verbose bool, el *eventLoop, profiling *profilingControl, logln func(...any)) { go func() { <-ctx.Done() if verbose { fmt.Println(el.stats()) } - if cfg.PprofEnable { - logln("Stoppig profiling, writing ior.cpuprofile and ior.memprofile") - pprof.StopCPUProfile() - pprof.WriteHeapProfile(memProfile) - close(pprofDone) - } + profiling.stop(logln) }() +} + +func runTraceWithContext(parentCtx context.Context, cfg flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { + if getEUID() != 0 { + return errRootPrivilegesRequired + } + + verbose := started == nil + logln := newLogger(verbose) + + bpfModule, mgr, releaseBindings, err := setupBPFModule(parentCtx, cfg) + if err != nil { + return err + } + defer bpfModule.Close() + defer mgr.Close() + defer releaseBindings() + + ch, err := setupEventChannel(bpfModule) + if err != nil { + return err + } + ctx, cancel, stopSignals := setupTraceContext(parentCtx, cfg, logln) + defer cancel() + defer stopSignals() + + profiling, err := setupProfiling(ctx, cfg, started) + if err != nil { + return err + } + + signalTraceStarted(started) + + el, err := newEventLoop(newEventLoopConfig(cfg)) + if err != nil { + return err + } + configureEventLoopOutput(el, mgr, configure) + startTraceShutdownWatcher(ctx, verbose, el, profiling, logln) startTime := time.Now() el.run(ctx, ch) totalDuration := time.Since(startTime) - <-pprofDone + <-profiling.done logln("Good bye... (unloading BPF tracepoints will take a few seconds...) after", totalDuration) return nil } @@ -436,6 +496,13 @@ func signalTraceStarted(started chan<- struct{}) { close(started) } -func shouldAutoStopByDuration(cfg flags.Flags) bool { - return cfg.PlainMode || cfg.FlamegraphEnable || cfg.LiveFlamegraph || cfg.PprofEnable +func shouldAutoStopByDuration(cfg flags.Config) bool { + return cfg.PlainMode +} + +func profilingFilesForMode(tuiMode bool) (cpuProfilePath, memProfilePath, execTracePath string, execTraceDuration time.Duration) { + if tuiMode { + return "ior-tui-cpu.prof", "ior-tui-mem.prof", "ior-tui-trace.out", 10 * time.Second + } + return "ior.cpuprofile", "ior.memprofile", "", 0 } diff --git a/internal/ior_mode_test.go b/internal/ior_mode_test.go index bbca555..48b2c36 100644 --- a/internal/ior_mode_test.go +++ b/internal/ior_mode_test.go @@ -1,7 +1,9 @@ package internal import ( + "bytes" "context" + "encoding/json" "errors" "testing" "time" @@ -11,7 +13,7 @@ import ( ) func TestShouldRunTraceMode(t *testing.T) { - base := flags.Flags{} + base := flags.Config{} if shouldRunTraceMode(base) { t.Fatalf("expected default mode to use TUI") @@ -23,27 +25,27 @@ func TestShouldRunTraceMode(t *testing.T) { t.Fatalf("expected plain mode to use trace mode") } - withFlamegraph := base - withFlamegraph.FlamegraphEnable = true - if !shouldRunTraceMode(withFlamegraph) { - t.Fatalf("expected flamegraph mode to use trace mode") - } - withPprof := base withPprof.PprofEnable = true - if !shouldRunTraceMode(withPprof) { - t.Fatalf("expected pprof mode to use trace mode") + if shouldRunTraceMode(withPprof) { + t.Fatalf("expected pprof flag alone to keep TUI mode") + } + + withTestFlames := base + withTestFlames.TestFlames = true + if shouldRunTraceMode(withTestFlames) { + t.Fatalf("expected --testflames to stay in TUI mode") } - withLive := base - withLive.LiveFlamegraph = true - if !shouldRunTraceMode(withLive) { - t.Fatalf("expected live mode to use trace mode") + withTestLiveFlames := base + withTestLiveFlames.TestLiveFlames = true + if shouldRunTraceMode(withTestLiveFlames) { + t.Fatalf("expected --testliveflames to stay in TUI mode") } } func TestShouldAutoStopByDuration(t *testing.T) { - base := flags.Flags{} + base := flags.Config{} if shouldAutoStopByDuration(base) { t.Fatalf("expected default TUI mode not to auto-stop by duration") } @@ -54,45 +56,46 @@ func TestShouldAutoStopByDuration(t *testing.T) { t.Fatalf("expected plain mode to auto-stop by duration") } - withFlamegraph := base - withFlamegraph.FlamegraphEnable = true - if !shouldAutoStopByDuration(withFlamegraph) { - t.Fatalf("expected flamegraph mode to auto-stop by duration") - } - withPprof := base withPprof.PprofEnable = true - if !shouldAutoStopByDuration(withPprof) { - t.Fatalf("expected pprof mode to auto-stop by duration") + if shouldAutoStopByDuration(withPprof) { + t.Fatalf("expected pprof flag alone not to auto-stop by duration") } - withLive := base - withLive.LiveFlamegraph = true - if !shouldAutoStopByDuration(withLive) { - t.Fatalf("expected live mode to auto-stop by duration") - } } func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { origRunTrace := runTraceFn origRunTUI := runTUIFn + origRunTUITestFlames := runTUITestFlamesFn + origRunTUITestLiveFlames := runTUITestLiveFlamesFn defer func() { runTraceFn = origRunTrace runTUIFn = origRunTUI + runTUITestFlamesFn = origRunTUITestFlames + runTUITestLiveFlamesFn = origRunTUITestLiveFlames }() traceCalled := false tuiCalled := false - runTraceFn = func() error { + runTraceFn = func(flags.Config) error { traceCalled = true return nil } - runTUIFn = func(tui.TraceStarter) error { + runTUIFn = func(flags.Config, tui.TraceStarter) error { tuiCalled = true return nil } + runTUITestFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestFlamesFn should not be called in trace mode") + return nil + } + runTUITestLiveFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestLiveFlamesFn should not be called in trace mode") + return nil + } - cfg := flags.Flags{PlainMode: true} + cfg := flags.Config{PlainMode: true} if err := dispatchRun(cfg); err != nil { t.Fatalf("dispatchRun returned error: %v", err) } @@ -104,16 +107,63 @@ func TestDispatchRunUsesTraceModeWhenRequested(t *testing.T) { } } +func TestDispatchRunUsesTUIWhenOnlyPprofEnabled(t *testing.T) { + origRunTrace := runTraceFn + origRunTUI := runTUIFn + origRunTUITestFlames := runTUITestFlamesFn + origRunTUITestLiveFlames := runTUITestLiveFlamesFn + defer func() { + runTraceFn = origRunTrace + runTUIFn = origRunTUI + runTUITestFlamesFn = origRunTUITestFlames + runTUITestLiveFlamesFn = origRunTUITestLiveFlames + }() + + traceCalled := false + tuiCalled := false + runTraceFn = func(flags.Config) error { + traceCalled = true + return nil + } + runTUIFn = func(flags.Config, tui.TraceStarter) error { + tuiCalled = true + return nil + } + runTUITestFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestFlamesFn should not be called for regular TUI mode") + return nil + } + runTUITestLiveFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestLiveFlamesFn should not be called for regular TUI mode") + return nil + } + + cfg := flags.Config{PprofEnable: true} + if err := dispatchRun(cfg); err != nil { + t.Fatalf("dispatchRun returned error: %v", err) + } + if traceCalled { + t.Fatalf("did not expect runTraceFn when only -pprof is enabled") + } + if !tuiCalled { + t.Fatalf("expected runTUIFn to be called") + } +} + func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { origRunTraceWithContext := runTraceWithContextFn origRunTUI := runTUIFn + origRunTUITestFlames := runTUITestFlamesFn + origRunTUITestLiveFlames := runTUITestLiveFlamesFn defer func() { runTraceWithContextFn = origRunTraceWithContext runTUIFn = origRunTUI + runTUITestFlamesFn = origRunTUITestFlames + runTUITestLiveFlamesFn = origRunTUITestLiveFlames }() traceDone := make(chan struct{}, 1) - runTraceWithContextFn = func(_ context.Context, started chan<- struct{}, configure func(*eventLoop)) error { + runTraceWithContextFn = func(_ context.Context, _ flags.Config, started chan<- struct{}, configure func(*eventLoop)) error { if configure != nil { configure(&eventLoop{}) } @@ -123,7 +173,7 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } tuiCalled := false - runTUIFn = func(starter tui.TraceStarter) error { + runTUIFn = func(_ flags.Config, starter tui.TraceStarter) error { tuiCalled = true if starter == nil { t.Fatalf("expected non-nil starter") @@ -133,8 +183,16 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } return nil } + runTUITestFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestFlamesFn should not be called for normal starter path") + return nil + } + runTUITestLiveFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestLiveFlamesFn should not be called for normal starter path") + return nil + } - cfg := flags.Flags{} + cfg := flags.Config{} if err := dispatchRun(cfg); err != nil { t.Fatalf("dispatchRun returned error: %v", err) } @@ -149,61 +207,209 @@ func TestDispatchRunUsesTUIStarterWhenNotPlain(t *testing.T) { } } -func TestDispatchRunRejectsLiveAndFlamegraph(t *testing.T) { +func TestDispatchRunUsesTestFlamesModeWhenRequested(t *testing.T) { + origRunTrace := runTraceFn + origRunTUI := runTUIFn + origRunTUITestFlames := runTUITestFlamesFn + origRunTUITestLiveFlames := runTUITestLiveFlamesFn + defer func() { + runTraceFn = origRunTrace + runTUIFn = origRunTUI + runTUITestFlamesFn = origRunTUITestFlames + runTUITestLiveFlamesFn = origRunTUITestLiveFlames + }() + + traceCalled := false + regularTUICalled := false + testFlamesCalled := false + runTraceFn = func(flags.Config) error { + traceCalled = true + return nil + } + runTUIFn = func(flags.Config, tui.TraceStarter) error { + regularTUICalled = true + return nil + } + runTUITestFlamesFn = func(_ flags.Config, starter tui.TraceStarter) error { + testFlamesCalled = true + if starter == nil { + t.Fatalf("expected non-nil starter for test flames mode") + } + return starter(context.Background()) + } + runTUITestLiveFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestLiveFlamesFn should not be called for --testflames") + return nil + } + + cfg := flags.Config{TestFlames: true} + if err := dispatchRun(cfg); err != nil { + t.Fatalf("dispatchRun returned error: %v", err) + } + if traceCalled { + t.Fatalf("did not expect runTraceFn for test flames mode") + } + if regularTUICalled { + t.Fatalf("did not expect runTUIFn for test flames mode") + } + if !testFlamesCalled { + t.Fatalf("expected runTUITestFlamesFn to be called") + } +} + +func TestDispatchRunUsesTestLiveFlamesModeWhenRequested(t *testing.T) { origRunTrace := runTraceFn origRunTUI := runTUIFn + origRunTUITestFlames := runTUITestFlamesFn + origRunTUITestLiveFlames := runTUITestLiveFlamesFn defer func() { runTraceFn = origRunTrace runTUIFn = origRunTUI + runTUITestFlamesFn = origRunTUITestFlames + runTUITestLiveFlamesFn = origRunTUITestLiveFlames }() - runTraceFn = func() error { - t.Fatalf("runTraceFn should not be called for invalid flag combos") + traceCalled := false + regularTUICalled := false + testLiveFlamesCalled := false + runTraceFn = func(flags.Config) error { + traceCalled = true + return nil + } + runTUIFn = func(flags.Config, tui.TraceStarter) error { + regularTUICalled = true return nil } - runTUIFn = func(tui.TraceStarter) error { - t.Fatalf("runTUIFn should not be called for invalid flag combos") + runTUITestFlamesFn = func(flags.Config, tui.TraceStarter) error { + t.Fatalf("runTUITestFlamesFn should not be called for --testliveflames") return nil } + runTUITestLiveFlamesFn = func(_ flags.Config, starter tui.TraceStarter) error { + testLiveFlamesCalled = true + if starter == nil { + t.Fatalf("expected non-nil starter for test live flames mode") + } + return starter(context.Background()) + } + + cfg := flags.Config{TestLiveFlames: true} + if err := dispatchRun(cfg); err != nil { + t.Fatalf("dispatchRun returned error: %v", err) + } + if traceCalled { + t.Fatalf("did not expect runTraceFn for test live flames mode") + } + if regularTUICalled { + t.Fatalf("did not expect runTUIFn for test live flames mode") + } + if !testLiveFlamesCalled { + t.Fatalf("expected runTUITestLiveFlamesFn to be called") + } +} - cfg := flags.Flags{LiveFlamegraph: true, FlamegraphEnable: true} - err := dispatchRun(cfg) +func TestValidateRunConfigRejectsTestFlamesWithTraceFlags(t *testing.T) { + cfg := flags.Config{TestFlames: true, PlainMode: true} + err := validateRunConfig(cfg) if err == nil { - t.Fatalf("expected error for -live with -flamegraph") + t.Fatalf("expected error for --testflames with trace-mode flags") } - if err.Error() != "-live and -flamegraph are mutually exclusive" { + if err.Error() != "--testflames cannot be combined with -plain" { t.Fatalf("unexpected error: %v", err) } } -func TestValidateRunConfigRejectsIorWatchWithoutIor(t *testing.T) { - cfg := flags.Flags{IorWatchInterval: time.Second} +func TestValidateRunConfigRejectsTestLiveFlamesWithTraceFlags(t *testing.T) { + cfg := flags.Config{TestLiveFlames: true, PlainMode: true} err := validateRunConfig(cfg) if err == nil { - t.Fatalf("expected error for -iorWatchInterval without -ior") + t.Fatalf("expected error for --testliveflames with trace-mode flags") } - if err.Error() != "-iorWatchInterval requires -ior" { + if err.Error() != "--testliveflames cannot be combined with -plain" { t.Fatalf("unexpected error: %v", err) } } -func TestValidateRunConfigRejectsNegativeIorWatchInterval(t *testing.T) { - cfg := flags.Flags{IorWatchInterval: -time.Second} +func TestValidateRunConfigRejectsBothTestModes(t *testing.T) { + cfg := flags.Config{TestFlames: true, TestLiveFlames: true} err := validateRunConfig(cfg) if err == nil { - t.Fatalf("expected error for negative -iorWatchInterval") + t.Fatalf("expected error when both test flame modes are enabled") } - if err.Error() != "-iorWatchInterval must be >= 0" { + if err.Error() != "--testflames and --testliveflames are mutually exclusive" { t.Fatalf("unexpected error: %v", err) } } +func TestBuildTestFlamesRuntimeSeedsLiveTrie(t *testing.T) { + cfg := flags.NewFlags() + _, streamBuf, liveTrie := buildTestFlamesRuntime(cfg) + if streamBuf == nil { + t.Fatalf("expected stream buffer in test flames runtime") + } + if liveTrie == nil { + t.Fatalf("expected live trie in test flames runtime") + } + if liveTrie.Version() == 0 { + t.Fatalf("expected seeded live trie version to be non-zero") + } + + payload, _ := liveTrie.SnapshotJSON() + var snap map[string]any + if err := json.Unmarshal(payload, &snap); err != nil { + t.Fatalf("decode snapshot: %v", err) + } + total, ok := snap["t"].(float64) + if !ok || total <= 0 { + t.Fatalf("expected seeded snapshot total > 0, got %v", snap["t"]) + } +} + +func TestBuildTestLiveFlamesRuntimeContinuouslyUpdatesLiveTrie(t *testing.T) { + cfg := flags.NewFlags() + cfg.LiveInterval = 15 * time.Millisecond + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, streamBuf, liveTrie := buildTestLiveFlamesRuntime(ctx, cfg) + if streamBuf == nil { + t.Fatalf("expected stream buffer in test live flames runtime") + } + if liveTrie == nil { + t.Fatalf("expected live trie in test live flames runtime") + } + + initialVersion := liveTrie.Version() + if initialVersion == 0 { + t.Fatalf("expected seeded live trie version to be non-zero") + } + initialSnapshot, _ := liveTrie.SnapshotJSON() + + sawUpdate := false + deadline := time.Now().Add(300 * time.Millisecond) + for time.Now().Before(deadline) { + if liveTrie.Version() <= initialVersion { + time.Sleep(10 * time.Millisecond) + continue + } + currentSnapshot, _ := liveTrie.SnapshotJSON() + if !bytes.Equal(initialSnapshot, currentSnapshot) { + sawUpdate = true + break + } + time.Sleep(10 * time.Millisecond) + } + if !sawUpdate { + t.Fatalf("expected test live flames snapshot shape to change over time (version > %d)", initialVersion) + } +} + func TestRunTraceWithContextRequiresRoot(t *testing.T) { origGetEUID := getEUID defer func() { getEUID = origGetEUID }() getEUID = func() int { return 1000 } - err := runTraceWithContext(context.Background(), nil, nil) + err := runTraceWithContext(context.Background(), flags.NewFlags(), nil, nil) if !errors.Is(err, errRootPrivilegesRequired) { t.Fatalf("expected root-required error, got %v", err) } @@ -211,7 +417,10 @@ func TestRunTraceWithContextRequiresRoot(t *testing.T) { func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { starter := tuiTraceStarterFromRunTrace( - func(context.Context, chan<- struct{}, func(*eventLoop)) error { return errors.New("startup failed") }, + flags.NewFlags(), + func(context.Context, flags.Config, chan<- struct{}, func(*eventLoop)) error { + return errors.New("startup failed") + }, ) err := starter(context.Background()) @@ -220,9 +429,55 @@ func TestTuiTraceStarterFromRunTracePropagatesError(t *testing.T) { } } +func TestTuiTraceStarterFromRunTraceUsesContextFilters(t *testing.T) { + base := flags.NewFlags() + base.PidFilter = 11 + base.TidFilter = 12 + + var gotCfg flags.Config + starter := tuiTraceStarterFromRunTrace( + base, + func(_ context.Context, cfg flags.Config, started chan<- struct{}, _ func(*eventLoop)) error { + gotCfg = cfg + close(started) + return nil + }, + ) + + ctx := tui.ContextWithTraceFilters(context.Background(), 2222, 3333) + if err := starter(ctx); err != nil { + t.Fatalf("starter returned error: %v", err) + } + if gotCfg.PidFilter != 2222 { + t.Fatalf("expected pid filter from context, got %d", gotCfg.PidFilter) + } + if gotCfg.TidFilter != 3333 { + t.Fatalf("expected tid filter from context, got %d", gotCfg.TidFilter) + } +} + +func TestProfilingFilesForMode(t *testing.T) { + cpu, mem, execTrace, duration := profilingFilesForMode(false) + if cpu != "ior.cpuprofile" || mem != "ior.memprofile" { + t.Fatalf("unexpected trace-mode profiling file names: cpu=%q mem=%q", cpu, mem) + } + if execTrace != "" || duration != 0 { + t.Fatalf("expected trace-mode execution tracing to be disabled, got trace=%q duration=%s", execTrace, duration) + } + + cpu, mem, execTrace, duration = profilingFilesForMode(true) + if cpu != "ior-tui-cpu.prof" || mem != "ior-tui-mem.prof" || execTrace != "ior-tui-trace.out" { + t.Fatalf("unexpected TUI profiling file names: cpu=%q mem=%q trace=%q", cpu, mem, execTrace) + } + if duration != 10*time.Second { + t.Fatalf("expected 10s TUI execution trace duration, got %s", duration) + } +} + func TestTuiTraceStarterFromRunTraceRespectsCancel(t *testing.T) { starter := tuiTraceStarterFromRunTrace( - func(ctx context.Context, _ chan<- struct{}, _ func(*eventLoop)) error { + flags.NewFlags(), + func(ctx context.Context, _ flags.Config, _ chan<- struct{}, _ func(*eventLoop)) error { <-ctx.Done() return ctx.Err() }, diff --git a/internal/ior_test.go b/internal/ior_test.go deleted file mode 100644 index 43e8091..0000000 --- a/internal/ior_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package internal - -import ( - "errors" - "strings" - "testing" -) - -type fakeTracepointProgram struct { - attachCalls int - attachErr error -} - -type fakeTracepointLink struct{} - -func (fakeTracepointLink) Destroy() error { return nil } - -func (p *fakeTracepointProgram) attachTracepoint(_, _ string) (tracepointLink, error) { - p.attachCalls++ - if p.attachErr != nil { - return nil, p.attachErr - } - return fakeTracepointLink{}, nil -} - -type fakeTracepointModule struct { - getProgramCalls []string - getProgramErrs map[string]error - programs map[string]*fakeTracepointProgram -} - -func (m *fakeTracepointModule) getProgram(progName string) (tracepointProgram, error) { - m.getProgramCalls = append(m.getProgramCalls, progName) - if err, ok := m.getProgramErrs[progName]; ok { - return nil, err - } - if prog, ok := m.programs[progName]; ok { - return prog, nil - } - return nil, errors.New("missing program") -} - -func TestAttachTracepointsWithSkipsFilteredTracepoints(t *testing.T) { - module := &fakeTracepointModule{ - programs: map[string]*fakeTracepointProgram{ - "handle_sys_enter_read": {}, - "handle_sys_enter_write": {}, - }, - getProgramErrs: map[string]error{}, - } - - err := attachTracepointsWith(module, func(tracepoint string) bool { - return tracepoint == "sys_enter_read" - }, []string{"sys_enter_read", "sys_enter_write"}, false) - if err != nil { - t.Fatalf("attachTracepointsWith returned error: %v", err) - } - - if len(module.getProgramCalls) != 1 || module.getProgramCalls[0] != "handle_sys_enter_read" { - t.Fatalf("getProgram calls = %v, want only handle_sys_enter_read", module.getProgramCalls) - } - - if module.programs["handle_sys_enter_read"].attachCalls != 1 { - t.Fatalf("read attach calls = %d, want 1", module.programs["handle_sys_enter_read"].attachCalls) - } - if module.programs["handle_sys_enter_write"].attachCalls != 0 { - t.Fatalf("write attach calls = %d, want 0", module.programs["handle_sys_enter_write"].attachCalls) - } -} - -func TestAttachTracepointsWithReturnsErrorWhenProgramMissing(t *testing.T) { - module := &fakeTracepointModule{ - programs: map[string]*fakeTracepointProgram{}, - getProgramErrs: map[string]error{ - "handle_sys_enter_read": errors.New("not found"), - }, - } - - err := attachTracepointsWith(module, func(string) bool { return true }, []string{"sys_enter_read"}, false) - if err == nil { - t.Fatal("attachTracepointsWith returned nil error, want non-nil") - } - if !strings.Contains(err.Error(), "handle_sys_enter_read") { - t.Fatalf("error %q does not mention handle_sys_enter_read", err) - } -} - -func TestAttachTracepointsWithAttachFailureContinues(t *testing.T) { - module := &fakeTracepointModule{ - programs: map[string]*fakeTracepointProgram{ - "handle_sys_enter_read": {attachErr: errors.New("no tracepoint")}, - "handle_sys_enter_write": {}, - }, - getProgramErrs: map[string]error{}, - } - - err := attachTracepointsWith(module, func(string) bool { return true }, []string{"sys_enter_read", "sys_enter_write"}, false) - if err != nil { - t.Fatalf("attachTracepointsWith returned error: %v", err) - } - - if module.programs["handle_sys_enter_read"].attachCalls != 1 { - t.Fatalf("read attach calls = %d, want 1", module.programs["handle_sys_enter_read"].attachCalls) - } - if module.programs["handle_sys_enter_write"].attachCalls != 1 { - t.Fatalf("write attach calls = %d, want 1", module.programs["handle_sys_enter_write"].attachCalls) - } -} diff --git a/internal/probemanager/doc.go b/internal/probemanager/doc.go new file mode 100644 index 0000000..dd940b4 --- /dev/null +++ b/internal/probemanager/doc.go @@ -0,0 +1,2 @@ +// Package probemanager tracks probe enablement state and runtime toggling operations. +package probemanager diff --git a/internal/probemanager/manager.go b/internal/probemanager/manager.go index b991c7c..7feb407 100644 --- a/internal/probemanager/manager.go +++ b/internal/probemanager/manager.go @@ -50,6 +50,7 @@ type Manager struct { closed bool } +// NewManager creates a new probe manager that resolves programs via attacher. func NewManager(attacher Attacher) *Manager { return &Manager{ attacher: attacher, @@ -57,6 +58,7 @@ func NewManager(attacher Attacher) *Manager { } } +// Register registers the enter/exit tracepoint pair for a syscall key. func (m *Manager) Register(syscall string, pair TracepointPair) { if m == nil || syscall == "" { return @@ -74,6 +76,7 @@ func (m *Manager) Register(syscall string, pair TracepointPair) { entry.exitTP = pair.Exit } +// AttachAll registers and attaches all tracepoint pairs selected by shouldAttach. func (m *Manager) AttachAll(shouldAttach func(string) bool, tpNames []string) error { if m == nil { return errors.New("probe manager is nil") @@ -95,6 +98,7 @@ func (m *Manager) AttachAll(shouldAttach func(string) bool, tpNames []string) er return nil } +// Toggle flips a syscall probe between attached and detached states. func (m *Manager) Toggle(syscall string) error { if m == nil { return errors.New("probe manager is nil") @@ -118,6 +122,7 @@ func (m *Manager) Toggle(syscall string) error { return m.Attach(syscall) } +// Attach attaches enter/exit tracepoints for a registered syscall. func (m *Manager) Attach(syscall string) error { if syscall == "" { return errors.New("syscall is required") @@ -167,6 +172,7 @@ func (m *Manager) Attach(syscall string) error { return nil } +// Detach detaches enter/exit tracepoints for a registered syscall. func (m *Manager) Detach(syscall string) error { if syscall == "" { return errors.New("syscall is required") @@ -220,6 +226,7 @@ func (m *Manager) Detach(syscall string) error { return combined } +// States returns a stable snapshot of all known probe states. func (m *Manager) States() []ProbeState { if m == nil { return nil @@ -243,6 +250,7 @@ func (m *Manager) States() []ProbeState { return out } +// ActiveCount returns the number of active probes and total registered probes. func (m *Manager) ActiveCount() (active, total int) { if m == nil { return 0, 0 @@ -276,6 +284,7 @@ func (m *Manager) IsActive(syscall string) bool { return entry.active } +// Close detaches all registered probes and marks the manager closed. func (m *Manager) Close() error { if m == nil { return nil diff --git a/internal/statsengine/bench_test.go b/internal/statsengine/bench_test.go index 27f17b1..646bdda 100644 --- a/internal/statsengine/bench_test.go +++ b/internal/statsengine/bench_test.go @@ -1,10 +1,11 @@ package statsengine import ( - "ior/internal/types" "math/rand" "testing" "time" + + "ior/internal/types" ) func BenchmarkSyscallAccumulatorSnapshot(b *testing.B) { diff --git a/internal/statsengine/doc.go b/internal/statsengine/doc.go new file mode 100644 index 0000000..2d2c0c0 --- /dev/null +++ b/internal/statsengine/doc.go @@ -0,0 +1,2 @@ +// Package statsengine aggregates trace events into dashboard snapshot statistics. +package statsengine diff --git a/internal/statsengine/engine.go b/internal/statsengine/engine.go index 1ef58cf..fb85558 100644 --- a/internal/statsengine/engine.go +++ b/internal/statsengine/engine.go @@ -1,11 +1,12 @@ package statsengine import ( - "ior/internal/event" - "ior/internal/types" "math" "sync" "time" + + "ior/internal/event" + "ior/internal/types" ) const trendWindowSlots = 20 diff --git a/internal/statsengine/engine_test.go b/internal/statsengine/engine_test.go index 943fe9c..7ba8c3a 100644 --- a/internal/statsengine/engine_test.go +++ b/internal/statsengine/engine_test.go @@ -1,12 +1,13 @@ package statsengine import ( - "ior/internal/event" - "ior/internal/file" - "ior/internal/types" "math" "testing" "time" + + "ior/internal/event" + "ior/internal/file" + "ior/internal/types" ) type fakeClock struct { diff --git a/internal/statsengine/filerank.go b/internal/statsengine/filerank.go index 6e8f27f..dd83e8d 100644 --- a/internal/statsengine/filerank.go +++ b/internal/statsengine/filerank.go @@ -2,9 +2,10 @@ package statsengine import ( "container/heap" + "sort" + "ior/internal/event" "ior/internal/types" - "sort" ) const fileRankTopNDefault = 20 diff --git a/internal/statsengine/filerank_test.go b/internal/statsengine/filerank_test.go index 26a0b23..bec5eae 100644 --- a/internal/statsengine/filerank_test.go +++ b/internal/statsengine/filerank_test.go @@ -2,11 +2,12 @@ package statsengine import ( "fmt" + "reflect" + "testing" + "ior/internal/event" "ior/internal/file" "ior/internal/types" - "reflect" - "testing" ) func TestFileRankerHeapEviction(t *testing.T) { diff --git a/internal/statsengine/process.go b/internal/statsengine/process.go index e677744..b00a4bb 100644 --- a/internal/statsengine/process.go +++ b/internal/statsengine/process.go @@ -1,9 +1,10 @@ package statsengine import ( - "ior/internal/event" "sort" "time" + + "ior/internal/event" ) const processRankTopNDefault = 20 diff --git a/internal/statsengine/process_test.go b/internal/statsengine/process_test.go index aa3c5d2..77e7a0a 100644 --- a/internal/statsengine/process_test.go +++ b/internal/statsengine/process_test.go @@ -1,11 +1,12 @@ package statsengine import ( - "ior/internal/event" - "ior/internal/types" "math" "testing" "time" + + "ior/internal/event" + "ior/internal/types" ) func TestProcessAccumulatorBasicStats(t *testing.T) { diff --git a/internal/statsengine/snapshot.go b/internal/statsengine/snapshot.go index 8c9656e..f2b617b 100644 --- a/internal/statsengine/snapshot.go +++ b/internal/statsengine/snapshot.go @@ -1,9 +1,10 @@ package statsengine import ( - "ior/internal/types" "slices" "time" + + "ior/internal/types" ) // TrendDirection is the direction of a time-window comparison. diff --git a/internal/statsengine/syscall.go b/internal/statsengine/syscall.go index 6c34f4a..4feeab2 100644 --- a/internal/statsengine/syscall.go +++ b/internal/statsengine/syscall.go @@ -1,12 +1,13 @@ package statsengine import ( - "ior/internal/event" - "ior/internal/types" "math" "math/rand" "sort" "time" + + "ior/internal/event" + "ior/internal/types" ) const syscallReservoirSampleCapDefault = 10_000 diff --git a/internal/statsengine/syscall_test.go b/internal/statsengine/syscall_test.go index 1ebe214..b315bd8 100644 --- a/internal/statsengine/syscall_test.go +++ b/internal/statsengine/syscall_test.go @@ -1,12 +1,13 @@ package statsengine import ( - "ior/internal/event" - "ior/internal/types" "math" "math/rand" "testing" "time" + + "ior/internal/event" + "ior/internal/types" ) func TestSyscallAccumulatorBasicStats(t *testing.T) { diff --git a/internal/tracepoints/doc.go b/internal/tracepoints/doc.go new file mode 100644 index 0000000..8cd0d1f --- /dev/null +++ b/internal/tracepoints/doc.go @@ -0,0 +1,2 @@ +// Package tracepoints exposes generated syscall tracepoint metadata. +package tracepoints diff --git a/internal/tui/common/doc.go b/internal/tui/common/doc.go new file mode 100644 index 0000000..e15ceb7 --- /dev/null +++ b/internal/tui/common/doc.go @@ -0,0 +1,2 @@ +// Package common provides shared TUI styling, keymaps, and viewport utilities. +package common diff --git a/internal/tui/common/keys.go b/internal/tui/common/keys.go index ba17998..ab9865d 100644 --- a/internal/tui/common/keys.go +++ b/internal/tui/common/keys.go @@ -1,6 +1,6 @@ package common -import "github.com/charmbracelet/bubbles/key" +import "charm.land/bubbles/v2/key" // HelpSection groups related key bindings under a shared heading. type HelpSection struct { @@ -38,12 +38,12 @@ func DefaultKeyMap() KeyMap { return KeyMap{ Tab: key.NewBinding(key.WithKeys("tab"), key.WithHelp("tab", "next tab")), ShiftTab: key.NewBinding(key.WithKeys("shift+tab"), key.WithHelp("shift+tab", "prev tab")), - One: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "overview")), - Two: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "syscalls")), - Three: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "files")), - Four: key.NewBinding(key.WithKeys("4"), key.WithHelp("4", "processes")), - Five: key.NewBinding(key.WithKeys("5"), key.WithHelp("5", "lat+gaps")), - Six: key.NewBinding(key.WithKeys("6"), key.WithHelp("6", "stream")), + One: key.NewBinding(key.WithKeys("1"), key.WithHelp("1", "flame")), + Two: key.NewBinding(key.WithKeys("2"), key.WithHelp("2", "overview")), + Three: key.NewBinding(key.WithKeys("3"), key.WithHelp("3", "syscalls")), + Four: key.NewBinding(key.WithKeys("4"), key.WithHelp("4", "files")), + Five: key.NewBinding(key.WithKeys("5"), key.WithHelp("5", "processes")), + Six: key.NewBinding(key.WithKeys("6"), key.WithHelp("6", "lat+gaps")), Seven: key.NewBinding(key.WithKeys("7"), key.WithHelp("7", "stream")), DirGroup: key.NewBinding(key.WithKeys("d"), key.WithHelp("d", "dir group")), SelectPID: key.NewBinding(key.WithKeys("p"), key.WithHelp("p", "select pid")), @@ -83,6 +83,7 @@ func (k KeyMap) DashboardStatusHelpSections() []HelpSection { k.Four, k.Five, k.Six, + k.Seven, k.SelectPID, k.SelectTID, k.Probes, @@ -126,7 +127,7 @@ func (k KeyMap) DashboardFullHelp() [][]key.Binding { controls = append(controls, k.DirGroup, k.SelectPID, k.SelectTID, k.Probes, k.Refresh, k.Quit) return [][]key.Binding{ - {k.One, k.Two, k.Three, k.Four, k.Five, k.Six}, + {k.One, k.Two, k.Three, k.Four, k.Five, k.Six, k.Seven}, controls, { helpTextBinding("space", "stream pause"), diff --git a/internal/tui/common/keys_test.go b/internal/tui/common/keys_test.go index 42e47ab..4284faf 100644 --- a/internal/tui/common/keys_test.go +++ b/internal/tui/common/keys_test.go @@ -23,6 +23,11 @@ func TestDefaultKeyMapIncludesDirGroupBinding(t *testing.T) { if selectTIDHelp.Key != "t" || selectTIDHelp.Desc != "select tid" { t.Fatalf("unexpected select tid binding help: key=%q desc=%q", selectTIDHelp.Key, selectTIDHelp.Desc) } + + flameHelp := keys.One.Help() + if flameHelp.Key != "1" || flameHelp.Desc != "flame" { + t.Fatalf("unexpected flame binding help: key=%q desc=%q", flameHelp.Key, flameHelp.Desc) + } } func TestDashboardFullHelpIncludesDirGroupBinding(t *testing.T) { @@ -33,6 +38,7 @@ func TestDashboardFullHelpIncludesDirGroupBinding(t *testing.T) { } found := false + foundOne := false for _, binding := range groups[1] { help := binding.Help() if help.Key == "d" && help.Desc == "dir group" { @@ -44,6 +50,17 @@ func TestDashboardFullHelpIncludesDirGroupBinding(t *testing.T) { t.Fatalf("expected dir group binding in dashboard full help controls") } + for _, binding := range groups[0] { + help := binding.Help() + if help.Key == "1" && help.Desc == "flame" { + foundOne = true + break + } + } + if !foundOne { + t.Fatalf("expected flame tab binding in dashboard full help tabs") + } + found = false for _, binding := range groups[1] { help := binding.Help() @@ -86,6 +103,7 @@ func TestDashboardStatusHelpIncludesProbesBinding(t *testing.T) { short := keys.DashboardStatusHelp() found := false foundSelectTID := false + foundOne := false for _, binding := range short { help := binding.Help() if help.Key == "o" && help.Desc == "probes" { @@ -94,6 +112,9 @@ func TestDashboardStatusHelpIncludesProbesBinding(t *testing.T) { if help.Key == "t" && help.Desc == "select tid" { foundSelectTID = true } + if help.Key == "1" && help.Desc == "flame" { + foundOne = true + } } if !found { t.Fatalf("expected probes binding in dashboard short help") @@ -101,4 +122,7 @@ func TestDashboardStatusHelpIncludesProbesBinding(t *testing.T) { if !foundSelectTID { t.Fatalf("expected select tid binding in dashboard short help") } + if !foundOne { + t.Fatalf("expected flame tab binding in dashboard short help") + } } diff --git a/internal/tui/common/styles.go b/internal/tui/common/styles.go index d4c75ff..a71ef81 100644 --- a/internal/tui/common/styles.go +++ b/internal/tui/common/styles.go @@ -1,59 +1,117 @@ package common -import "github.com/charmbracelet/lipgloss" +import ( + "image/color" + + "charm.land/lipgloss/v2" +) + +// Palette defines themed colors shared across the TUI package. +type Palette struct { + Background color.Color + Panel color.Color + Primary color.Color + Accent color.Color + Muted color.Color + Text color.Color + Danger color.Color +} + +// NewPalette returns a color palette for dark or light terminal backgrounds. +func NewPalette(isDark bool) Palette { + if isDark { + return Palette{ + Background: lipgloss.Color("235"), + Panel: lipgloss.Color("238"), + Primary: lipgloss.Color("75"), + Accent: lipgloss.Color("222"), + Muted: lipgloss.Color("246"), + Text: lipgloss.Color("255"), + Danger: lipgloss.Color("203"), + } + } + + return Palette{ + Background: lipgloss.Color("255"), + Panel: lipgloss.Color("250"), + Primary: lipgloss.Color("26"), + Accent: lipgloss.Color("88"), + Muted: lipgloss.Color("242"), + Text: lipgloss.Color("235"), + Danger: lipgloss.Color("160"), + } +} var ( // Palette colors shared across the TUI package. - ColorBackground = lipgloss.Color("235") - ColorPanel = lipgloss.Color("238") - ColorPrimary = lipgloss.Color("75") - ColorAccent = lipgloss.Color("222") - ColorMuted = lipgloss.Color("246") - ColorText = lipgloss.Color("255") - ColorDanger = lipgloss.Color("203") + ColorBackground color.Color + ColorPanel color.Color + ColorPrimary color.Color + ColorAccent color.Color + ColorMuted color.Color + ColorText color.Color + ColorDanger color.Color ) var ( // ScreenStyle is the base style for full-screen models. - ScreenStyle = lipgloss.NewStyle(). - Foreground(ColorText) + ScreenStyle lipgloss.Style // HeaderStyle is used by top-level titles and screen headers. - HeaderStyle = lipgloss.NewStyle(). - Bold(true). - Foreground(ColorPrimary) + HeaderStyle lipgloss.Style // TabActiveStyle is applied to the currently-selected tab. - TabActiveStyle = lipgloss.NewStyle(). - Bold(true). - Foreground(ColorBackground). - Background(ColorPrimary). - Padding(0, 1) + TabActiveStyle lipgloss.Style // TabInactiveStyle is applied to non-selected tabs. - TabInactiveStyle = lipgloss.NewStyle(). - Foreground(ColorMuted). - Padding(0, 1) + TabInactiveStyle lipgloss.Style // PanelStyle is used for boxed sections. - PanelStyle = lipgloss.NewStyle(). - Border(lipgloss.NormalBorder()). - BorderForeground(ColorPanel). - Padding(0, 1) + PanelStyle lipgloss.Style // HelpBarStyle is used for keybinding hints at the bottom. - HelpBarStyle = lipgloss.NewStyle(). - Foreground(ColorMuted). - BorderTop(true). - BorderForeground(ColorPanel) + HelpBarStyle lipgloss.Style // HighlightStyle emphasizes inline values. - HighlightStyle = lipgloss.NewStyle(). - Bold(true). - Foreground(ColorAccent) + HighlightStyle lipgloss.Style // ErrorStyle is used for fatal or warning messages. - ErrorStyle = lipgloss.NewStyle(). - Bold(true). - Foreground(ColorDanger) + ErrorStyle lipgloss.Style ) + +// ApplyPalette updates shared colors and styles to match the provided theme. +func ApplyPalette(isDark bool) { + palette := NewPalette(isDark) + ColorBackground = palette.Background + ColorPanel = palette.Panel + ColorPrimary = palette.Primary + ColorAccent = palette.Accent + ColorMuted = palette.Muted + ColorText = palette.Text + ColorDanger = palette.Danger + + ScreenStyle = lipgloss.NewStyle().Foreground(ColorText) + HeaderStyle = lipgloss.NewStyle().Bold(true).Foreground(ColorPrimary) + TabActiveStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(ColorBackground). + Background(ColorPrimary). + Padding(0, 1) + TabInactiveStyle = lipgloss.NewStyle(). + Foreground(ColorMuted). + Padding(0, 1) + PanelStyle = lipgloss.NewStyle(). + Border(lipgloss.NormalBorder()). + BorderForeground(ColorPanel). + Padding(0, 1) + HelpBarStyle = lipgloss.NewStyle(). + Foreground(ColorMuted). + BorderTop(true). + BorderForeground(ColorPanel) + HighlightStyle = lipgloss.NewStyle().Bold(true).Foreground(ColorAccent) + ErrorStyle = lipgloss.NewStyle().Bold(true).Foreground(ColorDanger) +} + +func init() { + ApplyPalette(true) +} diff --git a/internal/tui/common/styles_test.go b/internal/tui/common/styles_test.go new file mode 100644 index 0000000..c0900b3 --- /dev/null +++ b/internal/tui/common/styles_test.go @@ -0,0 +1,39 @@ +package common + +import ( + "testing" + + "charm.land/lipgloss/v2" +) + +func TestNewPaletteRendersDistinctThemes(t *testing.T) { + dark := NewPalette(true) + light := NewPalette(false) + + darkRender := lipgloss.NewStyle(). + Foreground(dark.Text). + Background(dark.Background). + Render("ior") + lightRender := lipgloss.NewStyle(). + Foreground(light.Text). + Background(light.Background). + Render("ior") + + if darkRender == lightRender { + t.Fatalf("expected dark and light palettes to render differently") + } +} + +func TestApplyPaletteUpdatesSharedStyles(t *testing.T) { + t.Cleanup(func() { ApplyPalette(true) }) + + ApplyPalette(true) + dark := ScreenStyle.Render("ior") + + ApplyPalette(false) + light := ScreenStyle.Render("ior") + + if dark == light { + t.Fatalf("expected ScreenStyle render to differ between dark and light palettes") + } +} diff --git a/internal/tui/common/viewport.go b/internal/tui/common/viewport.go index e1729db..d54c886 100644 --- a/internal/tui/common/viewport.go +++ b/internal/tui/common/viewport.go @@ -11,20 +11,22 @@ const ( defaultViewportHeight = 24 ) +var queryTerminalSize = func() (int, int, error) { + return xterm.GetSize(os.Stdout.Fd()) +} + // EffectiveViewport returns a usable terminal viewport size. Missing or invalid -// dimensions are resolved from the active terminal when possible. +// dimensions fall back to defaults. func EffectiveViewport(width, height int) (int, int) { - if width > 0 && height > 0 { - return width, height - } - - termWidth, termHeight, err := xterm.GetSize(os.Stdout.Fd()) - if err == nil { - if width <= 0 && termWidth > 0 { - width = termWidth - } - if height <= 0 && termHeight > 0 { - height = termHeight + if width <= 0 || height <= 0 { + terminalWidth, terminalHeight, err := queryTerminalSize() + if err == nil { + if width <= 0 && terminalWidth > 0 { + width = terminalWidth + } + if height <= 0 && terminalHeight > 0 { + height = terminalHeight + } } } diff --git a/internal/tui/common/viewport_test.go b/internal/tui/common/viewport_test.go new file mode 100644 index 0000000..2dda81b --- /dev/null +++ b/internal/tui/common/viewport_test.go @@ -0,0 +1,76 @@ +package common + +import "testing" + +func TestEffectiveViewport(t *testing.T) { + originalQuery := queryTerminalSize + t.Cleanup(func() { + queryTerminalSize = originalQuery + }) + queryTerminalSize = func() (int, int, error) { + return 132, 41, nil + } + + tests := []struct { + name string + width int + height int + wantWidth int + wantHeight int + }{ + { + name: "provided dimensions", + width: 120, + height: 40, + wantWidth: 120, + wantHeight: 40, + }, + { + name: "both missing use terminal size", + width: 0, + height: 0, + wantWidth: 132, + wantHeight: 41, + }, + { + name: "missing height uses terminal size", + width: 100, + height: 0, + wantWidth: 100, + wantHeight: 41, + }, + { + name: "missing width uses terminal size", + width: -1, + height: 30, + wantWidth: 132, + wantHeight: 30, + }, + } + + for _, tt := range tests { + gotWidth, gotHeight := EffectiveViewport(tt.width, tt.height) + if gotWidth != tt.wantWidth || gotHeight != tt.wantHeight { + t.Fatalf("%s: got (%d,%d), want (%d,%d)", tt.name, gotWidth, gotHeight, tt.wantWidth, tt.wantHeight) + } + } +} + +func TestEffectiveViewportFallsBackToDefaultsWhenTerminalQueryFails(t *testing.T) { + originalQuery := queryTerminalSize + t.Cleanup(func() { + queryTerminalSize = originalQuery + }) + queryTerminalSize = func() (int, int, error) { + return 0, 0, assertiveError{} + } + + gotWidth, gotHeight := EffectiveViewport(0, 0) + if gotWidth != defaultViewportWidth || gotHeight != defaultViewportHeight { + t.Fatalf("got (%d,%d), want (%d,%d)", gotWidth, gotHeight, defaultViewportWidth, defaultViewportHeight) + } +} + +type assertiveError struct{} + +func (assertiveError) Error() string { return "terminal query failed" } diff --git a/internal/tui/dashboard/doc.go b/internal/tui/dashboard/doc.go new file mode 100644 index 0000000..b9bc30e --- /dev/null +++ b/internal/tui/dashboard/doc.go @@ -0,0 +1,2 @@ +// Package dashboard implements the multi-tab runtime dashboard used in TUI mode. +package dashboard diff --git a/internal/tui/dashboard/files.go b/internal/tui/dashboard/files.go index 80e3037..d43e215 100644 --- a/internal/tui/dashboard/files.go +++ b/internal/tui/dashboard/files.go @@ -2,12 +2,13 @@ package dashboard import ( "fmt" - "ior/internal/statsengine" "path/filepath" "sort" "strconv" - "github.com/charmbracelet/bubbles/table" + "ior/internal/statsengine" + + "charm.land/bubbles/v2/table" ) type DirSnapshot struct { diff --git a/internal/tui/dashboard/histogram.go b/internal/tui/dashboard/histogram.go index 7613230..28f5b2b 100644 --- a/internal/tui/dashboard/histogram.go +++ b/internal/tui/dashboard/histogram.go @@ -2,11 +2,12 @@ package dashboard import ( "fmt" - "ior/internal/statsengine" - common "ior/internal/tui/common" "math" "strconv" "strings" + + "ior/internal/statsengine" + common "ior/internal/tui/common" ) func renderLatencyTab(snap *statsengine.Snapshot, width, height int) string { @@ -14,9 +15,10 @@ func renderLatencyTab(snap *statsengine.Snapshot, width, height int) string { return common.PanelStyle.Render("Latency: waiting for stats...") } + panelW := panelWidth(width) panelInner := panelInnerWidth(width) hist := renderHistogram(snap.LatencyHistogram, "Latency Histogram", width, height) - spark := common.PanelStyle.Width(panelInner).Render( + spark := common.PanelStyle.Width(panelW).Render( renderOverviewSparkline("Latency sparkline:", snap.LatencySeriesNs(), panelInner), ) return strings.Join([]string{hist, spark}, "\n") @@ -27,9 +29,10 @@ func renderGapsTab(snap *statsengine.Snapshot, width, height int) string { return common.PanelStyle.Render("Gaps: waiting for stats...") } + panelW := panelWidth(width) panelInner := panelInnerWidth(width) hist := renderHistogram(snap.GapHistogram, "Gap Histogram", width, height) - spark := common.PanelStyle.Width(panelInner).Render( + spark := common.PanelStyle.Width(panelW).Render( renderOverviewSparkline("Gap sparkline:", snap.GapSeriesNs(), panelInner), ) return strings.Join([]string{hist, spark}, "\n") @@ -53,6 +56,7 @@ func renderHistogram(hist statsengine.HistogramSnapshot, title string, width, he if width <= 0 { width = 80 } + panelW := panelWidth(width) panelInner := panelInnerWidth(width) if height > 0 { @@ -93,7 +97,7 @@ func renderHistogram(hist statsengine.HistogramSnapshot, title string, width, he } lines = append(lines, "Scale: █▓▒░") - return common.PanelStyle.Width(panelInner).Render(strings.Join(lines, "\n")) + return common.PanelStyle.Width(panelW).Render(strings.Join(lines, "\n")) } func renderHistogramBar(count, maxCount uint64, width int) string { diff --git a/internal/tui/dashboard/histogram_test.go b/internal/tui/dashboard/histogram_test.go index 7790394..48297a2 100644 --- a/internal/tui/dashboard/histogram_test.go +++ b/internal/tui/dashboard/histogram_test.go @@ -6,7 +6,7 @@ import ( "ior/internal/statsengine" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" ) func TestRenderHistogramNoBuckets(t *testing.T) { diff --git a/internal/tui/dashboard/layout.go b/internal/tui/dashboard/layout.go index 0035a9d..75cbafb 100644 --- a/internal/tui/dashboard/layout.go +++ b/internal/tui/dashboard/layout.go @@ -4,7 +4,3 @@ const panelHorizontalChrome = 4 // Keep a small guard so sparkline rows never soft-wrap in panel cells. const sparklineSafetyMargin = 3 - -// Stats engine currently provides 120 time-series slots; cap rendering width -// so wide terminals don't introduce wrap/placement artifacts. -const sparklineMaxWidth = 120 diff --git a/internal/tui/dashboard/model.go b/internal/tui/dashboard/model.go index fc9caf6..d10a91a 100644 --- a/internal/tui/dashboard/model.go +++ b/internal/tui/dashboard/model.go @@ -1,20 +1,26 @@ package dashboard import ( + "strings" + "time" + "ior/internal/statsengine" common "ior/internal/tui/common" "ior/internal/tui/eventstream" + flamegraphtui "ior/internal/tui/flamegraph" "ior/internal/tui/messages" - "strings" - "time" - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" ) const defaultRefreshMs = 1000 const streamRefreshMs = 200 +const flameRefreshMs = 200 const streamChromeRows = 4 +const dashboardHelpHintRows = 1 +const dashboardExpandedHelpRows = 2 +const dashboardTabBarRows = 1 // SnapshotSource is the dashboard data source. type SnapshotSource interface { @@ -23,6 +29,7 @@ type SnapshotSource interface { type refreshTickMsg struct{} type streamTickMsg struct{} +type flameTickMsg struct{} type streamEditorDoneMsg struct { err error } @@ -31,8 +38,9 @@ type streamEditorDoneMsg struct { type Model struct { activeTab Tab - engine SnapshotSource - latest *statsengine.Snapshot + engine SnapshotSource + latest *statsengine.Snapshot + liveTrie flamegraphtui.LiveTrieSource width int height int @@ -46,32 +54,50 @@ type Model struct { filesDirOffset int processesOffset int streamModel eventstream.Model + flamegraphModel flamegraphtui.Model showHelp bool + isDark bool + focused bool } // NewModel creates a dashboard model with default refresh cadence. -func NewModel(engine SnapshotSource, streamSource *eventstream.RingBuffer) Model { +func NewModel(engine SnapshotSource, streamSource eventstream.Source) Model { return NewModelWithConfig(engine, streamSource, defaultRefreshMs, common.Keys) } // NewModelWithConfig creates a dashboard model with explicit refresh and keys. -func NewModelWithConfig(engine SnapshotSource, streamSource *eventstream.RingBuffer, refreshMs int, keys common.KeyMap) Model { +func NewModelWithConfig(engine SnapshotSource, streamSource eventstream.Source, refreshMs int, keys common.KeyMap) Model { if refreshMs <= 0 { refreshMs = defaultRefreshMs } - return Model{ - activeTab: TabOverview, - engine: engine, - refreshEvery: time.Duration(refreshMs) * time.Millisecond, - keys: keys, - pidFilter: -1, - streamModel: eventstream.NewModel(streamSource), + m := Model{ + activeTab: TabFlame, + engine: engine, + refreshEvery: time.Duration(refreshMs) * time.Millisecond, + keys: keys, + pidFilter: -1, + streamModel: eventstream.NewModel(streamSource), + flamegraphModel: flamegraphtui.NewModel(nil), + isDark: true, + focused: true, } + m.SetDarkMode(true) + return m } // Init starts periodic refresh ticks. func (m Model) Init() tea.Cmd { - return tickCmd(m.refreshEvery) + cmds := []tea.Cmd{tickCmd(m.refreshEvery)} + switch m.activeTab { + case TabStream: + cmds = append(cmds, streamTickCmd()) + case TabFlame: + cmds = append(cmds, flameTickCmd()) + } + if len(cmds) == 1 { + return cmds[0] + } + return tea.Batch(cmds...) } // Update handles ticks, snapshots, tab changes, and resize events. @@ -82,19 +108,42 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.height = msg.Height streamWidth, streamHeight := streamViewport(msg.Width, msg.Height) m.streamModel.SetViewport(streamWidth, streamHeight) + flameWidth, flameHeight := flameViewport(msg.Width, msg.Height, m.showHelp) + m.flamegraphModel.SetViewport(flameWidth, flameHeight) return m, nil case refreshTickMsg: + if !m.focused { + return m, nil + } snap := m.snapshot() return m, tea.Batch( tickCmd(m.refreshEvery), func() tea.Msg { return messages.StatsTickMsg{Snap: snap} }, ) case streamTickMsg: + if !m.focused { + return m, nil + } if m.activeTab != TabStream { return m, nil } m.streamModel.Refresh() return m, streamTickCmd() + case flameTickMsg: + if !m.focused { + return m, nil + } + if m.activeTab != TabFlame { + return m, nil + } + var animCmd tea.Cmd + if m.liveTrie != nil && m.flamegraphModel.RefreshFromLiveTrie() { + animCmd = m.flamegraphModel.AnimationCmd() + } + if animCmd != nil { + return m, tea.Batch(flameTickCmd(), animCmd) + } + return m, flameTickCmd() case messages.StatsTickMsg: m.latest = msg.Snap m.syscallsOffset = clampOffset(m.syscallsOffset, m.maxSyscallsRows()) @@ -103,7 +152,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.processesOffset = clampOffset(m.processesOffset, m.maxProcessesRows()) m.streamModel.Refresh() return m, nil - case tea.KeyMsg: + case tea.KeyPressMsg: return m.handleKey(msg) case streamEditorDoneMsg: if msg.err != nil { @@ -111,17 +160,29 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } return m, nil } + if m.activeTab == TabFlame { + next, cmd := m.flamegraphModel.Update(msg) + m.flamegraphModel = next.(flamegraphtui.Model) + return m, cmd + } return m, nil } -func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { +func (m Model) handleKey(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { prevActiveTab := m.activeTab var cmd tea.Cmd keyStr := msg.String() if keyStr == "H" { m.showHelp = !m.showHelp + flameWidth, flameHeight := flameViewport(m.width, m.height, m.showHelp) + m.flamegraphModel.SetViewport(flameWidth, flameHeight) return m, nil } + if m.activeTab == TabFlame && m.flamegraphModel.ConsumesKey(msg) { + next, flameCmd := m.flamegraphModel.Update(msg) + m.flamegraphModel = next.(flamegraphtui.Model) + return m, flameCmd + } handled, scrollCmd := m.handleScrollKey(msg) if scrollCmd != nil { cmd = scrollCmd @@ -132,29 +193,29 @@ func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { if !handled { switch { + case key.Matches(msg, m.keys.One): + m.activeTab = TabFlame + handled = true case key.Matches(msg, m.keys.Tab): m.activeTab = nextTab(m.activeTab) handled = true case key.Matches(msg, m.keys.ShiftTab): m.activeTab = prevTab(m.activeTab) handled = true - case key.Matches(msg, m.keys.One): - m.activeTab = TabOverview - handled = true case key.Matches(msg, m.keys.Two): - m.activeTab = TabSyscalls + m.activeTab = TabOverview handled = true case key.Matches(msg, m.keys.Three): - m.activeTab = TabFiles + m.activeTab = TabSyscalls handled = true case key.Matches(msg, m.keys.Four): - m.activeTab = TabProcesses + m.activeTab = TabFiles handled = true case key.Matches(msg, m.keys.Five): - m.activeTab = TabLatency + m.activeTab = TabProcesses handled = true case key.Matches(msg, m.keys.Six): - m.activeTab = TabStream + m.activeTab = TabLatency handled = true case key.Matches(msg, m.keys.Seven): m.activeTab = TabStream @@ -171,18 +232,34 @@ func (m Model) handleKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { } } if !handled { + if m.activeTab == TabFlame { + next, flameCmd := m.flamegraphModel.Update(msg) + m.flamegraphModel = next.(flamegraphtui.Model) + return m, flameCmd + } return m, nil } + batch := make([]tea.Cmd, 0, 3) + if cmd != nil { + batch = append(batch, cmd) + } if prevActiveTab != TabStream && m.activeTab == TabStream { - if cmd == nil { - return m, streamTickCmd() - } - return m, tea.Batch(cmd, streamTickCmd()) + batch = append(batch, streamTickCmd()) + } + if prevActiveTab != TabFlame && m.activeTab == TabFlame { + batch = append(batch, flameTickCmd()) + } + switch len(batch) { + case 0: + return m, nil + case 1: + return m, batch[0] + default: + return m, tea.Batch(batch...) } - return m, cmd } -func (m *Model) handleScrollKey(msg tea.KeyMsg) (bool, tea.Cmd) { +func (m *Model) handleScrollKey(msg tea.KeyPressMsg) (bool, tea.Cmd) { keyStr := msg.String() switch m.activeTab { case TabSyscalls: @@ -271,26 +348,60 @@ func (m Model) LatestSnapshot() *statsengine.Snapshot { return m.latest } -// BlocksGlobalShortcuts reports whether modal UI in the active tab should -// suppress top-level shortcuts (for example global export key handling). -func (m Model) BlocksGlobalShortcuts() bool { - return m.activeTab == TabStream && (m.streamModel.FilterModalVisible() || m.streamModel.ExportModalVisible() || m.streamModel.SearchModalVisible()) +// BlocksGlobalShortcuts reports whether the active tab should suppress a +// top-level shortcut for the given key press. +func (m Model) BlocksGlobalShortcuts(msg tea.KeyPressMsg) bool { + if m.activeTab == TabStream { + return m.streamModel.FilterModalVisible() || m.streamModel.ExportModalVisible() || m.streamModel.SearchModalVisible() + } + if m.activeTab == TabFlame { + return m.flamegraphModel.ConsumesKey(msg) + } + return false } // SetStreamSource updates the live stream source used by the stream tab. -func (m *Model) SetStreamSource(source *eventstream.RingBuffer) { +func (m *Model) SetStreamSource(source eventstream.Source) { m.streamModel.SetSource(source) } +// SetLiveTrie updates the live trie source used by the flamegraph tab. +func (m *Model) SetLiveTrie(liveTrie flamegraphtui.LiveTrieSource) { + m.liveTrie = liveTrie + m.flamegraphModel.SetLiveTrie(liveTrie) + if m.width > 0 && m.height > 0 { + m.flamegraphModel.SetViewport(m.width, m.height) + } + m.flamegraphModel.RefreshFromLiveTrie() +} + +// SetDarkMode updates dashboard child models for the active theme. +func (m *Model) SetDarkMode(isDark bool) { + m.isDark = isDark + m.streamModel.SetDarkMode(isDark) + m.flamegraphModel.SetDarkMode(isDark) +} + +// SetFocused controls whether periodic refresh ticks are processed. +func (m *Model) SetFocused(focused bool) { + m.focused = focused +} + +// SnapshotCmd returns a command that fetches and emits a fresh dashboard snapshot. +func (m Model) SnapshotCmd() tea.Cmd { + snap := m.snapshot() + return func() tea.Msg { return messages.StatsTickMsg{Snap: snap} } +} + // SetPidFilter updates the active PID filter used by tab render hints. func (m *Model) SetPidFilter(pid int) { m.pidFilter = pid } // View renders the tab bar, active tab scaffold, and help bar. -func (m Model) View() string { +func (m Model) View() tea.View { width, height := common.EffectiveViewport(m.width, m.height) - activeHeight := height + _, activeHeight := flameViewport(width, height, m.showHelp) streamModel := m.streamModel streamModel.SetFooterVisible(m.showHelp) if m.activeTab == TabStream { @@ -304,6 +415,7 @@ func (m Model) View() string { m.activeTab, m.latest, &streamModel, + &m.flamegraphModel, width, activeHeight, m.pidFilter, @@ -319,20 +431,27 @@ func (m Model) View() string { } else { b.WriteString(renderHelpHint(width)) } - return common.ScreenStyle.Render(b.String()) + return tea.NewView(common.ScreenStyle.Render(b.String())) } func tickCmd(d time.Duration) tea.Cmd { return tea.Tick(d, func(time.Time) tea.Msg { return refreshTickMsg{} }) } -func renderActiveTab(tab Tab, snap *statsengine.Snapshot, streamModel *eventstream.Model, width, height, pidFilter, syscallsOffset, filesOffset int, filesDirGrouped bool, filesDirOffset, processesOffset int) string { +func renderActiveTab(tab Tab, snap *statsengine.Snapshot, streamModel *eventstream.Model, flameModel *flamegraphtui.Model, width, height, pidFilter, syscallsOffset, filesOffset int, filesDirGrouped bool, filesDirOffset, processesOffset int) string { if tab == TabStream { if streamModel == nil { return common.PanelStyle.Render("Stream: waiting for source...") } return streamModel.View(width, height) } + if tab == TabFlame { + if flameModel == nil { + return common.PanelStyle.Render("Flame: waiting for model...") + } + flameModel.SetViewport(width, height) + return flameModel.View().Content + } if snap == nil { return common.PanelStyle.Render(tab.String() + ": waiting for stats...") @@ -361,6 +480,10 @@ func streamTickCmd() tea.Cmd { return tea.Tick(streamRefreshMs*time.Millisecond, func(time.Time) tea.Msg { return streamTickMsg{} }) } +func flameTickCmd() tea.Cmd { + return tea.Tick(flameRefreshMs*time.Millisecond, func(time.Time) tea.Msg { return flameTickMsg{} }) +} + func streamViewport(width, height int) (int, int) { width, height = common.EffectiveViewport(width, height) height -= streamChromeRows @@ -369,3 +492,16 @@ func streamViewport(width, height int) (int, int) { } return width, height } + +func flameViewport(width, height int, showHelp bool) (int, int) { + width, height = common.EffectiveViewport(width, height) + chromeRows := dashboardTabBarRows + dashboardHelpHintRows + if showHelp { + chromeRows = dashboardTabBarRows + dashboardExpandedHelpRows + } + height -= chromeRows + if height < 1 { + height = 1 + } + return width, height +} diff --git a/internal/tui/dashboard/model_test.go b/internal/tui/dashboard/model_test.go index 87b60e3..d5b78e0 100644 --- a/internal/tui/dashboard/model_test.go +++ b/internal/tui/dashboard/model_test.go @@ -7,12 +7,13 @@ import ( "strings" "testing" + coreflamegraph "ior/internal/flamegraph" "ior/internal/statsengine" common "ior/internal/tui/common" "ior/internal/tui/eventstream" "ior/internal/tui/messages" - tea "github.com/charmbracelet/bubbletea" + tea "charm.land/bubbletea/v2" ) type fakeSnapshotSource struct { @@ -28,59 +29,60 @@ func (f *fakeSnapshotSource) Snapshot() *statsengine.Snapshot { func TestKeySwitchingChangesActiveTab(t *testing.T) { m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'2'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'2'}[0], Text: string([]rune{'2'})}) model := next.(Model) - if model.activeTab != TabSyscalls { - t.Fatalf("expected syscalls tab, got %v", model.activeTab) + if model.activeTab != TabOverview { + t.Fatalf("expected overview tab on key 2, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyTab}) + next, _ = model.Update(tea.KeyPressMsg{Code: tea.KeyTab}) model = next.(Model) - if model.activeTab != TabFiles { - t.Fatalf("expected next tab to be files, got %v", model.activeTab) + if model.activeTab != TabSyscalls { + t.Fatalf("expected next tab to be syscalls, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyShiftTab}) + next, _ = model.Update(tea.KeyPressMsg{Code: tea.KeyTab, Mod: tea.ModShift}) model = next.(Model) - if model.activeTab != TabSyscalls { - t.Fatalf("expected previous tab to be syscalls, got %v", model.activeTab) + if model.activeTab != TabOverview { + t.Fatalf("expected previous tab to be overview, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'7'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'7'}[0], Text: string([]rune{'7'})}) model = next.(Model) if model.activeTab != TabStream { t.Fatalf("expected stream tab on key 7, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'6'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'1'}[0], Text: string([]rune{'1'})}) model = next.(Model) - if model.activeTab != TabStream { - t.Fatalf("expected stream tab on key 6, got %v", model.activeTab) + if model.activeTab != TabFlame { + t.Fatalf("expected flame tab on key 1, got %v", model.activeTab) } } func TestArrowAndViKeysDoNotCycleTabs(t *testing.T) { m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.activeTab = TabOverview - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRight}) + next, _ := m.Update(tea.KeyPressMsg{Code: tea.KeyRight}) model := next.(Model) if model.activeTab != TabOverview { t.Fatalf("expected right arrow not to change tabs, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'l'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'l'}[0], Text: string([]rune{'l'})}) model = next.(Model) if model.activeTab != TabOverview { t.Fatalf("expected l not to change tabs, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyLeft}) + next, _ = model.Update(tea.KeyPressMsg{Code: tea.KeyLeft}) model = next.(Model) if model.activeTab != TabOverview { t.Fatalf("expected left arrow not to change tabs, got %v", model.activeTab) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'h'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'h'}[0], Text: string([]rune{'h'})}) model = next.(Model) if model.activeTab != TabOverview { t.Fatalf("expected h not to change tabs, got %v", model.activeTab) @@ -93,13 +95,13 @@ func TestSyscallsTabScrollsWithJK(t *testing.T) { snap := statsengine.NewSnapshot(nil, nil, nil, []statsengine.SyscallSnapshot{{Name: "read", Count: 1}, {Name: "write", Count: 1}}, nil, nil, statsengine.HistogramSnapshot{}, statsengine.HistogramSnapshot{}) m.latest = &snap - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'j'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'j'}[0], Text: string([]rune{'j'})}) model := next.(Model) if model.syscallsOffset != 1 { t.Fatalf("expected offset 1 after j, got %d", model.syscallsOffset) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'k'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'k'}[0], Text: string([]rune{'k'})}) model = next.(Model) if model.syscallsOffset != 0 { t.Fatalf("expected offset 0 after k, got %d", model.syscallsOffset) @@ -112,13 +114,13 @@ func TestProcessesTabScrollsWithJK(t *testing.T) { snap := statsengine.NewSnapshot(nil, nil, nil, nil, nil, []statsengine.ProcessSnapshot{{PID: 1}, {PID: 2}}, statsengine.HistogramSnapshot{}, statsengine.HistogramSnapshot{}) m.latest = &snap - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'j'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'j'}[0], Text: string([]rune{'j'})}) model := next.(Model) if model.processesOffset != 1 { t.Fatalf("expected processes offset 1 after j, got %d", model.processesOffset) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'k'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'k'}[0], Text: string([]rune{'k'})}) model = next.(Model) if model.processesOffset != 0 { t.Fatalf("expected processes offset 0 after k, got %d", model.processesOffset) @@ -131,13 +133,13 @@ func TestFilesTabScrollsWithJK(t *testing.T) { snap := statsengine.NewSnapshot(nil, nil, nil, nil, []statsengine.FileSnapshot{{Path: "/a"}, {Path: "/b"}}, nil, statsengine.HistogramSnapshot{}, statsengine.HistogramSnapshot{}) m.latest = &snap - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'j'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'j'}[0], Text: string([]rune{'j'})}) model := next.(Model) if model.filesOffset != 1 { t.Fatalf("expected files offset 1 after j, got %d", model.filesOffset) } - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'k'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'k'}[0], Text: string([]rune{'k'})}) model = next.(Model) if model.filesOffset != 0 { t.Fatalf("expected files offset 0 after k, got %d", model.filesOffset) @@ -155,7 +157,7 @@ func TestFilesTabGroupedScrollUsesDirectoryOffset(t *testing.T) { }, nil, statsengine.HistogramSnapshot{}, statsengine.HistogramSnapshot{}) m.latest = &snap - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'j'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'j'}[0], Text: string([]rune{'j'})}) model := next.(Model) if model.filesDirOffset != 1 { t.Fatalf("expected grouped dir offset 1 after j, got %d", model.filesDirOffset) @@ -171,13 +173,73 @@ func TestStreamSpaceUnpauseSchedulesStreamTick(t *testing.T) { m.activeTab = TabStream m.streamModel.HandleKey("space") // pause - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeySpace}) + next, cmd := m.Update(tea.KeyPressMsg{Code: tea.KeySpace}) _ = next if cmd == nil { t.Fatalf("expected stream tick command when unpausing stream") } } +func TestFlameTickRefreshesFlamegraphModel(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + liveTrie.Reset() + + m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.SetLiveTrie(liveTrie) + m.activeTab = TabFlame + + next, cmd := m.Update(flameTickMsg{}) + model := next.(Model) + if cmd == nil { + t.Fatalf("expected flame tick to schedule next tick command") + } + if got, want := model.flamegraphModel.LastVersion(), liveTrie.Version(); got != want { + t.Fatalf("expected flame model version %d, got %d", want, got) + } +} + +func TestSetLiveTriePreloadsInitialSnapshotWithoutVersionChange(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + + m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.SetLiveTrie(liveTrie) + m.activeTab = TabFlame + if !m.flamegraphModel.HasSnapshot() { + t.Fatalf("expected SetLiveTrie to preload a baseline snapshot") + } + + next, _ := m.Update(flameTickMsg{}) + model := next.(Model) + if !model.flamegraphModel.HasSnapshot() { + t.Fatalf("expected flame tick to retain initial snapshot even when trie version is unchanged") + } +} + +func TestFlameTickPausedFreezesAfterInitialSnapshot(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.SetLiveTrie(liveTrie) + m.activeTab = TabFlame + + next, _ := m.Update(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + model := next.(Model) + + next, _ = model.Update(flameTickMsg{}) + model = next.(Model) + initialVersion := model.flamegraphModel.LastVersion() + + liveTrie.Reset() + if liveTrie.Version() == initialVersion { + t.Fatalf("expected reset to advance trie version") + } + + next, _ = model.Update(flameTickMsg{}) + model = next.(Model) + if got, want := model.flamegraphModel.LastVersion(), initialVersion; got != want { + t.Fatalf("expected paused flame tick to freeze version at %d, got %d", want, got) + } +} + func TestStreamPausedSupportsJKArrowsAndPageKeys(t *testing.T) { rb := eventstream.NewRingBuffer() for i := 0; i < 300; i++ { @@ -200,34 +262,34 @@ func TestStreamPausedSupportsJKArrowsAndPageKeys(t *testing.T) { m.streamModel.Refresh() _ = m.View() - next, _ = m.Update(tea.KeyMsg{Type: tea.KeySpace}) // pause + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeySpace}) // pause m = next.(Model) - before := rowFromStreamView(t, m.View()) + before := rowFromStreamView(t, m.View().Content) - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'k'}}) + next, _ = m.Update(tea.KeyPressMsg{Code: []rune{'k'}[0], Text: string([]rune{'k'})}) m = next.(Model) - afterK := rowFromStreamView(t, m.View()) + afterK := rowFromStreamView(t, m.View().Content) if afterK >= before { t.Fatalf("expected k to scroll up while paused: before=%d afterK=%d", before, afterK) } - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyDown}) + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyDown}) m = next.(Model) - afterDown := rowFromStreamView(t, m.View()) + afterDown := rowFromStreamView(t, m.View().Content) if afterDown <= afterK { t.Fatalf("expected down arrow to scroll down while paused: afterK=%d afterDown=%d", afterK, afterDown) } - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyPgUp}) + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyPgUp}) m = next.(Model) - afterPgUp := rowFromStreamView(t, m.View()) + afterPgUp := rowFromStreamView(t, m.View().Content) if afterPgUp >= afterDown { t.Fatalf("expected pgup to scroll up while paused: afterDown=%d afterPgUp=%d", afterDown, afterPgUp) } - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyPgDown}) + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyPgDown}) m = next.(Model) - afterPgDown := rowFromStreamView(t, m.View()) + afterPgDown := rowFromStreamView(t, m.View().Content) if afterPgDown <= afterPgUp { t.Fatalf("expected pgdown to scroll down while paused: afterPgUp=%d afterPgDown=%d", afterPgUp, afterPgDown) } @@ -251,14 +313,14 @@ func TestDirGroupKeyTogglesOnlyOnFilesTab(t *testing.T) { m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) m.activeTab = TabFiles - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'d'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'d'}[0], Text: string([]rune{'d'})}) model := next.(Model) if !model.filesDirGrouped { t.Fatalf("expected filesDirGrouped to toggle on files tab") } model.activeTab = TabOverview - next, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'d'}}) + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'d'}[0], Text: string([]rune{'d'})}) model = next.(Model) if !model.filesDirGrouped { t.Fatalf("expected filesDirGrouped unchanged outside files tab") @@ -272,7 +334,7 @@ func TestScrollOffsetDoesNotGrowUnbounded(t *testing.T) { m.latest = &snap for i := 0; i < 50; i++ { - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'j'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'j'}[0], Text: string([]rune{'j'})}) m = next.(Model) } if m.syscallsOffset != 1 { @@ -284,7 +346,8 @@ func TestRefreshKeyEmitsRefreshTick(t *testing.T) { snap := &statsengine.Snapshot{TotalSyscalls: 13} engine := &fakeSnapshotSource{snap: snap} m := NewModelWithConfig(engine, nil, 250, common.DefaultKeyMap()) - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'r'}}) + m.activeTab = TabOverview + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'r'}[0], Text: string([]rune{'r'})}) _ = next if cmd == nil { t.Fatalf("expected refresh command") @@ -299,6 +362,63 @@ func TestRefreshKeyEmitsRefreshTick(t *testing.T) { } } +func TestFlameTabReceivesSlashKey(t *testing.T) { + m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.activeTab = TabFlame + m.width = 120 + m.height = 30 + + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'/'}[0], Text: string([]rune{'/'})}) + model := next.(Model) + if cmd != nil { + t.Fatalf("did not expect global command for flame search key") + } + if !strings.Contains(model.View().Content, "0/0 matches") { + t.Fatalf("expected flame search footer after pressing /") + } +} + +func TestFlameTabReceivesResetAndPauseKeys(t *testing.T) { + m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.activeTab = TabFlame + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + model := next.(Model) + if !strings.Contains(model.View().Content, "[PAUSED]") { + t.Fatalf("expected flame space key to toggle paused state") + } + + next, cmd := model.Update(tea.KeyPressMsg{Code: []rune{'r'}[0], Text: string([]rune{'r'})}) + model = next.(Model) + if cmd != nil { + t.Fatalf("expected flame reset key to be handled by flame tab without global refresh command") + } + if model.activeTab != TabFlame { + t.Fatalf("expected flame tab to stay active after reset key") + } +} + +func TestFlameSearchConsumesNumericTabKeys(t *testing.T) { + m := NewModelWithConfig(nil, nil, 250, common.DefaultKeyMap()) + m.activeTab = TabFlame + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'/'}[0], Text: string([]rune{'/'})}) + model := next.(Model) + if model.activeTab != TabFlame { + t.Fatalf("expected flame tab to stay active after opening search") + } + + next, _ = model.Update(tea.KeyPressMsg{Code: []rune{'2'}[0], Text: string([]rune{'2'})}) + model = next.(Model) + if model.activeTab != TabFlame { + t.Fatalf("expected numeric key while searching to stay in flame tab") + } +} + func TestRefreshTickEmitsStatsTickMsg(t *testing.T) { snap := &statsengine.Snapshot{TotalSyscalls: 9} engine := &fakeSnapshotSource{snap: snap} @@ -366,9 +486,9 @@ func TestStatsTickClampsGroupedFilesOffset(t *testing.T) { func TestViewRendersTabBarAndHelp(t *testing.T) { m := NewModelWithConfig(nil, nil, 1000, common.DefaultKeyMap()) - out := m.View() - if !strings.Contains(out, "Overview") { - t.Fatalf("expected overview label in view") + out := m.View().Content + if !strings.Contains(out, "Flame") { + t.Fatalf("expected flame tab label in view") } if !strings.Contains(out, "press H for help") { t.Fatalf("expected help hint text in view") @@ -378,6 +498,18 @@ func TestViewRendersTabBarAndHelp(t *testing.T) { } } +func TestFlameTabRendersWaitingForDataPlaceholder(t *testing.T) { + m := NewModelWithConfig(nil, nil, 1000, common.DefaultKeyMap()) + m.activeTab = TabFlame + m.width = 120 + m.height = 30 + + out := m.View().Content + if !strings.Contains(out, "Flame: waiting for data...") { + t.Fatalf("expected flame waiting placeholder, got %q", out) + } +} + func TestRenderActiveTabUsesDirectoryFilesViewWhenGrouped(t *testing.T) { snap := statsengine.NewSnapshot( nil, nil, nil, nil, @@ -386,7 +518,7 @@ func TestRenderActiveTabUsesDirectoryFilesViewWhenGrouped(t *testing.T) { statsengine.HistogramSnapshot{}, statsengine.HistogramSnapshot{}, ) - out := renderActiveTab(TabFiles, &snap, nil, 120, 30, -1, 0, 0, true, 0, 0) + out := renderActiveTab(TabFiles, &snap, nil, nil, 120, 30, -1, 0, 0, true, 0, 0) if !strings.Contains(out, "Directory") { t.Fatalf("expected grouped directory files view header, got %q", out) } @@ -405,8 +537,8 @@ func TestStreamTabViewKeepsTabAndHelpChromeVisible(t *testing.T) { m.streamModel.SetSource(rb) m.streamModel.Refresh() - out := m.View() - if !strings.Contains(out, "1:Overview") { + out := m.View().Content + if !strings.Contains(out, "1:Flame") { t.Fatalf("expected tab bar to remain visible in stream view") } if !strings.Contains(out, "press H for help") { @@ -416,21 +548,21 @@ func TestStreamTabViewKeepsTabAndHelpChromeVisible(t *testing.T) { func TestHelpToggleWithH(t *testing.T) { m := NewModelWithConfig(nil, nil, 1000, common.DefaultKeyMap()) - out := m.View() + out := m.View().Content if !strings.Contains(out, "press H for help") { t.Fatalf("expected default help hint") } - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'H'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'H'}[0], Text: string([]rune{'H'})}) m = next.(Model) - out = m.View() + out = m.View().Content if !strings.Contains(out, "tab next tab") { t.Fatalf("expected expanded help after pressing h") } - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'H'}}) + next, _ = m.Update(tea.KeyPressMsg{Code: []rune{'H'}[0], Text: string([]rune{'H'})}) m = next.(Model) - out = m.View() + out = m.View().Content if !strings.Contains(out, "press H for help") { t.Fatalf("expected help hint after pressing h again") } diff --git a/internal/tui/dashboard/overview.go b/internal/tui/dashboard/overview.go index 5b8fab8..24932b9 100644 --- a/internal/tui/dashboard/overview.go +++ b/internal/tui/dashboard/overview.go @@ -2,13 +2,14 @@ package dashboard import ( "fmt" - "ior/internal/statsengine" - common "ior/internal/tui/common" "strings" "time" "unicode/utf8" - "github.com/charmbracelet/lipgloss" + "ior/internal/statsengine" + common "ior/internal/tui/common" + + "charm.land/lipgloss/v2" ) func renderOverview(snap *statsengine.Snapshot, width, height int) string { @@ -33,6 +34,7 @@ func renderOverview(snap *statsengine.Snapshot, width, height int) string { trendWithArrow(snap.ThroughputTrend), ) + panelW := panelWidth(width) panelInner := panelInnerWidth(width) labelWidth := maxLabelWidth("Latency:", "Gap:", "Throughput:") latencySpark := renderOverviewSparklineAligned("Latency:", snap.LatencySeriesNs(), panelInner, labelWidth) @@ -44,8 +46,8 @@ func renderOverview(snap *statsengine.Snapshot, width, height int) string { latencyHist := "Latency buckets: " + summarizeHistogramBrief(snap.LatencyHistogram) gapHist := "Gap buckets: " + summarizeHistogramBrief(snap.GapHistogram) - panel := common.PanelStyle.Width(panelInner) - sparkPanel := panel.Render(strings.Join([]string{latencySpark, "", gapSpark, "", throughputSpark}, "\n")) + panel := common.PanelStyle.Width(panelW) + sparkPanel := panel.Render(strings.Join([]string{latencySpark, gapSpark, throughputSpark}, "\n")) topPanel := panel.Render(strings.Join([]string{topSyscalls, topFiles, topProcesses}, "\n")) histPanel := panel.Render(strings.Join([]string{latencyHist, gapHist}, "\n")) @@ -73,7 +75,7 @@ func renderSyscallBox(snap *statsengine.Snapshot, width int) string { snap.SyscallRatePerSec, generatedAt, ) - return common.PanelStyle.Width(summaryBoxInnerWidth(width)).Height(5).Render(content) + return common.PanelStyle.Width(width).Height(5).Render(content) } func renderBytesBox(snap *statsengine.Snapshot, width int) string { @@ -83,7 +85,7 @@ func renderBytesBox(snap *statsengine.Snapshot, width int) string { formatBytes(snap.WriteBytesPerSec), formatBytes(float64(snap.TotalBytes)), ) - return common.PanelStyle.Width(summaryBoxInnerWidth(width)).Height(5).Render(content) + return common.PanelStyle.Width(width).Height(5).Render(content) } func renderErrorBox(snap *statsengine.Snapshot, width int) string { @@ -99,7 +101,7 @@ func renderErrorBox(snap *statsengine.Snapshot, width int) string { snap.LatencyMeanNs, snap.GapMeanNs, ) - return common.PanelStyle.Width(summaryBoxInnerWidth(width)).Height(5).Render(content) + return common.PanelStyle.Width(width).Height(5).Render(content) } func trendWithArrow(trend statsengine.Trend) string { @@ -212,19 +214,8 @@ func summaryBoxWidth(width int) int { return w } -func summaryBoxInnerWidth(width int) int { - inner := width - panelHorizontalChrome - if inner < 14 { - return 14 - } - return inner -} - func renderOverviewSparkline(label string, data []float64, panelInner int) string { w := panelInner - utf8.RuneCountInString(label) - 1 - sparklineSafetyMargin - if w > sparklineMaxWidth { - w = sparklineMaxWidth - } if w < 8 { w = 8 } @@ -234,9 +225,6 @@ func renderOverviewSparkline(label string, data []float64, panelInner int) strin func renderOverviewSparklineAligned(label string, data []float64, panelInner int, labelWidth int) string { paddedLabel := padLabelRight(label, labelWidth) w := panelInner - labelWidth - 1 - sparklineSafetyMargin - if w > sparklineMaxWidth { - w = sparklineMaxWidth - } if w < 8 { w = 8 } @@ -262,13 +250,20 @@ func padLabelRight(label string, width int) string { return label + strings.Repeat(" ", pad) } -func panelInnerWidth(width int) int { +func panelWidth(width int) int { if width <= 0 { width = 80 } - inner := width - panelHorizontalChrome - if inner < 20 { + if width < 20 { return 20 } + return width +} + +func panelInnerWidth(width int) int { + inner := panelWidth(width) - panelHorizontalChrome + if inner < 16 { + return 16 + } return inner } diff --git a/internal/tui/dashboard/overview_test.go b/internal/tui/dashboard/overview_test.go index 9895490..6ac3704 100644 --- a/internal/tui/dashboard/overview_test.go +++ b/internal/tui/dashboard/overview_test.go @@ -6,8 +6,9 @@ import ( "time" "ior/internal/statsengine" + common "ior/internal/tui/common" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" ) func TestRenderOverviewIncludesCoreMetrics(t *testing.T) { @@ -121,23 +122,22 @@ func TestRenderOverviewDoesNotOverflowWidth(t *testing.T) { func TestRenderOverviewSparklineHasSafetyMargin(t *testing.T) { const panelInner = 80 out := renderOverviewSparkline("Latency:", []float64{1, 2, 3, 4, 5}, panelInner) - lines := strings.Split(out, "\n") - if len(lines) != 2 { - t.Fatalf("expected 2-line sparkline, got %q", out) + if strings.Contains(out, "\n") { + t.Fatalf("expected single-line sparkline, got %q", out) } - if got, max := lipgloss.Width(lines[0]), panelInner-sparklineSafetyMargin; got > max { + if got, max := lipgloss.Width(out), panelInner-sparklineSafetyMargin; got > max { t.Fatalf("expected sparkline width <= %d with safety margin, got %d", max, got) } } -func TestRenderOverviewSparklineCapsWidth(t *testing.T) { +func TestRenderOverviewSparklineUsesAvailableWidth(t *testing.T) { out := renderOverviewSparkline("Latency:", make([]float64, 120), 400) - lines := strings.Split(out, "\n") - if len(lines) != 2 { - t.Fatalf("expected 2-line sparkline, got %q", out) + if strings.Contains(out, "\n") { + t.Fatalf("expected single-line sparkline, got %q", out) } - if got := lipgloss.Width(lines[0]) - len("Latency: "); got > sparklineMaxWidth { - t.Fatalf("expected capped sparkline width <= %d, got %d", sparklineMaxWidth, got) + want := 400 - len("Latency:") - 1 - sparklineSafetyMargin + if got := lipgloss.Width(out) - len("Latency: "); got != want { + t.Fatalf("expected sparkline width %d, got %d", want, got) } } @@ -164,3 +164,14 @@ func TestRenderOverviewSparklineAlignedUsesSameSparkStartColumn(t *testing.T) { t.Fatalf("unexpected throughput prefix: %q", thrTop) } } + +func TestRenderOverviewSparklineAlignedFitsSinglePanelRow(t *testing.T) { + panelW := panelWidth(220) + panelInner := panelInnerWidth(220) + labelWidth := maxLabelWidth("Latency:", "Gap:", "Throughput:") + line := renderOverviewSparklineAligned("Latency:", []float64{0, 10, 5, 10, 0}, panelInner, labelWidth) + rendered := common.PanelStyle.Width(panelW).Render(line) + if got := len(strings.Split(rendered, "\n")); got != 3 { + t.Fatalf("expected sparkline to fit one panel row (3 total lines with border), got %d lines", got) + } +} diff --git a/internal/tui/dashboard/processes.go b/internal/tui/dashboard/processes.go index 281a86a..a5e8d79 100644 --- a/internal/tui/dashboard/processes.go +++ b/internal/tui/dashboard/processes.go @@ -2,11 +2,12 @@ package dashboard import ( "fmt" - "ior/internal/statsengine" "strconv" "strings" - "github.com/charmbracelet/bubbles/table" + "ior/internal/statsengine" + + "charm.land/bubbles/v2/table" ) func renderProcesses(snap *statsengine.Snapshot, width, height int) string { diff --git a/internal/tui/dashboard/sparkline.go b/internal/tui/dashboard/sparkline.go index 2ce8c90..ab78cce 100644 --- a/internal/tui/dashboard/sparkline.go +++ b/internal/tui/dashboard/sparkline.go @@ -1,9 +1,8 @@ package dashboard import "math" -import "strings" -var sparkRowChars = []rune(" ▁▂▃▄▅▆▇█") +var sparkChars = []rune("▁▂▃▄▅▆▇█") func renderSparkline(data []float64, width int) string { if len(data) == 0 || width <= 0 { @@ -11,23 +10,15 @@ func renderSparkline(data []float64, width int) string { } samples := sampleForWidth(data, width) - leftPad := 0 - if len(samples) < width { - leftPad = width - len(samples) - } min, max := minMax(samples) if min == max { - top := repeatRune(' ', width) - bottom := repeatRune(' ', leftPad) + repeatRune('█', len(samples)) - return top + "\n" + bottom + if min == 0 { + return repeatRune(' ', width) + } + return repeatRune('▁', width) } - top := make([]rune, width) - bottom := make([]rune, width) - for i := 0; i < leftPad; i++ { - top[i] = ' ' - bottom[i] = ' ' - } + row := make([]rune, width) scale := 16.0 denom := max - min for i, value := range samples { @@ -39,20 +30,17 @@ func renderSparkline(data []float64, width int) string { level = 16 } - topLevel := level - 8 - if topLevel < 0 { - topLevel = 0 + // Collapse the previous two-row 0..16 scale to a single-row 0..7 scale. + oneRow := level / 2 + if oneRow < 0 { + oneRow = 0 } - bottomLevel := level - if bottomLevel > 8 { - bottomLevel = 8 + if oneRow > 7 { + oneRow = 7 } - - col := leftPad + i - top[col] = sparkRowChars[topLevel] - bottom[col] = sparkRowChars[bottomLevel] + row[i] = sparkChars[oneRow] } - return string(top) + "\n" + string(bottom) + return string(row) } func renderLabeledSparkline(label string, data []float64, width int) string { @@ -60,20 +48,47 @@ func renderLabeledSparkline(label string, data []float64, width int) string { if spark == "" { return label } - lines := strings.Split(spark, "\n") - if len(lines) == 1 { - return label + " " + lines[0] - } - pad := repeatRune(' ', len([]rune(label))+1) - return label + " " + lines[0] + "\n" + pad + lines[1] + return label + " " + spark } func sampleForWidth(data []float64, width int) []float64 { - if width >= len(data) { + if width <= 0 || len(data) == 0 { + return nil + } + + if width < len(data) { + start := len(data) - width + return append([]float64(nil), data[start:]...) + } + + if width == len(data) { return append([]float64(nil), data...) } - start := len(data) - width - return append([]float64(nil), data[start:]...) + + if len(data) == 1 { + out := make([]float64, width) + for i := range out { + out[i] = data[0] + } + return out + } + + out := make([]float64, width) + srcLast := len(data) - 1 + dstLast := width - 1 + for i := 0; i < width; i++ { + // Nearest-neighbor upsampling preserves the original series shape + // without introducing interpolated spikes between samples. + srcIdx := int(math.Round(float64(i) * float64(srcLast) / float64(dstLast))) + if srcIdx < 0 { + srcIdx = 0 + } + if srcIdx > srcLast { + srcIdx = srcLast + } + out[i] = data[srcIdx] + } + return out } func minMax(values []float64) (float64, float64) { diff --git a/internal/tui/dashboard/sparkline_test.go b/internal/tui/dashboard/sparkline_test.go index d7acd33..6f549d1 100644 --- a/internal/tui/dashboard/sparkline_test.go +++ b/internal/tui/dashboard/sparkline_test.go @@ -16,37 +16,52 @@ func TestRenderSparklineEmptyOrInvalidWidth(t *testing.T) { func TestRenderSparklineSingleValue(t *testing.T) { got := renderSparkline([]float64{10}, 8) - if got != " \n █" { - t.Fatalf("expected two-line constant sparkline, got %q", got) + if got != "▁▁▁▁▁▁▁▁" { + t.Fatalf("expected single-line constant sparkline, got %q", got) } } func TestRenderSparklineAllEqualValues(t *testing.T) { got := renderSparkline([]float64{5, 5, 5, 5}, 4) - if got != " \n████" { - t.Fatalf("expected two-line flat sparkline, got %q", got) + if got != "▁▁▁▁" { + t.Fatalf("expected single-line flat sparkline, got %q", got) } } -func TestRenderSparklineRightAlignsShortHistory(t *testing.T) { +func TestRenderSparklineAllZeroValuesRendersBlank(t *testing.T) { + got := renderSparkline([]float64{0, 0, 0}, 5) + if got != " " { + t.Fatalf("expected blank sparkline for all-zero series, got %q", got) + } +} + +func TestRenderSparklineLeftAlignsShortHistory(t *testing.T) { got := renderSparkline([]float64{1, 2, 3}, 6) - lines := strings.Split(got, "\n") - if len(lines) != 2 { - t.Fatalf("expected 2 lines, got %q", got) + first := strings.IndexFunc(got, func(r rune) bool { return r != ' ' }) + last := strings.LastIndexFunc(got, func(r rune) bool { return r != ' ' }) + if first < 0 || last < 0 { + t.Fatalf("expected visible sparkline cells, got %q", got) + } + if strings.HasPrefix(got, " ") { + t.Fatalf("expected sparkline not to use old right-aligned padding, got %q", got) } - if !strings.HasPrefix(lines[1], " ") { - t.Fatalf("expected left padding for short history, got %q", lines[1]) +} + +func TestRenderSparklineUsesRightmostColumn(t *testing.T) { + got := renderSparkline([]float64{1, 3, 2, 5}, 20) + row := []rune(got) + if len(row) != 20 { + t.Fatalf("expected 20 columns, got %d", len(row)) + } + if row[19] == ' ' { + t.Fatalf("expected rightmost column to contain sparkline data, got %q", got) } } func TestRenderSparklineRespectsWidthTruncation(t *testing.T) { got := renderSparkline([]float64{1, 2, 3, 4, 5, 6, 7, 8}, 4) - lines := strings.Split(got, "\n") - if len(lines) != 2 { - t.Fatalf("expected 2 lines, got %q", got) - } - if len([]rune(lines[0])) != 4 || len([]rune(lines[1])) != 4 { - t.Fatalf("expected 4 runes per line, got %q", got) + if len([]rune(got)) != 4 { + t.Fatalf("expected 4 runes, got %q", got) } } @@ -63,27 +78,32 @@ func TestSampleForWidthUsesRecentTail(t *testing.T) { } } +func TestSampleForWidthUpsamplesToFullWidth(t *testing.T) { + got := sampleForWidth([]float64{10, 20, 30}, 7) + if len(got) != 7 { + t.Fatalf("expected 7 samples, got %d", len(got)) + } + if got[0] != 10 { + t.Fatalf("expected first sample to preserve series start, got %v", got[0]) + } + if got[len(got)-1] != 30 { + t.Fatalf("expected last sample to preserve series end, got %v", got[len(got)-1]) + } +} + func TestRenderSparklineSpansLowToHigh(t *testing.T) { got := renderSparkline([]float64{0, 10}, 2) - lines := strings.Split(got, "\n") - if len(lines) != 2 { - t.Fatalf("expected 2 lines, got %q", got) - } - if !strings.Contains(got, "█") { - t.Fatalf("expected high bar, got %q", got) + if got != "▁█" { + t.Fatalf("expected low-to-high one-row sparkline, got %q", got) } } -func TestRenderLabeledSparklineAlignsSecondRow(t *testing.T) { +func TestRenderLabeledSparklineSingleLine(t *testing.T) { got := renderLabeledSparkline("Latency:", []float64{0, 10}, 2) - lines := strings.Split(got, "\n") - if len(lines) != 2 { - t.Fatalf("expected 2 lines, got %q", got) - } - if !strings.HasPrefix(lines[0], "Latency: ") { - t.Fatalf("expected label prefix on first row, got %q", lines[0]) + if strings.Contains(got, "\n") { + t.Fatalf("expected single-line labeled sparkline, got %q", got) } - if !strings.HasPrefix(lines[1], " ") { - t.Fatalf("expected padding on second row to align sparkline, got %q", lines[1]) + if !strings.HasPrefix(got, "Latency: ") { + t.Fatalf("expected label prefix, got %q", got) } } diff --git a/internal/tui/dashboard/syscalls.go b/internal/tui/dashboard/syscalls.go index 23fe37c..87acc80 100644 --- a/internal/tui/dashboard/syscalls.go +++ b/internal/tui/dashboard/syscalls.go @@ -2,11 +2,12 @@ package dashboard import ( "fmt" - "ior/internal/statsengine" "strconv" "time" - "github.com/charmbracelet/bubbles/table" + "ior/internal/statsengine" + + "charm.land/bubbles/v2/table" ) func renderSyscalls(snap *statsengine.Snapshot, width, height int) string { diff --git a/internal/tui/dashboard/tabs.go b/internal/tui/dashboard/tabs.go index df8f03e..5d15acc 100644 --- a/internal/tui/dashboard/tabs.go +++ b/internal/tui/dashboard/tabs.go @@ -2,11 +2,12 @@ package dashboard import ( "fmt" - common "ior/internal/tui/common" "strings" "unicode/utf8" - "github.com/charmbracelet/lipgloss" + common "ior/internal/tui/common" + + "charm.land/lipgloss/v2" ) // Tab is a dashboard tab identifier. @@ -25,9 +26,12 @@ const ( TabLatency // TabStream is the live event stream tab. TabStream + // TabFlame is the live flamegraph tab. + TabFlame ) var allTabs = []Tab{ + TabFlame, TabOverview, TabSyscalls, TabFiles, @@ -50,6 +54,8 @@ func (t Tab) String() string { return "Latency+Gaps" case TabStream: return "Stream" + case TabFlame: + return "Flame" default: return "Unknown" } @@ -192,6 +198,8 @@ func tabLabel(tab Tab, short bool) string { return "Lat" case TabStream: return "Str" + case TabFlame: + return "Flm" default: return "Unk" } diff --git a/internal/tui/dashboard/tabs_test.go b/internal/tui/dashboard/tabs_test.go index 1148103..16f8b76 100644 --- a/internal/tui/dashboard/tabs_test.go +++ b/internal/tui/dashboard/tabs_test.go @@ -11,17 +11,20 @@ func TestTabNavigationWraps(t *testing.T) { if got := nextTab(TabLatency); got != TabStream { t.Fatalf("expected next after latency+gaps to be stream, got %v", got) } - if got := nextTab(TabStream); got != TabOverview { - t.Fatalf("expected wrap to overview from stream, got %v", got) + if got := nextTab(TabStream); got != TabFlame { + t.Fatalf("expected next after stream to be flame, got %v", got) } - if got := prevTab(TabOverview); got != TabStream { - t.Fatalf("expected wrap to stream, got %v", got) + if got := nextTab(TabFlame); got != TabOverview { + t.Fatalf("expected wrap to overview from flame, got %v", got) + } + if got := prevTab(TabOverview); got != TabFlame { + t.Fatalf("expected wrap to flame, got %v", got) } } func TestRenderTabBarContainsLabels(t *testing.T) { out := renderTabBar(TabOverview, 100) - for _, label := range []string{"Overview", "Syscalls", "Files", "Processes", "Latency+Gaps", "Stream"} { + for _, label := range []string{"Overview", "Syscalls", "Files", "Processes", "Latency+Gaps", "Stream", "Flame"} { if !strings.Contains(out, label) { t.Fatalf("expected tab label %q in tab bar", label) } @@ -34,7 +37,7 @@ func TestRenderTabBarSmallWidthUsesSingleLine(t *testing.T) { if len(lines) != 1 { t.Fatalf("expected single-line tab bar at width 70, got %d lines", len(lines)) } - if strings.Contains(out, "6:Strea") { + if strings.Contains(out, "7:Flam") { t.Fatalf("tab label should not be wrapped/split in small width output") } } diff --git a/internal/tui/doc.go b/internal/tui/doc.go new file mode 100644 index 0000000..3175d65 --- /dev/null +++ b/internal/tui/doc.go @@ -0,0 +1,2 @@ +// Package tui hosts the top-level terminal UI model and screen routing. +package tui diff --git a/internal/tui/eventstream/doc.go b/internal/tui/eventstream/doc.go new file mode 100644 index 0000000..38bc854 --- /dev/null +++ b/internal/tui/eventstream/doc.go @@ -0,0 +1,2 @@ +// Package eventstream renders live event rows and interactive filtering controls. +package eventstream diff --git a/internal/tui/eventstream/exportmodal.go b/internal/tui/eventstream/exportmodal.go index cf020f7..3c0e2cd 100644 --- a/internal/tui/eventstream/exportmodal.go +++ b/internal/tui/eventstream/exportmodal.go @@ -3,9 +3,9 @@ package eventstream import ( "strings" - "github.com/charmbracelet/bubbles/textinput" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) type ExportModal struct { @@ -18,7 +18,8 @@ func NewExportModal() ExportModal { input := textinput.New() input.Prompt = "" input.CharLimit = 0 - input.Width = 44 + input.SetWidth(44) + input.SetStyles(textinput.DefaultStyles(true)) return ExportModal{textInput: input} } @@ -26,6 +27,12 @@ func (m ExportModal) Visible() bool { return m.visible } +// SetDarkMode updates export modal text input styles. +func (m ExportModal) SetDarkMode(isDark bool) ExportModal { + m.textInput.SetStyles(textinput.DefaultStyles(isDark)) + return m +} + func (m ExportModal) Open(defaultName string) ExportModal { m.visible = true m.err = "" @@ -47,7 +54,7 @@ func (m ExportModal) Update(msg tea.Msg) (ExportModal, string, bool) { if !m.visible { return m, "", false } - if keyMsg, ok := msg.(tea.KeyMsg); ok { + if keyMsg, ok := msg.(tea.KeyPressMsg); ok { switch keyMsg.String() { case "esc": return m.Close(), "", false diff --git a/internal/tui/eventstream/filtermodal.go b/internal/tui/eventstream/filtermodal.go index f98db7f..bd20a03 100644 --- a/internal/tui/eventstream/filtermodal.go +++ b/internal/tui/eventstream/filtermodal.go @@ -5,9 +5,9 @@ import ( "strconv" "strings" - "github.com/charmbracelet/bubbles/textinput" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) type fieldKey int @@ -48,7 +48,8 @@ func NewFilterModal() FilterModal { input := textinput.New() input.Prompt = "" input.CharLimit = 0 - input.Width = 24 + input.SetWidth(24) + input.SetStyles(textinput.DefaultStyles(true)) m := FilterModal{textInput: input} m.fields = defaultFilterFields() @@ -63,6 +64,12 @@ func (m FilterModal) Filter() Filter { return m.filter } +// SetDarkMode updates filter modal text input styles. +func (m FilterModal) SetDarkMode(isDark bool) FilterModal { + m.textInput.SetStyles(textinput.DefaultStyles(isDark)) + return m +} + func (m FilterModal) Open(initial Filter) FilterModal { m.visible = true m.activeField = 0 @@ -86,7 +93,7 @@ func (m FilterModal) Update(msg tea.Msg) FilterModal { return m } - if keyMsg, ok := msg.(tea.KeyMsg); ok { + if keyMsg, ok := msg.(tea.KeyPressMsg); ok { switch keyMsg.String() { case "esc": if m.editing { @@ -112,7 +119,7 @@ func (m FilterModal) Update(msg tea.Msg) FilterModal { m.fields[m.activeField].opIndex = (m.fields[m.activeField].opIndex + 1) % len(compareOps) } return m - case " ": + case " ", "space": if !m.editing && m.fields[m.activeField].fieldKey == fieldErrorsOnly { if strings.TrimSpace(m.fields[m.activeField].value) == "true" { m.fields[m.activeField].value = "false" diff --git a/internal/tui/eventstream/filtermodal_test.go b/internal/tui/eventstream/filtermodal_test.go index ee53c82..a33cbb1 100644 --- a/internal/tui/eventstream/filtermodal_test.go +++ b/internal/tui/eventstream/filtermodal_test.go @@ -3,7 +3,7 @@ package eventstream import ( "testing" - tea "github.com/charmbracelet/bubbletea" + tea "charm.land/bubbletea/v2" ) func TestFilterModalOpenClose(t *testing.T) { @@ -17,7 +17,7 @@ func TestFilterModalOpenClose(t *testing.T) { t.Fatalf("modal should be visible after open") } - m = m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) if m.Visible() { t.Fatalf("modal should close on esc") } @@ -29,11 +29,11 @@ func TestFilterModalNavigateFields(t *testing.T) { t.Fatalf("activeField=%d, want 0", m.activeField) } - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) if m.activeField != 1 { t.Fatalf("activeField=%d, want 1", m.activeField) } - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + m = m.Update(tea.KeyPressMsg{Code: []rune("k")[0], Text: string([]rune("k"))}) if m.activeField != 0 { t.Fatalf("activeField=%d, want 0", m.activeField) } @@ -43,34 +43,34 @@ func TestFilterModalEditAndBuildFilter(t *testing.T) { m := NewFilterModal().Open(Filter{}) // Syscall = read - m = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("read")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) + m = m.Update(tea.KeyPressMsg{Code: []rune("read")[0], Text: string([]rune("read"))}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) // PID >= 123 - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyTab}) // '=' -> '>' - m = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("123")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyTab}) // '=' -> '>' + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) + m = m.Update(tea.KeyPressMsg{Code: []rune("123")[0], Text: string([]rune("123"))}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) // Latency >= 1ms - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyTab}) // '=' -> '>=' - m = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("1ms")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyTab}) // '=' -> '>=' + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) + m = m.Update(tea.KeyPressMsg{Code: []rune("1ms")[0], Text: string([]rune("1ms"))}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) // ErrorsOnly = true for m.activeField < len(m.fields)-1 { - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("j")}) + m = m.Update(tea.KeyPressMsg{Code: []rune("j")[0], Text: string([]rune("j"))}) } - m = m.Update(tea.KeyMsg{Type: tea.KeySpace}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeySpace}) - m = m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) if m.Visible() { t.Fatalf("modal should close on esc") } @@ -98,8 +98,8 @@ func TestFilterModalClearAll(t *testing.T) { } m := NewFilterModal().Open(initial) - m = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) - m = m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + m = m.Update(tea.KeyPressMsg{Code: []rune("c")[0], Text: string([]rune("c"))}) + m = m.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) f := m.Filter() if f.IsActive() { diff --git a/internal/tui/eventstream/model.go b/internal/tui/eventstream/model.go index d9c4ee3..12aff4d 100644 --- a/internal/tui/eventstream/model.go +++ b/internal/tui/eventstream/model.go @@ -6,7 +6,8 @@ import ( "strconv" "strings" - tea "github.com/charmbracelet/bubbletea" + "charm.land/bubbles/v2/viewport" + tea "charm.land/bubbletea/v2" ) const ( @@ -23,8 +24,14 @@ const ( streamColumnCount ) +// Source is the minimal stream buffer contract needed by the stream model. +type Source interface { + Len() int + Snapshot() []StreamEvent +} + type Model struct { - source *RingBuffer + source Source allEvents []StreamEvent filtered []StreamEvent @@ -53,11 +60,13 @@ type Model struct { pendingOpenPath string statusMessage string exportDir string + isDark bool width int height int showFooter bool + viewport viewport.Model } type fdTraceViewState struct { @@ -68,8 +77,8 @@ type fdTraceViewState struct { offset int } -func NewModel(source *RingBuffer) Model { - return Model{ +func NewModel(source Source) Model { + m := Model{ source: source, filterModal: NewFilterModal(), exportModal: NewExportModal(), @@ -79,7 +88,25 @@ func NewModel(source *RingBuffer) Model { selectedCol: 0, exportDir: ".", showFooter: true, + isDark: true, + viewport: newStreamViewport(), } + m.SetDarkMode(true) + return m +} + +func newStreamViewport() viewport.Model { + vp := viewport.New() + keyMap := viewport.DefaultKeyMap() + keyMap.Down.SetKeys("down", "j") + keyMap.Up.SetKeys("up", "k") + keyMap.Left.SetKeys("left", "h") + keyMap.Right.SetKeys("right", "l") + keyMap.PageDown.SetKeys("pgdown", "pgdn", "pagedown") + keyMap.PageUp.SetKeys("pgup", "pageup") + vp.KeyMap = keyMap + vp.SoftWrap = true + return vp } // SetViewport updates the render/scroll viewport dimensions used for @@ -87,9 +114,11 @@ func NewModel(source *RingBuffer) Model { func (m *Model) SetViewport(width, height int) { if width > 0 { m.width = width + m.viewport.SetWidth(width) } if height > 0 { m.height = height + m.viewport.SetHeight(m.visibleRows()) } } @@ -99,11 +128,19 @@ func (m *Model) SetFooterVisible(visible bool) { } // SetSource updates the backing ring buffer and refreshes visible rows. -func (m *Model) SetSource(source *RingBuffer) { +func (m *Model) SetSource(source Source) { m.source = source m.Refresh() } +// SetDarkMode updates stream modal text input styles for the active theme. +func (m *Model) SetDarkMode(isDark bool) { + m.isDark = isDark + m.filterModal = m.filterModal.SetDarkMode(isDark) + m.exportModal = m.exportModal.SetDarkMode(isDark) + m.searchModal = m.searchModal.SetDarkMode(isDark) +} + // FilterModalVisible reports whether the filter modal is currently open. func (m Model) FilterModalVisible() bool { return m.filterModal.Visible() @@ -284,7 +321,8 @@ func (m *Model) HandleKey(keyStr string) bool { m.moveSelectionTo(len(m.filtered) - 1) } else { m.autoScroll = true - m.scrollOffset = m.maxScrollOffset() + m.viewport.GotoBottom() + m.scrollOffset = clamp(m.viewport.YOffset(), 0, m.maxScrollOffset()) } return true case "g": @@ -292,6 +330,7 @@ func (m *Model) HandleKey(keyStr string) bool { m.moveSelectionTo(0) } else { m.autoScroll = false + m.viewport.GotoTop() m.scrollOffset = 0 } return true @@ -305,14 +344,14 @@ func (m *Model) HandleKey(keyStr string) bool { if m.paused { m.moveSelectionBy(1) } else { - m.scrollByLines(1) + m.handleViewportUpdate(keyMsgFromString("down")) } return true case "k", "up": if m.paused { m.moveSelectionBy(-1) } else { - m.scrollByLines(-1) + m.handleViewportUpdate(keyMsgFromString("up")) } return true case "left", "h": @@ -320,25 +359,25 @@ func (m *Model) HandleKey(keyStr string) bool { m.moveSelectedColBy(-1) return true } - return false + return m.handleViewportUpdate(keyMsgFromString("left")) case "right", "l": if m.paused { m.moveSelectedColBy(1) return true } - return false + return m.handleViewportUpdate(keyMsgFromString("right")) case "pgdown", "pgdn", "pagedown": if m.paused { m.moveSelectionBy(m.pageStep()) } else { - m.scrollByLines(m.pageStep()) + m.handleViewportUpdate(keyMsgFromString("pgdown")) } return true case "pgup", "pageup": if m.paused { m.moveSelectionBy(-m.pageStep()) } else { - m.scrollByLines(-m.pageStep()) + m.handleViewportUpdate(keyMsgFromString("pgup")) } return true case "esc": @@ -353,8 +392,12 @@ func (m *Model) HandleKey(keyStr string) bool { // HandleTeaKey handles stream keys based on Bubble Tea key message types first, // then falls back to string matching for rune-driven shortcuts. -func (m *Model) HandleTeaKey(msg tea.KeyMsg) bool { - switch msg.Type { +func (m *Model) HandleTeaKey(msg tea.KeyPressMsg) bool { + if m.handleViewportUpdate(msg) { + return true + } + + switch msg.Code { case tea.KeyLeft: return m.HandleKey("left") case tea.KeyRight: @@ -373,14 +416,45 @@ func (m *Model) HandleTeaKey(msg tea.KeyMsg) bool { return m.HandleKey("esc") case tea.KeyEnter: return m.HandleKey("enter") - case tea.KeyRunes: - if len(msg.Runes) == 1 { - return m.HandleKey(string(msg.Runes[0])) + default: + if msg.Text != "" { + runes := []rune(msg.Text) + if len(runes) == 1 { + return m.HandleKey(msg.Text) + } } } return m.HandleKey(msg.String()) } +func (m *Model) handleViewportUpdate(msg tea.KeyPressMsg) bool { + if m.paused || m.fdTraceView.visible || m.filterModal.Visible() || m.exportModal.Visible() || m.searchModal.Visible() { + return false + } + + switch msg.String() { + case "down", "j", "up", "k", "left", "h", "right", "l", "pgup", "pageup", "pgdown", "pgdn", "pagedown": + default: + return false + } + + switch msg.String() { + case "pgup", "pageup": + m.viewport.ScrollUp(m.pageStep()) + case "pgdown", "pgdn", "pagedown": + m.viewport.ScrollDown(m.pageStep()) + default: + vp, cmd := m.viewport.Update(msg) + _ = cmd + m.viewport = vp + } + m.scrollOffset = clamp(m.viewport.YOffset(), 0, m.maxScrollOffset()) + if m.scrollOffset < m.maxScrollOffset() { + m.autoScroll = false + } + return true +} + func (m *Model) View(width, height int) string { if width <= 0 { width = 100 @@ -390,13 +464,16 @@ func (m *Model) View(width, height int) string { } m.width = width m.height = height + m.viewport.SetWidth(width) + m.viewport.SetHeight(m.visibleRows()) if m.fdTraceView.visible { return m.viewFDTrace(width) } rows := m.visibleRows() - start := clamp(m.scrollOffset, 0, m.maxScrollOffset()) + start := clamp(m.viewport.YOffset(), 0, m.maxScrollOffset()) + m.scrollOffset = start end := start + rows if end > len(m.filtered) { end = len(m.filtered) @@ -464,6 +541,8 @@ func (m *Model) Refresh() { m.allEvents = []StreamEvent{} m.filtered = []StreamEvent{} m.scrollOffset = 0 + m.viewport.SetContentLines(nil) + m.viewport.SetYOffset(0) return } @@ -476,6 +555,8 @@ func (m *Model) applyFilter() { m.filtered = []StreamEvent{} m.scrollOffset = 0 m.selectedIdx = -1 + m.viewport.SetContentLines(nil) + m.viewport.SetYOffset(0) return } @@ -487,12 +568,18 @@ func (m *Model) applyFilter() { } } m.filtered = filtered + m.viewport.SetWidth(m.width) + m.viewport.SetHeight(m.visibleRows()) + lines := make([]string, len(m.filtered)) + m.viewport.SetContentLines(lines) max := m.maxScrollOffset() if m.autoScroll { - m.scrollOffset = max + m.viewport.GotoBottom() + m.scrollOffset = clamp(m.viewport.YOffset(), 0, max) } else { m.scrollOffset = clamp(m.scrollOffset, 0, max) + m.viewport.SetYOffset(m.scrollOffset) } m.clampSelection() if m.paused { @@ -529,26 +616,6 @@ func (m *Model) pageStep() int { return rows - 1 } -func (m *Model) scrollByLines(delta int) { - if delta == 0 { - return - } - max := m.maxScrollOffset() - next := m.scrollOffset + delta - if next < 0 { - next = 0 - } - if next > max { - next = max - } - if next != m.scrollOffset { - m.scrollOffset = next - } - if m.scrollOffset < max { - m.autoScroll = false - } -} - func (m *Model) openFDTraceView() bool { if m.fdTraceView.visible || m.selectedIdx < 0 || m.selectedIdx >= len(m.filtered) { return false @@ -646,6 +713,7 @@ func (m *Model) centerSelection() { mid := m.visibleRows() / 2 target := m.selectedIdx - mid m.scrollOffset = clamp(target, 0, m.maxScrollOffset()) + m.viewport.SetYOffset(m.scrollOffset) } func (m *Model) ensureSelection() { @@ -807,26 +875,26 @@ func (m *Model) clampSelection() { m.selectedIdx = clamp(m.selectedIdx, 0, len(m.filtered)-1) } -func keyMsgFromString(keyStr string) tea.KeyMsg { +func keyMsgFromString(keyStr string) tea.KeyPressMsg { switch keyStr { case "esc": - return tea.KeyMsg{Type: tea.KeyEsc} + return tea.KeyPressMsg{Code: tea.KeyEsc} case "enter": - return tea.KeyMsg{Type: tea.KeyEnter} + return tea.KeyPressMsg{Code: tea.KeyEnter} case "tab": - return tea.KeyMsg{Type: tea.KeyTab} + return tea.KeyPressMsg{Code: tea.KeyTab} case "up": - return tea.KeyMsg{Type: tea.KeyUp} + return tea.KeyPressMsg{Code: tea.KeyUp} case "down": - return tea.KeyMsg{Type: tea.KeyDown} + return tea.KeyPressMsg{Code: tea.KeyDown} case " ", "space": - return tea.KeyMsg{Type: tea.KeySpace} + return tea.KeyPressMsg{Code: tea.KeySpace, Text: " "} } if keyStr == "" { - return tea.KeyMsg{} + return tea.KeyPressMsg{} } runes := []rune(keyStr) - return tea.KeyMsg{Type: tea.KeyRunes, Runes: runes} + return tea.KeyPressMsg{Code: runes[0], Text: keyStr} } func rowNumber(start, total int) int { diff --git a/internal/tui/eventstream/render.go b/internal/tui/eventstream/render.go index 1f539c6..3ec4d65 100644 --- a/internal/tui/eventstream/render.go +++ b/internal/tui/eventstream/render.go @@ -2,11 +2,12 @@ package eventstream import ( "fmt" - "ior/internal/tui/common" "strconv" "strings" - "github.com/charmbracelet/lipgloss" + "ior/internal/tui/common" + + "charm.land/lipgloss/v2" ) type columnLayout struct { diff --git a/internal/tui/eventstream/render_test.go b/internal/tui/eventstream/render_test.go index b020edf..6240c69 100644 --- a/internal/tui/eventstream/render_test.go +++ b/internal/tui/eventstream/render_test.go @@ -4,7 +4,7 @@ import ( "strings" "testing" - "github.com/charmbracelet/lipgloss" + "charm.land/lipgloss/v2" ) func TestRenderStatusAndFilterLines(t *testing.T) { diff --git a/internal/tui/eventstream/searchmodal.go b/internal/tui/eventstream/searchmodal.go index f744d00..c09542b 100644 --- a/internal/tui/eventstream/searchmodal.go +++ b/internal/tui/eventstream/searchmodal.go @@ -3,9 +3,9 @@ package eventstream import ( "strings" - "github.com/charmbracelet/bubbles/textinput" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) type SearchDirection int @@ -26,7 +26,8 @@ func NewSearchModal() SearchModal { input := textinput.New() input.Prompt = "" input.CharLimit = 0 - input.Width = 44 + input.SetWidth(44) + input.SetStyles(textinput.DefaultStyles(true)) return SearchModal{textInput: input, direction: SearchForward} } @@ -38,6 +39,12 @@ func (m SearchModal) Direction() SearchDirection { return m.direction } +// SetDarkMode updates search modal text input styles. +func (m SearchModal) SetDarkMode(isDark bool) SearchModal { + m.textInput.SetStyles(textinput.DefaultStyles(isDark)) + return m +} + func (m SearchModal) Open(direction SearchDirection, defaultTerm string) SearchModal { m.visible = true m.err = "" @@ -60,7 +67,7 @@ func (m SearchModal) Update(msg tea.Msg) (SearchModal, string, bool) { if !m.visible { return m, "", false } - if keyMsg, ok := msg.(tea.KeyMsg); ok { + if keyMsg, ok := msg.(tea.KeyPressMsg); ok { switch keyMsg.String() { case "esc": return m.Close(), "", false diff --git a/internal/tui/eventstream/streamevent.go b/internal/tui/eventstream/streamevent.go index dbe04dd..5f1e27f 100644 --- a/internal/tui/eventstream/streamevent.go +++ b/internal/tui/eventstream/streamevent.go @@ -1,9 +1,10 @@ package eventstream import ( + "time" + "ior/internal/event" "ior/internal/types" - "time" ) type StreamEvent struct { diff --git a/internal/tui/eventstream/streamevent_test.go b/internal/tui/eventstream/streamevent_test.go index 6131fed..dd65dd1 100644 --- a/internal/tui/eventstream/streamevent_test.go +++ b/internal/tui/eventstream/streamevent_test.go @@ -1,10 +1,11 @@ package eventstream import ( + "testing" + "ior/internal/event" "ior/internal/file" "ior/internal/types" - "testing" ) func TestNewStreamEventPopulatesFields(t *testing.T) { diff --git a/internal/tui/export/doc.go b/internal/tui/export/doc.go new file mode 100644 index 0000000..356b800 --- /dev/null +++ b/internal/tui/export/doc.go @@ -0,0 +1,2 @@ +// Package export implements the TUI snapshot export modal and option handling. +package export diff --git a/internal/tui/export/model.go b/internal/tui/export/model.go index 57612db..179754d 100644 --- a/internal/tui/export/model.go +++ b/internal/tui/export/model.go @@ -5,8 +5,8 @@ import ( "fmt" "strings" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) // Option is a selectable export target. @@ -75,7 +75,7 @@ func (m Model) Close() Model { // Update handles modal key navigation and export completion messages. func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { switch msg := msg.(type) { - case tea.KeyMsg: + case tea.KeyPressMsg: if !m.visible { return m, nil } diff --git a/internal/tui/export/model_test.go b/internal/tui/export/model_test.go index a97cd8b..2d47435 100644 --- a/internal/tui/export/model_test.go +++ b/internal/tui/export/model_test.go @@ -5,7 +5,7 @@ import ( "strings" "testing" - tea "github.com/charmbracelet/bubbletea" + tea "charm.land/bubbletea/v2" ) func TestOpenAndClose(t *testing.T) { @@ -21,7 +21,7 @@ func TestOpenAndClose(t *testing.T) { func TestEnterEmitsRequest(t *testing.T) { m := NewModel().Open() - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + next, cmd := m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) if cmd == nil { t.Fatalf("expected request command on enter") } @@ -40,7 +40,7 @@ func TestEnterEmitsRequest(t *testing.T) { func TestCancelOptionCloses(t *testing.T) { m := NewModel().Open() m.selected = len(optionValues) - 1 - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + next, cmd := m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) if cmd != nil { t.Fatalf("expected no command when selecting cancel") } diff --git a/internal/tui/flamegraph/animation.go b/internal/tui/flamegraph/animation.go new file mode 100644 index 0000000..103d43b --- /dev/null +++ b/internal/tui/flamegraph/animation.go @@ -0,0 +1,145 @@ +package flamegraph + +import ( + "math" + + "github.com/charmbracelet/harmonica" +) + +const springEpsilon = 0.01 + +type frameSpring struct { + path string + base tuiFrame + widthSpring harmonica.Spring + colSpring harmonica.Spring + + currentW float64 + currentCol float64 + velocityW float64 + velocityCol float64 + + targetW float64 + targetCol float64 +} + +// AnimationState stores per-frame spring interpolation state. +type AnimationState struct { + springs []frameSpring + frames []tuiFrame + settled bool + + fps int + angularVelocity float64 + damping float64 +} + +// NewAnimationState builds a spring animation state with the provided parameters. +func NewAnimationState(fps int, angularVelocity, damping float64) AnimationState { + if fps <= 0 { + fps = 30 + } + return AnimationState{ + fps: fps, + angularVelocity: angularVelocity, + damping: damping, + settled: true, + } +} + +// SetTargets sets new frame targets, preserving spring motion for matching paths. +func (a *AnimationState) SetTargets(targets []tuiFrame) { + existing := make(map[string]frameSpring, len(a.springs)) + for _, spring := range a.springs { + existing[spring.path] = spring + } + + next := make([]frameSpring, 0, len(targets)) + for _, target := range targets { + spring, ok := existing[target.Path] + if !ok { + spring = frameSpring{ + path: target.Path, + currentW: float64(target.Width), + currentCol: float64(target.Col), + } + } + spring.base = target + spring.targetW = float64(target.Width) + spring.targetCol = float64(target.Col) + spring.widthSpring = harmonica.NewSpring(harmonica.FPS(a.fps), a.angularVelocity, a.damping) + spring.colSpring = harmonica.NewSpring(harmonica.FPS(a.fps), a.angularVelocity, a.damping) + next = append(next, spring) + } + a.springs = next + if cap(a.frames) < len(a.springs) { + a.frames = make([]tuiFrame, len(a.springs)) + } else { + a.frames = a.frames[:len(a.springs)] + } + a.settled = len(a.springs) == 0 + for _, spring := range a.springs { + if !isSpringSettled(spring) { + a.settled = false + break + } + } +} + +// Tick advances springs by delta seconds and returns true while animation is active. +func (a *AnimationState) Tick(delta float64) bool { + if len(a.springs) == 0 { + a.settled = true + return false + } + baseDelta := harmonica.FPS(a.fps) + if delta <= 0 { + delta = baseDelta + } + + active := false + for idx := range a.springs { + spring := &a.springs[idx] + if delta != baseDelta { + spring.widthSpring = harmonica.NewSpring(delta, a.angularVelocity, a.damping) + spring.colSpring = harmonica.NewSpring(delta, a.angularVelocity, a.damping) + } + spring.currentW, spring.velocityW = spring.widthSpring.Update(spring.currentW, spring.velocityW, spring.targetW) + spring.currentCol, spring.velocityCol = spring.colSpring.Update(spring.currentCol, spring.velocityCol, spring.targetCol) + if !isSpringSettled(*spring) { + active = true + } + } + a.settled = !active + return active +} + +// CurrentFrames returns interpolated frames for the current animation step. +func (a *AnimationState) CurrentFrames() []tuiFrame { + for idx, spring := range a.springs { + frame := spring.base + frame.Col = maxInt(0, int(math.Round(spring.currentCol))) + frame.Width = maxInt(1, int(math.Round(spring.currentW))) + a.frames[idx] = frame + } + return a.frames +} + +// Settled reports whether all active springs are at rest. +func (a AnimationState) Settled() bool { + return a.settled +} + +func isSpringSettled(s frameSpring) bool { + return math.Abs(s.currentW-s.targetW) < springEpsilon && + math.Abs(s.currentCol-s.targetCol) < springEpsilon && + math.Abs(s.velocityW) < springEpsilon && + math.Abs(s.velocityCol) < springEpsilon +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/internal/tui/flamegraph/animation_test.go b/internal/tui/flamegraph/animation_test.go new file mode 100644 index 0000000..94272e2 --- /dev/null +++ b/internal/tui/flamegraph/animation_test.go @@ -0,0 +1,50 @@ +package flamegraph + +import "testing" + +func TestAnimationStateConvergesToTarget(t *testing.T) { + state := NewAnimationState(30, 6.0, 1.0) + state.SetTargets([]tuiFrame{{Path: "root", Col: 0, Width: 10}}) + state.SetTargets([]tuiFrame{{Path: "root", Col: 100, Width: 50}}) + + active := true + for i := 0; i < 180 && active; i++ { + active = state.Tick(0) + } + if active { + t.Fatalf("expected springs to settle within 180 ticks") + } + + frames := state.CurrentFrames() + if len(frames) != 1 { + t.Fatalf("expected one interpolated frame, got %d", len(frames)) + } + if frames[0].Col != 100 || frames[0].Width != 50 { + t.Fatalf("expected settled frame at col=100 width=50, got col=%d width=%d", frames[0].Col, frames[0].Width) + } + if state.Tick(0) { + t.Fatalf("expected settled animation to remain inactive") + } +} + +func TestAnimationStateHandlesAddedAndRemovedFrames(t *testing.T) { + state := NewAnimationState(30, 6.0, 1.0) + state.SetTargets([]tuiFrame{ + {Path: "root", Col: 0, Width: 20}, + {Path: "root\x1fchild", Col: 20, Width: 20}, + }) + if got := len(state.CurrentFrames()); got != 2 { + t.Fatalf("expected 2 frames after initial targets, got %d", got) + } + + state.SetTargets([]tuiFrame{ + {Path: "root\x1fchild", Col: 40, Width: 30}, + }) + frames := state.CurrentFrames() + if len(frames) != 1 { + t.Fatalf("expected removed frame to be dropped, got %d frames", len(frames)) + } + if frames[0].Path != "root\x1fchild" { + t.Fatalf("expected remaining frame path root\\x1fchild, got %q", frames[0].Path) + } +} diff --git a/internal/tui/flamegraph/bench_test.go b/internal/tui/flamegraph/bench_test.go new file mode 100644 index 0000000..33d77d1 --- /dev/null +++ b/internal/tui/flamegraph/bench_test.go @@ -0,0 +1,401 @@ +package flamegraph + +import ( + "encoding/json" + "fmt" + "testing" + + coreflamegraph "ior/internal/flamegraph" + "ior/internal/types" + + "github.com/charmbracelet/harmonica" +) + +var ( + benchFramesSink []tuiFrame + benchStringSink string + benchIntSink int + benchFloatSink float64 +) + +func BenchmarkBuildTerminalLayout(b *testing.B) { + // Performance target: medium_120col should remain below 500us/op. + fixtures := []struct { + label string + depth int + breadth int + }{ + {label: "small", depth: fixtureSmallDepth, breadth: fixtureSmallBreadth}, + {label: "medium", depth: fixtureMediumDepth, breadth: fixtureMediumBreadth}, + {label: "large", depth: fixtureLargeDepth, breadth: fixtureLargeBreadth}, + {label: "deep", depth: fixtureDeepDepth, breadth: fixtureDeepBreadth}, + {label: "wide", depth: fixtureWideDepth, breadth: fixtureWideBreadth}, + } + widths := []int{80, 120, 200, 300} + const height = 40 + + snapshots := make(map[string]*snapshotNode, len(fixtures)) + for _, fixture := range fixtures { + snapshots[fixture.label] = generateTestSnapshot(fixture.depth, fixture.breadth) + } + + for _, fixture := range fixtures { + snapshot := snapshots[fixture.label] + for _, width := range widths { + name := fmt.Sprintf("%s_%dcol", fixture.label, width) + b.Run(name, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchFramesSink = BuildTerminalLayout(snapshot, width, height) + } + if len(benchFramesSink) == 0 { + b.Fatal("layout returned no frames") + } + }) + } + } +} + +func BenchmarkRenderFrame(b *testing.B) { + // Performance target: medium_120x40 should remain below 2ms/op. + // Allocation target: run with -benchmem and keep render path below 5 allocs/op. + fixtures := []struct { + label string + snapshot *snapshotNode + }{ + {label: "medium", snapshot: generateTestSnapshot(fixtureMediumDepth, fixtureMediumBreadth)}, + {label: "large", snapshot: generateTestSnapshot(fixtureLargeDepth, fixtureLargeBreadth)}, + } + viewports := []struct { + width int + height int + }{ + {width: 80, height: 24}, + {width: 120, height: 40}, + {width: 200, height: 60}, + } + + for _, fixture := range fixtures { + for _, viewport := range viewports { + name := fmt.Sprintf("%s_%dx%d", fixture.label, viewport.width, viewport.height) + b.Run(name, func(b *testing.B) { + model := NewModel(nil) + model.width = viewport.width + model.height = viewport.height + model.snapshot = fixture.snapshot + model.rebuildFrames(false) + if len(model.frames) == 0 { + b.Fatal("render benchmark requires non-empty frame layout") + } + + for idx := range model.frames { + switch idx % 12 { + case 0: + model.frames[idx].Name = "sys_enter_read" + case 1: + model.frames[idx].Name = "sys_enter_write" + } + } + model.selectedIdx = midDepthFrameIndex(model.frames) + model.subtreeSet = computeSubtreeSetInto(model.frames, model.selectedIdx, model.subtreeSet) + model.applySearchQuery("sys_") + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchStringSink = model.View().Content + } + }) + } + } +} + +func BenchmarkComputeSubtreeSet(b *testing.B) { + // Performance target: 1000-frame subtree membership should remain below 100us/op. + // Allocation target: zero allocs/op by reusing map storage. + cases := []struct { + label string + frameCount int + }{ + {label: "100frames", frameCount: 100}, + {label: "1000frames", frameCount: 1000}, + {label: "5000frames", frameCount: 5000}, + } + + for _, tc := range cases { + frames := benchmarkFramesForCount(tc.frameCount) + if len(frames) == 0 { + b.Fatalf("%s produced no frames", tc.label) + } + selectedIdx := midDepthFrameIndex(frames) + reuse := make(map[int]bool, len(frames)) + + b.Run(tc.label, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + subtree := computeSubtreeSetInto(frames, selectedIdx, reuse) + benchIntSink = len(subtree) + } + }) + } +} + +func BenchmarkSearchHighlight(b *testing.B) { + // Performance target: 1000-frame search should remain below 200us/op. + cases := []struct { + label string + frameCount int + }{ + {label: "100frames", frameCount: 100}, + {label: "1000frames", frameCount: 1000}, + {label: "5000frames", frameCount: 5000}, + } + queries := []string{"read", "sys_", "/srv/app"} + + for _, tc := range cases { + frames := benchmarkFramesForCount(tc.frameCount) + if len(frames) == 0 { + b.Fatalf("%s produced no frames", tc.label) + } + decorateFramesForSearch(frames) + + model := NewModel(nil) + model.frames = frames + model.selectedIdx = midDepthFrameIndex(frames) + model.subtreeSet = computeSubtreeSetInto(model.frames, model.selectedIdx, model.subtreeSet) + + b.Run(tc.label, func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + model.applySearchQuery(queries[i%len(queries)]) + benchIntSink = len(model.matchIndices) + } + }) + } +} + +func BenchmarkSpringUpdate(b *testing.B) { + // Performance target: 500 active springs should update in < 1ms per tick. + counts := []int{100, 500, 2000} + const ( + angularVelocity = 6.0 + damping = 1.0 + ) + + for _, count := range counts { + b.Run(fmt.Sprintf("%d_springs", count), func(b *testing.B) { + springs := make([]harmonica.Spring, count) + current := make([]float64, count) + velocity := make([]float64, count) + target := make([]float64, count) + + for idx := range springs { + springs[idx] = harmonica.NewSpring(harmonica.FPS(30), angularVelocity, damping) + current[idx] = float64(idx) + target[idx] = float64(idx + 8) + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + for idx := range springs { + current[idx], velocity[idx] = springs[idx].Update(current[idx], velocity[idx], target[idx]) + } + benchFloatSink = current[count-1] + } + }) + } +} + +func BenchmarkAnimationTick(b *testing.B) { + // Performance target: 500 animated frames should complete in < 1ms per tick. + // Allocation target: zero allocs/op in the tick + CurrentFrames path. + counts := []int{100, 500, 2000} + + for _, count := range counts { + b.Run(fmt.Sprintf("%d_frames", count), func(b *testing.B) { + state := NewAnimationState(30, 6.0, 1.0) + base := linearFrames(count, 0, 10) + target := linearFrames(count, 5, 20) + state.SetTargets(base) + state.SetTargets(target) + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + if !state.Tick(0) { + for idx := range state.springs { + state.springs[idx].targetCol += 3 + state.springs[idx].targetW += 2 + } + state.settled = false + } + frames := state.CurrentFrames() + benchIntSink = frames[len(frames)-1].Width + } + }) + } +} + +func BenchmarkZoomTransition(b *testing.B) { + // Performance target: zoom-in transition should stay below 1ms/op. + snapshot := generateTestSnapshot(fixtureMediumDepth, fixtureMediumBreadth) + model := NewModel(nil) + model.width = 120 + model.height = 40 + model.snapshot = snapshot + model.rebuildFrames(false) + if len(model.frames) == 0 { + b.Fatal("zoom benchmark requires non-empty initial layout") + } + zoomPath := model.frames[midDepthFrameIndex(model.frames)].Path + + b.Run("zoom_in", func(b *testing.B) { + benchModel := model + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchModel.zoomReset() + benchModel.selectedIdx = frameIndexByPath(benchModel.frames, zoomPath) + benchModel.zoomIn() + benchIntSink = len(benchModel.targetFrames) + } + }) + + b.Run("undo_zoom", func(b *testing.B) { + benchModel := model + benchModel.selectedIdx = frameIndexByPath(benchModel.frames, zoomPath) + benchModel.zoomIn() + if len(benchModel.zoomStack) == 0 { + b.Fatal("undo benchmark requires an active zoom stack") + } + + b.ReportAllocs() + for i := 0; i < b.N; i++ { + benchModel.zoomUndo() + benchIntSink = len(benchModel.frames) + + benchModel.selectedIdx = frameIndexByPath(benchModel.frames, zoomPath) + benchModel.zoomIn() + } + }) +} + +func BenchmarkLiveTrieIngestAndSnapshot(b *testing.B) { + // Performance target: ingest+snapshot pipeline should remain below 200us/op for small/medium cycles. + counts := []int{100, 1000, 10000} + for _, count := range counts { + b.Run(fmt.Sprintf("%d_events", count), func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + for eventIdx := 0; eventIdx < count; eventIdx++ { + traceID := types.SYS_ENTER_READ + if eventIdx%2 == 0 { + traceID = types.SYS_ENTER_WRITE + } + pair := newBenchmarkPair( + fmt.Sprintf("worker-%d", eventIdx%4), + traceID, + uint32(1000+(eventIdx%64)), + uint32(200000+eventIdx), + buildBenchmarkPath(8, 6, eventIdx), + ) + liveTrie.Ingest(pair) + pair.Recycle() + } + + payload, _ := liveTrie.SnapshotJSON() + var snapshot snapshotNode + if err := json.Unmarshal(payload, &snapshot); err != nil { + b.Fatalf("snapshot decode failed: %v", err) + } + benchFramesSink = BuildTerminalLayout(&snapshot, 120, 40) + } + }) + } +} + +func BenchmarkResizeRelayout(b *testing.B) { + // Performance target: resize relayout cost should match BuildTerminalLayout (< 500us medium@120col). + snapshot := generateTestSnapshot(fixtureMediumDepth, fixtureMediumBreadth) + b.ReportAllocs() + for i := 0; i < b.N; i++ { + frames120 := BuildTerminalLayout(snapshot, 120, 40) + frames80 := BuildTerminalLayout(snapshot, 80, 24) + benchFramesSink = BuildTerminalLayout(snapshot, 120, 40) + benchIntSink = len(frames120) + len(frames80) + len(benchFramesSink) + } +} + +func benchmarkFramesForCount(frameCount int) []tuiFrame { + var snapshot *snapshotNode + switch frameCount { + case 100: + snapshot = generateTestSnapshot(fixtureDeepDepth, fixtureDeepBreadth) + case 1000: + snapshot = generateTestSnapshot(20, 5) + case 5000: + snapshot = generateTestSnapshot(fixtureWideDepth, fixtureWideBreadth) + default: + snapshot = generateTestSnapshot(10, 5) + } + return BuildTerminalLayout(snapshot, 200, 80) +} + +func decorateFramesForSearch(frames []tuiFrame) { + for idx := range frames { + switch idx % 6 { + case 0: + frames[idx].Name = "sys_enter_read" + case 1: + frames[idx].Name = "sys_enter_write" + case 2: + frames[idx].Name = "read_cache_buffer" + case 3: + frames[idx].Name = "path:/srv/app/api" + case 4: + frames[idx].Name = "worker_loop" + default: + frames[idx].Name = "io_wait" + } + } +} + +func midDepthFrameIndex(frames []tuiFrame) int { + if len(frames) == 0 { + return 0 + } + maxDepth := 0 + for _, frame := range frames { + if frame.Depth > maxDepth { + maxDepth = frame.Depth + } + } + targetDepth := maxDepth / 2 + indices := framesAtDepth(frames, targetDepth) + if len(indices) == 0 { + return len(frames) / 2 + } + return indices[len(indices)/2] +} + +func frameIndexByPath(frames []tuiFrame, path string) int { + for idx, frame := range frames { + if frame.Path == path { + return idx + } + } + return 0 +} + +func linearFrames(count, colOffset, width int) []tuiFrame { + frames := make([]tuiFrame, count) + for idx := 0; idx < count; idx++ { + path := fmt.Sprintf("root%snode-%d", pathSeparator, idx) + frames[idx] = tuiFrame{ + Name: fmt.Sprintf("node-%d", idx), + Path: path, + Col: colOffset + idx, + Row: idx % 8, + Width: width, + } + } + return frames +} diff --git a/internal/tui/flamegraph/controls.go b/internal/tui/flamegraph/controls.go new file mode 100644 index 0000000..06e6d0d --- /dev/null +++ b/internal/tui/flamegraph/controls.go @@ -0,0 +1,173 @@ +package flamegraph + +import ( + "fmt" + "strings" + + common "ior/internal/tui/common" + + "charm.land/lipgloss/v2" +) + +func (m *Model) togglePause() { + m.paused = !m.paused +} + +func (m *Model) clearSnapshotState(clearSearch bool) { + m.zoomRoot = nil + m.zoomPath = "" + m.zoomStack = nil + m.selectedIdx = 0 + m.snapshot = nil + m.globalTotal = 0 + m.frames = nil + m.targetFrames = nil + m.matchIndices = make(map[int]bool) + m.filterVisible = make(map[int]bool) + m.subtreeSet = make(map[int]bool) + m.hasNavigableSnapshot = false + if clearSearch { + m.searchQuery = "" + } +} + +func (m *Model) resetBaseline() { + if m.liveTrie != nil { + m.liveTrie.Reset() + } + m.clearSnapshotState(true) + m.statusMessage = "Baseline reset" +} + +func (m *Model) cycleFieldOrder() { + if len(m.fieldPresets) == 0 { + return + } + m.fieldIndex = (m.fieldIndex + 1) % len(m.fieldPresets) + nextPreset := m.fieldPresets[m.fieldIndex] + if m.liveTrie != nil { + if err := m.liveTrie.Reconfigure(nextPreset); err != nil { + m.statusMessage = "Field order error: " + err.Error() + return + } + } + m.clearSnapshotState(false) + m.statusMessage = "Order: " + strings.Join(nextPreset, "/") +} + +func (m *Model) toggleCountField() { + next := "bytes" + if m.countField == "bytes" { + next = "count" + } + if m.liveTrie != nil { + if err := m.liveTrie.SetCountField(next); err != nil { + m.statusMessage = "Metric toggle error: " + err.Error() + return + } + } + m.countField = next + m.clearSnapshotState(false) + m.statusMessage = "Metric: " + m.countFieldLabel() + " (new baseline)" +} + +func (m *Model) toggleHelp() { + m.showHelp = !m.showHelp +} + +func (m Model) toolbarLine() string { + state := lipgloss.NewStyle().Foreground(common.ColorPrimary).Render("[LIVE]") + if m.paused { + state = lipgloss.NewStyle().Foreground(common.ColorDanger).Bold(true).Render("[PAUSED]") + } + order := m.currentFieldPresetLabel() + line := fmt.Sprintf("%s | view:%s | o:order(%s) | b:metric(%s) | /:search | enter:zoom | u/esc:undo | r:reset | space/p:pause", state, compactFramePath(m.currentRootPath()), order, m.countFieldLabel()) + if m.searchQuery != "" { + line += " | filter:" + m.searchQuery + } + if m.statusMessage != "" { + line += " | " + m.statusMessage + } + if m.lastKeyDebug != "" { + line += " | " + m.lastKeyDebug + } + width := m.width + if width <= 0 { + width = 80 + } + return padOrTrim(line, width) +} + +func (m Model) helpOverlay() string { + width := m.width + if width <= 0 { + width = 80 + } + help := "Flame help: j/k depth h/l sibling pgup top pgdn root enter zoom u/backspace/esc undo / search n/N matches space/p pause r reset baseline o order b metric ? help" + return common.HelpBarStyle.Width(width).Render(padOrTrim(help, width)) +} + +func (m Model) selectionStatusLine() string { + width := m.width + if width <= 0 { + width = 80 + } + mode := "LIVE" + if m.paused { + mode = "PAUSED" + } + if len(m.frames) == 0 { + line := fmt.Sprintf("[%s] sel:none | arrows/hjkl navigate | enter zoom | / filter", mode) + return common.HelpBarStyle.Width(width).Render(padOrTrim(line, width)) + } + selIdx := m.selectedIdx + if selIdx < 0 || selIdx >= len(m.frames) { + selIdx = 0 + } + frame := m.frames[selIdx] + systemShare := frame.Percent + if m.globalTotal > 0 { + systemShare = percentOfTotal(frame.Total, m.globalTotal) + } + metric := m.countFieldLabel() + shareLabel := fmt.Sprintf("%.2f%% of total %s", systemShare, metric) + if strings.TrimSpace(m.searchQuery) != "" && len(m.matchIndices) > 0 { + filterTotal, _ := filterCoverageTotals(m.frames, m.matchIndices, m.globalTotal) + if filterTotal > 0 { + selectedFilterTotal := filterCoverageTotalForPath(m.frames, m.matchIndices, frame.Path) + filterShare := percentOfTotal(selectedFilterTotal, filterTotal) + shareLabel = fmt.Sprintf("%.2f%% of filtered %s", filterShare, metric) + } + } + line := fmt.Sprintf("[%s] sel:%d/%d %s | path:%s | depth:%d | total(%s):%d | %s", + mode, selIdx+1, len(m.frames), frame.Name, compactFramePath(frame.Path), frame.Depth, m.countFieldLabel(), frame.Total, shareLabel) + if m.searchQuery != "" { + line += " | filter:" + m.searchQuery + } + return common.HelpBarStyle.Width(width).Render(padOrTrim(line, width)) +} + +func (m Model) currentFieldPresetLabel() string { + if len(m.fieldPresets) == 0 { + return "n/a" + } + idx := m.fieldIndex + if idx < 0 { + idx = 0 + } + if idx >= len(m.fieldPresets) { + idx = len(m.fieldPresets) - 1 + } + return strings.Join(m.fieldPresets[idx], "/") +} + +func (m Model) countFieldLabel() string { + switch m.countField { + case "count": + return "events" + case "bytes": + return "bytes" + default: + return m.countField + } +} diff --git a/internal/tui/flamegraph/doc.go b/internal/tui/flamegraph/doc.go new file mode 100644 index 0000000..7982ae9 --- /dev/null +++ b/internal/tui/flamegraph/doc.go @@ -0,0 +1,2 @@ +// Package flamegraph renders the interactive terminal flamegraph dashboard tab. +package flamegraph diff --git a/internal/tui/flamegraph/model.go b/internal/tui/flamegraph/model.go new file mode 100644 index 0000000..cc208ae --- /dev/null +++ b/internal/tui/flamegraph/model.go @@ -0,0 +1,1027 @@ +package flamegraph + +import ( + "encoding/json" + "fmt" + "image/color" + "slices" + "sort" + "strings" + "time" + + common "ior/internal/tui/common" + + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" +) + +type snapshotNode struct { + Name string `json:"n"` + Value uint64 `json:"v"` + Total uint64 `json:"t"` + Children []*snapshotNode `json:"c,omitempty"` +} + +type animTickMsg struct{} + +const animFrameDuration = 33 * time.Millisecond + +// LiveTrieSource is the minimal trie contract needed by the flamegraph TUI model. +type LiveTrieSource interface { + Fields() []string + CountField() string + Reconfigure([]string) error + SetCountField(string) error + Reset() + Version() uint64 + SnapshotJSON() ([]byte, uint64) +} + +type zoomState struct { + path string + previousSelectedIdx int +} + +type flameKeyMap struct { + MoveShallower key.Binding + MoveDeeper key.Binding + PrevSibling key.Binding + NextSibling key.Binding + JumpTop key.Binding + JumpRoot key.Binding + ZoomIn key.Binding + ZoomUndo key.Binding + ZoomReset key.Binding +} + +func defaultFlameKeyMap() flameKeyMap { + return flameKeyMap{ + MoveShallower: key.NewBinding(key.WithKeys("j", "down")), + MoveDeeper: key.NewBinding(key.WithKeys("k", "up")), + PrevSibling: key.NewBinding(key.WithKeys("h", "left")), + NextSibling: key.NewBinding(key.WithKeys("l", "right")), + JumpTop: key.NewBinding(key.WithKeys("pgup", "pageup")), + JumpRoot: key.NewBinding(key.WithKeys("pgdown", "pgdn", "pagedown")), + ZoomIn: key.NewBinding(key.WithKeys("enter")), + ZoomUndo: key.NewBinding(key.WithKeys("backspace", "u", "esc")), + ZoomReset: key.NewBinding(), + } +} + +// Model is the Bubble Tea model for the TUI flamegraph tab. +type Model struct { + liveTrie LiveTrieSource + lastVersion uint64 + snapshot *snapshotNode + globalTotal uint64 + + frames []tuiFrame + targetFrames []tuiFrame + width int + height int + + selectedIdx int + zoomStack []zoomState + zoomRoot *snapshotNode + zoomPath string + + searchActive bool + searchInput textinput.Model + searchQuery string + matchIndices map[int]bool + filterVisible map[int]bool + subtreeSet map[int]bool + showHelp bool + statusMessage string + lastKeyDebug string + + fieldPresets [][]string + fieldIndex int + countField string + + animation AnimationState + animating bool + paused bool + // hasNavigableSnapshot flips once we have at least one selectable non-root frame. + hasNavigableSnapshot bool + isDark bool + keys flameKeyMap +} + +// tuiFrame stores one terminal flamegraph frame cell. +type tuiFrame struct { + Name string + Col int + Row int + Width int + Total uint64 + Percent float64 + Fill color.Color + Depth int + Path string +} + +// NewModel constructs a flamegraph tab model with default state. +func NewModel(liveTrie LiveTrieSource) Model { + searchInput := textinput.New() + searchInput.Prompt = "/" + searchInput.CharLimit = 0 + searchInput.SetWidth(32) + searchInput.SetStyles(textinput.DefaultStyles(true)) + + m := Model{ + liveTrie: liveTrie, + matchIndices: make(map[int]bool), + filterVisible: make(map[int]bool), + subtreeSet: make(map[int]bool), + searchInput: searchInput, + fieldPresets: [][]string{ + {"comm", "tracepoint", "path"}, + {"path", "tracepoint", "comm"}, + {"tracepoint", "comm", "path"}, + {"pid", "tracepoint", "path"}, + {"comm", "path", "tracepoint"}, + }, + isDark: true, + keys: defaultFlameKeyMap(), + animation: NewAnimationState(30, 6.0, 1.0), + countField: "count", + } + m.syncFieldPresetToTrie() + m.syncCountFieldToTrie() + return m +} + +// Init starts the flamegraph model. +func (m Model) Init() tea.Cmd { + return nil +} + +// Update handles incoming messages. +func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case animTickMsg: + if !m.animating { + return m, nil + } + m.animating = m.animation.Tick(0) + m.frames = m.animation.CurrentFrames() + m.clampSelection() + m.subtreeSet = computeSubtreeSetInto(m.frames, m.selectedIdx, m.subtreeSet) + if m.animating { + return m, animTickCmd() + } + return m, nil + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.rebuildFrames(true) + if m.animating { + return m, animTickCmd() + } + return m, nil + case tea.KeyPressMsg: + if m.searchActive { + handled := false + switch msg.String() { + case "esc": + handled = true + m.clearSearch() + m.recordKeyDebug(msg, handled, false) + return m, nil + case "enter": + handled = true + m.applySearchQuery(m.searchInput.Value()) + m.searchActive = false + m.searchInput.Blur() + m.recordKeyDebug(msg, handled, false) + return m, nil + } + var cmd tea.Cmd + m.searchInput, cmd = m.searchInput.Update(msg) + _ = cmd + m.recordKeyDebug(msg, true, false) + return m, nil + } + + prev := m.selectedIdx + handled := false + switch { + case isSearchOpenKey(msg): + handled = true + m.openSearch() + case isNextMatchKey(msg): + handled = true + m.jumpMatch(1) + case isPrevMatchKey(msg): + handled = true + m.jumpMatch(-1) + case isPauseKey(msg): + handled = true + m.togglePause() + case isResetBaselineKey(msg): + handled = true + m.resetBaseline() + case isCycleOrderKey(msg): + handled = true + m.cycleFieldOrder() + case isCycleMetricKey(msg): + handled = true + m.toggleCountField() + case isHelpToggleKey(msg): + handled = true + m.toggleHelp() + case isZoomInKey(msg, m.keys): + handled = true + m.zoomIn() + case isZoomUndoKey(msg, m.keys): + handled = true + m.zoomUndo() + case isZoomResetKey(msg, m.keys): + handled = true + m.zoomReset() + case isMoveShallowerKey(msg, m.keys): + handled = true + m.moveVerticalWithFallback(-1, 1, -1) + case isMoveDeeperKey(msg, m.keys): + handled = true + m.moveVerticalWithFallback(1, -1, 1) + case isPrevSiblingKey(msg, m.keys): + handled = true + m.moveSibling(-1) + case isNextSiblingKey(msg, m.keys): + handled = true + m.moveSibling(1) + case isJumpTopKey(msg, m.keys): + handled = true + m.jumpToTop() + case isJumpRootKey(msg, m.keys): + handled = true + m.jumpToRoot() + } + if m.selectedIdx != prev { + m.subtreeSet = computeSubtreeSetInto(m.frames, m.selectedIdx, m.subtreeSet) + } + m.recordKeyDebug(msg, handled, m.selectedIdx != prev) + } + return m, nil +} + +// ConsumesKey reports whether the flamegraph should handle a key press before +// dashboard- or app-level shortcuts. +func (m Model) ConsumesKey(msg tea.KeyPressMsg) bool { + if m.searchActive { + return true + } + switch { + case isSearchOpenKey(msg), + isNextMatchKey(msg), + isPrevMatchKey(msg), + isPauseKey(msg), + isResetBaselineKey(msg), + isCycleOrderKey(msg), + isCycleMetricKey(msg), + isHelpToggleKey(msg): + return true + case isZoomInKey(msg, m.keys), + isZoomUndoKey(msg, m.keys), + isZoomResetKey(msg, m.keys), + isMoveShallowerKey(msg, m.keys), + isMoveDeeperKey(msg, m.keys), + isPrevSiblingKey(msg, m.keys), + isNextSiblingKey(msg, m.keys), + isJumpTopKey(msg, m.keys), + isJumpRootKey(msg, m.keys): + return true + default: + return false + } +} + +// View renders the flamegraph viewport. +func (m Model) View() tea.View { + extraLines := 1 // selection status line + if m.showHelp { + extraLines++ + } + renderHeight := m.height - extraLines + if renderHeight < 3 { + renderHeight = 3 + } + + content := RenderTerminalView(m.frames, m.width, renderHeight, m.selectedIdx, m.subtreeSet, m.matchIndices, m.filterVisible, m.globalTotal, m.countFieldLabel(), m.isDark, m.searchActive, m.searchQuery) + content = replaceHeaderLine(content, m.toolbarLine()) + if m.searchActive { + content = replaceFooterLine(content, m.searchFooter()) + } + if m.snapshot != nil && len(m.frames) == 0 { + content = common.PanelStyle.Render(fmt.Sprintf("Flame: snapshot v%d has no visible frames", m.lastVersion)) + } + content += "\n" + m.selectionStatusLine() + if m.showHelp { + content += "\n" + m.helpOverlay() + } + return tea.NewView(content) +} + +// SetLiveTrie updates the data source used by the flamegraph model. +func (m *Model) SetLiveTrie(liveTrie LiveTrieSource) { + m.liveTrie = liveTrie + m.syncFieldPresetToTrie() + m.syncCountFieldToTrie() + m.lastVersion = 0 + m.snapshot = nil + m.globalTotal = 0 + m.selectedIdx = 0 + m.frames = nil + m.targetFrames = nil + m.zoomStack = nil + m.zoomRoot = nil + m.zoomPath = "" + m.subtreeSet = make(map[int]bool) + m.filterVisible = make(map[int]bool) + m.animation = NewAnimationState(30, 6.0, 1.0) + m.animating = false + m.hasNavigableSnapshot = false +} + +func (m *Model) syncFieldPresetToTrie() { + if m.liveTrie == nil { + m.fieldIndex = 0 + return + } + fields := m.liveTrie.Fields() + if len(fields) == 0 { + m.fieldIndex = 0 + return + } + for idx, preset := range m.fieldPresets { + if slices.Equal(preset, fields) { + m.fieldIndex = idx + return + } + } + custom := slices.Clone(fields) + m.fieldPresets = append([][]string{custom}, m.fieldPresets...) + m.fieldIndex = 0 +} + +func (m *Model) syncCountFieldToTrie() { + if m.liveTrie == nil { + m.countField = "count" + return + } + field := strings.TrimSpace(m.liveTrie.CountField()) + if field == "" { + field = "count" + } + m.countField = field +} + +// RefreshFromLiveTrie loads a new snapshot when the source version changes. +func (m *Model) RefreshFromLiveTrie() bool { + if m.liveTrie == nil { + return false + } + // Once a snapshot exists, paused mode must freeze it regardless of current + // navigability so selection and percentages remain stable. + if m.paused && m.snapshot != nil { + return false + } + version := m.liveTrie.Version() + if version == m.lastVersion && m.snapshot != nil { + return false + } + + payload, version := m.liveTrie.SnapshotJSON() + var snapshot snapshotNode + if err := json.Unmarshal(payload, &snapshot); err != nil { + return false + } + m.snapshot = &snapshot + m.globalTotal = snapshotTotal(m.snapshot) + if m.zoomPath != "" { + m.zoomRoot = findNodeByPath(m.snapshot, m.zoomPath) + } else { + m.zoomRoot = nil + } + m.rebuildFrames(true) + m.lastVersion = version + return true +} + +// LastVersion returns the latest snapshot version loaded into the model. +func (m Model) LastVersion() uint64 { + return m.lastVersion +} + +// HasSnapshot reports whether the flamegraph model has loaded at least one snapshot. +func (m Model) HasSnapshot() bool { + return m.snapshot != nil +} + +// AnimationCmd returns a frame animation tick command when animation is active. +func (m Model) AnimationCmd() tea.Cmd { + if !m.animating { + return nil + } + return animTickCmd() +} + +// Paused reports whether live refresh is paused. +func (m Model) Paused() bool { + return m.paused +} + +// SetViewport updates model render dimensions. +func (m *Model) SetViewport(width, height int) { + m.width = width + m.height = height + m.rebuildFrames(true) +} + +// SetDarkMode sets the active color theme mode. +func (m *Model) SetDarkMode(isDark bool) { + m.isDark = isDark + m.searchInput.SetStyles(textinput.DefaultStyles(isDark)) +} + +func (m *Model) rebuildFrames(animate bool) { + prevPath := "" + if len(m.frames) > 0 && m.selectedIdx >= 0 && m.selectedIdx < len(m.frames) { + prevPath = m.frames[m.selectedIdx].Path + } + + var root *snapshotNode + rootPath := "" + if m.zoomRoot != nil { + root = m.zoomRoot + rootPath = m.zoomPath + } else { + root = m.snapshot + } + m.targetFrames = buildTerminalLayoutWithPath(root, m.width, m.height, rootPath) + m.animation.SetTargets(m.targetFrames) + if animate && len(m.frames) > 0 && !m.animation.Settled() { + m.animating = true + m.frames = m.animation.CurrentFrames() + } else { + m.animating = false + m.frames = append(m.frames[:0], m.targetFrames...) + } + if len(m.frames) > 1 { + m.hasNavigableSnapshot = true + } + m.restoreSelectionByPath(prevPath) + m.clampSelection() + m.recomputeFilterState() + m.ensureSelectionNavigable() + m.ensureSelectionVisible() + m.subtreeSet = computeSubtreeSetInto(m.frames, m.selectedIdx, m.subtreeSet) +} + +func (m *Model) restoreSelectionByPath(path string) { + if path == "" || len(m.frames) == 0 { + return + } + if idx := m.frameIndexByPath(path); idx >= 0 { + m.selectedIdx = idx + return + } + for idx, frame := range m.frames { + if hasPathBoundaryPrefix(path, frame.Path) || hasPathBoundaryPrefix(frame.Path, path) { + m.selectedIdx = idx + return + } + } +} + +func (m Model) frameIndexByPath(path string) int { + for idx, frame := range m.frames { + if frame.Path == path { + return idx + } + } + return -1 +} + +func (m *Model) zoomIn() { + if len(m.frames) == 0 || m.snapshot == nil { + m.statusMessage = "Zoom unavailable: no frame selected" + return + } + m.clampSelection() + selectedPath := m.frames[m.selectedIdx].Path + if selectedPath == m.currentRootPath() { + m.statusMessage = "Zoom unchanged: selected frame is current view root" + return + } + target := findNodeByPath(m.snapshot, selectedPath) + if target == nil { + m.statusMessage = "Zoom failed: selected node is unavailable" + return + } + m.zoomStack = append(m.zoomStack, zoomState{ + path: m.zoomPath, + previousSelectedIdx: m.selectedIdx, + }) + m.zoomRoot = target + m.zoomPath = selectedPath + m.selectedIdx = 0 + m.rebuildFrames(true) + m.statusMessage = "Zoom: " + compactFramePath(selectedPath) +} + +func (m *Model) zoomUndo() { + if len(m.zoomStack) == 0 || m.snapshot == nil { + m.statusMessage = "Zoom undo unavailable" + return + } + last := m.zoomStack[len(m.zoomStack)-1] + m.zoomStack = m.zoomStack[:len(m.zoomStack)-1] + m.zoomPath = last.path + if m.zoomPath == "" { + m.zoomRoot = nil + } else { + m.zoomRoot = findNodeByPath(m.snapshot, m.zoomPath) + } + m.selectedIdx = last.previousSelectedIdx + m.rebuildFrames(true) + if m.zoomPath == "" { + m.statusMessage = "Zoom: root" + return + } + m.statusMessage = "Zoom: " + compactFramePath(m.zoomPath) +} + +func (m *Model) zoomReset() { + if m.zoomRoot == nil && len(m.zoomStack) == 0 { + m.statusMessage = "Zoom already at root" + return + } + m.zoomRoot = nil + m.zoomPath = "" + m.zoomStack = nil + m.rebuildFrames(false) + m.statusMessage = "Zoom reset to root" +} + +func (m *Model) moveVertical(delta int) { + if len(m.frames) == 0 { + return + } + m.clampSelection() + m.ensureSelectionNavigable() + current := m.frames[m.selectedIdx] + targetDepth := current.Depth + delta + targets := m.framesAtDepth(targetDepth) + if len(targets) == 0 { + return + } + best := targets[0] + bestDist := abs(m.frames[best].Col - current.Col) + for _, idx := range targets[1:] { + dist := abs(m.frames[idx].Col - current.Col) + if dist < bestDist { + best = idx + bestDist = dist + } + } + m.selectedIdx = best +} + +func (m *Model) moveVerticalWithFallback(primaryDelta, fallbackDelta, traversalDelta int) { + before := m.selectedIdx + m.moveVertical(primaryDelta) + if m.selectedIdx == before && fallbackDelta != 0 { + m.moveVertical(fallbackDelta) + } + if m.selectedIdx == before && traversalDelta != 0 { + m.moveTraversal(traversalDelta) + } +} + +func (m *Model) moveSibling(delta int) { + if len(m.frames) == 0 { + return + } + before := m.selectedIdx + m.clampSelection() + m.ensureSelectionNavigable() + current := m.frames[m.selectedIdx] + siblings := m.framesAtDepth(current.Depth) + if len(siblings) <= 1 { + m.moveTraversal(delta) + return + } + pos := indexOf(siblings, m.selectedIdx) + if pos < 0 { + m.moveTraversal(delta) + return + } + next := pos + delta + if next < 0 { + next = 0 + } + if next >= len(siblings) { + next = len(siblings) - 1 + } + m.selectedIdx = siblings[next] + if m.selectedIdx == before { + m.moveTraversal(delta) + } +} + +func (m *Model) jumpToTop() { + if len(m.frames) == 0 { + return + } + m.clampSelection() + m.ensureSelectionNavigable() + + include := m.navigableFrameSet() + currentCol := m.frames[m.selectedIdx].Col + bestIdx := -1 + bestDepth := -1 + bestDist := int(^uint(0) >> 1) + + for idx, frame := range m.frames { + if include != nil && !include[idx] { + continue + } + dist := abs(frame.Col - currentCol) + if frame.Depth > bestDepth { + bestDepth = frame.Depth + bestIdx = idx + bestDist = dist + continue + } + if frame.Depth == bestDepth { + if dist < bestDist || (dist == bestDist && frame.Col < m.frames[bestIdx].Col) { + bestIdx = idx + bestDist = dist + } + } + } + if bestIdx >= 0 { + m.selectedIdx = bestIdx + } +} + +func (m *Model) jumpToRoot() { + if len(m.frames) == 0 { + return + } + m.clampSelection() + m.ensureSelectionNavigable() + + rootPath := m.currentRootPath() + if rootPath != "" { + if idx := m.frameIndexByPath(rootPath); idx >= 0 { + if !m.filterActive() || m.frameNavigable(idx) { + m.selectedIdx = idx + return + } + } + } + + include := m.navigableFrameSet() + currentCol := m.frames[m.selectedIdx].Col + bestIdx := -1 + bestDepth := int(^uint(0) >> 1) + bestDist := int(^uint(0) >> 1) + for idx, frame := range m.frames { + if include != nil && !include[idx] { + continue + } + dist := abs(frame.Col - currentCol) + if frame.Depth < bestDepth { + bestDepth = frame.Depth + bestDist = dist + bestIdx = idx + continue + } + if frame.Depth == bestDepth { + if dist < bestDist || (dist == bestDist && frame.Col < m.frames[bestIdx].Col) { + bestDist = dist + bestIdx = idx + } + } + } + if bestIdx >= 0 { + m.selectedIdx = bestIdx + } +} + +func framesAtDepth(frames []tuiFrame, depth int) []int { + return framesAtDepthFiltered(frames, depth, nil) +} + +func framesAtDepthFiltered(frames []tuiFrame, depth int, include map[int]bool) []int { + if depth < 0 { + return nil + } + indices := make([]int, 0) + for idx, frame := range frames { + if include != nil && !include[idx] { + continue + } + if frame.Depth == depth { + indices = append(indices, idx) + } + } + sort.Slice(indices, func(i, j int) bool { + return frames[indices[i]].Col < frames[indices[j]].Col + }) + return indices +} + +func indexOf(values []int, target int) int { + for idx, value := range values { + if value == target { + return idx + } + } + return -1 +} + +func (m *Model) clampSelection() { + if len(m.frames) == 0 { + m.selectedIdx = 0 + return + } + if m.selectedIdx < 0 { + m.selectedIdx = 0 + } + if m.selectedIdx >= len(m.frames) { + m.selectedIdx = len(m.frames) - 1 + } +} + +func abs(v int) int { + if v < 0 { + return -v + } + return v +} + +func animTickCmd() tea.Cmd { + return tea.Tick(animFrameDuration, func(time.Time) tea.Msg { return animTickMsg{} }) +} + +func (m Model) currentRootPath() string { + if m.zoomPath != "" { + return m.zoomPath + } + if len(m.frames) == 0 { + return "" + } + return m.frames[0].Path +} + +func (m Model) filterActive() bool { + return strings.TrimSpace(m.searchQuery) != "" +} + +func (m Model) navigableFrameSet() map[int]bool { + if !m.filterActive() { + return nil + } + return m.filterVisible +} + +func (m Model) framesAtDepth(depth int) []int { + return framesAtDepthFiltered(m.frames, depth, m.navigableFrameSet()) +} + +func (m Model) frameNavigable(idx int) bool { + if idx < 0 || idx >= len(m.frames) { + return false + } + if !m.filterActive() { + return true + } + return m.filterVisible[idx] +} + +func (m *Model) ensureSelectionNavigable() { + if len(m.frames) == 0 { + m.selectedIdx = 0 + return + } + m.clampSelection() + if m.frameNavigable(m.selectedIdx) { + return + } + + if len(m.matchIndices) > 0 { + for _, idx := range orderedMatchIndices(m.matchIndices) { + if m.frameNavigable(idx) { + m.selectedIdx = idx + return + } + } + } + + for idx := range m.frames { + if m.frameNavigable(idx) { + m.selectedIdx = idx + return + } + } +} + +func (m *Model) recordKeyDebug(msg tea.KeyPressMsg, handled, moved bool) { + keyID := keyString(msg) + if keyID == "" { + keyID = fmt.Sprintf("code:%d", msg.Code) + } + sel := "-" + selIdx := m.selectedIdx + if len(m.frames) > 0 && m.selectedIdx >= 0 && m.selectedIdx < len(m.frames) { + sel = compactFramePath(m.frames[m.selectedIdx].Path) + } + m.lastKeyDebug = fmt.Sprintf("dbg frames=%d idx=%d key=%q code=%d handled=%t moved=%t sel=%s", len(m.frames), selIdx, keyID, msg.Code, handled, moved, sel) +} + +func (m *Model) moveTraversal(delta int) { + if len(m.frames) == 0 || delta == 0 { + return + } + order := m.visibleTraversalOrder() + if len(order) == 0 { + return + } + pos := indexOf(order, m.selectedIdx) + if pos < 0 { + pos = 0 + } + next := pos + delta + if next < 0 { + next = 0 + } + if next >= len(order) { + next = len(order) - 1 + } + m.selectedIdx = order[next] +} + +func (m Model) visibleTraversalOrder() []int { + indices := make([]int, 0, len(m.frames)) + include := m.navigableFrameSet() + for idx := range m.frames { + if include != nil && !include[idx] { + continue + } + indices = append(indices, idx) + } + sort.Slice(indices, func(i, j int) bool { + left := m.frames[indices[i]] + right := m.frames[indices[j]] + if left.Depth != right.Depth { + return left.Depth < right.Depth + } + if left.Col != right.Col { + return left.Col < right.Col + } + if left.Row != right.Row { + return left.Row < right.Row + } + return indices[i] < indices[j] + }) + return indices +} + +func keyString(msg tea.KeyPressMsg) string { + if s := msg.String(); s != "" { + return s + } + return msg.Text +} + +func isSearchOpenKey(msg tea.KeyPressMsg) bool { return keyString(msg) == "/" } +func isNextMatchKey(msg tea.KeyPressMsg) bool { return keyString(msg) == "n" } +func isPrevMatchKey(msg tea.KeyPressMsg) bool { return keyString(msg) == "N" } +func isPauseKey(msg tea.KeyPressMsg) bool { + k := keyString(msg) + return k == "p" || k == " " || k == "space" || msg.Code == tea.KeySpace +} +func isResetBaselineKey(msg tea.KeyPressMsg) bool { + return keyString(msg) == "r" +} +func isCycleOrderKey(msg tea.KeyPressMsg) bool { return keyString(msg) == "o" } +func isCycleMetricKey(msg tea.KeyPressMsg) bool { + return keyString(msg) == "b" +} +func isHelpToggleKey(msg tea.KeyPressMsg) bool { return keyString(msg) == "?" } + +func isZoomInKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + return key.Matches(msg, keys.ZoomIn) || msg.Code == tea.KeyEnter || strings.EqualFold(keyString(msg), "enter") +} + +func isZoomUndoKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + return key.Matches(msg, keys.ZoomUndo) || msg.Code == tea.KeyBackspace || msg.Code == tea.KeyEsc +} + +func isZoomResetKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + return key.Matches(msg, keys.ZoomReset) +} + +func isMoveShallowerKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + k := keyString(msg) + return key.Matches(msg, keys.MoveShallower) || msg.Code == tea.KeyDown || keyMatchesDirection(k, "down", 'B') +} + +func isMoveDeeperKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + k := keyString(msg) + return key.Matches(msg, keys.MoveDeeper) || msg.Code == tea.KeyUp || keyMatchesDirection(k, "up", 'A') +} + +func isPrevSiblingKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + k := keyString(msg) + return key.Matches(msg, keys.PrevSibling) || msg.Code == tea.KeyLeft || keyMatchesDirection(k, "left", 'D') +} + +func isNextSiblingKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + k := keyString(msg) + return key.Matches(msg, keys.NextSibling) || msg.Code == tea.KeyRight || keyMatchesDirection(k, "right", 'C') +} + +func isJumpTopKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + k := strings.ToLower(keyString(msg)) + return key.Matches(msg, keys.JumpTop) || msg.Code == tea.KeyPgUp || k == "pgup" || k == "pageup" +} + +func isJumpRootKey(msg tea.KeyPressMsg, keys flameKeyMap) bool { + k := strings.ToLower(keyString(msg)) + return key.Matches(msg, keys.JumpRoot) || msg.Code == tea.KeyPgDown || k == "pgdown" || k == "pgdn" || k == "pagedown" +} + +func keyMatchesDirection(keyName, plain string, ansiFinal byte) bool { + if keyName == plain || strings.HasSuffix(keyName, "+"+plain) { + return true + } + return isArrowEscapeSequence(keyName, ansiFinal) +} + +func isArrowEscapeSequence(value string, ansiFinal byte) bool { + if len(value) < 3 || value[0] != '\x1b' { + return false + } + last := value[len(value)-1] + if last != ansiFinal { + return false + } + return value[1] == '[' || value[1] == 'O' +} + +func (m Model) visibleRowOffset() int { + if len(m.frames) == 0 { + return 0 + } + availableRows := m.height - 2 // toolbar + status + if availableRows <= 0 { + return 0 + } + maxRow := maxFrameRowForSet(m.frames, m.navigableFrameSet()) + if maxRow+1 <= availableRows { + return 0 + } + return maxRow + 1 - availableRows +} + +func (m *Model) ensureSelectionVisible() { + if len(m.frames) == 0 { + return + } + m.clampSelection() + m.ensureSelectionNavigable() + if !m.frameNavigable(m.selectedIdx) { + return + } + rowOffset := m.visibleRowOffset() + selected := m.frames[m.selectedIdx] + if selected.Row >= rowOffset { + return + } + + bestIdx := -1 + bestScore := int(^uint(0) >> 1) + for idx, frame := range m.frames { + if !m.frameNavigable(idx) { + continue + } + if frame.Row < rowOffset { + continue + } + score := abs(frame.Row-rowOffset)*1000 + abs(frame.Col-selected.Col) + if score < bestScore { + bestIdx = idx + bestScore = score + } + } + if bestIdx >= 0 { + m.selectedIdx = bestIdx + } +} diff --git a/internal/tui/flamegraph/model_test.go b/internal/tui/flamegraph/model_test.go new file mode 100644 index 0000000..74ce8d9 --- /dev/null +++ b/internal/tui/flamegraph/model_test.go @@ -0,0 +1,987 @@ +package flamegraph + +import ( + "reflect" + "strings" + "testing" + + coreflamegraph "ior/internal/flamegraph" + + tea "charm.land/bubbletea/v2" +) + +func TestNewModelDefaults(t *testing.T) { + m := NewModel(nil) + if m.liveTrie != nil { + t.Fatalf("expected nil liveTrie when constructor input is nil") + } + if m.matchIndices == nil { + t.Fatalf("expected matchIndices map to be initialized") + } + if len(m.fieldPresets) == 0 { + t.Fatalf("expected default field presets to be initialized") + } + if got, want := m.fieldPresets[0], []string{"comm", "tracepoint", "path"}; !reflect.DeepEqual(got, want) { + t.Fatalf("default field preset[0] = %v, want %v", got, want) + } + if !m.isDark { + t.Fatalf("expected dark mode enabled by default") + } +} + +func TestSetViewportAndDarkMode(t *testing.T) { + m := NewModel(nil) + m.SetViewport(120, 40) + m.SetDarkMode(false) + if m.width != 120 || m.height != 40 { + t.Fatalf("expected viewport 120x40, got %dx%d", m.width, m.height) + } + if m.isDark { + t.Fatalf("expected dark mode to be disabled") + } +} + +func TestRefreshFromLiveTrieTracksVersionAndSnapshot(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + m := NewModel(trie) + + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected first refresh to load baseline snapshot") + } + if m.snapshot == nil { + t.Fatalf("expected snapshot to be populated after refresh") + } + + if changed := m.RefreshFromLiveTrie(); changed { + t.Fatalf("expected no refresh when version is unchanged") + } +} + +func TestRefreshFromLiveTrieAllowsInitialLoadWhilePaused(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + m := NewModel(trie) + m.paused = true + + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected initial paused refresh to load first snapshot") + } + if m.snapshot == nil { + t.Fatalf("expected snapshot to be available after initial paused refresh") + } + if changed := m.RefreshFromLiveTrie(); changed { + t.Fatalf("expected subsequent paused refresh to be skipped once snapshot exists") + } +} + +func TestRefreshFromLiveTriePausedBlocksAfterNavigableSnapshot(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + m := NewModel(trie) + m.paused = true + m.snapshot = &snapshotNode{Name: "root", Total: 1} + m.frames = []tuiFrame{ + {Name: "root", Path: "root"}, + {Name: "child", Path: "root" + pathSeparator + "child"}, + } + m.hasNavigableSnapshot = true + m.lastVersion = 1 + + if changed := m.RefreshFromLiveTrie(); changed { + t.Fatalf("expected paused refresh to remain frozen once navigable snapshot exists") + } + if got, want := m.lastVersion, uint64(1); got != want { + t.Fatalf("expected version to remain unchanged while paused, got %d want %d", got, want) + } +} + +func TestRefreshFromLiveTriePausedBlocksAfterAnySnapshot(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + m := NewModel(trie) + m.paused = true + m.snapshot = &snapshotNode{Name: "root", Total: 1} + m.frames = []tuiFrame{{Name: "root", Path: "root"}} + m.hasNavigableSnapshot = false + m.lastVersion = 1 + + if changed := m.RefreshFromLiveTrie(); changed { + t.Fatalf("expected paused refresh to freeze after first snapshot even when non-navigable") + } + if got, want := m.lastVersion, uint64(1); got != want { + t.Fatalf("expected paused refresh to keep existing snapshot version, got %d want %d", got, want) + } +} + +func TestKeyboardNavigationDeepNarrowTree(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Path: "root"}, + {Name: "child", Depth: 1, Col: 0, Path: "root" + pathSeparator + "child"}, + {Name: "leaf", Depth: 2, Col: 0, Path: "root" + pathSeparator + "child" + pathSeparator + "leaf"}, + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'k'}[0], Text: "k"}) + if m.selectedIdx != 1 { + t.Fatalf("expected selection to move deeper to idx 1, got %d", m.selectedIdx) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'k'}[0], Text: "k"}) + if m.selectedIdx != 2 { + t.Fatalf("expected selection to move deeper to idx 2, got %d", m.selectedIdx) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'j'}[0], Text: "j"}) + if m.selectedIdx != 1 { + t.Fatalf("expected selection to move shallower to idx 1, got %d", m.selectedIdx) + } +} + +func TestKeyboardNavigationShallowWideSiblings(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Path: "root"}, + {Name: "A", Depth: 1, Col: 0, Path: "root" + pathSeparator + "A"}, + {Name: "B", Depth: 1, Col: 30, Path: "root" + pathSeparator + "B"}, + {Name: "C", Depth: 1, Col: 60, Path: "root" + pathSeparator + "C"}, + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'k'}[0], Text: "k"}) + if m.selectedIdx != 1 { + t.Fatalf("expected first deeper frame to be A, got idx %d", m.selectedIdx) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'l'}[0], Text: "l"}) + if m.selectedIdx != 2 { + t.Fatalf("expected next sibling B, got idx %d", m.selectedIdx) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'l'}[0], Text: "l"}) + if m.selectedIdx != 3 { + t.Fatalf("expected next sibling C, got idx %d", m.selectedIdx) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'l'}[0], Text: "l"}) + if m.selectedIdx != 3 { + t.Fatalf("expected selection to clamp at last sibling, got idx %d", m.selectedIdx) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'h'}[0], Text: "h"}) + if m.selectedIdx != 2 { + t.Fatalf("expected previous sibling B, got idx %d", m.selectedIdx) + } +} + +func TestHorizontalTraversalFallbackFromRoot(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Path: "root"}, + {Name: "A", Depth: 1, Col: 0, Path: "root" + pathSeparator + "A"}, + {Name: "B", Depth: 1, Col: 30, Path: "root" + pathSeparator + "B"}, + } + m.selectedIdx = 0 + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyRight}) + if m.selectedIdx != 1 { + t.Fatalf("expected right arrow from root to move to first traversable frame, got idx %d", m.selectedIdx) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'l'}[0], Text: "l"}) + if m.selectedIdx != 2 { + t.Fatalf("expected vi right key to move to next frame, got idx %d", m.selectedIdx) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyLeft}) + if m.selectedIdx != 1 { + t.Fatalf("expected left arrow to move back to previous frame, got idx %d", m.selectedIdx) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'h'}[0], Text: "h"}) + if m.selectedIdx != 0 { + t.Fatalf("expected vi left key to move back to root, got idx %d", m.selectedIdx) + } +} + +func TestPageUpJumpsSelectionToTopMostDepth(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Path: "root"}, + {Name: "A", Depth: 1, Col: 0, Path: "root" + pathSeparator + "A"}, + {Name: "B", Depth: 1, Col: 40, Path: "root" + pathSeparator + "B"}, + {Name: "A1", Depth: 2, Col: 0, Path: "root" + pathSeparator + "A" + pathSeparator + "A1"}, + {Name: "B1", Depth: 2, Col: 40, Path: "root" + pathSeparator + "B" + pathSeparator + "B1"}, + {Name: "A2", Depth: 3, Col: 0, Path: "root" + pathSeparator + "A" + pathSeparator + "A1" + pathSeparator + "A2"}, + {Name: "B2", Depth: 3, Col: 40, Path: "root" + pathSeparator + "B" + pathSeparator + "B1" + pathSeparator + "B2"}, + } + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"B"+pathSeparator+"B1") + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyPgUp}) + if got, want := m.frames[m.selectedIdx].Path, "root"+pathSeparator+"B"+pathSeparator+"B1"+pathSeparator+"B2"; got != want { + t.Fatalf("expected pgup to jump to deepest top frame %q, got %q", want, got) + } +} + +func TestPageDownJumpsSelectionToCurrentViewRoot(t *testing.T) { + m := newZoomModel() + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"A") + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"A"+pathSeparator+"A1") + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyPgDown}) + if got, want := m.frames[m.selectedIdx].Path, "root"+pathSeparator+"A"; got != want { + t.Fatalf("expected pgdn to jump to current zoom root %q, got %q", want, got) + } +} + +func TestPausedStateStillAllowsNavigation(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Path: "root"}, + {Name: "A", Depth: 1, Col: 0, Path: "root" + pathSeparator + "A"}, + } + m.paused = true + m.selectedIdx = 0 + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyRight}) + if m.selectedIdx != 1 { + t.Fatalf("expected navigation to work while paused, got idx %d", m.selectedIdx) + } +} + +func TestStaticFixtureArrowTraversalVisitsAllFrames(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + coreflamegraph.SeedTestFlameData(trie) + + m := NewModel(trie) + m.SetViewport(180, 40) + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected seeded fixture refresh to load frames") + } + if len(m.frames) < 2 { + t.Fatalf("expected seeded fixture to contain navigable frames, got %d", len(m.frames)) + } + + visited := map[int]bool{m.selectedIdx: true} + for i := 0; i < len(m.frames)*4; i++ { + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyRight}) + visited[m.selectedIdx] = true + } + for i := 0; i < len(m.frames)*4; i++ { + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyLeft}) + visited[m.selectedIdx] = true + } + + if got, want := len(visited), len(m.frames); got != want { + t.Fatalf("expected arrow traversal to visit all frames: visited=%d frames=%d", got, want) + } + if !strings.Contains(m.View().Content, "sel:") { + t.Fatalf("expected view to expose selected-frame status line") + } +} + +func TestLiveFixtureArrowTraversalWhileStreamingVisitsAllFrames(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + coreflamegraph.SeedTestLiveFlameData(trie, 0) + + m := NewModel(trie) + m.SetViewport(180, 40) + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected initial refresh to load frames") + } + if len(m.frames) < 2 { + t.Fatalf("expected seeded fixture to contain navigable frames, got %d", len(m.frames)) + } + + selectedPath := func(model Model) string { + if len(model.frames) == 0 || model.selectedIdx < 0 || model.selectedIdx >= len(model.frames) { + return "" + } + return model.frames[model.selectedIdx].Path + } + + visitedPaths := map[string]bool{selectedPath(m): true} + moves := 0 + for i := 0; i < len(m.frames)*4; i++ { + trie.Reset() + coreflamegraph.SeedTestLiveFlameData(trie, uint64(i+1)) + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected refresh after synthetic live ingest at step %d", i) + } + before := selectedPath(m) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyRight}) + after := selectedPath(m) + if after != before { + moves++ + } + visitedPaths[after] = true + } + for i := 0; i < len(m.frames)*4; i++ { + trie.Reset() + coreflamegraph.SeedTestLiveFlameData(trie, uint64(i+1+len(m.frames)*4)) + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected refresh after synthetic live ingest (reverse) at step %d", i) + } + before := selectedPath(m) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyLeft}) + after := selectedPath(m) + if after != before { + moves++ + } + visitedPaths[after] = true + } + + if moves == 0 { + t.Fatalf("expected live-stream navigation to change selection at least once") + } + if len(visitedPaths) < 8 { + t.Fatalf("expected traversal across live updates to reach multiple frame paths, got %d", len(visitedPaths)) + } +} + +func TestSelectionRestoresByPathAcrossLiveRefresh(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + coreflamegraph.SeedTestLiveFlameData(trie, 0) + + m := NewModel(trie) + m.SetViewport(180, 40) + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected initial refresh") + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyRight}) + selected := m.frames[m.selectedIdx].Path + if selected == "" || selected == "root" { + t.Fatalf("expected selection to move off root, got %q", selected) + } + + trie.Reset() + coreflamegraph.SeedTestLiveFlameData(trie, 2) + if changed := m.RefreshFromLiveTrie(); !changed { + t.Fatalf("expected refresh after live update") + } + if got := m.frames[m.selectedIdx].Path; got != selected { + t.Fatalf("expected selection path to persist across refresh, got %q want %q", got, selected) + } +} + +func TestKeyboardNavigationSingleNodeClamped(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{{Name: "root", Depth: 0, Col: 0, Path: "root"}} + + keys := []tea.KeyPressMsg{ + {Code: []rune{'j'}[0], Text: "j"}, + {Code: []rune{'k'}[0], Text: "k"}, + {Code: []rune{'h'}[0], Text: "h"}, + {Code: []rune{'l'}[0], Text: "l"}, + {Code: tea.KeyDown}, + {Code: tea.KeyUp}, + {Code: tea.KeyLeft}, + {Code: tea.KeyRight}, + } + for _, keyMsg := range keys { + m = pressFlameKey(t, m, keyMsg) + if m.selectedIdx != 0 { + t.Fatalf("expected single-node selection to stay at idx 0, got %d", m.selectedIdx) + } + } +} + +func TestArrowDownFallsBackToVisibleDepthFromRoot(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Path: "root"}, + {Name: "child", Depth: 1, Col: 0, Path: "root" + pathSeparator + "child"}, + } + m.selectedIdx = 0 + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyDown}) + if m.selectedIdx != 1 { + t.Fatalf("expected down arrow to move selection to child when root has no shallower row, got %d", m.selectedIdx) + } +} + +func TestArrowEscapeSequencesAreRecognized(t *testing.T) { + tests := []struct { + key string + dir string + ansiCode byte + }{ + {key: "\x1b[A", dir: "up", ansiCode: 'A'}, + {key: "\x1b[B", dir: "down", ansiCode: 'B'}, + {key: "\x1b[C", dir: "right", ansiCode: 'C'}, + {key: "\x1b[D", dir: "left", ansiCode: 'D'}, + {key: "\x1bOA", dir: "up", ansiCode: 'A'}, // application mode + {key: "\x1bOB", dir: "down", ansiCode: 'B'}, // application mode + {key: "\x1b[1;2A", dir: "up", ansiCode: 'A'}, + } + for _, tc := range tests { + if !keyMatchesDirection(tc.key, tc.dir, tc.ansiCode) { + t.Fatalf("expected key %q to match %s", tc.key, tc.dir) + } + } +} + +func TestFilteredNavigationSkipsHiddenBranches(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Row: 0, Path: "root"}, + {Name: "keep", Depth: 1, Col: 0, Row: 1, Path: "root" + pathSeparator + "keep"}, + {Name: "drop", Depth: 1, Col: 40, Row: 1, Path: "root" + pathSeparator + "drop"}, + } + m.searchQuery = "keep" + m.recomputeFilterState() + m.selectedIdx = 1 + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyRight}) + if m.selectedIdx != 1 { + t.Fatalf("expected sibling navigation to stay on visible filtered branch, got idx %d", m.selectedIdx) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyDown}) + if m.selectedIdx != 0 { + t.Fatalf("expected down key to move to visible root ancestor, got idx %d", m.selectedIdx) + } +} + +func TestZoomInUndoSingleLevelAndNestedEsc(t *testing.T) { + m := newZoomModel() + + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"A") + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if got, want := m.zoomPath, "root"+pathSeparator+"A"; got != want { + t.Fatalf("expected zoomPath %q, got %q", want, got) + } + if len(m.zoomStack) != 1 || m.zoomStack[0].path != "" { + t.Fatalf("expected one zoom stack entry from root, got %#v", m.zoomStack) + } + if m.zoomRoot == nil || m.zoomRoot.Name != "A" { + t.Fatalf("expected zoomRoot A, got %+v", m.zoomRoot) + } + + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"A"+pathSeparator+"A1") + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if got, want := m.zoomPath, "root"+pathSeparator+"A"+pathSeparator+"A1"; got != want { + t.Fatalf("expected nested zoomPath %q, got %q", want, got) + } + if len(m.zoomStack) != 2 || m.zoomStack[1].path != "root"+pathSeparator+"A" { + t.Fatalf("expected nested zoom stack to preserve parent path, got %#v", m.zoomStack) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEsc}) + if got, want := m.zoomPath, "root"+pathSeparator+"A"; got != want { + t.Fatalf("expected zoomPath after esc undo %q, got %q", want, got) + } + if len(m.zoomStack) != 1 { + t.Fatalf("expected one stack entry after esc undo, got %d", len(m.zoomStack)) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEsc}) + if m.zoomPath != "" || m.zoomRoot != nil || len(m.zoomStack) != 0 { + t.Fatalf("expected second esc undo to return to root state, got path=%q root=%+v stack=%d", m.zoomPath, m.zoomRoot, len(m.zoomStack)) + } +} + +func TestZoomResetToRoot(t *testing.T) { + m := newZoomModel() + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"A") + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"A"+pathSeparator+"A1") + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if m.zoomPath == "" || len(m.zoomStack) == 0 { + t.Fatalf("expected nested zoom before reset") + } + + m.zoomReset() + if m.zoomPath != "" || m.zoomRoot != nil || len(m.zoomStack) != 0 { + t.Fatalf("expected explicit zoom reset to clear zoom stack, got path=%q root=%+v stack=%d", m.zoomPath, m.zoomRoot, len(m.zoomStack)) + } +} + +func TestZoomInOnCurrentRootSetsStatusMessage(t *testing.T) { + m := newZoomModel() + m.selectedIdx = mustFrameIndex(t, m.frames, "root") + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if m.zoomPath != "" { + t.Fatalf("expected zoom path to remain root, got %q", m.zoomPath) + } + if m.statusMessage != "Zoom unchanged: selected frame is current view root" { + t.Fatalf("unexpected status message: %q", m.statusMessage) + } +} + +func TestZoomTransitionAnimatesToNewLayout(t *testing.T) { + m := newZoomModel() + pathA := "root" + pathSeparator + "A" + preWidth := m.frames[mustFrameIndex(t, m.frames, pathA)].Width + + m.selectedIdx = mustFrameIndex(t, m.frames, pathA) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if !m.animating { + t.Fatalf("expected zoom-in to start animation") + } + currentWidth := m.frames[mustFrameIndex(t, m.frames, pathA)].Width + targetWidth := m.targetFrames[mustFrameIndex(t, m.targetFrames, pathA)].Width + if currentWidth == targetWidth { + t.Fatalf("expected intermediate zoom frame width to differ from target (current=%d target=%d, pre=%d)", currentWidth, targetWidth, preWidth) + } + + for i := 0; i < 180 && m.animating; i++ { + next, _ := m.Update(animTickMsg{}) + m = next.(Model) + } + if m.animating { + t.Fatalf("expected zoom animation to settle within 180 ticks") + } + finalWidth := m.frames[mustFrameIndex(t, m.frames, pathA)].Width + if finalWidth != targetWidth { + t.Fatalf("expected final zoom width %d, got %d", targetWidth, finalWidth) + } +} + +func TestSearchLifecycleAndMatchNavigation(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "alpha", Path: "root" + pathSeparator + "alpha"}, + {Name: "beta", Path: "root" + pathSeparator + "beta"}, + {Name: "alphabet", Path: "root" + pathSeparator + "alphabet"}, + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'/'}[0], Text: "/"}) + if !m.searchActive { + t.Fatalf("expected search mode to activate on '/'") + } + for _, r := range []rune{'a', 'l', 'p'} { + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: r, Text: string(r)}) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + + if m.searchActive { + t.Fatalf("expected search mode to close on enter") + } + if got := len(m.matchIndices); got != 2 { + t.Fatalf("expected 2 matches for 'alp', got %d", got) + } + first := m.selectedIdx + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'n'}[0], Text: "n"}) + if m.selectedIdx == first { + t.Fatalf("expected 'n' to jump to next match") + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'N'}[0], Text: "N"}) + if m.selectedIdx != first { + t.Fatalf("expected 'N' to jump back to previous match") + } +} + +func TestSearchEscapeClearsState(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{{Name: "alpha", Path: "root" + pathSeparator + "alpha"}} + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'/'}[0], Text: "/"}) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'a'}[0], Text: "a"}) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEsc}) + + if m.searchActive { + t.Fatalf("expected search mode to close on escape") + } + if m.searchQuery != "" || len(m.matchIndices) != 0 { + t.Fatalf("expected search state to reset on escape, got query=%q matches=%d", m.searchQuery, len(m.matchIndices)) + } + if m.statusMessage != "Filter cleared" { + t.Fatalf("expected filter cleared status message, got %q", m.statusMessage) + } +} + +func TestSearchSubmitSetsFilterStatusMessage(t *testing.T) { + m := NewModel(nil) + m.frames = []tuiFrame{ + {Name: "alpha", Path: "root" + pathSeparator + "alpha"}, + {Name: "beta", Path: "root" + pathSeparator + "beta"}, + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'/'}[0], Text: "/"}) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'a'}[0], Text: "a"}) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if m.statusMessage != `Filter "a": 2 matches` { + t.Fatalf("unexpected status after applying filter: %q", m.statusMessage) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'/'}[0], Text: "/"}) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEsc}) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'/'}[0], Text: "/"}) + for _, r := range []rune{'z', 'z'} { + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: r, Text: string(r)}) + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeyEnter}) + if m.statusMessage != `Filter "zz": no matches` { + t.Fatalf("unexpected status for unmatched filter: %q", m.statusMessage) + } +} + +func TestControlPauseToggle(t *testing.T) { + m := NewModel(nil) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'p'}[0], Text: "p"}) + if !m.paused { + t.Fatalf("expected pause to toggle on") + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + if m.paused { + t.Fatalf("expected space key to toggle pause off") + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + if !m.paused { + t.Fatalf("expected space key to toggle pause on") + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'p'}[0], Text: "p"}) + if m.paused { + t.Fatalf("expected p key to toggle pause off") + } +} + +func TestControlResetBaseline(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + m := NewModel(liveTrie) + m.snapshot = &snapshotNode{Name: "root", Total: 10} + m.frames = []tuiFrame{{Name: "root", Path: "root"}} + m.zoomPath = "root" + m.zoomStack = []zoomState{{path: "", previousSelectedIdx: 0}} + m.selectedIdx = 3 + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'r'}[0], Text: "r"}) + if m.snapshot != nil || len(m.frames) != 0 || len(m.zoomStack) != 0 || m.zoomPath != "" { + t.Fatalf("expected baseline reset to clear snapshot/layout/zoom state") + } + if m.statusMessage != "Baseline reset" { + t.Fatalf("expected reset status message, got %q", m.statusMessage) + } +} + +func TestViewIncludesSelectionStatusBar(t *testing.T) { + m := NewModel(nil) + m.width = 120 + m.height = 20 + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Row: 0, Width: 120, Total: 100, Percent: 100, Path: "root"}, + {Name: "child", Depth: 1, Col: 0, Row: 1, Width: 60, Total: 40, Percent: 40, Path: "root" + pathSeparator + "child"}, + } + m.selectedIdx = 1 + m.globalTotal = 100 + + view := m.View().Content + if !strings.Contains(view, "[LIVE] sel:2/2 child") { + t.Fatalf("expected selection status bar to include selected frame info, got %q", view) + } + if !strings.Contains(view, "40.00% of total events") { + t.Fatalf("expected selection status bar to include selected share, got %q", view) + } +} + +func TestViewSelectionStatusUsesBytesLabelInBytesMode(t *testing.T) { + m := NewModel(nil) + m.width = 120 + m.height = 20 + m.countField = "bytes" + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Row: 0, Width: 120, Total: 200, Percent: 100, Path: "root"}, + {Name: "child", Depth: 1, Col: 0, Row: 1, Width: 60, Total: 80, Percent: 40, Path: "root" + pathSeparator + "child"}, + } + m.selectedIdx = 1 + m.globalTotal = 200 + + view := m.View().Content + if !strings.Contains(view, "40.00% of total bytes") { + t.Fatalf("expected bytes-based selection share label, got %q", view) + } +} + +func TestViewFitsViewportHeightAndKeepsSearchFooterVisible(t *testing.T) { + m := NewModel(nil) + m.width = 100 + m.height = 12 + m.frames = []tuiFrame{ + {Name: "root", Depth: 0, Col: 0, Row: 0, Width: 100, Total: 100, Percent: 100, Path: "root"}, + {Name: "child", Depth: 1, Col: 0, Row: 1, Width: 80, Total: 80, Percent: 80, Path: "root" + pathSeparator + "child"}, + } + m.selectedIdx = 1 + m.globalTotal = 100 + m.searchActive = true + m.searchInput.SetValue("child") + + view := m.View().Content + lines := strings.Split(view, "\n") + if got, max := len(lines), m.height; got > max { + t.Fatalf("expected flame view to fit viewport height <=%d, got %d lines", max, got) + } + if !strings.Contains(view, "matches") { + t.Fatalf("expected search footer to remain visible in viewport, got %q", view) + } + if !strings.Contains(view, "[LIVE] sel:2/2 child") { + t.Fatalf("expected selection status line to remain visible, got %q", view) + } +} + +func TestViewFilterSelectionStatusUsesFilteredTotalAndKeepsContextVisible(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + { + Name: "keep", + Total: 60, + Children: []*snapshotNode{ + {Name: "needle", Total: 60}, + }, + }, + { + Name: "drop", + Total: 40, + Children: []*snapshotNode{ + {Name: "noise", Total: 40}, + }, + }, + }, + } + m := NewModel(nil) + m.width = 220 + m.height = 12 + m.frames = BuildTerminalLayout(snapshot, m.width, m.height) + m.globalTotal = 100 + m.selectedIdx = mustFrameIndex(t, m.frames, "root"+pathSeparator+"keep"+pathSeparator+"needle") + m.searchQuery = "needle" + m.recomputeFilterState() + + view := m.View().Content + if !strings.Contains(view, "100.00% of filtered events") { + t.Fatalf("expected filtered selection share in status line, got %q", view) + } + if !strings.Contains(view, "drop") || !strings.Contains(view, "noise") { + t.Fatalf("expected non-matching branches to remain visible while filtering, got %q", view) + } +} + +func TestControlCycleFieldOrderReconfiguresLiveTrie(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + m := NewModel(liveTrie) + initial := append([]string(nil), m.fieldPresets[m.fieldIndex]...) + expectedNextIdx := (m.fieldIndex + 1) % len(m.fieldPresets) + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'o'}[0], Text: "o"}) + if m.fieldIndex != expectedNextIdx { + t.Fatalf("expected field index to advance to %d, got %d", expectedNextIdx, m.fieldIndex) + } + next := m.fieldPresets[m.fieldIndex] + if reflect.DeepEqual(initial, next) { + t.Fatalf("expected next field preset to differ from initial") + } + if got := liveTrie.Fields(); !reflect.DeepEqual(got, next) { + t.Fatalf("expected live trie fields %v, got %v", next, got) + } +} + +func TestControlMetricToggleReconfiguresLiveTrieCountField(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + m := NewModel(liveTrie) + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'b'}[0], Text: "b"}) + if got, want := m.countField, "bytes"; got != want { + t.Fatalf("expected model count field %q, got %q", want, got) + } + if got, want := liveTrie.CountField(), "bytes"; got != want { + t.Fatalf("expected live trie count field %q, got %q", want, got) + } + if got, want := m.statusMessage, "Metric: bytes (new baseline)"; got != want { + t.Fatalf("expected metric toggle status %q, got %q", want, got) + } + + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'b'}[0], Text: "b"}) + if got, want := m.countField, "count"; got != want { + t.Fatalf("expected model count field %q after second toggle, got %q", want, got) + } + if got, want := liveTrie.CountField(), "count"; got != want { + t.Fatalf("expected live trie count field %q after second toggle, got %q", want, got) + } +} + +func TestNewModelAlignsPresetIndexToLiveTrieFields(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + m := NewModel(liveTrie) + if got, want := m.fieldPresets[m.fieldIndex], []string{"comm", "path", "tracepoint"}; !reflect.DeepEqual(got, want) { + t.Fatalf("expected model field preset to align with trie fields, got %v want %v", got, want) + } +} + +func TestNewModelAlignsCountFieldToLiveTrie(t *testing.T) { + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "bytes") + m := NewModel(liveTrie) + if got, want := m.countField, "bytes"; got != want { + t.Fatalf("expected model count field to align with trie field, got %q want %q", got, want) + } +} + +func TestControlHelpToggle(t *testing.T) { + m := NewModel(nil) + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'?'}[0], Text: "?"}) + if !m.showHelp { + t.Fatalf("expected help overlay to toggle on") + } + m = pressFlameKey(t, m, tea.KeyPressMsg{Code: []rune{'?'}[0], Text: "?"}) + if m.showHelp { + t.Fatalf("expected help overlay to toggle off") + } +} + +func TestDataRefreshAnimationConvergesOverTicks(t *testing.T) { + m := NewModel(nil) + m.width = 120 + m.height = 20 + m.snapshot = &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + {Name: "A", Total: 60}, + {Name: "B", Total: 40}, + }, + } + m.rebuildFrames(false) + initial := append([]tuiFrame(nil), m.frames...) + + m.snapshot = &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + {Name: "A", Total: 20}, + {Name: "B", Total: 80}, + }, + } + m.rebuildFrames(true) + if !m.animating { + t.Fatalf("expected animation to start after animated rebuild") + } + + next, _ := m.Update(animTickMsg{}) + m = next.(Model) + if len(m.frames) != len(initial) { + t.Fatalf("expected frame count to remain stable during animation") + } + + for i := 0; i < 180 && m.animating; i++ { + next, _ = m.Update(animTickMsg{}) + m = next.(Model) + } + if m.animating { + t.Fatalf("expected animation to settle within 180 ticks") + } + if len(m.frames) != len(m.targetFrames) { + t.Fatalf("expected settled frame count to match targets") + } + for i := range m.frames { + if m.frames[i].Width != m.targetFrames[i].Width || m.frames[i].Col != m.targetFrames[i].Col { + t.Fatalf("frame %d did not converge to target: got col=%d width=%d want col=%d width=%d", + i, m.frames[i].Col, m.frames[i].Width, m.targetFrames[i].Col, m.targetFrames[i].Width) + } + } +} + +func TestRebuildKeepsSelectionOnVisibleRowsWhenTruncated(t *testing.T) { + m := NewModel(nil) + m.width = 80 + m.height = 4 // only 2 render rows remain after toolbar+status + m.snapshot = &snapshotNode{ + Name: "root", + Children: []*snapshotNode{ + { + Name: "a", + Children: []*snapshotNode{ + { + Name: "b", + Children: []*snapshotNode{ + {Name: "c", Total: 5}, + }, + }, + }, + }, + }, + } + + m.rebuildFrames(false) + if len(m.frames) == 0 { + t.Fatalf("expected rebuilt frames") + } + rowOffset := m.visibleRowOffset() + if m.frames[m.selectedIdx].Row < rowOffset { + t.Fatalf("expected selected frame row %d to be visible (offset=%d)", m.frames[m.selectedIdx].Row, rowOffset) + } +} + +func TestResizeRecalculatesLayoutAndCullsNarrowFrames(t *testing.T) { + m := NewModel(nil) + m.width = 120 + m.height = 40 + m.snapshot = &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + { + Name: "big", + Total: 99, + Children: []*snapshotNode{ + {Name: "deep", Total: 99}, + }, + }, + {Name: "tiny", Total: 1}, + }, + } + m.rebuildFrames(false) + _ = mustFrameIndex(t, m.frames, "root"+pathSeparator+"tiny") + + next, _ := m.Update(tea.WindowSizeMsg{Width: 80, Height: 24}) + m = next.(Model) + for i := 0; i < 180 && m.animating; i++ { + next, _ = m.Update(animTickMsg{}) + m = next.(Model) + } + + for _, frame := range m.frames { + if frame.Col+frame.Width > 80 { + t.Fatalf("frame exceeds resized width: %+v", frame) + } + if frame.Row >= 24 { + t.Fatalf("frame row exceeds resized height: %+v", frame) + } + } + for _, frame := range m.frames { + if frame.Path == "root"+pathSeparator+"tiny" { + t.Fatalf("expected tiny frame to be culled at width 80") + } + } +} + +func newZoomModel() Model { + m := NewModel(nil) + m.width = 120 + m.height = 30 + m.snapshot = &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + { + Name: "A", + Total: 60, + Children: []*snapshotNode{ + {Name: "A1", Total: 30}, + {Name: "A2", Total: 30}, + }, + }, + {Name: "B", Total: 40}, + }, + } + m.rebuildFrames(false) + return m +} + +func mustFrameIndex(t *testing.T, frames []tuiFrame, path string) int { + t.Helper() + for idx, frame := range frames { + if frame.Path == path { + return idx + } + } + t.Fatalf("frame path %q not found", path) + return -1 +} + +func pressFlameKey(t *testing.T, m Model, keyMsg tea.KeyPressMsg) Model { + t.Helper() + next, _ := m.Update(keyMsg) + return next.(Model) +} diff --git a/internal/tui/flamegraph/renderer.go b/internal/tui/flamegraph/renderer.go new file mode 100644 index 0000000..e4c4043 --- /dev/null +++ b/internal/tui/flamegraph/renderer.go @@ -0,0 +1,708 @@ +package flamegraph + +import ( + "fmt" + "hash/fnv" + "image/color" + "math" + "sort" + "strings" + "unicode/utf8" + + common "ior/internal/tui/common" + + "charm.land/lipgloss/v2" +) + +const pathSeparator = "\x1f" +const pathSeparatorByte = '\x1f' +const minFlameWidth = 60 +const maxBarVisualHeight = 3 + +// BuildTerminalLayout converts a live trie snapshot into terminal frame cells. +func BuildTerminalLayout(snapshot *snapshotNode, width, height int) []tuiFrame { + return buildTerminalLayoutWithPath(snapshot, width, height, "") +} + +func buildTerminalLayoutWithPath(snapshot *snapshotNode, width, height int, rootPath string) []tuiFrame { + if snapshot == nil || width <= 0 || height <= 0 { + return nil + } + rootTotal := snapshotTotal(snapshot) + if rootTotal == 0 { + return nil + } + + rootName := frameName(snapshot.Name, 0) + if rootPath != "" { + rootName = rootPath + } + frames := make([]tuiFrame, 0, len(snapshot.Children)+1) + collectTerminalLayout(&frames, snapshot, rootTotal, height, 0, 0, rootName, width) + return frames +} + +func collectTerminalLayout(out *[]tuiFrame, node *snapshotNode, rootTotal uint64, height, depth, col int, path string, span int) { + if node == nil || depth >= height { + return + } + total := snapshotTotal(node) + if total == 0 || span < 1 { + return + } + + name := frameName(node.Name, depth) + *out = append(*out, tuiFrame{ + Name: name, + Col: col, + Row: depth, + Width: span, + Total: total, + Percent: 100 * float64(total) / float64(rootTotal), + Fill: terminalFrameColor(name), + Depth: depth, + Path: path, + }) + + if len(node.Children) == 0 { + return + } + + childWidths := allocateChildWidths(node.Children, total, span) + cursor := col + for idx, child := range node.Children { + childWidth := childWidths[idx] + if childWidth < 1 { + continue + } + childName := frameName(child.Name, depth+1) + childPath := strings.Join([]string{path, childName}, pathSeparator) + collectTerminalLayout(out, child, rootTotal, height, depth+1, cursor, childPath, childWidth) + cursor += childWidth + } +} + +func allocateChildWidths(children []*snapshotNode, parentTotal uint64, span int) []int { + widths := make([]int, len(children)) + if span <= 0 || parentTotal == 0 || len(children) == 0 { + return widths + } + + type childWidth struct { + idx int + total uint64 + raw float64 + } + items := make([]childWidth, 0, len(children)) + used := 0 + for idx, child := range children { + total := snapshotTotal(child) + if total == 0 { + continue + } + raw := float64(span) * (float64(total) / float64(parentTotal)) + width := int(math.Floor(raw)) + if width > 0 { + widths[idx] = width + used += width + } + items = append(items, childWidth{idx: idx, total: total, raw: raw}) + } + if len(items) == 0 { + return widths + } + + // If proportional rounding culled every child, surface top contributors so + // the user can still navigate beyond the root frame. + if used == 0 { + sort.Slice(items, func(i, j int) bool { + if items[i].total == items[j].total { + return items[i].idx < items[j].idx + } + return items[i].total > items[j].total + }) + visible := min(span, len(items)) + for i := 0; i < visible; i++ { + widths[items[i].idx] = 1 + } + } + return widths +} + +func snapshotTotal(node *snapshotNode) uint64 { + if node == nil { + return 0 + } + total := node.Value + for _, child := range node.Children { + total += snapshotTotal(child) + } + if node.Total > total { + return node.Total + } + return total +} + +func frameName(name string, depth int) string { + if name != "" { + return name + } + if depth == 0 { + return "root" + } + return "(unknown)" +} + +func terminalFrameColor(name string) color.Color { + if semantic, ok := semanticFrameColor(name); ok { + return semantic + } + + hasher := fnv.New32a() + _, _ = hasher.Write([]byte(name)) + h := hasher.Sum32() + return color.RGBA{ + R: uint8(200 + int(h%35)), + G: uint8(80 + int((h>>8)%120)), + B: uint8(40 + int((h>>16)%90)), + A: 255, + } +} + +func semanticFrameColor(name string) (color.Color, bool) { + label := strings.ToLower(strings.TrimSpace(name)) + switch { + case label == "": + return nil, false + case strings.Contains(label, "read"), strings.Contains(label, "pread"): + return color.RGBA{R: 78, G: 132, B: 201, A: 255}, true // read I/O: blue + case strings.Contains(label, "write"), strings.Contains(label, "pwrite"): + return color.RGBA{R: 222, G: 122, B: 58, A: 255}, true // write I/O: orange + case strings.Contains(label, "open"), strings.Contains(label, "close"), strings.Contains(label, "stat"), strings.Contains(label, "rename"), strings.Contains(label, "link"): + return color.RGBA{R: 196, G: 168, B: 72, A: 255}, true // metadata I/O: amber + case strings.HasPrefix(label, "/"), strings.Contains(label, "path:"), strings.Contains(label, "/"): + return color.RGBA{R: 88, G: 156, B: 84, A: 255}, true // file paths: green + case strings.Contains(label, "pid"), strings.Contains(label, "tid"): + return color.RGBA{R: 67, G: 151, B: 149, A: 255}, true // process/thread dimensions: teal + case strings.HasPrefix(label, "sys_"): + return color.RGBA{R: 191, G: 99, B: 74, A: 255}, true // other syscall buckets: rust + default: + return nil, false + } +} + +// RenderTerminalView renders a terminal flamegraph viewport from laid out frames. +func RenderTerminalView(frames []tuiFrame, width, height, selectedIdx int, subtreeSet, matchSet, filterSet map[int]bool, globalTotal uint64, metricLabel string, isDark, searchActive bool, searchQuery string) string { + if width < minFlameWidth { + return common.PanelStyle.Render("Flame: terminal too narrow (need >= 60 columns)") + } + if height < 3 { + return common.PanelStyle.Render("Flame: viewport too short") + } + if len(frames) == 0 { + return common.PanelStyle.Render("Flame: waiting for data...") + } + if strings.TrimSpace(metricLabel) == "" { + metricLabel = "events" + } + + filterActive := strings.TrimSpace(searchQuery) != "" + if filterActive { + if filterSet == nil { + filterSet = computeFilterVisibleSetInto(frames, matchSet, nil) + } + if len(filterSet) == 0 { + return common.PanelStyle.Render(fmt.Sprintf("Flame: no frames match filter %q", searchQuery)) + } + } else { + filterSet = nil + } + + selectedIdx = normalizeSelectedIndex(frames, selectedIdx, filterSet) + selected := frames[selectedIdx] + viewPath := compactFramePath(frames[0].Path) + if subtreeSet == nil { + subtreeSet = computeSubtreeSet(frames, selectedIdx) + } + + availableRows := height - 2 // toolbar + status + maxRow := maxFrameRowForSet(frames, nil) + totalDepthRows := maxRow + 1 + barHeight := computeBarHeight(availableRows, totalDepthRows, maxBarVisualHeight) + visibleDepthRows := availableRows / barHeight + if visibleDepthRows < 1 { + visibleDepthRows = 1 + } + rowOffset := 0 + truncated := false + if maxRow+1 > visibleDepthRows { + rowOffset = maxRow + 1 - visibleDepthRows + truncated = true + } + + visibleFrames := countVisibleFrames(frames, nil) + toolbar := fmt.Sprintf("Flame | view:%s | frames:%d", viewPath, visibleFrames) + toolbar += fmt.Sprintf(" | rows:%d", availableRows) + if truncated { + toolbar += " | showing deepest levels" + } + toolbar = padOrTrim(toolbar, width) + selectedSystemShare := selected.Percent + if globalTotal > 0 { + selectedSystemShare = percentOfTotal(selected.Total, globalTotal) + } + if filterActive { + filterCoveredTotal, filterBaseTotal := filterCoverageTotals(frames, matchSet, globalTotal) + filterSystemShare := percentOfTotal(filterCoveredTotal, filterBaseTotal) + selectedFilterShare := 0.0 + if filterCoveredTotal > 0 { + selectedMatchTotal := filterCoverageTotalForPath(frames, matchSet, selected.Path) + selectedFilterShare = percentOfTotal(selectedMatchTotal, filterCoveredTotal) + } + matches := orderedMatchIndices(matchSet) + pos := 0 + if len(matches) > 0 { + if idx := indexOf(matches, selectedIdx); idx >= 0 { + pos = idx + 1 + } + } + frameCoverage := 0.0 + if len(frames) > 0 { + frameCoverage = 100 * float64(visibleFrames) / float64(len(frames)) + } + status := fmt.Sprintf("Filter %q: %.1f%% %s (%d/%d matches, %.1f%% frames shown) | Selected: %s total(%s)=%d depth=%d %.2f%% filtered %s", + searchQuery, filterSystemShare, metricLabel, pos, len(matches), frameCoverage, + selected.Name, metricLabel, selected.Total, selected.Depth, selectedFilterShare, metricLabel) + return renderViewRows(toolbar, status, rowsForRender(frames, width, rowOffset, maxRow, barHeight, availableRows, selected.Path, subtreeSet, matchSet, selectedIdx, isDark, searchActive, filterActive), width) + } else { + status := fmt.Sprintf("Selected: %s [%s] total(%s)=%d depth=%d col=%d width=%d share=%.2f%% %s", + selected.Name, compactFramePath(selected.Path), metricLabel, selected.Total, selected.Depth, selected.Col, selected.Width, selectedSystemShare, metricLabel) + return renderViewRows(toolbar, status, rowsForRender(frames, width, rowOffset, maxRow, barHeight, availableRows, selected.Path, subtreeSet, matchSet, selectedIdx, isDark, searchActive, filterActive), width) + } +} + +func rowsForRender(frames []tuiFrame, width, rowOffset, maxRow, barHeight, availableRows int, selectedPath string, subtreeSet, matchSet map[int]bool, selectedIdx int, isDark, searchActive, filterActive bool) []string { + return buildRenderRows(frames, width, rowOffset, maxRow, barHeight, availableRows, selectedPath, subtreeSet, matchSet, selectedIdx, isDark, searchActive, filterActive) +} + +func renderViewRows(toolbar, status string, rows []string, width int) string { + status = padOrTrim(status, width) + var b strings.Builder + b.Grow((width + 1) * (len(rows) + 2)) + b.WriteString(toolbar) + for _, row := range rows { + b.WriteString("\n") + b.WriteString(row) + } + b.WriteString("\n") + b.WriteString(status) + return b.String() +} + +type indexedFrame struct { + idx int + frame tuiFrame +} + +func buildRenderRows(frames []tuiFrame, width, rowOffset, maxRow, barHeight, availableRows int, selectedPath string, subtreeSet, matchSet map[int]bool, selectedIdx int, isDark, searchActive, filterActive bool) []string { + rowsByDepth := make(map[int][]indexedFrame) + for idx, frame := range frames { + if frame.Row < rowOffset || frame.Row > maxRow { + continue + } + rowsByDepth[frame.Row] = append(rowsByDepth[frame.Row], indexedFrame{idx: idx, frame: frame}) + } + + if barHeight < 1 { + barHeight = 1 + } + + rows := make([]string, 0, (maxRow-rowOffset+1)*barHeight) + for row := maxRow; row >= rowOffset; row-- { + framesAtRow := rowsByDepth[row] + sort.Slice(framesAtRow, func(i, j int) bool { + return framesAtRow[i].frame.Col < framesAtRow[j].frame.Col + }) + for repeat := 0; repeat < barHeight; repeat++ { + showLabels := repeat == barHeight/2 + rows = append(rows, renderRow(framesAtRow, width, selectedPath, subtreeSet, matchSet, selectedIdx, isDark, searchActive, filterActive, showLabels)) + } + } + + if availableRows > 0 { + if len(rows) > availableRows { + rows = rows[:availableRows] + } + if len(rows) < availableRows { + blank := strings.Repeat(" ", width) + pad := make([]string, 0, availableRows) + for i := 0; i < availableRows-len(rows); i++ { + pad = append(pad, blank) + } + pad = append(pad, rows...) + rows = pad + } + } + return rows +} + +func renderRow(frames []indexedFrame, width int, selectedPath string, subtreeSet, matchSet map[int]bool, selectedIdx int, isDark, searchActive, filterActive, showLabels bool) string { + if len(frames) == 0 { + return strings.Repeat(" ", width) + } + var b strings.Builder + b.Grow(width + 8) + cursor := 0 + for _, item := range frames { + frame := item.frame + if frame.Col >= width { + continue + } + if frame.Col > cursor { + gap := frame.Col - cursor + b.WriteString(strings.Repeat(" ", gap)) + cursor += gap + } + + cellWidth := frame.Width + if frame.Col+cellWidth > width { + cellWidth = width - frame.Col + } + if cellWidth <= 0 { + continue + } + label := strings.Repeat(" ", cellWidth) + if showLabels { + label = frameLabel(frame.Name, cellWidth, item.idx == selectedIdx, matchSet != nil && matchSet[item.idx]) + } + style := styleForFrame(item.idx, frame, selectedPath, subtreeSet, matchSet, selectedIdx, isDark, searchActive, filterActive) + cell := style.Render(label) + b.WriteString(cell) + cursor = frame.Col + cellWidth + } + if cursor < width { + b.WriteString(strings.Repeat(" ", width-cursor)) + } + return b.String() +} + +func computeSubtreeSet(frames []tuiFrame, selectedIdx int) map[int]bool { + return computeSubtreeSetInto(frames, selectedIdx, nil) +} + +func computeSubtreeSetInto(frames []tuiFrame, selectedIdx int, subtree map[int]bool) map[int]bool { + if subtree == nil { + subtree = make(map[int]bool) + } else { + for idx := range subtree { + delete(subtree, idx) + } + } + if selectedIdx < 0 || selectedIdx >= len(frames) { + return subtree + } + + selectedPath := frames[selectedIdx].Path + for idx, frame := range frames { + path := frame.Path + if path == selectedPath || + hasPathBoundaryPrefix(path, selectedPath) || + hasPathBoundaryPrefix(selectedPath, path) { + subtree[idx] = true + } + } + return subtree +} + +func hasPathBoundaryPrefix(value, prefix string) bool { + if len(value) <= len(prefix) { + return false + } + if !strings.HasPrefix(value, prefix) { + return false + } + return value[len(prefix)] == pathSeparatorByte +} + +func computeFilterVisibleSetInto(frames []tuiFrame, matchSet, visible map[int]bool) map[int]bool { + if visible == nil { + visible = make(map[int]bool) + } else { + for idx := range visible { + delete(visible, idx) + } + } + if len(matchSet) == 0 { + return visible + } + + matchPaths := make([]string, 0, len(matchSet)) + for idx := range matchSet { + if idx >= 0 && idx < len(frames) { + matchPaths = append(matchPaths, frames[idx].Path) + } + } + for idx, frame := range frames { + for _, matchPath := range matchPaths { + // Show matching frames and their full ancestry to root. + if frame.Path == matchPath || hasPathBoundaryPrefix(matchPath, frame.Path) { + visible[idx] = true + break + } + } + } + return visible +} + +func styleForFrame(idx int, frame tuiFrame, selectedPath string, subtreeSet, matchSet map[int]bool, selectedIdx int, isDark, searchActive, filterActive bool) lipgloss.Style { + _ = searchActive + base := lipgloss.NewStyle(). + Foreground(common.ColorBackground). + Background(frame.Fill) + + isSelected := idx == selectedIdx + inSubtree := subtreeSet[idx] + isMatch := matchSet != nil && matchSet[idx] + + matchColor := lipgloss.Color("160") + if !isDark { + matchColor = lipgloss.Color("124") + } + + if isSelected { + selectedBg := lipgloss.Color("129") + selectedFg := lipgloss.Color("15") + if !isDark { + selectedBg = lipgloss.Color("129") + selectedFg = lipgloss.Color("15") + } + return base.Background(selectedBg).Foreground(selectedFg).Bold(true).Underline(true) + } + + if isMatch { + style := base.Background(matchColor).Foreground(lipgloss.Color("15")) + if inSubtree { + return style.Bold(true) + } + return style.Faint(true) + } + + if filterActive { + return base.Background(common.ColorPanel).Foreground(common.ColorMuted).Faint(true) + } + + if inSubtree { + if frameRelation(frame.Path, selectedPath) == relationAncestor { + return base.BorderLeft(true).BorderForeground(common.ColorAccent) + } + return base + } + + return base.Background(common.ColorPanel).Foreground(common.ColorMuted).Faint(true) +} + +func frameLabel(name string, width int, isSelected, isMatch bool) string { + if width <= 0 { + return "" + } + if isSelected { + if width == 1 { + return ">" + } + return ">" + padOrTrim(name, width-2) + "<" + } + if isMatch { + if width == 1 { + return "*" + } + return "*" + padOrTrim(name, width-1) + } + return padOrTrim(name, width) +} + +func compactFramePath(path string) string { + if path == "" { + return "root" + } + parts := strings.Split(path, pathSeparator) + if len(parts) <= 3 { + return strings.Join(parts, "/") + } + return strings.Join([]string{parts[0], "...", parts[len(parts)-1]}, "/") +} + +type relation int + +const ( + relationNone relation = iota + relationAncestor + relationDescendant +) + +func frameRelation(path, selectedPath string) relation { + if path == selectedPath { + return relationDescendant + } + if strings.HasPrefix(selectedPath, path+pathSeparator) { + return relationAncestor + } + if strings.HasPrefix(path, selectedPath+pathSeparator) { + return relationDescendant + } + return relationNone +} + +func maxFrameRow(frames []tuiFrame) int { + return maxFrameRowForSet(frames, nil) +} + +func maxFrameRowForSet(frames []tuiFrame, include map[int]bool) int { + maxRow := 0 + for idx, frame := range frames { + if include != nil && !include[idx] { + continue + } + if frame.Row > maxRow { + maxRow = frame.Row + } + } + return maxRow +} + +func countVisibleFrames(frames []tuiFrame, include map[int]bool) int { + if include == nil { + return len(frames) + } + count := 0 + for idx := range frames { + if include[idx] { + count++ + } + } + return count +} + +func normalizeSelectedIndex(frames []tuiFrame, selectedIdx int, include map[int]bool) int { + if len(frames) == 0 { + return 0 + } + if selectedIdx >= 0 && selectedIdx < len(frames) && (include == nil || include[selectedIdx]) { + return selectedIdx + } + if include != nil { + for idx := range frames { + if include[idx] { + return idx + } + } + } + return 0 +} + +func filterSampleCoverage(frames []tuiFrame, matchSet map[int]bool, totalBase uint64) float64 { + coveredTotal, rootTotal := filterCoverageTotals(frames, matchSet, totalBase) + return percentOfTotal(coveredTotal, rootTotal) +} + +func computeBarHeight(availableRows, depthRows, maxHeight int) int { + if availableRows <= 0 || depthRows <= 0 { + return 1 + } + height := availableRows / depthRows + if height < 1 { + height = 1 + } + if maxHeight > 0 && height > maxHeight { + height = maxHeight + } + return height +} + +func filterCoverageTotals(frames []tuiFrame, matchSet map[int]bool, totalBase uint64) (coveredTotal uint64, rootTotal uint64) { + if len(frames) == 0 || len(matchSet) == 0 { + return 0, 0 + } + rootTotal = totalBase + if rootTotal == 0 { + rootTotal = frames[0].Total + } + if rootTotal == 0 { + return 0, 0 + } + roots := compactMatchRoots(frames, matchSet) + for _, root := range roots { + coveredTotal += root.total + } + return coveredTotal, rootTotal +} + +func filterCoverageTotalForPath(frames []tuiFrame, matchSet map[int]bool, path string) uint64 { + if path == "" || len(frames) == 0 || len(matchSet) == 0 { + return 0 + } + roots := compactMatchRoots(frames, matchSet) + var coveredTotal uint64 + for _, root := range roots { + if root.path == path || hasPathBoundaryPrefix(root.path, path) { + coveredTotal += root.total + } + } + return coveredTotal +} + +type matchRoot struct { + path string + total uint64 +} + +func compactMatchRoots(frames []tuiFrame, matchSet map[int]bool) []matchRoot { + roots := make([]matchRoot, 0, len(matchSet)) + for idx := range matchSet { + if idx < 0 || idx >= len(frames) { + continue + } + roots = append(roots, matchRoot{ + path: frames[idx].Path, + total: frames[idx].Total, + }) + } + sort.Slice(roots, func(i, j int) bool { + return len(roots[i].path) < len(roots[j].path) + }) + merged := make([]matchRoot, 0, len(roots)) + for _, candidate := range roots { + covered := false + for _, root := range merged { + if candidate.path == root.path || hasPathBoundaryPrefix(candidate.path, root.path) { + covered = true + break + } + } + if covered { + continue + } + merged = append(merged, candidate) + } + return merged +} + +func percentOfTotal(value, total uint64) float64 { + if total == 0 { + return 0 + } + return 100 * float64(value) / float64(total) +} + +func padOrTrim(s string, width int) string { + if width <= 0 { + return "" + } + if utf8.RuneCountInString(s) <= width { + return s + strings.Repeat(" ", width-utf8.RuneCountInString(s)) + } + if width == 1 { + return "…" + } + r := []rune(s) + return string(r[:width-1]) + "…" +} diff --git a/internal/tui/flamegraph/renderer_test.go b/internal/tui/flamegraph/renderer_test.go new file mode 100644 index 0000000..c546200 --- /dev/null +++ b/internal/tui/flamegraph/renderer_test.go @@ -0,0 +1,368 @@ +package flamegraph + +import ( + "image/color" + "strings" + "testing" +) + +func TestBuildTerminalLayoutWidthScaling(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + { + Name: "A", + Total: 60, + Children: []*snapshotNode{ + {Name: "A1", Total: 30}, + {Name: "A2", Total: 30}, + }, + }, + {Name: "B", Total: 40}, + }, + } + + tests := []struct { + width int + wantA int + wantB int + wantA1 int + wantA2 int + wantAll int + }{ + {width: 80, wantA: 48, wantB: 32, wantA1: 24, wantA2: 24, wantAll: 5}, + {width: 120, wantA: 72, wantB: 48, wantA1: 36, wantA2: 36, wantAll: 5}, + {width: 200, wantA: 120, wantB: 80, wantA1: 60, wantA2: 60, wantAll: 5}, + } + + for _, tc := range tests { + frames := BuildTerminalLayout(snapshot, tc.width, 10) + if len(frames) != tc.wantAll { + t.Fatalf("width %d: expected %d frames, got %d", tc.width, tc.wantAll, len(frames)) + } + root := mustFindFrame(t, frames, "root") + if root.Width != tc.width || root.Row != 0 || root.Col != 0 { + t.Fatalf("width %d: unexpected root frame %+v", tc.width, root) + } + a := mustFindFrame(t, frames, "root"+pathSeparator+"A") + b := mustFindFrame(t, frames, "root"+pathSeparator+"B") + a1 := mustFindFrame(t, frames, "root"+pathSeparator+"A"+pathSeparator+"A1") + a2 := mustFindFrame(t, frames, "root"+pathSeparator+"A"+pathSeparator+"A2") + + if a.Width != tc.wantA || b.Width != tc.wantB { + t.Fatalf("width %d: unexpected child widths A=%d B=%d", tc.width, a.Width, b.Width) + } + if a1.Width != tc.wantA1 || a2.Width != tc.wantA2 { + t.Fatalf("width %d: unexpected grandchild widths A1=%d A2=%d", tc.width, a1.Width, a2.Width) + } + if b.Col != a.Col+a.Width { + t.Fatalf("width %d: expected B col %d, got %d", tc.width, a.Col+a.Width, b.Col) + } + } +} + +func TestBuildTerminalLayoutCullsSubCellFramesAndRespectsHeight(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + { + Name: "big", + Total: 99, + Children: []*snapshotNode{ + {Name: "deep", Total: 99}, + }, + }, + {Name: "tiny", Total: 1}, + }, + } + + frames := BuildTerminalLayout(snapshot, 80, 2) + if hasFrame(frames, "root"+pathSeparator+"tiny") { + t.Fatalf("expected tiny frame to be culled (<1 terminal cell)") + } + if hasFrame(frames, "root"+pathSeparator+"big"+pathSeparator+"deep") { + t.Fatalf("expected deep frame to be omitted due height limit") + } + if !hasFrame(frames, "root"+pathSeparator+"big") { + t.Fatalf("expected big frame to be present") + } +} + +func TestBuildTerminalLayoutKeepsChildrenVisibleWhenRoundingWouldCullAll(t *testing.T) { + children := make([]*snapshotNode, 0, 200) + for i := 0; i < 200; i++ { + children = append(children, &snapshotNode{Name: "c", Total: 1}) + } + snapshot := &snapshotNode{Name: "root", Children: children} + + frames := BuildTerminalLayout(snapshot, 120, 6) + depthOne := 0 + for _, frame := range frames { + if frame.Depth == 1 { + depthOne++ + } + } + if depthOne == 0 { + t.Fatalf("expected at least one visible depth-1 frame, got none") + } +} + +func TestBuildTerminalLayoutUsesPathSeparatorAndColor(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 10, + Children: []*snapshotNode{ + {Name: "child", Total: 10}, + }, + } + + frames := BuildTerminalLayout(snapshot, 80, 4) + child := mustFindFrame(t, frames, "root"+pathSeparator+"child") + if !strings.Contains(child.Path, pathSeparator) { + t.Fatalf("expected path %q to contain separator %q", child.Path, pathSeparator) + } + if child.Fill == nil { + t.Fatalf("expected frame color to be set") + } +} + +func TestTerminalFrameColorSemanticPalette(t *testing.T) { + tests := []struct { + name string + label string + want color.RGBA + }{ + {name: "read", label: "sys_enter_read", want: color.RGBA{R: 78, G: 132, B: 201, A: 255}}, + {name: "write", label: "sys_enter_write", want: color.RGBA{R: 222, G: 122, B: 58, A: 255}}, + {name: "metadata", label: "sys_enter_openat", want: color.RGBA{R: 196, G: 168, B: 72, A: 255}}, + {name: "path", label: "/var/log/app.log", want: color.RGBA{R: 88, G: 156, B: 84, A: 255}}, + {name: "pid", label: "pid=1234", want: color.RGBA{R: 67, G: 151, B: 149, A: 255}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := terminalFrameColor(tc.label) + if got != tc.want { + t.Fatalf("unexpected semantic color for %q: got=%v want=%v", tc.label, got, tc.want) + } + }) + } +} + +func TestRenderTerminalViewShowsNarrowMessage(t *testing.T) { + out := RenderTerminalView(nil, 50, 10, 0, nil, nil, nil, 0, "events", true, false, "") + if !strings.Contains(out, "terminal too narrow") { + t.Fatalf("expected narrow terminal warning, got %q", out) + } +} + +func TestComputeBarHeightCappedAtThree(t *testing.T) { + if got := computeBarHeight(30, 4, 3); got != 3 { + t.Fatalf("expected bar height cap at 3, got %d", got) + } + if got := computeBarHeight(5, 10, 3); got != 1 { + t.Fatalf("expected bar height minimum 1 when depth exceeds rows, got %d", got) + } +} + +func TestRenderTerminalViewIncludesToolbarAndStatus(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 10, + Children: []*snapshotNode{ + {Name: "child", Total: 10}, + }, + } + frames := BuildTerminalLayout(snapshot, 80, 6) + + out := RenderTerminalView(frames, 80, 6, 1, nil, nil, nil, 0, "events", true, false, "") + if !strings.Contains(out, "Flame | view:root | frames:2") { + t.Fatalf("expected toolbar to include frame count, got %q", out) + } + if !strings.Contains(out, "Selected: child") { + t.Fatalf("expected status line to show selected frame, got %q", out) + } +} + +func TestRenderTerminalViewFillsAvailableHeightForShallowTree(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 10, + Children: []*snapshotNode{ + {Name: "child", Total: 10}, + }, + } + frames := BuildTerminalLayout(snapshot, 100, 20) + + out := RenderTerminalView(frames, 100, 20, 1, nil, nil, nil, 0, "events", true, false, "") + lines := strings.Split(out, "\n") + if got, want := len(lines), 20; got != want { + t.Fatalf("expected render to fill viewport height (%d lines), got %d", want, got) + } +} + +func TestFrameLabelAddsSelectionAndMatchMarkers(t *testing.T) { + if got := frameLabel("child", 7, true, false); got != ">child<" { + t.Fatalf("expected selected marker label, got %q", got) + } + if got := frameLabel("child", 6, false, true); got != "*child" { + t.Fatalf("expected match marker label, got %q", got) + } +} + +func TestRenderTerminalViewShowsPersistentFilterContext(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 10, + Children: []*snapshotNode{ + {Name: "child", Total: 10}, + }, + } + frames := BuildTerminalLayout(snapshot, 80, 6) + matchSet := map[int]bool{1: true} + + out := RenderTerminalView(frames, 140, 6, 1, nil, matchSet, nil, 0, "events", true, false, "child") + if !strings.Contains(out, `Filter "child"`) { + t.Fatalf("expected filter context in status line, got %q", out) + } +} + +func TestRenderTerminalViewFilterKeepsNonMatchingBranchesVisible(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 100, + Children: []*snapshotNode{ + { + Name: "keep", + Total: 60, + Children: []*snapshotNode{ + {Name: "needle", Total: 60}, + }, + }, + { + Name: "drop", + Total: 40, + Children: []*snapshotNode{ + {Name: "noise", Total: 40}, + }, + }, + }, + } + frames := BuildTerminalLayout(snapshot, 80, 8) + needleIdx := frameIndexByPathRenderer(frames, "root"+pathSeparator+"keep"+pathSeparator+"needle") + if needleIdx < 0 { + t.Fatalf("expected needle frame in layout") + } + matchSet := map[int]bool{needleIdx: true} + + out := RenderTerminalView(frames, 180, 8, needleIdx, nil, matchSet, nil, 100, "bytes", true, false, "needle") + if !strings.Contains(out, `Filter "needle": 60.0% bytes`) { + t.Fatalf("expected filter status to report 60.0%% bytes share, got %q", out) + } + if !strings.Contains(out, "keep") || !strings.Contains(out, "needle") { + t.Fatalf("expected matching branch to remain visible, got %q", out) + } + if !strings.Contains(out, "drop") || !strings.Contains(out, "noise") { + t.Fatalf("expected non-matching branch to remain visible (greyed), got %q", out) + } + if !strings.Contains(out, "100.00% filtered bytes") { + t.Fatalf("expected selected match share to be computed against filtered total, got %q", out) + } +} + +func TestFilterSampleCoverageAvoidsDoubleCountingNestedMatches(t *testing.T) { + frames := []tuiFrame{ + {Path: "root", Total: 100}, + {Path: "root" + pathSeparator + "A", Total: 60}, + {Path: "root" + pathSeparator + "A" + pathSeparator + "A1", Total: 30}, + {Path: "root" + pathSeparator + "B", Total: 40}, + } + matchSet := map[int]bool{ + 1: true, // A + 2: true, // A1 (nested under A) + } + if got := filterSampleCoverage(frames, matchSet, 100); got != 60 { + t.Fatalf("expected nested matches to count once at 60%%, got %.1f%%", got) + } +} + +func TestRenderTerminalViewShowsDeepLevelTruncationHint(t *testing.T) { + snapshot := &snapshotNode{ + Name: "root", + Total: 4, + Children: []*snapshotNode{ + { + Name: "a", + Total: 4, + Children: []*snapshotNode{ + { + Name: "b", + Total: 4, + Children: []*snapshotNode{ + { + Name: "c", + Total: 4, + Children: []*snapshotNode{ + {Name: "d", Total: 4}, + }, + }, + }, + }, + }, + }, + }, + } + frames := BuildTerminalLayout(snapshot, 80, 10) + out := RenderTerminalView(frames, 80, 4, 0, nil, nil, nil, 0, "events", true, false, "") + if !strings.Contains(out, "showing deepest levels") { + t.Fatalf("expected truncation hint in toolbar, got %q", out) + } +} + +func TestComputeSubtreeSetIncludesAncestorsAndDescendants(t *testing.T) { + frames := []tuiFrame{ + {Path: "root"}, + {Path: "root" + pathSeparator + "A"}, + {Path: "root" + pathSeparator + "A" + pathSeparator + "A1"}, + {Path: "root" + pathSeparator + "B"}, + } + + set := computeSubtreeSet(frames, 1) + if !set[0] || !set[1] || !set[2] { + t.Fatalf("expected root/A/A1 to be in selected subtree: %#v", set) + } + if set[3] { + t.Fatalf("did not expect sibling branch B in subtree: %#v", set) + } +} + +func mustFindFrame(t *testing.T, frames []tuiFrame, path string) tuiFrame { + t.Helper() + for _, frame := range frames { + if frame.Path == path { + return frame + } + } + t.Fatalf("frame with path %q not found", path) + return tuiFrame{} +} + +func hasFrame(frames []tuiFrame, path string) bool { + for _, frame := range frames { + if frame.Path == path { + return true + } + } + return false +} + +func frameIndexByPathRenderer(frames []tuiFrame, path string) int { + for idx, frame := range frames { + if frame.Path == path { + return idx + } + } + return -1 +} diff --git a/internal/tui/flamegraph/search.go b/internal/tui/flamegraph/search.go new file mode 100644 index 0000000..6bedc3e --- /dev/null +++ b/internal/tui/flamegraph/search.go @@ -0,0 +1,141 @@ +package flamegraph + +import ( + "fmt" + "sort" + "strings" +) + +func (m *Model) openSearch() { + m.searchActive = true + m.searchInput.SetValue(m.searchQuery) + m.searchInput.CursorEnd() + m.searchInput.Focus() +} + +func (m *Model) clearSearch() { + m.searchActive = false + m.searchQuery = "" + clearBoolMap(m.matchIndices) + clearBoolMap(m.filterVisible) + m.searchInput.SetValue("") + m.searchInput.Blur() + m.statusMessage = "Filter cleared" +} + +func (m *Model) applySearchQuery(raw string) { + m.searchQuery = strings.ToLower(strings.TrimSpace(raw)) + m.recomputeFilterState() + query := m.searchQuery + if query == "" { + m.ensureSelectionNavigable() + m.statusMessage = "Filter cleared" + return + } + + if len(m.matchIndices) > 0 { + m.jumpMatch(1) + m.statusMessage = fmt.Sprintf("Filter %q: %d matches", query, len(m.matchIndices)) + return + } + m.statusMessage = fmt.Sprintf("Filter %q: no matches", query) +} + +func (m *Model) jumpMatch(direction int) { + matches := orderedMatchIndices(m.matchIndices) + if len(matches) == 0 { + return + } + currentPos := indexOf(matches, m.selectedIdx) + if currentPos == -1 { + if direction < 0 { + m.selectedIdx = matches[len(matches)-1] + } else { + m.selectedIdx = matches[0] + } + m.subtreeSet = computeSubtreeSetInto(m.frames, m.selectedIdx, m.subtreeSet) + return + } + + next := currentPos + direction + if next < 0 { + next = len(matches) - 1 + } + if next >= len(matches) { + next = 0 + } + m.selectedIdx = matches[next] + m.subtreeSet = computeSubtreeSetInto(m.frames, m.selectedIdx, m.subtreeSet) +} + +func (m *Model) recomputeFilterState() { + if m.matchIndices == nil { + m.matchIndices = make(map[int]bool) + } else { + clearBoolMap(m.matchIndices) + } + if m.filterVisible == nil { + m.filterVisible = make(map[int]bool) + } else { + clearBoolMap(m.filterVisible) + } + if m.searchQuery == "" { + return + } + + for idx, frame := range m.frames { + if strings.Contains(strings.ToLower(frame.Name), m.searchQuery) { + m.matchIndices[idx] = true + } + } + m.filterVisible = computeFilterVisibleSetInto(m.frames, m.matchIndices, m.filterVisible) +} + +func orderedMatchIndices(matchSet map[int]bool) []int { + matches := make([]int, 0, len(matchSet)) + for idx := range matchSet { + matches = append(matches, idx) + } + sort.Ints(matches) + return matches +} + +func (m Model) searchFooter() string { + matches := orderedMatchIndices(m.matchIndices) + pos := 0 + if len(matches) > 0 { + idx := indexOf(matches, m.selectedIdx) + if idx >= 0 { + pos = idx + 1 + } + } + return fmt.Sprintf("%s %d/%d matches", m.searchInput.View(), pos, len(matches)) +} + +func replaceFooterLine(content, footer string) string { + if content == "" { + return footer + } + lastNewline := strings.LastIndexByte(content, '\n') + if lastNewline == -1 { + return footer + } + return content[:lastNewline+1] + footer +} + +func replaceHeaderLine(content, header string) string { + if content == "" { + return header + } + firstNewline := strings.IndexByte(content, '\n') + if firstNewline == -1 { + return header + } + return header + content[firstNewline:] +} + +func clearBoolMap[K comparable](values map[K]bool) { + for key := range values { + delete(values, key) + } +} diff --git a/internal/tui/flamegraph/stress_race_disabled_test.go b/internal/tui/flamegraph/stress_race_disabled_test.go new file mode 100644 index 0000000..c9769fd --- /dev/null +++ b/internal/tui/flamegraph/stress_race_disabled_test.go @@ -0,0 +1,7 @@ +//go:build !race + +package flamegraph + +func stressBudgetMultiplier() int { + return 1 +} diff --git a/internal/tui/flamegraph/stress_race_enabled_test.go b/internal/tui/flamegraph/stress_race_enabled_test.go new file mode 100644 index 0000000..30338f4 --- /dev/null +++ b/internal/tui/flamegraph/stress_race_enabled_test.go @@ -0,0 +1,7 @@ +//go:build race + +package flamegraph + +func stressBudgetMultiplier() int { + return 3 +} diff --git a/internal/tui/flamegraph/stress_test.go b/internal/tui/flamegraph/stress_test.go new file mode 100644 index 0000000..e53e4d5 --- /dev/null +++ b/internal/tui/flamegraph/stress_test.go @@ -0,0 +1,236 @@ +package flamegraph + +import ( + "encoding/json" + "fmt" + "math/rand" + "sync" + "testing" + "time" + + coreflamegraph "ior/internal/flamegraph" + "ior/internal/types" + + tea "charm.land/bubbletea/v2" +) + +func TestStressHighEventRate(t *testing.T) { + t.Parallel() + + const ( + workerCount = 10 + eventsPerWorker = 10000 + testDuration = 5 * time.Second + renderFPS = 30 + frameBudget = time.Second / renderFPS + ) + allowedBudget := frameBudget * time.Duration(stressBudgetMultiplier()) + + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + var ingestWG sync.WaitGroup + + type renderMetrics struct { + err error + samples int + total time.Duration + maxDuration time.Duration + } + renderDone := make(chan renderMetrics, 1) + + go func() { + ticker := time.NewTicker(frameBudget) + defer ticker.Stop() + deadline := time.NewTimer(testDuration) + defer deadline.Stop() + + metrics := renderMetrics{} + for { + select { + case <-ticker.C: + start := time.Now() + payload, _ := liveTrie.SnapshotJSON() + var snapshot snapshotNode + if err := json.Unmarshal(payload, &snapshot); err != nil { + metrics.err = fmt.Errorf("decode snapshot: %w", err) + renderDone <- metrics + return + } + frames := BuildTerminalLayout(&snapshot, 120, 40) + _ = frames + + elapsed := time.Since(start) + metrics.samples++ + metrics.total += elapsed + if elapsed > metrics.maxDuration { + metrics.maxDuration = elapsed + } + case <-deadline.C: + renderDone <- metrics + return + } + } + }() + + for worker := 0; worker < workerCount; worker++ { + worker := worker + ingestWG.Add(1) + go func() { + defer ingestWG.Done() + for i := 0; i < eventsPerWorker; i++ { + seed := worker*eventsPerWorker + i + traceID := types.SYS_ENTER_READ + if seed%2 == 0 { + traceID = types.SYS_ENTER_WRITE + } + pair := newBenchmarkPair( + fmt.Sprintf("worker-%d", worker), + traceID, + uint32(1000+worker), + uint32(200000+seed), + buildBenchmarkPath(6, 3, seed), + ) + liveTrie.Ingest(pair) + pair.Recycle() + } + }() + } + + ingestWG.Wait() + metrics := <-renderDone + + if metrics.err != nil { + t.Fatalf("render loop failed: %v", metrics.err) + } + if metrics.samples == 0 { + t.Fatal("render loop produced no samples") + } + avg := metrics.total / time.Duration(metrics.samples) + if avg > allowedBudget { + t.Fatalf("average render latency exceeded frame budget: avg=%s budget=%s samples=%d", avg, allowedBudget, metrics.samples) + } + if metrics.maxDuration > allowedBudget*6 { + t.Fatalf("max render latency too high: max=%s budget=%s", metrics.maxDuration, allowedBudget) + } +} + +func TestStressRapidResize(t *testing.T) { + t.Parallel() + + model := NewModel(nil) + model.width = 120 + model.height = 40 + model.snapshot = generateTestSnapshot(fixtureMediumDepth, fixtureMediumBreadth) + model.rebuildFrames(false) + if len(model.frames) == 0 { + t.Fatal("expected initial medium fixture frames") + } + + rng := rand.New(rand.NewSource(42)) + lastWidth, lastHeight := model.width, model.height + for i := 0; i < 100; i++ { + lastWidth = 60 + rng.Intn(241) // [60, 300] + lastHeight = 20 + rng.Intn(61) // [20, 80] + next, _ := model.Update(tea.WindowSizeMsg{Width: lastWidth, Height: lastHeight}) + model = next.(Model) + model = settleStressAnimation(model, 180) + + assertFramesWithinBounds(t, model.frames, lastWidth, lastHeight) + if len(model.frames) > 0 && (model.selectedIdx < 0 || model.selectedIdx >= len(model.frames)) { + t.Fatalf("invalid selectedIdx after resize %d: idx=%d frames=%d", i, model.selectedIdx, len(model.frames)) + } + } + + if model.width != lastWidth || model.height != lastHeight { + t.Fatalf("final viewport mismatch: got %dx%d want %dx%d", model.width, model.height, lastWidth, lastHeight) + } + assertFramesWithinBounds(t, model.frames, lastWidth, lastHeight) +} + +func TestStressZoomDuringRefresh(t *testing.T) { + t.Parallel() + + liveTrie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + ingestStressEvents(liveTrie, 200, 0) + + model := NewModel(liveTrie) + model.SetViewport(120, 40) + if changed := model.RefreshFromLiveTrie(); !changed { + t.Fatal("expected initial live trie refresh") + } + if len(model.frames) == 0 { + t.Fatal("expected initial frames after refresh") + } + + for i := 0; i < 50; i++ { + ingestStressEvents(liveTrie, 20, 1000+i*20) + _ = model.RefreshFromLiveTrie() + model = settleStressAnimation(model, 180) + if len(model.frames) == 0 { + t.Fatalf("expected frames after refresh tick %d", i) + } + + prevDepth := len(model.zoomStack) + model.selectedIdx = midDepthFrameIndex(model.frames) + model.zoomIn() + model = settleStressAnimation(model, 180) + if len(model.zoomStack) != prevDepth+1 { + t.Fatalf("zoom stack did not grow after zoom-in at iteration %d: got=%d want=%d", i, len(model.zoomStack), prevDepth+1) + } + + model.zoomUndo() + model = settleStressAnimation(model, 180) + if len(model.zoomStack) != prevDepth { + t.Fatalf("zoom stack depth mismatch after undo at iteration %d: got=%d want=%d", i, len(model.zoomStack), prevDepth) + } + if model.zoomPath != "" { + if findNodeByPath(model.snapshot, model.zoomPath) == nil { + t.Fatalf("zoomPath became invalid after undo at iteration %d: %q", i, model.zoomPath) + } + } + assertFramesWithinBounds(t, model.frames, model.width, model.height) + } +} + +func settleStressAnimation(model Model, maxTicks int) Model { + for i := 0; i < maxTicks && model.animating; i++ { + next, _ := model.Update(animTickMsg{}) + model = next.(Model) + } + return model +} + +func assertFramesWithinBounds(t *testing.T, frames []tuiFrame, width, height int) { + t.Helper() + for _, frame := range frames { + if frame.Col < 0 || frame.Width <= 0 { + t.Fatalf("invalid frame geometry: %+v", frame) + } + if frame.Col+frame.Width > width { + t.Fatalf("frame exceeds width %d: %+v", width, frame) + } + if frame.Row < 0 || frame.Row >= height { + t.Fatalf("frame row outside height %d: %+v", height, frame) + } + } +} + +func ingestStressEvents(liveTrie *coreflamegraph.LiveTrie, count, seedBase int) { + for i := 0; i < count; i++ { + seed := seedBase + i + traceID := types.SYS_ENTER_READ + if seed%3 == 0 { + traceID = types.SYS_ENTER_OPENAT + } else if seed%2 == 0 { + traceID = types.SYS_ENTER_WRITE + } + pair := newBenchmarkPair( + fmt.Sprintf("stress-%d", seed%8), + traceID, + uint32(1200+(seed%64)), + uint32(300000+seed), + buildBenchmarkPath(9, 5, seed), + ) + liveTrie.Ingest(pair) + pair.Recycle() + } +} diff --git a/internal/tui/flamegraph/testdata_fixture_test.go b/internal/tui/flamegraph/testdata_fixture_test.go new file mode 100644 index 0000000..1f22c26 --- /dev/null +++ b/internal/tui/flamegraph/testdata_fixture_test.go @@ -0,0 +1,39 @@ +package flamegraph + +import "testing" + +func TestFixtureSnapshotsHaveApproximateFrameCounts(t *testing.T) { + fixtures := []struct { + name string + depth int + breadth int + expect int + }{ + {name: "small", depth: fixtureSmallDepth, breadth: fixtureSmallBreadth, expect: 121}, + {name: "medium", depth: fixtureMediumDepth, breadth: fixtureMediumBreadth, expect: 2500}, + {name: "large", depth: fixtureLargeDepth, breadth: fixtureLargeBreadth, expect: 12000}, + {name: "deep", depth: fixtureDeepDepth, breadth: fixtureDeepBreadth, expect: 100}, + {name: "wide", depth: fixtureWideDepth, breadth: fixtureWideBreadth, expect: 5000}, + } + + for _, fixture := range fixtures { + t.Run(fixture.name, func(t *testing.T) { + snap := generateTestSnapshot(fixture.depth, fixture.breadth) + got := snapshotNodeCount(snap) + if !approxEqualCount(got, fixture.expect) { + t.Fatalf("%s fixture nodes=%d, expected approximately %d", fixture.name, got, fixture.expect) + } + }) + } +} + +func TestGenerateTestTrieProducesSnapshotData(t *testing.T) { + lt := generateTestTrie(fixtureSmallDepth, fixtureSmallBreadth) + snap, err := decodeTrieSnapshot(lt) + if err != nil { + t.Fatalf("decode trie snapshot: %v", err) + } + if snap.Total == 0 { + t.Fatalf("expected generated trie snapshot to contain data") + } +} diff --git a/internal/tui/flamegraph/testdata_test.go b/internal/tui/flamegraph/testdata_test.go new file mode 100644 index 0000000..c7d97b0 --- /dev/null +++ b/internal/tui/flamegraph/testdata_test.go @@ -0,0 +1,185 @@ +package flamegraph + +import ( + "encoding/json" + "fmt" + "math" + + "ior/internal/event" + "ior/internal/file" + coreflamegraph "ior/internal/flamegraph" + "ior/internal/types" +) + +const ( + fixtureSmallDepth = 5 + fixtureSmallBreadth = 3 + + fixtureMediumDepth = 10 + fixtureMediumBreadth = 5 + + fixtureLargeDepth = 15 + fixtureLargeBreadth = 8 + + fixtureDeepDepth = 50 + fixtureDeepBreadth = 2 + + fixtureWideDepth = 3 + fixtureWideBreadth = 50 +) + +func generateTestTrie(depth, breadthPerLevel int) *coreflamegraph.LiveTrie { + lt := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + comms := []string{"api", "db", "worker", "cache"} + traceIDs := []types.TraceId{ + types.SYS_ENTER_READ, + types.SYS_ENTER_WRITE, + types.SYS_ENTER_OPENAT, + types.SYS_ENTER_CLOSE, + } + + totalEvents := maxInt(100, fixtureTargetFrames(depth, breadthPerLevel)/2) + for i := 0; i < totalEvents; i++ { + comm := comms[i%len(comms)] + traceID := traceIDs[i%len(traceIDs)] + path := buildBenchmarkPath(depth, breadthPerLevel, i) + lt.Ingest(newBenchmarkPair(comm, traceID, uint32(1000+(i%256)), uint32(200000+i), path)) + } + return lt +} + +func generateTestSnapshot(depth, breadthPerLevel int) *snapshotNode { + targetFrames := fixtureTargetFrames(depth, breadthPerLevel) + if targetFrames < 1 { + targetFrames = 1 + } + + root := &snapshotNode{Name: "root", Value: 1} + type qItem struct { + node *snapshotNode + depth int + } + queue := []qItem{{node: root, depth: 0}} + created := 1 + + for len(queue) > 0 && created < targetFrames { + item := queue[0] + queue = queue[1:] + if item.depth >= depth { + continue + } + remaining := targetFrames - created + branchCount := breadthPerLevel + if branchCount > remaining { + branchCount = remaining + } + for i := 0; i < branchCount; i++ { + child := &snapshotNode{ + Name: fmt.Sprintf("d%d-n%d", item.depth+1, created+i), + Value: 1, + } + item.node.Children = append(item.node.Children, child) + queue = append(queue, qItem{node: child, depth: item.depth + 1}) + } + created += branchCount + } + + computeSnapshotTotals(root) + return root +} + +func fixtureTargetFrames(depth, breadth int) int { + switch { + case depth == fixtureSmallDepth && breadth == fixtureSmallBreadth: + return 121 + case depth == fixtureMediumDepth && breadth == fixtureMediumBreadth: + return 2500 + case depth == fixtureLargeDepth && breadth == fixtureLargeBreadth: + return 12000 + case depth == fixtureDeepDepth && breadth == fixtureDeepBreadth: + return 100 + case depth == fixtureWideDepth && breadth == fixtureWideBreadth: + return 5000 + default: + return maxInt(1, depth*breadth*10) + } +} + +func computeSnapshotTotals(node *snapshotNode) uint64 { + if node == nil { + return 0 + } + total := node.Value + for _, child := range node.Children { + total += computeSnapshotTotals(child) + } + node.Total = total + return total +} + +func buildBenchmarkPath(depth, breadth, seed int) string { + if depth < 1 { + depth = 1 + } + if breadth < 1 { + breadth = 1 + } + path := "/bench" + value := seed + for level := 0; level < depth; level++ { + slot := value % breadth + path += fmt.Sprintf("/l%d-b%d", level, slot) + value = value / breadth + } + return path +} + +func newBenchmarkPair(comm string, traceID types.TraceId, pid, tid uint32, path string) *event.Pair { + enter := &types.OpenEvent{ + TraceId: traceID, + Pid: pid, + Tid: tid, + } + exit := &types.RetEvent{ + TraceId: types.SYS_EXIT_OPENAT, + Pid: pid, + Tid: tid, + } + pair := event.NewPair(enter) + pair.ExitEv = exit + pair.File = file.NewFd(3, path, 0) + pair.Comm = comm + pair.Duration = 1 + pair.DurationToPrev = 1 + pair.Bytes = 64 + return pair +} + +func snapshotNodeCount(node *snapshotNode) int { + if node == nil { + return 0 + } + total := 1 + for _, child := range node.Children { + total += snapshotNodeCount(child) + } + return total +} + +func approxEqualCount(got, want int) bool { + if got == want { + return true + } + const tolerance = 0.2 + diff := math.Abs(float64(got-want)) / float64(want) + return diff <= tolerance +} + +func decodeTrieSnapshot(lt *coreflamegraph.LiveTrie) (*snapshotNode, error) { + payload, _ := lt.SnapshotJSON() + var snap snapshotNode + if err := json.Unmarshal(payload, &snap); err != nil { + return nil, err + } + return &snap, nil +} diff --git a/internal/tui/flamegraph/zoom.go b/internal/tui/flamegraph/zoom.go new file mode 100644 index 0000000..7a3aa42 --- /dev/null +++ b/internal/tui/flamegraph/zoom.go @@ -0,0 +1,39 @@ +package flamegraph + +import "strings" + +func findNodeByPath(root *snapshotNode, path string) *snapshotNode { + if root == nil { + return nil + } + if path == "" { + return root + } + parts := strings.Split(path, pathSeparator) + if len(parts) == 0 { + return root + } + rootName := frameName(root.Name, 0) + if parts[0] == rootName { + parts = parts[1:] + } + + node := root + for _, part := range parts { + next := findChildByName(node, part) + if next == nil { + return nil + } + node = next + } + return node +} + +func findChildByName(node *snapshotNode, name string) *snapshotNode { + for _, child := range node.Children { + if child.Name == name || frameName(child.Name, 1) == name { + return child + } + } + return nil +} diff --git a/internal/tui/messages/doc.go b/internal/tui/messages/doc.go new file mode 100644 index 0000000..8d70859 --- /dev/null +++ b/internal/tui/messages/doc.go @@ -0,0 +1,2 @@ +// Package messages defines typed Bubble Tea messages exchanged across TUI models. +package messages diff --git a/internal/tui/pidpicker/doc.go b/internal/tui/pidpicker/doc.go new file mode 100644 index 0000000..5f863c0 --- /dev/null +++ b/internal/tui/pidpicker/doc.go @@ -0,0 +1,2 @@ +// Package pidpicker implements PID and TID selection workflows before tracing starts. +package pidpicker diff --git a/internal/tui/pidpicker/model.go b/internal/tui/pidpicker/model.go index 73f21ae..f4bb414 100644 --- a/internal/tui/pidpicker/model.go +++ b/internal/tui/pidpicker/model.go @@ -2,14 +2,15 @@ package pidpicker import ( "fmt" + "strings" + common "ior/internal/tui/common" "ior/internal/tui/messages" - "strings" - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/textinput" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) const allPIDsLabel = "All PIDs" @@ -50,6 +51,14 @@ var ( errorStyle = common.ErrorStyle ) +func syncPickerStyles() { + screenStyle = common.ScreenStyle + headerStyle = common.HeaderStyle + helpBarStyle = common.HelpBarStyle + highlightStyle = common.HighlightStyle + errorStyle = common.ErrorStyle +} + type processesLoadedMsg struct { processes []ProcessInfo err error @@ -67,6 +76,7 @@ type Model struct { height int keys KeyMap lastErr error + isDark bool } // New creates a PID picker model with default shared key bindings. @@ -81,12 +91,14 @@ func NewWithKeys(keys KeyMap) Model { // NewPIDWithKeys creates a PID picker model with the provided key bindings. func NewPIDWithKeys(keys KeyMap) Model { + syncPickerStyles() input := textinput.New() input.Prompt = "Filter: " input.Placeholder = "pid, comm, or cmdline" input.Focus() input.CharLimit = 0 - input.Width = 40 + input.SetWidth(40) + input.SetStyles(textinput.DefaultStyles(true)) return Model{ input: input, @@ -94,6 +106,7 @@ func NewPIDWithKeys(keys KeyMap) Model { filtered: []ProcessInfo{}, mode: PickerModePID, targetPID: -1, + isDark: true, } } @@ -117,14 +130,18 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case tea.WindowSizeMsg: m.width = msg.Width m.height = msg.Height - m.input.Width = clamp(msg.Width-16, 10, 100) + inputWidth := msg.Width - 16 + if inputWidth < 10 { + inputWidth = 10 + } + m.input.SetWidth(inputWidth) return m, nil case processesLoadedMsg: m.processes = msg.processes m.lastErr = msg.err m.applyFilter() return m, nil - case tea.KeyMsg: + case tea.KeyPressMsg: return m.updateKey(msg) } @@ -134,21 +151,21 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { return m, cmd } -func (m Model) updateKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { +func (m Model) updateKey(msg tea.KeyPressMsg) (tea.Model, tea.Cmd) { switch { case key.Matches(msg, m.keys.Esc): return m, tea.Quit - case msg.Type == tea.KeyCtrlR: + case msg.Key().Mod&tea.ModCtrl != 0 && (msg.Key().Code == 'r' || msg.Key().Code == 'R'): return m, m.scanCmd() case key.Matches(msg, m.keys.Enter): return m, m.emitSelection() - case msg.Type == tea.KeyUp: + case msg.Key().Code == tea.KeyUp: if m.selectedIndex > 0 { m.selectedIndex-- } m.input.Blur() return m, nil - case msg.Type == tea.KeyDown: + case msg.Key().Code == tea.KeyDown: maxIndex := len(m.filtered) if m.selectedIndex < maxIndex { m.selectedIndex++ @@ -157,7 +174,7 @@ func (m Model) updateKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { return m, nil } - if msg.Type == tea.KeyRunes && !m.input.Focused() { + if msg.Key().Text != "" && !m.input.Focused() { if key.Matches(msg, m.keys.Refresh) { return m, m.scanCmd() } @@ -240,7 +257,7 @@ func cloneProcesses(in []ProcessInfo) []ProcessInfo { } // View renders the PID picker with filter input, list, and help bar. -func (m Model) View() string { +func (m Model) View() tea.View { var b strings.Builder if m.mode == PickerModeTID { if m.targetPID > 0 { @@ -264,8 +281,18 @@ func (m Model) View() string { } b.WriteString("\n") - b.WriteString(helpBarStyle.Render(renderHelp(m.keys.PickerShortHelp()))) - return screenStyle.Render(b.String()) + viewWidth, _ := common.EffectiveViewport(m.width, m.height) + helpStyle := helpBarStyle.Copy().Width(viewWidth) + b.WriteString(helpStyle.Render(renderHelp(m.keys.PickerShortHelp()))) + return tea.NewView(screenStyle.Render(b.String())) +} + +// SetDarkMode updates picker theme and text input styles. +func (m Model) SetDarkMode(isDark bool) Model { + m.isDark = isDark + syncPickerStyles() + m.input.SetStyles(textinput.DefaultStyles(isDark)) + return m } func (m Model) renderRows() string { diff --git a/internal/tui/pidpicker/model_test.go b/internal/tui/pidpicker/model_test.go index 2d76508..695e5bd 100644 --- a/internal/tui/pidpicker/model_test.go +++ b/internal/tui/pidpicker/model_test.go @@ -1,11 +1,12 @@ package pidpicker import ( - "ior/internal/tui/messages" "strings" "testing" - tea "github.com/charmbracelet/bubbletea" + "ior/internal/tui/messages" + + tea "charm.land/bubbletea/v2" ) func TestApplyFilterByPIDCommAndCmdline(t *testing.T) { @@ -39,7 +40,7 @@ func TestEnterEmitsAllPIDsAndSelectedPID(t *testing.T) { m.processes = []ProcessInfo{{Pid: 7, Comm: "vim"}, {Pid: 9, Comm: "top"}} m.applyFilter() - modelAny, cmdAny := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + modelAny, cmdAny := m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) _ = modelAny msgAny := cmdAny() pidAny, ok := msgAny.(messages.PidSelectedMsg) @@ -51,7 +52,7 @@ func TestEnterEmitsAllPIDsAndSelectedPID(t *testing.T) { } m.selectedIndex = 2 - modelOne, cmdOne := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + modelOne, cmdOne := m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) _ = modelOne msgOne := cmdOne() pidOne, ok := msgOne.(messages.PidSelectedMsg) @@ -71,7 +72,7 @@ func TestEnterEmitsAllTIDsAndSelectedTIDInTIDMode(t *testing.T) { } m.applyFilter() - modelAny, cmdAny := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + modelAny, cmdAny := m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) _ = modelAny msgAny := cmdAny() tidAny, ok := msgAny.(messages.TidSelectedMsg) @@ -86,7 +87,7 @@ func TestEnterEmitsAllTIDsAndSelectedTIDInTIDMode(t *testing.T) { } m.selectedIndex = 2 - modelOne, cmdOne := m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + modelOne, cmdOne := m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) _ = modelOne msgOne := cmdOne() tidOne, ok := msgOne.(messages.TidSelectedMsg) @@ -104,7 +105,7 @@ func TestEnterEmitsAllTIDsAndSelectedTIDInTIDMode(t *testing.T) { func TestEscQuitsAndRefreshTriggersScan(t *testing.T) { m := NewWithKeys(DefaultKeyMap()) - _, escCmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + _, escCmd := m.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) if escCmd == nil { t.Fatalf("expected esc to return quit cmd") } @@ -112,7 +113,7 @@ func TestEscQuitsAndRefreshTriggersScan(t *testing.T) { t.Fatalf("expected quit msg from esc, got %T", msg) } - _, refreshCmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlR}) + _, refreshCmd := m.Update(tea.KeyPressMsg{Code: rune('r'), Text: "r", Mod: tea.ModCtrl}) if refreshCmd == nil { t.Fatalf("expected refresh cmd") } @@ -124,7 +125,7 @@ func TestEscQuitsAndRefreshTriggersScan(t *testing.T) { func TestRuneRDoesNotTriggerRefreshWhileFilterFocused(t *testing.T) { m := NewWithKeys(DefaultKeyMap()) - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'r'}}) + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'r'}[0], Text: string([]rune{'r'})}) if cmd == nil { t.Fatalf("expected textinput update cmd") } @@ -152,3 +153,13 @@ func TestRenderRowsKeepsSelectionVisible(t *testing.T) { t.Fatalf("expected selected row to remain visible, got:\n%s", rows) } } + +func TestWindowSizeDoesNotCapInputWidthOnWideTerminals(t *testing.T) { + m := NewWithKeys(DefaultKeyMap()) + next, _ := m.Update(tea.WindowSizeMsg{Width: 160, Height: 40}) + updated := next.(Model) + + if got, want := updated.input.Width(), 144; got != want { + t.Fatalf("expected input width %d for 160-col terminal, got %d", want, got) + } +} diff --git a/internal/tui/probes/doc.go b/internal/tui/probes/doc.go new file mode 100644 index 0000000..922aee6 --- /dev/null +++ b/internal/tui/probes/doc.go @@ -0,0 +1,2 @@ +// Package probes implements the runtime probe toggling modal for the TUI. +package probes diff --git a/internal/tui/probes/model.go b/internal/tui/probes/model.go index 5cec2c7..baf22e8 100644 --- a/internal/tui/probes/model.go +++ b/internal/tui/probes/model.go @@ -2,13 +2,14 @@ package probes import ( "fmt" - "ior/internal/probemanager" "strings" "unicode/utf8" - "github.com/charmbracelet/bubbles/textinput" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "ior/internal/probemanager" + + "charm.land/bubbles/v2/textinput" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) // Manager defines the probe operations used by the modal. @@ -39,16 +40,19 @@ type Model struct { lastErr string manager Manager height int + isDark bool } func NewModel(manager Manager) Model { ti := textinput.New() ti.Prompt = "/ " ti.CharLimit = 0 - ti.Width = 28 + ti.SetWidth(28) + ti.SetStyles(textinput.DefaultStyles(true)) return Model{ manager: manager, textInput: ti, + isDark: true, } } @@ -72,6 +76,13 @@ func (m Model) Close() Model { return m } +// SetDarkMode updates probe modal text input styles. +func (m Model) SetDarkMode(isDark bool) Model { + m.isDark = isDark + m.textInput.SetStyles(textinput.DefaultStyles(isDark)) + return m +} + func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { if !m.visible { return m, nil @@ -87,7 +98,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { } m.clampCursor() return m, nil - case tea.KeyMsg: + case tea.KeyPressMsg: if m.searching { return m.updateSearch(msg) } @@ -110,7 +121,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { m.textInput.CursorEnd() m.textInput.Focus() return m, nil - case " ", "enter": + case " ", "space", "enter": selected := m.selectedSyscall() if selected == "" { return m, nil @@ -125,7 +136,7 @@ func (m Model) Update(msg tea.Msg) (Model, tea.Cmd) { return m, nil } -func (m Model) updateSearch(msg tea.KeyMsg) (Model, tea.Cmd) { +func (m Model) updateSearch(msg tea.KeyPressMsg) (Model, tea.Cmd) { switch msg.String() { case "esc": m.searching = false diff --git a/internal/tui/probes/model_test.go b/internal/tui/probes/model_test.go index 73a83bc..3a14675 100644 --- a/internal/tui/probes/model_test.go +++ b/internal/tui/probes/model_test.go @@ -5,7 +5,7 @@ import ( "ior/internal/probemanager" - tea "github.com/charmbracelet/bubbletea" + tea "charm.land/bubbletea/v2" ) type fakeManager struct { @@ -61,7 +61,7 @@ func TestToggleEmitsProbeToggledMsg(t *testing.T) { states: []probemanager.ProbeState{{Syscall: "read", Active: true}}, } m := NewModel(fm).Open() - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}}) + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{' '}[0], Text: string([]rune{' '})}) if cmd == nil { t.Fatalf("expected toggle command") } @@ -90,7 +90,7 @@ func TestBulkKeysApplyGloballyNotOnlyFiltered(t *testing.T) { m := NewModel(fm).Open() m.search = "read" - _, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}}) + _, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'n'}[0], Text: string([]rune{'n'})}) if cmd == nil { t.Fatalf("expected bulk off command") } @@ -107,7 +107,7 @@ func TestBulkKeysApplyGloballyNotOnlyFiltered(t *testing.T) { m = NewModel(fm).Open() m.search = "read" fm.toggles = nil - _, cmd = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'a'}}) + _, cmd = m.Update(tea.KeyPressMsg{Code: []rune{'a'}[0], Text: string([]rune{'a'})}) if cmd == nil { t.Fatalf("expected bulk on command") } diff --git a/internal/tui/styles.go b/internal/tui/styles.go index 3bf69f7..5452e57 100644 --- a/internal/tui/styles.go +++ b/internal/tui/styles.go @@ -26,9 +26,6 @@ var ( // TabInactiveStyle is applied to non-selected tabs. TabInactiveStyle = common.TabInactiveStyle - // PanelStyle is used for boxed sections. - PanelStyle = common.PanelStyle - // HelpBarStyle is used for keybinding hints at the bottom. HelpBarStyle = common.HelpBarStyle @@ -38,3 +35,25 @@ var ( // ErrorStyle is used for fatal or warning messages. ErrorStyle = common.ErrorStyle ) + +func syncStylesFromCommon() { + ColorBackground = common.ColorBackground + ColorPanel = common.ColorPanel + ColorPrimary = common.ColorPrimary + ColorAccent = common.ColorAccent + ColorMuted = common.ColorMuted + ColorText = common.ColorText + ColorDanger = common.ColorDanger + + ScreenStyle = common.ScreenStyle + HeaderStyle = common.HeaderStyle + TabActiveStyle = common.TabActiveStyle + TabInactiveStyle = common.TabInactiveStyle + HelpBarStyle = common.HelpBarStyle + HighlightStyle = common.HighlightStyle + ErrorStyle = common.ErrorStyle +} + +func init() { + syncStylesFromCommon() +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index bdd3ab5..d60ee4b 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -2,9 +2,14 @@ package tui import ( "context" - "encoding/csv" "errors" "fmt" + "log" + "strings" + "sync" + "time" + + coreexport "ior/internal/export" "ior/internal/flags" "ior/internal/probemanager" "ior/internal/statsengine" @@ -12,18 +17,15 @@ import ( dashboardui "ior/internal/tui/dashboard" "ior/internal/tui/eventstream" tuiexport "ior/internal/tui/export" + flamegraphtui "ior/internal/tui/flamegraph" "ior/internal/tui/messages" "ior/internal/tui/pidpicker" "ior/internal/tui/probes" - "os" - "strings" - "sync" - "time" - "github.com/charmbracelet/bubbles/key" - "github.com/charmbracelet/bubbles/spinner" - tea "github.com/charmbracelet/bubbletea" - "github.com/charmbracelet/lipgloss" + "charm.land/bubbles/v2/key" + "charm.land/bubbles/v2/spinner" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) // Screen identifies the currently active TUI screen. @@ -56,20 +58,28 @@ type ProbeManager interface { // (snapshot source, stream source, probe manager) into the active TUI model. type TraceRuntimeBindings interface { SetDashboardSnapshotSource(source SnapshotSource) - SetEventStreamSource(source *eventstream.RingBuffer) + SetEventStreamSource(source eventstream.Source) + SetLiveTrie(liveTrie flamegraphtui.LiveTrieSource) SetProbeManager(manager ProbeManager) } type runtimeBindingsContextKey struct{} +type traceFiltersContextKey struct{} type runtimeBindings struct { mu sync.RWMutex snapshotSource SnapshotSource - streamSource *eventstream.RingBuffer + streamSource eventstream.Source + liveTrieSource flamegraphtui.LiveTrieSource probeManager ProbeManager } +type traceFilters struct { + pidFilter int + tidFilter int +} + func newRuntimeBindings() *runtimeBindings { return &runtimeBindings{} } @@ -80,12 +90,18 @@ func (r *runtimeBindings) SetDashboardSnapshotSource(source SnapshotSource) { r.mu.Unlock() } -func (r *runtimeBindings) SetEventStreamSource(source *eventstream.RingBuffer) { +func (r *runtimeBindings) SetEventStreamSource(source eventstream.Source) { r.mu.Lock() r.streamSource = source r.mu.Unlock() } +func (r *runtimeBindings) SetLiveTrie(liveTrie flamegraphtui.LiveTrieSource) { + r.mu.Lock() + r.liveTrieSource = liveTrie + r.mu.Unlock() +} + func (r *runtimeBindings) SetProbeManager(manager ProbeManager) { r.mu.Lock() r.probeManager = manager @@ -98,12 +114,18 @@ func (r *runtimeBindings) dashboardSnapshotSource() SnapshotSource { return r.snapshotSource } -func (r *runtimeBindings) eventStreamSource() *eventstream.RingBuffer { +func (r *runtimeBindings) eventStreamSource() eventstream.Source { r.mu.RLock() defer r.mu.RUnlock() return r.streamSource } +func (r *runtimeBindings) liveTrie() flamegraphtui.LiveTrieSource { + r.mu.RLock() + defer r.mu.RUnlock() + return r.liveTrieSource +} + func (r *runtimeBindings) currentProbeManager() ProbeManager { r.mu.RLock() defer r.mu.RUnlock() @@ -135,6 +157,21 @@ func RuntimeBindingsFromContext(ctx context.Context) (TraceRuntimeBindings, bool return bindings, true } +// ContextWithTraceFilters stores the active PID/TID filters for the trace starter. +func ContextWithTraceFilters(ctx context.Context, pidFilter, tidFilter int) context.Context { + filters := traceFilters{pidFilter: pidFilter, tidFilter: tidFilter} + return context.WithValue(ctx, traceFiltersContextKey{}, filters) +} + +// TraceFiltersFromContext returns the active PID/TID filters when provided by the TUI model. +func TraceFiltersFromContext(ctx context.Context) (pidFilter, tidFilter int, ok bool) { + filters, ok := ctx.Value(traceFiltersContextKey{}).(traceFilters) + if !ok { + return 0, 0, false + } + return filters.pidFilter, filters.tidFilter, true +} + // Run starts the TUI program in alternate screen mode. func Run() error { return RunWithTraceStarter(defaultTraceStarter) @@ -142,9 +179,27 @@ func Run() error { // RunWithTraceStarter starts the TUI program with a custom trace starter. func RunWithTraceStarter(starter TraceStarter) error { - cfg := flags.Get() - model := newModelWithRuntimeConfig(cfg.PidFilter, cfg.PidFilter, cfg.TUIExportEnable, starter) - program := tea.NewProgram(model, tea.WithAltScreen()) + return RunWithTraceStarterConfig(flags.Get(), starter) +} + +// RunWithTraceStarterConfig starts the TUI with explicit runtime flags. +func RunWithTraceStarterConfig(cfg flags.Config, starter TraceStarter) error { + model := newModelWithRuntimeConfig(cfg.PidFilter, cfg.PidFilter, cfg.TidFilter, cfg.TUIExportEnable, starter) + program := tea.NewProgram(model) + _, err := program.Run() + return err +} + +// RunTestFlamesWithTraceStarter starts the TUI directly on dashboard/flame view +// with a synthetic static flamegraph source. +func RunTestFlamesWithTraceStarter(starter TraceStarter) error { + return RunTestFlamesWithTraceStarterConfig(flags.Get(), starter) +} + +// RunTestFlamesWithTraceStarterConfig starts test-flames mode with explicit runtime flags. +func RunTestFlamesWithTraceStarterConfig(cfg flags.Config, starter TraceStarter) error { + model := newModelWithRuntimeConfig(1, 1, -1, cfg.TUIExportEnable, starter) + program := tea.NewProgram(model) _, err := program.Run() return err } @@ -160,6 +215,8 @@ type Model struct { keys KeyMap + helpOverlayVisible bool + width int height int quitting bool @@ -172,16 +229,38 @@ type Model struct { traceStop context.CancelFunc pidFilter int + tidFilter int exportEnabled bool + isDark bool + focused bool + + keyboardEnhancements tea.KeyboardEnhancementsMsg + keyboardEnhancementsKnown bool + + lastKeyEventID string + lastKeyEventAt time.Time + lastKeyEventWasPress bool + // Some terminals emit release+press for a single physical key event. + // When we fallback-handle a release as a press, suppress the immediate + // matching press to avoid double-handling. + suppressPressKeyID string + suppressPressUntil time.Time } // NewModel creates the top-level TUI model. func NewModel(initialPID int, startTrace TraceStarter) Model { - cfg := flags.Get() - return newModelWithRuntimeConfig(initialPID, cfg.PidFilter, cfg.TUIExportEnable, startTrace) + return NewModelWithConfig(flags.Get(), initialPID, startTrace) +} + +// NewModelWithConfig creates the top-level TUI model with explicit runtime flags. +func NewModelWithConfig(cfg flags.Config, initialPID int, startTrace TraceStarter) Model { + return newModelWithRuntimeConfig(initialPID, cfg.PidFilter, cfg.TidFilter, cfg.TUIExportEnable, startTrace) } -func newModelWithRuntimeConfig(initialPID, startupPidFilter int, exportEnabled bool, startTrace TraceStarter) Model { +func newModelWithRuntimeConfig(initialPID, startupPidFilter, startupTidFilter int, exportEnabled bool, startTrace TraceStarter) Model { + common.ApplyPalette(true) + syncStylesFromCommon() + spin := spinner.New() spin.Spinner = spinner.MiniDot if startTrace == nil { @@ -194,28 +273,35 @@ func newModelWithRuntimeConfig(initialPID, startupPidFilter int, exportEnabled b runtime := newRuntimeBindings() dashboard := dashboardui.NewModelWithConfig(lateBoundDashboardSource{runtime: runtime}, runtime.eventStreamSource(), 1000, keys) + dashboard.SetDarkMode(true) pidFilter := selectedPIDFilter(startupPidFilter) if initialPID > 0 { pidFilter = selectedPIDFilter(initialPID) } + tidFilter := selectedPIDFilter(startupTidFilter) + if initialPID > 0 { + tidFilter = -1 + } dashboard.SetPidFilter(pidFilter) model := Model{ screen: ScreenPIDPicker, - pidPicker: pidpicker.New(), + pidPicker: pidpicker.New().SetDarkMode(true), dashboard: dashboard, exporter: tuiexport.NewModel(), - probeModal: probes.NewModel(runtime.currentProbeManager()), + probeModal: probes.NewModel(runtime.currentProbeManager()).SetDarkMode(true), runtime: runtime, keys: keys, spin: spin, startTrace: startTrace, pidFilter: pidFilter, + tidFilter: tidFilter, exportEnabled: exportEnabled, + isDark: true, + focused: true, } if initialPID > 0 { - flags.SetPidFilter(initialPID) model.screen = ScreenDashboard model.attaching = true } @@ -227,9 +313,9 @@ func newModelWithRuntimeConfig(initialPID, startupPidFilter int, exportEnabled b func (m Model) Init() tea.Cmd { sizeCmd := initialWindowSizeCmd() if m.screen == ScreenDashboard && m.attaching { - return tea.Batch(sizeCmd, tea.WindowSize(), m.spin.Tick, m.beginTraceCmd()) + return tea.Batch(sizeCmd, tea.RequestWindowSize, tea.RequestBackgroundColor, m.spin.Tick, m.beginTraceCmd()) } - return tea.Batch(sizeCmd, tea.WindowSize(), m.pidPicker.Init()) + return tea.Batch(sizeCmd, tea.RequestWindowSize, tea.RequestBackgroundColor, m.pidPicker.Init()) } func initialWindowSizeCmd() tea.Cmd { @@ -241,30 +327,41 @@ func initialWindowSizeCmd() tea.Cmd { // Update routes messages, transitions screens, and manages tracing startup state. func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + normalizedMsg, ok := m.keyNormalizer(msg) + if !ok { + return m, nil + } + msg = normalizedMsg + switch msg := msg.(type) { case tea.WindowSizeMsg: m.width = msg.Width m.height = msg.Height return m.updateActiveModel(msg) - case tea.KeyMsg: - if key.Matches(msg, m.keys.Quit) { - m.quitting = true - m.stopTrace() - return m, tea.Quit - } - if m.exportEnabled && m.screen == ScreenDashboard && !m.attaching && m.lastErr == nil && key.Matches(msg, m.keys.Export) && !m.exporter.Visible() && !m.probeModal.Visible() && !m.dashboard.BlocksGlobalShortcuts() { - m.exporter = m.exporter.Open() - return m, nil - } - if m.screen == ScreenDashboard && !m.attaching && m.lastErr == nil && key.Matches(msg, m.keys.Probes) && !m.exporter.Visible() && !m.probeModal.Visible() && !m.dashboard.BlocksGlobalShortcuts() { - m.probeModal = probes.NewModel(m.runtime.currentProbeManager()).Open() - return m, nil + case tea.BackgroundColorMsg: + m.applyTheme(msg.IsDark()) + return m, nil + case tea.KeyboardEnhancementsMsg: + m.keyboardEnhancements = msg + m.keyboardEnhancementsKnown = true + if msg.SupportsKeyDisambiguation() { + log.Printf("tui: keyboard enhancements enabled (flags=%d, eventTypes=%t)", msg.Flags, msg.SupportsEventTypes()) } - if m.screen == ScreenDashboard && !m.attaching && m.lastErr == nil && key.Matches(msg, m.keys.SelectPID) && !m.exporter.Visible() && !m.probeModal.Visible() && !m.dashboard.BlocksGlobalShortcuts() { - return m.reselectPID() + return m, nil + case tea.FocusMsg: + m.focused = true + m.dashboard.SetFocused(true) + if m.screen == ScreenDashboard && !m.attaching { + return m, tea.Batch(m.dashboard.Init(), m.dashboard.SnapshotCmd()) } - if m.screen == ScreenDashboard && !m.attaching && m.lastErr == nil && key.Matches(msg, m.keys.SelectTID) && !m.exporter.Visible() && !m.probeModal.Visible() && !m.dashboard.BlocksGlobalShortcuts() { - return m.reselectTID() + return m, nil + case tea.BlurMsg: + m.focused = false + m.dashboard.SetFocused(false) + return m, nil + case tea.KeyPressMsg: + if next, cmd, handled := m.handleGlobalKeyPress(msg); handled { + return next, cmd } case tuiexport.RequestMsg: return m, runExportCmd(m.exportEnabled, msg.Option, m.dashboard.LatestSnapshot()) @@ -292,44 +389,199 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case TracingStartedMsg: m.attaching = false m.dashboard.SetStreamSource(m.runtime.eventStreamSource()) - return m, m.dashboard.Init() + m.dashboard.SetLiveTrie(m.runtime.liveTrie()) + width, height := common.EffectiveViewport(m.width, m.height) + next, sizeCmd := m.dashboard.Update(tea.WindowSizeMsg{Width: width, Height: height}) + m.dashboard = next.(dashboardui.Model) + return m, tea.Batch(sizeCmd, m.dashboard.Init()) case TracingErrorMsg: m.attaching = false m.lastErr = msg.Err return m, nil } + if next, cmd, handled := m.handleModalDispatch(msg); handled { + return next, cmd + } + + return m.updateActiveModel(msg) +} + +func (m *Model) keyNormalizer(msg tea.Msg) (tea.Msg, bool) { + return m.normalizeKeyEvent(msg) +} + +func (m Model) canHandleDashboardShortcut(msg tea.KeyPressMsg) bool { + return m.screen == ScreenDashboard && + !m.attaching && + m.lastErr == nil && + !m.exporter.Visible() && + !m.probeModal.Visible() && + !m.dashboard.BlocksGlobalShortcuts(msg) +} + +func (m Model) handleGlobalKeyPress(msg tea.KeyPressMsg) (tea.Model, tea.Cmd, bool) { + if key.Matches(msg, m.keys.Quit) { + m.quitting = true + m.stopTrace() + return m, tea.Quit, true + } + if m.helpOverlayVisible { + if isHelpOverlayCloseKey(msg) || isHelpOverlayOpenKey(msg) { + m.helpOverlayVisible = false + } + return m, nil, true + } + if isHelpOverlayOpenKey(msg) && !m.attaching && m.lastErr == nil { + m.helpOverlayVisible = true + return m, nil, true + } + if m.exportEnabled && m.canHandleDashboardShortcut(msg) && key.Matches(msg, m.keys.Export) { + m.exporter = m.exporter.Open() + return m, nil, true + } + if m.canHandleDashboardShortcut(msg) && key.Matches(msg, m.keys.Probes) { + m.probeModal = probes.NewModel(m.runtime.currentProbeManager()).SetDarkMode(m.isDark).Open() + return m, nil, true + } + if m.canHandleDashboardShortcut(msg) && key.Matches(msg, m.keys.SelectPID) { + next, cmd := m.reselectPID() + return next, cmd, true + } + if m.canHandleDashboardShortcut(msg) && key.Matches(msg, m.keys.SelectTID) { + next, cmd := m.reselectTID() + return next, cmd, true + } + return m, nil, false +} + +func (m Model) updateDashboardForModal(msg tea.Msg) (Model, tea.Cmd) { + if _, isKey := msg.(tea.KeyPressMsg); isKey || m.screen != ScreenDashboard { + return m, nil + } + next, cmd := m.dashboard.Update(msg) + m.dashboard = next.(dashboardui.Model) + return m, cmd +} + +func (m Model) updateProbeModal(msg tea.Msg) (tea.Model, tea.Cmd) { + m, dashboardCmd := m.updateDashboardForModal(msg) + var cmd tea.Cmd + m.probeModal, cmd = m.probeModal.Update(msg) + return m, tea.Batch(dashboardCmd, cmd) +} + +func (m Model) updateExportModal(msg tea.Msg) (tea.Model, tea.Cmd) { + m, dashboardCmd := m.updateDashboardForModal(msg) + var cmd tea.Cmd + m.exporter, cmd = m.exporter.Update(msg) + return m, tea.Batch(dashboardCmd, cmd) +} + +func (m Model) handleModalDispatch(msg tea.Msg) (tea.Model, tea.Cmd, bool) { if m.attaching { var cmd tea.Cmd m.spin, cmd = m.spin.Update(msg) - return m, cmd + return m, cmd, true } if m.probeModal.Visible() { - var dashboardCmd tea.Cmd - // Keep dashboard refresh/data flow alive while probe modal is open. - if _, isKey := msg.(tea.KeyMsg); !isKey && m.screen == ScreenDashboard { - next, cmd := m.dashboard.Update(msg) - m.dashboard = next.(dashboardui.Model) - dashboardCmd = cmd - } - var cmd tea.Cmd - m.probeModal, cmd = m.probeModal.Update(msg) - return m, tea.Batch(dashboardCmd, cmd) + next, cmd := m.updateProbeModal(msg) + return next, cmd, true } if m.exporter.Visible() { - var dashboardCmd tea.Cmd - // Keep dashboard refresh/data flow alive while export modal is open. - if _, isKey := msg.(tea.KeyMsg); !isKey && m.screen == ScreenDashboard { - next, cmd := m.dashboard.Update(msg) - m.dashboard = next.(dashboardui.Model) - dashboardCmd = cmd + next, cmd := m.updateExportModal(msg) + return next, cmd, true + } + return m, nil, false +} + +func (m *Model) normalizeKeyEvent(msg tea.Msg) (tea.Msg, bool) { + switch keyMsg := msg.(type) { + case tea.KeyPressMsg: + keyID := keyEventID(keyMsg) + if m.shouldSuppressPress(keyID) { + return nil, false } - var cmd tea.Cmd - m.exporter, cmd = m.exporter.Update(msg) - return m, tea.Batch(dashboardCmd, cmd) + m.recordKeyEvent(keyMsg, true) + return keyMsg, true + case tea.KeyReleaseMsg: + pressMsg := tea.KeyPressMsg(keyMsg) + keyID := keyEventID(pressMsg) + if m.lastKeyEventWasPress && keyID != "" && keyID == m.lastKeyEventID && time.Since(m.lastKeyEventAt) <= 500*time.Millisecond { + // Some terminals emit both press+release; avoid handling release as a duplicate. + m.lastKeyEventWasPress = false + return nil, false + } + if !releaseHasIdentity(pressMsg) { + // Ignore release messages that don't carry enough identity information. + // Some terminals emit these before a usable press event. + return nil, false + } + // Fallback: treat release as press for terminals that only emit release events. + if shouldSuppressMatchingPressAfterRelease(pressMsg) { + m.armPressSuppression(keyID) + } + m.recordKeyEvent(pressMsg, false) + return pressMsg, true + default: + return msg, true } +} - return m.updateActiveModel(msg) +func (m *Model) shouldSuppressPress(keyID string) bool { + if m.suppressPressKeyID == "" { + return false + } + if time.Now().After(m.suppressPressUntil) { + m.clearPressSuppression() + return false + } + if keyID == "" || keyID != m.suppressPressKeyID { + return false + } + m.clearPressSuppression() + return true +} + +func (m *Model) armPressSuppression(keyID string) { + if keyID == "" { + return + } + // Keep this short so fast repeated key presses still work naturally. + m.suppressPressKeyID = keyID + m.suppressPressUntil = time.Now().Add(60 * time.Millisecond) +} + +func (m *Model) clearPressSuppression() { + m.suppressPressKeyID = "" + m.suppressPressUntil = time.Time{} +} + +func (m *Model) recordKeyEvent(msg tea.KeyPressMsg, wasPress bool) { + m.lastKeyEventID = keyEventID(msg) + m.lastKeyEventAt = time.Now() + m.lastKeyEventWasPress = wasPress +} + +func keyEventID(msg tea.KeyPressMsg) string { + return fmt.Sprintf("code:%d/mod:%d/key:%q/text:%q", msg.Code, msg.Mod, msg.String(), msg.Text) +} + +func releaseHasIdentity(msg tea.KeyPressMsg) bool { + if msg.Text != "" { + return true + } + keyStr := msg.String() + if keyStr != "" && keyStr != "\x00" { + return true + } + // Some terminals emit release-only space events without text identity. + return msg.Code == tea.KeySpace +} + +func shouldSuppressMatchingPressAfterRelease(msg tea.KeyPressMsg) bool { + keyStr := msg.String() + return msg.Code == tea.KeySpace || keyStr == " " || keyStr == "space" || msg.Text == " " } func (m Model) updateActiveModel(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -350,9 +602,8 @@ func (m Model) updateActiveModel(msg tea.Msg) (tea.Model, tea.Cmd) { func (m Model) handlePidSelected(msg PidSelectedMsg) (tea.Model, tea.Cmd) { pid := selectedPIDFilter(msg.Pid) m.stopTrace() - flags.SetPidFilter(pid) - flags.SetTidFilter(-1) m.pidFilter = pid + m.tidFilter = -1 m.dashboard.SetPidFilter(pid) m.screen = ScreenDashboard m.attaching = true @@ -367,9 +618,8 @@ func (m Model) handleTidSelected(msg TidSelectedMsg) (tea.Model, tea.Cmd) { pid = msg.Pid } m.stopTrace() - flags.SetPidFilter(pid) - flags.SetTidFilter(tid) m.pidFilter = pid + m.tidFilter = tid m.dashboard.SetPidFilter(pid) m.screen = ScreenDashboard m.attaching = true @@ -383,8 +633,8 @@ func (m Model) reselectPID() (tea.Model, tea.Cmd) { m.attaching = false m.lastErr = nil m.exporter = tuiexport.NewModel() - m.probeModal = probes.NewModel(m.runtime.currentProbeManager()) - m.pidPicker = pidpicker.New() + m.probeModal = probes.NewModel(m.runtime.currentProbeManager()).SetDarkMode(m.isDark) + m.pidPicker = pidpicker.New().SetDarkMode(m.isDark) var sizeCmd tea.Cmd if m.width > 0 && m.height > 0 { @@ -404,8 +654,8 @@ func (m Model) reselectTID() (tea.Model, tea.Cmd) { m.attaching = false m.lastErr = nil m.exporter = tuiexport.NewModel() - m.probeModal = probes.NewModel(m.runtime.currentProbeManager()) - m.pidPicker = pidpicker.NewTIDWithKeys(pid, pidpicker.DefaultKeyMap()) + m.probeModal = probes.NewModel(m.runtime.currentProbeManager()).SetDarkMode(m.isDark) + m.pidPicker = pidpicker.NewTIDWithKeys(pid, pidpicker.DefaultKeyMap()).SetDarkMode(m.isDark) var sizeCmd tea.Cmd if m.width > 0 && m.height > 0 { @@ -428,6 +678,7 @@ func (m *Model) beginTraceCmd() tea.Cmd { ctx, cancel := context.WithCancel(context.Background()) m.traceStop = cancel ctx = context.WithValue(ctx, runtimeBindingsContextKey{}, m.runtime) + ctx = ContextWithTraceFilters(ctx, m.pidFilter, m.tidFilter) return startTraceCmd(m.startTrace, ctx) } @@ -454,44 +705,81 @@ func (m *Model) stopTrace() { } } +func (m *Model) applyTheme(isDark bool) { + if m.isDark == isDark { + return + } + m.isDark = isDark + common.ApplyPalette(isDark) + syncStylesFromCommon() + m.dashboard.SetDarkMode(isDark) + m.pidPicker = m.pidPicker.SetDarkMode(isDark) + m.probeModal = m.probeModal.SetDarkMode(isDark) +} + +func (m Model) windowTitle() string { + switch m.screen { + case ScreenPIDPicker: + return "ior - select process" + case ScreenDashboard: + if m.pidFilter > 0 { + return fmt.Sprintf("ior - tracing PID %d", m.pidFilter) + } + } + return "ior - I/O Riot" +} + // View renders the currently active screen and startup overlay state. -func (m Model) View() string { +func (m Model) View() tea.View { + title := m.windowTitle() if m.quitting { - return "" + return altScreenView("", title) } width, height := common.EffectiveViewport(m.width, m.height) if m.attaching { line := fmt.Sprintf("%s Attaching tracepoints...", m.spin.View()) - return placeToViewport(width, height, ScreenStyle.Render(PanelStyle.Render(line))) + return altScreenView(placeToViewport(width, height, ScreenStyle.Render(common.PanelStyle.Render(line))), title) } if m.lastErr != nil { - return placeToViewport(width, height, ScreenStyle.Render(ErrorStyle.Render(m.lastErr.Error()))) + return altScreenView(placeToViewport(width, height, ScreenStyle.Render(ErrorStyle.Render(m.lastErr.Error()))), title) + } + if m.helpOverlayVisible { + helpView := renderGlobalHelpOverlay(width, height, m.helpSections()) + return altScreenView(helpView, title) } switch m.screen { case ScreenPIDPicker: - base := m.pidPicker.View() + base := m.pidPicker.View().Content if m.exporter.Visible() { - return placeToViewport(width, height, m.exporter.View(width, height)+"\n"+base) + return altScreenView(placeToViewport(width, height, m.exporter.View(width, height)+"\n"+base), title) } - return placeToViewport(width, height, base) + return altScreenView(placeToViewport(width, height, base), title) case ScreenDashboard: - base := m.dashboard.View() + base := m.dashboard.View().Content if m.probeModal.Visible() { - return placeToViewport(width, height, m.probeModal.View(width, height)) + return altScreenView(placeToViewport(width, height, m.probeModal.View(width, height)), title) } if m.exporter.Visible() { - return placeToViewport(width, height, m.exporter.View(width, height)+"\n"+base) + return altScreenView(placeToViewport(width, height, m.exporter.View(width, height)+"\n"+base), title) } - return placeToViewport(width, height, base) + return altScreenView(placeToViewport(width, height, base), title) default: - return "" + return altScreenView("", title) } } +func isHelpOverlayOpenKey(msg tea.KeyPressMsg) bool { + return msg.String() == "H" +} + +func isHelpOverlayCloseKey(msg tea.KeyPressMsg) bool { + return msg.Code == tea.KeyEsc || msg.String() == "esc" || msg.String() == "?" +} + func runExportCmd(exportEnabled bool, option tuiexport.Option, snap *statsengine.Snapshot) tea.Cmd { return func() tea.Msg { if !exportEnabled { @@ -499,7 +787,7 @@ func runExportCmd(exportEnabled bool, option tuiexport.Option, snap *statsengine } switch option { case tuiexport.OptionCSV: - path, err := exportSnapshotCSV(snap) + path, err := coreexport.SnapshotCSV(snap) if err != nil { return tuiexport.FailedMsg{Err: err} } @@ -525,98 +813,6 @@ func (s lateBoundDashboardSource) Snapshot() *statsengine.Snapshot { return source.Snapshot() } -func exportSnapshotCSV(snap *statsengine.Snapshot) (string, error) { - filename := fmt.Sprintf("ior-snapshot-%s.csv", time.Now().Format("20060102-150405")) - f, err := os.Create(filename) - if err != nil { - return "", err - } - defer f.Close() - - w := csv.NewWriter(f) - - rows := [][]string{ - {"section", "name", "value1", "value2", "value3"}, - {"summary", "totals", fmt.Sprint(snapValue(snap, func(s *statsengine.Snapshot) uint64 { return s.TotalSyscalls })), fmt.Sprint(snapValue(snap, func(s *statsengine.Snapshot) uint64 { return s.TotalErrors })), fmt.Sprint(snapValue(snap, func(s *statsengine.Snapshot) uint64 { return s.TotalBytes }))}, - {"summary", "rates_per_sec", fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.SyscallRatePerSec })), fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.ReadBytesPerSec })), fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.WriteBytesPerSec }))}, - {"summary", "latency_gap_mean_ns", fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.LatencyMeanNs })), fmt.Sprintf("%.2f", snapValueF(snap, func(s *statsengine.Snapshot) float64 { return s.GapMeanNs })), ""}, - {"summary", "trend", trendSummary(snap, func(s *statsengine.Snapshot) statsengine.Trend { return s.LatencyTrend }), trendSummary(snap, func(s *statsengine.Snapshot) statsengine.Trend { return s.GapTrend }), trendSummary(snap, func(s *statsengine.Snapshot) statsengine.Trend { return s.ThroughputTrend })}, - } - for _, row := range rows { - if err := w.Write(row); err != nil { - return "", err - } - } - - if snap != nil { - for _, s := range snap.Syscalls() { - if err := w.Write([]string{"syscall", s.Name, fmt.Sprint(s.Count), fmt.Sprintf("%.2f", s.RatePerSec), fmt.Sprint(s.Bytes)}); err != nil { - return "", err - } - if err := w.Write([]string{"syscall_latency_ns", s.Name, fmt.Sprintf("%.2f", s.LatencyMeanNs), fmt.Sprint(s.LatencyMinNs), fmt.Sprint(s.LatencyMaxNs)}); err != nil { - return "", err - } - if err := w.Write([]string{"syscall_percentiles_ns", s.Name, fmt.Sprint(s.LatencyP50Ns), fmt.Sprint(s.LatencyP95Ns), fmt.Sprint(s.LatencyP99Ns)}); err != nil { - return "", err - } - } - for _, r := range snap.Files() { - if err := w.Write([]string{"file", r.Path, fmt.Sprint(r.Accesses), fmt.Sprint(r.BytesRead), fmt.Sprint(r.BytesWritten)}); err != nil { - return "", err - } - if err := w.Write([]string{"file_latency_ns", r.Path, fmt.Sprintf("%.2f", r.AvgLatencyNs), fmt.Sprint(r.MaxLatencyNs), ""}); err != nil { - return "", err - } - } - for _, p := range snap.Processes() { - if err := w.Write([]string{"process", fmt.Sprint(p.PID), fmt.Sprint(p.Syscalls), fmt.Sprintf("%.2f", p.RatePerSec), fmt.Sprint(p.Bytes)}); err != nil { - return "", err - } - if err := w.Write([]string{"process_latency_ns", fmt.Sprint(p.PID), fmt.Sprintf("%.2f", p.AvgLatencyNs), "", ""}); err != nil { - return "", err - } - } - for _, b := range snap.LatencyHistogram.Buckets() { - if err := w.Write([]string{"latency_hist", b.Label, fmt.Sprint(b.Count), fmt.Sprint(b.LowerNs), fmt.Sprint(b.UpperNs)}); err != nil { - return "", err - } - } - for _, b := range snap.GapHistogram.Buckets() { - if err := w.Write([]string{"gap_hist", b.Label, fmt.Sprint(b.Count), fmt.Sprint(b.LowerNs), fmt.Sprint(b.UpperNs)}); err != nil { - return "", err - } - } - } - - w.Flush() - if err := w.Error(); err != nil { - return "", err - } - return filename, nil -} - -func snapValue(snap *statsengine.Snapshot, get func(*statsengine.Snapshot) uint64) uint64 { - if snap == nil { - return 0 - } - return get(snap) -} - -func snapValueF(snap *statsengine.Snapshot, get func(*statsengine.Snapshot) float64) float64 { - if snap == nil { - return 0 - } - return get(snap) -} - -func trendSummary(snap *statsengine.Snapshot, get func(*statsengine.Snapshot) statsengine.Trend) string { - if snap == nil { - return "stable:0.00" - } - trend := get(snap) - return fmt.Sprintf("%s:%.2f", trend.Direction, trend.DeltaPercent) -} - func renderHelpOverlay(width, height int, groups [][]key.Binding) string { if width <= 0 { width = 80 @@ -637,23 +833,138 @@ func renderHelpOverlay(width, height int, groups [][]key.Binding) string { lines = append(lines, "", "Esc/? close") boxWidth := width - 6 - if boxWidth > 110 { - boxWidth = 110 - } if boxWidth < 72 { boxWidth = 72 } - box := PanelStyle.Copy(). + box := common.PanelStyle.Copy(). Width(boxWidth). Render(strings.Join(lines, "\n")) return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, box) } +type helpSection struct { + title string + lines []string +} + +func (m Model) helpSections() []helpSection { + globalLines := []string{ + "H help esc close help q quit", + "tab/shift+tab cycle tabs 1..7 jump tab", + "p pid picker t tid picker o probes r refresh", + } + if help := m.keys.Export.Help(); help.Key != "" || help.Desc != "" { + globalLines = append(globalLines, "e snapshot export") + } + + return []helpSection{ + { + title: "Global", + lines: globalLines, + }, + { + title: "Flame Tab", + lines: []string{ + "arrows/hjkl navigate pgup top pgdn root", + "enter zoom u/backspace/esc undo", + "/ filter n/N match next/prev", + "space/p pause o order b metric r reset baseline", + }, + }, + { + title: "Stream Tab", + lines: []string{ + "space pause/live f add filter esc undo filter", + "enter apply filter / or ? search n/N next/prev", + "j/k/up/down scroll pgup/pgdn page g/G top/tail", + "left/right or h/l switch columns", + "c clear x export X export-as E open last", + }, + }, + { + title: "PID/TID Picker", + lines: []string{ + "enter select r refresh esc back", + }, + }, + } +} + +func renderGlobalHelpOverlay(width, height int, sections []helpSection) string { + if width <= 0 { + width = 80 + } + if height <= 0 { + height = 24 + } + + boxWidth := width - 4 + if boxWidth > 100 { + boxWidth = 100 + } + if boxWidth < 74 { + boxWidth = 74 + } + contentWidth := boxWidth - 4 + if contentWidth < 20 { + contentWidth = boxWidth + } + + lines := make([]string, 0, 24) + lines = append(lines, "Help") + for _, section := range sections { + lines = append(lines, "") + lines = append(lines, section.title) + for _, line := range section.lines { + lines = append(lines, " "+truncateHelpLine(line, contentWidth-2)) + } + } + lines = append(lines, "", "Esc close") + + maxLines := height - 4 + if maxLines < 6 { + maxLines = 6 + } + if len(lines) > maxLines { + lines = lines[:maxLines-1] + lines = append(lines, truncateHelpLine("... (resize for full help)", contentWidth)) + } + + box := common.PanelStyle.Copy().Width(boxWidth).Render(strings.Join(lines, "\n")) + return lipgloss.Place(width, height, lipgloss.Center, lipgloss.Center, box) +} + +func truncateHelpLine(s string, width int) string { + if width <= 0 { + return "" + } + if lipgloss.Width(s) <= width { + return s + } + if width == 1 { + return "…" + } + r := []rune(s) + if len(r) >= width { + return string(r[:width-1]) + "…" + } + return s +} + func placeToViewport(width, height int, content string) string { if width <= 0 || height <= 0 { return content } return lipgloss.Place(width, height, lipgloss.Left, lipgloss.Top, content) } + +func altScreenView(content, title string) tea.View { + view := tea.NewView(content) + view.AltScreen = true + view.ReportFocus = true + view.WindowTitle = title + view.KeyboardEnhancements.ReportEventTypes = true + return view +} diff --git a/internal/tui/tui_test.go b/internal/tui/tui_test.go index 890dfc4..ad529fc 100644 --- a/internal/tui/tui_test.go +++ b/internal/tui/tui_test.go @@ -3,22 +3,27 @@ package tui import ( "context" "errors" - "ior/internal/probemanager" - "ior/internal/statsengine" - "ior/internal/tui/eventstream" - tuiexport "ior/internal/tui/export" - "ior/internal/tui/messages" "os" "path/filepath" + "regexp" "strings" "testing" "time" + coreflamegraph "ior/internal/flamegraph" + "ior/internal/probemanager" + "ior/internal/statsengine" + dashboardui "ior/internal/tui/dashboard" + "ior/internal/tui/eventstream" + tuiexport "ior/internal/tui/export" + "ior/internal/tui/messages" + "ior/internal/flags" "ior/internal/tui/probes" - "github.com/charmbracelet/bubbles/key" - tea "github.com/charmbracelet/bubbletea" + "charm.land/bubbles/v2/key" + tea "charm.land/bubbletea/v2" + "charm.land/lipgloss/v2" ) type fakeProbeManager struct { @@ -46,11 +51,11 @@ func TestPidSelectedTransitionsToDashboardAndSetsPIDFilter(t *testing.T) { if !updated.attaching { t.Fatalf("expected attaching state to be true") } - if got := flags.Get().PidFilter; got != 42 { - t.Fatalf("expected pid filter 42, got %d", got) + if updated.pidFilter != 42 { + t.Fatalf("expected pid filter 42, got %d", updated.pidFilter) } - if got := flags.Get().TidFilter; got != -1 { - t.Fatalf("expected tid filter reset to -1, got %d", got) + if updated.tidFilter != -1 { + t.Fatalf("expected tid filter reset to -1, got %d", updated.tidFilter) } } @@ -75,10 +80,9 @@ func TestPidSelectedAllSetsNoFilter(t *testing.T) { next, _ := m.Update(PidSelectedMsg{Pid: 0}) updated := next.(Model) - if got := flags.Get().PidFilter; got != -1 { - t.Fatalf("expected pid filter -1 for all pids, got %d", got) + if updated.pidFilter != -1 { + t.Fatalf("expected pid filter -1 for all pids, got %d", updated.pidFilter) } - _ = updated } func TestTracingErrorMessageClearsAttachingState(t *testing.T) { @@ -98,14 +102,14 @@ func TestTracingErrorMessageClearsAttachingState(t *testing.T) { func TestViewShowsAttachingAndErrorStates(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) m.attaching = true - attachingView := m.View() + attachingView := m.View().Content if !strings.Contains(attachingView, "Attaching tracepoints...") { t.Fatalf("expected attaching view, got %q", attachingView) } m.attaching = false m.lastErr = errors.New("failed") - errorView := m.View() + errorView := m.View().Content if !strings.Contains(errorView, "failed") { t.Fatalf("expected error view, got %q", errorView) } @@ -114,7 +118,7 @@ func TestViewShowsAttachingAndErrorStates(t *testing.T) { func TestQuitKeySetsQuittingState(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'q'}}) + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'q'}[0], Text: string([]rune{'q'})}) if cmd == nil { t.Fatalf("expected quit cmd") } @@ -132,9 +136,9 @@ func TestQuitKeyMatchesSingleBindingWithoutPanic(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) m.keys.Quit = key.NewBinding(key.WithKeys("x"), key.WithHelp("x", "quit")) - _, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'z'}}) + _, _ = m.Update(tea.KeyPressMsg{Code: []rune{'z'}[0], Text: string([]rune{'z'})}) - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'x'}}) + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'x'}[0], Text: string([]rune{'x'})}) if cmd == nil { t.Fatalf("expected quit cmd") } @@ -171,7 +175,7 @@ func TestQuitInvokesTraceStop(t *testing.T) { close(done) } - _, quitCmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'q'}}) + _, quitCmd := m.Update(tea.KeyPressMsg{Code: []rune{'q'}[0], Text: string([]rune{'q'})}) if quitCmd == nil { t.Fatalf("expected quit command") } @@ -218,6 +222,20 @@ func TestDashboardRefreshPicksLateBoundSource(t *testing.T) { } } +func TestRuntimeBindingsStoreAndExposeLiveTrie(t *testing.T) { + runtime := newRuntimeBindings() + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path"}, "count") + runtime.SetLiveTrie(trie) + if got := runtime.liveTrie(); got != trie { + t.Fatalf("expected live trie to be stored and returned") + } + + runtime.SetLiveTrie(nil) + if got := runtime.liveTrie(); got != nil { + t.Fatalf("expected live trie to clear on nil assignment") + } +} + func TestProbeToggledMsgResetsDashboardStatsSource(t *testing.T) { src := &fakeResettableDashboardSource{snap: &statsengine.Snapshot{TotalSyscalls: 99}} @@ -253,16 +271,79 @@ func TestTracingStartedRebindsEventStreamSource(t *testing.T) { next, _ = m.Update(tea.WindowSizeMsg{Width: 120, Height: 30}) m = next.(Model) - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'7'}}) + next, _ = m.Update(tea.KeyPressMsg{Code: []rune{'7'}[0], Text: string([]rune{'7'})}) m = next.(Model) next, _ = m.Update(messages.StatsTickMsg{}) m = next.(Model) - if !strings.Contains(m.View(), "read") { + if !strings.Contains(m.View().Content, "read") { t.Fatalf("expected stream tab to render rebound stream event") } } +func TestTracingStartedUsesCurrentViewportForFlameNavigationWithoutResize(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + coreflamegraph.SeedTestFlameData(trie) + + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = true + m.width = 120 + m.height = 30 + m.runtime.SetLiveTrie(trie) + + next, _ := m.Update(TracingStartedMsg{}) + m = next.(Model) + + if strings.Contains(m.View().Content, "sel:none") { + t.Fatalf("expected flamegraph selection to be available immediately after tracing start") + } + + selectedLabel := func(view string) string { + re := regexp.MustCompile(`sel:[0-9]+/[0-9]+ ([^|]+) \|`) + match := re.FindStringSubmatch(view) + if len(match) != 2 { + return "" + } + return strings.TrimSpace(match[1]) + } + + moved := false + before := selectedLabel(m.View().Content) + for i := 0; i < 12 && !moved; i++ { + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyRight}) + m = next.(Model) + after := selectedLabel(m.View().Content) + if after != "" && after != before { + moved = true + break + } + } + if !moved { + t.Fatalf("expected arrow navigation to move selection without requiring resize, view=%q", m.View().Content) + } +} + +func TestTracingStartedAppliesViewportWhenModelSizeIsUnset(t *testing.T) { + trie := coreflamegraph.NewLiveTrie([]string{"comm", "path", "tracepoint"}, "count") + coreflamegraph.SeedTestFlameData(trie) + + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = true + m.runtime.SetLiveTrie(trie) + m.width = 0 + m.height = 0 + + next, _ := m.Update(TracingStartedMsg{}) + m = next.(Model) + + view := m.View().Content + if strings.Contains(view, "sel:none") { + t.Fatalf("expected tracing start to apply an effective viewport even when width/height are unset") + } +} + func TestExportKeyOpensModalOnDashboard(t *testing.T) { flags.SetTUIExportEnable(true) t.Cleanup(func() { flags.SetTUIExportEnable(true) }) @@ -271,13 +352,172 @@ func TestExportKeyOpensModalOnDashboard(t *testing.T) { m.screen = ScreenDashboard m.attaching = false - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'e'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'e'}[0], Text: string([]rune{'e'})}) updated := next.(Model) if !updated.exporter.Visible() { t.Fatalf("expected export modal to open on e key") } } +func TestFlamePauseKeyDoesNotTriggerPIDReselect(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + updated := next.(Model) + if updated.screen != ScreenDashboard { + t.Fatalf("expected flame space key to keep dashboard screen, got %v", updated.screen) + } + if !strings.Contains(updated.View().Content, "[PAUSED]") { + t.Fatalf("expected flame space key to toggle flame paused state") + } +} + +func TestFlameSpaceKeyReleaseFallbackTogglesPause(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyReleaseMsg{Code: tea.KeySpace, Text: " "}) + updated := next.(Model) + if !strings.Contains(updated.View().Content, "[PAUSED]") { + t.Fatalf("expected key release fallback to toggle flame paused state") + } +} + +func TestFlameSpacePressReleaseDoesNotDoubleTogglePause(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + updated := next.(Model) + if !strings.Contains(updated.View().Content, "[PAUSED]") { + t.Fatalf("expected key press to pause flame") + } + + next, _ = updated.Update(tea.KeyReleaseMsg{Code: tea.KeySpace, Text: " "}) + updated = next.(Model) + if !strings.Contains(updated.View().Content, "[PAUSED]") { + t.Fatalf("expected key release after key press to be ignored as duplicate") + } +} + +func TestFlameSpaceReleasePressDoesNotDoubleTogglePause(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyReleaseMsg{Code: tea.KeySpace, Text: " "}) + updated := next.(Model) + if !strings.Contains(updated.View().Content, "[PAUSED]") { + t.Fatalf("expected key release fallback to pause flame") + } + + next, _ = updated.Update(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + updated = next.(Model) + if !strings.Contains(updated.View().Content, "[PAUSED]") { + t.Fatalf("expected immediate matching key press after release fallback to be ignored") + } +} + +func TestNormalizeKeyEventReleaseFallbackSuppressesImmediatePressOnly(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + + normalized, ok := m.normalizeKeyEvent(tea.KeyReleaseMsg{Code: tea.KeySpace, Text: " "}) + if !ok { + t.Fatalf("expected release fallback to be handled") + } + if _, isPress := normalized.(tea.KeyPressMsg); !isPress { + t.Fatalf("expected release fallback to normalize to KeyPressMsg, got %T", normalized) + } + + if normalized, ok = m.normalizeKeyEvent(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}); ok { + t.Fatalf("expected immediate matching press to be suppressed, got %T", normalized) + } + + time.Sleep(70 * time.Millisecond) + if normalized, ok = m.normalizeKeyEvent(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}); !ok { + t.Fatalf("expected press to be accepted after suppression window") + } + if _, isPress := normalized.(tea.KeyPressMsg); !isPress { + t.Fatalf("expected accepted message to be KeyPressMsg, got %T", normalized) + } +} + +func TestNormalizeKeyEventIgnoresUnidentifiedRelease(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + + if normalized, ok := m.normalizeKeyEvent(tea.KeyReleaseMsg{}); ok { + t.Fatalf("expected unidentified release to be ignored, got %T", normalized) + } + + normalized, ok := m.normalizeKeyEvent(tea.KeyPressMsg{Code: tea.KeySpace, Text: " "}) + if !ok { + t.Fatalf("expected subsequent real key press to be handled") + } + if _, isPress := normalized.(tea.KeyPressMsg); !isPress { + t.Fatalf("expected normalized message to be KeyPressMsg, got %T", normalized) + } +} + +func TestNormalizeKeyEventReleaseFallbackDoesNotSuppressArrowPress(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + + normalized, ok := m.normalizeKeyEvent(tea.KeyReleaseMsg{Code: tea.KeyRight}) + if !ok { + t.Fatalf("expected right release fallback to be handled") + } + if _, isPress := normalized.(tea.KeyPressMsg); !isPress { + t.Fatalf("expected release fallback to normalize to KeyPressMsg, got %T", normalized) + } + + normalized, ok = m.normalizeKeyEvent(tea.KeyPressMsg{Code: tea.KeyRight}) + if !ok { + t.Fatalf("expected right key press to be accepted after release fallback") + } + if _, isPress := normalized.(tea.KeyPressMsg); !isPress { + t.Fatalf("expected normalized message to be KeyPressMsg, got %T", normalized) + } +} + +func TestFlameOrderKeyDoesNotOpenProbeModal(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'o'}[0], Text: string([]rune{'o'})}) + updated := next.(Model) + if updated.probeModal.Visible() { + t.Fatalf("expected flame order key to stay in flame tab, not open probes modal") + } +} + +func TestFlameMetricKeyDoesNotOpenProbeModal(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 120 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'b'}[0], Text: string([]rune{'b'})}) + updated := next.(Model) + if updated.probeModal.Visible() { + t.Fatalf("expected flame metric key to stay in flame tab, not open probes modal") + } +} + func TestSelectPIDKeyReturnsToFreshPickerAndStopsTrace(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) m.screen = ScreenDashboard @@ -287,7 +527,10 @@ func TestSelectPIDKeyReturnsToFreshPickerAndStopsTrace(t *testing.T) { stopped := false m.traceStop = func() { stopped = true } - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'p'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'2'}[0], Text: string([]rune{'2'})}) + m = next.(Model) + + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'p'}[0], Text: string([]rune{'p'})}) updated := next.(Model) if !stopped { @@ -319,7 +562,10 @@ func TestSelectTIDKeyReturnsToPickerWhenPIDFilterIsAll(t *testing.T) { stopped := false m.traceStop = func() { stopped = true } - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'t'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'2'}[0], Text: string([]rune{'2'})}) + m = next.(Model) + + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'t'}[0], Text: string([]rune{'t'})}) updated := next.(Model) if !stopped { t.Fatalf("expected tracing stop before tid reselect") @@ -344,7 +590,10 @@ func TestSelectTIDKeyReturnsToPickerWhenSinglePIDSelected(t *testing.T) { stopped := false m.traceStop = func() { stopped = true } - next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'t'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'2'}[0], Text: string([]rune{'2'})}) + m = next.(Model) + + next, cmd := m.Update(tea.KeyPressMsg{Code: []rune{'t'}[0], Text: string([]rune{'t'})}) updated := next.(Model) if !stopped { t.Fatalf("expected tracing stop before tid reselect") @@ -373,11 +622,11 @@ func TestTidSelectedTransitionsToDashboardAndSetsTIDFilter(t *testing.T) { if !updated.attaching { t.Fatalf("expected attaching state to be true") } - if got := flags.Get().TidFilter; got != 3333 { - t.Fatalf("expected tid filter 3333, got %d", got) + if updated.tidFilter != 3333 { + t.Fatalf("expected tid filter 3333, got %d", updated.tidFilter) } - if got := flags.Get().PidFilter; got != 2222 { - t.Fatalf("expected pid filter to remain 2222, got %d", got) + if updated.pidFilter != 2222 { + t.Fatalf("expected pid filter to remain 2222, got %d", updated.pidFilter) } } @@ -394,11 +643,11 @@ func TestTidSelectedFromAllPIDModeSetsOwningPID(t *testing.T) { if updated.screen != ScreenDashboard { t.Fatalf("expected dashboard screen, got %v", updated.screen) } - if got := flags.Get().PidFilter; got != 4444 { - t.Fatalf("expected pid filter switched to owning pid 4444, got %d", got) + if updated.pidFilter != 4444 { + t.Fatalf("expected pid filter switched to owning pid 4444, got %d", updated.pidFilter) } - if got := flags.Get().TidFilter; got != 5555 { - t.Fatalf("expected tid filter 5555, got %d", got) + if updated.tidFilter != 5555 { + t.Fatalf("expected tid filter 5555, got %d", updated.tidFilter) } } @@ -410,7 +659,7 @@ func TestExportKeyIgnoredWhenExportDisabled(t *testing.T) { m.screen = ScreenDashboard m.attaching = false - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'e'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'e'}[0], Text: string([]rune{'e'})}) updated := next.(Model) if updated.exporter.Visible() { t.Fatalf("expected export modal to remain closed when export is disabled") @@ -427,23 +676,23 @@ func TestStreamFilterModalConsumesEKeyInsteadOfOpeningExport(t *testing.T) { m.width = 120 m.height = 30 - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'7'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'7'}[0], Text: string([]rune{'7'})}) m = next.(Model) - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'f'}}) + next, _ = m.Update(tea.KeyPressMsg{Code: []rune{'f'}[0], Text: string([]rune{'f'})}) m = next.(Model) - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyEnter}) + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyEnter}) m = next.(Model) for _, r := range []rune{'o', 'p', 'e'} { - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}}) + next, _ = m.Update(tea.KeyPressMsg{Code: []rune{r}[0], Text: string([]rune{r})}) m = next.(Model) } - next, _ = m.Update(tea.KeyMsg{Type: tea.KeyEsc}) + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) m = next.(Model) if m.exporter.Visible() { t.Fatalf("expected export modal to remain closed while stream filter modal handles typing") } - if !strings.Contains(m.View(), "syscall~ope") { + if !strings.Contains(m.View().Content, "syscall~ope") { t.Fatalf("expected typed syscall filter to be applied") } } @@ -475,7 +724,7 @@ func TestRunExportCmdCSVWritesFile(t *testing.T) { func TestHelpKeyDoesNotToggleOverlay(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'?'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'?'}[0], Text: string([]rune{'?'})}) updated := next.(Model) if updated.screen != ScreenPIDPicker { t.Fatalf("expected ? to have no effect, got screen %v", updated.screen) @@ -488,17 +737,60 @@ func TestViewShowsDashboardWithoutHelpOverlay(t *testing.T) { m.width = 100 m.height = 30 - out := m.View() + out := m.View().Content if !strings.Contains(out, "press H for help") { t.Fatalf("expected bottom help hint in dashboard") } } +func TestHelpOverlayOpensWithUppercaseHAndClosesWithEsc(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.width = 100 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'H'}[0], Text: string([]rune{'H'})}) + m = next.(Model) + if !m.helpOverlayVisible { + t.Fatalf("expected help overlay to become visible after H") + } + view := m.View().Content + if !strings.Contains(view, "Help") || !strings.Contains(view, "Global") || !strings.Contains(view, "Esc close") { + t.Fatalf("expected global help overlay content, got %q", view) + } + + next, _ = m.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) + m = next.(Model) + if m.helpOverlayVisible { + t.Fatalf("expected esc to close help overlay") + } + if !strings.Contains(m.View().Content, "press H for help") { + t.Fatalf("expected dashboard help hint after closing overlay") + } +} + +func TestHelpOverlayCanOpenFromPIDPicker(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenPIDPicker + m.width = 100 + m.height = 30 + + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'H'}[0], Text: string([]rune{'H'})}) + m = next.(Model) + if !m.helpOverlayVisible { + t.Fatalf("expected help overlay to open on pid picker screen") + } + if !strings.Contains(m.View().Content, "PID/TID Picker") { + t.Fatalf("expected picker shortcuts in help overlay") + } +} + func TestQuestionMarkDoesNotBlockUnderlyingActions(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) m.screen = ScreenDashboard - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'e'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'e'}[0], Text: string([]rune{'e'})}) updated := next.(Model) if !updated.exporter.Visible() { t.Fatalf("expected export modal to open; ? overlay is removed") @@ -512,19 +804,19 @@ func TestQuestionMarkDoesNotBreakExportModalInput(t *testing.T) { m := NewModel(-1, func(context.Context) error { return nil }) m.screen = ScreenDashboard - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'e'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'e'}[0], Text: string([]rune{'e'})}) updated := next.(Model) if !updated.exporter.Visible() { t.Fatalf("expected export modal to open") } - next, _ = updated.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'?'}}) + next, _ = updated.Update(tea.KeyPressMsg{Code: []rune{'?'}[0], Text: string([]rune{'?'})}) updated = next.(Model) if !updated.exporter.Visible() { t.Fatalf("expected export modal to remain open after ? key") } - next, _ = updated.Update(tea.KeyMsg{Type: tea.KeyEsc}) + next, _ = updated.Update(tea.KeyPressMsg{Code: tea.KeyEsc}) updated = next.(Model) if updated.exporter.Visible() { t.Fatalf("expected esc to close export modal") @@ -540,7 +832,7 @@ func TestStatusBarHidesExportBindingWhenExportDisabled(t *testing.T) { m.width = 100 m.height = 30 - out := m.View() + out := m.View().Content if strings.Contains(out, "e snapshot export") { t.Fatalf("did not expect export shortcut in status bar when export is disabled") } @@ -568,23 +860,23 @@ func TestDashboardTabKeysChangeActiveView(t *testing.T) { m.width = 120 m.height = 30 - out := m.View() - if !strings.Contains(out, "Overview: waiting for stats") { - t.Fatalf("expected overview waiting view by default") + out := m.View().Content + if !strings.Contains(out, "Flame: waiting for data") { + t.Fatalf("expected flame waiting view by default") } - next, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'2'}}) + next, _ := m.Update(tea.KeyPressMsg{Code: []rune{'2'}[0], Text: string([]rune{'2'})}) updated := next.(Model) - out = updated.View() - if !strings.Contains(out, "Syscalls: waiting for stats") { - t.Fatalf("expected syscalls waiting view after pressing 2") + out = updated.View().Content + if !strings.Contains(out, "Overview: waiting for stats") { + t.Fatalf("expected overview waiting view after pressing 2") } - next, _ = updated.Update(tea.KeyMsg{Type: tea.KeyTab}) + next, _ = updated.Update(tea.KeyPressMsg{Code: tea.KeyTab}) updated = next.(Model) - out = updated.View() - if !strings.Contains(out, "Files: waiting for stats") { - t.Fatalf("expected files waiting view after tab") + out = updated.View().Content + if !strings.Contains(out, "Syscalls: waiting for stats") { + t.Fatalf("expected syscalls waiting view after tab") } } @@ -598,11 +890,123 @@ func TestProbeModalViewDoesNotStackDashboardContent(t *testing.T) { m.height = 30 m.probeModal = m.probeModal.Open() - out := m.View() + out := m.View().Content if !strings.Contains(out, "Probes (") { t.Fatalf("expected probe modal content, got %q", out) } - if strings.Contains(out, "Overview: waiting for stats") { + if strings.Contains(out, "Flame: waiting for data") || strings.Contains(out, "Overview: waiting for stats") { t.Fatalf("expected probe modal to render as standalone view, got stacked dashboard content") } } + +func TestBlurPausesDashboardRefreshAndFocusResumesIt(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + m.screen = ScreenDashboard + m.attaching = false + m.dashboard = dashboardui.NewModelWithConfig(nil, nil, 1, m.keys) + m.focused = true + + next, _ := m.Update(tea.BlurMsg{}) + m = next.(Model) + if m.focused { + t.Fatalf("expected focused=false after blur") + } + + tickMsg := m.dashboard.Init()() + next, tickCmd := m.Update(tickMsg) + m = next.(Model) + if tickCmd != nil { + t.Fatalf("expected no follow-up tick command while blurred") + } + + next, focusCmd := m.Update(tea.FocusMsg{}) + m = next.(Model) + if !m.focused { + t.Fatalf("expected focused=true after focus") + } + if focusCmd == nil { + t.Fatalf("expected focus to resume refresh with a command batch") + } + if _, ok := focusCmd().(tea.BatchMsg); !ok { + t.Fatalf("expected focus command to be a batch") + } +} + +func TestKeyboardEnhancementsMsgHandledGracefully(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + + next, cmd := m.Update(tea.KeyboardEnhancementsMsg{Flags: 1}) + if cmd != nil { + t.Fatalf("expected no command when handling keyboard enhancements msg") + } + + updated := next.(Model) + if !updated.keyboardEnhancementsKnown { + t.Fatalf("expected keyboard enhancements to be marked as known") + } + if !updated.keyboardEnhancements.SupportsKeyDisambiguation() { + t.Fatalf("expected non-zero flags to report key disambiguation support") + } +} + +func TestViewSetsDynamicWindowTitle(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + + m.screen = ScreenPIDPicker + view := m.View() + if view.WindowTitle != "ior - select process" { + t.Fatalf("unexpected picker window title: %q", view.WindowTitle) + } + + m.screen = ScreenDashboard + m.pidFilter = 1234 + view = m.View() + if view.WindowTitle != "ior - tracing PID 1234" { + t.Fatalf("unexpected tracing window title: %q", view.WindowTitle) + } + + m.pidFilter = -1 + view = m.View() + if view.WindowTitle != "ior - I/O Riot" { + t.Fatalf("unexpected default window title: %q", view.WindowTitle) + } +} + +func TestRenderHelpOverlayUsesWideViewport(t *testing.T) { + groups := [][]key.Binding{{key.NewBinding(key.WithKeys("?"), key.WithHelp("?", "help"))}} + out := renderHelpOverlay(160, 40, groups) + + maxWidth := 0 + for _, line := range strings.Split(out, "\n") { + if w := lipgloss.Width(line); w > maxWidth { + maxWidth = w + } + } + + if maxWidth <= 110 { + t.Fatalf("expected wide help overlay to exceed previous 110-col cap, got %d", maxWidth) + } +} + +func TestGlobalHelpOverlayFitsStandardTerminal(t *testing.T) { + m := NewModel(-1, func(context.Context) error { return nil }) + out := renderGlobalHelpOverlay(80, 24, m.helpSections()) + + lines := strings.Split(out, "\n") + if len(lines) > 24 { + t.Fatalf("expected help overlay to fit within 24 lines, got %d", len(lines)) + } + + maxWidth := 0 + for _, line := range lines { + if w := lipgloss.Width(line); w > maxWidth { + maxWidth = w + } + } + if maxWidth > 80 { + t.Fatalf("expected help overlay width <= 80, got %d", maxWidth) + } + if !strings.Contains(out, "Flame Tab") || !strings.Contains(out, "Stream Tab") { + t.Fatalf("expected overlay to include tab-specific help sections") + } +} diff --git a/internal/types/doc.go b/internal/types/doc.go new file mode 100644 index 0000000..79f156e --- /dev/null +++ b/internal/types/doc.go @@ -0,0 +1,2 @@ +// Package types provides generated Go structs mirroring BPF event payload layouts. +package types diff --git a/internal/types/generated_types.go b/internal/types/generated_types.go index 9cdcc5b..1f4b9d8 100644 --- a/internal/types/generated_types.go +++ b/internal/types/generated_types.go @@ -22,7 +22,7 @@ var traceId2Name = map[TraceId]string{ func (s TraceId) String() string { str, ok := traceId2String[s] if !ok { - panic(fmt.Sprintf("no string representation for trace ID %d found", s)) + return fmt.Sprintf("unknown_trace_id_%d", s) } return str } @@ -30,7 +30,7 @@ func (s TraceId) String() string { func (s TraceId) Name() string { str, ok := traceId2Name[s] if !ok { - panic(fmt.Sprintf("no name for trace ID %d found", s)) + return fmt.Sprintf("unknown_trace_id_%d", s) } return str } diff --git a/internal/types/types_test.go b/internal/types/types_test.go index 6abebdb..8ba7367 100644 --- a/internal/types/types_test.go +++ b/internal/types/types_test.go @@ -1,6 +1,7 @@ package types import ( + "fmt" "syscall" "testing" ) @@ -156,6 +157,14 @@ func TestEqualsDifferentValues(t *testing.T) { t.Log("Equals returns false for same type but different values") } +func TestTraceIdUnknownFallback(t *testing.T) { + unknown := TraceId(0xFFFFFFFF) + want := fmt.Sprintf("unknown_trace_id_%d", unknown) + + assertEquals(t, want, unknown.String()) + assertEquals(t, want, unknown.Name()) +} + func assertEquals[T comparable](t *testing.T, a, b T) { if a != b { t.Errorf("Expected %v, got %v", a, b) |
