1use crate::loader::subtitle::SubtitleLoader;
13use crate::loader::{DocumentLoader, LoaderRegistry};
14use crate::media::{SubtitleCue, SubtitleFormat, SubtitleTrack};
15use crate::{Document, Error, Result};
16use std::collections::HashMap;
17use std::path::{Path, PathBuf};
18use whisper_apr::{Segment, TranscribeOptions, WhisperApr};
19
/// File extensions (lowercase, no leading dot) this loader registers for —
/// both audio and video containers that `whisper_apr::audio` can decode.
const MEDIA_EXTENSIONS: &[&str] = &["mp4", "mp3", "wav", "m4a", "ogg", "flac", "webm"];
22
/// Compute backend requested for Whisper inference.
// NOTE(review): this enum is not read anywhere in this file — presumably
// consumed by whisper_apr at a call site outside this view; confirm.
#[derive(Debug, Clone, Copy, Default)]
pub enum TranscriptionBackend {
    /// CPU inference (the default).
    #[default]
    Cpu,
    /// Generic GPU backend.
    // NOTE(review): how `Gpu` differs from `Cuda` is not visible here — verify
    // against whisper_apr's backend selection.
    Gpu,
    /// NVIDIA CUDA backend.
    Cuda,
}
34
/// Settings controlling how media files are transcribed by `TranscriptionLoader`.
#[derive(Debug, Clone)]
pub struct TranscriptionConfig {
    /// Language hint passed to Whisper (e.g. `"en"`); `None` leaves the
    /// option unset so the model's own default applies.
    pub language: Option<String>,
    /// Beam width for decoding. Values <= 1 select greedy decoding.
    // NOTE(review): the numeric value is never forwarded to TranscribeOptions
    // in this file — only the <=1 greedy switch; confirm intended.
    pub beam_size: usize,
    /// Request per-word timestamps from the model.
    pub word_timestamps: bool,
    /// After transcribing, write a `.srt` sidecar next to the media file
    /// (best-effort; write failures are ignored).
    pub write_sidecar: bool,
    /// Requested compute backend.
    // NOTE(review): not read anywhere in this file — confirm it is honored.
    pub backend: TranscriptionBackend,
    /// Path to a Whisper `.apr` model file; `None` means no model is loaded
    /// and only sidecar loading will work.
    pub model_path: Option<PathBuf>,
    /// Optional initial prompt to bias the transcription.
    pub prompt: Option<String>,
    /// Domain-specific terms to boost during decoding.
    pub hotwords: Vec<String>,
}
59
60impl Default for TranscriptionConfig {
61 fn default() -> Self {
62 Self {
63 language: Some("en".into()),
64 beam_size: 5,
65 word_timestamps: false,
66 write_sidecar: true,
67 backend: TranscriptionBackend::default(),
68 model_path: None,
69 prompt: None,
70 hotwords: Vec::new(),
71 }
72 }
73}
74
/// Document loader that turns audio/video files into text documents by
/// running Whisper transcription, preferring an existing `.srt` sidecar
/// when one is present.
pub struct TranscriptionLoader {
    // Settings supplied at construction time.
    config: TranscriptionConfig,
    // Loaded Whisper model; `None` when no model_path was configured or
    // the model failed to load (loader then works in sidecar-only mode).
    whisper: Option<WhisperApr>,
}
96
impl TranscriptionLoader {
    /// Create a loader from `config`, eagerly loading the Whisper model when
    /// `config.model_path` is set.
    ///
    /// Model read/load failures are non-fatal: a warning is printed to stderr
    /// and the loader degrades to sidecar-only operation (`has_model()` will
    /// return `false`).
    pub fn new(config: TranscriptionConfig) -> Self {
        let whisper = config.model_path.as_ref().and_then(|path| match std::fs::read(path) {
            Ok(data) => match WhisperApr::load_from_apr(&data) {
                Ok(w) => Some(w),
                Err(e) => {
                    eprintln!("Warning: failed to load whisper model from {}: {e}", path.display());
                    None
                }
            },
            Err(e) => {
                eprintln!("Warning: failed to read model file {}: {e}", path.display());
                None
            }
        });
        Self { config, whisper }
    }

    /// Create a loader with [`TranscriptionConfig::default`] (no model loaded).
    #[must_use]
    pub fn with_defaults() -> Self {
        Self::new(TranscriptionConfig::default())
    }

