summaryrefslogtreecommitdiff
path: root/internal/llm/ollama.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm/ollama.go')
-rw-r--r--internal/llm/ollama.go192
1 files changed, 78 insertions, 114 deletions
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index 20dfe2a..50e9837 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -1,5 +1,4 @@
// Summary: Ollama client against a local server; supports chat responses and streaming via /api/chat.
-// Not yet reviewed by a human
package llm
import (
@@ -18,11 +17,11 @@ import (
// ollamaClient implements Client against a local Ollama server.
type ollamaClient struct {
- httpClient *http.Client
- baseURL string
- defaultModel string
- chatLogger logging.ChatLogger
- defaultTemperature *float64
+ httpClient *http.Client
+ baseURL string
+ defaultModel string
+ chatLogger logging.ChatLogger
+ defaultTemperature *float64
}
type ollamaChatRequest struct {
@@ -49,13 +48,13 @@ func newOllama(baseURL, model string, defaultTemp *float64) Client {
if strings.TrimSpace(model) == "" {
model = "qwen3-coder:30b-a3b-q4_K_M`"
}
- return ollamaClient{
- httpClient: &http.Client{Timeout: 30 * time.Second},
- baseURL: strings.TrimRight(baseURL, "/"),
- defaultModel: model,
- chatLogger: logging.NewChatLogger("ollama"),
- defaultTemperature: defaultTemp,
- }
+ return ollamaClient{
+ httpClient: &http.Client{Timeout: 30 * time.Second},
+ baseURL: strings.TrimRight(baseURL, "/"),
+ defaultModel: model,
+ chatLogger: logging.NewChatLogger("ollama"),
+ defaultTemperature: defaultTemp,
+ }
}
// TODO: This function is too long and should be refactored for readability and maintainability.
@@ -69,41 +68,8 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ
}
start := time.Now()
- logMessages := make([]struct {
- Role string
- Content string
- }, len(messages))
- for i, m := range messages {
- logMessages[i] = struct {
- Role string
- Content string
- }{Role: m.Role, Content: m.Content}
- }
- c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
-
- req := ollamaChatRequest{Model: o.Model, Stream: false}
- req.Messages = make([]oaMessage, len(messages))
- for i, m := range messages {
- req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
- }
-
- // Build options map only if any option is set
- optsMap := map[string]any{}
- if o.Temperature != 0 {
- optsMap["temperature"] = o.Temperature
- } else if c.defaultTemperature != nil {
- optsMap["temperature"] = *c.defaultTemperature
- }
- if o.MaxTokens > 0 {
- optsMap["num_predict"] = o.MaxTokens
- }
- if len(o.Stop) > 0 {
- optsMap["stop"] = o.Stop
- }
- if len(optsMap) > 0 {
- req.Options = optsMap
- }
-
+ c.logStart(false, o, messages)
+ req := buildOllamaRequest(o, messages, c.defaultTemperature, false)
body, err := json.Marshal(req)
if err != nil {
return "", err
@@ -111,27 +77,14 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ
endpoint := c.baseURL + "/api/chat"
logging.Logf("llm/ollama ", "POST %s", endpoint)
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
- if err != nil {
- return "", err
- }
- httpReq.Header.Set("Content-Type", "application/json")
-
- resp, err := c.httpClient.Do(httpReq)
+ resp, err := c.doJSON(ctx, endpoint, body)
if err != nil {
logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
return "", err
}
defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- var apiErr ollamaChatResponse
- _ = json.NewDecoder(resp.Body).Decode(&apiErr)
- if strings.TrimSpace(apiErr.Error) != "" {
- logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
- return "", fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
- }
- logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
- return "", fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+ if err := handleOllamaNon2xx(resp, start); err != nil {
+ return "", err
}
var out ollamaChatResponse
@@ -163,40 +116,8 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
}
start := time.Now()
- logMessages := make([]struct {
- Role string
- Content string
- }, len(messages))
- for i, m := range messages {
- logMessages[i] = struct {
- Role string
- Content string
- }{Role: m.Role, Content: m.Content}
- }
- c.chatLogger.LogStart(true, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
-
- req := ollamaChatRequest{Model: o.Model, Stream: true}
- req.Messages = make([]oaMessage, len(messages))
- for i, m := range messages {
- req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
- }
- // Build options map
- optsMap := map[string]any{}
- if o.Temperature != 0 {
- optsMap["temperature"] = o.Temperature
- } else if c.defaultTemperature != nil {
- optsMap["temperature"] = *c.defaultTemperature
- }
- if o.MaxTokens > 0 {
- optsMap["num_predict"] = o.MaxTokens
- }
- if len(o.Stop) > 0 {
- optsMap["stop"] = o.Stop
- }
- if len(optsMap) > 0 {
- req.Options = optsMap
- }
-
+ c.logStart(true, o, messages)
+ req := buildOllamaRequest(o, messages, c.defaultTemperature, true)
body, err := json.Marshal(req)
if err != nil {
return err
@@ -204,27 +125,14 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
endpoint := c.baseURL + "/api/chat"
logging.Logf("llm/ollama ", "POST %s (stream)", endpoint)
- httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
- if err != nil {
- return err
- }
- httpReq.Header.Set("Content-Type", "application/json")
-
- resp, err := c.httpClient.Do(httpReq)
+ resp, err := c.doJSON(ctx, endpoint, body)
if err != nil {
logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
return err
}
defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- var apiErr ollamaChatResponse
- _ = json.NewDecoder(resp.Body).Decode(&apiErr)
- if strings.TrimSpace(apiErr.Error) != "" {
- logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
- return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
- }
- logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
- return fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+ if err := handleOllamaNon2xx(resp, start); err != nil {
+ return err
}
dec := json.NewDecoder(resp.Body)
@@ -251,3 +159,59 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start))
return nil
}
+
+// helpers to keep methods small
+func (c ollamaClient) logStart(stream bool, o Options, messages []Message) {
+ logMessages := make([]struct{ Role, Content string }, len(messages))
+ for i, m := range messages {
+ logMessages[i] = struct{ Role, Content string }{m.Role, m.Content}
+ }
+ c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages)
+}
+
+func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, stream bool) ollamaChatRequest {
+ req := ollamaChatRequest{Model: o.Model, Stream: stream}
+ req.Messages = make([]oaMessage, len(messages))
+ for i, m := range messages {
+ req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
+ }
+ optsMap := map[string]any{}
+ if o.Temperature != 0 {
+ optsMap["temperature"] = o.Temperature
+ } else if defaultTemp != nil {
+ optsMap["temperature"] = *defaultTemp
+ }
+ if o.MaxTokens > 0 {
+ optsMap["num_predict"] = o.MaxTokens
+ }
+ if len(o.Stop) > 0 {
+ optsMap["stop"] = o.Stop
+ }
+ if len(optsMap) > 0 {
+ req.Options = optsMap
+ }
+ return req
+}
+
+func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ return c.httpClient.Do(req)
+}
+
+func handleOllamaNon2xx(resp *http.Response, start time.Time) error {
+ if resp.StatusCode >= 200 && resp.StatusCode < 300 {
+ return nil
+ }
+ var apiErr ollamaChatResponse
+ _ = json.NewDecoder(resp.Body).Decode(&apiErr)
+ if strings.TrimSpace(apiErr.Error) != "" {
+ logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
+ }
+ logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+}