diff options
| author | Paul Buetow <paul@buetow.org> | 2025-09-06 10:56:27 +0300 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2025-09-06 10:56:27 +0300 |
| commit | 320de746f7a2985b60c8564a0e65bdf231e840b7 (patch) | |
| tree | e70bcf50813dba411afa2934e774383124bbc99e /internal/llm | |
| parent | 06247527d5170f329b454b42f59a3e4434ab1f4b (diff) | |
use gofumpt
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/copilot.go | 297 | ||||
| -rw-r--r-- | internal/llm/copilot_http_test.go | 392 | ||||
| -rw-r--r-- | internal/llm/ollama_test.go | 296 | ||||
| -rw-r--r-- | internal/llm/openai_http_test.go | 250 | ||||
| -rw-r--r-- | internal/llm/openai_sse_negative_test.go | 46 | ||||
| -rw-r--r-- | internal/llm/openai_test.go | 79 | ||||
| -rw-r--r-- | internal/llm/provider.go | 108 | ||||
| -rw-r--r-- | internal/llm/provider_more_test.go | 37 | ||||
| -rw-r--r-- | internal/llm/provider_test.go | 30 | ||||
| -rw-r--r-- | internal/llm/util_test.go | 7 |
10 files changed, 876 insertions, 666 deletions
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go index 16eeda6..d3b1a9d 100644 --- a/internal/llm/copilot.go +++ b/internal/llm/copilot.go @@ -4,6 +4,7 @@ package llm import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -13,7 +14,6 @@ import ( "strings" "time" - "encoding/base64" appver "codeberg.org/snonux/hexai/internal" "codeberg.org/snonux/hexai/internal/logging" ) @@ -162,10 +162,14 @@ func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64 } func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) - if err != nil { return nil, err } - for k, v := range headers { req.Header.Set(k, v) } - return c.httpClient.Do(req) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + return c.httpClient.Do(req) } func handleCopilotNon2xx(resp *http.Response, start time.Time) error { @@ -194,55 +198,73 @@ func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatRespons // --- Copilot session token management --- type ghCopilotTokenResp struct { - Token string `json:"token"` + Token string `json:"token"` } func (c *copilotClient) ensureSession(ctx context.Context) error { - // If token valid for >60s, reuse - if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { - return nil - } - if strings.TrimSpace(c.apiKey) == "" { - return errors.New("missing Copilot API key") - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) - if err != nil { return err } - req.Header.Set("Authorization", "Bearer "+c.apiKey) - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", "hexai/"+appver.Version) - resp, err := c.httpClient.Do(req) - if err != nil { return err } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("copilot token http error: %d", resp.StatusCode) - } - var out ghCopilotTokenResp - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err } - if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") } - // Parse JWT exp - exp := parseJWTExp(out.Token) - if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) } - c.sessionToken = out.Token - c.tokenExpiry = exp - return nil + // If token valid for >60s, reuse + if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { + return nil + } + if strings.TrimSpace(c.apiKey) == "" { + return errors.New("missing Copilot API key") + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "hexai/"+appver.Version) + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("copilot token http error: %d", resp.StatusCode) + } + var out ghCopilotTokenResp + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return err + } + if strings.TrimSpace(out.Token) == "" { + return errors.New("empty copilot session token") + } + // Parse JWT exp + exp := parseJWTExp(out.Token) + if exp.IsZero() { + exp = time.Now().Add(10 * time.Minute) + } + c.sessionToken = out.Token + c.tokenExpiry = exp + return nil } var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode func parseJWTExp(token string) time.Time { - parts := strings.Split(token, ".") - if len(parts) < 2 { return time.Time{} } - b, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 { - if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) } - } - return time.Time{} - } - var payload struct{ Exp int64 `json:"exp"` } - _ = json.Unmarshal(b, &payload) - if payload.Exp == 0 { return time.Time{} } - return time.Unix(payload.Exp, 0) + parts := strings.Split(token, ".") + if len(parts) < 2 { + return time.Time{} + } + b, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 { + if n, err2 := parseInt64(m[1]); err2 == nil { + return time.Unix(n, 0) + } + } + return time.Time{} + } + var payload struct { + Exp int64 `json:"exp"` + } + _ = json.Unmarshal(b, &payload) + if payload.Exp == 0 { + return time.Time{} + } + return time.Unix(payload.Exp, 0) } func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err } @@ -250,99 +272,120 @@ func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, & // --- Copilot headers --- func (c *copilotClient) headersChat() map[string]string { - _ = c.ensureSession(context.Background()) - h := map[string]string{ - "Content-Type": "application/json; charset=utf-8", - "Accept": "application/json", - "Authorization": "Bearer " + c.sessionToken, - "User-Agent": "GitHubCopilotChat/0.8.0", - "Editor-Plugin-Version": "copilot-chat/0.8.0", - "Editor-Version": "vscode/1.85.1", - "Openai-Intent": "conversation-panel", - "Openai-Organization": "github-copilot", - "VScode-MachineId": randHex(64), - "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - } - return h + _ = c.ensureSession(context.Background()) + h := map[string]string{ + "Content-Type": "application/json; charset=utf-8", + "Accept": "application/json", + "Authorization": "Bearer " + c.sessionToken, + "User-Agent": "GitHubCopilotChat/0.8.0", + "Editor-Plugin-Version": "copilot-chat/0.8.0", + "Editor-Version": "vscode/1.85.1", + "Openai-Intent": "conversation-panel", + "Openai-Organization": "github-copilot", + "VScode-MachineId": randHex(64), + "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + } + return h } func (c *copilotClient) headersGhost() map[string]string { - _ = c.ensureSession(context.Background()) - h := map[string]string{ - "Content-Type": "application/json; charset=utf-8", - "Accept": "*/*", - "Authorization": "Bearer " + c.sessionToken, - "User-Agent": "GithubCopilot/1.155.0", - "Editor-Plugin-Version": "copilot/1.155.0", - "Editor-Version": "vscode/1.85.1", - "Openai-Intent": "copilot-ghost", - "Openai-Organization": "github-copilot", - "VScode-MachineId": randHex(64), - "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), - } - return h + _ = c.ensureSession(context.Background()) + h := map[string]string{ + "Content-Type": "application/json; charset=utf-8", + "Accept": "*/*", + "Authorization": "Bearer " + c.sessionToken, + "User-Agent": "GithubCopilot/1.155.0", + "Editor-Plugin-Version": "copilot/1.155.0", + "Editor-Version": "vscode/1.85.1", + "Openai-Intent": "copilot-ghost", + "Openai-Organization": "github-copilot", + "VScode-MachineId": randHex(64), + "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12), + } + return h } func randHex(n int) string { - const hex = "0123456789abcdef" - b := make([]byte, n) - for i := range b { - b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)] - } - return string(b) + const hex = "0123456789abcdef" + b := make([]byte, n) + for i := range b { + b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)] + } + return string(b) } // --- Codex-style code completion --- // CodeCompletion implements CodeCompleter; returns up to n suggestions. func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) { - if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") } - if err := c.ensureSession(ctx); err != nil { return nil, err } - if n <= 0 { n = 1 } - maxTokens := 500 - body := map[string]any{ - "extra": map[string]any{ - "language": language, - "next_indent": 0, - "prompt_tokens": 500, - "suffix_tokens": 400, - "trim_by_indentation": true, - }, - "max_tokens": maxTokens, - "n": n, - "nwo": "hexai", - "prompt": prompt, - "stop": []string{"\n\n"}, - "stream": true, - "suffix": suffix, - "temperature": temperature, - "top_p": 1, - } - buf, _ := json.Marshal(body) - url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions" - resp, err := c.postJSON(ctx, url, buf, c.headersGhost()) - if err != nil { return nil, err } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode) - } - // Read all and parse lines that start with "data: " accumulating by index - raw, _ := io.ReadAll(resp.Body) - byIndex := make(map[int]string) - lines := strings.Split(string(raw), "\n") - for _, ln := range lines { - if !strings.HasPrefix(ln, "data: ") { continue } - var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` } - if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue } - for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text } - } - out := make([]string, 0, len(byIndex)) - for i := 0; i < n; i++ { - if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) } - } - return out, nil + if strings.TrimSpace(c.apiKey) == "" { + return nil, errors.New("missing Copilot API key") + } + if err := c.ensureSession(ctx); err != nil { + return nil, err + } + if n <= 0 { + n = 1 + } + maxTokens := 500 + body := map[string]any{ + "extra": map[string]any{ + "language": language, + "next_indent": 0, + "prompt_tokens": 500, + "suffix_tokens": 400, + "trim_by_indentation": true, + }, + "max_tokens": maxTokens, + "n": n, + "nwo": "hexai", + "prompt": prompt, + "stop": []string{"\n\n"}, + "stream": true, + "suffix": suffix, + "temperature": temperature, + "top_p": 1, + } + buf, _ := json.Marshal(body) + url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions" + resp, err := c.postJSON(ctx, url, buf, c.headersGhost()) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode) + } + // Read all and parse lines that start with "data: " accumulating by index + raw, _ := io.ReadAll(resp.Body) + byIndex := make(map[int]string) + lines := strings.Split(string(raw), "\n") + for _, ln := range lines { + if !strings.HasPrefix(ln, "data: ") { + continue + } + var evt struct { + Choices []struct { + Index int `json:"index"` + Text string `json:"text"` + } `json:"choices"` + } + if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { + continue + } + for _, ch := range evt.Choices { + byIndex[ch.Index] += ch.Text + } + } + out := make([]string, 0, len(byIndex)) + for i := 0; i < n; i++ { + if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { + out = append(out, s) + } + } + return out, nil } // newLineDataReader wraps a streaming body and exposes a JSON decoder that diff --git a/internal/llm/copilot_http_test.go b/internal/llm/copilot_http_test.go index 180e43e..d66311c 100644 --- a/internal/llm/copilot_http_test.go +++ b/internal/llm/copilot_http_test.go @@ -1,205 +1,261 @@ package llm import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - "encoding/base64" - "os" + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" ) type rtFunc2 func(*http.Request) (*http.Response, error) + func (f rtFunc2) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } func TestCopilot_EnsureSession_AndChat_Success(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // Mock chat endpoint - chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/chat/completions" { t.Fatalf("unexpected path: %s", r.URL.Path) } - _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index":0, "message": map[string]string{"role":"assistant","content":"OK"}}}}) - })) - defer chatSrv.Close() - c := newCopilot(chatSrv.URL, "gpt-4o-mini", "APIKEY", f64p(0.1)).(copilotClient) - // Intercept token endpoint to return a session token - tr := rtFunc2(func(r *http.Request) (*http.Response, error) { - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder() - _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}) - res := rw.Result() - res.StatusCode = 200 - return res, nil - } - // Fallback to default transport for chatSrv - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - out, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}) - if err != nil || out != "OK" { t.Fatalf("copilot chat failed: %v %q", err, out) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // Mock chat endpoint + chatSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index": 0, "message": map[string]string{"role": "assistant", "content": "OK"}}}}) + })) + defer chatSrv.Close() + c := newCopilot(chatSrv.URL, "gpt-4o-mini", "APIKEY", f64p(0.1)).(copilotClient) + // Intercept token endpoint to return a session token + tr := rtFunc2(func(r *http.Request) (*http.Response, error) { + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + // Fallback to default transport for chatSrv + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) + if err != nil || out != "OK" { + t.Fatalf("copilot chat failed: %v %q", err, out) + } } func TestCopilot_HandleNon2xx(t *testing.T) { - b, _ := json.Marshal(map[string]any{"error": map[string]any{"message":"bad","type":"invalid"}}) - resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytesReader(b))} - if err := handleCopilotNon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error") } + b, _ := json.Marshal(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}}) + resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytesReader(b))} + if err := handleCopilotNon2xx(resp, time.Now()); err == nil { + t.Fatalf("expected error") + } } func TestCopilot_CodeCompletion_Success(t *testing.T) { - c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient) - tr := rtFunc2(func(r *http.Request) (*http.Response, error) { - // Token endpoint - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder() - _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}) - res := rw.Result(); res.StatusCode = 200; return res, nil - } - // Codex completion endpoint - if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") { - rw := httptest.NewRecorder() - // two choices for index 0 and 1 - rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"A\"}]}\n") - rw.WriteString("data: {\"choices\":[{\"index\":1,\"text\":\"B\"}]}\n") - res := rw.Result(); res.StatusCode = 200; return res, nil - } - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - out, err := c.CodeCompletion(context.Background(), "p", "s", 2, "go", 0.1) - if err != nil || len(out) != 2 || out[0] != "A" || out[1] != "B" { - t.Fatalf("codex: %v %#v", err, out) - } + c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient) + tr := rtFunc2(func(r *http.Request) (*http.Response, error) { + // Token endpoint + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + // Codex completion endpoint + if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") { + rw := httptest.NewRecorder() + // two choices for index 0 and 1 + rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"A\"}]}\n") + rw.WriteString("data: {\"choices\":[{\"index\":1,\"text\":\"B\"}]}\n") + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + out, err := c.CodeCompletion(context.Background(), "p", "s", 2, "go", 0.1) + if err != nil || len(out) != 2 || out[0] != "A" || out[1] != "B" { + t.Fatalf("codex: %v %#v", err, out) + } } func TestCopilot_Chat_MultiChoice_And_ErrorBody(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // Chat multi-choice: return two choices; client returns first content - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}}, - {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}}, - }, - }) - })) - defer srv.Close() - c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) - // Token success - tr := rtFunc2(func(r *http.Request) (*http.Response, error) { - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil - } - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) - if err != nil || out != "FIRST" { t.Fatalf("copilot multi-choice: %v %q", err, out) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // Chat multi-choice: return two choices; client returns first content + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}}, + {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}}, + }, + }) + })) + defer srv.Close() + c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) + // Token success + tr := rtFunc2(func(r *http.Request) (*http.Response, error) { + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) + if err != nil || out != "FIRST" { + t.Fatalf("copilot multi-choice: %v %q", err, out) + } - // Non-2xx with error body - srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(403) - _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message":"denied","type":"forbidden"}}) - })) - defer srv2.Close() - c2 := newCopilot(srv2.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) - c2.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - if _, err := c2.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { - t.Fatalf("expected error for copilot non-2xx with error body") - } + // Non-2xx with error body + srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(403) + _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "denied", "type": "forbidden"}}) + })) + defer srv2.Close() + c2 := newCopilot(srv2.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) + c2.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected error for copilot non-2xx with error body") + } } func TestCopilot_Chat_NoChoices_Error(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}}) - })) - defer srv.Close() - c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) - tr := rtFunc2(func(r *http.Request) (*http.Response, error) { - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil - } - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { - t.Fatalf("expected error when no choices returned") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}}) + })) + defer srv.Close() + c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) + tr := rtFunc2(func(r *http.Request) (*http.Response, error) { + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected error when no choices returned") + } } func TestCopilot_Chat_DecodeError_StatusOK(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // Chat returns 200 but invalid JSON; expect decode error - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - io.WriteString(w, "{invalid") - })) - defer srv.Close() - c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) - tr := rtFunc2(func(r *http.Request) (*http.Response, error) { - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil - } - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { - t.Fatalf("expected decode error for invalid body") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // Chat returns 200 but invalid JSON; expect decode error + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "{invalid") + })) + defer srv.Close() + c := newCopilot(srv.URL, "gpt-4o-mini", "KEY", f64p(0.1)).(copilotClient) + tr := rtFunc2(func(r *http.Request) (*http.Response, error) { + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected decode error for invalid body") + } } func TestCopilot_CodeCompletion_MalformedAndEmpty(t *testing.T) { - c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient) - tr := rtFunc2(func(r *http.Request) (*http.Response, error) { - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil - } - if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") { - rw := httptest.NewRecorder() - // malformed line - rw.WriteString("data: {bad}\n") - // done; should produce empty suggestions - rw.WriteString("data: [DONE]\n") - res := rw.Result(); res.StatusCode = 200; return res, nil - } - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} - out, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1) - if err != nil { t.Fatalf("unexpected error: %v", err) } - if len(out) != 0 { t.Fatalf("expected empty suggestions, got %#v", out) } + c := newCopilot("https://api.githubcopilot.com", "gpt-4o-mini", "API", f64p(0.1)).(copilotClient) + tr := rtFunc2(func(r *http.Request) (*http.Response, error) { + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") { + rw := httptest.NewRecorder() + // malformed line + rw.WriteString("data: {bad}\n") + // done; should produce empty suggestions + rw.WriteString("data: [DONE]\n") + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr, Timeout: 5 * time.Second} + out, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(out) != 0 { + t.Fatalf("expected empty suggestions, got %#v", out) + } - // Now include one good chunk after malformed - tr2 := rtFunc2(func(r *http.Request) (*http.Response, error) { - if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { - rw := httptest.NewRecorder(); _ = json.NewEncoder(rw).Encode(map[string]string{"token":"tok"}); res := rw.Result(); res.StatusCode = 200; return res, nil - } - if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") { - rw := httptest.NewRecorder() - rw.WriteString("data: {bad}\n") - rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"OK\"}]}\n") - rw.WriteString("data: [DONE]\n") - res := rw.Result(); res.StatusCode = 200; return res, nil - } - return http.DefaultTransport.RoundTrip(r) - }) - c.httpClient = &http.Client{Transport: tr2, Timeout: 5 * time.Second} - out2, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1) - if err != nil || len(out2) != 1 || out2[0] != "OK" { t.Fatalf("unexpected: %v %#v", err, out2) } + // Now include one good chunk after malformed + tr2 := rtFunc2(func(r *http.Request) (*http.Response, error) { + if r.URL.Host == "api.github.com" && r.URL.Path == "/copilot_internal/v2/token" { + rw := httptest.NewRecorder() + _ = json.NewEncoder(rw).Encode(map[string]string{"token": "tok"}) + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + if r.URL.Host == "copilot-proxy.githubusercontent.com" && strings.HasSuffix(r.URL.Path, "/v1/engines/copilot-codex/completions") { + rw := httptest.NewRecorder() + rw.WriteString("data: {bad}\n") + rw.WriteString("data: {\"choices\":[{\"index\":0,\"text\":\"OK\"}]}\n") + rw.WriteString("data: [DONE]\n") + res := rw.Result() + res.StatusCode = 200 + return res, nil + } + return http.DefaultTransport.RoundTrip(r) + }) + c.httpClient = &http.Client{Transport: tr2, Timeout: 5 * time.Second} + out2, err := c.CodeCompletion(context.Background(), "p", "s", 1, "go", 0.1) + if err != nil || len(out2) != 1 || out2[0] != "OK" { + t.Fatalf("unexpected: %v %#v", err, out2) + } } func TestParseJWTExp_AndParseInt64(t *testing.T) { - // Valid base64 payload - payload := `{"exp": 1700000000}` - b := base64.RawURLEncoding.EncodeToString([]byte(payload)) - tok := "x." + b + ".y" - if tm := parseJWTExp(tok); tm.IsZero() { t.Fatalf("expected non-zero time") } - if n, err := parseInt64("123"); err != nil || n != 123 { t.Fatalf("parseInt64: %v %d", err, n) } + // Valid base64 payload + payload := `{"exp": 1700000000}` + b := base64.RawURLEncoding.EncodeToString([]byte(payload)) + tok := "x." + b + ".y" + if tm := parseJWTExp(tok); tm.IsZero() { + t.Fatalf("expected non-zero time") + } + if n, err := parseInt64("123"); err != nil || n != 123 { + t.Fatalf("parseInt64: %v %d", err, n) + } } // bytesReader wraps a byte slice with an io.ReadCloser without importing extra. type bytesReader []byte + func (b bytesReader) Read(p []byte) (int, error) { n := copy(p, b); return n, io.EOF } -func (b bytesReader) Close() error { return nil } +func (b bytesReader) Close() error { return nil } diff --git a/internal/llm/ollama_test.go b/internal/llm/ollama_test.go index 15f9cff..8bd33ca 100644 --- a/internal/llm/ollama_test.go +++ b/internal/llm/ollama_test.go @@ -1,173 +1,217 @@ package llm import ( - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - "os" + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" ) func TestBuildOllamaRequest_OptionsAndStream(t *testing.T) { - o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}} - msgs := []Message{{Role: "user", Content: "hello"}} - req := buildOllamaRequest(o, msgs, f64p(0.2), false) - if req.Model != "codemodel" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) } - if req.Options == nil { t.Fatalf("expected options map") } - if req.Options.(map[string]any)["temperature"].(float64) != 0.2 { t.Fatalf("default temp not applied") } - if req.Options.(map[string]any)["num_predict"].(int) != 256 { t.Fatalf("num_predict not applied") } - if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" { t.Fatalf("stop not applied") } - - req2 := buildOllamaRequest(o, msgs, f64p(0.2), true) - if !req2.Stream { t.Fatalf("expected stream=true") } + o := Options{Model: "codemodel", Temperature: 0, MaxTokens: 256, Stop: []string{"STOP"}} + msgs := []Message{{Role: "user", Content: "hello"}} + req := buildOllamaRequest(o, msgs, f64p(0.2), false) + if req.Model != "codemodel" || req.Stream { + t.Fatalf("model/stream mismatch: %+v", req) + } + if req.Options == nil { + t.Fatalf("expected options map") + } + if req.Options.(map[string]any)["temperature"].(float64) != 0.2 { + t.Fatalf("default temp not applied") + } + if req.Options.(map[string]any)["num_predict"].(int) != 256 { + t.Fatalf("num_predict not applied") + } + if req.Options.(map[string]any)["stop"].([]string)[0] != "STOP" { + t.Fatalf("stop not applied") + } + + req2 := buildOllamaRequest(o, msgs, f64p(0.2), true) + if !req2.Stream { + t.Fatalf("expected stream=true") + } } func TestBuildOllamaRequest_TempOverride(t *testing.T) { - o := Options{Model: "m", Temperature: 0.9} - msgs := []Message{{Role: "user", Content: "hi"}} - req := buildOllamaRequest(o, msgs, f64p(0.2), false) - m := req.Options.(map[string]any) - if m["temperature"].(float64) != 0.9 { t.Fatalf("explicit temp should override default") } + o := Options{Model: "m", Temperature: 0.9} + msgs := []Message{{Role: "user", Content: "hi"}} + req := buildOllamaRequest(o, msgs, f64p(0.2), false) + m := req.Options.(map[string]any) + if m["temperature"].(float64) != 0.9 { + t.Fatalf("explicit temp should override default") + } } func TestOllama_NameAndModel(t *testing.T) { - c := newOllama("http://x", "model-x", nil).(ollamaClient) - if c.Name() != "ollama" { t.Fatalf("name: %q", c.Name()) } - if c.DefaultModel() != "model-x" { t.Fatalf("default model: %q", c.DefaultModel()) } + c := newOllama("http://x", "model-x", nil).(ollamaClient) + if c.Name() != "ollama" { + t.Fatalf("name: %q", c.Name()) + } + if c.DefaultModel() != "model-x" { + t.Fatalf("default model: %q", c.DefaultModel()) + } } func TestOllamaChat_Success(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost || r.URL.Path != "/api/chat" { t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role":"assistant","content":"Hello"}, "done": true}) - })) - defer ts.Close() - c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient) - c.httpClient = ts.Client() - out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) - if err != nil { t.Fatalf("unexpected err: %v", err) } - if out != "Hello" { t.Fatalf("got %q", out) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || r.URL.Path != "/api/chat" { + t.Fatalf("unexpected request: %s %s", r.Method, r.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": "Hello"}, "done": true}) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", f64p(0.1)).(ollamaClient) + c.httpClient = ts.Client() + out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if out != "Hello" { + t.Fatalf("got %q", out) + } } func TestOllamaChat_EmptyContent(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role":"assistant","content":""}, "done": true}) - })) - defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) - c.httpClient = ts.Client() - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil { - t.Fatalf("expected error for empty content") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"message": map[string]string{"role": "assistant", "content": ""}, "done": true}) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil).(ollamaClient) + c.httpClient = ts.Client() + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { + t.Fatalf("expected error for empty content") + } } func TestOllamaChat_Non2xx(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // API error string - ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(400) - _ = json.NewEncoder(w).Encode(map[string]any{"error":"bad"}) - })) - defer ts1.Close() - c1 := newOllama(ts1.URL, "m", nil).(ollamaClient) - c1.httpClient = ts1.Client() - if _, err := c1.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil { - t.Fatalf("expected error for 400 with api body") - } - // Plain HTTP error without api message - ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - _, _ = w.Write([]byte("{}")) - })) - defer ts2.Close() - c2 := newOllama(ts2.URL, "m", nil).(ollamaClient) - c2.httpClient = ts2.Client() - if _, err := c2.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil { - t.Fatalf("expected error for 500") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // API error string + ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(400) + _ = json.NewEncoder(w).Encode(map[string]any{"error": "bad"}) + })) + defer ts1.Close() + c1 := newOllama(ts1.URL, "m", nil).(ollamaClient) + c1.httpClient = ts1.Client() + if _, err := c1.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { + t.Fatalf("expected error for 400 with api body") + } + // Plain HTTP error without api message + ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + _, _ = w.Write([]byte("{}")) + })) + defer ts2.Close() + c2 := newOllama(ts2.URL, "m", nil).(ollamaClient) + c2.httpClient = ts2.Client() + if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { + t.Fatalf("expected error for 500") + } } type rtFunc func(*http.Request) (*http.Response, error) + func (f rtFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } func TestOllamaChat_HTTPError(t *testing.T) { - c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient) - c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request)(*http.Response,error){ return nil, fmt.Errorf("boom") })} - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil { - t.Fatalf("expected http error path") - } + c := newOllama("http://127.0.0.1:0", "m", nil).(ollamaClient) + c.httpClient = &http.Client{Transport: rtFunc(func(*http.Request) (*http.Response, error) { return nil, fmt.Errorf("boom") })} + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { + t.Fatalf("expected http error path") + } } func TestOllamaChat_DecodeError(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte("{bad json}")) - })) - defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) - c.httpClient = ts.Client() - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"x"}}); err == nil { - t.Fatalf("expected decode error") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("{bad json}")) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil).(ollamaClient) + c.httpClient = ts.Client() + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "x"}}); err == nil { + t.Fatalf("expected decode error") + } } func TestHandleOllamaNon2xx_OK(t *testing.T) { - resp := &http.Response{StatusCode: 200, Body: ioNopCloser(strings.NewReader(""))} - if err := handleOllamaNon2xx(resp, time.Now()); err != nil { t.Fatalf("unexpected: %v", err) } + resp := &http.Response{StatusCode: 200, Body: ioNopCloser(strings.NewReader(""))} + if err := handleOllamaNon2xx(resp, time.Now()); err != nil { + t.Fatalf("unexpected: %v", err) + } } func TestOllamaChatStream_Success(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - // two JSON objects back-to-back - _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"Hi"},"done":false}`)) - _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`)) - })) - defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) - c.httpClient = ts.Client() - var got strings.Builder - if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(s string){ got.WriteString(s) }); err != nil { - t.Fatalf("unexpected: %v", err) - } - if got.String() != "Hi!" { t.Fatalf("got %q", got.String()) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + // two JSON objects back-to-back + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"Hi"},"done":false}`)) + _, _ = w.Write([]byte(`{"message":{"role":"assistant","content":"!"},"done":true}`)) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil).(ollamaClient) + c.httpClient = ts.Client() + var got strings.Builder + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(s string) { got.WriteString(s) }); err != nil { + t.Fatalf("unexpected: %v", err) + } + if got.String() != "Hi!" { + t.Fatalf("got %q", got.String()) + } } func TestOllamaChatStream_ErrorEvent(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{"error":"oops"}) - })) - defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) - c.httpClient = ts.Client() - if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(string){}); err == nil { - t.Fatalf("expected stream error") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"error": "oops"}) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil).(ollamaClient) + c.httpClient = ts.Client() + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil { + t.Fatalf("expected stream error") + } } func TestOllamaChatStream_DecodeError(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, _ = w.Write([]byte("{not json}")) - })) - defer ts.Close() - c := newOllama(ts.URL, "m", nil).(ollamaClient) - c.httpClient = ts.Client() - if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"x"}}, func(string){}); err == nil { - t.Fatalf("expected decode error") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("{not json}")) + })) + defer ts.Close() + c := newOllama(ts.URL, "m", nil).(ollamaClient) + c.httpClient = ts.Client() + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "x"}}, func(string) {}); err == nil { + t.Fatalf("expected decode error") + } } // small helper to construct an io.ReadCloser without importing extra packages type readCloser struct{ *strings.Reader } -func (readCloser) Close() error { return nil } + +func (readCloser) Close() error { return nil } func ioNopCloser(r *strings.Reader) *readCloser { return &readCloser{r} } diff --git a/internal/llm/openai_http_test.go b/internal/llm/openai_http_test.go index ac7b897..cb4bfcb 100644 --- a/internal/llm/openai_http_test.go +++ b/internal/llm/openai_http_test.go @@ -1,143 +1,171 @@ package llm import ( - "context" - "encoding/json" - "io" - "net/http" - "net/http/httptest" - "testing" - "strings" - "time" - "os" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" ) func TestOpenAI_Chat_Success(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/chat/completions" { t.Fatalf("unexpected path: %s", r.URL.Path) } - _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index":0, "message": map[string]string{"role":"assistant","content":"OK"}}}}) - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - out, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}) - if err != nil || out != "OK" { t.Fatalf("openai chat: %v %q", err, out) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{"choices": []map[string]any{{"index": 0, "message": map[string]string{"role": "assistant", "content": "OK"}}}}) + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) + if err != nil || out != "OK" { + t.Fatalf("openai chat: %v %q", err, out) + } } func TestOpenAI_Chat_MissingKey(t *testing.T) { - c := newOpenAI("http://x", "g", "", f64p(0.2)).(openAIClient) - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { t.Fatalf("expected error for missing key") } + c := newOpenAI("http://x", "g", "", f64p(0.2)).(openAIClient) + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected error for missing key") + } } func TestOpenAI_ChatStream_SSE(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return SSE-like stream - w.Header().Set("Content-Type", "text/event-stream") - io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n") - io.WriteString(w, "data: [DONE]\n") - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - var got string - err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }) - if err != nil || got != "Hi" { t.Fatalf("chat stream: %v %q", err, got) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return SSE-like stream + w.Header().Set("Content-Type", "text/event-stream") + io.WriteString(w, "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n") + io.WriteString(w, "data: [DONE]\n") + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + var got string + err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }) + if err != nil || got != "Hi" { + t.Fatalf("chat stream: %v %q", err, got) + } } func TestHandleOpenAINon2xx_NoErrorBody(t *testing.T) { - resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))} - if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected http error") } + resp := &http.Response{StatusCode: 500, Body: io.NopCloser(strings.NewReader("{}"))} + if err := handleOpenAINon2xx(resp, time.Now()); err == nil { + t.Fatalf("expected http error") + } } func TestOpenAI_ChatStream_SSE_ErrorChunk(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - io.WriteString(w, "data: {\"error\":{\"message\":\"oops\"}}\n\n") - io.WriteString(w, "data: [DONE]\n") - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - var got string - if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }); err == nil { - t.Fatalf("expected error due to error chunk") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + io.WriteString(w, "data: {\"error\":{\"message\":\"oops\"}}\n\n") + io.WriteString(w, "data: [DONE]\n") + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + var got string + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err == nil { + t.Fatalf("expected error due to error chunk") + } } func TestOpenAI_Chat_NoChoices_Error(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}}) - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - if _, err := c.Chat(context.Background(), []Message{{Role:"user", Content:"hi"}}); err == nil { - t.Fatalf("expected error when choices empty") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{"choices": []any{}}) + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected error when choices empty") + } } func TestOpenAI_ChatStream_SSE_EmptyDelta_NoError(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - io.WriteString(w, "data: {\\\"choices\\\":[{\\\"delta\\\":{\\\"content\\\":\\\"\\\"}}]}\\n\\n") - io.WriteString(w, "data: [DONE]\\n") - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - var got string - if err := c.ChatStream(context.Background(), []Message{{Role:"user", Content:"hi"}}, func(s string){ got += s }); err != nil { - t.Fatalf("unexpected error for empty delta: %v", err) - } - if got != "" { t.Fatalf("expected no output for empty delta, got %q", got) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + io.WriteString(w, "data: {\\\"choices\\\":[{\\\"delta\\\":{\\\"content\\\":\\\"\\\"}}]}\\n\\n") + io.WriteString(w, "data: [DONE]\\n") + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + var got string + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err != nil { + t.Fatalf("unexpected error for empty delta: %v", err) + } + if got != "" { + t.Fatalf("expected no output for empty delta, got %q", got) + } } func TestOpenAI_Chat_DecodeError_StatusOK(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // Return status 200 but invalid JSON body; Chat should return an error - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - io.WriteString(w, "{invalid") - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { - t.Fatalf("expected decode error for invalid JSON body") - } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // Return status 200 but invalid JSON body; Chat should return an error + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + io.WriteString(w, "{invalid") + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + if _, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected decode error for invalid JSON body") + } } func TestOpenAI_Chat_MultiChoiceAndErrorBody(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // Multi-choice success: return two choices with different finish reasons - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _ = json.NewEncoder(w).Encode(map[string]any{ - "choices": []map[string]any{ - {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}}, - {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}}, - }, - }) - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) - if err != nil || out != "FIRST" { t.Fatalf("openai multi-choice: %v %q", err, out) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // Multi-choice success: return two choices with different finish reasons + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "choices": []map[string]any{ + {"index": 0, "finish_reason": "stop", "message": map[string]string{"role": "assistant", "content": "FIRST"}}, + {"index": 1, "finish_reason": "length", "message": map[string]string{"role": "assistant", "content": "SECOND"}}, + }, + }) + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + out, err := c.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}) + if err != nil || out != "FIRST" { + t.Fatalf("openai multi-choice: %v %q", err, out) + } - // Error body case: non-2xx with error message - srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(400) - _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}}) - })) - defer srv2.Close() - c2 := newOpenAI(srv2.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c2.httpClient = srv2.Client() - if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { - t.Fatalf("expected error from non-2xx with error body") - } + // Error body case: non-2xx with error message + srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(400) + _ = json.NewEncoder(w).Encode(map[string]any{"error": map[string]any{"message": "bad", "type": "invalid"}}) + })) + defer srv2.Close() + c2 := newOpenAI(srv2.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c2.httpClient = srv2.Client() + if _, err := c2.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}); err == nil { + t.Fatalf("expected error from non-2xx with error body") + } } diff --git a/internal/llm/openai_sse_negative_test.go b/internal/llm/openai_sse_negative_test.go index 8da5526..de2ff71 100644 --- a/internal/llm/openai_sse_negative_test.go +++ b/internal/llm/openai_sse_negative_test.go @@ -1,28 +1,32 @@ package llm import ( - "context" - "io" - "net/http" - "net/http/httptest" - "testing" - "os" + "context" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" ) func TestOpenAI_ChatStream_SSE_MalformedChunk(t *testing.T) { - if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { t.Skip("skip network-bound tests in restricted environments") } - // Malformed JSON chunk should be skipped; no onDelta calls; no error. - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - io.WriteString(w, "data: {not json}\n\n") - io.WriteString(w, "data: [DONE]\n") - })) - defer srv.Close() - c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) - c.httpClient = srv.Client() - var got string - if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string){ got += s }); err != nil { - t.Fatalf("unexpected error for malformed chunk: %v", err) - } - if got != "" { t.Fatalf("expected no deltas for malformed chunk, got %q", got) } + if os.Getenv("HEXAI_TEST_SKIP_NET") == "1" { + t.Skip("skip network-bound tests in restricted environments") + } + // Malformed JSON chunk should be skipped; no onDelta calls; no error. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + io.WriteString(w, "data: {not json}\n\n") + io.WriteString(w, "data: [DONE]\n") + })) + defer srv.Close() + c := newOpenAI(srv.URL, "g", "KEY", f64p(0.2)).(openAIClient) + c.httpClient = srv.Client() + var got string + if err := c.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, func(s string) { got += s }); err != nil { + t.Fatalf("unexpected error for malformed chunk: %v", err) + } + if got != "" { + t.Fatalf("expected no deltas for malformed chunk, got %q", got) + } } diff --git a/internal/llm/openai_test.go b/internal/llm/openai_test.go index f50b171..f7ce080 100644 --- a/internal/llm/openai_test.go +++ b/internal/llm/openai_test.go @@ -1,44 +1,67 @@ package llm import ( - "bytes" - "encoding/json" - "io" - "net/http" - "strings" - "testing" - "time" + "bytes" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + "time" ) func f64p(v float64) *float64 { return &v } func TestBuildOAChatRequest_TempFallbackAndFields(t *testing.T) { - o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}} - msgs := []Message{{Role: "user", Content: "hi"}} - req := buildOAChatRequest(o, msgs, f64p(0.3), false) - if req.Model != "m1" || req.Stream { t.Fatalf("model/stream mismatch: %+v", req) } - if req.Temperature == nil || *req.Temperature != 0.3 { t.Fatalf("expected default temp 0.3, got %#v", req.Temperature) } - if req.MaxTokens == nil || *req.MaxTokens != 42 { t.Fatalf("expected max tokens 42") } - if len(req.Stop) != 1 || req.Stop[0] != "END" { t.Fatalf("stop not propagated: %#v", req.Stop) } - if len(req.Messages) != 1 || req.Messages[0].Content != "hi" { t.Fatalf("messages not copied") } + o := Options{Model: "m1", Temperature: 0, MaxTokens: 42, Stop: []string{"END"}} + msgs := []Message{{Role: "user", Content: "hi"}} + req := buildOAChatRequest(o, msgs, f64p(0.3), false) + if req.Model != "m1" || req.Stream { + t.Fatalf("model/stream mismatch: %+v", req) + } + if req.Temperature == nil || *req.Temperature != 0.3 { + t.Fatalf("expected default temp 0.3, got %#v", req.Temperature) + } + if req.MaxTokens == nil || *req.MaxTokens != 42 { + t.Fatalf("expected max tokens 42") + } + if len(req.Stop) != 1 || req.Stop[0] != "END" { + t.Fatalf("stop not propagated: %#v", req.Stop) + } + if len(req.Messages) != 1 || req.Messages[0].Content != "hi" { + t.Fatalf("messages not copied") + } - // stream on - req2 := buildOAChatRequest(o, msgs, f64p(0.3), true) - if !req2.Stream { t.Fatalf("expected stream=true") } + // stream on + req2 := buildOAChatRequest(o, msgs, f64p(0.3), true) + if !req2.Stream { + t.Fatalf("expected stream=true") + } } func TestHandleOpenAINon2xx_WithAPIError(t *testing.T) { - api := oaChatResponse{Error: &struct{ Message string `json:"message"`; Type string `json:"type"`; Param any `json:"param"`; Code any `json:"code"` }{Message: "bad", Type: "invalid"}} - b, _ := json.Marshal(api) - resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))} - if err := handleOpenAINon2xx(resp, time.Now()); err == nil { t.Fatalf("expected error for non-2xx with body") } + api := oaChatResponse{Error: &struct { + Message string `json:"message"` + Type string `json:"type"` + Param any `json:"param"` + Code any `json:"code"` + }{Message: "bad", Type: "invalid"}} + b, _ := json.Marshal(api) + resp := &http.Response{StatusCode: 400, Body: io.NopCloser(bytes.NewReader(b))} + if err := handleOpenAINon2xx(resp, time.Now()); err == nil { + t.Fatalf("expected error for non-2xx with body") + } } func TestParseOpenAIStream_DeliversChunks(t *testing.T) { - stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" + - "data: [DONE]\n" - resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))} - var got strings.Builder - if err := parseOpenAIStream(resp, time.Now(), func(s string){ got.WriteString(s) }); err != nil { t.Fatalf("unexpected error: %v", err) } - if got.String() != "Hi" { t.Fatalf("got %q want %q", got.String(), "Hi") } + stream := "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n" + + "data: [DONE]\n" + resp := &http.Response{Body: io.NopCloser(strings.NewReader(stream))} + var got strings.Builder + if err := parseOpenAIStream(resp, time.Now(), func(s string) { got.WriteString(s) }); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got.String() != "Hi" { + t.Fatalf("got %q want %q", got.String(), "Hi") + } } diff --git a/internal/llm/provider.go b/internal/llm/provider.go index 7ab58c6..88c280c 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -28,20 +28,20 @@ type Client interface { // token-by-token streaming responses. Callers can type-assert to Streamer and // fall back to Client.Chat when not implemented. type Streamer interface { - // ChatStream sends chat messages and invokes onDelta with incremental text - // chunks as they are produced by the model. Implementations should call - // onDelta with empty strings sparingly (prefer only non-empty chunks). - ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error + // ChatStream sends chat messages and invokes onDelta with incremental text + // chunks as they are produced by the model. Implementations should call + // onDelta with empty strings sparingly (prefer only non-empty chunks). + ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error } // CodeCompleter is an optional interface for providers that support a // prompt/suffix code-completion API (e.g., Copilot Codex endpoint). Clients // can type-assert to this and prefer it over chat when available. type CodeCompleter interface { - // CodeCompletion requests up to n suggestions given a left-hand prompt and - // right-hand suffix around the cursor. Language is advisory and may be - // ignored. Temperature applies when provider supports it. - CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) + // CodeCompletion requests up to n suggestions given a left-hand prompt and + // right-hand suffix around the cursor. Language is advisory and may be + // ignored. Temperature applies when provider supports it. + CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) } // Options for a request. Providers may ignore unsupported fields. @@ -64,56 +64,56 @@ func WithStop(stop ...string) RequestOption { // Config defines provider configuration read from the Hexai config file. type Config struct { - Provider string - // OpenAI options - OpenAIBaseURL string - OpenAIModel string - OpenAITemperature *float64 - // Ollama options - OllamaBaseURL string - OllamaModel string - OllamaTemperature *float64 - // Copilot options - CopilotBaseURL string - CopilotModel string - CopilotTemperature *float64 + Provider string + // OpenAI options + OpenAIBaseURL string + OpenAIModel string + OpenAITemperature *float64 + // Ollama options + OllamaBaseURL string + OllamaModel string + OllamaTemperature *float64 + // Copilot options + CopilotBaseURL string + CopilotModel string + CopilotTemperature *float64 } // NewFromConfig creates an LLM client using only the supplied configuration. // The OpenAI API key is supplied separately and may be read from the environment // by the caller; other environment-based configuration is not used. func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) { - p := strings.ToLower(strings.TrimSpace(cfg.Provider)) - if p == "" { - p = "openai" - } - switch p { - case "openai": - if strings.TrimSpace(openAIAPIKey) == "" { - return nil, errors.New("missing OPENAI_API_KEY for provider openai") - } - // Set coding-friendly default temperature if none provided - if cfg.OpenAITemperature == nil { - t := 0.2 - cfg.OpenAITemperature = &t - } - return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil - case "ollama": - if cfg.OllamaTemperature == nil { - t := 0.2 - cfg.OllamaTemperature = &t - } - return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil - case "copilot": - if strings.TrimSpace(copilotAPIKey) == "" { - return nil, errors.New("missing COPILOT_API_KEY for provider copilot") - } - if cfg.CopilotTemperature == nil { - t := 0.2 - cfg.CopilotTemperature = &t - } - return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil - default: - return nil, errors.New("unknown LLM provider: " + p) - } + p := strings.ToLower(strings.TrimSpace(cfg.Provider)) + if p == "" { + p = "openai" + } + switch p { + case "openai": + if strings.TrimSpace(openAIAPIKey) == "" { + return nil, errors.New("missing OPENAI_API_KEY for provider openai") + } + // Set coding-friendly default temperature if none provided + if cfg.OpenAITemperature == nil { + t := 0.2 + cfg.OpenAITemperature = &t + } + return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil + case "ollama": + if cfg.OllamaTemperature == nil { + t := 0.2 + cfg.OllamaTemperature = &t + } + return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil + case "copilot": + if strings.TrimSpace(copilotAPIKey) == "" { + return nil, errors.New("missing COPILOT_API_KEY for provider copilot") + } + if cfg.CopilotTemperature == nil { + t := 0.2 + cfg.CopilotTemperature = &t + } + return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil + default: + return nil, errors.New("unknown LLM provider: " + p) + } } diff --git a/internal/llm/provider_more_test.go b/internal/llm/provider_more_test.go index bd08552..d7469af 100644 --- a/internal/llm/provider_more_test.go +++ b/internal/llm/provider_more_test.go @@ -3,24 +3,27 @@ package llm import "testing" func TestWithOptions_Apply(t *testing.T) { - o := Options{} - WithModel("m")(&o) - WithTemperature(0.7)(&o) - WithMaxTokens(123)(&o) - WithStop("END")(&o) - if o.Model != "m" || o.Temperature != 0.7 || o.MaxTokens != 123 || len(o.Stop) != 1 || o.Stop[0] != "END" { - t.Fatalf("options not applied correctly: %+v", o) - } + o := Options{} + WithModel("m")(&o) + WithTemperature(0.7)(&o) + WithMaxTokens(123)(&o) + WithStop("END")(&o) + if o.Model != "m" || o.Temperature != 0.7 || o.MaxTokens != 123 || len(o.Stop) != 1 || o.Stop[0] != "END" { + t.Fatalf("options not applied correctly: %+v", o) + } } func TestNewFromConfig_Success_OpenAI_And_Copilot(t *testing.T) { - // OpenAI success - oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"} - c, err := NewFromConfig(oc, "KEY", "") - if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" { t.Fatalf("openai new: %v %v", c, err) } - // Copilot success - cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"} - c2, err := NewFromConfig(cc, "", "KEY") - if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" { t.Fatalf("copilot new: %v %v", c2, err) } + // OpenAI success + oc := Config{Provider: "openai", OpenAIBaseURL: "http://x", OpenAIModel: "gpt"} + c, err := NewFromConfig(oc, "KEY", "") + if err != nil || c == nil || c.Name() != "openai" || c.DefaultModel() == "" { + t.Fatalf("openai new: %v %v", c, err) + } + // Copilot success + cc := Config{Provider: "copilot", CopilotBaseURL: "http://x", CopilotModel: "gpt-4o-mini"} + c2, err := NewFromConfig(cc, "", "KEY") + if err != nil || c2 == nil || c2.Name() != "copilot" || c2.DefaultModel() == "" { + t.Fatalf("copilot new: %v %v", c2, err) + } } - diff --git a/internal/llm/provider_test.go b/internal/llm/provider_test.go index 1412b3c..29e2514 100644 --- a/internal/llm/provider_test.go +++ b/internal/llm/provider_test.go @@ -1,21 +1,29 @@ package llm import ( - "context" - "testing" + "context" + "testing" ) func TestNewFromConfig_DefaultsAndErrors(t *testing.T) { - // Unknown provider - if _, err := NewFromConfig(Config{Provider:"bogus"}, "", ""); err == nil { t.Fatalf("expected error for unknown provider") } - // OpenAI missing key - if _, err := NewFromConfig(Config{Provider:"openai", OpenAIModel:"g"}, "", ""); err == nil { t.Fatalf("expected key error") } - // Copilot missing key - if _, err := NewFromConfig(Config{Provider:"copilot", CopilotModel:"m"}, "", ""); err == nil { t.Fatalf("expected key error") } + // Unknown provider + if _, err := NewFromConfig(Config{Provider: "bogus"}, "", ""); err == nil { + t.Fatalf("expected error for unknown provider") + } + // OpenAI missing key + if _, err := NewFromConfig(Config{Provider: "openai", OpenAIModel: "g"}, "", ""); err == nil { + t.Fatalf("expected key error") + } + // Copilot missing key + if _, err := NewFromConfig(Config{Provider: "copilot", CopilotModel: "m"}, "", ""); err == nil { + t.Fatalf("expected key error") + } } type fakeClientMin struct{} -func (fakeClientMin) Chat(context.Context, []Message, ...RequestOption) (string, error) { return "", nil } -func (fakeClientMin) Name() string { return "x" } -func (fakeClientMin) DefaultModel() string { return "m" } +func (fakeClientMin) Chat(context.Context, []Message, ...RequestOption) (string, error) { + return "", nil +} +func (fakeClientMin) Name() string { return "x" } +func (fakeClientMin) DefaultModel() string { return "m" } diff --git a/internal/llm/util_test.go b/internal/llm/util_test.go index acffe5a..137e149 100644 --- a/internal/llm/util_test.go +++ b/internal/llm/util_test.go @@ -3,7 +3,8 @@ package llm import "testing" func TestNilStringErr(t *testing.T) { - s, err := nilStringErr("boom") - if s != "" || err == nil { t.Fatalf("expected empty string and error") } + s, err := nilStringErr("boom") + if s != "" || err == nil { + t.Fatalf("expected empty string and error") + } } - |