    /// Run Whisper on decoded audio samples and collect text, segments, and
    /// the detected language.
    ///
    /// # Errors
    /// Returns [`Error::InvalidInput`] when no model is loaded, or when the
    /// transcription itself fails.
    fn transcribe_audio(&self, samples: &[f32]) -> Result<TranscriptionResult> {
        let whisper = self.whisper.as_ref().ok_or_else(|| {
            Error::InvalidInput(
                "No Whisper model loaded. Set model_path in TranscriptionConfig \
                or provide a .srt sidecar file alongside the media."
                    .into(),
            )
        })?;

        let mut options = TranscribeOptions::default();
        if let Some(ref lang) = self.config.language {
            options.language = Some(lang.clone());
        }
        options.word_timestamps = self.config.word_timestamps;
        // beam_size <= 1 means "don't search": force greedy decoding.
        // NOTE(review): the numeric beam_size is never written into `options`,
        // so values > 1 all behave identically (whisper_apr's default
        // strategy). Confirm whether TranscribeOptions exposes a beam width.
        if self.config.beam_size <= 1 {
            options.strategy = whisper_apr::DecodingStrategy::Greedy;
        }
        options.prompt = self.config.prompt.clone();
        options.hotwords = self.config.hotwords.clone();

        let result = whisper
            .transcribe(samples, options)
            .map_err(|e| Error::InvalidInput(format!("Transcription failed: {e}")))?;

        Ok(TranscriptionResult {
            text: result.text,
            segments: result.segments,
            language: result.language,
        })
    }

    /// Borrow the loader's configuration.
    #[must_use]
    pub fn config(&self) -> &TranscriptionConfig {
        &self.config
    }

    /// Whether a Whisper model was successfully loaded at construction.
    #[must_use]
    pub fn has_model(&self) -> bool {
        self.whisper.is_some()
    }
}
169
/// Intermediate result of one Whisper run, before conversion into a
/// `SubtitleTrack` / `Document`.
#[derive(Debug)]
struct TranscriptionResult {
    // Full transcript text; currently unused — the segments carry the text
    // that actually ends up in the document.
    #[allow(dead_code)]
    text: String,
    // Timestamped segments produced by the model.
    segments: Vec<Segment>,
    // Language reported by the model (detected or as-configured).
    language: String,
}
179
impl DocumentLoader for TranscriptionLoader {
    fn supported_extensions(&self) -> Vec<&str> {
        MEDIA_EXTENSIONS.to_vec()
    }

