d10n/translator/translator.go

package translator

import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"net/http"
	"strings"
	"sync"
	"unicode/utf8"

	"github.com/schollz/progressbar/v3"
	"github.com/wholetrans/d10n/config"
)

// Translator handles communication with the OpenAI API.
type Translator struct {
	config       *config.Config
	mutex        sync.Mutex
	progressBars map[string]*progressbar.ProgressBar
}

// Message represents a message in the OpenAI chat format.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ChatRequest represents a request to the OpenAI chat completion API.
type ChatRequest struct {
	Model    string    `json:"model"`
	Messages []Message `json:"messages"`
}

// ChatResponse represents a response from the OpenAI chat completion API.
type ChatResponse struct {
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
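
// A successful response decodes from JSON of roughly this form; any fields
// not modeled above are simply ignored by encoding/json:
//
//	{"choices":[{"message":{"content":"<translated text>"}}]}
//
// while an error response carries {"error":{"message":"..."}} instead.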

// NewTranslator creates a new translator.
func NewTranslator(cfg *config.Config) *Translator {
	return &Translator{
		config:       cfg,
		progressBars: make(map[string]*progressbar.ProgressBar),
	}
}
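
// Example usage (an illustrative sketch: only fields referenced in this file
// are shown, and constructing config.Config directly like this is an
// assumption about the config package):
//
//	cfg := &config.Config{
//		Model:   "gpt-4o-mini",
//		APIBase: "https://api.openai.com",
//		APIKey:  os.Getenv("OPENAI_API_KEY"),
//	}
//	tr := NewTranslator(cfg)
//	out, err := tr.Translate(doc, "English", "Spanish", "docs/guide.md")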

// GetConcurrency returns the configured concurrency level.
func (t *Translator) GetConcurrency() int {
	return t.config.Concurrency
}

// Translate translates content from sourceLanguage to targetLanguage.
func (t *Translator) Translate(content, sourceLanguage, targetLanguage string, filePath string) (string, error) {
	// If chunking is not enabled, translate the whole content at once.
	if !t.config.Chunk.Enabled {
		return t.translateSingle(content, sourceLanguage, targetLanguage)
	}

	// Split content into chunks.
	chunks, err := t.splitIntoChunks(content)
	if err != nil {
		return "", fmt.Errorf("error splitting content into chunks: %w", err)
	}

	// If there's only one chunk, translate it directly.
	if len(chunks) == 1 {
		return t.translateSingle(chunks[0], sourceLanguage, targetLanguage)
	}

	return t.translateChunks(chunks, sourceLanguage, targetLanguage, filePath)
}
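
// With chunking disabled the whole document goes out as a single request;
// with chunking enabled the document is split by splitIntoChunks and each
// piece is translated sequentially in translateChunks, carrying up to
// Chunk.Context preceding chunks along as conversation history.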

// translateSingle translates a single content block.
func (t *Translator) translateSingle(content, sourceLanguage, targetLanguage string) (string, error) {
	messages := []Message{
		{
			Role:    "system",
			Content: t.getSystemPrompt(sourceLanguage, targetLanguage),
		},
		{
			Role:    "user",
			Content: content,
		},
	}
	return t.sendChatRequest(messages)
}

// translateChunks translates content in chunks with context.
func (t *Translator) translateChunks(chunks []string, sourceLanguage, targetLanguage, filePath string) (string, error) {
	systemPrompt := t.getSystemPrompt(sourceLanguage, targetLanguage)
	var translations []string
	contextStart := 0

	// Create a progress bar for this file if it doesn't exist.
	t.mutex.Lock()
	progressBarID := filePath
	if _, exists := t.progressBars[progressBarID]; !exists {
		t.progressBars[progressBarID] = progressbar.NewOptions(len(chunks),
			progressbar.OptionSetDescription(fmt.Sprintf("[Chunks: %s]", filePath)),
			progressbar.OptionShowCount(),
			progressbar.OptionSetTheme(progressbar.Theme{
				Saucer:        "=",
				SaucerHead:    ">",
				SaucerPadding: " ",
				BarStart:      "[",
				BarEnd:        "]",
			}),
		)
	}
	progressBar := t.progressBars[progressBarID]
	t.mutex.Unlock()

	// Process each chunk.
	for i, chunk := range chunks {
		// Build messages with context.
		messages := []Message{
			{
				Role:    "system",
				Content: systemPrompt,
			},
		}

		// Add the preceding context chunks, each paired with its own
		// translation, as prior user/assistant turns. translations[j] holds
		// the translation of chunks[j], since one translation is appended per
		// completed chunk.
		for j := contextStart; j < i; j++ {
			messages = append(messages, Message{
				Role:    "user",
				Content: chunks[j],
			})
			if j < len(translations) {
				messages = append(messages, Message{
					Role:    "assistant",
					Content: translations[j],
				})
			}
		}

		// Add the current chunk to translate.
		messages = append(messages, Message{
			Role:    "user",
			Content: chunk,
		})

		// For chunks after the first one, add the continuation prompt.
		if i > 0 {
			messages = append(messages, Message{
				Role:    "user",
				Content: t.config.Chunk.Prompt,
			})
		}

		// Send the request.
		translation, err := t.sendChatRequest(messages)
		if err != nil {
			return "", fmt.Errorf("error translating chunk %d: %w", i+1, err)
		}
		translations = append(translations, translation)

		// Update the progress bar.
		progressBar.Add(1)

		// Slide the context window forward once it spans Chunk.Context chunks.
		if i >= t.config.Chunk.Context {
			contextStart++
		}
	}

	// Finish the progress bar.
	progressBar.Finish()

	// Combine the translations.
	return strings.Join(translations, "\n"), nil
}
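
// For chunk i the request sent to the model is, in order: the system prompt,
// then up to Chunk.Context preceding chunks each followed by its translation
// as a user/assistant turn, then chunk i itself, and finally the continuation
// prompt (Chunk.Prompt) for every chunk after the first.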

// RemoveProgressBar removes a progress bar by its ID.
func (t *Translator) RemoveProgressBar(id string) {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	delete(t.progressBars, id)
}

// sendChatRequest sends a chat completion request to the OpenAI API.
func (t *Translator) sendChatRequest(messages []Message) (string, error) {
	// Create the API request body.
	requestBody := ChatRequest{
		Model:    t.config.Model,
		Messages: messages,
	}
	jsonData, err := json.Marshal(requestBody)
	if err != nil {
		return "", fmt.Errorf("error marshaling request: %w", err)
	}

	// Create the HTTP request.
	req, err := http.NewRequest("POST", t.config.APIBase+"/v1/chat/completions", bytes.NewBuffer(jsonData))
	if err != nil {
		return "", fmt.Errorf("error creating HTTP request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+t.config.APIKey)

	// Send the request.
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return "", fmt.Errorf("error sending request to API: %w", err)
	}
	defer resp.Body.Close()

	// Parse the response.
	var responseBody ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&responseBody); err != nil {
		return "", fmt.Errorf("error parsing API response: %w", err)
	}

	// Check for API errors.
	if responseBody.Error != nil {
		return "", fmt.Errorf("API error: %s", responseBody.Error.Message)
	}

	// Check if we have any choices.
	if len(responseBody.Choices) == 0 {
		return "", fmt.Errorf("API returned no translation")
	}

	return responseBody.Choices[0].Message.Content, nil
}
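
// The serialized request body follows the shape defined by ChatRequest, e.g.
// (model name is whatever is configured):
//
//	{"model":"<configured model>","messages":[{"role":"system","content":"..."},{"role":"user","content":"..."}]}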

// splitIntoChunks splits content into chunks based on estimated token size.
func (t *Translator) splitIntoChunks(content string) ([]string, error) {
	lines := strings.Split(content, "\n")

	var chunks []string
	var currentChunk strings.Builder
	currentTokens := 0

	for _, line := range lines {
		// Estimate tokens in the line (approximation: ~4 characters per token).
		lineTokens := int(math.Ceil(float64(utf8.RuneCountInString(line)) / 4.0))

		// If adding this line would exceed the chunk size, start a new chunk.
		if currentTokens > 0 && currentTokens+lineTokens > t.config.Chunk.Size {
			chunks = append(chunks, currentChunk.String())
			currentChunk.Reset()
			currentTokens = 0
		}

		// Add the line to the current chunk.
		if currentChunk.Len() > 0 {
			currentChunk.WriteString("\n")
		}
		currentChunk.WriteString(line)
		currentTokens += lineTokens
	}

	// Add the last chunk if it's not empty.
	if currentChunk.Len() > 0 {
		chunks = append(chunks, currentChunk.String())
	}

	return chunks, nil
}
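
// For example, a line of 90 runes is estimated at ceil(90/4) = 23 tokens, so
// with Chunk.Size = 1000 roughly 43 such lines fit before a new chunk starts.
// Note that individual lines are never split, so a single line longer than
// Chunk.Size yields a chunk that exceeds the configured size.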

// getSystemPrompt constructs the system prompt for translation.
func (t *Translator) getSystemPrompt(sourceLanguage, targetLanguage string) string {
	basePrompt := t.config.SystemPrompt

	// Replace variables in the system prompt.
	basePrompt = strings.ReplaceAll(basePrompt, "$SOURCE_LANG", sourceLanguage)
	basePrompt = strings.ReplaceAll(basePrompt, "$TARGET_LANG", targetLanguage)

	// If the source language is specified, include it in the prompt.
	sourceLangStr := ""
	if sourceLanguage != "" {
		sourceLangStr = fmt.Sprintf(" from %s", sourceLanguage)
	}

	// Append the translation instruction to the base prompt.
	translationInstruction := fmt.Sprintf("\nTranslate the following text%s to %s. Only output the translated text, without any explanations or additional content.", sourceLangStr, targetLanguage)
	return basePrompt + translationInstruction
}
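
// For example, with SystemPrompt = "You are a translator for $TARGET_LANG docs."
// (an illustrative value), sourceLanguage = "English" and targetLanguage =
// "Korean", the returned prompt is:
//
//	You are a translator for Korean docs.
//	Translate the following text from English to Korean. Only output the
//	translated text, without any explanations or additional content.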