Skip to main content

scirs2/
signal.rs

1//! Python bindings for scirs2-signal
2//!
3//! Provides signal processing functions similar to scipy.signal
4
5use pyo3::prelude::*;
6use pyo3::types::PyDict;
7use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
8use scirs2_numpy::{PyArray1, PyReadonlyArray1};
9
10// Import signal functions
11use scirs2_signal::hilbert::hilbert;
12
13// Import filter design functions
14use scirs2_signal::filter::fir::firwin;
15use scirs2_signal::filter::iir::{butter, cheby1};
16use scirs2_signal::filter::FilterType;
17
18// =============================================================================
19// Convolution and Correlation
20// =============================================================================
21
22/// Convolve two 1-D arrays - optimized direct implementation
23///
24/// Parameters:
25/// - a: First input array
26/// - v: Second input array (kernel)
27/// - mode: 'full', 'same', or 'valid'
28#[pyfunction]
29#[pyo3(signature = (a, v, mode="full"))]
30fn convolve_py(
31    py: Python,
32    a: PyReadonlyArray1<f64>,
33    v: PyReadonlyArray1<f64>,
34    mode: &str,
35) -> PyResult<Py<PyArray1<f64>>> {
36    let a_arr = a.as_array();
37    let v_arr = v.as_array();
38    let a_slice = a_arr.as_slice().ok_or_else(|| {
39        pyo3::exceptions::PyValueError::new_err("Array 'a' is not contiguous in memory")
40    })?;
41    let v_slice = v_arr.as_slice().ok_or_else(|| {
42        pyo3::exceptions::PyValueError::new_err("Array 'v' is not contiguous in memory")
43    })?;
44    let n = a_slice.len();
45    let m = v_slice.len();
46
47    if n == 0 || m == 0 {
48        return Err(pyo3::exceptions::PyValueError::new_err(
49            "Arrays must not be empty",
50        ));
51    }
52
53    // Calculate output size based on mode
54    let (out_len, offset) = match mode {
55        "full" => (n + m - 1, 0),
56        "same" => (n, (m - 1) / 2),
57        "valid" => {
58            if n < m {
59                return Err(pyo3::exceptions::PyValueError::new_err(
60                    "For 'valid' mode, first array must be at least as long as second",
61                ));
62            }
63            (n - m + 1, m - 1)
64        }
65        _ => {
66            return Err(pyo3::exceptions::PyValueError::new_err(
67                "mode must be 'full', 'same', or 'valid'",
68            ))
69        }
70    };
71
72    // Direct convolution (optimized for small kernels)
73    let mut result = vec![0.0f64; out_len];
74
75    for (i, res) in result.iter_mut().enumerate() {
76        let full_idx = i + offset;
77        let mut sum = 0.0f64;
78        for (j, &vj) in v_slice.iter().enumerate() {
79            let ai = full_idx as isize - j as isize;
80            if ai >= 0 && (ai as usize) < n {
81                sum += a_slice[ai as usize] * vj;
82            }
83        }
84        *res = sum;
85    }
86
87    scirs_to_numpy_array1(Array1::from_vec(result), py)
88}
89
90/// Cross-correlation of two 1-D arrays - optimized direct implementation
91///
92/// Parameters:
93/// - a: First input array
94/// - v: Second input array
95/// - mode: 'full', 'same', or 'valid'
96#[pyfunction]
97#[pyo3(signature = (a, v, mode="full"))]
98fn correlate_py(
99    py: Python,
100    a: PyReadonlyArray1<f64>,
101    v: PyReadonlyArray1<f64>,
102    mode: &str,
103) -> PyResult<Py<PyArray1<f64>>> {
104    let a_arr = a.as_array();
105    let v_arr = v.as_array();
106    let a_slice = a_arr.as_slice().ok_or_else(|| {
107        pyo3::exceptions::PyValueError::new_err("Array 'a' is not contiguous in memory")
108    })?;
109    let v_slice = v_arr.as_slice().ok_or_else(|| {
110        pyo3::exceptions::PyValueError::new_err("Array 'v' is not contiguous in memory")
111    })?;
112    let n = a_slice.len();
113    let m = v_slice.len();
114
115    if n == 0 || m == 0 {
116        return Err(pyo3::exceptions::PyValueError::new_err(
117            "Arrays must not be empty",
118        ));
119    }
120
121    // Reverse kernel for correlation (correlation = convolution with reversed kernel)
122    // Calculate output size based on mode
123    let (out_len, offset) = match mode {
124        "full" => (n + m - 1, 0),
125        "same" => (n, (m - 1) / 2),
126        "valid" => {
127            if n < m {
128                return Err(pyo3::exceptions::PyValueError::new_err(
129                    "For 'valid' mode, first array must be at least as long as second",
130                ));
131            }
132            (n - m + 1, m - 1)
133        }
134        _ => {
135            return Err(pyo3::exceptions::PyValueError::new_err(
136                "mode must be 'full', 'same', or 'valid'",
137            ))
138        }
139    };
140
141    // Direct correlation
142    let mut result = vec![0.0f64; out_len];
143
144    for (i, res) in result.iter_mut().enumerate() {
145        let full_idx = i + offset;
146        let mut sum = 0.0f64;
147        for (j, &vj) in v_slice.iter().rev().enumerate() {
148            let ai = full_idx as isize - j as isize;
149            if ai >= 0 && (ai as usize) < n {
150                sum += a_slice[ai as usize] * vj;
151            }
152        }
153        *res = sum;
154    }
155
156    scirs_to_numpy_array1(Array1::from_vec(result), py)
157}
158
159// =============================================================================
160// Hilbert Transform
161// =============================================================================
162
163/// Compute the analytic signal using Hilbert transform
164///
165/// Returns the analytic signal (real and imaginary parts separately)
166#[pyfunction]
167fn hilbert_py(py: Python, x: PyReadonlyArray1<f64>) -> PyResult<Py<PyAny>> {
168    let x_slice = x.as_array().to_vec();
169
170    let result =
171        hilbert(&x_slice).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
172
173    // Extract real and imaginary parts
174    let real: Vec<f64> = result.iter().map(|c| c.re).collect();
175    let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
176
177    let dict = PyDict::new(py);
178    dict.set_item("real", scirs_to_numpy_array1(Array1::from_vec(real), py)?)?;
179    dict.set_item("imag", scirs_to_numpy_array1(Array1::from_vec(imag), py)?)?;
180
181    Ok(dict.into())
182}
183
184// =============================================================================
185// Window Functions
186// =============================================================================
187
188/// Hann window
189#[pyfunction]
190fn hann_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
191    let mut window = Vec::with_capacity(n);
192    for i in 0..n {
193        let val = 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos());
194        window.push(val);
195    }
196    scirs_to_numpy_array1(Array1::from_vec(window), py)
197}
198
199/// Hamming window
200#[pyfunction]
201fn hamming_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
202    let mut window = Vec::with_capacity(n);
203    for i in 0..n {
204        let val = 0.54 - 0.46 * (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos();
205        window.push(val);
206    }
207    scirs_to_numpy_array1(Array1::from_vec(window), py)
208}
209
210/// Blackman window
211#[pyfunction]
212fn blackman_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
213    let mut window = Vec::with_capacity(n);
214    for i in 0..n {
215        let t = 2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64;
216        let val = 0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos();
217        window.push(val);
218    }
219    scirs_to_numpy_array1(Array1::from_vec(window), py)
220}
221
222/// Bartlett (triangular) window
223#[pyfunction]
224fn bartlett_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
225    let mut window = Vec::with_capacity(n);
226    let half = (n - 1) as f64 / 2.0;
227    for i in 0..n {
228        let val = 1.0 - ((i as f64 - half) / half).abs();
229        window.push(val);
230    }
231    scirs_to_numpy_array1(Array1::from_vec(window), py)
232}
233
234/// Kaiser window
235#[pyfunction]
236fn kaiser_py(py: Python, n: usize, beta: f64) -> PyResult<Py<PyArray1<f64>>> {
237    let mut window = Vec::with_capacity(n);
238
239    // Simple approximation of I0 (modified Bessel function)
240    fn bessel_i0(x: f64) -> f64 {
241        let mut sum = 1.0;
242        let mut term = 1.0;
243        for k in 1..50 {
244            term *= (x / 2.0).powi(2) / (k as f64).powi(2);
245            sum += term;
246            if term < 1e-12 {
247                break;
248            }
249        }
250        sum
251    }
252
253    let denom = bessel_i0(beta);
254    for i in 0..n {
255        let t = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
256        let arg = beta * (1.0 - t * t).sqrt();
257        let val = bessel_i0(arg) / denom;
258        window.push(val);
259    }
260
261    scirs_to_numpy_array1(Array1::from_vec(window), py)
262}
263
264// =============================================================================
265// Filter Design
266// =============================================================================
267
268/// Design a Butterworth digital filter
269///
270/// Parameters:
271/// - order: Filter order
272/// - cutoff: Cutoff frequency (normalized 0-1, where 1 is Nyquist)
273/// - filter_type: 'lowpass', 'highpass'
274///
275/// Returns:
276/// - Dict with 'b' (numerator) and 'a' (denominator) coefficients
277#[pyfunction]
278#[pyo3(signature = (order, cutoff, filter_type="lowpass"))]
279fn butter_py(py: Python, order: usize, cutoff: f64, filter_type: &str) -> PyResult<Py<PyAny>> {
280    let ftype = match filter_type.to_lowercase().as_str() {
281        "lowpass" | "low" => FilterType::Lowpass,
282        "highpass" | "high" => FilterType::Highpass,
283        "bandpass" | "band" => FilterType::Bandpass,
284        "bandstop" | "stop" => FilterType::Bandstop,
285        _ => {
286            return Err(pyo3::exceptions::PyValueError::new_err(
287                "Invalid filter type. Use 'lowpass', 'highpass', 'bandpass', or 'bandstop'",
288            ));
289        }
290    };
291
292    let (b, a) = butter(order, cutoff, ftype)
293        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
294
295    let dict = PyDict::new(py);
296    dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
297    dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
298
299    Ok(dict.into())
300}
301
302/// Design a Chebyshev Type I digital filter
303///
304/// Parameters:
305/// - order: Filter order
306/// - ripple: Passband ripple in dB
307/// - cutoff: Cutoff frequency (normalized 0-1, where 1 is Nyquist)
308/// - filter_type: 'lowpass', 'highpass'
309///
310/// Returns:
311/// - Dict with 'b' (numerator) and 'a' (denominator) coefficients
312#[pyfunction]
313#[pyo3(signature = (order, ripple, cutoff, filter_type="lowpass"))]
314fn cheby1_py(
315    py: Python,
316    order: usize,
317    ripple: f64,
318    cutoff: f64,
319    filter_type: &str,
320) -> PyResult<Py<PyAny>> {
321    let ftype = match filter_type.to_lowercase().as_str() {
322        "lowpass" | "low" => FilterType::Lowpass,
323        "highpass" | "high" => FilterType::Highpass,
324        _ => {
325            return Err(pyo3::exceptions::PyValueError::new_err(
326                "Invalid filter type for cheby1. Use 'lowpass' or 'highpass'",
327            ));
328        }
329    };
330
331    let (b, a) = cheby1(order, ripple, cutoff, ftype)
332        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
333
334    let dict = PyDict::new(py);
335    dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
336    dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
337
338    Ok(dict.into())
339}
340
341/// Design a FIR filter using window method
342///
343/// Parameters:
344/// - numtaps: Number of filter taps (filter order + 1)
345/// - cutoff: Cutoff frequency (normalized 0-1, where 1 is Nyquist)
346/// - window: Window function ('hamming', 'hann', 'blackman', 'kaiser')
347/// - pass_zero: If true, lowpass; if false, highpass
348///
349/// Returns:
350/// - Filter coefficients as numpy array
351#[pyfunction]
352#[pyo3(signature = (numtaps, cutoff, window="hamming", pass_zero=true))]
353fn firwin_py(
354    py: Python,
355    numtaps: usize,
356    cutoff: f64,
357    window: &str,
358    pass_zero: bool,
359) -> PyResult<Py<PyArray1<f64>>> {
360    let coeffs = firwin(numtaps, cutoff, window, pass_zero)
361        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
362
363    scirs_to_numpy_array1(Array1::from_vec(coeffs), py)
364}
365
366// =============================================================================
367// Peak Finding
368// =============================================================================
369
370/// Find peaks in a 1-D array
371///
372/// Returns indices of peaks
373#[pyfunction]
374#[pyo3(signature = (x, height=None, distance=None))]
375fn find_peaks_py(
376    py: Python,
377    x: PyReadonlyArray1<f64>,
378    height: Option<f64>,
379    distance: Option<usize>,
380) -> PyResult<Py<PyArray1<i64>>> {
381    let x_arr = x.as_array();
382    let n = x_arr.len();
383
384    if n < 3 {
385        return scirs_to_numpy_array1(Array1::from_vec(vec![]), py);
386    }
387
388    let mut peaks: Vec<i64> = Vec::new();
389
390    // Find local maxima
391    for i in 1..n - 1 {
392        if x_arr[i] > x_arr[i - 1] && x_arr[i] > x_arr[i + 1] {
393            // Check height threshold
394            if let Some(h) = height {
395                if x_arr[i] < h {
396                    continue;
397                }
398            }
399            peaks.push(i as i64);
400        }
401    }
402
403    // Apply distance filter
404    if let Some(dist) = distance {
405        let mut filtered = Vec::new();
406        for &peak in &peaks {
407            let keep = filtered
408                .iter()
409                .all(|&p: &i64| (peak - p).unsigned_abs() >= dist as u64);
410            if keep {
411                filtered.push(peak);
412            }
413        }
414        peaks = filtered;
415    }
416
417    scirs_to_numpy_array1(Array1::from_vec(peaks), py)
418}
419
420/// Python module registration
421pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
422    // Convolution/correlation
423    m.add_function(wrap_pyfunction!(convolve_py, m)?)?;
424    m.add_function(wrap_pyfunction!(correlate_py, m)?)?;
425
426    // Hilbert transform
427    m.add_function(wrap_pyfunction!(hilbert_py, m)?)?;
428
429    // Window functions
430    m.add_function(wrap_pyfunction!(hann_py, m)?)?;
431    m.add_function(wrap_pyfunction!(hamming_py, m)?)?;
432    m.add_function(wrap_pyfunction!(blackman_py, m)?)?;
433    m.add_function(wrap_pyfunction!(bartlett_py, m)?)?;
434    m.add_function(wrap_pyfunction!(kaiser_py, m)?)?;
435
436    // Filter design
437    m.add_function(wrap_pyfunction!(butter_py, m)?)?;
438    m.add_function(wrap_pyfunction!(cheby1_py, m)?)?;
439    m.add_function(wrap_pyfunction!(firwin_py, m)?)?;
440
441    // Peak finding
442    m.add_function(wrap_pyfunction!(find_peaks_py, m)?)?;
443
444    Ok(())
445}