1use pyo3::prelude::*;
8use pyo3::types::PyDict;
9
10use scirs2_numpy::IntoPyArray;
12
13use scirs2_core::{ndarray::ArrayView1, Array1};
15
16use scirs2_optimize::global::{differential_evolution, DifferentialEvolutionOptions};
18use scirs2_optimize::scalar::{minimize_scalar, Method as ScalarMethod, Options as ScalarOptions};
19use scirs2_optimize::unconstrained::{minimize, Bounds, Method, Options};
20
21#[pyfunction]
29#[pyo3(signature = (fun, bracket, method="brent", options=None))]
30fn minimize_scalar_py(
31 py: Python,
32 fun: &Bound<'_, PyAny>,
33 bracket: (f64, f64),
34 method: &str,
35 options: Option<&Bound<'_, PyDict>>,
36) -> PyResult<Py<PyAny>> {
37 let maxiter = options
38 .and_then(|o| o.get_item("maxiter").ok().flatten())
39 .and_then(|v| v.extract().ok());
40 let tol = options
41 .and_then(|o| o.get_item("tol").ok().flatten())
42 .and_then(|v| v.extract().ok());
43
44 let fun_clone = fun.clone().unbind();
45 let f = move |x: f64| -> f64 {
46 #[allow(deprecated)]
47 Python::with_gil(|py| {
48 let result = fun_clone
49 .bind(py)
50 .call1((x,))
51 .expect("Failed to call objective function");
52 result.extract().expect("Failed to extract result")
53 })
54 };
55
56 let scalar_method = match method {
58 "brent" => ScalarMethod::Brent,
59 "golden" => ScalarMethod::Golden,
60 "bounded" => ScalarMethod::Bounded,
61 _ => ScalarMethod::Brent,
62 };
63
64 let mut options = ScalarOptions::default();
66 if let Some(mi) = maxiter {
67 options.max_iter = mi;
68 }
69 if let Some(t) = tol {
70 options.xatol = t;
71 }
72
73 let result = minimize_scalar(f, Some(bracket), scalar_method, Some(options))
74 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
75
76 let dict = PyDict::new(py);
77 dict.set_item("x", result.x)?;
78 dict.set_item("fun", result.fun)?;
79 dict.set_item("success", result.success)?;
80 dict.set_item("nit", result.nit)?;
81 dict.set_item("nfev", result.function_evals)?;
82
83 Ok(dict.into())
84}
85
86#[pyfunction]
99#[pyo3(signature = (fun, a, b, xtol=1e-12, maxiter=100))]
100fn brentq_py(
101 py: Python,
102 fun: &Bound<'_, PyAny>,
103 a: f64,
104 b: f64,
105 xtol: f64,
106 maxiter: usize,
107) -> PyResult<Py<PyAny>> {
108 let fun_clone = fun.clone().unbind();
109 let f = |x: f64| -> f64 {
110 #[allow(deprecated)]
111 Python::with_gil(|py| {
112 let result = fun_clone
113 .bind(py)
114 .call1((x,))
115 .expect("Failed to call objective function");
116 result.extract().expect("Failed to extract result")
117 })
118 };
119
120 let mut a = a;
122 let mut b = b;
123 let mut fa = f(a);
124 let mut fb = f(b);
125
126 if fa * fb > 0.0 {
127 return Err(pyo3::exceptions::PyValueError::new_err(
128 "f(a) and f(b) must have opposite signs",
129 ));
130 }
131
132 if fa.abs() < fb.abs() {
134 std::mem::swap(&mut a, &mut b);
135 std::mem::swap(&mut fa, &mut fb);
136 }
137
138 let mut c = a;
139 let mut fc = fa;
140 let mut d = b - a;
141 let mut e = d;
142 let mut iter = 0;
143
144 while iter < maxiter {
145 if fb.abs() < xtol {
146 let dict = PyDict::new(py);
147 dict.set_item("x", b)?;
148 dict.set_item("fun", fb)?;
149 dict.set_item("iterations", iter)?;
150 dict.set_item("success", true)?;
151 return Ok(dict.into());
152 }
153
154 if fa.abs() < fb.abs() {
155 std::mem::swap(&mut a, &mut b);
156 std::mem::swap(&mut fa, &mut fb);
157 c = a;
158 fc = fa;
159 }
160
161 let tol = 2.0 * f64::EPSILON * b.abs() + xtol;
162 let m = (c - b) / 2.0;
163
164 if m.abs() <= tol {
165 let dict = PyDict::new(py);
166 dict.set_item("x", b)?;
167 dict.set_item("fun", fb)?;
168 dict.set_item("iterations", iter)?;
169 dict.set_item("success", true)?;
170 return Ok(dict.into());
171 }
172
173 let mut use_bisection = true;
175
176 if e.abs() >= tol && fa.abs() > fb.abs() {
177 let s = fb / fa;
178 let (p, q) = if (a - c).abs() < 1e-14 {
179 (2.0 * m * s, 1.0 - s)
181 } else {
182 let q = fa / fc;
184 let r = fb / fc;
185 (
186 s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
187 (q - 1.0) * (r - 1.0) * (s - 1.0),
188 )
189 };
190
191 let (p, q) = if p > 0.0 { (p, -q) } else { (-p, q) };
192
193 if 2.0 * p < 3.0 * m * q - (tol * q).abs() && p < (e * q / 2.0).abs() {
194 e = d;
195 d = p / q;
196 use_bisection = false;
197 }
198 }
199
200 if use_bisection {
201 d = m;
202 e = m;
203 }
204
205 a = b;
206 fa = fb;
207
208 if d.abs() > tol {
209 b += d;
210 } else {
211 b += if m > 0.0 { tol } else { -tol };
212 }
213
214 fb = f(b);
215
216 if (fb > 0.0) == (fc > 0.0) {
217 c = a;
218 fc = fa;
219 d = b - a;
220 e = d;
221 }
222
223 iter += 1;
224 }
225
226 let dict = PyDict::new(py);
227 dict.set_item("x", b)?;
228 dict.set_item("fun", fb)?;
229 dict.set_item("iterations", iter)?;
230 dict.set_item("success", false)?;
231 dict.set_item("message", "Maximum iterations reached")?;
232 Ok(dict.into())
233}
234
235#[pyfunction]
247#[pyo3(signature = (fun, x0, method="bfgs", options=None, bounds=None))]
248fn minimize_py(
249 py: Python,
250 fun: &Bound<'_, PyAny>,
251 x0: Vec<f64>,
252 method: &str,
253 options: Option<&Bound<'_, PyDict>>,
254 bounds: Option<Vec<(f64, f64)>>,
255) -> PyResult<Py<PyAny>> {
256 let opt_method = match method.to_lowercase().as_str() {
258 "nelder-mead" | "neldermead" => Method::NelderMead,
259 "powell" => Method::Powell,
260 "cg" | "conjugate-gradient" => Method::CG,
261 "bfgs" => Method::BFGS,
262 "lbfgs" | "l-bfgs" => Method::LBFGS,
263 "lbfgsb" | "l-bfgs-b" => Method::LBFGSB,
264 "newton-cg" => Method::NewtonCG,
265 "trust-ncg" => Method::TrustNCG,
266 "sr1" => Method::SR1,
267 "dfp" => Method::DFP,
268 _ => Method::BFGS, };
270
271 let maxiter = options
273 .and_then(|o| o.get_item("maxiter").ok().flatten())
274 .and_then(|v| v.extract().ok());
275 let ftol = options
276 .and_then(|o| o.get_item("ftol").ok().flatten())
277 .and_then(|v| v.extract().ok());
278 let gtol = options
279 .and_then(|o| o.get_item("gtol").ok().flatten())
280 .and_then(|v| v.extract().ok());
281
282 let mut opt_options = Options::default();
283 if let Some(mi) = maxiter {
284 opt_options.max_iter = mi;
285 }
286 if let Some(ft) = ftol {
287 opt_options.ftol = ft;
288 }
289 if let Some(gt) = gtol {
290 opt_options.gtol = gt;
291 }
292
293 if let Some(b) = bounds {
295 let n = x0.len();
296 let mut lower = vec![None; n];
297 let mut upper = vec![None; n];
298 for (i, (l, u)) in b.iter().enumerate() {
299 if i < n {
300 lower[i] = Some(*l);
301 upper[i] = Some(*u);
302 }
303 }
304 opt_options.bounds = Some(Bounds { lower, upper });
305 }
306
307 let fun_arc = std::sync::Arc::new(fun.clone().unbind());
309 let f = move |x: &ArrayView1<f64>| -> f64 {
310 let fun_clone = fun_arc.clone();
311 #[allow(deprecated)]
312 Python::with_gil(|py| {
313 let x_vec: Vec<f64> = x.to_vec();
314 let result = fun_clone
315 .bind(py)
316 .call1((x_vec,))
317 .expect("Failed to call objective function");
318 result.extract().expect("Failed to extract result")
319 })
320 };
321
322 let result = minimize(f, &x0, opt_method, Some(opt_options))
324 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
325
326 let dict = PyDict::new(py);
328 dict.set_item("x", result.x.into_pyarray(py).unbind())?;
329 dict.set_item("fun", result.fun)?;
330 dict.set_item("success", result.success)?;
331 dict.set_item("message", result.message)?;
332 dict.set_item("nit", result.nit)?;
333 dict.set_item("nfev", result.func_evals)?;
334
335 Ok(dict.into())
336}
337
338#[pyfunction]
345#[pyo3(signature = (fun, bounds, options=None))]
346fn differential_evolution_py(
347 py: Python,
348 fun: &Bound<'_, PyAny>,
349 bounds: Vec<(f64, f64)>,
350 options: Option<&Bound<'_, PyDict>>,
351) -> PyResult<Py<PyAny>> {
352 let maxiter = options
353 .and_then(|o| o.get_item("maxiter").ok().flatten())
354 .and_then(|v| v.extract().ok());
355 let popsize = options
356 .and_then(|o| o.get_item("popsize").ok().flatten())
357 .and_then(|v| v.extract().ok());
358 let tol = options
359 .and_then(|o| o.get_item("tol").ok().flatten())
360 .and_then(|v| v.extract().ok());
361 let seed = options
362 .and_then(|o| o.get_item("seed").ok().flatten())
363 .and_then(|v| v.extract().ok());
364
365 let fun_arc = std::sync::Arc::new(fun.clone().unbind());
366 let f = move |x: &ArrayView1<f64>| -> f64 {
367 let fun_clone = fun_arc.clone();
368 #[allow(deprecated)]
369 Python::with_gil(|py| {
370 let x_vec: Vec<f64> = x.to_vec();
371 let result = fun_clone
372 .bind(py)
373 .call1((x_vec,))
374 .expect("Failed to call objective function");
375 result.extract().expect("Failed to extract result")
376 })
377 };
378
379 let mut de_options = DifferentialEvolutionOptions::default();
381 if let Some(mi) = maxiter {
382 de_options.maxiter = mi;
383 }
384 if let Some(ps) = popsize {
385 de_options.popsize = ps;
386 }
387 if let Some(t) = tol {
388 de_options.tol = t;
389 }
390 if let Some(s) = seed {
391 de_options.seed = Some(s);
392 }
393
394 let result = differential_evolution(f, bounds.to_vec(), Some(de_options), None)
396 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
397
398 let dict = PyDict::new(py);
399 dict.set_item("x", result.x.into_pyarray(py).unbind())?;
400 dict.set_item("fun", result.fun)?;
401 dict.set_item("success", result.success)?;
402 dict.set_item("message", result.message)?;
403 dict.set_item("nit", result.nit)?;
404 dict.set_item("nfev", result.func_evals)?;
405
406 Ok(dict.into())
407}
408
409#[pyfunction]
443#[pyo3(signature = (f, xdata, ydata, p0=None, method="lm", maxfev=1000))]
444fn curve_fit_py(
445 py: Python,
446 f: &Bound<'_, PyAny>,
447 xdata: Vec<f64>,
448 ydata: Vec<f64>,
449 p0: Option<Vec<f64>>,
450 method: &str,
451 maxfev: usize,
452) -> PyResult<Py<PyAny>> {
453 use scirs2_optimize::least_squares::{least_squares, Method as LSMethod, Options as LSOptions};
454
455 if xdata.len() != ydata.len() {
456 return Err(pyo3::exceptions::PyValueError::new_err(
457 "xdata and ydata must have the same length",
458 ));
459 }
460
461 let n_data = xdata.len();
462
463 let params_init = p0.unwrap_or_else(|| vec![1.0; 2]);
465
466 let ls_method = match method.to_lowercase().as_str() {
468 "lm" => LSMethod::LevenbergMarquardt,
469 "trf" => LSMethod::TrustRegionReflective,
470 "dogbox" => LSMethod::Dogbox,
471 _ => LSMethod::LevenbergMarquardt,
472 };
473
474 let xdata_clone = xdata.clone();
476 let ydata_clone = ydata.clone();
477 let f_arc = std::sync::Arc::new(f.clone().unbind());
478
479 let residual_fn = move |params: &[f64], _data: &[f64]| -> Array1<f64> {
481 let f_clone = f_arc.clone();
482 let xdata_ref = &xdata_clone;
483 let ydata_ref = &ydata_clone;
484
485 #[allow(deprecated)]
486 Python::with_gil(|py| {
487 let mut residuals = Vec::with_capacity(n_data);
488
489 for i in 0..n_data {
490 let mut args = vec![xdata_ref[i]];
492 args.extend_from_slice(params);
493
494 let f_val: f64 = f_clone
495 .bind(py)
496 .call1(pyo3::types::PyTuple::new(py, &args).expect("Operation failed"))
497 .expect("Failed to call model function")
498 .extract()
499 .expect("Failed to extract model result");
500
501 residuals.push(ydata_ref[i] - f_val);
502 }
503
504 Array1::from_vec(residuals)
505 })
506 };
507
508 let options = LSOptions {
510 max_nfev: Some(maxfev),
511 ..Default::default()
512 };
513
514 let empty_data = Array1::from_vec(vec![]);
516
517 let result = least_squares(
518 residual_fn,
519 &Array1::from_vec(params_init),
520 ls_method,
521 None::<fn(&[f64], &[f64]) -> scirs2_core::ndarray::Array2<f64>>, &empty_data, Some(options),
524 )
525 .map_err(|e| {
526 pyo3::exceptions::PyRuntimeError::new_err(format!("Curve fitting failed: {}", e))
527 })?;
528
529 let dict = PyDict::new(py);
531 dict.set_item("popt", result.x.into_pyarray(py).unbind())?;
532 dict.set_item("success", result.success)?;
533 dict.set_item("nfev", result.nfev)?;
534 dict.set_item("message", result.message)?;
535
536 Ok(dict.into())
537}
538
539pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
541 m.add_function(wrap_pyfunction!(minimize_py, m)?)?;
542 m.add_function(wrap_pyfunction!(minimize_scalar_py, m)?)?;
543 m.add_function(wrap_pyfunction!(brentq_py, m)?)?;
544 m.add_function(wrap_pyfunction!(differential_evolution_py, m)?)?;
545 m.add_function(wrap_pyfunction!(curve_fit_py, m)?)?;
546
547 Ok(())
548}