summary | refs | log | tree | commit | diff
path: root/internal/llm/ollama.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm/ollama.go')
-rw-r--r-- internal/llm/ollama.go | 52
1 files changed, 31 insertions, 21 deletions
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index be93ab0..e212466 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -133,7 +133,35 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ
func (c ollamaClient) Name() string { return "ollama" }
func (c ollamaClient) DefaultModel() string { return c.defaultModel }
-// Streaming support (optional)
+// parseOllamaStream decodes NDJSON chat events from dec, calling onDelta for each
+// non-empty content delta. It returns nil on io.EOF or once an event has Done set,
+// and an error if decoding fails or an event carries a non-blank Error message.
+func parseOllamaStream(dec *json.Decoder, start time.Time, onDelta func(string)) error {
+	for {
+		var ev ollamaChatResponse
+		if err := dec.Decode(&ev); err != nil {
+			if errors.Is(err, io.EOF) {
+				return nil
+			}
+			logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+			return err
+		}
+		if strings.TrimSpace(ev.Error) != "" {
+			logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase)
+			return fmt.Errorf("ollama stream error: %s", ev.Error)
+		}
+		// Whitespace-only deltas (" ", "\n") are real output; dropping them garbles the text.
+		if s := ev.Message.Content; s != "" {
+			onDelta(s)
+		}
+		if ev.Done {
+			return nil
+		}
+	}
+}
+
+// ChatStream sends a streaming chat request to Ollama, calling onDelta for each
+// received content delta. It blocks until the stream ends or an error occurs.
func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error {
o := Options{Model: c.defaultModel}
for _, opt := range opts {
@@ -167,26 +195,8 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
return err
}
- dec := json.NewDecoder(resp.Body)
- for {
- var ev ollamaChatResponse
- if err := dec.Decode(&ev); err != nil {
- if errors.Is(err, io.EOF) {
- break
- }
- logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
- return err
- }
- if strings.TrimSpace(ev.Error) != "" {
- logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase)
- return fmt.Errorf("ollama stream error: %s", ev.Error)
- }
- if s := ev.Message.Content; strings.TrimSpace(s) != "" {
- onDelta(s)
- }
- if ev.Done {
- break
- }
+ if err := parseOllamaStream(json.NewDecoder(resp.Body), start, onDelta); err != nil {
+ return err
}
logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start))
return nil