diff options
Diffstat (limited to 'internal/llm/ollama.go')
| -rw-r--r-- | internal/llm/ollama.go | 192 |
1 files changed, 78 insertions, 114 deletions
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index 20dfe2a..50e9837 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -1,5 +1,4 @@ // Summary: Ollama client against a local server; supports chat responses and streaming via /api/chat. -// Not yet reviewed by a human package llm import ( @@ -18,11 +17,11 @@ import ( // ollamaClient implements Client against a local Ollama server. type ollamaClient struct { - httpClient *http.Client - baseURL string - defaultModel string - chatLogger logging.ChatLogger - defaultTemperature *float64 + httpClient *http.Client + baseURL string + defaultModel string + chatLogger logging.ChatLogger + defaultTemperature *float64 } type ollamaChatRequest struct { @@ -49,13 +48,13 @@ func newOllama(baseURL, model string, defaultTemp *float64) Client { if strings.TrimSpace(model) == "" { model = "qwen3-coder:30b-a3b-q4_K_M`" } - return ollamaClient{ - httpClient: &http.Client{Timeout: 30 * time.Second}, - baseURL: strings.TrimRight(baseURL, "/"), - defaultModel: model, - chatLogger: logging.NewChatLogger("ollama"), - defaultTemperature: defaultTemp, - } + return ollamaClient{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + baseURL: strings.TrimRight(baseURL, "/"), + defaultModel: model, + chatLogger: logging.NewChatLogger("ollama"), + defaultTemperature: defaultTemp, + } } // TODO: This function is too long and should be refactored for readability and maintainability. @@ -69,41 +68,8 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ } start := time.Now() - logMessages := make([]struct { - Role string - Content string - }, len(messages)) - for i, m := range messages { - logMessages[i] = struct { - Role string - Content string - }{Role: m.Role, Content: m.Content} - } - c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) - - req := ollamaChatRequest{Model: o.Model, Stream: false} - req.Messages = make([]oaMessage, len(messages)) - for i, m := range messages { - req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} - } - - // Build options map only if any option is set - optsMap := map[string]any{} - if o.Temperature != 0 { - optsMap["temperature"] = o.Temperature - } else if c.defaultTemperature != nil { - optsMap["temperature"] = *c.defaultTemperature - } - if o.MaxTokens > 0 { - optsMap["num_predict"] = o.MaxTokens - } - if len(o.Stop) > 0 { - optsMap["stop"] = o.Stop - } - if len(optsMap) > 0 { - req.Options = optsMap - } - + c.logStart(false, o, messages) + req := buildOllamaRequest(o, messages, c.defaultTemperature, false) body, err := json.Marshal(req) if err != nil { return "", err @@ -111,27 +77,14 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ endpoint := c.baseURL + "/api/chat" logging.Logf("llm/ollama ", "POST %s", endpoint) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return "", err - } - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(httpReq) + resp, err := c.doJSON(ctx, endpoint, body) if err != nil { logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return "", err } defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - var apiErr ollamaChatResponse - _ = json.NewDecoder(resp.Body).Decode(&apiErr) - if strings.TrimSpace(apiErr.Error) != "" { - logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase) - return "", fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode) - } - logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) - return "", fmt.Errorf("ollama http error: status %d", resp.StatusCode) + if err := handleOllamaNon2xx(resp, start); err != nil { + return "", err } var out ollamaChatResponse @@ -163,40 +116,8 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt } start := time.Now() - logMessages := make([]struct { - Role string - Content string - }, len(messages)) - for i, m := range messages { - logMessages[i] = struct { - Role string - Content string - }{Role: m.Role, Content: m.Content} - } - c.chatLogger.LogStart(true, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) - - req := ollamaChatRequest{Model: o.Model, Stream: true} - req.Messages = make([]oaMessage, len(messages)) - for i, m := range messages { - req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} - } - // Build options map - optsMap := map[string]any{} - if o.Temperature != 0 { - optsMap["temperature"] = o.Temperature - } else if c.defaultTemperature != nil { - optsMap["temperature"] = *c.defaultTemperature - } - if o.MaxTokens > 0 { - optsMap["num_predict"] = o.MaxTokens - } - if len(o.Stop) > 0 { - optsMap["stop"] = o.Stop - } - if len(optsMap) > 0 { - req.Options = optsMap - } - + c.logStart(true, o, messages) + req := buildOllamaRequest(o, messages, c.defaultTemperature, true) body, err := json.Marshal(req) if err != nil { return err @@ -204,27 +125,14 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt endpoint := c.baseURL + "/api/chat" logging.Logf("llm/ollama ", "POST %s (stream)", endpoint) - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return err - } - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(httpReq) + resp, err := c.doJSON(ctx, endpoint, body) if err != nil { logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return err } defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - var apiErr ollamaChatResponse - _ = json.NewDecoder(resp.Body).Decode(&apiErr) - if strings.TrimSpace(apiErr.Error) != "" { - logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase) - return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode) - } - logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) - return fmt.Errorf("ollama http error: status %d", resp.StatusCode) + if err := handleOllamaNon2xx(resp, start); err != nil { + return err } dec := json.NewDecoder(resp.Body) @@ -251,3 +159,59 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) return nil } + +// helpers to keep methods small +func (c ollamaClient) logStart(stream bool, o Options, messages []Message) { + logMessages := make([]struct{ Role, Content string }, len(messages)) + for i, m := range messages { + logMessages[i] = struct{ Role, Content string }{m.Role, m.Content} + } + c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) +} + +func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, stream bool) ollamaChatRequest { + req := ollamaChatRequest{Model: o.Model, Stream: stream} + req.Messages = make([]oaMessage, len(messages)) + for i, m := range messages { + req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} + } + optsMap := map[string]any{} + if o.Temperature != 0 { + optsMap["temperature"] = o.Temperature + } else if defaultTemp != nil { + optsMap["temperature"] = *defaultTemp + } + if o.MaxTokens > 0 { + optsMap["num_predict"] = o.MaxTokens + } + if len(o.Stop) > 0 { + optsMap["stop"] = o.Stop + } + if len(optsMap) > 0 { + req.Options = optsMap + } + return req +} + +func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + return c.httpClient.Do(req) +} + +func handleOllamaNon2xx(resp *http.Response, start time.Time) error { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + return nil + } + var apiErr ollamaChatResponse + _ = json.NewDecoder(resp.Body).Decode(&apiErr) + if strings.TrimSpace(apiErr.Error) != "" { + logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase) + return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode) + } + logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) + return fmt.Errorf("ollama http error: status %d", resp.StatusCode) +} |
