subx_cli/services/vad/
audio_processor.rs1use super::AudioInfo;
2use crate::{Result, error::SubXError};
3use hound::{SampleFormat, WavReader};
4use rubato::{Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType};
5use std::fs::File;
6use std::io::BufReader;
7use std::path::Path;
8
/// Loads audio from disk and converts it into the format expected by voice-activity detection.
///
/// The configured sample rate decides whether `load_and_prepare_audio` resamples the input.
#[derive(Debug, Clone)]
pub struct VadAudioProcessor {
    /// Sample rate, in Hz, that loaded audio is resampled to.
    target_sample_rate: u32,
    /// Requested number of output channels.
    ///
    /// NOTE(review): this value is stored but not read by the methods visible in this
    /// file; the pipeline always downmixes to mono — confirm that is intended.
    target_channels: u16,
}
17
18#[derive(Debug)]
23pub struct ProcessedAudioData {
24 pub samples: Vec<i16>,
26 pub info: AudioInfo,
28}
29
30impl VadAudioProcessor {
31 pub fn new(target_sample_rate: u32, target_channels: u16) -> Result<Self> {
42 Ok(Self {
43 target_sample_rate,
44 target_channels,
45 })
46 }
47
48 pub async fn load_and_prepare_audio(&self, audio_path: &Path) -> Result<ProcessedAudioData> {
70 let raw_audio_data = self.load_wav_file(audio_path)?;
72
73 let resampled_data = if raw_audio_data.info.sample_rate != self.target_sample_rate {
75 self.resample_audio(&raw_audio_data)?
76 } else {
77 raw_audio_data
78 };
79
80 let mono_data = if resampled_data.info.channels > 1 {
82 self.convert_to_mono(&resampled_data)?
83 } else {
84 resampled_data
85 };
86
87 Ok(mono_data)
88 }
89
90 fn load_wav_file(&self, path: &Path) -> Result<ProcessedAudioData> {
91 let file = File::open(path).map_err(|e| {
92 SubXError::audio_processing(format!("Failed to open audio file: {}", e))
93 })?;
94
95 let reader = WavReader::new(BufReader::new(file))
96 .map_err(|e| SubXError::audio_processing(format!("Failed to read WAV file: {}", e)))?;
97
98 let spec = reader.spec();
99 let sample_rate = spec.sample_rate;
100 let channels = spec.channels;
101
102 let samples: Vec<i16> = match spec.sample_format {
104 SampleFormat::Int => match spec.bits_per_sample {
105 16 => {
106 let samples: std::result::Result<Vec<i16>, hound::Error> =
107 reader.into_samples::<i16>().collect();
108 samples.map_err(|e| {
109 SubXError::audio_processing(format!("Failed to read samples: {}", e))
110 })?
111 }
112 32 => {
113 let samples: std::result::Result<Vec<i32>, hound::Error> =
114 reader.into_samples::<i32>().collect();
115 let i32_samples = samples.map_err(|e| {
116 SubXError::audio_processing(format!("Failed to read i32 samples: {}", e))
117 })?;
118 i32_samples.iter().map(|&s| (s >> 16) as i16).collect()
119 }
120 _ => {
121 return Err(SubXError::audio_processing(format!(
122 "Unsupported bit depth: {}",
123 spec.bits_per_sample
124 )));
125 }
126 },
127 SampleFormat::Float => {
128 let samples: std::result::Result<Vec<f32>, hound::Error> =
129 reader.into_samples::<f32>().collect();
130 let f32_samples = samples.map_err(|e| {
131 SubXError::audio_processing(format!("Failed to read f32 samples: {}", e))
132 })?;
133 f32_samples.iter().map(|&s| (s * 32767.0) as i16).collect()
134 }
135 };
136
137 let samples_len = samples.len();
138 let duration_seconds = samples_len as f64 / (sample_rate as f64 * channels as f64);
139
140 Ok(ProcessedAudioData {
141 samples,
142 info: AudioInfo {
143 sample_rate,
144 channels,
145 duration_seconds,
146 total_samples: samples_len,
147 },
148 })
149 }
150
151 fn resample_audio(&self, audio_data: &ProcessedAudioData) -> Result<ProcessedAudioData> {
152 if audio_data.info.sample_rate == self.target_sample_rate {
153 return Ok(ProcessedAudioData {
155 samples: audio_data.samples.clone(),
156 info: audio_data.info.clone(),
157 });
158 }
159
160 let params = SincInterpolationParameters {
162 sinc_len: 256,
163 f_cutoff: 0.95,
164 interpolation: SincInterpolationType::Linear,
165 oversampling_factor: 128,
166 window: rubato::WindowFunction::BlackmanHarris2,
167 };
168
169 let mut resampler = SincFixedIn::<f64>::new(
171 self.target_sample_rate as f64 / audio_data.info.sample_rate as f64,
172 2.0, params,
174 audio_data.samples.len(),
175 audio_data.info.channels as usize,
176 )
177 .map_err(|e| SubXError::audio_processing(format!("Failed to create resampler: {}", e)))?;
178
179 let input_channels = if audio_data.info.channels == 1 {
181 vec![
182 audio_data
183 .samples
184 .iter()
185 .map(|&s| s as f64 / 32768.0)
186 .collect(),
187 ]
188 } else {
189 let mut channels = vec![Vec::new(); audio_data.info.channels as usize];
191 for (i, &sample) in audio_data.samples.iter().enumerate() {
192 channels[i % audio_data.info.channels as usize].push(sample as f64 / 32768.0);
193 }
194 channels
195 };
196
197 let output_channels = resampler
199 .process(&input_channels, None)
200 .map_err(|e| SubXError::audio_processing(format!("Resampling failed: {}", e)))?;
201
202 let mut resampled_samples = Vec::new();
204 if audio_data.info.channels == 1 {
205 resampled_samples = output_channels[0]
206 .iter()
207 .map(|&s| (s * 32767.0) as i16)
208 .collect();
209 } else {
210 let max_len = output_channels.iter().map(|ch| ch.len()).max().unwrap_or(0);
212 for i in 0..max_len {
213 for ch in &output_channels {
214 if i < ch.len() {
215 resampled_samples.push((ch[i] * 32767.0) as i16);
216 }
217 }
218 }
219 }
220
221 let samples_len = resampled_samples.len();
222 let duration_seconds =
223 samples_len as f64 / (self.target_sample_rate as f64 * audio_data.info.channels as f64);
224
225 Ok(ProcessedAudioData {
226 samples: resampled_samples,
227 info: AudioInfo {
228 sample_rate: self.target_sample_rate,
229 channels: audio_data.info.channels,
230 duration_seconds,
231 total_samples: samples_len,
232 },
233 })
234 }
235
236 fn convert_to_mono(&self, audio_data: &ProcessedAudioData) -> Result<ProcessedAudioData> {
237 if audio_data.info.channels == 1 {
238 return Ok(ProcessedAudioData {
239 samples: audio_data.samples.clone(),
240 info: audio_data.info.clone(),
241 });
242 }
243
244 let channels = audio_data.info.channels as usize;
245 let mut mono_samples = Vec::new();
246
247 for chunk in audio_data.samples.chunks_exact(channels) {
249 let sum: i32 = chunk.iter().map(|&s| s as i32).sum();
250 let average = (sum / channels as i32) as i16;
251 mono_samples.push(average);
252 }
253
254 let samples_len = mono_samples.len();
255 let duration_seconds = samples_len as f64 / audio_data.info.sample_rate as f64;
256
257 Ok(ProcessedAudioData {
258 samples: mono_samples,
259 info: AudioInfo {
260 sample_rate: audio_data.info.sample_rate,
261 channels: 1,
262 duration_seconds,
263 total_samples: samples_len,
264 },
265 })
266 }
267}