Skip to main content

scirs2/
optimize.rs

1//! Python bindings for scirs2-optimize
2//!
3//! Provides optimization algorithms similar to scipy.optimize
4
5// Allow deprecated with_gil for callback patterns where GIL must be acquired from Rust
6
7use pyo3::prelude::*;
8use pyo3::types::PyDict;
9
10// NumPy types for Python array interface (scirs2-numpy with native ndarray 0.17)
11use scirs2_numpy::IntoPyArray;
12
13// ndarray types from scirs2-core
14use scirs2_core::{ndarray::ArrayView1, Array1};
15
16// Direct imports from scirs2-optimize (native ndarray 0.17 support)
17use scirs2_optimize::global::{differential_evolution, DifferentialEvolutionOptions};
18use scirs2_optimize::scalar::{minimize_scalar, Method as ScalarMethod, Options as ScalarOptions};
19use scirs2_optimize::unconstrained::{minimize, Bounds, Method, Options};
20
21/// Minimize a scalar function of one variable
22///
23/// Parameters:
24/// - fun: The objective function to minimize
25/// - bracket: (a, b) interval to search
26/// - method: 'brent', 'golden', or 'bounded'
27/// - options: Dict with 'maxiter', 'tol'
28#[pyfunction]
29#[pyo3(signature = (fun, bracket, method="brent", options=None))]
30fn minimize_scalar_py(
31    py: Python,
32    fun: &Bound<'_, PyAny>,
33    bracket: (f64, f64),
34    method: &str,
35    options: Option<&Bound<'_, PyDict>>,
36) -> PyResult<Py<PyAny>> {
37    let maxiter = options
38        .and_then(|o| o.get_item("maxiter").ok().flatten())
39        .and_then(|v| v.extract().ok());
40    let tol = options
41        .and_then(|o| o.get_item("tol").ok().flatten())
42        .and_then(|v| v.extract().ok());
43
44    let fun_clone = fun.clone().unbind();
45    let f = move |x: f64| -> f64 {
46        #[allow(deprecated)]
47        Python::with_gil(|py| {
48            let result = fun_clone
49                .bind(py)
50                .call1((x,))
51                .expect("Failed to call objective function");
52            result.extract().expect("Failed to extract result")
53        })
54    };
55
56    // Parse method
57    let scalar_method = match method {
58        "brent" => ScalarMethod::Brent,
59        "golden" => ScalarMethod::Golden,
60        "bounded" => ScalarMethod::Bounded,
61        _ => ScalarMethod::Brent,
62    };
63
64    // Set up options
65    let mut options = ScalarOptions::default();
66    if let Some(mi) = maxiter {
67        options.max_iter = mi;
68    }
69    if let Some(t) = tol {
70        options.xatol = t;
71    }
72
73    let result = minimize_scalar(f, Some(bracket), scalar_method, Some(options))
74        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
75
76    let dict = PyDict::new(py);
77    dict.set_item("x", result.x)?;
78    dict.set_item("fun", result.fun)?;
79    dict.set_item("success", result.success)?;
80    dict.set_item("nit", result.nit)?;
81    dict.set_item("nfev", result.function_evals)?;
82
83    Ok(dict.into())
84}
85
86/// Find root of a scalar function using Brent's method
87///
88/// Parameters:
89/// - fun: The function for which to find the root
90/// - a: Lower bound of the bracket
91/// - b: Upper bound of the bracket
92/// - xtol: Absolute tolerance (default 1e-12)
93/// - maxiter: Maximum iterations (default 100)
94///
95/// Returns:
96/// - Dict with 'x' (root location), 'fun' (function value at root),
97///   'iterations', 'success'
98#[pyfunction]
99#[pyo3(signature = (fun, a, b, xtol=1e-12, maxiter=100))]
100fn brentq_py(
101    py: Python,
102    fun: &Bound<'_, PyAny>,
103    a: f64,
104    b: f64,
105    xtol: f64,
106    maxiter: usize,
107) -> PyResult<Py<PyAny>> {
108    let fun_clone = fun.clone().unbind();
109    let f = |x: f64| -> f64 {
110        #[allow(deprecated)]
111        Python::with_gil(|py| {
112            let result = fun_clone
113                .bind(py)
114                .call1((x,))
115                .expect("Failed to call objective function");
116            result.extract().expect("Failed to extract result")
117        })
118    };
119
120    // Brent's method implementation
121    let mut a = a;
122    let mut b = b;
123    let mut fa = f(a);
124    let mut fb = f(b);
125
126    if fa * fb > 0.0 {
127        return Err(pyo3::exceptions::PyValueError::new_err(
128            "f(a) and f(b) must have opposite signs",
129        ));
130    }
131
132    // Ensure |f(a)| >= |f(b)|
133    if fa.abs() < fb.abs() {
134        std::mem::swap(&mut a, &mut b);
135        std::mem::swap(&mut fa, &mut fb);
136    }
137
138    let mut c = a;
139    let mut fc = fa;
140    let mut d = b - a;
141    let mut e = d;
142    let mut iter = 0;
143
144    while iter < maxiter {
145        if fb.abs() < xtol {
146            let dict = PyDict::new(py);
147            dict.set_item("x", b)?;
148            dict.set_item("fun", fb)?;
149            dict.set_item("iterations", iter)?;
150            dict.set_item("success", true)?;
151            return Ok(dict.into());
152        }
153
154        if fa.abs() < fb.abs() {
155            std::mem::swap(&mut a, &mut b);
156            std::mem::swap(&mut fa, &mut fb);
157            c = a;
158            fc = fa;
159        }
160
161        let tol = 2.0 * f64::EPSILON * b.abs() + xtol;
162        let m = (c - b) / 2.0;
163
164        if m.abs() <= tol {
165            let dict = PyDict::new(py);
166            dict.set_item("x", b)?;
167            dict.set_item("fun", fb)?;
168            dict.set_item("iterations", iter)?;
169            dict.set_item("success", true)?;
170            return Ok(dict.into());
171        }
172
173        // Use bisection or interpolation
174        let mut use_bisection = true;
175
176        if e.abs() >= tol && fa.abs() > fb.abs() {
177            let s = fb / fa;
178            let (p, q) = if (a - c).abs() < 1e-14 {
179                // Linear interpolation
180                (2.0 * m * s, 1.0 - s)
181            } else {
182                // Inverse quadratic interpolation
183                let q = fa / fc;
184                let r = fb / fc;
185                (
186                    s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
187                    (q - 1.0) * (r - 1.0) * (s - 1.0),
188                )
189            };
190
191            let (p, q) = if p > 0.0 { (p, -q) } else { (-p, q) };
192
193            if 2.0 * p < 3.0 * m * q - (tol * q).abs() && p < (e * q / 2.0).abs() {
194                e = d;
195                d = p / q;
196                use_bisection = false;
197            }
198        }
199
200        if use_bisection {
201            d = m;
202            e = m;
203        }
204
205        a = b;
206        fa = fb;
207
208        if d.abs() > tol {
209            b += d;
210        } else {
211            b += if m > 0.0 { tol } else { -tol };
212        }
213
214        fb = f(b);
215
216        if (fb > 0.0) == (fc > 0.0) {
217            c = a;
218            fc = fa;
219            d = b - a;
220            e = d;
221        }
222
223        iter += 1;
224    }
225
226    let dict = PyDict::new(py);
227    dict.set_item("x", b)?;
228    dict.set_item("fun", fb)?;
229    dict.set_item("iterations", iter)?;
230    dict.set_item("success", false)?;
231    dict.set_item("message", "Maximum iterations reached")?;
232    Ok(dict.into())
233}
234
235/// Minimize a function of one or more variables
236///
237/// Parameters:
238/// - fun: The objective function to minimize
239/// - x0: Initial guess as array
240/// - method: Optimization method ('nelder-mead', 'bfgs', 'cg', 'powell', 'lbfgs', etc.)
241/// - options: Dict with 'maxiter', 'ftol', 'gtol'
242/// - bounds: Optional list of (min, max) bounds for each variable
243///
244/// Returns:
245/// - Dict with 'x' (solution), 'fun' (function value), 'success', 'nit', 'nfev', 'message'
246#[pyfunction]
247#[pyo3(signature = (fun, x0, method="bfgs", options=None, bounds=None))]
248fn minimize_py(
249    py: Python,
250    fun: &Bound<'_, PyAny>,
251    x0: Vec<f64>,
252    method: &str,
253    options: Option<&Bound<'_, PyDict>>,
254    bounds: Option<Vec<(f64, f64)>>,
255) -> PyResult<Py<PyAny>> {
256    // Parse method
257    let opt_method = match method.to_lowercase().as_str() {
258        "nelder-mead" | "neldermead" => Method::NelderMead,
259        "powell" => Method::Powell,
260        "cg" | "conjugate-gradient" => Method::CG,
261        "bfgs" => Method::BFGS,
262        "lbfgs" | "l-bfgs" => Method::LBFGS,
263        "lbfgsb" | "l-bfgs-b" => Method::LBFGSB,
264        "newton-cg" => Method::NewtonCG,
265        "trust-ncg" => Method::TrustNCG,
266        "sr1" => Method::SR1,
267        "dfp" => Method::DFP,
268        _ => Method::BFGS, // Default to BFGS
269    };
270
271    // Parse options
272    let maxiter = options
273        .and_then(|o| o.get_item("maxiter").ok().flatten())
274        .and_then(|v| v.extract().ok());
275    let ftol = options
276        .and_then(|o| o.get_item("ftol").ok().flatten())
277        .and_then(|v| v.extract().ok());
278    let gtol = options
279        .and_then(|o| o.get_item("gtol").ok().flatten())
280        .and_then(|v| v.extract().ok());
281
282    let mut opt_options = Options::default();
283    if let Some(mi) = maxiter {
284        opt_options.max_iter = mi;
285    }
286    if let Some(ft) = ftol {
287        opt_options.ftol = ft;
288    }
289    if let Some(gt) = gtol {
290        opt_options.gtol = gt;
291    }
292
293    // Parse bounds
294    if let Some(b) = bounds {
295        let n = x0.len();
296        let mut lower = vec![None; n];
297        let mut upper = vec![None; n];
298        for (i, (l, u)) in b.iter().enumerate() {
299            if i < n {
300                lower[i] = Some(*l);
301                upper[i] = Some(*u);
302            }
303        }
304        opt_options.bounds = Some(Bounds { lower, upper });
305    }
306
307    // Create closure for the objective function
308    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
309    let f = move |x: &ArrayView1<f64>| -> f64 {
310        let fun_clone = fun_arc.clone();
311        #[allow(deprecated)]
312        Python::with_gil(|py| {
313            let x_vec: Vec<f64> = x.to_vec();
314            let result = fun_clone
315                .bind(py)
316                .call1((x_vec,))
317                .expect("Failed to call objective function");
318            result.extract().expect("Failed to extract result")
319        })
320    };
321
322    // Run optimization
323    let result = minimize(f, &x0, opt_method, Some(opt_options))
324        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
325
326    // Return result as dict
327    let dict = PyDict::new(py);
328    dict.set_item("x", result.x.into_pyarray(py).unbind())?;
329    dict.set_item("fun", result.fun)?;
330    dict.set_item("success", result.success)?;
331    dict.set_item("message", result.message)?;
332    dict.set_item("nit", result.nit)?;
333    dict.set_item("nfev", result.func_evals)?;
334
335    Ok(dict.into())
336}
337
338/// Global optimization using differential evolution
339///
340/// Parameters:
341/// - fun: The objective function to minimize
342/// - bounds: List of (min, max) bounds for each variable
343/// - options: Dict with 'maxiter', 'popsize', 'tol', 'seed'
344#[pyfunction]
345#[pyo3(signature = (fun, bounds, options=None))]
346fn differential_evolution_py(
347    py: Python,
348    fun: &Bound<'_, PyAny>,
349    bounds: Vec<(f64, f64)>,
350    options: Option<&Bound<'_, PyDict>>,
351) -> PyResult<Py<PyAny>> {
352    let maxiter = options
353        .and_then(|o| o.get_item("maxiter").ok().flatten())
354        .and_then(|v| v.extract().ok());
355    let popsize = options
356        .and_then(|o| o.get_item("popsize").ok().flatten())
357        .and_then(|v| v.extract().ok());
358    let tol = options
359        .and_then(|o| o.get_item("tol").ok().flatten())
360        .and_then(|v| v.extract().ok());
361    let seed = options
362        .and_then(|o| o.get_item("seed").ok().flatten())
363        .and_then(|v| v.extract().ok());
364
365    let fun_arc = std::sync::Arc::new(fun.clone().unbind());
366    let f = move |x: &ArrayView1<f64>| -> f64 {
367        let fun_clone = fun_arc.clone();
368        #[allow(deprecated)]
369        Python::with_gil(|py| {
370            let x_vec: Vec<f64> = x.to_vec();
371            let result = fun_clone
372                .bind(py)
373                .call1((x_vec,))
374                .expect("Failed to call objective function");
375            result.extract().expect("Failed to extract result")
376        })
377    };
378
379    // Set up options
380    let mut de_options = DifferentialEvolutionOptions::default();
381    if let Some(mi) = maxiter {
382        de_options.maxiter = mi;
383    }
384    if let Some(ps) = popsize {
385        de_options.popsize = ps;
386    }
387    if let Some(t) = tol {
388        de_options.tol = t;
389    }
390    if let Some(s) = seed {
391        de_options.seed = Some(s);
392    }
393
394    // Use bounds vector directly (differential_evolution expects Vec<(f64, f64)>)
395    let result = differential_evolution(f, bounds.to_vec(), Some(de_options), None)
396        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
397
398    let dict = PyDict::new(py);
399    dict.set_item("x", result.x.into_pyarray(py).unbind())?;
400    dict.set_item("fun", result.fun)?;
401    dict.set_item("success", result.success)?;
402    dict.set_item("message", result.message)?;
403    dict.set_item("nit", result.nit)?;
404    dict.set_item("nfev", result.func_evals)?;
405
406    Ok(dict.into())
407}
408
409/// Curve fitting using non-linear least squares
410///
411/// Use non-linear least squares to fit a function to data.
412/// This is similar to scipy.optimize.curve_fit.
413///
414/// Parameters:
415/// - f: Model function f(x, *params) that takes independent variable(s) and parameters
416/// - xdata: Independent variable where data is measured (array or scalar for each point)
417/// - ydata: Dependent data to fit
418/// - p0: Initial guess for parameters (optional, defaults to ones)
419/// - method: Optimization method ('lm', 'trf', or 'dogbox')
420/// - maxfev: Maximum number of function evaluations (default: 1000)
421///
422/// Returns:
423/// - Dict with 'popt' (optimized parameters), 'success', 'nfev', 'message'
424///
425/// Example:
426/// ```python
427/// import numpy as np
428/// import scirs2
429///
430/// # Define exponential model: f(x, a, b) = a * exp(b * x)
431/// def model(x, a, b):
432///     return a * np.exp(b * x)
433///
434/// # Generate noisy data
435/// xdata = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
436/// ydata = np.array([1.0, 2.7, 7.4, 20.1, 54.6])  # ≈ 1.0 * exp(1.0 * x) with noise
437///
438/// # Fit the curve
439/// result = scirs2.curve_fit_py(model, xdata, ydata, p0=[1.0, 1.0])
440/// print(f"Optimized parameters: {result['popt']}")
441/// ```
442#[pyfunction]
443#[pyo3(signature = (f, xdata, ydata, p0=None, method="lm", maxfev=1000))]
444fn curve_fit_py(
445    py: Python,
446    f: &Bound<'_, PyAny>,
447    xdata: Vec<f64>,
448    ydata: Vec<f64>,
449    p0: Option<Vec<f64>>,
450    method: &str,
451    maxfev: usize,
452) -> PyResult<Py<PyAny>> {
453    use scirs2_optimize::least_squares::{least_squares, Method as LSMethod, Options as LSOptions};
454
455    if xdata.len() != ydata.len() {
456        return Err(pyo3::exceptions::PyValueError::new_err(
457            "xdata and ydata must have the same length",
458        ));
459    }
460
461    let n_data = xdata.len();
462
463    // Use p0 or default to ones
464    let params_init = p0.unwrap_or_else(|| vec![1.0; 2]);
465
466    // Parse method
467    let ls_method = match method.to_lowercase().as_str() {
468        "lm" => LSMethod::LevenbergMarquardt,
469        "trf" => LSMethod::TrustRegionReflective,
470        "dogbox" => LSMethod::Dogbox,
471        _ => LSMethod::LevenbergMarquardt,
472    };
473
474    // Create copies for closure
475    let xdata_clone = xdata.clone();
476    let ydata_clone = ydata.clone();
477    let f_arc = std::sync::Arc::new(f.clone().unbind());
478
479    // Define residual function
480    let residual_fn = move |params: &[f64], _data: &[f64]| -> Array1<f64> {
481        let f_clone = f_arc.clone();
482        let xdata_ref = &xdata_clone;
483        let ydata_ref = &ydata_clone;
484
485        #[allow(deprecated)]
486        Python::with_gil(|py| {
487            let mut residuals = Vec::with_capacity(n_data);
488
489            for i in 0..n_data {
490                // Call f(x, *params)
491                let mut args = vec![xdata_ref[i]];
492                args.extend_from_slice(params);
493
494                let f_val: f64 = f_clone
495                    .bind(py)
496                    .call1(pyo3::types::PyTuple::new(py, &args).expect("Operation failed"))
497                    .expect("Failed to call model function")
498                    .extract()
499                    .expect("Failed to extract model result");
500
501                residuals.push(ydata_ref[i] - f_val);
502            }
503
504            Array1::from_vec(residuals)
505        })
506    };
507
508    // Set up options
509    let options = LSOptions {
510        max_nfev: Some(maxfev),
511        ..Default::default()
512    };
513
514    // Run least squares optimization
515    let empty_data = Array1::from_vec(vec![]);
516
517    let result = least_squares(
518        residual_fn,
519        &Array1::from_vec(params_init),
520        ls_method,
521        None::<fn(&[f64], &[f64]) -> scirs2_core::ndarray::Array2<f64>>, // No jacobian provided
522        &empty_data,                                                     // No additional data
523        Some(options),
524    )
525    .map_err(|e| {
526        pyo3::exceptions::PyRuntimeError::new_err(format!("Curve fitting failed: {}", e))
527    })?;
528
529    // Return results
530    let dict = PyDict::new(py);
531    dict.set_item("popt", result.x.into_pyarray(py).unbind())?;
532    dict.set_item("success", result.success)?;
533    dict.set_item("nfev", result.nfev)?;
534    dict.set_item("message", result.message)?;
535
536    Ok(dict.into())
537}
538
539/// Python module registration
540pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
541    m.add_function(wrap_pyfunction!(minimize_py, m)?)?;
542    m.add_function(wrap_pyfunction!(minimize_scalar_py, m)?)?;
543    m.add_function(wrap_pyfunction!(brentq_py, m)?)?;
544    m.add_function(wrap_pyfunction!(differential_evolution_py, m)?)?;
545    m.add_function(wrap_pyfunction!(curve_fit_py, m)?)?;
546
547    Ok(())
548}