diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/hexailsp/run.go | 22 | ||||
| -rw-r--r-- | internal/lsp/completion_throttle_test.go | 75 | ||||
| -rw-r--r-- | internal/lsp/handlers.go | 22 | ||||
| -rw-r--r-- | internal/lsp/server.go | 12 |
4 files changed, 119 insertions, 12 deletions
diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go index 64607e3..dd12600 100644 --- a/internal/hexailsp/run.go +++ b/internal/hexailsp/run.go @@ -98,14 +98,16 @@ func ensureFactory(factory ServerFactory) ServerFactory { } func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client) lsp.ServerOptions { - return lsp.ServerOptions{ - LogContext: logContext, - MaxTokens: cfg.MaxTokens, - ContextMode: cfg.ContextMode, - WindowLines: cfg.ContextWindowLines, - MaxContextTokens: cfg.MaxContextTokens, - CodingTemperature: cfg.CodingTemperature, - Client: client, - TriggerCharacters: cfg.TriggerCharacters, - } + return lsp.ServerOptions{ + LogContext: logContext, + MaxTokens: cfg.MaxTokens, + ContextMode: cfg.ContextMode, + WindowLines: cfg.ContextWindowLines, + MaxContextTokens: cfg.MaxContextTokens, + CodingTemperature: cfg.CodingTemperature, + Client: client, + TriggerCharacters: cfg.TriggerCharacters, + // Optional; when zero, server uses a sensible default + MinCompletionIntervalMs: 0, + } } diff --git a/internal/lsp/completion_throttle_test.go b/internal/lsp/completion_throttle_test.go new file mode 100644 index 0000000..2de8edb --- /dev/null +++ b/internal/lsp/completion_throttle_test.go @@ -0,0 +1,75 @@ +package lsp + +import ( + "bytes" + "context" + "log" + "testing" + "time" + + "hexai/internal/llm" +) + +// countingLLM counts Chat calls; minimal implementation for tests. +type countingLLM struct{ calls int } + +func (f *countingLLM) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) { + f.calls++ + return "x := 1", nil +} +func (f *countingLLM) Name() string { return "fake" } +func (f *countingLLM) DefaultModel() string { return "m" } + +func TestDefaultTriggerChars_DoesNotIncludeSemicolonOrQuestion(t *testing.T) { + var buf bytes.Buffer + logger := log.New(&buf, "", 0) + s := NewServer(bytes.NewBuffer(nil), &buf, logger, ServerOptions{}) + has := func(ch string) bool { + for _, c := range s.triggerChars { + if c == ch { return true } + } + return false + } + if has(";") || has("?") { + t.Fatalf("default trigger chars should not include ';' or '?' got=%v", s.triggerChars) + } +} + +func TestTryLLMCompletion_ThrottleSkipsRapidCalls(t *testing.T) { + // Build server with long min interval and set last completion to now + s := &Server{ maxTokens: 32 } + s.minCompletionInterval = time.Hour + s.lastLLMCompletion = time.Now() + fake := &countingLLM{} + s.llmClient = fake + // Position with adequate prefix to avoid prefix heuristic from skipping + p := CompletionParams{ Position: Position{ Line: 0, Character: 3 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } + items, ok := s.tryLLMCompletion(p, "", "foo", "", "", "", false, "") + if !ok { + t.Fatalf("expected ok=true even when throttled") + } + if len(items) != 0 { + t.Fatalf("expected zero items when throttled, got %d", len(items)) + } + if fake.calls != 0 { + t.Fatalf("LLM Chat should not be called when throttled; calls=%d", fake.calls) + } +} + +func TestTryLLMCompletion_MinPrefixSkipsEarly(t *testing.T) { + s := &Server{ maxTokens: 32 } + fake := &countingLLM{} + s.llmClient = fake + // Only 1 identifier character before cursor + p := CompletionParams{ Position: Position{ Line: 0, Character: 1 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} } + items, ok := s.tryLLMCompletion(p, "", "a", "", "", "", false, "") + if !ok { + t.Fatalf("expected ok=true when skipped by min-prefix heuristic") + } + if len(items) != 0 { + t.Fatalf("expected zero items when min-prefix not satisfied") + } + if fake.calls != 0 { + t.Fatalf("LLM Chat should not be called when min-prefix not met; calls=%d", fake.calls) + } +} diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index cfd71ea..95656df 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -451,7 +451,27 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second) defer cancel() - inParams := inParamList(current, p.Position.Character) + inParams := inParamList(current, p.Position.Character) + // Heuristic 1: Require a minimal typed identifier prefix to avoid early triggers + if !inParams { + start := computeWordStart(current, p.Position.Character) + if p.Position.Character-start < 2 { // fewer than 2 identifier chars + return []CompletionItem{}, true + } + } + // Heuristic 2: Throttle LLM calls to avoid rapid-fire requests + if s.minCompletionInterval > 0 { + s.mu.Lock() + tooSoon := time.Since(s.lastLLMCompletion) < s.minCompletionInterval + // Preemptively update timestamp to coalesce bursts + if !tooSoon { + s.lastLLMCompletion = time.Now() + } + s.mu.Unlock() + if tooSoon { + return []CompletionItem{}, true + } + } sysPrompt, userPrompt := buildPrompts(inParams, p, above, current, below, funcCtx) messages := []llm.Message{ {Role: "system", Content: sysPrompt}, diff --git a/internal/lsp/server.go b/internal/lsp/server.go index f1ca302..474020c 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -32,6 +32,9 @@ type Server struct { triggerChars []string // If set, used as the LSP coding temperature for all LLM calls codingTemperature *float64 + // Throttling for LLM-powered completion + lastLLMCompletion time.Time + minCompletionInterval time.Duration // LLM request stats llmReqTotal int64 llmSentBytesTotal int64 @@ -51,6 +54,7 @@ type ServerOptions struct { Client llm.Client TriggerCharacters []string CodingTemperature *float64 + MinCompletionIntervalMs int } func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) *Server { @@ -79,11 +83,17 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) s.startTime = time.Now() s.llmClient = opts.Client if len(opts.TriggerCharacters) == 0 { - s.triggerChars = []string{".", ":", "/", "_", ";", "?"} + // Conservative defaults to reduce early triggers and API usage + s.triggerChars = []string{".", ":", "/", "_"} } else { s.triggerChars = append([]string{}, opts.TriggerCharacters...) } s.codingTemperature = opts.CodingTemperature + if opts.MinCompletionIntervalMs <= 0 { + s.minCompletionInterval = 900 * time.Millisecond + } else { + s.minCompletionInterval = time.Duration(opts.MinCompletionIntervalMs) * time.Millisecond + } return s } |
