From 61137206eb7dd6a3df865591d710923838f59f18 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Fri, 5 Sep 2025 21:17:25 +0300 Subject: over 80% coverage now --- docs/coverage.html | 1697 ++++++++++++++++++++++++++-------------------------- 1 file changed, 864 insertions(+), 833 deletions(-) (limited to 'docs/coverage.html') diff --git a/docs/coverage.html b/docs/coverage.html index 4976a0c..df02a90 100644 --- a/docs/coverage.html +++ b/docs/coverage.html @@ -61,45 +61,47 @@ - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + - + + + @@ -230,7 +232,7 @@ type App struct { } // Constructor: defaults for App (kept first among functions) -func newDefaultConfig() App { +func newDefaultConfig() App { // Coding-friendly default temperature across providers // Users can override per provider in config.json (including 0.0). t := 0.2 @@ -252,18 +254,18 @@ func newDefaultConfig() App { // Load reads configuration from a file and merges with defaults. // It respects the XDG Base Directory Specification. -func Load(logger *log.Logger) App { +func Load(logger *log.Logger) App { cfg := newDefaultConfig() if logger == nil { return cfg // Return defaults if no logger is provided (e.g. in tests) } - configPath, err := getConfigPath() + configPath, err := getConfigPath() if err != nil { logger.Printf("%v", err) // Even if config path cannot be resolved, still allow env overrides below. - } else { - if fileCfg, err := loadFromFile(configPath, logger); err == nil && fileCfg != nil { + } else { + if fileCfg, err := loadFromFile(configPath, logger); err == nil && fileCfg != nil { cfg.mergeWith(fileCfg) } // When the config file is missing or invalid, we keep defaults and still @@ -271,14 +273,14 @@ func Load(logger *log.Logger) App { } // Environment overrides (take precedence over file) - if envCfg := loadFromEnv(logger); envCfg != nil { + if envCfg := loadFromEnv(logger); envCfg != nil { cfg.mergeWith(envCfg) } - return cfg + return cfg } // Private helpers -func loadFromFile(path string, logger *log.Logger) (*App, error) { +func loadFromFile(path string, logger *log.Logger) (*App, error) { f, err := os.Open(path) if err != nil { if !os.IsNotExist(err) && logger != nil { @@ -286,7 +288,7 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) return nil, err } - defer f.Close() + defer f.Close() dec := json.NewDecoder(f) var fileCfg App @@ -296,81 +298,81 @@ func loadFromFile(path string, logger *log.Logger) (*App, error) return nil, err } - return &fileCfg, nil + return &fileCfg, nil } -func (a *App) mergeWith(other *App) { +func (a *App) mergeWith(other *App) { a.mergeBasics(other) a.mergeProviderFields(other) } // mergeBasics merges general (non-provider) fields. -func (a *App) mergeBasics(other *App) { +func (a *App) mergeBasics(other *App) { if other.MaxTokens > 0 { a.MaxTokens = other.MaxTokens } - if s := strings.TrimSpace(other.ContextMode); s != "" { + if s := strings.TrimSpace(other.ContextMode); s != "" { a.ContextMode = s } - if other.ContextWindowLines > 0 { + if other.ContextWindowLines > 0 { a.ContextWindowLines = other.ContextWindowLines } - if other.MaxContextTokens > 0 { + if other.MaxContextTokens > 0 { a.MaxContextTokens = other.MaxContextTokens } - if other.LogPreviewLimit >= 0 { + if other.LogPreviewLimit >= 0 { a.LogPreviewLimit = other.LogPreviewLimit } - if other.CodingTemperature != nil { // allow explicit 0.0 + if other.CodingTemperature != nil { // allow explicit 0.0 a.CodingTemperature = other.CodingTemperature } - if other.ManualInvokeMinPrefix >= 0 { + if other.ManualInvokeMinPrefix >= 0 { a.ManualInvokeMinPrefix = other.ManualInvokeMinPrefix } - if other.CompletionDebounceMs > 0 { a.CompletionDebounceMs = other.CompletionDebounceMs } - if other.CompletionThrottleMs > 0 { a.CompletionThrottleMs = other.CompletionThrottleMs } - if len(other.TriggerCharacters) > 0 { + if other.CompletionDebounceMs > 0 { a.CompletionDebounceMs = other.CompletionDebounceMs } + if other.CompletionThrottleMs > 0 { a.CompletionThrottleMs = other.CompletionThrottleMs } + if len(other.TriggerCharacters) > 0 { a.TriggerCharacters = slices.Clone(other.TriggerCharacters) } - if s := strings.TrimSpace(other.Provider); s != "" { + if s := strings.TrimSpace(other.Provider); s != "" { a.Provider = s } } // mergeProviderFields merges per-provider configuration. -func (a *App) mergeProviderFields(other *App) { +func (a *App) mergeProviderFields(other *App) { if s := strings.TrimSpace(other.OpenAIBaseURL); s != "" { a.OpenAIBaseURL = s } - if s := strings.TrimSpace(other.OpenAIModel); s != "" { + if s := strings.TrimSpace(other.OpenAIModel); s != "" { a.OpenAIModel = s } - if other.OpenAITemperature != nil { // allow explicit 0.0 + if other.OpenAITemperature != nil { // allow explicit 0.0 a.OpenAITemperature = other.OpenAITemperature } - if s := strings.TrimSpace(other.OllamaBaseURL); s != "" { + if s := strings.TrimSpace(other.OllamaBaseURL); s != "" { a.OllamaBaseURL = s } - if s := strings.TrimSpace(other.OllamaModel); s != "" { + if s := strings.TrimSpace(other.OllamaModel); s != "" { a.OllamaModel = s } - if other.OllamaTemperature != nil { // allow explicit 0.0 + if other.OllamaTemperature != nil { // allow explicit 0.0 a.OllamaTemperature = other.OllamaTemperature } - if s := strings.TrimSpace(other.CopilotBaseURL); s != "" { + if s := strings.TrimSpace(other.CopilotBaseURL); s != "" { a.CopilotBaseURL = s } - if s := strings.TrimSpace(other.CopilotModel); s != "" { + if s := strings.TrimSpace(other.CopilotModel); s != "" { a.CopilotModel = s } - if other.CopilotTemperature != nil { // allow explicit 0.0 + if other.CopilotTemperature != nil { // allow explicit 0.0 a.CopilotTemperature = other.CopilotTemperature } } -func getConfigPath() (string, error) { +func getConfigPath() (string, error) { var configPath string - if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" { + if xdgConfigHome := os.Getenv("XDG_CONFIG_HOME"); xdgConfigHome != "" { configPath = filepath.Join(xdgConfigHome, "hexai", "config.json") } else { home, err := os.UserHomeDir() @@ -379,29 +381,29 @@ func getConfigPath() (string, error) { } configPath = filepath.Join(home, ".config", "hexai", "config.json") } - return configPath, nil + return configPath, nil } // --- Environment overrides --- // loadFromEnv constructs an App containing only fields set via HEXAI_* env vars. // These values should take precedence over file config when merged. -func loadFromEnv(logger *log.Logger) *App { +func loadFromEnv(logger *log.Logger) *App { var out App var any bool // helpers - getenv := func(k string) string { return strings.TrimSpace(os.Getenv(k)) } - parseInt := func(k string) (int, bool) { + getenv := func(k string) string { return strings.TrimSpace(os.Getenv(k)) } + parseInt := func(k string) (int, bool) { v := getenv(k) - if v == "" { return 0, false } + if v == "" { return 0, false } n, err := strconv.Atoi(v) if err != nil { if logger != nil { logger.Printf("invalid %s: %v", k, err) } ; return 0, false } return n, true } - parseFloatPtr := func(k string) (*float64, bool) { + parseFloatPtr := func(k string) (*float64, bool) { v := getenv(k) - if v == "" { return nil, false } + if v == "" { return nil, false } f, err := strconv.ParseFloat(v, 64) if err != nil { if logger != nil { logger.Printf("invalid %s: %v", k, err) } @@ -410,34 +412,34 @@ func loadFromEnv(logger *log.Logger) *App { return &f, true } - if n, ok := parseInt("HEXAI_MAX_TOKENS"); ok { + if n, ok := parseInt("HEXAI_MAX_TOKENS"); ok { out.MaxTokens = n; any = true } - if s := getenv("HEXAI_CONTEXT_MODE"); s != "" { + if s := getenv("HEXAI_CONTEXT_MODE"); s != "" { out.ContextMode = s; any = true } - if n, ok := parseInt("HEXAI_CONTEXT_WINDOW_LINES"); ok { + if n, ok := parseInt("HEXAI_CONTEXT_WINDOW_LINES"); ok { out.ContextWindowLines = n; any = true } - if n, ok := parseInt("HEXAI_MAX_CONTEXT_TOKENS"); ok { + if n, ok := parseInt("HEXAI_MAX_CONTEXT_TOKENS"); ok { out.MaxContextTokens = n; any = true } - if n, ok := parseInt("HEXAI_LOG_PREVIEW_LIMIT"); ok { + if n, ok := parseInt("HEXAI_LOG_PREVIEW_LIMIT"); ok { out.LogPreviewLimit = n; any = true } - if n, ok := parseInt("HEXAI_MANUAL_INVOKE_MIN_PREFIX"); ok { + if n, ok := parseInt("HEXAI_MANUAL_INVOKE_MIN_PREFIX"); ok { out.ManualInvokeMinPrefix = n; any = true } - if n, ok := parseInt("HEXAI_COMPLETION_DEBOUNCE_MS"); ok { + if n, ok := parseInt("HEXAI_COMPLETION_DEBOUNCE_MS"); ok { out.CompletionDebounceMs = n; any = true } - if n, ok := parseInt("HEXAI_COMPLETION_THROTTLE_MS"); ok { + if n, ok := parseInt("HEXAI_COMPLETION_THROTTLE_MS"); ok { out.CompletionThrottleMs = n; any = true } - if f, ok := parseFloatPtr("HEXAI_CODING_TEMPERATURE"); ok { + if f, ok := parseFloatPtr("HEXAI_CODING_TEMPERATURE"); ok { out.CodingTemperature = f; any = true } - if s := getenv("HEXAI_TRIGGER_CHARACTERS"); s != "" { + if s := getenv("HEXAI_TRIGGER_CHARACTERS"); s != "" { parts := strings.Split(s, ",") out.TriggerCharacters = nil for _, p := range parts { @@ -447,24 +449,24 @@ func loadFromEnv(logger *log.Logger) *App { } any = true } - if s := getenv("HEXAI_PROVIDER"); s != "" { + if s := getenv("HEXAI_PROVIDER"); s != "" { out.Provider = s; any = true } // Provider-specific - if s := getenv("HEXAI_OPENAI_BASE_URL"); s != "" { out.OpenAIBaseURL = s; any = true } - if s := getenv("HEXAI_OPENAI_MODEL"); s != "" { out.OpenAIModel = s; any = true } - if f, ok := parseFloatPtr("HEXAI_OPENAI_TEMPERATURE"); ok { out.OpenAITemperature = f; any = true } + if s := getenv("HEXAI_OPENAI_BASE_URL"); s != "" { out.OpenAIBaseURL = s; any = true } + if s := getenv("HEXAI_OPENAI_MODEL"); s != "" { out.OpenAIModel = s; any = true } + if f, ok := parseFloatPtr("HEXAI_OPENAI_TEMPERATURE"); ok { out.OpenAITemperature = f; any = true } - if s := getenv("HEXAI_OLLAMA_BASE_URL"); s != "" { out.OllamaBaseURL = s; any = true } - if s := getenv("HEXAI_OLLAMA_MODEL"); s != "" { out.OllamaModel = s; any = true } - if f, ok := parseFloatPtr("HEXAI_OLLAMA_TEMPERATURE"); ok { out.OllamaTemperature = f; any = true } + if s := getenv("HEXAI_OLLAMA_BASE_URL"); s != "" { out.OllamaBaseURL = s; any = true } + if s := getenv("HEXAI_OLLAMA_MODEL"); s != "" { out.OllamaModel = s; any = true } + if f, ok := parseFloatPtr("HEXAI_OLLAMA_TEMPERATURE"); ok { out.OllamaTemperature = f; any = true } - if s := getenv("HEXAI_COPILOT_BASE_URL"); s != "" { out.CopilotBaseURL = s; any = true } - if s := getenv("HEXAI_COPILOT_MODEL"); s != "" { out.CopilotModel = s; any = true } - if f, ok := parseFloatPtr("HEXAI_COPILOT_TEMPERATURE"); ok { out.CopilotTemperature = f; any = true } + if s := getenv("HEXAI_COPILOT_BASE_URL"); s != "" { out.CopilotBaseURL = s; any = true } + if s := getenv("HEXAI_COPILOT_MODEL"); s != "" { out.CopilotModel = s; any = true } + if f, ok := parseFloatPtr("HEXAI_COPILOT_TEMPERATURE"); ok { out.CopilotTemperature = f; any = true } - if !any { + if !any { return nil } return &out @@ -492,12 +494,12 @@ import ( // Run executes the Hexai CLI behavior given arguments and I/O streams. // It assumes flags have already been parsed by the caller. -func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { +func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer) error { // Load configuration with a logger so file-based config is respected. logger := log.New(stderr, "hexai ", log.LstdFlags|log.Lmsgprefix) cfg := appconfig.Load(logger) client, err := newClientFromConfig(cfg) - if err != nil { + if err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: LLM disabled: %v"+logging.AnsiReset+"\n", err) return err } @@ -507,35 +509,35 @@ func Run(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io. // RunWithClient executes the CLI flow using an already-constructed client. // Useful for testing and embedding. -func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error { +func RunWithClient(ctx context.Context, args []string, stdin io.Reader, stdout, stderr io.Writer, client llm.Client) error { input, err := readInput(stdin, args) if err != nil { fmt.Fprintln(stderr, logging.AnsiBase+err.Error()+logging.AnsiReset) return err } - printProviderInfo(stderr, client) + printProviderInfo(stderr, client) msgs := buildMessages(input) - if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { + if err := runChat(ctx, client, msgs, input, stdout, stderr); err != nil { fmt.Fprintf(stderr, logging.AnsiBase+"hexai: error: %v"+logging.AnsiReset+"\n", err) return err } - return nil + return nil } // readInput reads from stdin and args, then combines them per CLI rules. -func readInput(stdin io.Reader, args []string) (string, error) { +func readInput(stdin io.Reader, args []string) (string, error) { var stdinData string - if fi, err := os.Stdin.Stat(); err == nil && (fi.Mode()&os.ModeCharDevice) == 0 { + if fi, err := os.Stdin.Stat(); err == nil && (fi.Mode()&os.ModeCharDevice) == 0 { b, _ := io.ReadAll(bufio.NewReader(stdin)) stdinData = strings.TrimSpace(string(b)) } - argData := strings.TrimSpace(strings.Join(args, " ")) + argData := strings.TrimSpace(strings.Join(args, " ")) switch { - case stdinData != "" && argData != "": + case stdinData != "" && argData != "": return fmt.Sprintf("%s:\n\n%s", argData, stdinData), nil case stdinData != "": return stdinData, nil - case argData != "": + case argData != "": return argData, nil default: return "", fmt.Errorf("hexai: no input provided; pass text as an argument or via stdin") @@ -543,7 +545,7 @@ func readInput(stdin io.Reader, args []string) (string, error) { +func newClientFromConfig(cfg appconfig.App) (llm.Client, error) { llmCfg := llm.Config{ Provider: cfg.Provider, OpenAIBaseURL: cfg.OpenAIBaseURL, @@ -558,59 +560,59 @@ func newClientFromConfig(cfg appconfig.App) (llm.Client, error) { + if strings.TrimSpace(oaKey) == "" { oaKey = os.Getenv("OPENAI_API_KEY") } // Prefer HEXAI_COPILOT_API_KEY; fall back to COPILOT_API_KEY - cpKey := os.Getenv("HEXAI_COPILOT_API_KEY") - if strings.TrimSpace(cpKey) == "" { + cpKey := os.Getenv("HEXAI_COPILOT_API_KEY") + if strings.TrimSpace(cpKey) == "" { cpKey = os.Getenv("COPILOT_API_KEY") } - return llm.NewFromConfig(llmCfg, oaKey, cpKey) + return llm.NewFromConfig(llmCfg, oaKey, cpKey) } // buildMessages creates system and user messages based on input content. -func buildMessages(input string) []llm.Message { +func buildMessages(input string) []llm.Message { lower := strings.ToLower(input) system := "You are Hexai CLI. Default to very short, concise answers. If the user asks for commands, output only the commands (one per line) with no commentary or explanation. Only when the word 'explain' appears in the prompt, produce a verbose explanation." if strings.Contains(lower, "explain") { system = "You are Hexai CLI. The user requested an explanation. Provide a clear, verbose explanation with reasoning and details. If commands are needed, include them with brief context." } - return []llm.Message{ + return []llm.Message{ {Role: "system", Content: system}, {Role: "user", Content: input}, } } // runChat executes the chat request, handling streaming and summary output. -func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { +func runChat(ctx context.Context, client llm.Client, msgs []llm.Message, input string, out io.Writer, errw io.Writer) error { start := time.Now() var output string - if s, ok := client.(llm.Streamer); ok { + if s, ok := client.(llm.Streamer); ok { var b strings.Builder - if err := s.ChatStream(ctx, msgs, func(chunk string) { + if err := s.ChatStream(ctx, msgs, func(chunk string) { b.WriteString(chunk) fmt.Fprint(out, chunk) }); err != nil { return err } - output = b.String() + output = b.String() } else { txt, err := client.Chat(ctx, msgs) - if err != nil { + if err != nil { return err } - output = txt + output = txt fmt.Fprint(out, output) } - dur := time.Since(start) + dur := time.Since(start) fmt.Fprintf(errw, "\n"+logging.AnsiBase+"done provider=%s model=%s time=%s in_bytes=%d out_bytes=%d"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel(), dur.Round(time.Millisecond), len(input), len(output)) return nil } // printProviderInfo writes the provider/model line to stderr. -func printProviderInfo(errw io.Writer, client llm.Client) { +func printProviderInfo(errw io.Writer, client llm.Client) { fmt.Fprintf(errw, logging.AnsiBase+"provider=%s model=%s"+logging.AnsiReset+"\n", client.Name(), client.DefaultModel()) } @@ -804,16 +806,16 @@ type copilotChatResponse struct { } // Constructor (kept among the first functions by convention) -func newCopilot(baseURL, model, apiKey string, defaultTemp *float64) Client { +func newCopilot(baseURL, model, apiKey string, defaultTemp *float64) Client { if strings.TrimSpace(baseURL) == "" { baseURL = "https://api.githubcopilot.com" } - if strings.TrimSpace(model) == "" { + if strings.TrimSpace(model) == "" { // GitHub Models (Copilot API) commonly supports gpt-4o/gpt-4o-mini. // Default to a broadly available, cost-effective option. model = "gpt-4o-mini" } - return copilotClient{ + return copilotClient{ httpClient: &http.Client{Timeout: 30 * time.Second}, apiKey: apiKey, baseURL: strings.TrimRight(baseURL, "/"), @@ -823,27 +825,27 @@ func newCopilot(baseURL, model, apiKey string, defaultTemp *float64) Client } -func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { +func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { if strings.TrimSpace(c.apiKey) == "" { return nilStringErr("missing Copilot API key") } // Ensure we have a fresh session token - if err := c.ensureSession(ctx); err != nil { + if err := c.ensureSession(ctx); err != nil { return "", err } - o := Options{Model: c.defaultModel} + o := Options{Model: c.defaultModel} for _, opt := range opts { opt(&o) } - if o.Model == "" { + if o.Model == "" { o.Model = c.defaultModel } - start := time.Now() + start := time.Now() logMessages := make([]struct{ Role, Content string }, len(messages)) - for i, m := range messages { + for i, m := range messages { logMessages[i] = struct{ Role, Content string }{m.Role, m.Content} } - c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) + c.chatLogger.LogStart(false, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) req := buildCopilotChatRequest(o, messages, c.defaultTemperature) body, err := json.Marshal(req) @@ -852,70 +854,70 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req return "", err } - endpoint := c.baseURL + "/chat/completions" + endpoint := c.baseURL + "/chat/completions" logging.Logf("llm/copilot ", "POST %s", endpoint) resp, err := c.postJSON(ctx, endpoint, body, c.headersChat()) if err != nil { logging.Logf("llm/copilot ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return "", err } - defer resp.Body.Close() - if err := handleCopilotNon2xx(resp, start); err != nil { + defer resp.Body.Close() + if err := handleCopilotNon2xx(resp, start); err != nil { return "", err } - out, err := decodeCopilotChat(resp, start) - if err != nil { + out, err := decodeCopilotChat(resp, start) + if err != nil { return "", err } - if len(out.Choices) == 0 { + if len(out.Choices) == 0 { logging.Logf("llm/copilot ", "%sno choices returned duration=%s%s", logging.AnsiRed, time.Since(start), logging.AnsiBase) return "", errors.New("copilot: no choices returned") } - content := out.Choices[0].Message.Content + content := out.Choices[0].Message.Content logging.Logf("llm/copilot ", "success choice=0 finish=%s size=%d preview=%s%s%s duration=%s", out.Choices[0].FinishReason, len(content), logging.AnsiGreen, logging.PreviewForLog(content), logging.AnsiBase, time.Since(start)) return content, nil } // Provider metadata -func (c copilotClient) Name() string { return "copilot" } -func (c copilotClient) DefaultModel() string { return c.defaultModel } +func (c copilotClient) Name() string { return "copilot" } +func (c copilotClient) DefaultModel() string { return c.defaultModel } // helpers -func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64) copilotChatRequest { +func buildCopilotChatRequest(o Options, messages []Message, defaultTemp *float64) copilotChatRequest { req := copilotChatRequest{Model: o.Model} req.Messages = make([]copilotMessage, len(messages)) - for i, m := range messages { + for i, m := range messages { req.Messages[i] = copilotMessage{Role: m.Role, Content: m.Content} } - if o.Temperature != 0 { + if o.Temperature != 0 { req.Temperature = &o.Temperature - } else if defaultTemp != nil { + } else if defaultTemp != nil { t := *defaultTemp req.Temperature = &t } - if o.MaxTokens > 0 { + if o.MaxTokens > 0 { req.MaxTokens = &o.MaxTokens } - if len(o.Stop) > 0 { + if len(o.Stop) > 0 { req.Stop = o.Stop } - return req + return req } -func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { +func (c copilotClient) postJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return nil, err } - for k, v := range headers { req.Header.Set(k, v) } - return c.httpClient.Do(req) + for k, v := range headers { req.Header.Set(k, v) } + return c.httpClient.Do(req) } -func handleCopilotNon2xx(resp *http.Response, start time.Time) error { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { +func handleCopilotNon2xx(resp *http.Response, start time.Time) error { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } - var apiErr copilotChatResponse + var apiErr copilotChatResponse _ = json.NewDecoder(resp.Body).Decode(&apiErr) - if apiErr.Error != nil && strings.TrimSpace(apiErr.Error.Message) != "" { + if apiErr.Error != nil && strings.TrimSpace(apiErr.Error.Message) != "" { logging.Logf("llm/copilot ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase) return fmt.Errorf("copilot error: %s (status %d)", apiErr.Error.Message, resp.StatusCode) } @@ -923,13 +925,13 @@ func handleCopilotNon2xx(resp *http.Response, start time.Time) error } -func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatResponse, error) { +func decodeCopilotChat(resp *http.Response, start time.Time) (copilotChatResponse, error) { var out copilotChatResponse - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { logging.Logf("llm/copilot ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return copilotChatResponse{}, err } - return out, nil + return out, nil } // --- Copilot session token management --- @@ -938,59 +940,59 @@ type ghCopilotTokenResp struct { Token string `json:"token"` } -func (c *copilotClient) ensureSession(ctx context.Context) error { +func (c *copilotClient) ensureSession(ctx context.Context) error { // If token valid for >60s, reuse - if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { + if c.sessionToken != "" && time.Now().Add(60*time.Second).Before(c.tokenExpiry) { return nil } - if strings.TrimSpace(c.apiKey) == "" { + if strings.TrimSpace(c.apiKey) == "" { return errors.New("missing Copilot API key") } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/copilot_internal/v2/token", nil) if err != nil { return err } - req.Header.Set("Authorization", "Bearer "+c.apiKey) + req.Header.Set("Authorization", "Bearer "+c.apiKey) req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", "hexai/"+appver.Version) resp, err := c.httpClient.Do(req) if err != nil { return err } - defer resp.Body.Close() + defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return fmt.Errorf("copilot token http error: %d", resp.StatusCode) } - var out ghCopilotTokenResp + var out ghCopilotTokenResp if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { return err } - if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") } + if strings.TrimSpace(out.Token) == "" { return errors.New("empty copilot session token") } // Parse JWT exp - exp := parseJWTExp(out.Token) - if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) } - c.sessionToken = out.Token + exp := parseJWTExp(out.Token) + if exp.IsZero() { exp = time.Now().Add(10 * time.Minute) } + c.sessionToken = out.Token c.tokenExpiry = exp return nil } var jwtExpRe = regexp.MustCompile(`"exp"\s*:\s*([0-9]+)`) // fallback if we can't base64 decode -func parseJWTExp(token string) time.Time { +func parseJWTExp(token string) time.Time { parts := strings.Split(token, ".") - if len(parts) < 2 { return time.Time{} } - b, err := base64.RawURLEncoding.DecodeString(parts[1]) + if len(parts) < 2 { return time.Time{} } + b, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { if m := jwtExpRe.FindStringSubmatch(token); len(m) == 2 { if n, err2 := parseInt64(m[1]); err2 == nil { return time.Unix(n, 0) } } return time.Time{} } - var payload struct{ Exp int64 `json:"exp"` } + var payload struct{ Exp int64 `json:"exp"` } _ = json.Unmarshal(b, &payload) if payload.Exp == 0 { return time.Time{} } - return time.Unix(payload.Exp, 0) + return time.Unix(payload.Exp, 0) } -func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err } +func parseInt64(s string) (int64, error) { var n int64; _, err := fmt.Sscan(s, &n); return n, err } // --- Copilot headers --- -func (c *copilotClient) headersChat() map[string]string { +func (c *copilotClient) headersChat() map[string]string { _ = c.ensureSession(context.Background()) h := map[string]string{ "Content-Type": "application/json; charset=utf-8", @@ -1008,7 +1010,7 @@ func (c *copilotClient) headersChat() map[string]string -func (c *copilotClient) headersGhost() map[string]string { +func (c *copilotClient) headersGhost() map[string]string { _ = c.ensureSession(context.Background()) h := map[string]string{ "Content-Type": "application/json; charset=utf-8", @@ -1026,23 +1028,23 @@ func (c *copilotClient) headersGhost() map[string]string -func randHex(n int) string { +func randHex(n int) string { const hex = "0123456789abcdef" b := make([]byte, n) - for i := range b { + for i := range b { b[i] = hex[int(time.Now().UnixNano()+int64(i))%len(hex)] } - return string(b) + return string(b) } // --- Codex-style code completion --- // CodeCompletion implements CodeCompleter; returns up to n suggestions. -func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) { +func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix string, n int, language string, temperature float64) ([]string, error) { if strings.TrimSpace(c.apiKey) == "" { return nil, errors.New("missing Copilot API key") } - if err := c.ensureSession(ctx); err != nil { return nil, err } - if n <= 0 { n = 1 } - maxTokens := 500 + if err := c.ensureSession(ctx); err != nil { return nil, err } + if n <= 0 { n = 1 } + maxTokens := 500 body := map[string]any{ "extra": map[string]any{ "language": language, @@ -1065,25 +1067,25 @@ func (c copilotClient) CodeCompletion(ctx context.Context, prompt string, suffix url := "https://copilot-proxy.githubusercontent.com/v1/engines/copilot-codex/completions" resp, err := c.postJSON(ctx, url, buf, c.headersGhost()) if err != nil { return nil, err } - defer resp.Body.Close() + defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, fmt.Errorf("copilot codex http error: %d", resp.StatusCode) } // Read all and parse lines that start with "data: " accumulating by index - raw, _ := io.ReadAll(resp.Body) + raw, _ := io.ReadAll(resp.Body) byIndex := make(map[int]string) lines := strings.Split(string(raw), "\n") - for _, ln := range lines { - if !strings.HasPrefix(ln, "data: ") { continue } - var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` } - if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue } - for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text } + for _, ln := range lines { + if !strings.HasPrefix(ln, "data: ") { continue } + var evt struct{ Choices []struct{ Index int `json:"index"`; Text string `json:"text"` } `json:"choices"` } + if err := json.Unmarshal([]byte(strings.TrimPrefix(ln, "data: ")), &evt); err != nil { continue } + for _, ch := range evt.Choices { byIndex[ch.Index] += ch.Text } } - out := make([]string, 0, len(byIndex)) - for i := 0; i < n; i++ { - if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) } + out := make([]string, 0, len(byIndex)) + for i := 0; i < n; i++ { + if s, ok := byIndex[i]; ok && strings.TrimSpace(s) != "" { out = append(out, s) } } - return out, nil + return out, nil } // newLineDataReader wraps a streaming body and exposes a JSON decoder that @@ -1134,14 +1136,14 @@ type ollamaChatResponse struct { } // Constructor (kept among the first functions by convention) -func newOllama(baseURL, model string, defaultTemp *float64) Client { +func newOllama(baseURL, model string, defaultTemp *float64) Client { if strings.TrimSpace(baseURL) == "" { baseURL = "http://localhost:11434" } - if strings.TrimSpace(model) == "" { + if strings.TrimSpace(model) == "" { model = "qwen3-coder:30b-a3b-q4_K_M`" } - return ollamaClient{ + return ollamaClient{ httpClient: &http.Client{Timeout: 30 * time.Second}, baseURL: strings.TrimRight(baseURL, "/"), defaultModel: model, @@ -1150,16 +1152,16 @@ func newOllama(baseURL, model string, defaultTemp *float64) Client { +func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { o := Options{Model: c.defaultModel} for _, opt := range opts { opt(&o) } - if o.Model == "" { + if o.Model == "" { o.Model = c.defaultModel } - start := time.Now() + start := time.Now() c.logStart(false, o, messages) req := buildOllamaRequest(o, messages, c.defaultTemperature, false) body, err := json.Marshal(req) @@ -1167,47 +1169,47 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ return "", err } - endpoint := c.baseURL + "/api/chat" + endpoint := c.baseURL + "/api/chat" logging.Logf("llm/ollama ", "POST %s", endpoint) resp, err := c.doJSON(ctx, endpoint, body) - if err != nil { + if err != nil { logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return "", err } - defer resp.Body.Close() - if err := handleOllamaNon2xx(resp, start); err != nil { + defer resp.Body.Close() + if err := handleOllamaNon2xx(resp, start); err != nil { return "", err } - var out ollamaChatResponse - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + var out ollamaChatResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { logging.Logf("llm/ollama ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return "", err } - if strings.TrimSpace(out.Message.Content) == "" { + if strings.TrimSpace(out.Message.Content) == "" { logging.Logf("llm/ollama ", "%sempty content returned duration=%s%s", logging.AnsiRed, time.Since(start), logging.AnsiBase) return "", errors.New("ollama: empty content") } - content := out.Message.Content + content := out.Message.Content logging.Logf("llm/ollama ", "success size=%d preview=%s%s%s duration=%s", len(content), logging.AnsiGreen, logging.PreviewForLog(content), logging.AnsiBase, time.Since(start)) return content, nil } // Provider metadata -func (c ollamaClient) Name() string { return "ollama" } -func (c ollamaClient) DefaultModel() string { return c.defaultModel } +func (c ollamaClient) Name() string { return "ollama" } +func (c ollamaClient) DefaultModel() string { return c.defaultModel } // Streaming support (optional) -func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { +func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { o := Options{Model: c.defaultModel} for _, opt := range opts { opt(&o) } - if o.Model == "" { + if o.Model == "" { o.Model = c.defaultModel } - start := time.Now() + start := time.Now() c.logStart(true, o, messages) req := buildOllamaRequest(o, messages, c.defaultTemperature, true) body, err := json.Marshal(req) @@ -1215,96 +1217,96 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt return err } - endpoint := c.baseURL + "/api/chat" + endpoint := c.baseURL + "/api/chat" logging.Logf("llm/ollama ", "POST %s (stream)", endpoint) resp, err := c.doJSON(ctx, endpoint, body) if err != nil { logging.Logf("llm/ollama ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return err } - defer resp.Body.Close() + defer resp.Body.Close() if err := handleOllamaNon2xx(resp, start); err != nil { return err } - dec := json.NewDecoder(resp.Body) - for { + dec := json.NewDecoder(resp.Body) + for { var ev ollamaChatResponse - if err := dec.Decode(&ev); err != nil { + if err := dec.Decode(&ev); err != nil { if errors.Is(err, io.EOF) { break } - logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) + logging.Logf("llm/ollama ", "%sdecode stream error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return err } - if strings.TrimSpace(ev.Error) != "" { + if strings.TrimSpace(ev.Error) != "" { logging.Logf("llm/ollama ", "%sstream event error: %s%s", logging.AnsiRed, ev.Error, logging.AnsiBase) return fmt.Errorf("ollama stream error: %s", ev.Error) } - if s := ev.Message.Content; strings.TrimSpace(s) != "" { + if s := ev.Message.Content; strings.TrimSpace(s) != "" { onDelta(s) } - if ev.Done { + if ev.Done { break } } - logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) + logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) return nil } // helpers to keep methods small -func (c ollamaClient) logStart(stream bool, o Options, messages []Message) { +func (c ollamaClient) logStart(stream bool, o Options, messages []Message) { logMessages := make([]struct{ Role, Content string }, len(messages)) - for i, m := range messages { + for i, m := range messages { logMessages[i] = struct{ Role, Content string }{m.Role, m.Content} } - c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) + c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) } -func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, stream bool) ollamaChatRequest { +func buildOllamaRequest(o Options, messages []Message, defaultTemp *float64, stream bool) ollamaChatRequest { req := ollamaChatRequest{Model: o.Model, Stream: stream} req.Messages = make([]oaMessage, len(messages)) - for i, m := range messages { + for i, m := range messages { req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} } - optsMap := map[string]any{} - if o.Temperature != 0 { + optsMap := map[string]any{} + if o.Temperature != 0 { optsMap["temperature"] = o.Temperature - } else if defaultTemp != nil { + } else if defaultTemp != nil { optsMap["temperature"] = *defaultTemp } - if o.MaxTokens > 0 { + if o.MaxTokens > 0 { optsMap["num_predict"] = o.MaxTokens } - if len(o.Stop) > 0 { + if len(o.Stop) > 0 { optsMap["stop"] = o.Stop } - if len(optsMap) > 0 { + if len(optsMap) > 0 { req.Options = optsMap } - return req + return req } -func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) { +func (c ollamaClient) doJSON(ctx context.Context, url string, body []byte) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", "application/json") return c.httpClient.Do(req) } -func handleOllamaNon2xx(resp *http.Response, start time.Time) error { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { +func handleOllamaNon2xx(resp *http.Response, start time.Time) error { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } - var apiErr ollamaChatResponse + var apiErr ollamaChatResponse _ = json.NewDecoder(resp.Body).Decode(&apiErr) - if strings.TrimSpace(apiErr.Error) != "" { + if strings.TrimSpace(apiErr.Error) != "" { logging.Logf("llm/ollama ", "%sapi error status=%d msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error, time.Since(start), logging.AnsiBase) return fmt.Errorf("ollama error: %s (status %d)", apiErr.Error, resp.StatusCode) } - logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) + logging.Logf("llm/ollama ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) return fmt.Errorf("ollama http error: status %d", resp.StatusCode) } @@ -1386,14 +1388,14 @@ type oaStreamChunk struct { // Constructor (kept among the first functions by convention) // newOpenAI constructs an OpenAI client using explicit configuration values. // The apiKey may be empty; calls will fail until a valid key is supplied. -func newOpenAI(baseURL, model, apiKey string, defaultTemp *float64) Client { - if strings.TrimSpace(baseURL) == "" { +func newOpenAI(baseURL, model, apiKey string, defaultTemp *float64) Client { + if strings.TrimSpace(baseURL) == "" { baseURL = "https://api.openai.com/v1" } - if strings.TrimSpace(model) == "" { + if strings.TrimSpace(model) == "" { model = "gpt-4.1" } - return openAIClient{ + return openAIClient{ httpClient: &http.Client{Timeout: 30 * time.Second}, apiKey: apiKey, baseURL: baseURL, @@ -1403,18 +1405,18 @@ func newOpenAI(baseURL, model, apiKey string, defaultTemp *float64) Client } -func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { - if c.apiKey == "" { +func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { + if c.apiKey == "" { return nilStringErr("missing OpenAI API key") } - o := Options{Model: c.defaultModel} + o := Options{Model: c.defaultModel} for _, opt := range opts { opt(&o) } - if o.Model == "" { + if o.Model == "" { o.Model = c.defaultModel } - start := time.Now() + start := time.Now() c.logStart(false, o, messages) req := buildOAChatRequest(o, messages, c.defaultTemperature, false) body, err := json.Marshal(req) @@ -1422,7 +1424,7 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ c.logf("marshal error: %v", err) return "", err } - endpoint := c.baseURL + "/chat/completions" + endpoint := c.baseURL + "/chat/completions" logging.Logf("llm/openai ", "POST %s", endpoint) resp, err := c.doJSON(ctx, endpoint, body, map[string]string{ "Authorization": "Bearer " + c.apiKey, @@ -1431,41 +1433,41 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ logging.Logf("llm/openai ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return "", err } - defer resp.Body.Close() - if err := handleOpenAINon2xx(resp, start); err != nil { + defer resp.Body.Close() + if err := handleOpenAINon2xx(resp, start); err != nil { return "", err } - out, err := decodeOpenAIChat(resp, start) - if err != nil { + out, err := decodeOpenAIChat(resp, start) + if err != nil { return "", err } - if len(out.Choices) == 0 { + if len(out.Choices) == 0 { logging.Logf("llm/openai ", "%sno choices returned duration=%s%s", logging.AnsiRed, time.Since(start), logging.AnsiBase) return "", errors.New("openai: no choices returned") } - content := out.Choices[0].Message.Content + content := out.Choices[0].Message.Content logging.Logf("llm/openai ", "success choice=0 finish=%s size=%d preview=%s%s%s duration=%s", out.Choices[0].FinishReason, len(content), logging.AnsiGreen, logging.PreviewForLog(content), logging.AnsiBase, time.Since(start)) return content, nil } // Provider metadata -func (c openAIClient) Name() string { return "openai" } -func (c openAIClient) DefaultModel() string { return c.defaultModel } +func (c openAIClient) Name() string { return "openai" } +func (c openAIClient) DefaultModel() string { return c.defaultModel } // Streaming support (optional) -func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { +func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelta func(string), opts ...RequestOption) error { if c.apiKey == "" { return errors.New("missing OpenAI API key") } - o := Options{Model: c.defaultModel} + o := Options{Model: c.defaultModel} for _, opt := range opts { opt(&o) } - if o.Model == "" { + if o.Model == "" { o.Model = c.defaultModel } - start := time.Now() + start := time.Now() c.logStart(true, o, messages) req := buildOAChatRequest(o, messages, c.defaultTemperature, true) body, err := json.Marshal(req) @@ -1473,7 +1475,7 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt c.logf("marshal error: %v", err) return err } - endpoint := c.baseURL + "/chat/completions" + endpoint := c.baseURL + "/chat/completions" logging.Logf("llm/openai ", "POST %s (stream)", endpoint) resp, err := c.doJSONWithAccept(ctx, endpoint, body, map[string]string{ "Authorization": "Bearer " + c.apiKey, @@ -1482,15 +1484,15 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt logging.Logf("llm/openai ", "%shttp error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return err } - defer resp.Body.Close() + defer resp.Body.Close() if err := handleOpenAINon2xx(resp, start); err != nil { return err } - if err := parseOpenAIStream(resp, start, onDelta); err != nil { + if err := parseOpenAIStream(resp, start, onDelta); err != nil { return err } - logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start)) + logging.Logf("llm/openai ", "stream end duration=%s", time.Since(start)) return nil } @@ -1498,117 +1500,117 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt func (c openAIClient) logf(format string, args ...any) { logging.Logf("llm/openai ", format, args...) } // helpers extracted to keep methods small -func (c openAIClient) logStart(stream bool, o Options, messages []Message) { +func (c openAIClient) logStart(stream bool, o Options, messages []Message) { logMessages := make([]struct{ Role, Content string }, len(messages)) - for i, m := range messages { + for i, m := range messages { logMessages[i] = struct{ Role, Content string }{m.Role, m.Content} } - c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) + c.chatLogger.LogStart(stream, o.Model, o.Temperature, o.MaxTokens, o.Stop, logMessages) } -func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool) oaChatRequest { +func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, stream bool) oaChatRequest { req := oaChatRequest{Model: o.Model, Stream: stream} req.Messages = make([]oaMessage, len(messages)) - for i, m := range messages { + for i, m := range messages { req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} } - if o.Temperature != 0 { + if o.Temperature != 0 { req.Temperature = &o.Temperature - } else if defaultTemp != nil { + } else if defaultTemp != nil { t := *defaultTemp req.Temperature = &t } - if o.MaxTokens > 0 { + if o.MaxTokens > 0 { req.MaxTokens = &o.MaxTokens } - if len(o.Stop) > 0 { + if len(o.Stop) > 0 { req.Stop = o.Stop } - return req + return req } -func (c openAIClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { +func (c openAIClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") - for k, v := range headers { + req.Header.Set("Content-Type", "application/json") + for k, v := range headers { req.Header.Set(k, v) } - return c.httpClient.Do(req) + return c.httpClient.Do(req) } -func (c openAIClient) doJSONWithAccept(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) { +func (c openAIClient) doJSONWithAccept(ctx context.Context, url string, body []byte, headers map[string]string, accept string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json") + req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", accept) - for k, v := range headers { + for k, v := range headers { req.Header.Set(k, v) } - return c.httpClient.Do(req) + return c.httpClient.Do(req) } -func handleOpenAINon2xx(resp *http.Response, start time.Time) error { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { +func handleOpenAINon2xx(resp *http.Response, start time.Time) error { + if resp.StatusCode >= 200 && resp.StatusCode < 300 { return nil } - var apiErr oaChatResponse + var apiErr oaChatResponse _ = json.NewDecoder(resp.Body).Decode(&apiErr) - if apiErr.Error != nil && apiErr.Error.Message != "" { + if apiErr.Error != nil && apiErr.Error.Message != "" { logging.Logf("llm/openai ", "%sapi error status=%d type=%s msg=%s duration=%s%s", logging.AnsiRed, resp.StatusCode, apiErr.Error.Type, apiErr.Error.Message, time.Since(start), logging.AnsiBase) return fmt.Errorf("openai error: %s (status %d)", apiErr.Error.Message, resp.StatusCode) } - logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) + logging.Logf("llm/openai ", "%shttp non-2xx status=%d duration=%s%s", logging.AnsiRed, resp.StatusCode, time.Since(start), logging.AnsiBase) return fmt.Errorf("openai http error: status %d", resp.StatusCode) } -func decodeOpenAIChat(resp *http.Response, start time.Time) (oaChatResponse, error) { +func decodeOpenAIChat(resp *http.Response, start time.Time) (oaChatResponse, error) { var out oaChatResponse - if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { logging.Logf("llm/openai ", "%sdecode error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return oaChatResponse{}, err } - return out, nil + return out, nil } -func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string)) error { +func parseOpenAIStream(resp *http.Response, start time.Time, onDelta func(string)) error { // Parse SSE: lines starting with "data: " containing JSON or [DONE] scanner := bufio.NewScanner(resp.Body) const maxBuf = 1024 * 1024 buf := make([]byte, 0, 64*1024) scanner.Buffer(buf, maxBuf) - for scanner.Scan() { + for scanner.Scan() { line := scanner.Text() - if !strings.HasPrefix(line, "data: ") { + if !strings.HasPrefix(line, "data: ") { continue } - payload := strings.TrimPrefix(line, "data: ") - if strings.TrimSpace(payload) == "[DONE]" { + payload := strings.TrimPrefix(line, "data: ") + if strings.TrimSpace(payload) == "[DONE]" { break } - var chunk oaStreamChunk - if err := json.Unmarshal([]byte(payload), &chunk); err != nil { + var chunk oaStreamChunk + if err := json.Unmarshal([]byte(payload), &chunk); err != nil { continue } - if chunk.Error != nil && chunk.Error.Message != "" { + if chunk.Error != nil && chunk.Error.Message != "" { logging.Logf("llm/openai ", "%sstream error: %s%s", logging.AnsiRed, chunk.Error.Message, logging.AnsiBase) return fmt.Errorf("openai stream error: %s", chunk.Error.Message) } - for _, ch := range chunk.Choices { - if ch.Delta.Content != "" { + for _, ch := range chunk.Choices { + if ch.Delta.Content != "" { onDelta(ch.Delta.Content) } } } - if err := scanner.Err(); err != nil { + if err := scanner.Err(); err != nil { logging.Logf("llm/openai ", "%sstream read error after %s: %v%s", logging.AnsiRed, time.Since(start), err, logging.AnsiBase) return err } - return nil + return nil } @@ -1669,11 +1671,11 @@ type Options struct { // RequestOption mutates Options. type RequestOption func(*Options) -func WithModel(model string) RequestOption { return func(o *Options) { o.Model = model } } -func WithTemperature(t float64) RequestOption { return func(o *Options) { o.Temperature = t } } -func WithMaxTokens(n int) RequestOption { return func(o *Options) { o.MaxTokens = n } } -func WithStop(stop ...string) RequestOption { - return func(o *Options) { o.Stop = append([]string{}, stop...) } +func WithModel(model string) RequestOption { return func(o *Options) { o.Model = model } } +func WithTemperature(t float64) RequestOption { return func(o *Options) { o.Temperature = t } } +func WithMaxTokens(n int) RequestOption { return func(o *Options) { o.MaxTokens = n } } +func WithStop(stop ...string) RequestOption { + return func(o *Options) { o.Stop = append([]string{}, stop...) } } // Config defines provider configuration read from the Hexai config file. @@ -1696,38 +1698,38 @@ type Config struct { // NewFromConfig creates an LLM client using only the supplied configuration. // The OpenAI API key is supplied separately and may be read from the environment // by the caller; other environment-based configuration is not used. -func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) { +func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) { p := strings.ToLower(strings.TrimSpace(cfg.Provider)) - if p == "" { + if p == "" { p = "openai" } - switch p { - case "openai": - if strings.TrimSpace(openAIAPIKey) == "" { + switch p { + case "openai": + if strings.TrimSpace(openAIAPIKey) == "" { return nil, errors.New("missing OPENAI_API_KEY for provider openai") } // Set coding-friendly default temperature if none provided - if cfg.OpenAITemperature == nil { + if cfg.OpenAITemperature == nil { t := 0.2 cfg.OpenAITemperature = &t } - return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil - case "ollama": - if cfg.OllamaTemperature == nil { + return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil + case "ollama": + if cfg.OllamaTemperature == nil { t := 0.2 cfg.OllamaTemperature = &t } - return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil - case "copilot": - if strings.TrimSpace(copilotAPIKey) == "" { + return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil + case "copilot": + if strings.TrimSpace(copilotAPIKey) == "" { return nil, errors.New("missing COPILOT_API_KEY for provider copilot") } - if cfg.CopilotTemperature == nil { + if cfg.CopilotTemperature == nil { t := 0.2 cfg.CopilotTemperature = &t } - return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil - default: + return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil + default: return nil, errors.New("unknown LLM provider: " + p) } } @@ -1738,7 +1740,7 @@ func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, erro import "errors" // small helper to keep return type consistent -func nilStringErr(msg string) (string, error) { return "", errors.New(msg) } +func nilStringErr(msg string) (string, error) { return "", errors.New(msg) } @@ -1841,21 +1843,21 @@ import ( // - window: include a window of lines around the cursor // - file-on-new-func: include full file only when defining a new function // - always-full: always include the full file -func (s *Server) buildAdditionalContext(newFunc bool, uri string, pos Position) (string, bool) { +func (s *Server) buildAdditionalContext(newFunc bool, uri string, pos Position) (string, bool) { mode := s.contextMode switch mode { case "minimal": return "", false case "window": return s.windowContext(uri, pos), true - case "file-on-new-func": + case "file-on-new-func": if newFunc { return s.fullFileContext(uri), true } return "", false case "always-full": return s.fullFileContext(uri), true - default: + default: // fallback to minimal if unknown return "", false } @@ -1881,23 +1883,23 @@ func (s *Server) windowContext(uri string, pos Position) string { +func (s *Server) fullFileContext(uri string) string { d := s.getDocument(uri) if d == nil { logging.Logf("lsp ", "context: full-file requested but document not open; skipping uri=%s", uri) return "" } - return truncateToApproxTokens(d.text, s.maxContextTokens) + return truncateToApproxTokens(d.text, s.maxContextTokens) } // truncateToApproxTokens naively truncates the input to fit approx N tokens. // Uses 4 chars/token heuristic for speed and determinism. -func truncateToApproxTokens(text string, maxTokens int) string { +func truncateToApproxTokens(text string, maxTokens int) string { if maxTokens <= 0 { return "" } - maxChars := maxTokens * 4 - if len(text) <= maxChars { + maxChars := maxTokens * 4 + if len(text) <= maxChars { return text } // try to cut on a line boundary near maxChars @@ -1926,136 +1928,136 @@ type document struct { lines []string } -func (s *Server) setDocument(uri, text string) { +func (s *Server) setDocument(uri, text string) { s.mu.Lock() defer s.mu.Unlock() s.docs[uri] = &document{uri: uri, text: text, lines: splitLines(text)} } -func (s *Server) deleteDocument(uri string) { +func (s *Server) deleteDocument(uri string) { s.mu.Lock() defer s.mu.Unlock() delete(s.docs, uri) } -func (s *Server) markActivity() { +func (s *Server) markActivity() { s.mu.Lock() s.lastInput = time.Now() s.mu.Unlock() } -func (s *Server) getDocument(uri string) *document { +func (s *Server) getDocument(uri string) *document { s.mu.RLock() defer s.mu.RUnlock() return s.docs[uri] } // splitLines splits the input string into lines, normalizing line endings to '\n'. -func splitLines(sx string) []string { +func splitLines(sx string) []string { sx = strings.ReplaceAll(sx, "\r\n", "\n") return strings.Split(sx, "\n") } -func (s *Server) lineContext(uri string, pos Position) (above, current, below, funcCtx string) { +func (s *Server) lineContext(uri string, pos Position) (above, current, below, funcCtx string) { d := s.getDocument(uri) if d == nil || len(d.lines) == 0 { return "", "", "", "" } - idx := pos.Line + idx := pos.Line if idx < 0 { idx = 0 } - if idx >= len(d.lines) { + if idx >= len(d.lines) { idx = len(d.lines) - 1 } - current = d.lines[idx] - if idx-1 >= 0 { + current = d.lines[idx] + if idx-1 >= 0 { above = d.lines[idx-1] } - if idx+1 < len(d.lines) { + if idx+1 < len(d.lines) { below = d.lines[idx+1] } - for i := idx; i >= 0; i-- { + for i := idx; i >= 0; i-- { line := strings.TrimSpace(d.lines[i]) - if hasAny(line, []string{"func ", "def ", "class ", "fn ", "procedure ", "sub "}) { + if hasAny(line, []string{"func ", "def ", "class ", "fn ", "procedure ", "sub "}) { funcCtx = line break } } - return + return } // isDefiningNewFunction returns true when the cursor appears to be within // a function declaration/signature and before the opening '{' of the body. // Heuristic: find nearest preceding line containing "func "; ensure no '{' // appears before the cursor across those lines. -func (s *Server) isDefiningNewFunction(uri string, pos Position) bool { +func (s *Server) isDefiningNewFunction(uri string, pos Position) bool { d := s.getDocument(uri) if d == nil || len(d.lines) == 0 { return false } - idx := pos.Line + idx := pos.Line if idx < 0 { idx = 0 } - if idx >= len(d.lines) { + if idx >= len(d.lines) { idx = len(d.lines) - 1 } // Find signature start - sigStart := -1 - for i := idx; i >= 0; i-- { - if strings.Contains(d.lines[i], "func ") { + sigStart := -1 + for i := idx; i >= 0; i-- { + if strings.Contains(d.lines[i], "func ") { sigStart = i break } // stop if we hit a closing brace which likely ends a previous block - if strings.Contains(d.lines[i], "}") { + if strings.Contains(d.lines[i], "}") { break } } - if sigStart == -1 { + if sigStart == -1 { return false } // Scan for '{' from sigStart up to cursor position; if found before or at cursor, we're in body - for i := sigStart; i <= idx; i++ { + for i := sigStart; i <= idx; i++ { line := d.lines[i] brace := strings.Index(line, "{") - if brace >= 0 { - if i < idx { + if brace >= 0 { + if i < idx { return false // body started on a previous line } // same line as cursor: if brace position < cursor character, then already in body - if pos.Character > brace { + if pos.Character > brace { return false } } } - return true + return true } -func hasAny(s string, needles []string) bool { - for _, n := range needles { - if strings.Contains(s, n) { +func hasAny(s string, needles []string) bool { + for _, n := range needles { + if strings.Contains(s, n) { return true } } - return false + return false } -func trimLen(s string) string { +func trimLen(s string) string { s = strings.TrimSpace(s) if len(s) > 200 { return s[:200] + "…" } - return s + return s } -func firstLine(s string) string { +func firstLine(s string) string { s = strings.ReplaceAll(s, "\r\n", "\n") - if idx := strings.IndexByte(s, '\n'); idx >= 0 { + if idx := strings.IndexByte(s, '\n'); idx >= 0 { return s[:idx] } - return s + return s } @@ -2068,12 +2070,12 @@ import ( "strings" ) -func (s *Server) handle(req Request) { - if h, ok := s.handlers[req.Method]; ok { +func (s *Server) handle(req Request) { + if h, ok := s.handlers[req.Method]; ok { h(req) return } - if len(req.ID) != 0 { + if len(req.ID) != 0 { s.reply(req.ID, nil, &RespError{Code: -32601, Message: fmt.Sprintf("method not found: %s", req.Method)}) } } @@ -2086,15 +2088,15 @@ func (s *Server) handle(req Request) { // Preference order on each line: strict ;text; marker (no inner spaces), then // a line comment (//, #, --). Returns the instruction string and the selection // text cleaned of the matched instruction marker or comment. -func instructionFromSelection(sel string) (string, string) { +func instructionFromSelection(sel string) (string, string) { lines := splitLines(sel) - for idx, line := range lines { + for idx, line := range lines { if instr, cleaned, ok := findFirstInstructionInLine(line); ok && strings.TrimSpace(instr) != "" { lines[idx] = cleaned return instr, strings.Join(lines, "\n") } } - return "", sel + return "", sel } // findFirstInstructionInLine returns the earliest instruction marker on the @@ -2106,51 +2108,51 @@ func instructionFromSelection(sel string) (string, string) { +func findFirstInstructionInLine(line string) (instr string, cleaned string, ok bool) { type cand struct { start, end int text string } cands := []cand{} - if t, l, r, ok := findStrictSemicolonTag(line); ok { + if t, l, r, ok := findStrictSemicolonTag(line); ok { cands = append(cands, cand{start: l, end: r, text: t}) } - if i := strings.Index(line, "/*"); i >= 0 { - if j := strings.Index(line[i+2:], "*/"); j >= 0 { + if i := strings.Index(line, "/*"); i >= 0 { + if j := strings.Index(line[i+2:], "*/"); j >= 0 { start := i end := i + 2 + j + 2 text := strings.TrimSpace(line[i+2 : i+2+j]) cands = append(cands, cand{start: start, end: end, text: text}) } } - if i := strings.Index(line, "<!--"); i >= 0 { - if j := strings.Index(line[i+4:], "-->"); j >= 0 { + if i := strings.Index(line, "<!--"); i >= 0 { + if j := strings.Index(line[i+4:], "-->"); j >= 0 { start := i end := i + 4 + j + 3 text := strings.TrimSpace(line[i+4 : i+4+j]) cands = append(cands, cand{start: start, end: end, text: text}) } } - if i := strings.Index(line, "//"); i >= 0 { + if i := strings.Index(line, "//"); i >= 0 { cands = append(cands, cand{start: i, end: len(line), text: strings.TrimSpace(line[i+2:])}) } - if i := strings.Index(line, "#"); i >= 0 { + if i := strings.Index(line, "#"); i >= 0 { cands = append(cands, cand{start: i, end: len(line), text: strings.TrimSpace(line[i+1:])}) } - if i := strings.Index(line, "--"); i >= 0 { + if i := strings.Index(line, "--"); i >= 0 { cands = append(cands, cand{start: i, end: len(line), text: strings.TrimSpace(line[i+2:])}) } - if len(cands) == 0 { + if len(cands) == 0 { return "", line, false } // pick earliest start index - best := cands[0] - for _, c := range cands[1:] { + best := cands[0] + for _, c := range cands[1:] { if c.start >= 0 && (best.start < 0 || c.start < best.start) { best = c } } - cleaned = strings.TrimRight(line[:best.start]+line[best.end:], " \t") + cleaned = strings.TrimRight(line[:best.start]+line[best.end:], " \t") return best.text, cleaned, true } @@ -2175,7 +2177,7 @@ func findFirstInstructionInLine(line string) (instr string, cleaned string, ok b // handleCompletion moved to handlers_completion.go -func (s *Server) reply(id json.RawMessage, result any, err *RespError) { +func (s *Server) reply(id json.RawMessage, result any, err *RespError) { resp := Response{JSONRPC: "2.0", ID: id, Result: result, Error: err} s.writeMessage(resp) } @@ -2249,33 +2251,33 @@ func (s *Server) reply(id json.RawMessage, result any, err *RespError) { +func (s *Server) completionCacheKey(p CompletionParams, above, current, below, funcCtx string, inParams bool, hasExtra bool, extraText string) string { // Normalize left-of-cursor by trimming trailing spaces/tabs idx := p.Position.Character if idx > len(current) { idx = len(current) } - left := strings.TrimRight(current[:idx], " \t") + left := strings.TrimRight(current[:idx], " \t") right := "" if idx < len(current) { right = current[idx:] } - prov := "" + prov := "" model := "" - if s.llmClient != nil { + if s.llmClient != nil { prov = s.llmClient.Name() model = s.llmClient.DefaultModel() } - temp := "" + temp := "" if s.codingTemperature != nil { temp = fmt.Sprintf("%.3f", *s.codingTemperature) } - extra := "" + extra := "" if hasExtra { extra = strings.TrimSpace(extraText) } // Compose a key from essential context parts - return strings.Join([]string{ + return strings.Join([]string{ "v1", // version for future-proofing prov, model, @@ -2304,13 +2306,13 @@ func (s *Server) completionCacheGet(key string) (string, bool) { +func (s *Server) completionCachePut(key, value string) { s.mu.Lock() defer s.mu.Unlock() - if s.compCache == nil { + if s.compCache == nil { s.compCache = make(map[string]string) } - if _, exists := s.compCache[key]; !exists { + if _, exists := s.compCache[key]; !exists { s.compCacheOrder = append(s.compCacheOrder, key) s.compCache[key] = value if len(s.compCacheOrder) > 10 { @@ -2319,7 +2321,7 @@ func (s *Server) completionCachePut(key, value string) - return + return } // update existing and mark most-recent s.compCache[key] = value @@ -2346,36 +2348,36 @@ func (s *Server) compCacheTouchLocked(key string) { // by typing one of our configured trigger characters. It checks the LSP // CompletionContext if provided and also falls back to inspecting the character // immediately to the left of the cursor. -func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { +func (s *Server) isTriggerEvent(p CompletionParams, current string) bool { // 1) Inspect LSP completion context if present - if p.Context != nil { + if p.Context != nil { var ctx struct { TriggerKind int `json:"triggerKind"` TriggerCharacter string `json:"triggerCharacter,omitempty"` } - if raw, ok := p.Context.(json.RawMessage); ok { + if raw, ok := p.Context.(json.RawMessage); ok { _ = json.Unmarshal(raw, &ctx) - } else { + } else { b, _ := json.Marshal(p.Context) _ = json.Unmarshal(b, &ctx) } // If the line contains a bare ';;' (no ';;text;'), do not treat as a trigger source. - if strings.Contains(current, ";;") && !hasDoubleSemicolonTrigger(current) { + if strings.Contains(current, ";;") && !hasDoubleSemicolonTrigger(current) { return false } // TriggerKind 1 = Invoked (manual). Always allow manual invoke. - if ctx.TriggerKind == 1 { + if ctx.TriggerKind == 1 { return true } // TriggerKind 2 is TriggerCharacter per LSP spec - if ctx.TriggerKind == 2 { - if ctx.TriggerCharacter != "" { - for _, c := range s.triggerChars { - if c == ctx.TriggerCharacter { + if ctx.TriggerKind == 2 { + if ctx.TriggerCharacter != "" { + for _, c := range s.triggerChars { + if c == ctx.TriggerCharacter { return true } } - return false + return false } // No character provided but reported as TriggerCharacter; be conservative return false @@ -2383,32 +2385,32 @@ func (s *Server) isTriggerEvent(p CompletionParams, current string) bool idx := p.Position.Character + idx := p.Position.Character if idx <= 0 || idx > len(current) { return false } // Bare ';;' should not trigger via fallback char either - if strings.Contains(current, ";;") && !hasDoubleSemicolonTrigger(current) { + if strings.Contains(current, ";;") && !hasDoubleSemicolonTrigger(current) { return false } - ch := string(current[idx-1]) - for _, c := range s.triggerChars { - if c == ch { + ch := string(current[idx-1]) + for _, c := range s.triggerChars { + if c == ch { return true } } return false } -func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string) []CompletionItem { +func (s *Server) makeCompletionItems(cleaned string, inParams bool, current string, p CompletionParams, docStr string) []CompletionItem { te, filter := computeTextEditAndFilter(cleaned, inParams, current, p) rm := s.collectPromptRemovalEdits(p.TextDocument.URI) label := labelForCompletion(cleaned, filter) detail := "Hexai LLM completion" - if s.llmClient != nil { + if s.llmClient != nil { detail = "Hexai " + s.llmClient.Name() + ":" + s.llmClient.DefaultModel() } - return []CompletionItem{{ + return []CompletionItem{{ Label: label, Kind: 1, Detail: detail, @@ -2493,7 +2495,7 @@ func (s *Server) makeCompletionItems(cleaned string, inParams bool, current stri // labelForCompletion moved to handlers_utils.go -func (s *Server) fallbackCompletionItems(docStr string) []CompletionItem { +func (s *Server) fallbackCompletionItems(docStr string) []CompletionItem { return []CompletionItem{{ Label: "hexai-complete", Kind: 1, @@ -2520,7 +2522,7 @@ import ( "path/filepath" ) -func (s *Server) handleCodeAction(req Request) { +func (s *Server) handleCodeAction(req Request) { var p CodeActionParams if err := json.Unmarshal(req.Params, &p); err != nil { if len(req.ID) != 0 { @@ -2528,34 +2530,34 @@ func (s *Server) handleCodeAction(req Request) { } return } - d := s.getDocument(p.TextDocument.URI) - if d == nil || len(d.lines) == 0 || s.llmClient == nil { - if len(req.ID) != 0 { + d := s.getDocument(p.TextDocument.URI) + if d == nil || len(d.lines) == 0 || s.llmClient == nil { + if len(req.ID) != 0 { s.reply(req.ID, []CodeAction{}, nil) } - return + return } - sel := extractRangeText(d, p.Range) + sel := extractRangeText(d, p.Range) actions := make([]CodeAction, 0, 4) if a := s.buildRewriteCodeAction(p, sel); a != nil { actions = append(actions, *a) } - if a := s.buildDiagnosticsCodeAction(p, sel); a != nil { + if a := s.buildDiagnosticsCodeAction(p, sel); a != nil { actions = append(actions, *a) } - if a := s.buildDocumentCodeAction(p, sel); a != nil { + if a := s.buildDocumentCodeAction(p, sel); a != nil { actions = append(actions, *a) } - if a := s.buildGoUnitTestCodeAction(p); a != nil { + if a := s.buildGoUnitTestCodeAction(p); a != nil { actions = append(actions, *a) } - if len(req.ID) != 0 { + if len(req.ID) != 0 { s.reply(req.ID, actions, nil) } } -func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction { +func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction { if instr, cleaned := instructionFromSelection(sel); strings.TrimSpace(instr) != "" { payload := struct { Type string `json:"type"` @@ -2568,15 +2570,15 @@ func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAct ca := CodeAction{Title: "Hexai: rewrite selection", Kind: "refactor.rewrite", Data: raw} return &ca } - return nil + return nil } -func (s *Server) buildDiagnosticsCodeAction(p CodeActionParams, sel string) *CodeAction { +func (s *Server) buildDiagnosticsCodeAction(p CodeActionParams, sel string) *CodeAction { diags := s.diagnosticsInRange(p.Context, p.Range) - if len(diags) == 0 { + if len(diags) == 0 { return nil } - payload := struct { + payload := struct { Type string `json:"type"` URI string `json:"uri"` Range Range `json:"range"` @@ -2588,11 +2590,11 @@ func (s *Server) buildDiagnosticsCodeAction(p CodeActionParams, sel string) *Cod return &ca } -func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { +func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { if s.llmClient == nil || len(ca.Data) == 0 { return ca, false } - var payload struct { + var payload struct { Type string `json:"type"` URI string `json:"uri"` Range Range `json:"range"` @@ -2603,16 +2605,16 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { return ca, false } - switch payload.Type { - case "rewrite": + switch payload.Type { + case "rewrite": sys := "You are a precise code refactoring engine. Rewrite the given code strictly according to the instruction. Return only the updated code with no prose or backticks. Preserve formatting where reasonable." user := fmt.Sprintf("Instruction: %s\n\nSelected code to transform:\n%s", payload.Instruction, payload.Selection) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} opts := s.llmRequestOpts() - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - if out := stripCodeFences(strings.TrimSpace(text)); out != "" { + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := stripCodeFences(strings.TrimSpace(text)); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{payload.URI: {{Range: payload.Range, NewText: out}}}} ca.Edit = &edit return ca, true @@ -2620,25 +2622,25 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { logging.Logf("lsp ", "codeAction rewrite llm error: %v", err) } - case "diagnostics": + case "diagnostics": sys := "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes." var b strings.Builder b.WriteString("Diagnostics to resolve (selection only):\n") - for i, dgn := range payload.Diagnostics { + for i, dgn := range payload.Diagnostics { if dgn.Source != "" { fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) - } else { + } else { fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) } } - b.WriteString("\nSelected code:\n") + b.WriteString("\nSelected code:\n") b.WriteString(payload.Selection) ctx, cancel := context.WithTimeout(context.Background(), 12*time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: b.String()}} opts := s.llmRequestOpts() - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - if out := stripCodeFences(strings.TrimSpace(text)); out != "" { + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := stripCodeFences(strings.TrimSpace(text)); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{payload.URI: {{Range: payload.Range, NewText: out}}}} ca.Edit = &edit return ca, true @@ -2646,15 +2648,15 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { logging.Logf("lsp ", "codeAction diagnostics llm error: %v", err) } - case "document": + case "document": sys := "You are a precise code documentation engine. Add idiomatic documentation comments to the given code. Preserve exact behavior and formatting as much as possible. Return only the updated code with comments, no prose or backticks." user := "Add documentation comments to this code:\n" + payload.Selection ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} opts := s.llmRequestOpts() - if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { - if out := stripCodeFences(strings.TrimSpace(text)); out != "" { + if text, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out := stripCodeFences(strings.TrimSpace(text)); out != "" { edit := WorkspaceEdit{Changes: map[string][]TextEdit{payload.URI: {{Range: payload.Range, NewText: out}}}} ca.Edit = &edit return ca, true @@ -2676,7 +2678,7 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) return ca, false } -func (s *Server) handleCodeActionResolve(req Request) { +func (s *Server) handleCodeActionResolve(req Request) { var ca CodeAction if err := json.Unmarshal(req.Params, &ca); err != nil { if len(req.ID) != 0 { @@ -2684,7 +2686,7 @@ func (s *Server) handleCodeActionResolve(req Request) return } - if resolved, ok := s.resolveCodeAction(ca); ok { + if resolved, ok := s.resolveCodeAction(ca); ok { s.reply(req.ID, resolved, nil) return } @@ -2694,77 +2696,77 @@ func (s *Server) handleCodeActionResolve(req Request) { - if len(ctxRaw) == 0 { +func (s *Server) diagnosticsInRange(ctxRaw json.RawMessage, sel Range) []Diagnostic { + if len(ctxRaw) == 0 { return nil } - var ctx CodeActionContext + var ctx CodeActionContext if err := json.Unmarshal(ctxRaw, &ctx); err != nil { return nil } - if len(ctx.Diagnostics) == 0 { + if len(ctx.Diagnostics) == 0 { return nil } - out := make([]Diagnostic, 0, len(ctx.Diagnostics)) - for _, d := range ctx.Diagnostics { - if rangesOverlap(d.Range, sel) { + out := make([]Diagnostic, 0, len(ctx.Diagnostics)) + for _, d := range ctx.Diagnostics { + if rangesOverlap(d.Range, sel) { out = append(out, d) } } - return out + return out } // rangesOverlap reports whether two LSP ranges overlap at all. -func rangesOverlap(a, b Range) bool { +func rangesOverlap(a, b Range) bool { // Normalize ordering if greaterPos(a.Start, a.End) { a.Start, a.End = a.End, a.Start } - if greaterPos(b.Start, b.End) { + if greaterPos(b.Start, b.End) { b.Start, b.End = b.End, b.Start } // a ends before b starts - if lessPos(a.End, b.Start) { + if lessPos(a.End, b.Start) { return false } // b ends before a starts - if lessPos(b.End, a.Start) { + if lessPos(b.End, a.Start) { return false } - return true + return true } -func lessPos(p, q Position) bool { - if p.Line != q.Line { +func lessPos(p, q Position) bool { + if p.Line != q.Line { return p.Line < q.Line } - return p.Character < q.Character + return p.Character < q.Character } -func greaterPos(p, q Position) bool { - if p.Line != q.Line { +func greaterPos(p, q Position) bool { + if p.Line != q.Line { return p.Line > q.Line } - return p.Character > q.Character + return p.Character > q.Character } // --- Go unit test code action --- -func (s *Server) buildGoUnitTestCodeAction(p CodeActionParams) *CodeAction { +func (s *Server) buildGoUnitTestCodeAction(p CodeActionParams) *CodeAction { uri := p.TextDocument.URI if uri == "" || !strings.HasSuffix(strings.TrimPrefix(uri, "file://"), ".go") { return nil } // Skip if already a _test.go file - if strings.HasSuffix(strings.TrimPrefix(uri, "file://"), "_test.go") { + if strings.HasSuffix(strings.TrimPrefix(uri, "file://"), "_test.go") { return nil } // Heuristic: only offer when a function context is found above the cursor - _, _, _, funcCtx := s.lineContext(uri, p.Range.Start) + _, _, _, funcCtx := s.lineContext(uri, p.Range.Start) if !strings.Contains(funcCtx, "func ") { return nil } - payload := struct { + payload := struct { Type string `json:"type"` URI string `json:"uri"` Range Range `json:"range"` @@ -2775,14 +2777,14 @@ func (s *Server) buildGoUnitTestCodeAction(p CodeActionParams) *CodeAction { +func (s *Server) buildDocumentCodeAction(p CodeActionParams, sel string) *CodeAction { if s.llmClient == nil { return nil } - if strings.TrimSpace(sel) == "" { + if strings.TrimSpace(sel) == "" { return nil } - payload := struct { + payload := struct { Type string `json:"type"` URI string `json:"uri"` Range Range `json:"range"` @@ -2793,47 +2795,47 @@ func (s *Server) buildDocumentCodeAction(p CodeActionParams, sel string) *CodeAc return &ca } -func (s *Server) resolveGoTest(uri string, pos Position) (WorkspaceEdit, string, Range, bool) { +func (s *Server) resolveGoTest(uri string, pos Position) (WorkspaceEdit, string, Range, bool) { path := strings.TrimPrefix(uri, "file://") if !strings.HasSuffix(path, ".go") || strings.HasSuffix(path, "_test.go") { return WorkspaceEdit{}, "", Range{}, false } // Load source text - _, lines := s.loadFileText(uri) + _, lines := s.loadFileText(uri) if len(lines) == 0 { return WorkspaceEdit{}, "", Range{}, false } - pkg := parseGoPackageName(lines) + pkg := parseGoPackageName(lines) fnStart, fnEnd := findGoFunctionAtLine(lines, pos.Line) if fnStart < 0 || fnEnd < fnStart { return WorkspaceEdit{}, "", Range{}, false } - funcCode := strings.Join(lines[fnStart:fnEnd+1], "\n") + funcCode := strings.Join(lines[fnStart:fnEnd+1], "\n") testFunc := s.generateGoTestFunction(funcCode) if strings.TrimSpace(testFunc) == "" { return WorkspaceEdit{}, "", Range{}, false } // Determine test file target - testPath := strings.TrimSuffix(path, ".go") + "_test.go" + testPath := strings.TrimSuffix(path, ".go") + "_test.go" testURI := "file://" + testPath // If test file exists, append test at EOF; otherwise, create a new file with package+import - if fileExists(testPath) { + if fileExists(testPath) { // Build an insertion at end of file _, tLines := s.loadFileText(testURI) // Fallback when not open and cannot read: still insert at line 0 lineIdx := 0 col := 0 - if len(tLines) > 0 { + if len(tLines) > 0 { lineIdx = len(tLines) - 1 col = len(tLines[lineIdx]) } - var b strings.Builder + var b strings.Builder // Ensure at least two newlines before the new test if len(tLines) == 0 || (len(tLines) > 0 && !strings.HasSuffix(strings.Join(tLines, "\n"), "\n\n")) { b.WriteString("\n\n") } - b.WriteString(testFunc) + b.WriteString(testFunc) insert := b.String() edit := TextEdit{Range: Range{Start: Position{Line: lineIdx, Character: col}, End: Position{Line: lineIdx, Character: col}}, NewText: insert} we := WorkspaceEdit{Changes: map[string][]TextEdit{testURI: {edit}}} @@ -2841,16 +2843,16 @@ func (s *Server) resolveGoTest(uri string, pos Position) (WorkspaceEdit, string, // Count how many prefix newlines added before the test function prefixNL := 0 if strings.HasPrefix(insert, "\n\n") { prefixNL = 2 } - startLine := lineIdx + prefixNL + startLine := lineIdx + prefixNL // If we inserted with two newlines and last line wasn't blank, first newline moves to next line if prefixNL > 0 { startLine = lineIdx + prefixNL } - jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}} + jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}} return we, testURI, jump, true } // Create new file content - var content strings.Builder + var content strings.Builder if pkg == "" { pkg = filepath.Base(filepath.Dir(path)) } - content.WriteString("package ") + content.WriteString("package ") content.WriteString(pkg) content.WriteString("\n\n") content.WriteString("import (\n\t\"testing\"\n)\n\n") @@ -2865,59 +2867,59 @@ func (s *Server) resolveGoTest(uri string, pos Position) (WorkspaceEdit, string, pre := content.String() idx := strings.Index(pre, "func Test") startLine := 0 - if idx > 0 { + if idx > 0 { before := pre[:idx] startLine = strings.Count(before, "\n") } - jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}} + jump := Range{Start: Position{Line: startLine, Character: 0}, End: Position{Line: startLine, Character: 0}} return we, testURI, jump, true } // loadFileText returns the file content and lines. It prefers the open document; otherwise reads from disk. -func (s *Server) loadFileText(uri string) (string, []string) { - if d := s.getDocument(uri); d != nil { +func (s *Server) loadFileText(uri string) (string, []string) { + if d := s.getDocument(uri); d != nil { return d.text, append([]string{}, d.lines...) } - path := strings.TrimPrefix(uri, "file://") + path := strings.TrimPrefix(uri, "file://") b, err := os.ReadFile(path) if err != nil { return "", nil } - txt := string(b) + txt := string(b) return txt, splitLines(txt) } -func fileExists(path string) bool { - if _, err := os.Stat(path); err == nil { +func fileExists(path string) bool { + if _, err := os.Stat(path); err == nil { return true } - return false + return false } // parseGoPackageName returns the package name from file lines, or empty if not found. -func parseGoPackageName(lines []string) string { - for _, ln := range lines { +func parseGoPackageName(lines []string) string { + for _, ln := range lines { t := strings.TrimSpace(ln) - if strings.HasPrefix(t, "package ") { + if strings.HasPrefix(t, "package ") { name := strings.TrimSpace(strings.TrimPrefix(t, "package ")) // strip inline comments - if i := strings.Index(name, " "); i >= 0 { name = name[:i] } - if i := strings.Index(name, "\t"); i >= 0 { name = name[:i] } - if i := strings.Index(name, "//"); i >= 0 { name = strings.TrimSpace(name[:i]) } - return name + if i := strings.Index(name, " "); i >= 0 { name = name[:i] } + if i := strings.Index(name, "\t"); i >= 0 { name = name[:i] } + if i := strings.Index(name, "//"); i >= 0 { name = strings.TrimSpace(name[:i]) } + return name } } - return "" + return "" } // findGoFunctionAtLine finds the function enclosing or preceding line idx. Returns start and end line indexes. -func findGoFunctionAtLine(lines []string, idx int) (int, int) { +func findGoFunctionAtLine(lines []string, idx int) (int, int) { if idx < 0 { idx = 0 } - if idx >= len(lines) { idx = len(lines)-1 } + if idx >= len(lines) { idx = len(lines)-1 } // find signature start - start := -1 - for i := idx; i >= 0; i-- { - if strings.Contains(lines[i], "func ") { + start := -1 + for i := idx; i >= 0; i-- { + if strings.Contains(lines[i], "func ") { start = i break } @@ -2925,20 +2927,20 @@ func findGoFunctionAtLine(lines []string, idx int) (int, int) } } - if start == -1 { return -1, -1 } + if start == -1 { return -1, -1 } // find first '{' - depth := 0 + depth := 0 seenOpen := false - for i := start; i < len(lines); i++ { + for i := start; i < len(lines); i++ { ln := lines[i] - for j := 0; j < len(ln); j++ { + for j := 0; j < len(ln); j++ { switch ln[j] { - case '{': + case '{': depth++ seenOpen = true - case '}': - if depth > 0 { depth-- } - if seenOpen && depth == 0 { + case '}': + if depth > 0 { depth-- } + if seenOpen && depth == 0 { return start, i } } @@ -2952,55 +2954,55 @@ func findGoFunctionAtLine(lines []string, idx int) (int, int) { - if s.llmClient != nil { +func (s *Server) generateGoTestFunction(funcCode string) string { + if s.llmClient != nil { sys := "You are a precise Go unit test generator. Given a Go function, write one or more Test* functions using the testing package. Do NOT include package or imports, only the test function(s). Prefer table-driven tests. Keep it minimal and idiomatic." user := "Function under test:\n" + funcCode ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) defer cancel() messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}} opts := s.llmRequestOpts() - if out, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { + if out, err := s.llmClient.Chat(ctx, messages, opts...); err == nil { cleaned := strings.TrimSpace(stripCodeFences(out)) - if cleaned != "" { return cleaned } + if cleaned != "" { return cleaned } } else { logging.Logf("lsp ", "codeAction go_test llm error: %v", err) } } // Fallback stub - name := deriveGoFuncName(funcCode) + name := deriveGoFuncName(funcCode) if name == "" { name = "Function" } - return fmt.Sprintf("func Test%s(t *testing.T) {\n\t// TODO: implement tests for %s\n}\n", exportName(name), name) + return fmt.Sprintf("func Test%s(t *testing.T) {\n\t// TODO: implement tests for %s\n}\n", exportName(name), name) } // deriveGoFuncName extracts function or method name from code. -func deriveGoFuncName(code string) string { +func deriveGoFuncName(code string) string { // look for line starting with func line := firstLine(code) line = strings.TrimSpace(line) if !strings.HasPrefix(line, "func ") { return "" } - rest := strings.TrimSpace(strings.TrimPrefix(line, "func ")) + rest := strings.TrimSpace(strings.TrimPrefix(line, "func ")) // method receiver - if strings.HasPrefix(rest, "(") { + if strings.HasPrefix(rest, "(") { // find ")" - if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) { + if i := strings.Index(rest, ")"); i >= 0 && i+1 < len(rest) { rest = strings.TrimSpace(rest[i+1:]) } } // now rest should start with Name( - if i := strings.Index(rest, "("); i > 0 { + if i := strings.Index(rest, "("); i > 0 { return strings.TrimSpace(rest[:i]) } return "" } -func exportName(name string) string { +func exportName(name string) string { if name == "" { return name } - r := []rune(name) + r := []rune(name) if r[0] >= 'a' && r[0] <= 'z' { r[0] = r[0] - ('a' - 'A') } - return string(r) + return string(r) } @@ -3017,10 +3019,10 @@ import ( "time" ) -func (s *Server) handleCompletion(req Request) { +func (s *Server) handleCompletion(req Request) { var p CompletionParams var docStr string - if err := json.Unmarshal(req.Params, &p); err == nil { + if err := json.Unmarshal(req.Params, &p); err == nil { // Log trigger information for every completion request from client tk, tch := extractTriggerInfo(p) logging.Logf("lsp ", "completion trigger kind=%d char=%q uri=%s line=%d char=%d", @@ -3030,11 +3032,11 @@ func (s *Server) handleCompletion(req Request) { if s.logContext { s.logCompletionContext(p, above, current, below, funcCtx) } - if s.llmClient != nil { + if s.llmClient != nil { newFunc := s.isDefiningNewFunction(p.TextDocument.URI, p.Position) extra, has := s.buildAdditionalContext(newFunc, p.TextDocument.URI, p.Position) items, ok := s.tryLLMCompletion(p, above, current, below, funcCtx, docStr, has, extra) - if ok { + if ok { s.reply(req.ID, CompletionList{IsIncomplete: false, Items: items}, nil) return } @@ -3046,41 +3048,41 @@ func (s *Server) handleCompletion(req Request) { // extractTriggerInfo returns the LSP completion TriggerKind and TriggerCharacter // if provided by the client; when absent it returns zeros. -func extractTriggerInfo(p CompletionParams) (kind int, ch string) { +func extractTriggerInfo(p CompletionParams) (kind int, ch string) { if p.Context == nil { return 0, "" } - var ctx struct { + var ctx struct { TriggerKind int `json:"triggerKind"` TriggerCharacter string `json:"triggerCharacter,omitempty"` } - if raw, ok := p.Context.(json.RawMessage); ok { + if raw, ok := p.Context.(json.RawMessage); ok { _ = json.Unmarshal(raw, &ctx) - } else { + } else { b, _ := json.Marshal(p.Context) _ = json.Unmarshal(b, &ctx) } - return ctx.TriggerKind, ctx.TriggerCharacter + return ctx.TriggerKind, ctx.TriggerCharacter } // --- completion helpers --- -func (s *Server) buildDocString(p CompletionParams, above, current, below, funcCtx string) string { +func (s *Server) buildDocString(p CompletionParams, above, current, below, funcCtx string) string { return fmt.Sprintf("file: %s\nline: %d\nabove: %s\ncurrent: %s\nbelow: %s\nfunction: %s", p.TextDocument.URI, p.Position.Line, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) } -func (s *Server) logCompletionContext(p CompletionParams, above, current, below, funcCtx string) { +func (s *Server) logCompletionContext(p CompletionParams, above, current, below, funcCtx string) { logging.Logf("lsp ", "completion ctx uri=%s line=%d char=%d above=%q current=%q below=%q function=%q", p.TextDocument.URI, p.Position.Line, p.Position.Character, trimLen(above), trimLen(current), trimLen(below), trimLen(funcCtx)) } -func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) { +func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, funcCtx, docStr string, hasExtra bool, extraText string) ([]CompletionItem, bool) { ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) defer cancel() inlinePrompt := lineHasInlinePrompt(current) - if !inlinePrompt && !s.isTriggerEvent(p, current) { + if !inlinePrompt && !s.isTriggerEvent(p, current) { logging.Logf("lsp ", "%scompletion skip=no-trigger line=%d char=%d current=%q%s", logging.AnsiYellow, p.Position.Line, p.Position.Character, trimLen(current), logging.AnsiBase) return []CompletionItem{}, true } @@ -3151,77 +3153,77 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun } // parseManualInvoke inspects the LSP completion context and reports whether the user manually invoked completion. -func parseManualInvoke(ctx any) bool { +func parseManualInvoke(ctx any) bool { if ctx == nil { return false } - var c struct { + var c struct { TriggerKind int `json:"triggerKind"` } - if raw, ok := ctx.(json.RawMessage); ok { + if raw, ok := ctx.(json.RawMessage); ok { _ = json.Unmarshal(raw, &c) } else { b, _ := json.Marshal(ctx) _ = json.Unmarshal(b, &c) } - return c.TriggerKind == 1 + return c.TriggerKind == 1 } // shouldSuppressForChatTriggerEOL returns true when a chat trigger like ">" follows ?, !, :, or ; at EOL. -func (s *Server) shouldSuppressForChatTriggerEOL(current string, p CompletionParams) bool { - if t := strings.TrimRight(current, " \t"); len(t) >= 2 && t[len(t)-1] == '>' { +func (s *Server) shouldSuppressForChatTriggerEOL(current string, p CompletionParams) bool { + if t := strings.TrimRight(current, " \t"); len(t) >= 2 && t[len(t)-1] == '>' { prev := t[len(t)-2] - if prev == '?' || prev == '!' || prev == ':' || prev == ';' { + if prev == '?' || prev == '!' || prev == ':' || prev == ';' { logging.Logf("lsp ", "completion skip=chat-trigger-eol uri=%s line=%d", p.TextDocument.URI, p.Position.Line) return true } } - return false + return false } // prefixHeuristicAllows applies minimal prefix rules unless inlinePrompt or structural triggers apply. -func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p CompletionParams, manualInvoke bool) bool { +func (s *Server) prefixHeuristicAllows(inlinePrompt bool, current string, p CompletionParams, manualInvoke bool) bool { // Determine the effective cursor index within current line, clamped, and // skip over trailing spaces/tabs to support cases like "type Matrix| ". idx := p.Position.Character if idx > len(current) { idx = len(current) } - allowNoPrefix := inlinePrompt - if idx > 0 { + allowNoPrefix := inlinePrompt + if idx > 0 { ch := current[idx-1] - if ch == '.' || ch == ':' || ch == '/' || ch == '_' || ch == ')' { + if ch == '.' || ch == ':' || ch == '/' || ch == '_' || ch == ')' { allowNoPrefix = true } } - if allowNoPrefix { + if allowNoPrefix { return true } // Walk left over whitespace - j := idx - for j > 0 { + j := idx + for j > 0 { c := current[j-1] if c == ' ' || c == '\t' { j-- continue } - break + break } - start := computeWordStart(current, j) + start := computeWordStart(current, j) min := 1 - if manualInvoke && s.manualInvokeMinPrefix >= 0 { + if manualInvoke && s.manualInvokeMinPrefix >= 0 { min = s.manualInvokeMinPrefix } - return j-start >= min + return j-start >= min } // tryProviderNativeCompletion attempts provider-native completion and returns items when successful. -func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { +func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, above, below, funcCtx, docStr string, hasExtra bool, extraText string, inParams bool) ([]CompletionItem, bool) { cc, ok := s.llmClient.(llm.CodeCompleter) if !ok { return nil, false } - before, after := s.docBeforeAfter(p.TextDocument.URI, p.Position) + before, after := s.docBeforeAfter(p.TextDocument.URI, p.Position) path := strings.TrimPrefix(p.TextDocument.URI, "file://") prompt := "// Path: " + path + "\n" + before lang := "" @@ -3229,11 +3231,11 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, if s.codingTemperature != nil { temp = *s.codingTemperature } - prov := "" - if s.llmClient != nil { + prov := "" + if s.llmClient != nil { prov = s.llmClient.Name() } - logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path) + logging.Logf("lsp ", "completion path=codex provider=%s uri=%s", prov, path) ctx2, cancel2 := context.WithTimeout(context.Background(), 8*time.Second) defer cancel2() @@ -3242,21 +3244,21 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, if !s.waitForThrottle(ctx2) { return nil, false } - suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp) - if err == nil && len(suggestions) > 0 { + suggestions, err := cc.CodeCompletion(ctx2, prompt, after, 1, lang, temp) + if err == nil && len(suggestions) > 0 { cleaned := strings.TrimSpace(suggestions[0]) - if cleaned != "" { + if cleaned != "" { cleaned = stripDuplicateAssignmentPrefix(current[:p.Position.Character], cleaned) - if cleaned != "" { + if cleaned != "" { cleaned = stripDuplicateGeneralPrefix(current[:p.Position.Character], cleaned) } - if cleaned != "" && hasDoubleSemicolonTrigger(current) { + if cleaned != "" && hasDoubleSemicolonTrigger(current) { indent := leadingIndent(current) if indent != "" { cleaned = applyIndent(indent, cleaned) } } - if strings.TrimSpace(cleaned) != "" { + if strings.TrimSpace(cleaned) != "" { key := s.completionCacheKey(p, above, current, below, funcCtx, inParams, hasExtra, extraText) s.completionCachePut(key, cleaned) return s.makeCompletionItems(cleaned, inParams, current, p, docStr), true @@ -3270,9 +3272,9 @@ func (s *Server) tryProviderNativeCompletion(current string, p CompletionParams, // waitForDebounce sleeps until there has been no input activity for at least // completionDebounce. If debounce is zero or ctx is done, it returns promptly. -func (s *Server) waitForDebounce(ctx context.Context) { +func (s *Server) waitForDebounce(ctx context.Context) { d := s.completionDebounce - if d <= 0 { + if d <= 0 { return } for { @@ -3300,9 +3302,9 @@ func (s *Server) waitForDebounce(ctx context.Context) { +func (s *Server) waitForThrottle(ctx context.Context) bool { interval := s.throttleInterval - if interval <= 0 { + if interval <= 0 { return true } var wait time.Duration @@ -3331,41 +3333,41 @@ func (s *Server) waitForThrottle(ctx context.Context) bool { +func (s *Server) buildCompletionMessages(inlinePrompt, hasExtra bool, extraText string, inParams bool, p CompletionParams, above, current, below, funcCtx string) []llm.Message { sysPrompt, userPrompt := buildPrompts(inParams, p, above, current, below, funcCtx) messages := []llm.Message{ {Role: "system", Content: sysPrompt}, {Role: "user", Content: userPrompt}, } - if hasExtra && extraText != "" { + if hasExtra && extraText != "" { messages = append(messages, llm.Message{Role: "user", Content: "Additional context:\n" + extraText}) } - if inlinePrompt { + if inlinePrompt { messages[0].Content = "You are a precise code completion/refactoring engine. Output only the code to insert with no prose, no comments, and no backticks. Return raw code only." } - return messages + return messages } // postProcessCompletion normalizes and deduplicates completion text and applies indentation rules. -func (s *Server) postProcessCompletion(text string, leftOfCursor string, currentLine string) string { +func (s *Server) postProcessCompletion(text string, leftOfCursor string, currentLine string) string { cleaned := stripCodeFences(text) if cleaned != "" && strings.ContainsRune(cleaned, '`') { if inline := stripInlineCodeSpan(cleaned); strings.TrimSpace(inline) != "" { cleaned = inline } } - if cleaned != "" { + if cleaned != "" { cleaned = stripDuplicateAssignmentPrefix(leftOfCursor, cleaned) } - if cleaned != "" { + if cleaned != "" { cleaned = stripDuplicateGeneralPrefix(leftOfCursor, cleaned) } - if cleaned != "" && hasDoubleSemicolonTrigger(currentLine) { - if indent := leadingIndent(currentLine); indent != "" { + if cleaned != "" && hasDoubleSemicolonTrigger(currentLine) { + if indent := leadingIndent(currentLine); indent != "" { cleaned = applyIndent(indent, cleaned) } } - return cleaned + return cleaned } @@ -3381,29 +3383,29 @@ import ( "time" ) -func (s *Server) handleDidOpen(req Request) { +func (s *Server) handleDidOpen(req Request) { var p DidOpenTextDocumentParams - if err := json.Unmarshal(req.Params, &p); err == nil { + if err := json.Unmarshal(req.Params, &p); err == nil { s.setDocument(p.TextDocument.URI, p.TextDocument.Text) s.markActivity() } } -func (s *Server) handleDidChange(req Request) { +func (s *Server) handleDidChange(req Request) { var p DidChangeTextDocumentParams - if err := json.Unmarshal(req.Params, &p); err == nil { - if len(p.ContentChanges) > 0 { + if err := json.Unmarshal(req.Params, &p); err == nil { + if len(p.ContentChanges) > 0 { s.setDocument(p.TextDocument.URI, p.ContentChanges[len(p.ContentChanges)-1].Text) } - s.markActivity() + s.markActivity() // Detect in-editor chat trigger lines and respond inline. s.detectAndHandleChat(p.TextDocument.URI) } } -func (s *Server) handleDidClose(req Request) { +func (s *Server) handleDidClose(req Request) { var p DidCloseTextDocumentParams - if err := json.Unmarshal(req.Params, &p); err == nil { + if err := json.Unmarshal(req.Params, &p); err == nil { s.deleteDocument(p.TextDocument.URI) s.markActivity() } @@ -3412,42 +3414,42 @@ func (s *Server) handleDidClose(req Request) { // docBeforeAfter returns the full document text split at the given position. // The returned strings are the text before the cursor (inclusive of anything // left of the position) and the text after the cursor. -func (s *Server) docBeforeAfter(uri string, pos Position) (string, string) { +func (s *Server) docBeforeAfter(uri string, pos Position) (string, string) { d := s.getDocument(uri) - if d == nil { + if d == nil { return "", "" } // Clamp indices - line := pos.Line + line := pos.Line if line < 0 { line = 0 } - if line >= len(d.lines) { + if line >= len(d.lines) { line = len(d.lines) - 1 } - col := pos.Character + col := pos.Character if col < 0 { col = 0 } - if col > len(d.lines[line]) { + if col > len(d.lines[line]) { col = len(d.lines[line]) } // Build before - var b strings.Builder - for i := 0; i < line; i++ { + var b strings.Builder + for i := 0; i < line; i++ { b.WriteString(d.lines[i]) b.WriteByte('\n') } - b.WriteString(d.lines[line][:col]) + b.WriteString(d.lines[line][:col]) before := b.String() // Build after var a strings.Builder a.WriteString(d.lines[line][col:]) - for i := line + 1; i < len(d.lines); i++ { + for i := line + 1; i < len(d.lines); i++ { a.WriteByte('\n') a.WriteString(d.lines[i]) } - return before, a.String() + return before, a.String() } // --- in-editor chat (";C ...") --- @@ -3455,50 +3457,50 @@ func (s *Server) docBeforeAfter(uri string, pos Position) (string, string) { - if s.llmClient == nil { +func (s *Server) detectAndHandleChat(uri string) { + if s.llmClient == nil { return } - d := s.getDocument(uri) + d := s.getDocument(uri) if d == nil || len(d.lines) == 0 { return } - for i, raw := range d.lines { + for i, raw := range d.lines { // Find last non-space character index j := len(raw) - 1 - for j >= 0 { + for j >= 0 { if raw[j] == ' ' || raw[j] == '\t' { j-- continue } - break + break } - if j < 1 { + if j < 1 { continue } // need at least two chars - pair := raw[j-1 : j+1] + pair := raw[j-1 : j+1] isTrigger := pair == "?>" || pair == "!>" || pair == ":>" || pair == ";>" - if !isTrigger { + if !isTrigger { continue } // Avoid double-answering: if the next non-empty line starts with '>' we skip. - k := i + 1 - for k < len(d.lines) && strings.TrimSpace(d.lines[k]) == "" { + k := i + 1 + for k < len(d.lines) && strings.TrimSpace(d.lines[k]) == "" { k++ } - if k < len(d.lines) && strings.HasPrefix(strings.TrimSpace(d.lines[k]), ">") { + if k < len(d.lines) && strings.HasPrefix(strings.TrimSpace(d.lines[k]), ">") { continue } // Derive prompt by removing only the trailing '>' - removeCount := 1 + removeCount := 1 base := raw[:j+1-removeCount] prompt := strings.TrimSpace(base) if prompt == "" { continue } - lineIdx := i + lineIdx := i lastIdx := j - go func(prompt string, remove int) { + go func(prompt string, remove int) { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() sys := "You are a helpful coding assistant. Answer concisely and clearly." @@ -3512,26 +3514,26 @@ func (s *Server) detectAndHandleChat(uri string) { logging.Logf("lsp ", "chat llm error: %v", err) return } - out := strings.TrimSpace(stripCodeFences(text)) + out := strings.TrimSpace(stripCodeFences(text)) if out == "" { return } - s.applyChatEdits(uri, lineIdx, lastIdx, remove, "> "+out) + s.applyChatEdits(uri, lineIdx, lastIdx, remove, "> "+out) }(prompt, removeCount) // Only handle one per change tick to avoid flooding - break + break } } // applyChatEdits removes the triggering punctuation at end of the line and // inserts two newlines followed by a new line with the response prefixed. -func (s *Server) applyChatEdits(uri string, lineIdx int, lastNonSpace int, removeCount int, response string) { +func (s *Server) applyChatEdits(uri string, lineIdx int, lastNonSpace int, removeCount int, response string) { d := s.getDocument(uri) if d == nil { return } // 1) Delete the trailing punctuation (1 or 2 chars) - delStart := Position{Line: lineIdx, Character: lastNonSpace + 1 - removeCount} + delStart := Position{Line: lineIdx, Character: lastNonSpace + 1 - removeCount} delEnd := Position{Line: lineIdx, Character: lastNonSpace + 1} // 2) Insert two newlines and the response at end-of-line, then one extra blank line insPos := Position{Line: lineIdx, Character: len(d.lines[lineIdx])} @@ -3547,84 +3549,84 @@ func (s *Server) applyChatEdits(uri string, lineIdx int, lastNonSpace int, remov // buildChatHistory walks upwards from the current line to collect the most recent // Q/A pairs in the in-editor transcript. Returns messages ending with current prompt. -func (s *Server) buildChatHistory(uri string, lineIdx int, currentPrompt string) []llm.Message { +func (s *Server) buildChatHistory(uri string, lineIdx int, currentPrompt string) []llm.Message { d := s.getDocument(uri) if d == nil { return []llm.Message{{Role: "user", Content: currentPrompt}} } - type pair struct{ q, a string } + type pair struct{ q, a string } pairs := []pair{} i := lineIdx - 1 - for i >= 0 && len(pairs) < 3 { - for i >= 0 && strings.TrimSpace(d.lines[i]) == "" { + for i >= 0 && len(pairs) < 3 { + for i >= 0 && strings.TrimSpace(d.lines[i]) == "" { i-- } - if i < 0 { + if i < 0 { break } - if !strings.HasPrefix(strings.TrimSpace(d.lines[i]), ">") { + if !strings.HasPrefix(strings.TrimSpace(d.lines[i]), ">") { break } - var replyLines []string - for i >= 0 { + var replyLines []string + for i >= 0 { line := strings.TrimSpace(d.lines[i]) - if strings.HasPrefix(line, ">") { + if strings.HasPrefix(line, ">") { replyLines = append([]string{strings.TrimSpace(strings.TrimPrefix(line, ">"))}, replyLines...) i-- continue } - break + break } - for i >= 0 && strings.TrimSpace(d.lines[i]) == "" { + for i >= 0 && strings.TrimSpace(d.lines[i]) == "" { i-- } - if i < 0 { + if i < 0 { break } - q := strings.TrimSpace(d.lines[i]) + q := strings.TrimSpace(d.lines[i]) q = stripTrailingTrigger(q) pairs = append([]pair{{q: q, a: strings.Join(replyLines, "\n")}}, pairs...) i-- } - msgs := make([]llm.Message, 0, len(pairs)*2+1) - for _, p := range pairs { - if strings.TrimSpace(p.q) != "" { + msgs := make([]llm.Message, 0, len(pairs)*2+1) + for _, p := range pairs { + if strings.TrimSpace(p.q) != "" { msgs = append(msgs, llm.Message{Role: "user", Content: p.q}) } - if strings.TrimSpace(p.a) != "" { + if strings.TrimSpace(p.a) != "" { msgs = append(msgs, llm.Message{Role: "assistant", Content: p.a}) } } - msgs = append(msgs, llm.Message{Role: "user", Content: currentPrompt}) + msgs = append(msgs, llm.Message{Role: "user", Content: currentPrompt}) return msgs } // stripTrailingTrigger removes the trailing chat trigger punctuation from a line if present. -func stripTrailingTrigger(sx string) string { +func stripTrailingTrigger(sx string) string { s := strings.TrimRight(sx, " \t") - if len(s) >= 2 && s[len(s)-1] == '>' { // new triggers + if len(s) >= 2 && s[len(s)-1] == '>' { // new triggers prev := s[len(s)-2] - if prev == '?' || prev == '!' || prev == ':' || prev == ';' { + if prev == '?' || prev == '!' || prev == ':' || prev == ';' { return strings.TrimRight(s[:len(s)-1], " \t") } } - if strings.HasSuffix(s, ";;") { // legacy inline cleanup used in history building + if strings.HasSuffix(s, ";;") { // legacy inline cleanup used in history building return strings.TrimRight(strings.TrimSuffix(s, ";;"), " \t") } - if len(s) == 0 { + if len(s) == 0 { return sx } - last := s[len(s)-1] + last := s[len(s)-1] switch last { // legacy: remove one trailing punctuation - case '?', '!', ':': + case '?', '!', ':': return strings.TrimRight(s[:len(s)-1], " \t") - default: + default: return sx } } // clientApplyEdit sends a workspace/applyEdit request to the client. -func (s *Server) clientApplyEdit(label string, edit WorkspaceEdit) { +func (s *Server) clientApplyEdit(label string, edit WorkspaceEdit) { params := ApplyWorkspaceEditParams{Label: label, Edit: edit} id := s.nextReqID() req := Request{JSONRPC: "2.0", ID: id, Method: "workspace/applyEdit"} @@ -3634,7 +3636,7 @@ func (s *Server) clientApplyEdit(label string, edit WorkspaceEdit) { +func (s *Server) nextReqID() json.RawMessage { s.mu.Lock() s.nextID++ idNum := s.nextID @@ -3644,7 +3646,7 @@ func (s *Server) nextReqID() json.RawMessage { } // clientShowDocument asks the client to open/focus a document and select a range. -func (s *Server) clientShowDocument(uri string, sel *Range) { +func (s *Server) clientShowDocument(uri string, sel *Range) { var params struct { URI string `json:"uri"` External bool `json:"external,omitempty"` @@ -3663,8 +3665,8 @@ func (s *Server) clientShowDocument(uri string, sel *Range) { - go func() { +func (s *Server) deferShowDocument(uri string, sel Range) { + go func() { time.Sleep(120 * time.Millisecond) s.clientShowDocument(uri, &sel) }() @@ -3678,26 +3680,26 @@ import ( "encoding/json" ) -func (s *Server) handleExecuteCommand(req Request) { +func (s *Server) handleExecuteCommand(req Request) { var p ExecuteCommandParams if err := json.Unmarshal(req.Params, &p); err != nil { s.reply(req.ID, nil, nil) return } - switch p.Command { - case "hexai.showDocument": - if len(p.Arguments) >= 2 { + switch p.Command { + case "hexai.showDocument": + if len(p.Arguments) >= 2 { uri, _ := p.Arguments[0].(string) var r Range // Convert second arg to Range via re-marshal to be robust across clients - if b, err := json.Marshal(p.Arguments[1]); err == nil { + if b, err := json.Marshal(p.Arguments[1]); err == nil { _ = json.Unmarshal(b, &r) } - if uri != "" { + if uri != "" { s.clientShowDocument(uri, &r) } } - s.reply(req.ID, nil, nil) + s.reply(req.ID, nil, nil) return default: // Unknown command; no-op @@ -3717,12 +3719,12 @@ import ( "os" ) -func (s *Server) handleInitialize(req Request) { +func (s *Server) handleInitialize(req Request) { version := internal.Version if s.llmClient != nil { version = version + " [" + s.llmClient.Name() + ":" + s.llmClient.DefaultModel() + "]" } - res := InitializeResult{ + res := InitializeResult{ Capabilities: ServerCapabilities{ TextDocumentSync: 1, // 1 = TextDocumentSyncKindFull CompletionProvider: &CompletionOptions{ @@ -3740,7 +3742,7 @@ func (s *Server) handleInitialized() { logging.Logf("lsp ", "client initialized") } -func (s *Server) handleShutdown(req Request) { +func (s *Server) handleShutdown(req Request) { s.reply(req.ID, nil, nil) } @@ -3762,12 +3764,12 @@ import ( ) // llmRequestOpts builds request options from server settings. -func (s *Server) llmRequestOpts() []llm.RequestOption { +func (s *Server) llmRequestOpts() []llm.RequestOption { opts := []llm.RequestOption{llm.WithMaxTokens(s.maxTokens)} if s.codingTemperature != nil { opts = append(opts, llm.WithTemperature(*s.codingTemperature)) } - return opts + return opts } // small helpers for LLM traffic stats @@ -3808,110 +3810,110 @@ func (s *Server) logLLMStats() { } // Completion prompt builders and filters -func inParamList(current string, cursor int) bool { +func inParamList(current string, cursor int) bool { if !strings.Contains(current, "func ") { return false } - open := strings.Index(current, "(") + open := strings.Index(current, "(") close := strings.Index(current, ")") return open >= 0 && cursor > open && (close == -1 || cursor <= close) } -func buildPrompts(inParams bool, p CompletionParams, above, current, below, funcCtx string) (string, string) { - if inParams { +func buildPrompts(inParams bool, p CompletionParams, above, current, below, funcCtx string) (string, string) { + if inParams { sys := "You are a code completion engine for function signatures. Return only the parameter list contents (without parentheses), no braces, no prose. Prefer idiomatic names and types." user := fmt.Sprintf("Cursor is inside the function parameter list. Suggest only the parameter list (no parentheses).\nFunction line: %s\nCurrent line (cursor at %d): %s", funcCtx, p.Position.Character, current) return sys, user } - sys := "You are a terse code completion engine. Return only the code to insert, no surrounding prose or backticks. Only continue from the cursor; never repeat characters already present to the left of the cursor on the current line (e.g., if 'name :=' is already typed, only return the right-hand side expression)." + sys := "You are a terse code completion engine. Return only the code to insert, no surrounding prose or backticks. Only continue from the cursor; never repeat characters already present to the left of the cursor on the current line (e.g., if 'name :=' is already typed, only return the right-hand side expression)." user := fmt.Sprintf("Provide the next likely code to insert at the cursor.\nFile: %s\nFunction/context: %s\nAbove line: %s\nCurrent line (cursor at character %d): %s\nBelow line: %s\nOnly return the completion snippet.", p.TextDocument.URI, funcCtx, above, p.Position.Character, current, below) return sys, user } -func computeTextEditAndFilter(cleaned string, inParams bool, current string, p CompletionParams) (*TextEdit, string) { - if inParams { +func computeTextEditAndFilter(cleaned string, inParams bool, current string, p CompletionParams) (*TextEdit, string) { + if inParams { open := strings.Index(current, "(") close := strings.Index(current, ")") - if open >= 0 { + if open >= 0 { left := open + 1 right := len(current) - if close >= 0 && close >= left { + if close >= 0 && close >= left { right = close } - if p.Position.Character < right { + if p.Position.Character < right { right = p.Position.Character } - te := &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: left}, End: Position{Line: p.Position.Line, Character: right}}, NewText: cleaned} + te := &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: left}, End: Position{Line: p.Position.Line, Character: right}}, NewText: cleaned} var filter string - if left >= 0 && right >= left && right <= len(current) { + if left >= 0 && right >= left && right <= len(current) { filter = strings.TrimLeft(current[left:right], " \t") } - return te, filter + return te, filter } } - startChar := computeWordStart(current, p.Position.Character) + startChar := computeWordStart(current, p.Position.Character) te := &TextEdit{Range: Range{Start: Position{Line: p.Position.Line, Character: startChar}, End: Position{Line: p.Position.Line, Character: p.Position.Character}}, NewText: cleaned} filter := strings.TrimLeft(current[startChar:p.Position.Character], " \t") return te, filter } -func computeWordStart(current string, at int) int { +func computeWordStart(current string, at int) int { if at > len(current) { at = len(current) } - for at > 0 { + for at > 0 { ch := current[at-1] - if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { at-- continue } - break + break } - return at + return at } -func isIdentChar(ch byte) bool { +func isIdentChar(ch byte) bool { return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' } // Inline prompt utilities -func lineHasInlinePrompt(line string) bool { +func lineHasInlinePrompt(line string) bool { if _, _, _, ok := findStrictSemicolonTag(line); ok { return true } - return hasDoubleSemicolonTrigger(line) + return hasDoubleSemicolonTrigger(line) } -func leadingIndent(line string) string { +func leadingIndent(line string) string { i := 0 - for i < len(line) { - if line[i] == ' ' || line[i] == '\t' { + for i < len(line) { + if line[i] == ' ' || line[i] == '\t' { i++ continue } - break + break } - if i == 0 { + if i == 0 { return "" } - return line[:i] + return line[:i] } -func applyIndent(indent, suggestion string) string { +func applyIndent(indent, suggestion string) string { if indent == "" || suggestion == "" { return suggestion } - lines := splitLines(suggestion) - for i, ln := range lines { - if strings.TrimSpace(ln) == "" { + lines := splitLines(suggestion) + for i, ln := range lines { + if strings.TrimSpace(ln) == "" { continue } - if strings.HasPrefix(ln, indent) { + if strings.HasPrefix(ln, indent) { continue } - lines[i] = indent + ln + lines[i] = indent + ln } - return strings.Join(lines, "\n") + return strings.Join(lines, "\n") } // --- Inline marker parsing and general string utilities --- @@ -3920,53 +3922,53 @@ func applyIndent(indent, suggestion string) string // before the last ';' on the given line. Returns the text between semicolons, // the start index of the opening ';', the end index just after the closing ';', // and whether it was found. -func findStrictSemicolonTag(line string) (string, int, int, bool) { +func findStrictSemicolonTag(line string) (string, int, int, bool) { pos := 0 - for pos < len(line) { + for pos < len(line) { j := strings.Index(line[pos:], ";") - if j < 0 { + if j < 0 { return "", 0, 0, false } - j += pos + j += pos // ensure single ';' (not ';;') and non-space after - if j+1 >= len(line) || line[j+1] == ';' || line[j+1] == ' ' { + if j+1 >= len(line) || line[j+1] == ';' || line[j+1] == ' ' { pos = j + 1 continue } - k := strings.Index(line[j+1:], ";") + k := strings.Index(line[j+1:], ";") if k < 0 { return "", 0, 0, false } - closeIdx := j + 1 + k + closeIdx := j + 1 + k if closeIdx-1 < 0 || line[closeIdx-1] == ' ' { pos = closeIdx + 1 continue } - inner := strings.TrimSpace(line[j+1 : closeIdx]) + inner := strings.TrimSpace(line[j+1 : closeIdx]) if inner == "" { pos = closeIdx + 1 continue } - end := closeIdx + 1 + end := closeIdx + 1 return inner, j, end, true } - return "", 0, 0, false + return "", 0, 0, false } // isBareDoubleSemicolon reports whether the line contains a standalone // double-semicolon marker with no inline content (";;" possibly with only // whitespace after it). It explicitly excludes the valid form ";;text;". -func isBareDoubleSemicolon(line string) bool { +func isBareDoubleSemicolon(line string) bool { t := strings.TrimSpace(line) if !strings.Contains(t, ";;") { return false } - if hasDoubleSemicolonTrigger(t) { + if hasDoubleSemicolonTrigger(t) { return false } - if strings.HasPrefix(t, ";;") { + if strings.HasPrefix(t, ";;") { rest := strings.TrimSpace(t[2:]) - if rest == "" || rest == ";" { + if rest == "" || rest == ";" { return true } } @@ -3974,260 +3976,260 @@ func isBareDoubleSemicolon(line string) bool { } // stripDuplicateAssignmentPrefix removes a duplicated assignment prefix from the suggestion. -func stripDuplicateAssignmentPrefix(prefixBeforeCursor, suggestion string) string { +func stripDuplicateAssignmentPrefix(prefixBeforeCursor, suggestion string) string { s2 := strings.TrimLeft(suggestion, " \t") // Prefer := if present at end of prefix - if idx := strings.LastIndex(prefixBeforeCursor, ":="); idx >= 0 && idx+2 <= len(prefixBeforeCursor) { + if idx := strings.LastIndex(prefixBeforeCursor, ":="); idx >= 0 && idx+2 <= len(prefixBeforeCursor) { tail := prefixBeforeCursor[idx+2:] - if strings.TrimSpace(tail) == "" { + if strings.TrimSpace(tail) == "" { start := idx - 1 - for start >= 0 && (isIdentChar(prefixBeforeCursor[start]) || prefixBeforeCursor[start] == ' ' || prefixBeforeCursor[start] == '\t') { + for start >= 0 && (isIdentChar(prefixBeforeCursor[start]) || prefixBeforeCursor[start] == ' ' || prefixBeforeCursor[start] == '\t') { start-- } - start++ + start++ seg := strings.TrimRight(prefixBeforeCursor[start:idx+2], " \t") - if strings.HasPrefix(s2, seg) { + if strings.HasPrefix(s2, seg) { return strings.TrimLeft(s2[len(seg):], " \t") } } } // Fallback to plain '=' if present - if idx := strings.LastIndex(prefixBeforeCursor, "="); idx >= 0 { - if !(idx > 0 && prefixBeforeCursor[idx-1] == ':') { // not := + if idx := strings.LastIndex(prefixBeforeCursor, "="); idx >= 0 { + if !(idx > 0 && prefixBeforeCursor[idx-1] == ':') { // not := tail := prefixBeforeCursor[idx+1:] - if strings.TrimSpace(tail) == "" { + if strings.TrimSpace(tail) == "" { start := idx - 1 - for start >= 0 && (isIdentChar(prefixBeforeCursor[start]) || prefixBeforeCursor[start] == ' ' || prefixBeforeCursor[start] == '\t') { + for start >= 0 && (isIdentChar(prefixBeforeCursor[start]) || prefixBeforeCursor[start] == ' ' || prefixBeforeCursor[start] == '\t') { start-- } - start++ + start++ seg := strings.TrimRight(prefixBeforeCursor[start:idx+1], " \t") - if strings.HasPrefix(s2, seg) { + if strings.HasPrefix(s2, seg) { return strings.TrimLeft(s2[len(seg):], " \t") } } } } - return suggestion + return suggestion } // stripDuplicateGeneralPrefix removes any already-typed prefix that the model repeated. -func stripDuplicateGeneralPrefix(prefixBeforeCursor, suggestion string) string { +func stripDuplicateGeneralPrefix(prefixBeforeCursor, suggestion string) string { if suggestion == "" { return suggestion } - s := strings.TrimLeft(suggestion, " \t") + s := strings.TrimLeft(suggestion, " \t") p := strings.TrimRight(prefixBeforeCursor, " \t") - if p != "" && strings.HasPrefix(s, p) { + if p != "" && strings.HasPrefix(s, p) { return strings.TrimLeft(s[len(p):], " \t") } - for k := len(p) - 1; k > 0; k-- { - if !isIdentBoundary(p[k-1]) { + for k := len(p) - 1; k > 0; k-- { + if !isIdentBoundary(p[k-1]) { continue } - suf := strings.TrimLeft(p[k:], " \t") + suf := strings.TrimLeft(p[k:], " \t") if suf == "" { continue } - if strings.HasPrefix(s, suf) { + if strings.HasPrefix(s, suf) { return strings.TrimLeft(s[len(suf):], " \t") } } - return suggestion + return suggestion } -func isIdentBoundary(ch byte) bool { +func isIdentBoundary(ch byte) bool { return !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_') } // stripCodeFences removes surrounding Markdown code fences from a model response. -func stripCodeFences(s string) string { +func stripCodeFences(s string) string { t := strings.TrimSpace(s) if t == "" { return t } - lines := splitLines(t) + lines := splitLines(t) start := 0 for start < len(lines) && strings.TrimSpace(lines[start]) == "" { start++ } - end := len(lines) - 1 + end := len(lines) - 1 for end >= 0 && strings.TrimSpace(lines[end]) == "" { end-- } - if start >= len(lines) || end < 0 || start > end { + if start >= len(lines) || end < 0 || start > end { return t } - first := strings.TrimSpace(lines[start]) + first := strings.TrimSpace(lines[start]) last := strings.TrimSpace(lines[end]) - if strings.HasPrefix(first, "```") && last == "```" && end > start { + if strings.HasPrefix(first, "```") && last == "```" && end > start { inner := strings.Join(lines[start+1:end], "\n") return inner } - return t + return t } // stripInlineCodeSpan returns the contents of the first inline backtick code span if present. -func stripInlineCodeSpan(s string) string { +func stripInlineCodeSpan(s string) string { t := strings.TrimSpace(s) if t == "" { return t } - i := strings.IndexByte(t, '`') - if i < 0 { + i := strings.IndexByte(t, '`') + if i < 0 { return t } - jrel := strings.IndexByte(t[i+1:], '`') - if jrel < 0 { + jrel := strings.IndexByte(t[i+1:], '`') + if jrel < 0 { return t } - j := i + 1 + jrel + j := i + 1 + jrel return t[i+1 : j] } // labelForCompletion picks a short, readable label for the completion list. -func labelForCompletion(cleaned, filter string) string { +func labelForCompletion(cleaned, filter string) string { label := trimLen(firstLine(cleaned)) - if filter != "" && !strings.HasPrefix(strings.ToLower(label), strings.ToLower(filter)) { + if filter != "" && !strings.HasPrefix(strings.ToLower(label), strings.ToLower(filter)) { return filter } - return label + return label } // extractRangeText returns the exact text within the given document range. -func extractRangeText(d *document, r Range) string { - if r.Start.Line == r.End.Line { +func extractRangeText(d *document, r Range) string { + if r.Start.Line == r.End.Line { line := d.lines[r.Start.Line] if r.Start.Character < 0 { r.Start.Character = 0 } - if r.End.Character > len(line) { + if r.End.Character > len(line) { r.End.Character = len(line) } - if r.Start.Character > r.End.Character { + if r.Start.Character > r.End.Character { return "" } - return line[r.Start.Character:r.End.Character] + return line[r.Start.Character:r.End.Character] } - var b strings.Builder + var b strings.Builder // first line first := d.lines[r.Start.Line] if r.Start.Character < 0 { r.Start.Character = 0 } - if r.Start.Character > len(first) { + if r.Start.Character > len(first) { r.Start.Character = len(first) } - b.WriteString(first[r.Start.Character:]) + b.WriteString(first[r.Start.Character:]) b.WriteString("\n") // middle lines - for i := r.Start.Line + 1; i < r.End.Line; i++ { + for i := r.Start.Line + 1; i < r.End.Line; i++ { b.WriteString(d.lines[i]) - if i+1 <= r.End.Line { + if i+1 <= r.End.Line { b.WriteString("\n") } } // last line - last := d.lines[r.End.Line] + last := d.lines[r.End.Line] if r.End.Character < 0 { r.End.Character = 0 } - if r.End.Character > len(last) { + if r.End.Character > len(last) { r.End.Character = len(last) } - b.WriteString(last[:r.End.Character]) + b.WriteString(last[:r.End.Character]) return b.String() } // collectPromptRemovalEdits returns edits to remove all inline prompt markers. -func (s *Server) collectPromptRemovalEdits(uri string) []TextEdit { +func (s *Server) collectPromptRemovalEdits(uri string) []TextEdit { d := s.getDocument(uri) - if d == nil || len(d.lines) == 0 { + if d == nil || len(d.lines) == 0 { return nil } - var edits []TextEdit - for i, line := range d.lines { + var edits []TextEdit + for i, line := range d.lines { edits = append(edits, promptRemovalEditsForLine(line, i)...) } - return edits + return edits } -func promptRemovalEditsForLine(line string, lineNum int) []TextEdit { - if hasDoubleSemicolonTrigger(line) { +func promptRemovalEditsForLine(line string, lineNum int) []TextEdit { + if hasDoubleSemicolonTrigger(line) { return []TextEdit{{Range: Range{Start: Position{Line: lineNum, Character: 0}, End: Position{Line: lineNum, Character: len(line)}}, NewText: ""}} } - return collectSemicolonMarkers(line, lineNum) + return collectSemicolonMarkers(line, lineNum) } -func hasDoubleSemicolonTrigger(line string) bool { +func hasDoubleSemicolonTrigger(line string) bool { pos := 0 - for pos < len(line) { + for pos < len(line) { j := strings.Index(line[pos:], ";;") - if j < 0 { + if j < 0 { return false } - j += pos + j += pos contentStart := j + 2 - if contentStart >= len(line) { + if contentStart >= len(line) { return false } - first := line[contentStart] + first := line[contentStart] if first == ' ' || first == ';' { pos = contentStart + 1 continue } - k := strings.Index(line[contentStart+1:], ";") + k := strings.Index(line[contentStart+1:], ";") if k < 0 { return false } - closeIdx := contentStart + 1 + k + closeIdx := contentStart + 1 + k if closeIdx-1 >= 0 && line[closeIdx-1] == ' ' { pos = closeIdx + 1 continue } - return true + return true } - return false + return false } -func collectSemicolonMarkers(line string, lineNum int) []TextEdit { +func collectSemicolonMarkers(line string, lineNum int) []TextEdit { var edits []TextEdit startSemi := 0 - for startSemi < len(line) { + for startSemi < len(line) { j := strings.Index(line[startSemi:], ";") - if j < 0 { + if j < 0 { break } - j += startSemi + j += startSemi k := strings.Index(line[j+1:], ";") if k < 0 { break } - if j+1 >= len(line) || line[j+1] == ' ' { + if j+1 >= len(line) || line[j+1] == ' ' { startSemi = j + 1 continue } - if line[j+1] == ';' { + if line[j+1] == ';' { startSemi = j + 2 continue } - closeIdx := j + 1 + k + closeIdx := j + 1 + k if closeIdx-1 < 0 || line[closeIdx-1] == ' ' { startSemi = closeIdx + 1 continue } - if closeIdx-(j+1) < 1 { + if closeIdx-(j+1) < 1 { startSemi = closeIdx + 1 continue } - endChar := closeIdx + 1 - if endChar < len(line) && line[endChar] == ' ' { + endChar := closeIdx + 1 + if endChar < len(line) && line[endChar] == ' ' { endChar++ } - edits = append(edits, TextEdit{Range: Range{Start: Position{Line: lineNum, Character: j}, End: Position{Line: lineNum, Character: endChar}}, NewText: ""}) + edits = append(edits, TextEdit{Range: Range{Start: Position{Line: lineNum, Character: j}, End: Position{Line: lineNum, Character: endChar}}, NewText: ""}) startSemi = endChar } - return edits + return edits } @@ -4302,48 +4304,48 @@ type ServerOptions struct { CompletionThrottleMs int } -func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) *Server { +func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) *Server { s := &Server{in: bufio.NewReader(r), out: w, logger: logger, docs: make(map[string]*document), logContext: opts.LogContext} maxTokens := opts.MaxTokens - if maxTokens <= 0 { + if maxTokens <= 0 { maxTokens = 500 } - s.maxTokens = maxTokens + s.maxTokens = maxTokens contextMode := opts.ContextMode - if contextMode == "" { + if contextMode == "" { contextMode = "file-on-new-func" } - windowLines := opts.WindowLines - if windowLines <= 0 { + windowLines := opts.WindowLines + if windowLines <= 0 { windowLines = 120 } - maxContextTokens := opts.MaxContextTokens - if maxContextTokens <= 0 { + maxContextTokens := opts.MaxContextTokens + if maxContextTokens <= 0 { maxContextTokens = 2000 } - s.contextMode = contextMode + s.contextMode = contextMode s.windowLines = windowLines s.maxContextTokens = maxContextTokens s.startTime = time.Now() s.llmClient = opts.Client - if len(opts.TriggerCharacters) == 0 { + if len(opts.TriggerCharacters) == 0 { // Defaults (no space to avoid auto-trigger after whitespace) s.triggerChars = []string{".", ":", "/", "_", ")", "{"} } else { s.triggerChars = append([]string{}, opts.TriggerCharacters...) } - s.codingTemperature = opts.CodingTemperature + s.codingTemperature = opts.CodingTemperature s.compCache = make(map[string]string) s.manualInvokeMinPrefix = opts.ManualInvokeMinPrefix if opts.CompletionDebounceMs > 0 { s.completionDebounce = time.Duration(opts.CompletionDebounceMs) * time.Millisecond } - if opts.CompletionThrottleMs > 0 { + if opts.CompletionThrottleMs > 0 { s.throttleInterval = time.Duration(opts.CompletionThrottleMs) * time.Millisecond } // Initialize dispatch table - s.handlers = map[string]func(Request){ + s.handlers = map[string]func(Request){ "initialize": s.handleInitialize, "initialized": func(_ Request) { s.handleInitialized() }, "shutdown": s.handleShutdown, @@ -4356,7 +4358,7 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) "codeAction/resolve": s.handleCodeActionResolve, "workspace/executeCommand": s.handleExecuteCommand, } - return s + return s } func (s *Server) Run() error { @@ -4398,58 +4400,87 @@ import ( "strings" ) -func (s *Server) readMessage() ([]byte, error) { +func (s *Server) readMessage() ([]byte, error) { tp := textproto.NewReader(s.in) var contentLength int - for { + for { line, err := tp.ReadLine() - if err != nil { + if err != nil { return nil, err } - if line == "" { // end of headers + if line == "" { // end of headers break } - parts := strings.SplitN(line, ":", 2) + parts := strings.SplitN(line, ":", 2) if len(parts) != 2 { continue } - key := strings.TrimSpace(strings.ToLower(parts[0])) + key := strings.TrimSpace(strings.ToLower(parts[0])) val := strings.TrimSpace(parts[1]) switch key { - case "content-length": + case "content-length": n, err := strconv.Atoi(val) if err != nil { return nil, fmt.Errorf("invalid Content-Length: %v", err) } - contentLength = n + contentLength = n } } - if contentLength <= 0 { + if contentLength <= 0 { return nil, fmt.Errorf("missing or invalid Content-Length") } - buf := make([]byte, contentLength) + buf := make([]byte, contentLength) if _, err := io.ReadFull(s.in, buf); err != nil { return nil, err } - return buf, nil + return buf, nil } -func (s *Server) writeMessage(v any) { +func (s *Server) writeMessage(v any) { data, err := json.Marshal(v) if err != nil { logging.Logf("lsp ", "marshal error: %v", err) return } - header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) + header := fmt.Sprintf("Content-Length: %d\r\n\r\n", len(data)) if _, err := io.WriteString(s.out, header); err != nil { logging.Logf("lsp ", "write header error: %v", err) return } - if _, err := s.out.Write(data); err != nil { + if _, err := s.out.Write(data); err != nil { logging.Logf("lsp ", "write body error: %v", err) return } } + + + -- cgit v1.2.3