summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/hexailsp/run.go22
-rw-r--r--internal/lsp/completion_throttle_test.go75
-rw-r--r--internal/lsp/handlers.go22
-rw-r--r--internal/lsp/server.go12
4 files changed, 119 insertions, 12 deletions
diff --git a/internal/hexailsp/run.go b/internal/hexailsp/run.go
index 64607e3..dd12600 100644
--- a/internal/hexailsp/run.go
+++ b/internal/hexailsp/run.go
@@ -98,14 +98,16 @@ func ensureFactory(factory ServerFactory) ServerFactory {
}
func makeServerOptions(cfg appconfig.App, logContext bool, client llm.Client) lsp.ServerOptions {
- return lsp.ServerOptions{
- LogContext: logContext,
- MaxTokens: cfg.MaxTokens,
- ContextMode: cfg.ContextMode,
- WindowLines: cfg.ContextWindowLines,
- MaxContextTokens: cfg.MaxContextTokens,
- CodingTemperature: cfg.CodingTemperature,
- Client: client,
- TriggerCharacters: cfg.TriggerCharacters,
- }
+ return lsp.ServerOptions{
+ LogContext: logContext,
+ MaxTokens: cfg.MaxTokens,
+ ContextMode: cfg.ContextMode,
+ WindowLines: cfg.ContextWindowLines,
+ MaxContextTokens: cfg.MaxContextTokens,
+ CodingTemperature: cfg.CodingTemperature,
+ Client: client,
+ TriggerCharacters: cfg.TriggerCharacters,
+ // Optional; when zero, server uses a sensible default
+ MinCompletionIntervalMs: 0,
+ }
}
diff --git a/internal/lsp/completion_throttle_test.go b/internal/lsp/completion_throttle_test.go
new file mode 100644
index 0000000..2de8edb
--- /dev/null
+++ b/internal/lsp/completion_throttle_test.go
@@ -0,0 +1,75 @@
+package lsp
+
+import (
+ "bytes"
+ "context"
+ "log"
+ "testing"
+ "time"
+
+ "hexai/internal/llm"
+)
+
+// countingLLM counts Chat calls; minimal implementation for tests.
+type countingLLM struct{ calls int }
+
+func (f *countingLLM) Chat(_ context.Context, _ []llm.Message, _ ...llm.RequestOption) (string, error) {
+ f.calls++
+ return "x := 1", nil
+}
+func (f *countingLLM) Name() string { return "fake" }
+func (f *countingLLM) DefaultModel() string { return "m" }
+
+func TestDefaultTriggerChars_DoesNotIncludeSemicolonOrQuestion(t *testing.T) {
+ var buf bytes.Buffer
+ logger := log.New(&buf, "", 0)
+ s := NewServer(bytes.NewBuffer(nil), &buf, logger, ServerOptions{})
+ has := func(ch string) bool {
+ for _, c := range s.triggerChars {
+ if c == ch { return true }
+ }
+ return false
+ }
+ if has(";") || has("?") {
+ t.Fatalf("default trigger chars should not include ';' or '?' got=%v", s.triggerChars)
+ }
+}
+
+func TestTryLLMCompletion_ThrottleSkipsRapidCalls(t *testing.T) {
+ // Build server with long min interval and set last completion to now
+ s := &Server{ maxTokens: 32 }
+ s.minCompletionInterval = time.Hour
+ s.lastLLMCompletion = time.Now()
+ fake := &countingLLM{}
+ s.llmClient = fake
+ // Position with adequate prefix to avoid prefix heuristic from skipping
+ p := CompletionParams{ Position: Position{ Line: 0, Character: 3 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} }
+ items, ok := s.tryLLMCompletion(p, "", "foo", "", "", "", false, "")
+ if !ok {
+ t.Fatalf("expected ok=true even when throttled")
+ }
+ if len(items) != 0 {
+ t.Fatalf("expected zero items when throttled, got %d", len(items))
+ }
+ if fake.calls != 0 {
+ t.Fatalf("LLM Chat should not be called when throttled; calls=%d", fake.calls)
+ }
+}
+
+func TestTryLLMCompletion_MinPrefixSkipsEarly(t *testing.T) {
+ s := &Server{ maxTokens: 32 }
+ fake := &countingLLM{}
+ s.llmClient = fake
+ // Only 1 identifier character before cursor
+ p := CompletionParams{ Position: Position{ Line: 0, Character: 1 }, TextDocument: TextDocumentIdentifier{URI: "file://x.go"} }
+ items, ok := s.tryLLMCompletion(p, "", "a", "", "", "", false, "")
+ if !ok {
+ t.Fatalf("expected ok=true when skipped by min-prefix heuristic")
+ }
+ if len(items) != 0 {
+ t.Fatalf("expected zero items when min-prefix not satisfied")
+ }
+ if fake.calls != 0 {
+ t.Fatalf("LLM Chat should not be called when min-prefix not met; calls=%d", fake.calls)
+ }
+}
diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go
index cfd71ea..95656df 100644
--- a/internal/lsp/handlers.go
+++ b/internal/lsp/handlers.go
@@ -451,7 +451,27 @@ func (s *Server) tryLLMCompletion(p CompletionParams, above, current, below, fun
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second)
defer cancel()
- inParams := inParamList(current, p.Position.Character)
+ inParams := inParamList(current, p.Position.Character)
+ // Heuristic 1: Require a minimal typed identifier prefix to avoid early triggers
+ if !inParams {
+ start := computeWordStart(current, p.Position.Character)
+ if p.Position.Character-start < 2 { // fewer than 2 identifier chars
+ return []CompletionItem{}, true
+ }
+ }
+ // Heuristic 2: Throttle LLM calls to avoid rapid-fire requests
+ if s.minCompletionInterval > 0 {
+ s.mu.Lock()
+ tooSoon := time.Since(s.lastLLMCompletion) < s.minCompletionInterval
+ // Preemptively update timestamp to coalesce bursts
+ if !tooSoon {
+ s.lastLLMCompletion = time.Now()
+ }
+ s.mu.Unlock()
+ if tooSoon {
+ return []CompletionItem{}, true
+ }
+ }
sysPrompt, userPrompt := buildPrompts(inParams, p, above, current, below, funcCtx)
messages := []llm.Message{
{Role: "system", Content: sysPrompt},
diff --git a/internal/lsp/server.go b/internal/lsp/server.go
index f1ca302..474020c 100644
--- a/internal/lsp/server.go
+++ b/internal/lsp/server.go
@@ -32,6 +32,9 @@ type Server struct {
triggerChars []string
// If set, used as the LSP coding temperature for all LLM calls
codingTemperature *float64
+ // Throttling for LLM-powered completion
+ lastLLMCompletion time.Time
+ minCompletionInterval time.Duration
// LLM request stats
llmReqTotal int64
llmSentBytesTotal int64
@@ -51,6 +54,7 @@ type ServerOptions struct {
Client llm.Client
TriggerCharacters []string
CodingTemperature *float64
+ MinCompletionIntervalMs int
}
func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions) *Server {
@@ -79,11 +83,17 @@ func NewServer(r io.Reader, w io.Writer, logger *log.Logger, opts ServerOptions)
s.startTime = time.Now()
s.llmClient = opts.Client
if len(opts.TriggerCharacters) == 0 {
- s.triggerChars = []string{".", ":", "/", "_", ";", "?"}
+ // Conservative defaults to reduce early triggers and API usage
+ s.triggerChars = []string{".", ":", "/", "_"}
} else {
s.triggerChars = append([]string{}, opts.TriggerCharacters...)
}
s.codingTemperature = opts.CodingTemperature
+ if opts.MinCompletionIntervalMs <= 0 {
+ s.minCompletionInterval = 900 * time.Millisecond
+ } else {
+ s.minCompletionInterval = time.Duration(opts.MinCompletionIntervalMs) * time.Millisecond
+ }
return s
}