summaryrefslogtreecommitdiff
path: root/internal/hexaicli
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2025-09-28 00:20:05 +0300
committerPaul Buetow <paul@buetow.org>2025-09-28 00:20:05 +0300
commit0ac2d186e84f77d73d924e2c0ce975a17c3a8078 (patch)
tree49f3e2def38449544e1d67f047cbcb4aab802658 /internal/hexaicli
parent51b2621d58633aa5c0f5cc7b64616d70d41acc91 (diff)
Improve multi-provider completion streaming and CLI selector flags
Diffstat (limited to 'internal/hexaicli')
-rw-r--r--internal/hexaicli/run.go316
-rw-r--r--internal/hexaicli/run_model_override_test.go54
-rw-r--r--internal/hexaicli/run_test.go17
3 files changed, 338 insertions, 49 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 06fcb83..b7745c8 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -20,6 +20,8 @@ import (
"codeberg.org/snonux/hexai/internal/logging"
"codeberg.org/snonux/hexai/internal/stats"
"codeberg.org/snonux/hexai/internal/tmux"
+ "github.com/mattn/go-runewidth"
+ "golang.org/x/term"
)
type requestArgs struct {
@@ -35,6 +37,23 @@ type cliJob struct {
req requestArgs
}
+type columnPrinter struct {
+ mu sync.Mutex
+ stdout io.Writer
+ columns int
+ colWidth int
+ partial []string
+ providers []string
+ models []string
+}
+
+type columnWriter struct {
+ printer *columnPrinter
+ index int
+}
+
+type selectionContextKey struct{}
+
func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) {
entries := cfg.CLIConfigs
if len(entries) == 0 {
@@ -150,6 +169,13 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.
fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
return err
}
+ if selected := selectionFromContext(ctx); len(selected) > 0 {
+ jobs, err = filterJobsBySelection(jobs, selected)
+ if err != nil {
+ fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err)
+ return err
+ }
+ }
if len(jobs) == 0 {
return fmt.Errorf("hexai: no CLI providers configured")
}
@@ -203,16 +229,29 @@ type cliJobResult struct {
func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
results := make([]*cliJobResult, len(jobs))
var wg sync.WaitGroup
+ var printer *columnPrinter
+ if len(jobs) > 0 {
+ printer = newColumnPrinter(stdout, jobs)
+ printer.PrintHeader()
+ }
for _, job := range jobs {
job := job
wg.Add(1)
printProviderInfo(stderr, job.client, job.req.model)
go func() {
defer wg.Done()
- var outBuf, errBuf bytes.Buffer
+ var errBuf bytes.Buffer
+ var outBuf bytes.Buffer
jobMsgs := make([]llm.Message, len(msgs))
copy(jobMsgs, msgs)
- err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf)
+ writer := io.Writer(&outBuf)
+ if printer != nil {
+ writer = printer.Writer(job.index)
+ }
+ err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
+ if printer != nil {
+ printer.Flush(job.index)
+ }
results[job.index] = &cliJobResult{
provider: job.client.Name(),
model: job.req.model,
@@ -224,48 +263,275 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st
}
wg.Wait()
var firstErr error
- printed := false
- for _, res := range results {
- if res == nil {
- continue
- }
- if printed {
- if _, err := io.WriteString(stdout, "\n"); err != nil {
- return err
+ if printer == nil {
+ printed := false
+ for _, res := range results {
+ if res == nil {
+ continue
}
- }
- heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
- if _, err := io.WriteString(stdout, heading); err != nil {
- return err
- }
- if res.output != "" {
- if _, err := io.WriteString(stdout, res.output); err != nil {
+ if printed {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
+ return err
+ }
+ }
+ heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model)
+ if _, err := io.WriteString(stdout, heading); err != nil {
return err
}
- if !strings.HasSuffix(res.output, "\n") {
- if _, err := io.WriteString(stdout, "\n"); err != nil {
+ if res.output != "" {
+ if _, err := io.WriteString(stdout, res.output); err != nil {
return err
}
+ if !strings.HasSuffix(res.output, "\n") {
+ if _, err := io.WriteString(stdout, "\n"); err != nil {
+ return err
+ }
+ }
}
+ printed = true
+ }
+ }
+ for _, res := range results {
+ if res == nil {
+ continue
}
- printed = true
if res.summary != "" {
- if _, err := io.WriteString(stderr, res.summary); err != nil {
- return err
+ summary := strings.TrimLeft(res.summary, "\n")
+ if summary != "" {
+ if _, err := io.WriteString(stderr, summary); err != nil {
+ return err
+ }
}
}
if res.err != nil {
if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil {
return err
}
- if firstErr == nil {
- firstErr = res.err
- }
+ }
+ if firstErr == nil && res.err != nil {
+ firstErr = res.err
}
}
return firstErr
}
+func newColumnPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter {
+ cols := len(jobs)
+ width := detectTerminalWidth(stdout)
+ if width <= 0 {
+ width = 100
+ }
+ sepWidth := (cols - 1) * 3
+ colWidth := (width - sepWidth) / cols
+ if colWidth < 20 {
+ colWidth = 20
+ }
+ providers := make([]string, cols)
+ models := make([]string, cols)
+ for _, job := range jobs {
+ providers[job.index] = job.client.Name()
+ models[job.index] = job.req.model
+ }
+ return &columnPrinter{
+ stdout: stdout,
+ columns: cols,
+ colWidth: colWidth,
+ partial: make([]string, cols),
+ providers: providers,
+ models: models,
+ }
+}
+
+func detectTerminalWidth(w io.Writer) int {
+ type fder interface{ Fd() uintptr }
+ if f, ok := w.(*os.File); ok {
+ if width, _, err := term.GetSize(int(f.Fd())); err == nil {
+ return width
+ }
+ }
+ if f, ok := w.(fder); ok {
+ if width, _, err := term.GetSize(int(f.Fd())); err == nil {
+ return width
+ }
+ }
+ return 0
+}
+
+func (cp *columnPrinter) Writer(idx int) io.Writer {
+ return columnWriter{printer: cp, index: idx}
+}
+
+func (cp *columnPrinter) PrintHeader() {
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ combo := make([]string, cp.columns)
+ for i := 0; i < cp.columns; i++ {
+ provider := strings.TrimSpace(cp.providers[i])
+ model := strings.TrimSpace(cp.models[i])
+ switch {
+ case provider != "" && model != "":
+ combo[i] = provider + ":" + model
+ case provider != "":
+ combo[i] = provider
+ case model != "":
+ combo[i] = model
+ default:
+ combo[i] = ""
+ }
+ }
+ cp.writeLine(combo)
+ divider := make([]string, cp.columns)
+ line := strings.Repeat("─", cp.colWidth)
+ for i := range divider {
+ divider[i] = line
+ }
+ cp.writeLine(divider)
+}
+
+func (cp *columnPrinter) Flush(idx int) {
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ if idx < 0 || idx >= len(cp.partial) {
+ return
+ }
+ if cp.partial[idx] == "" {
+ return
+ }
+ cp.emitJobLine(idx, cp.partial[idx])
+ cp.partial[idx] = ""
+}
+
+func (w columnWriter) Write(p []byte) (int, error) {
+ return w.printer.write(w.index, string(p))
+}
+
+func (cp *columnPrinter) write(idx int, data string) (int, error) {
+ cp.mu.Lock()
+ defer cp.mu.Unlock()
+ if idx < 0 || idx >= len(cp.partial) {
+ return len(data), nil
+ }
+ data = strings.ReplaceAll(data, "\r", "")
+ cp.partial[idx] += data
+ for strings.Contains(cp.partial[idx], "\n") {
+ line, rest, _ := strings.Cut(cp.partial[idx], "\n")
+ cp.partial[idx] = rest
+ cp.emitJobLine(idx, line)
+ }
+ return len(data), nil
+}
+
+func (cp *columnPrinter) emitJobLine(idx int, line string) {
+ segments := cp.wrap(line)
+ for _, seg := range segments {
+ cells := make([]string, cp.columns)
+ if idx >= 0 && idx < len(cells) {
+ cells[idx] = seg
+ }
+ cp.writeLine(cells)
+ }
+}
+
+func (cp *columnPrinter) wrap(text string) []string {
+ text = strings.ReplaceAll(text, "\t", " ")
+ if runewidth.StringWidth(text) <= cp.colWidth {
+ return []string{text}
+ }
+ var lines []string
+ var current strings.Builder
+ width := 0
+ for _, r := range text {
+ rw := runewidth.RuneWidth(r)
+ if width+rw > cp.colWidth && current.Len() > 0 {
+ lines = append(lines, current.String())
+ current.Reset()
+ width = 0
+ }
+ current.WriteRune(r)
+ width += rw
+ }
+ if current.Len() > 0 {
+ lines = append(lines, current.String())
+ }
+ if len(lines) == 0 {
+ lines = append(lines, "")
+ }
+ return lines
+}
+
+func (cp *columnPrinter) writeLine(cells []string) {
+ if len(cells) < cp.columns {
+ extra := make([]string, cp.columns-len(cells))
+ cells = append(cells, extra...)
+ }
+ var builder strings.Builder
+ for i := 0; i < cp.columns; i++ {
+ cell := cells[i]
+ width := runewidth.StringWidth(cell)
+ if width > cp.colWidth {
+ cell = runewidth.Truncate(cell, cp.colWidth, "…")
+ width = runewidth.StringWidth(cell)
+ }
+ builder.WriteString(cell)
+ if pad := cp.colWidth - width; pad > 0 {
+ builder.WriteString(strings.Repeat(" ", pad))
+ }
+ if i != cp.columns-1 {
+ builder.WriteString(" │ ")
+ }
+ }
+ builder.WriteByte('\n')
+ _, _ = cp.stdout.Write([]byte(builder.String()))
+}
+
+// WithCLISelection injects provider indices into the context so Run only executes those jobs.
+func WithCLISelection(ctx context.Context, indices []int) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ cpy := make([]int, len(indices))
+ copy(cpy, indices)
+ return context.WithValue(ctx, selectionContextKey{}, cpy)
+}
+
+func selectionFromContext(ctx context.Context) []int {
+ if ctx == nil {
+ return nil
+ }
+ if v, ok := ctx.Value(selectionContextKey{}).([]int); ok {
+ cpy := make([]int, len(v))
+ copy(cpy, v)
+ return cpy
+ }
+ return nil
+}
+
+func filterJobsBySelection(jobs []cliJob, indices []int) ([]cliJob, error) {
+ if len(indices) == 0 {
+ return jobs, nil
+ }
+ filtered := make([]cliJob, 0, len(indices))
+ seen := make(map[int]struct{}, len(indices))
+ for _, idx := range indices {
+ if idx < 0 || idx >= len(jobs) {
+ return nil, fmt.Errorf("provider index %d out of range (0-%d)", idx, len(jobs)-1)
+ }
+ if _, ok := seen[idx]; ok {
+ continue
+ }
+ clone := jobs[idx]
+ filtered = append(filtered, clone)
+ seen[idx] = struct{}{}
+ }
+ for i := range filtered {
+ filtered[i].index = i
+ }
+ if len(filtered) == 0 {
+ return nil, fmt.Errorf("no CLI providers matched selection")
+ }
+ return filtered, nil
+}
+
// readInput reads from stdin and args, then combines them per CLI rules.
func readInput(stdin io.Reader, args []string) (string, error) {
var stdinData string
diff --git a/internal/hexaicli/run_model_override_test.go b/internal/hexaicli/run_model_override_test.go
index 6394bd1..b32b172 100644
--- a/internal/hexaicli/run_model_override_test.go
+++ b/internal/hexaicli/run_model_override_test.go
@@ -1,39 +1,45 @@
package hexaicli
import (
- "bytes"
- "context"
- "strings"
- "testing"
+ "bytes"
+ "context"
+ "strings"
+ "testing"
- "codeberg.org/snonux/hexai/internal/appconfig"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/llm"
)
type fakeClientModelEnv struct{ name, model string }
-func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { return "ok", nil }
+
+func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) {
+ return "ok", nil
+}
func (f fakeClientModelEnv) Name() string { return f.name }
func (f fakeClientModelEnv) DefaultModel() string { return f.model }
// Ensure that HEXAI_MODEL overrides config for CLI runs.
func TestRun_ModelEnvOverride_FlowsIntoClient(t *testing.T) {
- t.Setenv("HEXAI_MODEL", "gpt-5-codex")
- t.Setenv("HEXAI_PROVIDER", "openai")
- // Replace client constructor to assert model was overridden
- oldNew := newClientFromApp
- defer func() { newClientFromApp = oldNew }()
+ t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+ t.Setenv("HEXAI_MODEL", "gpt-5-codex")
+ t.Setenv("HEXAI_PROVIDER", "openai")
+ // Replace client constructor to assert model was overridden
+ oldNew := newClientFromApp
+ defer func() { newClientFromApp = oldNew }()
+ var seenModel string
newClientFromApp = func(cfg appconfig.App) (llm.Client, error) {
- if strings.TrimSpace(cfg.OpenAIModel) != "gpt-5-codex" {
- t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", cfg.OpenAIModel)
- }
- return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil
- }
+ seenModel = strings.TrimSpace(cfg.OpenAIModel)
+ return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil
+ }
- var out, errb bytes.Buffer
- if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil {
- t.Fatalf("run error: %v", err)
- }
- if !strings.Contains(errb.String(), "model=gpt-5-codex") {
- t.Fatalf("stderr should print effective model, got: %s", errb.String())
- }
+ var out, errb bytes.Buffer
+ if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil {
+ t.Fatalf("run error: %v", err)
+ }
+ if seenModel != "gpt-5-codex" {
+ t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", seenModel)
+ }
+ if !strings.Contains(errb.String(), "model=gpt-5-codex") {
+ t.Fatalf("stderr should print effective model, got: %s", errb.String())
+ }
}
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index f11545e..dfde068 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -225,6 +225,23 @@ func TestBuildCLIJobs_MultiEntries(t *testing.T) {
}
}
+func TestFilterJobsBySelection(t *testing.T) {
+ jobs := []cliJob{{index: 0, provider: "openai"}, {index: 1, provider: "ollama"}, {index: 2, provider: "copilot"}}
+ filtered, err := filterJobsBySelection(jobs, []int{2, 0})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(filtered) != 2 || filtered[0].provider != "copilot" || filtered[1].provider != "openai" {
+ t.Fatalf("unexpected filtered order: %+v", filtered)
+ }
+ if filtered[0].index != 0 || filtered[1].index != 1 {
+ t.Fatalf("expected reindexed jobs, got %+v", filtered)
+ }
+ if _, err := filterJobsBySelection(jobs, []int{5}); err == nil {
+ t.Fatalf("expected out-of-range error")
+ }
+}
+
func TestNewClientFromConfig_Ollama(t *testing.T) {
cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"}
c, err := newClientFromConfig(cfg)