summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/lsp/handlers_codeaction.go289
-rw-r--r--internal/lsp/handlers_document.go39
-rw-r--r--internal/lsp/handlers_execute.go35
-rw-r--r--internal/lsp/server.go1
-rw-r--r--internal/lsp/types.go35
-rw-r--r--internal/version.go2
6 files changed, 367 insertions, 34 deletions
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go
index 4407ac0..ad11861 100644
--- a/internal/lsp/handlers_codeaction.go
+++ b/internal/lsp/handlers_codeaction.go
@@ -9,6 +9,8 @@ import (
"codeberg.org/snonux/hexai/internal/logging"
"strings"
"time"
+ "os"
+ "path/filepath"
)
func (s *Server) handleCodeAction(req Request) {
@@ -26,24 +28,21 @@ func (s *Server) handleCodeAction(req Request) {
}
return
}
- sel := extractRangeText(d, p.Range)
- if strings.TrimSpace(sel) == "" {
- if len(req.ID) != 0 {
- s.reply(req.ID, []CodeAction{}, nil)
- }
- return
- }
+ sel := extractRangeText(d, p.Range)
- actions := make([]CodeAction, 0, 2)
- if a := s.buildRewriteCodeAction(p, sel); a != nil {
- actions = append(actions, *a)
- }
- if a := s.buildDiagnosticsCodeAction(p, sel); a != nil {
- actions = append(actions, *a)
- }
- if len(req.ID) != 0 {
- s.reply(req.ID, actions, nil)
- }
+ actions := make([]CodeAction, 0, 3)
+ if a := s.buildRewriteCodeAction(p, sel); a != nil {
+ actions = append(actions, *a)
+ }
+ if a := s.buildDiagnosticsCodeAction(p, sel); a != nil {
+ actions = append(actions, *a)
+ }
+ if a := s.buildGoUnitTestCodeAction(p); a != nil {
+ actions = append(actions, *a)
+ }
+ if len(req.ID) != 0 {
+ s.reply(req.ID, actions, nil)
+ }
}
func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction {
@@ -94,8 +93,8 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
if err := json.Unmarshal(ca.Data, &payload); err != nil {
return ca, false
}
- switch payload.Type {
- case "rewrite":
+ switch payload.Type {
+ case "rewrite":
sys := "You are a precise code refactoring engine. Rewrite the given code strictly according to the instruction. Return only the updated code with no prose or backticks. Preserve formatting where reasonable."
user := fmt.Sprintf("Instruction: %s\n\nSelected code to transform:\n%s", payload.Instruction, payload.Selection)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -111,7 +110,7 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
} else {
logging.Logf("lsp ", "codeAction rewrite llm error: %v", err)
}
- case "diagnostics":
+ case "diagnostics":
sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes."
var b strings.Builder
b.WriteString("Diagnostics to resolve (selection only):\n")
@@ -137,8 +136,18 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
} else {
logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err)
}
- }
- return ca, false
+ case "go_test":
+ if edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start); ok {
+ ca.Edit = &edit
+ // After edit is applied, ask client to jump to new test function
+ ca.Command = &Command{Title: "Jump to generated test", Command: "hexai.showDocument", Arguments: []any{jumpURI, jumpRange}}
+ // Also send a server-initiated showDocument shortly after resolve to cover
+ // clients that do not execute commands from code actions.
+ s.deferShowDocument(jumpURI, jumpRange)
+ return ca, true
+ }
+ }
+ return ca, false
}
func (s *Server) handleCodeActionResolve(req Request) {
@@ -212,3 +221,239 @@ func greaterPos(p, q Position) bool {
}
return p.Character > q.Character
}
+
+// --- Go unit test code action ---
+
+func (s *Server) buildGoUnitTestCodeAction(p CodeActionParams) *CodeAction {
+ uri := p.TextDocument.URI
+ if uri == "" || !strings.HasSuffix(strings.TrimPrefix(uri, "file://"), ".go") {
+ return nil
+ }
+ // Skip if already a _test.go file
+ if strings.HasSuffix(strings.TrimPrefix(uri, "file://"), "_test.go") {
+ return nil
+ }
+ // Heuristic: only offer when a function context is found above the cursor
+ _, _, _, funcCtx := s.lineContext(uri, p.Range.Start)
+ if !strings.Contains(funcCtx, "func ") {
+ return nil
+ }
+ payload := struct {
+ Type string `json:"type"`
+ URI string `json:"uri"`
+ Range Range `json:"range"`
+ }{Type: "go_test", URI: uri, Range: p.Range}
+ raw, _ := json.Marshal(payload)
+ ca := CodeAction{Title: "Hexai: implement unit test", Kind: "quickfix", Data: raw}
+ return &ca
+}
+
+func (s *Server) resolveGoTest(uri string, pos Position) (WorkspaceEdit, string, Range, bool) {
+ path := strings.TrimPrefix(uri, "file://")
+ if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ // Load source text
+ _, lines := s.loadFileText(uri)
+ if len(lines) == 0 {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ pkg := parseGoPackageName(lines)
+ fnStart, fnEnd := findGoFunctionAtLine(lines, pos.Line)
+ if fnStart < 0 || fnEnd < fnStart {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ funcCode := strings.Join(lines[fnStart:fnEnd+1], "\n")
+ testFunc := s.generateGoTestFunction(funcCode)
+ if strings.TrimSpace(testFunc) == "" {
+ return WorkspaceEdit{}, "", Range{}, false
+ }
+ // Determine test file target
+ testPath := strings.TrimSuffix(path, ".go") + "_test.go"
+ testURI := "file://" + testPath
+
+ // If test file exists, append test at EOF; otherwise, create a new file with package+import
+ if fileExists(testPath) {
+ // Build an insertion at end of file
+ _, tLines := s.loadFileText(testURI)
+ // Fallback when not open and cannot read: still insert at line 0
+ lineIdx := 0
+ col := 0
+ if len(tLines) > 0 {
+ lineIdx = len(tLines) - 1
+ col = len(tLines[lineIdx])
+ }
+ var b strings.Builder
+ // Ensure at least two newlines before the new test
+ if len(tLines) == 0 || (len(tLines) > 0 && !strings.HasSuffix(strings.Join(tLines, "\n"), "\n\n")) {
+ b.WriteString("\n\n")
+ }
+ b.WriteString(testFunc)
+ insert := b.String()
+ edit := TextEdit{Range: Range{Start: Position{Line: lineIdx, Character: col}, End: Position{Line: lineIdx, Character: col}}, NewText: insert}
+ we := WorkspaceEdit{Changes: map[string][]TextEdit{testURI: {edit}}}
+ // Compute jump range start
+ // Count how many prefix newlines added before the test function
+ prefixNL := 0
+ if strings.HasPrefix(insert, "\n\n") { prefixNL = 2 }
+ startLine := lineIdx + prefixNL
+ // If we inserted with two newlines and last line wasn't blank, first newline moves to next line
+ if prefixNL > 0 { startLine = lineIdx + prefixNL }
+ jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}}
+ return we, testURI, jump, true
+ }
+ // Create new file content
+ var content strings.Builder
+ if pkg == "" { pkg = filepath.Base(filepath.Dir(path)) }
+ content.WriteString("package ")
+ content.WriteString(pkg)
+ content.WriteString("\n\n")
+ content.WriteString("import (\n\t\"testing\"\n)\n\n")
+ content.WriteString(testFunc)
+ full := content.String()
+ // Use documentChanges with create + full content insert
+ create := CreateFile{Kind: "create", URI: testURI}
+ tde := TextDocumentEdit{TextDocument: VersionedTextDocumentIdentifier{URI: testURI}, Edits: []TextEdit{{Range: Range{Start: Position{Line: 0, Character: 0}, End: Position{Line: 0, Character: 0}}, NewText: full}}}
+ we := WorkspaceEdit{DocumentChanges: []any{create, tde}}
+ // Find start line of first test function
+ // Count lines before the substring "func Test"
+ pre := content.String()
+ idx := strings.Index(pre, "func Test")
+ startLine := 0
+ if idx > 0 {
+ before := pre[:idx]
+ startLine = strings.Count(before, "\n")
+ }
+ jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}}
+ return we, testURI, jump, true
+}
+
+// loadFileText returns the file content and lines. It prefers the open document; otherwise reads from disk.
+func (s *Server) loadFileText(uri string) (string, []string) {
+ if d := s.getDocument(uri); d != nil {
+ return d.text, append([]string{}, d.lines...)
+ }
+ path := strings.TrimPrefix(uri, "file://")
+ b, err := os.ReadFile(path)
+ if err != nil {
+ return "", nil
+ }
+ txt := string(b)
+ return txt, splitLines(txt)
+}
+
+func fileExists(path string) bool {
+ if _, err := os.Stat(path); err == nil {
+ return true
+ }
+ return false
+}
+
+// parseGoPackageName returns the package name from file lines, or empty if not found.
+func parseGoPackageName(lines []string) string {
+ for _, ln := range lines {
+ t := strings.TrimSpace(ln)
+ if strings.HasPrefix(t, "package ") {
+ name := strings.TrimSpace(strings.TrimPrefix(t, "package "))
+ // strip inline comments
+ if i := strings.Index(name, " "); i >= 0 { name = name[:i] }
+ if i := strings.Index(name, "\t"); i >= 0 { name = name[:i] }
+ if i := strings.Index(name, "//"); i >= 0 { name = strings.TrimSpace(name[:i]) }
+ return name
+ }
+ }
+ return ""
+}
+
+// findGoFunctionAtLine finds the function enclosing or preceding line idx. Returns start and end line indexes.
+func findGoFunctionAtLine(lines []string, idx int) (int, int) {
+ if idx < 0 { idx = 0 }
+ if idx >= len(lines) { idx = len(lines)-1 }
+ // find signature start
+ start := -1
+ for i := idx; i >= 0; i-- {
+ if strings.Contains(lines[i], "func ") {
+ start = i
+ break
+ }
+ if strings.Contains(lines[i], "}") {
+ break
+ }
+ }
+ if start == -1 { return -1, -1 }
+ // find first '{'
+ depth := 0
+ seenOpen := false
+ for i := start; i < len(lines); i++ {
+ ln := lines[i]
+ for j := 0; j < len(ln); j++ {
+ switch ln[j] {
+ case '{':
+ depth++
+ seenOpen = true
+ case '}':
+ if depth > 0 { depth-- }
+ if seenOpen && depth == 0 {
+ return start, i
+ }
+ }
+ }
+ }
+ // if never saw '{', assume single-line prototype; return that line
+ if !seenOpen {
+ return start, start
+ }
+ return start, -1
+}
+
+// generateGoTestFunction uses LLM to produce a test function; falls back to a stub when unavailable.
+func (s *Server) generateGoTestFunction(funcCode string) string {
+ if s.llmClient != nil {
+ sys := "You are a precise Go unit test generator. Given a Go function, write one or more Test* functions using the testing package. Do NOT include package or imports, only the test function(s). Prefer table-driven tests. Keep it minimal and idiomatic."
+ user := "Function under test:\n" + funcCode
+ ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
+ defer cancel()
+ messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
+ opts := s.llmRequestOpts()
+ if out, err := s.llmClient.Chat(ctx, messages, opts...); err == nil {
+ cleaned := strings.TrimSpace(stripCodeFences(out))
+ if cleaned != "" { return cleaned }
+ } else {
+ logging.Logf("lsp ", "codeAction go_test llm error: %v", err)
+ }
+ }
+ // Fallback stub
+ name := deriveGoFuncName(funcCode)
+ if name == "" { name = "Function" }
+ return fmt.Sprintf("func Test%s(t *testing.T) {\n\t// TODO: implement tests for %s\n}\n", exportName(name), name)
+}
+
+// deriveGoFuncName extracts function or method name from code.
+func deriveGoFuncName(code string) string {
+ // look for line starting with func
+ line := firstLine(code)
+ line = strings.TrimSpace(line)
+ if !strings.HasPrefix(line, "func ") { return "" }
+ rest := strings.TrimSpace(strings.TrimPrefix(line, "func "))
+ // method receiver
+ if strings.HasPrefix(rest, "(") {
+ // find ")"
+ if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) {
+ rest = strings.TrimSpace(rest[i+1:])
+ }
+ }
+ // now rest should start with Name(
+ if i := strings.Index(rest, "("); i > 0 {
+ return strings.TrimSpace(rest[:i])
+ }
+ return ""
+}
+
+func exportName(name string) string {
+ if name == "" { return name }
+ r := []rune(name)
+ if r[0] >= 'a' && r[0] <= 'z' {
+ r[0] = r[0] - ('a' - 'A')
+ }
+ return string(r)
+}
diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go
index 53c1588..5b83d78 100644
--- a/internal/lsp/handlers_document.go
+++ b/internal/lsp/handlers_document.go
@@ -254,12 +254,12 @@ func stripTrailingTrigger(sx string) string {
// clientApplyEdit sends a workspace/applyEdit request to the client.
func (s *Server) clientApplyEdit(label string, edit WorkspaceEdit) {
- params := ApplyWorkspaceEditParams{Label: label, Edit: edit}
- id := s.nextReqID()
- req := Request{JSONRPC: "2.0", ID: id, Method: "workspace/applyEdit"}
- b, _ := json.Marshal(params)
- req.Params = b
- s.writeMessage(req)
+ params := ApplyWorkspaceEditParams{Label: label, Edit: edit}
+ id := s.nextReqID()
+ req := Request{JSONRPC: "2.0", ID: id, Method: "workspace/applyEdit"}
+ b, _ := json.Marshal(params)
+ req.Params = b
+ s.writeMessage(req)
}
// nextReqID returns a unique json.RawMessage id for server-initiated requests.
@@ -271,3 +271,30 @@ func (s *Server) nextReqID() json.RawMessage {
b, _ := json.Marshal(idNum)
return b
}
+
+// clientShowDocument asks the client to open/focus a document and select a range.
+func (s *Server) clientShowDocument(uri string, sel *Range) {
+ var params struct {
+ URI string `json:"uri"`
+ External bool `json:"external,omitempty"`
+ TakeFocus bool `json:"takeFocus,omitempty"`
+ Selection *Range `json:"selection,omitempty"`
+ }
+ params.URI = uri
+ params.TakeFocus = true
+ params.Selection = sel
+ id := s.nextReqID()
+ req := Request{JSONRPC: "2.0", ID: id, Method: "window/showDocument"}
+ b, _ := json.Marshal(params)
+ req.Params = b
+ s.writeMessage(req)
+}
+
+// deferShowDocument schedules a showDocument after a short delay to allow the client
+// time to apply any pending edits (e.g., create the file before focusing it).
+func (s *Server) deferShowDocument(uri string, sel Range) {
+ go func() {
+ time.Sleep(120 * time.Millisecond)
+ s.clientShowDocument(uri, &sel)
+ }()
+}
diff --git a/internal/lsp/handlers_execute.go b/internal/lsp/handlers_execute.go
new file mode 100644
index 0000000..2e3ec52
--- /dev/null
+++ b/internal/lsp/handlers_execute.go
@@ -0,0 +1,35 @@
+// Summary: ExecuteCommand handler to support post-edit navigation (jump to generated test).
+package lsp
+
+import (
+ "encoding/json"
+)
+
+func (s *Server) handleExecuteCommand(req Request) {
+ var p ExecuteCommandParams
+ if err := json.Unmarshal(req.Params, &p); err != nil {
+ s.reply(req.ID, nil, nil)
+ return
+ }
+ switch p.Command {
+ case "hexai.showDocument":
+ if len(p.Arguments) >= 2 {
+ uri, _ := p.Arguments[0].(string)
+ var r Range
+ // Convert second arg to Range via re-marshal to be robust across clients
+ if b, err := json.Marshal(p.Arguments[1]); err == nil {
+ _ = json.Unmarshal(b, &r)
+ }
+ if uri != "" {
+ s.clientShowDocument(uri, &r)
+ }
+ }
+ s.reply(req.ID, nil, nil)
+ return
+ default:
+ // Unknown command; no-op
+ s.reply(req.ID, nil, nil)
+ return
+ }
+}
+
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 8af64ec..7a1007e 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -121,6 +121,7 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions)
"textDocument/completion": s.handleCompletion,
"textDocument/codeAction": s.handleCodeAction,
"codeAction/resolve": s.handleCodeActionResolve,
+ "workspace/executeCommand": s.handleExecuteCommand,
}
return s
}
diff --git a/internal/lsp/types.go b/internal/lsp/types.go
index 5169d44..1598b96 100644
--- a/internal/lsp/types.go
+++ b/internal/lsp/types.go
@@ -124,7 +124,8 @@ type CodeActionParams struct {
}
type WorkspaceEdit struct {
- Changes map[string][]TextEdit `json:"changes,omitempty"`
+ Changes map[string][]TextEdit `json:"changes,omitempty"`
+ DocumentChanges []any `json:"documentChanges,omitempty"`
}
// ApplyWorkspaceEditParams is the client request payload for workspace/applyEdit.
@@ -134,10 +135,34 @@ type ApplyWorkspaceEditParams struct {
}
type CodeAction struct {
- Title string `json:"title"`
- Kind string `json:"kind,omitempty"`
- Edit *WorkspaceEdit `json:"edit,omitempty"`
- Data json.RawMessage `json:"data,omitempty"`
+ Title string `json:"title"`
+ Kind string `json:"kind,omitempty"`
+ Edit *WorkspaceEdit `json:"edit,omitempty"`
+ Data json.RawMessage `json:"data,omitempty"`
+ Command *Command `json:"command,omitempty"`
+}
+
+// Extended workspace edit types (minimal subset)
+type TextDocumentEdit struct {
+ TextDocument VersionedTextDocumentIdentifier `json:"textDocument"`
+ Edits []TextEdit `json:"edits"`
+}
+
+type CreateFile struct {
+ Kind string `json:"kind"`
+ URI string `json:"uri"`
+}
+
+// Commands
+type Command struct {
+ Title string `json:"title"`
+ Command string `json:"command"`
+ Arguments []any `json:"arguments,omitempty"`
+}
+
+type ExecuteCommandParams struct {
+ Command string `json:"command"`
+ Arguments []any `json:"arguments,omitempty"`
}
// Diagnostics (subset needed for code action context)
diff --git a/internal/version.go b/internal/version.go
index 372522c..77c279b 100644
--- a/internal/version.go
+++ b/internal/version.go
@@ -1,4 +1,4 @@
// Summary: Hexai semantic version identifier used by CLI and LSP binaries.
package internal
-const Version = "0.4.0"
+const Version = "0.4.1"