diff options
Diffstat (limited to 'internal/lsp')
| -rw-r--r-- | internal/lsp/handlers_codeaction.go | 311 |
1 files changed, 219 insertions, 92 deletions
diff --git a/internal/lsp/handlers_codeaction.go b/internal/lsp/handlers_codeaction.go index 4562954..2a1d7fa 100644 --- a/internal/lsp/handlers_codeaction.go +++ b/internal/lsp/handlers_codeaction.go @@ -14,6 +14,43 @@ import ( "codeberg.org/snonux/hexai/internal/logging" ) +type codeActionPayload struct { + Type string `json:"type"` + ID string `json:"id"` + URI string `json:"uri"` + Range Range `json:"range"` + Instruction string `json:"instruction,omitempty"` + Selection string `json:"selection"` + Diagnostics []Diagnostic `json:"diagnostics,omitempty"` +} + +// CodeActionHandler builds and resolves code actions for a specific action type. +type CodeActionHandler interface { + Build(s *Server, p CodeActionParams, selection string) []CodeAction + Resolve(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) +} + +type codeActionHandler struct { + build func(s *Server, p CodeActionParams, selection string) []CodeAction + resolve func(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) +} + +func (h codeActionHandler) Build(s *Server, p CodeActionParams, selection string) []CodeAction { + if h.build == nil { + return nil + } + return h.build(s, p, selection) +} + +func (h codeActionHandler) Resolve(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + if h.resolve == nil { + return action, false + } + return h.resolve(s, action, payload) +} + +var codeActionBuildOrder = []string{"rewrite", "diagnostics", "document", "go_test", "simplify", "custom"} + func (s *Server) handleCodeAction(req Request) { var p CodeActionParams if err := json.Unmarshal(req.Params, &p); err != nil { @@ -39,29 +76,25 @@ func (s *Server) handleCodeAction(req Request) { } sel := extractRangeText(d, p.Range) - actions := make([]CodeAction, 0, 8) - if a := s.buildRewriteCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - if a := s.buildDiagnosticsCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - if a := s.buildDocumentCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - if a := s.buildGoUnitTestCodeAction(p); a != nil { - actions = append(actions, *a) - } - if a := s.buildSimplifyCodeAction(p, sel); a != nil { - actions = append(actions, *a) - } - // Custom actions from config - s.appendCustomActions(&actions, p, sel) + actions := s.buildCodeActions(p, sel) if len(req.ID) != 0 { s.reply(req.ID, actions, nil) } } +func (s *Server) buildCodeActions(p CodeActionParams, selection string) []CodeAction { + actions := make([]CodeAction, 0, 8) + handlers := s.codeActionHandlers() + for _, key := range codeActionBuildOrder { + handler, ok := handlers[key] + if !ok { + continue + } + actions = append(actions, handler.Build(s, p, selection)...) + } + return actions +} + // appendCustomActions adds user-defined actions depending on scope and availability. func (s *Server) appendCustomActions(actions *[]CodeAction, p CodeActionParams, sel string) { customs := s.customActions() @@ -115,6 +148,133 @@ func (s *Server) appendCustomActions(actions *[]CodeAction, p CodeActionParams, } } +func (s *Server) codeActionHandlers() map[string]CodeActionHandler { + return map[string]CodeActionHandler{ + "rewrite": codeActionHandler{build: buildRewriteActions, resolve: resolveRewriteCodeAction}, + "diagnostics": codeActionHandler{build: buildDiagnosticsActions, resolve: resolveDiagnosticsCodeAction}, + "document": codeActionHandler{build: buildDocumentActions, resolve: resolveDocumentCodeAction}, + "go_test": codeActionHandler{build: buildGoTestActions, resolve: resolveGoTestCodeAction}, + "simplify": codeActionHandler{build: buildSimplifyActions, resolve: resolveSimplifyCodeAction}, + "custom": codeActionHandler{build: buildCustomActions, resolve: resolveCustomCodeAction}, + } +} + +func buildRewriteActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildRewriteCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildDiagnosticsActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildDiagnosticsCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildDocumentActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildDocumentCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildGoTestActions(s *Server, p CodeActionParams, _ string) []CodeAction { + if action := s.buildGoUnitTestCodeAction(p); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildSimplifyActions(s *Server, p CodeActionParams, selection string) []CodeAction { + if action := s.buildSimplifyCodeAction(p, selection); action != nil { + return []CodeAction{*action} + } + return nil +} + +func buildCustomActions(s *Server, p CodeActionParams, selection string) []CodeAction { + actions := make([]CodeAction, 0, len(s.customActions())) + s.appendCustomActions(&actions, p, selection) + return actions +} + +func resolveRewriteCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionRewriteSystem + user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ + "instruction": payload.Instruction, + "selection": payload.Selection, + }) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + +func resolveDiagnosticsCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionDiagnosticsSystem + user := renderTemplate(cfg.PromptCodeActionDiagnosticsUser, map[string]string{ + "diagnostics": formatDiagnostics(payload.Diagnostics), + "selection": payload.Selection, + }) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 22*time.Second) +} + +func resolveDocumentCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionDocumentSystem + user := renderTemplate(cfg.PromptCodeActionDocumentUser, map[string]string{"selection": payload.Selection}) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + +func resolveGoTestCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start) + if !ok { + return action, false + } + action.Edit = &edit + action.Command = &Command{ + Title: "Jump to generated test", + Command: "hexai.showDocument", + Arguments: []any{jumpURI, jumpRange}, + } + s.deferShowDocument(jumpURI, jumpRange) + return action, true +} + +func resolveSimplifyCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + cfg := s.currentConfig() + sys := cfg.PromptCodeActionRewriteSystem + user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ + "instruction": "Simplify and improve the code while preserving behavior. Return only the improved code.", + "selection": payload.Selection, + }) + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + +func resolveCustomCodeAction(s *Server, action CodeAction, payload codeActionPayload) (CodeAction, bool) { + custom := s.customActionByID(payload.ID) + if custom == nil { + return action, false + } + cfg := s.currentConfig() + sys := cfg.PromptCodeActionRewriteSystem + user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{ + "instruction": payload.Instruction, + "selection": payload.Selection, + }) + if strings.TrimSpace(custom.User) != "" { + if strings.TrimSpace(custom.System) != "" { + sys = custom.System + } + user = renderTemplate(custom.User, map[string]string{ + "selection": payload.Selection, + "diagnostics": strings.TrimSpace(formatCustomDiagnostics(payload.Diagnostics)), + }) + } + return s.completeCodeAction(action, payload.URI, payload.Range, sys, user, 20*time.Second) +} + func (s *Server) buildSimplifyCodeAction(p CodeActionParams, sel string) *CodeAction { if strings.TrimSpace(sel) == "" { return nil @@ -167,86 +327,53 @@ func (s *Server) resolveCodeAction(ca CodeAction) (CodeAction, bool) { if s.currentLLMClient() == nil || len(ca.Data) == 0 { return ca, false } - var payload struct { - Type string `json:"type"` - ID string `json:"id"` - URI string `json:"uri"` - Range Range `json:"range"` - Instruction string `json:"instruction,omitempty"` - Selection string `json:"selection"` - Diagnostics []Diagnostic `json:"diagnostics,omitempty"` + payload, ok := decodeCodeActionPayload(ca.Data) + if !ok { + return ca, false } - if err := json.Unmarshal(ca.Data, &payload); err != nil { + handler, found := s.codeActionHandlers()[payload.Type] + if !found { return ca, false } - cfg := s.currentConfig() - switch payload.Type { - case "rewrite": - sys := cfg.PromptCodeActionRewriteSystem - user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": payload.Instruction, "selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) - case "diagnostics": - sys := cfg.PromptCodeActionDiagnosticsSystem - var b strings.Builder - for i, dgn := range payload.Diagnostics { - if dgn.Source != "" { - fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) - } else { - fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) - } - } - diagList := b.String() - user := renderTemplate(cfg.PromptCodeActionDiagnosticsUser, map[string]string{"diagnostics": diagList, "selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 22*time.Second) - case "document": - sys := cfg.PromptCodeActionDocumentSystem - user := renderTemplate(cfg.PromptCodeActionDocumentUser, map[string]string{"selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) - case "go_test": - if edit, jumpURI, jumpRange, ok := s.resolveGoTest(payload.URI, payload.Range.Start); ok { - ca.Edit = &edit - ca.Command = &Command{Title: "Jump to generated test", Command: "hexai.showDocument", Arguments: []any{jumpURI, jumpRange}} - s.deferShowDocument(jumpURI, jumpRange) - return ca, true - } - case "simplify": - sys := cfg.PromptCodeActionRewriteSystem - user := renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": "Simplify and improve the code while preserving behavior. Return only the improved code.", "selection": payload.Selection}) - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) - case "custom": - var action *CustomAction - for _, caDef := range s.customActions() { - if caDef.ID == payload.ID { - action = &caDef - break - } - } - if action == nil { - return ca, false + return handler.Resolve(s, ca, payload) +} + +func decodeCodeActionPayload(raw json.RawMessage) (codeActionPayload, bool) { + var payload codeActionPayload + if err := json.Unmarshal(raw, &payload); err != nil { + return codeActionPayload{}, false + } + return payload, true +} + +func formatDiagnostics(diagnostics []Diagnostic) string { + var b strings.Builder + for i, dgn := range diagnostics { + if dgn.Source != "" { + fmt.Fprintf(&b, "%d. [%s] %s\n", i+1, dgn.Source, dgn.Message) + continue } - var sys, user string - if strings.TrimSpace(action.User) != "" { - if strings.TrimSpace(action.System) != "" { - sys = action.System - } else { - sys = cfg.PromptCodeActionRewriteSystem - } - var diagList string - if len(payload.Diagnostics) > 0 { - var b strings.Builder - for _, d := range payload.Diagnostics { - fmt.Fprintf(&b, "%s\n", d.Message) - } - diagList = b.String() - } - user = renderTemplate(action.User, map[string]string{"selection": payload.Selection, "diagnostics": strings.TrimSpace(diagList)}) - } else { - sys = cfg.PromptCodeActionRewriteSystem - user = renderTemplate(cfg.PromptCodeActionRewriteUser, map[string]string{"instruction": payload.Instruction, "selection": payload.Selection}) + fmt.Fprintf(&b, "%d. %s\n", i+1, dgn.Message) + } + return b.String() +} + +func formatCustomDiagnostics(diagnostics []Diagnostic) string { + var b strings.Builder + for _, d := range diagnostics { + fmt.Fprintf(&b, "%s\n", d.Message) + } + return b.String() +} + +func (s *Server) customActionByID(id string) *CustomAction { + for _, item := range s.customActions() { + if item.ID == id { + action := item + return &action } - return s.completeCodeAction(ca, payload.URI, payload.Range, sys, user, 20*time.Second) } - return ca, false + return nil } func (s *Server) completeCodeAction(ca CodeAction, uri string, rng Range, sys, user string, timeout time.Duration) (CodeAction, bool) { |
