summaryrefslogtreecommitdiff
path: root/internal/llm
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/copilot.go194
-rw-r--r--internal/llm/provider.go18
2 files changed, 190 insertions, 22 deletions
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go
index 6ab3a0d..7b3574c 100644
--- a/internal/llm/copilot.go
+++ b/internal/llm/copilot.go
@@ -1,4 +1,4 @@
-// Summary: GitHub Copilot client implementation for chat completions using the Copilot API.
+// Summary: GitHub Copilot client for chat and Codex-style code completion.
package llm
import (
@@ -7,10 +7,13 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"net/http"
+ "regexp"
"strings"
"time"
+ "encoding/base64"
appver "hexai/internal"
"hexai/internal/logging"
)
@@ -23,6 +26,10 @@ type copilotClient struct {
defaultModel string
chatLogger logging.ChatLogger
defaultTemperature *float64
+
+ // cached Copilot session token retrieved from GitHub API using apiKey
+ sessionToken string
+ tokenExpiry time.Time
}
type copilotChatRequest struct {
@@ -79,6 +86,10 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req
if strings.TrimSpace(c.apiKey) == "" {
return nilStringErr("missing Copilot API key")
}
+ // Ensure we have a fresh session token
+ if err := c.ensureSession(ctx); err != nil {
+ return "", err
+ }
o := Options{Model: c.defaultModel}
for _, opt := range opts {
opt(&o)
@@ -102,9 +113,7 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req
endpoint := c.baseURL + "/chat/completions"
logging.Logf("llm/copilot ", "POST %s", endpoint)
- resp, err := c.doJSON(ctx, endpoint, body, map[string]string{
- "Authorization": "Bearer " + c.apiKey,
- })
+ resp, err := c.postJSON(ctx, endpoint, body, c.headersChat())
if err != nil {
logging.Logf("llm/copilot ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase)
return "", err
@@ -152,20 +161,11 @@ func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64
return req
}
-func (c copilotClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
- req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
- if err != nil {
- return nil, err
- }
- req.Header.Set("Content-Type", "application/json")
- // GitHub Copilot (GitHub Models) requires an API version header and a UA.
- req.Header.Set("Accept", "application/json")
- req.Header.Set("X-GitHub-Api-Version", "2023-07-07")
- req.Header.Set("User-Agent", "hexai/"+appver.Version)
- for k, v := range headers {
- req.Header.Set(k, v)
- }
- return c.httpClient.Do(req)
+func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+ if err != nil { return nil, err }
+ for k, v := range headers { req.Header.Set(k, v) }
+ return c.httpClient.Do(req)
}
func handleCopilotNon2xx(resp *http.Response, start time.Time) error {
@@ -190,3 +190,161 @@ func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatRespons
}
return out, nil
}
+
+// --- Copilot session token management ---
+
+type ghCopilotTokenResp struct {
+ Token string `json:"token"`
+}
+
+func (c *copilotClient) ensureSession(ctx context.Context) error {
+ // If token valid for >60s, reuse
+ if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) {
+ return nil
+ }
+ if strings.TrimSpace(c.apiKey) == "" {
+ return errors.New("missing Copilot API key")
+ }
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil)
+ if err != nil { return err }
+ req.Header.Set("Authorization", "Bearer "+c.apiKey)
+ req.Header.Set("Accept", "application/json")
+ req.Header.Set("User-Agent", "hexai/"+appver.Version)
+ resp, err := c.httpClient.Do(req)
+ if err != nil { return err }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return fmt.Errorf("copilot token http error: %d", resp.StatusCode)
+ }
+ var out ghCopilotTokenResp
+ if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err }
+ if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") }
+ // Parse JWT exp
+ exp := parseJWTExp(out.Token)
+ if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) }
+ c.sessionToken = out.Token
+ c.tokenExpiry = exp
+ return nil
+}
+
+var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode
+
+func parseJWTExp(token string) time.Time {
+ parts := strings.Split(token, ".")
+ if len(parts) < 2 { return time.Time{} }
+ b, err := base64.RawURLEncoding.DecodeString(parts[1])
+ if err != nil {
+ if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 {
+ if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) }
+ }
+ return time.Time{}
+ }
+ var payload struct{ Exp int64 `json:"exp"` }
+ _ = json.Unmarshal(b, &payload)
+ if payload.Exp == 0 { return time.Time{} }
+ return time.Unix(payload.Exp, 0)
+}
+
+func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err }
+
+// --- Copilot headers ---
+
+func (c *copilotClient) headersChat() map[string]string {
+ _ = c.ensureSession(context.Background())
+ h := map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "Accept": "application/json",
+ "Authorization": "Bearer " + c.sessionToken,
+ "User-Agent": "GitHubCopilotChat/0.8.0",
+ "Editor-Plugin-Version": "copilot-chat/0.8.0",
+ "Editor-Version": "vscode/1.85.1",
+ "Openai-Intent": "conversation-panel",
+ "Openai-Organization": "github-copilot",
+ "VScode-MachineId": randHex(64),
+ "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ }
+ return h
+}
+
+func (c *copilotClient) headersGhost() map[string]string {
+ _ = c.ensureSession(context.Background())
+ h := map[string]string{
+ "Content-Type": "application/json; charset=utf-8",
+ "Accept": "*/*",
+ "Authorization": "Bearer " + c.sessionToken,
+ "User-Agent": "GithubCopilot/1.155.0",
+ "Editor-Plugin-Version": "copilot/1.155.0",
+ "Editor-Version": "vscode/1.85.1",
+ "Openai-Intent": "copilot-ghost",
+ "Openai-Organization": "github-copilot",
+ "VScode-MachineId": randHex(64),
+ "VScode-SessionId": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ "X-Request-Id": randHex(8) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(4) + "-" + randHex(12),
+ }
+ return h
+}
+
+func randHex(n int) string {
+ const hex = "0123456789abcdef"
+ b := make([]byte, n)
+ for i := range b {
+ b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)]
+ }
+ return string(b)
+}
+
+// --- Codex-style code completion ---
+
+// CodeCompletion implements CodeCompleter; returns up to n suggestions.
+func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) {
+ if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") }
+ if err := c.ensureSession(ctx); err != nil { return nil, err }
+ if n <= 0 { n = 1 }
+ maxTokens := 500
+ body := map[string]any{
+ "extra": map[string]any{
+ "language": language,
+ "next_indent": 0,
+ "prompt_tokens": 500,
+ "suffix_tokens": 400,
+ "trim_by_indentation": true,
+ },
+ "max_tokens": maxTokens,
+ "n": n,
+ "nwo": "hexai",
+ "prompt": prompt,
+ "stop": []string{"\n\n"},
+ "stream": true,
+ "suffix": suffix,
+ "temperature": temperature,
+ "top_p": 1,
+ }
+ buf, _ := json.Marshal(body)
+ url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions"
+ resp, err := c.postJSON(ctx, url, buf, c.headersGhost())
+ if err != nil { return nil, err }
+ defer resp.Body.Close()
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode)
+ }
+ // Read all and parse lines that start with "data: " accumulating by index
+ raw, _ := io.ReadAll(resp.Body)
+ byIndex := make(map[int]string)
+ lines := strings.Split(string(raw), "\n")
+ for _, ln := range lines {
+ if !strings.HasPrefix(ln, "data: ") { continue }
+ var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` }
+ if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue }
+ for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text }
+ }
+ out := make([]string, 0, len(byIndex))
+ for i := 0; i < n; i++ {
+ if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) }
+ }
+ return out, nil
+}
+
+// newLineDataReader wraps a streaming body and exposes a JSON decoder that
+// decodes successive objects from lines prefixed by "data: ".
+// (no streaming decoder needed; we parse whole body lines)
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index ed9ca59..7ab58c6 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -28,10 +28,20 @@ type Client interface {
// token-by-token streaming responses. Callers can type-assert to Streamer and
// fall back to Client.Chat when not implemented.
type Streamer interface {
- // ChatStream sends chat messages and invokes onDelta with incremental text
- // chunks as they are produced by the model. Implementations should call
- // onDelta with empty strings sparingly (prefer only non-empty chunks).
- ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
+ // ChatStream sends chat messages and invokes onDelta with incremental text
+ // chunks as they are produced by the model. Implementations should call
+ // onDelta with empty strings sparingly (prefer only non-empty chunks).
+ ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error
+}
+
+// CodeCompleter is an optional interface for providers that support a
+// prompt/suffix code-completion API (e.g., Copilot Codex endpoint). Clients
+// can type-assert to this and prefer it over chat when available.
+type CodeCompleter interface {
+ // CodeCompletion requests up to n suggestions given a left-hand prompt and
+ // right-hand suffix around the cursor. Language is advisory and may be
+ // ignored. Temperature applies when provider supports it.
+ CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error)
}
// Options for a request. Providers may ignore unsupported fields.