1use super::Key;
2#[cfg(feature = "python")]
3use pyo3::prelude::*;
4use serde::{Deserialize, Serialize};
5#[cfg(feature = "utoipa")]
6use utoipa::ToSchema;
7
/// Payload of `JoinBy::JoinById`: an ordered list of keys.
// NOTE(review): presumably the keys name what the two sides are matched on — confirm against
// the definition of `Key` and the code that executes the join; neither is visible in this file.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[cfg_attr(feature = "python", pyo3::pyclass)]
#[cfg_attr(feature = "utoipa", derive(ToSchema))]
pub struct JoinById {
    /// The keys, kept exactly as supplied by the caller.
    pub keys: Vec<Key>,
}
14
15impl JoinById {
16 pub fn new(keys: Vec<Key>) -> Self {
17 Self { keys }
18 }
19}
20
#[cfg(feature = "python")]
#[pymethods]
impl JoinById {
    // Python-side constructor (`#[new]`). It takes the same argument as the Rust constructor
    // `JoinById::new` and simply forwards to it, so both entry points build identical values.
    #[new]
    pub fn init(keys: Vec<Key>) -> Self {
        Self::new(keys)
    }
}
29
/// The kind of join, as stored in `JoinRelation::join_type`.
// NOTE(review): what each kind actually does is implemented by the consumers of this enum,
// which are not part of this file. The variant notes below only record what is visible here;
// the hedged ones are guesses from the names and must be confirmed against the join code.
#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(ToSchema))]
pub enum JoinBy {
    /// The default join kind (this is the `#[default]` variant).
    #[default]
    AddColumns,

    // NOTE(review): presumably replaces existing data with the joined data — confirm.
    Replace,

    // NOTE(review): presumably appends the joined data to the existing data — confirm.
    Extend,

    // NOTE(review): presumably repeats one side across the other — confirm.
    Broadcast,

    // NOTE(review): presumably pairs every element of one side with every element of the
    // other — confirm.
    CartesianProduct,

