1use std::collections::{BTreeMap, BTreeSet};
7use std::sync::Arc;
8
9use serde_json::json;
10
11use crate::Result;
12use crate::core::formats::Subtitle;
13use crate::core::formats::manager::FormatManager;
14use crate::core::translation::CueIdGenerator;
15use crate::core::translation::request::{
16 GlossaryEntry, TerminologyMap, TranslationOutcome, TranslationRequest, TranslationResult,
17 merge_terminology,
18};
19use crate::error::SubXError;
20use crate::services::ai::AIProvider;
21use crate::services::ai::translation_prompts::{
22 TERMINOLOGY_SYSTEM_MESSAGE, TRANSLATION_SYSTEM_MESSAGE, build_terminology_prompt,
23 build_translation_prompt, is_unknown_cue_id_error, parse_terminology_response,
24 parse_translation_response_partial,
25};
26
/// Translates subtitle cues through an [`AIProvider`], a fixed number of cues per request.
pub struct TranslationEngine {
    /// Provider whose `chat_completion` receives the terminology and translation prompts.
    ai_provider: Arc<dyn AIProvider>,
    /// Parser used by `translate_content` (`parse_auto`); also exposed via `format_manager()`.
    format_manager: FormatManager,
    /// Upper bound on the number of cues sent in one translation batch; `new` rejects 0.
    batch_size: usize,
}
39
40impl std::fmt::Debug for TranslationEngine {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 f.debug_struct("TranslationEngine")
43 .field("batch_size", &self.batch_size)
44 .finish()
45 }
46}
47
48impl TranslationEngine {
49 pub fn new(ai_provider: Arc<dyn AIProvider>, batch_size: usize) -> Result<Self> {
55 if batch_size == 0 {
56 return Err(SubXError::config(
57 "Translation batch size must be greater than 0",
58 ));
59 }
60 Ok(Self {
61 ai_provider,
62 format_manager: FormatManager::new(),
63 batch_size,
64 })
65 }
66
67 pub fn batch_size(&self) -> usize {
69 self.batch_size
70 }
71
72 pub fn format_manager(&self) -> &FormatManager {
74 &self.format_manager
75 }
76
77 pub async fn translate_subtitle(
93 &self,
94 subtitle: Subtitle,
95 request: &TranslationRequest,
96 ) -> Result<TranslationResult> {
97 if request.target_language.trim().is_empty() {
98 return Err(SubXError::config(
99 "Translation target language must be provided",
100 ));
101 }
102
103 let mut subtitle = subtitle;
104 if subtitle.entries.is_empty() {
105 return Ok(TranslationResult {
106 subtitle,
107 outcome: TranslationOutcome::default(),
108 });
109 }
110
111 let mut id_gen = CueIdGenerator::new();
113 let cue_ids: Vec<String> = subtitle
114 .entries
115 .iter()
116 .map(|_| id_gen.next_id().to_string())
117 .collect();
118 let protected_cues: Vec<ProtectedCueText> = subtitle
119 .entries
120 .iter()
121 .enumerate()
122 .map(|(idx, entry)| protect_inline_formatting(&entry.text, idx))
123 .collect();
124 let terminology_texts: Vec<String> = protected_cues
125 .iter()
126 .map(|cue| cue.visible_text.clone())
127 .collect();
128
129 let generated_terms = self
131 .extract_terminology(&terminology_texts, request)
132 .await?;
133 let effective_terminology = merge_terminology(generated_terms, &request.glossary_entries);
134
135 let mut translations: BTreeMap<String, String> = BTreeMap::new();
138 let mut batch_count = 0usize;
139 for chunk_indices in chunk_ranges(subtitle.entries.len(), self.batch_size) {
140 let mut batch_cues: Vec<(String, String)> = Vec::with_capacity(chunk_indices.len());
141 let mut batch_ids: Vec<String> = Vec::with_capacity(chunk_indices.len());
142 for &i in &chunk_indices {
143 batch_cues.push((cue_ids[i].clone(), protected_cues[i].prompt_text.clone()));
144 batch_ids.push(cue_ids[i].clone());
145 }
146
147 let (map, issued_batches) = self
148 .translate_batch_with_unknown_retry(
149 &batch_cues,
150 &batch_ids,
151 request,
152 &effective_terminology,
153 )
154 .await?;
155 for (id, text) in map {
156 translations.insert(id, text);
157 }
158 batch_count += issued_batches;
159 log_translation_progress(translations.len(), cue_ids.len());
160 }
161
162 let mut empty_fallback_ids = BTreeSet::new();
163 let missing_after_initial = missing_translation_indices(&cue_ids, &translations);
164 if !missing_after_initial.is_empty() {
165 let (retry_map, issued_batches) = self
166 .retry_missing_translations(
167 &cue_ids,
168 &protected_cues,
169 &missing_after_initial,
170 request,
171 &effective_terminology,
172 )
173 .await?;
174 for (id, text) in retry_map {
175 translations.insert(id, text);
176 }
177 batch_count += issued_batches;
178
179 for idx in missing_translation_indices(&cue_ids, &translations) {
180 let id = cue_ids[idx].clone();
181 translations.insert(id.clone(), String::new());
182 empty_fallback_ids.insert(id);
183 }
184 log_translation_progress(translations.len(), cue_ids.len());
185 }
186
187 for ((entry, id), protected) in subtitle
190 .entries
191 .iter_mut()
192 .zip(cue_ids.iter())
193 .zip(protected_cues.iter())
194 {
195 if let Some(translated) = translations.get(id) {
196 if empty_fallback_ids.contains(id) {
197 entry.text = String::new();
198 } else {
199 entry.text = restore_inline_formatting(translated, protected)?;
200 }
201 }
202 }
203
204 let translated_cue_count = subtitle.entries.len();
205 Ok(TranslationResult {
206 subtitle,
207 outcome: TranslationOutcome {
208 effective_terminology,
209 translated_cue_count,
210 batch_count,
211 },
212 })
213 }
214
215 pub async fn translate_content(
221 &self,
222 content: &str,
223 request: &TranslationRequest,
224 ) -> Result<TranslationResult> {
225 let subtitle = self.format_manager.parse_auto(content)?;
226 self.translate_subtitle(subtitle, request).await
227 }
228
229 pub async fn extract_terminology(
234 &self,
235 cue_texts: &[String],
236 request: &TranslationRequest,
237 ) -> Result<TerminologyMap> {
238 let prompt = build_terminology_prompt(
239 &request.target_language,
240 request.source_language.as_deref(),
241 cue_texts,
242 request.glossary_text.as_deref(),
243 request.context.as_deref(),
244 );
245 let messages = vec![
246 json!({"role": "system", "content": TERMINOLOGY_SYSTEM_MESSAGE}),
247 json!({"role": "user", "content": prompt}),
248 ];
249 let response = self.ai_provider.chat_completion(messages).await?;
250 parse_terminology_response(&response)
251 }
252
253 async fn retry_missing_translations(
254 &self,
255 cue_ids: &[String],
256 protected_cues: &[ProtectedCueText],
257 missing_indices: &[usize],
258 request: &TranslationRequest,
259 terminology: &TerminologyMap,
260 ) -> Result<(BTreeMap<String, String>, usize)> {
261 let mut retry_cues = Vec::with_capacity(missing_indices.len());
262 let mut retry_ids = Vec::with_capacity(missing_indices.len());
263 for &idx in missing_indices {
264 retry_cues.push((
265 cue_ids[idx].clone(),
266 protected_cues[idx].prompt_text.clone(),
267 ));
268 retry_ids.push(cue_ids[idx].clone());
269 }
270
271 self.translate_batch_with_unknown_retry(&retry_cues, &retry_ids, request, terminology)
272 .await
273 }
274
275 async fn translate_batch_with_unknown_retry(
276 &self,
277 batch_cues: &[(String, String)],
278 batch_ids: &[String],
279 request: &TranslationRequest,
280 terminology: &TerminologyMap,
281 ) -> Result<(BTreeMap<String, String>, usize)> {
282 match self
283 .translate_batch_once(batch_cues, batch_ids, request, terminology)
284 .await
285 {
286 Ok(map) => Ok((map, 1)),
287 Err(err) if is_unknown_cue_id_error(&err) => {
288 if !crate::cli::output::is_quiet() && !crate::cli::output::active_mode().is_json() {
292 eprintln!(
293 "⚠ Translation response contained an unknown cue ID; discarding the batch response and retrying once."
294 );
295 }
296 match self
297 .translate_batch_once(batch_cues, batch_ids, request, terminology)
298 .await
299 {
300 Ok(map) => Ok((map, 2)),
301 Err(retry_err) if is_unknown_cue_id_error(&retry_err) => {
302 Err(SubXError::ai_service(format!(
303 "Translation response still contained an unknown cue ID after retry; failing this file: {retry_err}"
304 )))
305 }
306 Err(retry_err) => Err(retry_err),
307 }
308 }
309 Err(err) => Err(err),
310 }
311 }
312
313 async fn translate_batch_once(
314 &self,
315 batch_cues: &[(String, String)],
316 batch_ids: &[String],
317 request: &TranslationRequest,
318 terminology: &TerminologyMap,
319 ) -> Result<BTreeMap<String, String>> {
320 let prompt = build_translation_prompt(
321 &request.target_language,
322 request.source_language.as_deref(),
323 terminology,
324 request.glossary_text.as_deref(),
325 request.context.as_deref(),
326 batch_cues,
327 );
328 let messages = vec![
329 json!({"role": "system", "content": TRANSLATION_SYSTEM_MESSAGE}),
330 json!({"role": "user", "content": prompt}),
331 ];
332 let response = self.ai_provider.chat_completion(messages).await?;
333 Ok(parse_translation_response_partial(&response, batch_ids)?
334 .into_iter()
335 .collect())
336 }
337}
338
339pub fn parse_glossary_text(text: &str) -> Vec<GlossaryEntry> {
346 let mut out = Vec::new();
347 for raw_line in text.lines() {
348 let line = raw_line.trim();
349 if line.is_empty() || line.starts_with('#') {
350 continue;
351 }
352 let separator = if line.contains("->") {
353 "->"
354 } else if line.contains('=') {
355 "="
356 } else {
357 continue;
358 };
359 let mut parts = line.splitn(2, separator);
360 let source = parts.next().map(str::trim).unwrap_or("").to_string();
361 let target = parts.next().map(str::trim).unwrap_or("").to_string();
362 if source.is_empty() || target.is_empty() {
363 continue;
364 }
365 out.push(GlossaryEntry { source, target });
366 }
367 out
368}
369
/// Splits the indices `0..total` into consecutive groups of at most `batch_size`.
///
/// Only the last group may be shorter. `total == 0` yields no groups.
/// `batch_size` must be non-zero: with 0 the cursor never advances and the
/// function does not terminate.
fn chunk_ranges(total: usize, batch_size: usize) -> Vec<Vec<usize>> {
    let mut cursor = 0usize;
    std::iter::from_fn(|| {
        if cursor >= total {
            return None;
        }
        let upper = total.min(cursor + batch_size);
        let batch: Vec<usize> = (cursor..upper).collect();
        cursor = upper;
        Some(batch)
    })
    .collect()
}
380
/// Returns, in ascending order, the positions in `cue_ids` whose ID has no
/// entry in `translations`.
fn missing_translation_indices(
    cue_ids: &[String],
    translations: &BTreeMap<String, String>,
) -> Vec<usize> {
    let mut missing = Vec::new();
    for (position, cue_id) in cue_ids.iter().enumerate() {
        if translations.get(cue_id).is_none() {
            missing.push(position);
        }
    }
    missing
}
391
392fn log_translation_progress(processed_cues: usize, total_cues: usize) {
393 if crate::cli::output::is_quiet() || crate::cli::output::active_mode().is_json() {
396 return;
397 }
398 eprintln!(
399 "{}",
400 format_translation_progress(processed_cues, total_cues)
401 );
402}
403
/// Builds the two-line progress message: a header line followed by
/// `Processed cues: <processed>/<total>`.
fn format_translation_progress(processed_cues: usize, total_cues: usize) -> String {
    let ratio = format!("{processed_cues}/{total_cues}");
    format!("📊 Translation Progress:\n Processed cues: {ratio}")
}
407
/// One cue's text with its inline formatting tokens separated out, as produced
/// by `protect_inline_formatting`.
#[derive(Debug, Clone)]
struct ProtectedCueText {
    /// Cue text in which every recognized formatting token has been replaced by
    /// a `__SUBX_FMT_<cue>_<n>__` placeholder; this is what goes into the prompt.
    prompt_text: String,
    /// Cue text with the formatting tokens removed and nothing put in their place.
    visible_text: String,
    /// `(placeholder, original token)` pairs, in the order the tokens occurred.
    markers: Vec<(String, String)>,
}
414
415fn protect_inline_formatting(text: &str, cue_index: usize) -> ProtectedCueText {
416 let mut prompt_text = String::new();
417 let mut visible_text = String::new();
418 let mut markers = Vec::new();
419 let mut offset = 0usize;
420
421 while offset < text.len() {
422 let remaining = &text[offset..];
423
424 if let Some(end_offset) = html_like_tag_end(remaining) {
425 let token = &text[offset..offset + end_offset];
426 push_format_marker(cue_index, token, &mut prompt_text, &mut markers);
427 offset += end_offset;
428 continue;
429 }
430
431 if let Some(end_offset) = ass_override_tag_end(remaining) {
432 let token = &text[offset..offset + end_offset];
433 push_format_marker(cue_index, token, &mut prompt_text, &mut markers);
434 offset += end_offset;
435 continue;
436 }
437
438 let ch = remaining
439 .chars()
440 .next()
441 .expect("offset is always inside a non-empty string slice");
442 prompt_text.push(ch);
443 visible_text.push(ch);
444 offset += ch.len_utf8();
445 }
446
447 ProtectedCueText {
448 prompt_text,
449 visible_text,
450 markers,
451 }
452}
453
/// Returns the byte length of an HTML-like tag (`<...>`) at the start of `text`.
///
/// Yields `None` when `text` does not start with `<`, when no `>` follows, when
/// the tag is the empty `<>`, or when another `<` appears before the closing `>`.
///
/// The last rule keeps a literal `<` in dialogue (as in `3 < 5 <i>yes</i>`) from
/// being paired with the `>` of a later real tag. Without it the text in between
/// would be treated as one formatting token and withheld from translation.
/// This is still a heuristic: any other `<...>` span counts as a tag.
fn html_like_tag_end(text: &str) -> Option<usize> {
    let after_open = text.strip_prefix('<')?;
    let close = after_open.find('>')?;
    let inner = &after_open[..close];
    if inner.is_empty() || inner.contains('<') {
        return None;
    }
    // `close` is relative to `after_open`: +1 for the leading `<`, +1 to step past `>`.
    Some(close + 2)
}
461
462fn ass_override_tag_end(text: &str) -> Option<usize> {
463 if !text.starts_with('{') {
464 return None;
465 }
466 let end = text.find('}')? + 1;
467 let token = &text[..end];
468 token.contains('\\').then_some(end)
469}
470
/// Records `token` as the next marker of a cue and appends its placeholder to
/// `prompt_text`.
///
/// The placeholder is `__SUBX_FMT_<cue_index>_<n>__`, where `n` is the number
/// of markers recorded before this call.
fn push_format_marker(
    cue_index: usize,
    token: &str,
    prompt_text: &mut String,
    markers: &mut Vec<(String, String)>,
) {
    let ordinal = markers.len();
    let placeholder = format!("__SUBX_FMT_{cue_index}_{ordinal}__");
    prompt_text.push_str(placeholder.as_str());
    markers.push((placeholder, token.to_owned()));
}
481
482fn restore_inline_formatting(translated: &str, protected: &ProtectedCueText) -> Result<String> {
483 let mut restored = translated.to_string();
484 for (placeholder, token) in &protected.markers {
485 let count = restored.matches(placeholder).count();
486 if count != 1 {
487 return Err(SubXError::ai_service(format!(
488 "Translation response must preserve formatting placeholder {placeholder} exactly once"
489 )));
490 }
491 restored = restored.replace(placeholder, token);
492 }
493 Ok(restored)
494}
495
/// Unit tests for the translation engine and its private helpers, driven by
/// scripted in-memory AI providers so that no network access is needed.
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Mutex;
    use std::time::Duration;

    use crate::core::formats::{Subtitle, SubtitleEntry, SubtitleFormatType, SubtitleMetadata};
    use crate::services::ai::{
        AIProvider, AnalysisRequest, ConfidenceScore, MatchResult, VerificationRequest,
    };

    /// Test double that replays canned chat responses in the order given.
    struct ScriptedAI {
        /// Responses not yet handed out; the front element is returned next.
        responses: Mutex<Vec<String>>,
    }

    impl ScriptedAI {
        /// Wraps the canned responses in a shareable provider.
        fn new(responses: Vec<&str>) -> Arc<Self> {
            Arc::new(Self {
                responses: Mutex::new(responses.into_iter().map(|s| s.to_string()).collect()),
            })
        }
    }

    #[async_trait]
    impl AIProvider for ScriptedAI {
        // The engine under test only uses `chat_completion`.
        async fn analyze_content(&self, _r: AnalysisRequest) -> Result<MatchResult> {
            unreachable!()
        }

        async fn verify_match(&self, _r: VerificationRequest) -> Result<ConfidenceScore> {
            unreachable!()
        }

        /// Returns the next scripted response, or an AI-service error once the
        /// script is exhausted. The incoming messages are ignored.
        async fn chat_completion(&self, _messages: Vec<serde_json::Value>) -> Result<String> {
            let mut responses = self.responses.lock().unwrap();
            if responses.is_empty() {
                return Err(SubXError::ai_service("no scripted response left"));
            }
            Ok(responses.remove(0))
        }
    }

    /// Two-cue SRT subtitle ("Hello Alice" at 1s-2s, "Goodbye Alice" at 3s-4s).
    fn sample_subtitle() -> Subtitle {
        let metadata = SubtitleMetadata::new(SubtitleFormatType::Srt);
        let mut sub = Subtitle::new(SubtitleFormatType::Srt, metadata);
        sub.entries.push(SubtitleEntry::new(
            1,
            Duration::from_secs(1),
            Duration::from_secs(2),
            "Hello Alice".to_string(),
        ));
        sub.entries.push(SubtitleEntry::new(
            2,
            Duration::from_secs(3),
            Duration::from_secs(4),
            "Goodbye Alice".to_string(),
        ));
        sub
    }

    /// Happy path: one terminology reply plus one translation reply produce the
    /// translated cues in their original order, with timing left untouched.
    #[tokio::test]
    async fn translation_engine_translates_in_order() {
        let term_resp = r#"{"terms":[{"source":"Alice","target":"愛麗絲"}]}"#;
        // `__ID0__` / `__ID1__` stand in for the cue IDs the engine generates;
        // `PlaceholderAI` below swaps in the real ones.
        let cues_resp = r#"{"translations":[{"id":"__ID0__","text":"哈囉 愛麗絲"},{"id":"__ID1__","text":"再見 愛麗絲"}]}"#;
        let provider = ScriptedAI::new(vec![term_resp, cues_resp]);

        /// Wraps `ScriptedAI` and rewrites the `__IDn__` stand-ins in a scripted
        /// response with the IDs found in the prompt it was sent.
        struct PlaceholderAI {
            inner: Arc<ScriptedAI>,
            /// IDs scraped from the most recent translation prompt.
            captured_ids: Mutex<Vec<String>>,
        }
        #[async_trait]
        impl AIProvider for PlaceholderAI {
            async fn analyze_content(&self, _r: AnalysisRequest) -> Result<MatchResult> {
                unreachable!()
            }
            async fn verify_match(&self, _r: VerificationRequest) -> Result<ConfidenceScore> {
                unreachable!()
            }
            async fn chat_completion(&self, messages: Vec<serde_json::Value>) -> Result<String> {
                // The last message is the user prompt; keep a copy before
                // `messages` is moved into the inner provider.
                let prompt = messages
                    .last()
                    .and_then(|m| m.get("content"))
                    .and_then(|c| c.as_str())
                    .unwrap_or("")
                    .to_string();
                let mut response = self.inner.chat_completion(messages).await?;
                // Only the translation response contains stand-ins; the
                // terminology response passes through unchanged.
                if response.contains("__ID0__") {
                    // Cue IDs are read from prompt lines that start with `- id: `.
                    let ids: Vec<String> = prompt
                        .lines()
                        .filter_map(|l| l.trim().strip_prefix("- id: "))
                        .map(|s| s.trim().to_string())
                        .collect();
                    let mut captured = self.captured_ids.lock().unwrap();
                    *captured = ids.clone();
                    for (i, id) in ids.iter().enumerate() {
                        response = response.replace(&format!("__ID{}__", i), id);
                    }
                }
                Ok(response)
            }
        }

        let provider: Arc<dyn AIProvider> = Arc::new(PlaceholderAI {
            inner: provider,
            captured_ids: Mutex::new(Vec::new()),
        });
        // Batch size 10 keeps both cues in a single translation batch.
        let engine = TranslationEngine::new(provider, 10).unwrap();
        let request = TranslationRequest {
            target_language: "zh-TW".to_string(),
            source_language: Some("en".to_string()),
            glossary_text: None,
            context: None,
            glossary_entries: vec![],
        };
        let result = engine
            .translate_subtitle(sample_subtitle(), &request)
            .await
            .unwrap();
        assert_eq!(result.subtitle.entries.len(), 2);
        assert_eq!(result.subtitle.entries[0].text, "哈囉 愛麗絲");
        assert_eq!(result.subtitle.entries[1].text, "再見 愛麗絲");
        assert_eq!(result.outcome.translated_cue_count, 2);
        assert_eq!(result.outcome.batch_count, 1);
        assert_eq!(
            result.outcome.effective_terminology.get("Alice").unwrap(),
            "愛麗絲"
        );
        // Timing must be carried over unchanged.
        assert_eq!(
            result.subtitle.entries[0].start_time,
            Duration::from_secs(1)
        );
        assert_eq!(result.subtitle.entries[1].end_time, Duration::from_secs(4));
    }

    /// A subtitle without cues returns early: the empty script would make any
    /// provider call fail, so success shows that none was made.
    #[tokio::test]
    async fn empty_subtitle_returns_empty_outcome() {
        let provider: Arc<dyn AIProvider> = ScriptedAI::new(vec![]);
        let engine = TranslationEngine::new(provider, 5).unwrap();
        let metadata = SubtitleMetadata::new(SubtitleFormatType::Srt);
        let sub = Subtitle::new(SubtitleFormatType::Srt, metadata);
        let request = TranslationRequest {
            target_language: "zh-TW".to_string(),
            source_language: None,
            glossary_text: None,
            context: None,
            glossary_entries: vec![],
        };
        let result = engine.translate_subtitle(sub, &request).await.unwrap();
        assert_eq!(result.outcome.translated_cue_count, 0);
        assert_eq!(result.outcome.batch_count, 0);
    }

    /// The constructor refuses a batch size of 0.
    #[test]
    fn batch_size_zero_is_rejected() {
        let provider: Arc<dyn AIProvider> = ScriptedAI::new(vec![]);
        let err = TranslationEngine::new(provider, 0).unwrap_err();
        assert!(err.to_string().contains("batch size"));
    }

    /// Both `=` and `->` separators are accepted; comment and blank lines are skipped.
    #[test]
    fn parse_glossary_text_handles_multiple_separators() {
        let text = "# comment\nAlice = 艾莉絲\nBob -> 鮑伯\n\n";
        let entries = parse_glossary_text(text);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].source, "Alice");
        assert_eq!(entries[0].target, "艾莉絲");
        assert_eq!(entries[1].source, "Bob");
        assert_eq!(entries[1].target, "鮑伯");
    }

    /// HTML-like and ASS tokens become placeholders and come back verbatim
    /// around the translated text.
    #[test]
    fn protect_and_restore_inline_formatting_tokens() {
        let protected = protect_inline_formatting(r#"<i>{\b1}Hello{\b0}</i>"#, 3);
        assert_eq!(
            protected.prompt_text,
            "__SUBX_FMT_3_0____SUBX_FMT_3_1__Hello__SUBX_FMT_3_2____SUBX_FMT_3_3__"
        );
        assert_eq!(protected.visible_text, "Hello");

        let translated = "__SUBX_FMT_3_0____SUBX_FMT_3_1__你好__SUBX_FMT_3_2____SUBX_FMT_3_3__";
        let restored = restore_inline_formatting(translated, &protected).unwrap();
        assert_eq!(restored, r#"<i>{\b1}你好{\b0}</i>"#);
    }

    /// Pins the exact text of the progress message.
    #[test]
    fn translation_progress_message_includes_processed_and_total_cues() {
        assert_eq!(
            format_translation_progress(42, 100),
            "📊 Translation Progress:\n Processed cues: 42/100"
        );
    }
}