diff options
Diffstat (limited to 'internal/llm/ollama.go')
| -rw-r--r-- | internal/llm/ollama.go | 52 |
1 files changed, 31 insertions, 21 deletions
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index be93ab0..e212466 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -133,7 +133,35 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ func (c ollamaClient) Name() string { return "ollama" } func (c ollamaClient) DefaultModel() string { return c.defaultModel } -// Streaming support (optional) +// parseOllamaStream reads NDJSON streaming events from dec, calling onDelta for each +// non-empty content delta. Returns an error if decoding fails or the server signals +// an error event; returns nil when the done flag is received or the stream ends. +func parseOllamaStream(dec *json.Decoder, start time.Time, onDelta func(string)) error { + for { + var ev ollamaChatResponse + if err := dec.Decode(&ev); err != nil { + if errors.Is(err, io.EOF) { + break + } + logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + return err + } + if strings.TrimSpace(ev.Error) != "" { + logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase) + return fmt.Errorf("ollama stream error: %s", ev.Error) + } + if s := ev.Message.Content; strings.TrimSpace(s) != "" { + onDelta(s) + } + if ev.Done { + break + } + } + return nil +} + +// ChatStream sends a streaming chat request to Ollama, calling onDelta for each +// received content delta. It blocks until the stream ends or an error occurs. func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { o := Options{Model: c.defaultModel} for _, opt := range opts { @@ -167,26 +195,8 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt return err } - dec := json.NewDecoder(resp.Body) - for { - var ev ollamaChatResponse - if err := dec.Decode(&ev); err != nil { - if errors.Is(err, io.EOF) { - break - } - logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) - return err - } - if strings.TrimSpace(ev.Error) != "" { - logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase) - return fmt.Errorf("ollama stream error: %s", ev.Error) - } - if s := ev.Message.Content; strings.TrimSpace(s) != "" { - onDelta(s) - } - if ev.Done { - break - } + if err := parseOllamaStream(json.NewDecoder(resp.Body), start, onDelta); err != nil { + return err } logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) return nil |
