diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/audio/provider.go | 2 | ||||
| -rw-r--r-- | internal/image/openai.go | 329 | ||||
| -rw-r--r-- | internal/image/openai_test.go | 280 | ||||
| -rw-r--r-- | internal/version.go | 2 |
4 files changed, 611 insertions, 2 deletions
diff --git a/internal/audio/provider.go b/internal/audio/provider.go index 5b8c336..c803b61 100644 --- a/internal/audio/provider.go +++ b/internal/audio/provider.go @@ -44,7 +44,7 @@ type Config struct { // DefaultConfig returns default configuration func DefaultProviderConfig() *Config { return &Config{ - Provider: "espeak", + Provider: "openai", OutputDir: "./", OutputFormat: "mp3", ESpeakVoice: "bg", diff --git a/internal/image/openai.go b/internal/image/openai.go new file mode 100644 index 0000000..a5a3e31 --- /dev/null +++ b/internal/image/openai.go @@ -0,0 +1,329 @@ +package image + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + + "github.com/sashabaranov/go-openai" +) + +// OpenAIClient implements ImageSearcher for OpenAI DALL-E image generation +type OpenAIClient struct { + client *openai.Client + apiKey string + model string // dall-e-2 or dall-e-3 + size string // 256x256, 512x512, 1024x1024 + quality string // standard or hd (dall-e-3 only) + style string // natural or vivid (dall-e-3 only) + cacheDir string + enableCache bool +} + +// OpenAIConfig holds configuration for the OpenAI image provider +type OpenAIConfig struct { + APIKey string + Model string + Size string + Quality string + Style string + CacheDir string + EnableCache bool +} + +// NewOpenAIClient creates a new OpenAI DALL-E client +func NewOpenAIClient(config *OpenAIConfig) *OpenAIClient { + if config.APIKey == "" { + // Return nil client that will fail on operations + return &OpenAIClient{} + } + + client := openai.NewClient(config.APIKey) + + // Set defaults + if config.Model == "" { + config.Model = "dall-e-2" + } + if config.Size == "" { + config.Size = "512x512" + } + if config.Quality == "" { + config.Quality = "standard" + } + if config.Style == "" { + config.Style = "natural" + } + if config.CacheDir == "" { + config.CacheDir = "./.image_cache" + } + + oc := &OpenAIClient{ + client: client, + apiKey: config.APIKey, + model: config.Model, + size: config.Size, + quality: config.Quality, + style: config.Style, + cacheDir: config.CacheDir, + enableCache: config.EnableCache, + } + + // Create cache directory if caching is enabled + if oc.enableCache && oc.cacheDir != "" { + os.MkdirAll(oc.cacheDir, 0755) + } + + return oc +} + +// Search generates an image for the Bulgarian word using DALL-E +func (c *OpenAIClient) Search(ctx context.Context, opts *SearchOptions) ([]SearchResult, error) { + if c.client == nil { + return nil, &SearchError{ + Provider: "openai", + Code: "NO_API_KEY", + Message: "OpenAI API key not configured", + } + } + + // Check cache first + if c.enableCache { + cacheFile := c.getCacheFilePath(opts.Query) + if info, err := os.Stat(cacheFile); err == nil && info.Size() > 0 { + // Return cached result + result := SearchResult{ + ID: c.generateImageID(opts.Query), + URL: cacheFile, + ThumbnailURL: cacheFile, + Width: c.getSizeWidth(), + Height: c.getSizeHeight(), + Description: fmt.Sprintf("Generated image for %s", opts.Query), + Attribution: "Generated by OpenAI DALL-E", + Source: "openai", + } + return []SearchResult{result}, nil + } + } + + // Translate Bulgarian word to English for better results + translatedWord := translateBulgarianToEnglish(opts.Query) + + // Create educational prompt + prompt := c.createEducationalPrompt(opts.Query, translatedWord) + + // Create the image generation request + req := openai.ImageRequest{ + Prompt: prompt, + Model: c.model, + Size: c.size, + ResponseFormat: openai.CreateImageResponseFormatURL, + N: 1, + } + + // Add model-specific parameters + if c.model == "dall-e-3" { + req.Quality = c.quality + req.Style = c.style + } + + // Generate the image + resp, err := c.client.CreateImage(ctx, req) + if err != nil { + return nil, &SearchError{ + Provider: "openai", + Code: "API_ERROR", + Message: fmt.Sprintf("Failed to generate image: %v", err), + } + } + + if len(resp.Data) == 0 { + return nil, &SearchError{ + Provider: "openai", + Code: "NO_RESULTS", + Message: "No image generated", + } + } + + // Get the generated image URL + imageURL := resp.Data[0].URL + + // Download and cache the image if caching is enabled + if c.enableCache { + cacheFile := c.getCacheFilePath(opts.Query) + if err := c.downloadAndCache(ctx, imageURL, cacheFile); err == nil { + // Update URL to point to cached file + imageURL = cacheFile + } + // Continue even if caching fails + } + + // Create result + result := SearchResult{ + ID: c.generateImageID(opts.Query), + URL: imageURL, + ThumbnailURL: imageURL, + Width: c.getSizeWidth(), + Height: c.getSizeHeight(), + Description: fmt.Sprintf("Generated educational image for %s (%s)", opts.Query, translatedWord), + Attribution: "Generated by OpenAI DALL-E", + Source: "openai", + } + + return []SearchResult{result}, nil +} + +// Download downloads an image from the given URL +func (c *OpenAIClient) Download(ctx context.Context, url string) (io.ReadCloser, error) { + // If it's a local cached file (not an HTTP/HTTPS URL), open it directly + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + file, err := os.Open(url) + if err != nil { + return nil, fmt.Errorf("failed to open cached file: %w", err) + } + return file, nil + } + + // Otherwise download from URL + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status) + } + + return resp.Body, nil +} + +// GetAttribution returns the required attribution text +func (c *OpenAIClient) GetAttribution(result *SearchResult) string { + return "Image generated by OpenAI DALL-E" +} + +// Name returns the name of the provider +func (c *OpenAIClient) Name() string { + return "openai" +} + +// createEducationalPrompt generates a prompt optimized for language learning +func (c *OpenAIClient) createEducationalPrompt(bulgarianWord, englishTranslation string) string { + // Create a prompt that generates clear, educational images + // suitable for language learning flashcards + return fmt.Sprintf( + "A simple, clear, photorealistic educational image showing %s, "+ + "suitable for language learning flashcards. "+ + "The image should be easily recognizable, with good lighting, "+ + "plain background, and focused on a single clear subject. "+ + "No text, labels, or writing in the image.", + englishTranslation, + ) +} + +// translateBulgarianToEnglish translates a Bulgarian word to English +func translateBulgarianToEnglish(word string) string { + // Use the existing translation function from translate.go + return translateBulgarianQuery(word) +} + +// getCacheFilePath generates a cache file path for the given word +func (c *OpenAIClient) getCacheFilePath(word string) string { + // Create a hash of the word and settings + h := md5.New() + h.Write([]byte(word)) + h.Write([]byte(c.model)) + h.Write([]byte(c.size)) + h.Write([]byte(c.quality)) + h.Write([]byte(c.style)) + hash := hex.EncodeToString(h.Sum(nil)) + + // Use first 2 chars as subdirectory for better file system performance + subdir := hash[:2] + filename := hash[2:] + ".png" + + return filepath.Join(c.cacheDir, subdir, filename) +} + +// downloadAndCache downloads an image and saves it to the cache +func (c *OpenAIClient) downloadAndCache(ctx context.Context, url, cacheFile string) error { + // Ensure directory exists + dir := filepath.Dir(cacheFile) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + // Download the image + resp, err := c.Download(ctx, url) + if err != nil { + return err + } + defer resp.Close() + + // Create the cache file + out, err := os.Create(cacheFile) + if err != nil { + return err + } + defer out.Close() + + // Copy the data + _, err = io.Copy(out, resp) + return err +} + +// generateImageID creates a unique ID for the image +func (c *OpenAIClient) generateImageID(word string) string { + h := md5.New() + h.Write([]byte(word)) + h.Write([]byte(c.model)) + return "openai_" + hex.EncodeToString(h.Sum(nil))[:8] +} + +// getSizeWidth returns the width based on the size setting +func (c *OpenAIClient) getSizeWidth() int { + switch c.size { + case "256x256": + return 256 + case "512x512": + return 512 + case "1024x1024": + return 1024 + case "1024x1792", "1792x1024": // DALL-E 3 sizes + if strings.HasPrefix(c.size, "1024") { + return 1024 + } + return 1792 + default: + return 512 + } +} + +// getSizeHeight returns the height based on the size setting +func (c *OpenAIClient) getSizeHeight() int { + switch c.size { + case "256x256": + return 256 + case "512x512": + return 512 + case "1024x1024": + return 1024 + case "1024x1792": + return 1792 + case "1792x1024": + return 1024 + default: + return 512 + } +}
\ No newline at end of file diff --git a/internal/image/openai_test.go b/internal/image/openai_test.go new file mode 100644 index 0000000..8f42aeb --- /dev/null +++ b/internal/image/openai_test.go @@ -0,0 +1,280 @@ +package image + +import ( + "context" + "os" + "testing" +) + +func TestOpenAIClient_NewClient(t *testing.T) { + tests := []struct { + name string + config *OpenAIConfig + wantNil bool + }{ + { + name: "with API key", + config: &OpenAIConfig{ + APIKey: "test-key", + Model: "dall-e-2", + Size: "512x512", + }, + wantNil: false, + }, + { + name: "without API key", + config: &OpenAIConfig{ + APIKey: "", + }, + wantNil: false, // Client is created but will fail on operations + }, + { + name: "with defaults", + config: &OpenAIConfig{ + APIKey: "test-key", + }, + wantNil: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewOpenAIClient(tt.config) + if (client == nil) != tt.wantNil { + t.Errorf("NewOpenAIClient() returned nil = %v, want %v", client == nil, tt.wantNil) + } + + if client != nil && tt.config.APIKey != "" { + // Check defaults were set + if tt.config.Model == "" && client.model != "dall-e-2" { + t.Errorf("Expected default model dall-e-2, got %s", client.model) + } + if tt.config.Size == "" && client.size != "512x512" { + t.Errorf("Expected default size 512x512, got %s", client.size) + } + } + }) + } +} + +func TestOpenAIClient_createEducationalPrompt(t *testing.T) { + client := &OpenAIClient{} + + tests := []struct { + bulgarian string + english string + wantContains []string + }{ + { + bulgarian: "ябълка", + english: "apple", + wantContains: []string{"apple", "educational", "flashcard"}, + }, + { + bulgarian: "котка", + english: "cat", + wantContains: []string{"cat", "simple", "clear"}, + }, + } + + for _, tt := range tests { + t.Run(tt.bulgarian, func(t *testing.T) { + prompt := client.createEducationalPrompt(tt.bulgarian, tt.english) + + for _, want := range tt.wantContains { + if !contains(prompt, want) { + t.Errorf("Prompt missing expected word '%s': %s", want, prompt) + } + } + }) + } +} + +func TestOpenAIClient_getCacheFilePath(t *testing.T) { + client := &OpenAIClient{ + model: "dall-e-2", + size: "512x512", + quality: "standard", + style: "natural", + cacheDir: "./.test_cache", + } + + // Test that same input produces same cache path + path1 := client.getCacheFilePath("ябълка") + path2 := client.getCacheFilePath("ябълка") + + if path1 != path2 { + t.Errorf("Cache paths differ for same input: %s vs %s", path1, path2) + } + + // Test that different inputs produce different paths + path3 := client.getCacheFilePath("котка") + if path1 == path3 { + t.Errorf("Cache paths same for different inputs") + } + + // Test path structure + if !contains(path1, ".test_cache") { + t.Errorf("Cache path doesn't contain cache dir: %s", path1) + } + + if !contains(path1, ".png") { + t.Errorf("Cache path doesn't have .png extension: %s", path1) + } +} + +func TestOpenAIClient_translateBulgarianToEnglish(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"ябълка", "apple"}, + {"котка", "cat"}, + {"куче", "dog"}, + {"хляб", "bread"}, + {"unknown", "unknown"}, // Should return original if not in dictionary + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := translateBulgarianToEnglish(tt.input) + if result != tt.expected { + t.Errorf("translateBulgarianToEnglish(%s) = %s, want %s", + tt.input, result, tt.expected) + } + }) + } +} + +func TestOpenAIClient_getSizeWidthHeight(t *testing.T) { + tests := []struct { + size string + width int + height int + }{ + {"256x256", 256, 256}, + {"512x512", 512, 512}, + {"1024x1024", 1024, 1024}, + {"1024x1792", 1024, 1792}, + {"1792x1024", 1792, 1024}, + {"unknown", 512, 512}, // Default + } + + for _, tt := range tests { + t.Run(tt.size, func(t *testing.T) { + client := &OpenAIClient{size: tt.size} + + if w := client.getSizeWidth(); w != tt.width { + t.Errorf("getSizeWidth() = %d, want %d", w, tt.width) + } + + if h := client.getSizeHeight(); h != tt.height { + t.Errorf("getSizeHeight() = %d, want %d", h, tt.height) + } + }) + } +} + +func TestOpenAIClient_Search_NoAPIKey(t *testing.T) { + client := NewOpenAIClient(&OpenAIConfig{}) + + opts := DefaultSearchOptions("ябълка") + _, err := client.Search(context.Background(), opts) + + if err == nil { + t.Error("Expected error for missing API key") + } + + if searchErr, ok := err.(*SearchError); ok { + if searchErr.Code != "NO_API_KEY" { + t.Errorf("Expected NO_API_KEY error, got %s", searchErr.Code) + } + } else { + t.Error("Expected SearchError type") + } +} + +func TestOpenAIClient_Name(t *testing.T) { + client := &OpenAIClient{} + if name := client.Name(); name != "openai" { + t.Errorf("Name() = %s, want 'openai'", name) + } +} + +func TestOpenAIClient_GetAttribution(t *testing.T) { + client := &OpenAIClient{} + result := &SearchResult{} + + attr := client.GetAttribution(result) + if !contains(attr, "OpenAI DALL-E") { + t.Errorf("Attribution doesn't mention OpenAI DALL-E: %s", attr) + } +} + +// Helper function +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// Integration test (skipped by default) +func TestOpenAIClient_Search_Integration(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set, skipping integration test") + } + + client := NewOpenAIClient(&OpenAIConfig{ + APIKey: apiKey, + Model: "dall-e-2", + Size: "256x256", // Smallest size to minimize cost + EnableCache: true, + CacheDir: t.TempDir(), + }) + + opts := DefaultSearchOptions("ябълка") + results, err := client.Search(context.Background(), opts) + + if err != nil { + t.Fatalf("Search failed: %v", err) + } + + if len(results) != 1 { + t.Fatalf("Expected 1 result, got %d", len(results)) + } + + result := results[0] + + // Check result fields + if result.ID == "" { + t.Error("Result ID is empty") + } + if result.URL == "" { + t.Error("Result URL is empty") + } + if result.Width != 256 || result.Height != 256 { + t.Errorf("Expected 256x256, got %dx%d", result.Width, result.Height) + } + if result.Source != "openai" { + t.Errorf("Expected source 'openai', got '%s'", result.Source) + } + + // Test caching - second request should use cache + results2, err := client.Search(context.Background(), opts) + if err != nil { + t.Fatalf("Second search failed: %v", err) + } + + if results2[0].URL != results[0].URL { + t.Log("Note: URLs differ, cache might not be working as expected") + } +}
\ No newline at end of file diff --git a/internal/version.go b/internal/version.go index 93a42a8..0894830 100644 --- a/internal/version.go +++ b/internal/version.go @@ -1,3 +1,3 @@ package internal -const Version = "0.0.0" +const Version = "0.1.0" |
