// trueno_rag/loader/transcription.rs
1//! Feature-gated transcription loader using whisper-apr for speech-to-text.
2//!
3//! When a media file has a sidecar subtitle (`.srt` or `.vtt`) adjacent to it,
4//! the subtitle is loaded directly without transcription. For media files without
5//! sidecars, the audio is decoded (WAV natively, MP4/MP3/etc via symphonia)
6//! and transcribed using whisper-apr's Whisper ASR engine.
7//!
8//! Full ASR inference requires a `.apr` model file (e.g. `base.apr`,
9//! `large-v3-turbo.apr`). When no model is configured, the loader computes
10//! the mel spectrogram and reports what would be needed for transcription.
11
12use crate::loader::subtitle::SubtitleLoader;
13use crate::loader::{DocumentLoader, LoaderRegistry};
14use crate::media::{SubtitleCue, SubtitleFormat, SubtitleTrack};
15use crate::{Document, Error, Result};
16use std::collections::HashMap;
17use std::path::{Path, PathBuf};
18use whisper_apr::{Segment, TranscribeOptions, WhisperApr};
19
/// Media file extensions supported by the transcription loader.
///
/// Extensions are matched lowercase; the list covers the containers/codecs
/// that `whisper_apr::audio::load_audio_file` can decode (WAV natively, the
/// rest via symphonia).
const MEDIA_EXTENSIONS: &[&str] = &["mp4", "mp3", "wav", "m4a", "ogg", "flac", "webm"];
22
/// Compute backend for transcription inference.
///
/// Defaults to [`TranscriptionBackend::Cpu`] via `#[derive(Default)]`.
#[derive(Debug, Clone, Copy, Default)]
pub enum TranscriptionBackend {
    /// CPU with SIMD acceleration via trueno (the default).
    #[default]
    Cpu,
    /// GPU via wgpu (cross-platform).
    Gpu,
    /// NVIDIA CUDA (Linux/Windows).
    Cuda,
}
34
/// Configuration for the transcription pipeline.
///
/// Construct via [`Default`] and override fields with struct-update syntax:
/// `TranscriptionConfig { beam_size: 1, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct TranscriptionConfig {
    /// Language hint (ISO 639-1, e.g., "en"). `None` for auto-detect.
    pub language: Option<String>,
    /// Beam size for decoding (1 = greedy, 5 = default).
    /// Values <= 1 select greedy decoding in [`TranscriptionLoader`].
    pub beam_size: usize,
    /// Enable word-level timestamps (more precise but slower).
    pub word_timestamps: bool,
    /// Write `.srt` sidecar files after transcription for caching.
    /// Sidecar writes are best-effort; failures do not fail the load.
    pub write_sidecar: bool,
    /// Compute backend for inference.
    pub backend: TranscriptionBackend,
    /// Path to the `.apr` model file (e.g. `base.apr`).
    /// When `None`, transcription of files without sidecars will fail with
    /// a helpful error message.
    pub model_path: Option<PathBuf>,
    /// Initial prompt to condition the decoder on domain vocabulary.
    /// Example: "This lecture covers AWS, Kubernetes, and YAML configurations."
    pub prompt: Option<String>,
    /// Hotwords to boost during decoding for domain-specific terms.
    /// Each string is a word or phrase to bias positively in the logit space.
    pub hotwords: Vec<String>,
}
59
60impl Default for TranscriptionConfig {
61    fn default() -> Self {
62        Self {
63            language: Some("en".into()),
64            beam_size: 5,
65            word_timestamps: false,
66            write_sidecar: true,
67            backend: TranscriptionBackend::default(),
68            model_path: None,
69            prompt: None,
70            hotwords: Vec::new(),
71        }
72    }
73}
74
/// Loader that handles media files via sidecar subtitle detection
/// and whisper-apr-based speech-to-text transcription.
///
/// When a media file has a sidecar subtitle (`.srt` or `.vtt`) adjacent to it,
/// the subtitle is loaded directly. Otherwise, the audio is decoded and
/// transcribed using the whisper-apr Whisper ASR engine.
///
/// # Example
///
/// ```rust,no_run
/// use trueno_rag::loader::transcription::{TranscriptionLoader, TranscriptionConfig};
/// use trueno_rag::loader::LoaderRegistry;
///
/// let mut registry = LoaderRegistry::new();
/// registry.register(Box::new(TranscriptionLoader::with_defaults()));
/// // Now the registry handles .mp4, .wav, etc. via sidecar detection
/// ```
pub struct TranscriptionLoader {
    /// Pipeline configuration (language, beam size, sidecar policy, model path, ...).
    config: TranscriptionConfig,
    /// Eagerly loaded ASR engine. `None` when no `model_path` was configured
    /// or the model failed to load; the loader then works in sidecar-only mode.
    whisper: Option<WhisperApr>,
}
96
97impl TranscriptionLoader {
98    /// Create a new transcription loader with the given configuration.
99    ///
100    /// If `config.model_path` is set, loads the whisper-apr model eagerly.
101    /// Otherwise, transcription of files without sidecars will fail gracefully.
102    pub fn new(config: TranscriptionConfig) -> Self {
103        let whisper = config.model_path.as_ref().and_then(|path| match std::fs::read(path) {
104            Ok(data) => match WhisperApr::load_from_apr(&data) {
105                Ok(w) => Some(w),
106                Err(e) => {
107                    eprintln!("Warning: failed to load whisper model from {}: {e}", path.display());
108                    None
109                }
110            },
111            Err(e) => {
112                eprintln!("Warning: failed to read model file {}: {e}", path.display());
113                None
114            }
115        });
116        Self { config, whisper }
117    }
118
119    /// Create a loader with default configuration (no model loaded).
120    #[must_use]
121    pub fn with_defaults() -> Self {
122        Self::new(TranscriptionConfig::default())
123    }
124
125    /// Transcribe audio samples using the loaded whisper-apr model.
126    fn transcribe_audio(&self, samples: &[f32]) -> Result<TranscriptionResult> {
127        let whisper = self.whisper.as_ref().ok_or_else(|| {
128            Error::InvalidInput(
129                "No Whisper model loaded. Set model_path in TranscriptionConfig \
130                 or provide a .srt sidecar file alongside the media."
131                    .into(),
132            )
133        })?;
134
135        let mut options = TranscribeOptions::default();
136        if let Some(ref lang) = self.config.language {
137            options.language = Some(lang.clone());
138        }
139        options.word_timestamps = self.config.word_timestamps;
140        if self.config.beam_size <= 1 {
141            options.strategy = whisper_apr::DecodingStrategy::Greedy;
142        }
143        options.prompt = self.config.prompt.clone();
144        options.hotwords = self.config.hotwords.clone();
145
146        let result = whisper
147            .transcribe(samples, options)
148            .map_err(|e| Error::InvalidInput(format!("Transcription failed: {e}")))?;
149
150        Ok(TranscriptionResult {
151            text: result.text,
152            segments: result.segments,
153            language: result.language,
154        })
155    }
156
157    /// Access the transcription configuration.
158    #[must_use]
159    pub fn config(&self) -> &TranscriptionConfig {
160        &self.config
161    }
162
163    /// Check if a Whisper model is loaded and ready for transcription.
164    #[must_use]
165    pub fn has_model(&self) -> bool {
166        self.whisper.is_some()
167    }
168}
169
/// Internal transcription result (simplified from whisper-apr types).
///
/// Holds just the pieces of the engine output that `load` consumes.
#[derive(Debug)]
struct TranscriptionResult {
    /// Full transcribed text (available for direct use if needed).
    #[allow(dead_code)]
    text: String,
    /// Timestamped segments, converted to cues by [`segments_to_track`].
    segments: Vec<Segment>,
    /// Language reported by the engine (detected or the configured hint).
    language: String,
}
179
180impl DocumentLoader for TranscriptionLoader {
181    fn supported_extensions(&self) -> Vec<&str> {
182        MEDIA_EXTENSIONS.to_vec()
183    }
184
185    fn load(&self, path: &Path) -> Result<Document> {
186        // 1. Check for sidecar subtitle file
187        if let Some(sidecar) = LoaderRegistry::find_sidecar(path) {
188            return SubtitleLoader.load(&sidecar);
189        }
190
191        // 2. Load and decode audio (WAV native, MP4/MP3/etc via symphonia)
192        let samples_16k = whisper_apr::audio::load_audio_file(path).map_err(|e| {
193            Error::InvalidInput(format!("Audio decode failed for {}: {e}", path.display()))
194        })?;
195
196        // 4. Transcribe
197        let result = self.transcribe_audio(&samples_16k)?;
198
199        // 5. Convert to subtitle track
200        let track = segments_to_track(&result.segments);
201
202        // 6. Optionally write sidecar for caching
203        if self.config.write_sidecar {
204            let _ = write_sidecar(path, &track);
205        }
206
207        // 7. Build document
208        let mut doc = build_transcription_document(path, &track)?;
209        doc.metadata.insert("language".into(), serde_json::json!(result.language));
210        Ok(doc)
211    }
212}
213
214impl std::fmt::Debug for TranscriptionLoader {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        f.debug_struct("TranscriptionLoader")
217            .field("config", &self.config)
218            .field("model_loaded", &self.whisper.is_some())
219            .finish_non_exhaustive()
220    }
221}
222
223/// Convert whisper-apr segments to a [`SubtitleTrack`].
224///
225/// Maps the `start`/`end` fields (seconds as f32) from whisper-apr's `Segment`
226/// type to the f64 representation used by `SubtitleCue`.
227pub fn segments_to_track(segments: &[Segment]) -> SubtitleTrack {
228    let cues = segments
229        .iter()
230        .enumerate()
231        .map(|(i, seg)| SubtitleCue {
232            index: i,
233            start_secs: f64::from(seg.start),
234            end_secs: f64::from(seg.end),
235            text: seg.text.trim().to_string(),
236        })
237        .collect();
238    SubtitleTrack { format: SubtitleFormat::Srt, cues }
239}
240
241/// Build a [`Document`] from a transcription result.
242pub fn build_transcription_document(path: &Path, track: &SubtitleTrack) -> Result<Document> {
243    let title = path.file_stem().and_then(|s| s.to_str()).unwrap_or("Untitled").to_string();
244
245    let mut metadata = HashMap::new();
246    metadata.insert("duration_secs".into(), serde_json::json!(track.duration_secs()));
247    metadata.insert("format".into(), serde_json::json!("transcription"));
248    metadata.insert("cue_count".into(), serde_json::json!(track.cues.len()));
249    metadata.insert(
250        "subtitle_cues".into(),
251        serde_json::to_value(&track.cues).map_err(Error::Serialization)?,
252    );
253
254    let mut doc =
255        Document::new(track.to_plain_text()).with_title(title).with_source(path.to_string_lossy());
256    doc.metadata = metadata;
257    Ok(doc)
258}
259
260/// Write a [`SubtitleTrack`] as an SRT sidecar file adjacent to a media file.
261///
262/// Returns the path of the written sidecar.
263pub fn write_sidecar(media_path: &Path, track: &SubtitleTrack) -> Result<PathBuf> {
264    let sidecar_path = media_path.with_extension("srt");
265    let srt_content = track.to_srt_string();
266    std::fs::write(&sidecar_path, srt_content).map_err(Error::Io)?;
267    Ok(sidecar_path)
268}
269
#[cfg(test)]
mod tests {
    //! Unit tests for the transcription loader. Model-dependent paths are
    //! covered only for their failure modes, since no `.apr` model is
    //! available in the test environment.
    use super::*;

