summaryrefslogtreecommitdiff
path: root/internal/llm/openai.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/llm/openai.go')
-rw-r--r--internal/llm/openai.go27
1 files changed, 20 insertions, 7 deletions
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index e9a1fdc..8b00335 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -26,12 +26,13 @@ type openAIClient struct {
}
type oaChatRequest struct {
- Model string `json:"model"`
- Messages []oaMessage `json:"messages"`
- Temperature *float64 `json:"temperature,omitempty"`
- MaxTokens *int `json:"max_tokens,omitempty"`
- Stop []string `json:"stop,omitempty"`
- Stream bool `json:"stream,omitempty"`
+ Model string `json:"model"`
+ Messages []oaMessage `json:"messages"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ MaxTokens *int `json:"max_tokens,omitempty"`
+ MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
+ Stop []string `json:"stop,omitempty"`
+ Stream bool `json:"stream,omitempty"`
}
type oaMessage struct {
@@ -208,7 +209,11 @@ func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, str
req.Temperature = &t
}
if o.MaxTokens > 0 {
- req.MaxTokens = &o.MaxTokens
+ if requiresMaxCompletionTokens(o.Model) {
+ req.MaxCompletionTokens = &o.MaxTokens
+ } else {
+ req.MaxTokens = &o.MaxTokens
+ }
}
if len(o.Stop) > 0 {
req.Stop = o.Stop
@@ -216,6 +221,14 @@ func buildOAChatRequest(o Options, messages []Message, defaultTemp *float64, str
return req
}
+// requiresMaxCompletionTokens reports whether the given model prefers the
+// new parameter name "max_completion_tokens" instead of "max_tokens". Newer
+// models (e.g., gpt-5 family) expect this per OpenAI's API error guidance.
+func requiresMaxCompletionTokens(model string) bool {
+ m := strings.ToLower(strings.TrimSpace(model))
+ return strings.HasPrefix(m, "gpt-5")
+}
+
func (c openAIClient) doJSON(ctx context.Context, url string, body []byte, headers map[string]string) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {