summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/eventloop_comm.go32
-rw-r--r--internal/eventloop_commresolver_test.go48
2 files changed, 72 insertions, 8 deletions
diff --git a/internal/eventloop_comm.go b/internal/eventloop_comm.go
index 4d49ef2..6e9ed0b 100644
--- a/internal/eventloop_comm.go
+++ b/internal/eventloop_comm.go
@@ -1,6 +1,7 @@
package internal
import (
+ "context"
"errors"
"fmt"
"os"
@@ -8,8 +9,13 @@ import (
"strconv"
"sync"
"syscall"
+ "time"
)
+// resolveCommTimeout caps each procfs read so a frozen cgroup cannot stall
+// a lookup worker indefinitely and block clean shutdown.
+const resolveCommTimeout = time.Second
+
type commResolver struct {
comms map[uint32]string
@@ -19,7 +25,7 @@ type commResolver struct {
lookupQueue chan uint32
lookupWorkers int
- resolveFn func(uint32) (string, error)
+ resolveFn func(context.Context, uint32) (string, error)
warningFn func(string)
startWorkersOnce sync.Once
workersWG sync.WaitGroup
@@ -46,7 +52,16 @@ func (r *commResolver) ensureLookupConfig() {
r.lookupQueue = make(chan uint32, defaultCommLookupQueueSize)
}
if r.resolveFn == nil {
- r.resolveFn = resolveCommFromProcWithError
+ // Default resolver wraps resolveCommFromProcWithError, which does not
+ // accept a context itself, so we honour cancellation by returning early
+ // when the context deadline is already exceeded before the call returns.
+ r.resolveFn = func(ctx context.Context, tid uint32) (string, error) {
+ comm, err := resolveCommFromProcWithError(tid)
+ if ctx.Err() != nil {
+ return "", ctx.Err()
+ }
+ return comm, err
+ }
}
}
@@ -69,7 +84,12 @@ func (r *commResolver) startLookupWorkers() {
func (r *commResolver) lookupWorker() {
defer r.workersWG.Done()
for tid := range r.lookupQueue {
- comm, err := r.resolveFn(tid)
+ // Each procfs read gets an independent timeout so that a frozen cgroup
+ // or a slow /proc entry cannot block a worker goroutine indefinitely
+ // and stall shutdown (which waits on workersWG).
+ ctx, cancel := context.WithTimeout(context.Background(), resolveCommTimeout)
+ comm, err := r.resolveFn(ctx, tid)
+ cancel()
r.mu.Lock()
delete(r.pending, tid)
if comm != "" {
@@ -95,7 +115,11 @@ func (r *commResolver) seedTrackedPidComm(pidFilter int) {
continue
}
seen[tid] = struct{}{}
- comm, err := r.resolveFn(tid)
+ // Use a short timeout here too; seeding happens at startup and a stall
+ // would delay the entire event loop initialisation.
+ ctx, cancel := context.WithTimeout(context.Background(), resolveCommTimeout)
+ comm, err := r.resolveFn(ctx, tid)
+ cancel()
if comm != "" {
r.setCached(tid, comm)
continue
diff --git a/internal/eventloop_commresolver_test.go b/internal/eventloop_commresolver_test.go
index 351db70..04019d4 100644
--- a/internal/eventloop_commresolver_test.go
+++ b/internal/eventloop_commresolver_test.go
@@ -1,6 +1,7 @@
package internal
import (
+ "context"
"errors"
"fmt"
"strings"
@@ -26,7 +27,7 @@ func TestCommResolverQueueLookupRespectsWorkerLimit(t *testing.T) {
defer resolver.shutdown()
resolver.lookupWorkers = workers
resolver.lookupQueue = make(chan uint32, lookups)
- resolver.resolveFn = func(tid uint32) (string, error) {
+ resolver.resolveFn = func(_ context.Context, tid uint32) (string, error) {
current := atomic.AddInt32(&running, 1)
setMaxInt32(&maxRunning, current)
started <- struct{}{}
@@ -86,7 +87,7 @@ func TestCommResolverQueueLookupQueueFullClearsPending(t *testing.T) {
defer resolver.shutdown()
resolver.lookupWorkers = 1
resolver.lookupQueue = make(chan uint32, 1)
- resolver.resolveFn = func(tid uint32) (string, error) {
+ resolver.resolveFn = func(_ context.Context, tid uint32) (string, error) {
select {
case started <- struct{}{}:
default:
@@ -141,7 +142,7 @@ func TestCommResolverShutdownStopsWorkersAndPreventsNewLookups(t *testing.T) {
resolver := newCommResolver(nil)
resolver.lookupWorkers = 1
resolver.lookupQueue = make(chan uint32, 1)
- resolver.resolveFn = func(tid uint32) (string, error) {
+ resolver.resolveFn = func(_ context.Context, tid uint32) (string, error) {
started <- struct{}{}
<-release
return fmt.Sprintf("comm-%d", tid), nil
@@ -193,7 +194,7 @@ func TestCommResolverLookupWarnsOnUnexpectedResolveError(t *testing.T) {
resolver.lookupWorkers = 1
resolver.lookupQueue = make(chan uint32, 1)
resolver.warningFn = func(message string) { warnings <- message }
- resolver.resolveFn = func(uint32) (string, error) {
+ resolver.resolveFn = func(context.Context, uint32) (string, error) {
return "", errors.New("boom")
}
@@ -226,6 +227,45 @@ func TestResolveCommFromProcWithErrorIgnoresMissingProcess(t *testing.T) {
}
}
+// TestCommResolverLookupWorkerRespectsTimeout verifies that a resolveFn that
+// blocks longer than resolveCommTimeout is interrupted and the pending entry
+// is cleared so shutdown is not stalled.
+func TestCommResolverLookupWorkerRespectsTimeout(t *testing.T) {
+ const tid uint32 = 401
+
+ // blockUntilCtxDone blocks until the context passed by the worker expires.
+ blockUntilCtxDone := make(chan struct{})
+ resolver := newCommResolver(nil)
+ defer resolver.shutdown()
+ resolver.lookupWorkers = 1
+ resolver.lookupQueue = make(chan uint32, 1)
+ resolver.resolveFn = func(ctx context.Context, _ uint32) (string, error) {
+ close(blockUntilCtxDone)
+ <-ctx.Done()
+ return "", ctx.Err()
+ }
+
+ resolver.queueLookup(tid)
+
+ // Wait until the resolver fn has started and confirmed it is blocking.
+ select {
+ case <-blockUntilCtxDone:
+ case <-time.After(2 * time.Second):
+ t.Fatal("timed out waiting for resolver fn to start")
+ }
+
+ // The pending entry must be cleared once the context times out and the
+ // worker loop continues to the next iteration.
+ waitForCondition(t, resolveCommTimeout+2*time.Second,
+ "expected pending entry to be cleared after context timeout",
+ func() bool { return pendingCount(resolver) == 0 },
+ )
+
+ if _, ok := resolver.cached(tid); ok {
+ t.Fatalf("did not expect tid %d to be cached after a timed-out resolve", tid)
+ }
+}
+
func hasPending(r *commResolver, tid uint32) bool {
r.mu.RLock()
defer r.mu.RUnlock()