use super::Key;
#[cfg(feature = "python")]
use pyo3::prelude::*;
use serde::{Deserialize, Serialize};
#[cfg(feature = "utoipa")]
use utoipa::ToSchema;
/// Payload of the [`JoinBy::JoinById`] strategy: the list of [`Key`]s the
/// join is keyed on.
///
/// NOTE(review): how the keys are matched is decided by the consumer of
/// `JoinBy`, which is not visible here.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
#[cfg_attr(feature = "utoipa", derive(ToSchema))]
pub struct JoinById {
    /// Keys used to identify matching entries.
    pub keys: Vec<Key>,
}

impl JoinById {
    /// Builds a `JoinById` that owns `keys`.
    pub fn new(keys: Vec<Key>) -> Self {
        JoinById { keys }
    }
}

#[cfg(feature = "python")]
#[pymethods]
impl JoinById {
    /// Python constructor `JoinById(keys)`; identical to [`JoinById::new`].
    #[new]
    pub fn init(keys: Vec<Key>) -> Self {
        Self::new(keys)
    }
}
/// Strategy selector for a join; the `Default` value is [`JoinBy::AddColumns`].
///
/// NOTE(review): the concrete meaning of each unit variant is implemented by
/// whatever consumes this enum, which is not visible here — confirm there.
///
/// Variant order matters: serde's derived `Serialize`/`Deserialize` encode
/// variants by index in non-self-describing formats, so reordering the
/// variants changes the wire format.
///
/// With the `python` feature, values cross the Python boundary as
/// `python::PythonJoin` via the `FromPyObject`/`IntoPyObject` impls in the
/// `python` module.
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(ToSchema))]
pub enum JoinBy {
    /// Default variant (so `JoinRelation::default()` also uses it).
    #[default]
    AddColumns,
    Replace,
    Extend,
    Broadcast,
    CartesianProduct,
    /// Key-based join; the keys are carried in a [`JoinById`].
    JoinById(JoinById),
}
#[cfg(feature = "python")]
pub mod python {
    //! Python (pyo3) bridge for [`JoinBy`].
    //!
    //! On the Python side a `JoinBy` is represented by [`PythonJoin`]: a
    //! fieldless [`PythonJoinBy`] tag plus an optional [`JoinById`] payload
    //! that only the `JoinById` tag uses. The `FromPyObject` / `IntoPyObject`
    //! impls in this module convert through that type.
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fieldless tag with one variant per [`JoinBy`] variant.
    ///
    /// `eq, eq_int` make the Python class comparable with `==`, including
    /// against its integer discriminant.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass(eq, eq_int)]
    pub enum PythonJoinBy {
        AddColumns,
        Replace,
        Extend,
        Broadcast,
        CartesianProduct,
        JoinById,
    }

