summaryrefslogtreecommitdiff
path: root/internal/lsp/handlers_codeaction.go
blob: 4407ac04ae0b51483391a28a72c7f45e67923099 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
// Summary: Code Action handlers and helpers split from handlers.go for clarity.
package lsp

import (
    "context"
    "encoding/json"
    "fmt"
    "codeberg.org/snonux/hexai/internal/llm"
    "codeberg.org/snonux/hexai/internal/logging"
    "strings"
    "time"
)

// handleCodeAction answers textDocument/codeAction. For a non-blank selection
// it offers up to two LLM-backed actions: a rewrite driven by an inline
// instruction, and a quick fix for diagnostics overlapping the selected range.
// The actions carry only Data; their edits are computed later by
// codeAction/resolve.
//
// If no action can be offered, the reply is an empty list. That covers params
// that do not decode, an unknown or empty document, a missing LLM client, or a
// whitespace-only selection. Messages without an ID receive no reply at all.
func (s *Server) handleCodeAction(req Request) {
	respond := func(actions []CodeAction) {
		if len(req.ID) == 0 {
			return
		}
		s.reply(req.ID, actions, nil)
	}

	var params CodeActionParams
	if json.Unmarshal(req.Params, &params) != nil {
		respond([]CodeAction{})
		return
	}
	doc := s.getDocument(params.TextDocument.URI)
	if doc == nil || len(doc.lines) == 0 || s.llmClient == nil {
		respond([]CodeAction{})
		return
	}
	selection := extractRangeText(doc, params.Range)
	if strings.TrimSpace(selection) == "" {
		respond([]CodeAction{})
		return
	}

	// Builders run in a fixed order so the rewrite action always precedes the
	// diagnostics quick fix in the reply.
	builders := []func(CodeActionParams, string) *CodeAction{
		s.buildRewriteCodeAction,
		s.buildDiagnosticsCodeAction,
	}
	actions := make([]CodeAction, 0, len(builders))
	for _, build := range builders {
		if action := build(params, selection); action != nil {
			actions = append(actions, *action)
		}
	}
	respond(actions)
}

// buildRewriteCodeAction returns an unresolved "rewrite selection" action when
// instructionFromSelection finds a non-blank instruction in sel; otherwise it
// returns nil. The target URI and range, the instruction, and the selection
// with the instruction removed are serialized into Data. resolveCodeAction
// reads them back to prompt the LLM.
func (s *Server) buildRewriteCodeAction(p CodeActionParams, sel string) *CodeAction {
	instruction, cleaned := instructionFromSelection(sel)
	if strings.TrimSpace(instruction) == "" {
		return nil
	}

	type rewriteData struct {
		Type        string `json:"type"`
		URI         string `json:"uri"`
		Range       Range  `json:"range"`
		Instruction string `json:"instruction"`
		Selection   string `json:"selection"`
	}
	// The Marshal error is ignored. If it ever failed, Data would be empty and
	// resolveCodeAction would simply decline to resolve the action.
	data, _ := json.Marshal(rewriteData{
		Type:        "rewrite",
		URI:         p.TextDocument.URI,
		Range:       p.Range,
		Instruction: instruction,
		Selection:   cleaned,
	})
	return &CodeAction{
		Title: "Hexai: rewrite selection",
		Kind:  "refactor.rewrite",
		Data:  data,
	}
}