    /// Join kind that carries a `JoinById` payload (a list of keys).
    JoinById(JoinById),
}
58
59#[cfg(feature = "python")]
60pub mod python {
61 use super::*;
62 use serde::{Deserialize, Serialize};
63
    /// Fieldless counterpart of `JoinBy`, exposed to Python as a class with `eq`/`eq_int`.
    ///
    /// Each variant corresponds to the `JoinBy` variant of the same name.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass(eq, eq_int)]
    pub enum PythonJoinBy {
        AddColumns,

        Replace,

        Extend,

        Broadcast,

        CartesianProduct,

        // Unlike `JoinBy::JoinById`, this variant carries no payload: the keys travel
        // separately in `PythonJoin::join_by_id`.
        JoinById,
    }
90
    /// Python-facing representation of a `JoinBy` value: a variant tag plus an optional payload.
    ///
    /// Converted to and from `JoinBy` by the `TryFrom` implementations in this module.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass]
    pub struct PythonJoin {
        /// Which `JoinBy` variant this value stands for.
        pub join_type: PythonJoinBy,
        /// Keys payload for `PythonJoinBy::JoinById`.
        ///
        /// Must be `Some` when `join_type` is `JoinById`; converting to `JoinBy` otherwise fails
        /// with a `MissingField("join_by_id")` error. For every other `join_type` the value is
        /// ignored by that conversion.
        pub join_by_id: Option<JoinById>,
    }
103
104 impl TryFrom<PythonJoin> for JoinBy {
105 type Error = crate::error::Error;
106 fn try_from(py_join: PythonJoin) -> Result<Self, Self::Error> {
107 Ok(match py_join.join_type {
108 PythonJoinBy::AddColumns => JoinBy::AddColumns,
109 PythonJoinBy::Replace => JoinBy::Replace,
110 PythonJoinBy::Extend => JoinBy::Extend,
111 PythonJoinBy::Broadcast => JoinBy::Broadcast,
112 PythonJoinBy::CartesianProduct => JoinBy::CartesianProduct,
113 PythonJoinBy::JoinById => {
114 let join_by_id = py_join
115 .join_by_id
116 .ok_or_else(|| crate::error::Error::MissingField("join_by_id".into()))?;
117 JoinBy::JoinById(join_by_id)
118 }
119 })
120 }
121 }
122
123 impl TryFrom<JoinBy> for PythonJoin {
124 type Error = crate::error::Error;
125 fn try_from(py_join: JoinBy) -> Result<Self, Self::Error> {
126 Ok(match py_join {
127 JoinBy::AddColumns => PythonJoin {
128 join_type: PythonJoinBy::AddColumns,
129 join_by_id: None,
130 },
131 JoinBy::Replace => PythonJoin {
132 join_type: PythonJoinBy::Replace,
133 join_by_id: None,
134 },
135 JoinBy::Extend => PythonJoin {
136 join_type: PythonJoinBy::Extend,
137 join_by_id: None,
138 },
139 JoinBy::Broadcast => PythonJoin {
140 join_type: PythonJoinBy::Broadcast,
141 join_by_id: None,
142 },
143 JoinBy::CartesianProduct => PythonJoin {
144 join_type: PythonJoinBy::CartesianProduct,
145 join_by_id: None,
146 },
147 JoinBy::JoinById(join_by_id) => PythonJoin {
148 join_type: PythonJoinBy::JoinById,
149 join_by_id: Some(join_by_id),
150 },
151 })
152 }
153 }
154
155 impl FromPyObject<'_> for JoinBy {
156 fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
157 let py_join: PythonJoin = ob.extract()?;
158 Self::try_from(py_join).map_err(|e: crate::error::Error| {
159 pyo3::exceptions::PyValueError::new_err(format!("{e}"))
160 })
161 }
162 }
163
164 impl<'py> IntoPyObject<'py> for JoinBy {
165 type Error = PyErr;
166 type Target = PythonJoin;
167 type Output = Bound<'py, Self::Target>;
168 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
169 let py_join: PythonJoin = self.try_into().map_err(|e: crate::error::Error| {
170 pyo3::exceptions::PyValueError::new_err(format!("Error converting: {e}"))
171 })?;
172 py_join.into_pyobject(py)
173 }
174 }
175
176 #[cfg(test)]
177 mod test {
178 use super::*;
179 use rstest::*;
180
181 #[rstest]
182 #[case(JoinBy::AddColumns)]
183 #[case(JoinBy::Replace)]
184 #[case(JoinBy::Extend)]
185 #[case(JoinBy::Broadcast)]
186 #[case(JoinBy::CartesianProduct)]
187 #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
188 fn test_join_by(#[case] join_by: JoinBy) {
189 let py_join = PythonJoin::try_from(join_by.clone()).unwrap();
190 let join_by2 = JoinBy::try_from(py_join).unwrap();
191 assert_eq!(join_by, join_by2);
192 }
193
194 #[rstest]
195 #[case(JoinBy::AddColumns)]
196 #[case(JoinBy::Replace)]
197 #[case(JoinBy::Extend)]
198 #[case(JoinBy::Broadcast)]
199 #[case(JoinBy::CartesianProduct)]
200 #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
201 fn test_into_py(#[case] join_by: JoinBy) {
202 pyo3::Python::with_gil(|py| {
203 let py_join = join_by.clone().into_pyobject(py);
204 assert!(py_join.is_ok());
205 let py_join = py_join.unwrap();
206 let from_py = JoinBy::extract_bound(&py_join);
207 assert!(from_py.is_ok());
208 let join_by2 = from_py.unwrap();
209 assert_eq!(join_by, join_by2);
210 });
211 }
212 }
213}
214
/// Wrapper around a single `JoinBy` value.
///
/// `Default` yields `join_type == JoinBy::AddColumns`, the enum's `#[default]` variant.
#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "utoipa", derive(ToSchema))]
pub struct JoinRelation {
    /// The kind of join this relation describes.
    pub join_type: JoinBy,
}
221
// Python API of `JoinRelation`. Every method forwards to the Rust constructor of the same
// meaning (as `init` already did), so the Python and Rust APIs cannot drift apart.
#[cfg(feature = "python")]
#[pymethods]
impl JoinRelation {
    /// Create a `JoinRelation` from an explicit join type.
    #[new]
    pub fn init(join_type: JoinBy) -> Self {
        Self::new(join_type)
    }

    /// Create a `JoinRelation` whose `join_type` is `JoinBy::Broadcast`.
    #[pyo3(name = "broadcast")]
    #[staticmethod]
    pub fn py_broadcast() -> Self {
        Self::broadcast()
    }

