summaryrefslogtreecommitdiff
path: root/internal/hexaicli/run.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli/run.go')
-rw-r--r--internal/hexaicli/run.go97
1 files changed, 19 insertions, 78 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 9c7ba73..0da1a5f 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -7,20 +7,17 @@ import (
"context"
"fmt"
"io"
- "log"
"os"
"strings"
"sync"
"time"
"codeberg.org/snonux/hexai/internal/appconfig"
- "codeberg.org/snonux/hexai/internal/editor"
"codeberg.org/snonux/hexai/internal/llm"
"codeberg.org/snonux/hexai/internal/llmutils"
"codeberg.org/snonux/hexai/internal/logging"
"codeberg.org/snonux/hexai/internal/stats"
"codeberg.org/snonux/hexai/internal/termprint"
- "codeberg.org/snonux/hexai/internal/tmux"
)
type requestArgs struct {
@@ -95,77 +92,13 @@ func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig
// Run executes the Hexai CLI behavior given arguments and I/O streams.
// It assumes flags have already been parsed by the caller.
func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
- if spec, ok, err := tpsSimulationFromContext(ctx); err != nil {
- _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset)
- return err
- } else if ok {
- input, inputErr := readSimulationInput(stdin, args)
- if inputErr != nil {
- _, _ = fmt.Fprintln(stderr, logging.AnsiBase+inputErr.Error()+logging.AnsiReset)
- return inputErr
- }
- return runTPSSimulation(ctx, spec, input, stdout)
- }
-
- // Load configuration silently; config-load messages are noise in the CLI.
- logger := log.New(io.Discard, "", 0)
- configPath := configPathFromContext(ctx)
- cfg := appconfig.LoadWithOptions(logger, appconfig.LoadOptions{ConfigPath: configPath})
- if cfg.StatsWindowMinutes > 0 {
- stats.SetWindow(time.Duration(cfg.StatsWindowMinutes) * time.Minute)
- }
- jobs, err := buildCLIJobs(cfg)
- if err != nil {
- _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
- return err
- }
- if selected := selectionFromContext(ctx); len(selected) > 0 {
- jobs, err = filterJobsBySelection(jobs, selected)
- if err != nil {
- _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: %v"+logging.AnsiReset+"\n", err)
- return err
- }
- }
- if len(jobs) == 0 {
- return fmt.Errorf("hexai: no CLI providers configured")
- }
- // Prefer piped stdin when present; only open the editor when there are no args
- // and no stdin content available.
- input, rerr := readInput(stdin, args)
- if rerr != nil && len(args) == 0 {
- if prompt, eerr := editor.OpenTempAndEdit(nil); eerr == nil && strings.TrimSpace(prompt) != "" {
- args = []string{prompt}
- input, rerr = readInput(stdin, args)
- }
- }
- if rerr != nil {
- _, _ = fmt.Fprintln(stderr, logging.AnsiBase+rerr.Error()+logging.AnsiReset)
- return rerr
- }
- msgs := buildMessagesFromConfig(cfg, input)
- if err := runCLIJobs(ctx, jobs, msgs, input, stdout, stderr); err != nil {
- _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
- return err
- }
- return nil
+ return NewRunner().Run(ctx, args, stdin, stdout, stderr)
}
// RunWithClient executes the CLI flow using an already-constructed client.
// Useful for testing and embedding.
func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error {
- input, err := readInput(stdin, args)
- if err != nil {
- _, _ = fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset)
- return err
- }
- req := requestArgs{model: strings.TrimSpace(client.DefaultModel())}
- printProviderInfo(stderr, client, req.model)
- msgs := buildMessages(input)
- if err := runChat(ctx, client, req, msgs, input, stdout, stderr); err != nil {
- _, _ = fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err)
- return err
- }
- return nil
+ return NewRunner().RunWithClient(ctx, args, stdin, stdout, stderr, client)
}
type cliJobResult struct {
@@ -184,9 +117,9 @@ type chatRunSummary struct {
scopeRPM float64
}
-func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer) error {
+func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout, stderr io.Writer, clientFactory cliClientFactory, statusSink cliStatusSink) error {
streamSingle := len(jobs) == 1
- results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle)
+ results, printer := executeCLIJobs(ctx, jobs, msgs, input, stdout, stderr, streamSingle, clientFactory, statusSink)
if printer == nil && !streamSingle {
if err := writeCLIJobOutputs(stdout, results); err != nil {
return err
@@ -195,7 +128,7 @@ func runCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input st
return writeCLIJobSummaries(stderr, results)
}
-func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool) ([]*cliJobResult, *termprint.ColumnPrinter) {
+func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, input string, stdout io.Writer, stderr io.Writer, streamSingle bool, clientFactory cliClientFactory, statusSink cliStatusSink) ([]*cliJobResult, *termprint.ColumnPrinter) {
results := make([]*cliJobResult, len(jobs))
printer := setupCLIPrinter(stdout, jobs)
printCLIHeader(stderr, jobs, printer)
@@ -205,7 +138,7 @@ func executeCLIJobs(ctx context.Context, jobs []cliJob, msgs []llm.Message, inpu
wg.Add(1)
go func() {
defer wg.Done()
- results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle)
+ results[job.index] = runSingleCLIJob(ctx, job, msgs, input, stdout, printer, streamSingle, clientFactory, statusSink)
}()
}
wg.Wait()
@@ -219,12 +152,12 @@ func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter {
return newColumnPrinter(stdout, jobs)
}
-func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool, clientFactory cliClientFactory, statusSink cliStatusSink) *cliJobResult {
if res := cachedCLIJobResult(job, msgs, stdout, printer, streamOutput); res != nil {
return res
}
- client, err := newClientFromApp(job.cfg)
+ client, err := clientFactory(job.cfg)
if err != nil {
return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
}
@@ -239,7 +172,7 @@ func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input
} else if streamOutput {
writer = io.MultiWriter(stdout, &outBuf)
}
- err = runChat(ctx, client, job.req, jobMsgs, input, writer, &errBuf)
+ err = runChatWithStatus(statusSink, ctx, client, job.req, jobMsgs, input, writer, &errBuf)
if printer != nil {
printer.Flush(job.index)
}
@@ -532,9 +465,15 @@ func buildMessagesFromConfig(cfg appconfig.App, input string) []llm.Message {
// runChat executes the chat request, handling streaming and summary output.
func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
+ return runChatWithStatus(tmuxCLIStatusSink{}, ctx, client, req, msgs, input, out, errw)
+}
+
+func runChatWithStatus(statusSink cliStatusSink, ctx context.Context, client llm.Client, req requestArgs, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error {
start := time.Now()
model := effectiveModel(req, client)
- _ = tmux.SetStatus(tmux.FormatLLMStartStatus(client.Name(), model))
+ if statusSink != nil {
+ _ = statusSink.SetLLMStart(client.Name(), model)
+ }
output, err := runChatRequest(ctx, client, req, msgs, out)
if err != nil {
@@ -547,7 +486,9 @@ func runChat(ctx context.Context, client llm.Client, req requestArgs, msgs []llm
client.Name(), model, dur.Round(time.Millisecond), summary.sent, summary.recv, summary.snapshot.Global.Reqs, summary.snapshot.RPM); err != nil {
return err
}
- _ = tmux.SetStatus(tmux.FormatGlobalStatusColored(summary.snapshot.Global.Reqs, summary.snapshot.RPM, summary.snapshot.Global.Sent, summary.snapshot.Global.Recv, client.Name(), model, summary.scopeRPM, summary.scopeReq, summary.snapshot.Window))
+ if statusSink != nil {
+ _ = statusSink.SetGlobal(summary.snapshot, client.Name(), model, summary.scopeRPM, summary.scopeReq)
+ }
return nil
}