visqol_rs/
vad_patch_creator.rs

1use crate::patch_creator::PatchCreator;
2use crate::visqol_error::VisqolError;
3use crate::{analysis_window::AnalysisWindow, audio_signal::AudioSignal, math_utils, rms_vad};
4use itertools::Itertools;
5use ndarray::{s, Array2};
/// Computes patch indices from a spectrogram by analyzing voice activity in the time domain
/// and rejecting patches which are considered silent.
#[derive(Debug, Clone)]
pub struct VadPatchCreator {
    /// Number of spectrogram columns (frames) in each patch.
    patch_size: usize,
    /// Minimum sum of a patch's VAD results for the patch to be kept.
    frames_with_va_threshold: f64,
}
11
12impl PatchCreator for VadPatchCreator {
13    fn create_ref_patch_indices(
14        &self,
15        spectrogram: &Array2<f64>,
16        ref_signal: &AudioSignal,
17        window: &AnalysisWindow,
18    ) -> Result<std::vec::Vec<usize>, VisqolError> {
19        let norm_mat = math_utils::normalize_signal(&ref_signal.data_matrix);
20        let norm_sig = AudioSignal::new(
21            norm_mat
22                .as_slice()
23                .expect("Failed to create AudioSignal from slice!"),
24            ref_signal.sample_rate,
25        );
26
27        let frame_size = (window.size as f64 * window.overlap) as usize;
28        let patch_sample_length = self.patch_size * frame_size;
29        let spectrum_length = spectrogram.ncols();
30        let first_patch_idx = self.patch_size / 2 - 1;
31        let patch_count = (spectrum_length - first_patch_idx) / self.patch_size;
32        let total_sample_count = patch_count * patch_sample_length;
33
34        let mut ref_patch_indices = Vec::<usize>::with_capacity(patch_count);
35
36        // Pass the reference signal to the VAD to determine which frames have voice
37        // activity.
38        let vad_result = self.get_voice_activity(
39            norm_sig
40                .data_matrix
41                .as_slice()
42                .ok_or(VisqolError::FailedToComputeVad)?,
43            first_patch_idx,
44            total_sample_count,
45            frame_size,
46        );
47
48        let mut patch_idx = first_patch_idx;
49
50        for patch in &vad_result.iter().chunks(self.patch_size) {
51            let frames_with_va = patch.sum::<f64>();
52
53            if frames_with_va >= self.frames_with_va_threshold {
54                ref_patch_indices.push(patch_idx);
55            }
56            patch_idx += self.patch_size;
57        }
58
59        Ok(ref_patch_indices)
60    }
61
62    fn create_patches_from_indices(
63        &self,
64        spectrogram: &Array2<f64>,
65        patch_indices: &[usize],
66    ) -> Vec<Array2<f64>> {
67        let mut patches = Vec::<Array2<f64>>::with_capacity(patch_indices.len());
68
69        let mut patch: Array2<f64>;
70
71        let mut end_col: usize;
72        for start_col in patch_indices {
73            end_col = start_col + self.patch_size;
74            patch = spectrogram.slice(s![.., *start_col..end_col]).to_owned();
75            patches.push(patch);
76        }
77        patches
78    }
79}
80
81impl VadPatchCreator {
82    /// Creates a new `VadPatchCreator` with the desired patch size.
83    pub fn new(patch_size: usize) -> Self {
84        Self {
85            patch_size,
86            frames_with_va_threshold: 1.0,
87        }
88    }
89
90    /// Given a time domain signal, this function returns a vector with 1s indicating voice acitivity and 0s indicating the absence of acitivity.
91    pub fn get_voice_activity(
92        &self,
93        signal: &[f64],
94        start_sample: usize,
95        total_samples: usize,
96        frame_length: usize,
97    ) -> Vec<f64> {
98        let mut vad = rms_vad::RmsVad::default();
99
100        let patch = &signal[start_sample..start_sample + total_samples];
101
102        let mut frame = Vec::<i16>::with_capacity(frame_length);
103        for patch_element in patch {
104            let mut scaled_val = ((*patch_element * ((1 << 15) as f64)) as i16) as f64;
105            scaled_val = (-(1 << 15) as f64)
106                .max(1.0 * ((1 << 15) - 1) as f64)
107                .min(scaled_val);
108            frame.push(scaled_val as i16);
109
110            if frame.len() == frame_length {
111                vad.process_chunk(&frame);
112                frame.clear();
113            }
114        }
115        vad.get_vad_results()
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis_window::AnalysisWindow;
    use crate::audio_utils::load_as_mono;
    use crate::constants::NUM_BANDS_SPEECH;
    use crate::gammatone_filterbank::GammatoneFilterbank;
    use crate::gammatone_spectrogram_builder::GammatoneSpectrogramBuilder;
    use crate::patch_creator::PatchCreator;
    use crate::spectrogram_builder::SpectrogramBuilder;

    /// Clean-speech reference recording shared by the tests below.
    const CA01_01_PATH: &str = "test_data/clean_speech/CA01_01.wav";

    #[test]
    fn clean_speech_vad() {
        const K_START_SAMPLE: usize = 14;
        const K_TOTAL_SAMPLE: usize = 115200;
        const K_FRAME_LEN: usize = 480;
        // One VAD result per complete frame: 115200 / 480.
        const K_CA01_01_VAD_RES_COUNT: usize = 240;

        let ref_signal = load_as_mono(CA01_01_PATH).unwrap();

        let vad = VadPatchCreator::new(20);
        let res = vad.get_voice_activity(
            ref_signal.data_matrix.as_slice().unwrap(),
            K_START_SAMPLE,
            K_TOTAL_SAMPLE,
            K_FRAME_LEN,
        );
        assert_eq!(K_CA01_01_VAD_RES_COUNT, res.len());
    }

    #[test]
    fn incomplete_trailing_frame_is_ignored() {
        let signal = vec![0.0; 1000];
        let vad = VadPatchCreator::new(20);
        // 1000 samples hold two complete 480-sample frames; the last 40 samples are dropped.
        let res = vad.get_voice_activity(&signal, 0, 1000, 480);
        assert_eq!(res.len(), 2);
    }

    #[test]
    fn patches_are_column_slices() {
        let spectrogram = Array2::<f64>::from_shape_fn((2, 6), |(r, c)| (r * 10 + c) as f64);
        let creator = VadPatchCreator::new(2);
        let patches = creator.create_patches_from_indices(&spectrogram, &[0, 3]);
        assert_eq!(patches.len(), 2);
        assert_eq!(patches[0], ndarray::array![[0.0, 1.0], [10.0, 11.0]]);
        assert_eq!(patches[1], ndarray::array![[3.0, 4.0], [13.0, 14.0]]);
    }

    #[test]
    fn patch_indices() {
        const K_MINIMUM_FREQ: f64 = 50.0;
        const K_PATCH_SIZE: usize = 20;

        let expected_patches: Vec<usize> = vec![9, 29, 49, 69, 89];
        let ref_signal = load_as_mono(CA01_01_PATH).unwrap();

        let mut spectrogram_builder: GammatoneSpectrogramBuilder<NUM_BANDS_SPEECH> =
            GammatoneSpectrogramBuilder::new(GammatoneFilterbank::<NUM_BANDS_SPEECH>::new(
                K_MINIMUM_FREQ,
            ));
        let window = AnalysisWindow::new(ref_signal.sample_rate, 0.25, 0.08);

        let spectrogram = spectrogram_builder.build(&ref_signal, &window).unwrap();

        let vad = VadPatchCreator::new(K_PATCH_SIZE);
        let patches = vad
            .create_ref_patch_indices(&spectrogram.data, &ref_signal, &window)
            .unwrap();

        assert_eq!(patches, expected_patches);
    }
}