solverforge_core/constraints/
stream.rs

1use crate::constraints::{Collector, Joiner, WasmFunction};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(tag = "kind")]
6pub enum StreamComponent {
7    #[serde(rename = "forEach")]
8    ForEach {
9        #[serde(rename = "className")]
10        class_name: String,
11    },
12    #[serde(rename = "forEachIncludingUnassigned")]
13    ForEachIncludingUnassigned {
14        #[serde(rename = "className")]
15        class_name: String,
16    },
17    #[serde(rename = "forEachUniquePair")]
18    ForEachUniquePair {
19        #[serde(rename = "className")]
20        class_name: String,
21        #[serde(default, skip_serializing_if = "Vec::is_empty")]
22        joiners: Vec<Joiner>,
23    },
24    #[serde(rename = "filter")]
25    Filter { predicate: WasmFunction },
26    #[serde(rename = "join")]
27    Join {
28        #[serde(rename = "className")]
29        class_name: String,
30        #[serde(default, skip_serializing_if = "Vec::is_empty")]
31        joiners: Vec<Joiner>,
32    },
33    #[serde(rename = "ifExists")]
34    IfExists {
35        #[serde(rename = "className")]
36        class_name: String,
37        #[serde(default, skip_serializing_if = "Vec::is_empty")]
38        joiners: Vec<Joiner>,
39    },
40    #[serde(rename = "ifNotExists")]
41    IfNotExists {
42        #[serde(rename = "className")]
43        class_name: String,
44        #[serde(default, skip_serializing_if = "Vec::is_empty")]
45        joiners: Vec<Joiner>,
46    },
47    #[serde(rename = "groupBy")]
48    GroupBy {
49        #[serde(default, skip_serializing_if = "Vec::is_empty")]
50        keys: Vec<WasmFunction>,
51        #[serde(default, skip_serializing_if = "Vec::is_empty")]
52        aggregators: Vec<Collector>,
53    },
54    #[serde(rename = "map")]
55    Map {
56        #[serde(rename = "mapper")]
57        mappers: Vec<WasmFunction>,
58    },
59    #[serde(rename = "flattenLast")]
60    FlattenLast {
61        #[serde(skip_serializing_if = "Option::is_none")]
62        map: Option<WasmFunction>,
63    },
64    #[serde(rename = "expand")]
65    Expand {
66        #[serde(rename = "mapper")]
67        mappers: Vec<WasmFunction>,
68    },
69    #[serde(rename = "complement")]
70    Complement {
71        #[serde(rename = "className")]
72        class_name: String,
73    },
74    #[serde(rename = "penalize")]
75    Penalize {
76        weight: String,
77        #[serde(rename = "scaleBy", skip_serializing_if = "Option::is_none")]
78        scale_by: Option<WasmFunction>,
79    },
80    #[serde(rename = "reward")]
81    Reward {
82        weight: String,
83        #[serde(rename = "scaleBy", skip_serializing_if = "Option::is_none")]
84        scale_by: Option<WasmFunction>,
85    },
86}
87
88impl StreamComponent {
89    pub fn for_each(class_name: impl Into<String>) -> Self {
90        StreamComponent::ForEach {
91            class_name: class_name.into(),
92        }
93    }
94
95    pub fn for_each_including_unassigned(class_name: impl Into<String>) -> Self {
96        StreamComponent::ForEachIncludingUnassigned {
97            class_name: class_name.into(),
98        }
99    }
100
101    pub fn for_each_unique_pair(class_name: impl Into<String>) -> Self {
102        StreamComponent::ForEachUniquePair {
103            class_name: class_name.into(),
104            joiners: Vec::new(),
105        }
106    }
107
108    pub fn for_each_unique_pair_with_joiners(
109        class_name: impl Into<String>,
110        joiners: Vec<Joiner>,
111    ) -> Self {
112        StreamComponent::ForEachUniquePair {
113            class_name: class_name.into(),
114            joiners,
115        }
116    }
117
118    pub fn filter(predicate: WasmFunction) -> Self {
119        StreamComponent::Filter { predicate }
120    }
121
122    pub fn join(class_name: impl Into<String>) -> Self {
123        StreamComponent::Join {
124            class_name: class_name.into(),
125            joiners: Vec::new(),
126        }
127    }
128
129    pub fn join_with_joiners(class_name: impl Into<String>, joiners: Vec<Joiner>) -> Self {
130        StreamComponent::Join {
131            class_name: class_name.into(),
132            joiners,
133        }
134    }
135
136    pub fn if_exists(class_name: impl Into<String>) -> Self {
137        StreamComponent::IfExists {
138            class_name: class_name.into(),
139            joiners: Vec::new(),
140        }
141    }
142
143    pub fn if_exists_with_joiners(class_name: impl Into<String>, joiners: Vec<Joiner>) -> Self {
144        StreamComponent::IfExists {
145            class_name: class_name.into(),
146            joiners,
147        }
148    }
149
150    pub fn if_not_exists(class_name: impl Into<String>) -> Self {
151        StreamComponent::IfNotExists {
152            class_name: class_name.into(),
153            joiners: Vec::new(),
154        }
155    }
156
157    pub fn if_not_exists_with_joiners(class_name: impl Into<String>, joiners: Vec<Joiner>) -> Self {
158        StreamComponent::IfNotExists {
159            class_name: class_name.into(),
160            joiners,
161        }
162    }
163
164    pub fn group_by(keys: Vec<WasmFunction>, aggregators: Vec<Collector>) -> Self {
165        StreamComponent::GroupBy { keys, aggregators }
166    }
167
168    pub fn group_by_key(key: WasmFunction) -> Self {
169        StreamComponent::GroupBy {
170            keys: vec![key],
171            aggregators: Vec::new(),
172        }
173    }
174
175    pub fn group_by_collector(aggregator: Collector) -> Self {
176        StreamComponent::GroupBy {
177            keys: Vec::new(),
178            aggregators: vec![aggregator],
179        }
180    }
181
182    pub fn map(mappers: Vec<WasmFunction>) -> Self {
183        StreamComponent::Map { mappers }
184    }
185
186    pub fn map_single(mapper: WasmFunction) -> Self {
187        StreamComponent::Map {
188            mappers: vec![mapper],
189        }
190    }
191
192    pub fn flatten_last() -> Self {
193        StreamComponent::FlattenLast { map: None }
194    }
195
196    pub fn flatten_last_with_map(map: WasmFunction) -> Self {
197        StreamComponent::FlattenLast { map: Some(map) }
198    }
199
200    pub fn expand(mappers: Vec<WasmFunction>) -> Self {
201        StreamComponent::Expand { mappers }
202    }
203
204    pub fn complement(class_name: impl Into<String>) -> Self {
205        StreamComponent::Complement {
206            class_name: class_name.into(),
207        }
208    }
209
210    pub fn penalize(weight: impl Into<String>) -> Self {
211        StreamComponent::Penalize {
212            weight: weight.into(),
213            scale_by: None,
214        }
215    }
216
217    pub fn penalize_with_weigher(weight: impl Into<String>, scale_by: WasmFunction) -> Self {
218        StreamComponent::Penalize {
219            weight: weight.into(),
220            scale_by: Some(scale_by),
221        }
222    }
223
224    pub fn reward(weight: impl Into<String>) -> Self {
225        StreamComponent::Reward {
226            weight: weight.into(),
227            scale_by: None,
228        }
229    }
230
231    pub fn reward_with_weigher(weight: impl Into<String>, scale_by: WasmFunction) -> Self {
232        StreamComponent::Reward {
233            weight: weight.into(),
234            scale_by: Some(scale_by),
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `component` to a JSON string, panicking on failure.
    fn to_json(component: &StreamComponent) -> String {
        serde_json::to_string(component).unwrap()
    }

    /// Parses `json` back into a component and checks it equals `expected`.
    fn assert_round_trips(json: &str, expected: &StreamComponent) {
        let parsed: StreamComponent = serde_json::from_str(json).unwrap();
        assert_eq!(&parsed, expected);
    }

    // --- constructors -----------------------------------------------------

    #[test]
    fn test_for_each() {
        let step = StreamComponent::for_each("Lesson");
        if let StreamComponent::ForEach { class_name } = step {
            assert_eq!(class_name, "Lesson");
        } else {
            panic!("Expected ForEach");
        }
    }

    #[test]
    fn test_for_each_including_unassigned() {
        let step = StreamComponent::for_each_including_unassigned("Lesson");
        if let StreamComponent::ForEachIncludingUnassigned { class_name } = step {
            assert_eq!(class_name, "Lesson");
        } else {
            panic!("Expected ForEachIncludingUnassigned");
        }
    }

    #[test]
    fn test_for_each_unique_pair() {
        let step = StreamComponent::for_each_unique_pair("Lesson");
        if let StreamComponent::ForEachUniquePair {
            class_name,
            joiners,
        } = step
        {
            assert_eq!(class_name, "Lesson");
            assert!(joiners.is_empty());
        } else {
            panic!("Expected ForEachUniquePair");
        }
    }

    #[test]
    fn test_for_each_unique_pair_with_joiners() {
        let step = StreamComponent::for_each_unique_pair_with_joiners(
            "Lesson",
            vec![Joiner::equal(WasmFunction::new("get_timeslot"))],
        );
        if let StreamComponent::ForEachUniquePair { joiners, .. } = step {
            assert_eq!(joiners.len(), 1);
        } else {
            panic!("Expected ForEachUniquePair");
        }
    }

    #[test]
    fn test_filter() {
        let step = StreamComponent::filter(WasmFunction::new("is_valid"));
        if let StreamComponent::Filter { predicate } = step {
            assert_eq!(predicate.name(), "is_valid");
        } else {
            panic!("Expected Filter");
        }
    }

    #[test]
    fn test_join() {
        let step = StreamComponent::join("Room");
        if let StreamComponent::Join {
            class_name,
            joiners,
        } = step
        {
            assert_eq!(class_name, "Room");
            assert!(joiners.is_empty());
        } else {
            panic!("Expected Join");
        }
    }

    #[test]
    fn test_join_with_joiners() {
        let step = StreamComponent::join_with_joiners(
            "Room",
            vec![Joiner::equal(WasmFunction::new("get_room"))],
        );
        if let StreamComponent::Join { joiners, .. } = step {
            assert_eq!(joiners.len(), 1);
        } else {
            panic!("Expected Join");
        }
    }

    #[test]
    fn test_if_exists() {
        let step = StreamComponent::if_exists("Conflict");
        if let StreamComponent::IfExists { class_name, .. } = step {
            assert_eq!(class_name, "Conflict");
        } else {
            panic!("Expected IfExists");
        }
    }

    #[test]
    fn test_if_not_exists() {
        let step = StreamComponent::if_not_exists("Conflict");
        if let StreamComponent::IfNotExists { class_name, .. } = step {
            assert_eq!(class_name, "Conflict");
        } else {
            panic!("Expected IfNotExists");
        }
    }

    #[test]
    fn test_group_by() {
        let step = StreamComponent::group_by(
            vec![WasmFunction::new("get_room")],
            vec![Collector::count()],
        );
        if let StreamComponent::GroupBy { keys, aggregators } = step {
            assert_eq!(keys.len(), 1);
            assert_eq!(aggregators.len(), 1);
        } else {
            panic!("Expected GroupBy");
        }
    }

    #[test]
    fn test_group_by_key() {
        let step = StreamComponent::group_by_key(WasmFunction::new("get_room"));
        if let StreamComponent::GroupBy { keys, aggregators } = step {
            assert_eq!(keys.len(), 1);
            assert!(aggregators.is_empty());
        } else {
            panic!("Expected GroupBy");
        }
    }

    #[test]
    fn test_group_by_collector() {
        let step = StreamComponent::group_by_collector(Collector::count());
        if let StreamComponent::GroupBy { keys, aggregators } = step {
            assert!(keys.is_empty());
            assert_eq!(aggregators.len(), 1);
        } else {
            panic!("Expected GroupBy");
        }
    }

    #[test]
    fn test_map() {
        let functions = vec![WasmFunction::new("get_a"), WasmFunction::new("get_b")];
        let step = StreamComponent::map(functions);
        if let StreamComponent::Map { mappers } = step {
            assert_eq!(mappers.len(), 2);
        } else {
            panic!("Expected Map");
        }
    }

    #[test]
    fn test_map_single() {
        let step = StreamComponent::map_single(WasmFunction::new("get_value"));
        if let StreamComponent::Map { mappers } = step {
            assert_eq!(mappers.len(), 1);
        } else {
            panic!("Expected Map");
        }
    }

    #[test]
    fn test_flatten_last() {
        let step = StreamComponent::flatten_last();
        if let StreamComponent::FlattenLast { map } = step {
            assert!(map.is_none());
        } else {
            panic!("Expected FlattenLast");
        }
    }

    #[test]
    fn test_flatten_last_with_map() {
        let step = StreamComponent::flatten_last_with_map(WasmFunction::new("get_items"));
        if let StreamComponent::FlattenLast { map } = step {
            assert!(map.is_some());
        } else {
            panic!("Expected FlattenLast");
        }
    }

    #[test]
    fn test_expand() {
        let step = StreamComponent::expand(vec![WasmFunction::new("get_extra")]);
        if let StreamComponent::Expand { mappers } = step {
            assert_eq!(mappers.len(), 1);
        } else {
            panic!("Expected Expand");
        }
    }

    #[test]
    fn test_complement() {
        let step = StreamComponent::complement("Timeslot");
        if let StreamComponent::Complement { class_name } = step {
            assert_eq!(class_name, "Timeslot");
        } else {
            panic!("Expected Complement");
        }
    }

    #[test]
    fn test_penalize() {
        let step = StreamComponent::penalize("1hard");
        if let StreamComponent::Penalize { weight, scale_by } = step {
            assert_eq!(weight, "1hard");
            assert!(scale_by.is_none());
        } else {
            panic!("Expected Penalize");
        }
    }

    #[test]
    fn test_penalize_with_weigher() {
        let weigher = WasmFunction::new("get_weight");
        let step = StreamComponent::penalize_with_weigher("1hard", weigher);
        if let StreamComponent::Penalize { weight, scale_by } = step {
            assert_eq!(weight, "1hard");
            assert!(scale_by.is_some());
        } else {
            panic!("Expected Penalize");
        }
    }

    #[test]
    fn test_reward() {
        let step = StreamComponent::reward("1soft");
        if let StreamComponent::Reward { weight, scale_by } = step {
            assert_eq!(weight, "1soft");
            assert!(scale_by.is_none());
        } else {
            panic!("Expected Reward");
        }
    }

    #[test]
    fn test_reward_with_weigher() {
        let weigher = WasmFunction::new("get_bonus");
        let step = StreamComponent::reward_with_weigher("1soft", weigher);
        if let StreamComponent::Reward { scale_by, .. } = step {
            assert!(scale_by.is_some());
        } else {
            panic!("Expected Reward");
        }
    }

    // --- JSON representation ----------------------------------------------

    #[test]
    fn test_for_each_json_serialization() {
        let step = StreamComponent::for_each("Lesson");
        let json = to_json(&step);
        assert!(json.contains("\"kind\":\"forEach\""));
        assert!(json.contains("\"className\":\"Lesson\""));

        assert_round_trips(&json, &step);
    }

    #[test]
    fn test_filter_json_serialization() {
        let step = StreamComponent::filter(WasmFunction::new("is_valid"));
        let json = to_json(&step);
        assert!(json.contains("\"kind\":\"filter\""));
        assert!(json.contains("\"predicate\":\"is_valid\""));

        assert_round_trips(&json, &step);
    }

    #[test]
    fn test_join_json_serialization() {
        let step = StreamComponent::join_with_joiners(
            "Room",
            vec![Joiner::equal(WasmFunction::new("get_room"))],
        );
        let json = to_json(&step);
        assert!(json.contains("\"kind\":\"join\""));
        assert!(json.contains("\"className\":\"Room\""));
        assert!(json.contains("\"joiners\""));

        assert_round_trips(&json, &step);
    }

    #[test]
    fn test_group_by_json_serialization() {
        let step = StreamComponent::group_by(
            vec![WasmFunction::new("get_room")],
            vec![Collector::count()],
        );
        let json = to_json(&step);
        assert!(json.contains("\"kind\":\"groupBy\""));
        assert!(json.contains("\"keys\""));
        assert!(json.contains("\"aggregators\""));

        assert_round_trips(&json, &step);
    }

    #[test]
    fn test_penalize_json_serialization() {
        let step = StreamComponent::penalize("1hard");
        let json = to_json(&step);
        assert!(json.contains("\"kind\":\"penalize\""));
        assert!(json.contains("\"weight\":\"1hard\""));

        assert_round_trips(&json, &step);
    }

    // --- derived traits ---------------------------------------------------

    #[test]
    fn test_component_clone() {
        let original = StreamComponent::for_each("Lesson");
        let copy = original.clone();
        assert_eq!(original, copy);
    }

    #[test]
    fn test_component_debug() {
        let step = StreamComponent::for_each("Lesson");
        let rendered = format!("{:?}", step);
        assert!(rendered.contains("ForEach"));
    }
}