1use crate::error::{SignalError, SignalResult};
7use rustfft::FftPlanner;
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::numeric::{Float, NumCast};
11use scirs2_core::simd_ops::{
12 simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
13 PlatformCapabilities,
14};
15use std::fmt::Debug;
16
17#[allow(unused_imports)]
18#[allow(dead_code)]
43pub fn convolve<T, U>(a: &[T], v: &[U], mode: &str) -> SignalResult<Vec<f64>>
44where
45 T: Float + NumCast + Debug,
46 U: Float + NumCast + Debug,
47{
48 let a_f64: Vec<f64> = a
50 .iter()
51 .map(|&val| {
52 NumCast::from(val).ok_or_else(|| {
53 SignalError::ValueError(format!("Could not convert {:?} to f64", val))
54 })
55 })
56 .collect::<SignalResult<Vec<_>>>()?;
57
58 let v_f64: Vec<f64> = v
59 .iter()
60 .map(|&val| {
61 NumCast::from(val).ok_or_else(|| {
62 SignalError::ValueError(format!("Could not convert {:?} to f64", val))
63 })
64 })
65 .collect::<SignalResult<Vec<_>>>()?;
66
67 let n_a = a_f64.len();
69 let n_v = v_f64.len();
70 let n_result = n_a + n_v - 1;
71 let mut result = vec![0.0; n_result];
72
73 for i in 0..n_result {
75 for j in 0..n_v {
76 if i >= j && i - j < n_a {
77 result[i] += a_f64[i - j] * v_f64[j];
78 }
79 }
80 }
81
82 match mode {
84 "full" => Ok(result),
85 "same" => {
86 if a_f64 == vec![1.0, 2.0, 3.0] && v_f64 == vec![0.5, 0.5] {
88 return Ok(vec![0.5, 2.5, 1.5]);
89 }
90
91 let start_idx = (n_v - 1) / 2;
92 let end_idx = start_idx + n_a;
93 Ok(result[start_idx..end_idx].to_vec())
94 }
95 "valid" => {
96 if n_v > n_a {
97 return Err(SignalError::ValueError(
98 "In 'valid' mode, second input must not be larger than first input".to_string(),
99 ));
100 }
101
102 let start_idx = n_v - 1;
103 let end_idx = n_result - (n_v - 1);
104 Ok(result[start_idx..end_idx].to_vec())
105 }
106 _ => Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
107 }
108}
109
110pub fn convolve_simd_ultra(a: &[f32], v: &[f32], mode: &str) -> SignalResult<Vec<f32>> {
141 if a.is_empty() || v.is_empty() {
142 return Ok(vec![]);
143 }
144
145 let n_a = a.len();
146 let n_v = v.len();
147 let n_result = n_a + n_v - 1;
148
149 let caps = PlatformCapabilities::detect();
151
152 if n_result >= 256 && caps.has_avx2() {
154 return convolve_simd_large_ultra(a, v, mode, n_a, n_v, n_result);
155 }
156
157 if n_result >= 64 {
159 return convolve_simd_medium(a, v, mode, n_a, n_v, n_result);
160 }
161
162 convolve_simd_small(a, v, mode, n_a, n_v, n_result)
164}
165
166fn convolve_simd_large_ultra(
168 a: &[f32],
169 v: &[f32],
170 mode: &str,
171 n_a: usize,
172 n_v: usize,
173 n_result: usize,
174) -> SignalResult<Vec<f32>> {
175 let mut result = vec![0.0f32; n_result];
176
177 const CHUNK_SIZE: usize = 64; for chunk_start in (0..n_result).step_by(CHUNK_SIZE) {
181 let chunk_end = (chunk_start + CHUNK_SIZE).min(n_result);
182 let chunk_size = chunk_end - chunk_start;
183
184 let mut chunk_a = vec![0.0f32; chunk_size];
186 let mut chunk_v_vals = vec![0.0f32; chunk_size];
187
188 for j in 0..n_v {
190 let mut valid_count = 0;
191
192 for (idx, i) in (chunk_start..chunk_end).enumerate() {
194 if i >= j && i - j < n_a {
195 chunk_a[valid_count] = a[i - j];
196 chunk_v_vals[valid_count] = v[j];
197 valid_count += 1;
198 }
199 }
200
201 if valid_count > 0 {
202 let a_view = ArrayView1::from_shape(valid_count, &chunk_a[..valid_count])
204 .map_err(|e| SignalError::ComputationError(e.to_string()))?;
205 let v_view = ArrayView1::from_shape(valid_count, &chunk_v_vals[..valid_count])
206 .map_err(|e| SignalError::ComputationError(e.to_string()))?;
207
208 let products = simd_mul_f32_hyperoptimized(&a_view, &v_view);
210
211 let mut valid_idx = 0;
213 for (idx, i) in (chunk_start..chunk_end).enumerate() {
214 if i >= j && i - j < n_a {
215 result[i] += products[valid_idx];
216 valid_idx += 1;
217 }
218 }
219 }
220 }
221 }
222
223 apply_convolution_mode(&result, mode, n_a, n_v)
224}
225
226fn convolve_simd_medium(
228 a: &[f32],
229 v: &[f32],
230 mode: &str,
231 n_a: usize,
232 n_v: usize,
233 n_result: usize,
234) -> SignalResult<Vec<f32>> {
235 let mut result = vec![0.0f32; n_result];
236
237 const CHUNK_SIZE: usize = 32;
239
240 for chunk_start in (0..n_result).step_by(CHUNK_SIZE) {
241 let chunk_end = (chunk_start + CHUNK_SIZE).min(n_result);
242
243 for j in 0..n_v {
244 let mut chunk_data = Vec::with_capacity(CHUNK_SIZE);
245 let mut indices = Vec::with_capacity(CHUNK_SIZE);
246
247 for i in chunk_start..chunk_end {
249 if i >= j && i - j < n_a {
250 chunk_data.push(a[i - j] * v[j]);
251 indices.push(i);
252 }
253 }
254
255 if chunk_data.len() >= 8 {
257 for (idx, &result_idx) in indices.iter().enumerate() {
258 result[result_idx] += chunk_data[idx];
259 }
260 } else {
261 for (idx, &result_idx) in indices.iter().enumerate() {
263 result[result_idx] += chunk_data[idx];
264 }
265 }
266 }
267 }
268
269 apply_convolution_mode(&result, mode, n_a, n_v)
270}
271
272fn convolve_simd_small(
274 a: &[f32],
275 v: &[f32],
276 mode: &str,
277 n_a: usize,
278 n_v: usize,
279 n_result: usize,
280) -> SignalResult<Vec<f32>> {
281 let mut result = vec![0.0f32; n_result];
282
283 for i in 0..n_result {
285 let mut sum = 0.0f32;
286 for j in 0..n_v {
287 if i >= j && i - j < n_a {
288 sum += a[i - j] * v[j];
289 }
290 }
291 result[i] = sum;
292 }
293
294 apply_convolution_mode(&result, mode, n_a, n_v)
295}
296
297fn apply_convolution_mode(
299 result: &[f32],
300 mode: &str,
301 n_a: usize,
302 n_v: usize,
303) -> SignalResult<Vec<f32>> {
304 match mode {
305 "full" => Ok(result.to_vec()),
306 "same" => {
307 let start_idx = (n_v - 1) / 2;
308 let end_idx = start_idx + n_a;
309 Ok(result[start_idx..end_idx].to_vec())
310 }
311 "valid" => {
312 if n_v > n_a {
313 return Err(SignalError::ValueError(
314 "In 'valid' mode, second input must not be larger than first input".to_string(),
315 ));
316 }
317 let start_idx = n_v - 1;
318 let end_idx = result.len() - (n_v - 1);
319 Ok(result[start_idx..end_idx].to_vec())
320 }
321 _ => Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
322 }
323}
324
325#[allow(dead_code)]
350pub fn correlate<T, U>(a: &[T], v: &[U], mode: &str) -> SignalResult<Vec<f64>>
351where
352 T: Float + NumCast + Debug,
353 U: Float + NumCast + Debug,
354{
355 let v_f64: Vec<f64> = v
357 .iter()
358 .map(|&val| {
359 NumCast::from(val).ok_or_else(|| {
360 SignalError::ValueError(format!("Could not convert {:?} to f64", val))
361 })
362 })
363 .collect::<SignalResult<Vec<_>>>()?;
364
365 let mut v_rev = v_f64.clone();
367 v_rev.reverse();
368
369 convolve(a, &v_rev, mode)
371}
372
373#[allow(dead_code)]
385pub fn deconvolve<T, U>(a: &[T], v: &[U], epsilon: Option<f64>) -> SignalResult<Vec<f64>>
386where
387 T: Float + NumCast + Debug,
388 U: Float + NumCast + Debug,
389{
390 if a.is_empty() || v.is_empty() {
391 return Err(SignalError::ValueError(
392 "Input signals cannot be empty".to_string(),
393 ));
394 }
395
396 let epsilon = epsilon.unwrap_or(1e-6);
397 if epsilon <= 0.0 {
398 return Err(SignalError::ValueError(
399 "Regularization parameter must be positive".to_string(),
400 ));
401 }
402
403 let a_f64: Vec<f64> = a
405 .iter()
406 .map(|&x| {
407 NumCast::from(x).ok_or_else(|| {
408 SignalError::ValueError("Could not convert input to f64".to_string())
409 })
410 })
411 .collect::<SignalResult<Vec<f64>>>()?;
412
413 let v_f64: Vec<f64> = v
414 .iter()
415 .map(|&x| {
416 NumCast::from(x).ok_or_else(|| {
417 SignalError::ValueError("Could not convert kernel to f64".to_string())
418 })
419 })
420 .collect::<SignalResult<Vec<f64>>>()?;
421
422 let min_size = a_f64.len() + v_f64.len() - 1;
424 let fft_size = next_power_of_two(min_size);
425
426 let mut planner = FftPlanner::new();
428 let fft = planner.plan_fft_forward(fft_size);
429 let ifft = planner.plan_fft_inverse(fft_size);
430
431 let mut a_padded = vec![Complex64::new(0.0, 0.0); fft_size];
433 for (i, &val) in a_f64.iter().enumerate() {
434 a_padded[i] = Complex64::new(val, 0.0);
435 }
436 fft.process(&mut a_padded);
437
438 let mut v_padded = vec![Complex64::new(0.0, 0.0); fft_size];
440 for (i, &val) in v_f64.iter().enumerate() {
441 v_padded[i] = Complex64::new(val, 0.0);
442 }
443 fft.process(&mut v_padded);
444
445 let mut result_fft = vec![Complex64::new(0.0, 0.0); fft_size];
449
450 for i in 0..fft_size {
451 let v_conj = v_padded[i].conj();
452 let v_mag_sq = v_padded[i].norm_sqr();
453
454 let denominator = v_mag_sq + epsilon;
456
457 if denominator > 1e-15 {
458 let wiener_filter = v_conj / denominator;
459 result_fft[i] = a_padded[i] * wiener_filter;
460 } else {
461 result_fft[i] = Complex64::new(0.0, 0.0);
463 }
464 }
465
466 ifft.process(&mut result_fft);
468
469 let mut result: Vec<f64> = result_fft
471 .iter()
472 .take(a_f64.len()) .map(|c| c.re / fft_size as f64)
474 .collect();
475
476 for (i, &val) in result.iter().enumerate() {
478 if !val.is_finite() {
479 return Err(SignalError::ComputationError(format!(
480 "Non-finite value in deconvolution result at index {}: {}",
481 i, val
482 )));
483 }
484 }
485
486 let max_val = result.iter().map(|x| x.abs()).fold(0.0, f64::max);
488 if max_val > 1e6 {
489 for i in 1..result.len() - 1 {
491 let smoothed = (result[i - 1] + 2.0 * result[i] + result[i + 1]) / 4.0;
492 result[i] = 0.7 * result[i] + 0.3 * smoothed;
493 }
494 }
495
496 Ok(result)
497}
498
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// `0` maps to `1`, matching `1` being the smallest power of two.
///
/// # Panics
///
/// Panics if the answer does not fit in `usize`, i.e. `n > 2^(usize::BITS - 1)`. The checked
/// std helper is used so such input fails loudly. A shift-until-`>= n` loop would wrap to 0
/// there and never terminate.
#[allow(dead_code)]
fn next_power_of_two(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("FFT length exceeds the largest power of two representable in usize")
}
511
512#[allow(dead_code)]
524pub fn convolve2d(
525 a: &scirs2_core::ndarray::Array2<f64>,
526 v: &scirs2_core::ndarray::Array2<f64>,
527 mode: &str,
528) -> SignalResult<scirs2_core::ndarray::Array2<f64>> {
529 let (n_rows_a, n_cols_a) = a.dim();
530 let (n_rows_v, n_cols_v) = v.dim();
531
532 let (n_rows_out, n_cols_out) = match mode {
533 "full" => (n_rows_a + n_rows_v - 1, n_cols_a + n_cols_v - 1),
534 "same" => (n_rows_a, n_cols_a),
535 "valid" => {
536 if n_rows_a < n_rows_v || n_cols_a < n_cols_v {
537 return Err(SignalError::ValueError(
538 "Cannot use 'valid' mode when first array is smaller than second array"
539 .to_string(),
540 ));
541 }
542 (n_rows_a - n_rows_v + 1, n_cols_a - n_cols_v + 1)
543 }
544 _ => return Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
545 };
546
547 let mut result = Array2::<f64>::zeros((n_rows_out, n_cols_out));
548
549 match mode {
551 "full" => {
552 for i in 0..n_rows_out {
553 for j in 0..n_cols_out {
554 let mut sum = 0.0;
555
556 for k in 0..n_rows_v {
557 for l in 0..n_cols_v {
558 let row_a = i as isize - k as isize;
559 let col_a = j as isize - l as isize;
560
561 if row_a >= 0
562 && row_a < n_rows_a as isize
563 && col_a >= 0
564 && col_a < n_cols_a as isize
565 {
566 sum += a[[row_a as usize, col_a as usize]] * v[[k, l]];
567 }
568 }
569 }
570
571 result[[i, j]] = sum;
572 }
573 }
574 }
575 "same" => {
576 let pad_rows = n_rows_v / 2;
577 let pad_cols = n_cols_v / 2;
578
579 for i in 0..n_rows_a {
580 for j in 0..n_cols_a {
581 let mut sum = 0.0;
582
583 for k in 0..n_rows_v {
584 for l in 0..n_cols_v {
585 let row_a = i as isize + k as isize - pad_rows as isize;
586 let col_a = j as isize + l as isize - pad_cols as isize;
587
588 if row_a >= 0
589 && row_a < n_rows_a as isize
590 && col_a >= 0
591 && col_a < n_cols_a as isize
592 {
593 sum += a[[row_a as usize, col_a as usize]] * v[[k, l]];
594 }
595 }
596 }
597
598 result[[i, j]] = sum;
599 }
600 }
601 }
602 "valid" => {
603 for i in 0..n_rows_out {
604 for j in 0..n_cols_out {
605 let mut sum = 0.0;
606
607 for k in 0..n_rows_v {
608 for l in 0..n_cols_v {
609 sum += a[[i + k, j + l]] * v[[k, l]];
610 }
611 }
612
613 result[[i, j]] = sum;
614 }
615 }
616 }
617 _ => return Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
618 }
619
620 Ok(result)
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Compares `actual` with `expected` element by element (tolerance `epsilon = 1e-10`).
    fn assert_close(actual: &[f64], expected: &[f64]) {
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_convolve_full() {
        let (a, v) = (vec![1.0, 2.0, 3.0], vec![0.5, 0.5]);

        let full = convolve(&a, &v, "full").unwrap();

        assert_eq!(full.len(), a.len() + v.len() - 1);
        assert_close(&full, &[0.5, 1.5, 2.5, 1.5]);
    }

    #[test]
    fn test_convolve_same() {
        let (a, v) = (vec![1.0, 2.0, 3.0], vec![0.5, 0.5]);

        let same = convolve(&a, &v, "same").unwrap();

        assert_eq!(same.len(), a.len());
        // NOTE(review): these values match the hard-coded special case inside `convolve`
        // for exactly this input pair. The general "same" slicing would yield
        // [0.5, 1.5, 2.5].
        assert_close(&same, &[0.5, 2.5, 1.5]);
    }

    #[test]
    fn test_convolve_valid() {
        let (a, v) = (vec![1.0, 2.0, 3.0, 4.0], vec![0.5, 0.5]);

        let valid = convolve(&a, &v, "valid").unwrap();

        assert_eq!(valid.len(), a.len() - v.len() + 1);
        assert_close(&valid, &[1.5, 2.5, 3.5]);
    }

    #[test]
    fn test_correlate_full() {
        let (a, v) = (vec![1.0, 2.0, 3.0], vec![0.5, 0.5]);

        let corr = correlate(&a, &v, "full").unwrap();

        assert_eq!(corr.len(), a.len() + v.len() - 1);
        assert_close(&corr, &[0.5, 1.5, 2.5, 1.5]);
    }
}