    // Defaults should be: English hint, beam 5, sidecar caching on, no model.
    #[test]
    fn test_transcription_config_default() {
        let config = TranscriptionConfig::default();
        assert_eq!(config.language, Some("en".into()));
        assert_eq!(config.beam_size, 5);
        assert!(!config.word_timestamps);
        assert!(config.write_sidecar);
        assert!(config.model_path.is_none());
        assert!(config.prompt.is_none());
        assert!(config.hotwords.is_empty());
    }

    // Struct-update syntax should let callers override just the prompt.
    #[test]
    fn test_transcription_config_with_prompt() {
        let config = TranscriptionConfig {
            prompt: Some("This is a lecture about AWS and Kubernetes.".into()),
            ..TranscriptionConfig::default()
        };
        assert_eq!(config.prompt.as_deref(), Some("This is a lecture about AWS and Kubernetes."));
    }

    // Hotwords list is carried through configuration unchanged.
    #[test]
    fn test_transcription_config_with_hotwords() {
        let config = TranscriptionConfig {
            hotwords: vec!["AWS".into(), "Kubernetes".into(), "YAML".into()],
            ..TranscriptionConfig::default()
        };
        assert_eq!(config.hotwords.len(), 3);
        assert_eq!(config.hotwords[0], "AWS");
    }