    /// Load a media file as a text document.
    ///
    /// Resolution order:
    /// 1. If a subtitle sidecar exists next to the media file, parse it with
    ///    `SubtitleLoader` — no audio decoding or transcription happens.
    /// 2. Otherwise decode the audio and transcribe it with Whisper.
    ///
    /// # Errors
    /// Fails when the audio cannot be decoded, or when transcription is
    /// required but no model is loaded (see `transcribe_audio`).
    fn load(&self, path: &Path) -> Result<Document> {
        if let Some(sidecar) = LoaderRegistry::find_sidecar(path) {
            return SubtitleLoader.load(&sidecar);
        }

        // Variable name suggests 16 kHz mono output — presumably whisper_apr
        // resamples to Whisper's expected format; confirm in its docs.
        let samples_16k = whisper_apr::audio::load_audio_file(path).map_err(|e| {
            Error::InvalidInput(format!("Audio decode failed for {}: {e}", path.display()))
        })?;

        let result = self.transcribe_audio(&samples_16k)?;

        let track = segments_to_track(&result.segments);

        // Best-effort cache: write an .srt next to the media so future loads
        // hit the sidecar path. Write errors are deliberately ignored.
        if self.config.write_sidecar {
            let _ = write_sidecar(path, &track);
        }

        let mut doc = build_transcription_document(path, &track)?;
        doc.metadata.insert("language".into(), serde_json::json!(result.language));
        Ok(doc)
    }
}
213
214impl std::fmt::Debug for TranscriptionLoader {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.debug_struct("TranscriptionLoader")
217 .field("config", &self.config)
218 .field("model_loaded", &self.whisper.is_some())
219 .finish_non_exhaustive()
220 }
221}
222
223pub fn segments_to_track(segments: &[Segment]) -> SubtitleTrack {
228 let cues = segments
229 .iter()
230 .enumerate()
231 .map(|(i, seg)| SubtitleCue {
232 index: i,
233 start_secs: f64::from(seg.start),
234 end_secs: f64::from(seg.end),
235 text: seg.text.trim().to_string(),
236 })
237 .collect();
238 SubtitleTrack { format: SubtitleFormat::Srt, cues }
239}
240
241pub fn build_transcription_document(path: &Path, track: &SubtitleTrack) -> Result<Document> {
243 let title = path.file_stem().and_then(|s| s.to_str()).unwrap_or("Untitled").to_string();
244
245 let mut metadata = HashMap::new();
246 metadata.insert("duration_secs".into(), serde_json::json!(track.duration_secs()));
247 metadata.insert("format".into(), serde_json::json!("transcription"));
248 metadata.insert("cue_count".into(), serde_json::json!(track.cues.len()));
249 metadata.insert(
250 "subtitle_cues".into(),
251 serde_json::to_value(&track.cues).map_err(Error::Serialization)?,
252 );
253
254 let mut doc =
255 Document::new(track.to_plain_text()).with_title(title).with_source(path.to_string_lossy());
256 doc.metadata = metadata;
257 Ok(doc)
258}
259
260pub fn write_sidecar(media_path: &Path, track: &SubtitleTrack) -> Result<PathBuf> {
264 let sidecar_path = media_path.with_extension("srt");
265 let srt_content = track.to_srt_string();
266 std::fs::write(&sidecar_path, srt_content).map_err(Error::Io)?;
267 Ok(sidecar_path)
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transcription_config_default() {
        let config = TranscriptionConfig::default();
        assert_eq!(config.language, Some("en".into()));
        assert_eq!(config.beam_size, 5);
        assert!(!config.word_timestamps);
        assert!(config.write_sidecar);
        assert!(config.model_path.is_none());
        assert!(config.prompt.is_none());
        assert!(config.hotwords.is_empty());
    }

    #[test]
    fn test_transcription_config_with_prompt() {
        let config = TranscriptionConfig {
            prompt: Some("This is a lecture about AWS and Kubernetes.".into()),
            ..TranscriptionConfig::default()
        };
        assert_eq!(config.prompt.as_deref(), Some("This is a lecture about AWS and Kubernetes."));
    }

    #[test]
    fn test_transcription_config_with_hotwords() {
        let config = TranscriptionConfig {
            hotwords: vec!["AWS".into(), "Kubernetes".into(), "YAML".into()],
            ..TranscriptionConfig::default()
        };
        assert_eq!(config.hotwords.len(), 3);
        assert_eq!(config.hotwords[0], "AWS");
    }

    #[test]
    fn test_transcription_backend_default() {
        let backend = TranscriptionBackend::default();
        assert!(matches!(backend, TranscriptionBackend::Cpu));
    }

    #[test]
    fn test_media_extensions() {
        let loader = TranscriptionLoader::with_defaults();
        let exts = loader.supported_extensions();
        assert!(exts.contains(&"mp4"));
        assert!(exts.contains(&"wav"));
        assert!(exts.contains(&"mp3"));
        assert!(exts.contains(&"flac"));
        assert!(exts.contains(&"webm"));
    }

    #[test]
    fn test_has_model_default_false() {
        let loader = TranscriptionLoader::with_defaults();
        assert!(!loader.has_model());
    }

    #[test]
    fn test_segments_to_track() {
        let segments = vec![
            Segment { start: 0.0, end: 3.0, text: "Hello world.".into(), tokens: vec![] },
            Segment { start: 3.5, end: 6.0, text: "How are you?".into(), tokens: vec![] },
        ];
        let track = segments_to_track(&segments);
        assert_eq!(track.cues.len(), 2);
        assert_eq!(track.cues[0].text, "Hello world.");
        assert!((track.cues[0].start_secs).abs() < 0.001);
        assert!((track.cues[0].end_secs - 3.0).abs() < 0.001);
        assert!((track.cues[1].start_secs - 3.5).abs() < 0.001);
        assert!((track.cues[1].end_secs - 6.0).abs() < 0.001);
    }

    #[test]
    fn test_segments_to_track_empty() {
        let track = segments_to_track(&[]);
        assert!(track.cues.is_empty());
        assert!((track.duration_secs()).abs() < 0.001);
    }

    #[test]
    fn test_load_non_wav_media_errors_helpful() {
        let loader = TranscriptionLoader::with_defaults();
        let result = loader.load(Path::new("/tmp/nonexistent_video.mp4"));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Audio decode") || err.contains("sidecar") || err.contains("not found")
        );
    }

    #[test]
    fn test_sidecar_fallback() {
        let dir = std::env::temp_dir().join("trueno_rag_test_transcription_sidecar");
        let _ = std::fs::create_dir_all(&dir);
        let media = dir.join("lecture.wav");
        let srt = dir.join("lecture.srt");
        std::fs::write(&media, b"fake wav data").unwrap();
        std::fs::write(&srt, "1\n00:00:01,000 --> 00:00:04,500\nSidecar text.\n").unwrap();

        // The sidecar must win: the fake wav bytes are never decoded.
        let loader = TranscriptionLoader::with_defaults();
        let doc = loader.load(&media).unwrap();
        assert!(doc.content.contains("Sidecar text"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_build_transcription_document() {
        let track = SubtitleTrack {
            format: SubtitleFormat::Srt,
            cues: vec![
                SubtitleCue { index: 0, start_secs: 0.0, end_secs: 3.0, text: "Hello".into() },
                SubtitleCue { index: 1, start_secs: 3.0, end_secs: 6.0, text: "World".into() },
            ],
        };
        let doc = build_transcription_document(Path::new("/tmp/test.wav"), &track).unwrap();
        assert_eq!(doc.content, "Hello World");
        assert_eq!(doc.title, Some("test".into()));
        assert!(doc.metadata.contains_key("duration_secs"));
        assert!(doc.metadata.contains_key("subtitle_cues"));
        assert!(doc.metadata.contains_key("cue_count"));
    }

    #[test]
    fn test_write_sidecar() {
        let dir = std::env::temp_dir().join("trueno_rag_test_write_sidecar");
        let _ = std::fs::create_dir_all(&dir);
        let media = dir.join("output.mp4");

        let track = SubtitleTrack {
            format: SubtitleFormat::Srt,
            cues: vec![SubtitleCue {
                index: 0,
                start_secs: 1.0,
                end_secs: 4.5,
                text: "Hello.".into(),
            }],
        };

        let sidecar = write_sidecar(&media, &track).unwrap();
        assert_eq!(sidecar.extension().unwrap(), "srt");
        let content = std::fs::read_to_string(&sidecar).unwrap();
        assert!(content.contains("Hello."));
        assert!(content.contains("00:00:01,000"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_loader_debug() {
        let loader = TranscriptionLoader::with_defaults();
        let debug = format!("{loader:?}");
        assert!(debug.contains("TranscriptionLoader"));
        assert!(debug.contains("model_loaded"));
    }

    #[test]
    #[ignore = "stereo_to_mono not yet implemented"]
    fn test_stereo_to_mono() {
        let stereo = vec![0.5_f32, -0.5, 1.0, 0.0, -1.0, 1.0];
        let mono: Vec<f32> = stereo.chunks(2).map(|c| (c[0] + c[1]) / 2.0).collect();
        assert_eq!(mono.len(), 3);
        assert!((mono[0]).abs() < 0.001);
        assert!((mono[1] - 0.5).abs() < 0.001);
        assert!((mono[2]).abs() < 0.001);
    }

    #[test]
    #[ignore = "stereo_to_mono not yet implemented"]
    fn test_stereo_to_mono_passthrough() {
        let mono_input = vec![0.1_f32, 0.2, 0.3];
        assert_eq!(mono_input.len(), 3);
    }

    #[test]
    fn test_transcribe_without_model_errors() {
        let loader = TranscriptionLoader::with_defaults();
        let result = loader.transcribe_audio(&[0.0; 16000]);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("model") || err.contains("sidecar"));
    }

    #[test]
    fn test_config_with_model_path() {
        let config = TranscriptionConfig {
            model_path: Some(PathBuf::from("/tmp/nonexistent.apr")),
            ..TranscriptionConfig::default()
        };
        // A bad model path must not panic; it just leaves no model loaded.
        let loader = TranscriptionLoader::new(config);
        assert!(!loader.has_model());
    }
}