diff options
Diffstat (limited to 'internal/lsp')
| -rw-r--r-- | internal/lsp/codeaction_test.go | 63 | ||||
| -rw-r--r-- | internal/lsp/handlers.go | 339 | ||||
| -rw-r--r-- | internal/lsp/handlers_helpers_test.go | 52 | ||||
| -rw-r--r-- | internal/lsp/handlers_test.go | 30 |
4 files changed, 305 insertions, 179 deletions
diff --git a/internal/lsp/codeaction_test.go b/internal/lsp/codeaction_test.go new file mode 100644 index 0000000..e9abbb8 --- /dev/null +++ b/internal/lsp/codeaction_test.go @@ -0,0 +1,63 @@ +package lsp + +import ( + "context" + "encoding/json" + "testing" + "hexai/internal/llm" +) + +type fakeLLM struct{ resp string; err error } + +func (f fakeLLM) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { + return f.resp, f.err +} +func (f fakeLLM) Name() string { return "fake" } +func (f fakeLLM) DefaultModel() string { return "fake-model" } + +func TestBuildRewriteCodeAction_ReturnsEdit(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "REWRITTEN"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{Start: Position{Line: 1, Character: 2}, End: Position{Line: 3, Character: 4}}} + sel := ";rewrite;\nold code" + ca := s.buildRewriteCodeAction(p, sel) + if ca == nil { t.Fatalf("expected code action") } + if ca.Edit == nil || len(ca.Edit.Changes) == 0 { t.Fatalf("expected workspace edit with changes") } + edits := ca.Edit.Changes[p.TextDocument.URI] + if len(edits) != 1 { t.Fatalf("expected 1 edit, got %d", len(edits)) } + if edits[0].Range != p.Range { t.Fatalf("edit range mismatch: got %+v want %+v", edits[0].Range, p.Range) } + if edits[0].NewText == "" { t.Fatalf("expected non-empty replacement text") } +} + +func TestBuildRewriteCodeAction_NoInstruction(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "IGNORED"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{}} + sel := "no instruction here" + if ca := s.buildRewriteCodeAction(p, sel); ca != nil { t.Fatalf("expected nil action when no instruction present") } +} + +func TestBuildDiagnosticsCodeAction_ReturnsEdit(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "FIXED"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{Start: Position{Line: 10}, End: Position{Line: 12, Character: 5}}} + ctx := CodeActionContext{Diagnostics: []Diagnostic{ + {Range: Range{Start: Position{Line: 11}, End: Position{Line: 11, Character: 10}}, Message: "inside"}, + {Range: Range{Start: Position{Line: 2}, End: Position{Line: 3}}, Message: "outside"}, + }} + raw, _ := json.Marshal(ctx) + p.Context = json.RawMessage(raw) + sel := "some selected code" + ca := s.buildDiagnosticsCodeAction(p, sel) + if ca == nil { t.Fatalf("expected diagnostics code action") } + if ca.Edit == nil || len(ca.Edit.Changes) == 0 { t.Fatalf("expected workspace edit") } +} + +func TestBuildDiagnosticsCodeAction_NoDiagnostics(t *testing.T) { + s := newTestServer() + s.llmClient = fakeLLM{resp: "FIXED"} + p := CodeActionParams{TextDocument: TextDocumentIdentifier{URI: "file:///t.go"}, Range: Range{}} + // empty context + p.Context = json.RawMessage(nil) + if ca := s.buildDiagnosticsCodeAction(p, "sel"); ca != nil { t.Fatalf("expected nil action when no diagnostics") } +} diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index d21c5b3..43d42c8 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -68,16 +68,15 @@ func (s *Server) handleCodeAction(req Request) { } return } - // Extract selected text d := s.getDocument(p.TextDocument.URI) - if d == nil || len(d.lines) == 0 { + if d == nil || len(d.lines) == 0 || s.llmClient == nil { if len(req.ID) != 0 { s.reply(req.ID, []CodeAction{}, nil) } return } sel := extractRangeText(d, p.Range) - if strings.TrimSpace(sel) == "" || s.llmClient == nil { + if strings.TrimSpace(sel) == "" { if len(req.ID) != 0 { s.reply(req.ID, []CodeAction{}, nil) } @@ -85,67 +84,77 @@ func (s *Server) handleCodeAction(req Request) { } actions := make([]CodeAction, 0, 2) + if a := s.buildRewriteCodeAction(p, sel); a != nil { + actions = append(actions, *a) + } + if a := s.buildDiagnosticsCodeAction(p, sel); a != nil { + actions = append(actions, *a) + } + if len(req.ID) != 0 { + s.reply(req.ID, actions, nil) + } +} - // Action 1: Rewrite selection based on first instruction in selection +func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction { if instr, cleaned := instructionFromSelection(sel); strings.TrimSpace(instr) != "" { sys := "You are a precise code refactoring engine. Rewrite the given code strictly according to the instruction. Return only the updated code with no prose or backticks. Preserve formatting where reasonable." user := fmt.Sprintf("Instruction: %s\n\nSelected code to transform:\n%s", instr, cleaned) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - // Build request options from server settings - opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} - if s.codingTemperature != nil { - opts = append(opts, llm.WithTemperature(*s.codingTemperature)) - } - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - out := strings.TrimSpace(text) - if out != "" { + opts := s.llmRequestOpts() + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := strings.TrimSpace(text); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{p.TextDocument.URI: {{Range: p.Range, NewText: out}}}} - actions = append(actions, CodeAction{Title: "Hexai: rewrite selection", Kind: "refactor.rewrite", Edit: &edit}) + ca := CodeAction{Title: "Hexai: rewrite selection", Kind: "refactor.rewrite", Edit: &edit} + return &ca } } else { logging.Logf("lsp ", "codeAction rewrite llm error: %v", err) } } + return nil +} - // Action 2: Resolve diagnostics within selection - if diags := s.diagnosticsInRange(p.Context, p.Range); len(diags) > 0 { - // Compose a prompt listing diagnostics relevant to the selected code - sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes." - var b strings.Builder - b.WriteString("Diagnostics to resolve (selection only):\n") - for i, dgn := range diags { - // Minimal, user-facing summary; include source if present - if dgn.Source != "" { - fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) - } else { - fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) - } - } - b.WriteString("\nSelected code:\n") - b.WriteString(sel) - ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) - defer cancel() - messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: b.String()}} - opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} - if s.codingTemperature != nil { - opts = append(opts, llm.WithTemperature(*s.codingTemperature)) - } - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - out := strings.TrimSpace(text) - if out != "" { - edit := WorkspaceEdit{Changes: map[string][]TextEdit{p.TextDocument.URI: {{Range: p.Range, NewText: out}}}} - actions = append(actions, CodeAction{Title: "Hexai: resolve diagnostics", Kind: "quickfix", Edit: &edit}) - } +func (s *Server) buildDiagnosticsCodeAction(p CodeActionParams, sel string) *CodeAction { + diags := s.diagnosticsInRange(p.Context, p.Range) + if len(diags) == 0 { + return nil + } + sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes." + var b strings.Builder + b.WriteString("Diagnostics to resolve (selection only):\n") + for i, dgn := range diags { + if dgn.Source != "" { + fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) } else { - logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err) + fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) } } + b.WriteString("\nSelected code:\n") + b.WriteString(sel) + ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) + defer cancel() + messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: b.String()}} + opts := s.llmRequestOpts() + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := strings.TrimSpace(text); out != "" { + edit := WorkspaceEdit{Changes: map[string][]TextEdit{p.TextDocument.URI: {{Range: p.Range, NewText: out}}}} + ca := CodeAction{Title: "Hexai: resolve diagnostics", Kind: "quickfix", Edit: &edit} + return &ca + } + } else { + logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err) + } + return nil +} - if len(req.ID) != 0 { - s.reply(req.ID, actions, nil) +func (s *Server) llmRequestOpts() []llm.RequestOption { + opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} + if s.codingTemperature != nil { + opts = append(opts, llm.WithTemperature(*s.codingTemperature)) } + return opts } // instructionFromSelection extracts the first instruction from selection text. @@ -457,64 +466,22 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun for _, m := range messages { sentSize += len(m.Content) } - // Update request counters (sent) - s.mu.Lock() - s.llmReqTotal++ - s.llmSentBytesTotal += int64(sentSize) - s.mu.Unlock() + s.incSentCounters(sentSize) - opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} - if s.codingTemperature != nil { - opts = append(opts, llm.WithTemperature(*s.codingTemperature)) - } - text, err := s.llmClient.Chat(ctx, messages, opts...) + opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} + if s.codingTemperature != nil { + opts = append(opts, llm.WithTemperature(*s.codingTemperature)) + } + text, err := s.llmClient.Chat(ctx, messages, opts...) if err != nil { logging.Logf("lsp ", "llm completion error: %v", err) // Log updated averages after this request (even if failed) - s.mu.RLock() - avgSent := int64(0) - if s.llmReqTotal > 0 { - avgSent = s.llmSentBytesTotal / s.llmReqTotal - } - avgRecv := int64(0) - if s.llmRespTotal > 0 { - avgRecv = s.llmRespBytesTotal / s.llmRespTotal - } - reqs, sentTot, recvTot := s.llmReqTotal, s.llmSentBytesTotal, s.llmRespBytesTotal - s.mu.RUnlock() - mins := time.Since(s.startTime).Minutes() - if mins <= 0 { - mins = 0.001 - } - rpm := float64(reqs) / mins - sentPerMin := float64(sentTot) / mins - recvPerMin := float64(recvTot) / mins - logging.Logf("lsp ", "llm stats reqs=%d avg_sent=%d avg_recv=%d sent_total=%d recv_total=%d rpm=%.2f sent_per_min=%.0f recv_per_min=%.0f", reqs, avgSent, avgRecv, sentTot, recvTot, rpm, sentPerMin, recvPerMin) + s.logLLMStats() return nil, false } // Update response counters (received) - recvSize := len(text) - s.mu.Lock() - s.llmRespTotal++ - s.llmRespBytesTotal += int64(recvSize) - avgSent := int64(0) - if s.llmReqTotal > 0 { - avgSent = s.llmSentBytesTotal / s.llmReqTotal - } - avgRecv := int64(0) - if s.llmRespTotal > 0 { - avgRecv = s.llmRespBytesTotal / s.llmRespTotal - } - reqs, sentTot, recvTot := s.llmReqTotal, s.llmSentBytesTotal, s.llmRespBytesTotal - s.mu.Unlock() - mins := time.Since(s.startTime).Minutes() - if mins <= 0 { - mins = 0.001 - } - rpm := float64(reqs) / mins - sentPerMin := float64(sentTot) / mins - recvPerMin := float64(recvTot) / mins - logging.Logf("lsp ", "llm stats reqs=%d avg_sent=%d avg_recv=%d sent_total=%d recv_total=%d rpm=%.2f sent_per_min=%.0f recv_per_min=%.0f", reqs, avgSent, avgRecv, sentTot, recvTot, rpm, sentPerMin, recvPerMin) + s.incRecvCounters(len(text)) + s.logLLMStats() cleaned := strings.TrimSpace(text) if cleaned != "" { cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) @@ -523,15 +490,18 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun return nil, false } + return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true +} + +func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string) []CompletionItem { te, filter := computeTextEditAndFilter(cleaned, inParams, current, p) rm := s.collectPromptRemovalEdits(p.TextDocument.URI) label := labelForCompletion(cleaned, filter) - // Detail shows provider/model for visibility in client UI detail := "Hexai LLM completion" if s.llmClient != nil { detail = "Hexai " + s.llmClient.Name() + ":" + s.llmClient.DefaultModel() } - items := []CompletionItem{{ + return []CompletionItem{{ Label: label, Kind: 1, Detail: detail, @@ -542,7 +512,43 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun SortText: "0000", Documentation: docStr, }} - return items, true +} + +// small helpers to keep tryLLMCompletion short +func (s *Server) incSentCounters(n int) { + s.mu.Lock() + s.llmReqTotal++ + s.llmSentBytesTotal += int64(n) + s.mu.Unlock() +} + +func (s *Server) incRecvCounters(n int) { + s.mu.Lock() + s.llmRespTotal++ + s.llmRespBytesTotal += int64(n) + s.mu.Unlock() +} + +func (s *Server) logLLMStats() { + s.mu.RLock() + avgSent := int64(0) + if s.llmReqTotal > 0 { + avgSent = s.llmSentBytesTotal / s.llmReqTotal + } + avgRecv := int64(0) + if s.llmRespTotal > 0 { + avgRecv = s.llmRespBytesTotal / s.llmRespTotal + } + reqs, sentTot, recvTot := s.llmReqTotal, s.llmSentBytesTotal, s.llmRespBytesTotal + s.mu.RUnlock() + mins := time.Since(s.startTime).Minutes() + if mins <= 0 { + mins = 0.001 + } + rpm := float64(reqs) / mins + sentPerMin := float64(sentTot) / mins + recvPerMin := float64(recvTot) / mins + logging.Logf("lsp ", "llm stats reqs=%d avg_sent=%d avg_recv=%d sent_total=%d recv_total=%d rpm=%.2f sent_per_min=%.0f recv_per_min=%.0f", reqs, avgSent, avgRecv, sentTot, recvTot, rpm, sentPerMin, recvPerMin) } // collectPromptRemovalEdits returns edits to remove all inline prompt markers. @@ -559,83 +565,78 @@ func (s *Server) collectPromptRemovalEdits(uri string) []TextEdit { } var edits []TextEdit for i, line := range d.lines { - // If the line contains a double-semicolon trigger of the form - // ";;text;" (no space after the ";;" and no space before the closing ';'), - // remove the entire line. - removeWholeLine := false - { - pos := 0 - for pos < len(line) { - j := strings.Index(line[pos:], ";;") - if j < 0 { - break - } - j += pos - // ensure there's a non-space after the two semicolons - if j+2 >= len(line) || line[j+2] == ' ' { - pos = j + 2 - continue - } - // find closing ';' after the content - k := strings.Index(line[j+2:], ";") - if k < 0 { - break - } + edits = append(edits, promptRemovalEditsForLine(line, i)...) + } + return edits +} + +func promptRemovalEditsForLine(line string, lineNum int) []TextEdit { + if hasDoubleSemicolonTrigger(line) { + return []TextEdit{{Range: Range{Start: Position{Line: lineNum, Character: 0}, End: Position{Line: lineNum, Character: len(line)}}, NewText: ""}} + } + return collectSemicolonMarkers(line, lineNum) +} + +func hasDoubleSemicolonTrigger(line string) bool { + pos := 0 + for pos < len(line) { + j := strings.Index(line[pos:], ";;") + if j < 0 { + return false + } + j += pos + if j+2 < len(line) && line[j+2] != ' ' { + if k := strings.Index(line[j+2:], ";"); k >= 0 { closeIdx := j + 2 + k - // ensure char before closing ';' is not a space - if closeIdx-1 < 0 || line[closeIdx-1] == ' ' { - pos = closeIdx + 1 - continue + if closeIdx-1 >= 0 && line[closeIdx-1] != ' ' { + return true } - removeWholeLine = true - break + pos = closeIdx + 1 + continue } + return false + } + pos = j + 2 + } + return false +} + +func collectSemicolonMarkers(line string, lineNum int) []TextEdit { + var edits []TextEdit + startSemi := 0 + for startSemi < len(line) { + j := strings.Index(line[startSemi:], ";") + if j < 0 { + break } - if removeWholeLine { - edits = append(edits, TextEdit{Range: Range{Start: Position{Line: i, Character: 0}, End: Position{Line: i, Character: len(line)}}, NewText: ""}) + j += startSemi + k := strings.Index(line[j+1:], ";") + if k < 0 { + break + } + if j+1 >= len(line) || line[j+1] == ' ' { + startSemi = j + 1 continue } - // Scan for ;...; markers that have no spaces directly inside the semicolons - startSemi := 0 - for startSemi < len(line) { - j := strings.Index(line[startSemi:], ";") - if j < 0 { - break - } - j += startSemi - k := strings.Index(line[j+1:], ";") - if k < 0 { - break - } - // Require no space immediately after the first ';' - if j+1 >= len(line) || line[j+1] == ' ' { - startSemi = j + 1 - continue - } - // Ignore patterns that start with double semicolon here; handled above - if line[j+1] == ';' { - startSemi = j + 2 - continue - } - // Index of the closing ';' - closeIdx := j + 1 + k - // Require no space immediately before the closing ';' - if closeIdx-1 < 0 || line[closeIdx-1] == ' ' { - startSemi = closeIdx + 1 - continue - } - // Require at least one character between the semicolons - if closeIdx-(j+1) < 1 { - startSemi = closeIdx + 1 - continue - } - endChar := closeIdx + 1 // include trailing ';' - if endChar < len(line) && line[endChar] == ' ' { - endChar++ - } - edits = append(edits, TextEdit{Range: Range{Start: Position{Line: i, Character: j}, End: Position{Line: i, Character: endChar}}, NewText: ""}) - startSemi = endChar + if line[j+1] == ';' { + startSemi = j + 2 + continue + } + closeIdx := j + 1 + k + if closeIdx-1 < 0 || line[closeIdx-1] == ' ' { + startSemi = closeIdx + 1 + continue + } + if closeIdx-(j+1) < 1 { + startSemi = closeIdx + 1 + continue + } + endChar := closeIdx + 1 + if endChar < len(line) && line[endChar] == ' ' { + endChar++ } + edits = append(edits, TextEdit{Range: Range{Start: Position{Line: lineNum, Character: j}, End: Position{Line: lineNum, Character: endChar}}, NewText: ""}) + startSemi = endChar } return edits } diff --git a/internal/lsp/handlers_helpers_test.go b/internal/lsp/handlers_helpers_test.go new file mode 100644 index 0000000..84dce77 --- /dev/null +++ b/internal/lsp/handlers_helpers_test.go @@ -0,0 +1,52 @@ +package lsp + +import ( + "strings" + "testing" +) + +func TestHasDoubleSemicolonTrigger(t *testing.T) { + cases := []struct{ + line string + want bool + }{ + {";;todo; remove this", true}, + {"prefix ;;x; suffix", true}, + {";; spaced ;", false}, + {"no markers", false}, + {";;x ; space before close", false}, + } + for _, tc := range cases { + got := hasDoubleSemicolonTrigger(tc.line) + if got != tc.want { + t.Fatalf("hasDoubleSemicolonTrigger(%q)=%v want %v", tc.line, got, tc.want) + } + } +} + +func TestCollectSemicolonMarkers(t *testing.T) { + line := "keep ;ok; this and ;another; that" + edits := collectSemicolonMarkers(line, 7) + if len(edits) != 2 { + t.Fatalf("expected 2 edits, got %d", len(edits)) + } + // Validate the first edit aligns with ;ok; + start := strings.Index(line, ";ok;") + if start < 0 { t.Fatalf("test setup: missing ;ok;") } + if edits[0].Range.Start.Line != 7 || edits[0].Range.Start.Character != start { + t.Fatalf("first edit start got line=%d char=%d want line=7 char=%d", edits[0].Range.Start.Line, edits[0].Range.Start.Character, start) + } +} + +func TestPromptRemovalEditsForLine_WholeLine(t *testing.T) { + line := ";;todo; remove this whole line" + edits := promptRemovalEditsForLine(line, 3) + if len(edits) != 1 { + t.Fatalf("expected 1 whole-line edit, got %d", len(edits)) + } + e := edits[0] + if e.Range.Start.Line != 3 || e.Range.End.Line != 3 || e.Range.Start.Character != 0 || e.Range.End.Character != len(line) { + t.Fatalf("unexpected range for whole-line removal: %+v", e.Range) + } +} + diff --git a/internal/lsp/handlers_test.go b/internal/lsp/handlers_test.go index 9a490e3..10b704b 100644 --- a/internal/lsp/handlers_test.go +++ b/internal/lsp/handlers_test.go @@ -9,16 +9,26 @@ import ( ) func TestInParamList(t *testing.T) { - line := "func foo(a int, b string) int {" - if !inParamList(line, 15) { // inside params - t.Fatalf("expected inParamList true for cursor inside params") - } - if inParamList(line, 2) { // before 'func' - t.Fatalf("expected inParamList false for cursor before params") - } - if inParamList(line, len(line)) { // after ')' - t.Fatalf("expected inParamList false for cursor after params") - } + line := "func foo(a int, b string) int {" + cases := []struct{ + name string + cursor int + want bool + }{ + {"inside-params", 15, true}, + {"before-func", 2, false}, + {"after-paren", len(line), false}, + {"at-open-paren", strings.Index(line, "(")+1, true}, + {"at-close-paren", strings.Index(line, ")"), true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := inParamList(line, tc.cursor) + if got != tc.want { + t.Fatalf("cursor=%d got %v want %v", tc.cursor, got, tc.want) + } + }) + } } func TestComputeWordStart(t *testing.T) { |