    // The derived Default must select the CPU backend.
    #[test]
    fn test_transcription_backend_default() {
        let backend = TranscriptionBackend::default();
        assert!(matches!(backend, TranscriptionBackend::Cpu));
    }

    // supported_extensions() should expose the full MEDIA_EXTENSIONS list.
    #[test]
    fn test_media_extensions() {
        let loader = TranscriptionLoader::with_defaults();
        let exts = loader.supported_extensions();
        assert!(exts.contains(&"mp4"));
        assert!(exts.contains(&"wav"));
        assert!(exts.contains(&"mp3"));
        assert!(exts.contains(&"flac"));
        assert!(exts.contains(&"webm"));
    }

    // Without a model_path, no model is loaded.
    #[test]
    fn test_has_model_default_false() {
        let loader = TranscriptionLoader::with_defaults();
        assert!(!loader.has_model());
    }

    // Segment times (f32) must survive the f64 widening and text is trimmed.
    #[test]
    fn test_segments_to_track() {
        let segments = vec![
            Segment { start: 0.0, end: 3.0, text: "Hello world.".into(), tokens: vec![] },
            Segment { start: 3.5, end: 6.0, text: "How are you?".into(), tokens: vec![] },
        ];
        let track = segments_to_track(&segments);
        assert_eq!(track.cues.len(), 2);
        assert_eq!(track.cues[0].text, "Hello world.");
        assert!((track.cues[0].start_secs).abs() < 0.001);
        assert!((track.cues[0].end_secs - 3.0).abs() < 0.001);
        assert!((track.cues[1].start_secs - 3.5).abs() < 0.001);
        assert!((track.cues[1].end_secs - 6.0).abs() < 0.001);
    }

