Skip to main content

scirs2/
async_ops.rs

1//! Async operations for Python
2//!
3//! This module provides async versions of long-running operations that can be awaited in Python.
4//!
5//! # Example (Python)
6//! ```python
7//! import asyncio
8//! import scirs2
9//! import numpy as np
10//!
11//! async def main():
12//!     # Async FFT for large arrays
13//!     data = np.random.randn(1_000_000)
14//!     result = await scirs2.fft_async(data)
15//!
16//!     # Async matrix decomposition
17//!     matrix = np.random.randn(1000, 1000)
18//!     svd = await scirs2.svd_async(matrix)
19//!
20//! asyncio.run(main())
21//! ```
22
23use crate::error::SciRS2Error;
24use pyo3::prelude::*;
25use pyo3_async_runtimes;
26use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods, PyUntypedArrayMethods};
27
28/// Async FFT operation for large arrays
29///
30/// This function runs FFT in a background thread and returns a Python awaitable.
31/// Useful for large arrays (>100k elements) to avoid blocking the event loop.
32#[pyfunction]
33pub fn fft_async<'py>(
34    py: Python<'py>,
35    data: &Bound<'_, PyArray1<f64>>,
36) -> PyResult<Bound<'py, PyAny>> {
37    let data_vec: Vec<f64> = {
38        let binding = data.readonly();
39        let arr = binding.as_array();
40        arr.iter().cloned().collect()
41    };
42
43    pyo3_async_runtimes::tokio::future_into_py(py, async move {
44        // Run FFT in blocking task
45        // FFT returns Vec<Complex64> - collect real and imag parts
46        let (real_part, imag_part): (Vec<f64>, Vec<f64>) = tokio::task::spawn_blocking(move || {
47            use scirs2_core::Complex64;
48            use scirs2_fft::fft;
49
50            let result: Vec<Complex64> = fft(data_vec.as_slice(), None)
51                .map_err(|e| SciRS2Error::ComputationError(format!("FFT failed: {}", e)))?;
52
53            let real: Vec<f64> = result.iter().map(|c| c.re).collect();
54            let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
55            Ok::<(Vec<f64>, Vec<f64>), SciRS2Error>((real, imag))
56        })
57        .await
58        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
59
60        // Return as Py<PyAny> which is Send
61        let py_result: Py<PyAny> = Python::attach(|py| {
62            use pyo3::types::PyDict;
63            use scirs2_core::Array1;
64            // Return dict with real and imag arrays
65            let dict = PyDict::new(py);
66            let real_arr: Array1<f64> = Array1::from_vec(real_part);
67            let imag_arr: Array1<f64> = Array1::from_vec(imag_part);
68            dict.set_item("real", real_arr.into_pyarray(py))?;
69            dict.set_item("imag", imag_arr.into_pyarray(py))?;
70            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
71        })?;
72
73        Ok(py_result)
74    })
75}
76
77/// Async SVD operation for large matrices
78///
79/// This function runs SVD in a background thread and returns a Python awaitable.
80/// Useful for large matrices (>500x500) to avoid blocking the event loop.
81#[pyfunction]
82pub fn svd_async<'py>(
83    py: Python<'py>,
84    matrix: &Bound<'_, PyArray2<f64>>,
85    full_matrices: Option<bool>,
86) -> PyResult<Bound<'py, PyAny>> {
87    let matrix_shape = matrix.shape().to_vec();
88    let matrix_vec: Vec<f64> = {
89        let binding = matrix.readonly();
90        let arr = binding.as_array();
91        arr.iter().cloned().collect()
92    };
93    let full_matrices = full_matrices.unwrap_or(true);
94
95    pyo3_async_runtimes::tokio::future_into_py(py, async move {
96        // Run SVD in blocking task
97        let result = tokio::task::spawn_blocking(move || {
98            use scirs2_core::Array2;
99            use scirs2_linalg::svd_f64_lapack;
100
101            let arr = Array2::from_shape_vec((matrix_shape[0], matrix_shape[1]), matrix_vec)
102                .map_err(|e| SciRS2Error::ArrayError(format!("Array reshape failed: {}", e)))?;
103
104            svd_f64_lapack(&arr.view(), full_matrices)
105                .map_err(|e| SciRS2Error::ComputationError(format!("SVD failed: {}", e)))
106        })
107        .await
108        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
109
110        // Convert result to Python dict; return Py<PyAny> which is Send
111        let py_result: Py<PyAny> = Python::attach(|py| {
112            use pyo3::types::PyDict;
113            let dict = PyDict::new(py);
114            dict.set_item("U", result.0.into_pyarray(py))?;
115            dict.set_item("S", result.1.into_pyarray(py))?;
116            dict.set_item("Vt", result.2.into_pyarray(py))?;
117            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
118        })?;
119
120        Ok(py_result)
121    })
122}
123
124/// Async QR decomposition for large matrices
125#[pyfunction]
126pub fn qr_async<'py>(
127    py: Python<'py>,
128    matrix: &Bound<'_, PyArray2<f64>>,
129) -> PyResult<Bound<'py, PyAny>> {
130    let matrix_shape = matrix.shape().to_vec();
131    let matrix_vec: Vec<f64> = {
132        let binding = matrix.readonly();
133        let arr = binding.as_array();
134        arr.iter().cloned().collect()
135    };
136
137    pyo3_async_runtimes::tokio::future_into_py(py, async move {
138        let result = tokio::task::spawn_blocking(move || {
139            use scirs2_core::Array2;
140            use scirs2_linalg::qr_f64_lapack;
141
142            let arr = Array2::from_shape_vec((matrix_shape[0], matrix_shape[1]), matrix_vec)
143                .map_err(|e| SciRS2Error::ArrayError(format!("Array reshape failed: {}", e)))?;
144
145            qr_f64_lapack(&arr.view())
146                .map_err(|e| SciRS2Error::ComputationError(format!("QR failed: {}", e)))
147        })
148        .await
149        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
150
151        let py_result: Py<PyAny> = Python::attach(|py| {
152            use pyo3::types::PyDict;
153            let dict = PyDict::new(py);
154            dict.set_item("Q", result.0.into_pyarray(py))?;
155            dict.set_item("R", result.1.into_pyarray(py))?;
156            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
157        })?;
158
159        Ok(py_result)
160    })
161}
162
163/// Async numerical integration for expensive integrands
164#[pyfunction]
165pub fn quad_async<'py>(
166    py: Python<'py>,
167    func: Py<PyAny>,
168    a: f64,
169    b: f64,
170    epsabs: Option<f64>,
171    epsrel: Option<f64>,
172) -> PyResult<Bound<'py, PyAny>> {
173    pyo3_async_runtimes::tokio::future_into_py(py, async move {
174        let result: (f64, f64) = tokio::task::spawn_blocking(move || {
175            Python::attach(|py| {
176                use scirs2_integrate::quad::{quad, QuadOptions};
177
178                let abs_tol = epsabs.unwrap_or(1e-8);
179                let rel_tol = epsrel.unwrap_or(1e-8);
180
181                // Create Rust closure that calls Python function
182                let integrand = |x: f64| -> f64 {
183                    func.call1(py, (x,))
184                        .and_then(|result| result.extract::<f64>(py))
185                        .unwrap_or(f64::NAN)
186                };
187
188                let options = QuadOptions {
189                    abs_tol,
190                    rel_tol,
191                    ..Default::default()
192                };
193
194                let result = quad(integrand, a, b, Some(options)).map_err(|e| {
195                    PyErr::from(SciRS2Error::ComputationError(format!(
196                        "Integration failed: {}",
197                        e
198                    )))
199                })?;
200
201                Ok::<(f64, f64), PyErr>((result.value, result.abs_error))
202            })
203        })
204        .await
205        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
206
207        let py_result: Py<PyAny> = Python::attach(|py| {
208            use pyo3::types::PyDict;
209            let dict = PyDict::new(py);
210            dict.set_item("value", result.0)?;
211            dict.set_item("error", result.1)?;
212            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
213        })?;
214
215        Ok(py_result)
216    })
217}
218
219/// Async optimization for expensive objective functions
220#[pyfunction]
221pub fn minimize_async<'py>(
222    py: Python<'py>,
223    func: Py<PyAny>,
224    x0: &Bound<'_, PyArray1<f64>>,
225    method: Option<String>,
226    maxiter: Option<usize>,
227) -> PyResult<Bound<'py, PyAny>> {
228    let x0_vec: Vec<f64> = {
229        let binding = x0.readonly();
230        let arr = binding.as_array();
231        arr.iter().cloned().collect()
232    };
233
234    pyo3_async_runtimes::tokio::future_into_py(py, async move {
235        let result: (Vec<f64>, f64, usize) = tokio::task::spawn_blocking(move || {
236            Python::attach(|py| {
237                use scirs2_core::ndarray::ArrayView1;
238                use scirs2_optimize::unconstrained::{minimize, Method};
239
240                // Create Rust closure that calls Python function
241                let objective = |x: &ArrayView1<f64>| -> f64 {
242                    let x_slice = x.as_slice().unwrap_or(&[]);
243                    let x_py = match pyo3::types::PyList::new(py, x_slice) {
244                        Ok(list) => list,
245                        Err(_) => return f64::NAN,
246                    };
247                    func.call1(py, (x_py,))
248                        .and_then(|r| r.extract::<f64>(py))
249                        .unwrap_or(f64::NAN)
250                };
251
252                let opt_method = match method.as_deref() {
253                    Some("BFGS") => Method::BFGS,
254                    Some("Newton") | Some("NewtonCG") => Method::NewtonCG,
255                    Some("GradientDescent") | Some("CG") => Method::CG,
256                    Some("NelderMead") => Method::NelderMead,
257                    Some("LBFGS") => Method::LBFGS,
258                    _ => Method::BFGS,
259                };
260
261                use scirs2_optimize::unconstrained::Options;
262                let options = Options {
263                    max_iter: maxiter.unwrap_or(1000),
264                    ..Default::default()
265                };
266
267                let result =
268                    minimize(objective, &x0_vec, opt_method, Some(options)).map_err(|e| {
269                        PyErr::from(SciRS2Error::ComputationError(format!(
270                            "Optimization failed: {}",
271                            e
272                        )))
273                    })?;
274
275                let x_vec = result.x.to_vec();
276                let fun_val: f64 = result.fun;
277                let nit = result.nit;
278                Ok::<(Vec<f64>, f64, usize), PyErr>((x_vec, fun_val, nit))
279            })
280        })
281        .await
282        .map_err(|e| SciRS2Error::RuntimeError(format!("Task join error: {}", e)))??;
283
284        let py_result: Py<PyAny> = Python::attach(|py| {
285            use pyo3::types::PyDict;
286            use scirs2_core::Array1;
287
288            let dict = PyDict::new(py);
289            let x = Array1::from_vec(result.0);
290            dict.set_item("x", x.into_pyarray(py))?;
291            dict.set_item("fun", result.1)?;
292            dict.set_item("nit", result.2)?;
293            Ok::<Py<PyAny>, PyErr>(dict.into_any().unbind())
294        })?;
295
296        Ok(py_result)
297    })
298}
299
300/// Register async operations with Python module
301pub fn register_async_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
302    m.add_function(wrap_pyfunction!(fft_async, m)?)?;
303    m.add_function(wrap_pyfunction!(svd_async, m)?)?;
304    m.add_function(wrap_pyfunction!(qr_async, m)?)?;
305    m.add_function(wrap_pyfunction!(quad_async, m)?)?;
306    m.add_function(wrap_pyfunction!(minimize_async, m)?)?;
307    Ok(())
308}