diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-02 13:42:06 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-02 13:42:06 +0200 |
| commit | 021785d750de2cd8d1f94334282a2b110b77c0fd (patch) | |
| tree | 06c1e4942af0e1885e4c872cbb99d4653a3ec8a6 /internal/llmutils | |
| parent | 8a70afb354d0050f3f8e1142753284859036fa1c (diff) | |
llmutils: centralize provider normalization and client setup (task 410)
Diffstat (limited to 'internal/llmutils')
| -rw-r--r-- | internal/llmutils/client.go | 53 | ||||
| -rw-r--r-- | internal/llmutils/client_test.go | 49 |
2 files changed, 102 insertions, 0 deletions
diff --git a/internal/llmutils/client.go b/internal/llmutils/client.go index c8d9a90..16a6338 100644 --- a/internal/llmutils/client.go +++ b/internal/llmutils/client.go @@ -8,6 +8,59 @@ import ( "codeberg.org/snonux/hexai/internal/llm" ) +// CanonicalProvider normalizes provider names and defaults to openai. +func CanonicalProvider(name string) string { + provider := strings.ToLower(strings.TrimSpace(name)) + if provider == "" { + return "openai" + } + return provider +} + +// DefaultModelForProvider returns the configured default model for a provider. +func DefaultModelForProvider(cfg appconfig.App, provider string) string { + switch CanonicalProvider(provider) { + case "openrouter": + return strings.TrimSpace(cfg.OpenRouterModel) + case "ollama": + return strings.TrimSpace(cfg.OllamaModel) + case "anthropic": + return strings.TrimSpace(cfg.AnthropicModel) + default: + return strings.TrimSpace(cfg.OpenAIModel) + } +} + +// ConfigForProvider returns cfg adjusted for the selected provider/model. +func ConfigForProvider(cfg appconfig.App, provider, modelOverride string) appconfig.App { + derived := cfg + if strings.TrimSpace(provider) == "" { + provider = cfg.Provider + } + normalized := CanonicalProvider(provider) + derived.Provider = normalized + model := strings.TrimSpace(modelOverride) + if model == "" { + return derived + } + switch normalized { + case "openrouter": + derived.OpenRouterModel = model + case "ollama": + derived.OllamaModel = model + case "anthropic": + derived.AnthropicModel = model + default: + derived.OpenAIModel = model + } + return derived +} + +// NewClientFromAppForProvider builds a client for a specific provider/model. +func NewClientFromAppForProvider(cfg appconfig.App, provider, modelOverride string) (llm.Client, error) { + return NewClientFromApp(ConfigForProvider(cfg, provider, modelOverride)) +} + // NewClientFromApp builds an llm.Client using app config and environment keys. func NewClientFromApp(cfg appconfig.App) (llm.Client, error) { llmCfg := llm.Config{ diff --git a/internal/llmutils/client_test.go b/internal/llmutils/client_test.go index 2e20db3..837d408 100644 --- a/internal/llmutils/client_test.go +++ b/internal/llmutils/client_test.go @@ -25,3 +25,52 @@ func TestNewClientFromApp_OpenAI_WithKey(t *testing.T) { // ensure env override precedence _ = os.Unsetenv("OPENAI_API_KEY") } + +func TestCanonicalProvider(t *testing.T) { + if got := CanonicalProvider(" OpenRouter "); got != "openrouter" { + t.Fatalf("CanonicalProvider(openrouter) = %q", got) + } + if got := CanonicalProvider(" "); got != "openai" { + t.Fatalf("CanonicalProvider(empty) = %q", got) + } +} + +func TestDefaultModelForProvider(t *testing.T) { + cfg := appconfig.App{ + OpenAIModel: "gpt-4.1", + OpenRouterModel: "openrouter/auto", + OllamaModel: "qwen3", + AnthropicModel: "claude", + } + if got := DefaultModelForProvider(cfg, "openai"); got != "gpt-4.1" { + t.Fatalf("openai model = %q", got) + } + if got := DefaultModelForProvider(cfg, "openrouter"); got != "openrouter/auto" { + t.Fatalf("openrouter model = %q", got) + } + if got := DefaultModelForProvider(cfg, "ollama"); got != "qwen3" { + t.Fatalf("ollama model = %q", got) + } + if got := DefaultModelForProvider(cfg, "anthropic"); got != "claude" { + t.Fatalf("anthropic model = %q", got) + } +} + +func TestConfigForProvider(t *testing.T) { + base := appconfig.App{ + Provider: "openai", + OpenAIModel: "gpt-4.1", + OllamaModel: "qwen3", + AnthropicModel: "claude", + } + got := ConfigForProvider(base, "ollama", "qwen3-coder") + if got.Provider != "ollama" { + t.Fatalf("provider = %q", got.Provider) + } + if got.OllamaModel != "qwen3-coder" { + t.Fatalf("ollama model = %q", got.OllamaModel) + } + if got.OpenAIModel != "gpt-4.1" { + t.Fatalf("openai model unexpectedly changed: %q", got.OpenAIModel) + } +} |
