visqol_rs/
envelope.rs

1use crate::fast_fourier_transform;
2use crate::fft_manager::FftManager;
3use ndarray::Array1;
4use num::complex::Complex64;
5
6/// Calculates the upper envelope for a given time domain signal.
7pub fn calculate_upper_env(signal: &Array1<f64>) -> Option<ndarray::Array1<f64>> {
8    let mean = signal.mean()?;
9    let mut signal_centered = signal - mean;
10    let hilbert = calculate_hilbert(signal_centered.as_slice_mut()?)?;
11
12    let mut hilbert_amplitude = Array1::<f64>::zeros(hilbert.len());
13
14    for (amplitude, h) in hilbert_amplitude.iter_mut().zip(&hilbert) {
15        *amplitude = h.norm();
16    }
17    hilbert_amplitude += mean;
18    Some(hilbert_amplitude)
19}
20
21/// Calculates the hilbert transform for a given time domain signal.
22pub fn calculate_hilbert(signal: &mut [f64]) -> Option<Array1<Complex64>> {
23    let mut fft_manager = FftManager::new(signal.len());
24    let freq_domain_signal =
25        fast_fourier_transform::forward_1d_from_matrix(&mut fft_manager, signal);
26
27    let is_odd = signal.len() % 2 == 1;
28    let is_non_empty = !signal.is_empty();
29
30    // Set up scaling vector
31    let mut hilbert_scaling = vec![0.0f64; freq_domain_signal.len()];
32    hilbert_scaling[0] = 1.0;
33
34    if !is_odd && is_non_empty {
35        hilbert_scaling[signal.len() / 2] = 1.0;
36    } else if is_odd && is_non_empty {
37        hilbert_scaling[signal.len() / 2] = 2.0;
38    }
39
40    let n = if is_odd {
41        freq_domain_signal.len().div_ceil(2)
42    } else {
43        freq_domain_signal.len() / 2
44    };
45
46    hilbert_scaling[1..n].fill(2.0);
47
48    let mut element_wise_product = Array1::<Complex64>::zeros(freq_domain_signal.len());
49
50    for i in 0..freq_domain_signal.len() {
51        element_wise_product[i] = freq_domain_signal[i] * hilbert_scaling[i];
52    }
53
54    let mut hilbert =
55        fast_fourier_transform::inverse_1d(&mut fft_manager, element_wise_product.as_slice()?);
56    hilbert
57        .iter_mut()
58        .for_each(|element| *element = *element * 2.0 - 0.000001);
59    Some(Array1::<Complex64>::from_vec(hilbert))
60}
61
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        audio_signal::AudioSignal,
        audio_utils::load_as_mono,
        fft_manager,
        xcorr::{
            calculate_best_lag, calculate_fft_pointwise_product,
            calculate_inverse_fft_pointwise_product, frexp,
        },
    };
    use approx::assert_abs_diff_eq;

    /// Clean reference recording shared by every test in this module.
    const REFERENCE_PATH: &str = "test_data/clean_speech/CA01_01.wav";
    /// Transcoded (degraded) version of the reference recording.
    const DEGRADED_PATH: &str = "test_data/clean_speech/transcoded_CA01_01.wav";

    /// Loads the reference/degraded speech pair as mono signals, in that order.
    fn load_audio_files() -> (AudioSignal, AudioSignal) {
        let reference = load_as_mono(REFERENCE_PATH).unwrap();
        let degraded = load_as_mono(DEGRADED_PATH).unwrap();
        (reference, degraded)
    }

    #[test]
    fn hilbert_transform_on_audio_signal() {
        let (mut reference, _) = load_audio_files();
        let samples = reference.data_matrix.as_slice_mut().unwrap();

        let analytic = calculate_hilbert(samples).unwrap();

        assert_abs_diff_eq!(analytic[0].re, 0.000_303_661_691_188_833, epsilon = 0.0001);
    }

    #[test]
    fn envelope_on_audio_signal() {
        let (reference, _) = load_audio_files();

        let envelope = calculate_upper_env(&reference.data_matrix).unwrap();

        assert_abs_diff_eq!(envelope[0], 0.00030159861338215923, epsilon = 0.0001);
    }

    #[test]
    fn xcorr_pointwise_prod_on_audio_signal() {
        let (reference, degraded) = load_audio_files();
        let reference_samples = reference.data_matrix.to_vec();
        let degraded_samples = degraded.data_matrix.to_vec();

        // Smallest power of two covering the full cross-correlation length.
        let (_, exponent) = frexp((reference_samples.len() * 2 - 1) as f64);
        let fft_points = 2i32.pow(exponent as u32) as usize;
        let mut manager = fft_manager::FftManager::new(fft_points);

        let product = calculate_fft_pointwise_product(
            &reference_samples,
            &degraded_samples,
            &mut manager,
            fft_points,
        );

        assert_abs_diff_eq!(product[0].re, 0.012231532484292984, epsilon = 0.001);
    }

    #[test]
    fn calculate_inverse_fft_pointwise_product_on_audio_pair() {
        let (reference, degraded) = load_audio_files();
        let mut reference_samples = reference.data_matrix.to_vec();
        let mut degraded_samples = degraded.data_matrix.to_vec();

        let correlation = calculate_inverse_fft_pointwise_product(
            &mut reference_samples,
            &mut degraded_samples,
        );

        assert_abs_diff_eq!(correlation[0], 79.66060597338944, epsilon = 0.0001);
    }

    #[test]
    fn calculate_best_lag_on_audio_signal() {
        let (reference, degraded) = load_audio_files();
        let reference_samples = reference.data_matrix.as_slice().unwrap();
        let degraded_samples = degraded.data_matrix.as_slice().unwrap();

        let best_lag = calculate_best_lag(reference_samples, degraded_samples).unwrap();

        assert_abs_diff_eq!(best_lag, 0);
    }
}