summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/llm/anthropic.go17
-rw-r--r--internal/llm/ollama.go13
-rw-r--r--internal/llm/openai.go33
-rw-r--r--internal/llm/openrouter.go17
-rw-r--r--internal/llm/provider.go119
5 files changed, 145 insertions, 54 deletions
diff --git a/internal/llm/anthropic.go b/internal/llm/anthropic.go
index 0d87424..82d8b8a 100644
--- a/internal/llm/anthropic.go
+++ b/internal/llm/anthropic.go
@@ -86,6 +86,23 @@ var (
_ Streamer = (*anthropicClient)(nil)
)
+func init() {
+ RegisterProvider("anthropic", anthropicProviderFactory)
+}
+
+func anthropicProviderFactory(cfg Config, keys ProviderKeys) (Client, error) {
+ if strings.TrimSpace(keys.AnthropicAPIKey) == "" {
+ return nil, errors.New("missing ANTHROPIC_API_KEY for provider anthropic")
+ }
+ return newAnthropicWithTimeout(
+ cfg.AnthropicBaseURL,
+ cfg.AnthropicModel,
+ keys.AnthropicAPIKey,
+ withDefaultTemperature(cfg.AnthropicTemperature, 0.2),
+ cfg.RequestTimeout,
+ ), nil
+}
+
// Constructor
// newAnthropic constructs an Anthropic client using explicit configuration values.
// The apiKey may be empty; calls will fail until a valid key is supplied.
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index a22dd7b..a878b62 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -40,6 +40,19 @@ type ollamaChatResponse struct {
Error string `json:"error,omitempty"`
}
+func init() {
+ RegisterProvider("ollama", ollamaProviderFactory)
+}
+
+func ollamaProviderFactory(cfg Config, _ ProviderKeys) (Client, error) {
+ return newOllamaWithTimeout(
+ cfg.OllamaBaseURL,
+ cfg.OllamaModel,
+ withDefaultTemperature(cfg.OllamaTemperature, 0.2),
+ cfg.RequestTimeout,
+ ), nil
+}
+
// Constructor (kept among the first functions by convention)
func newOllama(baseURL, model string, defaultTemp *float64) Client {
return newOllamaWithTimeout(baseURL, model, defaultTemp, 0)
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 6bc3a7c..5c1e525 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -73,6 +73,39 @@ type oaStreamChunk struct {
} `json:"error,omitempty"`
}
+func init() {
+ RegisterProvider("openai", openAIProviderFactory)
+}
+
+func openAIProviderFactory(cfg Config, keys ProviderKeys) (Client, error) {
+ if strings.TrimSpace(keys.OpenAIAPIKey) == "" {
+ return nil, errors.New("missing OPENAI_API_KEY for provider openai")
+ }
+ return newOpenAIWithTimeout(
+ cfg.OpenAIBaseURL,
+ cfg.OpenAIModel,
+ keys.OpenAIAPIKey,
+ resolveOpenAITemperature(cfg.OpenAIModel, cfg.OpenAITemperature),
+ cfg.RequestTimeout,
+ ), nil
+}
+
+func resolveOpenAITemperature(model string, configured *float64) *float64 {
+ isGPT5 := strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-5")
+ if isGPT5 {
+ if configured == nil || *configured == 0.2 {
+ v := 1.0
+ return &v
+ }
+ return configured
+ }
+ if configured != nil {
+ return configured
+ }
+ v := 0.2
+ return &v
+}
+
// Constructor (kept among the first functions by convention)
// newOpenAI constructs an OpenAI client using explicit configuration values.
// The apiKey may be empty; calls will fail until a valid key is supplied.
diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go
index 21e3102..8aae6b8 100644
--- a/internal/llm/openrouter.go
+++ b/internal/llm/openrouter.go
@@ -22,6 +22,23 @@ type openRouterClient struct {
defaultTemperature *float64
}
+func init() {
+ RegisterProvider("openrouter", openRouterProviderFactory)
+}
+
+func openRouterProviderFactory(cfg Config, keys ProviderKeys) (Client, error) {
+ if strings.TrimSpace(keys.OpenRouterAPIKey) == "" {
+ return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter")
+ }
+ return newOpenRouterWithTimeout(
+ cfg.OpenRouterBaseURL,
+ cfg.OpenRouterModel,
+ keys.OpenRouterAPIKey,
+ withDefaultTemperature(cfg.OpenRouterTemperature, 0.2),
+ cfg.RequestTimeout,
+ ), nil
+}
+
func newOpenRouter(baseURL, model, apiKey string, defaultTemp *float64) Client {
return newOpenRouterWithTimeout(baseURL, model, apiKey, defaultTemp, 0)
}
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index 8230b53..847ea60 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -5,6 +5,7 @@ import (
"context"
"errors"
"strings"
+ "sync"
)
// Message represents a chat-style prompt message.
@@ -84,64 +85,74 @@ type Config struct {
AnthropicTemperature *float64
}
+// ProviderKeys contains API credentials used by provider factories.
+type ProviderKeys struct {
+ OpenAIAPIKey string
+ OpenRouterAPIKey string
+ AnthropicAPIKey string
+}
+
+// ProviderFactory builds an LLM client for a named provider.
+type ProviderFactory func(cfg Config, keys ProviderKeys) (Client, error)
+
+var (
+ providerRegistryMu sync.RWMutex
+ providerRegistry = map[string]ProviderFactory{}
+)
+
+// RegisterProvider registers a provider factory by normalized name.
+func RegisterProvider(name string, factory ProviderFactory) {
+ normalized := normalizeProvider(name)
+ if normalized == "" {
+ panic("llm: provider name cannot be empty")
+ }
+ if factory == nil {
+ panic("llm: provider factory cannot be nil")
+ }
+ providerRegistryMu.Lock()
+ defer providerRegistryMu.Unlock()
+ if _, exists := providerRegistry[normalized]; exists {
+ panic("llm: provider already registered: " + normalized)
+ }
+ providerRegistry[normalized] = factory
+}
+
// NewFromConfig creates an LLM client using only the supplied configuration.
// The OpenAI API key is supplied separately and may be read from the environment
// by the caller; other environment-based configuration is not used.
func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey string) (Client, error) {
- p := strings.ToLower(strings.TrimSpace(cfg.Provider))
- if p == "" {
- p = "openai"
+ provider := normalizeProvider(cfg.Provider)
+ if provider == "" {
+ provider = "openai"
}
- switch p {
- case "openai":
- if strings.TrimSpace(openAIAPIKey) == "" {
- return nil, errors.New("missing OPENAI_API_KEY for provider openai")
- }
- // Default temperature selection:
- // - When model is gpt-5*, prefer 1.0 by default (more exploratory).
- // - Otherwise, prefer 0.2 by default (coding friendly).
- // The app-wide defaults currently set provider temps to 0.2.
- // If the user hasn't explicitly overridden and the model is gpt-5*,
- // upgrade 0.2 → 1.0 to satisfy the requested default for gpt-5.
- model := strings.ToLower(strings.TrimSpace(cfg.OpenAIModel))
- if strings.HasPrefix(model, "gpt-5") {
- if cfg.OpenAITemperature == nil {
- v := 1.0
- cfg.OpenAITemperature = &v
- } else if *cfg.OpenAITemperature == 0.2 {
- v := 1.0
- cfg.OpenAITemperature = &v
- }
- } else if cfg.OpenAITemperature == nil {
- v := 0.2
- cfg.OpenAITemperature = &v
- }
- return newOpenAIWithTimeout(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature, cfg.RequestTimeout), nil
- case "openrouter":
- if strings.TrimSpace(openRouterAPIKey) == "" {
- return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter")
- }
- if cfg.OpenRouterTemperature == nil {
- t := 0.2
- cfg.OpenRouterTemperature = &t
- }
- return newOpenRouterWithTimeout(cfg.OpenRouterBaseURL, cfg.OpenRouterModel, openRouterAPIKey, cfg.OpenRouterTemperature, cfg.RequestTimeout), nil
- case "ollama":
- if cfg.OllamaTemperature == nil {
- t := 0.2
- cfg.OllamaTemperature = &t
- }
- return newOllamaWithTimeout(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature, cfg.RequestTimeout), nil
- case "anthropic":
- if strings.TrimSpace(anthropicAPIKey) == "" {
- return nil, errors.New("missing ANTHROPIC_API_KEY for provider anthropic")
- }
- if cfg.AnthropicTemperature == nil {
- t := 0.2
- cfg.AnthropicTemperature = &t
- }
- return newAnthropicWithTimeout(cfg.AnthropicBaseURL, cfg.AnthropicModel, anthropicAPIKey, cfg.AnthropicTemperature, cfg.RequestTimeout), nil
- default:
- return nil, errors.New("unknown LLM provider: " + p)
+
+ factory, ok := lookupProviderFactory(provider)
+ if !ok {
+ return nil, errors.New("unknown LLM provider: " + provider)
+ }
+
+ return factory(cfg, ProviderKeys{
+ OpenAIAPIKey: openAIAPIKey,
+ OpenRouterAPIKey: openRouterAPIKey,
+ AnthropicAPIKey: anthropicAPIKey,
+ })
+}
+
+func normalizeProvider(provider string) string {
+ return strings.ToLower(strings.TrimSpace(provider))
+}
+
+func lookupProviderFactory(provider string) (ProviderFactory, bool) {
+ providerRegistryMu.RLock()
+ defer providerRegistryMu.RUnlock()
+ factory, ok := providerRegistry[provider]
+ return factory, ok
+}
+
+func withDefaultTemperature(configured *float64, fallback float64) *float64 {
+ if configured != nil {
+ return configured
}
+ v := fallback
+ return &v
}