summary refs log tree commit diff
path: root/internal/image/openai.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/image/openai.go')
-rw-r--r-- internal/image/openai.go 329
1 files changed, 329 insertions, 0 deletions
diff --git a/internal/image/openai.go b/internal/image/openai.go
new file mode 100644
index 0000000..a5a3e31
--- /dev/null
+++ b/internal/image/openai.go
@@ -0,0 +1,329 @@
+package image
+
+import (
+ "context"
+ "crypto/md5"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/sashabaranov/go-openai"
+)
+
+// OpenAIClient implements ImageSearcher for OpenAI DALL-E image generation.
+// NewOpenAIClient returns the zero value when no API key is configured; Search
+// detects that state through the nil client field and reports NO_API_KEY.
+type OpenAIClient struct {
+ client *openai.Client // nil when no API key was configured
+ apiKey string // API key the client was created with
+ model string // dall-e-2 or dall-e-3
+ size string // 256x256, 512x512, 1024x1024; 1024x1792 and 1792x1024 are also recognized (DALL-E 3)
+ quality string // standard or hd (dall-e-3 only)
+ style string // natural or vivid (dall-e-3 only)
+ cacheDir string // root directory under which cached image files are stored
+ enableCache bool // when true, Search reads from and writes to cacheDir
+}
+
+// OpenAIConfig holds configuration for the OpenAI image provider.
+// NewOpenAIClient substitutes a default for every string field other than
+// APIKey that is left empty.
+type OpenAIConfig struct {
+ APIKey string // required; an empty key yields a client whose Search fails with NO_API_KEY
+ Model string // image model name, e.g. dall-e-2 or dall-e-3
+ Size string // requested dimensions as WIDTHxHEIGHT, e.g. 512x512
+ Quality string // only sent to the API when Model is dall-e-3
+ Style string // only sent to the API when Model is dall-e-3
+ CacheDir string // directory in which generated images are cached
+ EnableCache bool // turns the on-disk image cache on
+}
+
+// NewOpenAIClient creates a new OpenAI DALL-E client from config.
+//
+// A nil config or an empty APIKey yields a zero-value client: it is safe to
+// call, but Search reports a NO_API_KEY error instead of contacting the API.
+//
+// Empty fields fall back to defaults: model "dall-e-2", quality "standard",
+// style "natural", cache directory "./.image_cache", and a size the selected
+// model accepts ("1024x1024" for dall-e-3, "512x512" otherwise). Defaults are
+// applied to a private copy, so the caller's config is never modified.
+func NewOpenAIClient(config *OpenAIConfig) *OpenAIClient {
+	if config == nil || config.APIKey == "" {
+		// Unconfigured client; every Search call fails with NO_API_KEY.
+		return &OpenAIClient{}
+	}
+
+	// Work on a copy so defaults do not leak back into the caller's struct.
+	cfg := *config
+	if cfg.Model == "" {
+		cfg.Model = "dall-e-2"
+	}
+	if cfg.Size == "" {
+		// dall-e-3 does not accept 512x512, so the default depends on the model.
+		cfg.Size = "512x512"
+		if cfg.Model == "dall-e-3" {
+			cfg.Size = "1024x1024"
+		}
+	}
+	if cfg.Quality == "" {
+		cfg.Quality = "standard"
+	}
+	if cfg.Style == "" {
+		cfg.Style = "natural"
+	}
+	if cfg.CacheDir == "" {
+		cfg.CacheDir = "./.image_cache"
+	}
+
+	oc := &OpenAIClient{
+		client:      openai.NewClient(cfg.APIKey),
+		apiKey:      cfg.APIKey,
+		model:       cfg.Model,
+		size:        cfg.Size,
+		quality:     cfg.Quality,
+		style:       cfg.Style,
+		cacheDir:    cfg.CacheDir,
+		enableCache: cfg.EnableCache,
+	}
+
+	if oc.enableCache && oc.cacheDir != "" {
+		// Best effort only: the constructor cannot return an error, and
+		// downloadAndCache creates the directory again (and reports failure)
+		// before anything is written.
+		_ = os.MkdirAll(oc.cacheDir, 0755)
+	}
+
+	return oc
+}
+
+// Search generates an image for the Bulgarian word using DALL-E
+func (c *OpenAIClient) Search(ctx context.Context, opts *SearchOptions) ([]SearchResult, error) {
+ if c.client == nil {
+ return nil, &SearchError{
+ Provider: "openai",
+ Code: "NO_API_KEY",
+ Message: "OpenAI API key not configured",
+ }
+ }
+
+ // Check cache first
+ if c.enableCache {
+ cacheFile := c.getCacheFilePath(opts.Query)
+ if info, err := os.Stat(cacheFile); err == nil && info.Size() > 0 {
+ // Return cached result
+ result := SearchResult{
+ ID: c.generateImageID(opts.Query),
+ URL: cacheFile,
+ ThumbnailURL: cacheFile,
+ Width: c.getSizeWidth(),
+ Height: c.getSizeHeight(),
+ Description: fmt.Sprintf("Generated image for %s", opts.Query),
+ Attribution: "Generated by OpenAI DALL-E",
+ Source: "openai",
+ }
+ return []SearchResult{result}, nil
+ }
+ }
+
+ // Translate Bulgarian word to English for better results
+ translatedWord := translateBulgarianToEnglish(opts.Query)
+
+ // Create educational prompt
+ prompt := c.createEducationalPrompt(opts.Query, translatedWord)
+
+ // Create the image generation request
+ req := openai.ImageRequest{
+ Prompt: prompt,
+ Model: c.model,
+ Size: c.size,
+ ResponseFormat: openai.CreateImageResponseFormatURL,
+ N: 1,
+ }
+
+ // Add model-specific parameters
+ if c.model == "dall-e-3" {
+ req.Quality = c.quality
+ req.Style = c.style
+ }
+
+ // Generate the image
+ resp, err := c.client.CreateImage(ctx, req)
+ if err != nil {
+ return nil, &SearchError{
+ Provider: "openai",
+ Code: "API_ERROR",
+ Message: fmt.Sprintf("Failed to generate image: %v", err),
+ }
+ }
+
+ if len(resp.Data) == 0 {
+ return nil, &SearchError{
+ Provider: "openai",
+ Code: "NO_RESULTS",
+ Message: "No image generated",
+ }
+ }
+
+ // Get the generated image URL
+ imageURL := resp.Data[0].URL
+
+ // Download and cache the image if caching is enabled
+ if c.enableCache {
+ cacheFile := c.getCacheFilePath(opts.Query)
+ if err := c.downloadAndCache(ctx, imageURL, cacheFile); err == nil {
+ // Update URL to point to cached file
+ imageURL = cacheFile
+ }
+ // Continue even if caching fails
+ }
+
+ // Create result
+ result := SearchResult{
+ ID: c.generateImageID(opts.Query),
+ URL: imageURL,
+ ThumbnailURL: imageURL,
+ Width: c.getSizeWidth(),
+ Height: c.getSizeHeight(),
+ Description: fmt.Sprintf("Generated educational image for %s (%s)", opts.Query, translatedWord),
+ Attribution: "Generated by OpenAI DALL-E",
+ Source: "openai",
+ }
+
+ return []SearchResult{result}, nil
+}
+
+// Download returns a reader for the image identified by url. The caller must
+// close the returned reader.
+//
+// A value starting with "http://" or "https://" (compared case-insensitively,
+// since URL schemes are case-insensitive) is fetched with an HTTP GET bound to
+// ctx; any status other than 200 OK is an error. Every other value is treated
+// as the path of a locally cached file and opened directly.
+//
+// NOTE(review): because non-HTTP values are opened as file paths, callers must
+// only pass URLs produced by Search, never untrusted input. The request uses
+// http.DefaultClient, so ctx is the only bound on how long a download may take.
+func (c *OpenAIClient) Download(ctx context.Context, url string) (io.ReadCloser, error) {
+	lower := strings.ToLower(url)
+	if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
+		file, err := os.Open(url)
+		if err != nil {
+			return nil, fmt.Errorf("failed to open cached file: %w", err)
+		}
+		return file, nil
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	if err != nil {
+		return nil, fmt.Errorf("building image request: %w", err)
+	}
+
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("downloading image: %w", err)
+	}
+
+	if resp.StatusCode != http.StatusOK {
+		resp.Body.Close()
+		// resp.Status already contains the numeric code, e.g. "404 Not Found".
+		return nil, fmt.Errorf("unexpected HTTP status %s", resp.Status)
+	}
+
+	return resp.Body, nil
+}
+
+// GetAttribution returns the required attribution text
+func (c *OpenAIClient) GetAttribution(result *SearchResult) string {
+ return "Image generated by OpenAI DALL-E"
+}
+
+// Name returns the name of the provider
+func (c *OpenAIClient) Name() string {
+ return "openai"
+}
+
+// createEducationalPrompt builds the DALL-E prompt for a language-learning
+// flashcard image.
+//
+// englishTranslation names the subject of the image. If it is blank (no
+// translation was found), bulgarianWord is used instead so the prompt never
+// ends up without a subject.
+func (c *OpenAIClient) createEducationalPrompt(bulgarianWord, englishTranslation string) string {
+	subject := englishTranslation
+	if strings.TrimSpace(subject) == "" {
+		subject = bulgarianWord
+	}
+
+	// Ask for a clear, text-free picture of a single subject.
+	return fmt.Sprintf(
+		"A simple, clear, photorealistic educational image showing %s, "+
+			"suitable for language learning flashcards. "+
+			"The image should be easily recognizable, with good lighting, "+
+			"plain background, and focused on a single clear subject. "+
+			"No text, labels, or writing in the image.",
+		subject,
+	)
+}
+
+// translateBulgarianToEnglish returns the English rendering of a Bulgarian
+// word by delegating to the package's translateBulgarianQuery helper.
+func translateBulgarianToEnglish(word string) string {
+	english := translateBulgarianQuery(word)
+	return english
+}
+
+// getCacheFilePath maps a word plus the current generation settings to the
+// location of its cached PNG.
+//
+// The path is <cacheDir>/<first two hex digits>/<remaining hex digits>.png,
+// where the hex digits are the MD5 of word, model, size, quality and style
+// concatenated in that order. The two-character directory level keeps any
+// single directory from growing too large.
+func (c *OpenAIClient) getCacheFilePath(word string) string {
+	key := word + c.model + c.size + c.quality + c.style
+	sum := md5.Sum([]byte(key))
+	digest := hex.EncodeToString(sum[:])
+
+	shard, rest := digest[:2], digest[2:]
+	return filepath.Join(c.cacheDir, shard, rest+".png")
+}
+
+// downloadAndCache fetches url through Download and stores it at cacheFile.
+//
+// The data is first written to a temporary file in the destination directory
+// and only renamed to cacheFile once it has been fully written and closed.
+// A failed or cancelled download therefore never leaves a truncated file at
+// cacheFile, which Search would otherwise treat as a valid cache hit. An
+// empty download is rejected for the same reason.
+//
+// It returns an error if the directory cannot be created, the download fails,
+// the body is empty, or the file cannot be written or moved into place.
+func (c *OpenAIClient) downloadAndCache(ctx context.Context, url, cacheFile string) error {
+	dir := filepath.Dir(cacheFile)
+	if err := os.MkdirAll(dir, 0755); err != nil {
+		return fmt.Errorf("creating cache directory %q: %w", dir, err)
+	}
+
+	body, err := c.Download(ctx, url)
+	if err != nil {
+		return err
+	}
+	defer body.Close()
+
+	// Same directory as the target so the final rename stays on one filesystem.
+	tmp, err := os.CreateTemp(dir, filepath.Base(cacheFile)+".*.tmp")
+	if err != nil {
+		return fmt.Errorf("creating temporary cache file: %w", err)
+	}
+	tmpName := tmp.Name()
+
+	committed := false
+	defer func() {
+		if !committed {
+			// Closing twice is harmless; make sure no partial file survives.
+			tmp.Close()
+			os.Remove(tmpName)
+		}
+	}()
+
+	written, err := io.Copy(tmp, body)
+	if err != nil {
+		return fmt.Errorf("writing cache file: %w", err)
+	}
+	if written == 0 {
+		return fmt.Errorf("downloaded image is empty")
+	}
+
+	// CreateTemp uses mode 0600; restore the usual world-readable file mode.
+	if err := tmp.Chmod(0644); err != nil {
+		return fmt.Errorf("setting cache file mode: %w", err)
+	}
+	// Close before renaming so buffered write errors are not lost.
+	if err := tmp.Close(); err != nil {
+		return fmt.Errorf("closing cache file: %w", err)
+	}
+	if err := os.Rename(tmpName, cacheFile); err != nil {
+		return fmt.Errorf("moving cache file into place: %w", err)
+	}
+
+	committed = true
+	return nil
+}
+
+// generateImageID derives a short identifier for a generated image: the
+// prefix "openai_" followed by the first eight hex digits of the MD5 of the
+// word and the model name concatenated.
+func (c *OpenAIClient) generateImageID(word string) string {
+	sum := md5.Sum([]byte(word + c.model))
+	// Four bytes encode to exactly eight hex characters.
+	return "openai_" + hex.EncodeToString(sum[:4])
+}
+
+// getSizeWidth reports the pixel width implied by the configured size string.
+// Sizes it does not recognize are reported as 512.
+func (c *OpenAIClient) getSizeWidth() int {
+	switch c.size {
+	case "256x256":
+		return 256
+	case "1024x1024", "1024x1792": // square, and DALL-E 3 portrait
+		return 1024
+	case "1792x1024": // DALL-E 3 landscape
+		return 1792
+	}
+	// "512x512" and anything unrecognized.
+	return 512
+}
+
+// getSizeHeight reports the pixel height implied by the configured size
+// string. Sizes it does not recognize are reported as 512.
+func (c *OpenAIClient) getSizeHeight() int {
+	switch c.size {
+	case "256x256":
+		return 256
+	case "1024x1024", "1792x1024": // square, and DALL-E 3 landscape
+		return 1024
+	case "1024x1792": // DALL-E 3 portrait
+		return 1792
+	}
+	// "512x512" and anything unrecognized.
+	return 512
+}
\ No newline at end of file