1use pyo3::prelude::*;
6use pyo3::types::PyDict;
7use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
8use scirs2_numpy::{PyArray1, PyReadonlyArray1};
9
10use scirs2_signal::hilbert::hilbert;
12
13use scirs2_signal::filter::fir::firwin;
15use scirs2_signal::filter::iir::{butter, cheby1};
16use scirs2_signal::filter::FilterType;
17
18#[pyfunction]
29#[pyo3(signature = (a, v, mode="full"))]
30fn convolve_py(
31 py: Python,
32 a: PyReadonlyArray1<f64>,
33 v: PyReadonlyArray1<f64>,
34 mode: &str,
35) -> PyResult<Py<PyArray1<f64>>> {
36 let a_arr = a.as_array();
37 let v_arr = v.as_array();
38 let a_slice = a_arr.as_slice().expect("Operation failed");
39 let v_slice = v_arr.as_slice().expect("Operation failed");
40 let n = a_slice.len();
41 let m = v_slice.len();
42
43 if n == 0 || m == 0 {
44 return Err(pyo3::exceptions::PyValueError::new_err(
45 "Arrays must not be empty",
46 ));
47 }
48
49 let (out_len, offset) = match mode {
51 "full" => (n + m - 1, 0),
52 "same" => (n, (m - 1) / 2),
53 "valid" => {
54 if n < m {
55 return Err(pyo3::exceptions::PyValueError::new_err(
56 "For 'valid' mode, first array must be at least as long as second",
57 ));
58 }
59 (n - m + 1, m - 1)
60 }
61 _ => {
62 return Err(pyo3::exceptions::PyValueError::new_err(
63 "mode must be 'full', 'same', or 'valid'",
64 ))
65 }
66 };
67
68 let mut result = vec![0.0f64; out_len];
70
71 for (i, res) in result.iter_mut().enumerate() {
72 let full_idx = i + offset;
73 let mut sum = 0.0f64;
74 for (j, &vj) in v_slice.iter().enumerate() {
75 let ai = full_idx as isize - j as isize;
76 if ai >= 0 && (ai as usize) < n {
77 sum += a_slice[ai as usize] * vj;
78 }
79 }
80 *res = sum;
81 }
82
83 scirs_to_numpy_array1(Array1::from_vec(result), py)
84}
85
86#[pyfunction]
93#[pyo3(signature = (a, v, mode="full"))]
94fn correlate_py(
95 py: Python,
96 a: PyReadonlyArray1<f64>,
97 v: PyReadonlyArray1<f64>,
98 mode: &str,
99) -> PyResult<Py<PyArray1<f64>>> {
100 let a_arr = a.as_array();
101 let v_arr = v.as_array();
102 let a_slice = a_arr.as_slice().expect("Operation failed");
103 let v_slice = v_arr.as_slice().expect("Operation failed");
104 let n = a_slice.len();
105 let m = v_slice.len();
106
107 if n == 0 || m == 0 {
108 return Err(pyo3::exceptions::PyValueError::new_err(
109 "Arrays must not be empty",
110 ));
111 }
112
113 let (out_len, offset) = match mode {
116 "full" => (n + m - 1, 0),
117 "same" => (n, (m - 1) / 2),
118 "valid" => {
119 if n < m {
120 return Err(pyo3::exceptions::PyValueError::new_err(
121 "For 'valid' mode, first array must be at least as long as second",
122 ));
123 }
124 (n - m + 1, m - 1)
125 }
126 _ => {
127 return Err(pyo3::exceptions::PyValueError::new_err(
128 "mode must be 'full', 'same', or 'valid'",
129 ))
130 }
131 };
132
133 let mut result = vec![0.0f64; out_len];
135
136 for (i, res) in result.iter_mut().enumerate() {
137 let full_idx = i + offset;
138 let mut sum = 0.0f64;
139 for (j, &vj) in v_slice.iter().rev().enumerate() {
140 let ai = full_idx as isize - j as isize;
141 if ai >= 0 && (ai as usize) < n {
142 sum += a_slice[ai as usize] * vj;
143 }
144 }
145 *res = sum;
146 }
147
148 scirs_to_numpy_array1(Array1::from_vec(result), py)
149}
150
151#[pyfunction]
159fn hilbert_py(py: Python, x: PyReadonlyArray1<f64>) -> PyResult<Py<PyAny>> {
160 let x_slice = x.as_array().to_vec();
161
162 let result =
163 hilbert(&x_slice).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
164
165 let real: Vec<f64> = result.iter().map(|c| c.re).collect();
167 let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
168
169 let dict = PyDict::new(py);
170 dict.set_item("real", scirs_to_numpy_array1(Array1::from_vec(real), py)?)?;
171 dict.set_item("imag", scirs_to_numpy_array1(Array1::from_vec(imag), py)?)?;
172
173 Ok(dict.into())
174}
175
176#[pyfunction]
182fn hann_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
183 let mut window = Vec::with_capacity(n);
184 for i in 0..n {
185 let val = 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos());
186 window.push(val);
187 }
188 scirs_to_numpy_array1(Array1::from_vec(window), py)
189}
190
191#[pyfunction]
193fn hamming_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
194 let mut window = Vec::with_capacity(n);
195 for i in 0..n {
196 let val = 0.54 - 0.46 * (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos();
197 window.push(val);
198 }
199 scirs_to_numpy_array1(Array1::from_vec(window), py)
200}
201
202#[pyfunction]
204fn blackman_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
205 let mut window = Vec::with_capacity(n);
206 for i in 0..n {
207 let t = 2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64;
208 let val = 0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos();
209 window.push(val);
210 }
211 scirs_to_numpy_array1(Array1::from_vec(window), py)
212}
213
214#[pyfunction]
216fn bartlett_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
217 let mut window = Vec::with_capacity(n);
218 let half = (n - 1) as f64 / 2.0;
219 for i in 0..n {
220 let val = 1.0 - ((i as f64 - half) / half).abs();
221 window.push(val);
222 }
223 scirs_to_numpy_array1(Array1::from_vec(window), py)
224}
225
226#[pyfunction]
228fn kaiser_py(py: Python, n: usize, beta: f64) -> PyResult<Py<PyArray1<f64>>> {
229 let mut window = Vec::with_capacity(n);
230
231 fn bessel_i0(x: f64) -> f64 {
233 let mut sum = 1.0;
234 let mut term = 1.0;
235 for k in 1..50 {
236 term *= (x / 2.0).powi(2) / (k as f64).powi(2);
237 sum += term;
238 if term < 1e-12 {
239 break;
240 }
241 }
242 sum
243 }
244
245 let denom = bessel_i0(beta);
246 for i in 0..n {
247 let t = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
248 let arg = beta * (1.0 - t * t).sqrt();
249 let val = bessel_i0(arg) / denom;
250 window.push(val);
251 }
252
253 scirs_to_numpy_array1(Array1::from_vec(window), py)
254}
255
256#[pyfunction]
270#[pyo3(signature = (order, cutoff, filter_type="lowpass"))]
271fn butter_py(py: Python, order: usize, cutoff: f64, filter_type: &str) -> PyResult<Py<PyAny>> {
272 let ftype = match filter_type.to_lowercase().as_str() {
273 "lowpass" | "low" => FilterType::Lowpass,
274 "highpass" | "high" => FilterType::Highpass,
275 "bandpass" | "band" => FilterType::Bandpass,
276 "bandstop" | "stop" => FilterType::Bandstop,
277 _ => {
278 return Err(pyo3::exceptions::PyValueError::new_err(
279 "Invalid filter type. Use 'lowpass', 'highpass', 'bandpass', or 'bandstop'",
280 ));
281 }
282 };
283
284 let (b, a) = butter(order, cutoff, ftype)
285 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
286
287 let dict = PyDict::new(py);
288 dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
289 dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
290
291 Ok(dict.into())
292}
293
294#[pyfunction]
305#[pyo3(signature = (order, ripple, cutoff, filter_type="lowpass"))]
306fn cheby1_py(
307 py: Python,
308 order: usize,
309 ripple: f64,
310 cutoff: f64,
311 filter_type: &str,
312) -> PyResult<Py<PyAny>> {
313 let ftype = match filter_type.to_lowercase().as_str() {
314 "lowpass" | "low" => FilterType::Lowpass,
315 "highpass" | "high" => FilterType::Highpass,
316 _ => {
317 return Err(pyo3::exceptions::PyValueError::new_err(
318 "Invalid filter type for cheby1. Use 'lowpass' or 'highpass'",
319 ));
320 }
321 };
322
323 let (b, a) = cheby1(order, ripple, cutoff, ftype)
324 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
325
326 let dict = PyDict::new(py);
327 dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
328 dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
329
330 Ok(dict.into())
331}
332
333#[pyfunction]
344#[pyo3(signature = (numtaps, cutoff, window="hamming", pass_zero=true))]
345fn firwin_py(
346 py: Python,
347 numtaps: usize,
348 cutoff: f64,
349 window: &str,
350 pass_zero: bool,
351) -> PyResult<Py<PyArray1<f64>>> {
352 let coeffs = firwin(numtaps, cutoff, window, pass_zero)
353 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
354
355 scirs_to_numpy_array1(Array1::from_vec(coeffs), py)
356}
357
358#[pyfunction]
366#[pyo3(signature = (x, height=None, distance=None))]
367fn find_peaks_py(
368 py: Python,
369 x: PyReadonlyArray1<f64>,
370 height: Option<f64>,
371 distance: Option<usize>,
372) -> PyResult<Py<PyArray1<i64>>> {
373 let x_arr = x.as_array();
374 let n = x_arr.len();
375
376 if n < 3 {
377 return scirs_to_numpy_array1(Array1::from_vec(vec![]), py);
378 }
379
380 let mut peaks: Vec<i64> = Vec::new();
381
382 for i in 1..n - 1 {
384 if x_arr[i] > x_arr[i - 1] && x_arr[i] > x_arr[i + 1] {
385 if let Some(h) = height {
387 if x_arr[i] < h {
388 continue;
389 }
390 }
391 peaks.push(i as i64);
392 }
393 }
394
395 if let Some(dist) = distance {
397 let mut filtered = Vec::new();
398 for &peak in &peaks {
399 let keep = filtered
400 .iter()
401 .all(|&p: &i64| (peak - p).unsigned_abs() >= dist as u64);
402 if keep {
403 filtered.push(peak);
404 }
405 }
406 peaks = filtered;
407 }
408
409 scirs_to_numpy_array1(Array1::from_vec(peaks), py)
410}
411
412pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
414 m.add_function(wrap_pyfunction!(convolve_py, m)?)?;
416 m.add_function(wrap_pyfunction!(correlate_py, m)?)?;
417
418 m.add_function(wrap_pyfunction!(hilbert_py, m)?)?;
420
421 m.add_function(wrap_pyfunction!(hann_py, m)?)?;
423 m.add_function(wrap_pyfunction!(hamming_py, m)?)?;
424 m.add_function(wrap_pyfunction!(blackman_py, m)?)?;
425 m.add_function(wrap_pyfunction!(bartlett_py, m)?)?;
426 m.add_function(wrap_pyfunction!(kaiser_py, m)?)?;
427
428 m.add_function(wrap_pyfunction!(butter_py, m)?)?;
430 m.add_function(wrap_pyfunction!(cheby1_py, m)?)?;
431 m.add_function(wrap_pyfunction!(firwin_py, m)?)?;
432
433 m.add_function(wrap_pyfunction!(find_peaks_py, m)?)?;
435
436 Ok(())
437}