subx_cli/services/vad/
audio_processor.rs

1use super::AudioInfo;
2use crate::{Result, error::SubXError};
3use hound::{SampleFormat, WavReader};
4use rubato::{Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType};
5use std::fs::File;
6use std::io::BufReader;
7use std::path::Path;
8
/// Audio processor for VAD operations.
///
/// Handles loading, resampling, and format conversion of audio files
/// for voice activity detection processing.
///
/// The processor only carries configuration, so it is cheap to clone and
/// safe to print with `{:?}` in logs and test failures.
#[derive(Debug, Clone)]
pub struct VadAudioProcessor {
    /// Sample rate (Hz) that loaded audio is resampled to.
    target_sample_rate: u32,
    /// Requested channel count.
    ///
    /// NOTE(review): stored at construction, but nothing in this file reads
    /// it; the processing pipeline always downmixes to mono.
    target_channels: u16,
}
17
18/// Processed audio data ready for VAD analysis.
19///
20/// Contains the audio samples and metadata after processing
21/// and format conversion.
22#[derive(Debug)]
23pub struct ProcessedAudioData {
24    /// Audio samples as 16-bit integers
25    pub samples: Vec<i16>,
26    /// Audio metadata and properties
27    pub info: AudioInfo,
28}
29
30impl VadAudioProcessor {
31    /// Create a new VAD audio processor.
32    ///
33    /// # Arguments
34    ///
35    /// * `target_sample_rate` - Desired sample rate for processing
36    /// * `target_channels` - Desired number of audio channels
37    ///
38    /// # Returns
39    ///
40    /// A new `VadAudioProcessor` instance
41    pub fn new(target_sample_rate: u32, target_channels: u16) -> Result<Self> {
42        Ok(Self {
43            target_sample_rate,
44            target_channels,
45        })
46    }
47
48    /// Load and prepare audio file for VAD processing.
49    ///
50    /// Performs all necessary audio processing steps including loading,
51    /// resampling, and format conversion to prepare the audio for
52    /// voice activity detection.
53    ///
54    /// # Arguments
55    ///
56    /// * `audio_path` - Path to the audio file to process
57    ///
58    /// # Returns
59    ///
60    /// Processed audio data ready for VAD analysis
61    ///
62    /// # Errors
63    ///
64    /// Returns an error if:
65    /// - Audio file cannot be loaded
66    /// - Audio format is unsupported
67    /// - Resampling fails
68    /// - Format conversion fails
69    pub async fn load_and_prepare_audio(&self, audio_path: &Path) -> Result<ProcessedAudioData> {
70        // 1. Load audio file
71        let raw_audio_data = self.load_wav_file(audio_path)?;
72
73        // 2. Convert sample rate (if needed)
74        let resampled_data = if raw_audio_data.info.sample_rate != self.target_sample_rate {
75            self.resample_audio(&raw_audio_data)?
76        } else {
77            raw_audio_data
78        };
79
80        // 3. Convert to mono (if needed)
81        let mono_data = if resampled_data.info.channels > 1 {
82            self.convert_to_mono(&resampled_data)?
83        } else {
84            resampled_data
85        };
86
87        Ok(mono_data)
88    }
89
90    fn load_wav_file(&self, path: &Path) -> Result<ProcessedAudioData> {
91        let file = File::open(path).map_err(|e| {
92            SubXError::audio_processing(format!("Failed to open audio file: {}", e))
93        })?;
94
95        let reader = WavReader::new(BufReader::new(file))
96            .map_err(|e| SubXError::audio_processing(format!("Failed to read WAV file: {}", e)))?;
97
98        let spec = reader.spec();
99        let sample_rate = spec.sample_rate;
100        let channels = spec.channels;
101
102        // Read all samples and convert to i16
103        let samples: Vec<i16> = match spec.sample_format {
104            SampleFormat::Int => match spec.bits_per_sample {
105                16 => {
106                    let samples: std::result::Result<Vec<i16>, hound::Error> =
107                        reader.into_samples::<i16>().collect();
108                    samples.map_err(|e| {
109                        SubXError::audio_processing(format!("Failed to read samples: {}", e))
110                    })?
111                }
112                32 => {
113                    let samples: std::result::Result<Vec<i32>, hound::Error> =
114                        reader.into_samples::<i32>().collect();
115                    let i32_samples = samples.map_err(|e| {
116                        SubXError::audio_processing(format!("Failed to read i32 samples: {}", e))
117                    })?;
118                    i32_samples.iter().map(|&s| (s >> 16) as i16).collect()
119                }
120                _ => {
121                    return Err(SubXError::audio_processing(format!(
122                        "Unsupported bit depth: {}",
123                        spec.bits_per_sample
124                    )));
125                }
126            },
127            SampleFormat::Float => {
128                let samples: std::result::Result<Vec<f32>, hound::Error> =
129                    reader.into_samples::<f32>().collect();
130                let f32_samples = samples.map_err(|e| {
131                    SubXError::audio_processing(format!("Failed to read f32 samples: {}", e))
132                })?;
133                f32_samples.iter().map(|&s| (s * 32767.0) as i16).collect()
134            }
135        };
136
137        let samples_len = samples.len();
138        let duration_seconds = samples_len as f64 / (sample_rate as f64 * channels as f64);
139
140        Ok(ProcessedAudioData {
141            samples,
142            info: AudioInfo {
143                sample_rate,
144                channels,
145                duration_seconds,
146                total_samples: samples_len,
147            },
148        })
149    }
150
151    fn resample_audio(&self, audio_data: &ProcessedAudioData) -> Result<ProcessedAudioData> {
152        if audio_data.info.sample_rate == self.target_sample_rate {
153            // Cloning via struct initializer to own data
154            return Ok(ProcessedAudioData {
155                samples: audio_data.samples.clone(),
156                info: audio_data.info.clone(),
157            });
158        }
159
160        // Configure resampling parameters
161        let params = SincInterpolationParameters {
162            sinc_len: 256,
163            f_cutoff: 0.95,
164            interpolation: SincInterpolationType::Linear,
165            oversampling_factor: 128,
166            window: rubato::WindowFunction::BlackmanHarris2,
167        };
168
169        // Create resampler
170        let mut resampler = SincFixedIn::<f64>::new(
171            self.target_sample_rate as f64 / audio_data.info.sample_rate as f64,
172            2.0, // max_resample_ratio_relative
173            params,
174            audio_data.samples.len(),
175            audio_data.info.channels as usize,
176        )
177        .map_err(|e| SubXError::audio_processing(format!("Failed to create resampler: {}", e)))?;
178
179        // Convert sample format to f64
180        let input_channels = if audio_data.info.channels == 1 {
181            vec![
182                audio_data
183                    .samples
184                    .iter()
185                    .map(|&s| s as f64 / 32768.0)
186                    .collect(),
187            ]
188        } else {
189            // Process multi-channel audio
190            let mut channels = vec![Vec::new(); audio_data.info.channels as usize];
191            for (i, &sample) in audio_data.samples.iter().enumerate() {
192                channels[i % audio_data.info.channels as usize].push(sample as f64 / 32768.0);
193            }
194            channels
195        };
196
197        // Perform resampling
198        let output_channels = resampler
199            .process(&input_channels, None)
200            .map_err(|e| SubXError::audio_processing(format!("Resampling failed: {}", e)))?;
201
202        // Convert back to i16 format
203        let mut resampled_samples = Vec::new();
204        if audio_data.info.channels == 1 {
205            resampled_samples = output_channels[0]
206                .iter()
207                .map(|&s| (s * 32767.0) as i16)
208                .collect();
209        } else {
210            // Interleave multi-channel samples
211            let max_len = output_channels.iter().map(|ch| ch.len()).max().unwrap_or(0);
212            for i in 0..max_len {
213                for ch in &output_channels {
214                    if i < ch.len() {
215                        resampled_samples.push((ch[i] * 32767.0) as i16);
216                    }
217                }
218            }
219        }
220
221        let samples_len = resampled_samples.len();
222        let duration_seconds =
223            samples_len as f64 / (self.target_sample_rate as f64 * audio_data.info.channels as f64);
224
225        Ok(ProcessedAudioData {
226            samples: resampled_samples,
227            info: AudioInfo {
228                sample_rate: self.target_sample_rate,
229                channels: audio_data.info.channels,
230                duration_seconds,
231                total_samples: samples_len,
232            },
233        })
234    }
235
236    fn convert_to_mono(&self, audio_data: &ProcessedAudioData) -> Result<ProcessedAudioData> {
237        if audio_data.info.channels == 1 {
238            return Ok(ProcessedAudioData {
239                samples: audio_data.samples.clone(),
240                info: audio_data.info.clone(),
241            });
242        }
243
244        let channels = audio_data.info.channels as usize;
245        let mut mono_samples = Vec::new();
246
247        // Convert to mono (average all channels)
248        for chunk in audio_data.samples.chunks_exact(channels) {
249            let sum: i32 = chunk.iter().map(|&s| s as i32).sum();
250            let average = (sum / channels as i32) as i16;
251            mono_samples.push(average);
252        }
253
254        let samples_len = mono_samples.len();
255        let duration_seconds = samples_len as f64 / audio_data.info.sample_rate as f64;
256
257        Ok(ProcessedAudioData {
258            samples: mono_samples,
259            info: AudioInfo {
260                sample_rate: audio_data.info.sample_rate,
261                channels: 1,
262                duration_seconds,
263                total_samples: samples_len,
264            },
265        })
266    }
267}