summaryrefslogtreecommitdiff
path: root/internal/llm
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm')
-rw-r--r--internal/llm/copilot.go13
-rw-r--r--internal/llm/ollama.go45
-rw-r--r--internal/llm/openai.go23
-rw-r--r--internal/llm/provider.go76
-rw-r--r--internal/llm/util.go6
5 files changed, 103 insertions, 60 deletions
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go
index 680e7ec..47ce11e 100644
--- a/internal/llm/copilot.go
+++ b/internal/llm/copilot.go
@@ -22,6 +22,7 @@ type copilotClient struct {
baseURL string
defaultModel string
chatLogger logging.ChatLogger
+ defaultTemperature *float64
}
type copilotChatRequest struct {
@@ -55,7 +56,7 @@ type copilotChatResponse struct {
}
// Constructor (kept among the first functions by convention)
-func newCopilot(baseURL, model, apiKey string) Client {
+func newCopilot(baseURL, model, apiKey string, defaultTemp *float64) Client {
if strings.TrimSpace(baseURL) == "" {
baseURL = "https://api.githubcopilot.com"
}
@@ -68,6 +69,7 @@ func newCopilot(baseURL, model, apiKey string) Client {
baseURL: strings.TrimRight(baseURL, "/"),
defaultModel: model,
chatLogger: logging.NewChatLogger("copilot"),
+ defaultTemperature: defaultTemp,
}
}
@@ -101,9 +103,12 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req
for i, m := range messages {
req.Messages[i] = copilotMessage{Role: m.Role, Content: m.Content}
}
- if o.Temperature != 0 {
- req.Temperature = &o.Temperature
- }
+ if o.Temperature != 0 {
+ req.Temperature = &o.Temperature
+ } else if c.defaultTemperature != nil {
+ t := *c.defaultTemperature
+ req.Temperature = &t
+ }
if o.MaxTokens > 0 {
req.MaxTokens = &o.MaxTokens
}
diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go
index 14aa558..20dfe2a 100644
--- a/internal/llm/ollama.go
+++ b/internal/llm/ollama.go
@@ -22,13 +22,14 @@ type ollamaClient struct {
baseURL string
defaultModel string
chatLogger logging.ChatLogger
+ defaultTemperature *float64
}
type ollamaChatRequest struct {
- Model string `json:"model"`
- Messages []oaMessage `json:"messages"`
- Stream bool `json:"stream"`
- Options any `json:"options,omitempty"`
+ Model string `json:"model"`
+ Messages []oaMessage `json:"messages"`
+ Stream bool `json:"stream"`
+ Options any `json:"options,omitempty"`
}
type ollamaChatResponse struct {
@@ -41,21 +42,23 @@ type ollamaChatResponse struct {
}
// Constructor (kept among the first functions by convention)
-func newOllama(baseURL, model string) Client {
- if strings.TrimSpace(baseURL) == "" {
- baseURL = "http://localhost:11434"
- }
- if strings.TrimSpace(model) == "" {
- model = "qwen2.5-coder:latest"
- }
+func newOllama(baseURL, model string, defaultTemp *float64) Client {
+ if strings.TrimSpace(baseURL) == "" {
+ baseURL = "http://localhost:11434"
+ }
+ if strings.TrimSpace(model) == "" {
+ model = "qwen3-coder:30b-a3b-q4_K_M`"
+ }
return ollamaClient{
httpClient: &http.Client{Timeout: 30 * time.Second},
baseURL: strings.TrimRight(baseURL, "/"),
defaultModel: model,
chatLogger: logging.NewChatLogger("ollama"),
+ defaultTemperature: defaultTemp,
}
}
+// TODO: This function is too long and should be refactored for readability and maintainability.
func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) {
o := Options{Model: c.defaultModel}
for _, opt := range opts {
@@ -86,9 +89,11 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ
// Build options map only if any option is set
optsMap := map[string]any{}
- if o.Temperature != 0 {
- optsMap["temperature"] = o.Temperature
- }
+ if o.Temperature != 0 {
+ optsMap["temperature"] = o.Temperature
+ } else if c.defaultTemperature != nil {
+ optsMap["temperature"] = *c.defaultTemperature
+ }
if o.MaxTokens > 0 {
optsMap["num_predict"] = o.MaxTokens
}
@@ -177,9 +182,11 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
}
// Build options map
optsMap := map[string]any{}
- if o.Temperature != 0 {
- optsMap["temperature"] = o.Temperature
- }
+ if o.Temperature != 0 {
+ optsMap["temperature"] = o.Temperature
+ } else if c.defaultTemperature != nil {
+ optsMap["temperature"] = *c.defaultTemperature
+ }
if o.MaxTokens > 0 {
optsMap["num_predict"] = o.MaxTokens
}
@@ -241,6 +248,6 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt
break
}
}
- logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start))
- return nil
+ logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start))
+ return nil
}
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 8dc2907..5348def 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -23,6 +23,7 @@ type openAIClient struct {
baseURL string
defaultModel string
chatLogger logging.ChatLogger
+ defaultTemperature *float64
}
type oaChatRequest struct {
@@ -75,7 +76,7 @@ type oaStreamChunk struct {
// Constructor (kept among the first functions by convention)
// newOpenAI constructs an OpenAI client using explicit configuration values.
// The apiKey may be empty; calls will fail until a valid key is supplied.
-func newOpenAI(baseURL, model, apiKey string) Client {
+func newOpenAI(baseURL, model, apiKey string, defaultTemp *float64) Client {
if strings.TrimSpace(baseURL) == "" {
baseURL = "https://api.openai.com/v1"
}
@@ -88,6 +89,7 @@ func newOpenAI(baseURL, model, apiKey string) Client {
baseURL: baseURL,
defaultModel: model,
chatLogger: logging.NewChatLogger("openai"),
+ defaultTemperature: defaultTemp,
}
}
@@ -120,9 +122,13 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ
for i, m := range messages {
req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
}
- if o.Temperature != 0 {
- req.Temperature = &o.Temperature
- }
+ // Decide temperature: request option overrides config default.
+ if o.Temperature != 0 {
+ req.Temperature = &o.Temperature
+ } else if c.defaultTemperature != nil {
+ t := *c.defaultTemperature
+ req.Temperature = &t
+ }
if o.MaxTokens > 0 {
req.MaxTokens = &o.MaxTokens
}
@@ -212,9 +218,12 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt
for i, m := range messages {
req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content}
}
- if o.Temperature != 0 {
- req.Temperature = &o.Temperature
- }
+ if o.Temperature != 0 {
+ req.Temperature = &o.Temperature
+ } else if c.defaultTemperature != nil {
+ t := *c.defaultTemperature
+ req.Temperature = &t
+ }
if o.MaxTokens > 0 {
req.MaxTokens = &o.MaxTokens
}
diff --git a/internal/llm/provider.go b/internal/llm/provider.go
index c605081..ed9ca59 100644
--- a/internal/llm/provider.go
+++ b/internal/llm/provider.go
@@ -54,40 +54,56 @@ func WithStop(stop ...string) RequestOption {
// Config defines provider configuration read from the Hexai config file.
type Config struct {
- Provider string
- // OpenAI options
- OpenAIBaseURL string
- OpenAIModel string
- // Ollama options
- OllamaBaseURL string
- OllamaModel string
- // Copilot options
- CopilotBaseURL string
- CopilotModel string
+ Provider string
+ // OpenAI options
+ OpenAIBaseURL string
+ OpenAIModel string
+ OpenAITemperature *float64
+ // Ollama options
+ OllamaBaseURL string
+ OllamaModel string
+ OllamaTemperature *float64
+ // Copilot options
+ CopilotBaseURL string
+ CopilotModel string
+ CopilotTemperature *float64
}
// NewFromConfig creates an LLM client using only the supplied configuration.
// The OpenAI API key is supplied separately and may be read from the environment
// by the caller; other environment-based configuration is not used.
func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) {
- p := strings.ToLower(strings.TrimSpace(cfg.Provider))
- if p == "" {
- p = "openai"
- }
- switch p {
- case "openai":
- if strings.TrimSpace(openAIAPIKey) == "" {
- return nil, errors.New("missing OPENAI_API_KEY for provider openai")
- }
- return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey), nil
- case "ollama":
- return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel), nil
- case "copilot":
- if strings.TrimSpace(copilotAPIKey) == "" {
- return nil, errors.New("missing COPILOT_API_KEY for provider copilot")
- }
- return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey), nil
- default:
- return nil, errors.New("unknown LLM provider: " + p)
- }
+ p := strings.ToLower(strings.TrimSpace(cfg.Provider))
+ if p == "" {
+ p = "openai"
+ }
+ switch p {
+ case "openai":
+ if strings.TrimSpace(openAIAPIKey) == "" {
+ return nil, errors.New("missing OPENAI_API_KEY for provider openai")
+ }
+ // Set coding-friendly default temperature if none provided
+ if cfg.OpenAITemperature == nil {
+ t := 0.2
+ cfg.OpenAITemperature = &t
+ }
+ return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil
+ case "ollama":
+ if cfg.OllamaTemperature == nil {
+ t := 0.2
+ cfg.OllamaTemperature = &t
+ }
+ return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil
+ case "copilot":
+ if strings.TrimSpace(copilotAPIKey) == "" {
+ return nil, errors.New("missing COPILOT_API_KEY for provider copilot")
+ }
+ if cfg.CopilotTemperature == nil {
+ t := 0.2
+ cfg.CopilotTemperature = &t
+ }
+ return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil
+ default:
+ return nil, errors.New("unknown LLM provider: " + p)
+ }
}
diff --git a/internal/llm/util.go b/internal/llm/util.go
new file mode 100644
index 0000000..b99d7c8
--- /dev/null
+++ b/internal/llm/util.go
@@ -0,0 +1,6 @@
+package llm
+
+import "errors"
+
+// small helper to keep return type consistent
+func nilStringErr(msg string) (string, error) { return "", errors.New(msg) }