subx_cli/services/audio/
transcoder.rs

1//! Audio transcoding service: Multi-format to WAV conversion based on Symphonia.
2
3use crate::{Result, error::SubXError};
4use hound::{SampleFormat, WavSpec, WavWriter};
5use log::warn;
6use std::fs::{self, File};
7use std::path::{Path, PathBuf};
8use std::time::Duration;
9use symphonia::core::{
10    audio::{Layout, SampleBuffer},
11    codecs::CODEC_TYPE_NULL,
12    io::MediaSourceStream,
13};
14use symphonia::core::{codecs::CodecRegistry, probe::Probe};
15use symphonia::default::{get_codecs, get_probe};
16use tempfile::TempDir;
/// Audio transcoder: Detects file format and converts non-WAV files to WAV.
///
/// Transcoded output is written into a temporary directory owned by this
/// instance; `cleanup` removes that directory explicitly.
pub struct AudioTranscoder {
    /// Temporary directory for storing transcoding results
    temp_dir: TempDir,
    /// Format probe used to detect the container format of an input file
    /// (taken from `symphonia::default::get_probe` in `new`).
    probe: &'static Probe,
    /// Codec registry used to build a decoder for the selected track
    /// (taken from `symphonia::default::get_codecs` in `new`).
    codecs: &'static CodecRegistry,
}
24
/// Audio transcoding statistics
///
/// A plain bag of counters filled in while packets are read and decoded.
/// All counters start at zero, both via [`TranscodeStats::new`] and via
/// [`Default`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TranscodeStats {
    /// Total number of packets processed
    pub total_packets: u64,
    /// Number of successfully decoded packets
    pub decoded_packets: u64,
    /// Number of packets skipped due to DecodeError
    pub skipped_decode_errors: u64,
    /// Number of packets skipped due to IoError
    pub skipped_io_errors: u64,
    /// Number of times reset was required
    pub reset_required_count: u64,
}

