// shgo_rs/constraints.rs

1use numpy::{PyArray1, PyArray2};
2use pyo3::prelude::*;
3use pyo3::types::{PyAny, PyDict, PyList};
4use std::sync::Arc;
5
6use crate::pybridge::{Objective, ConstraintFunction};
7
/// Kind of a dict-style constraint (`{"type": ..., "fun": ...}`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstraintType {
    /// Equality constraint; rendered as `"eq"`.
    Eq,
    /// Inequality constraint; rendered as `"ineq"`.
    Ineq,
}

impl ConstraintType {
    /// Returns the string stored under the `"type"` key of a constraint dict.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            ConstraintType::Eq => "eq",
            ConstraintType::Ineq => "ineq",
        }
    }
}

impl std::fmt::Display for ConstraintType {
    /// Formats the constraint kind exactly as [`ConstraintType::as_str`] returns it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
14
15#[derive(Clone)]
16pub enum ConstraintSpec {
17    Linear { a: Vec<Vec<f64>>, lb: Vec<f64>, ub: Vec<f64>, keep_feasible: Option<bool> },
18    NonlinearBounds { fun: Arc<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync + 'static>, lb: Vec<f64>, ub: Vec<f64> },
19    Dict { ctype: ConstraintType, fun: Arc<dyn Fn(&[f64]) -> f64 + Send + Sync + 'static> },
20}
21
22#[derive(Clone)]
23struct LinearConstr { a: Vec<Vec<f64>>, lb: Vec<f64>, ub: Vec<f64>, keep_feasible: Option<bool> }
24#[derive(Clone)]
25struct NonlinearConstrBounds { fun: Arc<dyn Fn(&[f64]) -> Vec<f64> + Send + Sync + 'static>, lb: Vec<f64>, ub: Vec<f64> }
26#[derive(Clone)]
27struct DictConstr { ctype: ConstraintType, fun: Arc<dyn Fn(&[f64]) -> f64 + Send + Sync + 'static> }
28
29#[derive(Clone, Default)]
30pub struct CombinedConstraints { linear: Vec<LinearConstr>, nonlinear: Vec<NonlinearConstrBounds>, dict: Vec<DictConstr> }
31
32impl From<Vec<ConstraintSpec>> for CombinedConstraints {
33    fn from(specs: Vec<ConstraintSpec>) -> Self {
34        let mut cc = CombinedConstraints::default();
35        for s in specs {
36            match s {
37                ConstraintSpec::Linear { a, lb, ub, keep_feasible } => cc.linear.push(LinearConstr { a, lb, ub, keep_feasible }),
38                ConstraintSpec::NonlinearBounds { fun, lb, ub } => cc.nonlinear.push(NonlinearConstrBounds { fun, lb, ub }),
39                ConstraintSpec::Dict { ctype, fun } => cc.dict.push(DictConstr { ctype, fun }),
40            }
41        }
42        cc
43    }
44}
45
46impl CombinedConstraints {
47    pub fn to_py_list<'py>(&self, py: Python<'py>) -> pyo3::Bound<'py, PyList> {
48        let mut items: Vec<pyo3::Bound<'py, PyAny>> = Vec::with_capacity(self.linear.len() + self.nonlinear.len() + self.dict.len());
49        let optimize = PyModule::import(py, "scipy.optimize").unwrap();
50
51        if !self.linear.is_empty() {
52            let constr_cls = optimize.getattr("LinearConstraint").unwrap();
53            for lc in &self.linear {
54                let tol = 0.0_f64;
55                let mut a_eq: Vec<Vec<f64>> = Vec::new();
56                let mut lb_eq: Vec<f64> = Vec::new();
57                let mut ub_eq: Vec<f64> = Vec::new();
58                let mut a_in: Vec<Vec<f64>> = Vec::new();
59                let mut lb_in: Vec<f64> = Vec::new();
60                let mut ub_in: Vec<f64> = Vec::new();
61                for (i, row) in lc.a.iter().enumerate() {
62                    if i < lc.lb.len() && i < lc.ub.len() && (lc.lb[i] - lc.ub[i]).abs() <= tol {
63                        a_eq.push(row.clone()); lb_eq.push(lc.lb[i]); ub_eq.push(lc.ub[i]);
64                    } else {
65                        a_in.push(row.clone());
66                        lb_in.push(lc.lb.get(i).copied().unwrap_or(f64::NEG_INFINITY));
67                        ub_in.push(lc.ub.get(i).copied().unwrap_or(f64::INFINITY));
68                    }
69                }
70                let kwargs_c = PyDict::new(py);
71                if let Some(kf) = lc.keep_feasible { kwargs_c.set_item("keep_feasible", kf).unwrap(); }
72                if !a_in.is_empty() {
73                    let a_arr = PyArray2::<f64>::from_vec2(py, &a_in).unwrap();
74                    let lb_arr = PyArray1::<f64>::from_vec(py, lb_in);
75                    let ub_arr = PyArray1::<f64>::from_vec(py, ub_in);
76                    let obj_any = constr_cls.call((a_arr, lb_arr, ub_arr), Some(&kwargs_c)).unwrap();
77                    items.push(obj_any);
78                }
79                if !a_eq.is_empty() {
80                    let a_arr = PyArray2::<f64>::from_vec2(py, &a_eq).unwrap();
81                    let lb_arr = PyArray1::<f64>::from_vec(py, lb_eq.clone());
82                    let ub_arr = PyArray1::<f64>::from_vec(py, ub_eq.clone());
83                    let obj_any = constr_cls.call((a_arr, lb_arr, ub_arr), Some(&kwargs_c)).unwrap();
84                    items.push(obj_any);
85                }
86            }
87        }
88
89        if !self.nonlinear.is_empty() {
90            let constr_cls = optimize.getattr("NonlinearConstraint").unwrap();
91            for nc in &self.nonlinear {
92                let fun_obj = Py::new(py, ConstraintFunction { func: nc.fun.clone() }).unwrap();
93                let lb_arr = PyArray1::<f64>::from_vec(py, nc.lb.clone());
94                let ub_arr = PyArray1::<f64>::from_vec(py, nc.ub.clone());
95                let obj_any = constr_cls.call((fun_obj, lb_arr, ub_arr), None).unwrap();
96                items.push(obj_any);
97            }
98        }
99
100        for dc in &self.dict {
101            let dict = PyDict::new(py);
102            dict.set_item("type", dc.ctype.as_str()).unwrap();
103            let fun_obj = Py::new(py, Objective { func: dc.fun.clone() }).unwrap();
104            dict.set_item("fun", fun_obj).unwrap();
105            items.push(dict.into_any());
106        }
107
108        PyList::new(py, items).unwrap()
109    }
110}