summaryrefslogtreecommitdiff
path: root/internal/hexaicli
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2025-09-06 10:56:27 +0300
committerPaul Buetow <paul@buetow.org>2025-09-06 10:56:27 +0300
commit320de746f7a2985b60c8564a0e65bdf231e840b7 (patch)
treee70bcf50813dba411afa2934e774383124bbc99e /internal/hexaicli
parent06247527d5170f329b454b42f59a3e4434ab1f4b (diff)
use gofumpt
Diffstat (limited to 'internal/hexaicli')
-rw-r--r--internal/hexaicli/run.go78
-rw-r--r--internal/hexaicli/run_test.go196
-rw-r--r--internal/hexaicli/testhelpers_test.go40
3 files changed, 174 insertions, 140 deletions
diff --git a/internal/hexaicli/run.go b/internal/hexaicli/run.go
index 7471816..54cb3ff 100644
--- a/internal/hexaicli/run.go
+++ b/internal/hexaicli/run.go
@@ -3,14 +3,14 @@
package hexaicli
import (
- "bufio"
- "context"
- "fmt"
- "io"
- "log"
- "os"
- "strings"
- "time"
+ "bufio"
+ "context"
+ "fmt"
+ "io"
+ "log"
+ "os"
+ "strings"
+ "time"
"codeberg.org/snonux/hexai/internal/appconfig"
"codeberg.org/snonux/hexai/internal/llm"
@@ -20,14 +20,14 @@ import (
// Run executes the Hexai CLI behavior given arguments and I/O streams.
// It assumes flags have already been parsed by the caller.
func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error {
- // Load configuration with a logger so file-based config is respected.
- logger := log.New(stderr, "hexai ", log.LstdFlags|log.Lmsgprefix)
- cfg := appconfig.Load(logger)
- client, err := newClientFromConfig(cfg)
- if err != nil {
- fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
- return err
- }
+ // Load configuration with a logger so file-based config is respected.
+ logger := log.New(stderr, "hexai ", log.LstdFlags|log.Lmsgprefix)
+ cfg := appconfig.Load(logger)
+ client, err := newClientFromConfig(cfg)
+ if err != nil {
+ fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err)
+ return err
+ }
return RunWithClient(ctx, args, stdin, stdout, stderr, client)
}
@@ -71,29 +71,29 @@ func readInput(stdin io.Reader, args []string) (string, error) {
// newClientFromConfig builds an LLM client from the app config and env keys.
func newClientFromConfig(cfg appconfig.App) (llm.Client, error) {
- llmCfg := llm.Config{
- Provider: cfg.Provider,
- OpenAIBaseURL: cfg.OpenAIBaseURL,
- OpenAIModel: cfg.OpenAIModel,
- OpenAITemperature: cfg.OpenAITemperature,
- OllamaBaseURL: cfg.OllamaBaseURL,
- OllamaModel: cfg.OllamaModel,
- OllamaTemperature: cfg.OllamaTemperature,
- CopilotBaseURL: cfg.CopilotBaseURL,
- CopilotModel: cfg.CopilotModel,
- CopilotTemperature: cfg.CopilotTemperature,
- }
- // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY
- oaKey := os.Getenv("HEXAI_OPENAI_API_KEY")
- if strings.TrimSpace(oaKey) == "" {
- oaKey = os.Getenv("OPENAI_API_KEY")
- }
- // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY
- cpKey := os.Getenv("HEXAI_COPILOT_API_KEY")
- if strings.TrimSpace(cpKey) == "" {
- cpKey = os.Getenv("COPILOT_API_KEY")
- }
- return llm.NewFromConfig(llmCfg, oaKey, cpKey)
+ llmCfg := llm.Config{
+ Provider: cfg.Provider,
+ OpenAIBaseURL: cfg.OpenAIBaseURL,
+ OpenAIModel: cfg.OpenAIModel,
+ OpenAITemperature: cfg.OpenAITemperature,
+ OllamaBaseURL: cfg.OllamaBaseURL,
+ OllamaModel: cfg.OllamaModel,
+ OllamaTemperature: cfg.OllamaTemperature,
+ CopilotBaseURL: cfg.CopilotBaseURL,
+ CopilotModel: cfg.CopilotModel,
+ CopilotTemperature: cfg.CopilotTemperature,
+ }
+ // Prefer HEXAI_OPENAI_API_KEY; fall back to OPENAI_API_KEY
+ oaKey := os.Getenv("HEXAI_OPENAI_API_KEY")
+ if strings.TrimSpace(oaKey) == "" {
+ oaKey = os.Getenv("OPENAI_API_KEY")
+ }
+ // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY
+ cpKey := os.Getenv("HEXAI_COPILOT_API_KEY")
+ if strings.TrimSpace(cpKey) == "" {
+ cpKey = os.Getenv("COPILOT_API_KEY")
+ }
+ return llm.NewFromConfig(llmCfg, oaKey, cpKey)
}
// buildMessages creates system and user messages based on input content.
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
index 0d77e19..77daa8b 100644
--- a/internal/hexaicli/run_test.go
+++ b/internal/hexaicli/run_test.go
@@ -1,122 +1,150 @@
package hexaicli
import (
- "bytes"
- "context"
- "io"
- "path/filepath"
- "strings"
- "testing"
+ "bytes"
+ "context"
+ "io"
+ "path/filepath"
+ "strings"
+ "testing"
- "codeberg.org/snonux/hexai/internal/appconfig"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/appconfig"
+ "codeberg.org/snonux/hexai/internal/llm"
)
func TestReadInput_Combinations(t *testing.T) {
- // stdin + arg
- restore, f := setStdin(t, "from-stdin")
- defer restore()
- s, err := readInput(f, []string{"from-arg"})
- if err != nil || !strings.HasPrefix(s, "from-arg:\n\nfrom-stdin") { t.Fatalf("stdin+arg failed: %q %v", s, err) }
- // stdin only
- restore2, f2 := setStdin(t, "from-stdin")
- defer restore2()
- s, err = readInput(f2, nil)
- if err != nil || s != "from-stdin" { t.Fatalf("stdin only failed: %q %v", s, err) }
- // arg only
- s, err = readInput(strings.NewReader(""), []string{"arg1","arg2"})
- if err != nil || s != "arg1 arg2" { t.Fatalf("arg only failed: %q %v", s, err) }
- // no input
- restore3, f3 := setStdin(t, "")
- defer restore3()
- _, err = readInput(f3, nil)
- if err == nil { t.Fatalf("expected error for no input") }
+ // stdin + arg
+ restore, f := setStdin(t, "from-stdin")
+ defer restore()
+ s, err := readInput(f, []string{"from-arg"})
+ if err != nil || !strings.HasPrefix(s, "from-arg:\n\nfrom-stdin") {
+ t.Fatalf("stdin+arg failed: %q %v", s, err)
+ }
+ // stdin only
+ restore2, f2 := setStdin(t, "from-stdin")
+ defer restore2()
+ s, err = readInput(f2, nil)
+ if err != nil || s != "from-stdin" {
+ t.Fatalf("stdin only failed: %q %v", s, err)
+ }
+ // arg only
+ s, err = readInput(strings.NewReader(""), []string{"arg1", "arg2"})
+ if err != nil || s != "arg1 arg2" {
+ t.Fatalf("arg only failed: %q %v", s, err)
+ }
+ // no input
+ restore3, f3 := setStdin(t, "")
+ defer restore3()
+ _, err = readInput(f3, nil)
+ if err == nil {
+ t.Fatalf("expected error for no input")
+ }
}
func TestBuildMessages_Explain(t *testing.T) {
- msgs := buildMessages("please explain this")
- if len(msgs) != 2 || msgs[0].Role != "system" || !strings.Contains(strings.ToLower(msgs[0].Content), "explanation") {
- t.Fatalf("unexpected system prompt: %#v", msgs)
- }
+ msgs := buildMessages("please explain this")
+ if len(msgs) != 2 || msgs[0].Role != "system" || !strings.Contains(strings.ToLower(msgs[0].Content), "explanation") {
+ t.Fatalf("unexpected system prompt: %#v", msgs)
+ }
}
func TestBuildMessages_Default(t *testing.T) {
- msgs := buildMessages("just do it")
- if len(msgs) != 2 || msgs[0].Role != "system" || strings.Contains(msgs[0].Content, "requested an explanation") {
- t.Fatalf("unexpected system prompt: %#v", msgs)
- }
+ msgs := buildMessages("just do it")
+ if len(msgs) != 2 || msgs[0].Role != "system" || strings.Contains(msgs[0].Content, "requested an explanation") {
+ t.Fatalf("unexpected system prompt: %#v", msgs)
+ }
}
func TestRunChat_StreamAndNonStream(t *testing.T) {
- // stream path
- fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H","i","!"}}
- var out, errb bytes.Buffer
- if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("stream: %v", err) }
- if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") { t.Fatalf("bad output or summary: %q %q", out.String(), errb.String()) }
- // non-stream path
- fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"}
- out.Reset(); errb.Reset()
- if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil { t.Fatalf("non-stream: %v", err) }
- if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") { t.Fatalf("bad output or summary (non-stream)") }
+ // stream path
+ fc := &fakeStreamer{fakeClient: fakeClient{name: "p", model: "m"}, chunks: []string{"H", "i", "!"}}
+ var out, errb bytes.Buffer
+ if err := runChat(context.Background(), fc, buildMessages("hello"), "hello", &out, &errb); err != nil {
+ t.Fatalf("stream: %v", err)
+ }
+ if out.String() != "Hi!" || !strings.Contains(errb.String(), "provider=p model=m") {
+ t.Fatalf("bad output or summary: %q %q", out.String(), errb.String())
+ }
+ // non-stream path
+ fc2 := &fakeClient{name: "p2", model: "m2", resp: "Yo"}
+ out.Reset()
+ errb.Reset()
+ if err := runChat(context.Background(), fc2, buildMessages("hello"), "hello", &out, &errb); err != nil {
+ t.Fatalf("non-stream: %v", err)
+ }
+ if out.String() != "Yo" || !strings.Contains(errb.String(), "provider=p2 model=m2") {
+ t.Fatalf("bad output or summary (non-stream)")
+ }
}
type clientErr struct{ name, model string }
-func (c clientErr) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) { return "", io.EOF }
-func (c clientErr) Name() string { return c.name }
+
+func (c clientErr) Chat(context.Context, []llm.Message, ...llm.RequestOption) (string, error) {
+ return "", io.EOF
+}
+func (c clientErr) Name() string { return c.name }
func (c clientErr) DefaultModel() string { return c.model }
func TestRunChat_ErrorPaths(t *testing.T) {
- ctx := context.Background()
- out, errb := &bytes.Buffer{}, &bytes.Buffer{}
- if err := runChat(ctx, clientErr{"p","m"}, buildMessages("hi"), "hi", out, errb); err == nil {
- t.Fatalf("expected error from Chat")
- }
+ ctx := context.Background()
+ out, errb := &bytes.Buffer{}, &bytes.Buffer{}
+ if err := runChat(ctx, clientErr{"p", "m"}, buildMessages("hi"), "hi", out, errb); err == nil {
+ t.Fatalf("expected error from Chat")
+ }
}
func TestRunWithClient_ErrorPrint(t *testing.T) {
- var out, errb bytes.Buffer
- err := RunWithClient(context.Background(), []string{"hi"}, strings.NewReader(""), &out, &errb, clientErr{"p","m"})
- if err == nil { t.Fatalf("expected error") }
- if !strings.Contains(errb.String(), "hexai: error:") {
- t.Fatalf("expected error line, got %q", errb.String())
- }
+ var out, errb bytes.Buffer
+ err := RunWithClient(context.Background(), []string{"hi"}, strings.NewReader(""), &out, &errb, clientErr{"p", "m"})
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+ if !strings.Contains(errb.String(), "hexai: error:") {
+ t.Fatalf("expected error line, got %q", errb.String())
+ }
}
func TestRun_OpenAI_NoKey_ShowsError(t *testing.T) {
- dir := testingTempDir(t)
- // write config with provider=openai
- writeJSON(t, filepath.Join(dir, "hexai", "config.json"), map[string]any{"provider":"openai", "openai_model":"gpt-x"})
- t.Setenv("XDG_CONFIG_HOME", dir)
- // Ensure no OpenAI API key is present in environment
- t.Setenv("HEXAI_OPENAI_API_KEY", "")
- t.Setenv("OPENAI_API_KEY", "")
- var out, errb bytes.Buffer
- // Run expects parsed flags; here args irrelevant
- err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb)
- if err == nil { t.Fatalf("expected error due to missing API key") }
- // Accept either explicit "LLM disabled" or a generic provider error emitted by Run.
- if !(strings.Contains(errb.String(), "LLM disabled") || strings.Contains(errb.String(), "openai error") || strings.Contains(errb.String(), "hexai: error:")) {
- t.Fatalf("expected disabled-or-error message, got %q", errb.String())
- }
+ dir := testingTempDir(t)
+ // write config with provider=openai
+ writeJSON(t, filepath.Join(dir, "hexai", "config.json"), map[string]any{"provider": "openai", "openai_model": "gpt-x"})
+ t.Setenv("XDG_CONFIG_HOME", dir)
+ // Ensure no OpenAI API key is present in environment
+ t.Setenv("HEXAI_OPENAI_API_KEY", "")
+ t.Setenv("OPENAI_API_KEY", "")
+ var out, errb bytes.Buffer
+ // Run expects parsed flags; here args irrelevant
+ err := Run(context.Background(), []string{"hello"}, strings.NewReader(""), &out, &errb)
+ if err == nil {
+ t.Fatalf("expected error due to missing API key")
+ }
+ // Accept either explicit "LLM disabled" or a generic provider error emitted by Run.
+ if !(strings.Contains(errb.String(), "LLM disabled") || strings.Contains(errb.String(), "openai error") || strings.Contains(errb.String(), "hexai: error:")) {
+ t.Fatalf("expected disabled-or-error message, got %q", errb.String())
+ }
}
func TestPrintProviderInfo(t *testing.T) {
- var b bytes.Buffer
- printProviderInfo(&b, &fakeClient{name:"x", model:"y"})
- if !strings.Contains(b.String(), "provider=x model=y") { t.Fatalf("missing provider line: %q", b.String()) }
+ var b bytes.Buffer
+ printProviderInfo(&b, &fakeClient{name: "x", model: "y"})
+ if !strings.Contains(b.String(), "provider=x model=y") {
+ t.Fatalf("missing provider line: %q", b.String())
+ }
}
func TestNewClientFromConfig_Ollama(t *testing.T) {
- cfg := appconfig.App{ Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m" }
- c, err := newClientFromConfig(cfg)
- if err != nil || c == nil { t.Fatalf("expected client: %v %v", c, err) }
+ cfg := appconfig.App{Provider: "ollama", OllamaBaseURL: "http://x", OllamaModel: "m"}
+ c, err := newClientFromConfig(cfg)
+ if err != nil || c == nil {
+ t.Fatalf("expected client: %v %v", c, err)
+ }
}
func TestNewClientFromConfig_OpenAI_MissingKey(t *testing.T) {
- cfg := appconfig.App{ Provider: "openai", OpenAIBaseURL: "https://api", OpenAIModel: "gpt" }
- t.Setenv("HEXAI_OPENAI_API_KEY", "")
- t.Setenv("OPENAI_API_KEY", "")
- if _, err := newClientFromConfig(cfg); err == nil {
- t.Fatalf("expected error for missing openai key")
- }
+ cfg := appconfig.App{Provider: "openai", OpenAIBaseURL: "https://api", OpenAIModel: "gpt"}
+ t.Setenv("HEXAI_OPENAI_API_KEY", "")
+ t.Setenv("OPENAI_API_KEY", "")
+ if _, err := newClientFromConfig(cfg); err == nil {
+ t.Fatalf("expected error for missing openai key")
+ }
}
diff --git a/internal/hexaicli/testhelpers_test.go b/internal/hexaicli/testhelpers_test.go
index 1f75916..512a3ba 100644
--- a/internal/hexaicli/testhelpers_test.go
+++ b/internal/hexaicli/testhelpers_test.go
@@ -2,13 +2,13 @@
package hexaicli
import (
- "context"
- "encoding/json"
- "os"
- "path/filepath"
- "testing"
+ "context"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "testing"
- "codeberg.org/snonux/hexai/internal/llm"
+ "codeberg.org/snonux/hexai/internal/llm"
)
// setStdin sets os.Stdin from a string and returns a restore func and reader.
@@ -55,21 +55,27 @@ type fakeStreamer struct {
}
func (s *fakeStreamer) ChatStream(ctx context.Context, messages []llm.Message, onDelta func(string), opts ...llm.RequestOption) error {
- s.sMsgs = append([]llm.Message{}, messages...)
- for _, c := range s.chunks {
- onDelta(c)
- }
- return nil
+ s.sMsgs = append([]llm.Message{}, messages...)
+ for _, c := range s.chunks {
+ onDelta(c)
+ }
+ return nil
}
// small JSON writer for tests
func writeJSON(t *testing.T, path string, v any) {
- t.Helper()
- if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { t.Fatalf("mkdir: %v", err) }
- f, err := os.Create(path)
- if err != nil { t.Fatalf("create: %v", err) }
- defer f.Close()
- if err := json.NewEncoder(f).Encode(v); err != nil { t.Fatalf("encode: %v", err) }
+ t.Helper()
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ t.Fatalf("mkdir: %v", err)
+ }
+ f, err := os.Create(path)
+ if err != nil {
+ t.Fatalf("create: %v", err)
+ }
+ defer f.Close()
+ if err := json.NewEncoder(f).Encode(v); err != nil {
+ t.Fatalf("encode: %v", err)
+ }
}
func testingTempDir(t *testing.T) string { t.Helper(); return t.TempDir() }