diff options
Diffstat (limited to 'internal/llm')
| -rw-r--r-- | internal/llm/copilot.go | 13 | ||||
| -rw-r--r-- | internal/llm/ollama.go | 45 | ||||
| -rw-r--r-- | internal/llm/openai.go | 23 | ||||
| -rw-r--r-- | internal/llm/provider.go | 76 | ||||
| -rw-r--r-- | internal/llm/util.go | 6 |
5 files changed, 103 insertions, 60 deletions
diff --git a/internal/llm/copilot.go b/internal/llm/copilot.go index 680e7ec..47ce11e 100644 --- a/internal/llm/copilot.go +++ b/internal/llm/copilot.go @@ -22,6 +22,7 @@ type copilotClient struct { baseURL string defaultModel string chatLogger logging.ChatLogger + defaultTemperature *float64 } type copilotChatRequest struct { @@ -55,7 +56,7 @@ type copilotChatResponse struct { } // Constructor (kept among the first functions by convention) -func newCopilot(baseURL, model, apiKey string) Client { +func newCopilot(baseURL, model, apiKey string, defaultTemp *float64) Client { if strings.TrimSpace(baseURL) == "" { baseURL = "https://api.githubcopilot.com" } @@ -68,6 +69,7 @@ func newCopilot(baseURL, model, apiKey string) Client { baseURL: strings.TrimRight(baseURL, "/"), defaultModel: model, chatLogger: logging.NewChatLogger("copilot"), + defaultTemperature: defaultTemp, } } @@ -101,9 +103,12 @@ func (c copilotClient) Chat(ctx context.Context, messages []Message, opts ...Req for i, m := range messages { req.Messages[i] = copilotMessage{Role: m.Role, Content: m.Content} } - if o.Temperature != 0 { - req.Temperature = &o.Temperature - } + if o.Temperature != 0 { + req.Temperature = &o.Temperature + } else if c.defaultTemperature != nil { + t := *c.defaultTemperature + req.Temperature = &t + } if o.MaxTokens > 0 { req.MaxTokens = &o.MaxTokens } diff --git a/internal/llm/ollama.go b/internal/llm/ollama.go index 14aa558..20dfe2a 100644 --- a/internal/llm/ollama.go +++ b/internal/llm/ollama.go @@ -22,13 +22,14 @@ type ollamaClient struct { baseURL string defaultModel string chatLogger logging.ChatLogger + defaultTemperature *float64 } type ollamaChatRequest struct { - Model string `json:"model"` - Messages []oaMessage `json:"messages"` - Stream bool `json:"stream"` - Options any `json:"options,omitempty"` + Model string `json:"model"` + Messages []oaMessage `json:"messages"` + Stream bool `json:"stream"` + Options any `json:"options,omitempty"` } type ollamaChatResponse struct { @@ -41,21 +42,23 @@ type ollamaChatResponse struct { } // Constructor (kept among the first functions by convention) -func newOllama(baseURL, model string) Client { - if strings.TrimSpace(baseURL) == "" { - baseURL = "http://localhost:11434" - } - if strings.TrimSpace(model) == "" { - model = "qwen2.5-coder:latest" - } +func newOllama(baseURL, model string, defaultTemp *float64) Client { + if strings.TrimSpace(baseURL) == "" { + baseURL = "http://localhost:11434" + } + if strings.TrimSpace(model) == "" { + model = "qwen3-coder:30b-a3b-q4_K_M`" + } return ollamaClient{ httpClient: &http.Client{Timeout: 30 * time.Second}, baseURL: strings.TrimRight(baseURL, "/"), defaultModel: model, chatLogger: logging.NewChatLogger("ollama"), + defaultTemperature: defaultTemp, } } +// TODO: This function is too long and should be refactored for readability and maintainability. func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...RequestOption) (string, error) { o := Options{Model: c.defaultModel} for _, opt := range opts { @@ -86,9 +89,11 @@ func (c ollamaClient) Chat(ctx context.Context, messages []Message, opts ...Requ // Build options map only if any option is set optsMap := map[string]any{} - if o.Temperature != 0 { - optsMap["temperature"] = o.Temperature - } + if o.Temperature != 0 { + optsMap["temperature"] = o.Temperature + } else if c.defaultTemperature != nil { + optsMap["temperature"] = *c.defaultTemperature + } if o.MaxTokens > 0 { optsMap["num_predict"] = o.MaxTokens } @@ -177,9 +182,11 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt } // Build options map optsMap := map[string]any{} - if o.Temperature != 0 { - optsMap["temperature"] = o.Temperature - } + if o.Temperature != 0 { + optsMap["temperature"] = o.Temperature + } else if c.defaultTemperature != nil { + optsMap["temperature"] = *c.defaultTemperature + } if o.MaxTokens > 0 { optsMap["num_predict"] = o.MaxTokens } @@ -241,6 +248,6 @@ func (c ollamaClient) ChatStream(ctx context.Context, messages []Message, onDelt break } } - logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) - return nil + logging.Logf("llm/ollama ", "stream end duration=%s", time.Since(start)) + return nil } diff --git a/internal/llm/openai.go b/internal/llm/openai.go index 8dc2907..5348def 100644 --- a/internal/llm/openai.go +++ b/internal/llm/openai.go @@ -23,6 +23,7 @@ type openAIClient struct { baseURL string defaultModel string chatLogger logging.ChatLogger + defaultTemperature *float64 } type oaChatRequest struct { @@ -75,7 +76,7 @@ type oaStreamChunk struct { // Constructor (kept among the first functions by convention) // newOpenAI constructs an OpenAI client using explicit configuration values. // The apiKey may be empty; calls will fail until a valid key is supplied. -func newOpenAI(baseURL, model, apiKey string) Client { +func newOpenAI(baseURL, model, apiKey string, defaultTemp *float64) Client { if strings.TrimSpace(baseURL) == "" { baseURL = "https://api.openai.com/v1" } @@ -88,6 +89,7 @@ func newOpenAI(baseURL, model, apiKey string) Client { baseURL: baseURL, defaultModel: model, chatLogger: logging.NewChatLogger("openai"), + defaultTemperature: defaultTemp, } } @@ -120,9 +122,13 @@ func (c openAIClient) Chat(ctx context.Context, messages []Message, opts ...Requ for i, m := range messages { req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} } - if o.Temperature != 0 { - req.Temperature = &o.Temperature - } + // Decide temperature: request option overrides config default. + if o.Temperature != 0 { + req.Temperature = &o.Temperature + } else if c.defaultTemperature != nil { + t := *c.defaultTemperature + req.Temperature = &t + } if o.MaxTokens > 0 { req.MaxTokens = &o.MaxTokens } @@ -212,9 +218,12 @@ func (c openAIClient) ChatStream(ctx context.Context, messages []Message, onDelt for i, m := range messages { req.Messages[i] = oaMessage{Role: m.Role, Content: m.Content} } - if o.Temperature != 0 { - req.Temperature = &o.Temperature - } + if o.Temperature != 0 { + req.Temperature = &o.Temperature + } else if c.defaultTemperature != nil { + t := *c.defaultTemperature + req.Temperature = &t + } if o.MaxTokens > 0 { req.MaxTokens = &o.MaxTokens } diff --git a/internal/llm/provider.go b/internal/llm/provider.go index c605081..ed9ca59 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -54,40 +54,56 @@ func WithStop(stop ...string) RequestOption { // Config defines provider configuration read from the Hexai config file. type Config struct { - Provider string - // OpenAI options - OpenAIBaseURL string - OpenAIModel string - // Ollama options - OllamaBaseURL string - OllamaModel string - // Copilot options - CopilotBaseURL string - CopilotModel string + Provider string + // OpenAI options + OpenAIBaseURL string + OpenAIModel string + OpenAITemperature *float64 + // Ollama options + OllamaBaseURL string + OllamaModel string + OllamaTemperature *float64 + // Copilot options + CopilotBaseURL string + CopilotModel string + CopilotTemperature *float64 } // NewFromConfig creates an LLM client using only the supplied configuration. // The OpenAI API key is supplied separately and may be read from the environment // by the caller; other environment-based configuration is not used. func NewFromConfig(cfg Config, openAIAPIKey, copilotAPIKey string) (Client, error) { - p := strings.ToLower(strings.TrimSpace(cfg.Provider)) - if p == "" { - p = "openai" - } - switch p { - case "openai": - if strings.TrimSpace(openAIAPIKey) == "" { - return nil, errors.New("missing OPENAI_API_KEY for provider openai") - } - return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey), nil - case "ollama": - return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel), nil - case "copilot": - if strings.TrimSpace(copilotAPIKey) == "" { - return nil, errors.New("missing COPILOT_API_KEY for provider copilot") - } - return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey), nil - default: - return nil, errors.New("unknown LLM provider: " + p) - } + p := strings.ToLower(strings.TrimSpace(cfg.Provider)) + if p == "" { + p = "openai" + } + switch p { + case "openai": + if strings.TrimSpace(openAIAPIKey) == "" { + return nil, errors.New("missing OPENAI_API_KEY for provider openai") + } + // Set coding-friendly default temperature if none provided + if cfg.OpenAITemperature == nil { + t := 0.2 + cfg.OpenAITemperature = &t + } + return newOpenAI(cfg.OpenAIBaseURL, cfg.OpenAIModel, openAIAPIKey, cfg.OpenAITemperature), nil + case "ollama": + if cfg.OllamaTemperature == nil { + t := 0.2 + cfg.OllamaTemperature = &t + } + return newOllama(cfg.OllamaBaseURL, cfg.OllamaModel, cfg.OllamaTemperature), nil + case "copilot": + if strings.TrimSpace(copilotAPIKey) == "" { + return nil, errors.New("missing COPILOT_API_KEY for provider copilot") + } + if cfg.CopilotTemperature == nil { + t := 0.2 + cfg.CopilotTemperature = &t + } + return newCopilot(cfg.CopilotBaseURL, cfg.CopilotModel, copilotAPIKey, cfg.CopilotTemperature), nil + default: + return nil, errors.New("unknown LLM provider: " + p) + } } diff --git a/internal/llm/util.go b/internal/llm/util.go new file mode 100644 index 0000000..b99d7c8 --- /dev/null +++ b/internal/llm/util.go @@ -0,0 +1,6 @@ +package llm + +import "errors" + +// small helper to keep return type consistent +func nilStringErr(msg string) (string, error) { return "", errors.New(msg) } |
