trs_dataframe/dataframe/
join.rs

1use super::Key;
2#[cfg(feature = "python")]
3use pyo3::prelude::*;
4use serde::{Deserialize, Serialize};
5#[cfg(feature = "utoipa")]
6use utoipa::ToSchema;
7
8#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
9#[cfg_attr(feature = "python", pyo3::pyclass)]
10#[cfg_attr(feature = "utoipa", derive(ToSchema))]
11pub struct JoinById {
12    pub keys: Vec<Key>,
13}
14
15impl JoinById {
16    pub fn new(keys: Vec<Key>) -> Self {
17        Self { keys }
18    }
19}
20
#[cfg(feature = "python")]
#[pymethods]
impl JoinById {
    /// Constructor exposed to Python; builds the same value as `JoinById::new`.
    #[new]
    pub fn init(keys: Vec<Key>) -> Self {
        Self::new(keys)
    }
}
29
30#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
31#[cfg_attr(feature = "utoipa", derive(ToSchema))]
32/// Enum representing different strategies for combining or joining data structures.
33pub enum JoinBy {
34    /// Adds only non-existing columns to the existing structure.
35    /// This is the default behavior.
36    #[default]
37    AddColumns,
38
39    /// Replaces existing data with the new data.
40    Replace,
41
42    /// Extends the existing data by appending new elements.
43    Extend,
44
45    /// Performs a broadcast operation, replicating smaller data structures
46    /// to match the size of larger ones.
47    Broadcast,
48
49    /// Computes the Cartesian product of the input structures,
50    /// resulting in all possible combinations of elements.
51    CartesianProduct,
52
53    /// Joins two structures using a specific identifier or key.
54    ///
55    /// The behavior is determined by the provided `JoinById` variant.
56    JoinById(JoinById),
57}
58
59#[cfg(feature = "python")]
60pub mod python {
61    use super::*;
62    use serde::{Deserialize, Serialize};
63
64    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
65    #[pyclass(eq, eq_int)]
66    pub enum PythonJoinBy {
67        /// Adds only non-existing columns to the existing structure.
68        /// This is the default behavior.
69        AddColumns,
70
71        /// Replaces existing data with the new data.
72        Replace,
73
74        /// Extends the existing data by appending new elements.
75        Extend,
76
77        /// Performs a broadcast operation, replicating smaller data structures
78        /// to match the size of larger ones.
79        Broadcast,
80
81        /// Computes the Cartesian product of the input structures,
82        /// resulting in all possible combinations of elements.
83        CartesianProduct,
84
85        /// Joins two structures using a specific identifier or key.
86        ///
87        /// The behavior is determined by the provided `JoinById` variant.
88        JoinById,
89    }
90
91    /// Python representation of the `JoinBy` enum,
92    /// which includes the join type and an optional `JoinById`.
93    /// This struct is used to facilitate conversions between Rust and Python representations.
94    /// It allows for the serialization and deserialization of join operations in a Python-friendly format.
95    /// This struct is particularly useful when integrating with Python code,
96    ///
97    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
98    #[pyclass]
99    pub struct PythonJoin {
100        pub join_type: PythonJoinBy,
101        pub join_by_id: Option<JoinById>,
102    }
103
104    impl TryFrom<PythonJoin> for JoinBy {
105        type Error = crate::error::Error;
106        fn try_from(py_join: PythonJoin) -> Result<Self, Self::Error> {
107            Ok(match py_join.join_type {
108                PythonJoinBy::AddColumns => JoinBy::AddColumns,
109                PythonJoinBy::Replace => JoinBy::Replace,
110                PythonJoinBy::Extend => JoinBy::Extend,
111                PythonJoinBy::Broadcast => JoinBy::Broadcast,
112                PythonJoinBy::CartesianProduct => JoinBy::CartesianProduct,
113                PythonJoinBy::JoinById => {
114                    let join_by_id = py_join
115                        .join_by_id
116                        .ok_or_else(|| crate::error::Error::MissingField("join_by_id".into()))?;
117                    JoinBy::JoinById(join_by_id)
118                }
119            })
120        }
121    }
122
123    impl TryFrom<JoinBy> for PythonJoin {
124        type Error = crate::error::Error;
125        fn try_from(py_join: JoinBy) -> Result<Self, Self::Error> {
126            Ok(match py_join {
127                JoinBy::AddColumns => PythonJoin {
128                    join_type: PythonJoinBy::AddColumns,
129                    join_by_id: None,
130                },
131                JoinBy::Replace => PythonJoin {
132                    join_type: PythonJoinBy::Replace,
133                    join_by_id: None,
134                },
135                JoinBy::Extend => PythonJoin {
136                    join_type: PythonJoinBy::Extend,
137                    join_by_id: None,
138                },
139                JoinBy::Broadcast => PythonJoin {
140                    join_type: PythonJoinBy::Broadcast,
141                    join_by_id: None,
142                },
143                JoinBy::CartesianProduct => PythonJoin {
144                    join_type: PythonJoinBy::CartesianProduct,
145                    join_by_id: None,
146                },
147                JoinBy::JoinById(join_by_id) => PythonJoin {
148                    join_type: PythonJoinBy::JoinById,
149                    join_by_id: Some(join_by_id),
150                },
151            })
152        }
153    }
154
155    impl FromPyObject<'_> for JoinBy {
156        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
157            let py_join: PythonJoin = ob.extract()?;
158            Self::try_from(py_join).map_err(|e: crate::error::Error| {
159                pyo3::exceptions::PyValueError::new_err(format!("{e}"))
160            })
161        }
162    }
163
164    impl<'py> IntoPyObject<'py> for JoinBy {
165        type Error = PyErr;
166        type Target = PythonJoin;
167        type Output = Bound<'py, Self::Target>;
168        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
169            let py_join: PythonJoin = self.try_into().map_err(|e: crate::error::Error| {
170                pyo3::exceptions::PyValueError::new_err(format!("Error converting: {e}"))
171            })?;
172            py_join.into_pyobject(py)
173        }
174    }
175
176    #[cfg(test)]
177    mod test {
178        use super::*;
179        use rstest::*;
180
181        #[rstest]
182        #[case(JoinBy::AddColumns)]
183        #[case(JoinBy::Replace)]
184        #[case(JoinBy::Extend)]
185        #[case(JoinBy::Broadcast)]
186        #[case(JoinBy::CartesianProduct)]
187        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
188        fn test_join_by(#[case] join_by: JoinBy) {
189            let py_join = PythonJoin::try_from(join_by.clone()).unwrap();
190            let join_by2 = JoinBy::try_from(py_join).unwrap();
191            assert_eq!(join_by, join_by2);
192        }
193
194        #[rstest]
195        #[case(JoinBy::AddColumns)]
196        #[case(JoinBy::Replace)]
197        #[case(JoinBy::Extend)]
198        #[case(JoinBy::Broadcast)]
199        #[case(JoinBy::CartesianProduct)]
200        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
201        fn test_into_py(#[case] join_by: JoinBy) {
202            pyo3::Python::with_gil(|py| {
203                let py_join = join_by.clone().into_pyobject(py);
204                assert!(py_join.is_ok());
205                let py_join = py_join.unwrap();
206                let from_py = JoinBy::extract_bound(&py_join);
207                assert!(from_py.is_ok());
208                let join_by2 = from_py.unwrap();
209                assert_eq!(join_by, join_by2);
210            });
211        }
212    }
213}
214
215#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
216#[cfg_attr(feature = "python", pyclass)]
217#[cfg_attr(feature = "utoipa", derive(ToSchema))]
218pub struct JoinRelation {
219    pub join_type: JoinBy,
220}
221
#[cfg(feature = "python")]
#[pymethods]
impl JoinRelation {
    /// Constructor exposed to Python; wraps the given join type.
    #[new]
    pub fn init(join_type: JoinBy) -> Self {
        Self { join_type }
    }

