feat: chunking + progress bar
This commit is contained in:
parent
0277157919
commit
68391cf532
6 changed files with 339 additions and 17 deletions
109
main.go
109
main.go
|
@ -6,7 +6,9 @@ import (
|
|||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/wholetrans/d10n/config"
|
||||
"github.com/wholetrans/d10n/translator"
|
||||
)
|
||||
|
@ -32,6 +34,15 @@ func main() {
|
|||
systemPromptPtr := flagSet.String("system-prompt", "", "System prompt for the model")
|
||||
formatPtr := flagSet.String("format", "", "File format to process (e.g., md, txt)")
|
||||
|
||||
// Concurrency flag
|
||||
concurrencyPtr := flagSet.Int("concurrency", 0, "Number of concurrent translation tasks (default is 3)")
|
||||
|
||||
// Chunking flags
|
||||
chunkEnabledPtr := flagSet.Bool("chunk", false, "Enable chunked translation")
|
||||
chunkSizePtr := flagSet.Int("chunk-size", 0, "Size of each chunk in tokens (default is 10240)")
|
||||
chunkPromptPtr := flagSet.String("chunk-prompt", "", "Prompt to use for continuing translation (default is 'Please continue translation')")
|
||||
chunkContextPtr := flagSet.Int("chunk-context", 0, "Number of chunks to include as context (default is 2)")
|
||||
|
||||
// Parse flags
|
||||
if err := flagSet.Parse(os.Args[2:]); err != nil {
|
||||
fmt.Println("Error parsing arguments:", err)
|
||||
|
@ -75,6 +86,28 @@ func main() {
|
|||
cfg.SystemPrompt = *systemPromptPtr
|
||||
}
|
||||
|
||||
// Set concurrency options
|
||||
if *concurrencyPtr > 0 {
|
||||
cfg.Concurrency = *concurrencyPtr
|
||||
}
|
||||
|
||||
// Set chunking options from command line arguments
|
||||
if *chunkEnabledPtr {
|
||||
cfg.Chunk.Enabled = true
|
||||
}
|
||||
|
||||
if *chunkSizePtr > 0 {
|
||||
cfg.Chunk.Size = *chunkSizePtr
|
||||
}
|
||||
|
||||
if *chunkPromptPtr != "" {
|
||||
cfg.Chunk.Prompt = *chunkPromptPtr
|
||||
}
|
||||
|
||||
if *chunkContextPtr > 0 {
|
||||
cfg.Chunk.Context = *chunkContextPtr
|
||||
}
|
||||
|
||||
// Set target path if not provided
|
||||
targetPath := *targetPathPtr
|
||||
if targetPath == "" {
|
||||
|
@ -118,6 +151,11 @@ func printUsage() {
|
|||
fmt.Println(" -api-base <url> API base URL for OpenAI compatible service")
|
||||
fmt.Println(" -system-prompt <prompt> System prompt for the model")
|
||||
fmt.Println(" -format <ext> File format to process (e.g., md, txt)")
|
||||
fmt.Println(" -concurrency <num> Number of concurrent translation tasks (default: 3)")
|
||||
fmt.Println(" -chunk Enable chunked translation")
|
||||
fmt.Println(" -chunk-size <tokens> Size of each chunk in tokens (default: 10240)")
|
||||
fmt.Println(" -chunk-prompt <prompt> Prompt for continuing translation (default: 'Please continue translation')")
|
||||
fmt.Println(" -chunk-context <num> Number of chunks to include as context (default: 2)")
|
||||
}
|
||||
|
||||
func processFile(sourcePath, targetPath, sourceLanguage, targetLanguage string, trans *translator.Translator) error {
|
||||
|
@ -126,7 +164,7 @@ func processFile(sourcePath, targetPath, sourceLanguage, targetLanguage string,
|
|||
return fmt.Errorf("error reading file %s: %w", sourcePath, err)
|
||||
}
|
||||
|
||||
translatedContent, err := trans.Translate(string(content), sourceLanguage, targetLanguage)
|
||||
translatedContent, err := trans.Translate(string(content), sourceLanguage, targetLanguage, sourcePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error translating file %s: %w", sourcePath, err)
|
||||
}
|
||||
|
@ -142,7 +180,9 @@ func processFile(sourcePath, targetPath, sourceLanguage, targetLanguage string,
|
|||
return fmt.Errorf("error writing to file %s: %w", targetPath, err)
|
||||
}
|
||||
|
||||
fmt.Printf("Translated %s to %s\n", sourcePath, targetPath)
|
||||
// Remove the progress bar for this file
|
||||
trans.RemoveProgressBar(sourcePath)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -185,7 +225,29 @@ func processDirectory(sourcePath, targetPath, sourceLanguage, targetLanguage, fo
|
|||
}
|
||||
fmt.Println()
|
||||
|
||||
// Process each matched file with progress updates
|
||||
// Create overall progress bar
|
||||
overallBar := progressbar.NewOptions(len(matchedFiles),
|
||||
progressbar.OptionSetDescription("[Overall Progress]"),
|
||||
progressbar.OptionShowCount(),
|
||||
progressbar.OptionSetTheme(progressbar.Theme{
|
||||
Saucer: "#",
|
||||
SaucerHead: ">",
|
||||
SaucerPadding: "-",
|
||||
BarStart: "[",
|
||||
BarEnd: "]",
|
||||
}),
|
||||
)
|
||||
|
||||
// Set up a worker pool for concurrent processing
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
mutex sync.Mutex
|
||||
errors []error
|
||||
concurrency = trans.GetConcurrency()
|
||||
semaphore = make(chan struct{}, concurrency)
|
||||
)
|
||||
|
||||
// Process files concurrently
|
||||
for i, path := range matchedFiles {
|
||||
// Compute relative path from source
|
||||
relPath, err := filepath.Rel(sourcePath, path)
|
||||
|
@ -202,14 +264,39 @@ func processDirectory(sourcePath, targetPath, sourceLanguage, targetLanguage, fo
|
|||
return fmt.Errorf("error creating target directory %s: %w", targetFileDir, err)
|
||||
}
|
||||
|
||||
// Display progress
|
||||
fmt.Printf("[%d/%d] Translating: %s\n", i+1, len(matchedFiles), relPath)
|
||||
|
||||
// Process individual file
|
||||
err = processFile(path, targetFilePath, sourceLanguage, targetLanguage, trans)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Process individual file concurrently
|
||||
wg.Add(1)
|
||||
go func(idx int, sourcePath, targetPath, relPath string) {
|
||||
defer wg.Done()
|
||||
|
||||
// Acquire semaphore (limit concurrency)
|
||||
semaphore <- struct{}{}
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
// Process the file
|
||||
err := processFile(sourcePath, targetPath, sourceLanguage, targetLanguage, trans)
|
||||
|
||||
// Update overall progress bar
|
||||
mutex.Lock()
|
||||
overallBar.Add(1)
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
fmt.Printf("Error translating %s: %v\n", relPath, err)
|
||||
}
|
||||
mutex.Unlock()
|
||||
}(i+1, path, targetFilePath, relPath)
|
||||
}
|
||||
|
||||
// Wait for all translations to complete
|
||||
wg.Wait()
|
||||
|
||||
// Complete the overall progress bar
|
||||
overallBar.Finish()
|
||||
fmt.Println() // Add some spacing after progress bars
|
||||
|
||||
// Check if any errors occurred
|
||||
if len(errors) > 0 {
|
||||
return fmt.Errorf("encountered %d errors during translation", len(errors))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue