summaryrefslogtreecommitdiff
path: root/internal/llm/openai_test.go
blob: ffa625224744a744ee454cb77af26d8c8d494212 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package llm

import (
	"context"
	"io"
	"net/http"
	"strings"
	"testing"

	"codeberg.org/snonux/hexai/internal/logging"
)

// TestOpenAIChatSuccess checks that Chat sends its request to
// /chat/completions with the configured bearer token, and returns the
// assistant message content from a canned successful response.
func TestOpenAIChatSuccess(t *testing.T) {
	// Record request details instead of asserting inside the transport:
	// t.Fatalf/FailNow must only be called from the test goroutine, and a
	// RoundTripper should not assume which goroutine invokes it.
	var gotPath, gotAuth string
	transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
		gotPath = r.URL.Path
		gotAuth = r.Header.Get("Authorization")
		return &http.Response{
			StatusCode: 200,
			Body:       io.NopCloser(strings.NewReader(`{"choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}]}`)),
			Header:     make(http.Header),
		}, nil
	})

	client := openAIClient{
		httpClient:   &http.Client{Transport: transport},
		apiKey:       "test-key",
		baseURL:      "https://example.com",
		defaultModel: "gpt-test",
		chatLogger:   logging.NewChatLogger("openai"),
	}

	out, err := client.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}})
	if err != nil {
		t.Fatalf("Chat returned error: %v", err)
	}
	// These assertions run on the test goroutine, after the transport returned.
	if gotPath != "/chat/completions" {
		t.Fatalf("unexpected path: %s", gotPath)
	}
	if gotAuth != "Bearer test-key" {
		t.Fatalf("expected auth header, got %q", gotAuth)
	}
	if out != "hi there" {
		t.Fatalf("unexpected chat output: %q", out)
	}
}

// TestOpenAIChat_MissingKey_IsActionable verifies that Chat on a client with
// no API key fails, and that the error text mentions both OPENAI_API_KEY and
// HEXAI_OPENAI_API_KEY so the user is told how to supply a key.
func TestOpenAIChat_MissingKey_IsActionable(t *testing.T) {
	c := openAIClient{defaultModel: "gpt-test"}
	_, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}})
	if err == nil {
		t.Fatal("expected missing key error")
	}

	msg := err.Error()
	for _, hint := range []string{"OPENAI_API_KEY", "HEXAI_OPENAI_API_KEY"} {
		if !strings.Contains(msg, hint) {
			t.Fatalf("expected actionable API key hint, got %q", msg)
		}
	}
}

func TestOpenAIChatStreamDeliversChunks(t *testing.T) {
	client := openAIClient{
		httpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			body := "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n" +
				"data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n" +
				"data: [DONE]\n"
			return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
		})},
		apiKey:       "test-key",
		baseURL:      "https://example.com",
		defaultModel: "gpt-test",
		chatLogger:   logging.NewChatLogger("openai"),
	}

	var received string
	err := client.ChatStream(context.Background(), []Message{{Role: "user", Content: "hello"}}, func(chunk string) {
		received += chunk
	})
	if err != nil {
		t.Fatalf("ChatStream returned error: %v", err)
	}
	if received != "Hello" {
		t.Fatalf("expected streamed content, got %q", received)
	}
}

// TestOpenAIChatHandlesNon2xx checks that Chat returns an error when the
// server answers with a non-2xx status (401 Unauthorized with body "denied").
func TestOpenAIChatHandlesNon2xx(t *testing.T) {
	// Record that the request actually reached the transport. Without this,
	// the test would also pass if Chat failed before issuing any HTTP call,
	// and the non-2xx handling would never be exercised.
	called := false
	client := openAIClient{
		httpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
			called = true
			return &http.Response{StatusCode: http.StatusUnauthorized, Body: io.NopCloser(strings.NewReader("denied")), Header: make(http.Header)}, nil
		})},
		apiKey:       "test-key",
		baseURL:      "https://example.com",
		defaultModel: "gpt-test",
		chatLogger:   logging.NewChatLogger("openai"),
	}

	_, err := client.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
	if !called {
		t.Fatalf("expected Chat to send a request before failing, got err %v", err)
	}
	if err == nil {
		t.Fatal("expected error for non-2xx response")
	}
}

// roundTripFunc adapts an ordinary function to the http.RoundTripper
// interface, letting tests stub HTTP responses without a real server.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip implements http.RoundTripper by delegating to f.
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }