diff options
| -rw-r--r-- | internal/eventloop_comm.go | 32 | ||||
| -rw-r--r-- | internal/eventloop_commresolver_test.go | 48 |
2 files changed, 72 insertions, 8 deletions
diff --git a/internal/eventloop_comm.go b/internal/eventloop_comm.go index 4d49ef2..6e9ed0b 100644 --- a/internal/eventloop_comm.go +++ b/internal/eventloop_comm.go @@ -1,6 +1,7 @@ package internal import ( + "context" "errors" "fmt" "os" @@ -8,8 +9,13 @@ import ( "strconv" "sync" "syscall" + "time" ) +// resolveCommTimeout caps each procfs read so a frozen cgroup cannot stall +// a lookup worker indefinitely and block clean shutdown. +const resolveCommTimeout = time.Second + type commResolver struct { comms map[uint32]string @@ -19,7 +25,7 @@ type commResolver struct { lookupQueue chan uint32 lookupWorkers int - resolveFn func(uint32) (string, error) + resolveFn func(context.Context, uint32) (string, error) warningFn func(string) startWorkersOnce sync.Once workersWG sync.WaitGroup @@ -46,7 +52,16 @@ func (r *commResolver) ensureLookupConfig() { r.lookupQueue = make(chan uint32, defaultCommLookupQueueSize) } if r.resolveFn == nil { - r.resolveFn = resolveCommFromProcWithError + // Default resolver wraps resolveCommFromProcWithError, which does not + // accept a context itself, so we honour cancellation by returning early + // when the context deadline is already exceeded before the call returns. + r.resolveFn = func(ctx context.Context, tid uint32) (string, error) { + comm, err := resolveCommFromProcWithError(tid) + if ctx.Err() != nil { + return "", ctx.Err() + } + return comm, err + } } } @@ -69,7 +84,12 @@ func (r *commResolver) startLookupWorkers() { func (r *commResolver) lookupWorker() { defer r.workersWG.Done() for tid := range r.lookupQueue { - comm, err := r.resolveFn(tid) + // Each procfs read gets an independent timeout so that a frozen cgroup + // or a slow /proc entry cannot block a worker goroutine indefinitely + // and stall shutdown (which waits on workersWG). + ctx, cancel := context.WithTimeout(context.Background(), resolveCommTimeout) + comm, err := r.resolveFn(ctx, tid) + cancel() r.mu.Lock() delete(r.pending, tid) if comm != "" { @@ -95,7 +115,11 @@ func (r *commResolver) seedTrackedPidComm(pidFilter int) { continue } seen[tid] = struct{}{} - comm, err := r.resolveFn(tid) + // Use a short timeout here too; seeding happens at startup and a stall + // would delay the entire event loop initialisation. + ctx, cancel := context.WithTimeout(context.Background(), resolveCommTimeout) + comm, err := r.resolveFn(ctx, tid) + cancel() if comm != "" { r.setCached(tid, comm) continue diff --git a/internal/eventloop_commresolver_test.go b/internal/eventloop_commresolver_test.go index 351db70..04019d4 100644 --- a/internal/eventloop_commresolver_test.go +++ b/internal/eventloop_commresolver_test.go @@ -1,6 +1,7 @@ package internal import ( + "context" "errors" "fmt" "strings" @@ -26,7 +27,7 @@ func TestCommResolverQueueLookupRespectsWorkerLimit(t *testing.T) { defer resolver.shutdown() resolver.lookupWorkers = workers resolver.lookupQueue = make(chan uint32, lookups) - resolver.resolveFn = func(tid uint32) (string, error) { + resolver.resolveFn = func(_ context.Context, tid uint32) (string, error) { current := atomic.AddInt32(&running, 1) setMaxInt32(&maxRunning, current) started <- struct{}{} @@ -86,7 +87,7 @@ func TestCommResolverQueueLookupQueueFullClearsPending(t *testing.T) { defer resolver.shutdown() resolver.lookupWorkers = 1 resolver.lookupQueue = make(chan uint32, 1) - resolver.resolveFn = func(tid uint32) (string, error) { + resolver.resolveFn = func(_ context.Context, tid uint32) (string, error) { select { case started <- struct{}{}: default: @@ -141,7 +142,7 @@ func TestCommResolverShutdownStopsWorkersAndPreventsNewLookups(t *testing.T) { resolver := newCommResolver(nil) resolver.lookupWorkers = 1 resolver.lookupQueue = make(chan uint32, 1) - resolver.resolveFn = func(tid uint32) (string, error) { + resolver.resolveFn = func(_ context.Context, tid uint32) (string, error) { started <- struct{}{} <-release return fmt.Sprintf("comm-%d", tid), nil @@ -193,7 +194,7 @@ func TestCommResolverLookupWarnsOnUnexpectedResolveError(t *testing.T) { resolver.lookupWorkers = 1 resolver.lookupQueue = make(chan uint32, 1) resolver.warningFn = func(message string) { warnings <- message } - resolver.resolveFn = func(uint32) (string, error) { + resolver.resolveFn = func(context.Context, uint32) (string, error) { return "", errors.New("boom") } @@ -226,6 +227,45 @@ func TestResolveCommFromProcWithErrorIgnoresMissingProcess(t *testing.T) { } } +// TestCommResolverLookupWorkerRespectsTimeout verifies that a resolveFn that +// blocks longer than resolveCommTimeout is interrupted and the pending entry +// is cleared so shutdown is not stalled. +func TestCommResolverLookupWorkerRespectsTimeout(t *testing.T) { + const tid uint32 = 401 + + // blockUntilCtxDone blocks until the context passed by the worker expires. + blockUntilCtxDone := make(chan struct{}) + resolver := newCommResolver(nil) + defer resolver.shutdown() + resolver.lookupWorkers = 1 + resolver.lookupQueue = make(chan uint32, 1) + resolver.resolveFn = func(ctx context.Context, _ uint32) (string, error) { + close(blockUntilCtxDone) + <-ctx.Done() + return "", ctx.Err() + } + + resolver.queueLookup(tid) + + // Wait until the resolver fn has started and confirmed it is blocking. + select { + case <-blockUntilCtxDone: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for resolver fn to start") + } + + // The pending entry must be cleared once the context times out and the + // worker loop continues to the next iteration. + waitForCondition(t, resolveCommTimeout+2*time.Second, + "expected pending entry to be cleared after context timeout", + func() bool { return pendingCount(resolver) == 0 }, + ) + + if _, ok := resolver.cached(tid); ok { + t.Fatalf("did not expect tid %d to be cached after a timed-out resolve", tid) + } +} + func hasPending(r *commResolver, tid uint32) bool { r.mu.RLock() defer r.mu.RUnlock() |
