summaryrefslogtreecommitdiff
path: root/internal/hexaicli/run.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli/run.go')
-rw-r--r--internal/hexaicli/run.go114
1 files changed, 86 insertions, 28 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index b48bee0..4cd94b4 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -24,15 +24,17 @@ import (
)
type requestArgs struct {
- model string
- options []llm.RequestOption
+ model string
+ maxTokens int
+ temperature *float64
+ options []llm.RequestOption
}
type cliJob struct {
index int
provider string
entry appconfig.SurfaceConfig
- client llm.Client
+ cfg appconfig.App
req requestArgs
}
@@ -55,40 +57,33 @@ func buildCLIJobs(cfg appconfig.App) ([]cliJob, error) {
}
provider = canonicalProvider(provider)
derived := llmutils.ConfigForProvider(cfg, provider, entry.Model)
- client, err := newClientFromApp(derived)
- if err != nil {
- return nil, err
- }
- req := buildCLIRequest(entry, provider, cfg, client)
- if strings.TrimSpace(req.model) == "" {
- req.model = strings.TrimSpace(client.DefaultModel())
- }
- jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, client: client, req: req})
+ req := buildCLIRequest(entry, provider, derived)
+ jobs = append(jobs, cliJob{index: i, provider: provider, entry: entry, cfg: derived, req: req})
}
return jobs, nil
}
-func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App, client llm.Client) requestArgs {
+func buildCLIRequest(entry appconfig.SurfaceConfig, provider string, cfg appconfig.App) requestArgs {
opts := make([]llm.RequestOption, 0, 2)
+ req := requestArgs{maxTokens: cfg.MaxTokens}
if cfg.MaxTokens > 0 {
opts = append(opts, llm.WithMaxTokens(cfg.MaxTokens))
}
model := strings.TrimSpace(entry.Model)
if model == "" {
- if client != nil {
- model = strings.TrimSpace(client.DefaultModel())
- }
- if model == "" {
- model = strings.TrimSpace(llmutils.DefaultModelForProvider(cfg, provider))
- }
+ model = strings.TrimSpace(llmutils.DefaultModelForProvider(cfg, provider))
}
if entry.Model != "" {
opts = append(opts, llm.WithModel(entry.Model))
}
if temp, ok := cliTemperatureFromEntry(cfg, provider, entry, model); ok {
+ tempValue := temp
+ req.temperature = &tempValue
opts = append(opts, llm.WithTemperature(temp))
}
- return requestArgs{model: model, options: opts}
+ req.model = model
+ req.options = opts
+ return req
}
func cliTemperatureFromEntry(cfg appconfig.App, provider string, entry appconfig.SurfaceConfig, model string) (float64, bool) {
@@ -240,28 +235,72 @@ func setupCLIPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter {
}
func runSingleCLIJob(ctx context.Context, job cliJob, msgs []llm.Message, input string, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+ if res := cachedCLIJobResult(job, msgs, stdout, printer, streamOutput); res != nil {
+ return res
+ }
+
+ client, err := newClientFromApp(job.cfg)
+ if err != nil {
+ return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
+ }
+ model := effectiveModel(job.req, client)
+
var errBuf bytes.Buffer
var outBuf bytes.Buffer
jobMsgs := append([]llm.Message(nil), msgs...)
writer := io.Writer(&outBuf)
if printer != nil {
- writer = printer.Writer(job.index)
+ writer = io.MultiWriter(printer.Writer(job.index), &outBuf)
} else if streamOutput {
writer = io.MultiWriter(stdout, &outBuf)
}
- err := runChat(ctx, job.client, job.req, jobMsgs, input, writer, &errBuf)
+ err = runChat(ctx, client, job.req, jobMsgs, input, writer, &errBuf)
if printer != nil {
printer.Flush(job.index)
}
+ if err == nil {
+ storeCLIResponseCache(newCLIResponseCacheKey(job.provider, model, job.req, jobMsgs), outBuf.String())
+ }
return &cliJobResult{
- provider: job.client.Name(),
- model: job.req.model,
+ provider: job.provider,
+ model: model,
output: outBuf.String(),
summary: errBuf.String(),
err: err,
}
}
+func cachedCLIJobResult(job cliJob, msgs []llm.Message, stdout io.Writer, printer *termprint.ColumnPrinter, streamOutput bool) *cliJobResult {
+ output, age, ok := lookupCLIResponseCache(newCLIResponseCacheKey(job.provider, job.req.model, job.req, msgs))
+ if !ok {
+ return nil
+ }
+ if err := writeCachedCLIJobOutput(output, stdout, printer, job.index, streamOutput); err != nil {
+ return &cliJobResult{provider: job.provider, model: job.req.model, err: err}
+ }
+ return &cliJobResult{
+ provider: job.provider,
+ model: job.req.model,
+ output: output,
+ summary: cacheHitSummary(job.provider, job.req.model, age),
+ }
+}
+
+func writeCachedCLIJobOutput(output string, stdout io.Writer, printer *termprint.ColumnPrinter, idx int, streamOutput bool) error {
+ if printer != nil {
+ if _, err := io.WriteString(printer.Writer(idx), output); err != nil {
+ return err
+ }
+ printer.Flush(idx)
+ return nil
+ }
+ if !streamOutput {
+ return nil
+ }
+ _, err := io.WriteString(stdout, output)
+ return err
+}
+
func writeCLIJobOutputs(stdout io.Writer, results []*cliJobResult) error {
printed := false
showHeading := cliJobResultCount(results) > 1
@@ -346,7 +385,7 @@ func newColumnPrinter(stdout io.Writer, jobs []cliJob) *termprint.ColumnPrinter
providers := make([]string, len(jobs))
models := make([]string, len(jobs))
for _, job := range jobs {
- providers[job.index] = job.client.Name()
+ providers[job.index] = job.provider
models[job.index] = job.req.model
}
return termprint.NewColumnPrinter(stdout, providers, models)
@@ -361,7 +400,7 @@ func printCLIHeader(stderr io.Writer, jobs []cliJob, printer *termprint.ColumnPr
return
}
job := jobs[0]
- printProviderInfo(stderr, job.client, job.req.model)
+ printProviderLabel(stderr, job.provider, job.req.model)
}
// WithCLISelection injects provider indices into the context so Run only executes those jobs.
@@ -577,16 +616,35 @@ func summarizeChatRun(ctx context.Context, client llm.Client, model string, msgs
// printProviderInfo writes the provider:model header and divider to stderr.
func printProviderInfo(errw io.Writer, client llm.Client, model string) {
+ printProviderLabel(errw, client.Name(), chooseCLIModel(model, client.DefaultModel()))
+}
+
+func printProviderLabel(errw io.Writer, provider, model string) {
if strings.TrimSpace(model) == "" {
- model = client.DefaultModel()
+ return
}
- printer := termprint.NewColumnPrinter(errw, []string{client.Name()}, []string{model})
+ printer := termprint.NewColumnPrinter(errw, []string{provider}, []string{model})
if printer == nil {
return
}
printer.PrintHeader()
}
+func chooseCLIModel(model, fallback string) string {
+ model = strings.TrimSpace(model)
+ if model != "" {
+ return model
+ }
+ return strings.TrimSpace(fallback)
+}
+
+func cacheHitSummary(provider, model string, age time.Duration) string {
+ if age < 0 {
+ age = 0
+ }
+ return fmt.Sprintf(logging.AnsiBase+"cache hit provider=%s model=%s age=%s"+logging.AnsiReset+"\n", provider, model, age.Round(time.Second))
+}
+
// newClientFromConfig is kept for tests; delegates to llmutils.
var newClientFromApp = llmutils.NewClientFromApp