diff --git a/config/config.go b/config/config.go index 4b5d09f..a733de4 100644 --- a/config/config.go +++ b/config/config.go @@ -8,16 +8,29 @@ import ( "gopkg.in/yaml.v3" ) +// ChunkConfig stores configuration for chunked translation +type ChunkConfig struct { + Enabled bool `yaml:"enabled"` + Size int `yaml:"size"` + Prompt string `yaml:"prompt"` + Context int `yaml:"context"` +} + // Config stores the application configuration type Config struct { - APIBase string `yaml:"api_base"` - APIKey string `yaml:"api_key"` - Model string `yaml:"model"` - SystemPrompt string `yaml:"system_prompt"` + APIBase string `yaml:"api_base"` + APIKey string `yaml:"api_key"` + Model string `yaml:"model"` + SystemPrompt string `yaml:"system_prompt"` + // Concurrency related settings + Concurrency int `yaml:"concurrency"` + // Chunk related settings + Chunk ChunkConfig `yaml:"chunk"` } // Default system prompt as a placeholder const DefaultSystemPrompt = "Placeholder" +const DefaultChunkPrompt = "Please continue translation" // LoadConfig loads configuration from ~/.config/d10n.yaml func LoadConfig() (*Config, error) { @@ -34,6 +47,13 @@ func LoadConfig() (*Config, error) { if _, err := os.Stat(configPath); os.IsNotExist(err) { return &Config{ SystemPrompt: DefaultSystemPrompt, + Concurrency: 3, // Default concurrency + Chunk: ChunkConfig{ + Enabled: false, // Chunking disabled by default + Size: 10240, // Default chunk size in tokens + Prompt: DefaultChunkPrompt, // Default chunk prompt + Context: 2, // Default context size + }, }, nil } @@ -49,10 +69,28 @@ func LoadConfig() (*Config, error) { return nil, fmt.Errorf("could not parse config file: %w", err) } - // Set default system prompt if not specified + // Set default values if not specified if config.SystemPrompt == "" { config.SystemPrompt = DefaultSystemPrompt } + + // Set default for concurrency + if config.Concurrency <= 0 { + config.Concurrency = 3 + } + + // Set defaults for chunk settings + if config.Chunk.Size <= 0 { + 
config.Chunk.Size = 10240 + } + + if config.Chunk.Prompt == "" { + config.Chunk.Prompt = DefaultChunkPrompt + } + + if config.Chunk.Context <= 0 { + config.Chunk.Context = 2 + } return &config, nil } diff --git a/go.mod b/go.mod index f38e0d3..0034844 100644 --- a/go.mod +++ b/go.mod @@ -3,3 +3,11 @@ module github.com/wholetrans/d10n go 1.24.1 require gopkg.in/yaml.v3 v3.0.1 + +require ( + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/schollz/progressbar/v3 v3.18.0 + golang.org/x/sys v0.29.0 // indirect + golang.org/x/term v0.28.0 // indirect +) diff --git a/go.sum b/go.sum index a62c313..b07dda9 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,13 @@ +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= +github.com/schollz/progressbar/v3 v3.18.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index b4a0dcb..160185d 
100644 --- a/main.go +++ b/main.go @@ -6,7 +6,9 @@ import ( "os" "path/filepath" "strings" + "sync" + "github.com/schollz/progressbar/v3" "github.com/wholetrans/d10n/config" "github.com/wholetrans/d10n/translator" ) @@ -32,6 +34,15 @@ func main() { systemPromptPtr := flagSet.String("system-prompt", "", "System prompt for the model") formatPtr := flagSet.String("format", "", "File format to process (e.g., md, txt)") + // Concurrency flag + concurrencyPtr := flagSet.Int("concurrency", 0, "Number of concurrent translation tasks (default is 3)") + + // Chunking flags + chunkEnabledPtr := flagSet.Bool("chunk", false, "Enable chunked translation") + chunkSizePtr := flagSet.Int("chunk-size", 0, "Size of each chunk in tokens (default is 10240)") + chunkPromptPtr := flagSet.String("chunk-prompt", "", "Prompt to use for continuing translation (default is 'Please continue translation')") + chunkContextPtr := flagSet.Int("chunk-context", 0, "Number of chunks to include as context (default is 2)") + // Parse flags if err := flagSet.Parse(os.Args[2:]); err != nil { fmt.Println("Error parsing arguments:", err) @@ -75,6 +86,28 @@ func main() { cfg.SystemPrompt = *systemPromptPtr } + // Set concurrency options + if *concurrencyPtr > 0 { + cfg.Concurrency = *concurrencyPtr + } + + // Set chunking options from command line arguments + if *chunkEnabledPtr { + cfg.Chunk.Enabled = true + } + + if *chunkSizePtr > 0 { + cfg.Chunk.Size = *chunkSizePtr + } + + if *chunkPromptPtr != "" { + cfg.Chunk.Prompt = *chunkPromptPtr + } + + if *chunkContextPtr > 0 { + cfg.Chunk.Context = *chunkContextPtr + } + // Set target path if not provided targetPath := *targetPathPtr if targetPath == "" { @@ -118,6 +151,11 @@ func printUsage() { fmt.Println(" -api-base API base URL for OpenAI compatible service") fmt.Println(" -system-prompt System prompt for the model") fmt.Println(" -format File format to process (e.g., md, txt)") + fmt.Println(" -concurrency Number of concurrent translation tasks (default: 
3)") + fmt.Println(" -chunk Enable chunked translation") + fmt.Println(" -chunk-size Size of each chunk in tokens (default: 10240)") + fmt.Println(" -chunk-prompt Prompt for continuing translation (default: 'Please continue translation')") + fmt.Println(" -chunk-context Number of chunks to include as context (default: 2)") } func processFile(sourcePath, targetPath, sourceLanguage, targetLanguage string, trans *translator.Translator) error { @@ -126,7 +164,7 @@ func processFile(sourcePath, targetPath, sourceLanguage, targetLanguage string, return fmt.Errorf("error reading file %s: %w", sourcePath, err) } - translatedContent, err := trans.Translate(string(content), sourceLanguage, targetLanguage) + translatedContent, err := trans.Translate(string(content), sourceLanguage, targetLanguage, sourcePath) if err != nil { return fmt.Errorf("error translating file %s: %w", sourcePath, err) } @@ -142,7 +180,9 @@ func processFile(sourcePath, targetPath, sourceLanguage, targetLanguage string, return fmt.Errorf("error writing to file %s: %w", targetPath, err) } - fmt.Printf("Translated %s to %s\n", sourcePath, targetPath) + // Remove the progress bar for this file + trans.RemoveProgressBar(sourcePath) + return nil } @@ -185,7 +225,29 @@ func processDirectory(sourcePath, targetPath, sourceLanguage, targetLanguage, fo } fmt.Println() - // Process each matched file with progress updates + // Create overall progress bar + overallBar := progressbar.NewOptions(len(matchedFiles), + progressbar.OptionSetDescription("[Overall Progress]"), + progressbar.OptionShowCount(), + progressbar.OptionSetTheme(progressbar.Theme{ + Saucer: "#", + SaucerHead: ">", + SaucerPadding: "-", + BarStart: "[", + BarEnd: "]", + }), + ) + + // Set up a worker pool for concurrent processing + var ( + wg sync.WaitGroup + mutex sync.Mutex + errors []error + concurrency = trans.GetConcurrency() + semaphore = make(chan struct{}, concurrency) + ) + + // Process files concurrently for i, path := range matchedFiles { 
// Compute relative path from source relPath, err := filepath.Rel(sourcePath, path) @@ -202,14 +264,39 @@ func processDirectory(sourcePath, targetPath, sourceLanguage, targetLanguage, fo return fmt.Errorf("error creating target directory %s: %w", targetFileDir, err) } - // Display progress - fmt.Printf("[%d/%d] Translating: %s\n", i+1, len(matchedFiles), relPath) - - // Process individual file - err = processFile(path, targetFilePath, sourceLanguage, targetLanguage, trans) - if err != nil { - return err - } + // Process individual file concurrently + wg.Add(1) + go func(idx int, sourcePath, targetPath, relPath string) { + defer wg.Done() + + // Acquire semaphore (limit concurrency) + semaphore <- struct{}{} + defer func() { <-semaphore }() + + // Process the file + err := processFile(sourcePath, targetPath, sourceLanguage, targetLanguage, trans) + + // Update overall progress bar + mutex.Lock() + overallBar.Add(1) + if err != nil { + errors = append(errors, err) + fmt.Printf("Error translating %s: %v\n", relPath, err) + } + mutex.Unlock() + }(i+1, path, targetFilePath, relPath) + } + + // Wait for all translations to complete + wg.Wait() + + // Complete the overall progress bar + overallBar.Finish() + fmt.Println() // Add some spacing after progress bars + + // Check if any errors occurred + if len(errors) > 0 { + return fmt.Errorf("encountered %d errors during translation", len(errors)) } return nil diff --git a/sample_config.yaml b/sample_config.yaml index 948681e..b162abb 100644 --- a/sample_config.yaml +++ b/sample_config.yaml @@ -5,3 +5,13 @@ api_base: "https://api.openai.com" # OpenAI-compatible API base URL api_key: "your-api-key-here" # API key for the service model: "gpt-4o" # Model to use for translation system_prompt: "You are a professional translator. You are translating from $SOURCE_LANG to $TARGET_LANG. Maintain the original formatting and structure of the text while translating it accurately." 
# Custom system prompt with variables + +# Concurrency settings +concurrency: 3 # Number of concurrent translation tasks (default: 3) + +# Chunked translation settings +chunk: + enabled: false # Whether to enable chunked translation (default: false) + size: 10240 # Size of each chunk in tokens (default: 10240) + prompt: "Please continue translation" # Prompt to use for continuing translation + context: 2 # Number of chunks to include as context (default: 2) diff --git a/translator/translator.go b/translator/translator.go index 948abde..9a51174 100644 --- a/translator/translator.go +++ b/translator/translator.go @@ -4,15 +4,21 @@ import ( "bytes" "encoding/json" "fmt" + "math" "net/http" "strings" + "sync" + "unicode/utf8" + "github.com/schollz/progressbar/v3" "github.com/wholetrans/d10n/config" ) // Translator handles communication with the OpenAI API type Translator struct { config *config.Config + mutex sync.Mutex + progressBars map[string]*progressbar.ProgressBar } // Message represents a message in the OpenAI chat format @@ -43,11 +49,38 @@ type ChatResponse struct { func NewTranslator(cfg *config.Config) *Translator { return &Translator{ config: cfg, + progressBars: make(map[string]*progressbar.ProgressBar), } } +// GetConcurrency returns the configured concurrency level +func (t *Translator) GetConcurrency() int { + return t.config.Concurrency +} + // Translate translates content from sourceLanguage to targetLanguage -func (t *Translator) Translate(content, sourceLanguage, targetLanguage string) (string, error) { +func (t *Translator) Translate(content, sourceLanguage, targetLanguage string, filePath string) (string, error) { + // If chunking is not enabled, translate the whole content at once + if !t.config.Chunk.Enabled { + return t.translateSingle(content, sourceLanguage, targetLanguage) + } + + // Split content into chunks + chunks, err := t.splitIntoChunks(content) + if err != nil { + return "", fmt.Errorf("error splitting content into chunks: %w", err) 
+ } + + // If there's only one chunk, translate it directly + if len(chunks) == 1 { + return t.translateSingle(chunks[0], sourceLanguage, targetLanguage) + } + + return t.translateChunks(chunks, sourceLanguage, targetLanguage, filePath) +} + +// translateSingle translates a single content block +func (t *Translator) translateSingle(content, sourceLanguage, targetLanguage string) (string, error) { messages := []Message{ { Role: "system", @@ -59,6 +92,108 @@ func (t *Translator) Translate(content, sourceLanguage, targetLanguage string) ( }, } + return t.sendChatRequest(messages) +} + +// translateChunks translates content in chunks with context +func (t *Translator) translateChunks(chunks []string, sourceLanguage, targetLanguage, filePath string) (string, error) { + systemPrompt := t.getSystemPrompt(sourceLanguage, targetLanguage) + var translations []string + contextStart := 0 + + // Create a progress bar for this file if it doesn't exist + t.mutex.Lock() + progressBarID := filePath + if _, exists := t.progressBars[progressBarID]; !exists { + t.progressBars[progressBarID] = progressbar.NewOptions(len(chunks), + progressbar.OptionSetDescription(fmt.Sprintf("[Chunks: %s]", filePath)), + progressbar.OptionShowCount(), + progressbar.OptionSetTheme(progressbar.Theme{ + Saucer: "=", + SaucerHead: ">", + SaucerPadding: " ", + BarStart: "[", + BarEnd: "]", + }), + ) + } + progressBar := t.progressBars[progressBarID] + t.mutex.Unlock() + + // Process each chunk + for i, chunk := range chunks { + // Build messages with context + messages := []Message{ + { + Role: "system", + Content: systemPrompt, + }, + } + + // Add context chunks and their translations (translations[j] is the translation of chunks[j]) + for j := contextStart; j < i; j++ { + if j >= 0 && j < len(chunks) { + messages = append(messages, Message{ + Role: "user", + Content: chunks[j], + }) + + if j < len(translations) { + messages = append(messages, Message{ + Role: "assistant", + Content: translations[j], + }) + } + } + } + + // Add 
current chunk to translate + messages = append(messages, Message{ + Role: "user", + Content: chunk, + }) + + // For chunks after the first one, add the continuation prompt + if i > 0 { + messages = append(messages, Message{ + Role: "user", + Content: t.config.Chunk.Prompt, + }) + } + + // Send request + translation, err := t.sendChatRequest(messages) + if err != nil { + return "", fmt.Errorf("error translating chunk %d: %w", i+1, err) + } + + translations = append(translations, translation) + + // Update progress bar + progressBar.Add(1) + + // Slide context window if needed + if i >= t.config.Chunk.Context { + contextStart++ + } + } + + // Finish and clear the progress bar + progressBar.Finish() + + // Combine translations + return strings.Join(translations, "\n"), nil +} + +// RemoveProgressBar removes a progress bar by its ID +func (t *Translator) RemoveProgressBar(id string) { + t.mutex.Lock() + defer t.mutex.Unlock() + delete(t.progressBars, id) +} + +// sendChatRequest sends a chat completion request to the OpenAI API +func (t *Translator) sendChatRequest(messages []Message) (string, error) { // Create the API request requestBody := ChatRequest{ Model: t.config.Model, @@ -106,6 +241,40 @@ func (t *Translator) Translate(content, sourceLanguage, targetLanguage string) ( return responseBody.Choices[0].Message.Content, nil } +// splitIntoChunks splits content into chunks based on token size +func (t *Translator) splitIntoChunks(content string) ([]string, error) { + lines := strings.Split(content, "\n") + var chunks []string + var currentChunk strings.Builder + currentTokens := 0 + + for _, line := range lines { + // Estimate tokens in the line (approximation: ~4 characters per token) + lineTokens := int(math.Ceil(float64(utf8.RuneCountInString(line)) / 4.0)) + + // If adding this line would exceed chunk size, start a new chunk + if currentTokens > 0 && currentTokens+lineTokens > t.config.Chunk.Size { + chunks = append(chunks, currentChunk.String()) + 
currentChunk.Reset() + currentTokens = 0 + } + + // Add line to current chunk + if currentChunk.Len() > 0 { + currentChunk.WriteString("\n") + } + currentChunk.WriteString(line) + currentTokens += lineTokens + } + + // Add the last chunk if it's not empty + if currentChunk.Len() > 0 { + chunks = append(chunks, currentChunk.String()) + } + + return chunks, nil +} + // getSystemPrompt constructs the system prompt for translation func (t *Translator) getSystemPrompt(sourceLanguage, targetLanguage string) string { basePrompt := t.config.SystemPrompt