diff options
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/anthropic.go | 17 | ||||
| -rw-r--r-- | internal/llm/ollama.go | 13 | ||||
| -rw-r--r-- | internal/llm/openai.go | 33 | ||||
| -rw-r--r-- | internal/llm/openrouter.go | 17 | ||||
| -rw-r--r-- | internal/llm/provider.go | 119 |
5 files changed, 145 insertions, 54 deletions
diff --git a/internal/llm/anthropic.go b/internal/llm/anthropic.go index 0d87424..82d8b8a 100644 --- a/internal/llm/anthropic.go +++ b/internal/llm/anthropic.go @@ -86,6 +86,23 @@ var ( _ Streamer = (*anthropicClient)(nil) ) +func init() { + RegisterProvider("anthropic", anthropicProviderFactory) +} + +func anthropicProviderFactory(cfg Config, keys ProviderKeys) (Client, error) { + if strings.TrimSpace(keys.AnthropicAPIKey) == "" { + return nil, errors.New("missing ANTHROPIC_API_KEY for provider anthropic") + } + return newAnthropicWithTimeout( + cfg.AnthropicBaseURL, + cfg.AnthropicModel, + keys.AnthropicAPIKey, + withDefaultTemperature(cfg.AnthropicTemperature, 0.2), + cfg.RequestTimeout, + ), nil +} + // Constructor // newAnthropic constructs an Anthropic client using explicit configuration values. // The apiKey may be empty; calls will fail until a valid key is supplied. diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index a22dd7b..a878b62 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -40,6 +40,19 @@ type ollamaChatResponse struct { Error string `json:"error,omitempty"` } +func init() { + RegisterProvider("ollama", ollamaProviderFactory) +} + +func ollamaProviderFactory(cfg Config, _ ProviderKeys) (Client, error) { + return newOllamaWithTimeout( + cfg.OllamaBaseURL, + cfg.OllamaModel, + withDefaultTemperature(cfg.OllamaTemperature, 0.2), + cfg.RequestTimeout, + ), nil +} + // Constructor (kept among the first functions by convention) func newOllama(baseURL, model string, defaultTemp *float64) Client { return newOllamaWithTimeout(baseURL, model, defaultTemp, 0) diff --git a/internal/llm/openai.go b/internal/llm/openai.go index 6bc3a7c..5c1e525 100644 --- a/internal/llm/openai.go +++ b/internal/llm/openai.go @@ -73,6 +73,39 @@ type oaStreamChunk struct { } `json:"error,omitempty"` } +func init() { + RegisterProvider("openai", openAIProviderFactory) +} + +func openAIProviderFactory(cfg Config, keys ProviderKeys) (Client, error) { + if strings.TrimSpace(keys.OpenAIAPIKey) == "" { + return nil, errors.New("missing OPENAI_API_KEY for provider openai") + } + return newOpenAIWithTimeout( + cfg.OpenAIBaseURL, + cfg.OpenAIModel, + keys.OpenAIAPIKey, + resolveOpenAITemperature(cfg.OpenAIModel, cfg.OpenAITemperature), + cfg.RequestTimeout, + ), nil +} + +func resolveOpenAITemperature(model string, configured *float64) *float64 { + isGPT5 := strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-5") + if isGPT5 { + if configured == nil || *configured == 0.2 { + v := 1.0 + return &v + } + return configured + } + if configured != nil { + return configured + } + v := 0.2 + return &v +} + // Constructor (kept among the first functions by convention) // newOpenAI constructs an OpenAI client using explicit configuration values. // The apiKey may be empty; calls will fail until a valid key is supplied. diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go index 21e3102..8aae6b8 100644 --- a/internal/llm/openrouter.go +++ b/internal/llm/openrouter.go @@ -22,6 +22,23 @@ type openRouterClient struct { defaultTemperature *float64 } +func init() { + RegisterProvider("openrouter", openRouterProviderFactory) +} + +func openRouterProviderFactory(cfg Config, keys ProviderKeys) (Client, error) { + if strings.TrimSpace(keys.OpenRouterAPIKey) == "" { + return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter") + } + return newOpenRouterWithTimeout( + cfg.OpenRouterBaseURL, + cfg.OpenRouterModel, + keys.OpenRouterAPIKey, + withDefaultTemperature(cfg.OpenRouterTemperature, 0.2), + cfg.RequestTimeout, + ), nil +} + func newOpenRouter(baseURL, model, apiKey string, defaultTemp *float64) Client { return newOpenRouterWithTimeout(baseURL, model, apiKey, defaultTemp, 0) } diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 8230b53..847ea60 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -5,6 +5,7 @@ import ( "context" "errors" "strings" + "sync" ) // Message represents a chat-style prompt message. @@ -84,64 +85,74 @@ type Config struct { AnthropicTemperature *float64 } +// ProviderKeys contains API credentials used by provider factories. +type ProviderKeys struct { + OpenAIAPIKey string + OpenRouterAPIKey string + AnthropicAPIKey string +} + +// ProviderFactory builds an LLM client for a named provider. +type ProviderFactory func(cfg Config, keys ProviderKeys) (Client, error) + +var ( + providerRegistryMu sync.RWMutex + providerRegistry = map[string]ProviderFactory{} +) + +// RegisterProvider registers a provider factory by normalized name. +func RegisterProvider(name string, factory ProviderFactory) { + normalized := normalizeProvider(name) + if normalized == "" { + panic("llm: provider name cannot be empty") + } + if factory == nil { + panic("llm: provider factory cannot be nil") + } + providerRegistryMu.Lock() + defer providerRegistryMu.Unlock() + if _, exists := providerRegistry[normalized]; exists { + panic("llm: provider already registered: " + normalized) + } + providerRegistry[normalized] = factory +} + // NewFromConfig creates an LLM client using only the supplied configuration. // The OpenAI API key is supplied separately and may be read from the environment // by the caller; other environment-based configuration is not used. func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey string) (Client, error) { - p := strings.ToLower(strings.TrimSpace(cfg.Provider)) - if p == "" { - p = "openai" + provider := normalizeProvider(cfg.Provider) + if provider == "" { + provider = "openai" } - switch p { - case "openai": - if strings.TrimSpace(openAIAPIKey) == "" { - return nil, errors.New("missing OPENAI_API_KEY for provider openai") - } - // Default temperature selection: - // - When model is gpt-5*, prefer 1.0 by default (more exploratory). - // - Otherwise, prefer 0.2 by default (coding friendly). - // The app-wide defaults currently set provider temps to 0.2. - // If the user hasn't explicitly overridden and the model is gpt-5*, - // upgrade 0.2 → 1.0 to satisfy the requested default for gpt-5. - model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel)) - if strings.HasPrefix(model, "gpt-5") { - if cfg.OpenAITemperature == nil { - v := 1.0 - cfg.OpenAITemperature = &v - } else if *cfg.OpenAITemperature == 0.2 { - v := 1.0 - cfg.OpenAITemperature = &v - } - } else if cfg.OpenAITemperature == nil { - v := 0.2 - cfg.OpenAITemperature = &v - } - return newOpenAIWithTimeout(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature, cfg.RequestTimeout), nil - case "openrouter": - if strings.TrimSpace(openRouterAPIKey) == "" { - return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter") - } - if cfg.OpenRouterTemperature == nil { - t := 0.2 - cfg.OpenRouterTemperature = &t - } - return newOpenRouterWithTimeout(cfg.OpenRouterBaseURL, cfg.OpenRouterModel, openRouterAPIKey, cfg.OpenRouterTemperature, cfg.RequestTimeout), nil - case "ollama": - if cfg.OllamaTemperature == nil { - t := 0.2 - cfg.OllamaTemperature = &t - } - return newOllamaWithTimeout(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature, cfg.RequestTimeout), nil - case "anthropic": - if strings.TrimSpace(anthropicAPIKey) == "" { - return nil, errors.New("missing ANTHROPIC_API_KEY for provider anthropic") - } - if cfg.AnthropicTemperature == nil { - t := 0.2 - cfg.AnthropicTemperature = &t - } - return newAnthropicWithTimeout(cfg.AnthropicBaseURL, cfg.AnthropicModel, anthropicAPIKey, cfg.AnthropicTemperature, cfg.RequestTimeout), nil - default: - return nil, errors.New("unknown LLM provider: " + p) + + factory, ok := lookupProviderFactory(provider) + if !ok { + return nil, errors.New("unknown LLM provider: " + provider) + } + + return factory(cfg, ProviderKeys{ + OpenAIAPIKey: openAIAPIKey, + OpenRouterAPIKey: openRouterAPIKey, + AnthropicAPIKey: anthropicAPIKey, + }) +} + +func normalizeProvider(provider string) string { + return strings.ToLower(strings.TrimSpace(provider)) +} + +func lookupProviderFactory(provider string) (ProviderFactory, bool) { + providerRegistryMu.RLock() + defer providerRegistryMu.RUnlock() + factory, ok := providerRegistry[provider] + return factory, ok +} + +func withDefaultTemperature(configured *float64, fallback float64) *float64 { + if configured != nil { + return configured } + v := fallback + return &v } |
