subx_cli/services/vad/
audio_loader.rs

1//! Direct audio loader: Uses Symphonia to directly decode various audio formats and obtain i16 sample data.
2//!
3//! Supports MP4, MKV, OGG, WAV and other formats, returning sample data and audio information.
4use crate::services::vad::detector::AudioInfo;
5use crate::{Result, error::SubXError};
6use log::{debug, trace, warn};
7use std::fs::File;
8use std::path::Path;
9use symphonia::core::audio::SampleBuffer;
10use symphonia::core::codecs::CodecRegistry;
11use symphonia::core::codecs::DecoderOptions;
12use symphonia::core::formats::FormatOptions;
13use symphonia::core::io::MediaSourceStream;
14use symphonia::core::probe::Hint;
15use symphonia::core::probe::Probe;
16use symphonia::default::{get_codecs, get_probe};
17
18/// Direct audio loader using Symphonia to decode and obtain raw sample data.
19pub struct DirectAudioLoader {
20    probe: &'static Probe,
21    codecs: &'static CodecRegistry,
22}
23
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative path of the sample video used to exercise the loader.
    const SAMPLE_MEDIA: &str = "assets/SubX - The Subtitle Revolution.mp4";

    /// Decodes the sample MP4 and checks that samples and basic audio info are populated.
    #[tokio::test]
    async fn test_direct_mp4_loading() {
        let audio_loader =
            DirectAudioLoader::new().expect("Failed to initialize DirectAudioLoader");

        let load_result = audio_loader.load_audio_samples(SAMPLE_MEDIA);
        let (pcm, audio_info) = load_result.expect("load_audio_samples failed");

        assert!(!pcm.is_empty(), "Sample data should not be empty");
        assert!(
            audio_info.sample_rate > 0,
            "sample_rate should be greater than 0"
        );
        assert!(
            audio_info.total_samples > 0,
            "total_samples should be greater than 0"
        );
    }
}
43
44impl DirectAudioLoader {
45    /// Creates a new audio loader instance.
46    pub fn new() -> Result<Self> {
47        Ok(Self {
48            probe: get_probe(),
49            codecs: get_codecs(),
50        })
51    }
52
53    /// Loads i16 samples and audio information from an audio file path.
54    pub fn load_audio_samples<P: AsRef<Path>>(&self, path: P) -> Result<(Vec<i16>, AudioInfo)> {
55        let path_ref = path.as_ref();
56        debug!(
57            "[DirectAudioLoader] Start loading audio file: {:?}",
58            path_ref
59        );
60        // Open the media source.
61        let file = File::open(path_ref).map_err(|e| {
62            warn!(
63                "[DirectAudioLoader] Failed to open audio file: {:?}, error: {}",
64                path_ref, e
65            );
66            SubXError::audio_processing(format!("Failed to open audio file: {}", e))
67        })?;
68        debug!(
69            "[DirectAudioLoader] Successfully opened audio file: {:?}",
70            path_ref
71        );
72
73        // Create the media source stream.
74        let mss = MediaSourceStream::new(Box::new(file), Default::default());
75        debug!("[DirectAudioLoader] MediaSourceStream created");
76
77        // Create a hint to help format probing based on file extension.
78        let mut hint = Hint::new();
79        if let Some(ext) = path_ref.extension().and_then(|e| e.to_str()) {
80            debug!(
81                "[DirectAudioLoader] Detected extension: {} (used for format probing)",
82                ext
83            );
84            hint.with_extension(ext);
85        } else {
86            debug!("[DirectAudioLoader] No extension detected, using default format probing");
87        }
88
89        // Probe the media format.
90        let probed = self
91            .probe
92            .format(&hint, mss, &FormatOptions::default(), &Default::default())
93            .map_err(|e| {
94                warn!("[DirectAudioLoader] Format probing failed: {}", e);
95                SubXError::audio_processing(format!("Failed to probe format: {}", e))
96            })?;
97        debug!("[DirectAudioLoader] Format probing succeeded");
98        let mut format = probed.format;
99
100        // List all tracks and their channel info before selecting
101        for (idx, t) in format.tracks().iter().enumerate() {
102            let sr = t
103                .codec_params
104                .sample_rate
105                .map(|v| v.to_string())
106                .unwrap_or("None".to_string());
107            let ch = t
108                .codec_params
109                .channels
110                .map(|c| c.count().to_string())
111                .unwrap_or("None".to_string());
112            debug!(
113                "[DirectAudioLoader] Track[{}]: id={}, sample_rate={}, channels={}",
114                idx, t.id, sr, ch
115            );
116        }
117
118        // Select the first audio track that contains sample_rate as the audio source.
119        let track = format
120            .tracks()
121            .iter()
122            .find(|t| t.codec_params.sample_rate.is_some())
123            .ok_or_else(|| {
124                warn!("[DirectAudioLoader] No audio track with sample_rate found");
125                SubXError::audio_processing("No audio track found".to_string())
126            })?;
127        let track_id = track.id;
128        let sample_rate = track.codec_params.sample_rate.ok_or_else(|| {
129            warn!("[DirectAudioLoader] Audio track sample_rate is unknown");
130            SubXError::audio_processing("Sample rate unknown".to_string())
131        })?;
132        let channels = track
133            .codec_params
134            .channels
135            .map(|c| c.count() as u16)
136            .unwrap_or(1);
137        debug!(
138            "[DirectAudioLoader] Selected track: id={}, sample_rate={}, channels={}",
139            track_id, sample_rate, channels
140        );
141
142        // Create decoder for the track.
143        let dec_opts = DecoderOptions::default();
144        let mut decoder = self
145            .codecs
146            .make(&track.codec_params, &dec_opts)
147            .map_err(|e| {
148                warn!("[DirectAudioLoader] Failed to create decoder: {}", e);
149                SubXError::audio_processing(format!("Failed to create decoder: {}", e))
150            })?;
151        debug!("[DirectAudioLoader] Decoder created successfully");
152
153        // Decode packets and collect samples.
154        let mut samples = Vec::new();
155        let mut packet_count = 0;
156        while let Ok(packet) = format.next_packet() {
157            if packet.track_id() != track_id {
158                continue;
159            }
160            packet_count += 1;
161            trace!(
162                "[DirectAudioLoader] Decoding packet {} (track_id={})",
163                packet_count, track_id
164            );
165            let decoded = decoder.decode(&packet).map_err(|e| {
166                warn!("[DirectAudioLoader] Failed to decode packet: {}", e);
167                SubXError::audio_processing(format!("Decode error: {}", e))
168            })?;
169            // Create a sample buffer for this packet using its signal spec and capacity.
170            let spec = *decoded.spec();
171            let mut sample_buf = SampleBuffer::<i16>::new(decoded.capacity() as u64, spec);
172            sample_buf.copy_interleaved_ref(decoded);
173            let sample_len = sample_buf.samples().len();
174            trace!(
175                "[DirectAudioLoader] Packet decoded successfully, got {} samples",
176                sample_len
177            );
178            samples.extend_from_slice(sample_buf.samples());
179        }
180        debug!(
181            "[DirectAudioLoader] Packet decoding finished, total {} packets, {} samples accumulated",
182            packet_count,
183            samples.len()
184        );
185
186        // Calculate total samples and audio duration
187        let total_samples = samples.len();
188        let duration_seconds = total_samples as f64 / (sample_rate as f64 * channels as f64);
189        debug!(
190            "[DirectAudioLoader] Audio info: sample_rate={}, channels={}, duration_seconds={:.3}, total_samples={}",
191            sample_rate, channels, duration_seconds, total_samples
192        );
193
194        Ok((
195            samples,
196            AudioInfo {
197                sample_rate,
198                channels,
199                duration_seconds,
200                total_samples,
201            },
202        ))
203    }
204}