Skip to main content

scirs2/
signal.rs

1//! Python bindings for scirs2-signal
2//!
3//! Provides signal processing functions similar to scipy.signal
4
5use pyo3::prelude::*;
6use pyo3::types::PyDict;
7use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
8use scirs2_numpy::{PyArray1, PyReadonlyArray1};
9
10// Import signal functions
11use scirs2_signal::hilbert::hilbert;
12
13// Import filter design functions
14use scirs2_signal::filter::fir::firwin;
15use scirs2_signal::filter::iir::{butter, cheby1};
16use scirs2_signal::filter::FilterType;
17
18// =============================================================================
19// Convolution and Correlation
20// =============================================================================
21
22/// Convolve two 1-D arrays - optimized direct implementation
23///
24/// Parameters:
25/// - a: First input array
26/// - v: Second input array (kernel)
27/// - mode: 'full', 'same', or 'valid'
28#[pyfunction]
29#[pyo3(signature = (a, v, mode="full"))]
30fn convolve_py(
31    py: Python,
32    a: PyReadonlyArray1<f64>,
33    v: PyReadonlyArray1<f64>,
34    mode: &str,
35) -> PyResult<Py<PyArray1<f64>>> {
36    let a_arr = a.as_array();
37    let v_arr = v.as_array();
38    let a_slice = a_arr.as_slice().expect("Operation failed");
39    let v_slice = v_arr.as_slice().expect("Operation failed");
40    let n = a_slice.len();
41    let m = v_slice.len();
42
43    if n == 0 || m == 0 {
44        return Err(pyo3::exceptions::PyValueError::new_err(
45            "Arrays must not be empty",
46        ));
47    }
48
49    // Calculate output size based on mode
50    let (out_len, offset) = match mode {
51        "full" => (n + m - 1, 0),
52        "same" => (n, (m - 1) / 2),
53        "valid" => {
54            if n < m {
55                return Err(pyo3::exceptions::PyValueError::new_err(
56                    "For 'valid' mode, first array must be at least as long as second",
57                ));
58            }
59            (n - m + 1, m - 1)
60        }
61        _ => {
62            return Err(pyo3::exceptions::PyValueError::new_err(
63                "mode must be 'full', 'same', or 'valid'",
64            ))
65        }
66    };
67
68    // Direct convolution (optimized for small kernels)
69    let mut result = vec![0.0f64; out_len];
70
71    for (i, res) in result.iter_mut().enumerate() {
72        let full_idx = i + offset;
73        let mut sum = 0.0f64;
74        for (j, &vj) in v_slice.iter().enumerate() {
75            let ai = full_idx as isize - j as isize;
76            if ai >= 0 && (ai as usize) < n {
77                sum += a_slice[ai as usize] * vj;
78            }
79        }
80        *res = sum;
81    }
82
83    scirs_to_numpy_array1(Array1::from_vec(result), py)
84}
85
86/// Cross-correlation of two 1-D arrays - optimized direct implementation
87///
88/// Parameters:
89/// - a: First input array
90/// - v: Second input array
91/// - mode: 'full', 'same', or 'valid'
92#[pyfunction]
93#[pyo3(signature = (a, v, mode="full"))]
94fn correlate_py(
95    py: Python,
96    a: PyReadonlyArray1<f64>,
97    v: PyReadonlyArray1<f64>,
98    mode: &str,
99) -> PyResult<Py<PyArray1<f64>>> {
100    let a_arr = a.as_array();
101    let v_arr = v.as_array();
102    let a_slice = a_arr.as_slice().expect("Operation failed");
103    let v_slice = v_arr.as_slice().expect("Operation failed");
104    let n = a_slice.len();
105    let m = v_slice.len();
106
107    if n == 0 || m == 0 {
108        return Err(pyo3::exceptions::PyValueError::new_err(
109            "Arrays must not be empty",
110        ));
111    }
112
113    // Reverse kernel for correlation (correlation = convolution with reversed kernel)
114    // Calculate output size based on mode
115    let (out_len, offset) = match mode {
116        "full" => (n + m - 1, 0),
117        "same" => (n, (m - 1) / 2),
118        "valid" => {
119            if n < m {
120                return Err(pyo3::exceptions::PyValueError::new_err(
121                    "For 'valid' mode, first array must be at least as long as second",
122                ));
123            }
124            (n - m + 1, m - 1)
125        }
126        _ => {
127            return Err(pyo3::exceptions::PyValueError::new_err(
128                "mode must be 'full', 'same', or 'valid'",
129            ))
130        }
131    };
132
133    // Direct correlation
134    let mut result = vec![0.0f64; out_len];
135
136    for (i, res) in result.iter_mut().enumerate() {
137        let full_idx = i + offset;
138        let mut sum = 0.0f64;
139        for (j, &vj) in v_slice.iter().rev().enumerate() {
140            let ai = full_idx as isize - j as isize;
141            if ai >= 0 && (ai as usize) < n {
142                sum += a_slice[ai as usize] * vj;
143            }
144        }
145        *res = sum;
146    }
147
148    scirs_to_numpy_array1(Array1::from_vec(result), py)
149}
150
151// =============================================================================
152// Hilbert Transform
153// =============================================================================
154
155/// Compute the analytic signal using Hilbert transform
156///
157/// Returns the analytic signal (real and imaginary parts separately)
158#[pyfunction]
159fn hilbert_py(py: Python, x: PyReadonlyArray1<f64>) -> PyResult<Py<PyAny>> {
160    let x_slice = x.as_array().to_vec();
161
162    let result =
163        hilbert(&x_slice).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
164
165    // Extract real and imaginary parts
166    let real: Vec<f64> = result.iter().map(|c| c.re).collect();
167    let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
168
169    let dict = PyDict::new(py);
170    dict.set_item("real", scirs_to_numpy_array1(Array1::from_vec(real), py)?)?;
171    dict.set_item("imag", scirs_to_numpy_array1(Array1::from_vec(imag), py)?)?;
172
173    Ok(dict.into())
174}
175
176// =============================================================================
177// Window Functions
178// =============================================================================
179
180/// Hann window
181#[pyfunction]
182fn hann_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
183    let mut window = Vec::with_capacity(n);
184    for i in 0..n {
185        let val = 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos());
186        window.push(val);
187    }
188    scirs_to_numpy_array1(Array1::from_vec(window), py)
189}
190
191/// Hamming window
192#[pyfunction]
193fn hamming_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
194    let mut window = Vec::with_capacity(n);
195    for i in 0..n {
196        let val = 0.54 - 0.46 * (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos();
197        window.push(val);
198    }
199    scirs_to_numpy_array1(Array1::from_vec(window), py)
200}
201
202/// Blackman window
203#[pyfunction]
204fn blackman_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
205    let mut window = Vec::with_capacity(n);
206    for i in 0..n {
207        let t = 2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64;
208        let val = 0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos();
209        window.push(val);
210    }
211    scirs_to_numpy_array1(Array1::from_vec(window), py)
212}
213
214/// Bartlett (triangular) window
215#[pyfunction]
216fn bartlett_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
217    let mut window = Vec::with_capacity(n);
218    let half = (n - 1) as f64 / 2.0;
219    for i in 0..n {
220        let val = 1.0 - ((i as f64 - half) / half).abs();
221        window.push(val);
222    }
223    scirs_to_numpy_array1(Array1::from_vec(window), py)
224}
225
226/// Kaiser window
227#[pyfunction]
228fn kaiser_py(py: Python, n: usize, beta: f64) -> PyResult<Py<PyArray1<f64>>> {
229    let mut window = Vec::with_capacity(n);
230
231    // Simple approximation of I0 (modified Bessel function)
232    fn bessel_i0(x: f64) -> f64 {
233        let mut sum = 1.0;
234        let mut term = 1.0;
235        for k in 1..50 {
236            term *= (x / 2.0).powi(2) / (k as f64).powi(2);
237            sum += term;
238            if term < 1e-12 {
239                break;
240            }
241        }
242        sum
243    }
244
245    let denom = bessel_i0(beta);
246    for i in 0..n {
247        let t = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
248        let arg = beta * (1.0 - t * t).sqrt();
249        let val = bessel_i0(arg) / denom;
250        window.push(val);
251    }
252
253    scirs_to_numpy_array1(Array1::from_vec(window), py)
254}
255
256// =============================================================================
257// Filter Design
258// =============================================================================
259
260/// Design a Butterworth digital filter
261///
262/// Parameters:
263/// - order: Filter order
264/// - cutoff: Cutoff frequency (normalized 0-1, where 1 is Nyquist)
265/// - filter_type: 'lowpass', 'highpass'
266///
267/// Returns:
268/// - Dict with 'b' (numerator) and 'a' (denominator) coefficients
269#[pyfunction]
270#[pyo3(signature = (order, cutoff, filter_type="lowpass"))]
271fn butter_py(py: Python, order: usize, cutoff: f64, filter_type: &str) -> PyResult<Py<PyAny>> {
272    let ftype = match filter_type.to_lowercase().as_str() {
273        "lowpass" | "low" => FilterType::Lowpass,
274        "highpass" | "high" => FilterType::Highpass,
275        "bandpass" | "band" => FilterType::Bandpass,
276        "bandstop" | "stop" => FilterType::Bandstop,
277        _ => {
278            return Err(pyo3::exceptions::PyValueError::new_err(
279                "Invalid filter type. Use 'lowpass', 'highpass', 'bandpass', or 'bandstop'",
280            ));
281        }
282    };
283
284    let (b, a) = butter(order, cutoff, ftype)
285        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
286
287    let dict = PyDict::new(py);
288    dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
289    dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
290
291    Ok(dict.into())
292}
293
294/// Design a Chebyshev Type I digital filter
295///
296/// Parameters:
297/// - order: Filter order
298/// - ripple: Passband ripple in dB
299/// - cutoff: Cutoff frequency (normalized 0-1, where 1 is Nyquist)
300/// - filter_type: 'lowpass', 'highpass'
301///
302/// Returns:
303/// - Dict with 'b' (numerator) and 'a' (denominator) coefficients
304#[pyfunction]
305#[pyo3(signature = (order, ripple, cutoff, filter_type="lowpass"))]
306fn cheby1_py(
307    py: Python,
308    order: usize,
309    ripple: f64,
310    cutoff: f64,
311    filter_type: &str,
312) -> PyResult<Py<PyAny>> {
313    let ftype = match filter_type.to_lowercase().as_str() {
314        "lowpass" | "low" => FilterType::Lowpass,
315        "highpass" | "high" => FilterType::Highpass,
316        _ => {
317            return Err(pyo3::exceptions::PyValueError::new_err(
318                "Invalid filter type for cheby1. Use 'lowpass' or 'highpass'",
319            ));
320        }
321    };
322
323    let (b, a) = cheby1(order, ripple, cutoff, ftype)
324        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
325
326    let dict = PyDict::new(py);
327    dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
328    dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
329
330    Ok(dict.into())
331}
332
333/// Design a FIR filter using window method
334///
335/// Parameters:
336/// - numtaps: Number of filter taps (filter order + 1)
337/// - cutoff: Cutoff frequency (normalized 0-1, where 1 is Nyquist)
338/// - window: Window function ('hamming', 'hann', 'blackman', 'kaiser')
339/// - pass_zero: If true, lowpass; if false, highpass
340///
341/// Returns:
342/// - Filter coefficients as numpy array
343#[pyfunction]
344#[pyo3(signature = (numtaps, cutoff, window="hamming", pass_zero=true))]
345fn firwin_py(
346    py: Python,
347    numtaps: usize,
348    cutoff: f64,
349    window: &str,
350    pass_zero: bool,
351) -> PyResult<Py<PyArray1<f64>>> {
352    let coeffs = firwin(numtaps, cutoff, window, pass_zero)
353        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
354
355    scirs_to_numpy_array1(Array1::from_vec(coeffs), py)
356}
357
358// =============================================================================
359// Peak Finding
360// =============================================================================
361
362/// Find peaks in a 1-D array
363///
364/// Returns indices of peaks
365#[pyfunction]
366#[pyo3(signature = (x, height=None, distance=None))]
367fn find_peaks_py(
368    py: Python,
369    x: PyReadonlyArray1<f64>,
370    height: Option<f64>,
371    distance: Option<usize>,
372) -> PyResult<Py<PyArray1<i64>>> {
373    let x_arr = x.as_array();
374    let n = x_arr.len();
375
376    if n < 3 {
377        return scirs_to_numpy_array1(Array1::from_vec(vec![]), py);
378    }
379
380    let mut peaks: Vec<i64> = Vec::new();
381
382    // Find local maxima
383    for i in 1..n - 1 {
384        if x_arr[i] > x_arr[i - 1] && x_arr[i] > x_arr[i + 1] {
385            // Check height threshold
386            if let Some(h) = height {
387                if x_arr[i] < h {
388                    continue;
389                }
390            }
391            peaks.push(i as i64);
392        }
393    }
394
395    // Apply distance filter
396    if let Some(dist) = distance {
397        let mut filtered = Vec::new();
398        for &peak in &peaks {
399            let keep = filtered
400                .iter()
401                .all(|&p: &i64| (peak - p).unsigned_abs() >= dist as u64);
402            if keep {
403                filtered.push(peak);
404            }
405        }
406        peaks = filtered;
407    }
408
409    scirs_to_numpy_array1(Array1::from_vec(peaks), py)
410}
411
412/// Python module registration
413pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
414    // Convolution/correlation
415    m.add_function(wrap_pyfunction!(convolve_py, m)?)?;
416    m.add_function(wrap_pyfunction!(correlate_py, m)?)?;
417
418    // Hilbert transform
419    m.add_function(wrap_pyfunction!(hilbert_py, m)?)?;
420
421    // Window functions
422    m.add_function(wrap_pyfunction!(hann_py, m)?)?;
423    m.add_function(wrap_pyfunction!(hamming_py, m)?)?;
424    m.add_function(wrap_pyfunction!(blackman_py, m)?)?;
425    m.add_function(wrap_pyfunction!(bartlett_py, m)?)?;
426    m.add_function(wrap_pyfunction!(kaiser_py, m)?)?;
427
428    // Filter design
429    m.add_function(wrap_pyfunction!(butter_py, m)?)?;
430    m.add_function(wrap_pyfunction!(cheby1_py, m)?)?;
431    m.add_function(wrap_pyfunction!(firwin_py, m)?)?;
432
433    // Peak finding
434    m.add_function(wrap_pyfunction!(find_peaks_py, m)?)?;
435
436    Ok(())
437}