diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/llm/openai.go | 162 | ||||
| -rw-r--r-- | internal/llm/provider.go | 49 | ||||
| -rw-r--r-- | internal/lsp/server.go | 201 | ||||
| -rw-r--r-- | internal/test.go | 18 | ||||
| -rw-r--r-- | internal/version.go | 1 |
5 files changed, 396 insertions, 35 deletions
diff --git a/internal/llm/openai.go b/internal/llm/openai.go new file mode 100644 index 0000000..860c80e --- /dev/null +++ b/internal/llm/openai.go @@ -0,0 +1,162 @@ +package llm + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "os" + "time" +) + +// openAIClient implements Client against OpenAI's Chat Completions API. +type openAIClient struct { + httpClient *http.Client + apiKey string + baseURL string + defaultModel string + logger *log.Logger +} + +func newOpenAIFromEnv(apiKey string, logger *log.Logger) Client { + base := os.Getenv("OPENAI_BASE_URL") + if base == "" { + base = "https://api.openai.com/v1" + } + model := os.Getenv("OPENAI_MODEL") + if model == "" { + model = "gpt-4o-mini" + } + return &openAIClient{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + apiKey: apiKey, + baseURL: base, + defaultModel: model, + logger: logger, + } +} + +type oaChatRequest struct { + Model string `json:"model"` + Messages []oaMessage `json:"messages"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Stop []string `json:"stop,omitempty"` +} + +type oaMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type oaChatResponse struct { + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + Param any `json:"param"` + Code any `json:"code"` + } `json:"error,omitempty"` +} + +func (c *openAIClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { + if c.apiKey == "" { + return nilStringErr("missing OpenAI API key") + } + o := Options{Model: c.defaultModel} + for _, opt := range opts { + opt(&o) + } + if o.Model == "" { + o.Model = c.defaultModel + } + start := time.Now() + c.logf("chat start model=%s temp=%.2f max_tokens=%d stop=%d messages=%d", o.Model, o.Temperature, o.MaxTokens, len(o.Stop), len(messages)) + for i, m := range messages { + c.logf("msg[%d] role=%s size=%d preview=%q", i, m.Role, len(m.Content), trimPreview(m.Content, 200)) + } + req := oaChatRequest{Model: o.Model} + req.Messages = make([]oaMessage, len(messages)) + for i, m := range messages { + req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} + } + if o.Temperature != 0 { + req.Temperature = &o.Temperature + } + if o.MaxTokens > 0 { + req.MaxTokens = &o.MaxTokens + } + if len(o.Stop) > 0 { + req.Stop = o.Stop + } + + body, err := json.Marshal(req) + if err != nil { + c.logf("marshal error: %v", err) + return "", err + } + endpoint := c.baseURL + "/chat/completions" + c.logf("POST %s", endpoint) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + c.logf("new request error: %v", err) + return "", err + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", "Bearer "+c.apiKey) + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + c.logf("http error after %s: %v", time.Since(start), err) + return "", err + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + var apiErr oaChatResponse + _ = json.NewDecoder(resp.Body).Decode(&apiErr) + if apiErr.Error != nil && apiErr.Error.Message != "" { + c.logf("api error status=%d type=%s msg=%s duration=%s", resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start)) + return "", fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode) + } + c.logf("http non-2xx status=%d duration=%s", resp.StatusCode, time.Since(start)) + return "", fmt.Errorf("openai http error: status %d", resp.StatusCode) + } + var out oaChatResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + c.logf("decode error after %s: %v", time.Since(start), err) + return "", err + } + if len(out.Choices) == 0 { + c.logf("no choices returned duration=%s", time.Since(start)) + return "", errors.New("openai: no choices returned") + } + content := out.Choices[0].Message.Content + c.logf("success choice=0 finish=%s size=%d preview=%q duration=%s", out.Choices[0].FinishReason, len(content), trimPreview(content, 200), time.Since(start)) + return content, nil +} + +// small helper to keep return type consistent +func nilStringErr(msg string) (string, error) { return "", errors.New(msg) } + +func (c *openAIClient) logf(format string, args ...any) { + if c.logger != nil { + c.logger.Printf("llm/openai "+format, args...) + } +} + +func trimPreview(s string, n int) string { + if n <= 0 || len(s) <= n { + return s + } + return s[:n] + "…" +} diff --git a/internal/llm/provider.go b/internal/llm/provider.go new file mode 100644 index 0000000..fd9d4d3 --- /dev/null +++ b/internal/llm/provider.go @@ -0,0 +1,49 @@ +package llm + +import ( + "context" + "errors" + "log" + "os" +) + +// Message represents a chat-style prompt message. +type Message struct { + Role string + Content string +} + +// Client is a minimal LLM provider interface. +// Future providers (Ollama, etc.) should implement this. +type Client interface { + // Chat sends chat messages and returns the assistant text. + Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) +} + +// Options for a request. Providers may ignore unsupported fields. +type Options struct { + Model string + Temperature float64 + MaxTokens int + Stop []string +} + +// RequestOption mutates Options. +type RequestOption func(*Options) + +func WithModel(model string) RequestOption { return func(o *Options) { o.Model = model } } +func WithTemperature(t float64) RequestOption { return func(o *Options) { o.Temperature = t } } +func WithMaxTokens(n int) RequestOption { return func(o *Options) { o.MaxTokens = n } } +func WithStop(stop ...string) RequestOption { + return func(o *Options) { o.Stop = append([]string{}, stop...) } +} + +// NewDefault returns the default provider using environment configuration. +// Currently this is the OpenAI provider using OPENAI_API_KEY. +func NewDefault(logger *log.Logger) (Client, error) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return nil, errors.New("OPENAI_API_KEY is not set") + } + return newOpenAIFromEnv(apiKey, logger), nil +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 3949680..ec1a113 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -2,9 +2,11 @@ package lsp import ( "bufio" + "context" "encoding/json" "fmt" "hexai/internal" + "hexai/internal/llm" "io" "log" "net/textproto" @@ -12,6 +14,7 @@ import ( "strconv" "strings" "sync" + "time" ) // JSON-RPC 2.0 structures (minimal) @@ -61,27 +64,39 @@ type CompletionList struct { } type CompletionItem struct { - Label string `json:"label"` - Kind int `json:"kind,omitempty"` - Detail string `json:"detail,omitempty"` - InsertText string `json:"insertText,omitempty"` - SortText string `json:"sortText,omitempty"` - Documentation string `json:"documentation,omitempty"` + Label string `json:"label"` + Kind int `json:"kind,omitempty"` + Detail string `json:"detail,omitempty"` + InsertText string `json:"insertText,omitempty"` + InsertTextFormat int `json:"insertTextFormat,omitempty"` + FilterText string `json:"filterText,omitempty"` + TextEdit *TextEdit `json:"textEdit,omitempty"` + SortText string `json:"sortText,omitempty"` + Documentation string `json:"documentation,omitempty"` } // Server implements a minimal LSP over stdio. type Server struct { - in *bufio.Reader - out io.Writer - logger *log.Logger - exited bool - mu sync.RWMutex - docs map[string]*document - logContext bool + in *bufio.Reader + out io.Writer + logger *log.Logger + exited bool + mu sync.RWMutex + docs map[string]*document + logContext bool + llmClient llm.Client + lastInput time.Time } func NewServer(r io.Reader, w io.Writer, logger *log.Logger, logContext bool) *Server { - return &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: logContext} + s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: logContext} + if c, err := llm.NewDefault(logger); err != nil { + // Keep running without LLM; completions will be basic. + s.logger.Printf("llm disabled: %v", err) + } else { + s.llmClient = c + } + return s } func (s *Server) Run() error { @@ -137,6 +152,7 @@ func (s *Server) handle(req Request) { var p DidOpenTextDocumentParams if err := json.Unmarshal(req.Params, &p); err == nil { s.setDocument(p.TextDocument.URI, p.TextDocument.Text) + s.markActivity() } case "textDocument/didChange": var p DidChangeTextDocumentParams @@ -144,32 +160,123 @@ func (s *Server) handle(req Request) { if len(p.ContentChanges) > 0 { s.setDocument(p.TextDocument.URI, p.ContentChanges[len(p.ContentChanges)-1].Text) } + s.markActivity() } case "textDocument/didClose": var p DidCloseTextDocumentParams if err := json.Unmarshal(req.Params, &p); err == nil { s.deleteDocument(p.TextDocument.URI) + s.markActivity() } - case "textDocument/completion": - var p CompletionParams - var docStr string - if err := json.Unmarshal(req.Params, &p); err == nil { - above, current, below, funcCtx := s.lineContext(p.TextDocument.URI, p.Position) - docStr = fmt.Sprintf("file: %s\nline: %d\nabove: %s\ncurrent: %s\nbelow: %s\nfunction: %s", p.TextDocument.URI, p.Position.Line, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) - if s.logContext { - s.logger.Printf("completion ctx uri=%s line=%d char=%d above=%q current=%q below=%q function=%q", - p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) - } - } - items := []CompletionItem{{ - Label: "hexai-complete", - Kind: 14, - Detail: "dummy completion", - InsertText: "hexai", - SortText: "0000", - Documentation: docStr, - }} - s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) + case "textDocument/completion": + var p CompletionParams + var docStr string + if err := json.Unmarshal(req.Params, &p); err == nil { + above, current, below, funcCtx := s.lineContext(p.TextDocument.URI, p.Position) + docStr = fmt.Sprintf("file: %s\nline: %d\nabove: %s\ncurrent: %s\nbelow: %s\nfunction: %s", p.TextDocument.URI, p.Position.Line, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) + if s.logContext { + s.logger.Printf("completion ctx uri=%s line=%d char=%d above=%q current=%q below=%q function=%q", + p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) + } + // Previously: gated LLM calls until 1s idle. Removed to complete as you type. + // Try LLM-backed suggestion if available (always, no idle gating) + if s.llmClient != nil { + ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) + defer cancel() + // Tailor prompt if inside a Go function parameter list + inParams := false + if strings.Contains(current, "func ") { + open := strings.Index(current, "(") + close := strings.Index(current, ")") + if open >= 0 && p.Position.Character > open && (close == -1 || p.Position.Character <= close) { + inParams = true + } + } + sysPrompt := "You are a terse code completion engine. Return only the code to insert, no surrounding prose or backticks." + userPrompt := fmt.Sprintf("Provide the next likely code to insert at the cursor.\nFile: %s\nFunction/context: %s\nAbove line: %s\nCurrent line (cursor at character %d): %s\nBelow line: %s\nOnly return the completion snippet.", p.TextDocument.URI, funcCtx, above, p.Position.Character, current, below) + if inParams { + sysPrompt = "You are a terse Go code completion engine for function signatures. Return only the parameter list contents (without parentheses), no braces, no prose. Prefer idiomatic names and types." + userPrompt = fmt.Sprintf("Cursor is inside the function parameter list. Suggest only the parameter list (no parentheses).\nFunction line: %s\nCurrent line (cursor at %d): %s", funcCtx, p.Position.Character, current) + } + messages := []llm.Message{ + {Role: "system", Content: sysPrompt}, + {Role: "user", Content: userPrompt}, + } + // keep completions small by default + text, err := s.llmClient.Chat(ctx, messages, llm.WithMaxTokens(96), llm.WithTemperature(0.2)) + if err == nil && strings.TrimSpace(text) != "" { + cleaned := strings.TrimSpace(text) + var te *TextEdit + var filter string + if inParams { + // Replace inside the parentheses + open := strings.Index(current, "(") + close := strings.Index(current, ")") + if open >= 0 { + left := open + 1 + right := len(current) + if close >= 0 && close >= left { + right = close + } + if p.Position.Character < right { + right = p.Position.Character + } + te = &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: left}, End: Position{Line: p.Position.Line, Character: right}}, NewText: cleaned} + if left >= 0 && right >= left && right <= len(current) { + filter = strings.TrimLeft(current[left:right], " \t") + } + } + } + if te == nil { + // compute word start for replacement + startChar := p.Position.Character + if startChar > len(current) { + startChar = len(current) + } + for startChar > 0 { + ch := current[startChar-1] + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { + startChar-- + continue + } + break + } + te = &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: startChar}, End: Position{Line: p.Position.Line, Character: p.Position.Character}}, NewText: cleaned} + filter = strings.TrimLeft(current[startChar:p.Position.Character], " \t") + } + // Choose a label that starts with the current prefix when possible so the client doesn't filter it out. + label := trimLen(firstLine(cleaned)) + if filter != "" && !strings.HasPrefix(strings.ToLower(label), strings.ToLower(filter)) { + label = filter + } + items := []CompletionItem{{ + Label: label, + Kind: 1, + Detail: "OpenAI completion", + InsertTextFormat: 1, + FilterText: strings.TrimLeft(filter, " \t"), + TextEdit: te, + SortText: "0000", + Documentation: docStr, + }} + s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) + return + } + if err != nil { + s.logger.Printf("llm completion error: %v", err) + } + } + } + // Fallback basic/dummy completion + items := []CompletionItem{{ + Label: "hexai-complete", + Kind: 1, + Detail: "dummy completion", + InsertText: "hexai", + SortText: "9999", + Documentation: docStr, + }} + s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) default: // Unknown method; reply with Method Not Found for requests that have an ID. if len(req.ID) != 0 { @@ -256,6 +363,12 @@ func (s *Server) deleteDocument(uri string) { delete(s.docs, uri) } +func (s *Server) markActivity() { + s.mu.Lock() + s.lastInput = time.Now() + s.mu.Unlock() +} + func (s *Server) getDocument(uri string) *document { s.mu.RLock() defer s.mu.RUnlock() @@ -314,6 +427,18 @@ type CompletionParams struct { Context any `json:"context,omitempty"` } +// Range defines a text range in a document. +type Range struct { + Start Position `json:"start"` + End Position `json:"end"` +} + +// TextEdit represents a textual edit applicable to a document. +type TextEdit struct { + Range Range `json:"range"` + NewText string `json:"newText"` +} + func (s *Server) lineContext(uri string, pos Position) (above, current, below, funcCtx string) { d := s.getDocument(uri) if d == nil || len(d.lines) == 0 { @@ -359,3 +484,11 @@ func trimLen(s string) string { } return s } + +func firstLine(s string) string { + s = strings.ReplaceAll(s, "\r\n", "\n") + if idx := strings.IndexByte(s, '\n'); idx >= 0 { + return s[:idx] + } + return s +} diff --git a/internal/test.go b/internal/test.go new file mode 100644 index 0000000..586a1bc --- /dev/null +++ b/internal/test.go @@ -0,0 +1,18 @@ +package internal + +import "os" + +func fib(i int) int { + if i <= 1 { + return i + } + return fib(i-1) + fib(i-2) +} + +func countFilesInDir(dirPath string) int { + files, err := os.ReadDir(dirPath) + if err != nil { + return 0 + } + return len(files) +} diff --git a/internal/version.go b/internal/version.go index 525ff73..673db85 100644 --- a/internal/version.go +++ b/internal/version.go @@ -1,4 +1,3 @@ package internal const Version = "0.0.1" - |
