subx_cli/services/vad/
audio_loader.rs

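//! Direct audio loader: uses Symphonia to decode audio from media files into
//! interleaved `i16` samples.
//!
//! Supports MP4, MKV, OGG, WAV and other formats that Symphonia can demux and decode
//! (subject to the Symphonia features enabled for this crate), returning the samples
//! together with basic audio information.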
4use crate::services::vad::detector::AudioInfo;
5use crate::{Result, error::SubXError};
6use log::{debug, trace, warn};
7use std::fs::File;
8use std::path::Path;
9use symphonia::core::audio::SampleBuffer;
10use symphonia::core::codecs::CodecRegistry;
11use symphonia::core::codecs::DecoderOptions;
12use symphonia::core::formats::FormatOptions;
13use symphonia::core::io::MediaSourceStream;
14use symphonia::core::probe::Hint;
15use symphonia::core::probe::Probe;
16use symphonia::default::{get_codecs, get_probe};
17
18/// Direct audio loader using Symphonia to decode and obtain raw sample data.
19pub struct DirectAudioLoader {
20    probe: &'static Probe,
21    codecs: &'static CodecRegistry,
22}
23
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample media file used by the loading test. The path is resolved relative to the
    /// working directory; cargo runs tests from the package root.
    const SAMPLE_MP4: &str = "assets/SubX - The Subtitle Revolution.mp4";

    /// Decodes the sample MP4 and checks that the returned samples and `AudioInfo`
    /// are consistent with each other.
    // Plain `#[test]`: the loader is fully synchronous, so no async runtime is needed.
    #[test]
    fn test_direct_mp4_loading() {
        let loader = DirectAudioLoader::new().expect("Failed to initialize DirectAudioLoader");
        let (samples, info) = loader
            .load_audio_samples(SAMPLE_MP4)
            .expect("load_audio_samples failed");
        assert!(!samples.is_empty(), "Sample data should not be empty");
        assert!(info.sample_rate > 0, "sample_rate should be greater than 0");
        assert!(
            info.total_samples > 0,
            "total_samples should be greater than 0"
        );
        assert_eq!(
            info.total_samples,
            samples.len(),
            "total_samples should equal the number of returned samples"
        );
        assert!(info.channels > 0, "channels should be greater than 0");
        assert!(
            info.duration_seconds > 0.0,
            "duration_seconds should be greater than 0"
        );
    }

    /// A path that cannot be opened must produce an error rather than a panic.
    #[test]
    fn test_missing_file_returns_error() {
        let loader = DirectAudioLoader::new().expect("Failed to initialize DirectAudioLoader");
        assert!(
            loader
                .load_audio_samples("assets/this-file-does-not-exist.mp4")
                .is_err(),
            "loading a missing file should fail"
        );
    }
}
43
44impl DirectAudioLoader {
45    /// Creates a new audio loader instance.
46    pub fn new() -> Result<Self> {
47        Ok(Self {
48            probe: get_probe(),
49            codecs: get_codecs(),
50        })
51    }
52
53    /// Loads i16 samples and audio information from an audio file path.
54    pub fn load_audio_samples<P: AsRef<Path>>(&self, path: P) -> Result<(Vec<i16>, AudioInfo)> {
55        let path_ref = path.as_ref();
56        debug!(
57            "[DirectAudioLoader] Start loading audio file: {:?}",
58            path_ref
59        );
60        // Open the media source.
61        let file = File::open(path_ref).map_err(|e| {
62            warn!(
63                "[DirectAudioLoader] Failed to open audio file: {:?}, error: {}",
64                path_ref, e
65            );
66            SubXError::audio_processing(format!("Failed to open audio file: {}", e))
67        })?;
68        debug!(
69            "[DirectAudioLoader] Successfully opened audio file: {:?}",
70            path_ref
71        );
72
73        // Create the media source stream.
74        let mss = MediaSourceStream::new(Box::new(file), Default::default());
75        debug!("[DirectAudioLoader] MediaSourceStream created");
76
77        // Create a hint to help format probing based on file extension.
78        let mut hint = Hint::new();
79        if let Some(ext) = path_ref.extension().and_then(|e| e.to_str()) {
80            debug!(
81                "[DirectAudioLoader] Detected extension: {} (used for format probing)",
82                ext
83            );
84            hint.with_extension(ext);
85        } else {
86            debug!("[DirectAudioLoader] No extension detected, using default format probing");
87        }
88
89        // Probe the media format.
90        let probed = self
91            .probe
92            .format(&hint, mss, &FormatOptions::default(), &Default::default())
93            .map_err(|e| {
94                warn!("[DirectAudioLoader] Format probing failed: {}", e);
95                SubXError::audio_processing(format!("Failed to probe format: {}", e))
96            })?;
97        debug!("[DirectAudioLoader] Format probing succeeded");
98        let mut format = probed.format;
99
100        // List all tracks and their channel info before selecting
101        for (idx, t) in format.tracks().iter().enumerate() {
102            let sr = t
103                .codec_params
104                .sample_rate
105                .map(|v| v.to_string())
106                .unwrap_or("None".to_string());
107            let ch = t
108                .codec_params
109                .channels
110                .map(|c| c.count().to_string())
111                .unwrap_or("None".to_string());
112            debug!(
113                "[DirectAudioLoader] Track[{}]: id={}, sample_rate={}, channels={}",
114                idx, t.id, sr, ch
115            );
116        }
117
118        // Select the first audio track that contains sample_rate as the audio source.
119        let track = format
120            .tracks()
121            .iter()
122            .find(|t| t.codec_params.sample_rate.is_some())
123            .ok_or_else(|| {
124                warn!("[DirectAudioLoader] No audio track with sample_rate found");
125                SubXError::audio_processing("No audio track found".to_string())
126            })?;
127        // Clone necessary info first to avoid borrow conflicts
128        let track_id = track.id;
129        let sample_rate = track.codec_params.sample_rate.ok_or_else(|| {
130            warn!("[DirectAudioLoader] Audio track sample_rate is unknown");
131            SubXError::audio_processing("Sample rate unknown".to_string())
132        })?;
133        let channels = track.codec_params.channels.map(|c| c.count() as u16);
134        let time_base = track.codec_params.time_base;
135        debug!(
136            "[DirectAudioLoader] Selected track: id={}, sample_rate={}, channels={:?}",
137            track_id, sample_rate, channels
138        );
139
140        // Create decoder for the track.
141        let dec_opts = DecoderOptions::default();
142        let mut decoder = self
143            .codecs
144            .make(&track.codec_params, &dec_opts)
145            .map_err(|e| {
146                warn!("[DirectAudioLoader] Failed to create decoder: {}", e);
147                SubXError::audio_processing(format!("Failed to create decoder: {}", e))
148            })?;
149        debug!("[DirectAudioLoader] Decoder created successfully");
150
151        // Decode packets and collect samples.
152        let mut samples = Vec::new();
153        let mut packet_count = 0;
154        let mut last_pts: u64 = 0;
155        while let Ok(packet) = format.next_packet() {
156            if packet.track_id() != track_id {
157                continue;
158            }
159            packet_count += 1;
160            trace!(
161                "[DirectAudioLoader] Decoding packet {} (track_id={})",
162                packet_count, track_id
163            );
164            let decoded = decoder.decode(&packet).map_err(|e| {
165                warn!("[DirectAudioLoader] Failed to decode packet: {}", e);
166                SubXError::audio_processing(format!("Decode error: {}", e))
167            })?;
168            // Create a sample buffer for this packet using its signal spec and capacity.
169            let spec = *decoded.spec();
170            let mut sample_buf = SampleBuffer::<i16>::new(decoded.capacity() as u64, spec);
171            sample_buf.copy_interleaved_ref(decoded);
172            let sample_len = sample_buf.samples().len();
173            trace!(
174                "[DirectAudioLoader] Packet decoded successfully, got {} samples",
175                sample_len
176            );
177            samples.extend_from_slice(sample_buf.samples());
178            // Directly record the timestamp of the last packet
179            last_pts = packet.ts;
180        }
181        debug!(
182            "[DirectAudioLoader] Packet decoding finished, total {} packets, {} samples accumulated",
183            packet_count,
184            samples.len()
185        );
186
187        // Calculate total samples and audio duration
188        let total_samples = samples.len();
189        // Use Timebase to calculate duration_seconds
190        let duration_seconds = if let Some(tb) = time_base {
191            if last_pts > 0 {
192                let (num, den) = (tb.numer, tb.denom);
193                last_pts as f64 * num as f64 / den as f64
194            } else {
195                total_samples as f64 / (sample_rate as f64 * channels.unwrap_or(1) as f64)
196            }
197        } else {
198            total_samples as f64 / (sample_rate as f64 * channels.unwrap_or(1) as f64)
199        };
200        // If channels is None, try to infer channel count from duration_seconds
201        let channels = channels.unwrap_or_else(|| {
202            let ch = if duration_seconds > 0.0 {
203                (total_samples as f64 / (sample_rate as f64 * duration_seconds)).round() as u16
204            } else {
205                1
206            };
207            debug!("[DirectAudioLoader] Inferred channel count: {}", ch);
208            ch
209        });
210        debug!(
211            "[DirectAudioLoader] Audio info: sample_rate={}, channels={}, duration_seconds={:.3}, total_samples={}",
212            sample_rate, channels, duration_seconds, total_samples
213        );
214
215        Ok((
216            samples,
217            AudioInfo {
218                sample_rate,
219                channels,
220                duration_seconds,
221                total_samples,
222            },
223        ))
224    }
225}