Skip to main content

scirs2/
optimize.rs

1//! Python bindings for scirs2-optimize
2//!
3//! Provides optimization algorithms similar to scipy.optimize
4
5// Callbacks acquire the GIL via Python::attach (pyo3 0.28+) since they are invoked
6// from Rust optimization routines where the caller's Python<'_> token is not in scope.
7
8use pyo3::prelude::*;
9use pyo3::types::PyDict;
10
11// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
12use scirs2_numpy::IntoPyArray;
13
14// ndarray types from scirs2-core
15use scirs2_core::{ndarray::ArrayView1, Array1};
16
17// Direct imports from scirs2-optimize (native ndarray 0.17 support)
18use scirs2_optimize::global::{differential_evolution, DifferentialEvolutionOptions};
19use scirs2_optimize::scalar::{minimize_scalar, Method as ScalarMethod, Options as ScalarOptions};
20use scirs2_optimize::unconstrained::{minimize, Bounds, Method, Options};
21
22/// Minimize a scalar function of one variable
23///
24/// Parameters:
25/// - fun: The objective function to minimize
26/// - bracket: (a, b) interval to search
27/// - method: 'brent', 'golden', or 'bounded'
28/// - options: Dict with 'maxiter', 'tol'
29#[pyfunction]
30#[pyo3(signature = (fun, bracket, method="brent", options=None))]
31fn minimize_scalar_py(
32    py: Python,
33    fun: &Bound<'_, PyAny>,
34    bracket: (f64, f64),
35    method: &str,
36    options: Option<&Bound<'_, PyDict>>,
37) -> PyResult<Py<PyAny>> {
38    let maxiter = options
39        .and_then(|o| o.get_item("maxiter").ok().flatten())
40        .and_then(|v| v.extract().ok());
41    let tol = options
42        .and_then(|o| o.get_item("tol").ok().flatten())
43        .and_then(|v| v.extract().ok());
44
45    let fun_clone = fun.clone().unbind();
46    let f = move |x: f64| -> f64 {
47        Python::attach(|py| {
48            let result = fun_clone
49                .bind(py)
50                .call1((x,))
51                .expect("Failed to call objective function");
52            result.extract().expect("Failed to extract result")
53        })
54    };
55
56    // Parse method
57    let scalar_method = match method {
58        "brent" => ScalarMethod::Brent,
59        "golden" => ScalarMethod::Golden,
60        "bounded" => ScalarMethod::Bounded,
61        _ => ScalarMethod::Brent,
62    };
63
64    // Set up options
65    let mut options = ScalarOptions::default();
66    if let Some(mi) = maxiter {
67        options.max_iter = mi;
68    }
69    if let Some(t) = tol {
70        options.xatol = t;
71    }
72
73    let result = minimize_scalar(f, Some(bracket), scalar_method, Some(options))
74        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
75
76    let dict = PyDict::new(py);
77    dict.set_item("x", result.x)?;
78    dict.set_item("fun", result.fun)?;
79    dict.set_item("success", result.success)?;
80    dict.set_item("nit", result.nit)?;
81    dict.set_item("nfev", result.function_evals)?;
82
83    Ok(dict.into())
84}
85
86/// Find root of a scalar function using Brent's method
87///
88/// Parameters:
89/// - fun: The function for which to find the root
90/// - a: Lower bound of the bracket
91/// - b: Upper bound of the bracket
92/// - xtol: Absolute tolerance (default 1e-12)
93/// - maxiter: Maximum iterations (default 100)
94///
95/// Returns:
96/// - Dict with 'x' (root location), 'fun' (function value at root),
97///   'iterations', 'success'
98#[pyfunction]
99#[pyo3(signature = (fun, a, b, xtol=1e-12, maxiter=100))]
100fn brentq_py(
101    py: Python,
102    fun: &Bound<'_, PyAny>,
103    a: f64,
104    b: f64,
105    xtol: f64,
106    maxiter: usize,
107) -> PyResult<Py<PyAny>> {
108    let fun_clone = fun.clone().unbind();
109    let f = |x: f64| -> f64 {
110        Python::attach(|py| {
111            let result = fun_clone
112                .bind(py)
113                .call1((x,))
114                .expect("Failed to call objective function");
115            result.extract().expect("Failed to extract result")
116        })
117    };
118
119    // Brent's method implementation
120    let mut a = a;
121    let mut b = b;
122    let mut fa = f(a);
123    let mut fb = f(b);
124
125    if fa * fb > 0.0 {
126        return Err(pyo3::exceptions::PyValueError::new_err(
127            "f(a) and f(b) must have opposite signs",
128        ));
129    }
130
131    // Ensure |f(a)| >= |f(b)|
132    if fa.abs() < fb.abs() {
133        std::mem::swap(&mut a, &mut b);
134        std::mem::swap(&mut fa, &mut fb);
135    }
136
137    let mut c = a;
138    let mut fc = fa;
139    let mut d = b - a;
140    let mut e = d;
141    let mut iter = 0;
142
143    while iter < maxiter {
144        if fb.abs() < xtol {
145            let dict = PyDict::new(py);
146            dict.set_item("x", b)?;
147            dict.set_item("fun", fb)?;
148            dict.set_item("iterations", iter)?;
149            dict.set_item("success", true)?;
150            return Ok(dict.into());
151        }
152
153        if fa.abs() < fb.abs() {
154            std::mem::swap(&mut a, &mut b);
155            std::mem::swap(&mut fa, &mut fb);
156            c = a;
157            fc = fa;
158        }
159
160        let tol = 2.0 * f64::EPSILON * b.abs() + xtol;
161        let m = (c - b) / 2.0;
162
163        if m.abs() <= tol {
164            let dict = PyDict::new(py);
165            dict.set_item("x", b)?;
166            dict.set_item("fun", fb)?;
167            dict.set_item("iterations", iter)?;
168            dict.set_item("success", true)?;
169            return Ok(dict.into());
170        }
171
172        // Use bisection or interpolation
173        let mut use_bisection = true;
174
175        if e.abs() >= tol && fa.abs() > fb.abs() {
176            let s = fb / fa;
177            let (p, q) = if (a - c).abs() < 1e-14 {
178                // Linear interpolation
179                (2.0 * m * s, 1.0 - s)
180            } else {
181                // Inverse quadratic interpolation
182                let q = fa / fc;
183                let r = fb / fc;
184                (
185                    s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
186                    (q - 1.0) * (r - 1.0) * (s - 1.0),
187                )
188            };
189
190            let (p, q) = if p > 0.0 { (p, -q) } else { (-p, q) };
191
192            if 2.0 * p < 3.0 * m * q - (tol * q).abs() && p < (e * q / 2.0).abs() {
193                e = d;
194                d = p / q;
195                use_bisection = false;
196            }
197        }
198
199        if use_bisection {
200            d = m;
201            e = m;
202        }
203
204        a = b;
205        fa = fb;
206
207        if d.abs() > tol {
208            b += d;
209        } else {
210            b += if m > 0.0 { tol } else { -tol };
211        }
212
213        fb = f(b);
214
215        if (fb > 0.0) == (fc > 0.0) {
216            c = a;
217            fc = fa;
218            d = b - a;
219            e = d;
220        }
221
222        iter += 1;
223    }
224
225    let dict = PyDict::new(py);
226    dict.set_item("x", b)?;
227    dict.set_item("fun", fb)?;
228    dict.set_item("iterations", iter)?;
229    dict.set_item("success", false)?;
230    dict.set_item("message", "Maximum iterations reached")?;
231    Ok(dict.into())
232}
233
234/// Minimize a function of one or more variables
235///
236/// Parameters:
237/// - fun: The objective function to minimize
238/// - x0: Initial guess as array
239/// - method: Optimization method ('nelder-mead', 'bfgs', 'cg', 'powell', 'lbfgs', etc.)
240/// - options: Dict with 'maxiter', 'ftol', 'gtol'
241/// - bounds: Optional list of (min, max) bounds for each variable
242///
243/// Returns:
244/// - Dict with 'x' (solution), 'fun' (function value), 'success', 'nit', 'nfev', 'message'
245#[pyfunction]
246#[pyo3(signature = (fun, x0, method="bfgs", options=None, bounds=None))]
247fn minimize_py(
248    py: Python,
249    fun: &Bound<'_, PyAny>,
250    x0: Vec<f64>,
251    method: &str,
252    options: Option<&Bound<'_, PyDict>>,
253    bounds: Option<Vec<(f64, f64)>>,
254) -> PyResult<Py<PyAny>> {
255    // Parse method
256    let opt_method = match method.to_lowercase().as_str() {
257        "nelder-mead" | "neldermead" => Method::NelderMead,
258        "powell" => Method::Powell,
259        "cg" | "conjugate-gradient" => Method::CG,
260        "bfgs" => Method::BFGS,
261        "lbfgs" | "l-bfgs" => Method::LBFGS,
262        "lbfgsb" | "l-bfgs-b" => Method::LBFGSB,
263        "newton-cg" => Method::NewtonCG,
264        "trust-ncg" => Method::TrustNCG,
265        "sr1" => Method::SR1,
266        "dfp" => Method::DFP,
267        _ => Method::BFGS, // Default to BFGS
268    };
269
270    // Parse options
271    let maxiter = options
272        .and_then(|o| o.get_item("maxiter").ok().flatten())
273        .and_then(|v| v.extract().ok());
274    let ftol = options
275        .and_then(|o| o.get_item("ftol").ok().flatten())
276        .and_then(|v| v.extract().ok());
277    let gtol = options
278        .and_then(|o| o.get_item("gtol").ok().flatten())
279        .and_then(|v| v.extract().ok());
280
281    let mut opt_options = Options::default();
282    if let Some(mi) = maxiter {
283        opt_options.max_iter = mi;
284    }
285    if let Some(ft) = ftol {
286        opt_options.ftol = ft;
287    }
288    if let Some(gt) = gtol {
289        opt_options.gtol = gt;
290    }
291
292    // Parse bounds
293    if let Some(b) = bounds {
294        let n = x0.len();
295        let mut lower = vec![None; n];
296        let mut upper = vec![None; n];
297        for (i, (l, u)) in b.iter().enumerate() {
298            if i < n {
299                lower[i] = Some(*l);
300                upper[i] = Some(*u);
301            }
302        }
303        opt_options.bounds = Some(Bounds { lower, upper });
304    }
305
306    // Create closure for the objective function
307    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
308    let f = move |x: &ArrayView1<f64>| -> f64 {
309        let fun_clone = fun_arc.clone();
310        Python::attach(|py| {
311            let x_vec: Vec<f64> = x.to_vec();
312            let result = fun_clone
313                .bind(py)
314                .call1((x_vec,))
315                .expect("Failed to call objective function");
316            result.extract().expect("Failed to extract result")
317        })
318    };
319
320    // Run optimization
321    let result = minimize(f, &x0, opt_method, Some(opt_options))
322        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
323
324    // Return result as dict
325    let dict = PyDict::new(py);
326    dict.set_item("x", result.x.into_pyarray(py).unbind())?;
327    dict.set_item("fun", result.fun)?;
328    dict.set_item("success", result.success)?;
329    dict.set_item("message", result.message)?;
330    dict.set_item("nit", result.nit)?;
331    dict.set_item("nfev", result.func_evals)?;
332
333    Ok(dict.into())
334}
335
336/// Global optimization using differential evolution
337///
338/// Parameters:
339/// - fun: The objective function to minimize
340/// - bounds: List of (min, max) bounds for each variable
341/// - options: Dict with 'maxiter', 'popsize', 'tol', 'seed'
342#[pyfunction]
343#[pyo3(signature = (fun, bounds, options=None))]
344fn differential_evolution_py(
345    py: Python,
346    fun: &Bound<'_, PyAny>,
347    bounds: Vec<(f64, f64)>,
348    options: Option<&Bound<'_, PyDict>>,
349) -> PyResult<Py<PyAny>> {
350    let maxiter = options
351        .and_then(|o| o.get_item("maxiter").ok().flatten())
352        .and_then(|v| v.extract().ok());
353    let popsize = options
354        .and_then(|o| o.get_item("popsize").ok().flatten())
355        .and_then(|v| v.extract().ok());
356    let tol = options
357        .and_then(|o| o.get_item("tol").ok().flatten())
358        .and_then(|v| v.extract().ok());
359    let seed = options
360        .and_then(|o| o.get_item("seed").ok().flatten())
361        .and_then(|v| v.extract().ok());
362
363    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
364    let f = move |x: &ArrayView1<f64>| -> f64 {
365        let fun_clone = fun_arc.clone();
366        Python::attach(|py| {
367            let x_vec: Vec<f64> = x.to_vec();
368            let result = fun_clone
369                .bind(py)
370                .call1((x_vec,))
371                .expect("Failed to call objective function");
372            result.extract().expect("Failed to extract result")
373        })
374    };
375
376    // Set up options
377    let mut de_options = DifferentialEvolutionOptions::default();
378    if let Some(mi) = maxiter {
379        de_options.maxiter = mi;
380    }
381    if let Some(ps) = popsize {
382        de_options.popsize = ps;
383    }
384    if let Some(t) = tol {
385        de_options.tol = t;
386    }
387    if let Some(s) = seed {
388        de_options.seed = Some(s);
389    }
390
391    // Use bounds vector directly (differential_evolution expects Vec<(f64, f64)>)
392    let result = differential_evolution(f, bounds.to_vec(), Some(de_options), None)
393        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
394
395    let dict = PyDict::new(py);
396    dict.set_item("x", result.x.into_pyarray(py).unbind())?;
397    dict.set_item("fun", result.fun)?;
398    dict.set_item("success", result.success)?;
399    dict.set_item("message", result.message)?;
400    dict.set_item("nit", result.nit)?;
401    dict.set_item("nfev", result.func_evals)?;
402
403    Ok(dict.into())
404}
405
406/// Curve fitting using non-linear least squares
407///
408/// Use non-linear least squares to fit a function to data.
409/// This is similar to scipy.optimize.curve_fit.
410///
411/// Parameters:
412/// - f: Model function f(x, *params) that takes independent variable(s) and parameters
413/// - xdata: Independent variable where data is measured (array or scalar for each point)
414/// - ydata: Dependent data to fit
415/// - p0: Initial guess for parameters (optional, defaults to ones)
416/// - method: Optimization method ('lm', 'trf', or 'dogbox')
417/// - maxfev: Maximum number of function evaluations (default: 1000)
418///
419/// Returns:
420/// - Dict with 'popt' (optimized parameters), 'success', 'nfev', 'message'
421///
422/// Example:
423/// ```python
424/// import numpy as np
425/// import scirs2
426///
427/// # Define exponential model: f(x, a, b) = a * exp(b * x)
428/// def model(x, a, b):
429///     return a * np.exp(b * x)
430///
431/// # Generate noisy data
432/// xdata = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
433/// ydata = np.array([1.0, 2.7, 7.4, 20.1, 54.6])  # ≈ 1.0 * exp(1.0 * x) with noise
434///
435/// # Fit the curve
436/// result = scirs2.curve_fit_py(model, xdata, ydata, p0=[1.0, 1.0])
437/// print(f"Optimized parameters: {result['popt']}")
438/// ```
439#[pyfunction]
440#[pyo3(signature = (f, xdata, ydata, p0=None, method="lm", maxfev=1000))]
441fn curve_fit_py(
442    py: Python,
443    f: &Bound<'_, PyAny>,
444    xdata: Vec<f64>,
445    ydata: Vec<f64>,
446    p0: Option<Vec<f64>>,
447    method: &str,
448    maxfev: usize,
449) -> PyResult<Py<PyAny>> {
450    use scirs2_optimize::least_squares::{least_squares, Method as LSMethod, Options as LSOptions};
451
452    if xdata.len() != ydata.len() {
453        return Err(pyo3::exceptions::PyValueError::new_err(
454            "xdata and ydata must have the same length",
455        ));
456    }
457
458    let n_data = xdata.len();
459
460    // Use p0 or default to ones
461    let params_init = p0.unwrap_or_else(|| vec![1.0; 2]);
462
463    // Parse method
464    let ls_method = match method.to_lowercase().as_str() {
465        "lm" => LSMethod::LevenbergMarquardt,
466        "trf" => LSMethod::TrustRegionReflective,
467        "dogbox" => LSMethod::Dogbox,
468        _ => LSMethod::LevenbergMarquardt,
469    };
470
471    // Create copies for closure
472    let xdata_clone = xdata.clone();
473    let ydata_clone = ydata.clone();
474    let f_arc = std::sync::Arc::new(f.clone().unbind());
475
476    // Define residual function
477    let residual_fn = move |params: &[f64], _data: &[f64]| -> Array1<f64> {
478        let f_clone = f_arc.clone();
479        let xdata_ref = &xdata_clone;
480        let ydata_ref = &ydata_clone;
481
482        Python::attach(|py| {
483            let mut residuals = Vec::with_capacity(n_data);
484
485            for i in 0..n_data {
486                // Call f(x, *params)
487                let mut args = vec![xdata_ref[i]];
488                args.extend_from_slice(params);
489
490                let f_val: f64 = f_clone
491                    .bind(py)
492                    .call1(pyo3::types::PyTuple::new(py, &args).expect("Operation failed"))
493                    .expect("Failed to call model function")
494                    .extract()
495                    .expect("Failed to extract model result");
496
497                residuals.push(ydata_ref[i] - f_val);
498            }
499
500            Array1::from_vec(residuals)
501        })
502    };
503
504    // Set up options
505    let options = LSOptions {
506        max_nfev: Some(maxfev),
507        ..Default::default()
508    };
509
510    // Run least squares optimization
511    let empty_data = Array1::from_vec(vec![]);
512
513    let result = least_squares(
514        residual_fn,
515        &Array1::from_vec(params_init),
516        ls_method,
517        None::<fn(&[f64], &[f64]) -> scirs2_core::ndarray::Array2<f64>>, // No jacobian provided
518        &empty_data,                                                     // No additional data
519        Some(options),
520    )
521    .map_err(|e| {
522        pyo3::exceptions::PyRuntimeError::new_err(format!("Curve fitting failed: {}", e))
523    })?;
524
525    // Return results
526    let dict = PyDict::new(py);
527    dict.set_item("popt", result.x.into_pyarray(py).unbind())?;
528    dict.set_item("success", result.success)?;
529    dict.set_item("nfev", result.nfev)?;
530    dict.set_item("message", result.message)?;
531
532    Ok(dict.into())
533}
534
535/// Python module registration
536pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
537    m.add_function(wrap_pyfunction!(minimize_py, m)?)?;
538    m.add_function(wrap_pyfunction!(minimize_scalar_py, m)?)?;
539    m.add_function(wrap_pyfunction!(brentq_py, m)?)?;
540    m.add_function(wrap_pyfunction!(differential_evolution_py, m)?)?;
541    m.add_function(wrap_pyfunction!(curve_fit_py, m)?)?;
542
543    Ok(())
544}