impl TranscodeStats {
    /// Create new statistics with every counter set to zero.
    ///
    /// Equivalent to `TranscodeStats::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Calculate decoding success rate
    ///
    /// Returns `decoded_packets / total_packets` as a fraction, or `0.0`
    /// when no packets have been counted (avoids a division by zero).
    pub fn success_rate(&self) -> f64 {
        if self.total_packets == 0 {
            return 0.0;
        }
        self.decoded_packets as f64 / self.total_packets as f64
    }
}
60
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::TempDir;

    /// Write a one-sample, mono, 16-bit, 44.1 kHz WAV file named `test.wav`
    /// into `dir` and return its path.
    fn create_minimal_wav_file(dir: &TempDir) -> PathBuf {
        let wav_path = dir.path().join("test.wav");
        let mut wav = WavWriter::create(
            &wav_path,
            WavSpec {
                channels: 1,
                sample_rate: 44100,
                bits_per_sample: 16,
                sample_format: SampleFormat::Int,
            },
        )
        .unwrap();
        wav.write_sample(0i16).unwrap();
        wav.finalize().unwrap();
        wav_path
    }

    /// Non-`wav` extensions require transcoding; `wav` does not.
    #[test]
    fn test_needs_transcoding() {
        let transcoder = AudioTranscoder::new().expect("Failed to create transcoder");
        for name in ["test.mp4", "test.MKV", "test.ogg"] {
            assert!(transcoder.needs_transcoding(name).unwrap());
        }
        assert!(!transcoder.needs_transcoding("test.wav").unwrap());
    }

    /// Round-trip a minimal WAV through the transcoder and check that a
    /// non-empty `.wav` file comes out. Ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_transcode_wav_to_wav() {
        let transcoder = AudioTranscoder::new().expect("Failed to create transcoder");
        let temp_dir = TempDir::new().unwrap();
        let source = create_minimal_wav_file(&temp_dir);

        let out_path = transcoder
            .transcode_to_wav(&source)
            .await
            .expect("Transcode failed");

        let extension = out_path.extension().and_then(|e| e.to_str());
        assert_eq!(extension, Some("wav"));

        let size = std::fs::metadata(&out_path)
            .expect("Failed to stat output file")
            .len();
        assert!(size > 0, "Output WAV file should not be empty");
    }

    /// The success rate is 0.0 with no packets and decoded/total otherwise.
    #[test]
    fn test_transcode_stats_success_rate() {
        let mut stats = TranscodeStats::new();
        assert_eq!(stats.success_rate(), 0.0);

        stats.total_packets = 10;
        stats.decoded_packets = 7;

        let rate = stats.success_rate();
        let deviation = (rate - 0.7).abs();
        assert!(deviation < f64::EPSILON, "Expected 0.7, got {}", rate);
    }
}
120
121impl AudioTranscoder {
122    /// Create a new AudioTranscoder instance and initialize temporary directory.
123    pub fn new() -> Result<Self> {
124        let temp_dir = TempDir::new().map_err(|e| {
125            SubXError::audio_processing(format!("Failed to create temp dir: {}", e))
126        })?;
127        let probe = get_probe();
128        let codecs = get_codecs();
129        Ok(Self {
130            temp_dir,
131            probe,
132            codecs,
133        })
134    }
135
136    /// Check if the audio file at the specified path needs transcoding (based on file extension).
137    pub fn needs_transcoding<P: AsRef<Path>>(&self, audio_path: P) -> Result<bool> {
138        if let Some(ext) = audio_path.as_ref().extension().and_then(|s| s.to_str()) {
139            let ext_lc = ext.to_lowercase();
140            if ext_lc == "wav" { Ok(false) } else { Ok(true) }
141        } else {
142            Err(SubXError::audio_processing(
143                "Missing file extension".to_string(),
144            ))
145        }
146    }
147
148    /// Actively clean up temporary directory
149    pub fn cleanup(self) -> Result<()> {
150        self.temp_dir
151            .close()
152            .map_err(|e| SubXError::audio_processing(format!("Failed to clean temp dir: {}", e)))?;
153        Ok(())
154    }
155}
156
157impl AudioTranscoder {
158    /// Audio transcoding method with configuration, allowing specification of minimum success rate
159    pub async fn transcode_to_wav_with_config<P: AsRef<Path>>(
160        &self,
161        input_path: P,
162        min_success_rate: Option<f64>,
163    ) -> Result<(PathBuf, TranscodeStats)> {
164        use symphonia::core::errors::Error as SymphoniaError;
165
166        let input = input_path.as_ref();
167        let min_success_rate = min_success_rate.unwrap_or(0.5);
168        let mut stats = TranscodeStats::new();
169
170        let file = File::open(input).map_err(|e| {
171            SubXError::audio_processing(format!(
172                "Failed to open input file {}: {}",
173                input.display(),
174                e
175            ))
176        })?;
177        let mss = MediaSourceStream::new(Box::new(file), Default::default());
178
179        let probed = self
180            .probe
181            .format(
182                &Default::default(),
183                mss,
184                &Default::default(),
185                &Default::default(),
186            )
187            .map_err(|e| SubXError::audio_processing(format!("Format probe error: {}", e)))?;
188        let mut format = probed.format;
189
190        let track = format
191            .tracks()
192            .iter()
193            .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
194            .ok_or_else(|| SubXError::audio_processing("No audio track found".to_string()))?;
195
196        let mut decoder = self
197            .codecs
198            .make(&track.codec_params, &Default::default())
199            .map_err(|e| SubXError::audio_processing(format!("Decoder error: {}", e)))?;
200
201        let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
202        let layout = track.codec_params.channel_layout.unwrap_or(Layout::Stereo);
203        let channels = layout.into_channels().count() as u16;
204        let spec = WavSpec {
205            channels,
206            sample_rate,
207            bits_per_sample: 16,
208            sample_format: SampleFormat::Int,
209        };
210
211        let wav_path = self
212            .temp_dir
213            .path()
214            .join(input.file_stem().unwrap_or_default())
215            .with_extension("wav");
216        let mut writer = WavWriter::create(&wav_path, spec)
217            .map_err(|e| SubXError::audio_processing(format!("WAV writer error: {}", e)))?;
218
219        loop {
220            stats.total_packets += 1;
221            match format.next_packet() {
222                Ok(packet) => match decoder.decode(&packet) {
223                    Ok(audio_buf) => {
224                        stats.decoded_packets += 1;
225
226                        let mut sample_buf = SampleBuffer::<i16>::new(
227                            audio_buf.capacity() as u64,
228                            *audio_buf.spec(),
229                        );
230                        sample_buf.copy_interleaved_ref(audio_buf);
231                        for sample in sample_buf.samples() {
232                            writer.write_sample(*sample).map_err(|e| {
233                                SubXError::audio_processing(format!("Write sample error: {}", e))
234                            })?;
235                        }
236                    }
237                    Err(SymphoniaError::DecodeError(decode_err)) => {
238                        warn!(
239                            "Decode error (recoverable), skipping packet: {}",
240                            decode_err
241                        );
242                        stats.skipped_decode_errors += 1;
243                        continue;
244                    }
245                    Err(SymphoniaError::IoError(io_err)) => {
246                        warn!("I/O error (recoverable), skipping packet: {}", io_err);
247                        stats.skipped_io_errors += 1;
248                        continue;
249                    }
250                    Err(SymphoniaError::ResetRequired) => {
251                        warn!("Decoder reset required, audio specs may change");
252                        stats.reset_required_count += 1;
253                        continue;
254                    }
255                    Err(other) => {
256                        return Err(SubXError::audio_processing(format!(
257                            "Unrecoverable decode error: {}",
258                            other
259                        )));
260                    }
261                },
262                Err(SymphoniaError::IoError(err))
263                    if err.kind() == std::io::ErrorKind::UnexpectedEof =>
264                {
265                    break;
266                }
267                Err(e) => {
268                    return Err(SubXError::audio_processing(format!(
269                        "Packet read error: {}",
270                        e
271                    )));
272                }
273            }
274        }
275
276        writer
277            .finalize()
278            .map_err(|e| SubXError::audio_processing(format!("Finalize WAV error: {}", e)))?;
279
280        if stats.success_rate() < min_success_rate {
281            warn!(
282                "Final decode success rate ({:.1}%) is below minimum threshold ({:.1}%)",
283                stats.success_rate() * 100.0,
284                min_success_rate * 100.0
285            );
286        }
287
288        if stats.total_packets > 10 && stats.success_rate() < min_success_rate {
289            return Err(SubXError::audio_processing(format!(
290                "Decode success rate ({:.1}%) below minimum threshold ({:.1}%), output quality unacceptable",
291                stats.success_rate() * 100.0,
292                min_success_rate * 100.0
293            )));
294        }
295
296        Ok((wav_path, stats))
297    }
298
299    /// Transcode input audio file to WAV and save to temporary directory (backward compatibility).
300    pub async fn transcode_to_wav<P: AsRef<Path>>(&self, input_path: P) -> Result<PathBuf> {
301        let (path, stats) = self.transcode_to_wav_with_config(input_path, None).await?;
302        if stats.success_rate() < 0.8 {
303            warn!(
304                "Low decode success rate ({:.1}%), output quality may be affected",
305                stats.success_rate() * 100.0
306            );
307        }
308        Ok(path)
309    }
310
311    /// Extract audio segment for specified time range and convert to WAV.
312    ///
313    /// # Arguments
314    ///
315    /// * `input` - Input audio file path
316    /// * `output` - Output WAV file path
317    /// * `_start` - Start time for extraction (currently unused)
318    /// * `_end` - End time for extraction (currently unused)
319    pub async fn extract_segment<P: AsRef<Path>, Q: AsRef<Path>>(
320        &self,
321        input: P,
322        output: Q,
323        _start: Duration,
324        _end: Duration,
325    ) -> Result<()> {
326        let temp = self.transcode_to_wav(input).await?;
327        fs::copy(&temp, output.as_ref()).map_err(|e| {
328            SubXError::audio_processing(format!("Failed to extract segment: {}", e))
329        })?;
330        Ok(())
331    }
332
333    /// Transcode audio to specified format (WAV), ignoring parameters.
334    ///
335    /// # Arguments
336    ///
337    /// * `input` - Input audio file path
338    /// * `output` - Output file path
339    /// * `_sample_rate` - Target sample rate (currently unused)
340    /// * `_channels` - Target channel count (currently unused)
341    pub async fn transcode_to_format<P: AsRef<Path>, Q: AsRef<Path>>(
342        &self,
343        input: P,
344        output: Q,
345        _sample_rate: u32,
346        _channels: u32,
347    ) -> Result<()> {
348        let temp = self.transcode_to_wav(input).await?;
349        fs::copy(&temp, output.as_ref()).map_err(|e| {
350            SubXError::audio_processing(format!("Failed to transcode format: {}", e))
351        })?;
352        Ok(())
353    }
354}