    /// Create a `JoinRelation` whose `join_type` is `JoinBy::AddColumns`.
    #[pyo3(name = "add_columns")]
    #[staticmethod]
    pub fn py_add_columns() -> Self {
        Self::add_columns()
    }

    /// Create a `JoinRelation` whose `join_type` is `JoinBy::Replace`.
    #[pyo3(name = "replace")]
    #[staticmethod]
    pub fn py_replace() -> Self {
        Self::replace()
    }

    /// Create a `JoinRelation` whose `join_type` is `JoinBy::Extend`.
    #[pyo3(name = "extend")]
    #[staticmethod]
    pub fn py_extend() -> Self {
        Self::extend()
    }

    /// Create a `JoinRelation` whose `join_type` is `JoinBy::CartesianProduct`.
    #[pyo3(name = "cartesian_product")]
    #[staticmethod]
    pub fn py_cartesian_product() -> Self {
        Self::cartesian_product()
    }

    /// Create a `JoinRelation` whose `join_type` is `JoinBy::JoinById` holding `keys`.
    #[pyo3(name = "join_by_id")]
    #[staticmethod]
    pub fn py_join_by_id(keys: Vec<Key>) -> Self {
        Self::join_by_id(keys)
    }
}
288
289impl JoinRelation {
290 pub fn new(join_type: JoinBy) -> Self {
291 Self { join_type }
292 }
293
294 pub fn broadcast() -> Self {
295 Self {
296 join_type: JoinBy::Broadcast,
297 }
298 }
299
300 pub fn add_columns() -> Self {
301 Self {
302 join_type: JoinBy::AddColumns,
303 }
304 }
305
306 pub fn replace() -> Self {
307 Self {
308 join_type: JoinBy::Replace,
309 }
310 }
311
312 pub fn extend() -> Self {
313 Self {
314 join_type: JoinBy::Extend,
315 }
316 }
317
318 pub fn cartesian_product() -> Self {
319 Self {
320 join_type: JoinBy::CartesianProduct,
321 }
322 }
323
324 pub fn join_by_id(keys: Vec<Key>) -> Self {
325 Self {
326 join_type: JoinBy::JoinById(JoinById::new(keys)),
327 }
328 }
329}
330
#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// Asserts that `relation` is unchanged by a JSON serialize/deserialize round trip.
    fn assert_json_round_trip(relation: &JoinRelation) {
        let json = serde_json::to_string(relation).expect("BUG: Cannot serialize");
        let restored: JoinRelation = serde_json::from_str(&json).expect("BUG: cannot deserialize");
        assert_eq!(&restored, relation);
    }

    /// The derived OpenAPI schema must register at least one schema entry.
    #[cfg(feature = "utoipa")]
    #[rstest]
    fn test_join_relation_to_schema() {
        let _name = JoinRelation::name();
        let mut schemas = vec![];

        JoinRelation::schemas(&mut schemas);

        assert!(!schemas.is_empty());
    }

    /// `JoinRelation::new` stores the join type as given and the result survives serde.
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_new(#[case] join_type: JoinBy) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_json_round_trip(&join_relation);
    }

    /// Each named Rust constructor agrees with `JoinRelation::new` for the matching variant.
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::add_columns())]
    #[case(JoinBy::Replace, JoinRelation::replace())]
    #[case(JoinBy::Extend, JoinRelation::extend())]
    #[case(JoinBy::Broadcast, JoinRelation::broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::join_by_id(vec!["a".into()]))]
    fn test_join_relation(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_json_round_trip(&join_relation);
    }

    /// The Python constructor `init` builds the same value as `JoinRelation::new`.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_py(#[case] join_type: JoinBy) {
        pyo3::Python::with_gil(|_py| {
            let join_relation = JoinRelation::new(join_type.clone());
            let py_join_relation = JoinRelation::init(join_type.clone());
            assert_eq!(join_relation.join_type, join_type);
            assert_eq!(join_relation, py_join_relation);
        });
    }

    /// Each Python static constructor agrees with `JoinRelation::new` for the matching variant.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::py_add_columns())]
    #[case(JoinBy::Replace, JoinRelation::py_replace())]
    #[case(JoinBy::Extend, JoinRelation::py_extend())]
    #[case(JoinBy::Broadcast, JoinRelation::py_broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::py_cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::py_join_by_id(vec!["a".into()]))]
    fn test_py_join_relation(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_json_round_trip(&join_relation);
    }
}