diff options
| -rw-r--r-- | internal/lsp/handlers_codeaction.go | 5 | ||||
| -rw-r--r-- | internal/lsp/handlers_completion.go | 2 | ||||
| -rw-r--r-- | internal/lsp/handlers_document.go | 3 | ||||
| -rw-r--r-- | internal/lsp/handlers_init.go | 2 | ||||
| -rw-r--r-- | internal/lsp/init_shutdown_test.go | 41 | ||||
| -rw-r--r-- | internal/lsp/server.go | 51 |
6 files changed, 86 insertions, 18 deletions
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index f61b79f..f45dec4 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -2,7 +2,6 @@ package lsp import ( - "context" "encoding/json" "fmt" "os" @@ -392,7 +391,7 @@ func (s *Server) customActionByID(id string) *CustomAction { } func (s *Server) completeCodeAction(ca CodeAction, uri string, rng Range, sys, user string, timeout time.Duration) (CodeAction, bool) { - ctx, cancel := context.WithTimeout(context.Background(), timeout) + ctx, cancel := s.requestTimeoutContext(timeout) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} spec := s.buildRequestSpec(surfaceCodeAction) @@ -724,7 +723,7 @@ func (s *Server) generateGoTestFunction(funcCode string) string { cfg := s.currentConfig() sys := cfg.PromptCodeActionGoTestSystem user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode}) - ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) + ctx, cancel := s.requestTimeoutContext(18 * time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} if out, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil { diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index 6350c59..4212897 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -104,7 +104,7 @@ func (s *Server) logCompletionContext(p CompletionParams, above, current, below, } func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool, bool) { - ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + ctx, cancel := s.requestTimeoutContext(12 * time.Second) var cancelOnce sync.Once end := func() { cancelOnce.Do(cancel) } diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go index b907014..e39022e 100644 --- a/internal/lsp/handlers_document.go +++ b/internal/lsp/handlers_document.go @@ -2,7 +2,6 @@ package lsp import ( - "context" "encoding/json" "strings" "time" @@ -166,7 +165,7 @@ func (s *Server) detectAndHandleChat(uri string) { return } go func(prompt string, remove int) { - ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + ctx, cancel := s.requestTimeoutContext(25 * time.Second) defer cancel() // Build messages with history and context_mode aware extras. pos := Position{Line: lineIdx, Character: lastIdx + 1} diff --git a/internal/lsp/handlers_init.go b/internal/lsp/handlers_init.go index 702871d..0cecc6c 100644 --- a/internal/lsp/handlers_init.go +++ b/internal/lsp/handlers_init.go @@ -36,9 +36,11 @@ func (s *Server) handleInitialized() { } func (s *Server) handleShutdown(req Request) { + s.cancelRequests() s.reply(req.ID, nil, nil) } func (s *Server) handleExit() { + s.cancelRequests() s.exited.Store(true) } diff --git a/internal/lsp/init_shutdown_test.go b/internal/lsp/init_shutdown_test.go index 2847170..4e1bd2f 100644 --- a/internal/lsp/init_shutdown_test.go +++ b/internal/lsp/init_shutdown_test.go @@ -2,10 +2,13 @@ package lsp import ( "bytes" + "context" "encoding/json" + "errors" "io" "log" "testing" + "time" ) func TestHandleShutdown_Replies(t *testing.T) { @@ -20,3 +23,41 @@ func TestHandleShutdown_Replies(t *testing.T) { t.Fatalf("unexpected shutdown response: %+v", resp) } } + +func TestHandleShutdown_CancelsServerContext(t *testing.T) { + var out bytes.Buffer + s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{}) + req := Request{JSONRPC: "2.0", ID: json.RawMessage("12"), Method: "shutdown"} + s.handleShutdown(req) + + ctx, cancel := s.requestTimeoutContext(2 * time.Second) + defer cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Fatalf("expected canceled context, got %v", ctx.Err()) + } + default: + t.Fatalf("expected canceled context after shutdown") + } +} + +func TestHandleExit_CancelsServerContext(t *testing.T) { + var out bytes.Buffer + s := NewServer(bytes.NewReader(nil), &out, log.New(io.Discard, "", 0), ServerOptions{}) + s.handleExit() + if !s.exited.Load() { + t.Fatalf("expected exited flag to be set") + } + + ctx, cancel := s.requestTimeoutContext(2 * time.Second) + defer cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Fatalf("expected canceled context, got %v", ctx.Err()) + } + default: + t.Fatalf("expected canceled context after exit") + } +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index fa3b375..bf1f724 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -3,6 +3,7 @@ package lsp import ( "bufio" + "context" "encoding/json" "io" "log" @@ -21,17 +22,19 @@ import ( // Server implements a minimal LSP over stdio. type Server struct { - in *bufio.Reader - out io.Writer - outMu sync.Mutex - logger *log.Logger - exited atomic.Bool - mu sync.RWMutex - docs map[string]*document - logContext bool - configStore *runtimeconfig.Store - cfg appconfig.App - llmClient llm.Client + in *bufio.Reader + out io.Writer + outMu sync.Mutex + logger *log.Logger + serverCtx context.Context + serverCancel context.CancelFunc + exited atomic.Bool + mu sync.RWMutex + docs map[string]*document + logContext bool + configStore *runtimeconfig.Store + cfg appconfig.App + llmClient llm.Client codeActionSubsystem chatSubsystem // LLM request stats @@ -94,7 +97,17 @@ type CustomAction struct { } func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) *Server { - s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: opts.LogContext, configStore: opts.ConfigStore} + ctx, cancel := context.WithCancel(context.Background()) + s := &Server{ + in: bufio.NewReader(r), + out: w, + logger: logger, + docs: make(map[string]*document), + logContext: opts.LogContext, + configStore: opts.ConfigStore, + serverCtx: ctx, + serverCancel: cancel, + } s.startTime = time.Now() s.compCache = make(map[string]string) s.pendingCompletions = make(map[string][]CompletionItem) @@ -408,7 +421,21 @@ func (s *Server) customActions() []CustomAction { return customs } +func (s *Server) requestTimeoutContext(timeout time.Duration) (context.Context, context.CancelFunc) { + if s.serverCtx == nil { + return context.WithTimeout(context.Background(), timeout) + } + return context.WithTimeout(s.serverCtx, timeout) +} + +func (s *Server) cancelRequests() { + if s.serverCancel != nil { + s.serverCancel() + } +} + func (s *Server) Run() error { + defer s.cancelRequests() for { body, err := s.readMessage() if err == io.EOF { |
