diff options
| author | Paul Buetow <paul@buetow.org> | 2026-04-26 08:50:20 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-04-26 08:50:20 +0300 |
| commit | e4a6723bb679b13401020bb4953cd7c4c9564e8c (patch) | |
| tree | 6404ac14586e6207056b83f3db8a1cd9657992d1 /internal | |
| parent | 97e2dde7693618516a42019d7aa7cfda1f5a8811 (diff) | |
feat: optional API key for Ollama provider (Ollama Cloud)
Adds an optional HEXAI_OLLAMA_API_KEY (with OLLAMA_API_KEY fallback) so
the existing Ollama provider can target Ollama Cloud (ollama.ai) in
addition to a local server. When the key is empty the request is
unauthenticated, preserving local-server behavior byte-for-byte; when
set, an Authorization: Bearer header is attached for both Chat and
ChatStream. Documented cloud usage in config.toml.example.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/llm/ollama.go | 33 | ||||
| -rw-r--r-- | internal/llm/ollama_test.go | 80 | ||||
| -rw-r--r-- | internal/llm/openai_temp_test.go | 6 | ||||
| -rw-r--r-- | internal/llm/provider.go | 11 | ||||
| -rw-r--r-- | internal/llm/provider_more_test.go | 2 | ||||
| -rw-r--r-- | internal/llm/provider_test.go | 4 | ||||
| -rw-r--r-- | internal/llmutils/client.go | 8 |
7 files changed, 116 insertions, 28 deletions
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index b2cecfa..0916c06 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -1,4 +1,6 @@ -// Ollama client against a local server; supports chat responses and streaming via /api/chat. +// Ollama client supporting both a local server and Ollama Cloud (ollama.ai). +// The optional API key is sent as a Bearer token; when empty, requests are +// unauthenticated so a local Ollama server keeps working unchanged. package llm import ( @@ -14,11 +16,13 @@ import ( "codeberg.org/snonux/hexai/internal/logging" ) -// ollamaClient implements Client against a local Ollama server. +// ollamaClient implements Client against a local Ollama server or Ollama Cloud. +// apiKey is optional: empty for local, non-empty enables Bearer auth for cloud. type ollamaClient struct { httpClient *http.Client baseURL string defaultModel string + apiKey string chatLogger logging.ChatLogger defaultTemperature *float64 } @@ -45,21 +49,23 @@ type ollamaChatResponse struct { Error string `json:"error,omitempty"` } -func ollamaProviderFactory(cfg Config, _ ProviderKeys) (Client, error) { +func ollamaProviderFactory(cfg Config, keys ProviderKeys) (Client, error) { return newOllamaWithTimeout( cfg.OllamaBaseURL, cfg.OllamaModel, + keys.OllamaAPIKey, withDefaultTemperature(cfg.OllamaTemperature, 0.2), cfg.RequestTimeout, ), nil } -// Constructor (kept among the first functions by convention) -func newOllama(baseURL, model string, defaultTemp *float64) Client { - return newOllamaWithTimeout(baseURL, model, defaultTemp, 0) +// Constructor (kept among the first functions by convention). +// apiKey may be empty for local Ollama; pass a non-empty key for Ollama Cloud. +func newOllama(baseURL, model string, defaultTemp *float64, apiKey string) Client { + return newOllamaWithTimeout(baseURL, model, apiKey, defaultTemp, 0) } -func newOllamaWithTimeout(baseURL, model string, defaultTemp *float64, timeoutSec int) Client { +func newOllamaWithTimeout(baseURL, model, apiKey string, defaultTemp *float64, timeoutSec int) Client { if strings.TrimSpace(baseURL) == "" { baseURL = "http://localhost:11434" } @@ -73,6 +79,7 @@ func newOllamaWithTimeout(baseURL, model string, defaultTemp *float64, timeoutSe httpClient: &http.Client{Timeout: time.Duration(timeoutSec) * time.Second}, baseURL: strings.TrimRight(baseURL, "/"), defaultModel: model, + apiKey: strings.TrimSpace(apiKey), chatLogger: logging.NewChatLogger("ollama"), defaultTemperature: defaultTemp, } @@ -228,7 +235,17 @@ func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, str } func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) { - return doJSONRequest(ctx, c.httpClient, url, body, nil, "") + return doJSONRequest(ctx, c.httpClient, url, body, c.authHeaders(), "") +} + +// authHeaders returns Bearer auth for Ollama Cloud, or nil for unauthenticated +// local Ollama. Returning nil keeps the local request shape byte-identical to +// the previous implementation. +func (c ollamaClient) authHeaders() map[string]string { + if c.apiKey == "" { + return nil + } + return map[string]string{"Authorization": "Bearer " + c.apiKey} } func handleOllamaNon2xx(resp *http.Response, start time.Time) error { diff --git a/internal/llm/ollama_test.go b/internal/llm/ollama_test.go index 8bd33ca..2216e21 100644 --- a/internal/llm/ollama_test.go +++ b/internal/llm/ollama_test.go @@ -49,7 +49,7 @@ func TestBuildOllamaRequest_TempOverride(t *testing.T) { } func TestOllama_NameAndModel(t *testing.T) { - c := newOllama("http://x", "model-x", nil).(ollamaClient) + c := newOllama("http://x", "model-x", nil, "").(ollamaClient) if c.Name() != "ollama" { t.Fatalf("name: %q", c.Name()) } @@ -58,6 +58,66 @@ func TestOllama_NameAndModel(t *testing.T) { } } +// Local Ollama (no key) must not send an Authorization header — the existing +// unauthenticated server would reject or misinterpret one. +func TestOllamaChat_NoAuthHeaderWhenKeyEmpty(t *testing.T) { + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "" { + t.Fatalf("expected no Authorization header, got %q", got) + } + _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": "ok"}, "done": true}) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil, "").(ollamaClient) + c.httpClient = ts.Client() + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err != nil { + t.Fatalf("unexpected: %v", err) + } +} + +// Ollama Cloud usage: when an API key is configured, both Chat and ChatStream +// must send "Authorization: Bearer <key>". +func TestOllamaChat_AuthHeaderWhenKeySet(t *testing.T) { + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + const key = "test-key-xyz" + const want = "Bearer " + key + + t.Run("chat", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != want { + t.Fatalf("Authorization: got %q, want %q", got, want) + } + _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": "ok"}, "done": true}) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", f64p(0.1), key).(ollamaClient) + c.httpClient = ts.Client() + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err != nil { + t.Fatalf("unexpected: %v", err) + } + }) + + t.Run("stream", func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != want { + t.Fatalf("Authorization: got %q, want %q", got, want) + } + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"ok"},"done":true}`)) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil, key).(ollamaClient) + c.httpClient = ts.Client() + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(string) {}); err != nil { + t.Fatalf("unexpected: %v", err) + } + }) +} + func TestOllamaChat_Success(t *testing.T) { if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") @@ -70,7 +130,7 @@ func TestOllamaChat_Success(t *testing.T) { _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": "Hello"}, "done": true}) })) defer ts.Close() - c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient) + c := newOllama(ts.URL, "m", f64p(0.1), "").(ollamaClient) c.httpClient = ts.Client() out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) if err != nil { @@ -89,7 +149,7 @@ func TestOllamaChat_EmptyContent(t *testing.T) { _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": ""}, "done": true}) })) defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) + c := newOllama(ts.URL, "m", nil, "").(ollamaClient) c.httpClient = ts.Client() if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { t.Fatalf("expected error for empty content") @@ -106,7 +166,7 @@ func TestOllamaChat_Non2xx(t *testing.T) { _ = json.NewEncoder(w).Encode(map[string]any{"error": "bad"}) })) defer ts1.Close() - c1 := newOllama(ts1.URL, "m", nil).(ollamaClient) + c1 := newOllama(ts1.URL, "m", nil, "").(ollamaClient) c1.httpClient = ts1.Client() if _, err := c1.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { t.Fatalf("expected error for 400 with api body") @@ -117,7 +177,7 @@ func TestOllamaChat_Non2xx(t *testing.T) { _, _ = w.Write([]byte("{}")) })) defer ts2.Close() - c2 := newOllama(ts2.URL, "m", nil).(ollamaClient) + c2 := newOllama(ts2.URL, "m", nil, "").(ollamaClient) c2.httpClient = ts2.Client() if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { t.Fatalf("expected error for 500") @@ -129,7 +189,7 @@ type rtFunc func(*http.Request) (*http.Response, error) func (f rtFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } func TestOllamaChat_HTTPError(t *testing.T) { - c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient) + c := newOllama("http://127.0.0.1:0", "m", nil, "").(ollamaClient) c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request) (*http.Response, error) { return nil, fmt.Errorf("boom") })} if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { t.Fatalf("expected http error path") @@ -144,7 +204,7 @@ func TestOllamaChat_DecodeError(t *testing.T) { _, _ = w.Write([]byte("{bad json}")) })) defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) + c := newOllama(ts.URL, "m", nil, "").(ollamaClient) c.httpClient = ts.Client() if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { t.Fatalf("expected decode error") @@ -169,7 +229,7 @@ func TestOllamaChatStream_Success(t *testing.T) { _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`)) })) defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) + c := newOllama(ts.URL, "m", nil, "").(ollamaClient) c.httpClient = ts.Client() var got strings.Builder if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(s string) { got.WriteString(s) }); err != nil { @@ -188,7 +248,7 @@ func TestOllamaChatStream_ErrorEvent(t *testing.T) { _ = json.NewEncoder(w).Encode(map[string]any{"error": "oops"}) })) defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) + c := newOllama(ts.URL, "m", nil, "").(ollamaClient) c.httpClient = ts.Client() if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil { t.Fatalf("expected stream error") @@ -203,7 +263,7 @@ func TestOllamaChatStream_DecodeError(t *testing.T) { _, _ = w.Write([]byte("{not json}")) })) defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) + c := newOllama(ts.URL, "m", nil, "").(ollamaClient) c.httpClient = ts.Client() if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil { t.Fatalf("expected decode error") diff --git a/internal/llm/openai_temp_test.go b/internal/llm/openai_temp_test.go index 3d71b94..07abbd5 100644 --- a/internal/llm/openai_temp_test.go +++ b/internal/llm/openai_temp_test.go @@ -5,7 +5,7 @@ import "testing" func TestNewFromConfig_DefaultTemp_ByModel(t *testing.T) { // OpenAI, gpt-5.* → default temp 1.0 when not provided cfg := Config{Provider: "openai", OpenAIModel: "gpt-5.0-preview"} - c, err := NewFromConfig(cfg, "key", "", "") + c, err := NewFromConfig(cfg, "key", "", "", "") if err != nil { t.Fatalf("new: %v", err) } @@ -18,7 +18,7 @@ func TestNewFromConfig_DefaultTemp_ByModel(t *testing.T) { } // OpenAI, gpt-4.* → default temp 0.2 when not provided cfg2 := Config{Provider: "openai", OpenAIModel: "gpt-4.1"} - c2, err := NewFromConfig(cfg2, "key", "", "") + c2, err := NewFromConfig(cfg2, "key", "", "", "") if err != nil { t.Fatalf("new2: %v", err) } @@ -32,7 +32,7 @@ func TestNewFromConfig_DefaultTemp_UpgradeWhenGpt5AndDefault02(t *testing.T) { // Simulate app-default of 0.2 while selecting a gpt-5 model: should upgrade to 1.0 v := 0.2 cfg := Config{Provider: "openai", OpenAIModel: "gpt-5.0", OpenAITemperature: &v} - c, err := NewFromConfig(cfg, "key", "", "") + c, err := NewFromConfig(cfg, "key", "", "", "") if err != nil { t.Fatalf("new: %v", err) } diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 6c0c04b..255297c 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -94,10 +94,13 @@ type Config struct { } // ProviderKeys contains API credentials used by provider factories. +// OllamaAPIKey is optional: it enables auth against Ollama Cloud while a local +// Ollama server still works with an empty key. type ProviderKeys struct { OpenAIAPIKey string OpenRouterAPIKey string AnthropicAPIKey string + OllamaAPIKey string } // ProviderFactory builds an LLM client for a named provider. @@ -143,9 +146,10 @@ func RegisterAllProviders() { } // NewFromConfig creates an LLM client using only the supplied configuration. -// The OpenAI API key is supplied separately and may be read from the environment -// by the caller; other environment-based configuration is not used. -func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey string) (Client, error) { +// API keys are supplied separately and may be read from the environment by the +// caller. ollamaAPIKey is optional and only used when targeting Ollama Cloud; +// a local Ollama server works with an empty value. +func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey, ollamaAPIKey string) (Client, error) { provider := normalizeProvider(cfg.Provider) if provider == "" { provider = "openai" @@ -160,6 +164,7 @@ func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, anthropicAPIKey s OpenAIAPIKey: openAIAPIKey, OpenRouterAPIKey: openRouterAPIKey, AnthropicAPIKey: anthropicAPIKey, + OllamaAPIKey: ollamaAPIKey, }) } diff --git a/internal/llm/provider_more_test.go b/internal/llm/provider_more_test.go index 8d7b133..d3be8ef 100644 --- a/internal/llm/provider_more_test.go +++ b/internal/llm/provider_more_test.go @@ -16,7 +16,7 @@ func TestWithOptions_Apply(t *testing.T) { func TestNewFromConfig_Success_OpenAI(t *testing.T) { // OpenAI success oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"} - c, err := NewFromConfig(oc, "KEY", "", "") + c, err := NewFromConfig(oc, "KEY", "", "", "") if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" { t.Fatalf("openai new: %v %v", c, err) } diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go index 14de7a6..2acbc69 100644 --- a/internal/llm/provider_test.go +++ b/internal/llm/provider_test.go @@ -7,13 +7,13 @@ import ( func TestNewFromConfig_DefaultsAndErrors(t *testing.T) { // Unknown provider - if _, err := NewFromConfig(Config{Provider: "bogus"}, "", "", ""); err == nil { + if _, err := NewFromConfig(Config{Provider: "bogus"}, "", "", "", ""); err == nil { t.Fatalf("expected error for unknown provider") } else if !strings.Contains(err.Error(), "supported providers:") { t.Fatalf("expected supported providers hint, got %q", err.Error()) } // OpenAI missing key - if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", "", ""); err == nil { + if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", "", "", ""); err == nil { t.Fatalf("expected key error") } else if !strings.Contains(err.Error(), "OPENAI_API_KEY") || !strings.Contains(err.Error(), "HEXAI_OPENAI_API_KEY") { t.Fatalf("expected actionable API key hint, got %q", err.Error()) diff --git a/internal/llmutils/client.go b/internal/llmutils/client.go index 3641556..ef24571 100644 --- a/internal/llmutils/client.go +++ b/internal/llmutils/client.go @@ -103,5 +103,11 @@ func NewClientFromApp(cfg appconfig.App) (llm.Client, error) { if strings.TrimSpace(anKey) == "" { anKey = os.Getenv("ANTHROPIC_API_KEY") } - return llm.NewFromConfig(llmCfg, oaKey, orKey, anKey) + // Ollama API key is optional: only needed for Ollama Cloud (ollama.ai). + // A local Ollama server keeps working when this is empty. + olKey := os.Getenv("HEXAI_OLLAMA_API_KEY") + if strings.TrimSpace(olKey) == "" { + olKey = os.Getenv("OLLAMA_API_KEY") + } + return llm.NewFromConfig(llmCfg, oaKey, orKey, anKey, olKey) } |
