1use numpy::{PyArray1, PyArray2};
2use pyo3::prelude::*;
3use pyo3::types::{PyAny, PyDict, PyList};
4use std::sync::Arc;
5
6use crate::pybridge::{Objective, ConstraintFunction};
7
/// Kind of a dict-style constraint, serialised to a lowercase string tag.
#[derive(Debug, Clone)]
pub enum ConstraintType {
    /// Tagged as `"eq"`.
    Eq,
    /// Tagged as `"ineq"`.
    Ineq,
}

impl ConstraintType {
    /// Returns the static string tag for this kind: `"eq"` or `"ineq"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Eq => "eq",
            Self::Ineq => "ineq",
        }
    }
}
14
15#[derive(Clone)]
16pub enum ConstraintSpec {
17 Linear { a: Vec<Vec<f64>>, lb: Vec<f64>, ub: Vec<f64>, keep_feasible: Option<bool> },
18 NonlinearBounds { fun: Arc<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync + 'static>, lb: Vec<f64>, ub: Vec<f64> },
19 Dict { ctype: ConstraintType, fun: Arc<dyn Fn(&[f64]) -> f64 + Send + Sync + 'static> },
20}
21
22#[derive(Clone)]
23struct LinearConstr { a: Vec<Vec<f64>>, lb: Vec<f64>, ub: Vec<f64>, keep_feasible: Option<bool> }
24#[derive(Clone)]
25struct NonlinearConstrBounds { fun: Arc<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync + 'static>, lb: Vec<f64>, ub: Vec<f64> }
26#[derive(Clone)]
27struct DictConstr { ctype: ConstraintType, fun: Arc<dyn Fn(&[f64]) -> f64 + Send + Sync + 'static> }
28
29#[derive(Clone, Default)]
30pub struct CombinedConstraints { linear: Vec<LinearConstr>, nonlinear: Vec<NonlinearConstrBounds>, dict: Vec<DictConstr> }
31
32impl From<Vec<ConstraintSpec>> for CombinedConstraints {
33 fn from(specs: Vec<ConstraintSpec>) -> Self {
34 let mut cc = CombinedConstraints::default();
35 for s in specs {
36 match s {
37 ConstraintSpec::Linear { a, lb, ub, keep_feasible } => cc.linear.push(LinearConstr { a, lb, ub, keep_feasible }),
38 ConstraintSpec::NonlinearBounds { fun, lb, ub } => cc.nonlinear.push(NonlinearConstrBounds { fun, lb, ub }),
39 ConstraintSpec::Dict { ctype, fun } => cc.dict.push(DictConstr { ctype, fun }),
40 }
41 }
42 cc
43 }
44}
45
46impl CombinedConstraints {
47 pub fn to_py_list<'py>(&self, py: Python<'py>) -> pyo3::Bound<'py, PyList> {
48 let mut items: Vec<pyo3::Bound<'py, PyAny>> = Vec::with_capacity(self.linear.len() + self.nonlinear.len() + self.dict.len());
49 let optimize = PyModule::import(py, "scipy.optimize").unwrap();
50
51 if !self.linear.is_empty() {
52 let constr_cls = optimize.getattr("LinearConstraint").unwrap();
53 for lc in &self.linear {
54 let tol = 0.0_f64;
55 let mut a_eq: Vec<Vec<f64>> = Vec::new();
56 let mut lb_eq: Vec<f64> = Vec::new();
57 let mut ub_eq: Vec<f64> = Vec::new();
58 let mut a_in: Vec<Vec<f64>> = Vec::new();
59 let mut lb_in: Vec<f64> = Vec::new();
60 let mut ub_in: Vec<f64> = Vec::new();
61 for (i, row) in lc.a.iter().enumerate() {
62 if i < lc.lb.len() && i < lc.ub.len() && (lc.lb[i] - lc.ub[i]).abs() <= tol {
63 a_eq.push(row.clone()); lb_eq.push(lc.lb[i]); ub_eq.push(lc.ub[i]);
64 } else {
65 a_in.push(row.clone());
66 lb_in.push(lc.lb.get(i).copied().unwrap_or(f64::NEG_INFINITY));
67 ub_in.push(lc.ub.get(i).copied().unwrap_or(f64::INFINITY));
68 }
69 }
70 let kwargs_c = PyDict::new(py);
71 if let Some(kf) = lc.keep_feasible { kwargs_c.set_item("keep_feasible", kf).unwrap(); }
72 if !a_in.is_empty() {
73 let a_arr = PyArray2::<f64>::from_vec2(py, &a_in).unwrap();
74 let lb_arr = PyArray1::<f64>::from_vec(py, lb_in);
75 let ub_arr = PyArray1::<f64>::from_vec(py, ub_in);
76 let obj_any = constr_cls.call((a_arr, lb_arr, ub_arr), Some(&kwargs_c)).unwrap();
77 items.push(obj_any);
78 }
79 if !a_eq.is_empty() {
80 let a_arr = PyArray2::<f64>::from_vec2(py, &a_eq).unwrap();
81 let lb_arr = PyArray1::<f64>::from_vec(py, lb_eq.clone());
82 let ub_arr = PyArray1::<f64>::from_vec(py, ub_eq.clone());
83 let obj_any = constr_cls.call((a_arr, lb_arr, ub_arr), Some(&kwargs_c)).unwrap();
84 items.push(obj_any);
85 }
86 }
87 }
88
89 if !self.nonlinear.is_empty() {
90 let constr_cls = optimize.getattr("NonlinearConstraint").unwrap();
91 for nc in &self.nonlinear {
92 let fun_obj = Py::new(py, ConstraintFunction { func: nc.fun.clone() }).unwrap();
93 let lb_arr = PyArray1::<f64>::from_vec(py, nc.lb.clone());
94 let ub_arr = PyArray1::<f64>::from_vec(py, nc.ub.clone());
95 let obj_any = constr_cls.call((fun_obj, lb_arr, ub_arr), None).unwrap();
96 items.push(obj_any);
97 }
98 }
99
100 for dc in &self.dict {
101 let dict = PyDict::new(py);
102 dict.set_item("type", dc.ctype.as_str()).unwrap();
103 let fun_obj = Py::new(py, Objective { func: dc.fun.clone() }).unwrap();
104 dict.set_item("fun", fun_obj).unwrap();
105 items.push(dict.into_any());
106 }
107
108 PyList::new(py, items).unwrap()
109 }
110}