diff options
| -rw-r--r-- | internal/hexaiaction/prompts.go | 16 | ||||
| -rw-r--r-- | internal/hexaicli/run.go | 28 | ||||
| -rw-r--r-- | internal/hexailsp/run.go | 40 | ||||
| -rw-r--r-- | internal/llmutils/client.go | 53 | ||||
| -rw-r--r-- | internal/llmutils/client_test.go | 49 | ||||
| -rw-r--r-- | internal/lsp/handlers_utils.go | 18 | ||||
| -rw-r--r-- | internal/lsp/server.go | 64 |
7 files changed, 122 insertions, 146 deletions
diff --git a/internal/hexaiaction/prompts.go b/internal/hexaiaction/prompts.go index fc743a0..2e4b52b 100644 --- a/internal/hexaiaction/prompts.go +++ b/internal/hexaiaction/prompts.go @@ -7,6 +7,7 @@ import ( "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/stats" "codeberg.org/snonux/hexai/internal/textutil" "codeberg.org/snonux/hexai/internal/tmux" @@ -38,22 +39,11 @@ func providerOf(c any) string { } func canonicalProvider(name string) string { - p := strings.ToLower(strings.TrimSpace(name)) - if p == "" { - return "openai" - } - return p + return llmutils.CanonicalProvider(name) } func defaultModelForProvider(cfg appconfig.App, provider string) string { - switch provider { - case "ollama": - return cfg.OllamaModel - case "anthropic": - return cfg.AnthropicModel - default: - return cfg.OpenAIModel - } + return llmutils.DefaultModelForProvider(cfg, provider) } func selectActionTemperature(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) { diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 9ea3a40..4cb9e01 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -70,18 +70,7 @@ func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) { provider = cfg.Provider } provider = canonicalProvider(provider) - derived := cfg - derived.Provider = provider - switch provider { - case "openai": - if entry.Model != "" { - derived.OpenAIModel = entry.Model - } - case "ollama": - if entry.Model != "" { - derived.OllamaModel = entry.Model - } - } + derived := llmutils.ConfigForProvider(cfg, provider, entry.Model) client, err := newClientFromApp(derived) if err != nil { return nil, err @@ -136,22 +125,11 @@ func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig } func canonicalProvider(name string) string { - p := strings.ToLower(strings.TrimSpace(name)) - if p == "" { - return "openai" - } - return p + return llmutils.CanonicalProvider(name) } func defaultModelForProvider(cfg appconfig.App, provider string) string { - switch provider { - case "ollama": - return cfg.OllamaModel - case "anthropic": - return cfg.AnthropicModel - default: - return cfg.OpenAIModel - } + return llmutils.DefaultModelForProvider(cfg, provider) } // Run executes the Hexai CLI behavior given arguments and I/O streams. diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go index 99779bb..250fc67 100644 --- a/internal/hexailsp/run.go +++ b/internal/hexailsp/run.go @@ -12,6 +12,7 @@ import ( "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/ignore" "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/lsp" "codeberg.org/snonux/hexai/internal/runtimeconfig" @@ -116,44 +117,13 @@ func buildClientIfNil(cfg appconfig.App, client llm.Client) llm.Client { if client != nil { return client } - llmCfg := llm.Config{ - Provider: cfg.Provider, - RequestTimeout: cfg.RequestTimeout, - OpenAIBaseURL: cfg.OpenAIBaseURL, - OpenAIModel: cfg.OpenAIModel, - OpenAITemperature: cfg.OpenAITemperature, - OpenRouterBaseURL: cfg.OpenRouterBaseURL, - OpenRouterModel: cfg.OpenRouterModel, - OpenRouterTemperature: cfg.OpenRouterTemperature, - OllamaBaseURL: cfg.OllamaBaseURL, - OllamaModel: cfg.OllamaModel, - OllamaTemperature: cfg.OllamaTemperature, - AnthropicBaseURL: cfg.AnthropicBaseURL, - AnthropicModel: cfg.AnthropicModel, - AnthropicTemperature: cfg.AnthropicTemperature, - } - // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY - oaKey := os.Getenv("HEXAI_OPENAI_API_KEY") - if strings.TrimSpace(oaKey) == "" { - oaKey = os.Getenv("OPENAI_API_KEY") - } - // Prefer HEXAI_OPENROUTER_API_KEY; fall back to OPENROUTER_API_KEY - orKey := os.Getenv("HEXAI_OPENROUTER_API_KEY") - if strings.TrimSpace(orKey) == "" { - orKey = os.Getenv("OPENROUTER_API_KEY") - } - // Prefer HEXAI_ANTHROPIC_API_KEY; fall back to ANTHROPIC_API_KEY - anKey := os.Getenv("HEXAI_ANTHROPIC_API_KEY") - if strings.TrimSpace(anKey) == "" { - anKey = os.Getenv("ANTHROPIC_API_KEY") - } - if c, err := llm.NewFromConfig(llmCfg, oaKey, orKey, anKey); err != nil { + c, err := llmutils.NewClientFromApp(cfg) + if err != nil { logging.Logf("lsp ", "llm disabled: %v", err) return nil - } else { - logging.Logf("lsp ", "llm enabled provider=%s model=%s", c.Name(), c.DefaultModel()) - return c } + logging.Logf("lsp ", "llm enabled provider=%s model=%s", c.Name(), c.DefaultModel()) + return c } func ensureFactory(factory ServerFactory) ServerFactory { diff --git a/internal/llmutils/client.go b/internal/llmutils/client.go index c8d9a90..16a6338 100644 --- a/internal/llmutils/client.go +++ b/internal/llmutils/client.go @@ -8,6 +8,59 @@ import ( "codeberg.org/snonux/hexai/internal/llm" ) +// CanonicalProvider normalizes provider names and defaults to openai. +func CanonicalProvider(name string) string { + provider := strings.ToLower(strings.TrimSpace(name)) + if provider == "" { + return "openai" + } + return provider +} + +// DefaultModelForProvider returns the configured default model for a provider. +func DefaultModelForProvider(cfg appconfig.App, provider string) string { + switch CanonicalProvider(provider) { + case "openrouter": + return strings.TrimSpace(cfg.OpenRouterModel) + case "ollama": + return strings.TrimSpace(cfg.OllamaModel) + case "anthropic": + return strings.TrimSpace(cfg.AnthropicModel) + default: + return strings.TrimSpace(cfg.OpenAIModel) + } +} + +// ConfigForProvider returns cfg adjusted for the selected provider/model. +func ConfigForProvider(cfg appconfig.App, provider, modelOverride string) appconfig.App { + derived := cfg + if strings.TrimSpace(provider) == "" { + provider = cfg.Provider + } + normalized := CanonicalProvider(provider) + derived.Provider = normalized + model := strings.TrimSpace(modelOverride) + if model == "" { + return derived + } + switch normalized { + case "openrouter": + derived.OpenRouterModel = model + case "ollama": + derived.OllamaModel = model + case "anthropic": + derived.AnthropicModel = model + default: + derived.OpenAIModel = model + } + return derived +} + +// NewClientFromAppForProvider builds a client for a specific provider/model. +func NewClientFromAppForProvider(cfg appconfig.App, provider, modelOverride string) (llm.Client, error) { + return NewClientFromApp(ConfigForProvider(cfg, provider, modelOverride)) +} + // NewClientFromApp builds an llm.Client using app config and environment keys. func NewClientFromApp(cfg appconfig.App) (llm.Client, error) { llmCfg := llm.Config{ diff --git a/internal/llmutils/client_test.go b/internal/llmutils/client_test.go index 2e20db3..837d408 100644 --- a/internal/llmutils/client_test.go +++ b/internal/llmutils/client_test.go @@ -25,3 +25,52 @@ func TestNewClientFromApp_OpenAI_WithKey(t *testing.T) { // ensure env override precedence _ = os.Unsetenv("OPENAI_API_KEY") } + +func TestCanonicalProvider(t *testing.T) { + if got := CanonicalProvider(" OpenRouter "); got != "openrouter" { + t.Fatalf("CanonicalProvider(openrouter) = %q", got) + } + if got := CanonicalProvider(" "); got != "openai" { + t.Fatalf("CanonicalProvider(empty) = %q", got) + } +} + +func TestDefaultModelForProvider(t *testing.T) { + cfg := appconfig.App{ + OpenAIModel: "gpt-4.1", + OpenRouterModel: "openrouter/auto", + OllamaModel: "qwen3", + AnthropicModel: "claude", + } + if got := DefaultModelForProvider(cfg, "openai"); got != "gpt-4.1" { + t.Fatalf("openai model = %q", got) + } + if got := DefaultModelForProvider(cfg, "openrouter"); got != "openrouter/auto" { + t.Fatalf("openrouter model = %q", got) + } + if got := DefaultModelForProvider(cfg, "ollama"); got != "qwen3" { + t.Fatalf("ollama model = %q", got) + } + if got := DefaultModelForProvider(cfg, "anthropic"); got != "claude" { + t.Fatalf("anthropic model = %q", got) + } +} + +func TestConfigForProvider(t *testing.T) { + base := appconfig.App{ + Provider: "openai", + OpenAIModel: "gpt-4.1", + OllamaModel: "qwen3", + AnthropicModel: "claude", + } + got := ConfigForProvider(base, "ollama", "qwen3-coder") + if got.Provider != "ollama" { + t.Fatalf("provider = %q", got.Provider) + } + if got.OllamaModel != "qwen3-coder" { + t.Fatalf("ollama model = %q", got.OllamaModel) + } + if got.OpenAIModel != "gpt-4.1" { + t.Fatalf("openai model unexpectedly changed: %q", got.OpenAIModel) + } +} diff --git a/internal/lsp/handlers_utils.go b/internal/lsp/handlers_utils.go index 1ea36c8..1f6acfe 100644 --- a/internal/lsp/handlers_utils.go +++ b/internal/lsp/handlers_utils.go @@ -9,6 +9,7 @@ import ( "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" "codeberg.org/snonux/hexai/internal/textutil" @@ -99,24 +100,11 @@ func (s *Server) buildRequestSpec(surface surfaceKind) requestSpec { } func canonicalProvider(name string) string { - p := strings.ToLower(strings.TrimSpace(name)) - if p == "" { - return "openai" - } - return p + return llmutils.CanonicalProvider(name) } func resolveDefaultModel(cfg appconfig.App, provider string) string { - switch provider { - case "ollama": - return strings.TrimSpace(cfg.OllamaModel) - case "anthropic": - return strings.TrimSpace(cfg.AnthropicModel) - case "openrouter": - return strings.TrimSpace(cfg.OpenRouterModel) - default: - return strings.TrimSpace(cfg.OpenAIModel) - } + return llmutils.DefaultModelForProvider(cfg, provider) } func surfaceConfigsFor(cfg appconfig.App, surface surfaceKind) []appconfig.SurfaceConfig { diff --git a/internal/lsp/server.go b/internal/lsp/server.go index a5a8a2a..385f5ce 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -6,7 +6,6 @@ import ( "encoding/json" "io" "log" - "os" "strings" "sync" "time" @@ -14,6 +13,7 @@ import ( "codeberg.org/snonux/hexai/internal/appconfig" "codeberg.org/snonux/hexai/internal/ignore" "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/llmutils" "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/runtimeconfig" ) @@ -227,36 +227,8 @@ func (s *Server) currentLLMClient() llm.Client { return s.llmClient } -func newClientForProvider(cfg appconfig.App, provider string) (llm.Client, error) { - llmCfg := llm.Config{ - Provider: provider, - RequestTimeout: cfg.RequestTimeout, - OpenAIBaseURL: cfg.OpenAIBaseURL, - OpenAIModel: cfg.OpenAIModel, - OpenAITemperature: cfg.OpenAITemperature, - OpenRouterBaseURL: cfg.OpenRouterBaseURL, - OpenRouterModel: cfg.OpenRouterModel, - OpenRouterTemperature: cfg.OpenRouterTemperature, - OllamaBaseURL: cfg.OllamaBaseURL, - OllamaModel: cfg.OllamaModel, - OllamaTemperature: cfg.OllamaTemperature, - AnthropicBaseURL: cfg.AnthropicBaseURL, - AnthropicModel: cfg.AnthropicModel, - AnthropicTemperature: cfg.AnthropicTemperature, - } - oaKey := strings.TrimSpace(os.Getenv("HEXAI_OPENAI_API_KEY")) - if oaKey == "" { - oaKey = strings.TrimSpace(os.Getenv("OPENAI_API_KEY")) - } - orKey := strings.TrimSpace(os.Getenv("HEXAI_OPENROUTER_API_KEY")) - if orKey == "" { - orKey = strings.TrimSpace(os.Getenv("OPENROUTER_API_KEY")) - } - anKey := strings.TrimSpace(os.Getenv("HEXAI_ANTHROPIC_API_KEY")) - if anKey == "" { - anKey = strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY")) - } - return llm.NewFromConfig(llmCfg, oaKey, orKey, anKey) +func newClientForProvider(cfg appconfig.App, provider, modelOverride string) (llm.Client, error) { + return llmutils.NewClientFromAppForProvider(cfg, provider, modelOverride) } func (s *Server) clientFor(spec requestSpec) llm.Client { @@ -284,35 +256,11 @@ func (s *Server) clientFor(spec requestSpec) llm.Client { if store != nil { cfg = store.Snapshot() } - cfg.Provider = provider modelOverride := strings.TrimSpace(spec.entry.Model) - switch provider { - case "openai": - if modelOverride != "" { - cfg.OpenAIModel = modelOverride - } else if spec.fallbackModel != "" { - cfg.OpenAIModel = spec.fallbackModel - } - case "openrouter": - if modelOverride != "" { - cfg.OpenRouterModel = modelOverride - } else if spec.fallbackModel != "" { - cfg.OpenRouterModel = spec.fallbackModel - } - case "ollama": - if modelOverride != "" { - cfg.OllamaModel = modelOverride - } else if spec.fallbackModel != "" { - cfg.OllamaModel = spec.fallbackModel - } - case "anthropic": - if modelOverride != "" { - cfg.AnthropicModel = modelOverride - } else if spec.fallbackModel != "" { - cfg.AnthropicModel = spec.fallbackModel - } + if modelOverride == "" { + modelOverride = strings.TrimSpace(spec.fallbackModel) } - client, err := newClientForProvider(cfg, provider) + client, err := newClientForProvider(cfg, provider, modelOverride) if err != nil { logging.Logf("lsp ", "failed to build client for provider=%s: %v", provider, err) if baseClient != nil { |
