summaryrefslogtreecommitdiff
path: root/internal/llm
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/anthropic.go17
-rw-r--r--internal/llm/ollama.go14
-rw-r--r--internal/llm/openai.go28
-rw-r--r--internal/llm/openrouter.go20
-rw-r--r--internal/llm/util.go33
5 files changed, 41 insertions, 71 deletions
diff --git a/internal/llm/anthropic.go b/internal/llm/anthropic.go
index 82d8b8a..2fc1d84 100644
--- a/internal/llm/anthropic.go
+++ b/internal/llm/anthropic.go
@@ -3,7 +3,6 @@ package llm
import (
"bufio"
- "bytes"
"context"
"encoding/json"
"errors"
@@ -241,11 +240,7 @@ func (c anthropicClient) logf(format string, args ...any) {
}
func (c anthropicClient) logStart(stream bool, o Options, messages []Message) {
- logMessages := make([]struct{ Role, Content string }, len(messages))
- for i, m := range messages {
- logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
- }
- c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+ logStartMessages(c.chatLogger, stream, o, messages)
}
func buildAnthropicChatRequest(o Options, messages []Message, defaultModel string, defaultTemp *float64, stream bool) anthropicChatRequest {
@@ -287,15 +282,7 @@ func buildAnthropicChatRequest(o Options, messages []Message, defaultModel strin
}
func (c anthropicClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- return c.httpClient.Do(req)
+ return doJSONRequest(ctx, c.httpClient, url, body, headers, "")
}
func handleAnthropicNon2xx(resp *http.Response, start time.Time) error {
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index a878b62..896ca13 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -2,7 +2,6 @@
package llm
import (
- "bytes"
"context"
"encoding/json"
"errors"
@@ -189,11 +188,7 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
// helpers to keep methods small
func (c ollamaClient) logStart(stream bool, o Options, messages []Message) {
- logMessages := make([]struct{ Role, Content string }, len(messages))
- for i, m := range messages {
- logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
- }
- c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+ logStartMessages(c.chatLogger, stream, o, messages)
}
func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, stream bool) ollamaChatRequest {
@@ -221,12 +216,7 @@ func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, str
}
func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- return c.httpClient.Do(req)
+ return doJSONRequest(ctx, c.httpClient, url, body, nil, "")
}
func handleOllamaNon2xx(resp *http.Response, start time.Time) error {
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 5c1e525..3d2bf94 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -3,7 +3,6 @@ package llm
import (
"bufio"
- "bytes"
"context"
"encoding/json"
"errors"
@@ -237,11 +236,7 @@ func (c openAIClient) logf(format string, args ...any) { logging.Logf("llm/opena
// helpers extracted to keep methods small
func (c openAIClient) logStart(stream bool, o Options, messages []Message) {
- logMessages := make([]struct{ Role, Content string }, len(messages))
- for i, m := range messages {
- logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
- }
- c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+ logStartMessages(c.chatLogger, stream, o, messages)
}
func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool, logPrefix string) oaChatRequest {
@@ -286,28 +281,11 @@ func requiresMaxCompletionTokens(model string) bool {
}
func (c openAIClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- return c.httpClient.Do(req)
+ return doJSONRequest(ctx, c.httpClient, url, body, headers, "")
}
func (c openAIClient) doJSONWithAccept(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- req.Header.Set("Accept", accept)
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- return c.httpClient.Do(req)
+ return doJSONRequest(ctx, c.httpClient, url, body, headers, accept)
}
func handleOpenAINon2xx(resp *http.Response, start time.Time, logPrefix, provider string) error {
diff --git a/internal/llm/openrouter.go b/internal/llm/openrouter.go
index 8aae6b8..19c2f6c 100644
--- a/internal/llm/openrouter.go
+++ b/internal/llm/openrouter.go
@@ -2,7 +2,6 @@
package llm
import (
- "bytes"
"context"
"encoding/json"
"errors"
@@ -159,11 +158,7 @@ func (c openRouterClient) logf(format string, args ...any) {
}
func (c openRouterClient) logStart(stream bool, o Options, messages []Message) {
- logMessages := make([]struct{ Role, Content string }, len(messages))
- for i, m := range messages {
- logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
- }
- c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+ logStartMessages(c.chatLogger, stream, o, messages)
}
func (c openRouterClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) {
@@ -185,16 +180,5 @@ func (c openRouterClient) doJSONWithAccept(ctx context.Context, url string, body
}
func (c openRouterClient) doJSONWithHeaders(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- if strings.TrimSpace(accept) != "" {
- req.Header.Set("Accept", accept)
- }
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- return c.httpClient.Do(req)
+ return doJSONRequest(ctx, c.httpClient, url, body, headers, accept)
}
diff --git a/internal/llm/util.go b/internal/llm/util.go
index b99d7c8..b6e2adc 100644
--- a/internal/llm/util.go
+++ b/internal/llm/util.go
@@ -1,6 +1,37 @@
package llm
-import "errors"
+import (
+ "bytes"
+ "context"
+ "errors"
+ "net/http"
+ "strings"
+
+ "codeberg.org/snonux/hexai/internal/logging"
+)
// small helper to keep return type consistent
func nilStringErr(msg string) (string, error) { return "", errors.New(msg) }
+
+func doJSONRequest(ctx context.Context, httpClient *http.Client, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ if strings.TrimSpace(accept) != "" {
+ req.Header.Set("Accept", accept)
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ return httpClient.Do(req)
+}
+
+func logStartMessages(chatLogger logging.ChatLogger, stream bool, o Options, messages []Message) {
+ logMessages := make([]struct{ Role, Content string }, len(messages))
+ for i, m := range messages {
+ logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
+ }
+ chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+}