summaryrefslogtreecommitdiff
path: root/internal/lsp
diff options
context:
space:
mode:
Diffstat (limited to 'internal/lsp')
-rw-r--r--internal/lsp/document_test.go9
-rw-r--r--internal/lsp/handlers_codeaction.go32
-rw-r--r--internal/lsp/handlers_completion.go56
-rw-r--r--internal/lsp/handlers_document.go15
-rw-r--r--internal/lsp/handlers_utils.go166
-rw-r--r--internal/lsp/llm_request_opts_test.go11
-rw-r--r--internal/lsp/llm_stats_test.go2
-rw-r--r--internal/lsp/provider_native_success_test.go9
-rw-r--r--internal/lsp/server.go85
9 files changed, 303 insertions, 82 deletions
diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go
index ed2ccea..fd13e5d 100644
--- a/internal/lsp/document_test.go
+++ b/internal/lsp/document_test.go
@@ -8,6 +8,7 @@ import (
"testing"
"codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/llm"
)
func newTestServer() *Server {
@@ -35,9 +36,11 @@ func newTestServer() *Server {
PromptCodeActionGoTestUser: "Function under test:\n{{function}}",
}
return &Server{
- logger: log.New(io.Discard, "", 0),
- docs: make(map[string]*document),
- cfg: cfg,
+ logger: log.New(io.Discard, "", 0),
+ docs: make(map[string]*document),
+ cfg: cfg,
+ altClients: make(map[string]llm.Client),
+ llmProvider: canonicalProvider(cfg.Provider),
}
}
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go
index 7631935..24429a1 100644
--- a/internal/lsp/handlers_codeaction.go
+++ b/internal/lsp/handlers_codeaction.go
@@ -245,8 +245,8 @@ func (s *Server) completeCodeAction(ca CodeAction, uri string, rng Range, sys, u
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
- opts := s.llmRequestOpts()
- if text, err := s.chatWithStats(ctx, messages, opts...); err == nil {
+ spec := s.buildRequestSpec(surfaceCodeAction)
+ if text, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil {
if out := stripCodeFences(strings.TrimSpace(text)); out != "" {
edit := WorkspaceEdit{Changes: map[string][]TextEdit{uri: {{Range: rng, NewText: out}}}}
ca.Edit = &edit
@@ -555,22 +555,20 @@ func findGoFunctionAtLine(lines []string, idx int) (int, int) {
// generateGoTestFunction uses LLM to produce a test function; falls back to a stub when unavailable.
func (s *Server) generateGoTestFunction(funcCode string) string {
- if client := s.currentLLMClient(); client != nil {
- cfg := s.currentConfig()
- sys := cfg.PromptCodeActionGoTestSystem
- user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode})
- ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second)
- defer cancel()
- messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
- opts := s.llmRequestOpts()
- if out, err := s.chatWithStats(ctx, messages, opts...); err == nil {
- cleaned := strings.TrimSpace(stripCodeFences(out))
- if cleaned != "" {
- return cleaned
- }
- } else {
- logging.Logf("lsp ", "codeAction go_test llm error: %v", err)
+ spec := s.buildRequestSpec(surfaceCodeAction)
+ cfg := s.currentConfig()
+ sys := cfg.PromptCodeActionGoTestSystem
+ user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode})
+ ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second)
+ defer cancel()
+ messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
+ if out, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil {
+ cleaned := strings.TrimSpace(stripCodeFences(out))
+ if cleaned != "" {
+ return cleaned
}
+ } else {
+ logging.Logf("lsp ", "codeAction go_test llm error: %v", err)
}
// Fallback stub
name := deriveGoFuncName(funcCode)
diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go
index f7f41ef..d115741 100644
--- a/internal/lsp/handlers_completion.go
+++ b/internal/lsp/handlers_completion.go
@@ -95,11 +95,16 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun
return items, true
}
- if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok {
+ spec := s.buildRequestSpec(surfaceCompletion)
+ client := s.clientFor(spec)
+ if client == nil {
+ return nil, false
+ }
+ if items, ok := s.tryProviderNativeCompletion(spec, client, current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok {
return items, true
}
- return s.executeChatCompletion(ctx, plan)
+ return s.executeChatCompletion(ctx, plan, spec, client)
}
func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) {
@@ -142,31 +147,31 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below
return plan, nil, false
}
-func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan) ([]CompletionItem, bool) {
+func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client) ([]CompletionItem, bool) {
messages := s.buildCompletionMessages(plan.inlinePrompt, plan.hasExtra, plan.extraText, plan.inParams, plan.params, plan.above, plan.current, plan.below, plan.funcCtx)
sentSize := 0
for _, m := range messages {
sentSize += len(m.Content)
}
s.incSentCounters(sentSize)
- opts := s.llmRequestOpts()
+ opts := spec.options
s.waitForDebounce(ctx)
if !s.waitForThrottle(ctx) {
return nil, false
}
- client := s.currentLLMClient()
- if client == nil {
- return nil, false
+ modelUsed := spec.effectiveModel()
+ if strings.TrimSpace(modelUsed) == "" {
+ modelUsed = client.DefaultModel()
}
- logging.Logf("lsp ", "completion llm=requesting model=%s", client.DefaultModel())
+ logging.Logf("lsp ", "completion llm=requesting model=%s", modelUsed)
text, err := client.Chat(ctx, messages, opts...)
if err != nil {
logging.Logf("lsp ", "llm completion error: %v", err)
- s.logLLMStats()
+ s.logLLMStats(modelUsed)
return nil, false
}
s.incRecvCounters(len(text))
- s.logLLMStats()
+ s.logLLMStats(modelUsed)
trimmed := strings.TrimSpace(text)
cleaned := s.postProcessCompletion(trimmed, plan.current[:plan.params.Position.Character], plan.current)
if cleaned == "" {
@@ -255,8 +260,7 @@ func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p Comp
}
// tryProviderNativeCompletion attempts provider-native completion and returns items when successful.
-func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) {
- client := s.currentLLMClient()
+func (s *Server) tryProviderNativeCompletion(spec requestSpec, client llm.Client, current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) {
cc, ok := client.(llm.CodeCompleter)
if !ok {
return nil, false
@@ -271,15 +275,11 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams,
"before": before,
})
lang := ""
- temp := 0.0
- if cfg.CodingTemperature != nil {
- temp = *cfg.CodingTemperature
- }
- prov := ""
- if client != nil {
- prov = client.Name()
+ provider := spec.provider
+ if provider == "" {
+ provider = canonicalProvider(cfg.Provider)
}
- logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path)
+ logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", provider, path)
ctx2, cancel2 := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel2()
@@ -290,16 +290,24 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams,
}
// Count approximate payload sizes: prompt+after sent; first suggestion received
sentBytes := len(prompt) + len(after)
- suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp)
+ modelUsed := spec.effectiveModel()
+ if strings.TrimSpace(modelUsed) == "" {
+ modelUsed = client.DefaultModel()
+ }
+ tempVal := 0.0
+ if val, ok := chooseSurfaceTemperature(surfaceCompletion, cfg, provider, spec.modelOverride, spec.fallbackModel); ok {
+ tempVal = val
+ }
+ suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, tempVal)
if err == nil && len(suggestions) > 0 {
// Update counters and heartbeat
s.incSentCounters(sentBytes)
s.incRecvCounters(len(suggestions[0]))
// Contribute to global stats (provider-native path)
if client != nil {
- _ = stats.Update(ctx2, client.Name(), client.DefaultModel(), sentBytes, len(suggestions[0]))
+ _ = stats.Update(ctx2, client.Name(), modelUsed, sentBytes, len(suggestions[0]))
}
- s.logLLMStats()
+ s.logLLMStats(modelUsed)
cleaned := strings.TrimSpace(suggestions[0])
if cleaned != "" {
cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned)
@@ -322,7 +330,7 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams,
logging.Logf("lsp ", "completion path=codex error=%v (falling back to chat)", err)
// Still emit a heartbeat for visibility, even on error
s.incSentCounters(sentBytes)
- s.logLLMStats()
+ s.logLLMStats(modelUsed)
}
return nil, false
}
diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go
index 0340866..f8ed9ed 100644
--- a/internal/lsp/handlers_document.go
+++ b/internal/lsp/handlers_document.go
@@ -161,22 +161,23 @@ func (s *Server) detectAndHandleChat(uri string) {
}
return
}
- if s.currentLLMClient() == nil {
- continue
- }
go func(prompt string, remove int) {
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
defer cancel()
// Build messages with history and context_mode aware extras.
pos := Position{Line: lineIdx, Character: lastIdx + 1}
msgs := s.buildChatMessages(uri, pos, prompt)
- opts := s.llmRequestOpts()
- client := s.currentLLMClient()
+ spec := s.buildRequestSpec(surfaceChat)
+ client := s.clientFor(spec)
if client == nil {
return
}
- logging.Logf("lsp ", "chat llm=requesting model=%s", client.DefaultModel())
- text, err := s.chatWithStats(ctx, msgs, opts...)
+ modelUsed := spec.effectiveModel()
+ if strings.TrimSpace(modelUsed) == "" {
+ modelUsed = client.DefaultModel()
+ }
+ logging.Logf("lsp ", "chat llm=requesting model=%s", modelUsed)
+ text, err := s.chatWithStats(ctx, surfaceChat, spec, msgs)
if err != nil {
logging.Logf("lsp ", "chat llm error: %v", err)
return
diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go
index 5d5ca27..3bd13ee 100644
--- a/internal/lsp/handlers_utils.go
+++ b/internal/lsp/handlers_utils.go
@@ -7,6 +7,7 @@ import (
"strings"
"time"
+ "codeberg.org/snonux/hexai/internal/appconfig"
"codeberg.org/snonux/hexai/internal/llm"
"codeberg.org/snonux/hexai/internal/logging"
"codeberg.org/snonux/hexai/internal/stats"
@@ -14,24 +15,134 @@ import (
tmx "codeberg.org/snonux/hexai/internal/tmux"
)
-// llmRequestOpts builds request options from server settings.
-func (s *Server) llmRequestOpts() []llm.RequestOption {
+type surfaceKind string
+
+const (
+ surfaceCompletion surfaceKind = "completion"
+ surfaceCodeAction surfaceKind = "code_action"
+ surfaceChat surfaceKind = "chat"
+)
+
+type requestSpec struct {
+ provider string
+ modelOverride string
+ fallbackModel string
+ options []llm.RequestOption
+}
+
+func (r requestSpec) effectiveModel() string {
+ if s := strings.TrimSpace(r.modelOverride); s != "" {
+ return s
+ }
+ return strings.TrimSpace(r.fallbackModel)
+}
+
+func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec {
+ cfg := s.currentConfig()
+ providerOverride := strings.TrimSpace(surfaceProviderFromConfig(cfg, surface))
+ provider := canonicalProvider(cfg.Provider)
+ if providerOverride != "" {
+ provider = canonicalProvider(providerOverride)
+ }
+ fallbackModel := strings.TrimSpace(resolveDefaultModel(cfg, provider))
+ modelOverride := strings.TrimSpace(surfaceModelFromConfig(cfg, surface))
maxTokens := s.maxTokens()
- client := s.currentLLMClient()
- tempPtr := s.codingTemperature()
opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)}
- if tempPtr != nil {
- temp := *tempPtr
- if client != nil {
- prov := strings.ToLower(strings.TrimSpace(client.Name()))
- model := strings.ToLower(strings.TrimSpace(client.DefaultModel()))
- if prov == "openai" && strings.HasPrefix(model, "gpt-5") {
- temp = 1.0
- }
+ if tempVal, ok := chooseSurfaceTemperature(surface, cfg, provider, modelOverride, fallbackModel); ok {
+ opts = append(opts, llm.WithTemperature(tempVal))
+ }
+ if modelOverride != "" {
+ opts = append(opts, llm.WithModel(modelOverride))
+ }
+ return requestSpec{
+ provider: provider,
+ modelOverride: modelOverride,
+ fallbackModel: fallbackModel,
+ options: opts,
+ }
+}
+
+func canonicalProvider(name string) string {
+ p := strings.ToLower(strings.TrimSpace(name))
+ if p == "" {
+ return "openai"
+ }
+ return p
+}
+
+func resolveDefaultModel(cfg appconfig.App, provider string) string {
+ switch provider {
+ case "ollama":
+ return strings.TrimSpace(cfg.OllamaModel)
+ case "copilot":
+ return strings.TrimSpace(cfg.CopilotModel)
+ default:
+ return strings.TrimSpace(cfg.OpenAIModel)
+ }
+}
+
+func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string {
+ switch surface {
+ case surfaceCompletion:
+ return cfg.CompletionModel
+ case surfaceCodeAction:
+ return cfg.CodeActionModel
+ case surfaceChat:
+ return cfg.ChatModel
+ default:
+ return ""
+ }
+}
+
+func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string {
+ switch surface {
+ case surfaceCompletion:
+ return cfg.CompletionProvider
+ case surfaceCodeAction:
+ return cfg.CodeActionProvider
+ case surfaceChat:
+ return cfg.ChatProvider
+ default:
+ return ""
+ }
+}
+
+func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 {
+ switch surface {
+ case surfaceCompletion:
+ return cfg.CompletionTemperature
+ case surfaceCodeAction:
+ return cfg.CodeActionTemperature
+ case surfaceChat:
+ return cfg.ChatTemperature
+ default:
+ return nil
+ }
+}
+
+func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) {
+ if t := surfaceTemperatureFromConfig(cfg, surface); t != nil {
+ return *t, true
+ }
+ if cfg.CodingTemperature != nil {
+ temp := *cfg.CodingTemperature
+ effectiveModel := strings.TrimSpace(overrideModel)
+ if effectiveModel == "" {
+ effectiveModel = strings.TrimSpace(fallbackModel)
+ }
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") && temp == 0.2 {
+ temp = 1.0
}
- opts = append(opts, llm.WithTemperature(temp))
+ return temp, true
}
- return opts
+ effectiveModel := strings.TrimSpace(overrideModel)
+ if effectiveModel == "" {
+ effectiveModel = strings.TrimSpace(fallbackModel)
+ }
+ if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") {
+ return 1.0, true
+ }
+ return 0, false
}
// small helpers for LLM traffic stats
@@ -49,7 +160,7 @@ func (s *Server) incRecvCounters(n int) {
s.mu.Unlock()
}
-func (s *Server) logLLMStats() {
+func (s *Server) logLLMStats(model string) {
s.mu.RLock()
avgSent := int64(0)
if s.llmReqTotal > 0 {
@@ -75,11 +186,14 @@ func (s *Server) logLLMStats() {
if err == nil {
if client := s.currentLLMClient(); client != nil {
provider := client.Name()
- model := client.DefaultModel()
+ modelName := strings.TrimSpace(model)
+ if modelName == "" {
+ modelName = client.DefaultModel()
+ }
// Per-scope rpm estimated from window
scopeReqs := int64(0)
if pe, ok := snap.Providers[provider]; ok {
- if mc, ok2 := pe.Models[model]; ok2 {
+ if mc, ok2 := pe.Models[modelName]; ok2 {
scopeReqs = mc.Reqs
}
}
@@ -88,7 +202,7 @@ func (s *Server) logLLMStats() {
minsWin = 0.001
}
scopeRPM := float64(scopeReqs) / minsWin
- status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, model, scopeRPM, scopeReqs, snap.Window)
+ status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, modelName, scopeRPM, scopeReqs, snap.Window)
_ = tmx.SetStatus(status)
}
}
@@ -154,7 +268,7 @@ func isIdentChar(ch byte) bool {
}
// chatWithStats wraps llmClient.Chat to increment counters and emit a tmux heartbeat.
-func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ...llm.RequestOption) (string, error) {
+func (s *Server) chatWithStats(ctx context.Context, surface surfaceKind, spec requestSpec, msgs []llm.Message) (string, error) {
// Count bytes sent
sent := 0
for _, m := range msgs {
@@ -167,19 +281,23 @@ func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ...
return "", context.Canceled
}
// Perform request
- client := s.currentLLMClient()
+ client := s.clientFor(spec)
if client == nil {
return "", fmt.Errorf("llm client unavailable")
}
- txt, err := client.Chat(ctx, msgs, opts...)
+ txt, err := client.Chat(ctx, msgs, spec.options...)
if err != nil {
- s.logLLMStats()
+ s.logLLMStats(spec.effectiveModel())
return "", err
}
s.incRecvCounters(len(txt))
// Update global stats cache
- _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, len(txt))
- s.logLLMStats()
+ model := spec.effectiveModel()
+ if model == "" {
+ model = client.DefaultModel()
+ }
+ _ = stats.Update(ctx, client.Name(), model, sent, len(txt))
+ s.logLLMStats(model)
return txt, nil
}
diff --git a/internal/lsp/llm_request_opts_test.go b/internal/lsp/llm_request_opts_test.go
index c6699b0..263db79 100644
--- a/internal/lsp/llm_request_opts_test.go
+++ b/internal/lsp/llm_request_opts_test.go
@@ -15,17 +15,22 @@ func (f fakeClient) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt
func (f fakeClient) Name() string { return f.name }
func (f fakeClient) DefaultModel() string { return f.model }
-func TestLlmRequestOpts_Gpt5_ForcesTemp1(t *testing.T) {
+func TestRequestSpec_Gpt5_ForcesTemp1(t *testing.T) {
s := newTestServer()
one := 0.2
s.cfg.CodingTemperature = &one
s.llmClient = fakeClient{name: "openai", model: "gpt-5.0"}
- opts := s.llmRequestOpts()
+ s.cfg.OpenAIModel = "gpt-5.0"
+
+ spec := s.buildRequestSpec(surfaceCompletion)
var got llm.Options
- for _, o := range opts {
+ for _, o := range spec.options {
o(&got)
}
if got.Temperature != 1.0 {
t.Fatalf("expected temp 1.0 for gpt-5, got %v", got.Temperature)
}
+ if model := spec.effectiveModel(); model != "gpt-5.0" {
+ t.Fatalf("expected fallback model gpt-5.0, got %q", model)
+ }
}
diff --git a/internal/lsp/llm_stats_test.go b/internal/lsp/llm_stats_test.go
index 43582a2..7813c10 100644
--- a/internal/lsp/llm_stats_test.go
+++ b/internal/lsp/llm_stats_test.go
@@ -6,5 +6,5 @@ func TestLogLLMStats_CoversCounters(t *testing.T) {
s := newTestServer()
s.incSentCounters(10)
s.incRecvCounters(20)
- s.logLLMStats() // just ensure it does not panic and executes
+ s.logLLMStats("model") // just ensure it does not panic and executes
}
diff --git a/internal/lsp/provider_native_success_test.go b/internal/lsp/provider_native_success_test.go
index 6df5698..aab886c 100644
--- a/internal/lsp/provider_native_success_test.go
+++ b/internal/lsp/provider_native_success_test.go
@@ -21,10 +21,11 @@ func (fakeCompleterOk) CodeCompletion(context.Context, string, string, int, stri
func TestProviderNativeCompletion_Success(t *testing.T) {
s := newTestServer()
s.llmClient = fakeCompleterOk{}
+ spec := s.buildRequestSpec(surfaceCompletion)
// current line with dot trigger; position after dot
current := "fmt."
p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
- items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false)
+ items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false)
if !ok || len(items) == 0 {
t.Fatalf("expected provider-native items")
}
@@ -47,9 +48,10 @@ func (fakeCompleterIndent) CodeCompletion(context.Context, string, string, int,
func TestProviderNativeCompletion_IndentWithDoubleOpen(t *testing.T) {
s := newTestServer()
s.llmClient = fakeCompleterIndent{}
+ spec := s.buildRequestSpec(surfaceCompletion)
current := " >>do>" // leading indent + double-open marker
p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}}
- items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false)
+ items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false)
if !ok || len(items) == 0 {
t.Fatalf("expected provider-native items")
}
@@ -80,12 +82,13 @@ func TestProviderNativeCompletion_UsesPromptTemplate(t *testing.T) {
cfg := s.cfg
cfg.PromptNativeCompletion = "NATIVE {{path}} {{before}}"
s.cfg = cfg
+ spec := s.buildRequestSpec(surfaceCompletion)
uri := "file:///x.go"
s.setDocument(uri, "AAA\nBBB\nCCC")
current := "fmt."
// Cursor at line 1, char 1 -> before should be "AAA\nB"
p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Position: Position{Line: 1, Character: 1}}
- if _, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false); !ok {
+ if _, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false); !ok {
t.Fatalf("expected provider-native path")
}
if cap.lastPrompt == "" {
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index 7b8bc88..28f3218 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -6,6 +6,7 @@ import (
"encoding/json"
"io"
"log"
+ "os"
"strings"
"sync"
"time"
@@ -29,6 +30,8 @@ type Server struct {
configStore *runtimeconfig.Store
cfg appconfig.App
llmClient llm.Client
+ llmProvider string
+ altClients map[string]llm.Client
lastInput time.Time
// LLM request stats
llmReqTotal int64
@@ -186,6 +189,12 @@ func (s *Server) applyOptions(opts ServerOptions) {
}
}
s.llmClient = opts.Client
+ if opts.Client != nil {
+ s.llmProvider = canonicalProvider(opts.Client.Name())
+ } else {
+ s.llmProvider = canonicalProvider(s.cfg.Provider)
+ }
+ s.altClients = make(map[string]llm.Client)
}
// ApplyOptions updates the server's configuration at runtime.
@@ -199,6 +208,82 @@ func (s *Server) currentLLMClient() llm.Client {
return s.llmClient
}
+func newClientForProvider(cfg appconfig.App, provider string) (llm.Client, error) {
+ llmCfg := llm.Config{
+ Provider: provider,
+ OpenAIBaseURL: cfg.OpenAIBaseURL,
+ OpenAIModel: cfg.OpenAIModel,
+ OpenAITemperature: cfg.OpenAITemperature,
+ OllamaBaseURL: cfg.OllamaBaseURL,
+ OllamaModel: cfg.OllamaModel,
+ OllamaTemperature: cfg.OllamaTemperature,
+ CopilotBaseURL: cfg.CopilotBaseURL,
+ CopilotModel: cfg.CopilotModel,
+ CopilotTemperature: cfg.CopilotTemperature,
+ }
+ oaKey := strings.TrimSpace(os.Getenv("HEXAI_OPENAI_API_KEY"))
+ if oaKey == "" {
+ oaKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY"))
+ }
+ cpKey := strings.TrimSpace(os.Getenv("HEXAI_COPILOT_API_KEY"))
+ if cpKey == "" {
+ cpKey = strings.TrimSpace(os.Getenv("COPILOT_API_KEY"))
+ }
+ return llm.NewFromConfig(llmCfg, oaKey, cpKey)
+}
+
+func (s *Server) clientFor(spec requestSpec) llm.Client {
+ provider := canonicalProvider(spec.provider)
+ s.mu.RLock()
+ baseProvider := s.llmProvider
+ baseClient := s.llmClient
+ if baseClient != nil && strings.TrimSpace(baseProvider) == "" {
+ baseProvider = canonicalProvider(baseClient.Name())
+ }
+ if provider == "" {
+ provider = baseProvider
+ }
+ if provider == baseProvider && baseClient != nil {
+ s.mu.RUnlock()
+ return baseClient
+ }
+ if c, ok := s.altClients[provider]; ok {
+ s.mu.RUnlock()
+ return c
+ }
+ cfg := s.cfg
+ store := s.configStore
+ s.mu.RUnlock()
+ if store != nil {
+ cfg = store.Snapshot()
+ }
+ client, err := newClientForProvider(cfg, provider)
+ if err != nil {
+ logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err)
+ if baseClient != nil {
+ return baseClient
+ }
+ return nil
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if provider == s.llmProvider {
+ if s.llmClient == nil {
+ s.llmClient = client
+ s.llmProvider = provider
+ }
+ return s.llmClient
+ }
+ if existing, ok := s.altClients[provider]; ok {
+ return existing
+ }
+ if s.altClients == nil {
+ s.altClients = make(map[string]llm.Client)
+ }
+ s.altClients[provider] = client
+ return client
+}
+
func (s *Server) currentConfig() appconfig.App {
if s.configStore != nil {
return s.configStore.Snapshot()