diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-26 19:34:19 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-26 19:34:19 +0300 |
| commit | 0583b360ceb606b8e58f12a17f588bd27feeb117 (patch) | |
| tree | ae8ac0d7968a409a76d18d84e080d02da52ce775 /internal/lsp | |
| parent | 869c018a7a26285263cf7692f25f6aa44e2635c9 (diff) | |
Add per-surface provider overrides and wiring
Diffstat (limited to 'internal/lsp')
| -rw-r--r-- | internal/lsp/document_test.go | 9 | ||||
| -rw-r--r-- | internal/lsp/handlers_codeaction.go | 32 | ||||
| -rw-r--r-- | internal/lsp/handlers_completion.go | 56 | ||||
| -rw-r--r-- | internal/lsp/handlers_document.go | 15 | ||||
| -rw-r--r-- | internal/lsp/handlers_utils.go | 166 | ||||
| -rw-r--r-- | internal/lsp/llm_request_opts_test.go | 11 | ||||
| -rw-r--r-- | internal/lsp/llm_stats_test.go | 2 | ||||
| -rw-r--r-- | internal/lsp/provider_native_success_test.go | 9 | ||||
| -rw-r--r-- | internal/lsp/server.go | 85 |
9 files changed, 303 insertions, 82 deletions
diff --git a/internal/lsp/document_test.go b/internal/lsp/document_test.go index ed2ccea..fd13e5d 100644 --- a/internal/lsp/document_test.go +++ b/internal/lsp/document_test.go @@ -8,6 +8,7 @@ import ( "testing" "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" ) func newTestServer() *Server { @@ -35,9 +36,11 @@ func newTestServer() *Server { PromptCodeActionGoTestUser: "Function under test:\n{{function}}", } return &Server{ - logger: log.New(io.Discard, "", 0), - docs: make(map[string]*document), - cfg: cfg, + logger: log.New(io.Discard, "", 0), + docs: make(map[string]*document), + cfg: cfg, + altClients: make(map[string]llm.Client), + llmProvider: canonicalProvider(cfg.Provider), } } diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index 7631935..24429a1 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -245,8 +245,8 @@ func (s *Server) completeCodeAction(ca CodeAction, uri string, rng Range, sys, u ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - opts := s.llmRequestOpts() - if text, err := s.chatWithStats(ctx, messages, opts...); err == nil { + spec := s.buildRequestSpec(surfaceCodeAction) + if text, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil { if out := stripCodeFences(strings.TrimSpace(text)); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{uri: {{Range: rng, NewText: out}}}} ca.Edit = &edit @@ -555,22 +555,20 @@ func findGoFunctionAtLine(lines []string, idx int) (int, int) { // generateGoTestFunction uses LLM to produce a test function; falls back to a stub when unavailable. func (s *Server) generateGoTestFunction(funcCode string) string { - if client := s.currentLLMClient(); client != nil { - cfg := s.currentConfig() - sys := cfg.PromptCodeActionGoTestSystem - user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode}) - ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) - defer cancel() - messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} - opts := s.llmRequestOpts() - if out, err := s.chatWithStats(ctx, messages, opts...); err == nil { - cleaned := strings.TrimSpace(stripCodeFences(out)) - if cleaned != "" { - return cleaned - } - } else { - logging.Logf("lsp ", "codeAction go_test llm error: %v", err) + spec := s.buildRequestSpec(surfaceCodeAction) + cfg := s.currentConfig() + sys := cfg.PromptCodeActionGoTestSystem + user := renderTemplate(cfg.PromptCodeActionGoTestUser, map[string]string{"function": funcCode}) + ctx, cancel := context.WithTimeout(context.Background(), 18*time.Second) + defer cancel() + messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} + if out, err := s.chatWithStats(ctx, surfaceCodeAction, spec, messages); err == nil { + cleaned := strings.TrimSpace(stripCodeFences(out)) + if cleaned != "" { + return cleaned } + } else { + logging.Logf("lsp ", "codeAction go_test llm error: %v", err) } // Fallback stub name := deriveGoFuncName(funcCode) diff --git a/internal/lsp/handlers_completion.go b/internal/lsp/handlers_completion.go index f7f41ef..d115741 100644 --- a/internal/lsp/handlers_completion.go +++ b/internal/lsp/handlers_completion.go @@ -95,11 +95,16 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun return items, true } - if items, ok := s.tryProviderNativeCompletion(current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok { + spec := s.buildRequestSpec(surfaceCompletion) + client := s.clientFor(spec) + if client == nil { + return nil, false + } + if items, ok := s.tryProviderNativeCompletion(spec, client, current, p, above, below, funcCtx, docStr, hasExtra, extraText, plan.inParams); ok { return items, true } - return s.executeChatCompletion(ctx, plan) + return s.executeChatCompletion(ctx, plan, spec, client) } func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) (completionPlan, []CompletionItem, bool) { @@ -142,31 +147,31 @@ func (s *Server) prepareCompletionPlan(p CompletionParams, above, current, below return plan, nil, false } -func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan) ([]CompletionItem, bool) { +func (s *Server) executeChatCompletion(ctx context.Context, plan completionPlan, spec requestSpec, client llm.Client) ([]CompletionItem, bool) { messages := s.buildCompletionMessages(plan.inlinePrompt, plan.hasExtra, plan.extraText, plan.inParams, plan.params, plan.above, plan.current, plan.below, plan.funcCtx) sentSize := 0 for _, m := range messages { sentSize += len(m.Content) } s.incSentCounters(sentSize) - opts := s.llmRequestOpts() + opts := spec.options s.waitForDebounce(ctx) if !s.waitForThrottle(ctx) { return nil, false } - client := s.currentLLMClient() - if client == nil { - return nil, false + modelUsed := spec.effectiveModel() + if strings.TrimSpace(modelUsed) == "" { + modelUsed = client.DefaultModel() } - logging.Logf("lsp ", "completion llm=requesting model=%s", client.DefaultModel()) + logging.Logf("lsp ", "completion llm=requesting model=%s", modelUsed) text, err := client.Chat(ctx, messages, opts...) if err != nil { logging.Logf("lsp ", "llm completion error: %v", err) - s.logLLMStats() + s.logLLMStats(modelUsed) return nil, false } s.incRecvCounters(len(text)) - s.logLLMStats() + s.logLLMStats(modelUsed) trimmed := strings.TrimSpace(text) cleaned := s.postProcessCompletion(trimmed, plan.current[:plan.params.Position.Character], plan.current) if cleaned == "" { @@ -255,8 +260,7 @@ func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p Comp } // tryProviderNativeCompletion attempts provider-native completion and returns items when successful. -func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { - client := s.currentLLMClient() +func (s *Server) tryProviderNativeCompletion(spec requestSpec, client llm.Client, current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { cc, ok := client.(llm.CodeCompleter) if !ok { return nil, false @@ -271,15 +275,11 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, "before": before, }) lang := "" - temp := 0.0 - if cfg.CodingTemperature != nil { - temp = *cfg.CodingTemperature - } - prov := "" - if client != nil { - prov = client.Name() + provider := spec.provider + if provider == "" { + provider = canonicalProvider(cfg.Provider) } - logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path) + logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", provider, path) ctx2, cancel2 := context.WithTimeout(context.Background(), 15*time.Second) defer cancel2() @@ -290,16 +290,24 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, } // Count approximate payload sizes: prompt+after sent; first suggestion received sentBytes := len(prompt) + len(after) - suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp) + modelUsed := spec.effectiveModel() + if strings.TrimSpace(modelUsed) == "" { + modelUsed = client.DefaultModel() + } + tempVal := 0.0 + if val, ok := chooseSurfaceTemperature(surfaceCompletion, cfg, provider, spec.modelOverride, spec.fallbackModel); ok { + tempVal = val + } + suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, tempVal) if err == nil && len(suggestions) > 0 { // Update counters and heartbeat s.incSentCounters(sentBytes) s.incRecvCounters(len(suggestions[0])) // Contribute to global stats (provider-native path) if client != nil { - _ = stats.Update(ctx2, client.Name(), client.DefaultModel(), sentBytes, len(suggestions[0])) + _ = stats.Update(ctx2, client.Name(), modelUsed, sentBytes, len(suggestions[0])) } - s.logLLMStats() + s.logLLMStats(modelUsed) cleaned := strings.TrimSpace(suggestions[0]) if cleaned != "" { cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) @@ -322,7 +330,7 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, logging.Logf("lsp ", "completion path=codex error=%v (falling back to chat)", err) // Still emit a heartbeat for visibility, even on error s.incSentCounters(sentBytes) - s.logLLMStats() + s.logLLMStats(modelUsed) } return nil, false } diff --git a/internal/lsp/handlers_document.go b/internal/lsp/handlers_document.go index 0340866..f8ed9ed 100644 --- a/internal/lsp/handlers_document.go +++ b/internal/lsp/handlers_document.go @@ -161,22 +161,23 @@ func (s *Server) detectAndHandleChat(uri string) { } return } - if s.currentLLMClient() == nil { - continue - } go func(prompt string, remove int) { ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) defer cancel() // Build messages with history and context_mode aware extras. pos := Position{Line: lineIdx, Character: lastIdx + 1} msgs := s.buildChatMessages(uri, pos, prompt) - opts := s.llmRequestOpts() - client := s.currentLLMClient() + spec := s.buildRequestSpec(surfaceChat) + client := s.clientFor(spec) if client == nil { return } - logging.Logf("lsp ", "chat llm=requesting model=%s", client.DefaultModel()) - text, err := s.chatWithStats(ctx, msgs, opts...) + modelUsed := spec.effectiveModel() + if strings.TrimSpace(modelUsed) == "" { + modelUsed = client.DefaultModel() + } + logging.Logf("lsp ", "chat llm=requesting model=%s", modelUsed) + text, err := s.chatWithStats(ctx, surfaceChat, spec, msgs) if err != nil { logging.Logf("lsp ", "chat llm error: %v", err) return diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index 5d5ca27..3bd13ee 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" @@ -14,24 +15,134 @@ import ( tmx "codeberg.org/snonux/hexai/internal/tmux" ) -// llmRequestOpts builds request options from server settings. -func (s *Server) llmRequestOpts() []llm.RequestOption { +type surfaceKind string + +const ( + surfaceCompletion surfaceKind = "completion" + surfaceCodeAction surfaceKind = "code_action" + surfaceChat surfaceKind = "chat" +) + +type requestSpec struct { + provider string + modelOverride string + fallbackModel string + options []llm.RequestOption +} + +func (r requestSpec) effectiveModel() string { + if s := strings.TrimSpace(r.modelOverride); s != "" { + return s + } + return strings.TrimSpace(r.fallbackModel) +} + +func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { + cfg := s.currentConfig() + providerOverride := strings.TrimSpace(surfaceProviderFromConfig(cfg, surface)) + provider := canonicalProvider(cfg.Provider) + if providerOverride != "" { + provider = canonicalProvider(providerOverride) + } + fallbackModel := strings.TrimSpace(resolveDefaultModel(cfg, provider)) + modelOverride := strings.TrimSpace(surfaceModelFromConfig(cfg, surface)) maxTokens := s.maxTokens() - client := s.currentLLMClient() - tempPtr := s.codingTemperature() opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} - if tempPtr != nil { - temp := *tempPtr - if client != nil { - prov := strings.ToLower(strings.TrimSpace(client.Name())) - model := strings.ToLower(strings.TrimSpace(client.DefaultModel())) - if prov == "openai" && strings.HasPrefix(model, "gpt-5") { - temp = 1.0 - } + if tempVal, ok := chooseSurfaceTemperature(surface, cfg, provider, modelOverride, fallbackModel); ok { + opts = append(opts, llm.WithTemperature(tempVal)) + } + if modelOverride != "" { + opts = append(opts, llm.WithModel(modelOverride)) + } + return requestSpec{ + provider: provider, + modelOverride: modelOverride, + fallbackModel: fallbackModel, + options: opts, + } +} + +func canonicalProvider(name string) string { + p := strings.ToLower(strings.TrimSpace(name)) + if p == "" { + return "openai" + } + return p +} + +func resolveDefaultModel(cfg appconfig.App, provider string) string { + switch provider { + case "ollama": + return strings.TrimSpace(cfg.OllamaModel) + case "copilot": + return strings.TrimSpace(cfg.CopilotModel) + default: + return strings.TrimSpace(cfg.OpenAIModel) + } +} + +func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { + switch surface { + case surfaceCompletion: + return cfg.CompletionModel + case surfaceCodeAction: + return cfg.CodeActionModel + case surfaceChat: + return cfg.ChatModel + default: + return "" + } +} + +func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string { + switch surface { + case surfaceCompletion: + return cfg.CompletionProvider + case surfaceCodeAction: + return cfg.CodeActionProvider + case surfaceChat: + return cfg.ChatProvider + default: + return "" + } +} + +func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 { + switch surface { + case surfaceCompletion: + return cfg.CompletionTemperature + case surfaceCodeAction: + return cfg.CodeActionTemperature + case surfaceChat: + return cfg.ChatTemperature + default: + return nil + } +} + +func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) { + if t := surfaceTemperatureFromConfig(cfg, surface); t != nil { + return *t, true + } + if cfg.CodingTemperature != nil { + temp := *cfg.CodingTemperature + effectiveModel := strings.TrimSpace(overrideModel) + if effectiveModel == "" { + effectiveModel = strings.TrimSpace(fallbackModel) + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") && temp == 0.2 { + temp = 1.0 } - opts = append(opts, llm.WithTemperature(temp)) + return temp, true } - return opts + effectiveModel := strings.TrimSpace(overrideModel) + if effectiveModel == "" { + effectiveModel = strings.TrimSpace(fallbackModel) + } + if provider == "openai" && strings.HasPrefix(strings.ToLower(effectiveModel), "gpt-5") { + return 1.0, true + } + return 0, false } // small helpers for LLM traffic stats @@ -49,7 +160,7 @@ func (s *Server) incRecvCounters(n int) { s.mu.Unlock() } -func (s *Server) logLLMStats() { +func (s *Server) logLLMStats(model string) { s.mu.RLock() avgSent := int64(0) if s.llmReqTotal > 0 { @@ -75,11 +186,14 @@ func (s *Server) logLLMStats() { if err == nil { if client := s.currentLLMClient(); client != nil { provider := client.Name() - model := client.DefaultModel() + modelName := strings.TrimSpace(model) + if modelName == "" { + modelName = client.DefaultModel() + } // Per-scope rpm estimated from window scopeReqs := int64(0) if pe, ok := snap.Providers[provider]; ok { - if mc, ok2 := pe.Models[model]; ok2 { + if mc, ok2 := pe.Models[modelName]; ok2 { scopeReqs = mc.Reqs } } @@ -88,7 +202,7 @@ func (s *Server) logLLMStats() { minsWin = 0.001 } scopeRPM := float64(scopeReqs) / minsWin - status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, model, scopeRPM, scopeReqs, snap.Window) + status := tmx.FormatGlobalStatusColored(snap.Global.Reqs, snap.RPM, snap.Global.Sent, snap.Global.Recv, provider, modelName, scopeRPM, scopeReqs, snap.Window) _ = tmx.SetStatus(status) } } @@ -154,7 +268,7 @@ func isIdentChar(ch byte) bool { } // chatWithStats wraps llmClient.Chat to increment counters and emit a tmux heartbeat. -func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ...llm.RequestOption) (string, error) { +func (s *Server) chatWithStats(ctx context.Context, surface surfaceKind, spec requestSpec, msgs []llm.Message) (string, error) { // Count bytes sent sent := 0 for _, m := range msgs { @@ -167,19 +281,23 @@ func (s *Server) chatWithStats(ctx context.Context, msgs []llm.Message, opts ... return "", context.Canceled } // Perform request - client := s.currentLLMClient() + client := s.clientFor(spec) if client == nil { return "", fmt.Errorf("llm client unavailable") } - txt, err := client.Chat(ctx, msgs, opts...) + txt, err := client.Chat(ctx, msgs, spec.options...) if err != nil { - s.logLLMStats() + s.logLLMStats(spec.effectiveModel()) return "", err } s.incRecvCounters(len(txt)) // Update global stats cache - _ = stats.Update(ctx, client.Name(), client.DefaultModel(), sent, len(txt)) - s.logLLMStats() + model := spec.effectiveModel() + if model == "" { + model = client.DefaultModel() + } + _ = stats.Update(ctx, client.Name(), model, sent, len(txt)) + s.logLLMStats(model) return txt, nil } diff --git a/internal/lsp/llm_request_opts_test.go b/internal/lsp/llm_request_opts_test.go index c6699b0..263db79 100644 --- a/internal/lsp/llm_request_opts_test.go +++ b/internal/lsp/llm_request_opts_test.go @@ -15,17 +15,22 @@ func (f fakeClient) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOpt func (f fakeClient) Name() string { return f.name } func (f fakeClient) DefaultModel() string { return f.model } -func TestLlmRequestOpts_Gpt5_ForcesTemp1(t *testing.T) { +func TestRequestSpec_Gpt5_ForcesTemp1(t *testing.T) { s := newTestServer() one := 0.2 s.cfg.CodingTemperature = &one s.llmClient = fakeClient{name: "openai", model: "gpt-5.0"} - opts := s.llmRequestOpts() + s.cfg.OpenAIModel = "gpt-5.0" + + spec := s.buildRequestSpec(surfaceCompletion) var got llm.Options - for _, o := range opts { + for _, o := range spec.options { o(&got) } if got.Temperature != 1.0 { t.Fatalf("expected temp 1.0 for gpt-5, got %v", got.Temperature) } + if model := spec.effectiveModel(); model != "gpt-5.0" { + t.Fatalf("expected fallback model gpt-5.0, got %q", model) + } } diff --git a/internal/lsp/llm_stats_test.go b/internal/lsp/llm_stats_test.go index 43582a2..7813c10 100644 --- a/internal/lsp/llm_stats_test.go +++ b/internal/lsp/llm_stats_test.go @@ -6,5 +6,5 @@ func TestLogLLMStats_CoversCounters(t *testing.T) { s := newTestServer() s.incSentCounters(10) s.incRecvCounters(20) - s.logLLMStats() // just ensure it does not panic and executes + s.logLLMStats("model") // just ensure it does not panic and executes } diff --git a/internal/lsp/provider_native_success_test.go b/internal/lsp/provider_native_success_test.go index 6df5698..aab886c 100644 --- a/internal/lsp/provider_native_success_test.go +++ b/internal/lsp/provider_native_success_test.go @@ -21,10 +21,11 @@ func (fakeCompleterOk) CodeCompletion(context.Context, string, string, int, stri func TestProviderNativeCompletion_Success(t *testing.T) { s := newTestServer() s.llmClient = fakeCompleterOk{} + spec := s.buildRequestSpec(surfaceCompletion) // current line with dot trigger; position after dot current := "fmt." p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}} - items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false) + items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false) if !ok || len(items) == 0 { t.Fatalf("expected provider-native items") } @@ -47,9 +48,10 @@ func (fakeCompleterIndent) CodeCompletion(context.Context, string, string, int, func TestProviderNativeCompletion_IndentWithDoubleOpen(t *testing.T) { s := newTestServer() s.llmClient = fakeCompleterIndent{} + spec := s.buildRequestSpec(surfaceCompletion) current := " >>do>" // leading indent + double-open marker p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: "file:///x.go"}, Position: Position{Line: 0, Character: len(current)}} - items, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false) + items, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false) if !ok || len(items) == 0 { t.Fatalf("expected provider-native items") } @@ -80,12 +82,13 @@ func TestProviderNativeCompletion_UsesPromptTemplate(t *testing.T) { cfg := s.cfg cfg.PromptNativeCompletion = "NATIVE {{path}} {{before}}" s.cfg = cfg + spec := s.buildRequestSpec(surfaceCompletion) uri := "file:///x.go" s.setDocument(uri, "AAA\nBBB\nCCC") current := "fmt." // Cursor at line 1, char 1 -> before should be "AAA\nB" p := CompletionParams{TextDocument: TextDocumentIdentifier{URI: uri}, Position: Position{Line: 1, Character: 1}} - if _, ok := s.tryProviderNativeCompletion(current, p, "", "", "func f(){}", "doc", false, "", false); !ok { + if _, ok := s.tryProviderNativeCompletion(spec, s.llmClient, current, p, "", "", "func f(){}", "doc", false, "", false); !ok { t.Fatalf("expected provider-native path") } if cap.lastPrompt == "" { diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 7b8bc88..28f3218 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "log" + "os" "strings" "sync" "time" @@ -29,6 +30,8 @@ type Server struct { configStore *runtimeconfig.Store cfg appconfig.App llmClient llm.Client + llmProvider string + altClients map[string]llm.Client lastInput time.Time // LLM request stats llmReqTotal int64 @@ -186,6 +189,12 @@ func (s *Server) applyOptions(opts ServerOptions) { } } s.llmClient = opts.Client + if opts.Client != nil { + s.llmProvider = canonicalProvider(opts.Client.Name()) + } else { + s.llmProvider = canonicalProvider(s.cfg.Provider) + } + s.altClients = make(map[string]llm.Client) } // ApplyOptions updates the server's configuration at runtime. @@ -199,6 +208,82 @@ func (s *Server) currentLLMClient() llm.Client { return s.llmClient } +func newClientForProvider(cfg appconfig.App, provider string) (llm.Client, error) { + llmCfg := llm.Config{ + Provider: provider, + OpenAIBaseURL: cfg.OpenAIBaseURL, + OpenAIModel: cfg.OpenAIModel, + OpenAITemperature: cfg.OpenAITemperature, + OllamaBaseURL: cfg.OllamaBaseURL, + OllamaModel: cfg.OllamaModel, + OllamaTemperature: cfg.OllamaTemperature, + CopilotBaseURL: cfg.CopilotBaseURL, + CopilotModel: cfg.CopilotModel, + CopilotTemperature: cfg.CopilotTemperature, + } + oaKey := strings.TrimSpace(os.Getenv("HEXAI_OPENAI_API_KEY")) + if oaKey == "" { + oaKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) + } + cpKey := strings.TrimSpace(os.Getenv("HEXAI_COPILOT_API_KEY")) + if cpKey == "" { + cpKey = strings.TrimSpace(os.Getenv("COPILOT_API_KEY")) + } + return llm.NewFromConfig(llmCfg, oaKey, cpKey) +} + +func (s *Server) clientFor(spec requestSpec) llm.Client { + provider := canonicalProvider(spec.provider) + s.mu.RLock() + baseProvider := s.llmProvider + baseClient := s.llmClient + if baseClient != nil && strings.TrimSpace(baseProvider) == "" { + baseProvider = canonicalProvider(baseClient.Name()) + } + if provider == "" { + provider = baseProvider + } + if provider == baseProvider && baseClient != nil { + s.mu.RUnlock() + return baseClient + } + if c, ok := s.altClients[provider]; ok { + s.mu.RUnlock() + return c + } + cfg := s.cfg + store := s.configStore + s.mu.RUnlock() + if store != nil { + cfg = store.Snapshot() + } + client, err := newClientForProvider(cfg, provider) + if err != nil { + logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err) + if baseClient != nil { + return baseClient + } + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if provider == s.llmProvider { + if s.llmClient == nil { + s.llmClient = client + s.llmProvider = provider + } + return s.llmClient + } + if existing, ok := s.altClients[provider]; ok { + return existing + } + if s.altClients == nil { + s.altClients = make(map[string]llm.Client) + } + s.altClients[provider] = client + return client +} + func (s *Server) currentConfig() appconfig.App { if s.configStore != nil { return s.configStore.Snapshot() |
