1use pyo3::prelude::*;
9use pyo3::types::PyDict;
10
11use scirs2_numpy::IntoPyArray;
13
14use scirs2_core::{ndarray::ArrayView1, Array1};
16
17use scirs2_optimize::global::{differential_evolution, DifferentialEvolutionOptions};
19use scirs2_optimize::scalar::{minimize_scalar, Method as ScalarMethod, Options as ScalarOptions};
20use scirs2_optimize::unconstrained::{minimize, Bounds, Method, Options};
21
22#[pyfunction]
30#[pyo3(signature = (fun, bracket, method="brent", options=None))]
31fn minimize_scalar_py(
32 py: Python,
33 fun: &Bound<'_, PyAny>,
34 bracket: (f64, f64),
35 method: &str,
36 options: Option<&Bound<'_, PyDict>>,
37) -> PyResult<Py<PyAny>> {
38 let maxiter = options
39 .and_then(|o| o.get_item("maxiter").ok().flatten())
40 .and_then(|v| v.extract().ok());
41 let tol = options
42 .and_then(|o| o.get_item("tol").ok().flatten())
43 .and_then(|v| v.extract().ok());
44
45 let fun_clone = fun.clone().unbind();
46 let f = move |x: f64| -> f64 {
47 Python::attach(|py| {
48 let result = fun_clone
49 .bind(py)
50 .call1((x,))
51 .expect("Failed to call objective function");
52 result.extract().expect("Failed to extract result")
53 })
54 };
55
56 let scalar_method = match method {
58 "brent" => ScalarMethod::Brent,
59 "golden" => ScalarMethod::Golden,
60 "bounded" => ScalarMethod::Bounded,
61 _ => ScalarMethod::Brent,
62 };
63
64 let mut options = ScalarOptions::default();
66 if let Some(mi) = maxiter {
67 options.max_iter = mi;
68 }
69 if let Some(t) = tol {
70 options.xatol = t;
71 }
72
73 let result = minimize_scalar(f, Some(bracket), scalar_method, Some(options))
74 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
75
76 let dict = PyDict::new(py);
77 dict.set_item("x", result.x)?;
78 dict.set_item("fun", result.fun)?;
79 dict.set_item("success", result.success)?;
80 dict.set_item("nit", result.nit)?;
81 dict.set_item("nfev", result.function_evals)?;
82
83 Ok(dict.into())
84}
85
86#[pyfunction]
99#[pyo3(signature = (fun, a, b, xtol=1e-12, maxiter=100))]
100fn brentq_py(
101 py: Python,
102 fun: &Bound<'_, PyAny>,
103 a: f64,
104 b: f64,
105 xtol: f64,
106 maxiter: usize,
107) -> PyResult<Py<PyAny>> {
108 let fun_clone = fun.clone().unbind();
109 let f = |x: f64| -> f64 {
110 Python::attach(|py| {
111 let result = fun_clone
112 .bind(py)
113 .call1((x,))
114 .expect("Failed to call objective function");
115 result.extract().expect("Failed to extract result")
116 })
117 };
118
119 let mut a = a;
121 let mut b = b;
122 let mut fa = f(a);
123 let mut fb = f(b);
124
125 if fa * fb > 0.0 {
126 return Err(pyo3::exceptions::PyValueError::new_err(
127 "f(a) and f(b) must have opposite signs",
128 ));
129 }
130
131 if fa.abs() < fb.abs() {
133 std::mem::swap(&mut a, &mut b);
134 std::mem::swap(&mut fa, &mut fb);
135 }
136
137 let mut c = a;
138 let mut fc = fa;
139 let mut d = b - a;
140 let mut e = d;
141 let mut iter = 0;
142
143 while iter < maxiter {
144 if fb.abs() < xtol {
145 let dict = PyDict::new(py);
146 dict.set_item("x", b)?;
147 dict.set_item("fun", fb)?;
148 dict.set_item("iterations", iter)?;
149 dict.set_item("success", true)?;
150 return Ok(dict.into());
151 }
152
153 if fa.abs() < fb.abs() {
154 std::mem::swap(&mut a, &mut b);
155 std::mem::swap(&mut fa, &mut fb);
156 c = a;
157 fc = fa;
158 }
159
160 let tol = 2.0 * f64::EPSILON * b.abs() + xtol;
161 let m = (c - b) / 2.0;
162
163 if m.abs() <= tol {
164 let dict = PyDict::new(py);
165 dict.set_item("x", b)?;
166 dict.set_item("fun", fb)?;
167 dict.set_item("iterations", iter)?;
168 dict.set_item("success", true)?;
169 return Ok(dict.into());
170 }
171
172 let mut use_bisection = true;
174
175 if e.abs() >= tol && fa.abs() > fb.abs() {
176 let s = fb / fa;
177 let (p, q) = if (a - c).abs() < 1e-14 {
178 (2.0 * m * s, 1.0 - s)
180 } else {
181 let q = fa / fc;
183 let r = fb / fc;
184 (
185 s * (2.0 * m * q * (q - r) - (b - a) * (r - 1.0)),
186 (q - 1.0) * (r - 1.0) * (s - 1.0),
187 )
188 };
189
190 let (p, q) = if p > 0.0 { (p, -q) } else { (-p, q) };
191
192 if 2.0 * p < 3.0 * m * q - (tol * q).abs() && p < (e * q / 2.0).abs() {
193 e = d;
194 d = p / q;
195 use_bisection = false;
196 }
197 }
198
199 if use_bisection {
200 d = m;
201 e = m;
202 }
203
204 a = b;
205 fa = fb;
206
207 if d.abs() > tol {
208 b += d;
209 } else {
210 b += if m > 0.0 { tol } else { -tol };
211 }
212
213 fb = f(b);
214
215 if (fb > 0.0) == (fc > 0.0) {
216 c = a;
217 fc = fa;
218 d = b - a;
219 e = d;
220 }
221
222 iter += 1;
223 }
224
225 let dict = PyDict::new(py);
226 dict.set_item("x", b)?;
227 dict.set_item("fun", fb)?;
228 dict.set_item("iterations", iter)?;
229 dict.set_item("success", false)?;
230 dict.set_item("message", "Maximum iterations reached")?;
231 Ok(dict.into())
232}
233
234#[pyfunction]
246#[pyo3(signature = (fun, x0, method="bfgs", options=None, bounds=None))]
247fn minimize_py(
248 py: Python,
249 fun: &Bound<'_, PyAny>,
250 x0: Vec<f64>,
251 method: &str,
252 options: Option<&Bound<'_, PyDict>>,
253 bounds: Option<Vec<(f64, f64)>>,
254) -> PyResult<Py<PyAny>> {
255 let opt_method = match method.to_lowercase().as_str() {
257 "nelder-mead" | "neldermead" => Method::NelderMead,
258 "powell" => Method::Powell,
259 "cg" | "conjugate-gradient" => Method::CG,
260 "bfgs" => Method::BFGS,
261 "lbfgs" | "l-bfgs" => Method::LBFGS,
262 "lbfgsb" | "l-bfgs-b" => Method::LBFGSB,
263 "newton-cg" => Method::NewtonCG,
264 "trust-ncg" => Method::TrustNCG,
265 "sr1" => Method::SR1,
266 "dfp" => Method::DFP,
267 _ => Method::BFGS, };
269
270 let maxiter = options
272 .and_then(|o| o.get_item("maxiter").ok().flatten())
273 .and_then(|v| v.extract().ok());
274 let ftol = options
275 .and_then(|o| o.get_item("ftol").ok().flatten())
276 .and_then(|v| v.extract().ok());
277 let gtol = options
278 .and_then(|o| o.get_item("gtol").ok().flatten())
279 .and_then(|v| v.extract().ok());
280
281 let mut opt_options = Options::default();
282 if let Some(mi) = maxiter {
283 opt_options.max_iter = mi;
284 }
285 if let Some(ft) = ftol {
286 opt_options.ftol = ft;
287 }
288 if let Some(gt) = gtol {
289 opt_options.gtol = gt;
290 }
291
292 if let Some(b) = bounds {
294 let n = x0.len();
295 let mut lower = vec![None; n];
296 let mut upper = vec![None; n];
297 for (i, (l, u)) in b.iter().enumerate() {
298 if i < n {
299 lower[i] = Some(*l);
300 upper[i] = Some(*u);
301 }
302 }
303 opt_options.bounds = Some(Bounds { lower, upper });
304 }
305
306 let fun_arc = std::sync::Arc::new(fun.clone().unbind());
308 let f = move |x: &ArrayView1<f64>| -> f64 {
309 let fun_clone = fun_arc.clone();
310 Python::attach(|py| {
311 let x_vec: Vec<f64> = x.to_vec();
312 let result = fun_clone
313 .bind(py)
314 .call1((x_vec,))
315 .expect("Failed to call objective function");
316 result.extract().expect("Failed to extract result")
317 })
318 };
319
320 let result = minimize(f, &x0, opt_method, Some(opt_options))
322 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
323
324 let dict = PyDict::new(py);
326 dict.set_item("x", result.x.into_pyarray(py).unbind())?;
327 dict.set_item("fun", result.fun)?;
328 dict.set_item("success", result.success)?;
329 dict.set_item("message", result.message)?;
330 dict.set_item("nit", result.nit)?;
331 dict.set_item("nfev", result.func_evals)?;
332
333 Ok(dict.into())
334}
335
336#[pyfunction]
343#[pyo3(signature = (fun, bounds, options=None))]
344fn differential_evolution_py(
345 py: Python,
346 fun: &Bound<'_, PyAny>,
347 bounds: Vec<(f64, f64)>,
348 options: Option<&Bound<'_, PyDict>>,
349) -> PyResult<Py<PyAny>> {
350 let maxiter = options
351 .and_then(|o| o.get_item("maxiter").ok().flatten())
352 .and_then(|v| v.extract().ok());
353 let popsize = options
354 .and_then(|o| o.get_item("popsize").ok().flatten())
355 .and_then(|v| v.extract().ok());
356 let tol = options
357 .and_then(|o| o.get_item("tol").ok().flatten())
358 .and_then(|v| v.extract().ok());
359 let seed = options
360 .and_then(|o| o.get_item("seed").ok().flatten())
361 .and_then(|v| v.extract().ok());
362
363 let fun_arc = std::sync::Arc::new(fun.clone().unbind());
364 let f = move |x: &ArrayView1<f64>| -> f64 {
365 let fun_clone = fun_arc.clone();
366 Python::attach(|py| {
367 let x_vec: Vec<f64> = x.to_vec();
368 let result = fun_clone
369 .bind(py)
370 .call1((x_vec,))
371 .expect("Failed to call objective function");
372 result.extract().expect("Failed to extract result")
373 })
374 };
375
376 let mut de_options = DifferentialEvolutionOptions::default();
378 if let Some(mi) = maxiter {
379 de_options.maxiter = mi;
380 }
381 if let Some(ps) = popsize {
382 de_options.popsize = ps;
383 }
384 if let Some(t) = tol {
385 de_options.tol = t;
386 }
387 if let Some(s) = seed {
388 de_options.seed = Some(s);
389 }
390
391 let result = differential_evolution(f, bounds.to_vec(), Some(de_options), None)
393 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
394
395 let dict = PyDict::new(py);
396 dict.set_item("x", result.x.into_pyarray(py).unbind())?;
397 dict.set_item("fun", result.fun)?;
398 dict.set_item("success", result.success)?;
399 dict.set_item("message", result.message)?;
400 dict.set_item("nit", result.nit)?;
401 dict.set_item("nfev", result.func_evals)?;
402
403 Ok(dict.into())
404}
405
406#[pyfunction]
440#[pyo3(signature = (f, xdata, ydata, p0=None, method="lm", maxfev=1000))]
441fn curve_fit_py(
442 py: Python,
443 f: &Bound<'_, PyAny>,
444 xdata: Vec<f64>,
445 ydata: Vec<f64>,
446 p0: Option<Vec<f64>>,
447 method: &str,
448 maxfev: usize,
449) -> PyResult<Py<PyAny>> {
450 use scirs2_optimize::least_squares::{least_squares, Method as LSMethod, Options as LSOptions};
451
452 if xdata.len() != ydata.len() {
453 return Err(pyo3::exceptions::PyValueError::new_err(
454 "xdata and ydata must have the same length",
455 ));
456 }
457
458 let n_data = xdata.len();
459
460 let params_init = p0.unwrap_or_else(|| vec![1.0; 2]);
462
463 let ls_method = match method.to_lowercase().as_str() {
465 "lm" => LSMethod::LevenbergMarquardt,
466 "trf" => LSMethod::TrustRegionReflective,
467 "dogbox" => LSMethod::Dogbox,
468 _ => LSMethod::LevenbergMarquardt,
469 };
470
471 let xdata_clone = xdata.clone();
473 let ydata_clone = ydata.clone();
474 let f_arc = std::sync::Arc::new(f.clone().unbind());
475
476 let residual_fn = move |params: &[f64], _data: &[f64]| -> Array1<f64> {
478 let f_clone = f_arc.clone();
479 let xdata_ref = &xdata_clone;
480 let ydata_ref = &ydata_clone;
481
482 Python::attach(|py| {
483 let mut residuals = Vec::with_capacity(n_data);
484
485 for i in 0..n_data {
486 let mut args = vec![xdata_ref[i]];
488 args.extend_from_slice(params);
489
490 let f_val: f64 = f_clone
491 .bind(py)
492 .call1(pyo3::types::PyTuple::new(py, &args).expect("Operation failed"))
493 .expect("Failed to call model function")
494 .extract()
495 .expect("Failed to extract model result");
496
497 residuals.push(ydata_ref[i] - f_val);
498 }
499
500 Array1::from_vec(residuals)
501 })
502 };
503
504 let options = LSOptions {
506 max_nfev: Some(maxfev),
507 ..Default::default()
508 };
509
510 let empty_data = Array1::from_vec(vec![]);
512
513 let result = least_squares(
514 residual_fn,
515 &Array1::from_vec(params_init),
516 ls_method,
517 None::<fn(&[f64], &[f64]) -> scirs2_core::ndarray::Array2<f64>>, &empty_data, Some(options),
520 )
521 .map_err(|e| {
522 pyo3::exceptions::PyRuntimeError::new_err(format!("Curve fitting failed: {}", e))
523 })?;
524
525 let dict = PyDict::new(py);
527 dict.set_item("popt", result.x.into_pyarray(py).unbind())?;
528 dict.set_item("success", result.success)?;
529 dict.set_item("nfev", result.nfev)?;
530 dict.set_item("message", result.message)?;
531
532 Ok(dict.into())
533}
534
535pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
537 m.add_function(wrap_pyfunction!(minimize_py, m)?)?;
538 m.add_function(wrap_pyfunction!(minimize_scalar_py, m)?)?;
539 m.add_function(wrap_pyfunction!(brentq_py, m)?)?;
540 m.add_function(wrap_pyfunction!(differential_evolution_py, m)?)?;
541 m.add_function(wrap_pyfunction!(curve_fit_py, m)?)?;
542
543 Ok(())
544}