diff options
Diffstat (limited to 'internal/llm/provider.go')
| -rw-r--r-- | internal/llm/provider.go | 119 |
1 files changed, 65 insertions, 54 deletions
diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 8230b53..847ea60 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -5,6 +5,7 @@ import ( "context" "errors" "strings" + "sync" ) // Message represents a chat-style prompt message. @@ -84,64 +85,74 @@ type Config struct { AnthropicTemperature *float64 } +// ProviderKeys contains API credentials used by provider factories. +type ProviderKeys struct { + OpenAIAPIKey string + OpenRouterAPIKey string + AnthropicAPIKey string +} + +// ProviderFactory builds an LLM client for a named provider. +type ProviderFactory func(cfg Config, keys ProviderKeys) (Client, error) + +var ( + providerRegistryMu sync.RWMutex + providerRegistry = map[string]ProviderFactory{} +) + +// RegisterProvider registers a provider factory by normalized name. +func RegisterProvider(name string, factory ProviderFactory) { + normalized := normalizeProvider(name) + if normalized == "" { + panic("llm: provider name cannot be empty") + } + if factory == nil { + panic("llm: provider factory cannot be nil") + } + providerRegistryMu.Lock() + defer providerRegistryMu.Unlock() + if _, exists := providerRegistry[normalized]; exists { + panic("llm: provider already registered: " + normalized) + } + providerRegistry[normalized] = factory +} + // NewFromConfig creates an LLM client using only the supplied configuration. // The OpenAI API key is supplied separately and may be read from the environment // by the caller; other environment-based configuration is not used. func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey string) (Client, error) { - p := strings.ToLower(strings.TrimSpace(cfg.Provider)) - if p == "" { - p = "openai" + provider := normalizeProvider(cfg.Provider) + if provider == "" { + provider = "openai" } - switch p { - case "openai": - if strings.TrimSpace(openAIAPIKey) == "" { - return nil, errors.New("missing OPENAI_API_KEY for provider openai") - } - // Default temperature selection: - // - When model is gpt-5*, prefer 1.0 by default (more exploratory). - // - Otherwise, prefer 0.2 by default (coding friendly). - // The app-wide defaults currently set provider temps to 0.2. - // If the user hasn't explicitly overridden and the model is gpt-5*, - // upgrade 0.2 → 1.0 to satisfy the requested default for gpt-5. - model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel)) - if strings.HasPrefix(model, "gpt-5") { - if cfg.OpenAITemperature == nil { - v := 1.0 - cfg.OpenAITemperature = &v - } else if *cfg.OpenAITemperature == 0.2 { - v := 1.0 - cfg.OpenAITemperature = &v - } - } else if cfg.OpenAITemperature == nil { - v := 0.2 - cfg.OpenAITemperature = &v - } - return newOpenAIWithTimeout(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature, cfg.RequestTimeout), nil - case "openrouter": - if strings.TrimSpace(openRouterAPIKey) == "" { - return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter") - } - if cfg.OpenRouterTemperature == nil { - t := 0.2 - cfg.OpenRouterTemperature = &t - } - return newOpenRouterWithTimeout(cfg.OpenRouterBaseURL, cfg.OpenRouterModel, openRouterAPIKey, cfg.OpenRouterTemperature, cfg.RequestTimeout), nil - case "ollama": - if cfg.OllamaTemperature == nil { - t := 0.2 - cfg.OllamaTemperature = &t - } - return newOllamaWithTimeout(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature, cfg.RequestTimeout), nil - case "anthropic": - if strings.TrimSpace(anthropicAPIKey) == "" { - return nil, errors.New("missing ANTHROPIC_API_KEY for provider anthropic") - } - if cfg.AnthropicTemperature == nil { - t := 0.2 - cfg.AnthropicTemperature = &t - } - return newAnthropicWithTimeout(cfg.AnthropicBaseURL, cfg.AnthropicModel, anthropicAPIKey, cfg.AnthropicTemperature, cfg.RequestTimeout), nil - default: - return nil, errors.New("unknown LLM provider: " + p) + + factory, ok := lookupProviderFactory(provider) + if !ok { + return nil, errors.New("unknown LLM provider: " + provider) + } + + return factory(cfg, ProviderKeys{ + OpenAIAPIKey: openAIAPIKey, + OpenRouterAPIKey: openRouterAPIKey, + AnthropicAPIKey: anthropicAPIKey, + }) +} + +func normalizeProvider(provider string) string { + return strings.ToLower(strings.TrimSpace(provider)) +} + +func lookupProviderFactory(provider string) (ProviderFactory, bool) { + providerRegistryMu.RLock() + defer providerRegistryMu.RUnlock() + factory, ok := providerRegistry[provider] + return factory, ok +} + +func withDefaultTemperature(configured *float64, fallback float64) *float64 { + if configured != nil { + return configured } + v := fallback + return &v } |
