summaryrefslogtreecommitdiff
path: root/internal/llm/provider.go
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-02 13:33:08 +0200
committerPaul Buetow <paul@buetow.org>2026-03-02 13:33:08 +0200
commit3bd295d60ecbb30852e8bcf672b1befd93eb9bff (patch)
tree2a530a83cb766a990a31d98d16328aaf634c1eeb /internal/llm/provider.go
parent10406467650942b780e5de462d5103431c5a951e (diff)
llm: add provider registry and self-registration factories (task 410)
Diffstat (limited to 'internal/llm/provider.go')
-rw-r--r--internal/llm/provider.go119
1 files changed, 65 insertions, 54 deletions
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index 8230b53..847ea60 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -5,6 +5,7 @@ import (
"context"
"errors"
"strings"
+ "sync"
)
// Message represents a chat-style prompt message.
@@ -84,64 +85,74 @@ type Config struct {
AnthropicTemperature *float64
}
+// ProviderKeys contains API credentials used by provider factories.
+type ProviderKeys struct {
+ OpenAIAPIKey string
+ OpenRouterAPIKey string
+ AnthropicAPIKey string
+}
+
+// ProviderFactory builds an LLM client for a named provider.
+type ProviderFactory func(cfg Config, keys ProviderKeys) (Client, error)
+
+var (
+ providerRegistryMu sync.RWMutex
+ providerRegistry = map[string]ProviderFactory{}
+)
+
+// RegisterProvider registers a provider factory by normalized name.
+func RegisterProvider(name string, factory ProviderFactory) {
+ normalized := normalizeProvider(name)
+ if normalized == "" {
+ panic("llm: provider name cannot be empty")
+ }
+ if factory == nil {
+ panic("llm: provider factory cannot be nil")
+ }
+ providerRegistryMu.Lock()
+ defer providerRegistryMu.Unlock()
+ if _, exists := providerRegistry[normalized]; exists {
+ panic("llm: provider already registered: " + normalized)
+ }
+ providerRegistry[normalized] = factory
+}
+
// NewFromConfig creates an LLM client using only the supplied configuration.
// The OpenAI API key is supplied separately and may be read from the environment
// by the caller; other environment-based configuration is not used.
func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey string) (Client, error) {
- p := strings.ToLower(strings.TrimSpace(cfg.Provider))
- if p == "" {
- p = "openai"
+ provider := normalizeProvider(cfg.Provider)
+ if provider == "" {
+ provider = "openai"
}
- switch p {
- case "openai":
- if strings.TrimSpace(openAIAPIKey) == "" {
- return nil, errors.New("missing OPENAI_API_KEY for provider openai")
- }
- // Default temperature selection:
- // - When model is gpt-5*, prefer 1.0 by default (more exploratory).
- // - Otherwise, prefer 0.2 by default (coding friendly).
- // The app-wide defaults currently set provider temps to 0.2.
- // If the user hasn't explicitly overridden and the model is gpt-5*,
- // upgrade 0.2 → 1.0 to satisfy the requested default for gpt-5.
- model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel))
- if strings.HasPrefix(model, "gpt-5") {
- if cfg.OpenAITemperature == nil {
- v := 1.0
- cfg.OpenAITemperature = &v
- } else if *cfg.OpenAITemperature == 0.2 {
- v := 1.0
- cfg.OpenAITemperature = &v
- }
- } else if cfg.OpenAITemperature == nil {
- v := 0.2
- cfg.OpenAITemperature = &v
- }
- return newOpenAIWithTimeout(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature, cfg.RequestTimeout), nil
- case "openrouter":
- if strings.TrimSpace(openRouterAPIKey) == "" {
- return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter")
- }
- if cfg.OpenRouterTemperature == nil {
- t := 0.2
- cfg.OpenRouterTemperature = &t
- }
- return newOpenRouterWithTimeout(cfg.OpenRouterBaseURL, cfg.OpenRouterModel, openRouterAPIKey, cfg.OpenRouterTemperature, cfg.RequestTimeout), nil
- case "ollama":
- if cfg.OllamaTemperature == nil {
- t := 0.2
- cfg.OllamaTemperature = &t
- }
- return newOllamaWithTimeout(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature, cfg.RequestTimeout), nil
- case "anthropic":
- if strings.TrimSpace(anthropicAPIKey) == "" {
- return nil, errors.New("missing ANTHROPIC_API_KEY for provider anthropic")
- }
- if cfg.AnthropicTemperature == nil {
- t := 0.2
- cfg.AnthropicTemperature = &t
- }
- return newAnthropicWithTimeout(cfg.AnthropicBaseURL, cfg.AnthropicModel, anthropicAPIKey, cfg.AnthropicTemperature, cfg.RequestTimeout), nil
- default:
- return nil, errors.New("unknown LLM provider: " + p)
+
+ factory, ok := lookupProviderFactory(provider)
+ if !ok {
+ return nil, errors.New("unknown LLM provider: " + provider)
+ }
+
+ return factory(cfg, ProviderKeys{
+ OpenAIAPIKey: openAIAPIKey,
+ OpenRouterAPIKey: openRouterAPIKey,
+ AnthropicAPIKey: anthropicAPIKey,
+ })
+}
+
+func normalizeProvider(provider string) string {
+ return strings.ToLower(strings.TrimSpace(provider))
+}
+
+func lookupProviderFactory(provider string) (ProviderFactory, bool) {
+ providerRegistryMu.RLock()
+ defer providerRegistryMu.RUnlock()
+ factory, ok := providerRegistry[provider]
+ return factory, ok
+}
+
+func withDefaultTemperature(configured *float64, fallback float64) *float64 {
+ if configured != nil {
+ return configured
}
+ v := fallback
+ return &v
}