1use super::Key;
2#[cfg(feature = "python")]
3use pyo3::prelude::*;
4use serde::{Deserialize, Serialize};
5#[cfg(feature = "utoipa")]
6use utoipa::ToSchema;
7
8#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
9#[cfg_attr(feature = "python", pyo3::pyclass)]
10#[cfg_attr(feature = "utoipa", derive(ToSchema))]
11pub struct JoinById {
13 pub keys: Vec<Key>,
14}
15
16impl JoinById {
17 pub fn new(keys: Vec<Key>) -> Self {
19 Self { keys }
20 }
21}
22
#[cfg(feature = "python")]
#[pymethods]
impl JoinById {
    /// Python constructor: builds a `JoinById` from the given keys.
    #[new]
    pub fn init(keys: Vec<Key>) -> Self {
        // Same construction as the plain-Rust constructor; delegate to keep one code path.
        Self::new(keys)
    }
}
31
32#[derive(Serialize, Deserialize, Debug, Clone, Default, PartialEq)]
33#[cfg_attr(feature = "utoipa", derive(ToSchema))]
34pub enum JoinBy {
36 #[default]
39 AddColumns,
40
41 Replace,
43
44 Extend,
46
47 Broadcast,
50
51 CartesianProduct,
54
55 JoinById(JoinById),
59}
60
/// Python-facing mirror of [`JoinBy`] and the conversions between the two.
///
/// `JoinBy` has a variant with a payload, so it is exposed to Python as a flat
/// [`PythonJoin`](python::PythonJoin): a fieldless discriminant plus an optional payload.
#[cfg(feature = "python")]
pub mod python {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Fieldless discriminant with one variant per [`JoinBy`] variant.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass(eq, eq_int)]
    pub enum PythonJoinBy {
        /// Mirrors [`JoinBy::AddColumns`].
        AddColumns,

        /// Mirrors [`JoinBy::Replace`].
        Replace,

        /// Mirrors [`JoinBy::Extend`].
        Extend,

        /// Mirrors [`JoinBy::Broadcast`].
        Broadcast,

        /// Mirrors [`JoinBy::CartesianProduct`].
        CartesianProduct,

        /// Mirrors [`JoinBy::JoinById`]; the payload lives in [`PythonJoin::join_by_id`].
        JoinById,
    }

    /// Flat representation of a [`JoinBy`]: discriminant plus optional payload.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    #[pyclass]
    pub struct PythonJoin {
        /// Which [`JoinBy`] variant this value stands for.
        pub join_type: PythonJoinBy,
        /// Payload for [`PythonJoinBy::JoinById`]; ignored for every other discriminant.
        pub join_by_id: Option<JoinById>,
    }

    impl TryFrom<PythonJoin> for JoinBy {
        type Error = crate::error::Error;

        /// Rebuilds a [`JoinBy`] from its flat Python representation.
        ///
        /// For every discriminant other than `JoinById` the `join_by_id` field is dropped
        /// without inspection.
        ///
        /// # Errors
        ///
        /// Returns `Error::MissingField("join_by_id")` when the discriminant is
        /// `JoinById` but no payload is present.
        fn try_from(py_join: PythonJoin) -> Result<Self, Self::Error> {
            let PythonJoin {
                join_type,
                join_by_id,
            } = py_join;

            Ok(match join_type {
                PythonJoinBy::AddColumns => JoinBy::AddColumns,
                PythonJoinBy::Replace => JoinBy::Replace,
                PythonJoinBy::Extend => JoinBy::Extend,
                PythonJoinBy::Broadcast => JoinBy::Broadcast,
                PythonJoinBy::CartesianProduct => JoinBy::CartesianProduct,
                PythonJoinBy::JoinById => {
                    let payload = join_by_id
                        .ok_or_else(|| crate::error::Error::MissingField("join_by_id".into()))?;
                    JoinBy::JoinById(payload)
                }
            })
        }
    }

    impl TryFrom<JoinBy> for PythonJoin {
        type Error = crate::error::Error;

        /// Flattens a [`JoinBy`] into discriminant + optional payload.
        ///
        /// Every arm succeeds, so this currently never returns `Err`; the fallible
        /// signature is kept because callers name the error type.
        fn try_from(join_by: JoinBy) -> Result<Self, Self::Error> {
            let (join_type, join_by_id) = match join_by {
                JoinBy::AddColumns => (PythonJoinBy::AddColumns, None),
                JoinBy::Replace => (PythonJoinBy::Replace, None),
                JoinBy::Extend => (PythonJoinBy::Extend, None),
                JoinBy::Broadcast => (PythonJoinBy::Broadcast, None),
                JoinBy::CartesianProduct => (PythonJoinBy::CartesianProduct, None),
                JoinBy::JoinById(payload) => (PythonJoinBy::JoinById, Some(payload)),
            };

            Ok(PythonJoin {
                join_type,
                join_by_id,
            })
        }
    }

    impl FromPyObject<'_> for JoinBy {
        /// Extracts a [`PythonJoin`] from `ob` and converts it into a [`JoinBy`].
        ///
        /// A failed conversion is surfaced to Python as a `ValueError` carrying the
        /// error's display text.
        fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult<Self> {
            let py_join: PythonJoin = ob.extract()?;
            Self::try_from(py_join).map_err(|e: crate::error::Error| {
                pyo3::exceptions::PyValueError::new_err(e.to_string())
            })
        }
    }

    impl<'py> IntoPyObject<'py> for JoinBy {
        type Error = PyErr;
        type Target = PythonJoin;
        type Output = Bound<'py, Self::Target>;

        /// Converts `self` into a Python-owned [`PythonJoin`] object.
        ///
        /// A failed flattening step is reported as a `ValueError` prefixed with
        /// `"Error converting: "`.
        fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
            let py_join = PythonJoin::try_from(self).map_err(|e: crate::error::Error| {
                pyo3::exceptions::PyValueError::new_err(format!("Error converting: {e}"))
            })?;
            py_join.into_pyobject(py)
        }
    }

    #[cfg(test)]
    mod test {
        use super::*;
        use rstest::*;

        /// `JoinBy -> PythonJoin -> JoinBy` must be lossless for every variant.
        #[rstest]
        #[case(JoinBy::AddColumns)]
        #[case(JoinBy::Replace)]
        #[case(JoinBy::Extend)]
        #[case(JoinBy::Broadcast)]
        #[case(JoinBy::CartesianProduct)]
        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
        fn test_join_by(#[case] join_by: JoinBy) {
            let py_join = PythonJoin::try_from(join_by.clone()).unwrap();
            let join_by2 = JoinBy::try_from(py_join).unwrap();
            assert_eq!(join_by, join_by2);
        }

        /// Same round trip, but through real Python objects.
        #[rstest]
        #[case(JoinBy::AddColumns)]
        #[case(JoinBy::Replace)]
        #[case(JoinBy::Extend)]
        #[case(JoinBy::Broadcast)]
        #[case(JoinBy::CartesianProduct)]
        #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
        fn test_into_py(#[case] join_by: JoinBy) {
            pyo3::Python::attach(|py| {
                // `expect` reports the underlying PyErr on failure, unlike `assert!(is_ok())`.
                let py_join = join_by
                    .clone()
                    .into_pyobject(py)
                    .expect("JoinBy should convert into a Python object");
                let join_by2 = JoinBy::extract_bound(&py_join)
                    .expect("JoinBy should be extractable from a PythonJoin object");
                assert_eq!(join_by, join_by2);
            });
        }
    }
}
216
217#[derive(Default, Serialize, Deserialize, Debug, Clone, PartialEq)]
218#[cfg_attr(feature = "python", pyclass)]
219#[cfg_attr(feature = "utoipa", derive(ToSchema))]
220pub struct JoinRelation {
221 pub join_type: JoinBy,
222}
223
// Python bindings for `JoinRelation`. Each static method is exported under the same name
// as the plain-Rust constructor it forwards to; the Rust-side `py_` prefix only avoids a
// name clash between the two inherent impls.
#[cfg(feature = "python")]
#[pymethods]
impl JoinRelation {
    /// Python constructor: wraps the given join strategy.
    #[new]
    pub fn init(join_type: JoinBy) -> Self {
        Self::new(join_type)
    }

