solverforge_core/constraints/
collectors.rs

1use crate::constraints::WasmFunction;
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(tag = "name")]
6pub enum Collector {
7    #[serde(rename = "count")]
8    Count {
9        #[serde(skip_serializing_if = "Option::is_none")]
10        distinct: Option<bool>,
11        #[serde(skip_serializing_if = "Option::is_none")]
12        map: Option<WasmFunction>,
13    },
14    #[serde(rename = "sum")]
15    Sum { map: WasmFunction },
16    #[serde(rename = "average")]
17    Average { map: WasmFunction },
18    #[serde(rename = "min")]
19    Min {
20        map: WasmFunction,
21        comparator: WasmFunction,
22    },
23    #[serde(rename = "max")]
24    Max {
25        map: WasmFunction,
26        comparator: WasmFunction,
27    },
28    #[serde(rename = "toList")]
29    ToList {
30        #[serde(skip_serializing_if = "Option::is_none")]
31        map: Option<WasmFunction>,
32    },
33    #[serde(rename = "toSet")]
34    ToSet {
35        #[serde(skip_serializing_if = "Option::is_none")]
36        map: Option<WasmFunction>,
37    },
38    #[serde(rename = "compose")]
39    Compose {
40        collectors: Vec<Collector>,
41        combiner: WasmFunction,
42    },
43    #[serde(rename = "conditionally")]
44    Conditionally {
45        predicate: WasmFunction,
46        collector: Box<Collector>,
47    },
48    #[serde(rename = "collectAndThen")]
49    CollectAndThen {
50        collector: Box<Collector>,
51        mapper: WasmFunction,
52    },
53    #[serde(rename = "loadBalance")]
54    LoadBalance {
55        map: WasmFunction,
56        #[serde(skip_serializing_if = "Option::is_none")]
57        load: Option<WasmFunction>,
58    },
59}
60
61impl Collector {
62    pub fn count() -> Self {
63        Collector::Count {
64            distinct: None,
65            map: None,
66        }
67    }
68
69    pub fn count_distinct() -> Self {
70        Collector::Count {
71            distinct: Some(true),
72            map: None,
73        }
74    }
75
76    pub fn count_with_map(map: WasmFunction) -> Self {
77        Collector::Count {
78            distinct: None,
79            map: Some(map),
80        }
81    }
82
83    pub fn sum(map: WasmFunction) -> Self {
84        Collector::Sum { map }
85    }
86
87    pub fn average(map: WasmFunction) -> Self {
88        Collector::Average { map }
89    }
90
91    pub fn min(map: WasmFunction, comparator: WasmFunction) -> Self {
92        Collector::Min { map, comparator }
93    }
94
95    pub fn max(map: WasmFunction, comparator: WasmFunction) -> Self {
96        Collector::Max { map, comparator }
97    }
98
99    pub fn to_list() -> Self {
100        Collector::ToList { map: None }
101    }
102
103    pub fn to_list_with_map(map: WasmFunction) -> Self {
104        Collector::ToList { map: Some(map) }
105    }
106
107    pub fn to_set() -> Self {
108        Collector::ToSet { map: None }
109    }
110
111    pub fn to_set_with_map(map: WasmFunction) -> Self {
112        Collector::ToSet { map: Some(map) }
113    }
114
115    pub fn compose(collectors: Vec<Collector>, combiner: WasmFunction) -> Self {
116        Collector::Compose {
117            collectors,
118            combiner,
119        }
120    }
121
122    pub fn conditionally(predicate: WasmFunction, collector: Collector) -> Self {
123        Collector::Conditionally {
124            predicate,
125            collector: Box::new(collector),
126        }
127    }
128
129    pub fn collect_and_then(collector: Collector, mapper: WasmFunction) -> Self {
130        Collector::CollectAndThen {
131            collector: Box::new(collector),
132            mapper,
133        }
134    }
135
136    pub fn load_balance(map: WasmFunction) -> Self {
137        Collector::LoadBalance { map, load: None }
138    }
139
140    pub fn load_balance_with_load(map: WasmFunction, load: WasmFunction) -> Self {
141        Collector::LoadBalance {
142            map,
143            load: Some(load),
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count() {
        let collector = Collector::count();
        match collector {
            Collector::Count { distinct, map } => {
                assert!(distinct.is_none());
                assert!(map.is_none());
            }
            _ => panic!("Expected Count collector"),
        }
    }

    #[test]
    fn test_count_distinct() {
        let collector = Collector::count_distinct();
        match collector {
            Collector::Count { distinct, map } => {
                assert_eq!(distinct, Some(true));
                assert!(map.is_none());
            }
            _ => panic!("Expected Count collector"),
        }
    }

    #[test]
    fn test_count_with_map() {
        let collector = Collector::count_with_map(WasmFunction::new("get_id"));
        match collector {
            Collector::Count { distinct, map } => {
                assert!(distinct.is_none());
                assert_eq!(map.expect("map should be set").name(), "get_id");
            }
            _ => panic!("Expected Count collector"),
        }
    }

    #[test]
    fn test_sum() {
        let collector = Collector::sum(WasmFunction::new("get_value"));
        match collector {
            Collector::Sum { map } => {
                assert_eq!(map.name(), "get_value");
            }
            _ => panic!("Expected Sum collector"),
        }
    }

    #[test]
    fn test_average() {
        let collector = Collector::average(WasmFunction::new("get_score"));
        match collector {
            Collector::Average { map } => {
                assert_eq!(map.name(), "get_score");
            }
            _ => panic!("Expected Average collector"),
        }
    }

    #[test]
    fn test_min() {
        let collector = Collector::min(
            WasmFunction::new("get_time"),
            WasmFunction::new("compare_time"),
        );
        match collector {
            Collector::Min { map, comparator } => {
                assert_eq!(map.name(), "get_time");
                assert_eq!(comparator.name(), "compare_time");
            }
            _ => panic!("Expected Min collector"),
        }
    }

    #[test]
    fn test_max() {
        let collector = Collector::max(
            WasmFunction::new("get_priority"),
            WasmFunction::new("compare_priority"),
        );
        match collector {
            Collector::Max { map, comparator } => {
                assert_eq!(map.name(), "get_priority");
                assert_eq!(comparator.name(), "compare_priority");
            }
            _ => panic!("Expected Max collector"),
        }
    }

    #[test]
    fn test_to_list() {
        let collector = Collector::to_list();
        match collector {
            Collector::ToList { map } => {
                assert!(map.is_none());
            }
            _ => panic!("Expected ToList collector"),
        }
    }

    #[test]
    fn test_to_list_with_map() {
        let collector = Collector::to_list_with_map(WasmFunction::new("get_name"));
        match collector {
            Collector::ToList { map } => {
                assert_eq!(map.expect("map should be set").name(), "get_name");
            }
            _ => panic!("Expected ToList collector"),
        }
    }

    #[test]
    fn test_to_set() {
        let collector = Collector::to_set();
        match collector {
            Collector::ToSet { map } => {
                assert!(map.is_none());
            }
            _ => panic!("Expected ToSet collector"),
        }
    }

    #[test]
    fn test_to_set_with_map() {
        let collector = Collector::to_set_with_map(WasmFunction::new("get_tag"));
        match collector {
            Collector::ToSet { map } => {
                assert_eq!(map.expect("map should be set").name(), "get_tag");
            }
            _ => panic!("Expected ToSet collector"),
        }
    }

    #[test]
    fn test_compose() {
        let collector = Collector::compose(
            vec![
                Collector::count(),
                Collector::sum(WasmFunction::new("get_value")),
            ],
            WasmFunction::new("combine"),
        );
        match collector {
            Collector::Compose {
                collectors,
                combiner,
            } => {
                assert_eq!(collectors.len(), 2);
                assert!(matches!(collectors[0], Collector::Count { .. }));
                assert!(matches!(collectors[1], Collector::Sum { .. }));
                assert_eq!(combiner.name(), "combine");
            }
            _ => panic!("Expected Compose collector"),
        }
    }

    #[test]
    fn test_conditionally() {
        let collector = Collector::conditionally(WasmFunction::new("is_valid"), Collector::count());
        match collector {
            Collector::Conditionally {
                predicate,
                collector,
            } => {
                assert_eq!(predicate.name(), "is_valid");
                // `matches!` only yields a bool; it must be asserted to check anything.
                assert!(matches!(*collector, Collector::Count { .. }));
            }
            _ => panic!("Expected Conditionally collector"),
        }
    }

    #[test]
    fn test_collect_and_then() {
        let collector =
            Collector::collect_and_then(Collector::count(), WasmFunction::new("to_string"));
        match collector {
            Collector::CollectAndThen { collector, mapper } => {
                // `matches!` only yields a bool; it must be asserted to check anything.
                assert!(matches!(*collector, Collector::Count { .. }));
                assert_eq!(mapper.name(), "to_string");
            }
            _ => panic!("Expected CollectAndThen collector"),
        }
    }

    #[test]
    fn test_load_balance() {
        let collector = Collector::load_balance(WasmFunction::new("get_employee"));
        match collector {
            Collector::LoadBalance { map, load } => {
                assert_eq!(map.name(), "get_employee");
                assert!(load.is_none());
            }
            _ => panic!("Expected LoadBalance collector"),
        }
    }

    #[test]
    fn test_load_balance_with_load() {
        let collector = Collector::load_balance_with_load(
            WasmFunction::new("get_employee"),
            WasmFunction::new("get_load"),
        );
        match collector {
            Collector::LoadBalance { map, load } => {
                assert_eq!(map.name(), "get_employee");
                assert_eq!(load.expect("load should be set").name(), "get_load");
            }
            _ => panic!("Expected LoadBalance collector"),
        }
    }

    #[test]
    fn test_count_json_serialization() {
        let collector = Collector::count();
        let json = serde_json::to_string(&collector).unwrap();
        // Both optional fields are `None`, so only the tag is emitted.
        assert_eq!(json, r#"{"name":"count"}"#);

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_count_distinct_json_serialization() {
        let collector = Collector::count_distinct();
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"count\""));
        assert!(json.contains("\"distinct\":true"));
        assert!(!json.contains("\"map\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_sum_json_serialization() {
        let collector = Collector::sum(WasmFunction::new("get_value"));
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"sum\""));
        assert!(json.contains("\"map\":\"get_value\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_min_json_serialization() {
        let collector = Collector::min(WasmFunction::new("get_time"), WasmFunction::new("cmp"));
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"min\""));
        assert!(json.contains("\"comparator\":\"cmp\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_to_list_and_to_set_tag_names() {
        let list_json = serde_json::to_string(&Collector::to_list()).unwrap();
        assert_eq!(list_json, r#"{"name":"toList"}"#);

        let set_json = serde_json::to_string(&Collector::to_set()).unwrap();
        assert_eq!(set_json, r#"{"name":"toSet"}"#);
    }

    #[test]
    fn test_deserialize_with_omitted_optional_fields() {
        let parsed: Collector = serde_json::from_str(r#"{"name":"count"}"#).unwrap();
        assert_eq!(parsed, Collector::count());

        let parsed: Collector = serde_json::from_str(r#"{"name":"toList"}"#).unwrap();
        assert_eq!(parsed, Collector::to_list());

        let parsed: Collector = serde_json::from_str(r#"{"name":"toSet"}"#).unwrap();
        assert_eq!(parsed, Collector::to_set());
    }

    #[test]
    fn test_compose_json_serialization() {
        let collector = Collector::compose(vec![Collector::count()], WasmFunction::new("wrap"));
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"compose\""));
        assert!(json.contains("\"collectors\""));
        assert!(json.contains("\"combiner\":\"wrap\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_conditionally_json_serialization() {
        let collector = Collector::conditionally(WasmFunction::new("pred"), Collector::count());
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"conditionally\""));
        assert!(json.contains("\"predicate\":\"pred\""));
        assert!(json.contains("\"collector\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_collect_and_then_json_serialization() {
        let collector =
            Collector::collect_and_then(Collector::count(), WasmFunction::new("transform"));
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"collectAndThen\""));
        assert!(json.contains("\"mapper\":\"transform\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_load_balance_json_serialization() {
        let collector = Collector::load_balance(WasmFunction::new("get_item"));
        let json = serde_json::to_string(&collector).unwrap();
        assert!(json.contains("\"name\":\"loadBalance\""));
        // `load` is `None` and must be skipped entirely.
        assert!(!json.contains("\"load\""));

        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_nested_collectors_json() {
        let collector = Collector::compose(
            vec![
                Collector::conditionally(
                    WasmFunction::new("is_valid"),
                    Collector::sum(WasmFunction::new("get_value")),
                ),
                Collector::collect_and_then(Collector::count(), WasmFunction::new("double")),
            ],
            WasmFunction::new("combine_results"),
        );
        let json = serde_json::to_string(&collector).unwrap();
        let parsed: Collector = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, collector);
    }

    #[test]
    fn test_collector_clone() {
        let collector = Collector::sum(WasmFunction::new("get_value"));
        let cloned = collector.clone();
        assert_eq!(collector, cloned);
    }

    #[test]
    fn test_collector_debug() {
        let collector = Collector::count();
        let debug = format!("{:?}", collector);
        assert!(debug.contains("Count"));
    }
}