solverforge_core/constraints/
joiners.rs

1use serde::{Deserialize, Deserializer, Serialize, Serializer};
2
/// A reference to a named function, plus optional names of companion
/// relation, hash, and comparator functions.
///
/// The `Serialize`/`Deserialize` impls in this module write and read only
/// `name` (as a bare JSON string); the companion names live in memory only.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WasmFunction {
    /// The function's name; the only field that is serialized.
    name: String,
    /// Optional name of a custom relation (equality) function.
    relation_function: Option<String>,
    /// Optional name of a custom hash function.
    hash_function: Option<String>,
    /// Optional name of a custom comparator function.
    comparator_function: Option<String>,
}
10
11impl Serialize for WasmFunction {
12    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
13    where
14        S: Serializer,
15    {
16        serializer.serialize_str(&self.name)
17    }
18}
19
20impl<'de> Deserialize<'de> for WasmFunction {
21    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
22    where
23        D: Deserializer<'de>,
24    {
25        let name = String::deserialize(deserializer)?;
26        Ok(WasmFunction::new(name))
27    }
28}
29
30impl WasmFunction {
31    pub fn new(name: impl Into<String>) -> Self {
32        Self {
33            name: name.into(),
34            relation_function: None,
35            hash_function: None,
36            comparator_function: None,
37        }
38    }
39
40    pub fn with_relation(mut self, relation: impl Into<String>) -> Self {
41        self.relation_function = Some(relation.into());
42        self
43    }
44
45    pub fn with_hash(mut self, hash: impl Into<String>) -> Self {
46        self.hash_function = Some(hash.into());
47        self
48    }
49
50    pub fn with_comparator(mut self, comparator: impl Into<String>) -> Self {
51        self.comparator_function = Some(comparator.into());
52        self
53    }
54
55    pub fn name(&self) -> &str {
56        &self.name
57    }
58
59    pub fn relation_function(&self) -> Option<&str> {
60        self.relation_function.as_deref()
61    }
62
63    pub fn hash_function(&self) -> Option<&str> {
64        self.hash_function.as_deref()
65    }
66
67    pub fn comparator_function(&self) -> Option<&str> {
68        self.comparator_function.as_deref()
69    }
70}
71
72#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
73#[serde(tag = "relation")]
74pub enum Joiner {
75    #[serde(rename = "equal")]
76    Equal {
77        #[serde(skip_serializing_if = "Option::is_none")]
78        map: Option<WasmFunction>,
79        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
80        left_map: Option<WasmFunction>,
81        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
82        right_map: Option<WasmFunction>,
83        #[serde(rename = "equal", skip_serializing_if = "Option::is_none")]
84        relation_predicate: Option<WasmFunction>,
85        #[serde(skip_serializing_if = "Option::is_none")]
86        hasher: Option<WasmFunction>,
87    },
88    #[serde(rename = "lessThan")]
89    LessThan {
90        #[serde(skip_serializing_if = "Option::is_none")]
91        map: Option<WasmFunction>,
92        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
93        left_map: Option<WasmFunction>,
94        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
95        right_map: Option<WasmFunction>,
96        comparator: WasmFunction,
97    },
98    #[serde(rename = "lessThanOrEqual")]
99    LessThanOrEqual {
100        #[serde(skip_serializing_if = "Option::is_none")]
101        map: Option<WasmFunction>,
102        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
103        left_map: Option<WasmFunction>,
104        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
105        right_map: Option<WasmFunction>,
106        comparator: WasmFunction,
107    },
108    #[serde(rename = "greaterThan")]
109    GreaterThan {
110        #[serde(skip_serializing_if = "Option::is_none")]
111        map: Option<WasmFunction>,
112        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
113        left_map: Option<WasmFunction>,
114        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
115        right_map: Option<WasmFunction>,
116        comparator: WasmFunction,
117    },
118    #[serde(rename = "greaterThanOrEqual")]
119    GreaterThanOrEqual {
120        #[serde(skip_serializing_if = "Option::is_none")]
121        map: Option<WasmFunction>,
122        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
123        left_map: Option<WasmFunction>,
124        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
125        right_map: Option<WasmFunction>,
126        comparator: WasmFunction,
127    },
128    #[serde(rename = "overlapping")]
129    Overlapping {
130        #[serde(rename = "startMap", skip_serializing_if = "Option::is_none")]
131        start_map: Option<WasmFunction>,
132        #[serde(rename = "endMap", skip_serializing_if = "Option::is_none")]
133        end_map: Option<WasmFunction>,
134        #[serde(rename = "leftStartMap", skip_serializing_if = "Option::is_none")]
135        left_start_map: Option<WasmFunction>,
136        #[serde(rename = "leftEndMap", skip_serializing_if = "Option::is_none")]
137        left_end_map: Option<WasmFunction>,
138        #[serde(rename = "rightStartMap", skip_serializing_if = "Option::is_none")]
139        right_start_map: Option<WasmFunction>,
140        #[serde(rename = "rightEndMap", skip_serializing_if = "Option::is_none")]
141        right_end_map: Option<WasmFunction>,
142        #[serde(skip_serializing_if = "Option::is_none")]
143        comparator: Option<WasmFunction>,
144    },
145    #[serde(rename = "filtering")]
146    Filtering { filter: WasmFunction },
147}
148
149impl Joiner {
150    pub fn equal(map: WasmFunction) -> Self {
151        Joiner::Equal {
152            map: Some(map),
153            left_map: None,
154            right_map: None,
155            relation_predicate: None,
156            hasher: None,
157        }
158    }
159
160    pub fn equal_with_mappings(left_map: WasmFunction, right_map: WasmFunction) -> Self {
161        Joiner::Equal {
162            map: None,
163            left_map: Some(left_map),
164            right_map: Some(right_map),
165            relation_predicate: None,
166            hasher: None,
167        }
168    }
169
170    pub fn equal_with_custom_equals(
171        map: WasmFunction,
172        relation_predicate: WasmFunction,
173        hasher: WasmFunction,
174    ) -> Self {
175        Joiner::Equal {
176            map: Some(map),
177            left_map: None,
178            right_map: None,
179            relation_predicate: Some(relation_predicate),
180            hasher: Some(hasher),
181        }
182    }
183
184    pub fn less_than(map: WasmFunction, comparator: WasmFunction) -> Self {
185        Joiner::LessThan {
186            map: Some(map),
187            left_map: None,
188            right_map: None,
189            comparator,
190        }
191    }
192
193    pub fn less_than_with_mappings(
194        left_map: WasmFunction,
195        right_map: WasmFunction,
196        comparator: WasmFunction,
197    ) -> Self {
198        Joiner::LessThan {
199            map: None,
200            left_map: Some(left_map),
201            right_map: Some(right_map),
202            comparator,
203        }
204    }
205
206    pub fn less_than_or_equal(map: WasmFunction, comparator: WasmFunction) -> Self {
207        Joiner::LessThanOrEqual {
208            map: Some(map),
209            left_map: None,
210            right_map: None,
211            comparator,
212        }
213    }
214
215    pub fn less_than_or_equal_with_mappings(
216        left_map: WasmFunction,
217        right_map: WasmFunction,
218        comparator: WasmFunction,
219    ) -> Self {
220        Joiner::LessThanOrEqual {
221            map: None,
222            left_map: Some(left_map),
223            right_map: Some(right_map),
224            comparator,
225        }
226    }
227
228    pub fn greater_than(map: WasmFunction, comparator: WasmFunction) -> Self {
229        Joiner::GreaterThan {
230            map: Some(map),
231            left_map: None,
232            right_map: None,
233            comparator,
234        }
235    }
236
237    pub fn greater_than_with_mappings(
238        left_map: WasmFunction,
239        right_map: WasmFunction,
240        comparator: WasmFunction,
241    ) -> Self {
242        Joiner::GreaterThan {
243            map: None,
244            left_map: Some(left_map),
245            right_map: Some(right_map),
246            comparator,
247        }
248    }
249
250    pub fn greater_than_or_equal(map: WasmFunction, comparator: WasmFunction) -> Self {
251        Joiner::GreaterThanOrEqual {
252            map: Some(map),
253            left_map: None,
254            right_map: None,
255            comparator,
256        }
257    }
258
259    pub fn greater_than_or_equal_with_mappings(
260        left_map: WasmFunction,
261        right_map: WasmFunction,
262        comparator: WasmFunction,
263    ) -> Self {
264        Joiner::GreaterThanOrEqual {
265            map: None,
266            left_map: Some(left_map),
267            right_map: Some(right_map),
268            comparator,
269        }
270    }
271
272    pub fn overlapping(start_map: WasmFunction, end_map: WasmFunction) -> Self {
273        Joiner::Overlapping {
274            start_map: Some(start_map),
275            end_map: Some(end_map),
276            left_start_map: None,
277            left_end_map: None,
278            right_start_map: None,
279            right_end_map: None,
280            comparator: None,
281        }
282    }
283
284    pub fn overlapping_with_mappings(
285        left_start_map: WasmFunction,
286        left_end_map: WasmFunction,
287        right_start_map: WasmFunction,
288        right_end_map: WasmFunction,
289    ) -> Self {
290        Joiner::Overlapping {
291            start_map: None,
292            end_map: None,
293            left_start_map: Some(left_start_map),
294            left_end_map: Some(left_end_map),
295            right_start_map: Some(right_start_map),
296            right_end_map: Some(right_end_map),
297            comparator: None,
298        }
299    }
300
301    pub fn overlapping_with_comparator(
302        start_map: WasmFunction,
303        end_map: WasmFunction,
304        comparator: WasmFunction,
305    ) -> Self {
306        Joiner::Overlapping {
307            start_map: Some(start_map),
308            end_map: Some(end_map),
309            left_start_map: None,
310            left_end_map: None,
311            right_start_map: None,
312            right_end_map: None,
313            comparator: Some(comparator),
314        }
315    }
316
317    pub fn filtering(filter: WasmFunction) -> Self {
318        Joiner::Filtering { filter }
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wasm_function_new() {
        let func = WasmFunction::new("get_timeslot");
        assert_eq!(func.name(), "get_timeslot");
        assert!(func.relation_function().is_none());
        assert!(func.hash_function().is_none());
        assert!(func.comparator_function().is_none());
    }

    #[test]
    fn test_wasm_function_with_relation() {
        let func = WasmFunction::new("get_value")
            .with_relation("equals_fn")
            .with_hash("hash_fn");
        assert_eq!(func.name(), "get_value");
        assert_eq!(func.relation_function(), Some("equals_fn"));
        assert_eq!(func.hash_function(), Some("hash_fn"));
    }

    #[test]
    fn test_wasm_function_with_comparator() {
        let func = WasmFunction::new("get_time").with_comparator("compare_times");
        assert_eq!(func.comparator_function(), Some("compare_times"));
    }

    /// Pins the current wire contract: only the name is serialized, so the
    /// companion function names do not survive a round trip.
    #[test]
    fn test_wasm_function_serializes_name_only() {
        let func = WasmFunction::new("get_value")
            .with_relation("eq")
            .with_hash("hash")
            .with_comparator("cmp");
        let json = serde_json::to_string(&func).unwrap();
        assert_eq!(json, "\"get_value\"");

        let parsed: WasmFunction = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, WasmFunction::new("get_value"));
        assert_ne!(parsed, func);
    }

    #[test]
    fn test_equal_joiner() {
        let joiner = Joiner::equal(WasmFunction::new("get_timeslot"));
        match joiner {
            Joiner::Equal {
                map,
                left_map,
                right_map,
                ..
            } => {
                assert!(map.is_some());
                assert_eq!(map.unwrap().name(), "get_timeslot");
                assert!(left_map.is_none());
                assert!(right_map.is_none());
            }
            _ => panic!("Expected Equal joiner"),
        }
    }

    #[test]
    fn test_equal_with_mappings() {
        let joiner = Joiner::equal_with_mappings(
            WasmFunction::new("get_left_timeslot"),
            WasmFunction::new("get_right_timeslot"),
        );
        match joiner {
            Joiner::Equal {
                map,
                left_map,
                right_map,
                ..
            } => {
                assert!(map.is_none());
                assert_eq!(left_map.unwrap().name(), "get_left_timeslot");
                assert_eq!(right_map.unwrap().name(), "get_right_timeslot");
            }
            _ => panic!("Expected Equal joiner"),
        }
    }

    #[test]
    fn test_equal_with_custom_equals() {
        let joiner = Joiner::equal_with_custom_equals(
            WasmFunction::new("get_id"),
            WasmFunction::new("id_equals"),
            WasmFunction::new("id_hash"),
        );
        match joiner {
            Joiner::Equal {
                map,
                relation_predicate,
                hasher,
                ..
            } => {
                assert!(map.is_some());
                assert!(relation_predicate.is_some());
                assert!(hasher.is_some());
            }
            _ => panic!("Expected Equal joiner"),
        }
    }

    #[test]
    fn test_less_than_joiner() {
        let joiner = Joiner::less_than(
            WasmFunction::new("get_start_time"),
            WasmFunction::new("compare_time"),
        );
        match joiner {
            Joiner::LessThan {
                map, comparator, ..
            } => {
                assert!(map.is_some());
                assert_eq!(comparator.name(), "compare_time");
            }
            _ => panic!("Expected LessThan joiner"),
        }
    }

    #[test]
    fn test_less_than_with_mappings() {
        let joiner = Joiner::less_than_with_mappings(
            WasmFunction::new("get_left_time"),
            WasmFunction::new("get_right_time"),
            WasmFunction::new("compare_time"),
        );
        match joiner {
            Joiner::LessThan {
                map,
                left_map,
                right_map,
                ..
            } => {
                assert!(map.is_none());
                assert!(left_map.is_some());
                assert!(right_map.is_some());
            }
            _ => panic!("Expected LessThan joiner"),
        }
    }

    #[test]
    fn test_greater_than_joiner() {
        let joiner = Joiner::greater_than(
            WasmFunction::new("get_priority"),
            WasmFunction::new("compare_priority"),
        );
        match joiner {
            Joiner::GreaterThan { map, .. } => {
                assert!(map.is_some());
            }
            _ => panic!("Expected GreaterThan joiner"),
        }
    }

    #[test]
    fn test_overlapping_joiner() {
        let joiner =
            Joiner::overlapping(WasmFunction::new("get_start"), WasmFunction::new("get_end"));
        match joiner {
            Joiner::Overlapping {
                start_map,
                end_map,
                left_start_map,
                ..
            } => {
                assert!(start_map.is_some());
                assert!(end_map.is_some());
                assert!(left_start_map.is_none());
            }
            _ => panic!("Expected Overlapping joiner"),
        }
    }

    #[test]
    fn test_overlapping_with_mappings() {
        let joiner = Joiner::overlapping_with_mappings(
            WasmFunction::new("left_start"),
            WasmFunction::new("left_end"),
            WasmFunction::new("right_start"),
            WasmFunction::new("right_end"),
        );
        match joiner {
            Joiner::Overlapping {
                start_map,
                left_start_map,
                left_end_map,
                right_start_map,
                right_end_map,
                ..
            } => {
                assert!(start_map.is_none());
                assert!(left_start_map.is_some());
                assert!(left_end_map.is_some());
                assert!(right_start_map.is_some());
                assert!(right_end_map.is_some());
            }
            _ => panic!("Expected Overlapping joiner"),
        }
    }

    #[test]
    fn test_filtering_joiner() {
        let joiner = Joiner::filtering(WasmFunction::new("is_compatible"));
        match joiner {
            Joiner::Filtering { filter } => {
                assert_eq!(filter.name(), "is_compatible");
            }
            _ => panic!("Expected Filtering joiner"),
        }
    }

    #[test]
    fn test_equal_joiner_json_serialization() {
        let joiner = Joiner::equal(WasmFunction::new("get_timeslot"));
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"equal\""));
        assert!(json.contains("\"map\":\"get_timeslot\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    /// The custom predicate is written under the key `"equal"`, not
    /// `"relationPredicate"`.
    #[test]
    fn test_equal_with_custom_equals_json_serialization() {
        let joiner = Joiner::equal_with_custom_equals(
            WasmFunction::new("get_id"),
            WasmFunction::new("id_equals"),
            WasmFunction::new("id_hash"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"equal\""));
        assert!(json.contains("\"map\":\"get_id\""));
        assert!(json.contains("\"equal\":\"id_equals\""));
        assert!(json.contains("\"hasher\":\"id_hash\""));
        assert!(!json.contains("relationPredicate"));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_less_than_joiner_json_serialization() {
        let joiner = Joiner::less_than(
            WasmFunction::new("get_time"),
            WasmFunction::new("compare_time"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"lessThan\""));
        assert!(json.contains("\"comparator\":\"compare_time\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    /// Every comparison joiner serializes its own relation tag and round-trips.
    #[test]
    fn test_comparison_joiner_relation_tags() {
        let cases = vec![
            (
                Joiner::less_than(WasmFunction::new("m"), WasmFunction::new("c")),
                "lessThan",
            ),
            (
                Joiner::less_than_or_equal(WasmFunction::new("m"), WasmFunction::new("c")),
                "lessThanOrEqual",
            ),
            (
                Joiner::greater_than(WasmFunction::new("m"), WasmFunction::new("c")),
                "greaterThan",
            ),
            (
                Joiner::greater_than_or_equal(WasmFunction::new("m"), WasmFunction::new("c")),
                "greaterThanOrEqual",
            ),
        ];
        for (joiner, tag) in cases {
            let json = serde_json::to_string(&joiner).unwrap();
            assert!(
                json.contains(&format!("\"relation\":\"{}\"", tag)),
                "unexpected JSON for {}: {}",
                tag,
                json
            );
            assert!(json.contains("\"map\":\"m\""));
            assert!(json.contains("\"comparator\":\"c\""));

            let parsed: Joiner = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, joiner);
        }
    }

    /// Every `*_with_mappings` comparison constructor writes leftMap/rightMap
    /// and omits `map`.
    #[test]
    fn test_comparison_joiners_with_mappings_json() {
        let joiners = vec![
            Joiner::less_than_with_mappings(
                WasmFunction::new("l"),
                WasmFunction::new("r"),
                WasmFunction::new("c"),
            ),
            Joiner::less_than_or_equal_with_mappings(
                WasmFunction::new("l"),
                WasmFunction::new("r"),
                WasmFunction::new("c"),
            ),
            Joiner::greater_than_with_mappings(
                WasmFunction::new("l"),
                WasmFunction::new("r"),
                WasmFunction::new("c"),
            ),
            Joiner::greater_than_or_equal_with_mappings(
                WasmFunction::new("l"),
                WasmFunction::new("r"),
                WasmFunction::new("c"),
            ),
        ];
        for joiner in joiners {
            let json = serde_json::to_string(&joiner).unwrap();
            assert!(json.contains("\"leftMap\":\"l\""));
            assert!(json.contains("\"rightMap\":\"r\""));
            assert!(json.contains("\"comparator\":\"c\""));
            assert!(!json.contains("\"map\""));

            let parsed: Joiner = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, joiner);
        }
    }

    #[test]
    fn test_overlapping_joiner_json_serialization() {
        let joiner = Joiner::overlapping(WasmFunction::new("start"), WasmFunction::new("end"));
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"overlapping\""));
        assert!(json.contains("\"startMap\":\"start\""));
        assert!(json.contains("\"endMap\":\"end\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_overlapping_with_mappings_json_serialization() {
        let joiner = Joiner::overlapping_with_mappings(
            WasmFunction::new("ls"),
            WasmFunction::new("le"),
            WasmFunction::new("rs"),
            WasmFunction::new("re"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"leftStartMap\":\"ls\""));
        assert!(json.contains("\"leftEndMap\":\"le\""));
        assert!(json.contains("\"rightStartMap\":\"rs\""));
        assert!(json.contains("\"rightEndMap\":\"re\""));
        assert!(!json.contains("\"startMap\""));
        assert!(!json.contains("\"endMap\""));
        assert!(!json.contains("\"comparator\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_overlapping_with_comparator_json_serialization() {
        let joiner = Joiner::overlapping_with_comparator(
            WasmFunction::new("start"),
            WasmFunction::new("end"),
            WasmFunction::new("cmp"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"overlapping\""));
        assert!(json.contains("\"startMap\":\"start\""));
        assert!(json.contains("\"endMap\":\"end\""));
        assert!(json.contains("\"comparator\":\"cmp\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_filtering_joiner_json_serialization() {
        let joiner = Joiner::filtering(WasmFunction::new("is_valid"));
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"filtering\""));
        assert!(json.contains("\"filter\":\"is_valid\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_equal_with_left_right_json() {
        let joiner = Joiner::equal_with_mappings(
            WasmFunction::new("left_fn"),
            WasmFunction::new("right_fn"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"leftMap\":\"left_fn\""));
        assert!(json.contains("\"rightMap\":\"right_fn\""));
        assert!(!json.contains("\"map\""));
    }

    /// Hand-written JSON that omits optional fields deserializes with those
    /// fields set to `None`.
    #[test]
    fn test_deserialize_with_omitted_optional_fields() {
        let json = r#"{"relation":"lessThan","leftMap":"l","rightMap":"r","comparator":"c"}"#;
        let parsed: Joiner = serde_json::from_str(json).unwrap();
        assert_eq!(
            parsed,
            Joiner::less_than_with_mappings(
                WasmFunction::new("l"),
                WasmFunction::new("r"),
                WasmFunction::new("c"),
            )
        );
    }

    #[test]
    fn test_deserialize_rejects_unknown_relation() {
        let result = serde_json::from_str::<Joiner>(r#"{"relation":"bogus","map":"m"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_rejects_missing_comparator() {
        let result = serde_json::from_str::<Joiner>(r#"{"relation":"lessThan","map":"m"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_joiner_clone() {
        let joiner = Joiner::equal(WasmFunction::new("get_value"));
        let cloned = joiner.clone();
        assert_eq!(joiner, cloned);
    }

    #[test]
    fn test_joiner_debug() {
        let joiner = Joiner::filtering(WasmFunction::new("test"));
        let debug = format!("{:?}", joiner);
        assert!(debug.contains("Filtering"));
        assert!(debug.contains("test"));
    }
}