    /// Returns a relation using `JoinBy::Broadcast`.
    #[pyo3(name = "broadcast")]
    #[staticmethod]
    pub fn py_broadcast() -> Self {
        Self::broadcast()
    }

    /// Returns a relation using `JoinBy::AddColumns`.
    #[pyo3(name = "add_columns")]
    #[staticmethod]
    pub fn py_add_columns() -> Self {
        Self::add_columns()
    }

    /// Returns a relation using `JoinBy::Replace`.
    #[pyo3(name = "replace")]
    #[staticmethod]
    pub fn py_replace() -> Self {
        Self::replace()
    }

    /// Returns a relation using `JoinBy::Extend`.
    #[pyo3(name = "extend")]
    #[staticmethod]
    pub fn py_extend() -> Self {
        Self::extend()
    }

    /// Returns a relation using `JoinBy::CartesianProduct`.
    #[pyo3(name = "cartesian_product")]
    #[staticmethod]
    pub fn py_cartesian_product() -> Self {
        Self::cartesian_product()
    }

    /// Returns a relation using `JoinBy::JoinById` built from `keys`.
    #[pyo3(name = "join_by_id")]
    #[staticmethod]
    pub fn py_join_by_id(keys: Vec<Key>) -> Self {
        Self::join_by_id(keys)
    }
}
290
291impl JoinRelation {
292 pub fn new(join_type: JoinBy) -> Self {
293 Self { join_type }
294 }
295
296 pub fn broadcast() -> Self {
297 Self {
298 join_type: JoinBy::Broadcast,
299 }
300 }
301
302 pub fn add_columns() -> Self {
303 Self {
304 join_type: JoinBy::AddColumns,
305 }
306 }
307
308 pub fn replace() -> Self {
309 Self {
310 join_type: JoinBy::Replace,
311 }
312 }
313
314 pub fn extend() -> Self {
315 Self {
316 join_type: JoinBy::Extend,
317 }
318 }
319
320 pub fn cartesian_product() -> Self {
321 Self {
322 join_type: JoinBy::CartesianProduct,
323 }
324 }
325
326 pub fn join_by_id(keys: Vec<Key>) -> Self {
327 Self {
328 join_type: JoinBy::JoinById(JoinById::new(keys)),
329 }
330 }
331}
332
#[cfg(test)]
mod test {
    use super::*;
    use rstest::*;

