1use crate::{Result, error::SubXError};
4use hound::{SampleFormat, WavSpec, WavWriter};
5use log::warn;
6use std::fs::{self, File};
7use std::path::{Path, PathBuf};
8use std::time::Duration;
9use symphonia::core::{
10 audio::{Layout, SampleBuffer},
11 codecs::CODEC_TYPE_NULL,
12 io::MediaSourceStream,
13};
14use symphonia::core::{codecs::CodecRegistry, probe::Probe};
15use symphonia::default::{get_codecs, get_probe};
16use tempfile::TempDir;
17pub struct AudioTranscoder {
19 temp_dir: TempDir,
21 probe: &'static Probe,
22 codecs: &'static CodecRegistry,
23}
24
/// Counters collected while decoding packets during a single transcode.
///
/// All fields start at zero (`Default` / `TranscodeStats::new`) and are only
/// incremented by the decode loop.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TranscodeStats {
    /// Number of packets counted by the decode loop.
    pub total_packets: u64,
    /// Packets the decoder turned into audio successfully.
    pub decoded_packets: u64,
    /// Packets dropped because the decoder reported a recoverable decode error.
    pub skipped_decode_errors: u64,
    /// Packets dropped because the decoder reported an I/O error.
    pub skipped_io_errors: u64,
    /// Number of times the decoder signalled that a reset was required.
    pub reset_required_count: u64,
}
38
39impl TranscodeStats {
40 pub fn new() -> Self {
42 Self {
43 total_packets: 0,
44 decoded_packets: 0,
45 skipped_decode_errors: 0,
46 skipped_io_errors: 0,
47 reset_required_count: 0,
48 }
49 }
50
51 pub fn success_rate(&self) -> f64 {
53 if self.total_packets == 0 {
54 0.0
55 } else {
56 self.decoded_packets as f64 / self.total_packets as f64
57 }
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    use tempfile::TempDir;

    /// Writes a mono, 16-bit, 44.1 kHz WAV file containing a single silent
    /// sample into `dir` and returns its path.
    fn create_minimal_wav_file(dir: &TempDir) -> PathBuf {
        let path = dir.path().join("test.wav");
        let spec = WavSpec {
            channels: 1,
            sample_rate: 44100,
            bits_per_sample: 16,
            sample_format: SampleFormat::Int,
        };
        let mut writer = WavWriter::create(&path, spec).unwrap();
        writer.write_sample(0i16).unwrap();
        writer.finalize().unwrap();
        path
    }

    #[test]
    fn test_needs_transcoding() {
        let transcoder = AudioTranscoder::new().expect("Failed to create transcoder");
        assert!(transcoder.needs_transcoding("test.mp4").unwrap());
        assert!(transcoder.needs_transcoding("test.MKV").unwrap());
        assert!(transcoder.needs_transcoding("test.ogg").unwrap());
        assert!(!transcoder.needs_transcoding("test.wav").unwrap());
    }

    #[test]
    fn test_needs_transcoding_without_extension_is_error() {
        let transcoder = AudioTranscoder::new().expect("Failed to create transcoder");
        assert!(transcoder.needs_transcoding("no_extension").is_err());
        // Only the final path component's extension matters.
        assert!(transcoder.needs_transcoding("dir.d/no_extension").is_err());
    }

    #[test]
    fn test_cleanup_removes_temp_dir() {
        let transcoder = AudioTranscoder::new().expect("Failed to create transcoder");
        let dir = transcoder.temp_dir.path().to_path_buf();
        assert!(dir.is_dir());
        transcoder.cleanup().expect("Cleanup failed");
        assert!(!dir.exists());
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcode_wav_to_wav() {
        let transcoder = AudioTranscoder::new().expect("Failed to create transcoder");
        let temp_dir = TempDir::new().unwrap();
        let wav_path = create_minimal_wav_file(&temp_dir);
        let out_path = transcoder
            .transcode_to_wav(&wav_path)
            .await
            .expect("Transcode failed");
        assert_eq!(out_path.extension().and_then(|e| e.to_str()), Some("wav"));
        let meta = std::fs::metadata(&out_path).expect("Failed to stat output file");
        assert!(meta.len() > 0, "Output WAV file should not be empty");
    }

    #[test]
    fn test_transcode_stats_success_rate() {
        let mut stats = TranscodeStats::new();
        assert_eq!(stats.success_rate(), 0.0);
        stats.total_packets = 10;
        stats.decoded_packets = 7;
        let rate = stats.success_rate();
        assert!(
            (rate - 0.7).abs() < f64::EPSILON,
            "Expected 0.7, got {}",
            rate
        );
    }

    #[test]
    fn test_transcode_stats_new_is_zeroed_and_full_success() {
        let mut stats = TranscodeStats::new();
        assert_eq!(stats.total_packets, 0);
        assert_eq!(stats.decoded_packets, 0);
        assert_eq!(stats.skipped_decode_errors, 0);
        assert_eq!(stats.skipped_io_errors, 0);
        assert_eq!(stats.reset_required_count, 0);
        stats.total_packets = 5;
        stats.decoded_packets = 5;
        assert_eq!(stats.success_rate(), 1.0);
    }
}
120
121impl AudioTranscoder {
122 pub fn new() -> Result<Self> {
124 let temp_dir = TempDir::new().map_err(|e| {
125 SubXError::audio_processing(format!("Failed to create temp dir: {}", e))
126 })?;
127 let probe = get_probe();
128 let codecs = get_codecs();
129 Ok(Self {
130 temp_dir,
131 probe,
132 codecs,
133 })
134 }
135
136 pub fn needs_transcoding<P: AsRef<Path>>(&self, audio_path: P) -> Result<bool> {
138 if let Some(ext) = audio_path.as_ref().extension().and_then(|s| s.to_str()) {
139 let ext_lc = ext.to_lowercase();
140 if ext_lc == "wav" { Ok(false) } else { Ok(true) }
141 } else {
142 Err(SubXError::audio_processing(
143 "Missing file extension".to_string(),
144 ))
145 }
146 }
147
148 pub fn cleanup(self) -> Result<()> {
150 self.temp_dir
151 .close()
152 .map_err(|e| SubXError::audio_processing(format!("Failed to clean temp dir: {}", e)))?;
153 Ok(())
154 }
155}
156
157impl AudioTranscoder {
158 pub async fn transcode_to_wav_with_config<P: AsRef<Path>>(
160 &self,
161 input_path: P,
162 min_success_rate: Option<f64>,
163 ) -> Result<(PathBuf, TranscodeStats)> {
164 use symphonia::core::errors::Error as SymphoniaError;
165
166 let input = input_path.as_ref();
167 let min_success_rate = min_success_rate.unwrap_or(0.5);
168 let mut stats = TranscodeStats::new();
169
170 let file = File::open(input).map_err(|e| {
171 SubXError::audio_processing(format!(
172 "Failed to open input file {}: {}",
173 input.display(),
174 e
175 ))
176 })?;
177 let mss = MediaSourceStream::new(Box::new(file), Default::default());
178
179 let probed = self
180 .probe
181 .format(
182 &Default::default(),
183 mss,
184 &Default::default(),
185 &Default::default(),
186 )
187 .map_err(|e| SubXError::audio_processing(format!("Format probe error: {}", e)))?;
188 let mut format = probed.format;
189
190 let track = format
191 .tracks()
192 .iter()
193 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
194 .ok_or_else(|| SubXError::audio_processing("No audio track found".to_string()))?;
195
196 let mut decoder = self
197 .codecs
198 .make(&track.codec_params, &Default::default())
199 .map_err(|e| SubXError::audio_processing(format!("Decoder error: {}", e)))?;
200
201 let sample_rate = track.codec_params.sample_rate.unwrap_or(44100);
202 let layout = track.codec_params.channel_layout.unwrap_or(Layout::Stereo);
203 let channels = layout.into_channels().count() as u16;
204 let spec = WavSpec {
205 channels,
206 sample_rate,
207 bits_per_sample: 16,
208 sample_format: SampleFormat::Int,
209 };
210
211 let wav_path = self
212 .temp_dir
213 .path()
214 .join(input.file_stem().unwrap_or_default())
215 .with_extension("wav");
216 let mut writer = WavWriter::create(&wav_path, spec)
217 .map_err(|e| SubXError::audio_processing(format!("WAV writer error: {}", e)))?;
218
219 loop {
220 stats.total_packets += 1;
221 match format.next_packet() {
222 Ok(packet) => match decoder.decode(&packet) {
223 Ok(audio_buf) => {
224 stats.decoded_packets += 1;
225
226 let mut sample_buf = SampleBuffer::<i16>::new(
227 audio_buf.capacity() as u64,
228 *audio_buf.spec(),
229 );
230 sample_buf.copy_interleaved_ref(audio_buf);
231 for sample in sample_buf.samples() {
232 writer.write_sample(*sample).map_err(|e| {
233 SubXError::audio_processing(format!("Write sample error: {}", e))
234 })?;
235 }
236 }
237 Err(SymphoniaError::DecodeError(decode_err)) => {
238 warn!(
239 "Decode error (recoverable), skipping packet: {}",
240 decode_err
241 );
242 stats.skipped_decode_errors += 1;
243 continue;
244 }
245 Err(SymphoniaError::IoError(io_err)) => {
246 warn!("I/O error (recoverable), skipping packet: {}", io_err);
247 stats.skipped_io_errors += 1;
248 continue;
249 }
250 Err(SymphoniaError::ResetRequired) => {
251 warn!("Decoder reset required, audio specs may change");
252 stats.reset_required_count += 1;
253 continue;
254 }
255 Err(other) => {
256 return Err(SubXError::audio_processing(format!(
257 "Unrecoverable decode error: {}",
258 other
259 )));
260 }
261 },
262 Err(SymphoniaError::IoError(err))
263 if err.kind() == std::io::ErrorKind::UnexpectedEof =>
264 {
265 break;
266 }
267 Err(e) => {
268 return Err(SubXError::audio_processing(format!(
269 "Packet read error: {}",
270 e
271 )));
272 }
273 }
274 }
275
276 writer
277 .finalize()
278 .map_err(|e| SubXError::audio_processing(format!("Finalize WAV error: {}", e)))?;
279
280 if stats.success_rate() < min_success_rate {
281 warn!(
282 "Final decode success rate ({:.1}%) is below minimum threshold ({:.1}%)",
283 stats.success_rate() * 100.0,
284 min_success_rate * 100.0
285 );
286 }
287
288 if stats.total_packets > 10 && stats.success_rate() < min_success_rate {
289 return Err(SubXError::audio_processing(format!(
290 "Decode success rate ({:.1}%) below minimum threshold ({:.1}%), output quality unacceptable",
291 stats.success_rate() * 100.0,
292 min_success_rate * 100.0
293 )));
294 }
295
296 Ok((wav_path, stats))
297 }
298
299 pub async fn transcode_to_wav<P: AsRef<Path>>(&self, input_path: P) -> Result<PathBuf> {
301 let (path, stats) = self.transcode_to_wav_with_config(input_path, None).await?;
302 if stats.success_rate() < 0.8 {
303 warn!(
304 "Low decode success rate ({:.1}%), output quality may be affected",
305 stats.success_rate() * 100.0
306 );
307 }
308 Ok(path)
309 }
310
311 pub async fn extract_segment<P: AsRef<Path>, Q: AsRef<Path>>(
320 &self,
321 input: P,
322 output: Q,
323 _start: Duration,
324 _end: Duration,
325 ) -> Result<()> {
326 let temp = self.transcode_to_wav(input).await?;
327 fs::copy(&temp, output.as_ref()).map_err(|e| {
328 SubXError::audio_processing(format!("Failed to extract segment: {}", e))
329 })?;
330 Ok(())
331 }
332
333 pub async fn transcode_to_format<P: AsRef<Path>, Q: AsRef<Path>>(
342 &self,
343 input: P,
344 output: Q,
345 _sample_rate: u32,
346 _channels: u32,
347 ) -> Result<()> {
348 let temp = self.transcode_to_wav(input).await?;
349 fs::copy(&temp, output.as_ref()).map_err(|e| {
350 SubXError::audio_processing(format!("Failed to transcode format: {}", e))
351 })?;
352 Ok(())
353 }
354}