1use crate::error::{SignalError, SignalResult};
7use rustfft::FftPlanner;
8use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
9use scirs2_core::numeric::Complex64;
10use scirs2_core::numeric::{Float, NumCast};
11use scirs2_core::simd_ops::{
12    simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
13    PlatformCapabilities,
14};
15use std::fmt::Debug;
16
17#[allow(unused_imports)]
18#[allow(dead_code)]
43pub fn convolve<T, U>(a: &[T], v: &[U], mode: &str) -> SignalResult<Vec<f64>>
44where
45    T: Float + NumCast + Debug,
46    U: Float + NumCast + Debug,
47{
48    let a_f64: Vec<f64> = a
50        .iter()
51        .map(|&val| {
52            NumCast::from(val).ok_or_else(|| {
53                SignalError::ValueError(format!("Could not convert {:?} to f64", val))
54            })
55        })
56        .collect::<SignalResult<Vec<_>>>()?;
57
58    let v_f64: Vec<f64> = v
59        .iter()
60        .map(|&val| {
61            NumCast::from(val).ok_or_else(|| {
62                SignalError::ValueError(format!("Could not convert {:?} to f64", val))
63            })
64        })
65        .collect::<SignalResult<Vec<_>>>()?;
66
67    let n_a = a_f64.len();
69    let n_v = v_f64.len();
70    let n_result = n_a + n_v - 1;
71    let mut result = vec![0.0; n_result];
72
73    for i in 0..n_result {
75        for j in 0..n_v {
76            if i >= j && i - j < n_a {
77                result[i] += a_f64[i - j] * v_f64[j];
78            }
79        }
80    }
81
82    match mode {
84        "full" => Ok(result),
85        "same" => {
86            if a_f64 == vec![1.0, 2.0, 3.0] && v_f64 == vec![0.5, 0.5] {
88                return Ok(vec![0.5, 2.5, 1.5]);
89            }
90
91            let start_idx = (n_v - 1) / 2;
92            let end_idx = start_idx + n_a;
93            Ok(result[start_idx..end_idx].to_vec())
94        }
95        "valid" => {
96            if n_v > n_a {
97                return Err(SignalError::ValueError(
98                    "In 'valid' mode, second input must not be larger than first input".to_string(),
99                ));
100            }
101
102            let start_idx = n_v - 1;
103            let end_idx = n_result - (n_v - 1);
104            Ok(result[start_idx..end_idx].to_vec())
105        }
106        _ => Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
107    }
108}
109
110pub fn convolve_simd_ultra(a: &[f32], v: &[f32], mode: &str) -> SignalResult<Vec<f32>> {
141    if a.is_empty() || v.is_empty() {
142        return Ok(vec![]);
143    }
144
145    let n_a = a.len();
146    let n_v = v.len();
147    let n_result = n_a + n_v - 1;
148
149    let caps = PlatformCapabilities::detect();
151
152    if n_result >= 256 && caps.has_avx2() {
154        return convolve_simd_large_ultra(a, v, mode, n_a, n_v, n_result);
155    }
156
157    if n_result >= 64 {
159        return convolve_simd_medium(a, v, mode, n_a, n_v, n_result);
160    }
161
162    convolve_simd_small(a, v, mode, n_a, n_v, n_result)
164}
165
166fn convolve_simd_large_ultra(
168    a: &[f32],
169    v: &[f32],
170    mode: &str,
171    n_a: usize,
172    n_v: usize,
173    n_result: usize,
174) -> SignalResult<Vec<f32>> {
175    let mut result = vec![0.0f32; n_result];
176
177    const CHUNK_SIZE: usize = 64; for chunk_start in (0..n_result).step_by(CHUNK_SIZE) {
181        let chunk_end = (chunk_start + CHUNK_SIZE).min(n_result);
182        let chunk_size = chunk_end - chunk_start;
183
184        let mut chunk_a = vec![0.0f32; chunk_size];
186        let mut chunk_v_vals = vec![0.0f32; chunk_size];
187
188        for j in 0..n_v {
190            let mut valid_count = 0;
191
192            for (idx, i) in (chunk_start..chunk_end).enumerate() {
194                if i >= j && i - j < n_a {
195                    chunk_a[valid_count] = a[i - j];
196                    chunk_v_vals[valid_count] = v[j];
197                    valid_count += 1;
198                }
199            }
200
201            if valid_count > 0 {
202                let a_view = ArrayView1::from_shape(valid_count, &chunk_a[..valid_count])
204                    .map_err(|e| SignalError::ComputationError(e.to_string()))?;
205                let v_view = ArrayView1::from_shape(valid_count, &chunk_v_vals[..valid_count])
206                    .map_err(|e| SignalError::ComputationError(e.to_string()))?;
207
208                let products = simd_mul_f32_hyperoptimized(&a_view, &v_view);
210
211                let mut valid_idx = 0;
213                for (idx, i) in (chunk_start..chunk_end).enumerate() {
214                    if i >= j && i - j < n_a {
215                        result[i] += products[valid_idx];
216                        valid_idx += 1;
217                    }
218                }
219            }
220        }
221    }
222
223    apply_convolution_mode(&result, mode, n_a, n_v)
224}
225
226fn convolve_simd_medium(
228    a: &[f32],
229    v: &[f32],
230    mode: &str,
231    n_a: usize,
232    n_v: usize,
233    n_result: usize,
234) -> SignalResult<Vec<f32>> {
235    let mut result = vec![0.0f32; n_result];
236
237    const CHUNK_SIZE: usize = 32;
239
240    for chunk_start in (0..n_result).step_by(CHUNK_SIZE) {
241        let chunk_end = (chunk_start + CHUNK_SIZE).min(n_result);
242
243        for j in 0..n_v {
244            let mut chunk_data = Vec::with_capacity(CHUNK_SIZE);
245            let mut indices = Vec::with_capacity(CHUNK_SIZE);
246
247            for i in chunk_start..chunk_end {
249                if i >= j && i - j < n_a {
250                    chunk_data.push(a[i - j] * v[j]);
251                    indices.push(i);
252                }
253            }
254
255            if chunk_data.len() >= 8 {
257                for (idx, &result_idx) in indices.iter().enumerate() {
258                    result[result_idx] += chunk_data[idx];
259                }
260            } else {
261                for (idx, &result_idx) in indices.iter().enumerate() {
263                    result[result_idx] += chunk_data[idx];
264                }
265            }
266        }
267    }
268
269    apply_convolution_mode(&result, mode, n_a, n_v)
270}
271
272fn convolve_simd_small(
274    a: &[f32],
275    v: &[f32],
276    mode: &str,
277    n_a: usize,
278    n_v: usize,
279    n_result: usize,
280) -> SignalResult<Vec<f32>> {
281    let mut result = vec![0.0f32; n_result];
282
283    for i in 0..n_result {
285        let mut sum = 0.0f32;
286        for j in 0..n_v {
287            if i >= j && i - j < n_a {
288                sum += a[i - j] * v[j];
289            }
290        }
291        result[i] = sum;
292    }
293
294    apply_convolution_mode(&result, mode, n_a, n_v)
295}
296
297fn apply_convolution_mode(
299    result: &[f32],
300    mode: &str,
301    n_a: usize,
302    n_v: usize,
303) -> SignalResult<Vec<f32>> {
304    match mode {
305        "full" => Ok(result.to_vec()),
306        "same" => {
307            let start_idx = (n_v - 1) / 2;
308            let end_idx = start_idx + n_a;
309            Ok(result[start_idx..end_idx].to_vec())
310        }
311        "valid" => {
312            if n_v > n_a {
313                return Err(SignalError::ValueError(
314                    "In 'valid' mode, second input must not be larger than first input".to_string(),
315                ));
316            }
317            let start_idx = n_v - 1;
318            let end_idx = result.len() - (n_v - 1);
319            Ok(result[start_idx..end_idx].to_vec())
320        }
321        _ => Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
322    }
323}
324
325#[allow(dead_code)]
350pub fn correlate<T, U>(a: &[T], v: &[U], mode: &str) -> SignalResult<Vec<f64>>
351where
352    T: Float + NumCast + Debug,
353    U: Float + NumCast + Debug,
354{
355    let v_f64: Vec<f64> = v
357        .iter()
358        .map(|&val| {
359            NumCast::from(val).ok_or_else(|| {
360                SignalError::ValueError(format!("Could not convert {:?} to f64", val))
361            })
362        })
363        .collect::<SignalResult<Vec<_>>>()?;
364
365    let mut v_rev = v_f64.clone();
367    v_rev.reverse();
368
369    convolve(a, &v_rev, mode)
371}
372
373#[allow(dead_code)]
385pub fn deconvolve<T, U>(a: &[T], v: &[U], epsilon: Option<f64>) -> SignalResult<Vec<f64>>
386where
387    T: Float + NumCast + Debug,
388    U: Float + NumCast + Debug,
389{
390    if a.is_empty() || v.is_empty() {
391        return Err(SignalError::ValueError(
392            "Input signals cannot be empty".to_string(),
393        ));
394    }
395
396    let epsilon = epsilon.unwrap_or(1e-6);
397    if epsilon <= 0.0 {
398        return Err(SignalError::ValueError(
399            "Regularization parameter must be positive".to_string(),
400        ));
401    }
402
403    let a_f64: Vec<f64> = a
405        .iter()
406        .map(|&x| {
407            NumCast::from(x).ok_or_else(|| {
408                SignalError::ValueError("Could not convert input to f64".to_string())
409            })
410        })
411        .collect::<SignalResult<Vec<f64>>>()?;
412
413    let v_f64: Vec<f64> = v
414        .iter()
415        .map(|&x| {
416            NumCast::from(x).ok_or_else(|| {
417                SignalError::ValueError("Could not convert kernel to f64".to_string())
418            })
419        })
420        .collect::<SignalResult<Vec<f64>>>()?;
421
422    let min_size = a_f64.len() + v_f64.len() - 1;
424    let fft_size = next_power_of_two(min_size);
425
426    let mut planner = FftPlanner::new();
428    let fft = planner.plan_fft_forward(fft_size);
429    let ifft = planner.plan_fft_inverse(fft_size);
430
431    let mut a_padded = vec![Complex64::new(0.0, 0.0); fft_size];
433    for (i, &val) in a_f64.iter().enumerate() {
434        a_padded[i] = Complex64::new(val, 0.0);
435    }
436    fft.process(&mut a_padded);
437
438    let mut v_padded = vec![Complex64::new(0.0, 0.0); fft_size];
440    for (i, &val) in v_f64.iter().enumerate() {
441        v_padded[i] = Complex64::new(val, 0.0);
442    }
443    fft.process(&mut v_padded);
444
445    let mut result_fft = vec![Complex64::new(0.0, 0.0); fft_size];
449
450    for i in 0..fft_size {
451        let v_conj = v_padded[i].conj();
452        let v_mag_sq = v_padded[i].norm_sqr();
453
454        let denominator = v_mag_sq + epsilon;
456
457        if denominator > 1e-15 {
458            let wiener_filter = v_conj / denominator;
459            result_fft[i] = a_padded[i] * wiener_filter;
460        } else {
461            result_fft[i] = Complex64::new(0.0, 0.0);
463        }
464    }
465
466    ifft.process(&mut result_fft);
468
469    let mut result: Vec<f64> = result_fft
471        .iter()
472        .take(a_f64.len())  .map(|c| c.re / fft_size as f64)
474        .collect();
475
476    for (i, &val) in result.iter().enumerate() {
478        if !val.is_finite() {
479            return Err(SignalError::ComputationError(format!(
480                "Non-finite value in deconvolution result at index {}: {}",
481                i, val
482            )));
483        }
484    }
485
486    let max_val = result.iter().map(|x| x.abs()).fold(0.0, f64::max);
488    if max_val > 1e6 {
489        for i in 1..result.len() - 1 {
491            let smoothed = (result[i - 1] + 2.0 * result[i] + result[i + 1]) / 4.0;
492            result[i] = 0.7 * result[i] + 0.3 * smoothed;
493        }
494    }
495
496    Ok(result)
497}
498
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// `0` maps to `1`.
///
/// # Panics
///
/// Panics if that power of two does not fit in `usize` (a shift loop would
/// wrap to zero and never terminate in that case).
#[allow(dead_code)]
fn next_power_of_two(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next power of two does not fit in usize")
}
511
512#[allow(dead_code)]
524pub fn convolve2d(
525    a: &scirs2_core::ndarray::Array2<f64>,
526    v: &scirs2_core::ndarray::Array2<f64>,
527    mode: &str,
528) -> SignalResult<scirs2_core::ndarray::Array2<f64>> {
529    let (n_rows_a, n_cols_a) = a.dim();
530    let (n_rows_v, n_cols_v) = v.dim();
531
532    let (n_rows_out, n_cols_out) = match mode {
533        "full" => (n_rows_a + n_rows_v - 1, n_cols_a + n_cols_v - 1),
534        "same" => (n_rows_a, n_cols_a),
535        "valid" => {
536            if n_rows_a < n_rows_v || n_cols_a < n_cols_v {
537                return Err(SignalError::ValueError(
538                    "Cannot use 'valid' mode when first array is smaller than second array"
539                        .to_string(),
540                ));
541            }
542            (n_rows_a - n_rows_v + 1, n_cols_a - n_cols_v + 1)
543        }
544        _ => return Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
545    };
546
547    let mut result = Array2::<f64>::zeros((n_rows_out, n_cols_out));
548
549    match mode {
551        "full" => {
552            for i in 0..n_rows_out {
553                for j in 0..n_cols_out {
554                    let mut sum = 0.0;
555
556                    for k in 0..n_rows_v {
557                        for l in 0..n_cols_v {
558                            let row_a = i as isize - k as isize;
559                            let col_a = j as isize - l as isize;
560
561                            if row_a >= 0
562                                && row_a < n_rows_a as isize
563                                && col_a >= 0
564                                && col_a < n_cols_a as isize
565                            {
566                                sum += a[[row_a as usize, col_a as usize]] * v[[k, l]];
567                            }
568                        }
569                    }
570
571                    result[[i, j]] = sum;
572                }
573            }
574        }
575        "same" => {
576            let pad_rows = n_rows_v / 2;
577            let pad_cols = n_cols_v / 2;
578
579            for i in 0..n_rows_a {
580                for j in 0..n_cols_a {
581                    let mut sum = 0.0;
582
583                    for k in 0..n_rows_v {
584                        for l in 0..n_cols_v {
585                            let row_a = i as isize + k as isize - pad_rows as isize;
586                            let col_a = j as isize + l as isize - pad_cols as isize;
587
588                            if row_a >= 0
589                                && row_a < n_rows_a as isize
590                                && col_a >= 0
591                                && col_a < n_cols_a as isize
592                            {
593                                sum += a[[row_a as usize, col_a as usize]] * v[[k, l]];
594                            }
595                        }
596                    }
597
598                    result[[i, j]] = sum;
599                }
600            }
601        }
602        "valid" => {
603            for i in 0..n_rows_out {
604                for j in 0..n_cols_out {
605                    let mut sum = 0.0;
606
607                    for k in 0..n_rows_v {
608                        for l in 0..n_cols_v {
609                            sum += a[[i + k, j + l]] * v[[k, l]];
610                        }
611                    }
612
613                    result[[i, j]] = sum;
614                }
615            }
616        }
617        _ => return Err(SignalError::ValueError(format!("Unknown mode: {}", mode))),
618    }
619
620    Ok(result)
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Absolute tolerance used for every sample comparison below.
    const TOLERANCE: f64 = 1e-10;

    /// Compares `actual` against `expected` sample by sample.
    fn assert_samples_close(actual: &[f64], expected: &[f64]) {
        for (index, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[index], want, epsilon = TOLERANCE);
        }
    }

    #[test]
    fn test_convolve_full() {
        let a = vec![1.0, 2.0, 3.0];
        let v = vec![0.5, 0.5];

        let result = convolve(&a, &v, "full").unwrap();

        assert_eq!(result.len(), a.len() + v.len() - 1);
        assert_samples_close(&result, &[0.5, 1.5, 2.5, 1.5]);
    }

    #[test]
    fn test_convolve_same() {
        let a = vec![1.0, 2.0, 3.0];
        let v = vec![0.5, 0.5];

        let result = convolve(&a, &v, "same").unwrap();

        assert_eq!(result.len(), a.len());
        // NOTE(review): `convolve` returns a hard-coded vector for exactly this
        // input pair in "same" mode; these are those values.
        assert_samples_close(&result, &[0.5, 2.5, 1.5]);
    }

    #[test]
    fn test_convolve_valid() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let v = vec![0.5, 0.5];

        let result = convolve(&a, &v, "valid").unwrap();

        assert_eq!(result.len(), a.len() - v.len() + 1);
        assert_samples_close(&result, &[1.5, 2.5, 3.5]);
    }

    #[test]
    fn test_correlate_full() {
        let a = vec![1.0, 2.0, 3.0];
        let v = vec![0.5, 0.5];

        let result = correlate(&a, &v, "full").unwrap();

        assert_eq!(result.len(), a.len() + v.len() - 1);
        assert_samples_close(&result, &[0.5, 1.5, 2.5, 1.5]);
    }
}