diff options
Diffstat (limited to 'internal/lsp/handlers_utils.go')
| -rw-r--r-- | internal/lsp/handlers_utils.go | 129 |
1 files changed, 70 insertions, 59 deletions
diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index c8d2d24..2748a60 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -25,41 +25,79 @@ const ( type requestSpec struct { provider string - modelOverride string + entry appconfig.SurfaceConfig fallbackModel string options []llm.RequestOption + index int } -func (r requestSpec) effectiveModel() string { - if s := strings.TrimSpace(r.modelOverride); s != "" { - return s +func (r requestSpec) modelOverride() string { return strings.TrimSpace(r.entry.Model) } + +func (r requestSpec) effectiveModel(defaultModel string) string { + if m := strings.TrimSpace(r.entry.Model); m != "" { + return m + } + if f := strings.TrimSpace(r.fallbackModel); f != "" { + return f } - return strings.TrimSpace(r.fallbackModel) + return strings.TrimSpace(defaultModel) } -func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { +func (s *Server) buildRequestSpecs(surface surfaceKind) []requestSpec { cfg := s.currentConfig() - providerOverride := strings.TrimSpace(surfaceProviderFromConfig(cfg, surface)) - provider := canonicalProvider(cfg.Provider) - if providerOverride != "" { - provider = canonicalProvider(providerOverride) + entries := surfaceConfigsFor(cfg, surface) + if len(entries) == 0 { + entries = []appconfig.SurfaceConfig{{Provider: cfg.Provider}} } - fallbackModel := strings.TrimSpace(resolveDefaultModel(cfg, provider)) - modelOverride := strings.TrimSpace(surfaceModelFromConfig(cfg, surface)) maxTokens := s.maxTokens() - opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} - if tempVal, ok := chooseSurfaceTemperature(surface, cfg, provider, modelOverride, fallbackModel); ok { - opts = append(opts, llm.WithTemperature(tempVal)) - } - if modelOverride != "" { - opts = append(opts, llm.WithModel(modelOverride)) + specs := make([]requestSpec, 0, len(entries)) + for idx, raw := range entries { + entry := appconfig.SurfaceConfig{ + Provider: strings.TrimSpace(raw.Provider), + Model: strings.TrimSpace(raw.Model), + Temperature: raw.Temperature, + } + provider := entry.Provider + if provider == "" { + provider = cfg.Provider + } + provider = canonicalProvider(provider) + fallbackModel := entry.Model + if fallbackModel == "" { + fallbackModel = strings.TrimSpace(resolveDefaultModel(cfg, provider)) + } + opts := []llm.RequestOption{llm.WithMaxTokens(maxTokens)} + if entry.Model != "" { + opts = append(opts, llm.WithModel(entry.Model)) + } + if temp, ok := chooseSurfaceTemperature(surface, cfg, entry, provider, fallbackModel); ok { + opts = append(opts, llm.WithTemperature(temp)) + } + specs = append(specs, requestSpec{ + provider: provider, + entry: entry, + fallbackModel: fallbackModel, + options: opts, + index: idx, + }) } - return requestSpec{ - provider: provider, - modelOverride: modelOverride, - fallbackModel: fallbackModel, - options: opts, + return specs +} + +func (s *Server) primaryRequestSpec(surface surfaceKind) requestSpec { + specs := s.buildRequestSpecs(surface) + if len(specs) == 0 { + cfg := s.currentConfig() + provider := canonicalProvider(cfg.Provider) + fallback := strings.TrimSpace(resolveDefaultModel(cfg, provider)) + return requestSpec{provider: provider, fallbackModel: fallback, options: []llm.RequestOption{llm.WithMaxTokens(s.maxTokens())}} } + return specs[0] +} + +// buildRequestSpec is retained for consumers expecting a single-entry helper. +func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { + return s.primaryRequestSpec(surface) } func canonicalProvider(name string) string { @@ -94,37 +132,13 @@ func surfaceConfigsFor(cfg appconfig.App, surface surfaceKind) []appconfig.Surfa } } -func surfaceModelFromConfig(cfg appconfig.App, surface surfaceKind) string { - configs := surfaceConfigsFor(cfg, surface) - if len(configs) == 0 { - return "" - } - return configs[0].Model -} - -func surfaceProviderFromConfig(cfg appconfig.App, surface surfaceKind) string { - configs := surfaceConfigsFor(cfg, surface) - if len(configs) == 0 { - return "" - } - return configs[0].Provider -} - -func surfaceTemperatureFromConfig(cfg appconfig.App, surface surfaceKind) *float64 { - configs := surfaceConfigsFor(cfg, surface) - if len(configs) == 0 { - return nil - } - return configs[0].Temperature -} - -func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider string, overrideModel, fallbackModel string) (float64, bool) { - if t := surfaceTemperatureFromConfig(cfg, surface); t != nil { - return *t, true +func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, entry appconfig.SurfaceConfig, provider string, fallbackModel string) (float64, bool) { + if entry.Temperature != nil { + return *entry.Temperature, true } if cfg.CodingTemperature != nil { temp := *cfg.CodingTemperature - effectiveModel := strings.TrimSpace(overrideModel) + effectiveModel := strings.TrimSpace(entry.Model) if effectiveModel == "" { effectiveModel = strings.TrimSpace(fallbackModel) } @@ -133,7 +147,7 @@ func chooseSurfaceTemperature(surface surfaceKind, cfg appconfig.App, provider s } return temp, true } - effectiveModel := strings.TrimSpace(overrideModel) + effectiveModel := strings.TrimSpace(entry.Model) if effectiveModel == "" { effectiveModel = strings.TrimSpace(fallbackModel) } @@ -283,19 +297,16 @@ func (s *Server) chatWithStats(ctx context.Context, surface surfaceKind, spec re if client == nil { return "", fmt.Errorf("llm client unavailable") } + modelUsed := spec.effectiveModel(client.DefaultModel()) txt, err := client.Chat(ctx, msgs, spec.options...) if err != nil { - s.logLLMStats(spec.effectiveModel()) + s.logLLMStats(modelUsed) return "", err } s.incRecvCounters(len(txt)) // Update global stats cache - model := spec.effectiveModel() - if model == "" { - model = client.DefaultModel() - } - _ = stats.Update(ctx, client.Name(), model, sent, len(txt)) - s.logLLMStats(model) + _ = stats.Update(ctx, client.Name(), modelUsed, sent, len(txt)) + s.logLLMStats(modelUsed) return txt, nil } |