    /// Python-facing, flattened form of [`JoinBy`].
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass]
    pub struct PythonJoin {
        /// Which [`JoinBy`] variant this value stands for.
        pub join_type: PythonJoinBy,
        /// Payload for [`PythonJoinBy::JoinById`]; ignored by every other tag.
        pub join_by_id: Option<JoinById>,
    }

    #[pymethods]
    impl PythonJoin {
        /// Python constructor `PythonJoin(join_type, join_by_id=None)`.
        ///
        /// Without it, Python code has no way to build the value that
        /// `JoinBy`'s `FromPyObject` impl extracts. Whether `join_by_id` is
        /// consistent with `join_type` is checked when converting into a
        /// [`JoinBy`], not here.
        #[new]
        #[pyo3(signature = (join_type, join_by_id = None))]
        pub fn init(join_type: PythonJoinBy, join_by_id: Option<JoinById>) -> Self {
            Self {
                join_type,
                join_by_id,
            }
        }

        /// Variant tag, exposed to Python as the read-only `join_type` attribute.
        #[getter(join_type)]
        fn py_join_type(&self) -> PythonJoinBy {
            self.join_type.clone()
        }

        /// `JoinById` payload (or `None`), exposed to Python as the read-only
        /// `join_by_id` attribute.
        #[getter(join_by_id)]
        fn py_join_by_id(&self) -> Option<JoinById> {
            self.join_by_id.clone()
        }
    }

    impl TryFrom<PythonJoin> for JoinBy {
        type Error = crate::error::Error;

        /// Rebuilds a [`JoinBy`] from its flattened Python form.
        ///
        /// # Errors
        /// Returns `Error::MissingField("join_by_id")` when the tag is
        /// `JoinById` but no payload is present. For every other tag, any
        /// payload is ignored.
        fn try_from(py_join: PythonJoin) -> Result<Self, Self::Error> {
            Ok(match py_join.join_type {
                PythonJoinBy::AddColumns => JoinBy::AddColumns,
                PythonJoinBy::Replace => JoinBy::Replace,
                PythonJoinBy::Extend => JoinBy::Extend,
                PythonJoinBy::Broadcast => JoinBy::Broadcast,
                PythonJoinBy::CartesianProduct => JoinBy::CartesianProduct,
                PythonJoinBy::JoinById => {
                    let join_by_id = py_join
                        .join_by_id
                        .ok_or_else(|| crate::error::Error::MissingField("join_by_id".into()))?;
                    JoinBy::JoinById(join_by_id)
                }
            })
        }
    }

    impl TryFrom<JoinBy> for PythonJoin {
        type Error = crate::error::Error;

        /// Flattens a [`JoinBy`] into its Python form.
        ///
        /// Every variant maps to a tag; only `JoinById` carries a payload.
        /// This conversion never fails in practice. It keeps the `TryFrom`
        /// signature for API compatibility.
        fn try_from(join_by: JoinBy) -> Result<Self, Self::Error> {
            let (join_type, join_by_id) = match join_by {
                JoinBy::AddColumns => (PythonJoinBy::AddColumns, None),
                JoinBy::Replace => (PythonJoinBy::Replace, None),
                JoinBy::Extend => (PythonJoinBy::Extend, None),
                JoinBy::Broadcast => (PythonJoinBy::Broadcast, None),
                JoinBy::CartesianProduct => (PythonJoinBy::CartesianProduct, None),
                JoinBy::JoinById(join_by_id) => (PythonJoinBy::JoinById, Some(join_by_id)),
            };
            Ok(PythonJoin {
                join_type,
                join_by_id,
            })
        }
    }

    /// Accepts a Python `PythonJoin` wherever a [`JoinBy`] is expected.
    ///
    /// Conversion errors (e.g. a `JoinById` tag without payload) are raised
    /// as Python `ValueError`.
    impl FromPyObject<'_> for JoinBy {
        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
            let py_join: PythonJoin = ob.extract()?;
            Self::try_from(py_join).map_err(|e: crate::error::Error| {
                pyo3::exceptions::PyValueError::new_err(e.to_string())
            })
        }
    }

    /// Returns a [`JoinBy`] to Python as a `PythonJoin` instance.
    impl<'py> IntoPyObject<'py> for JoinBy {
        type Error = PyErr;
        type Target = PythonJoin;
        type Output = Bound<'py, Self::Target>;

        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_join: PythonJoin = self.try_into().map_err(|e: crate::error::Error| {
                pyo3::exceptions::PyValueError::new_err(format!("Error converting: {}", e))
            })?;
            py_join.into_pyobject(py)
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;
        use rstest::*;

        #[rstest]
        #[case(JoinBy::AddColumns)]
        #[case(JoinBy::Replace)]
        #[case(JoinBy::Extend)]
        #[case(JoinBy::Broadcast)]
        #[case(JoinBy::CartesianProduct)]
        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
        fn test_join_by(#[case] join_by: JoinBy) {
            let py_join = PythonJoin::try_from(join_by.clone()).unwrap();
            let join_by2 = JoinBy::try_from(py_join).unwrap();
            assert_eq!(join_by, join_by2);
        }

        #[test]
        fn test_join_by_id_requires_payload() {
            let py_join = PythonJoin::init(PythonJoinBy::JoinById, None);
            assert!(JoinBy::try_from(py_join).is_err());
        }

        #[test]
        fn test_payload_ignored_for_other_tags() {
            let py_join = PythonJoin::init(
                PythonJoinBy::Replace,
                Some(JoinById::new(vec!["a".into()])),
            );
            assert_eq!(JoinBy::try_from(py_join).unwrap(), JoinBy::Replace);
        }

        #[rstest]
        #[case(JoinBy::AddColumns)]
        #[case(JoinBy::Replace)]
        #[case(JoinBy::Extend)]
        #[case(JoinBy::Broadcast)]
        #[case(JoinBy::CartesianProduct)]
        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
        fn test_into_py(#[case] join_by: JoinBy) {
            pyo3::Python::with_gil(|py| {
                let py_join = join_by.clone().into_pyobject(py);
                assert!(py_join.is_ok());
                let py_join = py_join.unwrap();
                let from_py = JoinBy::extract_bound(&py_join);
                assert!(from_py.is_ok());
                let join_by2 = from_py.unwrap();
                assert_eq!(join_by, join_by2);
            });
        }
    }
}
/// A relation described by a single [`JoinBy`] strategy.
///
/// `JoinRelation::default()` uses `JoinBy`'s `#[default]` variant,
/// `JoinBy::AddColumns`.
#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "utoipa", derive(ToSchema))]
pub struct JoinRelation {
    /// The join strategy this relation uses.
    pub join_type: JoinBy,
}

