diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-19 22:52:48 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-19 22:52:48 +0300 |
| commit | eb72b06fe8e62cb77af73f6dc558d384a5a5fe80 (patch) | |
| tree | efeb1165b9fbcb69a4ee675dba7bdc8c28fee3aa /internal | |
| parent | acc400768153a7bfda1413f15579c9455b877c87 (diff) | |
fix
Diffstat (limited to 'internal')
33 files changed, 455 insertions, 229 deletions
diff --git a/internal/appconfig/config.go b/internal/appconfig/config.go index 2274aee..9119688 100644 --- a/internal/appconfig/config.go +++ b/internal/appconfig/config.go @@ -247,9 +247,36 @@ type sectionStats struct { } type sectionOpenAI struct { - Model string `toml:"model"` - BaseURL string `toml:"base_url"` - Temperature *float64 `toml:"temperature"` + Model string `toml:"model"` + BaseURL string `toml:"base_url"` + Temperature *float64 `toml:"temperature"` + Presets map[string]string `toml:"presets"` +} + +func (s sectionOpenAI) isZero() bool { + return strings.TrimSpace(s.Model) == "" && strings.TrimSpace(s.BaseURL) == "" && s.Temperature == nil && len(s.Presets) == 0 +} + +func (s sectionOpenAI) resolvedModel() string { + model := strings.TrimSpace(s.Model) + if model == "" { + return "" + } + if len(s.Presets) == 0 { + return model + } + if mapped := strings.TrimSpace(s.Presets[model]); mapped != "" { + return mapped + } + lower := strings.ToLower(model) + for k, v := range s.Presets { + if strings.ToLower(strings.TrimSpace(k)) == lower { + if mapped := strings.TrimSpace(v); mapped != "" { + return mapped + } + } + } + return model } type sectionCopilot struct { @@ -380,10 +407,10 @@ func (fc *fileConfig) toApp() App { } // openai - if (fc.OpenAI != sectionOpenAI{}) || fc.OpenAI.Temperature != nil { + if !fc.OpenAI.isZero() || fc.OpenAI.Temperature != nil { tmp := App{ OpenAIBaseURL: fc.OpenAI.BaseURL, - OpenAIModel: fc.OpenAI.Model, + OpenAIModel: fc.OpenAI.resolvedModel(), OpenAITemperature: fc.OpenAI.Temperature, } out.mergeProviderFields(&tmp) @@ -939,13 +966,46 @@ func loadFromEnv(logger *log.Logger) *App { any = true } + modelForce := strings.TrimSpace(getenv("HEXAI_MODEL_FORCE")) + modelGeneric := strings.TrimSpace(getenv("HEXAI_MODEL")) + providerLower := strings.ToLower(strings.TrimSpace(out.Provider)) + forceUsed := false + genericUsed := false + pickModel := func(providerName, specific string) (string, bool) { + specific = strings.TrimSpace(specific) + nameLower := strings.ToLower(strings.TrimSpace(providerName)) + if modelForce != "" { + if providerLower == nameLower { + forceUsed = true + return modelForce, true + } + if providerLower == "" && !forceUsed { + forceUsed = true + return modelForce, true + } + } + if specific != "" { + return specific, true + } + if modelGeneric != "" { + if providerLower == nameLower { + return modelGeneric, true + } + if providerLower == "" && !genericUsed { + genericUsed = true + return modelGeneric, true + } + } + return "", false + } + // Provider-specific if s := getenv("HEXAI_OPENAI_BASE_URL"); s != "" { out.OpenAIBaseURL = s any = true } - if s := getenv("HEXAI_OPENAI_MODEL"); s != "" { - out.OpenAIModel = s + if model, ok := pickModel("openai", getenv("HEXAI_OPENAI_MODEL")); ok { + out.OpenAIModel = model any = true } if f, ok := parseFloatPtr("HEXAI_OPENAI_TEMPERATURE"); ok { @@ -957,8 +1017,8 @@ func loadFromEnv(logger *log.Logger) *App { out.OllamaBaseURL = s any = true } - if s := getenv("HEXAI_OLLAMA_MODEL"); s != "" { - out.OllamaModel = s + if model, ok := pickModel("ollama", getenv("HEXAI_OLLAMA_MODEL")); ok { + out.OllamaModel = model any = true } if f, ok := parseFloatPtr("HEXAI_OLLAMA_TEMPERATURE"); ok { @@ -970,8 +1030,8 @@ func loadFromEnv(logger *log.Logger) *App { out.CopilotBaseURL = s any = true } - if s := getenv("HEXAI_COPILOT_MODEL"); s != "" { - out.CopilotModel = s + if model, ok := pickModel("copilot", getenv("HEXAI_COPILOT_MODEL")); ok { + out.CopilotModel = model any = true } if f, ok := parseFloatPtr("HEXAI_COPILOT_TEMPERATURE"); ok { diff --git a/internal/hexaiaction/run.go b/internal/hexaiaction/run.go index 45eacc2..a48bf94 100644 --- a/internal/hexaiaction/run.go +++ b/internal/hexaiaction/run.go @@ -73,54 +73,85 @@ func executeAction(ctx context.Context, kind ActionKind, parts InputParts, cfg a case ActionSkip: return parts.Selection, nil case ActionRewrite: - instr, cleaned := ExtractInstruction(parts.Selection) - if strings.TrimSpace(instr) == "" { - fmt.Fprintln(stderr, logging.AnsiBase+"hexai-tmux-action: no inline instruction found; echoing input"+logging.AnsiReset) - return parts.Selection, nil - } - cctx, cancel := timeout10s(ctx) - defer cancel() - return runRewrite(cctx, cfg, client, instr, cleaned) + return handleRewriteAction(ctx, parts, cfg, client, stderr) case ActionDiagnostics: - cctx, cancel := timeout10s(ctx) - defer cancel() - return runDiagnostics(cctx, cfg, client, parts.Diagnostics, parts.Selection) + return handleDiagnosticsAction(ctx, parts, cfg, client) case ActionDocument: - cctx, cancel := timeout10s(ctx) - defer cancel() - return runDocument(cctx, cfg, client, parts.Selection) + return handleDocumentAction(ctx, parts, cfg, client) case ActionGoTest: - cctx, cancel := timeout8s(ctx) - defer cancel() - return runGoTest(cctx, cfg, client, parts.Selection) + return handleGoTestAction(ctx, parts, cfg, client) case ActionSimplify: - cctx, cancel := timeout10s(ctx) - defer cancel() - return runSimplify(cctx, cfg, client, parts.Selection) + return handleSimplifyAction(ctx, parts, cfg, client) case ActionCustom: - cctx, cancel := timeout10s(ctx) - defer cancel() - if selectedCustom != nil { - // Run configured custom action - out, err := runCustom(cctx, cfg, client, *selectedCustom, parts) - selectedCustom = nil // clear after use - return out, err - } - // No selected custom; treat as no-op - return parts.Selection, nil + return handleCustomAction(ctx, parts, cfg, client) case ActionCustomPrompt: - cctx, cancel := timeout10s(ctx) - defer cancel() - // Open editor for free-form instruction - prompt, err := editor.OpenTempAndEdit(nil) - if err != nil || strings.TrimSpace(prompt) == "" { - fmt.Fprintln(stderr, logging.AnsiBase+"hexai-tmux-action: custom prompt canceled or empty; echoing input"+logging.AnsiReset) - return parts.Selection, nil - } - return runRewrite(cctx, cfg, client, prompt, parts.Selection) + return handleCustomPromptAction(ctx, parts, cfg, client, stderr) default: return parts.Selection, nil } } +func handleRewriteAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (string, error) { + instr, cleaned := ExtractInstruction(parts.Selection) + if strings.TrimSpace(instr) == "" { + fmt.Fprintln(stderr, logging.AnsiBase+"hexai-tmux-action: no inline instruction found; echoing input"+logging.AnsiReset) + return parts.Selection, nil + } + return runWithTimeout(ctx, timeout10s, func(cctx context.Context) (string, error) { + return runRewrite(cctx, cfg, client, instr, cleaned) + }) +} + +func handleDiagnosticsAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer) (string, error) { + return runWithTimeout(ctx, timeout10s, func(cctx context.Context) (string, error) { + return runDiagnostics(cctx, cfg, client, parts.Diagnostics, parts.Selection) + }) +} + +func handleDocumentAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer) (string, error) { + return runWithTimeout(ctx, timeout10s, func(cctx context.Context) (string, error) { + return runDocument(cctx, cfg, client, parts.Selection) + }) +} + +func handleGoTestAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer) (string, error) { + return runWithTimeout(ctx, timeout8s, func(cctx context.Context) (string, error) { + return runGoTest(cctx, cfg, client, parts.Selection) + }) +} + +func handleSimplifyAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer) (string, error) { + return runWithTimeout(ctx, timeout10s, func(cctx context.Context) (string, error) { + return runSimplify(cctx, cfg, client, parts.Selection) + }) +} + +func handleCustomAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer) (string, error) { + if selectedCustom == nil { + return parts.Selection, nil + } + return runWithTimeout(ctx, timeout10s, func(cctx context.Context) (string, error) { + out, err := runCustom(cctx, cfg, client, *selectedCustom, parts) + selectedCustom = nil + return out, err + }) +} + +func handleCustomPromptAction(ctx context.Context, parts InputParts, cfg appconfig.App, client chatDoer, stderr io.Writer) (string, error) { + prompt, err := editor.OpenTempAndEdit(nil) + if err != nil || strings.TrimSpace(prompt) == "" { + fmt.Fprintln(stderr, logging.AnsiBase+"hexai-tmux-action: custom prompt canceled or empty; echoing input"+logging.AnsiReset) + return parts.Selection, nil + } + return runWithTimeout(ctx, timeout10s, func(cctx context.Context) (string, error) { + return runRewrite(cctx, cfg, client, prompt, parts.Selection) + }) +} + +func runWithTimeout(ctx context.Context, timeout func(context.Context) (context.Context, context.CancelFunc), fn func(context.Context) (string, error)) (string, error) { + innerCtx, cancel := timeout(ctx) + defer cancel() + return fn(innerCtx) +} + // client construction is shared via internal/llmutils diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 823dcaa..11e8938 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -3,7 +3,6 @@ package hexaicli import ( - "bufio" "context" "fmt" "io" @@ -78,8 +77,11 @@ func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, func readInput(stdin io.Reader, args []string) (string, error) { var stdinData string if fi, err := os.Stdin.Stat(); err == nil && (fi.Mode()&os.ModeCharDevice) == 0 { - b, _ := io.ReadAll(bufio.NewReader(stdin)) - stdinData = strings.TrimSpace(string(b)) + data, readErr := io.ReadAll(stdin) + if readErr != nil { + return "", fmt.Errorf("hexai: failed to read stdin: %w", readErr) + } + stdinData = strings.TrimSpace(string(data)) } argData := strings.TrimSpace(strings.Join(args, " ")) switch { diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index d192850..a4184f6 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -12,6 +12,10 @@ import ( "codeberg.org/snonux/hexai/internal/llm" ) +type failingReader struct{ err error } + +func (f failingReader) Read([]byte) (int, error) { return 0, f.err } + func TestReadInput_Combinations(t *testing.T) { // stdin + arg restore, f := setStdin(t, "from-stdin") @@ -41,6 +45,15 @@ func TestReadInput_Combinations(t *testing.T) { } } +func TestReadInput_PropagatesStdinError(t *testing.T) { + restore, _ := setStdin(t, "ignored") + defer restore() + bad := failingReader{err: io.ErrUnexpectedEOF} + if _, err := readInput(bad, nil); err == nil || !strings.Contains(err.Error(), "failed to read stdin") { + t.Fatalf("expected stdin read error, got %v", err) + } +} + func TestBuildMessages_Explain(t *testing.T) { msgs := buildMessages("please explain this") if len(msgs) != 2 || msgs[0].Role != "system" || !strings.Contains(strings.ToLower(msgs[0].Content), "explanation") { diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index ce607a7..374a771 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -46,7 +46,7 @@ func newOllama(baseURL, model string, defaultTemp *float64) Client { baseURL = "http://localhost:11434" } if strings.TrimSpace(model) == "" { - model = "qwen3-coder:30b-a3b-q4_K_M`" + model = "qwen3-coder:30b-a3b-q4_K_M" } return ollamaClient{ httpClient: &http.Client{Timeout: 30 * time.Second}, diff --git a/internal/lsp/chat_history_test.go b/internal/lsp/chat_history_test.go index b1cae80..70080f3 100644 --- a/internal/lsp/chat_history_test.go +++ b/internal/lsp/chat_history_test.go @@ -3,19 +3,20 @@ package lsp import "testing" func TestStripTrailingTrigger(t *testing.T) { - if got := stripTrailingTrigger("what?"); got != "what" { + s := newTestServer() + if got := s.stripTrailingTrigger("what?"); got != "what" { t.Fatalf("should remove trailing ?") } - if got := stripTrailingTrigger("what?>"); got != "what?" { + if got := s.stripTrailingTrigger("what?>"); got != "what?" { t.Fatalf("should drop trailing > when preceded by ?") } - if got := stripTrailingTrigger("ok!>"); got != "ok!" { + if got := s.stripTrailingTrigger("ok!>"); got != "ok!" { t.Fatalf("should drop > after !") } - if got := stripTrailingTrigger("note:>"); got != "note:" { + if got := s.stripTrailingTrigger("note:>"); got != "note:" { t.Fatalf("should drop > after :") } - if got := stripTrailingTrigger("go;>"); got != "go;" { + if got := s.stripTrailingTrigger("go;>"); got != "go;" { t.Fatalf("should drop > after ;") } } diff --git a/internal/lsp/chat_no_double_answer_test.go b/internal/lsp/chat_no_double_answer_test.go index 8821cd0..04196f8 100644 --- a/internal/lsp/chat_no_double_answer_test.go +++ b/internal/lsp/chat_no_double_answer_test.go @@ -10,6 +10,7 @@ import ( func TestDetectAndHandleChat_NoDoubleAnswer(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) s.llmClient = fakeLLM{resp: "IGNORED"} uri := "file:///x.go" // Question line with trigger, followed by an existing answer line starting with '>' diff --git a/internal/lsp/chat_trigger_suppression_test.go b/internal/lsp/chat_trigger_suppression_test.go index 55a5245..8d016d1 100644 --- a/internal/lsp/chat_trigger_suppression_test.go +++ b/internal/lsp/chat_trigger_suppression_test.go @@ -5,6 +5,7 @@ import "testing" // Ensure completion is suppressed when a chat trigger is at EOL (?>,!>,:>,;>) func TestCompletionSuppressedOnChatTriggerEOL(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + initServerDefaults(s) s.llmClient = &countingLLM{} tests := []string{"What now?>", "Explain!>", "Refactor:>", "note ;>"} for i, line := range tests { diff --git a/internal/lsp/codeaction_custom_test.go b/internal/lsp/codeaction_custom_test.go index 7baf993..1ea4c3c 100644 --- a/internal/lsp/codeaction_custom_test.go +++ b/internal/lsp/codeaction_custom_test.go @@ -27,7 +27,18 @@ func capResp(t *testing.T, buf *bytes.Buffer) Response { func TestHandleCodeAction_ListsCustomActions(t *testing.T) { var out bytes.Buffer - s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + s := &Server{ + logger: log.New(io.Discard, "", 0), + docs: make(map[string]*document), + out: &out, + inlineOpen: ">", + inlineClose: ">", + inlineOpenChar: '>', + inlineCloseChar: '>', + chatSuffix: ">", + chatSuffixChar: '>', + chatPrefixes: []string{"?", "!", ":", ";"}, + } s.llmClient = fakeLLM{resp: "IGN"} // Inject two custom actions s.customActions = []CustomAction{ diff --git a/internal/lsp/codeaction_gotest_int_test.go b/internal/lsp/codeaction_gotest_int_test.go index 04a73e0..384f3d5 100644 --- a/internal/lsp/codeaction_gotest_int_test.go +++ b/internal/lsp/codeaction_gotest_int_test.go @@ -13,6 +13,7 @@ func TestResolveGoTest_CreatesTestFile(t *testing.T) { t.Fatalf("write: %v", err) } s := &Server{} // minimal server with nil llmClient to trigger stub + initServerDefaults(s) uri := "file://" + src we, jumpURI, jumpRange, ok := s.resolveGoTest(uri, Position{Line: 2}) if !ok || jumpURI == "" || jumpRange.Start.Line < 0 { diff --git a/internal/lsp/completion_codex_path_test.go b/internal/lsp/completion_codex_path_test.go index bd3b3f4..6c0a60f 100644 --- a/internal/lsp/completion_codex_path_test.go +++ b/internal/lsp/completion_codex_path_test.go @@ -40,6 +40,7 @@ func (f *fakeCodeLLM) DefaultModel() string { return "m" } func TestTryLLMCompletion_PrefersCodeCompleterOverChat(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{"."}, compCache: make(map[string]string)} + initServerDefaults(s) fake := &fakeCodeLLM{result: "DoThing()"} s.llmClient = fake line := "obj." @@ -58,6 +59,7 @@ func TestTryLLMCompletion_PrefersCodeCompleterOverChat(t *testing.T) { func TestTryLLMCompletion_FallsBackToChatOnCodeCompleterError(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{"."}, compCache: make(map[string]string)} + initServerDefaults(s) fake := &fakeCodeLLM{result: "DoThing()", codeErr: errors.New("boom")} s.llmClient = fake line := "obj." diff --git a/internal/lsp/completion_prefix_strip_test.go b/internal/lsp/completion_prefix_strip_test.go index 6af87a0..acc7921 100644 --- a/internal/lsp/completion_prefix_strip_test.go +++ b/internal/lsp/completion_prefix_strip_test.go @@ -42,6 +42,7 @@ func TestStripDuplicateAssignmentPrefix_AssignAndWalrus(t *testing.T) { func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + initServerDefaults(s) s.llmClient = fakeLLM{resp: tut.MultilineFunctionSuggestion()} line := "func fib(i int) " // cursor after space p := CompletionParams{Position: Position{Line: 0, Character: len(line)}, TextDocument: TextDocumentIdentifier{URI: "file://x.go"}} @@ -58,6 +59,7 @@ func TestTryLLMCompletion_ManualInvokeAfterWhitespace_Allows(t *testing.T) { func TestTryLLMCompletion_InlinePromptAlwaysTriggers(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + initServerDefaults(s) s.llmClient = fakeLLM{resp: "replacement"} line := "prefix >do something> suffix" // No trigger char immediately before cursor; place cursor at end @@ -69,7 +71,17 @@ func TestTryLLMCompletion_InlinePromptAlwaysTriggers(t *testing.T) { } func TestTryLLMCompletion_DoubleOpenEmpty_DoesNotAutoTrigger(t *testing.T) { - s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + s := &Server{ + maxTokens: 32, + triggerChars: []string{".", ":", "/", "_"}, + compCache: make(map[string]string), + inlineOpen: ">", + inlineClose: ">", + inlineOpenChar: '>', + inlineCloseChar: '>', + } + initServerDefaults(s) + initServerDefaults(s) fake := &countingLLM{} s.llmClient = fake line := ">> " // empty content after double-open should not force-trigger @@ -87,22 +99,30 @@ func TestTryLLMCompletion_DoubleOpenEmpty_DoesNotAutoTrigger(t *testing.T) { } func TestHasDoubleSemicolonTrigger_Variants(t *testing.T) { - if hasDoubleOpenTrigger(">>") { + if hasDoubleOpenTrigger(">>", '>', '>') { t.Fatalf("bare double-open should not trigger") } - if hasDoubleOpenTrigger(">> ") { + if hasDoubleOpenTrigger(">> ", '>', '>') { t.Fatalf("double-open followed by space should not trigger") } - if hasDoubleOpenTrigger(">>>") { + if hasDoubleOpenTrigger(">>>", '>', '>') { t.Fatalf("';;;' should not trigger (no content)") } - if !hasDoubleOpenTrigger(">>x>") { + if !hasDoubleOpenTrigger(">>x>", '>', '>') { t.Fatalf("expected trigger for ';;x;' pattern") } } func TestBareDoubleOpenPreventsAutoTriggerEvenWithOtherTriggers(t *testing.T) { - s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + s := &Server{ + maxTokens: 32, + triggerChars: []string{".", ":", "/", "_"}, + compCache: make(map[string]string), + inlineOpen: ">", + inlineClose: ">", + inlineOpenChar: '>', + inlineCloseChar: '>', + } fake := &countingLLM{} s.llmClient = fake // Place a '.' earlier but also include bare double-open at end; should not auto-trigger @@ -122,6 +142,7 @@ func TestBareDoubleOpenPreventsAutoTriggerEvenWithOtherTriggers(t *testing.T) { func TestBareDoubleOpenOnNextLine_PreventsAutoTrigger(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + initServerDefaults(s) fake := &countingLLM{} s.llmClient = fake current := "expression := flag.String(\"expression\", \"\", \"Expression to evaluate\")" @@ -141,6 +162,7 @@ func TestBareDoubleOpenOnNextLine_PreventsAutoTrigger(t *testing.T) { func TestBareDoubleOpenPreventsManualInvoke(t *testing.T) { s := &Server{maxTokens: 32, triggerChars: []string{".", ":", "/", "_"}, compCache: make(map[string]string)} + initServerDefaults(s) fake := &countingLLM{} s.llmClient = fake line := ">>" diff --git a/internal/lsp/coverage_add_test.go b/internal/lsp/coverage_add_test.go index 7701a5e..b3b7322 100644 --- a/internal/lsp/coverage_add_test.go +++ b/internal/lsp/coverage_add_test.go @@ -56,13 +56,14 @@ func TestFindGoFunctionAtLine_NoBody(t *testing.T) { } func TestLineHasInlinePrompt(t *testing.T) { - if !lineHasInlinePrompt(">do>") { + if !lineHasInlinePrompt(">do>", '>', '>') { t.Fatalf("expected inline prompt") } } func TestDiagnosticsInRange_Overlap(t *testing.T) { s := &Server{} + initServerDefaults(s) ctx := CodeActionContext{Diagnostics: []Diagnostic{{ Range: Range{Start: Position{Line: 10, Character: 0}, End: Position{Line: 12, Character: 0}}, Message: "x", @@ -88,15 +89,12 @@ func TestIndentHelpersAndPromptRemoval(t *testing.T) { t.Fatalf("applyIndent: %q", out) } // double-open trigger removes whole line - edits := promptRemovalEditsForLine(">>ask>", 3) + edits := promptRemovalEditsForLine(">>ask>", 3, '>', '>') if len(edits) != 1 || edits[0].Range.Start.Line != 3 { t.Fatalf("unexpected edits: %#v", edits) } - // temporarily switch to semicolon tags and test collection - oldOpen, oldClose := inlineOpenChar, inlineCloseChar - inlineOpenChar, inlineCloseChar = ';', ';' - t.Cleanup(func() { inlineOpenChar, inlineCloseChar = oldOpen, oldClose }) - edits2 := collectSemicolonMarkers("pre;do;post", 1) + // semicolon tags collect correctly when provided explicitly + edits2 := collectSemicolonMarkers("pre;do;post", 1, ';', ';') if len(edits2) != 1 { t.Fatalf("expected one semicolon edit, got %#v", edits2) } diff --git a/internal/lsp/diagnostics_action_test.go b/internal/lsp/diagnostics_action_test.go index a607b86..761062d 100644 --- a/internal/lsp/diagnostics_action_test.go +++ b/internal/lsp/diagnostics_action_test.go @@ -9,6 +9,7 @@ import ( func TestHandleCodeAction_ListsDiagnosticsActionWhenOverlap(t *testing.T) { s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document)} + initServerDefaults(s) s.llmClient = fakeLLM{resp: "fixed"} uri := "file:///x.go" s.setDocument(uri, "package p\nvar a=1\n") diff --git a/internal/lsp/document_handlers_test.go b/internal/lsp/document_handlers_test.go index eae5020..1fdd0da 100644 --- a/internal/lsp/document_handlers_test.go +++ b/internal/lsp/document_handlers_test.go @@ -34,6 +34,7 @@ func TestDidOpenChangeClose_UpdateDocs(t *testing.T) { func TestClientShowDocument_WritesRequest(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) uri := "file:///x.go" sel := Range{Start: Position{Line: 1}, End: Position{Line: 2}} out.Reset() @@ -47,6 +48,7 @@ func TestClientShowDocument_WritesRequest(t *testing.T) { func TestHandleExecuteCommand_ShowDocument(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) uri := "file:///x.go" r := Range{Start: Position{Line: 0}, End: Position{Line: 0}} args := []any{uri, r} @@ -61,6 +63,7 @@ func TestHandleExecuteCommand_ShowDocument(t *testing.T) { func TestDeferShowDocument_WritesLater(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) uri := "file:///x.go" out.Reset() s.deferShowDocument(uri, Range{Start: Position{Line: 0}, End: Position{Line: 0}}) diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go index 652d867..cbea62a 100644 --- a/internal/lsp/document_test.go +++ b/internal/lsp/document_test.go @@ -10,12 +10,15 @@ import ( func newTestServer() *Server { s := &Server{ - logger: log.New(io.Discard, "", 0), - docs: make(map[string]*document), - inlineOpen: ">", - inlineClose: ">", - chatSuffix: ">", - chatPrefixes: []string{"?", "!", ":", ";"}, + logger: log.New(io.Discard, "", 0), + docs: make(map[string]*document), + inlineOpen: ">", + inlineClose: ">", + chatSuffix: ">", + chatPrefixes: []string{"?", "!", ":", ";"}, + inlineOpenChar: '>', + inlineCloseChar: '>', + chatSuffixChar: '>', } // Default prompt templates (mirror app defaults) s.promptCompSysParams = "You are a code completion engine for function signatures. Return only the parameter list contents (without parentheses), no braces, no prose. Prefer idiomatic names and types." @@ -34,14 +37,33 @@ func newTestServer() *Server { s.promptDocumentUser = "Add documentation comments to this code:\n{{selection}}" s.promptGoTestSystem = "You are a precise Go unit test generator. Given a Go function, write one or more Test* functions using the testing package. Do NOT include package or imports, only the test function(s). Prefer table-driven tests. Keep it minimal and idiomatic." s.promptGoTestUser = "Function under test:\n{{function}}" - // Keep package-level helpers in sync for tests using free functions - inlineOpenChar = '>' - inlineCloseChar = '>' - chatSuffixChar = '>' - chatPrefixSingles = []string{"?", "!", ":", ";"} return s } +func initServerDefaults(s *Server) { + if s.inlineOpen == "" { + s.inlineOpen = ">" + } + if s.inlineClose == "" { + s.inlineClose = ">" + } + if s.inlineOpenChar == 0 && s.inlineOpen != "" { + s.inlineOpenChar = s.inlineOpen[0] + } + if s.inlineCloseChar == 0 && s.inlineClose != "" { + s.inlineCloseChar = s.inlineClose[0] + } + if s.chatSuffix == "" { + s.chatSuffix = ">" + } + if s.chatSuffixChar == 0 && s.chatSuffix != "" { + s.chatSuffixChar = s.chatSuffix[0] + } + if len(s.chatPrefixes) == 0 { + s.chatPrefixes = []string{"?", "!", ":", ";"} + } +} + func TestSplitLines(t *testing.T) { in := "a\r\nb\nc" got := splitLines(in) diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index e85065b..9452551 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -25,10 +25,10 @@ func (s *Server) handle(req Request) { // Preference order on each line: strict ;text; marker (no inner spaces), then // a line comment (//, #, --). Returns the instruction string and the selection // text cleaned of the matched instruction marker or comment. -func instructionFromSelection(sel string) (string, string) { +func (s *Server) instructionFromSelection(sel string) (string, string) { lines := splitLines(sel) for idx, line := range lines { - if instr, cleaned, ok := findFirstInstructionInLine(line); ok && strings.TrimSpace(instr) != "" { + if instr, cleaned, ok := s.findFirstInstructionInLine(line); ok && strings.TrimSpace(instr) != "" { lines[idx] = cleaned return instr, strings.Join(lines, "\n") } @@ -45,13 +45,13 @@ func instructionFromSelection(sel string) (string, string) { // - // text // - # text // - -- text -func findFirstInstructionInLine(line string) (instr string, cleaned string, ok bool) { +func (s *Server) findFirstInstructionInLine(line string) (instr string, cleaned string, ok bool) { type cand struct { start, end int text string } cands := []cand{} - if t, l, r, ok := findStrictInlineTag(line); ok { + if t, l, r, ok := findStrictInlineTag(line, s.inlineOpenChar, s.inlineCloseChar); ok { cands = append(cands, cand{start: l, end: r, text: t}) } if i := strings.Index(line, "/*"); i >= 0 { @@ -300,7 +300,7 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { } // If configured and the line contains a bare double-open marker (e.g., '>>' with no '>>text>'), // do not treat as a trigger source. - if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current) { + if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current, s.inlineOpenChar, s.inlineCloseChar) { return false } // TriggerKind 1 = Invoked (manual). Always allow manual invoke. @@ -328,7 +328,7 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { return false } // Bare double-open should not trigger via fallback char either (only when configured) - if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current) { + if s.inlineOpen != "" && strings.Contains(current, s.inlineOpen+s.inlineOpen) && !hasDoubleOpenTrigger(current, s.inlineOpenChar, s.inlineCloseChar) { return false } ch := string(current[idx-1]) diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index e5e61ef..8764525 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -122,7 +122,7 @@ func (s *Server) buildSimplifyCodeAction(p CodeActionParams, sel string) *CodeAc } func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction { - if instr, cleaned := instructionFromSelection(sel); strings.TrimSpace(instr) != "" { + if instr, cleaned := s.instructionFromSelection(sel); strings.TrimSpace(instr) != "" { payload := struct { Type string `json:"type"` URI string `json:"uri"` diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index 6142a30..df541cc 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -13,6 +13,21 @@ import ( "codeberg.org/snonux/hexai/internal/stats" ) +type completionPlan struct { + params CompletionParams + above string + current string + below string + funcCtx string + docStr string + hasExtra bool + extraText string + inlinePrompt bool + inParams bool + manualInvoke bool + cacheKey string +} + func (s *Server) handleCompletion(req Request) { var p CompletionParams var docStr string @@ -75,44 +90,59 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) defer cancel() - inlinePrompt := lineHasInlinePrompt(current) - if !inlinePrompt && !s.isTriggerEvent(p, current) { - logging.Logf("lsp ", "%scompletion skip=no-trigger line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) - return []CompletionItem{}, true + plan, items, handled := s.prepareCompletionPlan(p, above, current, below, funcCtx, docStr, hasExtra, extraText) + if handled { + return items, true } - if s.shouldSuppressForChatTriggerEOL(current, p) { - return []CompletionItem{}, true + + if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok { + return items, true } - inParams := inParamList(current, p.Position.Character) - manualInvoke := parseManualInvoke(p.Context) + return s.executeChatCompletion(ctx, plan) +} - // Cache fast-path - key := s.completionCacheKey(p, above, current, below, funcCtx, inParams, hasExtra, extraText) - if cleaned, ok := s.completionCacheGet(key); ok && strings.TrimSpace(cleaned) != "" { +func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) { + plan := completionPlan{ + params: p, + above: above, + current: current, + below: below, + funcCtx: funcCtx, + docStr: docStr, + hasExtra: hasExtra, + extraText: extraText, + } + plan.inlinePrompt = lineHasInlinePrompt(current, s.inlineOpenChar, s.inlineCloseChar) + if !plan.inlinePrompt && !s.isTriggerEvent(p, current) { + logging.Logf("lsp ", "%scompletion skip=no-trigger line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) + return plan, []CompletionItem{}, true + } + if s.shouldSuppressForChatTriggerEOL(current, p) { + return plan, []CompletionItem{}, true + } + plan.inParams = inParamList(current, p.Position.Character) + plan.manualInvoke = parseManualInvoke(p.Context) + plan.cacheKey = s.completionCacheKey(p, above, current, below, funcCtx, plan.inParams, hasExtra, extraText) + if cleaned, ok := s.completionCacheGet(plan.cacheKey); ok && strings.TrimSpace(cleaned) != "" { logging.Logf("lsp ", "completion cache hit uri=%s line=%d char=%d preview=%s%s%s", p.TextDocument.URI, p.Position.Line, p.Position.Character, logging.AnsiGreen, logging.PreviewForLog(cleaned), logging.AnsiBase) - return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true + return plan, s.makeCompletionItems(cleaned, plan.inParams, current, p, docStr), true } - if isBareDoubleOpen(current) || isBareDoubleOpen(below) { + if isBareDoubleOpen(current, s.inlineOpenChar, s.inlineCloseChar) || isBareDoubleOpen(below, s.inlineOpenChar, s.inlineCloseChar) { logging.Logf("lsp ", "%scompletion skip=empty-double-semicolon line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) - return []CompletionItem{}, true + return plan, []CompletionItem{}, true } - - if !inParams && !s.prefixHeuristicAllows(inlinePrompt, current, p, manualInvoke) { + if !plan.inParams && !s.prefixHeuristicAllows(plan.inlinePrompt, current, p, plan.manualInvoke) { logging.Logf("lsp ", "%scompletion skip=short-prefix line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) - return []CompletionItem{}, true - } - - // Provider-native path - if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, inParams); ok { - return items, true + return plan, []CompletionItem{}, true } + return plan, nil, false +} - // Chat path - messages := s.buildCompletionMessages(inlinePrompt, hasExtra, extraText, inParams, p, above, current, below, funcCtx) - // Counters and options +func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan) ([]CompletionItem, bool) { + messages := s.buildCompletionMessages(plan.inlinePrompt, plan.hasExtra, plan.extraText, plan.inParams, plan.params, plan.above, plan.current, plan.below, plan.funcCtx) sentSize := 0 for _, m := range messages { sentSize += len(m.Content) @@ -122,13 +152,14 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun if s.codingTemperature != nil { opts = append(opts, llm.WithTemperature(*s.codingTemperature)) } - // Debounce and throttle before making the LLM call s.waitForDebounce(ctx) if !s.waitForThrottle(ctx) { return nil, false } + if s.llmClient == nil { + return nil, false + } logging.Logf("lsp ", "completion llm=requesting model=%s", s.llmClient.DefaultModel()) - text, err := s.llmClient.Chat(ctx, messages, opts...) if err != nil { logging.Logf("lsp ", "llm completion error: %v", err) @@ -137,13 +168,14 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun } s.incRecvCounters(len(text)) s.logLLMStats() - - cleaned := s.postProcessCompletion(strings.TrimSpace(text), current[:p.Position.Character], current) + trimmed := strings.TrimSpace(text) + cleaned := s.postProcessCompletion(trimmed, plan.current[:plan.params.Position.Character], plan.current) if cleaned == "" { return nil, false } - s.completionCachePut(key, cleaned) - return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true + s.completionCachePut(plan.cacheKey, cleaned) + items := s.makeCompletionItems(cleaned, plan.inParams, plan.current, plan.params, plan.docStr) + return items, true } // parseManualInvoke inspects the LSP completion context and reports whether the user manually invoked completion. @@ -269,7 +301,7 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, if cleaned != "" { cleaned = stripDuplicateGeneralPrefix(current[:p.Position.Character], cleaned) } - if cleaned != "" && hasDoubleOpenTrigger(current) { + if cleaned != "" && hasDoubleOpenTrigger(current, s.inlineOpenChar, s.inlineCloseChar) { indent := leadingIndent(current) if indent != "" { cleaned = applyIndent(indent, cleaned) @@ -398,7 +430,7 @@ func (s *Server) postProcessCompletion(text string, leftOfCursor string, current if cleaned != "" { cleaned = stripDuplicateGeneralPrefix(leftOfCursor, cleaned) } - if cleaned != "" && hasDoubleOpenTrigger(currentLine) { + if cleaned != "" && hasDoubleOpenTrigger(currentLine, s.inlineOpenChar, s.inlineCloseChar) { if indent := leadingIndent(currentLine); indent != "" { cleaned = applyIndent(indent, cleaned) } diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go index 3897885..ca0cb8d 100644 --- a/internal/lsp/handlers_document.go +++ b/internal/lsp/handlers_document.go @@ -11,13 +11,6 @@ import ( "codeberg.org/snonux/hexai/internal/logging" ) -// Package-level chat trigger vars for helpers without Server receiver. -// NewServer assigns these from configuration on startup. -var ( - chatSuffixChar byte = '>' - chatPrefixSingles = []string{"?", "!", ":", ";"} -) - func (s *Server) handleDidOpen(req Request) { var p DidOpenTextDocumentParams if err := json.Unmarshal(req.Params, &p); err == nil { @@ -236,7 +229,7 @@ func (s *Server) buildChatHistory(uri string, lineIdx int, currentPrompt string) break } q := strings.TrimSpace(d.lines[i]) - q = stripTrailingTrigger(q) + q = s.stripTrailingTrigger(q) pairs = append([]pair{{q: q, a: strings.Join(replyLines, "\n")}}, pairs...) i-- } @@ -254,25 +247,23 @@ func (s *Server) buildChatHistory(uri string, lineIdx int, currentPrompt string) } // stripTrailingTrigger removes the trailing chat trigger punctuation from a line if present. -func stripTrailingTrigger(sx string) string { - s := strings.TrimRight(sx, " \t") - if len(s) == 0 { +func (s *Server) stripTrailingTrigger(sx string) string { + trim := strings.TrimRight(sx, " \t") + if len(trim) == 0 { return sx } - // Configurable suffix removal when preceded by configured prefixes - if len(s) >= 2 && s[len(s)-1] == chatSuffixChar { - prev := string(s[len(s)-2]) - for _, pf := range chatPrefixSingles { + if len(trim) >= 2 && s.chatSuffixChar != 0 && trim[len(trim)-1] == s.chatSuffixChar { + prev := string(trim[len(trim)-2]) + for _, pf := range s.chatPrefixes { if prev == pf { - return strings.TrimRight(s[:len(s)-1], " \t") + return strings.TrimRight(trim[:len(trim)-1], " \t") } } } - // Legacy: remove one trailing punctuation (?, !, :) to build history nicely - last := s[len(s)-1] + last := trim[len(trim)-1] switch last { case '?', '!', ':': - return strings.TrimRight(s[:len(s)-1], " \t") + return strings.TrimRight(trim[:len(trim)-1], " \t") default: return sx } diff --git a/internal/lsp/handlers_end_to_end_test.go b/internal/lsp/handlers_end_to_end_test.go index 32cb488..5489b97 100644 --- a/internal/lsp/handlers_end_to_end_test.go +++ b/internal/lsp/handlers_end_to_end_test.go @@ -73,6 +73,7 @@ func TestHandleCodeAction_ListsHexaiActions(t *testing.T) { // Prepare server var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) s.chatSuffix = ">" s.chatPrefixes = []string{"?", "!", ":", ";"} s.llmClient = fakeLLM{resp: "// doc\nfunc add(a,b int) int { return a+b }"} @@ -121,6 +122,7 @@ func TestHandleCodeAction_ListsHexaiActions(t *testing.T) { func TestHandleCodeActionResolve_Document(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) s.llmClient = fakeLLM{resp: "// doc\nfunc f(){}"} uri := "file:///x.go" s.setDocument(uri, "package p\nfunc f(){}\n") @@ -152,6 +154,7 @@ func TestHandleCodeActionResolve_Document(t *testing.T) { func TestHandleCodeAction_NoLLMOrEmptySelection_ReturnsEmpty(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) uri := "file:///x.go" s.setDocument(uri, "package p\n\n") // Empty selection @@ -187,6 +190,7 @@ func mustJSON(v any) json.RawMessage { b, _ := json.Marshal(v); return b } func TestHandle_UnknownMethod_ReturnsError(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out, handlers: map[string]func(Request){}} + initServerDefaults(s) req := Request{JSONRPC: "2.0", ID: json.RawMessage("9"), Method: "no/such"} out.Reset() s.handle(req) @@ -253,6 +257,7 @@ func TestDetectAndHandleChat_InsertsReply(t *testing.T) { func TestHandleCodeActionResolve_Diagnostics(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) s.llmClient = fakeLLM{resp: "fixed"} uri := "file:///x.go" s.setDocument(uri, "package p\nvar x = 1\n") diff --git a/internal/lsp/handlers_helpers_test.go b/internal/lsp/handlers_helpers_test.go index 0120cc3..2bd677e 100644 --- a/internal/lsp/handlers_helpers_test.go +++ b/internal/lsp/handlers_helpers_test.go @@ -17,7 +17,7 @@ func TestHasDoubleSemicolonTrigger(t *testing.T) { {">>x > space before close", false}, } for _, tc := range cases { - got := hasDoubleOpenTrigger(tc.line) + got := hasDoubleOpenTrigger(tc.line, '>', '>') if got != tc.want { t.Fatalf("hasDoubleOpenTrigger(%q)=%v want %v", tc.line, got, tc.want) } @@ -26,7 +26,7 @@ func TestHasDoubleSemicolonTrigger(t *testing.T) { func TestCollectSemicolonMarkers(t *testing.T) { line := "keep >ok> this and >another> that" - edits := collectSemicolonMarkers(line, 7) + edits := collectSemicolonMarkers(line, 7, '>', '>') if len(edits) != 2 { t.Fatalf("expected 2 edits, got %d", len(edits)) } @@ -42,7 +42,7 @@ func TestCollectSemicolonMarkers(t *testing.T) { func TestPromptRemovalEditsForLine_WholeLine(t *testing.T) { line := ">>todo> remove this whole line" - edits := promptRemovalEditsForLine(line, 3) + edits := promptRemovalEditsForLine(line, 3, '>', '>') if len(edits) != 1 { t.Fatalf("expected 1 whole-line edit, got %d", len(edits)) } diff --git a/internal/lsp/handlers_test.go b/internal/lsp/handlers_test.go index a171143..6803d1e 100644 --- a/internal/lsp/handlers_test.go +++ b/internal/lsp/handlers_test.go @@ -5,7 +5,8 @@ import "testing" func TestFindFirstInstructionInLine_NoMarker(t *testing.T) { line := "fmt.Println(\"hello\")" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if ok { t.Fatalf("expected ok=false; got ok=true with instr=%q cleaned=%q", instr, cleaned) } @@ -16,7 +17,8 @@ func TestFindFirstInstructionInLine_NoMarker(t *testing.T) { func TestFindFirstInstructionInLine_StrictInline_Basic(t *testing.T) { line := "prefix >rename var> suffix" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -31,7 +33,8 @@ func TestFindFirstInstructionInLine_StrictInline_Basic(t *testing.T) { func TestFindFirstInstructionInLine_StrictInline_TrailingSpacesTrimmed(t *testing.T) { line := "code>fix> \t\t" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -50,7 +53,8 @@ func TestFindFirstInstructionInLine_Inline_InvalidPatterns(t *testing.T) { "prefix > > suffix", // empty inner ⇒ invalid } for _, line := range cases { - if instr, _, ok := findFirstInstructionInLine(line); ok && instr != "" { + s := newTestServer() + if instr, _, ok := s.findFirstInstructionInLine(line); ok && instr != "" { t.Fatalf("%q: expected no inline instruction; got instr=%q", line, instr) } } @@ -58,7 +62,8 @@ func TestFindFirstInstructionInLine_Inline_InvalidPatterns(t *testing.T) { func TestFindFirstInstructionInLine_CBlockComment(t *testing.T) { line := "foo /* update this part */ bar" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -72,7 +77,8 @@ func TestFindFirstInstructionInLine_CBlockComment(t *testing.T) { func TestFindFirstInstructionInLine_HTMLComment(t *testing.T) { line := "foo <!-- do x --> bar" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -86,7 +92,8 @@ func TestFindFirstInstructionInLine_HTMLComment(t *testing.T) { func TestFindFirstInstructionInLine_SlashSlash(t *testing.T) { line := "val // do this change" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -100,7 +107,8 @@ func TestFindFirstInstructionInLine_SlashSlash(t *testing.T) { func TestFindFirstInstructionInLine_Hash(t *testing.T) { line := "val # do this" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -114,7 +122,8 @@ func TestFindFirstInstructionInLine_Hash(t *testing.T) { func TestFindFirstInstructionInLine_DoubleDash(t *testing.T) { line := "SQL -- fix query" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -128,7 +137,8 @@ func TestFindFirstInstructionInLine_DoubleDash(t *testing.T) { func TestFindFirstInstructionInLine_EarliestWins_CommentOverInline(t *testing.T) { line := "aa // comment >not this> trailing" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -142,7 +152,8 @@ func TestFindFirstInstructionInLine_EarliestWins_CommentOverInline(t *testing.T) func TestFindFirstInstructionInLine_EarliestWins_InlineOverComment(t *testing.T) { line := "aa >short> // comment" - instr, cleaned, ok := findFirstInstructionInLine(line) + s := newTestServer() + instr, cleaned, ok := s.findFirstInstructionInLine(line) if !ok { t.Fatalf("expected ok=true") } @@ -157,19 +168,19 @@ func TestFindFirstInstructionInLine_EarliestWins_InlineOverComment(t *testing.T) func TestFindStrictInlineTag_Various(t *testing.T) { // basic - if text, l, r, ok := findStrictInlineTag("pre>do it>post"); !ok || text != "do it" || l != 3 || r != 10 { + if text, l, r, ok := findStrictInlineTag("pre>do it>post", '>', '>'); !ok || text != "do it" || l != 3 || r != 10 { t.Fatalf("unexpected: ok=%v text=%q l=%d r=%d", ok, text, l, r) } // at start - if text, l, r, ok := findStrictInlineTag(">x>"); !ok || text != "x" || l != 0 || r != 3 { + if text, l, r, ok := findStrictInlineTag(">x>", '>', '>'); !ok || text != "x" || l != 0 || r != 3 { t.Fatalf("unexpected at start: ok=%v text=%q l=%d r=%d", ok, text, l, r) } // double opening '>>' should still allow a tag starting at the second '>' - if text, _, _, ok := findStrictInlineTag("prefix >>bad> suffix"); !ok || text != "bad" { + if text, _, _, ok := findStrictInlineTag("prefix >>bad> suffix", '>', '>'); !ok || text != "bad" { t.Fatalf("unexpected double-open handling: ok=%v text=%q", ok, text) } // inner spaces directly after first '>' or before last '>' invalidate the tag - if _, _, _, ok := findStrictInlineTag("a> inner >b"); ok { + if _, _, _, ok := findStrictInlineTag("a> inner >b", '>', '>'); ok { t.Fatalf("expected invalid strict tag due to spaces at boundaries") } } diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index c0ec7c3..56d752d 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -13,13 +13,6 @@ import ( tmx "codeberg.org/snonux/hexai/internal/tmux" ) -// Configurable inline trigger characters (default to '>') used by free helpers below. -// NewServer assigns these based on ServerOptions. -var ( - inlineOpenChar byte = '>' - inlineCloseChar byte = '>' -) - // llmRequestOpts builds request options from server settings. func (s *Server) llmRequestOpts() []llm.RequestOption { opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} @@ -183,11 +176,12 @@ func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ... } // Inline prompt utilities -func lineHasInlinePrompt(line string) bool { - if _, _, _, ok := findStrictInlineTag(line); ok { + +func lineHasInlinePrompt(line string, open, close byte) bool { + if _, _, _, ok := findStrictInlineTag(line, open, close); ok { return true } - return hasDoubleOpenTrigger(line) + return hasDoubleOpenTrigger(line, open, close) } func leadingIndent(line string) string { @@ -227,22 +221,22 @@ func applyIndent(indent, suggestion string) string { // findStrictInlineTag finds >text> (configurable), with no space after the first // opening marker and no space immediately before the closing marker. Returns the // text between markers, the start index, the end index just after closing, and ok. -func findStrictInlineTag(line string) (string, int, int, bool) { +func findStrictInlineTag(line string, open, close byte) (string, int, int, bool) { pos := 0 for pos < len(line) { // find opening marker - j := strings.IndexByte(line[pos:], inlineOpenChar) + j := strings.IndexByte(line[pos:], open) if j < 0 { return "", 0, 0, false } j += pos // ensure single open (not double) and non-space after - if j+1 >= len(line) || line[j+1] == inlineOpenChar || line[j+1] == ' ' { + if j+1 >= len(line) || line[j+1] == open || line[j+1] == ' ' { pos = j + 1 continue } // find closing marker - k := strings.IndexByte(line[j+1:], inlineCloseChar) + k := strings.IndexByte(line[j+1:], close) if k < 0 { return "", 0, 0, false } @@ -265,14 +259,14 @@ func findStrictInlineTag(line string) (string, int, int, bool) { // isBareDoubleSemicolon reports whether the line contains a standalone // double-semicolon marker with no inline content (";;" possibly with only // whitespace after it). It explicitly excludes the valid form ";;text;". -func isBareDoubleOpen(line string) bool { +func isBareDoubleOpen(line string, open, close byte) bool { t := strings.TrimSpace(line) // check for double-open pattern - dbl := string([]byte{inlineOpenChar, inlineOpenChar}) + dbl := string([]byte{open, open}) if !strings.Contains(t, dbl) { return false } - if hasDoubleOpenTrigger(t) { + if hasDoubleOpenTrigger(t, open, close) { return false } if strings.HasPrefix(t, dbl) { @@ -434,23 +428,23 @@ func (s *Server) collectPromptRemovalEdits(uri string) []TextEdit { } var edits []TextEdit for i, line := range d.lines { - edits = append(edits, promptRemovalEditsForLine(line, i)...) + edits = append(edits, promptRemovalEditsForLine(line, i, s.inlineOpenChar, s.inlineCloseChar)...) } return edits } -func promptRemovalEditsForLine(line string, lineNum int) []TextEdit { - if hasDoubleOpenTrigger(line) { +func promptRemovalEditsForLine(line string, lineNum int, open, close byte) []TextEdit { + if hasDoubleOpenTrigger(line, open, close) { return []TextEdit{{Range: Range{Start: Position{Line: lineNum, Character: 0}, End: Position{Line: lineNum, Character: len(line)}}, NewText: ""}} } - return collectSemicolonMarkers(line, lineNum) + return collectSemicolonMarkers(line, lineNum, open, close) } -func hasDoubleOpenTrigger(line string) bool { +func hasDoubleOpenTrigger(line string, open, close byte) bool { pos := 0 for pos < len(line) { // look for double-open sequence - dbl := string([]byte{inlineOpenChar, inlineOpenChar}) + dbl := string([]byte{open, open}) j := strings.Index(line[pos:], dbl) if j < 0 { return false @@ -461,12 +455,12 @@ func hasDoubleOpenTrigger(line string) bool { return false } first := line[contentStart] - if first == ' ' || first == inlineOpenChar { + if first == ' ' || first == open { pos = contentStart + 1 continue } // find closing - k := strings.IndexByte(line[contentStart+1:], inlineCloseChar) + k := strings.IndexByte(line[contentStart+1:], close) if k < 0 { return false } @@ -480,16 +474,16 @@ func hasDoubleOpenTrigger(line string) bool { return false } -func collectSemicolonMarkers(line string, lineNum int) []TextEdit { +func collectSemicolonMarkers(line string, lineNum int, open, close byte) []TextEdit { var edits []TextEdit startSemi := 0 for startSemi < len(line) { - j := strings.IndexByte(line[startSemi:], inlineOpenChar) + j := strings.IndexByte(line[startSemi:], open) if j < 0 { break } j += startSemi - k := strings.IndexByte(line[j+1:], inlineCloseChar) + k := strings.IndexByte(line[j+1:], close) if k < 0 { break } @@ -497,7 +491,7 @@ func collectSemicolonMarkers(line string, lineNum int) []TextEdit { startSemi = j + 1 continue } - if line[j+1] == inlineOpenChar { // skip double-open start + if line[j+1] == open { // skip double-open start startSemi = j + 2 continue } diff --git a/internal/lsp/helpers_inline_prompt_test.go b/internal/lsp/helpers_inline_prompt_test.go index 4aaf892..e4a38f5 100644 --- a/internal/lsp/helpers_inline_prompt_test.go +++ b/internal/lsp/helpers_inline_prompt_test.go @@ -7,11 +7,11 @@ import ( func TestLineHasInlinePrompt_BasicAndDoubleOpen(t *testing.T) { // Basic inline - if !lineHasInlinePrompt("do >task> now") { + if !lineHasInlinePrompt("do >task> now", '>', '>') { t.Fatalf("expected inline prompt detection for >text>") } // Double-open variant should be recognized as inline prompt too - if !lineHasInlinePrompt(">>replace>") { + if !lineHasInlinePrompt(">>replace>", '>', '>') { t.Fatalf("expected inline prompt detection for >>text>") } } diff --git a/internal/lsp/helpers_more_test.go b/internal/lsp/helpers_more_test.go index 160f91c..287aa9d 100644 --- a/internal/lsp/helpers_more_test.go +++ b/internal/lsp/helpers_more_test.go @@ -25,10 +25,10 @@ func TestLeadingAndApplyIndent(t *testing.T) { } func TestFindStrictInlineTag(t *testing.T) { - if _, _, _, ok := findStrictInlineTag(">do this> next"); !ok { + if _, _, _, ok := findStrictInlineTag(">do this> next", '>', '>'); !ok { t.Fatalf("expected strict tag") } - if _, _, _, ok := findStrictInlineTag("> spaced >"); ok { + if _, _, _, ok := findStrictInlineTag("> spaced >", '>', '>'); ok { t.Fatalf("should ignore spaced tag") } } @@ -81,11 +81,11 @@ func TestRangesOverlapAndOrder(t *testing.T) { } func TestPromptRemovalEditsForLine(t *testing.T) { - edits := promptRemovalEditsForLine(">>do thing>", 3) + edits := promptRemovalEditsForLine(">>do thing>", 3, '>', '>') if len(edits) != 1 || edits[0].Range.Start.Line != 3 { t.Fatalf("expected full-line removal for double-semicolon") } - edits2 := promptRemovalEditsForLine(">act> and >b>", 1) + edits2 := promptRemovalEditsForLine(">act> and >b>", 1, '>', '>') if len(edits2) == 0 { t.Fatalf("expected edits to remove strict markers") } @@ -143,10 +143,10 @@ func TestComputeTextEditAndFilter(t *testing.T) { } func TestIsBareDoubleOpen(t *testing.T) { - if !isBareDoubleOpen(">> ") { + if !isBareDoubleOpen(">> ", '>', '>') { t.Fatalf("expected true") } - if isBareDoubleOpen(">>x>") { + if isBareDoubleOpen(">>x>", '>', '>') { t.Fatalf("expected false for content form") } } diff --git a/internal/lsp/init_and_trigger_test.go b/internal/lsp/init_and_trigger_test.go index 10c04fd..10d0968 100644 --- a/internal/lsp/init_and_trigger_test.go +++ b/internal/lsp/init_and_trigger_test.go @@ -11,6 +11,7 @@ import ( func TestHandleInitialize_Capabilities(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) s.triggerChars = []string{".", ":"} req := Request{JSONRPC: "2.0", ID: json.RawMessage("7"), Method: "initialize"} out.Reset() diff --git a/internal/lsp/init_shutdown_test.go b/internal/lsp/init_shutdown_test.go index 19b9b33..2847170 100644 --- a/internal/lsp/init_shutdown_test.go +++ b/internal/lsp/init_shutdown_test.go @@ -11,6 +11,7 @@ import ( func TestHandleShutdown_Replies(t *testing.T) { var out bytes.Buffer s := &Server{logger: log.New(io.Discard, "", 0), docs: make(map[string]*document), out: &out} + initServerDefaults(s) req := Request{JSONRPC: "2.0", ID: json.RawMessage("12"), Method: "shutdown"} out.Reset() s.handleShutdown(req) diff --git a/internal/lsp/instruction_table_test.go b/internal/lsp/instruction_table_test.go index ff750ca..a6042b1 100644 --- a/internal/lsp/instruction_table_test.go +++ b/internal/lsp/instruction_table_test.go @@ -16,7 +16,8 @@ func TestFindFirstInstructionInLine_Table(t *testing.T) { {"double_dash", "-- rewrite quickly", "rewrite quickly"}, } for _, c := range cases { - instr, _, ok := findFirstInstructionInLine(c.line) + s := newTestServer() + instr, _, ok := s.findFirstInstructionInLine(c.line) if !ok || instr != c.instr { t.Fatalf("%s: got %q ok=%v", c.name, instr, ok) } diff --git a/internal/lsp/server.go b/internal/lsp/server.go index e3728c8..13066f7 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -18,6 +18,7 @@ import ( type Server struct { in *bufio.Reader out io.Writer + outMu sync.Mutex logger *log.Logger exited bool mu sync.RWMutex @@ -55,10 +56,13 @@ type Server struct { handlers map[string]func(Request) // Configurable trigger characters - inlineOpen string - inlineClose string - chatSuffix string - chatPrefixes []string + inlineOpen string + inlineClose string + chatSuffix string + chatPrefixes []string + inlineOpenChar byte + inlineCloseChar byte + chatSuffixChar byte // Prompt templates // Completion @@ -230,18 +234,20 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) s.customActions = append([]CustomAction{}, opts.CustomActions...) } - // Assign package-level inline trigger chars for free helper functions if s.inlineOpen != "" { - inlineOpenChar = s.inlineOpen[0] + s.inlineOpenChar = s.inlineOpen[0] + } else { + s.inlineOpenChar = '>' } if s.inlineClose != "" { - inlineCloseChar = s.inlineClose[0] + s.inlineCloseChar = s.inlineClose[0] + } else { + s.inlineCloseChar = '>' } if s.chatSuffix != "" { - chatSuffixChar = s.chatSuffix[0] - } - if len(s.chatPrefixes) > 0 { - chatPrefixSingles = append([]string{}, s.chatPrefixes...) + s.chatSuffixChar = s.chatSuffix[0] + } else { + s.chatSuffixChar = '>' } // Initialize dispatch table s.handlers = map[string]func(Request){ diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index bdd01a1..60e5379 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -49,6 +49,9 @@ func (s *Server) readMessage() ([]byte, error) { } func (s *Server) writeMessage(v any) { + s.outMu.Lock() + defer s.outMu.Unlock() + data, err := json.Marshal(v) if err != nil { logging.Logf("lsp ", "marshal error: %v", err) diff --git a/internal/lsp/triggers_config_test.go b/internal/lsp/triggers_config_test.go index 93d312a..0fcbd15 100644 --- a/internal/lsp/triggers_config_test.go +++ b/internal/lsp/triggers_config_test.go @@ -30,16 +30,16 @@ func TestNewServer_AssignsTriggerGlobals_AndParsingUsesThem(t *testing.T) { InlineOpen: "<", InlineClose: ">", ChatSuffix: ")", ChatPrefixes: []string{":"}, }) _ = s // ensure server constructed applies globals - if inlineOpenChar != '<' || inlineCloseChar != '>' { - t.Fatalf("inline markers not applied: %q %q", string(inlineOpenChar), string(inlineCloseChar)) + if s.inlineOpenChar != '<' || s.inlineCloseChar != '>' { + t.Fatalf("inline markers not applied: %q %q", string(s.inlineOpenChar), string(s.inlineCloseChar)) } - if chatSuffixChar != ')' || len(chatPrefixSingles) == 0 || chatPrefixSingles[0] != ":" { - t.Fatalf("chat markers not applied: suffix=%q prefixes=%v", string(chatSuffixChar), chatPrefixSingles) + if s.chatSuffixChar != ')' || len(s.chatPrefixes) == 0 || s.chatPrefixes[0] != ":" { + t.Fatalf("chat markers not applied: suffix=%q prefixes=%v", string(s.chatSuffixChar), s.chatPrefixes) } - if txt, l, r, ok := findStrictInlineTag("x<do>y"); !ok || txt != "do" || l != 1 || r != 5 { + if txt, l, r, ok := findStrictInlineTag("x<do>y", s.inlineOpenChar, s.inlineCloseChar); !ok || txt != "do" || l != 1 || r != 5 { t.Fatalf("findStrictInlineTag failed: ok=%v txt=%q l=%d r=%d", ok, txt, l, r) } - if got := stripTrailingTrigger("note:)"); got != "note:" { + if got := s.stripTrailingTrigger("note:)"); got != "note:" { t.Fatalf("stripTrailingTrigger failed: %q", got) } } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 3a9a9ab..a8390ef 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -14,7 +14,6 @@ import ( "path/filepath" "strconv" "sync/atomic" - "syscall" "time" ) @@ -27,6 +26,8 @@ const ( var windowSeconds int64 = int64(defaultWindow.Seconds()) +var errLockWouldBlock = errors.New("stats: lock would block") + // SetWindow sets the sliding window used for pruning and aggregation. func SetWindow(d time.Duration) { if d < time.Second { @@ -88,19 +89,11 @@ func Update(ctx context.Context, provider, model string, sentBytes, recvBytes in return err } defer f.Close() - // Acquire exclusive flock; best-effort ctx support via polling - for { - if err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err == nil { - defer syscall.Flock(int(f.Fd()), syscall.LOCK_UN) - break - } - // Wait a bit or exit if context canceled - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(5 * time.Millisecond): - } + unlock, err := acquireFileLock(ctx, f) + if err != nil { + return err } + defer func() { _ = unlock() }() // Read existing file (if any) path := filepath.Join(dir, fileName) var sf File @@ -158,6 +151,25 @@ func Update(ctx context.Context, provider, model string, sentBytes, recvBytes in return nil } +func acquireFileLock(ctx context.Context, f *os.File) (func() error, error) { + fd := f.Fd() + for { + err := tryLockFile(fd) + if err == nil { + return func() error { return unlockFile(fd) }, nil + } + if errors.Is(err, errLockWouldBlock) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(5 * time.Millisecond): + } + continue + } + return nil, err + } +} + // Snapshot reads and aggregates events within the configured window. func TakeSnapshot() (Snapshot, error) { dir, err := CacheDir() |
