1use std::f64::consts::PI;
10
11use scirs2_core::numeric::Complex64;
12
13use crate::error::{FFTError, FFTResult};
14
/// Parameters describing a Morlet-wavelet filter bank.
///
/// Create one with [`FilterBankConfig::new`] and optionally refine it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct FilterBankConfig {
    /// Number of octaves spanned by each wavelet family.
    pub j_max: usize,
    /// Wavelets per octave; one entry per wavelet family (order).
    pub quality_factors: Vec<usize>,
    /// Length of the signals the bank is meant for.
    pub signal_length: usize,
    /// Center frequency (radians per sample) of the highest-frequency wavelet.
    pub xi0: f64,
    /// Optional explicit base Gaussian width; `None` means "derive it from the
    /// quality factor".
    pub sigma: Option<f64>,
}

impl FilterBankConfig {
    /// Creates a configuration whose highest center frequency is `PI`
    /// (Nyquist) and which has no explicit `sigma` override.
    pub fn new(j_max: usize, quality_factors: Vec<usize>, signal_length: usize) -> Self {
        let xi0 = PI;
        let sigma = None;
        Self {
            j_max,
            quality_factors,
            signal_length,
            xi0,
            sigma,
        }
    }

    /// Returns this configuration with `xi0` replaced.
    #[must_use]
    pub fn with_xi0(self, xi0: f64) -> Self {
        Self { xi0, ..self }
    }

