1use scirs2_core::numeric::Complex64;
13
14use crate::error::{FFTError, FFTResult};
15use crate::fft::{fft, ifft};
16
17use super::filter_bank::{FilterBank, FilterBankConfig};
18
/// Parameters controlling a [`ScatteringTransform`].
///
/// Build one with [`ScatteringConfig::new`] and refine it with the chained
/// `with_*` methods.
#[derive(Debug, Clone)]
pub struct ScatteringConfig {
    /// Scale parameter `J`. It is forwarded to the filter bank and, when
    /// `average` is set, determines the base subsampling factor `2^J`.
    pub j_max: usize,
    /// Quality factors forwarded to the filter bank configuration.
    pub quality_factors: Vec<usize>,
    /// Highest scattering order to compute (0, 1 or 2).
    pub max_order: usize,
    /// When `true`, outputs are subsampled by `2^(j_max - oversampling)`;
    /// when `false`, no subsampling is applied.
    pub average: bool,
    /// Number of halvings removed from the subsampling factor. Values larger
    /// than `j_max` behave like `j_max` (i.e. no subsampling).
    pub oversampling: usize,
}

impl ScatteringConfig {
    /// Highest scattering order the transform implements.
    const MAX_SUPPORTED_ORDER: usize = 2;

    /// Creates a configuration with the defaults: second-order scattering,
    /// averaging enabled and no oversampling.
    pub fn new(j_max: usize, quality_factors: Vec<usize>) -> Self {
        Self {
            j_max,
            quality_factors,
            max_order: Self::MAX_SUPPORTED_ORDER,
            average: true,
            oversampling: 0,
        }
    }

    /// Sets the maximum scattering order.
    ///
    /// Requests above 2 are silently clamped to 2.
    #[must_use]
    pub fn with_max_order(self, order: usize) -> Self {
        Self {
            max_order: order.min(Self::MAX_SUPPORTED_ORDER),
            ..self
        }
    }

    /// Enables or disables subsampling of the output coefficients.
    #[must_use]
    pub fn with_average(self, average: bool) -> Self {
        Self { average, ..self }
    }