    // An empty segment list yields an empty, zero-duration track.
    #[test]
    fn test_segments_to_track_empty() {
        let track = segments_to_track(&[]);
        assert!(track.cues.is_empty());
        assert!((track.duration_secs()).abs() < 0.001);
    }

    // A missing media file (and no model) must produce a helpful error, not a panic.
    #[test]
    fn test_load_non_wav_media_errors_helpful() {
        let loader = TranscriptionLoader::with_defaults();
        let result = loader.load(Path::new("/tmp/nonexistent_video.mp4"));
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("Audio decode") || err.contains("sidecar") || err.contains("not found")
        );
    }

    // When a .srt sidecar sits next to the media, it is loaded directly and
    // the (fake, undecodable) audio is never touched.
    #[test]
    fn test_sidecar_fallback() {
        let dir = std::env::temp_dir().join("trueno_rag_test_transcription_sidecar");
        let _ = std::fs::create_dir_all(&dir);
        let media = dir.join("lecture.wav");
        let srt = dir.join("lecture.srt");
        std::fs::write(&media, b"fake wav data").unwrap();
        std::fs::write(&srt, "1\n00:00:01,000 --> 00:00:04,500\nSidecar text.\n").unwrap();

        let loader = TranscriptionLoader::with_defaults();
        let doc = loader.load(&media).unwrap();
        assert!(doc.content.contains("Sidecar text"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    // Document assembly: title from file stem, metadata keys present.
    #[test]
    fn test_build_transcription_document() {
        let track = SubtitleTrack {
            format: SubtitleFormat::Srt,
            cues: vec![
                SubtitleCue { index: 0, start_secs: 0.0, end_secs: 3.0, text: "Hello".into() },
                SubtitleCue { index: 1, start_secs: 3.0, end_secs: 6.0, text: "World".into() },
            ],
        };
        let doc = build_transcription_document(Path::new("/tmp/test.wav"), &track).unwrap();
        assert_eq!(doc.content, "Hello World");
        assert_eq!(doc.title, Some("test".into()));
        assert!(doc.metadata.contains_key("duration_secs"));
        assert!(doc.metadata.contains_key("subtitle_cues"));
        assert!(doc.metadata.contains_key("cue_count"));
    }

    // Sidecar writing: .srt extension, SRT timestamp formatting, cue text present.
    #[test]
    fn test_write_sidecar() {
        let dir = std::env::temp_dir().join("trueno_rag_test_write_sidecar");
        let _ = std::fs::create_dir_all(&dir);
        let media = dir.join("output.mp4");

        let track = SubtitleTrack {
            format: SubtitleFormat::Srt,
            cues: vec![SubtitleCue {
                index: 0,
                start_secs: 1.0,
                end_secs: 4.5,
                text: "Hello.".into(),
            }],
        };

        let sidecar = write_sidecar(&media, &track).unwrap();
        assert_eq!(sidecar.extension().unwrap(), "srt");
        let content = std::fs::read_to_string(&sidecar).unwrap();
        assert!(content.contains("Hello."));
        assert!(content.contains("00:00:01,000"));

        let _ = std::fs::remove_dir_all(&dir);
    }

    // The manual Debug impl reports model presence, not the model itself.
    #[test]
    fn test_loader_debug() {
        let loader = TranscriptionLoader::with_defaults();
        let debug = format!("{loader:?}");
        assert!(debug.contains("TranscriptionLoader"));
        assert!(debug.contains("model_loaded"));
    }

    // NOTE(review): these two tests exercise inline arithmetic rather than a
    // real stereo_to_mono function; they are placeholders until one exists.
    #[test]
    #[ignore = "stereo_to_mono not yet implemented"]
    fn test_stereo_to_mono() {
        let stereo = vec![0.5_f32, -0.5, 1.0, 0.0, -1.0, 1.0];
        let mono: Vec<f32> = stereo.chunks(2).map(|c| (c[0] + c[1]) / 2.0).collect();
        assert_eq!(mono.len(), 3);
        assert!((mono[0]).abs() < 0.001); // (0.5 + -0.5) / 2
        assert!((mono[1] - 0.5).abs() < 0.001); // (1.0 + 0.0) / 2
        assert!((mono[2]).abs() < 0.001); // (-1.0 + 1.0) / 2
    }

    #[test]
    #[ignore = "stereo_to_mono not yet implemented"]
    fn test_stereo_to_mono_passthrough() {
        let mono_input = vec![0.1_f32, 0.2, 0.3];
        assert_eq!(mono_input.len(), 3);
    }

    // Transcribing without a model must fail with guidance about model_path/sidecars.
    #[test]
    fn test_transcribe_without_model_errors() {
        let loader = TranscriptionLoader::with_defaults();
        let result = loader.transcribe_audio(&[0.0; 16000]);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("model") || err.contains("sidecar"));
    }

    // A nonexistent model file degrades to "no model" rather than panicking.
    #[test]
    fn test_config_with_model_path() {
        let config = TranscriptionConfig {
            model_path: Some(PathBuf::from("/tmp/nonexistent.apr")),
            ..TranscriptionConfig::default()
        };
        let loader = TranscriptionLoader::new(config);
        // Model file doesn't exist, so model won't be loaded
        assert!(!loader.has_model());
    }
}