solverforge_core/constraints/
joiners.rs

1use crate::wasm::Expression;
2use serde::{Deserialize, Deserializer, Serialize, Serializer};
3
/// A reference to a function by name, with optional companion function names and an
/// optional [`Expression`] to compile.
///
/// Only `name` is serialized (as a bare string); all other fields are local metadata.
#[derive(Debug, Clone)]
pub struct WasmFunction {
    /// The function's name; this is the only field written during serialization.
    name: String,
    /// Optional name of a relation function, set via `with_relation`.
    /// NOTE(review): tests use names like "equals_fn", suggesting a custom equality — confirm.
    relation_function: Option<String>,
    /// Optional name of a hash function, set via `with_hash`.
    hash_function: Option<String>,
    /// Optional name of a comparator function, set via `with_comparator`.
    comparator_function: Option<String>,
    /// The expression to compile into WASM. Set when created from a NamedExpression
    /// (i.e. via `WasmFunction::with_expression`).
    /// Skipped during serialization as only the name is sent to the solver service.
    expression: Option<Expression>,
}
14
15// Implement PartialEq manually to compare by name only (for serialization round-trips)
16impl PartialEq for WasmFunction {
17    fn eq(&self, other: &Self) -> bool {
18        self.name == other.name
19            && self.relation_function == other.relation_function
20            && self.hash_function == other.hash_function
21            && self.comparator_function == other.comparator_function
22        // expression is intentionally not compared - it's runtime metadata
23    }
24}
25
26impl Eq for WasmFunction {}
27
28impl Serialize for WasmFunction {
29    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30    where
31        S: Serializer,
32    {
33        serializer.serialize_str(&self.name)
34    }
35}
36
37impl<'de> Deserialize<'de> for WasmFunction {
38    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
39    where
40        D: Deserializer<'de>,
41    {
42        let name = String::deserialize(deserializer)?;
43        Ok(WasmFunction::new(name))
44    }
45}
46
47impl WasmFunction {
48    pub fn new(name: impl Into<String>) -> Self {
49        Self {
50            name: name.into(),
51            relation_function: None,
52            hash_function: None,
53            comparator_function: None,
54            expression: None,
55        }
56    }
57
58    /// Creates a WasmFunction with an associated expression.
59    pub fn with_expression(name: impl Into<String>, expression: Expression) -> Self {
60        Self {
61            name: name.into(),
62            relation_function: None,
63            hash_function: None,
64            comparator_function: None,
65            expression: Some(expression),
66        }
67    }
68
69    pub fn with_relation(mut self, relation: impl Into<String>) -> Self {
70        self.relation_function = Some(relation.into());
71        self
72    }
73
74    pub fn with_hash(mut self, hash: impl Into<String>) -> Self {
75        self.hash_function = Some(hash.into());
76        self
77    }
78
79    pub fn with_comparator(mut self, comparator: impl Into<String>) -> Self {
80        self.comparator_function = Some(comparator.into());
81        self
82    }
83
84    pub fn name(&self) -> &str {
85        &self.name
86    }
87
88    pub fn relation_function(&self) -> Option<&str> {
89        self.relation_function.as_deref()
90    }
91
92    pub fn hash_function(&self) -> Option<&str> {
93        self.hash_function.as_deref()
94    }
95
96    pub fn comparator_function(&self) -> Option<&str> {
97        self.comparator_function.as_deref()
98    }
99
100    /// Returns the expression if one was set (from a NamedExpression).
101    pub fn expression(&self) -> Option<&Expression> {
102        self.expression.as_ref()
103    }
104}
105
/// A joiner definition exchanged with the solver service.
///
/// Serialized as a JSON object whose `"relation"` key names the variant, for example
/// `{"relation":"equal","map":"get_timeslot"}`. Every [`WasmFunction`] field is written
/// as its bare function name, and `None` fields are omitted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "relation")]
// Variant sizes differ widely (`Overlapping` holds seven `Option<WasmFunction>` fields,
// `Filtering` a single `WasmFunction`), which triggers clippy's `large_enum_variant`.
// Boxing fields to silence it would change the variants' public field types.
#[allow(clippy::large_enum_variant)]
pub enum Joiner {
    /// Equality joiner, tagged `"equal"`.
    ///
    /// The constructors set either `map` (used for both sides) or the
    /// `left_map`/`right_map` pair.
    #[serde(rename = "equal")]
    Equal {
        #[serde(skip_serializing_if = "Option::is_none")]
        map: Option<WasmFunction>,
        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
        left_map: Option<WasmFunction>,
        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
        right_map: Option<WasmFunction>,
        /// Optional custom equality function; serialized under the key `"equal"`.
        #[serde(rename = "equal", skip_serializing_if = "Option::is_none")]
        relation_predicate: Option<WasmFunction>,
        /// Optional hash function, set alongside `relation_predicate` by
        /// `Joiner::equal_with_custom_equals`.
        #[serde(skip_serializing_if = "Option::is_none")]
        hasher: Option<WasmFunction>,
    },
    /// Ordering joiner tagged `"lessThan"`; `comparator` is always present.
    #[serde(rename = "lessThan")]
    LessThan {
        #[serde(skip_serializing_if = "Option::is_none")]
        map: Option<WasmFunction>,
        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
        left_map: Option<WasmFunction>,
        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
        right_map: Option<WasmFunction>,
        comparator: WasmFunction,
    },
    /// Ordering joiner tagged `"lessThanOrEqual"`; `comparator` is always present.
    #[serde(rename = "lessThanOrEqual")]
    LessThanOrEqual {
        #[serde(skip_serializing_if = "Option::is_none")]
        map: Option<WasmFunction>,
        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
        left_map: Option<WasmFunction>,
        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
        right_map: Option<WasmFunction>,
        comparator: WasmFunction,
    },
    /// Ordering joiner tagged `"greaterThan"`; `comparator` is always present.
    #[serde(rename = "greaterThan")]
    GreaterThan {
        #[serde(skip_serializing_if = "Option::is_none")]
        map: Option<WasmFunction>,
        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
        left_map: Option<WasmFunction>,
        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
        right_map: Option<WasmFunction>,
        comparator: WasmFunction,
    },
    /// Ordering joiner tagged `"greaterThanOrEqual"`; `comparator` is always present.
    #[serde(rename = "greaterThanOrEqual")]
    GreaterThanOrEqual {
        #[serde(skip_serializing_if = "Option::is_none")]
        map: Option<WasmFunction>,
        #[serde(rename = "leftMap", skip_serializing_if = "Option::is_none")]
        left_map: Option<WasmFunction>,
        #[serde(rename = "rightMap", skip_serializing_if = "Option::is_none")]
        right_map: Option<WasmFunction>,
        comparator: WasmFunction,
    },
    /// Range-overlap joiner tagged `"overlapping"`.
    ///
    /// The constructors set either the shared `start_map`/`end_map` pair or all four
    /// per-side maps; `comparator` is optional.
    #[serde(rename = "overlapping")]
    Overlapping {
        #[serde(rename = "startMap", skip_serializing_if = "Option::is_none")]
        start_map: Option<WasmFunction>,
        #[serde(rename = "endMap", skip_serializing_if = "Option::is_none")]
        end_map: Option<WasmFunction>,
        #[serde(rename = "leftStartMap", skip_serializing_if = "Option::is_none")]
        left_start_map: Option<WasmFunction>,
        #[serde(rename = "leftEndMap", skip_serializing_if = "Option::is_none")]
        left_end_map: Option<WasmFunction>,
        #[serde(rename = "rightStartMap", skip_serializing_if = "Option::is_none")]
        right_start_map: Option<WasmFunction>,
        #[serde(rename = "rightEndMap", skip_serializing_if = "Option::is_none")]
        right_end_map: Option<WasmFunction>,
        #[serde(skip_serializing_if = "Option::is_none")]
        comparator: Option<WasmFunction>,
    },
    /// Joiner tagged `"filtering"` that carries a single filter function.
    #[serde(rename = "filtering")]
    Filtering { filter: WasmFunction },
}
183
184impl Joiner {
185    pub fn equal(map: WasmFunction) -> Self {
186        Joiner::Equal {
187            map: Some(map),
188            left_map: None,
189            right_map: None,
190            relation_predicate: None,
191            hasher: None,
192        }
193    }
194
195    pub fn equal_with_mappings(left_map: WasmFunction, right_map: WasmFunction) -> Self {
196        Joiner::Equal {
197            map: None,
198            left_map: Some(left_map),
199            right_map: Some(right_map),
200            relation_predicate: None,
201            hasher: None,
202        }
203    }
204
205    pub fn equal_with_custom_equals(
206        map: WasmFunction,
207        relation_predicate: WasmFunction,
208        hasher: WasmFunction,
209    ) -> Self {
210        Joiner::Equal {
211            map: Some(map),
212            left_map: None,
213            right_map: None,
214            relation_predicate: Some(relation_predicate),
215            hasher: Some(hasher),
216        }
217    }
218
219    pub fn less_than(map: WasmFunction, comparator: WasmFunction) -> Self {
220        Joiner::LessThan {
221            map: Some(map),
222            left_map: None,
223            right_map: None,
224            comparator,
225        }
226    }
227
228    pub fn less_than_with_mappings(
229        left_map: WasmFunction,
230        right_map: WasmFunction,
231        comparator: WasmFunction,
232    ) -> Self {
233        Joiner::LessThan {
234            map: None,
235            left_map: Some(left_map),
236            right_map: Some(right_map),
237            comparator,
238        }
239    }
240
241    pub fn less_than_or_equal(map: WasmFunction, comparator: WasmFunction) -> Self {
242        Joiner::LessThanOrEqual {
243            map: Some(map),
244            left_map: None,
245            right_map: None,
246            comparator,
247        }
248    }
249
250    pub fn less_than_or_equal_with_mappings(
251        left_map: WasmFunction,
252        right_map: WasmFunction,
253        comparator: WasmFunction,
254    ) -> Self {
255        Joiner::LessThanOrEqual {
256            map: None,
257            left_map: Some(left_map),
258            right_map: Some(right_map),
259            comparator,
260        }
261    }
262
263    pub fn greater_than(map: WasmFunction, comparator: WasmFunction) -> Self {
264        Joiner::GreaterThan {
265            map: Some(map),
266            left_map: None,
267            right_map: None,
268            comparator,
269        }
270    }
271
272    pub fn greater_than_with_mappings(
273        left_map: WasmFunction,
274        right_map: WasmFunction,
275        comparator: WasmFunction,
276    ) -> Self {
277        Joiner::GreaterThan {
278            map: None,
279            left_map: Some(left_map),
280            right_map: Some(right_map),
281            comparator,
282        }
283    }
284
285    pub fn greater_than_or_equal(map: WasmFunction, comparator: WasmFunction) -> Self {
286        Joiner::GreaterThanOrEqual {
287            map: Some(map),
288            left_map: None,
289            right_map: None,
290            comparator,
291        }
292    }
293
294    pub fn greater_than_or_equal_with_mappings(
295        left_map: WasmFunction,
296        right_map: WasmFunction,
297        comparator: WasmFunction,
298    ) -> Self {
299        Joiner::GreaterThanOrEqual {
300            map: None,
301            left_map: Some(left_map),
302            right_map: Some(right_map),
303            comparator,
304        }
305    }
306
307    pub fn overlapping(start_map: WasmFunction, end_map: WasmFunction) -> Self {
308        Joiner::Overlapping {
309            start_map: Some(start_map),
310            end_map: Some(end_map),
311            left_start_map: None,
312            left_end_map: None,
313            right_start_map: None,
314            right_end_map: None,
315            comparator: None,
316        }
317    }
318
319    pub fn overlapping_with_mappings(
320        left_start_map: WasmFunction,
321        left_end_map: WasmFunction,
322        right_start_map: WasmFunction,
323        right_end_map: WasmFunction,
324    ) -> Self {
325        Joiner::Overlapping {
326            start_map: None,
327            end_map: None,
328            left_start_map: Some(left_start_map),
329            left_end_map: Some(left_end_map),
330            right_start_map: Some(right_start_map),
331            right_end_map: Some(right_end_map),
332            comparator: None,
333        }
334    }
335
336    pub fn overlapping_with_comparator(
337        start_map: WasmFunction,
338        end_map: WasmFunction,
339        comparator: WasmFunction,
340    ) -> Self {
341        Joiner::Overlapping {
342            start_map: Some(start_map),
343            end_map: Some(end_map),
344            left_start_map: None,
345            left_end_map: None,
346            right_start_map: None,
347            right_end_map: None,
348            comparator: Some(comparator),
349        }
350    }
351
352    pub fn filtering(filter: WasmFunction) -> Self {
353        Joiner::Filtering { filter }
354    }
355}
356
/// Unit tests for `WasmFunction` and the `Joiner` constructors and their JSON form.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_wasm_function_new() {
        let func = WasmFunction::new("get_timeslot");
        assert_eq!(func.name(), "get_timeslot");
        assert!(func.relation_function().is_none());
        assert!(func.hash_function().is_none());
        assert!(func.comparator_function().is_none());
    }

    #[test]
    fn test_wasm_function_with_relation() {
        let func = WasmFunction::new("get_value")
            .with_relation("equals_fn")
            .with_hash("hash_fn");
        assert_eq!(func.name(), "get_value");
        assert_eq!(func.relation_function(), Some("equals_fn"));
        assert_eq!(func.hash_function(), Some("hash_fn"));
    }

    #[test]
    fn test_wasm_function_with_comparator() {
        let func = WasmFunction::new("get_time").with_comparator("compare_times");
        assert_eq!(func.comparator_function(), Some("compare_times"));
    }

    #[test]
    fn test_wasm_function_serializes_name_only() {
        let func = WasmFunction::new("get_value")
            .with_relation("equals_fn")
            .with_hash("hash_fn")
            .with_comparator("compare_fn");
        let json = serde_json::to_string(&func).unwrap();
        assert_eq!(json, "\"get_value\"");

        // Only the name survives a round trip: the metadata is dropped, so the rebuilt
        // value is not equal to the original.
        let parsed: WasmFunction = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name(), "get_value");
        assert!(parsed.relation_function().is_none());
        assert!(parsed.hash_function().is_none());
        assert!(parsed.comparator_function().is_none());
        assert_ne!(parsed, func);
    }

    #[test]
    fn test_equal_joiner() {
        let joiner = Joiner::equal(WasmFunction::new("get_timeslot"));
        match joiner {
            Joiner::Equal {
                map,
                left_map,
                right_map,
                ..
            } => {
                assert!(map.is_some());
                assert_eq!(map.unwrap().name(), "get_timeslot");
                assert!(left_map.is_none());
                assert!(right_map.is_none());
            }
            _ => panic!("Expected Equal joiner"),
        }
    }

    #[test]
    fn test_equal_with_mappings() {
        let joiner = Joiner::equal_with_mappings(
            WasmFunction::new("get_left_timeslot"),
            WasmFunction::new("get_right_timeslot"),
        );
        match joiner {
            Joiner::Equal {
                map,
                left_map,
                right_map,
                ..
            } => {
                assert!(map.is_none());
                assert_eq!(left_map.unwrap().name(), "get_left_timeslot");
                assert_eq!(right_map.unwrap().name(), "get_right_timeslot");
            }
            _ => panic!("Expected Equal joiner"),
        }
    }

    #[test]
    fn test_equal_with_custom_equals() {
        let joiner = Joiner::equal_with_custom_equals(
            WasmFunction::new("get_id"),
            WasmFunction::new("id_equals"),
            WasmFunction::new("id_hash"),
        );
        match joiner {
            Joiner::Equal {
                map,
                relation_predicate,
                hasher,
                ..
            } => {
                assert!(map.is_some());
                assert!(relation_predicate.is_some());
                assert!(hasher.is_some());
            }
            _ => panic!("Expected Equal joiner"),
        }
    }

    #[test]
    fn test_equal_with_custom_equals_json_serialization() {
        let joiner = Joiner::equal_with_custom_equals(
            WasmFunction::new("get_id"),
            WasmFunction::new("id_equals"),
            WasmFunction::new("id_hash"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        // `relation_predicate` is renamed to `equal`; it must coexist with the
        // `relation` tag key.
        assert!(json.contains("\"relation\":\"equal\""));
        assert!(json.contains("\"equal\":\"id_equals\""));
        assert!(json.contains("\"hasher\":\"id_hash\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_less_than_joiner() {
        let joiner = Joiner::less_than(
            WasmFunction::new("get_start_time"),
            WasmFunction::new("compare_time"),
        );
        match joiner {
            Joiner::LessThan {
                map, comparator, ..
            } => {
                assert!(map.is_some());
                assert_eq!(comparator.name(), "compare_time");
            }
            _ => panic!("Expected LessThan joiner"),
        }
    }

    #[test]
    fn test_less_than_with_mappings() {
        let joiner = Joiner::less_than_with_mappings(
            WasmFunction::new("get_left_time"),
            WasmFunction::new("get_right_time"),
            WasmFunction::new("compare_time"),
        );
        match joiner {
            Joiner::LessThan {
                map,
                left_map,
                right_map,
                ..
            } => {
                assert!(map.is_none());
                assert!(left_map.is_some());
                assert!(right_map.is_some());
            }
            _ => panic!("Expected LessThan joiner"),
        }
    }

    #[test]
    fn test_less_than_or_equal_joiner_json_serialization() {
        let joiner = Joiner::less_than_or_equal(
            WasmFunction::new("get_time"),
            WasmFunction::new("compare_time"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"lessThanOrEqual\""));
        assert!(json.contains("\"map\":\"get_time\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_greater_than_joiner() {
        let joiner = Joiner::greater_than(
            WasmFunction::new("get_priority"),
            WasmFunction::new("compare_priority"),
        );
        match joiner {
            Joiner::GreaterThan { map, .. } => {
                assert!(map.is_some());
            }
            _ => panic!("Expected GreaterThan joiner"),
        }
    }

    #[test]
    fn test_greater_than_or_equal_with_mappings() {
        let joiner = Joiner::greater_than_or_equal_with_mappings(
            WasmFunction::new("get_left_priority"),
            WasmFunction::new("get_right_priority"),
            WasmFunction::new("compare_priority"),
        );
        match &joiner {
            Joiner::GreaterThanOrEqual {
                map,
                left_map,
                right_map,
                comparator,
            } => {
                assert!(map.is_none());
                assert_eq!(left_map.as_ref().map(|f| f.name()), Some("get_left_priority"));
                assert_eq!(right_map.as_ref().map(|f| f.name()), Some("get_right_priority"));
                assert_eq!(comparator.name(), "compare_priority");
            }
            _ => panic!("Expected GreaterThanOrEqual joiner"),
        }

        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"greaterThanOrEqual\""));
        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_overlapping_joiner() {
        let joiner =
            Joiner::overlapping(WasmFunction::new("get_start"), WasmFunction::new("get_end"));
        match joiner {
            Joiner::Overlapping {
                start_map,
                end_map,
                left_start_map,
                ..
            } => {
                assert!(start_map.is_some());
                assert!(end_map.is_some());
                assert!(left_start_map.is_none());
            }
            _ => panic!("Expected Overlapping joiner"),
        }
    }

    #[test]
    fn test_overlapping_with_mappings() {
        let joiner = Joiner::overlapping_with_mappings(
            WasmFunction::new("left_start"),
            WasmFunction::new("left_end"),
            WasmFunction::new("right_start"),
            WasmFunction::new("right_end"),
        );
        match joiner {
            Joiner::Overlapping {
                start_map,
                left_start_map,
                left_end_map,
                right_start_map,
                right_end_map,
                ..
            } => {
                assert!(start_map.is_none());
                assert!(left_start_map.is_some());
                assert!(left_end_map.is_some());
                assert!(right_start_map.is_some());
                assert!(right_end_map.is_some());
            }
            _ => panic!("Expected Overlapping joiner"),
        }
    }

    #[test]
    fn test_overlapping_with_comparator_json_serialization() {
        let joiner = Joiner::overlapping_with_comparator(
            WasmFunction::new("start"),
            WasmFunction::new("end"),
            WasmFunction::new("compare_instant"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"overlapping\""));
        assert!(json.contains("\"comparator\":\"compare_instant\""));
        assert!(!json.contains("\"leftStartMap\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_filtering_joiner() {
        let joiner = Joiner::filtering(WasmFunction::new("is_compatible"));
        match joiner {
            Joiner::Filtering { filter } => {
                assert_eq!(filter.name(), "is_compatible");
            }
            _ => panic!("Expected Filtering joiner"),
        }
    }

    #[test]
    fn test_equal_joiner_json_serialization() {
        let joiner = Joiner::equal(WasmFunction::new("get_timeslot"));
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"equal\""));
        assert!(json.contains("\"map\":\"get_timeslot\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_less_than_joiner_json_serialization() {
        let joiner = Joiner::less_than(
            WasmFunction::new("get_time"),
            WasmFunction::new("compare_time"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"lessThan\""));
        assert!(json.contains("\"comparator\":\"compare_time\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_deserialize_with_omitted_optional_fields() {
        let json = r#"{"relation":"lessThan","map":"get_time","comparator":"compare_time"}"#;
        let parsed: Joiner = serde_json::from_str(json).unwrap();
        assert_eq!(
            parsed,
            Joiner::less_than(WasmFunction::new("get_time"), WasmFunction::new("compare_time"))
        );
    }

    #[test]
    fn test_overlapping_joiner_json_serialization() {
        let joiner = Joiner::overlapping(WasmFunction::new("start"), WasmFunction::new("end"));
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"overlapping\""));
        assert!(json.contains("\"startMap\":\"start\""));
        assert!(json.contains("\"endMap\":\"end\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_filtering_joiner_json_serialization() {
        let joiner = Joiner::filtering(WasmFunction::new("is_valid"));
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"relation\":\"filtering\""));
        assert!(json.contains("\"filter\":\"is_valid\""));

        let parsed: Joiner = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, joiner);
    }

    #[test]
    fn test_equal_with_left_right_json() {
        let joiner = Joiner::equal_with_mappings(
            WasmFunction::new("left_fn"),
            WasmFunction::new("right_fn"),
        );
        let json = serde_json::to_string(&joiner).unwrap();
        assert!(json.contains("\"leftMap\":\"left_fn\""));
        assert!(json.contains("\"rightMap\":\"right_fn\""));
        assert!(!json.contains("\"map\""));
    }

    #[test]
    fn test_joiner_clone() {
        let joiner = Joiner::equal(WasmFunction::new("get_value"));
        let cloned = joiner.clone();
        assert_eq!(joiner, cloned);
    }

    #[test]
    fn test_joiner_debug() {
        let joiner = Joiner::filtering(WasmFunction::new("test"));
        let debug = format!("{:?}", joiner);
        assert!(debug.contains("Filtering"));
        assert!(debug.contains("test"));
    }
}
632}