impl JoinRelation {
    /// Wraps `join_type` in a `JoinRelation`.
    pub fn new(join_type: JoinBy) -> Self {
        JoinRelation { join_type }
    }

    /// Shorthand for `JoinRelation::new(JoinBy::Broadcast)`.
    pub fn broadcast() -> Self {
        Self::new(JoinBy::Broadcast)
    }

    /// Shorthand for `JoinRelation::new(JoinBy::AddColumns)`.
    pub fn add_columns() -> Self {
        Self::new(JoinBy::AddColumns)
    }

    /// Shorthand for `JoinRelation::new(JoinBy::Replace)`.
    pub fn replace() -> Self {
        Self::new(JoinBy::Replace)
    }

    /// Shorthand for `JoinRelation::new(JoinBy::Extend)`.
    pub fn extend() -> Self {
        Self::new(JoinBy::Extend)
    }

    /// Shorthand for `JoinRelation::new(JoinBy::CartesianProduct)`.
    pub fn cartesian_product() -> Self {
        Self::new(JoinBy::CartesianProduct)
    }

    /// Builds a key-based relation, `JoinBy::JoinById`, over `keys`.
    pub fn join_by_id(keys: Vec<Key>) -> Self {
        let by_id = JoinById::new(keys);
        Self::new(JoinBy::JoinById(by_id))
    }
}

#[cfg(feature = "python")]
#[pymethods]
impl JoinRelation {
    /// Python constructor `JoinRelation(join_type)`; identical to [`JoinRelation::new`].
    #[new]
    pub fn init(join_type: JoinBy) -> Self {
        Self::new(join_type)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// The utoipa schema for `JoinRelation` can be generated and is non-empty.
    #[cfg(feature = "utoipa")]
    #[rstest]
    fn test_join_relation_to_schema() {
        let _name = JoinRelation::name();
        let mut schemas = vec![];
        JoinRelation::schemas(&mut schemas);
        assert!(!schemas.is_empty());
    }

    /// `JoinRelation::new` stores the strategy unchanged and survives a JSON
    /// round trip, including the data-carrying `JoinById` variant.
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_new(#[case] join_type: JoinBy) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        let serde = serde_json::to_string(&join_relation).expect("BUG: Cannot serialize");
        let deserialized: JoinRelation =
            serde_json::from_str(&serde).expect("BUG: cannot deserialize");
        assert_eq!(deserialized, join_relation);
    }

    /// Each named constructor is equivalent to `JoinRelation::new` with the
    /// matching variant, and the result round-trips through JSON.
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::add_columns())]
    #[case(JoinBy::Replace, JoinRelation::replace())]
    #[case(JoinBy::Extend, JoinRelation::extend())]
    #[case(JoinBy::Broadcast, JoinRelation::broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::join_by_id(vec!["a".into()]))]
    fn test_join_relation_constructors(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        let serde = serde_json::to_string(&join_relation).expect("BUG: Cannot serialize");
        let deserialized: JoinRelation =
            serde_json::from_str(&serde).expect("BUG: cannot deserialize");
        assert_eq!(deserialized, join_relation);
    }

    /// The Python `#[new]` constructor agrees with the Rust one.
    ///
    /// `JoinRelation::init` is a plain Rust function, so this test does not
    /// need the GIL or an initialized interpreter.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_py(#[case] join_type: JoinBy) {
        let join_relation = JoinRelation::new(join_type.clone());
        let py_join_relation = JoinRelation::init(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, py_join_relation);
    }
}