use pyo3::prelude::*;
use crate::factors::*;
use crate::loss_functions::*;
use crate::problem::Problem;
/// Converts a Python object wrapping a factor into a boxed Rust [`Factor`].
///
/// Dispatch is on the Python type's `__name__`, so only the exact class names
/// matched below (`BetweenFactorSE2`, `PriorFactor`) are recognised.
///
/// # Errors
///
/// Returns a `ValueError` naming the type if it is not a known factor. Any error
/// raised while reading `__name__` or extracting the Rust value is propagated
/// rather than panicking across the Python boundary.
fn convert_pyany_to_factor(py_any: &PyAny) -> PyResult<Box<dyn Factor + Send>> {
    let type_name: String = py_any.get_type().getattr("__name__")?.extract()?;
    match type_name.as_str() {
        "BetweenFactorSE2" => {
            let factor: BetweenFactorSE2 = py_any.extract()?;
            Ok(Box::new(factor))
        }
        "PriorFactor" => {
            let factor: PriorFactor = py_any.extract()?;
            Ok(Box::new(factor))
        }
        _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
            "Unknown factor type: {}",
            type_name
        ))),
    }
}
/// Converts an optional Python loss-function object into a boxed Rust [`Loss`].
///
/// Python `None` is mapped to `Ok(None)`, meaning no loss function. Other
/// objects are dispatched on their type's `__name__`; currently only
/// `HuberLoss` is recognised.
///
/// # Errors
///
/// Returns a `ValueError` naming the type if it is not a known loss function.
/// Any error raised while reading `__name__` or extracting the Rust value is
/// propagated rather than panicking across the Python boundary.
fn convert_pyany_to_loss_function(py_any: &PyAny) -> PyResult<Option<Box<dyn Loss + Send>>> {
    // `NoneType` cannot be subclassed, so this is equivalent to matching on the
    // type name "NoneType", but it is direct and skips the attribute lookup.
    if py_any.is_none() {
        return Ok(None);
    }
    let type_name: String = py_any.get_type().getattr("__name__")?.extract()?;
    match type_name.as_str() {
        "HuberLoss" => {
            let loss_func: HuberLoss = py_any.extract()?;
            Ok(Some(Box::new(loss_func)))
        }
        _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
            "Unknown loss function type: {}",
            type_name
        ))),
    }
}
#[pymethods]
impl Problem {
    /// Python constructor; equivalent to [`Problem::new`].
    #[new]
    pub fn new_py() -> Self {
        Problem::new()
    }

    /// Python-facing `add_residual_block`.
    ///
    /// Converts `pyfactor` and `pyloss_func` into their Rust counterparts and
    /// forwards them, together with `dim_residual` and `variable_key_size_list`,
    /// to [`Problem::add_residual_block`].
    ///
    /// # Errors
    ///
    /// Returns the conversion error, e.g. a `ValueError` for an unrecognised
    /// factor or loss type, instead of panicking. Both conversions happen before
    /// `self` is touched, so the problem is left unmodified on error.
    #[pyo3(name = "add_residual_block")]
    pub fn add_residual_block_py(
        &mut self,
        dim_residual: usize,
        variable_key_size_list: Vec<(String, usize)>,
        pyfactor: &PyAny,
        pyloss_func: &PyAny,
    ) -> PyResult<()> {
        let factor = convert_pyany_to_factor(pyfactor)?;
        let loss_func = convert_pyany_to_loss_function(pyloss_func)?;
        self.add_residual_block(dim_residual, variable_key_size_list, factor, loss_func);
        Ok(())
    }
}