From f2dd8d8a515c1a2a220836231ad1a671a5e9b73d Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Mon, 13 Apr 2026 08:09:33 +0300 Subject: ask: serialize concurrent CLI with repo lock and stale PID recovery Add advisory lock under .git/hexai-ask.lock around Taskwarrior execution, with metadata (PID and process basename) and Linux /proc comm checks to remove orphan lock files when the recorded holder is gone or not ask. Extract internal/filelock for shared flock helpers; stats uses it too. Made-with: Cursor --- internal/askcli/runlock.go | 147 ++++++++++++++++++++++++++++ internal/askcli/runlock_stale_linux.go | 33 +++++++ internal/askcli/runlock_stale_linux_test.go | 19 ++++ internal/askcli/runlock_stale_other.go | 23 +++++ internal/askcli/runlock_test.go | 46 +++++++++ internal/askcli/taskexec.go | 14 ++- internal/askcli/taskexec_test.go | 34 +++++-- 7 files changed, 306 insertions(+), 10 deletions(-) create mode 100644 internal/askcli/runlock.go create mode 100644 internal/askcli/runlock_stale_linux.go create mode 100644 internal/askcli/runlock_stale_linux_test.go create mode 100644 internal/askcli/runlock_stale_other.go create mode 100644 internal/askcli/runlock_test.go (limited to 'internal/askcli') diff --git a/internal/askcli/runlock.go b/internal/askcli/runlock.go new file mode 100644 index 0000000..d0d7be3 --- /dev/null +++ b/internal/askcli/runlock.go @@ -0,0 +1,147 @@ +package askcli + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "codeberg.org/snonux/hexai/internal/filelock" +) + +const askRepoLockFile = "hexai-ask.lock" + +var errAskLockReopen = errors.New("ask lock: reopen after stale file removal") + +func lockProcessLabel() string { + if exe, err := os.Executable(); err == nil { + if b := filepath.Base(exe); b != "" && b != "." { + return b + } + } + if b := filepath.Base(os.Args[0]); b != "" { + return b + } + return "ask" +} + +func readLockHolderPID(f *os.File) int { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return 0 + } + var buf [64]byte + n, err := f.Read(buf[:]) + if err != nil && !errors.Is(err, io.EOF) { + return 0 + } + line := strings.TrimSpace(string(buf[:n])) + if line == "" { + return 0 + } + end := strings.IndexAny(line, "\n\r \t") + if end >= 0 { + line = line[:end] + } + pid, err := strconv.Atoi(line) + if err != nil || pid <= 0 { + return 0 + } + return pid +} + +func writeLockMetadata(f *os.File, pid int, comm string) error { + if _, err := f.Seek(0, io.SeekStart); err != nil { + return err + } + if err := f.Truncate(0); err != nil { + return err + } + _, err := fmt.Fprintf(f, "%d\n%s\n", pid, comm) + if err != nil { + return err + } + return f.Sync() +} + +// waitOrAcquireAskLockFD tries to take an exclusive lock on f, or blocks until ctx ends. +// On success it writes lock metadata and returns an unlock function (which closes f). +// errAskLockReopen means the caller should open the lock path again after stale removal. +func waitOrAcquireAskLockFD( + ctx context.Context, + f *os.File, + lockPath string, + comm string, + retryTimer *time.Timer, +) (func() error, error) { + for { + err := filelock.TryExclusive(f) + if err == nil { + if werr := writeLockMetadata(f, os.Getpid(), comm); werr != nil { + _ = filelock.UnlockExclusive(f) + _ = f.Close() + return nil, fmt.Errorf("ask lock: write metadata: %w", werr) + } + return func() error { + uErr := filelock.UnlockExclusive(f) + cErr := f.Close() + return errors.Join(uErr, cErr) + }, nil + } + if !errors.Is(err, filelock.ErrWouldBlock) { + _ = f.Close() + return nil, fmt.Errorf("ask lock: %w", err) + } + + pid := readLockHolderPID(f) + if pid > 0 && lockHolderIsStale(pid, comm) { + _ = f.Close() + if rerr := os.Remove(lockPath); rerr != nil && !errors.Is(rerr, os.ErrNotExist) { + return nil, fmt.Errorf("ask lock: remove stale %s: %w", lockPath, rerr) + } + return nil, errAskLockReopen + } + + retryTimer.Reset(5 * time.Millisecond) + select { + case <-ctx.Done(): + _ = f.Close() + return nil, ctx.Err() + case <-retryTimer.C: + } + } +} + +// acquireAskRepoLock serializes ask CLI access for a git working copy. It uses an +// advisory lock under .git and records holder PID plus process name for stale detection. +func acquireAskRepoLock(ctx context.Context, gitRoot string) (func() error, error) { + lockPath := filepath.Join(gitRoot, ".git", askRepoLockFile) + if err := os.MkdirAll(filepath.Dir(lockPath), 0o755); err != nil { + return nil, fmt.Errorf("ask lock: mkdir: %w", err) + } + + comm := lockProcessLabel() + retryTimer := time.NewTimer(5 * time.Millisecond) + defer retryTimer.Stop() + + for removalAttempts := 0; removalAttempts < 16; removalAttempts++ { + f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600) + if err != nil { + return nil, fmt.Errorf("ask lock: open %s: %w", lockPath, err) + } + unlock, err := waitOrAcquireAskLockFD(ctx, f, lockPath, comm, retryTimer) + if err == nil { + return unlock, nil + } + if errors.Is(err, errAskLockReopen) { + continue + } + return nil, err + } + + return nil, fmt.Errorf("ask lock: could not acquire %s after stale recovery attempts", lockPath) +} diff --git a/internal/askcli/runlock_stale_linux.go b/internal/askcli/runlock_stale_linux.go new file mode 100644 index 0000000..183dfb2 --- /dev/null +++ b/internal/askcli/runlock_stale_linux.go @@ -0,0 +1,33 @@ +//go:build linux + +package askcli + +import ( + "os" + "strconv" + "strings" + + "golang.org/x/sys/unix" +) + +func lockHolderIsStale(pid int, expectedComm string) bool { + if pid <= 0 { + return false + } + if err := unix.Kill(pid, 0); err != nil { + return true + } + data, err := os.ReadFile("/proc/" + strconv.Itoa(pid) + "/comm") + if err != nil { + return false + } + holder := strings.TrimSpace(string(data)) + if holder == "" { + return false + } + want := expectedComm + if len(want) > 15 { + want = want[:15] + } + return holder != want +} diff --git a/internal/askcli/runlock_stale_linux_test.go b/internal/askcli/runlock_stale_linux_test.go new file mode 100644 index 0000000..7c581fb --- /dev/null +++ b/internal/askcli/runlock_stale_linux_test.go @@ -0,0 +1,19 @@ +//go:build linux + +package askcli + +import ( + "os/exec" + "testing" +) + +func TestLockHolderIsStale_NonAskLiveProcess(t *testing.T) { + cmd := exec.Command("sleep", "60") + if err := cmd.Start(); err != nil { + t.Skip("sleep not available:", err) + } + defer func() { _ = cmd.Process.Kill() }() + if !lockHolderIsStale(cmd.Process.Pid, "ask") { + t.Fatal("expected sleep process to be stale when expecting ask") + } +} diff --git a/internal/askcli/runlock_stale_other.go b/internal/askcli/runlock_stale_other.go new file mode 100644 index 0000000..21174c4 --- /dev/null +++ b/internal/askcli/runlock_stale_other.go @@ -0,0 +1,23 @@ +//go:build !linux + +package askcli + +import ( + "os" + "syscall" +) + +func lockHolderIsStale(pid int, expectedComm string) bool { + if pid <= 0 { + return false + } + _ = expectedComm + proc, err := os.FindProcess(pid) + if err != nil { + return true + } + if err := proc.Signal(syscall.Signal(0)); err != nil { + return true + } + return false +} diff --git a/internal/askcli/runlock_test.go b/internal/askcli/runlock_test.go new file mode 100644 index 0000000..f56f214 --- /dev/null +++ b/internal/askcli/runlock_test.go @@ -0,0 +1,46 @@ +package askcli + +import ( + "context" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestAcquireAskRepoLock_SerializesConcurrentHolders(t *testing.T) { + tmp := t.TempDir() + if err := os.MkdirAll(filepath.Join(tmp, ".git"), 0o755); err != nil { + t.Fatal(err) + } + var maxHeld int32 + var cur int32 + var wg sync.WaitGroup + for i := 0; i < 6; i++ { + wg.Add(1) + go func() { + defer wg.Done() + unlock, err := acquireAskRepoLock(context.Background(), tmp) + if err != nil { + t.Errorf("lock: %v", err) + return + } + defer func() { _ = unlock() }() + n := atomic.AddInt32(&cur, 1) + for { + old := atomic.LoadInt32(&maxHeld) + if n <= old || atomic.CompareAndSwapInt32(&maxHeld, old, n) { + break + } + } + time.Sleep(25 * time.Millisecond) + atomic.AddInt32(&cur, -1) + }() + } + wg.Wait() + if got := atomic.LoadInt32(&maxHeld); got != 1 { + t.Fatalf("max concurrent lock holders = %d, want 1", got) + } +} diff --git a/internal/askcli/taskexec.go b/internal/askcli/taskexec.go index 0b68e3b..4eed461 100644 --- a/internal/askcli/taskexec.go +++ b/internal/askcli/taskexec.go @@ -77,17 +77,25 @@ func (e Executor) Run(ctx context.Context, args []string, stdin io.Reader, stdou if err != nil { return 1, fmt.Errorf("%s: task binary lookup failed: %w", executor.label(), err) } + gitRoot, gitErr := executor.detectRepoRoot(ctx) repoRoot := "" if _, ok := taskProjectFromContext(ctx); !ok { - repoRoot, err = executor.detectRepoRoot(ctx) - if err != nil { - return 1, fmt.Errorf("%s: must be run inside a git repository: %w", executor.label(), err) + if gitErr != nil { + return 1, fmt.Errorf("%s: must be run inside a git repository: %w", executor.label(), gitErr) } + repoRoot = gitRoot } taskArgs, err := executor.taskArgs(ctx, repoRoot, args) if err != nil { return 1, fmt.Errorf("%s: %w", executor.label(), err) } + if gitErr == nil { + unlockAsk, lerr := acquireAskRepoLock(ctx, gitRoot) + if lerr != nil { + return 1, fmt.Errorf("%s: %w", executor.label(), lerr) + } + defer func() { _ = unlockAsk() }() + } if err := executor.runCommand(ctx, taskPath, taskArgs, stdin, stdout, stderr); err != nil { return exitCodeFor(err), nil } diff --git a/internal/askcli/taskexec_test.go b/internal/askcli/taskexec_test.go index 2236866..5e95f1c 100644 --- a/internal/askcli/taskexec_test.go +++ b/internal/askcli/taskexec_test.go @@ -5,12 +5,23 @@ import ( "context" "errors" "io" + "os" "os/exec" + "path/filepath" "reflect" "strings" "testing" ) +func fakeHexaiRepoDir(t *testing.T) string { + t.Helper() + base := filepath.Join(t.TempDir(), "hexai") + if err := os.MkdirAll(filepath.Join(base, ".git"), 0o755); err != nil { + t.Fatalf("mkdir .git: %v", err) + } + return base +} + func TestExecutorTaskArgs(t *testing.T) { exec_ := NewExecutor("ask") args, err := exec_.taskArgs(context.Background(), "/tmp/work/hexai", []string{"list", "limit:1"}) @@ -75,12 +86,13 @@ func TestExecutorTaskArgs_AddNoAgentScope(t *testing.T) { } func TestExecutorRun_InjectsProjectFilterAndAgentTag(t *testing.T) { + repo := fakeHexaiRepoDir(t) var gotName string var gotArgs []string exec_ := Executor{ commandName: "ask", findBinary: func() (string, error) { return "/usr/bin/task", nil }, - detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil }, + detectRepoRoot: func(context.Context) (string, error) { return repo, nil }, runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, stdout, stderr io.Writer) error { gotName = name gotArgs = append([]string(nil), args...) @@ -105,11 +117,12 @@ func TestExecutorRun_InjectsProjectFilterAndAgentTag(t *testing.T) { } func TestExecutorRun_InjectsProjectFilterAndNoAgentTag(t *testing.T) { + repo := fakeHexaiRepoDir(t) var gotArgs []string exec_ := Executor{ commandName: "ask", findBinary: func() (string, error) { return "/usr/bin/task", nil }, - detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil }, + detectRepoRoot: func(context.Context) (string, error) { return repo, nil }, runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, stdout, stderr io.Writer) error { gotArgs = append([]string(nil), args...) return nil @@ -130,14 +143,16 @@ func TestExecutorRun_InjectsProjectFilterAndNoAgentTag(t *testing.T) { } } -func TestExecutorRun_ProjectOverrideSkipsRepoDetection(t *testing.T) { +func TestExecutorRun_ProjectOverrideStillLocksUsingGitRoot(t *testing.T) { + repo := fakeHexaiRepoDir(t) + var detectCalls int var gotArgs []string exec_ := Executor{ commandName: "ask", findBinary: func() (string, error) { return "/usr/bin/task", nil }, detectRepoRoot: func(context.Context) (string, error) { - t.Fatal("detectRepoRoot should not be called when project override is set") - return "", nil + detectCalls++ + return repo, nil }, runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, stdout, stderr io.Writer) error { gotArgs = append([]string(nil), args...) @@ -153,6 +168,9 @@ func TestExecutorRun_ProjectOverrideSkipsRepoDetection(t *testing.T) { if exitCode != 0 { t.Fatalf("exitCode = %d, want 0", exitCode) } + if detectCalls != 1 { + t.Fatalf("detectRepoRoot calls = %d, want 1", detectCalls) + } wantArgs := []string{"rc.verbose=nothing", "rc.confirmation=off", "project:alpha", "+agent", "list"} if !reflect.DeepEqual(gotArgs, wantArgs) { t.Fatalf("task args = %v, want %v", gotArgs, wantArgs) @@ -180,10 +198,11 @@ func TestExecutorRun_OutsideGitRepo_IsActionable(t *testing.T) { } func TestExecutorRun_PreservesTaskwarriorExitCode(t *testing.T) { + repo := fakeHexaiRepoDir(t) exec_ := Executor{ commandName: "ask", findBinary: func() (string, error) { return "/usr/bin/task", nil }, - detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil }, + detectRepoRoot: func(context.Context) (string, error) { return repo, nil }, runCommand: func(context.Context, string, []string, io.Reader, io.Writer, io.Writer) error { return exec.Command("sh", "-c", "exit 7").Run() }, @@ -199,12 +218,13 @@ func TestExecutorRun_PreservesTaskwarriorExitCode(t *testing.T) { } func TestExecutorRun_PreservesStdoutAndStderr(t *testing.T) { + repo := fakeHexaiRepoDir(t) var stdout bytes.Buffer var stderr bytes.Buffer exec_ := Executor{ commandName: "ask", findBinary: func() (string, error) { return "/usr/bin/task", nil }, - detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil }, + detectRepoRoot: func(context.Context) (string, error) { return repo, nil }, runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, out, errOut io.Writer) error { _, _ = io.WriteString(out, "task stdout") _, _ = io.WriteString(errOut, "task stderr") -- cgit v1.2.3