diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-28 00:20:05 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-28 00:20:05 +0300 |
| commit | 0ac2d186e84f77d73d924e2c0ce975a17c3a8078 (patch) | |
| tree | 49f3e2def38449544e1d67f047cbcb4aab802658 /internal/hexaicli | |
| parent | 51b2621d58633aa5c0f5cc7b64616d70d41acc91 (diff) | |
Improve multi-provider completion streaming and CLI selector flags
Diffstat (limited to 'internal/hexaicli')
| -rw-r--r-- | internal/hexaicli/run.go | 316 | ||||
| -rw-r--r-- | internal/hexaicli/run_model_override_test.go | 54 | ||||
| -rw-r--r-- | internal/hexaicli/run_test.go | 17 |
3 files changed, 338 insertions, 49 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go index 06fcb83..b7745c8 100644 --- a/internal/hexaicli/run.go +++ b/internal/hexaicli/run.go @@ -20,6 +20,8 @@ import ( "codeberg.org/snonux/hexai/internal/logging" "codeberg.org/snonux/hexai/internal/stats" "codeberg.org/snonux/hexai/internal/tmux" + "github.com/mattn/go-runewidth" + "golang.org/x/term" ) type requestArgs struct { @@ -35,6 +37,23 @@ type cliJob struct { req requestArgs } +type columnPrinter struct { + mu sync.Mutex + stdout io.Writer + columns int + colWidth int + partial []string + providers []string + models []string +} + +type columnWriter struct { + printer *columnPrinter + index int +} + +type selectionContextKey struct{} + func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) { entries := cfg.CLIConfigs if len(entries) == 0 { @@ -150,6 +169,13 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } + if selected := selectionFromContext(ctx); len(selected) > 0 { + jobs, err = filterJobsBySelection(jobs, selected) + if err != nil { + fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err) + return err + } + } if len(jobs) == 0 { return fmt.Errorf("hexai: no CLI providers configured") } @@ -203,16 +229,29 @@ type cliJobResult struct { func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error { results := make([]*cliJobResult, len(jobs)) var wg sync.WaitGroup + var printer *columnPrinter + if len(jobs) > 0 { + printer = newColumnPrinter(stdout, jobs) + printer.PrintHeader() + } for _, job := range jobs { job := job wg.Add(1) printProviderInfo(stderr, job.client, job.req.model) go func() { defer wg.Done() - var outBuf, errBuf bytes.Buffer + var errBuf bytes.Buffer + var outBuf bytes.Buffer jobMsgs := make([]llm.Message, len(msgs)) copy(jobMsgs, msgs) - err := runChat(ctx, job.client, job.req, jobMsgs, input, &outBuf, &errBuf) + writer := io.Writer(&outBuf) + if printer != nil { + writer = printer.Writer(job.index) + } + err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf) + if printer != nil { + printer.Flush(job.index) + } results[job.index] = &cliJobResult{ provider: job.client.Name(), model: job.req.model, @@ -224,48 +263,275 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st } wg.Wait() var firstErr error - printed := false - for _, res := range results { - if res == nil { - continue - } - if printed { - if _, err := io.WriteString(stdout, "\n"); err != nil { - return err + if printer == nil { + printed := false + for _, res := range results { + if res == nil { + continue } - } - heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) - if _, err := io.WriteString(stdout, heading); err != nil { - return err - } - if res.output != "" { - if _, err := io.WriteString(stdout, res.output); err != nil { + if printed { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } + heading := fmt.Sprintf("=== %s:%s ===\n", res.provider, res.model) + if _, err := io.WriteString(stdout, heading); err != nil { return err } - if !strings.HasSuffix(res.output, "\n") { - if _, err := io.WriteString(stdout, "\n"); err != nil { + if res.output != "" { + if _, err := io.WriteString(stdout, res.output); err != nil { return err } + if !strings.HasSuffix(res.output, "\n") { + if _, err := io.WriteString(stdout, "\n"); err != nil { + return err + } + } } + printed = true + } + } + for _, res := range results { + if res == nil { + continue } - printed = true if res.summary != "" { - if _, err := io.WriteString(stderr, res.summary); err != nil { - return err + summary := strings.TrimLeft(res.summary, "\n") + if summary != "" { + if _, err := io.WriteString(stderr, summary); err != nil { + return err + } } } if res.err != nil { if _, err := fmt.Fprintf(stderr, logging.AnsiBase+"hexai: provider=%s model=%s error: %v"+logging.AnsiReset+"\n", res.provider, res.model, res.err); err != nil { return err } - if firstErr == nil { - firstErr = res.err - } + } + if firstErr == nil && res.err != nil { + firstErr = res.err } } return firstErr } +func newColumnPrinter(stdout io.Writer, jobs []cliJob) *columnPrinter { + cols := len(jobs) + width := detectTerminalWidth(stdout) + if width <= 0 { + width = 100 + } + sepWidth := (cols - 1) * 3 + colWidth := (width - sepWidth) / cols + if colWidth < 20 { + colWidth = 20 + } + providers := make([]string, cols) + models := make([]string, cols) + for _, job := range jobs { + providers[job.index] = job.client.Name() + models[job.index] = job.req.model + } + return &columnPrinter{ + stdout: stdout, + columns: cols, + colWidth: colWidth, + partial: make([]string, cols), + providers: providers, + models: models, + } +} + +func detectTerminalWidth(w io.Writer) int { + type fder interface{ Fd() uintptr } + if f, ok := w.(*os.File); ok { + if width, _, err := term.GetSize(int(f.Fd())); err == nil { + return width + } + } + if f, ok := w.(fder); ok { + if width, _, err := term.GetSize(int(f.Fd())); err == nil { + return width + } + } + return 0 +} + +func (cp *columnPrinter) Writer(idx int) io.Writer { + return columnWriter{printer: cp, index: idx} +} + +func (cp *columnPrinter) PrintHeader() { + cp.mu.Lock() + defer cp.mu.Unlock() + combo := make([]string, cp.columns) + for i := 0; i < cp.columns; i++ { + provider := strings.TrimSpace(cp.providers[i]) + model := strings.TrimSpace(cp.models[i]) + switch { + case provider != "" && model != "": + combo[i] = provider + ":" + model + case provider != "": + combo[i] = provider + case model != "": + combo[i] = model + default: + combo[i] = "" + } + } + cp.writeLine(combo) + divider := make([]string, cp.columns) + line := strings.Repeat("─", cp.colWidth) + for i := range divider { + divider[i] = line + } + cp.writeLine(divider) +} + +func (cp *columnPrinter) Flush(idx int) { + cp.mu.Lock() + defer cp.mu.Unlock() + if idx < 0 || idx >= len(cp.partial) { + return + } + if cp.partial[idx] == "" { + return + } + cp.emitJobLine(idx, cp.partial[idx]) + cp.partial[idx] = "" +} + +func (w columnWriter) Write(p []byte) (int, error) { + return w.printer.write(w.index, string(p)) +} + +func (cp *columnPrinter) write(idx int, data string) (int, error) { + cp.mu.Lock() + defer cp.mu.Unlock() + if idx < 0 || idx >= len(cp.partial) { + return len(data), nil + } + data = strings.ReplaceAll(data, "\r", "") + cp.partial[idx] += data + for strings.Contains(cp.partial[idx], "\n") { + line, rest, _ := strings.Cut(cp.partial[idx], "\n") + cp.partial[idx] = rest + cp.emitJobLine(idx, line) + } + return len(data), nil +} + +func (cp *columnPrinter) emitJobLine(idx int, line string) { + segments := cp.wrap(line) + for _, seg := range segments { + cells := make([]string, cp.columns) + if idx >= 0 && idx < len(cells) { + cells[idx] = seg + } + cp.writeLine(cells) + } +} + +func (cp *columnPrinter) wrap(text string) []string { + text = strings.ReplaceAll(text, "\t", " ") + if runewidth.StringWidth(text) <= cp.colWidth { + return []string{text} + } + var lines []string + var current strings.Builder + width := 0 + for _, r := range text { + rw := runewidth.RuneWidth(r) + if width+rw > cp.colWidth && current.Len() > 0 { + lines = append(lines, current.String()) + current.Reset() + width = 0 + } + current.WriteRune(r) + width += rw + } + if current.Len() > 0 { + lines = append(lines, current.String()) + } + if len(lines) == 0 { + lines = append(lines, "") + } + return lines +} + +func (cp *columnPrinter) writeLine(cells []string) { + if len(cells) < cp.columns { + extra := make([]string, cp.columns-len(cells)) + cells = append(cells, extra...) + } + var builder strings.Builder + for i := 0; i < cp.columns; i++ { + cell := cells[i] + width := runewidth.StringWidth(cell) + if width > cp.colWidth { + cell = runewidth.Truncate(cell, cp.colWidth, "…") + width = runewidth.StringWidth(cell) + } + builder.WriteString(cell) + if pad := cp.colWidth - width; pad > 0 { + builder.WriteString(strings.Repeat(" ", pad)) + } + if i != cp.columns-1 { + builder.WriteString(" │ ") + } + } + builder.WriteByte('\n') + _, _ = cp.stdout.Write([]byte(builder.String())) +} + +// WithCLISelection injects provider indices into the context so Run only executes those jobs. +func WithCLISelection(ctx context.Context, indices []int) context.Context { + if ctx == nil { + ctx = context.Background() + } + cpy := make([]int, len(indices)) + copy(cpy, indices) + return context.WithValue(ctx, selectionContextKey{}, cpy) +} + +func selectionFromContext(ctx context.Context) []int { + if ctx == nil { + return nil + } + if v, ok := ctx.Value(selectionContextKey{}).([]int); ok { + cpy := make([]int, len(v)) + copy(cpy, v) + return cpy + } + return nil +} + +func filterJobsBySelection(jobs []cliJob, indices []int) ([]cliJob, error) { + if len(indices) == 0 { + return jobs, nil + } + filtered := make([]cliJob, 0, len(indices)) + seen := make(map[int]struct{}, len(indices)) + for _, idx := range indices { + if idx < 0 || idx >= len(jobs) { + return nil, fmt.Errorf("provider index %d out of range (0-%d)", idx, len(jobs)-1) + } + if _, ok := seen[idx]; ok { + continue + } + clone := jobs[idx] + filtered = append(filtered, clone) + seen[idx] = struct{}{} + } + for i := range filtered { + filtered[i].index = i + } + if len(filtered) == 0 { + return nil, fmt.Errorf("no CLI providers matched selection") + } + return filtered, nil +} + // readInput reads from stdin and args, then combines them per CLI rules. func readInput(stdin io.Reader, args []string) (string, error) { var stdinData string diff --git a/internal/hexaicli/run_model_override_test.go b/internal/hexaicli/run_model_override_test.go index 6394bd1..b32b172 100644 --- a/internal/hexaicli/run_model_override_test.go +++ b/internal/hexaicli/run_model_override_test.go @@ -1,39 +1,45 @@ package hexaicli import ( - "bytes" - "context" - "strings" - "testing" + "bytes" + "context" + "strings" + "testing" - "codeberg.org/snonux/hexai/internal/appconfig" - "codeberg.org/snonux/hexai/internal/llm" + "codeberg.org/snonux/hexai/internal/appconfig" + "codeberg.org/snonux/hexai/internal/llm" ) type fakeClientModelEnv struct{ name, model string } -func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { return "ok", nil } + +func (f fakeClientModelEnv) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { + return "ok", nil +} func (f fakeClientModelEnv) Name() string { return f.name } func (f fakeClientModelEnv) DefaultModel() string { return f.model } // Ensure that HEXAI_MODEL overrides config for CLI runs. func TestRun_ModelEnvOverride_FlowsIntoClient(t *testing.T) { - t.Setenv("HEXAI_MODEL", "gpt-5-codex") - t.Setenv("HEXAI_PROVIDER", "openai") - // Replace client constructor to assert model was overridden - oldNew := newClientFromApp - defer func() { newClientFromApp = oldNew }() + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("HEXAI_MODEL", "gpt-5-codex") + t.Setenv("HEXAI_PROVIDER", "openai") + // Replace client constructor to assert model was overridden + oldNew := newClientFromApp + defer func() { newClientFromApp = oldNew }() + var seenModel string newClientFromApp = func(cfg appconfig.App) (llm.Client, error) { - if strings.TrimSpace(cfg.OpenAIModel) != "gpt-5-codex" { - t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", cfg.OpenAIModel) - } - return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil - } + seenModel = strings.TrimSpace(cfg.OpenAIModel) + return fakeClientModelEnv{name: "openai", model: cfg.OpenAIModel}, nil + } - var out, errb bytes.Buffer - if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil { - t.Fatalf("run error: %v", err) - } - if !strings.Contains(errb.String(), "model=gpt-5-codex") { - t.Fatalf("stderr should print effective model, got: %s", errb.String()) - } + var out, errb bytes.Buffer + if err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb); err != nil { + t.Fatalf("run error: %v", err) + } + if seenModel != "gpt-5-codex" { + t.Fatalf("expected cfg.OpenAIModel=gpt-5-codex, got %q", seenModel) + } + if !strings.Contains(errb.String(), "model=gpt-5-codex") { + t.Fatalf("stderr should print effective model, got: %s", errb.String()) + } } diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go index f11545e..dfde068 100644 --- a/internal/hexaicli/run_test.go +++ b/internal/hexaicli/run_test.go @@ -225,6 +225,23 @@ func TestBuildCLIJobs_MultiEntries(t *testing.T) { } } +func TestFilterJobsBySelection(t *testing.T) { + jobs := []cliJob{{index: 0, provider: "openai"}, {index: 1, provider: "ollama"}, {index: 2, provider: "copilot"}} + filtered, err := filterJobsBySelection(jobs, []int{2, 0}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(filtered) != 2 || filtered[0].provider != "copilot" || filtered[1].provider != "openai" { + t.Fatalf("unexpected filtered order: %+v", filtered) + } + if filtered[0].index != 0 || filtered[1].index != 1 { + t.Fatalf("expected reindexed jobs, got %+v", filtered) + } + if _, err := filterJobsBySelection(jobs, []int{5}); err == nil { + t.Fatalf("expected out-of-range error") + } +} + func TestNewClientFromConfig_Ollama(t *testing.T) { cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"} c, err := newClientFromConfig(cfg) |