// buildDiagnosticsCodeAction returns an unresolved "resolve diagnostics" quick
// fix when at least one diagnostic from the request context overlaps the
// selected range; otherwise it returns nil. The overlapping diagnostics, the
// raw selection and the target location are serialized into Data for
// resolveCodeAction.
func (s *Server) buildDiagnosticsCodeAction(p CodeActionParams, sel string) *CodeAction {
	overlapping := s.diagnosticsInRange(p.Context, p.Range)
	if len(overlapping) == 0 {
		return nil
	}

	type diagnosticsData struct {
		Type        string       `json:"type"`
		URI         string       `json:"uri"`
		Range       Range        `json:"range"`
		Selection   string       `json:"selection"`
		Diagnostics []Diagnostic `json:"diagnostics"`
	}
	// The Marshal error is ignored. An empty Data makes resolveCodeAction
	// decline the action rather than crash.
	data, _ := json.Marshal(diagnosticsData{
		Type:        "diagnostics",
		URI:         p.TextDocument.URI,
		Range:       p.Range,
		Selection:   sel,
		Diagnostics: overlapping,
	})
	return &CodeAction{
		Title: "Hexai: resolve diagnostics",
		Kind:  "quickfix",
		Data:  data,
	}
}

// resolveCodeAction attaches an LLM-generated WorkspaceEdit to an action built
// by handleCodeAction. It decodes the action's Data payload and prompts the
// LLM according to payload.Type:
//   - "rewrite": apply payload.Instruction to payload.Selection (10s timeout).
//   - "diagnostics": fix payload.Diagnostics within payload.Selection
//     (12s timeout).
//
// On success the returned action carries an edit replacing payload.Range in
// payload.URI with the cleaned LLM output, and ok is true. In every other case
// the action is returned unchanged with ok false: no LLM client, empty or
// undecodable Data, an empty URI, an unknown type, an LLM error, or an empty
// response.
func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) {
	if s.llmClient == nil || len(ca.Data) == 0 {
		return ca, false
	}
	var payload struct {
		Type        string       `json:"type"`
		URI         string       `json:"uri"`
		Range       Range        `json:"range"`
		Instruction string       `json:"instruction,omitempty"`
		Selection   string       `json:"selection"`
		Diagnostics []Diagnostic `json:"diagnostics,omitempty"`
	}
	if err := json.Unmarshal(ca.Data, &payload); err != nil {
		return ca, false
	}
	// Without a URI the edit would be keyed by "" and target no document.
	if payload.URI == "" {
		return ca, false
	}

	var (
		sys, user string
		timeout   time.Duration
	)
	switch payload.Type {
	case "rewrite":
		sys = "You are a precise code refactoring engine. Rewrite the given code strictly according to the instruction. Return only the updated code with no prose or backticks. Preserve formatting where reasonable."
		user = fmt.Sprintf("Instruction: %s\n\nSelected code to transform:\n%s", payload.Instruction, payload.Selection)
		timeout = 10 * time.Second
	case "diagnostics":
		sys = "You are a precise code fixer. Resolve the given diagnostics by editing only the selected code. Return only the corrected code with no prose or backticks. Keep behavior and style, and avoid unrelated changes."
		user = codeActionDiagnosticsPrompt(payload.Diagnostics, payload.Selection)
		timeout = 12 * time.Second
	default:
		return ca, false
	}

	out, ok := s.codeActionChat(payload.Type, sys, user, timeout)
	if !ok {
		return ca, false
	}
	edit := WorkspaceEdit{Changes: map[string][]TextEdit{payload.URI: {{Range: payload.Range, NewText: out}}}}
	ca.Edit = &edit
	return ca, true
}

// codeActionDiagnosticsPrompt renders the user prompt for a diagnostics fix.
// It is a numbered list of diagnostics, each prefixed with its [source] when
// one is set, followed by the selected code.
func codeActionDiagnosticsPrompt(diags []Diagnostic, selection string) string {
	var b strings.Builder
	b.WriteString("Diagnostics to resolve (selection only):\n")
	for i, dgn := range diags {
		if dgn.Source != "" {
			fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message)
		} else {
			fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message)
		}
	}
	b.WriteString("\nSelected code:\n")
	b.WriteString(selection)
	return b.String()
}

