summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-02 14:31:26 +0200
committerPaul Buetow <paul@buetow.org>2026-03-02 14:31:26 +0200
commit3e3d1cb988b457a55425a7334f8f5a91cc939666 (patch)
tree0a8130c5c81f7226a5d4c9bab2d5d4afaedadc16 /internal
parent5b04524a7c134e101da1bc7e6a99402ca07ad4cc (diff)
gotest: extract Go codegen heuristics from lsp handlers (task 406)
Diffstat (limited to 'internal')
-rw-r--r--internal/gotest/heuristics.go113
-rw-r--r--internal/gotest/heuristics_test.go39
-rw-r--r--internal/lsp/handlers_codeaction.go95
3 files changed, 157 insertions, 90 deletions
diff --git a/internal/gotest/heuristics.go b/internal/gotest/heuristics.go
new file mode 100644
index 0000000..d7bc8b7
--- /dev/null
+++ b/internal/gotest/heuristics.go
@@ -0,0 +1,113 @@
+package gotest
+
+import "strings"
+
+// ParsePackageName returns the package name from file lines, or empty when missing.
+func ParsePackageName(lines []string) string {
+ for _, ln := range lines {
+ t := strings.TrimSpace(ln)
+ if !strings.HasPrefix(t, "package ") {
+ continue
+ }
+ name := strings.TrimSpace(strings.TrimPrefix(t, "package "))
+ if i := strings.Index(name, " "); i >= 0 {
+ name = name[:i]
+ }
+ if i := strings.Index(name, "\t"); i >= 0 {
+ name = name[:i]
+ }
+ if i := strings.Index(name, "//"); i >= 0 {
+ name = strings.TrimSpace(name[:i])
+ }
+ return name
+ }
+ return ""
+}
+
+// FindFunctionAtLine finds the function enclosing or preceding idx.
+// It returns start/end line indexes or -1/-1 when no function is found.
+func FindFunctionAtLine(lines []string, idx int) (int, int) {
+ if len(lines) == 0 {
+ return -1, -1
+ }
+ if idx < 0 {
+ idx = 0
+ }
+ if idx >= len(lines) {
+ idx = len(lines) - 1
+ }
+
+ start := -1
+ for i := idx; i >= 0; i-- {
+ if strings.Contains(lines[i], "func ") {
+ start = i
+ break
+ }
+ if strings.Contains(lines[i], "}") {
+ break
+ }
+ }
+ if start == -1 {
+ return -1, -1
+ }
+
+ depth := 0
+ seenOpen := false
+ for i := start; i < len(lines); i++ {
+ line := lines[i]
+ for j := 0; j < len(line); j++ {
+ switch line[j] {
+ case '{':
+ depth++
+ seenOpen = true
+ case '}':
+ if depth > 0 {
+ depth--
+ }
+ if seenOpen && depth == 0 {
+ return start, i
+ }
+ }
+ }
+ }
+ if !seenOpen {
+ return start, start
+ }
+ return start, -1
+}
+
+// DeriveFuncName extracts function or method name from Go source snippet.
+func DeriveFuncName(code string) string {
+ line := strings.TrimSpace(firstLine(code))
+ if !strings.HasPrefix(line, "func ") {
+ return ""
+ }
+ rest := strings.TrimSpace(strings.TrimPrefix(line, "func "))
+ if strings.HasPrefix(rest, "(") {
+ if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) {
+ rest = strings.TrimSpace(rest[i+1:])
+ }
+ }
+ if i := strings.Index(rest, "("); i > 0 {
+ return strings.TrimSpace(rest[:i])
+ }
+ return ""
+}
+
+// ExportName upper-cases the first character for use in Test* names.
+func ExportName(name string) string {
+ if name == "" {
+ return ""
+ }
+ r := []rune(name)
+ first := string(r[0])
+ r[0] = []rune(strings.ToUpper(first))[0]
+ return string(r)
+}
+
+func firstLine(s string) string {
+ if i := strings.IndexByte(s, '\n'); i >= 0 {
+ return s[:i]
+ }
+ return s
+}
diff --git a/internal/gotest/heuristics_test.go b/internal/gotest/heuristics_test.go
new file mode 100644
index 0000000..831262d
--- /dev/null
+++ b/internal/gotest/heuristics_test.go
@@ -0,0 +1,39 @@
+package gotest
+
+import "testing"
+
+func TestParsePackageName(t *testing.T) {
+ lines := []string{"// comment", "package mypkg // trailing"}
+ if got := ParsePackageName(lines); got != "mypkg" {
+ t.Fatalf("got %q", got)
+ }
+ if got := ParsePackageName([]string{"no package"}); got != "" {
+ t.Fatalf("expected empty package name")
+ }
+}
+
+func TestFindFunctionAtLine_NoBody(t *testing.T) {
+ lines := []string{"func X(a int)", "// comment"}
+ start, end := FindFunctionAtLine(lines, 0)
+ if start != 0 || end != 0 {
+ t.Fatalf("expected single-line prototype, got %d,%d", start, end)
+ }
+}
+
+func TestDeriveFuncName(t *testing.T) {
+ if got := DeriveFuncName("func Sum(a int) int { return a }"); got != "Sum" {
+ t.Fatalf("got %q", got)
+ }
+ if got := DeriveFuncName("func (t *Type) Method(x int) {}"); got != "Method" {
+ t.Fatalf("got %q", got)
+ }
+}
+
+func TestExportName(t *testing.T) {
+ if got := ExportName("sum"); got != "Sum" {
+ t.Fatalf("got %q", got)
+ }
+ if got := ExportName("Sum"); got != "Sum" {
+ t.Fatalf("got %q", got)
+ }
+}
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go
index 58d7134..d393bd4 100644
--- a/internal/lsp/handlers_codeaction.go
+++ b/internal/lsp/handlers_codeaction.go
@@ -10,6 +10,7 @@ import (
"time"
"codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/gotest"
"codeberg.org/snonux/hexai/internal/llm"
"codeberg.org/snonux/hexai/internal/logging"
)
@@ -678,73 +679,12 @@ func fileExists(path string) bool {
// parseGoPackageName returns the package name from file lines, or empty if not found.
func parseGoPackageName(lines []string) string {
- for _, ln := range lines {
- t := strings.TrimSpace(ln)
- if strings.HasPrefix(t, "package ") {
- name := strings.TrimSpace(strings.TrimPrefix(t, "package "))
- // strip inline comments
- if i := strings.Index(name, " "); i >= 0 {
- name = name[:i]
- }
- if i := strings.Index(name, "\t"); i >= 0 {
- name = name[:i]
- }
- if i := strings.Index(name, "//"); i >= 0 {
- name = strings.TrimSpace(name[:i])
- }
- return name
- }
- }
- return ""
+ return gotest.ParsePackageName(lines)
}
// findGoFunctionAtLine finds the function enclosing or preceding line idx. Returns start and end line indexes.
func findGoFunctionAtLine(lines []string, idx int) (int, int) {
- if idx < 0 {
- idx = 0
- }
- if idx >= len(lines) {
- idx = len(lines) - 1
- }
- // find signature start
- start := -1
- for i := idx; i >= 0; i-- {
- if strings.Contains(lines[i], "func ") {
- start = i
- break
- }
- if strings.Contains(lines[i], "}") {
- break
- }
- }
- if start == -1 {
- return -1, -1
- }
- // find first '{'
- depth := 0
- seenOpen := false
- for i := start; i < len(lines); i++ {
- ln := lines[i]
- for j := 0; j < len(ln); j++ {
- switch ln[j] {
- case '{':
- depth++
- seenOpen = true
- case '}':
- if depth > 0 {
- depth--
- }
- if seenOpen && depth == 0 {
- return start, i
- }
- }
- }
- }
- // if never saw '{', assume single-line prototype; return that line
- if !seenOpen {
- return start, start
- }
- return start, -1
+ return gotest.FindFunctionAtLine(lines, idx)
}
// generateGoTestFunction uses LLM to produce a test function; falls back to a stub when unavailable.
@@ -774,34 +714,9 @@ func (s *Server) generateGoTestFunction(funcCode string) string {
// deriveGoFuncName extracts function or method name from code.
func deriveGoFuncName(code string) string {
- // look for line starting with func
- line := firstLine(code)
- line = strings.TrimSpace(line)
- if !strings.HasPrefix(line, "func ") {
- return ""
- }
- rest := strings.TrimSpace(strings.TrimPrefix(line, "func "))
- // method receiver
- if strings.HasPrefix(rest, "(") {
- // find ")"
- if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) {
- rest = strings.TrimSpace(rest[i+1:])
- }
- }
- // now rest should start with Name(
- if i := strings.Index(rest, "("); i > 0 {
- return strings.TrimSpace(rest[:i])
- }
- return ""
+ return gotest.DeriveFuncName(code)
}
func exportName(name string) string {
- if name == "" {
- return name
- }
- r := []rune(name)
- if r[0] >= 'a' && r[0] <= 'z' {
- r[0] = r[0] - ('a' - 'A')
- }
- return string(r)
+ return gotest.ExportName(name)
}