    /// Sets the oversampling exponent.
    #[must_use]
    pub fn with_oversampling(self, oversampling: usize) -> Self {
        Self {
            oversampling,
            ..self
        }
    }
}
71
/// Identifies the scattering path a set of coefficients belongs to.
///
/// Paths are plain index tuples, so they can be compared for equality and
/// used as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ScatteringOrder {
    /// Zeroth-order path: the signal filtered by the low-pass filter only.
    Zeroth,
    /// First-order path through the first-order wavelet with index `lambda1`.
    First { lambda1: usize },
    /// Second-order path through first-order wavelet `lambda1` followed by
    /// second-order wavelet `lambda2`.
    Second { lambda1: usize, lambda2: usize },
}
82
83#[derive(Debug, Clone)]
85pub struct ScatteringCoefficients {
86 pub order: ScatteringOrder,
88 pub values: Vec<f64>,
90}
91
92#[derive(Debug, Clone)]
94pub struct ScatteringResult {
95 pub coefficients: Vec<ScatteringCoefficients>,
97 pub num_zeroth: usize,
99 pub num_first: usize,
101 pub num_second: usize,
103 pub output_length: usize,
105}
106
107impl ScatteringResult {
108 pub fn zeroth_order(&self) -> &[ScatteringCoefficients] {
110 &self.coefficients[..self.num_zeroth]
111 }
112
113 pub fn first_order(&self) -> &[ScatteringCoefficients] {
115 &self.coefficients[self.num_zeroth..self.num_zeroth + self.num_first]
116 }
117
118 pub fn second_order(&self) -> &[ScatteringCoefficients] {
120 &self.coefficients[self.num_zeroth + self.num_first..]
121 }
122
123 pub fn flatten(&self) -> Vec<f64> {
125 let mut result = Vec::new();
126 for coeff in &self.coefficients {
127 result.extend_from_slice(&coeff.values);
128 }
129 result
130 }
131
132 pub fn total_energy(&self) -> f64 {
134 self.coefficients
135 .iter()
136 .flat_map(|c| c.values.iter())
137 .map(|v| v * v)
138 .sum()
139 }
140}
141
142#[derive(Debug, Clone)]
144pub struct ScatteringTransform {
145 config: ScatteringConfig,
147 filter_bank: FilterBank,
149}
150
151impl ScatteringTransform {
152 pub fn new(config: ScatteringConfig, signal_length: usize) -> FFTResult<Self> {
158 if signal_length == 0 {
159 return Err(FFTError::ValueError(
160 "signal_length must be positive".to_string(),
161 ));
162 }
163
164 let fb_config =
165 FilterBankConfig::new(config.j_max, config.quality_factors.clone(), signal_length);
166 let filter_bank = FilterBank::new(fb_config)?;
167
168 Ok(Self {
169 config,
170 filter_bank,
171 })
172 }
173
174 pub fn filter_bank(&self) -> &FilterBank {
176 &self.filter_bank
177 }
178
179 pub fn transform(&self, signal: &[f64]) -> FFTResult<ScatteringResult> {
183 if signal.is_empty() {
184 return Err(FFTError::ValueError(
185 "Input signal must not be empty".to_string(),
186 ));
187 }
188
189 let fft_size = self.filter_bank.fft_size;
190
191 let mut padded = vec![0.0_f64; fft_size];
193 let copy_len = signal.len().min(fft_size);
194 padded[..copy_len].copy_from_slice(&signal[..copy_len]);
195
196 let x_hat = fft(&padded, Some(fft_size))?;
198
199 let subsample = if self.config.average {
201 let base = 2_usize.pow(self.config.j_max as u32);
202 base >> self.config.oversampling.min(self.config.j_max)
203 } else {
204 1
205 };
206 let output_length = fft_size.div_ceil(subsample);
207
208 let mut coefficients = Vec::new();
209
210 let mut num_first = 0;
211 let mut num_second = 0;
212
213 let s0 = convolve_and_subsample(&x_hat, &self.filter_bank.phi, fft_size, subsample)?;
215 coefficients.push(ScatteringCoefficients {
216 order: ScatteringOrder::Zeroth,
217 values: s0,
218 });
219 let num_zeroth = 1;
220
221 if self.config.max_order == 0 {
222 return Ok(ScatteringResult {
223 coefficients,
224 num_zeroth,
225 num_first,
226 num_second,
227 output_length,
228 });
229 }
230
231 let first_order_wavelets = self
233 .filter_bank
234 .wavelets
235 .first()
236 .ok_or_else(|| FFTError::ComputationError("No first-order wavelets".to_string()))?;
237
238 let mut u1_hats: Vec<Vec<Complex64>> = Vec::new();
240
241 for (lambda1, wavelet) in first_order_wavelets.iter().enumerate() {
242 let convolved: Vec<Complex64> = x_hat
244 .iter()
245 .zip(wavelet.freq_response.iter())
246 .map(|(x, w)| x * w)
247 .collect();
248
249 let u1_time = ifft(&convolved, None)?;
251
252 let u1_mod: Vec<f64> = u1_time.iter().map(|c| c.norm()).collect();
254
255 if self.config.max_order >= 2 {
257 let u1_mod_hat = fft(&u1_mod, Some(fft_size))?;
258 u1_hats.push(u1_mod_hat);
259 }
260
261 let u1_mod_hat_for_avg = if self.config.max_order >= 2 {
263 u1_hats.last().ok_or_else(|| {
265 FFTError::ComputationError("u1_hats should not be empty".to_string())
266 })?
267 } else {
268 &fft(&u1_mod, Some(fft_size))?
270 };
271
272 let s1 = convolve_and_subsample(
273 u1_mod_hat_for_avg,
274 &self.filter_bank.phi,
275 fft_size,
276 subsample,
277 )?;
278
279 coefficients.push(ScatteringCoefficients {
280 order: ScatteringOrder::First { lambda1 },
281 values: s1,
282 });
283 num_first += 1;
284 }
285
286 if self.config.max_order < 2 {
287 return Ok(ScatteringResult {
288 coefficients,
289 num_zeroth,
290 num_first,
291 num_second,
292 output_length,
293 });
294 }
295
296 let second_order_wavelets = if self.filter_bank.wavelets.len() > 1 {
299 &self.filter_bank.wavelets[1]
300 } else {
301 &self.filter_bank.wavelets[0]
303 };
304
305 for (lambda1, u1_hat) in u1_hats.iter().enumerate() {
306 for (lambda2, wavelet2) in second_order_wavelets.iter().enumerate() {
307 let first_scale = if !first_order_wavelets.is_empty() {
310 first_order_wavelets[lambda1].j
311 } else {
312 0
313 };
314 let second_scale = wavelet2.j;
315
316 if second_scale <= first_scale {
317 continue;
318 }
319
320 let convolved2: Vec<Complex64> = u1_hat
322 .iter()
323 .zip(wavelet2.freq_response.iter())
324 .map(|(u, w)| u * w)
325 .collect();
326
327 let u2_time = ifft(&convolved2, None)?;
328
329 let u2_mod: Vec<f64> = u2_time.iter().map(|c| c.norm()).collect();
331
332 let u2_mod_hat = fft(&u2_mod, Some(fft_size))?;
334 let s2 = convolve_and_subsample(
335 &u2_mod_hat,
336 &self.filter_bank.phi,
337 fft_size,
338 subsample,
339 )?;
340
341 coefficients.push(ScatteringCoefficients {
342 order: ScatteringOrder::Second { lambda1, lambda2 },
343 values: s2,
344 });
345 num_second += 1;
346 }
347 }
348
349 Ok(ScatteringResult {
350 coefficients,
351 num_zeroth,
352 num_first,
353 num_second,
354 output_length,
355 })
356 }
357
358 pub fn features(&self, signal: &[f64]) -> FFTResult<Vec<f64>> {
360 let result = self.transform(signal)?;
361 Ok(result.flatten())
362 }
363}
364
365fn convolve_and_subsample(
367 x_hat: &[Complex64],
368 filter_hat: &[Complex64],
369 fft_size: usize,
370 subsample: usize,
371) -> FFTResult<Vec<f64>> {
372 let product: Vec<Complex64> = x_hat
374 .iter()
375 .zip(filter_hat.iter())
376 .map(|(x, f)| x * f)
377 .collect();
378
379 let time_domain = ifft(&product, None)?;
381
382 let output_len = fft_size.div_ceil(subsample);
384 let mut result = Vec::with_capacity(output_len);
385 for i in 0..output_len {
386 let idx = i * subsample;
387 if idx < time_domain.len() {
388 result.push(time_domain[idx].re);
389 } else {
390 result.push(0.0);
391 }
392 }
393
394 Ok(result)
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Sum of squared samples.
    fn energy(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum()
    }

    /// `n` samples of a sine completing `cycles` periods over the window.
    fn sine(n: usize, cycles: f64) -> Vec<f64> {
        (0..n)
            .map(|i| (2.0 * PI * cycles * i as f64 / n as f64).sin())
            .collect()
    }

    /// Total energy of all first-order paths of a result.
    fn first_order_energy(result: &ScatteringResult) -> f64 {
        result
            .first_order()
            .iter()
            .map(|c| energy(&c.values))
            .sum()
    }

    #[test]
    fn test_scattering_basic() {
        let config = ScatteringConfig::new(3, vec![2, 1]);
        let st = ScatteringTransform::new(config, 256)
            .expect("scattering transform creation should succeed");

        let signal = sine(256, 10.0);

        let result = st.transform(&signal).expect("transform should succeed");

        assert_eq!(result.num_zeroth, 1);
        assert!(result.num_first > 0);
    }

    #[test]
    fn test_translation_invariance() {
        let config = ScatteringConfig::new(3, vec![4, 1]).with_max_order(1);
        let n = 512;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        // Gaussian bump centred on sample 128.
        let signal1: Vec<f64> = (0..n)
            .map(|i| {
                let t = (i as f64 - 128.0) / 20.0;
                (-0.5 * t * t).exp()
            })
            .collect();

        // Same bump, circularly shifted by `shift` samples.
        let shift = 64;
        let signal2: Vec<f64> = (0..n).map(|i| signal1[(i + n - shift) % n]).collect();

        let r1 = st.transform(&signal1).expect("transform should succeed");
        let r2 = st.transform(&signal2).expect("transform should succeed");

        let total_e1 = first_order_energy(&r1);
        let total_e2 = first_order_energy(&r2);

        // A relative error is only meaningful when the reference is non-zero.
        if total_e1 > 1e-15 {
            let rel_error = ((total_e1 - total_e2) / total_e1).abs();
            assert!(
                rel_error < 0.3,
                "First-order total energy should be approximately translation invariant, \
                 rel_error={:.4} (e1={:.4}, e2={:.4})",
                rel_error,
                total_e1,
                total_e2
            );
        }
    }

    #[test]
    fn test_output_dimensions() {
        let j = 3;
        let q1 = 4;
        let q2 = 1;
        let config = ScatteringConfig::new(j, vec![q1, q2]);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal = sine(n, 5.0);

        let result = st.transform(&signal).expect("transform should succeed");

        assert_eq!(result.num_first, j * q1);

        // The per-order counters must account for every stored path.
        assert_eq!(
            result.coefficients.len(),
            result.num_zeroth + result.num_first + result.num_second,
            "per-order counts should add up to the number of paths"
        );

        let expected_len = result.output_length;
        for coeff in &result.coefficients {
            assert_eq!(
                coeff.values.len(),
                expected_len,
                "coefficient output length mismatch"
            );
        }
    }

    #[test]
    fn test_energy_approximate_preservation() {
        let config = ScatteringConfig::new(3, vec![4, 1]);
        let n = 256;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * PI * 8.0 * t).sin() + 0.5 * (2.0 * PI * 32.0 * t).cos()
            })
            .collect();

        let input_energy = energy(&signal);
        assert!(input_energy > 0.0, "test signal should carry energy");

        let result = st.transform(&signal).expect("transform should succeed");
        let scatter_energy = result.total_energy();

        assert!(scatter_energy > 0.0, "scattering energy should be positive");
    }

    #[test]
    fn test_sine_wave_first_order() {
        let config = ScatteringConfig::new(4, vec![8]).with_max_order(1);
        let n = 1024;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let freq = 20.0;
        let signal = sine(n, freq);

        let result = st.transform(&signal).expect("transform should succeed");

        let first = result.first_order();
        assert!(!first.is_empty(), "should have first-order coefficients");

        let max_path = first
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                energy(&a.values)
                    .partial_cmp(&energy(&b.values))
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(idx, _)| idx);

        assert!(max_path.is_some(), "should find a path with maximum energy");
    }

    #[test]
    fn test_zeroth_order_only() {
        let config = ScatteringConfig::new(3, vec![4]).with_max_order(0);
        let n = 128;
        let st = ScatteringTransform::new(config, n)
            .expect("scattering transform creation should succeed");

        let signal: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
        let result = st.transform(&signal).expect("transform should succeed");

        assert_eq!(result.num_zeroth, 1);
        assert_eq!(result.num_first, 0);
        assert_eq!(result.num_second, 0);
    }

    #[test]
    fn test_empty_signal_error() {
        let config = ScatteringConfig::new(3, vec![4]);
        let st = ScatteringTransform::new(config, 128)
            .expect("scattering transform creation should succeed");

        let result = st.transform(&[]);
        assert!(result.is_err());
    }
}