296 lines
8 KiB
Go
296 lines
8 KiB
Go
package translator
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"math"
	"net/http"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/schollz/progressbar/v3"
	"github.com/wholetrans/d10n/config"
)
|
|
|
|
// Translator handles communication with the OpenAI API
|
|
type Translator struct {
|
|
config *config.Config
|
|
mutex sync.Mutex
|
|
progressBars map[string]*progressbar.ProgressBar
|
|
}
|
|
|
|
// Message represents a message in the OpenAI chat format
|
|
type Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
}
|
|
|
|
// ChatRequest represents a request to the OpenAI chat completion API
|
|
type ChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []Message `json:"messages"`
|
|
}
|
|
|
|
// ChatResponse represents a response from the OpenAI chat completion API
|
|
type ChatResponse struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
} `json:"choices"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
// NewTranslator creates a new translator
|
|
func NewTranslator(cfg *config.Config) *Translator {
|
|
return &Translator{
|
|
config: cfg,
|
|
progressBars: make(map[string]*progressbar.ProgressBar),
|
|
}
|
|
}
|
|
|
|
// GetConcurrency returns the configured concurrency level
|
|
func (t *Translator) GetConcurrency() int {
|
|
return t.config.Concurrency
|
|
}
|
|
|
|
// Translate translates content from sourceLanguage to targetLanguage
|
|
func (t *Translator) Translate(content, sourceLanguage, targetLanguage string, filePath string) (string, error) {
|
|
// If chunking is not enabled, translate the whole content at once
|
|
if !t.config.Chunk.Enabled {
|
|
return t.translateSingle(content, sourceLanguage, targetLanguage)
|
|
}
|
|
|
|
// Split content into chunks
|
|
chunks, err := t.splitIntoChunks(content)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error splitting content into chunks: %w", err)
|
|
}
|
|
|
|
// If there's only one chunk, translate it directly
|
|
if len(chunks) == 1 {
|
|
return t.translateSingle(chunks[0], sourceLanguage, targetLanguage)
|
|
}
|
|
|
|
return t.translateChunks(chunks, sourceLanguage, targetLanguage, filePath)
|
|
}
|
|
|
|
// translateSingle translates a single content block
|
|
func (t *Translator) translateSingle(content, sourceLanguage, targetLanguage string) (string, error) {
|
|
messages := []Message{
|
|
{
|
|
Role: "system",
|
|
Content: t.getSystemPrompt(sourceLanguage, targetLanguage),
|
|
},
|
|
{
|
|
Role: "user",
|
|
Content: content,
|
|
},
|
|
}
|
|
|
|
return t.sendChatRequest(messages)
|
|
}
|
|
|
|
// translateChunks translates content in chunks with context
|
|
func (t *Translator) translateChunks(chunks []string, sourceLanguage, targetLanguage, filePath string) (string, error) {
|
|
systemPrompt := t.getSystemPrompt(sourceLanguage, targetLanguage)
|
|
var translations []string
|
|
contextStart := 0
|
|
|
|
// Create a progress bar for this file if it doesn't exist
|
|
t.mutex.Lock()
|
|
progressBarID := filePath
|
|
if _, exists := t.progressBars[progressBarID]; !exists {
|
|
t.progressBars[progressBarID] = progressbar.NewOptions(len(chunks),
|
|
progressbar.OptionSetDescription(fmt.Sprintf("[Chunks: %s]", filePath)),
|
|
progressbar.OptionShowCount(),
|
|
progressbar.OptionSetTheme(progressbar.Theme{
|
|
Saucer: "=",
|
|
SaucerHead: ">",
|
|
SaucerPadding: " ",
|
|
BarStart: "[",
|
|
BarEnd: "]",
|
|
}),
|
|
)
|
|
}
|
|
progressBar := t.progressBars[progressBarID]
|
|
t.mutex.Unlock()
|
|
|
|
// Process each chunk
|
|
for i, chunk := range chunks {
|
|
// Build messages with context
|
|
messages := []Message{
|
|
{
|
|
Role: "system",
|
|
Content: systemPrompt,
|
|
},
|
|
}
|
|
|
|
// Add context chunks and their translations
|
|
for j := contextStart; j < i; j++ {
|
|
if j >= 0 && j < len(chunks) {
|
|
messages = append(messages, Message{
|
|
Role: "user",
|
|
Content: chunks[j],
|
|
})
|
|
|
|
if j-contextStart < len(translations) {
|
|
messages = append(messages, Message{
|
|
Role: "assistant",
|
|
Content: translations[j-contextStart],
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add current chunk to translate
|
|
messages = append(messages, Message{
|
|
Role: "user",
|
|
Content: chunk,
|
|
})
|
|
|
|
// For chunks after the first one, add the continuation prompt
|
|
if i > 0 {
|
|
messages = append(messages, Message{
|
|
Role: "user",
|
|
Content: t.config.Chunk.Prompt,
|
|
})
|
|
}
|
|
|
|
// Send request
|
|
translation, err := t.sendChatRequest(messages)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error translating chunk %d: %w", i+1, err)
|
|
}
|
|
|
|
translations = append(translations, translation)
|
|
|
|
// Update progress bar
|
|
progressBar.Add(1)
|
|
|
|
// Slide context window if needed
|
|
if i >= t.config.Chunk.Context {
|
|
contextStart++
|
|
}
|
|
}
|
|
|
|
// Finish and clear the progress bar
|
|
progressBar.Finish()
|
|
|
|
// Combine translations
|
|
return strings.Join(translations, "\n"), nil
|
|
}
|
|
|
|
// RemoveProgressBar removes a progress bar by its ID
|
|
func (t *Translator) RemoveProgressBar(id string) {
|
|
t.mutex.Lock()
|
|
defer t.mutex.Unlock()
|
|
delete(t.progressBars, id)
|
|
}
|
|
|
|
// sendChatRequest sends a chat completion request to the OpenAI API
|
|
func (t *Translator) sendChatRequest(messages []Message) (string, error) {
|
|
// Create the API request
|
|
requestBody := ChatRequest{
|
|
Model: t.config.Model,
|
|
Messages: messages,
|
|
}
|
|
|
|
jsonData, err := json.Marshal(requestBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error marshaling request: %w", err)
|
|
}
|
|
|
|
// Create HTTP request
|
|
req, err := http.NewRequest("POST", t.config.APIBase+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return "", fmt.Errorf("error creating HTTP request: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+t.config.APIKey)
|
|
|
|
// Send the request
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("error sending request to API: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Parse the response
|
|
var responseBody ChatResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&responseBody); err != nil {
|
|
return "", fmt.Errorf("error parsing API response: %w", err)
|
|
}
|
|
|
|
// Check for API errors
|
|
if responseBody.Error != nil {
|
|
return "", fmt.Errorf("API error: %s", responseBody.Error.Message)
|
|
}
|
|
|
|
// Check if we have any choices
|
|
if len(responseBody.Choices) == 0 {
|
|
return "", fmt.Errorf("API returned no translation")
|
|
}
|
|
|
|
return responseBody.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
// splitIntoChunks splits content into chunks based on token size
|
|
func (t *Translator) splitIntoChunks(content string) ([]string, error) {
|
|
lines := strings.Split(content, "\n")
|
|
var chunks []string
|
|
var currentChunk strings.Builder
|
|
currentTokens := 0
|
|
|
|
for _, line := range lines {
|
|
// Estimate tokens in the line (approximation: ~4 characters per token)
|
|
lineTokens := int(math.Ceil(float64(utf8.RuneCountInString(line)) / 4.0))
|
|
|
|
// If adding this line would exceed chunk size, start a new chunk
|
|
if currentTokens > 0 && currentTokens+lineTokens > t.config.Chunk.Size {
|
|
chunks = append(chunks, currentChunk.String())
|
|
currentChunk.Reset()
|
|
currentTokens = 0
|
|
}
|
|
|
|
// Add line to current chunk
|
|
if currentChunk.Len() > 0 {
|
|
currentChunk.WriteString("\n")
|
|
}
|
|
currentChunk.WriteString(line)
|
|
currentTokens += lineTokens
|
|
}
|
|
|
|
// Add the last chunk if it's not empty
|
|
if currentChunk.Len() > 0 {
|
|
chunks = append(chunks, currentChunk.String())
|
|
}
|
|
|
|
return chunks, nil
|
|
}
|
|
|
|
// getSystemPrompt constructs the system prompt for translation
|
|
func (t *Translator) getSystemPrompt(sourceLanguage, targetLanguage string) string {
|
|
basePrompt := t.config.SystemPrompt
|
|
|
|
// Replace variables in the system prompt
|
|
basePrompt = strings.ReplaceAll(basePrompt, "$SOURCE_LANG", sourceLanguage)
|
|
basePrompt = strings.ReplaceAll(basePrompt, "$TARGET_LANG", targetLanguage)
|
|
|
|
// If the source language is specified, include it in the prompt
|
|
sourceLangStr := ""
|
|
if sourceLanguage != "" {
|
|
sourceLangStr = fmt.Sprintf(" from %s", sourceLanguage)
|
|
}
|
|
|
|
// Append the translation instruction to the base prompt
|
|
translationInstruction := fmt.Sprintf("\nTranslate the following text%s to %s. Only output the translated text, without any explanations or additional content.", sourceLangStr, targetLanguage)
|
|
|
|
return basePrompt + translationInstruction
|
|
}
|