Skip to main content

scirs2/
async_ops.rs

1//! Async operations for Python
2//!
3//! This module provides async versions of long-running operations that can be awaited in Python.
4//!
5//! # Example (Python)
6//! ```python
7//! import asyncio
8//! import scirs2
9//! import numpy as np
10//!
11//! async def main():
12//!     # Async FFT for large arrays
13//!     data = np.random.randn(1_000_000)
14//!     result = await scirs2.fft_async(data)
15//!
16//!     # Async matrix decomposition
17//!     matrix = np.random.randn(1000, 1000)
18//!     svd = await scirs2.svd_async(matrix)
19//!
20//! asyncio.run(main())
21//! ```
22
23use crate::error::SciRS2Error;
24use pyo3::prelude::*;
25use pyo3_async_runtimes;
26use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods};
27
28/// Async FFT operation for large arrays
29///
30/// This function runs FFT in a background thread and returns a Python awaitable.
31/// Useful for large arrays (>100k elements) to avoid blocking the event loop.
32#[pyfunction]
33pub fn fft_async<'py>(
34    py: Python<'py>,
35    data: &Bound<'_, PyArray1<f64>>,
36) -> PyResult<Bound<'py, PyAny>> {
37    let data_vec: Vec<f64> = {
38        let binding = data.readonly();
39        let arr = binding.as_array();
40        arr.iter().cloned().collect()
41    };
42
43    pyo3_async_runtimes::tokio::future_into_py(py, async move {
44        // Run FFT in blocking task
45        // FFT returns Vec<Complex64> - collect real and imag parts
46        let (real_part, imag_part): (Vec<f64>, Vec<f64>) = tokio::task::spawn_blocking(move || {
47            use scirs2_core::Complex64;
48            use scirs2_fft::fft;
49
50            let result: Vec<Complex64> = fft(data_vec.as_slice(), None)
51                .map_err(|e| SciRS2Error::ComputationError(format!("FFT failed: {}", e)))?;
52
53            let real: Vec<f64> = result.iter().map(|c| c.re).collect();
54            let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
55            Ok::<(Vec<f64>, Vec<f64>), SciRS2Error>((real, imag))
56        })
57        .await
58        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
59
60        // Return as Py<PyAny> which is Send
61        let py_result: Py<PyAny> = Python::attach(|py| {
62            use pyo3::types::PyDict;
63            use scirs2_core::Array1;
64            // Return dict with real and imag arrays
65            let dict = PyDict::new(py);
66            let real_arr: Array1<f64> = Array1::from_vec(real_part);
67            let imag_arr: Array1<f64> = Array1::from_vec(imag_part);
68            dict.set_item("real", real_arr.into_pyarray(py))?;
69            dict.set_item("imag", imag_arr.into_pyarray(py))?;
70            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
71        })?;
72
73        Ok(py_result)
74    })
75}
76
77/// Async SVD operation for large matrices
78///
79/// This function runs SVD in a background thread and returns a Python awaitable.
80/// Useful for large matrices (>500x500) to avoid blocking the event loop.
81#[pyfunction]
82pub fn svd_async<'py>(
83    py: Python<'py>,
84    matrix: &Bound<'_, PyArray2<f64>>,
85    full_matrices: Option<bool>,
86) -> PyResult<Bound<'py, PyAny>> {
87    let matrix_shape = matrix.shape().to_vec();
88    let matrix_vec: Vec<f64> = {
89        let binding = matrix.readonly();
90        let arr = binding.as_array();
91        arr.iter().cloned().collect()
92    };
93    let full_matrices = full_matrices.unwrap_or(true);
94
95    pyo3_async_runtimes::tokio::future_into_py(py, async move {
96        // Run SVD in blocking task
97        let result = tokio::task::spawn_blocking(move || {
98            use scirs2_core::Array2;
99            use scirs2_linalg::svd_f64_lapack;
100
101            let arr = Array2::from_shape_vec((matrix_shape[0], matrix_shape[1]), matrix_vec)
102                .map_err(|e| SciRS2Error::ArrayError(format!("Array reshape failed: {}", e)))?;
103
104            svd_f64_lapack(&arr.view(), full_matrices)
105                .map_err(|e| SciRS2Error::ComputationError(format!("SVD failed: {}", e)))
106        })
107        .await
108        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
109
110        // Convert result to Python dict; return Py<PyAny> which is Send
111        let py_result: Py<PyAny> = Python::attach(|py| {
112            use pyo3::types::PyDict;
113            let dict = PyDict::new(py);
114            dict.set_item("U", result.0.into_pyarray(py))?;
115            dict.set_item("S", result.1.into_pyarray(py))?;
116            dict.set_item("Vt", result.2.into_pyarray(py))?;
117            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
118        })?;
119
120        Ok(py_result)
121    })
122}
123
124/// Async QR decomposition for large matrices
125#[pyfunction]
126pub fn qr_async<'py>(
127    py: Python<'py>,
128    matrix: &Bound<'_, PyArray2<f64>>,
129) -> PyResult<Bound<'py, PyAny>> {
130    let matrix_shape = matrix.shape().to_vec();
131    let matrix_vec: Vec<f64> = {
132        let binding = matrix.readonly();
133        let arr = binding.as_array();
134        arr.iter().cloned().collect()
135    };
136
137    pyo3_async_runtimes::tokio::future_into_py(py, async move {
138        let result = tokio::task::spawn_blocking(move || {
139            use scirs2_core::Array2;
140            use scirs2_linalg::qr_f64_lapack;
141
142            let arr = Array2::from_shape_vec((matrix_shape[0], matrix_shape[1]), matrix_vec)
143                .map_err(|e| SciRS2Error::ArrayError(format!("Array reshape failed: {}", e)))?;
144
145            qr_f64_lapack(&arr.view())
146                .map_err(|e| SciRS2Error::ComputationError(format!("QR failed: {}", e)))
147        })
148        .await
149        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
150
151        let py_result: Py<PyAny> = Python::attach(|py| {
152            use pyo3::types::PyDict;
153            let dict = PyDict::new(py);
154            dict.set_item("Q", result.0.into_pyarray(py))?;
155            dict.set_item("R", result.1.into_pyarray(py))?;
156            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
157        })?;
158
159        Ok(py_result)
160    })
161}
162
163/// Async numerical integration for expensive integrands
164#[pyfunction]
165pub fn quad_async<'py>(
166    py: Python<'py>,
167    func: Py<PyAny>,
168    a: f64,
169    b: f64,
170    epsabs: Option<f64>,
171    epsrel: Option<f64>,
172) -> PyResult<Bound<'py, PyAny>> {
173    pyo3_async_runtimes::tokio::future_into_py(py, async move {
174        let result: (f64, f64) = tokio::task::spawn_blocking(move || {
175            Python::attach(|py| {
176                use scirs2_integrate::quad::{quad, QuadOptions};
177
178                let abs_tol = epsabs.unwrap_or(1e-8);
179                let rel_tol = epsrel.unwrap_or(1e-8);
180
181                // Create Rust closure that calls Python function
182                let integrand = |x: f64| -> f64 {
183                    func.call1(py, (x,))
184                        .and_then(|result| result.extract::<f64>(py))
185                        .unwrap_or(f64::NAN)
186                };
187
188                let options = QuadOptions {
189                    abs_tol,
190                    rel_tol,
191                    ..Default::default()
192                };
193
194                let result = quad(integrand, a, b, Some(options)).map_err(|e| {
195                    PyErr::from(SciRS2Error::ComputationError(format!(
196                        "Integration failed: {}",
197                        e
198                    )))
199                })?;
200
201                Ok::<(f64, f64), PyErr>((result.value, result.abs_error))
202            })
203        })
204        .await
205        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
206
207        let py_result: Py<PyAny> = Python::attach(|py| {
208            use pyo3::types::PyDict;
209            let dict = PyDict::new(py);
210            dict.set_item("value", result.0)?;
211            dict.set_item("error", result.1)?;
212            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
213        })?;
214
215        Ok(py_result)
216    })
217}
218
219/// Async optimization for expensive objective functions
220#[pyfunction]
221pub fn minimize_async<'py>(
222    py: Python<'py>,
223    func: Py<PyAny>,
224    x0: &Bound<'_, PyArray1<f64>>,
225    method: Option<String>,
226    maxiter: Option<usize>,
227) -> PyResult<Bound<'py, PyAny>> {
228    let x0_vec: Vec<f64> = {
229        let binding = x0.readonly();
230        let arr = binding.as_array();
231        arr.iter().cloned().collect()
232    };
233
234    pyo3_async_runtimes::tokio::future_into_py(py, async move {
235        let result: (Vec<f64>, f64, usize) = tokio::task::spawn_blocking(move || {
236            Python::attach(|py| {
237                use scirs2_core::ndarray::ArrayView1;
238                use scirs2_optimize::unconstrained::{minimize, Method};
239
240                // Create Rust closure that calls Python function
241                let objective = |x: &ArrayView1<f64>| -> f64 {
242                    let x_slice = x.as_slice().unwrap_or(&[]);
243                    let x_py = match pyo3::types::PyList::new(py, x_slice) {
244                        Ok(list) => list,
245                        Err(_) => return f64::NAN,
246                    };
247                    func.call1(py, (x_py,))
248                        .and_then(|r| r.extract::<f64>(py))
249                        .unwrap_or(f64::NAN)
250                };
251
252                let opt_method = match method.as_deref() {
253                    Some("BFGS") => Method::BFGS,
254                    Some("Newton") | Some("NewtonCG") => Method::NewtonCG,
255                    Some("GradientDescent") | Some("CG") => Method::CG,
256                    Some("NelderMead") => Method::NelderMead,
257                    Some("LBFGS") => Method::LBFGS,
258                    _ => Method::BFGS,
259                };
260
261                use scirs2_optimize::unconstrained::Options;
262                let options = Options {
263                    max_iter: maxiter.unwrap_or(1000),
264                    ..Default::default()
265                };
266
267                let result =
268                    minimize(objective, &x0_vec, opt_method, Some(options)).map_err(|e| {
269                        PyErr::from(SciRS2Error::ComputationError(format!(
270                            "Optimization failed: {}",
271                            e
272                        )))
273                    })?;
274
275                let x_vec = result.x.to_vec();
276                let fun_val: f64 = result.fun;
277                let nit = result.nit;
278                Ok::<(Vec<f64>, f64, usize), PyErr>((x_vec, fun_val, nit))
279            })
280        })
281        .await
282        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
283
284        let py_result: Py<PyAny> = Python::attach(|py| {
285            use pyo3::types::PyDict;
286            use scirs2_core::Array1;
287
288            let dict = PyDict::new(py);
289            let x = Array1::from_vec(result.0);
290            dict.set_item("x", x.into_pyarray(py))?;
291            dict.set_item("fun", result.1)?;
292            dict.set_item("nit", result.2)?;
293            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
294        })?;
295
296        Ok(py_result)
297    })
298}
299
300/// Register async operations with Python module
301pub fn register_async_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
302    m.add_function(wrap_pyfunction!(fft_async, m)?)?;
303    m.add_function(wrap_pyfunction!(svd_async, m)?)?;
304    m.add_function(wrap_pyfunction!(qr_async, m)?)?;
305    m.add_function(wrap_pyfunction!(quad_async, m)?)?;
306    m.add_function(wrap_pyfunction!(minimize_async, m)?)?;
307    Ok(())
308}
309
310// ─────────────────────────────────────────────────────────────────────────────
311// Tests
312// ─────────────────────────────────────────────────────────────────────────────
313
#[cfg(test)]
mod tests {
    use pyo3::prelude::*;
    use scirs2_core::Array1;
    use scirs2_core::Array2;
    use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2};

    /// Create a fresh asyncio event loop and mark it as the current thread's
    /// "running" loop via `asyncio.events._set_running_loop`, so that
    /// `pyo3_async_runtimes::tokio::future_into_py` (which internally calls
    /// `asyncio.get_running_loop()`) succeeds in plain Rust tests.
    fn install_running_loop(py: Python<'_>) {
        let fresh_loop = py
            .import("asyncio")
            .expect("import asyncio")
            .call_method0("new_event_loop")
            .expect("new_event_loop");
        py.import("asyncio.events")
            .expect("import asyncio.events")
            .call_method1("_set_running_loop", (fresh_loop,))
            .expect("_set_running_loop");
    }

    /// Shared check for all `*_async` smoke tests.
    ///
    /// The awaitable is never driven to completion here; the check only confirms
    /// that:
    ///   1. the call named `name` succeeded and produced a Python object, and
    ///   2. that object exposes `__await__`, `send` or `__next__`.
    fn assert_awaitable(name: &str, outcome: PyResult<Bound<'_, PyAny>>) {
        let obj = match outcome {
            Ok(obj) => obj,
            // Same text as asserting on `is_ok()` and printing `outcome.err()`.
            Err(err) => panic!("{} returned Err: {:?}", name, Some(err)),
        };
        let looks_awaitable = ["__await__", "send", "__next__"]
            .iter()
            .any(|attr| obj.hasattr(*attr).unwrap_or(false));
        assert!(looks_awaitable, "returned object is not awaitable");
    }

    /// `fft_async` hands back an awaitable for a four-sample input.
    #[test]
    fn fft_async_returns_awaitable() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            install_running_loop(py);
            let samples: Array1<f64> = Array1::from_vec(vec![1.0, 0.0, -1.0, 0.0]);
            let py_arr: Bound<'_, PyArray1<f64>> = samples.into_pyarray(py);

            assert_awaitable("fft_async", super::fft_async(py, &py_arr));
        });
    }

    /// `svd_async` hands back an awaitable for a small square matrix.
    #[test]
    fn svd_async_returns_awaitable() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            install_running_loop(py);
            // 2×2 identity matrix
            let identity: Array2<f64> =
                Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).expect("shape ok");
            let py_arr: Bound<'_, PyArray2<f64>> = identity.into_pyarray(py);

            assert_awaitable("svd_async", super::svd_async(py, &py_arr, Some(false)));
        });
    }

    /// `qr_async` hands back an awaitable for a small matrix.
    #[test]
    fn qr_async_returns_awaitable() {
        pyo3::Python::initialize();
        Python::attach(|py| {
            install_running_loop(py);
            let square: Array2<f64> =
                Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("shape ok");
            let py_arr: Bound<'_, PyArray2<f64>> = square.into_pyarray(py);

            assert_awaitable("qr_async", super::qr_async(py, &py_arr));
        });
    }
}
406}