visqol_rs/
vad_patch_creator.rs1use crate::patch_creator::PatchCreator;
2use crate::visqol_error::VisqolError;
3use crate::{analysis_window::AnalysisWindow, audio_signal::AudioSignal, math_utils, rms_vad};
4use itertools::Itertools;
5use ndarray::{s, Array2};
/// Patch creator that keeps only the reference patches in which voice
/// activity was detected.
#[derive(Debug, Clone, PartialEq)]
pub struct VadPatchCreator {
    /// Number of consecutive frames (spectrogram columns) that make up one patch.
    patch_size: usize,
    /// Minimum sum of the voice-activity results of a patch's frames for the
    /// patch to be kept.
    frames_with_va_threshold: f64,
}
11
12impl PatchCreator for VadPatchCreator {
13 fn create_ref_patch_indices(
14 &self,
15 spectrogram: &Array2<f64>,
16 ref_signal: &AudioSignal,
17 window: &AnalysisWindow,
18 ) -> Result<std::vec::Vec<usize>, VisqolError> {
19 let norm_mat = math_utils::normalize_signal(&ref_signal.data_matrix);
20 let norm_sig = AudioSignal::new(
21 norm_mat
22 .as_slice()
23 .expect("Failed to create AudioSignal from slice!"),
24 ref_signal.sample_rate,
25 );
26
27 let frame_size = (window.size as f64 * window.overlap) as usize;
28 let patch_sample_length = self.patch_size * frame_size;
29 let spectrum_length = spectrogram.ncols();
30 let first_patch_idx = self.patch_size / 2 - 1;
31 let patch_count = (spectrum_length - first_patch_idx) / self.patch_size;
32 let total_sample_count = patch_count * patch_sample_length;
33
34 let mut ref_patch_indices = Vec::<usize>::with_capacity(patch_count);
35
36 let vad_result = self.get_voice_activity(
39 norm_sig
40 .data_matrix
41 .as_slice()
42 .ok_or(VisqolError::FailedToComputeVad)?,
43 first_patch_idx,
44 total_sample_count,
45 frame_size,
46 );
47
48 let mut patch_idx = first_patch_idx;
49
50 for patch in &vad_result.iter().chunks(self.patch_size) {
51 let frames_with_va = patch.sum::<f64>();
52
53 if frames_with_va >= self.frames_with_va_threshold {
54 ref_patch_indices.push(patch_idx);
55 }
56 patch_idx += self.patch_size;
57 }
58
59 Ok(ref_patch_indices)
60 }
61
62 fn create_patches_from_indices(
63 &self,
64 spectrogram: &Array2<f64>,
65 patch_indices: &[usize],
66 ) -> Vec<Array2<f64>> {
67 let mut patches = Vec::<Array2<f64>>::with_capacity(patch_indices.len());
68
69 let mut patch: Array2<f64>;
70
71 let mut end_col: usize;
72 for start_col in patch_indices {
73 end_col = start_col + self.patch_size;
74 patch = spectrogram.slice(s![.., *start_col..end_col]).to_owned();
75 patches.push(patch);
76 }
77 patches
78 }
79}
80
81impl VadPatchCreator {
82 pub fn new(patch_size: usize) -> Self {
84 Self {
85 patch_size,
86 frames_with_va_threshold: 1.0,
87 }
88 }
89
90 pub fn get_voice_activity(
92 &self,
93 signal: &[f64],
94 start_sample: usize,
95 total_samples: usize,
96 frame_length: usize,
97 ) -> Vec<f64> {
98 let mut vad = rms_vad::RmsVad::default();
99
100 let patch = &signal[start_sample..start_sample + total_samples];
101
102 let mut frame = Vec::<i16>::with_capacity(frame_length);
103 for patch_element in patch {
104 let mut scaled_val = ((*patch_element * ((1 << 15) as f64)) as i16) as f64;
105 scaled_val = (-(1 << 15) as f64)
106 .max(1.0 * ((1 << 15) - 1) as f64)
107 .min(scaled_val);
108 frame.push(scaled_val as i16);
109
110 if frame.len() == frame_length {
111 vad.process_chunk(&frame);
112 frame.clear();
113 }
114 }
115 vad.get_vad_results()
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        analysis_window::AnalysisWindow, audio_utils::load_as_mono, constants::NUM_BANDS_SPEECH,
        gammatone_filterbank::GammatoneFilterbank,
        gammatone_spectrogram_builder::GammatoneSpectrogramBuilder, patch_creator::PatchCreator,
        spectrogram_builder::SpectrogramBuilder,
    };

    /// The detector must report the expected number of results for a section
    /// of the clean speech sample.
    #[test]
    fn clean_speech_vad() {
        const K_START_SAMPLE: usize = 14;
        const K_TOTAL_SAMPLE: usize = 115200;
        const K_FRAME_LEN: usize = 480;
        const K_CA01_01_VAD_RES_COUNT: usize = 240;

        let speech = load_as_mono("test_data/clean_speech/CA01_01.wav").unwrap();
        let samples = speech.data_matrix.as_slice().unwrap();

        let creator = VadPatchCreator::new(20);
        let activity =
            creator.get_voice_activity(samples, K_START_SAMPLE, K_TOTAL_SAMPLE, K_FRAME_LEN);

        assert_eq!(K_CA01_01_VAD_RES_COUNT, activity.len());
    }

    /// The reference patch indices computed for the clean speech sample must
    /// match the known-good list.
    #[test]
    fn patch_indices() {
        const _K_MINIMUM_FREQ: f64 = 50.0;
        const K_PATCH_SIZE: usize = 20;

        let expected_patches: Vec<usize> = vec![9, 29, 49, 69, 89];
        let speech = load_as_mono("test_data/clean_speech/CA01_01.wav").unwrap();

        let filterbank = GammatoneFilterbank::<NUM_BANDS_SPEECH>::new(50.0);
        let mut spectrogram_builder: GammatoneSpectrogramBuilder<NUM_BANDS_SPEECH> =
            GammatoneSpectrogramBuilder::new(filterbank);
        let window = AnalysisWindow::new(speech.sample_rate, 0.25, 0.08);

        let spectrogram = spectrogram_builder.build(&speech, &window).unwrap();

        let creator = VadPatchCreator::new(K_PATCH_SIZE);
        let actual_patches = creator
            .create_ref_patch_indices(&spectrogram.data, &speech, &window)
            .unwrap();

        assert_eq!(actual_patches.len(), expected_patches.len());
        for (actual, expected) in actual_patches.iter().zip(expected_patches.iter()) {
            assert_eq!(actual, expected);
        }
    }
}