summaryrefslogtreecommitdiff
path: root/internal/askcli
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-04-13 08:09:33 +0300
committerPaul Buetow <paul@buetow.org>2026-04-13 08:10:16 +0300
commitf2dd8d8a515c1a2a220836231ad1a671a5e9b73d (patch)
tree5b19585afb01b60d03d24a96b57bc7b986ea4cc0 /internal/askcli
parent56002ff942de1bfb0ce467ec37a692b8c4ca01e9 (diff)
ask: serialize concurrent CLI with repo lock and stale PID recovery
Add advisory lock under .git/hexai-ask.lock around Taskwarrior execution, with metadata (PID and process basename) and Linux /proc comm checks to remove orphan lock files when the recorded holder is gone or not ask. Extract internal/filelock for shared flock helpers; stats uses it too. Made-with: Cursor
Diffstat (limited to 'internal/askcli')
-rw-r--r--internal/askcli/runlock.go147
-rw-r--r--internal/askcli/runlock_stale_linux.go33
-rw-r--r--internal/askcli/runlock_stale_linux_test.go19
-rw-r--r--internal/askcli/runlock_stale_other.go23
-rw-r--r--internal/askcli/runlock_test.go46
-rw-r--r--internal/askcli/taskexec.go14
-rw-r--r--internal/askcli/taskexec_test.go34
7 files changed, 306 insertions, 10 deletions
diff --git a/internal/askcli/runlock.go b/internal/askcli/runlock.go
new file mode 100644
index 0000000..d0d7be3
--- /dev/null
+++ b/internal/askcli/runlock.go
@@ -0,0 +1,147 @@
+package askcli
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/filelock"
+)
+
+const askRepoLockFile = "hexai-ask.lock"
+
+var errAskLockReopen = errors.New("ask lock: reopen after stale file removal")
+
+func lockProcessLabel() string {
+ if exe, err := os.Executable(); err == nil {
+ if b := filepath.Base(exe); b != "" && b != "." {
+ return b
+ }
+ }
+ if b := filepath.Base(os.Args[0]); b != "" {
+ return b
+ }
+ return "ask"
+}
+
+func readLockHolderPID(f *os.File) int {
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ return 0
+ }
+ var buf [64]byte
+ n, err := f.Read(buf[:])
+ if err != nil && !errors.Is(err, io.EOF) {
+ return 0
+ }
+ line := strings.TrimSpace(string(buf[:n]))
+ if line == "" {
+ return 0
+ }
+ end := strings.IndexAny(line, "\n\r \t")
+ if end >= 0 {
+ line = line[:end]
+ }
+ pid, err := strconv.Atoi(line)
+ if err != nil || pid <= 0 {
+ return 0
+ }
+ return pid
+}
+
+func writeLockMetadata(f *os.File, pid int, comm string) error {
+ if _, err := f.Seek(0, io.SeekStart); err != nil {
+ return err
+ }
+ if err := f.Truncate(0); err != nil {
+ return err
+ }
+ _, err := fmt.Fprintf(f, "%d\n%s\n", pid, comm)
+ if err != nil {
+ return err
+ }
+ return f.Sync()
+}
+
+// waitOrAcquireAskLockFD tries to take an exclusive lock on f, or blocks until ctx ends.
+// On success it writes lock metadata and returns an unlock function (which closes f).
+// errAskLockReopen means the caller should open the lock path again after stale removal.
+func waitOrAcquireAskLockFD(
+ ctx context.Context,
+ f *os.File,
+ lockPath string,
+ comm string,
+ retryTimer *time.Timer,
+) (func() error, error) {
+ for {
+ err := filelock.TryExclusive(f)
+ if err == nil {
+ if werr := writeLockMetadata(f, os.Getpid(), comm); werr != nil {
+ _ = filelock.UnlockExclusive(f)
+ _ = f.Close()
+ return nil, fmt.Errorf("ask lock: write metadata: %w", werr)
+ }
+ return func() error {
+ uErr := filelock.UnlockExclusive(f)
+ cErr := f.Close()
+ return errors.Join(uErr, cErr)
+ }, nil
+ }
+ if !errors.Is(err, filelock.ErrWouldBlock) {
+ _ = f.Close()
+ return nil, fmt.Errorf("ask lock: %w", err)
+ }
+
+ pid := readLockHolderPID(f)
+ if pid > 0 && lockHolderIsStale(pid, comm) {
+ _ = f.Close()
+ if rerr := os.Remove(lockPath); rerr != nil && !errors.Is(rerr, os.ErrNotExist) {
+ return nil, fmt.Errorf("ask lock: remove stale %s: %w", lockPath, rerr)
+ }
+ return nil, errAskLockReopen
+ }
+
+ retryTimer.Reset(5 * time.Millisecond)
+ select {
+ case <-ctx.Done():
+ _ = f.Close()
+ return nil, ctx.Err()
+ case <-retryTimer.C:
+ }
+ }
+}
+
+// acquireAskRepoLock serializes ask CLI access for a git working copy. It uses an
+// advisory lock under .git and records holder PID plus process name for stale detection.
+func acquireAskRepoLock(ctx context.Context, gitRoot string) (func() error, error) {
+ lockPath := filepath.Join(gitRoot, ".git", askRepoLockFile)
+ if err := os.MkdirAll(filepath.Dir(lockPath), 0o755); err != nil {
+ return nil, fmt.Errorf("ask lock: mkdir: %w", err)
+ }
+
+ comm := lockProcessLabel()
+ retryTimer := time.NewTimer(5 * time.Millisecond)
+ defer retryTimer.Stop()
+
+ for removalAttempts := 0; removalAttempts < 16; removalAttempts++ {
+ f, err := os.OpenFile(lockPath, os.O_CREATE|os.O_RDWR, 0o600)
+ if err != nil {
+ return nil, fmt.Errorf("ask lock: open %s: %w", lockPath, err)
+ }
+ unlock, err := waitOrAcquireAskLockFD(ctx, f, lockPath, comm, retryTimer)
+ if err == nil {
+ return unlock, nil
+ }
+ if errors.Is(err, errAskLockReopen) {
+ continue
+ }
+ return nil, err
+ }
+
+ return nil, fmt.Errorf("ask lock: could not acquire %s after stale recovery attempts", lockPath)
+}
diff --git a/internal/askcli/runlock_stale_linux.go b/internal/askcli/runlock_stale_linux.go
new file mode 100644
index 0000000..183dfb2
--- /dev/null
+++ b/internal/askcli/runlock_stale_linux.go
@@ -0,0 +1,33 @@
+//go:build linux
+
+package askcli
+
+import (
+ "os"
+ "strconv"
+ "strings"
+
+ "golang.org/x/sys/unix"
+)
+
+func lockHolderIsStale(pid int, expectedComm string) bool {
+ if pid <= 0 {
+ return false
+ }
+ if err := unix.Kill(pid, 0); err != nil {
+ return true
+ }
+ data, err := os.ReadFile("/proc/" + strconv.Itoa(pid) + "/comm")
+ if err != nil {
+ return false
+ }
+ holder := strings.TrimSpace(string(data))
+ if holder == "" {
+ return false
+ }
+ want := expectedComm
+ if len(want) > 15 {
+ want = want[:15]
+ }
+ return holder != want
+}
diff --git a/internal/askcli/runlock_stale_linux_test.go b/internal/askcli/runlock_stale_linux_test.go
new file mode 100644
index 0000000..7c581fb
--- /dev/null
+++ b/internal/askcli/runlock_stale_linux_test.go
@@ -0,0 +1,19 @@
+//go:build linux
+
+package askcli
+
+import (
+ "os/exec"
+ "testing"
+)
+
+func TestLockHolderIsStale_NonAskLiveProcess(t *testing.T) {
+ cmd := exec.Command("sleep", "60")
+ if err := cmd.Start(); err != nil {
+ t.Skip("sleep not available:", err)
+ }
+ defer func() { _ = cmd.Process.Kill() }()
+ if !lockHolderIsStale(cmd.Process.Pid, "ask") {
+ t.Fatal("expected sleep process to be stale when expecting ask")
+ }
+}
diff --git a/internal/askcli/runlock_stale_other.go b/internal/askcli/runlock_stale_other.go
new file mode 100644
index 0000000..21174c4
--- /dev/null
+++ b/internal/askcli/runlock_stale_other.go
@@ -0,0 +1,23 @@
+//go:build !linux
+
+package askcli
+
+import (
+ "os"
+ "syscall"
+)
+
+func lockHolderIsStale(pid int, expectedComm string) bool {
+ if pid <= 0 {
+ return false
+ }
+ _ = expectedComm
+ proc, err := os.FindProcess(pid)
+ if err != nil {
+ return true
+ }
+ if err := proc.Signal(syscall.Signal(0)); err != nil {
+ return true
+ }
+ return false
+}
diff --git a/internal/askcli/runlock_test.go b/internal/askcli/runlock_test.go
new file mode 100644
index 0000000..f56f214
--- /dev/null
+++ b/internal/askcli/runlock_test.go
@@ -0,0 +1,46 @@
+package askcli
+
+import (
+ "context"
+ "os"
+ "path/filepath"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestAcquireAskRepoLock_SerializesConcurrentHolders(t *testing.T) {
+ tmp := t.TempDir()
+ if err := os.MkdirAll(filepath.Join(tmp, ".git"), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ var maxHeld int32
+ var cur int32
+ var wg sync.WaitGroup
+ for i := 0; i < 6; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ unlock, err := acquireAskRepoLock(context.Background(), tmp)
+ if err != nil {
+ t.Errorf("lock: %v", err)
+ return
+ }
+ defer func() { _ = unlock() }()
+ n := atomic.AddInt32(&cur, 1)
+ for {
+ old := atomic.LoadInt32(&maxHeld)
+ if n <= old || atomic.CompareAndSwapInt32(&maxHeld, old, n) {
+ break
+ }
+ }
+ time.Sleep(25 * time.Millisecond)
+ atomic.AddInt32(&cur, -1)
+ }()
+ }
+ wg.Wait()
+ if got := atomic.LoadInt32(&maxHeld); got != 1 {
+ t.Fatalf("max concurrent lock holders = %d, want 1", got)
+ }
+}
diff --git a/internal/askcli/taskexec.go b/internal/askcli/taskexec.go
index 0b68e3b..4eed461 100644
--- a/internal/askcli/taskexec.go
+++ b/internal/askcli/taskexec.go
@@ -77,17 +77,25 @@ func (e Executor) Run(ctx context.Context, args []string, stdin io.Reader, stdou
if err != nil {
return 1, fmt.Errorf("%s: task binary lookup failed: %w", executor.label(), err)
}
+ gitRoot, gitErr := executor.detectRepoRoot(ctx)
repoRoot := ""
if _, ok := taskProjectFromContext(ctx); !ok {
- repoRoot, err = executor.detectRepoRoot(ctx)
- if err != nil {
- return 1, fmt.Errorf("%s: must be run inside a git repository: %w", executor.label(), err)
+ if gitErr != nil {
+ return 1, fmt.Errorf("%s: must be run inside a git repository: %w", executor.label(), gitErr)
}
+ repoRoot = gitRoot
}
taskArgs, err := executor.taskArgs(ctx, repoRoot, args)
if err != nil {
return 1, fmt.Errorf("%s: %w", executor.label(), err)
}
+ if gitErr == nil {
+ unlockAsk, lerr := acquireAskRepoLock(ctx, gitRoot)
+ if lerr != nil {
+ return 1, fmt.Errorf("%s: %w", executor.label(), lerr)
+ }
+ defer func() { _ = unlockAsk() }()
+ }
if err := executor.runCommand(ctx, taskPath, taskArgs, stdin, stdout, stderr); err != nil {
return exitCodeFor(err), nil
}
diff --git a/internal/askcli/taskexec_test.go b/internal/askcli/taskexec_test.go
index 2236866..5e95f1c 100644
--- a/internal/askcli/taskexec_test.go
+++ b/internal/askcli/taskexec_test.go
@@ -5,12 +5,23 @@ import (
"context"
"errors"
"io"
+ "os"
"os/exec"
+ "path/filepath"
"reflect"
"strings"
"testing"
)
+func fakeHexaiRepoDir(t *testing.T) string {
+ t.Helper()
+ base := filepath.Join(t.TempDir(), "hexai")
+ if err := os.MkdirAll(filepath.Join(base, ".git"), 0o755); err != nil {
+ t.Fatalf("mkdir .git: %v", err)
+ }
+ return base
+}
+
func TestExecutorTaskArgs(t *testing.T) {
exec_ := NewExecutor("ask")
args, err := exec_.taskArgs(context.Background(), "/tmp/work/hexai", []string{"list", "limit:1"})
@@ -75,12 +86,13 @@ func TestExecutorTaskArgs_AddNoAgentScope(t *testing.T) {
}
func TestExecutorRun_InjectsProjectFilterAndAgentTag(t *testing.T) {
+ repo := fakeHexaiRepoDir(t)
var gotName string
var gotArgs []string
exec_ := Executor{
commandName: "ask",
findBinary: func() (string, error) { return "/usr/bin/task", nil },
- detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil },
+ detectRepoRoot: func(context.Context) (string, error) { return repo, nil },
runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
gotName = name
gotArgs = append([]string(nil), args...)
@@ -105,11 +117,12 @@ func TestExecutorRun_InjectsProjectFilterAndAgentTag(t *testing.T) {
}
func TestExecutorRun_InjectsProjectFilterAndNoAgentTag(t *testing.T) {
+ repo := fakeHexaiRepoDir(t)
var gotArgs []string
exec_ := Executor{
commandName: "ask",
findBinary: func() (string, error) { return "/usr/bin/task", nil },
- detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil },
+ detectRepoRoot: func(context.Context) (string, error) { return repo, nil },
runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
gotArgs = append([]string(nil), args...)
return nil
@@ -130,14 +143,16 @@ func TestExecutorRun_InjectsProjectFilterAndNoAgentTag(t *testing.T) {
}
}
-func TestExecutorRun_ProjectOverrideSkipsRepoDetection(t *testing.T) {
+func TestExecutorRun_ProjectOverrideStillLocksUsingGitRoot(t *testing.T) {
+ repo := fakeHexaiRepoDir(t)
+ var detectCalls int
var gotArgs []string
exec_ := Executor{
commandName: "ask",
findBinary: func() (string, error) { return "/usr/bin/task", nil },
detectRepoRoot: func(context.Context) (string, error) {
- t.Fatal("detectRepoRoot should not be called when project override is set")
- return "", nil
+ detectCalls++
+ return repo, nil
},
runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
gotArgs = append([]string(nil), args...)
@@ -153,6 +168,9 @@ func TestExecutorRun_ProjectOverrideSkipsRepoDetection(t *testing.T) {
if exitCode != 0 {
t.Fatalf("exitCode = %d, want 0", exitCode)
}
+ if detectCalls != 1 {
+ t.Fatalf("detectRepoRoot calls = %d, want 1", detectCalls)
+ }
wantArgs := []string{"rc.verbose=nothing", "rc.confirmation=off", "project:alpha", "+agent", "list"}
if !reflect.DeepEqual(gotArgs, wantArgs) {
t.Fatalf("task args = %v, want %v", gotArgs, wantArgs)
@@ -180,10 +198,11 @@ func TestExecutorRun_OutsideGitRepo_IsActionable(t *testing.T) {
}
func TestExecutorRun_PreservesTaskwarriorExitCode(t *testing.T) {
+ repo := fakeHexaiRepoDir(t)
exec_ := Executor{
commandName: "ask",
findBinary: func() (string, error) { return "/usr/bin/task", nil },
- detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil },
+ detectRepoRoot: func(context.Context) (string, error) { return repo, nil },
runCommand: func(context.Context, string, []string, io.Reader, io.Writer, io.Writer) error {
return exec.Command("sh", "-c", "exit 7").Run()
},
@@ -199,12 +218,13 @@ func TestExecutorRun_PreservesTaskwarriorExitCode(t *testing.T) {
}
func TestExecutorRun_PreservesStdoutAndStderr(t *testing.T) {
+ repo := fakeHexaiRepoDir(t)
var stdout bytes.Buffer
var stderr bytes.Buffer
exec_ := Executor{
commandName: "ask",
findBinary: func() (string, error) { return "/usr/bin/task", nil },
- detectRepoRoot: func(context.Context) (string, error) { return "/tmp/work/hexai", nil },
+ detectRepoRoot: func(context.Context) (string, error) { return repo, nil },
runCommand: func(_ context.Context, name string, args []string, stdin io.Reader, out, errOut io.Writer) error {
_, _ = io.WriteString(out, "task stdout")
_, _ = io.WriteString(errOut, "task stderr")