1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
package llm
import (
"context"
"io"
"net/http"
"strings"
"testing"
"codeberg.org/snonux/hexai/internal/logging"
)
// TestOpenAIChatSuccess verifies the happy path of openAIClient.Chat.
//
// It checks three things:
//   - the request goes to /chat/completions;
//   - the request carries "Bearer <apiKey>" in the Authorization header;
//   - the assistant message content from a 200 response is returned verbatim.
//
// The stub transport reports request mismatches with t.Errorf rather than
// t.Fatalf. t.FailNow (used by Fatal/Fatalf) may only be called from the
// goroutine running the test. http.RoundTripper gives no guarantee about
// which goroutine executes RoundTrip, while t.Errorf is safe from any goroutine.
func TestOpenAIChatSuccess(t *testing.T) {
	transport := roundTripFunc(func(r *http.Request) (*http.Response, error) {
		if r.URL.Path != "/chat/completions" {
			t.Errorf("unexpected path: %s", r.URL.Path)
		}
		if got := r.Header.Get("Authorization"); got != "Bearer test-key" {
			t.Errorf("expected auth header, got %q", got)
		}
		return &http.Response{
			StatusCode: http.StatusOK,
			Body:       io.NopCloser(strings.NewReader(`{"choices":[{"index":0,"message":{"role":"assistant","content":"hi there"},"finish_reason":"stop"}]}`)),
			Header:     make(http.Header),
		}, nil
	})
	client := openAIClient{
		httpClient:   &http.Client{Transport: transport},
		apiKey:       "test-key",
		baseURL:      "https://example.com",
		defaultModel: "gpt-test",
		chatLogger:   logging.NewChatLogger("openai"),
	}
	out, err := client.Chat(context.Background(), []Message{{Role: "user", Content: "hello"}})
	if err != nil {
		t.Fatalf("Chat returned error: %v", err)
	}
	if out != "hi there" {
		t.Fatalf("unexpected chat output: %q", out)
	}
}
// TestOpenAIChat_MissingKey_IsActionable ensures Chat refuses to run when the
// client has no API key configured. The resulting error must mention both
// OPENAI_API_KEY and HEXAI_OPENAI_API_KEY so the user knows what to set.
func TestOpenAIChat_MissingKey_IsActionable(t *testing.T) {
	c := openAIClient{defaultModel: "gpt-test"}
	msgs := []Message{{Role: "user", Content: "hello"}}

	_, err := c.Chat(context.Background(), msgs)
	if err == nil {
		t.Fatal("expected missing key error")
	}

	text := err.Error()
	for _, hint := range []string{"OPENAI_API_KEY", "HEXAI_OPENAI_API_KEY"} {
		if !strings.Contains(text, hint) {
			t.Fatalf("expected actionable API key hint, got %q", text)
		}
	}
}
func TestOpenAIChatStreamDeliversChunks(t *testing.T) {
client := openAIClient{
httpClient: &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) {
body := "data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n" +
"data: {\"choices\":[{\"finish_reason\":\"stop\"}]}\n" +
"data: [DONE]\n"
return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(body)), Header: make(http.Header)}, nil
})},
apiKey: "test-key",
baseURL: "https://example.com",
defaultModel: "gpt-test",
chatLogger: logging.NewChatLogger("openai"),
}
var received string
err := client.ChatStream(context.Background(), []Message{{Role: "user", Content: "hello"}}, func(chunk string) {
received += chunk
})
if err != nil {
t.Fatalf("ChatStream returned error: %v", err)
}
if received != "Hello" {
t.Fatalf("expected streamed content, got %q", received)
}
}
// TestOpenAIChatHandlesNon2xx checks that Chat reports an error when the
// server answers with a non-success status (here 401 with body "denied").
func TestOpenAIChatHandlesNon2xx(t *testing.T) {
	unauthorized := func(*http.Request) (*http.Response, error) {
		return &http.Response{
			StatusCode: http.StatusUnauthorized,
			Body:       io.NopCloser(strings.NewReader("denied")),
			Header:     make(http.Header),
		}, nil
	}
	client := openAIClient{
		httpClient:   &http.Client{Transport: roundTripFunc(unauthorized)},
		apiKey:       "test-key",
		baseURL:      "https://example.com",
		defaultModel: "gpt-test",
		chatLogger:   logging.NewChatLogger("openai"),
	}

	_, err := client.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
	if err == nil {
		t.Fatal("expected error for non-2xx response")
	}
}
// roundTripFunc lets an ordinary function act as an http.RoundTripper, so
// tests can plug canned responses into an http.Client without a network.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by delegating to the wrapped function.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
|