    /// Returns this configuration with an explicit base `sigma`, which takes
    /// precedence over the value derived from the quality factor.
    #[must_use]
    pub fn with_sigma(self, sigma: f64) -> Self {
        Self {
            sigma: Some(sigma),
            ..self
        }
    }
}
61
/// A single Morlet wavelet sampled in the frequency domain on the FFT grid
/// `omega_k = 2*pi*k / fft_size`.
///
/// Instances are produced by the filter-bank builder in this module; the
/// fields record the parameters used to generate `freq_response`.
#[derive(Debug, Clone)]
pub struct MorletWavelet {
    /// Center frequency in radians per sample: `xi0 / 2^(linear_index / Q)`.
    pub xi: f64,
    /// Gaussian width parameter. The envelope around `xi` is
    /// `exp(-0.5 * offset^2 * sigma^2)`, so a larger `sigma` means a narrower
    /// frequency band. It grows with `linear_index` as `base * 2^(linear_index / Q)`.
    pub sigma: f64,
    /// Octave index: `linear_index / Q` (integer division).
    pub j: usize,
    /// Position within the octave: `linear_index % Q`.
    pub q_index: usize,
    /// Index of this wavelet within its family; increasing index means a
    /// lower center frequency.
    pub linear_index: usize,
    /// Sampled frequency response (`fft_size` bins), scaled to unit L2 norm
    /// unless its energy is negligibly small.
    pub freq_response: Vec<Complex64>,
}
80
81#[derive(Debug, Clone)]
83pub struct FilterBank {
84 pub config: FilterBankConfig,
86 pub fft_size: usize,
88 pub wavelets: Vec<Vec<MorletWavelet>>,
90 pub phi: Vec<Complex64>,
92}
93
94impl FilterBank {
95 pub fn new(config: FilterBankConfig) -> FFTResult<Self> {
100 if config.j_max == 0 {
101 return Err(FFTError::ValueError("j_max must be at least 1".to_string()));
102 }
103 if config.quality_factors.is_empty() {
104 return Err(FFTError::ValueError(
105 "quality_factors must have at least one entry".to_string(),
106 ));
107 }
108 for (i, &q) in config.quality_factors.iter().enumerate() {
109 if q == 0 {
110 return Err(FFTError::ValueError(format!(
111 "quality_factors[{i}] must be at least 1"
112 )));
113 }
114 }
115 if config.signal_length == 0 {
116 return Err(FFTError::ValueError(
117 "signal_length must be positive".to_string(),
118 ));
119 }
120
121 let fft_size = config.signal_length.next_power_of_two();
122
123 let mut all_wavelets = Vec::new();
125 for (order, &q) in config.quality_factors.iter().enumerate() {
126 let sigma_base = compute_sigma_from_q(q, config.xi0, config.sigma);
127 let wavelets =
128 build_morlet_wavelets(config.j_max, q, config.xi0, sigma_base, fft_size, order)?;
129 all_wavelets.push(wavelets);
130 }
131
132 let sigma_phi = compute_sigma_from_q(config.quality_factors[0], config.xi0, config.sigma);
134 let phi = build_scaling_function(config.j_max, sigma_phi, fft_size)?;
135
136 Ok(Self {
137 config,
138 fft_size,
139 wavelets: all_wavelets,
140 phi,
141 })
142 }
143
144 pub fn num_first_order(&self) -> usize {
146 self.wavelets.first().map_or(0, |w| w.len())
147 }
148
149 pub fn num_second_order(&self) -> usize {
151 self.wavelets.get(1).map_or(0, |w| w.len())
152 }
153
154 pub fn total_wavelets(&self) -> usize {
156 self.wavelets.iter().map(|w| w.len()).sum()
157 }
158}
159
/// Returns the base Gaussian width `sigma` for a wavelet family with `q`
/// wavelets per octave.
///
/// The builders in this module evaluate the frequency envelope as
/// `exp(-0.5 * offset^2 * sigma^2)` and scale `sigma` *up* as the center
/// frequency goes down. `sigma` is therefore a time-domain width: a larger
/// value means a narrower pass band.
///
/// Choosing `sigma = q * sqrt(2 ln 2) / xi0` puts the half-maximum of the
/// highest wavelet exactly `xi0 / q` away from its center. The builders use
/// `xi = xi0 / s` and `sigma = sigma_base * s`, so every wavelet in the family
/// keeps the same constant-Q ratio: a higher `q` gives sharper filters.
///
/// When `custom_sigma` is `Some`, that value is returned unchanged.
fn compute_sigma_from_q(q: usize, xi0: f64, custom_sigma: Option<f64>) -> f64 {
    custom_sigma.unwrap_or_else(|| {
        // sqrt(2 ln 2) is the offset (in units of 1/sigma) at which
        // exp(-0.5 * x^2) falls to one half.
        let half_max_offset = (2.0_f64 * 2.0_f64.ln()).sqrt();
        // The reciprocal, xi0 / (q * sqrt(2 ln 2)), would be the
        // *frequency-domain* width. Because the builders multiply by sigma,
        // they need the time-domain width; otherwise increasing q would widen
        // the band instead of narrowing it.
        q as f64 * half_max_offset / xi0
    })
}
174
175fn build_morlet_wavelets(
177 j_max: usize,
178 q: usize,
179 xi0: f64,
180 sigma_base: f64,
181 fft_size: usize,
182 _order: usize,
183) -> FFTResult<Vec<MorletWavelet>> {
184 let total = j_max * q;
185 let mut wavelets = Vec::with_capacity(total);
186 let n = fft_size;
187
188 for idx in 0..total {
189 let j = idx / q;
190 let q_index = idx % q;
191
192 let scale = 2.0_f64.powf(idx as f64 / q as f64);
194
195 let xi = xi0 / scale;
197
198 let sigma = sigma_base * scale;
200
201 let mut freq_response = vec![Complex64::new(0.0, 0.0); n];
205 let n_f64 = n as f64;
206
207 for k in 0..n {
208 let omega = 2.0 * PI * k as f64 / n_f64;
210
211 let diff_pos = omega - xi;
213 let gauss_pos = (-0.5 * diff_pos * diff_pos * sigma * sigma).exp();
214
215 let gauss_correction = (-0.5 * xi * xi * sigma * sigma).exp();
217 let gauss_zero = (-0.5 * omega * omega * sigma * sigma).exp();
218
219 let value = gauss_pos - gauss_correction * gauss_zero;
220 freq_response[k] = Complex64::new(value, 0.0);
221 }
222
223 let energy: f64 = freq_response.iter().map(|c| c.norm_sqr()).sum();
225 if energy > 1e-15 {
226 let norm_factor = 1.0 / energy.sqrt();
227 for c in &mut freq_response {
228 *c = Complex64::new(c.re * norm_factor, c.im * norm_factor);
229 }
230 }
231
232 wavelets.push(MorletWavelet {
233 xi,
234 sigma,
235 j,
236 q_index,
237 linear_index: idx,
238 freq_response,
239 });
240 }
241
242 Ok(wavelets)
243}
244
245fn build_scaling_function(
249 j_max: usize,
250 sigma_base: f64,
251 fft_size: usize,
252) -> FFTResult<Vec<Complex64>> {
253 let n = fft_size;
254 let n_f64 = n as f64;
255 let sigma_j = sigma_base * 2.0_f64.powi(j_max as i32);
256
257 let mut phi = vec![Complex64::new(0.0, 0.0); n];
258
259 for k in 0..n {
260 let omega = 2.0 * PI * k as f64 / n_f64;
261 let omega_wrapped = if omega > PI { omega - 2.0 * PI } else { omega };
263 let value = (-0.5 * omega_wrapped * omega_wrapped * sigma_j * sigma_j).exp();
264 phi[k] = Complex64::new(value, 0.0);
265 }
266
267 let energy: f64 = phi.iter().map(|c| c.norm_sqr()).sum();
269 if energy > 1e-15 {
270 let norm_factor = 1.0 / energy.sqrt();
271 for c in &mut phi {
272 *c = Complex64::new(c.re * norm_factor, c.im * norm_factor);
273 }
274 }
275
276 Ok(phi)
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Builds a filter bank from raw parameters, panicking on failure.
    fn build(j_max: usize, quality_factors: Vec<usize>, signal_length: usize) -> FilterBank {
        FilterBank::new(FilterBankConfig::new(j_max, quality_factors, signal_length))
            .expect("filter bank creation should succeed")
    }

    #[test]
    fn test_filter_bank_creation() {
        let fb = build(4, vec![8, 1], 1024);

        assert_eq!(fb.num_first_order(), 32);
        assert_eq!(fb.num_second_order(), 4);
        assert_eq!(fb.fft_size, 1024);
        assert_eq!(fb.phi.len(), 1024);
    }

    #[test]
    fn test_wavelet_frequency_peaks() {
        let fb = build(3, vec![4], 512);
        let bin_to_omega = |bin: usize| 2.0 * PI * bin as f64 / fb.fft_size as f64;

        for w in &fb.wavelets[0] {
            let peak_bin = w
                .freq_response
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| {
                    a.norm_sqr()
                        .partial_cmp(&b.norm_sqr())
                        .unwrap_or(Ordering::Equal)
                })
                .map(|(bin, _)| bin)
                .expect("should find peak");
            let peak_omega = bin_to_omega(peak_bin);

            // Relative error for meaningful center frequencies, absolute otherwise.
            let rel_error = match w.xi > 1e-6 {
                true => (peak_omega - w.xi).abs() / w.xi,
                false => peak_omega.abs(),
            };

            assert!(
                rel_error < 0.5,
                "wavelet j={} q={}: peak_omega={:.4} vs xi={:.4}, rel_error={:.4}",
                w.j,
                w.q_index,
                peak_omega,
                w.xi,
                rel_error
            );
        }
    }

    #[test]
    fn test_dyadic_scaling() {
        // With Q = 1, consecutive center frequencies should halve.
        let fb = build(4, vec![1], 1024);

        for (i, pair) in fb.wavelets[0].windows(2).enumerate() {
            let ratio = pair[0].xi / pair[1].xi;
            assert!(
                (ratio - 2.0).abs() < 0.1,
                "octave {i} to {}: ratio={:.4}, expected ~2.0",
                i + 1,
                ratio
            );
        }
    }

    #[test]
    fn test_filter_bank_invalid_config() {
        let invalid = vec![
            FilterBankConfig::new(0, vec![8], 1024),
            FilterBankConfig::new(4, vec![], 1024),
            FilterBankConfig::new(4, vec![0], 1024),
            FilterBankConfig::new(4, vec![8], 0),
        ];

        for config in invalid {
            assert!(FilterBank::new(config).is_err());
        }
    }

    #[test]
    fn test_wavelet_l2_normalization() {
        let fb = build(3, vec![4], 256);

        for w in &fb.wavelets[0] {
            let energy = w
                .freq_response
                .iter()
                .fold(0.0_f64, |acc, c| acc + c.norm_sqr());
            assert!(
                (energy - 1.0).abs() < 1e-10,
                "wavelet j={} q={} has energy {:.6}, expected 1.0",
                w.j,
                w.q_index,
                energy
            );
        }
    }

    #[test]
    fn test_scaling_function_is_lowpass() {
        let fb = build(3, vec![4], 512);

        let dc_mag = fb.phi[0].norm_sqr();
        let nyquist_mag = fb.phi[fb.fft_size / 2].norm_sqr();

        assert!(
            dc_mag > nyquist_mag,
            "scaling function should peak at DC: dc={:.6} vs nyquist={:.6}",
            dc_mag,
            nyquist_mag
        );
    }
}
409}