1use crate::fast_fourier_transform;
2use crate::fft_manager::FftManager;
3use ndarray::Array1;
4use num::complex::Complex64;
5
6pub fn calculate_upper_env(signal: &Array1<f64>) -> Option<ndarray::Array1<f64>> {
8 let mean = signal.mean()?;
9 let mut signal_centered = signal - mean;
10 let hilbert = calculate_hilbert(signal_centered.as_slice_mut()?)?;
11
12 let mut hilbert_amplitude = Array1::<f64>::zeros(hilbert.len());
13
14 for (amplitude, h) in hilbert_amplitude.iter_mut().zip(&hilbert) {
15 *amplitude = h.norm();
16 }
17 hilbert_amplitude += mean;
18 Some(hilbert_amplitude)
19}
20
21pub fn calculate_hilbert(signal: &mut [f64]) -> Option<Array1<Complex64>> {
23 let mut fft_manager = FftManager::new(signal.len());
24 let freq_domain_signal =
25 fast_fourier_transform::forward_1d_from_matrix(&mut fft_manager, signal);
26
27 let is_odd = signal.len() % 2 == 1;
28 let is_non_empty = !signal.is_empty();
29
30 let mut hilbert_scaling = vec![0.0f64; freq_domain_signal.len()];
32 hilbert_scaling[0] = 1.0;
33
34 if !is_odd && is_non_empty {
35 hilbert_scaling[signal.len() / 2] = 1.0;
36 } else if is_odd && is_non_empty {
37 hilbert_scaling[signal.len() / 2] = 2.0;
38 }
39
40 let n = if is_odd {
41 freq_domain_signal.len().div_ceil(2)
42 } else {
43 freq_domain_signal.len() / 2
44 };
45
46 hilbert_scaling[1..n].fill(2.0);
47
48 let mut element_wise_product = Array1::<Complex64>::zeros(freq_domain_signal.len());
49
50 for i in 0..freq_domain_signal.len() {
51 element_wise_product[i] = freq_domain_signal[i] * hilbert_scaling[i];
52 }
53
54 let mut hilbert =
55 fast_fourier_transform::inverse_1d(&mut fft_manager, element_wise_product.as_slice()?);
56 hilbert
57 .iter_mut()
58 .for_each(|element| *element = *element * 2.0 - 0.000001);
59 Some(Array1::<Complex64>::from_vec(hilbert))
60}
61
#[cfg(test)]
mod tests {
    //! Regression tests that run the envelope and cross-correlation helpers
    //! on a reference/degraded pair of WAV files and compare the first output
    //! value (or the lag) against fixed expected numbers.

    use super::*;
    use crate::{
        audio_signal::AudioSignal,
        audio_utils::load_as_mono,
        fft_manager,
        xcorr::{
            calculate_best_lag, calculate_fft_pointwise_product,
            calculate_inverse_fft_pointwise_product, frexp,
        },
    };
    use approx::assert_abs_diff_eq;

    /// First sample of the Hilbert output for the reference file.
    #[test]
    fn hilbert_transform_on_audio_signal() {
        let (mut signal, _) = load_audio_files();
        let result = calculate_hilbert(signal.data_matrix.as_slice_mut().unwrap()).unwrap();

        assert_abs_diff_eq!(result[0].re, 0.000_303_661_691_188_833, epsilon = 0.0001);
    }

    /// First sample of the upper envelope for the reference file.
    #[test]
    fn envelope_on_audio_signal() {
        let (signal, _) = load_audio_files();
        let result = calculate_upper_env(&signal.data_matrix).unwrap();

        assert_abs_diff_eq!(result[0], 0.00030159861338215923, epsilon = 0.0001);
    }

    /// First bin of the FFT pointwise product of the reference/degraded pair.
    #[test]
    fn xcorr_pointwise_prod_on_audio_signal() {
        let (ref_signal, deg_signal) = load_audio_files();
        let ref_signal_vec = ref_signal.data_matrix.to_vec();

        // FFT size: two raised to the exponent `frexp` reports for
        // `2 * len - 1`.
        let (_, exponent) = frexp((ref_signal_vec.len() * 2 - 1) as f64);
        let fft_points = 2i32.pow(exponent as u32) as usize;
        let mut manager = fft_manager::FftManager::new(fft_points);

        let result = calculate_fft_pointwise_product(
            &ref_signal_vec,
            &deg_signal.data_matrix.to_vec(),
            &mut manager,
            fft_points,
        );

        assert_abs_diff_eq!(result[0].re, 0.012231532484292984, epsilon = 0.001);
    }

    /// First sample of the inverse FFT of the pointwise product.
    #[test]
    fn calculate_inverse_fft_pointwise_product_on_audio_pair() {
        let (ref_signal, deg_signal) = load_audio_files();

        let result = calculate_inverse_fft_pointwise_product(
            &mut ref_signal.data_matrix.to_vec(),
            &mut deg_signal.data_matrix.to_vec(),
        );

        assert_abs_diff_eq!(result[0], 79.66060597338944, epsilon = 0.0001);
    }

    /// The best lag between the two test files is expected to be zero.
    #[test]
    fn calculate_best_lag_on_audio_signal() {
        let (ref_signal, deg_signal) = load_audio_files();

        let result = calculate_best_lag(
            ref_signal.data_matrix.as_slice().unwrap(),
            deg_signal.data_matrix.as_slice().unwrap(),
        )
        .unwrap();

        assert_abs_diff_eq!(result, 0);
    }

    /// Loads the reference and degraded test recordings as mono signals.
    fn load_audio_files() -> (AudioSignal, AudioSignal) {
        let ref_signal_path = "test_data/clean_speech/CA01_01.wav";
        let deg_signal_path = "test_data/clean_speech/transcoded_CA01_01.wav";
        (
            load_as_mono(ref_signal_path).unwrap(),
            load_as_mono(deg_signal_path).unwrap(),
        )
    }
}