    /// Serializes `join_relation` to JSON and checks that deserializing yields an equal value.
    fn assert_json_round_trip(join_relation: &JoinRelation) {
        let json = serde_json::to_string(join_relation).expect("BUG: Cannot serialize");
        let deserialized: JoinRelation =
            serde_json::from_str(&json).expect("BUG: cannot deserialize");
        assert_eq!(&deserialized, join_relation);
    }

    /// Collecting the utoipa schemas for `JoinRelation` must yield at least one entry.
    #[cfg(feature = "utoipa")]
    #[rstest]
    fn test_join_relation_to_schema() {
        let _name = JoinRelation::name();
        let mut schemas = vec![];

        JoinRelation::schemas(&mut schemas);

        assert!(!schemas.is_empty());
    }

    /// `JoinRelation::new` stores the strategy unchanged and survives a JSON round trip.
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_new(#[case] join_type: JoinBy) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_json_round_trip(&join_relation);
    }

    /// Each named constructor is equivalent to `JoinRelation::new` with the matching variant.
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::add_columns())]
    #[case(JoinBy::Replace, JoinRelation::replace())]
    #[case(JoinBy::Extend, JoinRelation::extend())]
    #[case(JoinBy::Broadcast, JoinRelation::broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::join_by_id(vec!["a".into()]))]
    fn test_join_relation(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_json_round_trip(&join_relation);
    }

    /// The Python constructor (`init`) builds the same value as `JoinRelation::new`.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns)]
    #[case(JoinBy::Replace)]
    #[case(JoinBy::Extend)]
    #[case(JoinBy::Broadcast)]
    #[case(JoinBy::CartesianProduct)]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])))]
    fn test_join_relation_py(#[case] join_type: JoinBy) {
        pyo3::Python::attach(|_py| {
            let join_relation = JoinRelation::new(join_type.clone());
            let py_join_relation = JoinRelation::init(join_type.clone());
            assert_eq!(join_relation.join_type, join_type);
            assert_eq!(join_relation, py_join_relation);
        });
    }

    /// Each Python static constructor is equivalent to `JoinRelation::new` with the matching
    /// variant.
    #[cfg(feature = "python")]
    #[rstest]
    #[case(JoinBy::AddColumns, JoinRelation::py_add_columns())]
    #[case(JoinBy::Replace, JoinRelation::py_replace())]
    #[case(JoinBy::Extend, JoinRelation::py_extend())]
    #[case(JoinBy::Broadcast, JoinRelation::py_broadcast())]
    #[case(JoinBy::CartesianProduct, JoinRelation::py_cartesian_product())]
    #[case(JoinBy::JoinById(JoinById::new(vec!["a".into()])), JoinRelation::py_join_by_id(vec!["a".into()]))]
    fn test_py_join_relation(#[case] join_type: JoinBy, #[case] jt: JoinRelation) {
        let join_relation = JoinRelation::new(join_type.clone());
        assert_eq!(join_relation.join_type, join_type);
        assert_eq!(join_relation, jt);
        assert_json_round_trip(&join_relation);
    }
}