summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2025-08-17 09:05:45 +0300
committerPaul Buetow <paul@buetow.org>2025-08-17 09:05:45 +0300
commitd5fbb6ef5957894eb5be0854bdb328a6774abddb (patch)
tree7558a3d1c70db85d54f3367a5c81c28e7260392b
parent5d9e197a394089f66539320e77abb3f3689b3381 (diff)
cli: stream responses in hexai when supported (OpenAI, Ollama)
- Add llm.Streamer optional interface - Implement ChatStream for OpenAI (SSE) and Ollama (JSON stream) - CLI uses streaming; LSP unchanged (non-streaming) - README: document streaming behavior for CLI
-rw-r--r--README.md1
-rw-r--r--cmd/hexai/main.go42
-rw-r--r--internal/llm/ollama.go116
-rw-r--r--internal/llm/openai.go146
-rw-r--r--internal/llm/provider.go22
5 files changed, 284 insertions, 43 deletions
diff --git a/README.md b/README.md
index 7de7365..b17983d 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,7 @@ Notes for `hexai` (CLI):
- Prints LLM output to stdout.
- Prints provider/model immediately to stderr, and a summary to stderr at the end (time, input bytes, output bytes, provider/model).
- Default response style: short answers. If the prompt asks for commands, outputs only the commands with no explanation. Include the word `explain` anywhere in the prompt to request a verbose explanation.
+- Streams output: when supported by the provider (OpenAI, Ollama), `hexai` streams tokens and prints them to stdout as they arrive. Copilot falls back to non-streaming.
### Hexai CLI behavior
diff --git a/cmd/hexai/main.go b/cmd/hexai/main.go
index c356800..6cbd288 100644
--- a/cmd/hexai/main.go
+++ b/cmd/hexai/main.go
@@ -83,18 +83,34 @@ func main() {
{Role: "system", Content: system},
{Role: "user", Content: input},
}
- out, err := client.Chat(context.Background(), msgs)
- dur := time.Since(start)
- if err != nil {
- fmt.Fprintf(os.Stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
- os.Exit(1)
+ var out string
+ if s, ok := client.(llm.Streamer); ok {
+ var b strings.Builder
+ err := s.ChatStream(context.Background(), msgs, func(chunk string) {
+ b.WriteString(chunk)
+ fmt.Fprint(os.Stdout, chunk)
+ })
+ dur := time.Since(start)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
+ os.Exit(1)
+ }
+ out = b.String()
+ // Summary
+ inSize := len(input)
+ outSize := len(out)
+ fmt.Fprintf(os.Stderr, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), inSize, outSize)
+ } else {
+ outText, err := client.Chat(context.Background(), msgs)
+ dur := time.Since(start)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
+ os.Exit(1)
+ }
+ out = outText
+ fmt.Fprint(os.Stdout, out)
+ inSize := len(input)
+ outSize := len(out)
+ fmt.Fprintf(os.Stderr, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), inSize, outSize)
}
-
- // Write assistant output to stdout
- fmt.Fprint(os.Stdout, out)
-
- // Summary to stderr (preceded by a blank line)
- inSize := len(input)
- outSize := len(out)
- fmt.Fprintf(os.Stderr, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), inSize, outSize)
}
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index e8b75c9..ffee354 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -1,14 +1,15 @@
package llm
import (
- "bytes"
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "net/http"
- "strings"
- "time"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "time"
"hexai/internal/logging"
)
@@ -35,10 +36,10 @@ func newOllama(baseURL, model string) Client {
}
type ollamaChatRequest struct {
- Model string `json:"model"`
- Messages []oaMessage `json:"messages"`
- Stream bool `json:"stream"`
- Options any `json:"options,omitempty"`
+ Model string `json:"model"`
+ Messages []oaMessage `json:"messages"`
+ Stream bool `json:"stream"`
+ Options any `json:"options,omitempty"`
}
type ollamaChatResponse struct {
@@ -133,3 +134,94 @@ func (c *ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Req
// Provider metadata
func (c *ollamaClient) Name() string { return "ollama" }
func (c *ollamaClient) DefaultModel() string { return c.defaultModel }
+
+// Streaming support (optional)
+func (c *ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error {
+ o := Options{Model: c.defaultModel}
+ for _, opt := range opts {
+ opt(&o)
+ }
+ if o.Model == "" {
+ o.Model = c.defaultModel
+ }
+
+ start := time.Now()
+ logging.Logf("llm/ollama ", "stream start model=%s temp=%.2f max_tokens=%d stop=%d messages=%d", o.Model, o.Temperature, o.MaxTokens, len(o.Stop), len(messages))
+ for i, m := range messages {
+ logging.Logf("llm/ollama ", "msg[%d] role=%s size=%d preview=%s%s%s", i, m.Role, len(m.Content), logging.AnsiCyan, logging.PreviewForLog(m.Content), logging.AnsiBase)
+ }
+
+ req := ollamaChatRequest{Model: o.Model, Stream: true}
+ req.Messages = make([]oaMessage, len(messages))
+ for i, m := range messages {
+ req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
+ }
+ // Build options map
+ optsMap := map[string]any{}
+ if o.Temperature != 0 {
+ optsMap["temperature"] = o.Temperature
+ }
+ if o.MaxTokens > 0 {
+ optsMap["num_predict"] = o.MaxTokens
+ }
+ if len(o.Stop) > 0 {
+ optsMap["stop"] = o.Stop
+ }
+ if len(optsMap) > 0 {
+ req.Options = optsMap
+ }
+
+ body, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+
+ endpoint := c.baseURL + "/api/chat"
+ logging.Logf("llm/ollama ", "POST %s (stream)", endpoint)
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
+ if err != nil {
+ return err
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.httpClient.Do(httpReq)
+ if err != nil {
+ logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ var apiErr ollamaChatResponse
+ _ = json.NewDecoder(resp.Body).Decode(&apiErr)
+ if strings.TrimSpace(apiErr.Error) != "" {
+ logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode)
+ }
+ logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("ollama http error: status %d", resp.StatusCode)
+ }
+
+ dec := json.NewDecoder(resp.Body)
+ for {
+ var ev ollamaChatResponse
+ if err := dec.Decode(&ev); err != nil {
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ return err
+ }
+ if strings.TrimSpace(ev.Error) != "" {
+ logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase)
+ return fmt.Errorf("ollama stream error: %s", ev.Error)
+ }
+ if s := ev.Message.Content; strings.TrimSpace(s) != "" {
+ onDelta(s)
+ }
+ if ev.Done {
+ break
+ }
+ }
+ logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start))
+ return nil
+}
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 03e894a..080d4e9 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -1,16 +1,17 @@
package llm
import (
- "bytes"
- "context"
- "encoding/json"
- "errors"
- "fmt"
+ "bufio"
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
"net/http"
"strings"
- "time"
+ "time"
- "hexai/internal/logging"
+ "hexai/internal/logging"
)
// openAIClient implements Client against OpenAI's Chat Completions API.
@@ -41,11 +42,12 @@ func newOpenAI(baseURL, model, apiKey string) Client {
}
type oaChatRequest struct {
- Model string `json:"model"`
- Messages []oaMessage `json:"messages"`
- Temperature *float64 `json:"temperature,omitempty"`
- MaxTokens *int `json:"max_tokens,omitempty"`
- Stop []string `json:"stop,omitempty"`
+ Model string `json:"model"`
+ Messages []oaMessage `json:"messages"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ MaxTokens *int `json:"max_tokens,omitempty"`
+ Stop []string `json:"stop,omitempty"`
+ Stream bool `json:"stream,omitempty"`
}
type oaMessage struct {
@@ -163,3 +165,123 @@ func trimPreview(s string, n int) string {
// Provider metadata
func (c *openAIClient) Name() string { return "openai" }
func (c *openAIClient) DefaultModel() string { return c.defaultModel }
+
+// Streaming support (optional)
+type oaStreamChunk struct {
+ Choices []struct {
+ Delta struct {
+ Content string `json:"content"`
+ } `json:"delta"`
+ FinishReason string `json:"finish_reason"`
+ } `json:"choices"`
+ Error *struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param any `json:"param"`
+ Code any `json:"code"`
+ } `json:"error,omitempty"`
+}
+
+func (c *openAIClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error {
+ if c.apiKey == "" {
+ return errors.New("missing OpenAI API key")
+ }
+ o := Options{Model: c.defaultModel}
+ for _, opt := range opts {
+ opt(&o)
+ }
+ if o.Model == "" {
+ o.Model = c.defaultModel
+ }
+ start := time.Now()
+ logging.Logf("llm/openai ", "stream start model=%s temp=%.2f max_tokens=%d stop=%d messages=%d", o.Model, o.Temperature, o.MaxTokens, len(o.Stop), len(messages))
+ for i, m := range messages {
+ logging.Logf("llm/openai ", "msg[%d] role=%s size=%d preview=%s%s%s", i, m.Role, len(m.Content), logging.AnsiCyan, logging.PreviewForLog(m.Content), logging.AnsiBase)
+ }
+
+ req := oaChatRequest{Model: o.Model, Stream: true}
+ req.Messages = make([]oaMessage, len(messages))
+ for i, m := range messages {
+ req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
+ }
+ if o.Temperature != 0 {
+ req.Temperature = &o.Temperature
+ }
+ if o.MaxTokens > 0 {
+ req.MaxTokens = &o.MaxTokens
+ }
+ if len(o.Stop) > 0 {
+ req.Stop = o.Stop
+ }
+
+ body, err := json.Marshal(req)
+ if err != nil {
+ c.logf("marshal error: %v", err)
+ return err
+ }
+ endpoint := c.baseURL + "/chat/completions"
+ logging.Logf("llm/openai ", "POST %s (stream)", endpoint)
+ httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
+ if err != nil {
+ c.logf("new request error: %v", err)
+ return err
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+ httpReq.Header.Set("Authorization", "Bearer "+c.apiKey)
+ // Streaming uses SSE-style data lines
+ httpReq.Header.Set("Accept", "text/event-stream")
+
+ resp, err := c.httpClient.Do(httpReq)
+ if err != nil {
+ logging.Logf("llm/openai ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ // try to decode body to surface message
+ var apiErr oaChatResponse
+ _ = json.NewDecoder(resp.Body).Decode(&apiErr)
+ if apiErr.Error != nil && apiErr.Error.Message != "" {
+ logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode)
+ }
+ logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase)
+ return fmt.Errorf("openai http error: status %d", resp.StatusCode)
+ }
+
+ // Parse SSE: lines starting with "data: " containing JSON or [DONE]
+ scanner := bufio.NewScanner(resp.Body)
+ // Increase buffer for long lines
+ const maxBuf = 1024 * 1024
+ buf := make([]byte, 0, 64*1024)
+ scanner.Buffer(buf, maxBuf)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+ payload := strings.TrimPrefix(line, "data: ")
+ if strings.TrimSpace(payload) == "[DONE]" {
+ break
+ }
+ var chunk oaStreamChunk
+ if err := json.Unmarshal([]byte(payload), &chunk); err != nil {
+ continue // skip malformed lines
+ }
+ if chunk.Error != nil && chunk.Error.Message != "" {
+ logging.Logf("llm/openai ", "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase)
+ return fmt.Errorf("openai stream error: %s", chunk.Error.Message)
+ }
+ for _, ch := range chunk.Choices {
+ if ch.Delta.Content != "" {
+ onDelta(ch.Delta.Content)
+ }
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ logging.Logf("llm/openai ", "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
+ return err
+ }
+ logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start))
+ return nil
+}
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index dda3d16..3e3023e 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -15,12 +15,22 @@ type Message struct {
// Client is a minimal LLM provider interface.
// Future providers (Ollama, etc.) should implement this.
type Client interface {
- // Chat sends chat messages and returns the assistant text.
- Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error)
- // Name returns the provider's short name (e.g., "openai", "ollama").
- Name() string
- // DefaultModel returns the configured default model name.
- DefaultModel() string
+ // Chat sends chat messages and returns the assistant text.
+ Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error)
+ // Name returns the provider's short name (e.g., "openai", "ollama").
+ Name() string
+ // DefaultModel returns the configured default model name.
+ DefaultModel() string
+}
+
+// Streamer is an optional interface that providers may implement to support
+// token-by-token streaming responses. Callers can type-assert to Streamer and
+// fall back to Client.Chat when not implemented.
+type Streamer interface {
+ // ChatStream sends chat messages and invokes onDelta with incremental text
+ // chunks as they are produced by the model. Implementations should call
+ // onDelta with empty strings sparingly (prefer only non-empty chunks).
+ ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
}
// Options for a request. Providers may ignore unsupported fields.