diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-15 23:24:00 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-15 23:24:00 +0200 |
| commit | 8ec8ee16e23081018e32dea122ecd9a3b8d8b2c7 (patch) | |
| tree | 5a564bb36fc9750d3353435d2dd3cf2f28fa5261 /internal/hexaicli/run.go | |
| parent | 10112d4b7a8150118e705b95df73c08824ac2b22 (diff) | |
Release v0.23.0 (tag: v0.23.0)
Diffstat (limited to 'internal/hexaicli/run.go')
| -rw-r--r-- | internal/hexaicli/run.go | 114 |
1 file changed, 86 insertions, 28 deletions
```diff
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index b48bee0..4cd94b4 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -24,15 +24,17 @@ import (
 )
 
 type requestArgs struct {
-	model   string
-	options []llm.RequestOption
+	model       string
+	maxTokens   int
+	temperature *float64
+	options     []llm.RequestOption
 }
 
 type cliJob struct {
 	index    int
 	provider string
 	entry    appconfig.SurfaceConfig
-	client   llm.Client
+	cfg      appconfig.App
 	req      requestArgs
 }
 
@@ -55,40 +57,33 @@ func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) {
 		}
 		provider = canonicalProvider(provider)
 		derived := llmutils.ConfigForProvider(cfg, provider, entry.Model)
-		client, err := newClientFromApp(derived)
-		if err != nil {
-			return nil, err
-		}
-		req := buildCLIRequest(entry, provider, cfg, client)
-		if strings.TrimSpace(req.model) == "" {
-			req.model = strings.TrimSpace(client.DefaultModel())
-		}
-		jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, client: client, req: req})
+		req := buildCLIRequest(entry, provider, derived)
+		jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, cfg: derived, req: req})
 	}
 	return jobs, nil
 }
 
-func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App, client llm.Client) requestArgs {
+func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App) requestArgs {
 	opts := make([]llm.RequestOption, 0, 2)
+	req := requestArgs{maxTokens: cfg.MaxTokens}
 	if cfg.MaxTokens > 0 {
 		opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens))
 	}
 	model := strings.TrimSpace(entry.Model)
 	if model == "" {
-		if client != nil {
-			model = strings.TrimSpace(client.DefaultModel())
-		}
-		if model == "" {
-			model = strings.TrimSpace(llmutils.DefaultModelForProvider(cfg, provider))
-		}
+		model = strings.TrimSpace(llmutils.DefaultModelForProvider(cfg, provider))
 	}
 	if entry.Model != "" {
 		opts = append(opts, llm.WithModel(entry.Model))
 	}
 	if temp, ok := cliTemperatureFromEntry(cfg, provider, entry, model); ok {
+		tempValue := temp
+		req.temperature = &tempValue
 		opts = append(opts, llm.WithTemperature(temp))
 	}
-	return requestArgs{model: model, options: opts}
+	req.model = model
+	req.options = opts
+	return req
 }
 
 func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) {
@@ -240,28 +235,72 @@ func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter {
 }
 
 func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+	if res := cachedCLIJobResult(job, msgs, stdout, printer, streamOutput); res != nil {
+		return res
+	}
+
+	client, err := newClientFromApp(job.cfg)
+	if err != nil {
+		return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
+	}
+	model := effectiveModel(job.req, client)
+
 	var errBuf bytes.Buffer
 	var outBuf bytes.Buffer
 	jobMsgs := append([]llm.Message(nil), msgs...)
 	writer := io.Writer(&outBuf)
 	if printer != nil {
-		writer = printer.Writer(job.index)
+		writer = io.MultiWriter(printer.Writer(job.index), &outBuf)
 	} else if streamOutput {
 		writer = io.MultiWriter(stdout, &outBuf)
 	}
-	err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
+	err = runChat(ctx, client, job.req, jobMsgs, input, writer, &errBuf)
 	if printer != nil {
 		printer.Flush(job.index)
 	}
+	if err == nil {
+		storeCLIResponseCache(newCLIResponseCacheKey(job.provider, model, job.req, jobMsgs), outBuf.String())
+	}
 	return &cliJobResult{
-		provider: job.client.Name(),
-		model:    job.req.model,
+		provider: job.provider,
+		model:    model,
 		output:   outBuf.String(),
 		summary:  errBuf.String(),
 		err:      err,
 	}
 }
 
+func cachedCLIJobResult(job cliJob, msgs []llm.Message, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+	output, age, ok := lookupCLIResponseCache(newCLIResponseCacheKey(job.provider, job.req.model, job.req, msgs))
+	if !ok {
+		return nil
+	}
+	if err := writeCachedCLIJobOutput(output, stdout, printer, job.index, streamOutput); err != nil {
+		return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
+	}
+	return &cliJobResult{
+		provider: job.provider,
+		model:    job.req.model,
+		output:   output,
+		summary:  cacheHitSummary(job.provider, job.req.model, age),
+	}
+}
+
+func writeCachedCLIJobOutput(output string, stdout io.Writer, printer *termprint.ColumnPrinter, idx int, streamOutput bool) error {
+	if printer != nil {
+		if _, err := io.WriteString(printer.Writer(idx), output); err != nil {
+			return err
+		}
+		printer.Flush(idx)
+		return nil
+	}
+	if !streamOutput {
+		return nil
+	}
+	_, err := io.WriteString(stdout, output)
+	return err
+}
+
 func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
 	printed := false
 	showHeading := cliJobResultCount(results) > 1
@@ -346,7 +385,7 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter
 	providers := make([]string, len(jobs))
 	models := make([]string, len(jobs))
 	for _, job := range jobs {
-		providers[job.index] = job.client.Name()
+		providers[job.index] = job.provider
 		models[job.index] = job.req.model
 	}
 	return termprint.NewColumnPrinter(stdout, providers, models)
@@ -361,7 +400,7 @@ func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPr
 		return
 	}
 	job := jobs[0]
-	printProviderInfo(stderr, job.client, job.req.model)
+	printProviderLabel(stderr, job.provider, job.req.model)
 }
 
 // WithCLISelection injects provider indices into the context so Run only executes those jobs.
@@ -577,16 +616,35 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs
 
 // printProviderInfo writes the provider:model header and divider to stderr.
 func printProviderInfo(errw io.Writer, client llm.Client, model string) {
+	printProviderLabel(errw, client.Name(), chooseCLIModel(model, client.DefaultModel()))
+}
+
+func printProviderLabel(errw io.Writer, provider, model string) {
 	if strings.TrimSpace(model) == "" {
-		model = client.DefaultModel()
+		return
 	}
-	printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model})
+	printer := termprint.NewColumnPrinter(errw, []string{provider}, []string{model})
 	if printer == nil {
 		return
 	}
 	printer.PrintHeader()
 }
 
+func chooseCLIModel(model, fallback string) string {
+	model = strings.TrimSpace(model)
+	if model != "" {
+		return model
+	}
+	return strings.TrimSpace(fallback)
+}
+
+func cacheHitSummary(provider, model string, age time.Duration) string {
+	if age < 0 {
+		age = 0
+	}
+	return fmt.Sprintf(logging.AnsiBase+"cache hit provider=%s model=%s age=%s"+logging.AnsiReset+"\n", provider, model, age.Round(time.Second))
+}
+
 // newClientFromConfig is kept for tests; delegates to llmutils.
 var newClientFromApp = llmutils.NewClientFromApp
 
```