// codeActionChat sends a system and user prompt pair to the LLM, bounded by
// timeout. It returns the response with surrounding whitespace and code fences
// stripped. ok is false when the request fails or the cleaned response is
// empty; both cases are logged with label identifying the action type.
func (s *Server) codeActionChat(label, sys, user string, timeout time.Duration) (string, bool) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	messages := []llm.Message{{Role: "system", Content: sys}, {Role: "user", Content: user}}
	text, err := s.llmClient.Chat(ctx, messages, s.llmRequestOpts()...)
	if err != nil {
		logging.Logf("lsp ", "codeAction %s llm error: %v", label, err)
		return "", false
	}
	out := stripCodeFences(strings.TrimSpace(text))
	if out == "" {
		logging.Logf("lsp ", "codeAction %s llm returned empty output", label)
		return "", false
	}
	return out, true
}

// handleCodeActionResolve answers codeAction/resolve. It decodes the action
// sent back by the client and asks resolveCodeAction to attach an LLM edit. It
// replies with the resolved action, or with the action as received when
// resolution does not succeed. Params that fail to decode are echoed back as
// whatever json.Unmarshal managed to populate.
//
// Messages without an ID get no reply. handleCodeAction behaves the same way,
// and here the early return also avoids a pointless LLM round-trip.
func (s *Server) handleCodeActionResolve(req Request) {
	if len(req.ID) == 0 {
		return
	}
	var ca CodeAction
	if err := json.Unmarshal(req.Params, &ca); err != nil {
		logging.Logf("lsp ", "codeAction/resolve invalid params: %v", err)
		s.reply(req.ID, ca, nil)
		return
	}
	if resolved, ok := s.resolveCodeAction(ca); ok {
		ca = resolved
	}
	s.reply(req.ID, ca, nil)
}

// diagnosticsInRange decodes the raw CodeActionContext and keeps only the
// diagnostics whose ranges overlap sel, as judged by rangesOverlap, in their
// original order. It returns nil when the context is absent, cannot be decoded,
// or lists no diagnostics. When diagnostics exist but none overlap, the result
// is a non-nil empty slice.
func (s *Server) diagnosticsInRange(ctxRaw json.RawMessage, sel Range) []Diagnostic {
	var cac CodeActionContext
	if len(ctxRaw) == 0 || json.Unmarshal(ctxRaw, &cac) != nil || len(cac.Diagnostics) == 0 {
		return nil
	}
	matched := make([]Diagnostic, 0, len(cac.Diagnostics))
	for i := range cac.Diagnostics {
		if rangesOverlap(cac.Diagnostics[i].Range, sel) {
			matched = append(matched, cac.Diagnostics[i])
		}
	}
	return matched
}

// rangesOverlap reports whether ranges a and b share at least one position.
// Each range is normalized first so that Start <= End. The test is inclusive
// at both ends, so ranges that merely touch count as overlapping: when one
// ends exactly where the other starts, the result is true.
func rangesOverlap(a, b Range) bool {
	// before orders positions by line, then by character within the line.
	before := func(p, q Position) bool {
		return p.Line < q.Line || (p.Line == q.Line && p.Character < q.Character)
	}

	aStart, aEnd := a.Start, a.End
	if before(aEnd, aStart) {
		aStart, aEnd = aEnd, aStart
	}
	bStart, bEnd := b.Start, b.End
	if before(bEnd, bStart) {
		bStart, bEnd = bEnd, bStart
	}
	// Disjoint only if one range finishes strictly before the other begins.
	return !before(aEnd, bStart) && !before(bEnd, aStart)
}

// lessPos reports whether position p comes strictly before q. Positions are
// ordered by line first, then by character offset within the line; equal
// positions are not less.
func lessPos(p, q Position) bool {
	switch {
	case p.Line < q.Line:
		return true
	case p.Line > q.Line:
		return false
	default:
		return p.Character < q.Character
	}
}

// greaterPos reports whether position p comes strictly after q. Positions are
// ordered by line first, then by character offset within the line; equal
// positions are not greater.
func greaterPos(p, q Position) bool {
	if p.Line == q.Line {
		return p.Character > q.Character
	}
	return p.Line > q.Line
}