summaryrefslogtreecommitdiff
path: root/internal/llm
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/openai.go38
-rw-r--r--internal/llm/openai_http_test.go2
-rw-r--r--internal/llm/openai_request_test.go6
-rw-r--r--internal/llm/openai_temp_test.go6
-rw-r--r--internal/llm/openrouter.go168
-rw-r--r--internal/llm/openrouter_test.go125
-rw-r--r--internal/llm/provider.go15
-rw-r--r--internal/llm/provider_more2_test.go2
-rw-r--r--internal/llm/provider_more_test.go4
-rw-r--r--internal/llm/provider_test.go6
10 files changed, 339 insertions, 33 deletions
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 8a0d6d7..c284bb3 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -106,7 +106,7 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ
}
start := time.Now()
c.logStart(false, o, messages)
- req := buildOAChatRequest(o, messages, c.defaultTemperature, false)
+ req := buildOAChatRequest(o, messages, c.defaultTemperature, false, "llm/openai ")
body, err := json.Marshal(req)
if err != nil {
c.logf("marshal error: %v", err)
@@ -122,10 +122,10 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ
return "", err
}
defer resp.Body.Close()
- if err := handleOpenAINon2xx(resp, start); err != nil {
+ if err := handleOpenAINon2xx(resp, start, "llm/openai ", "openai"); err != nil {
return "", err
}
- out, err := decodeOpenAIChat(resp, start)
+ out, err := decodeOpenAIChat(resp, start, "llm/openai ")
if err != nil {
return "", err
}
@@ -157,7 +157,7 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt
}
start := time.Now()
c.logStart(true, o, messages)
- req := buildOAChatRequest(o, messages, c.defaultTemperature, true)
+ req := buildOAChatRequest(o, messages, c.defaultTemperature, true, "llm/openai ")
body, err := json.Marshal(req)
if err != nil {
c.logf("marshal error: %v", err)
@@ -173,11 +173,11 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt
return err
}
defer resp.Body.Close()
- if err := handleOpenAINon2xx(resp, start); err != nil {
+ if err := handleOpenAINon2xx(resp, start, "llm/openai ", "openai"); err != nil {
return err
}
- if err := parseOpenAIStream(resp, start, onDelta); err != nil {
+ if err := parseOpenAIStream(resp, start, onDelta, "llm/openai ", "openai"); err != nil {
return err
}
logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start))
@@ -196,7 +196,7 @@ func (c openAIClient) logStart(stream bool, o Options, messages []Message) {
c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
}
-func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool) oaChatRequest {
+func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool, logPrefix string) oaChatRequest {
req := oaChatRequest{Model: o.Model, Stream: stream}
req.Messages = make([]oaMessage, len(messages))
for i, m := range messages {
@@ -223,7 +223,7 @@ func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, str
if req.Temperature == nil || *req.Temperature != 1.0 {
t := 1.0
req.Temperature = &t
- logging.Logf("llm/openai ", "forcing temperature=1.0 for model=%s (gpt-5 constraint)", o.Model)
+ logging.Logf(logPrefix, "forcing temperature=1.0 for model=%s (gpt-5 constraint)", o.Model)
}
}
return req
@@ -262,30 +262,30 @@ func (c openAIClient) doJSONWithAccept(ctx context.Context, url string, body []b
return c.httpClient.Do(req)
}
-func handleOpenAINon2xx(resp *http.Response, start time.Time) error {
+func handleOpenAINon2xx(resp *http.Response, start time.Time, logPrefix, provider string) error {
if resp.StatusCode >= 200 && resp.StatusCode < 300 {
return nil
}
var apiErr oaChatResponse
_ = json.NewDecoder(resp.Body).Decode(&apiErr)
if apiErr.Error != nil && apiErr.Error.Message != "" {
- logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
- return fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
+ logging.Logf(logPrefix, "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("%s error: %s (status %d)", provider, apiErr.Error.Message, resp.StatusCode)
}
- logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
- return fmt.Errorf("openai http error: status %d", resp.StatusCode)
+ logging.Logf(logPrefix, "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("%s http error: status %d", provider, resp.StatusCode)
}
-func decodeOpenAIChat(resp *http.Response, start time.Time) (oaChatResponse, error) {
+func decodeOpenAIChat(resp *http.Response, start time.Time, logPrefix string) (oaChatResponse, error) {
var out oaChatResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
- logging.Logf("llm/openai ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ logging.Logf(logPrefix, "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
return oaChatResponse{}, err
}
return out, nil
}
-func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string)) error {
+func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string), logPrefix, provider string) error {
// Parse SSE: lines starting with "data: " containing JSON or [DONE]
scanner := bufio.NewScanner(resp.Body)
const maxBuf = 1024 * 1024
@@ -305,8 +305,8 @@ func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string
continue
}
if chunk.Error != nil && chunk.Error.Message != "" {
- logging.Logf("llm/openai ", "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase)
- return fmt.Errorf("openai stream error: %s", chunk.Error.Message)
+ logging.Logf(logPrefix, "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase)
+ return fmt.Errorf("%s stream error: %s", provider, chunk.Error.Message)
}
for _, ch := range chunk.Choices {
if ch.Delta.Content != "" {
@@ -315,7 +315,7 @@ func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string
}
}
if err := scanner.Err(); err != nil {
- logging.Logf("llm/openai ", "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ logging.Logf(logPrefix, "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
return err
}
return nil
diff --git a/internal/llm/openai_http_test.go b/internal/llm/openai_http_test.go
index cb4bfcb..affcae9 100644
--- a/internal/llm/openai_http_test.go
+++ b/internal/llm/openai_http_test.go
@@ -60,7 +60,7 @@ func TestOpenAI_ChatStream_SSE(t *testing.T) {
func TestHandleOpenAINon2xx_NoErrorBody(t *testing.T) {
resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))}
- if err := handleOpenAINon2xx(resp, time.Now()); err == nil {
+ if err := handleOpenAINon2xx(resp, time.Now(), "llm/openai ", "openai"); err == nil {
t.Fatalf("expected http error")
}
}
diff --git a/internal/llm/openai_request_test.go b/internal/llm/openai_request_test.go
index 001e3b7..d053031 100644
--- a/internal/llm/openai_request_test.go
+++ b/internal/llm/openai_request_test.go
@@ -9,13 +9,13 @@ func TestBuildOAChatRequest_MaxTokensKeyByModel(t *testing.T) {
msgs := []Message{{Role: "user", Content: "hi"}}
mt := 123
// Legacy model: use max_tokens
- r1 := buildOAChatRequest(Options{Model: "gpt-4.1", MaxTokens: mt}, msgs, nil, false)
+ r1 := buildOAChatRequest(Options{Model: "gpt-4.1", MaxTokens: mt}, msgs, nil, false, "llm/test ")
b1, _ := json.Marshal(r1)
if !contains(string(b1), "max_tokens") || contains(string(b1), "max_completion_tokens") {
t.Fatalf("expected max_tokens only, got %s", string(b1))
}
// gpt-5 family: use max_completion_tokens
- r2 := buildOAChatRequest(Options{Model: "gpt-5.0-preview", MaxTokens: mt}, msgs, nil, false)
+ r2 := buildOAChatRequest(Options{Model: "gpt-5.0-preview", MaxTokens: mt}, msgs, nil, false, "llm/test ")
b2, _ := json.Marshal(r2)
if !contains(string(b2), "max_completion_tokens") || contains(string(b2), "max_tokens\":") {
t.Fatalf("expected max_completion_tokens only, got %s", string(b2))
@@ -25,7 +25,7 @@ func TestBuildOAChatRequest_MaxTokensKeyByModel(t *testing.T) {
func TestBuildOAChatRequest_TemperatureForcedForGpt5(t *testing.T) {
msgs := []Message{{Role: "user", Content: "hi"}}
// Explicit temp 0.2 → should be forced to 1.0 for gpt-5
- r := buildOAChatRequest(Options{Model: "gpt-5.0", Temperature: 0.2, MaxTokens: 50}, msgs, nil, false)
+ r := buildOAChatRequest(Options{Model: "gpt-5.0", Temperature: 0.2, MaxTokens: 50}, msgs, nil, false, "llm/test ")
b, _ := json.Marshal(r)
if !contains(string(b), "\"temperature\":1") {
t.Fatalf("expected forced temperature 1.0 for gpt-5, got %s", string(b))
diff --git a/internal/llm/openai_temp_test.go b/internal/llm/openai_temp_test.go
index 7615117..3d71b94 100644
--- a/internal/llm/openai_temp_test.go
+++ b/internal/llm/openai_temp_test.go
@@ -5,7 +5,7 @@ import "testing"
func TestNewFromConfig_DefaultTemp_ByModel(t *testing.T) {
// OpenAI, gpt-5.* → default temp 1.0 when not provided
cfg := Config{Provider: "openai", OpenAIModel: "gpt-5.0-preview"}
- c, err := NewFromConfig(cfg, "key", "")
+ c, err := NewFromConfig(cfg, "key", "", "")
if err != nil {
t.Fatalf("new: %v", err)
}
@@ -18,7 +18,7 @@ func TestNewFromConfig_DefaultTemp_ByModel(t *testing.T) {
}
// OpenAI, gpt-4.* → default temp 0.2 when not provided
cfg2 := Config{Provider: "openai", OpenAIModel: "gpt-4.1"}
- c2, err := NewFromConfig(cfg2, "key", "")
+ c2, err := NewFromConfig(cfg2, "key", "", "")
if err != nil {
t.Fatalf("new2: %v", err)
}
@@ -32,7 +32,7 @@ func TestNewFromConfig_DefaultTemp_UpgradeWhenGpt5AndDefault02(t *testing.T) {
// Simulate app-default of 0.2 while selecting a gpt-5 model: should upgrade to 1.0
v := 0.2
cfg := Config{Provider: "openai", OpenAIModel: "gpt-5.0", OpenAITemperature: &v}
- c, err := NewFromConfig(cfg, "key", "")
+ c, err := NewFromConfig(cfg, "key", "", "")
if err != nil {
t.Fatalf("new: %v", err)
}
diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go
new file mode 100644
index 0000000..f03844a
--- /dev/null
+++ b/internal/llm/openrouter.go
@@ -0,0 +1,168 @@
+// Summary: OpenRouter client implementation leveraging OpenAI-compatible helpers with provider-specific headers.
+package llm
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "strings"
+ "time"
+
+ "codeberg.org/snonux/hexai/internal/logging"
+)
+
+type openRouterClient struct {
+ httpClient *http.Client
+ apiKey string
+ baseURL string
+ defaultModel string
+ chatLogger logging.ChatLogger
+ defaultTemperature *float64
+}
+
+func newOpenRouter(baseURL, model, apiKey string, defaultTemp *float64) Client {
+ if strings.TrimSpace(baseURL) == "" {
+ baseURL = "https://openrouter.ai/api/v1"
+ }
+ if strings.TrimSpace(model) == "" {
+ model = "openrouter/auto"
+ }
+ return openRouterClient{
+ httpClient: &http.Client{Timeout: 30 * time.Second},
+ apiKey: apiKey,
+ baseURL: baseURL,
+ defaultModel: model,
+ chatLogger: logging.NewChatLogger("openrouter"),
+ defaultTemperature: defaultTemp,
+ }
+}
+
+func (c openRouterClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) {
+ if strings.TrimSpace(c.apiKey) == "" {
+ return nilStringErr("missing OpenRouter API key")
+ }
+ o := Options{Model: c.defaultModel}
+ for _, opt := range opts {
+ opt(&o)
+ }
+ if strings.TrimSpace(o.Model) == "" {
+ o.Model = c.defaultModel
+ }
+ start := time.Now()
+ c.logStart(false, o, messages)
+ req := buildOAChatRequest(o, messages, c.defaultTemperature, false, "llm/openrouter ")
+ body, err := json.Marshal(req)
+ if err != nil {
+ c.logf("marshal error: %v", err)
+ return "", err
+ }
+ endpoint := strings.TrimRight(c.baseURL, "/") + "/chat/completions"
+ logging.Logf("llm/openrouter ", "POST %s", endpoint)
+ resp, err := c.doJSON(ctx, endpoint, body)
+ if err != nil {
+ logging.Logf("llm/openrouter ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ return "", err
+ }
+ defer resp.Body.Close()
+ if err := handleOpenAINon2xx(resp, start, "llm/openrouter ", "openrouter"); err != nil {
+ return "", err
+ }
+ out, err := decodeOpenAIChat(resp, start, "llm/openrouter ")
+ if err != nil {
+ return "", err
+ }
+ if len(out.Choices) == 0 {
+ logging.Logf("llm/openrouter ", "%sno choices returned duration=%s%s", logging.AnsiRed, time.Since(start), logging.AnsiBase)
+ return "", errors.New("openrouter: no choices returned")
+ }
+ content := out.Choices[0].Message.Content
+ logging.Logf("llm/openrouter ", "success choice=0 finish=%s size=%d preview=%s%s%s duration=%s", out.Choices[0].FinishReason, len(content), logging.AnsiGreen, logging.PreviewForLog(content), logging.AnsiBase, time.Since(start))
+ return content, nil
+}
+
+func (c openRouterClient) Name() string { return "openrouter" }
+func (c openRouterClient) DefaultModel() string { return c.defaultModel }
+
+func (c openRouterClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error {
+ if strings.TrimSpace(c.apiKey) == "" {
+ return errors.New("missing OpenRouter API key")
+ }
+ o := Options{Model: c.defaultModel}
+ for _, opt := range opts {
+ opt(&o)
+ }
+ if strings.TrimSpace(o.Model) == "" {
+ o.Model = c.defaultModel
+ }
+ start := time.Now()
+ c.logStart(true, o, messages)
+ req := buildOAChatRequest(o, messages, c.defaultTemperature, true, "llm/openrouter ")
+ body, err := json.Marshal(req)
+ if err != nil {
+ c.logf("marshal error: %v", err)
+ return err
+ }
+ endpoint := strings.TrimRight(c.baseURL, "/") + "/chat/completions"
+ logging.Logf("llm/openrouter ", "POST %s (stream)", endpoint)
+ resp, err := c.doJSONWithAccept(ctx, endpoint, body, "text/event-stream")
+ if err != nil {
+ logging.Logf("llm/openrouter ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ return err
+ }
+ defer resp.Body.Close()
+ if err := handleOpenAINon2xx(resp, start, "llm/openrouter ", "openrouter"); err != nil {
+ return err
+ }
+ if err := parseOpenAIStream(resp, start, onDelta, "llm/openrouter ", "openrouter"); err != nil {
+ return err
+ }
+ logging.Logf("llm/openrouter ", "stream end duration=%s", time.Since(start))
+ return nil
+}
+
+func (c openRouterClient) logf(format string, args ...any) {
+ logging.Logf("llm/openrouter ", format, args...)
+}
+
+func (c openRouterClient) logStart(stream bool, o Options, messages []Message) {
+ logMessages := make([]struct{ Role, Content string }, len(messages))
+ for i, m := range messages {
+ logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
+ }
+ c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+}
+
+func (c openRouterClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) {
+ headers := map[string]string{
+ "Authorization": "Bearer " + c.apiKey,
+ "HTTP-Referer": "https://github.com/snonux/hexai",
+ "X-Title": "Hexai",
+ }
+ return c.doJSONWithHeaders(ctx, url, body, headers, "")
+}
+
+func (c openRouterClient) doJSONWithAccept(ctx context.Context, url string, body []byte, accept string) (*http.Response, error) {
+ headers := map[string]string{
+ "Authorization": "Bearer " + c.apiKey,
+ "HTTP-Referer": "https://github.com/snonux/hexai",
+ "X-Title": "Hexai",
+ }
+ return c.doJSONWithHeaders(ctx, url, body, headers, accept)
+}
+
+func (c openRouterClient) doJSONWithHeaders(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ if strings.TrimSpace(accept) != "" {
+ req.Header.Set("Accept", accept)
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ return c.httpClient.Do(req)
+}
diff --git a/internal/llm/openrouter_test.go b/internal/llm/openrouter_test.go
new file mode 100644
index 0000000..2a07be0
--- /dev/null
+++ b/internal/llm/openrouter_test.go
@@ -0,0 +1,125 @@
+package llm
+
+import (
+ "context"
+ "encoding/json"
+ "io"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
+
+ "codeberg.org/snonux/hexai/internal/logging"
+)
+
+func TestOpenRouter_Chat_SendsHeadersAndBody(t *testing.T) {
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ var capturedHeaders http.Header
+ var capturedBody []byte
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ capturedHeaders = r.Header.Clone()
+ body, err := io.ReadAll(r.Body)
+ if err != nil {
+ t.Fatalf("read body: %v", err)
+ }
+ capturedBody = append([]byte(nil), body...)
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{
+ {"index": 0, "message": map[string]string{"role": "assistant", "content": "ack"}},
+ },
+ })
+ }))
+ defer srv.Close()
+
+ c := newOpenRouter(srv.URL, "anthropic/claude-test", "KEY", f64p(0.2)).(openRouterClient)
+ c.httpClient = srv.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "ping"}})
+ if err != nil {
+ t.Fatalf("chat returned error: %v", err)
+ }
+ if out != "ack" {
+ t.Fatalf("unexpected response: %q", out)
+ }
+ if capturedHeaders.Get("Authorization") != "Bearer KEY" {
+ t.Fatalf("missing auth header: %#v", capturedHeaders)
+ }
+ if capturedHeaders.Get("HTTP-Referer") != "https://github.com/snonux/hexai" {
+ t.Fatalf("missing referer header: %#v", capturedHeaders)
+ }
+ if capturedHeaders.Get("X-Title") != "Hexai" {
+ t.Fatalf("missing title header: %#v", capturedHeaders)
+ }
+
+ var req oaChatRequest
+ if err := json.Unmarshal(capturedBody, &req); err != nil {
+ t.Fatalf("unmarshal request: %v", err)
+ }
+ if req.Model != "anthropic/claude-test" {
+ t.Fatalf("unexpected model: %q", req.Model)
+ }
+ if len(req.Messages) != 1 || req.Messages[0].Role != "user" || req.Messages[0].Content != "ping" {
+ t.Fatalf("unexpected messages: %#v", req.Messages)
+ }
+}
+
+func TestOpenRouter_ChatStream_SendsHeaders(t *testing.T) {
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ var acceptHeader string
+ var referer string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ acceptHeader = r.Header.Get("Accept")
+ referer = r.Header.Get("HTTP-Referer")
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+
+ c := newOpenRouter(srv.URL, "anthropic/claude-test", "KEY", f64p(0.2)).(openRouterClient)
+ c.httpClient = srv.Client()
+ var got string
+ err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "ping"}}, func(s string) { got += s })
+ if err != nil {
+ t.Fatalf("chat stream error: %v", err)
+ }
+ if got != "hi" {
+ t.Fatalf("expected stream output 'hi', got %q", got)
+ }
+ if acceptHeader != "text/event-stream" {
+ t.Fatalf("unexpected Accept header: %q", acceptHeader)
+ }
+ if referer != "https://github.com/snonux/hexai" {
+ t.Fatalf("missing referer header in stream: %q", referer)
+ }
+}
+
+func TestOpenRouter_Chat_MissingKey(t *testing.T) {
+ c := newOpenRouter("http://example", "anthropic/claude-test", "", f64p(0.2)).(openRouterClient)
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "ping"}}); err == nil {
+ t.Fatalf("expected error for missing api key")
+ }
+}
+
+func TestOpenRouter_DefaultsAndMetadata(t *testing.T) {
+ logger := log.New(io.Discard, "", 0)
+ logging.Bind(logger)
+ c := newOpenRouter("", "", "KEY", nil).(openRouterClient)
+ if c.baseURL != "https://openrouter.ai/api/v1" {
+ t.Fatalf("default baseURL mismatch: %s", c.baseURL)
+ }
+ if c.defaultModel != "openrouter/auto" {
+ t.Fatalf("default model mismatch: %s", c.defaultModel)
+ }
+ if name := c.Name(); name != "openrouter" {
+ t.Fatalf("Name() = %s", name)
+ }
+ if model := c.DefaultModel(); model != "openrouter/auto" {
+ t.Fatalf("DefaultModel() = %s", model)
+ }
+ c.logf("smoke")
+}
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index 84efaf9..b2c47e4 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -69,6 +69,10 @@ type Config struct {
OpenAIBaseURL string
OpenAIModel string
OpenAITemperature *float64
+ // OpenRouter options
+ OpenRouterBaseURL string
+ OpenRouterModel string
+ OpenRouterTemperature *float64
// Ollama options
OllamaBaseURL string
OllamaModel string
@@ -82,7 +86,7 @@ type Config struct {
// NewFromConfig creates an LLM client using only the supplied configuration.
// The OpenAI API key is supplied separately and may be read from the environment
// by the caller; other environment-based configuration is not used.
-func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) {
+func NewFromConfig(cfg Config, openAIAPIKey, openRouterAPIKey, copilotAPIKey string) (Client, error) {
p := strings.ToLower(strings.TrimSpace(cfg.Provider))
if p == "" {
p = "openai"
@@ -112,6 +116,15 @@ func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, erro
cfg.OpenAITemperature = &v
}
return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil
+ case "openrouter":
+ if strings.TrimSpace(openRouterAPIKey) == "" {
+ return nil, errors.New("missing OPENROUTER_API_KEY for provider openrouter")
+ }
+ if cfg.OpenRouterTemperature == nil {
+ t := 0.2
+ cfg.OpenRouterTemperature = &t
+ }
+ return newOpenRouter(cfg.OpenRouterBaseURL, cfg.OpenRouterModel, openRouterAPIKey, cfg.OpenRouterTemperature), nil
case "ollama":
if cfg.OllamaTemperature == nil {
t := 0.2
diff --git a/internal/llm/provider_more2_test.go b/internal/llm/provider_more2_test.go
index 465be82..e001e5c 100644
--- a/internal/llm/provider_more2_test.go
+++ b/internal/llm/provider_more2_test.go
@@ -5,7 +5,7 @@ import "testing"
func TestNewFromConfig_Copilot(t *testing.T) {
t.Setenv("COPILOT_API_KEY", "x")
cfg := Config{Provider: "copilot", CopilotModel: "small"}
- c, err := NewFromConfig(cfg, "", "x")
+ c, err := NewFromConfig(cfg, "", "", "x")
if err != nil || c == nil {
t.Fatalf("copilot provider failed: %v %v", c, err)
}
diff --git a/internal/llm/provider_more_test.go b/internal/llm/provider_more_test.go
index d7469af..eff99e6 100644
--- a/internal/llm/provider_more_test.go
+++ b/internal/llm/provider_more_test.go
@@ -16,13 +16,13 @@ func TestWithOptions_Apply(t *testing.T) {
func TestNewFromConfig_Success_OpenAI_And_Copilot(t *testing.T) {
// OpenAI success
oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"}
- c, err := NewFromConfig(oc, "KEY", "")
+ c, err := NewFromConfig(oc, "KEY", "", "")
if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" {
t.Fatalf("openai new: %v %v", c, err)
}
// Copilot success
cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"}
- c2, err := NewFromConfig(cc, "", "KEY")
+ c2, err := NewFromConfig(cc, "", "", "KEY")
if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" {
t.Fatalf("copilot new: %v %v", c2, err)
}
diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go
index 29e2514..2c0d69c 100644
--- a/internal/llm/provider_test.go
+++ b/internal/llm/provider_test.go
@@ -7,15 +7,15 @@ import (
func TestNewFromConfig_DefaultsAndErrors(t *testing.T) {
// Unknown provider
- if _, err := NewFromConfig(Config{Provider: "bogus"}, "", ""); err == nil {
+ if _, err := NewFromConfig(Config{Provider: "bogus"}, "", "", ""); err == nil {
t.Fatalf("expected error for unknown provider")
}
// OpenAI missing key
- if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", ""); err == nil {
+ if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", "", ""); err == nil {
t.Fatalf("expected key error")
}
// Copilot missing key
- if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", ""); err == nil {
+ if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", "", ""); err == nil {
t.Fatalf("expected key error")
}
}