    /// Performs a broadcast operation, replicating smaller data structures
    /// to match the size of larger ones.
    #[staticmethod]
    #[pyo3(name = "broadcast")]
    pub fn py_broadcast() -> Self {
        Self::broadcast()
    }

    /// Adds only non-existing columns to the existing structure.
    /// This is the default behavior.
    #[staticmethod]
    #[pyo3(name = "add_columns")]
    pub fn py_add_columns() -> Self {
        Self::add_columns()
    }

    /// Replaces existing data with the new data.
    #[staticmethod]
    #[pyo3(name = "replace")]
    pub fn py_replace() -> Self {
        Self::replace()
    }

    /// Extends the existing data by appending new elements.
    #[staticmethod]
    #[pyo3(name = "extend")]
    pub fn py_extend() -> Self {
        Self::extend()
    }

    /// Computes the Cartesian product of the input structures,
    /// resulting in all possible combinations of elements.
    #[staticmethod]
    #[pyo3(name = "cartesian_product")]
    pub fn py_cartesian_product() -> Self {
        Self::cartesian_product()
    }

    /// Joins two structures using a specific identifier or key.
    ///
    /// The behavior is determined by the provided key variant.
    #[staticmethod]
    #[pyo3(name = "join_by_id")]
    pub fn py_join_by_id(keys: Vec<Key>) -> Self {
        Self::join_by_id(keys)
    }
}
288
289impl JoinRelation {
290    pub fn new(join_type: JoinBy) -> Self {
291        Self { join_type }
292    }
293
294    pub fn broadcast() -> Self {
295        Self {
296            join_type: JoinBy::Broadcast,
297        }
298    }
299
300    pub fn add_columns() -> Self {
301        Self {
302            join_type: JoinBy::AddColumns,
303        }
304    }
305
306    pub fn replace() -> Self {
307        Self {
308            join_type: JoinBy::Replace,
309        }
310    }
311
312    pub fn extend() -> Self {
313        Self {
314            join_type: JoinBy::Extend,
315        }
316    }
317
318    pub fn cartesian_product() -> Self {
319        Self {
320            join_type: JoinBy::CartesianProduct,
321        }
322    }
323
324    pub fn join_by_id(keys: Vec<Key>) -> Self {
325        Self {
326            join_type: JoinBy::JoinById(JoinById::new(keys)),
327        }
328    }
329}
330
#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// Serializes `relation` to JSON and back and checks that nothing was lost.
    fn assert_json_round_trip(relation: &JoinRelation) {
        let json = serde_json::to_string(relation).expect("BUG: Cannot serialize");
        let decoded: JoinRelation = serde_json::from_str(&json).expect("BUG: cannot deserialize");
        assert_eq!(&decoded, relation);
    }

    /// The `ToSchema` derive must register at least one schema.
    #[cfg(feature = "utoipa")]
    #[rstest]
    fn test_join_relation_to_schema() {
        let _name = JoinRelation::name();
        let mut schemas = vec![];

        JoinRelation::schemas(&mut schemas);

        assert!(!schemas.is_empty());
    }

    /// `JoinRelation::new` stores the join type and survives a serde round trip.
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_new(#[case] join_type: JoinBy) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_json_round_trip(&join_relation);
    }

    /// Each named constructor is equivalent to `new` with the matching variant.
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::add_columns())]
    #[case(JoinBy::Replace, JoinRelation::replace())]
    #[case(JoinBy::Extend, JoinRelation::extend())]
    #[case(JoinBy::Broadcast, JoinRelation::broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::join_by_id(vec!["a".into()]))]
    fn test_join_relation(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_json_round_trip(&join_relation);
    }

    /// The Python constructor `init` builds the same value as `new`.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_py(#[case] join_type: JoinBy) {
        pyo3::Python::with_gil(|_py| {
            let join_relation = JoinRelation::new(join_type.clone());
            let py_join_relation = JoinRelation::init(join_type.clone());
            assert_eq!(join_relation.join_type, join_type);
            assert_eq!(join_relation, py_join_relation);
        });
    }

    /// Each Python static constructor is equivalent to `new` with the matching variant.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::py_add_columns())]
    #[case(JoinBy::Replace, JoinRelation::py_replace())]
    #[case(JoinBy::Extend, JoinRelation::py_extend())]
    #[case(JoinBy::Broadcast, JoinRelation::py_broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::py_cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::py_join_by_id(vec!["a".into()]))]
    fn test_py_join_relation(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_json_round_trip(&join_relation);
    }
}