Skip to main content

subx_cli/services/vad/
audio_loader.rs

1//! Direct audio loader: Uses Symphonia to directly decode various audio formats and obtain i16 sample data.
2//!
3//! Supports MP4, MKV, OGG, WAV and other formats, returning sample data and audio information.
4use crate::services::vad::detector::AudioInfo;
5use crate::{Result, error::SubXError};
6use log::{debug, trace, warn};
7use std::fs::File;
8use std::path::Path;
9use symphonia::core::audio::SampleBuffer;
10use symphonia::core::codecs::CodecRegistry;
11use symphonia::core::codecs::DecoderOptions;
12use symphonia::core::formats::FormatOptions;
13use symphonia::core::io::MediaSourceStream;
14use symphonia::core::probe::Hint;
15use symphonia::core::probe::Probe;
16use symphonia::default::{get_codecs, get_probe};
17
18/// Direct audio loader using Symphonia to decode and obtain raw sample data.
19pub struct DirectAudioLoader {
20    probe: &'static Probe,
21    codecs: &'static CodecRegistry,
22}
23
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous 2 GiB file-size cap so the size guard does not interfere
    /// with decoding the bundled demo asset.
    const TEST_MAX_AUDIO_BYTES: u64 = 2 * 1024 * 1024 * 1024;

    // `load_audio_samples` is synchronous, so a plain `#[test]` suffices;
    // no async runtime is needed.
    #[test]
    fn test_direct_mp4_loading() {
        // Test direct audio loading using assets/SubX - The Subtitle Revolution.mp4
        let loader = DirectAudioLoader::new().expect("Failed to initialize DirectAudioLoader");
        let (samples, info) = loader
            .load_audio_samples(
                "assets/SubX - The Subtitle Revolution.mp4",
                TEST_MAX_AUDIO_BYTES,
            )
            .expect("load_audio_samples failed");
        assert!(!samples.is_empty(), "Sample data should not be empty");
        assert!(info.sample_rate > 0, "sample_rate should be greater than 0");
        assert!(info.channels > 0, "channels should be greater than 0");
        assert!(
            info.total_samples > 0,
            "total_samples should be greater than 0"
        );
        assert_eq!(
            info.total_samples,
            samples.len(),
            "total_samples should match the number of returned samples"
        );
        assert!(
            info.duration_seconds > 0.0,
            "duration_seconds should be greater than 0"
        );
    }
}
43
44impl DirectAudioLoader {
45    /// Creates a new audio loader instance.
46    pub fn new() -> Result<Self> {
47        Ok(Self {
48            probe: get_probe(),
49            codecs: get_codecs(),
50        })
51    }
52
53    /// Loads i16 samples and audio information from an audio file path.
54    ///
55    /// `max_audio_bytes` caps the accepted file size as a defense-in-depth
56    /// guard; callers should thread the configured limit from
57    /// `GeneralConfig.max_audio_bytes`.
58    pub fn load_audio_samples<P: AsRef<Path>>(
59        &self,
60        path: P,
61        max_audio_bytes: u64,
62    ) -> Result<(Vec<i16>, AudioInfo)> {
63        let path_ref = path.as_ref();
64        debug!(
65            "[DirectAudioLoader] Start loading audio file: {:?}",
66            path_ref
67        );
68        crate::core::fs_util::check_file_size(path_ref, max_audio_bytes, "Audio")
69            .map_err(|e| SubXError::audio_processing(e.to_string()))?;
70        // Open the media source.
71        let file = File::open(path_ref).map_err(|e| {
72            warn!(
73                "[DirectAudioLoader] Failed to open audio file: {:?}, error: {}",
74                path_ref, e
75            );
76            SubXError::audio_processing(format!("Failed to open audio file: {}", e))
77        })?;
78        debug!(
79            "[DirectAudioLoader] Successfully opened audio file: {:?}",
80            path_ref
81        );
82
83        // Create the media source stream.
84        let mss = MediaSourceStream::new(Box::new(file), Default::default());
85        debug!("[DirectAudioLoader] MediaSourceStream created");
86
87        // Create a hint to help format probing based on file extension.
88        let mut hint = Hint::new();
89        if let Some(ext) = path_ref.extension().and_then(|e| e.to_str()) {
90            debug!(
91                "[DirectAudioLoader] Detected extension: {} (used for format probing)",
92                ext
93            );
94            hint.with_extension(ext);
95        } else {
96            debug!("[DirectAudioLoader] No extension detected, using default format probing");
97        }
98
99        // Probe the media format.
100        let probed = self
101            .probe
102            .format(&hint, mss, &FormatOptions::default(), &Default::default())
103            .map_err(|e| {
104                warn!("[DirectAudioLoader] Format probing failed: {}", e);
105                SubXError::audio_processing(format!("Failed to probe format: {}", e))
106            })?;
107        debug!("[DirectAudioLoader] Format probing succeeded");
108        let mut format = probed.format;
109
110        // List all tracks and their channel info before selecting
111        for (idx, t) in format.tracks().iter().enumerate() {
112            let sr = t
113                .codec_params
114                .sample_rate
115                .map(|v| v.to_string())
116                .unwrap_or("None".to_string());
117            let ch = t
118                .codec_params
119                .channels
120                .map(|c| c.count().to_string())
121                .unwrap_or("None".to_string());
122            debug!(
123                "[DirectAudioLoader] Track[{}]: id={}, sample_rate={}, channels={}",
124                idx, t.id, sr, ch
125            );
126        }
127
128        // Select the first audio track that contains sample_rate as the audio source.
129        let track = format
130            .tracks()
131            .iter()
132            .find(|t| t.codec_params.sample_rate.is_some())
133            .ok_or_else(|| {
134                warn!("[DirectAudioLoader] No audio track with sample_rate found");
135                SubXError::audio_processing("No audio track found".to_string())
136            })?;
137        // Clone necessary info first to avoid borrow conflicts
138        let track_id = track.id;
139        let sample_rate = track.codec_params.sample_rate.ok_or_else(|| {
140            warn!("[DirectAudioLoader] Audio track sample_rate is unknown");
141            SubXError::audio_processing("Sample rate unknown".to_string())
142        })?;
143        let channels = track.codec_params.channels.map(|c| c.count() as u16);
144        let time_base = track.codec_params.time_base;
145        debug!(
146            "[DirectAudioLoader] Selected track: id={}, sample_rate={}, channels={:?}",
147            track_id, sample_rate, channels
148        );
149
150        // Create decoder for the track.
151        let dec_opts = DecoderOptions::default();
152        let mut decoder = self
153            .codecs
154            .make(&track.codec_params, &dec_opts)
155            .map_err(|e| {
156                warn!("[DirectAudioLoader] Failed to create decoder: {}", e);
157                SubXError::audio_processing(format!("Failed to create decoder: {}", e))
158            })?;
159        debug!("[DirectAudioLoader] Decoder created successfully");
160
161        // Decode packets and collect samples.
162        let mut samples = Vec::new();
163        let mut packet_count = 0;
164        let mut last_pts: u64 = 0;
165        while let Ok(packet) = format.next_packet() {
166            if packet.track_id() != track_id {
167                continue;
168            }
169            packet_count += 1;
170            trace!(
171                "[DirectAudioLoader] Decoding packet {} (track_id={})",
172                packet_count, track_id
173            );
174            let decoded = decoder.decode(&packet).map_err(|e| {
175                warn!("[DirectAudioLoader] Failed to decode packet: {}", e);
176                SubXError::audio_processing(format!("Decode error: {}", e))
177            })?;
178            // Create a sample buffer for this packet using its signal spec and capacity.
179            let spec = *decoded.spec();
180            let mut sample_buf = SampleBuffer::<i16>::new(decoded.capacity() as u64, spec);
181            sample_buf.copy_interleaved_ref(decoded);
182            let sample_len = sample_buf.samples().len();
183            trace!(
184                "[DirectAudioLoader] Packet decoded successfully, got {} samples",
185                sample_len
186            );
187            samples.extend_from_slice(sample_buf.samples());
188            // Directly record the timestamp of the last packet
189            last_pts = packet.ts;
190        }
191        debug!(
192            "[DirectAudioLoader] Packet decoding finished, total {} packets, {} samples accumulated",
193            packet_count,
194            samples.len()
195        );
196
197        // Calculate total samples and audio duration
198        let total_samples = samples.len();
199        // Use Timebase to calculate duration_seconds
200        let duration_seconds = if let Some(tb) = time_base {
201            if last_pts > 0 {
202                let (num, den) = (tb.numer, tb.denom);
203                last_pts as f64 * num as f64 / den as f64
204            } else {
205                total_samples as f64 / (sample_rate as f64 * channels.unwrap_or(1) as f64)
206            }
207        } else {
208            total_samples as f64 / (sample_rate as f64 * channels.unwrap_or(1) as f64)
209        };
210        // If channels is None, try to infer channel count from duration_seconds
211        let channels = channels.unwrap_or_else(|| {
212            let ch = if duration_seconds > 0.0 {
213                (total_samples as f64 / (sample_rate as f64 * duration_seconds)).round() as u16
214            } else {
215                1
216            };
217            debug!("[DirectAudioLoader] Inferred channel count: {}", ch);
218            ch
219        });
220        debug!(
221            "[DirectAudioLoader] Audio info: sample_rate={}, channels={}, duration_seconds={:.3}, total_samples={}",
222            sample_rate, channels, duration_seconds, total_samples
223        );
224
225        Ok((
226            samples,
227            AudioInfo {
228                sample_rate,
229                channels,
230                duration_seconds,
231                total_samples,
232            },
233        ))
234    }
235}