summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/audio/provider.go2
-rw-r--r--internal/image/openai.go329
-rw-r--r--internal/image/openai_test.go280
-rw-r--r--internal/version.go2
4 files changed, 611 insertions, 2 deletions
diff --git a/internal/audio/provider.go b/internal/audio/provider.go
index 5b8c336..c803b61 100644
--- a/internal/audio/provider.go
+++ b/internal/audio/provider.go
@@ -44,7 +44,7 @@ type Config struct {
// DefaultConfig returns default configuration
func DefaultProviderConfig() *Config {
return &Config{
- Provider: "espeak",
+ Provider: "openai",
OutputDir: "./",
OutputFormat: "mp3",
ESpeakVoice: "bg",
diff --git a/internal/image/openai.go b/internal/image/openai.go
new file mode 100644
index 0000000..a5a3e31
--- /dev/null
+++ b/internal/image/openai.go
@@ -0,0 +1,329 @@
+package image
+
+import (
+ "context"
+ "crypto/md5"
+ "encoding/hex"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/sashabaranov/go-openai"
+)
+
+// OpenAIClient implements ImageSearcher for OpenAI DALL-E image generation
+type OpenAIClient struct {
+ client *openai.Client
+ apiKey string
+ model string // dall-e-2 or dall-e-3
+ size string // 256x256, 512x512, 1024x1024
+ quality string // standard or hd (dall-e-3 only)
+ style string // natural or vivid (dall-e-3 only)
+ cacheDir string
+ enableCache bool
+}
+
+// OpenAIConfig holds configuration for the OpenAI image provider
+type OpenAIConfig struct {
+ APIKey string
+ Model string
+ Size string
+ Quality string
+ Style string
+ CacheDir string
+ EnableCache bool
+}
+
+// NewOpenAIClient creates a new OpenAI DALL-E client
+func NewOpenAIClient(config *OpenAIConfig) *OpenAIClient {
+ if config.APIKey == "" {
+ // Return nil client that will fail on operations
+ return &OpenAIClient{}
+ }
+
+ client := openai.NewClient(config.APIKey)
+
+ // Set defaults
+ if config.Model == "" {
+ config.Model = "dall-e-2"
+ }
+ if config.Size == "" {
+ config.Size = "512x512"
+ }
+ if config.Quality == "" {
+ config.Quality = "standard"
+ }
+ if config.Style == "" {
+ config.Style = "natural"
+ }
+ if config.CacheDir == "" {
+ config.CacheDir = "./.image_cache"
+ }
+
+ oc := &OpenAIClient{
+ client: client,
+ apiKey: config.APIKey,
+ model: config.Model,
+ size: config.Size,
+ quality: config.Quality,
+ style: config.Style,
+ cacheDir: config.CacheDir,
+ enableCache: config.EnableCache,
+ }
+
+ // Create cache directory if caching is enabled
+ if oc.enableCache && oc.cacheDir != "" {
+ os.MkdirAll(oc.cacheDir, 0755)
+ }
+
+ return oc
+}
+
+// Search generates an image for the Bulgarian word using DALL-E
+func (c *OpenAIClient) Search(ctx context.Context, opts *SearchOptions) ([]SearchResult, error) {
+ if c.client == nil {
+ return nil, &SearchError{
+ Provider: "openai",
+ Code: "NO_API_KEY",
+ Message: "OpenAI API key not configured",
+ }
+ }
+
+ // Check cache first
+ if c.enableCache {
+ cacheFile := c.getCacheFilePath(opts.Query)
+ if info, err := os.Stat(cacheFile); err == nil && info.Size() > 0 {
+ // Return cached result
+ result := SearchResult{
+ ID: c.generateImageID(opts.Query),
+ URL: cacheFile,
+ ThumbnailURL: cacheFile,
+ Width: c.getSizeWidth(),
+ Height: c.getSizeHeight(),
+ Description: fmt.Sprintf("Generated image for %s", opts.Query),
+ Attribution: "Generated by OpenAI DALL-E",
+ Source: "openai",
+ }
+ return []SearchResult{result}, nil
+ }
+ }
+
+ // Translate Bulgarian word to English for better results
+ translatedWord := translateBulgarianToEnglish(opts.Query)
+
+ // Create educational prompt
+ prompt := c.createEducationalPrompt(opts.Query, translatedWord)
+
+ // Create the image generation request
+ req := openai.ImageRequest{
+ Prompt: prompt,
+ Model: c.model,
+ Size: c.size,
+ ResponseFormat: openai.CreateImageResponseFormatURL,
+ N: 1,
+ }
+
+ // Add model-specific parameters
+ if c.model == "dall-e-3" {
+ req.Quality = c.quality
+ req.Style = c.style
+ }
+
+ // Generate the image
+ resp, err := c.client.CreateImage(ctx, req)
+ if err != nil {
+ return nil, &SearchError{
+ Provider: "openai",
+ Code: "API_ERROR",
+ Message: fmt.Sprintf("Failed to generate image: %v", err),
+ }
+ }
+
+ if len(resp.Data) == 0 {
+ return nil, &SearchError{
+ Provider: "openai",
+ Code: "NO_RESULTS",
+ Message: "No image generated",
+ }
+ }
+
+ // Get the generated image URL
+ imageURL := resp.Data[0].URL
+
+ // Download and cache the image if caching is enabled
+ if c.enableCache {
+ cacheFile := c.getCacheFilePath(opts.Query)
+ if err := c.downloadAndCache(ctx, imageURL, cacheFile); err == nil {
+ // Update URL to point to cached file
+ imageURL = cacheFile
+ }
+ // Continue even if caching fails
+ }
+
+ // Create result
+ result := SearchResult{
+ ID: c.generateImageID(opts.Query),
+ URL: imageURL,
+ ThumbnailURL: imageURL,
+ Width: c.getSizeWidth(),
+ Height: c.getSizeHeight(),
+ Description: fmt.Sprintf("Generated educational image for %s (%s)", opts.Query, translatedWord),
+ Attribution: "Generated by OpenAI DALL-E",
+ Source: "openai",
+ }
+
+ return []SearchResult{result}, nil
+}
+
+// Download downloads an image from the given URL
+func (c *OpenAIClient) Download(ctx context.Context, url string) (io.ReadCloser, error) {
+ // If it's a local cached file (not an HTTP/HTTPS URL), open it directly
+ if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
+ file, err := os.Open(url)
+ if err != nil {
+ return nil, fmt.Errorf("failed to open cached file: %w", err)
+ }
+ return file, nil
+ }
+
+ // Otherwise download from URL
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, err
+ }
+
+ if resp.StatusCode != http.StatusOK {
+ resp.Body.Close()
+ return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status)
+ }
+
+ return resp.Body, nil
+}
+
+// GetAttribution returns the required attribution text
+func (c *OpenAIClient) GetAttribution(result *SearchResult) string {
+ return "Image generated by OpenAI DALL-E"
+}
+
+// Name returns the name of the provider
+func (c *OpenAIClient) Name() string {
+ return "openai"
+}
+
+// createEducationalPrompt generates a prompt optimized for language learning
+func (c *OpenAIClient) createEducationalPrompt(bulgarianWord, englishTranslation string) string {
+ // Create a prompt that generates clear, educational images
+ // suitable for language learning flashcards
+ return fmt.Sprintf(
+ "A simple, clear, photorealistic educational image showing %s, "+
+ "suitable for language learning flashcards. "+
+ "The image should be easily recognizable, with good lighting, "+
+ "plain background, and focused on a single clear subject. "+
+ "No text, labels, or writing in the image.",
+ englishTranslation,
+ )
+}
+
+// translateBulgarianToEnglish translates a Bulgarian word to English
+func translateBulgarianToEnglish(word string) string {
+ // Use the existing translation function from translate.go
+ return translateBulgarianQuery(word)
+}
+
+// getCacheFilePath generates a cache file path for the given word
+func (c *OpenAIClient) getCacheFilePath(word string) string {
+ // Create a hash of the word and settings
+ h := md5.New()
+ h.Write([]byte(word))
+ h.Write([]byte(c.model))
+ h.Write([]byte(c.size))
+ h.Write([]byte(c.quality))
+ h.Write([]byte(c.style))
+ hash := hex.EncodeToString(h.Sum(nil))
+
+ // Use first 2 chars as subdirectory for better file system performance
+ subdir := hash[:2]
+ filename := hash[2:] + ".png"
+
+ return filepath.Join(c.cacheDir, subdir, filename)
+}
+
+// downloadAndCache downloads an image and saves it to the cache
+func (c *OpenAIClient) downloadAndCache(ctx context.Context, url, cacheFile string) error {
+ // Ensure directory exists
+ dir := filepath.Dir(cacheFile)
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ return err
+ }
+
+ // Download the image
+ resp, err := c.Download(ctx, url)
+ if err != nil {
+ return err
+ }
+ defer resp.Close()
+
+ // Create the cache file
+ out, err := os.Create(cacheFile)
+ if err != nil {
+ return err
+ }
+ defer out.Close()
+
+ // Copy the data
+ _, err = io.Copy(out, resp)
+ return err
+}
+
+// generateImageID creates a unique ID for the image
+func (c *OpenAIClient) generateImageID(word string) string {
+ h := md5.New()
+ h.Write([]byte(word))
+ h.Write([]byte(c.model))
+ return "openai_" + hex.EncodeToString(h.Sum(nil))[:8]
+}
+
+// getSizeWidth returns the width based on the size setting
+func (c *OpenAIClient) getSizeWidth() int {
+ switch c.size {
+ case "256x256":
+ return 256
+ case "512x512":
+ return 512
+ case "1024x1024":
+ return 1024
+ case "1024x1792", "1792x1024": // DALL-E 3 sizes
+ if strings.HasPrefix(c.size, "1024") {
+ return 1024
+ }
+ return 1792
+ default:
+ return 512
+ }
+}
+
+// getSizeHeight returns the height based on the size setting
+func (c *OpenAIClient) getSizeHeight() int {
+ switch c.size {
+ case "256x256":
+ return 256
+ case "512x512":
+ return 512
+ case "1024x1024":
+ return 1024
+ case "1024x1792":
+ return 1792
+ case "1792x1024":
+ return 1024
+ default:
+ return 512
+ }
+} \ No newline at end of file
diff --git a/internal/image/openai_test.go b/internal/image/openai_test.go
new file mode 100644
index 0000000..8f42aeb
--- /dev/null
+++ b/internal/image/openai_test.go
@@ -0,0 +1,280 @@
+package image
+
+import (
+ "context"
+ "os"
+ "testing"
+)
+
+func TestOpenAIClient_NewClient(t *testing.T) {
+ tests := []struct {
+ name string
+ config *OpenAIConfig
+ wantNil bool
+ }{
+ {
+ name: "with API key",
+ config: &OpenAIConfig{
+ APIKey: "test-key",
+ Model: "dall-e-2",
+ Size: "512x512",
+ },
+ wantNil: false,
+ },
+ {
+ name: "without API key",
+ config: &OpenAIConfig{
+ APIKey: "",
+ },
+ wantNil: false, // Client is created but will fail on operations
+ },
+ {
+ name: "with defaults",
+ config: &OpenAIConfig{
+ APIKey: "test-key",
+ },
+ wantNil: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ client := NewOpenAIClient(tt.config)
+ if (client == nil) != tt.wantNil {
+ t.Errorf("NewOpenAIClient() returned nil = %v, want %v", client == nil, tt.wantNil)
+ }
+
+ if client != nil && tt.config.APIKey != "" {
+ // Check defaults were set
+ if tt.config.Model == "" && client.model != "dall-e-2" {
+ t.Errorf("Expected default model dall-e-2, got %s", client.model)
+ }
+ if tt.config.Size == "" && client.size != "512x512" {
+ t.Errorf("Expected default size 512x512, got %s", client.size)
+ }
+ }
+ })
+ }
+}
+
+func TestOpenAIClient_createEducationalPrompt(t *testing.T) {
+ client := &OpenAIClient{}
+
+ tests := []struct {
+ bulgarian string
+ english string
+ wantContains []string
+ }{
+ {
+ bulgarian: "ябълка",
+ english: "apple",
+ wantContains: []string{"apple", "educational", "flashcard"},
+ },
+ {
+ bulgarian: "котка",
+ english: "cat",
+ wantContains: []string{"cat", "simple", "clear"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.bulgarian, func(t *testing.T) {
+ prompt := client.createEducationalPrompt(tt.bulgarian, tt.english)
+
+ for _, want := range tt.wantContains {
+ if !contains(prompt, want) {
+ t.Errorf("Prompt missing expected word '%s': %s", want, prompt)
+ }
+ }
+ })
+ }
+}
+
+func TestOpenAIClient_getCacheFilePath(t *testing.T) {
+ client := &OpenAIClient{
+ model: "dall-e-2",
+ size: "512x512",
+ quality: "standard",
+ style: "natural",
+ cacheDir: "./.test_cache",
+ }
+
+ // Test that same input produces same cache path
+ path1 := client.getCacheFilePath("ябълка")
+ path2 := client.getCacheFilePath("ябълка")
+
+ if path1 != path2 {
+ t.Errorf("Cache paths differ for same input: %s vs %s", path1, path2)
+ }
+
+ // Test that different inputs produce different paths
+ path3 := client.getCacheFilePath("котка")
+ if path1 == path3 {
+ t.Errorf("Cache paths same for different inputs")
+ }
+
+ // Test path structure
+ if !contains(path1, ".test_cache") {
+ t.Errorf("Cache path doesn't contain cache dir: %s", path1)
+ }
+
+ if !contains(path1, ".png") {
+ t.Errorf("Cache path doesn't have .png extension: %s", path1)
+ }
+}
+
+func TestOpenAIClient_translateBulgarianToEnglish(t *testing.T) {
+ tests := []struct {
+ input string
+ expected string
+ }{
+ {"ябълка", "apple"},
+ {"котка", "cat"},
+ {"куче", "dog"},
+ {"хляб", "bread"},
+ {"unknown", "unknown"}, // Should return original if not in dictionary
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.input, func(t *testing.T) {
+ result := translateBulgarianToEnglish(tt.input)
+ if result != tt.expected {
+ t.Errorf("translateBulgarianToEnglish(%s) = %s, want %s",
+ tt.input, result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestOpenAIClient_getSizeWidthHeight(t *testing.T) {
+ tests := []struct {
+ size string
+ width int
+ height int
+ }{
+ {"256x256", 256, 256},
+ {"512x512", 512, 512},
+ {"1024x1024", 1024, 1024},
+ {"1024x1792", 1024, 1792},
+ {"1792x1024", 1792, 1024},
+ {"unknown", 512, 512}, // Default
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.size, func(t *testing.T) {
+ client := &OpenAIClient{size: tt.size}
+
+ if w := client.getSizeWidth(); w != tt.width {
+ t.Errorf("getSizeWidth() = %d, want %d", w, tt.width)
+ }
+
+ if h := client.getSizeHeight(); h != tt.height {
+ t.Errorf("getSizeHeight() = %d, want %d", h, tt.height)
+ }
+ })
+ }
+}
+
+func TestOpenAIClient_Search_NoAPIKey(t *testing.T) {
+ client := NewOpenAIClient(&OpenAIConfig{})
+
+ opts := DefaultSearchOptions("ябълка")
+ _, err := client.Search(context.Background(), opts)
+
+ if err == nil {
+ t.Error("Expected error for missing API key")
+ }
+
+ if searchErr, ok := err.(*SearchError); ok {
+ if searchErr.Code != "NO_API_KEY" {
+ t.Errorf("Expected NO_API_KEY error, got %s", searchErr.Code)
+ }
+ } else {
+ t.Error("Expected SearchError type")
+ }
+}
+
+func TestOpenAIClient_Name(t *testing.T) {
+ client := &OpenAIClient{}
+ if name := client.Name(); name != "openai" {
+ t.Errorf("Name() = %s, want 'openai'", name)
+ }
+}
+
+func TestOpenAIClient_GetAttribution(t *testing.T) {
+ client := &OpenAIClient{}
+ result := &SearchResult{}
+
+ attr := client.GetAttribution(result)
+ if !contains(attr, "OpenAI DALL-E") {
+ t.Errorf("Attribution doesn't mention OpenAI DALL-E: %s", attr)
+ }
+}
+
+// Helper function
+func contains(s, substr string) bool {
+ return len(s) >= len(substr) &&
+ (s == substr || len(s) > 0 && containsHelper(s, substr))
+}
+
+func containsHelper(s, substr string) bool {
+ for i := 0; i <= len(s)-len(substr); i++ {
+ if s[i:i+len(substr)] == substr {
+ return true
+ }
+ }
+ return false
+}
+
+// Integration test (skipped by default)
+func TestOpenAIClient_Search_Integration(t *testing.T) {
+ apiKey := os.Getenv("OPENAI_API_KEY")
+ if apiKey == "" {
+ t.Skip("OPENAI_API_KEY not set, skipping integration test")
+ }
+
+ client := NewOpenAIClient(&OpenAIConfig{
+ APIKey: apiKey,
+ Model: "dall-e-2",
+ Size: "256x256", // Smallest size to minimize cost
+ EnableCache: true,
+ CacheDir: t.TempDir(),
+ })
+
+ opts := DefaultSearchOptions("ябълка")
+ results, err := client.Search(context.Background(), opts)
+
+ if err != nil {
+ t.Fatalf("Search failed: %v", err)
+ }
+
+ if len(results) != 1 {
+ t.Fatalf("Expected 1 result, got %d", len(results))
+ }
+
+ result := results[0]
+
+ // Check result fields
+ if result.ID == "" {
+ t.Error("Result ID is empty")
+ }
+ if result.URL == "" {
+ t.Error("Result URL is empty")
+ }
+ if result.Width != 256 || result.Height != 256 {
+ t.Errorf("Expected 256x256, got %dx%d", result.Width, result.Height)
+ }
+ if result.Source != "openai" {
+ t.Errorf("Expected source 'openai', got '%s'", result.Source)
+ }
+
+ // Test caching - second request should use cache
+ results2, err := client.Search(context.Background(), opts)
+ if err != nil {
+ t.Fatalf("Second search failed: %v", err)
+ }
+
+ if results2[0].URL != results[0].URL {
+ t.Log("Note: URLs differ, cache might not be working as expected")
+ }
+} \ No newline at end of file
diff --git a/internal/version.go b/internal/version.go
index 93a42a8..0894830 100644
--- a/internal/version.go
+++ b/internal/version.go
@@ -1,3 +1,3 @@
package internal
-const Version = "0.0.0"
+const Version = "0.1.0"