summaryrefslogtreecommitdiff
path: root/internal/hexaicli/run_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/hexaicli/run_test.go')
-rw-r--r--internal/hexaicli/run_test.go194
1 files changed, 194 insertions, 0 deletions
diff --git a/internal/hexaicli/run_test.go b/internal/hexaicli/run_test.go
new file mode 100644
index 0000000..f9c8443
--- /dev/null
+++ b/internal/hexaicli/run_test.go
@@ -0,0 +1,194 @@
+// Summary: Unit tests for Hexai CLI helpers and run flow (input parsing, messages, streaming).
+// Not yet reviewed by a human
+package hexaicli
+
+import (
+ "bytes"
+ "context"
+ "strings"
+ "testing"
+)
+
+// helpers moved to testhelpers_test.go
+
+func TestReadInput_ArgsOnly(t *testing.T) {
+ restore, f := setStdin(t, "")
+ defer restore()
+ // Pass the same file reader used for os.Stdin (empty)
+ got, err := readInput(f, []string{"hello", "world"})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ want := "hello world"
+ if got != want {
+ t.Fatalf("want %q, got %q", want, got)
+ }
+}
+
+func TestReadInput_StdinOnly(t *testing.T) {
+ restore, f := setStdin(t, "payload")
+ defer restore()
+ got, err := readInput(f, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != "payload" {
+ t.Fatalf("want %q, got %q", "payload", got)
+ }
+}
+
+func TestReadInput_Combined(t *testing.T) {
+ restore, f := setStdin(t, "payload")
+ defer restore()
+ got, err := readInput(f, []string{"subject"})
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ want := "subject:\n\npayload"
+ if got != want {
+ t.Fatalf("want %q, got %q", want, got)
+ }
+}
+
+func TestReadInput_EmptyError(t *testing.T) {
+ restore, f := setStdin(t, "")
+ defer restore()
+ _, err := readInput(f, nil)
+ if err == nil {
+ t.Fatalf("expected error, got nil")
+ }
+ if !strings.Contains(err.Error(), "no input") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestBuildMessages_DefaultAndExplain(t *testing.T) {
+ // Default concise
+ msgs := buildMessages("list files in folder")
+ if len(msgs) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(msgs))
+ }
+ if msgs[0].Role != "system" || msgs[1].Role != "user" {
+ t.Fatalf("unexpected roles: %+v", msgs)
+ }
+ if !strings.Contains(msgs[0].Content, "very short, concise answers") {
+ t.Fatalf("unexpected system message: %q", msgs[0].Content)
+ }
+ if msgs[1].Content != "list files in folder" {
+ t.Fatalf("unexpected user content: %q", msgs[1].Content)
+ }
+
+ // Verbose explain
+ msgs2 := buildMessages("please explain how this works")
+ if len(msgs2) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(msgs2))
+ }
+ if !strings.Contains(strings.ToLower(msgs2[0].Content), "requested an explanation") {
+ t.Fatalf("unexpected system message: %q", msgs2[0].Content)
+ }
+ if msgs2[1].Content != "please explain how this works" {
+ t.Fatalf("unexpected user content: %q", msgs2[1].Content)
+ }
+}
+
+func TestRunChat_NonStreaming(t *testing.T) {
+ var out bytes.Buffer
+ var errb bytes.Buffer
+ fc := fakeClient{name: "fake", model: "m", resp: "OUTPUT"}
+ if err := runChat(context.Background(), &fc, nil, "input", &out, &errb); err != nil {
+ t.Fatalf("runChat error: %v", err)
+ }
+ if out.String() != "OUTPUT" {
+ t.Fatalf("stdout want %q, got %q", "OUTPUT", out.String())
+ }
+ es := errb.String()
+ if !strings.Contains(es, "done provider=fake model=m") {
+ t.Fatalf("stderr missing provider/model: %q", es)
+ }
+ if !strings.Contains(es, "in_bytes=5") || !strings.Contains(es, "out_bytes=6") {
+ t.Fatalf("stderr missing byte counts: %q", es)
+ }
+}
+
+func TestRunChat_Streaming(t *testing.T) {
+ var out bytes.Buffer
+ var errb bytes.Buffer
+ fs := fakeStreamer{fakeClient: fakeClient{name: "fake", model: "m"}, chunks: []string{"OUT", "PUT"}}
+ if err := runChat(context.Background(), &fs, nil, "input", &out, &errb); err != nil {
+ t.Fatalf("runChat error: %v", err)
+ }
+ if out.String() != "OUTPUT" {
+ t.Fatalf("stdout want %q, got %q", "OUTPUT", out.String())
+ }
+ es := errb.String()
+ if !strings.Contains(es, "done provider=fake model=m") {
+ t.Fatalf("stderr missing provider/model: %q", es)
+ }
+ if !strings.Contains(es, "in_bytes=5") || !strings.Contains(es, "out_bytes=6") {
+ t.Fatalf("stderr missing byte counts: %q", es)
+ }
+}
+
+func TestPrintProviderInfo(t *testing.T) {
+ var b bytes.Buffer
+ fc := fakeClient{name: "fake", model: "m"}
+ printProviderInfo(&b, &fc)
+ s := b.String()
+ if !strings.Contains(s, "provider=fake model=m") {
+ t.Fatalf("unexpected banner: %q", s)
+ }
+}
+
+func TestRunWithClient_NonStreaming(t *testing.T) {
+ restore, f := setStdin(t, "")
+ defer restore()
+ var out bytes.Buffer
+ var errb bytes.Buffer
+ fc := fakeClient{name: "fake", model: "m", resp: "OK"}
+ if err := RunWithClient(context.Background(), []string{"ask"}, f, &out, &errb, &fc); err != nil {
+ t.Fatalf("RunWithClient error: %v", err)
+ }
+ if out.String() != "OK" {
+ t.Fatalf("stdout want %q, got %q", "OK", out.String())
+ }
+ if !strings.Contains(errb.String(), "provider=fake model=m") {
+ t.Fatalf("missing banner: %q", errb.String())
+ }
+}
+
+func TestRunWithClient_Streaming(t *testing.T) {
+ restore, f := setStdin(t, "")
+ defer restore()
+ var out bytes.Buffer
+ var errb bytes.Buffer
+ fs := fakeStreamer{fakeClient: fakeClient{name: "fake", model: "m"}, chunks: []string{"A", "B"}}
+ if err := RunWithClient(context.Background(), []string{"ask"}, f, &out, &errb, &fs); err != nil {
+ t.Fatalf("RunWithClient error: %v", err)
+ }
+ if out.String() != "AB" {
+ t.Fatalf("stdout want %q, got %q", "AB", out.String())
+ }
+ if !strings.Contains(errb.String(), "provider=fake model=m") {
+ t.Fatalf("missing banner: %q", errb.String())
+ }
+}
+
+func TestRunWithClient_CombinedInput_UsesCombinedMessage(t *testing.T) {
+ restore, f := setStdin(t, "payload")
+ defer restore()
+ var out bytes.Buffer
+ var errb bytes.Buffer
+ fc := fakeClient{name: "fake", model: "m", resp: "OK"}
+ if err := RunWithClient(context.Background(), []string{"subject"}, f, &out, &errb, &fc); err != nil {
+ t.Fatalf("RunWithClient error: %v", err)
+ }
+ if out.String() != "OK" {
+ t.Fatalf("stdout want %q, got %q", "OK", out.String())
+ }
+ if len(fc.gotMsgs) != 2 {
+ t.Fatalf("expected 2 messages, got %d", len(fc.gotMsgs))
+ }
+ if fc.gotMsgs[1].Content != "subject:\n\npayload" {
+ t.Fatalf("unexpected user message: %q", fc.gotMsgs[1].Content)
+ }
+}