summaryrefslogtreecommitdiff
path: root/internal/llm
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/copilot.go297
-rw-r--r--internal/llm/copilot_http_test.go392
-rw-r--r--internal/llm/ollama_test.go296
-rw-r--r--internal/llm/openai_http_test.go250
-rw-r--r--internal/llm/openai_sse_negative_test.go46
-rw-r--r--internal/llm/openai_test.go79
-rw-r--r--internal/llm/provider.go108
-rw-r--r--internal/llm/provider_more_test.go37
-rw-r--r--internal/llm/provider_test.go30
-rw-r--r--internal/llm/util_test.go7
10 files changed, 876 insertions, 666 deletions
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go
index 16eeda6..d3b1a9d 100644
--- a/internal/llm/copilot.go
+++ b/internal/llm/copilot.go
@@ -4,6 +4,7 @@ package llm
import (
"bytes"
"context"
+ "encoding/base64"
"encoding/json"
"errors"
"fmt"
@@ -13,7 +14,6 @@ import (
"strings"
"time"
- "encoding/base64"
appver "codeberg.org/snonux/hexai/internal"
"codeberg.org/snonux/hexai/internal/logging"
)
@@ -162,10 +162,14 @@ func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64
}
func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil { return nil, err }
- for k, v := range headers { req.Header.Set(k, v) }
- return c.httpClient.Do(req)
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ return c.httpClient.Do(req)
}
func handleCopilotNon2xx(resp *http.Response, start time.Time) error {
@@ -194,55 +198,73 @@ func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatRespons
// --- Copilot session token management ---
type ghCopilotTokenResp struct {
- Token string `json:"token"`
+ Token string `json:"token"`
}
func (c *copilotClient) ensureSession(ctx context.Context) error {
- // If token valid for >60s, reuse
- if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) {
- return nil
- }
- if strings.TrimSpace(c.apiKey) == "" {
- return errors.New("missing Copilot API key")
- }
- req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil)
- if err != nil { return err }
- req.Header.Set("Authorization", "Bearer "+c.apiKey)
- req.Header.Set("Accept", "application/json")
- req.Header.Set("User-Agent", "hexai/"+appver.Version)
- resp, err := c.httpClient.Do(req)
- if err != nil { return err }
- defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return fmt.Errorf("copilot token http error: %d", resp.StatusCode)
- }
- var out ghCopilotTokenResp
- if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err }
- if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") }
- // Parse JWT exp
- exp := parseJWTExp(out.Token)
- if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) }
- c.sessionToken = out.Token
- c.tokenExpiry = exp
- return nil
+ // If token valid for >60s, reuse
+ if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) {
+ return nil
+ }
+ if strings.TrimSpace(c.apiKey) == "" {
+ return errors.New("missing Copilot API key")
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil)
+ if err != nil {
+ return err
+ }
+ req.Header.Set("Authorization", "Bearer "+c.apiKey)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", "hexai/"+appver.Version)
+ resp, err := c.httpClient.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("copilot token http error: %d", resp.StatusCode)
+ }
+ var out ghCopilotTokenResp
+ if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
+ return err
+ }
+ if strings.TrimSpace(out.Token) == "" {
+ return errors.New("empty copilot session token")
+ }
+ // Parse JWT exp
+ exp := parseJWTExp(out.Token)
+ if exp.IsZero() {
+ exp = time.Now().Add(10 * time.Minute)
+ }
+ c.sessionToken = out.Token
+ c.tokenExpiry = exp
+ return nil
}
var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode
func parseJWTExp(token string) time.Time {
- parts := strings.Split(token, ".")
- if len(parts) < 2 { return time.Time{} }
- b, err := base64.RawURLEncoding.DecodeString(parts[1])
- if err != nil {
- if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 {
- if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) }
- }
- return time.Time{}
- }
- var payload struct{ Exp int64 `json:"exp"` }
- _ = json.Unmarshal(b, &payload)
- if payload.Exp == 0 { return time.Time{} }
- return time.Unix(payload.Exp, 0)
+ parts := strings.Split(token, ".")
+ if len(parts) < 2 {
+ return time.Time{}
+ }
+ b, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 {
+ if n, err2 := parseInt64(m[1]); err2 == nil {
+ return time.Unix(n, 0)
+ }
+ }
+ return time.Time{}
+ }
+ var payload struct {
+ Exp int64 `json:"exp"`
+ }
+ _ = json.Unmarshal(b, &payload)
+ if payload.Exp == 0 {
+ return time.Time{}
+ }
+ return time.Unix(payload.Exp, 0)
}
func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err }
@@ -250,99 +272,120 @@ func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &
// --- Copilot headers ---
func (c *copilotClient) headersChat() map[string]string {
- _ = c.ensureSession(context.Background())
- h := map[string]string{
- "Content-Type": "application/json; charset=utf-8",
- "Accept": "application/json",
- "Authorization": "Bearer " + c.sessionToken,
- "User-Agent": "GitHubCopilotChat/0.8.0",
- "Editor-Plugin-Version": "copilot-chat/0.8.0",
- "Editor-Version": "vscode/1.85.1",
- "Openai-Intent": "conversation-panel",
- "Openai-Organization": "github-copilot",
- "VScode-MachineId": randHex(64),
- "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- }
- return h
+ _ = c.ensureSession(context.Background())
+ h := map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "Accept": "application/json",
+ "Authorization": "Bearer " + c.sessionToken,
+ "User-Agent": "GitHubCopilotChat/0.8.0",
+ "Editor-Plugin-Version": "copilot-chat/0.8.0",
+ "Editor-Version": "vscode/1.85.1",
+ "Openai-Intent": "conversation-panel",
+ "Openai-Organization": "github-copilot",
+ "VScode-MachineId": randHex(64),
+ "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ }
+ return h
}
func (c *copilotClient) headersGhost() map[string]string {
- _ = c.ensureSession(context.Background())
- h := map[string]string{
- "Content-Type": "application/json; charset=utf-8",
- "Accept": "*/*",
- "Authorization": "Bearer " + c.sessionToken,
- "User-Agent": "GithubCopilot/1.155.0",
- "Editor-Plugin-Version": "copilot/1.155.0",
- "Editor-Version": "vscode/1.85.1",
- "Openai-Intent": "copilot-ghost",
- "Openai-Organization": "github-copilot",
- "VScode-MachineId": randHex(64),
- "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
- }
- return h
+ _ = c.ensureSession(context.Background())
+ h := map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "Accept": "*/*",
+ "Authorization": "Bearer " + c.sessionToken,
+ "User-Agent": "GithubCopilot/1.155.0",
+ "Editor-Plugin-Version": "copilot/1.155.0",
+ "Editor-Version": "vscode/1.85.1",
+ "Openai-Intent": "copilot-ghost",
+ "Openai-Organization": "github-copilot",
+ "VScode-MachineId": randHex(64),
+ "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ }
+ return h
}
func randHex(n int) string {
- const hex = "0123456789abcdef"
- b := make([]byte, n)
- for i := range b {
- b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)]
- }
- return string(b)
+ const hex = "0123456789abcdef"
+ b := make([]byte, n)
+ for i := range b {
+ b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)]
+ }
+ return string(b)
}
// --- Codex-style code completion ---
// CodeCompletion implements CodeCompleter; returns up to n suggestions.
func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) {
- if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") }
- if err := c.ensureSession(ctx); err != nil { return nil, err }
- if n <= 0 { n = 1 }
- maxTokens := 500
- body := map[string]any{
- "extra": map[string]any{
- "language": language,
- "next_indent": 0,
- "prompt_tokens": 500,
- "suffix_tokens": 400,
- "trim_by_indentation": true,
- },
- "max_tokens": maxTokens,
- "n": n,
- "nwo": "hexai",
- "prompt": prompt,
- "stop": []string{"\n\n"},
- "stream": true,
- "suffix": suffix,
- "temperature": temperature,
- "top_p": 1,
- }
- buf, _ := json.Marshal(body)
- url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions"
- resp, err := c.postJSON(ctx, url, buf, c.headersGhost())
- if err != nil { return nil, err }
- defer resp.Body.Close()
- if resp.StatusCode < 200 || resp.StatusCode >= 300 {
- return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode)
- }
- // Read all and parse lines that start with "data: " accumulating by index
- raw, _ := io.ReadAll(resp.Body)
- byIndex := make(map[int]string)
- lines := strings.Split(string(raw), "\n")
- for _, ln := range lines {
- if !strings.HasPrefix(ln, "data: ") { continue }
- var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` }
- if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue }
- for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text }
- }
- out := make([]string, 0, len(byIndex))
- for i := 0; i < n; i++ {
- if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) }
- }
- return out, nil
+ if strings.TrimSpace(c.apiKey) == "" {
+ return nil, errors.New("missing Copilot API key")
+ }
+ if err := c.ensureSession(ctx); err != nil {
+ return nil, err
+ }
+ if n <= 0 {
+ n = 1
+ }
+ maxTokens := 500
+ body := map[string]any{
+ "extra": map[string]any{
+ "language": language,
+ "next_indent": 0,
+ "prompt_tokens": 500,
+ "suffix_tokens": 400,
+ "trim_by_indentation": true,
+ },
+ "max_tokens": maxTokens,
+ "n": n,
+ "nwo": "hexai",
+ "prompt": prompt,
+ "stop": []string{"\n\n"},
+ "stream": true,
+ "suffix": suffix,
+ "temperature": temperature,
+ "top_p": 1,
+ }
+ buf, _ := json.Marshal(body)
+ url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions"
+ resp, err := c.postJSON(ctx, url, buf, c.headersGhost())
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode)
+ }
+ // Read all and parse lines that start with "data: " accumulating by index
+ raw, _ := io.ReadAll(resp.Body)
+ byIndex := make(map[int]string)
+ lines := strings.Split(string(raw), "\n")
+ for _, ln := range lines {
+ if !strings.HasPrefix(ln, "data: ") {
+ continue
+ }
+ var evt struct {
+ Choices []struct {
+ Index int `json:"index"`
+ Text string `json:"text"`
+ } `json:"choices"`
+ }
+ if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil {
+ continue
+ }
+ for _, ch := range evt.Choices {
+ byIndex[ch.Index] += ch.Text
+ }
+ }
+ out := make([]string, 0, len(byIndex))
+ for i := 0; i < n; i++ {
+ if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" {
+ out = append(out, s)
+ }
+ }
+ return out, nil
}
// newLineDataReader wraps a streaming body and exposes a JSON decoder that
diff --git a/internal/llm/copilot_http_test.go b/internal/llm/copilot_http_test.go
index 180e43e..d66311c 100644
--- a/internal/llm/copilot_http_test.go
+++ b/internal/llm/copilot_http_test.go
@@ -1,205 +1,261 @@
package llm
import (
- "context"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "encoding/base64"
- "os"
+ "context"
+ "encoding/base64"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
)
type rtFunc2 func(*http.Request) (*http.Response, error)
+
func (f rtFunc2) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
func TestCopilot_EnsureSession_AndChat_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Mock chat endpoint
- chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != "/chat/completions" { t.Fatalf("unexpected path: %s", r.URL.Path) }
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index":0, "message": map[string]string{"role":"assistant","content":"OK"}}}})
- }))
- defer chatSrv.Close()
- c := newCopilot(chatSrv.URL, "gpt-4o-mini", "APIKEY", f64p(0.1)).(copilotClient)
- // Intercept token endpoint to return a session token
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder()
- _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"})
- res := rw.Result()
- res.StatusCode = 200
- return res, nil
- }
- // Fallback to default transport for chatSrv
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}})
- if err != nil || out != "OK" { t.Fatalf("copilot chat failed: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Mock chat endpoint
+ chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/chat/completions" {
+ t.Fatalf("unexpected path: %s", r.URL.Path)
+ }
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index": 0, "message": map[string]string{"role": "assistant", "content": "OK"}}}})
+ }))
+ defer chatSrv.Close()
+ c := newCopilot(chatSrv.URL, "gpt-4o-mini", "APIKEY", f64p(0.1)).(copilotClient)
+ // Intercept token endpoint to return a session token
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ // Fallback to default transport for chatSrv
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "OK" {
+ t.Fatalf("copilot chat failed: %v %q", err, out)
+ }
}
func TestCopilot_HandleNon2xx(t *testing.T) {
- b, _ := json.Marshal(map[string]any{"error": map[string]any{"message":"bad","type":"invalid"}})
- resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytesReader(b))}
- if err := handleCopilotNon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error") }
+ b, _ := json.Marshal(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}})
+ resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytesReader(b))}
+ if err := handleCopilotNon2xx(resp, time.Now()); err == nil {
+ t.Fatalf("expected error")
+ }
}
func TestCopilot_CodeCompletion_Success(t *testing.T) {
- c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- // Token endpoint
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder()
- _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"})
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- // Codex completion endpoint
- if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
- rw := httptest.NewRecorder()
- // two choices for index 0 and 1
- rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"A\"}]}\n")
- rw.WriteString("data: {\"choices\":[{\"index\":1,\"text\":\"B\"}]}\n")
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.CodeCompletion(context.Background(), "p", "s", 2, "go", 0.1)
- if err != nil || len(out) != 2 || out[0] != "A" || out[1] != "B" {
- t.Fatalf("codex: %v %#v", err, out)
- }
+ c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ // Token endpoint
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ // Codex completion endpoint
+ if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
+ rw := httptest.NewRecorder()
+ // two choices for index 0 and 1
+ rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"A\"}]}\n")
+ rw.WriteString("data: {\"choices\":[{\"index\":1,\"text\":\"B\"}]}\n")
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.CodeCompletion(context.Background(), "p", "s", 2, "go", 0.1)
+ if err != nil || len(out) != 2 || out[0] != "A" || out[1] != "B" {
+ t.Fatalf("codex: %v %#v", err, out)
+ }
}
func TestCopilot_Chat_MultiChoice_And_ErrorBody(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Chat multi-choice: return two choices; client returns first content
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{
- "choices": []map[string]any{
- {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
- {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
- },
- })
- }))
- defer srv.Close()
- c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- // Token success
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
- if err != nil || out != "FIRST" { t.Fatalf("copilot multi-choice: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Chat multi-choice: return two choices; client returns first content
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{
+ {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
+ {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
+ },
+ })
+ }))
+ defer srv.Close()
+ c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ // Token success
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "FIRST" {
+ t.Fatalf("copilot multi-choice: %v %q", err, out)
+ }
- // Non-2xx with error body
- srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(403)
- _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message":"denied","type":"forbidden"}})
- }))
- defer srv2.Close()
- c2 := newCopilot(srv2.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- c2.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- if _, err := c2.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected error for copilot non-2xx with error body")
- }
+ // Non-2xx with error body
+ srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(403)
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "denied", "type": "forbidden"}})
+ }))
+ defer srv2.Close()
+ c2 := newCopilot(srv2.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ c2.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error for copilot non-2xx with error body")
+ }
}
func TestCopilot_Chat_NoChoices_Error(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
- }))
- defer srv.Close()
- c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected error when no choices returned")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
+ }))
+ defer srv.Close()
+ c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error when no choices returned")
+ }
}
func TestCopilot_Chat_DecodeError_StatusOK(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Chat returns 200 but invalid JSON; expect decode error
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- io.WriteString(w, "{invalid")
- }))
- defer srv.Close()
- c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected decode error for invalid body")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Chat returns 200 but invalid JSON; expect decode error
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "{invalid")
+ }))
+ defer srv.Close()
+ c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected decode error for invalid body")
+ }
}
func TestCopilot_CodeCompletion_MalformedAndEmpty(t *testing.T) {
- c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
- tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
- rw := httptest.NewRecorder()
- // malformed line
- rw.WriteString("data: {bad}\n")
- // done; should produce empty suggestions
- rw.WriteString("data: [DONE]\n")
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
- out, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
- if err != nil { t.Fatalf("unexpected error: %v", err) }
- if len(out) != 0 { t.Fatalf("expected empty suggestions, got %#v", out) }
+ c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient)
+ tr := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
+ rw := httptest.NewRecorder()
+ // malformed line
+ rw.WriteString("data: {bad}\n")
+ // done; should produce empty suggestions
+ rw.WriteString("data: [DONE]\n")
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second}
+ out, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if len(out) != 0 {
+ t.Fatalf("expected empty suggestions, got %#v", out)
+ }
- // Now include one good chunk after malformed
- tr2 := rtFunc2(func(r *http.Request) (*http.Response, error) {
- if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
- rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
- rw := httptest.NewRecorder()
- rw.WriteString("data: {bad}\n")
- rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"OK\"}]}\n")
- rw.WriteString("data: [DONE]\n")
- res := rw.Result(); res.StatusCode = 200; return res, nil
- }
- return http.DefaultTransport.RoundTrip(r)
- })
- c.httpClient = &http.Client{Transport: tr2, Timeout: 5 * time.Second}
- out2, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
- if err != nil || len(out2) != 1 || out2[0] != "OK" { t.Fatalf("unexpected: %v %#v", err, out2) }
+ // Now include one good chunk after malformed
+ tr2 := rtFunc2(func(r *http.Request) (*http.Response, error) {
+ if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" {
+ rw := httptest.NewRecorder()
+ _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"})
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") {
+ rw := httptest.NewRecorder()
+ rw.WriteString("data: {bad}\n")
+ rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"OK\"}]}\n")
+ rw.WriteString("data: [DONE]\n")
+ res := rw.Result()
+ res.StatusCode = 200
+ return res, nil
+ }
+ return http.DefaultTransport.RoundTrip(r)
+ })
+ c.httpClient = &http.Client{Transport: tr2, Timeout: 5 * time.Second}
+ out2, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1)
+ if err != nil || len(out2) != 1 || out2[0] != "OK" {
+ t.Fatalf("unexpected: %v %#v", err, out2)
+ }
}
func TestParseJWTExp_AndParseInt64(t *testing.T) {
- // Valid base64 payload
- payload := `{"exp": 1700000000}`
- b := base64.RawURLEncoding.EncodeToString([]byte(payload))
- tok := "x." + b + ".y"
- if tm := parseJWTExp(tok); tm.IsZero() { t.Fatalf("expected non-zero time") }
- if n, err := parseInt64("123"); err != nil || n != 123 { t.Fatalf("parseInt64: %v %d", err, n) }
+ // Valid base64 payload
+ payload := `{"exp": 1700000000}`
+ b := base64.RawURLEncoding.EncodeToString([]byte(payload))
+ tok := "x." + b + ".y"
+ if tm := parseJWTExp(tok); tm.IsZero() {
+ t.Fatalf("expected non-zero time")
+ }
+ if n, err := parseInt64("123"); err != nil || n != 123 {
+ t.Fatalf("parseInt64: %v %d", err, n)
+ }
}
// bytesReader wraps a byte slice with an io.ReadCloser without importing extra.
type bytesReader []byte
+
func (b bytesReader) Read(p []byte) (int, error) { n := copy(p, b); return n, io.EOF }
-func (b bytesReader) Close() error { return nil }
+func (b bytesReader) Close() error { return nil }
diff --git a/internal/llm/ollama_test.go b/internal/llm/ollama_test.go
index 15f9cff..8bd33ca 100644
--- a/internal/llm/ollama_test.go
+++ b/internal/llm/ollama_test.go
@@ -1,173 +1,217 @@
package llm
import (
- "context"
- "encoding/json"
- "fmt"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "time"
- "os"
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
)
func TestBuildOllamaRequest_OptionsAndStream(t *testing.T) {
- o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}}
- msgs := []Message{{Role: "user", Content: "hello"}}
- req := buildOllamaRequest(o, msgs, f64p(0.2), false)
- if req.Model != "codemodel" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) }
- if req.Options == nil { t.Fatalf("expected options map") }
- if req.Options.(map[string]any)["temperature"].(float64) != 0.2 { t.Fatalf("default temp not applied") }
- if req.Options.(map[string]any)["num_predict"].(int) != 256 { t.Fatalf("num_predict not applied") }
- if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" { t.Fatalf("stop not applied") }
-
- req2 := buildOllamaRequest(o, msgs, f64p(0.2), true)
- if !req2.Stream { t.Fatalf("expected stream=true") }
+ o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}}
+ msgs := []Message{{Role: "user", Content: "hello"}}
+ req := buildOllamaRequest(o, msgs, f64p(0.2), false)
+ if req.Model != "codemodel" || req.Stream {
+ t.Fatalf("model/stream mismatch: %+v", req)
+ }
+ if req.Options == nil {
+ t.Fatalf("expected options map")
+ }
+ if req.Options.(map[string]any)["temperature"].(float64) != 0.2 {
+ t.Fatalf("default temp not applied")
+ }
+ if req.Options.(map[string]any)["num_predict"].(int) != 256 {
+ t.Fatalf("num_predict not applied")
+ }
+ if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" {
+ t.Fatalf("stop not applied")
+ }
+
+ req2 := buildOllamaRequest(o, msgs, f64p(0.2), true)
+ if !req2.Stream {
+ t.Fatalf("expected stream=true")
+ }
}
func TestBuildOllamaRequest_TempOverride(t *testing.T) {
- o := Options{Model: "m", Temperature: 0.9}
- msgs := []Message{{Role: "user", Content: "hi"}}
- req := buildOllamaRequest(o, msgs, f64p(0.2), false)
- m := req.Options.(map[string]any)
- if m["temperature"].(float64) != 0.9 { t.Fatalf("explicit temp should override default") }
+ o := Options{Model: "m", Temperature: 0.9}
+ msgs := []Message{{Role: "user", Content: "hi"}}
+ req := buildOllamaRequest(o, msgs, f64p(0.2), false)
+ m := req.Options.(map[string]any)
+ if m["temperature"].(float64) != 0.9 {
+ t.Fatalf("explicit temp should override default")
+ }
}
func TestOllama_NameAndModel(t *testing.T) {
- c := newOllama("http://x", "model-x", nil).(ollamaClient)
- if c.Name() != "ollama" { t.Fatalf("name: %q", c.Name()) }
- if c.DefaultModel() != "model-x" { t.Fatalf("default model: %q", c.DefaultModel()) }
+ c := newOllama("http://x", "model-x", nil).(ollamaClient)
+ if c.Name() != "ollama" {
+ t.Fatalf("name: %q", c.Name())
+ }
+ if c.DefaultModel() != "model-x" {
+ t.Fatalf("default model: %q", c.DefaultModel())
+ }
}
func TestOllamaChat_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost || r.URL.Path != "/api/chat" { t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) }
- w.Header().Set("Content-Type", "application/json")
- _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role":"assistant","content":"Hello"}, "done": true})
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient)
- c.httpClient = ts.Client()
- out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
- if err != nil { t.Fatalf("unexpected err: %v", err) }
- if out != "Hello" { t.Fatalf("got %q", out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost || r.URL.Path != "/api/chat" {
+ t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path)
+ }
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": "Hello"}, "done": true})
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient)
+ c.httpClient = ts.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil {
+ t.Fatalf("unexpected err: %v", err)
+ }
+ if out != "Hello" {
+ t.Fatalf("got %q", out)
+ }
}
func TestOllamaChat_EmptyContent(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role":"assistant","content":""}, "done": true})
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected error for empty content")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": ""}, "done": true})
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected error for empty content")
+ }
}
func TestOllamaChat_Non2xx(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // API error string
- ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(400)
- _ = json.NewEncoder(w).Encode(map[string]any{"error":"bad"})
- }))
- defer ts1.Close()
- c1 := newOllama(ts1.URL, "m", nil).(ollamaClient)
- c1.httpClient = ts1.Client()
- if _, err := c1.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected error for 400 with api body")
- }
- // Plain HTTP error without api message
- ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(500)
- _, _ = w.Write([]byte("{}"))
- }))
- defer ts2.Close()
- c2 := newOllama(ts2.URL, "m", nil).(ollamaClient)
- c2.httpClient = ts2.Client()
- if _, err := c2.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected error for 500")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // API error string
+ ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(400)
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": "bad"})
+ }))
+ defer ts1.Close()
+ c1 := newOllama(ts1.URL, "m", nil).(ollamaClient)
+ c1.httpClient = ts1.Client()
+ if _, err := c1.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected error for 400 with api body")
+ }
+ // Plain HTTP error without api message
+ ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(500)
+ _, _ = w.Write([]byte("{}"))
+ }))
+ defer ts2.Close()
+ c2 := newOllama(ts2.URL, "m", nil).(ollamaClient)
+ c2.httpClient = ts2.Client()
+ if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected error for 500")
+ }
}
type rtFunc func(*http.Request) (*http.Response, error)
+
func (f rtFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
func TestOllamaChat_HTTPError(t *testing.T) {
- c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient)
- c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request)(*http.Response,error){ return nil, fmt.Errorf("boom") })}
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected http error path")
- }
+ c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient)
+ c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request) (*http.Response, error) { return nil, fmt.Errorf("boom") })}
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected http error path")
+ }
}
func TestOllamaChat_DecodeError(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, _ = w.Write([]byte("{bad json}"))
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil {
- t.Fatalf("expected decode error")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("{bad json}"))
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil {
+ t.Fatalf("expected decode error")
+ }
}
func TestHandleOllamaNon2xx_OK(t *testing.T) {
- resp := &http.Response{StatusCode: 200, Body: ioNopCloser(strings.NewReader(""))}
- if err := handleOllamaNon2xx(resp, time.Now()); err != nil { t.Fatalf("unexpected: %v", err) }
+ resp := &http.Response{StatusCode: 200, Body: ioNopCloser(strings.NewReader(""))}
+ if err := handleOllamaNon2xx(resp, time.Now()); err != nil {
+ t.Fatalf("unexpected: %v", err)
+ }
}
func TestOllamaChatStream_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- // two JSON objects back-to-back
- _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"Hi"},"done":false}`))
- _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`))
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- var got strings.Builder
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(s string){ got.WriteString(s) }); err != nil {
- t.Fatalf("unexpected: %v", err)
- }
- if got.String() != "Hi!" { t.Fatalf("got %q", got.String()) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ // two JSON objects back-to-back
+ _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"Hi"},"done":false}`))
+ _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`))
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ var got strings.Builder
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(s string) { got.WriteString(s) }); err != nil {
+ t.Fatalf("unexpected: %v", err)
+ }
+ if got.String() != "Hi!" {
+ t.Fatalf("got %q", got.String())
+ }
}
func TestOllamaChatStream_ErrorEvent(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"error":"oops"})
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(string){}); err == nil {
- t.Fatalf("expected stream error")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": "oops"})
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil {
+ t.Fatalf("expected stream error")
+ }
}
func TestOllamaChatStream_DecodeError(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _, _ = w.Write([]byte("{not json}"))
- }))
- defer ts.Close()
- c := newOllama(ts.URL, "m", nil).(ollamaClient)
- c.httpClient = ts.Client()
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(string){}); err == nil {
- t.Fatalf("expected decode error")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte("{not json}"))
+ }))
+ defer ts.Close()
+ c := newOllama(ts.URL, "m", nil).(ollamaClient)
+ c.httpClient = ts.Client()
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil {
+ t.Fatalf("expected decode error")
+ }
}
// small helper to construct an io.ReadCloser without importing extra packages
type readCloser struct{ *strings.Reader }
-func (readCloser) Close() error { return nil }
+
+func (readCloser) Close() error { return nil }
func ioNopCloser(r *strings.Reader) *readCloser { return &readCloser{r} }
diff --git a/internal/llm/openai_http_test.go b/internal/llm/openai_http_test.go
index ac7b897..cb4bfcb 100644
--- a/internal/llm/openai_http_test.go
+++ b/internal/llm/openai_http_test.go
@@ -1,143 +1,171 @@
package llm
import (
- "context"
- "encoding/json"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "strings"
- "time"
- "os"
+ "context"
+ "encoding/json"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "strings"
+ "testing"
+ "time"
)
func TestOpenAI_Chat_Success(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if r.URL.Path != "/chat/completions" { t.Fatalf("unexpected path: %s", r.URL.Path) }
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index":0, "message": map[string]string{"role":"assistant","content":"OK"}}}})
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- out, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}})
- if err != nil || out != "OK" { t.Fatalf("openai chat: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path != "/chat/completions" {
+ t.Fatalf("unexpected path: %s", r.URL.Path)
+ }
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index": 0, "message": map[string]string{"role": "assistant", "content": "OK"}}}})
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "OK" {
+ t.Fatalf("openai chat: %v %q", err, out)
+ }
}
func TestOpenAI_Chat_MissingKey(t *testing.T) {
- c := newOpenAI("http://x", "g", "", f64p(0.2)).(openAIClient)
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { t.Fatalf("expected error for missing key") }
+ c := newOpenAI("http://x", "g", "", f64p(0.2)).(openAIClient)
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error for missing key")
+ }
}
func TestOpenAI_ChatStream_SSE(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- // Return SSE-like stream
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n")
- io.WriteString(w, "data: [DONE]\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s })
- if err != nil || got != "Hi" { t.Fatalf("chat stream: %v %q", err, got) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Return SSE-like stream
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s })
+ if err != nil || got != "Hi" {
+ t.Fatalf("chat stream: %v %q", err, got)
+ }
}
func TestHandleOpenAINon2xx_NoErrorBody(t *testing.T) {
- resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))}
- if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected http error") }
+ resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))}
+ if err := handleOpenAINon2xx(resp, time.Now()); err == nil {
+ t.Fatalf("expected http error")
+ }
}
func TestOpenAI_ChatStream_SSE_ErrorChunk(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {\"error\":{\"message\":\"oops\"}}\n\n")
- io.WriteString(w, "data: [DONE]\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }); err == nil {
- t.Fatalf("expected error due to error chunk")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\"error\":{\"message\":\"oops\"}}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err == nil {
+ t.Fatalf("expected error due to error chunk")
+ }
}
func TestOpenAI_Chat_NoChoices_Error(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil {
- t.Fatalf("expected error when choices empty")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}})
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error when choices empty")
+ }
}
func TestOpenAI_ChatStream_SSE_EmptyDelta_NoError(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {\\\"choices\\\":[{\\\"delta\\\":{\\\"content\\\":\\\"\\\"}}]}\\n\\n")
- io.WriteString(w, "data: [DONE]\\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }); err != nil {
- t.Fatalf("unexpected error for empty delta: %v", err)
- }
- if got != "" { t.Fatalf("expected no output for empty delta, got %q", got) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {\\\"choices\\\":[{\\\"delta\\\":{\\\"content\\\":\\\"\\\"}}]}\\n\\n")
+ io.WriteString(w, "data: [DONE]\\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err != nil {
+ t.Fatalf("unexpected error for empty delta: %v", err)
+ }
+ if got != "" {
+ t.Fatalf("expected no output for empty delta, got %q", got)
+ }
}
func TestOpenAI_Chat_DecodeError_StatusOK(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Return status 200 but invalid JSON body; Chat should return an error
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(200)
- io.WriteString(w, "{invalid")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
- t.Fatalf("expected decode error for invalid JSON body")
- }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Return status 200 but invalid JSON body; Chat should return an error
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ io.WriteString(w, "{invalid")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected decode error for invalid JSON body")
+ }
}
func TestOpenAI_Chat_MultiChoiceAndErrorBody(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Multi-choice success: return two choices with different finish reasons
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- _ = json.NewEncoder(w).Encode(map[string]any{
- "choices": []map[string]any{
- {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
- {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
- },
- })
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
- if err != nil || out != "FIRST" { t.Fatalf("openai multi-choice: %v %q", err, out) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Multi-choice success: return two choices with different finish reasons
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _ = json.NewEncoder(w).Encode(map[string]any{
+ "choices": []map[string]any{
+ {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}},
+ {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}},
+ },
+ })
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}})
+ if err != nil || out != "FIRST" {
+ t.Fatalf("openai multi-choice: %v %q", err, out)
+ }
- // Error body case: non-2xx with error message
- srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(400)
- _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}})
- }))
- defer srv2.Close()
- c2 := newOpenAI(srv2.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c2.httpClient = srv2.Client()
- if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
- t.Fatalf("expected error from non-2xx with error body")
- }
+ // Error body case: non-2xx with error message
+ srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(400)
+ _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}})
+ }))
+ defer srv2.Close()
+ c2 := newOpenAI(srv2.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c2.httpClient = srv2.Client()
+ if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil {
+ t.Fatalf("expected error from non-2xx with error body")
+ }
}
diff --git a/internal/llm/openai_sse_negative_test.go b/internal/llm/openai_sse_negative_test.go
index 8da5526..de2ff71 100644
--- a/internal/llm/openai_sse_negative_test.go
+++ b/internal/llm/openai_sse_negative_test.go
@@ -1,28 +1,32 @@
package llm
import (
- "context"
- "io"
- "net/http"
- "net/http/httptest"
- "testing"
- "os"
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "testing"
)
func TestOpenAI_ChatStream_SSE_MalformedChunk(t *testing.T) {
- if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") }
- // Malformed JSON chunk should be skipped; no onDelta calls; no error.
- srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- io.WriteString(w, "data: {not json}\n\n")
- io.WriteString(w, "data: [DONE]\n")
- }))
- defer srv.Close()
- c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
- c.httpClient = srv.Client()
- var got string
- if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string){ got += s }); err != nil {
- t.Fatalf("unexpected error for malformed chunk: %v", err)
- }
- if got != "" { t.Fatalf("expected no deltas for malformed chunk, got %q", got) }
+ if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" {
+ t.Skip("skip network-bound tests in restricted environments")
+ }
+ // Malformed JSON chunk should be skipped; no onDelta calls; no error.
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ io.WriteString(w, "data: {not json}\n\n")
+ io.WriteString(w, "data: [DONE]\n")
+ }))
+ defer srv.Close()
+ c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient)
+ c.httpClient = srv.Client()
+ var got string
+ if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err != nil {
+ t.Fatalf("unexpected error for malformed chunk: %v", err)
+ }
+ if got != "" {
+ t.Fatalf("expected no deltas for malformed chunk, got %q", got)
+ }
}
diff --git a/internal/llm/openai_test.go b/internal/llm/openai_test.go
index f50b171..f7ce080 100644
--- a/internal/llm/openai_test.go
+++ b/internal/llm/openai_test.go
@@ -1,44 +1,67 @@
package llm
import (
- "bytes"
- "encoding/json"
- "io"
- "net/http"
- "strings"
- "testing"
- "time"
+ "bytes"
+ "encoding/json"
+ "io"
+ "net/http"
+ "strings"
+ "testing"
+ "time"
)
func f64p(v float64) *float64 { return &v }
func TestBuildOAChatRequest_TempFallbackAndFields(t *testing.T) {
- o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}}
- msgs := []Message{{Role: "user", Content: "hi"}}
- req := buildOAChatRequest(o, msgs, f64p(0.3), false)
- if req.Model != "m1" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) }
- if req.Temperature == nil || *req.Temperature != 0.3 { t.Fatalf("expected default temp 0.3, got %#v", req.Temperature) }
- if req.MaxTokens == nil || *req.MaxTokens != 42 { t.Fatalf("expected max tokens 42") }
- if len(req.Stop) != 1 || req.Stop[0] != "END" { t.Fatalf("stop not propagated: %#v", req.Stop) }
- if len(req.Messages) != 1 || req.Messages[0].Content != "hi" { t.Fatalf("messages not copied") }
+ o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}}
+ msgs := []Message{{Role: "user", Content: "hi"}}
+ req := buildOAChatRequest(o, msgs, f64p(0.3), false)
+ if req.Model != "m1" || req.Stream {
+ t.Fatalf("model/stream mismatch: %+v", req)
+ }
+ if req.Temperature == nil || *req.Temperature != 0.3 {
+ t.Fatalf("expected default temp 0.3, got %#v", req.Temperature)
+ }
+ if req.MaxTokens == nil || *req.MaxTokens != 42 {
+ t.Fatalf("expected max tokens 42")
+ }
+ if len(req.Stop) != 1 || req.Stop[0] != "END" {
+ t.Fatalf("stop not propagated: %#v", req.Stop)
+ }
+ if len(req.Messages) != 1 || req.Messages[0].Content != "hi" {
+ t.Fatalf("messages not copied")
+ }
- // stream on
- req2 := buildOAChatRequest(o, msgs, f64p(0.3), true)
- if !req2.Stream { t.Fatalf("expected stream=true") }
+ // stream on
+ req2 := buildOAChatRequest(o, msgs, f64p(0.3), true)
+ if !req2.Stream {
+ t.Fatalf("expected stream=true")
+ }
}
func TestHandleOpenAINon2xx_WithAPIError(t *testing.T) {
- api := oaChatResponse{Error: &struct{ Message string `json:"message"`; Type string `json:"type"`; Param any `json:"param"`; Code any `json:"code"` }{Message: "bad", Type: "invalid"}}
- b, _ := json.Marshal(api)
- resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))}
- if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error for non-2xx with body") }
+ api := oaChatResponse{Error: &struct {
+ Message string `json:"message"`
+ Type string `json:"type"`
+ Param any `json:"param"`
+ Code any `json:"code"`
+ }{Message: "bad", Type: "invalid"}}
+ b, _ := json.Marshal(api)
+ resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))}
+ if err := handleOpenAINon2xx(resp, time.Now()); err == nil {
+ t.Fatalf("expected error for non-2xx with body")
+ }
}
func TestParseOpenAIStream_DeliversChunks(t *testing.T) {
- stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" +
- "data: [DONE]\n"
- resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))}
- var got strings.Builder
- if err := parseOpenAIStream(resp, time.Now(), func(s string){ got.WriteString(s) }); err != nil { t.Fatalf("unexpected error: %v", err) }
- if got.String() != "Hi" { t.Fatalf("got %q want %q", got.String(), "Hi") }
+ stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" +
+ "data: [DONE]\n"
+ resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))}
+ var got strings.Builder
+ if err := parseOpenAIStream(resp, time.Now(), func(s string) { got.WriteString(s) }); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got.String() != "Hi" {
+ t.Fatalf("got %q want %q", got.String(), "Hi")
+ }
}
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index 7ab58c6..88c280c 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -28,20 +28,20 @@ type Client interface {
// token-by-token streaming responses. Callers can type-assert to Streamer and
// fall back to Client.Chat when not implemented.
type Streamer interface {
- // ChatStream sends chat messages and invokes onDelta with incremental text
- // chunks as they are produced by the model. Implementations should call
- // onDelta with empty strings sparingly (prefer only non-empty chunks).
- ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
+ // ChatStream sends chat messages and invokes onDelta with incremental text
+ // chunks as they are produced by the model. Implementations should call
+ // onDelta with empty strings sparingly (prefer only non-empty chunks).
+ ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
}
// CodeCompleter is an optional interface for providers that support a
// prompt/suffix code-completion API (e.g., Copilot Codex endpoint). Clients
// can type-assert to this and prefer it over chat when available.
type CodeCompleter interface {
- // CodeCompletion requests up to n suggestions given a left-hand prompt and
- // right-hand suffix around the cursor. Language is advisory and may be
- // ignored. Temperature applies when provider supports it.
- CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error)
+ // CodeCompletion requests up to n suggestions given a left-hand prompt and
+ // right-hand suffix around the cursor. Language is advisory and may be
+ // ignored. Temperature applies when provider supports it.
+ CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error)
}
// Options for a request. Providers may ignore unsupported fields.
@@ -64,56 +64,56 @@ func WithStop(stop ...string) RequestOption {
// Config defines provider configuration read from the Hexai config file.
type Config struct {
- Provider string
- // OpenAI options
- OpenAIBaseURL string
- OpenAIModel string
- OpenAITemperature *float64
- // Ollama options
- OllamaBaseURL string
- OllamaModel string
- OllamaTemperature *float64
- // Copilot options
- CopilotBaseURL string
- CopilotModel string
- CopilotTemperature *float64
+ Provider string
+ // OpenAI options
+ OpenAIBaseURL string
+ OpenAIModel string
+ OpenAITemperature *float64
+ // Ollama options
+ OllamaBaseURL string
+ OllamaModel string
+ OllamaTemperature *float64
+ // Copilot options
+ CopilotBaseURL string
+ CopilotModel string
+ CopilotTemperature *float64
}
// NewFromConfig creates an LLM client using only the supplied configuration.
// The OpenAI API key is supplied separately and may be read from the environment
// by the caller; other environment-based configuration is not used.
func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) {
- p := strings.ToLower(strings.TrimSpace(cfg.Provider))
- if p == "" {
- p = "openai"
- }
- switch p {
- case "openai":
- if strings.TrimSpace(openAIAPIKey) == "" {
- return nil, errors.New("missing OPENAI_API_KEY for provider openai")
- }
- // Set coding-friendly default temperature if none provided
- if cfg.OpenAITemperature == nil {
- t := 0.2
- cfg.OpenAITemperature = &t
- }
- return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil
- case "ollama":
- if cfg.OllamaTemperature == nil {
- t := 0.2
- cfg.OllamaTemperature = &t
- }
- return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil
- case "copilot":
- if strings.TrimSpace(copilotAPIKey) == "" {
- return nil, errors.New("missing COPILOT_API_KEY for provider copilot")
- }
- if cfg.CopilotTemperature == nil {
- t := 0.2
- cfg.CopilotTemperature = &t
- }
- return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil
- default:
- return nil, errors.New("unknown LLM provider: " + p)
- }
+ p := strings.ToLower(strings.TrimSpace(cfg.Provider))
+ if p == "" {
+ p = "openai"
+ }
+ switch p {
+ case "openai":
+ if strings.TrimSpace(openAIAPIKey) == "" {
+ return nil, errors.New("missing OPENAI_API_KEY for provider openai")
+ }
+ // Set coding-friendly default temperature if none provided
+ if cfg.OpenAITemperature == nil {
+ t := 0.2
+ cfg.OpenAITemperature = &t
+ }
+ return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil
+ case "ollama":
+ if cfg.OllamaTemperature == nil {
+ t := 0.2
+ cfg.OllamaTemperature = &t
+ }
+ return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil
+ case "copilot":
+ if strings.TrimSpace(copilotAPIKey) == "" {
+ return nil, errors.New("missing COPILOT_API_KEY for provider copilot")
+ }
+ if cfg.CopilotTemperature == nil {
+ t := 0.2
+ cfg.CopilotTemperature = &t
+ }
+ return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil
+ default:
+ return nil, errors.New("unknown LLM provider: " + p)
+ }
}
diff --git a/internal/llm/provider_more_test.go b/internal/llm/provider_more_test.go
index bd08552..d7469af 100644
--- a/internal/llm/provider_more_test.go
+++ b/internal/llm/provider_more_test.go
@@ -3,24 +3,27 @@ package llm
import "testing"
func TestWithOptions_Apply(t *testing.T) {
- o := Options{}
- WithModel("m")(&o)
- WithTemperature(0.7)(&o)
- WithMaxTokens(123)(&o)
- WithStop("END")(&o)
- if o.Model != "m" || o.Temperature != 0.7 || o.MaxTokens != 123 || len(o.Stop) != 1 || o.Stop[0] != "END" {
- t.Fatalf("options not applied correctly: %+v", o)
- }
+ o := Options{}
+ WithModel("m")(&o)
+ WithTemperature(0.7)(&o)
+ WithMaxTokens(123)(&o)
+ WithStop("END")(&o)
+ if o.Model != "m" || o.Temperature != 0.7 || o.MaxTokens != 123 || len(o.Stop) != 1 || o.Stop[0] != "END" {
+ t.Fatalf("options not applied correctly: %+v", o)
+ }
}
func TestNewFromConfig_Success_OpenAI_And_Copilot(t *testing.T) {
- // OpenAI success
- oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"}
- c, err := NewFromConfig(oc, "KEY", "")
- if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" { t.Fatalf("openai new: %v %v", c, err) }
- // Copilot success
- cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"}
- c2, err := NewFromConfig(cc, "", "KEY")
- if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" { t.Fatalf("copilot new: %v %v", c2, err) }
+ // OpenAI success
+ oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"}
+ c, err := NewFromConfig(oc, "KEY", "")
+ if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" {
+ t.Fatalf("openai new: %v %v", c, err)
+ }
+ // Copilot success
+ cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"}
+ c2, err := NewFromConfig(cc, "", "KEY")
+ if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" {
+ t.Fatalf("copilot new: %v %v", c2, err)
+ }
}
-
diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go
index 1412b3c..29e2514 100644
--- a/internal/llm/provider_test.go
+++ b/internal/llm/provider_test.go
@@ -1,21 +1,29 @@
package llm
import (
- "context"
- "testing"
+ "context"
+ "testing"
)
func TestNewFromConfig_DefaultsAndErrors(t *testing.T) {
- // Unknown provider
- if _, err := NewFromConfig(Config{Provider:"bogus"}, "", ""); err == nil { t.Fatalf("expected error for unknown provider") }
- // OpenAI missing key
- if _, err := NewFromConfig(Config{Provider:"openai", OpenAIModel:"g"}, "", ""); err == nil { t.Fatalf("expected key error") }
- // Copilot missing key
- if _, err := NewFromConfig(Config{Provider:"copilot", CopilotModel:"m"}, "", ""); err == nil { t.Fatalf("expected key error") }
+ // Unknown provider
+ if _, err := NewFromConfig(Config{Provider: "bogus"}, "", ""); err == nil {
+ t.Fatalf("expected error for unknown provider")
+ }
+ // OpenAI missing key
+ if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", ""); err == nil {
+ t.Fatalf("expected key error")
+ }
+ // Copilot missing key
+ if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", ""); err == nil {
+ t.Fatalf("expected key error")
+ }
}
type fakeClientMin struct{}
-func (fakeClientMin) Chat(context.Context, []Message, ...RequestOption) (string, error) { return "", nil }
-func (fakeClientMin) Name() string { return "x" }
-func (fakeClientMin) DefaultModel() string { return "m" }
+func (fakeClientMin) Chat(context.Context, []Message, ...RequestOption) (string, error) {
+ return "", nil
+}
+func (fakeClientMin) Name() string { return "x" }
+func (fakeClientMin) DefaultModel() string { return "m" }
diff --git a/internal/llm/util_test.go b/internal/llm/util_test.go
index acffe5a..137e149 100644
--- a/internal/llm/util_test.go
+++ b/internal/llm/util_test.go
@@ -3,7 +3,8 @@ package llm
import "testing"
func TestNilStringErr(t *testing.T) {
- s, err := nilStringErr("boom")
- if s != "" || err == nil { t.Fatalf("expected empty string and error") }
+ s, err := nilStringErr("boom")
+ if s != "" || err == nil {
+ t.Fatalf("expected empty string and error")
+ }
}
-