1use crate::constraints::{Collector, Joiner, WasmFunction};
2use serde::{Deserialize, Serialize};
3
4#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
5#[serde(tag = "kind")]
6pub enum StreamComponent {
7 #[serde(rename = "forEach")]
8 ForEach {
9 #[serde(rename = "className")]
10 class_name: String,
11 },
12 #[serde(rename = "forEachIncludingUnassigned")]
13 ForEachIncludingUnassigned {
14 #[serde(rename = "className")]
15 class_name: String,
16 },
17 #[serde(rename = "forEachUniquePair")]
18 ForEachUniquePair {
19 #[serde(rename = "className")]
20 class_name: String,
21 #[serde(default, skip_serializing_if = "Vec::is_empty")]
22 joiners: Vec<Joiner>,
23 },
24 #[serde(rename = "filter")]
25 Filter { predicate: WasmFunction },
26 #[serde(rename = "join")]
27 Join {
28 #[serde(rename = "className")]
29 class_name: String,
30 #[serde(default, skip_serializing_if = "Vec::is_empty")]
31 joiners: Vec<Joiner>,
32 },
33 #[serde(rename = "ifExists")]
34 IfExists {
35 #[serde(rename = "className")]
36 class_name: String,
37 #[serde(default, skip_serializing_if = "Vec::is_empty")]
38 joiners: Vec<Joiner>,
39 },
40 #[serde(rename = "ifNotExists")]
41 IfNotExists {
42 #[serde(rename = "className")]
43 class_name: String,
44 #[serde(default, skip_serializing_if = "Vec::is_empty")]
45 joiners: Vec<Joiner>,
46 },
47 #[serde(rename = "groupBy")]
48 GroupBy {
49 #[serde(default, skip_serializing_if = "Vec::is_empty")]
50 keys: Vec<WasmFunction>,
51 #[serde(default, skip_serializing_if = "Vec::is_empty")]
52 aggregators: Vec<Collector>,
53 },
54 #[serde(rename = "map")]
55 Map {
56 #[serde(rename = "mapper")]
57 mappers: Vec<WasmFunction>,
58 },
59 #[serde(rename = "flattenLast")]
60 FlattenLast {
61 #[serde(skip_serializing_if = "Option::is_none")]
62 map: Option<WasmFunction>,
63 },
64 #[serde(rename = "expand")]
65 Expand {
66 #[serde(rename = "mapper")]
67 mappers: Vec<WasmFunction>,
68 },
69 #[serde(rename = "complement")]
70 Complement {
71 #[serde(rename = "className")]
72 class_name: String,
73 },
74 #[serde(rename = "penalize")]
75 Penalize {
76 weight: String,
77 #[serde(rename = "scaleBy", skip_serializing_if = "Option::is_none")]
78 scale_by: Option<WasmFunction>,
79 },
80 #[serde(rename = "reward")]
81 Reward {
82 weight: String,
83 #[serde(rename = "scaleBy", skip_serializing_if = "Option::is_none")]
84 scale_by: Option<WasmFunction>,
85 },
86}
87
88impl StreamComponent {
89 pub fn for_each(class_name: impl Into<String>) -> Self {
90 StreamComponent::ForEach {
91 class_name: class_name.into(),
92 }
93 }
94
95 pub fn for_each_including_unassigned(class_name: impl Into<String>) -> Self {
96 StreamComponent::ForEachIncludingUnassigned {
97 class_name: class_name.into(),
98 }
99 }
100
101 pub fn for_each_unique_pair(class_name: impl Into<String>) -> Self {
102 StreamComponent::ForEachUniquePair {
103 class_name: class_name.into(),
104 joiners: Vec::new(),
105 }
106 }
107
108 pub fn for_each_unique_pair_with_joiners(
109 class_name: impl Into<String>,
110 joiners: Vec<Joiner>,
111 ) -> Self {
112 StreamComponent::ForEachUniquePair {
113 class_name: class_name.into(),
114 joiners,
115 }
116 }
117
118 pub fn filter(predicate: WasmFunction) -> Self {
119 StreamComponent::Filter { predicate }
120 }
121
122 pub fn join(class_name: impl Into<String>) -> Self {
123 StreamComponent::Join {
124 class_name: class_name.into(),
125 joiners: Vec::new(),
126 }
127 }
128
129 pub fn join_with_joiners(class_name: impl Into<String>, joiners: Vec<Joiner>) -> Self {
130 StreamComponent::Join {
131 class_name: class_name.into(),
132 joiners,
133 }
134 }
135
136 pub fn if_exists(class_name: impl Into<String>) -> Self {
137 StreamComponent::IfExists {
138 class_name: class_name.into(),
139 joiners: Vec::new(),
140 }
141 }
142
143 pub fn if_exists_with_joiners(class_name: impl Into<String>, joiners: Vec<Joiner>) -> Self {
144 StreamComponent::IfExists {
145 class_name: class_name.into(),
146 joiners,
147 }
148 }
149
150 pub fn if_not_exists(class_name: impl Into<String>) -> Self {
151 StreamComponent::IfNotExists {
152 class_name: class_name.into(),
153 joiners: Vec::new(),
154 }
155 }
156
157 pub fn if_not_exists_with_joiners(class_name: impl Into<String>, joiners: Vec<Joiner>) -> Self {
158 StreamComponent::IfNotExists {
159 class_name: class_name.into(),
160 joiners,
161 }
162 }
163
164 pub fn group_by(keys: Vec<WasmFunction>, aggregators: Vec<Collector>) -> Self {
165 StreamComponent::GroupBy { keys, aggregators }
166 }
167
168 pub fn group_by_key(key: WasmFunction) -> Self {
169 StreamComponent::GroupBy {
170 keys: vec![key],
171 aggregators: Vec::new(),
172 }
173 }
174
175 pub fn group_by_collector(aggregator: Collector) -> Self {
176 StreamComponent::GroupBy {
177 keys: Vec::new(),
178 aggregators: vec![aggregator],
179 }
180 }
181
182 pub fn map(mappers: Vec<WasmFunction>) -> Self {
183 StreamComponent::Map { mappers }
184 }
185
186 pub fn map_single(mapper: WasmFunction) -> Self {
187 StreamComponent::Map {
188 mappers: vec![mapper],
189 }
190 }
191
192 pub fn flatten_last() -> Self {
193 StreamComponent::FlattenLast { map: None }
194 }
195
196 pub fn flatten_last_with_map(map: WasmFunction) -> Self {
197 StreamComponent::FlattenLast { map: Some(map) }
198 }
199
200 pub fn expand(mappers: Vec<WasmFunction>) -> Self {
201 StreamComponent::Expand { mappers }
202 }
203
204 pub fn complement(class_name: impl Into<String>) -> Self {
205 StreamComponent::Complement {
206 class_name: class_name.into(),
207 }
208 }
209
210 pub fn penalize(weight: impl Into<String>) -> Self {
211 StreamComponent::Penalize {
212 weight: weight.into(),
213 scale_by: None,
214 }
215 }
216
217 pub fn penalize_with_weigher(weight: impl Into<String>, scale_by: WasmFunction) -> Self {
218 StreamComponent::Penalize {
219 weight: weight.into(),
220 scale_by: Some(scale_by),
221 }
222 }
223
224 pub fn reward(weight: impl Into<String>) -> Self {
225 StreamComponent::Reward {
226 weight: weight.into(),
227 scale_by: None,
228 }
229 }
230
231 pub fn reward_with_weigher(weight: impl Into<String>, scale_by: WasmFunction) -> Self {
232 StreamComponent::Reward {
233 weight: weight.into(),
234 scale_by: Some(scale_by),
235 }
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_for_each() {
        let component = StreamComponent::for_each("Lesson");
        match component {
            StreamComponent::ForEach { class_name } => {
                assert_eq!(class_name, "Lesson");
            }
            _ => panic!("Expected ForEach"),
        }
    }

    #[test]
    fn test_for_each_including_unassigned() {
        let component = StreamComponent::for_each_including_unassigned("Lesson");
        match component {
            StreamComponent::ForEachIncludingUnassigned { class_name } => {
                assert_eq!(class_name, "Lesson");
            }
            _ => panic!("Expected ForEachIncludingUnassigned"),
        }
    }

    #[test]
    fn test_for_each_unique_pair() {
        let component = StreamComponent::for_each_unique_pair("Lesson");
        match component {
            StreamComponent::ForEachUniquePair {
                class_name,
                joiners,
            } => {
                assert_eq!(class_name, "Lesson");
                assert!(joiners.is_empty());
            }
            _ => panic!("Expected ForEachUniquePair"),
        }
    }

    #[test]
    fn test_for_each_unique_pair_with_joiners() {
        let component = StreamComponent::for_each_unique_pair_with_joiners(
            "Lesson",
            vec![Joiner::equal(WasmFunction::new("get_timeslot"))],
        );
        match component {
            StreamComponent::ForEachUniquePair { joiners, .. } => {
                assert_eq!(joiners.len(), 1);
            }
            _ => panic!("Expected ForEachUniquePair"),
        }
    }

    #[test]
    fn test_filter() {
        let component = StreamComponent::filter(WasmFunction::new("is_valid"));
        match component {
            StreamComponent::Filter { predicate } => {
                assert_eq!(predicate.name(), "is_valid");
            }
            _ => panic!("Expected Filter"),
        }
    }

    #[test]
    fn test_join() {
        let component = StreamComponent::join("Room");
        match component {
            StreamComponent::Join {
                class_name,
                joiners,
            } => {
                assert_eq!(class_name, "Room");
                assert!(joiners.is_empty());
            }
            _ => panic!("Expected Join"),
        }
    }

    #[test]
    fn test_join_with_joiners() {
        let component = StreamComponent::join_with_joiners(
            "Room",
            vec![Joiner::equal(WasmFunction::new("get_room"))],
        );
        match component {
            StreamComponent::Join { joiners, .. } => {
                assert_eq!(joiners.len(), 1);
            }
            _ => panic!("Expected Join"),
        }
    }

    #[test]
    fn test_if_exists() {
        let component = StreamComponent::if_exists("Conflict");
        match component {
            StreamComponent::IfExists {
                class_name,
                joiners,
            } => {
                assert_eq!(class_name, "Conflict");
                assert!(joiners.is_empty());
            }
            _ => panic!("Expected IfExists"),
        }
    }

    #[test]
    fn test_if_exists_with_joiners() {
        let component = StreamComponent::if_exists_with_joiners(
            "Conflict",
            vec![Joiner::equal(WasmFunction::new("get_room"))],
        );
        match component {
            StreamComponent::IfExists {
                class_name,
                joiners,
            } => {
                assert_eq!(class_name, "Conflict");
                assert_eq!(joiners.len(), 1);
            }
            _ => panic!("Expected IfExists"),
        }
    }

    #[test]
    fn test_if_not_exists() {
        let component = StreamComponent::if_not_exists("Conflict");
        match component {
            StreamComponent::IfNotExists {
                class_name,
                joiners,
            } => {
                assert_eq!(class_name, "Conflict");
                assert!(joiners.is_empty());
            }
            _ => panic!("Expected IfNotExists"),
        }
    }

    #[test]
    fn test_if_not_exists_with_joiners() {
        let component = StreamComponent::if_not_exists_with_joiners(
            "Conflict",
            vec![Joiner::equal(WasmFunction::new("get_room"))],
        );
        match component {
            StreamComponent::IfNotExists {
                class_name,
                joiners,
            } => {
                assert_eq!(class_name, "Conflict");
                assert_eq!(joiners.len(), 1);
            }
            _ => panic!("Expected IfNotExists"),
        }
    }

    #[test]
    fn test_group_by() {
        let component = StreamComponent::group_by(
            vec![WasmFunction::new("get_room")],
            vec![Collector::count()],
        );
        match component {
            StreamComponent::GroupBy { keys, aggregators } => {
                assert_eq!(keys.len(), 1);
                assert_eq!(aggregators.len(), 1);
            }
            _ => panic!("Expected GroupBy"),
        }
    }

    #[test]
    fn test_group_by_key() {
        let component = StreamComponent::group_by_key(WasmFunction::new("get_room"));
        match component {
            StreamComponent::GroupBy { keys, aggregators } => {
                assert_eq!(keys.len(), 1);
                assert!(aggregators.is_empty());
            }
            _ => panic!("Expected GroupBy"),
        }
    }

    #[test]
    fn test_group_by_collector() {
        let component = StreamComponent::group_by_collector(Collector::count());
        match component {
            StreamComponent::GroupBy { keys, aggregators } => {
                assert!(keys.is_empty());
                assert_eq!(aggregators.len(), 1);
            }
            _ => panic!("Expected GroupBy"),
        }
    }

    #[test]
    fn test_map() {
        let component =
            StreamComponent::map(vec![WasmFunction::new("get_a"), WasmFunction::new("get_b")]);
        match component {
            StreamComponent::Map { mappers } => {
                assert_eq!(mappers.len(), 2);
            }
            _ => panic!("Expected Map"),
        }
    }

    #[test]
    fn test_map_single() {
        let component = StreamComponent::map_single(WasmFunction::new("get_value"));
        match component {
            StreamComponent::Map { mappers } => {
                assert_eq!(mappers.len(), 1);
            }
            _ => panic!("Expected Map"),
        }
    }

    #[test]
    fn test_flatten_last() {
        let component = StreamComponent::flatten_last();
        match component {
            StreamComponent::FlattenLast { map } => {
                assert!(map.is_none());
            }
            _ => panic!("Expected FlattenLast"),
        }
    }

    #[test]
    fn test_flatten_last_with_map() {
        let component = StreamComponent::flatten_last_with_map(WasmFunction::new("get_items"));
        match component {
            StreamComponent::FlattenLast { map } => {
                assert!(map.is_some());
            }
            _ => panic!("Expected FlattenLast"),
        }
    }

    #[test]
    fn test_expand() {
        let component = StreamComponent::expand(vec![WasmFunction::new("get_extra")]);
        match component {
            StreamComponent::Expand { mappers } => {
                assert_eq!(mappers.len(), 1);
            }
            _ => panic!("Expected Expand"),
        }
    }

    #[test]
    fn test_complement() {
        let component = StreamComponent::complement("Timeslot");
        match component {
            StreamComponent::Complement { class_name } => {
                assert_eq!(class_name, "Timeslot");
            }
            _ => panic!("Expected Complement"),
        }
    }

    #[test]
    fn test_penalize() {
        let component = StreamComponent::penalize("1hard");
        match component {
            StreamComponent::Penalize { weight, scale_by } => {
                assert_eq!(weight, "1hard");
                assert!(scale_by.is_none());
            }
            _ => panic!("Expected Penalize"),
        }
    }

    #[test]
    fn test_penalize_with_weigher() {
        let component =
            StreamComponent::penalize_with_weigher("1hard", WasmFunction::new("get_weight"));
        match component {
            StreamComponent::Penalize { weight, scale_by } => {
                assert_eq!(weight, "1hard");
                assert!(scale_by.is_some());
            }
            _ => panic!("Expected Penalize"),
        }
    }

    #[test]
    fn test_reward() {
        let component = StreamComponent::reward("1soft");
        match component {
            StreamComponent::Reward { weight, scale_by } => {
                assert_eq!(weight, "1soft");
                assert!(scale_by.is_none());
            }
            _ => panic!("Expected Reward"),
        }
    }

    #[test]
    fn test_reward_with_weigher() {
        let component =
            StreamComponent::reward_with_weigher("1soft", WasmFunction::new("get_bonus"));
        match component {
            StreamComponent::Reward { weight, scale_by } => {
                assert_eq!(weight, "1soft");
                assert!(scale_by.is_some());
            }
            _ => panic!("Expected Reward"),
        }
    }

    #[test]
    fn test_for_each_json_serialization() {
        let component = StreamComponent::for_each("Lesson");
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"kind\":\"forEach\""));
        assert!(json.contains("\"className\":\"Lesson\""));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_filter_json_serialization() {
        let component = StreamComponent::filter(WasmFunction::new("is_valid"));
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"kind\":\"filter\""));
        assert!(json.contains("\"predicate\":\"is_valid\""));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_join_json_serialization() {
        let component = StreamComponent::join_with_joiners(
            "Room",
            vec![Joiner::equal(WasmFunction::new("get_room"))],
        );
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"kind\":\"join\""));
        assert!(json.contains("\"className\":\"Room\""));
        assert!(json.contains("\"joiners\""));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_group_by_json_serialization() {
        let component = StreamComponent::group_by(
            vec![WasmFunction::new("get_room")],
            vec![Collector::count()],
        );
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"kind\":\"groupBy\""));
        assert!(json.contains("\"keys\""));
        assert!(json.contains("\"aggregators\""));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_penalize_json_serialization() {
        let component = StreamComponent::penalize("1hard");
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"kind\":\"penalize\""));
        assert!(json.contains("\"weight\":\"1hard\""));
        // `scale_by` is None, so it must be skipped entirely.
        assert!(!json.contains("scaleBy"));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_penalize_with_weigher_json_uses_scale_by_key() {
        let component =
            StreamComponent::penalize_with_weigher("1hard", WasmFunction::new("get_weight"));
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"scaleBy\":\"get_weight\""));
        assert!(!json.contains("scale_by"));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_map_json_uses_singular_mapper_key() {
        // The Vec of mappers is written under the singular key "mapper".
        let component = StreamComponent::map_single(WasmFunction::new("get_value"));
        let json = serde_json::to_string(&component).unwrap();
        assert!(json.contains("\"kind\":\"map\""));
        assert!(json.contains("\"mapper\":[\"get_value\"]"));
        assert!(!json.contains("\"mappers\""));

        let parsed: StreamComponent = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, component);
    }

    #[test]
    fn test_empty_collections_and_none_are_omitted() {
        assert_eq!(
            serde_json::to_string(&StreamComponent::join("Room")).unwrap(),
            r#"{"kind":"join","className":"Room"}"#
        );
        assert_eq!(
            serde_json::to_string(&StreamComponent::group_by(Vec::new(), Vec::new())).unwrap(),
            r#"{"kind":"groupBy"}"#
        );
        assert_eq!(
            serde_json::to_string(&StreamComponent::flatten_last()).unwrap(),
            r#"{"kind":"flattenLast"}"#
        );
    }

    #[test]
    fn test_missing_optional_fields_deserialize_to_defaults() {
        let join: StreamComponent =
            serde_json::from_str(r#"{"kind":"join","className":"Room"}"#).unwrap();
        assert_eq!(join, StreamComponent::join("Room"));

        let pair: StreamComponent =
            serde_json::from_str(r#"{"kind":"forEachUniquePair","className":"Lesson"}"#).unwrap();
        assert_eq!(pair, StreamComponent::for_each_unique_pair("Lesson"));

        let group: StreamComponent = serde_json::from_str(r#"{"kind":"groupBy"}"#).unwrap();
        assert_eq!(group, StreamComponent::group_by(Vec::new(), Vec::new()));

        let flatten: StreamComponent = serde_json::from_str(r#"{"kind":"flattenLast"}"#).unwrap();
        assert_eq!(flatten, StreamComponent::flatten_last());

        let penalize: StreamComponent =
            serde_json::from_str(r#"{"kind":"penalize","weight":"1hard"}"#).unwrap();
        assert_eq!(penalize, StreamComponent::penalize("1hard"));
    }

    #[test]
    fn test_every_variant_kind_tag_round_trips() {
        let cases = vec![
            (StreamComponent::for_each("A"), "forEach"),
            (
                StreamComponent::for_each_including_unassigned("A"),
                "forEachIncludingUnassigned",
            ),
            (StreamComponent::for_each_unique_pair("A"), "forEachUniquePair"),
            (StreamComponent::filter(WasmFunction::new("f")), "filter"),
            (StreamComponent::join("A"), "join"),
            (StreamComponent::if_exists("A"), "ifExists"),
            (StreamComponent::if_not_exists("A"), "ifNotExists"),
            (StreamComponent::group_by_key(WasmFunction::new("f")), "groupBy"),
            (StreamComponent::map_single(WasmFunction::new("f")), "map"),
            (StreamComponent::flatten_last(), "flattenLast"),
            (StreamComponent::expand(vec![WasmFunction::new("f")]), "expand"),
            (StreamComponent::complement("A"), "complement"),
            (StreamComponent::penalize("1hard"), "penalize"),
            (StreamComponent::reward("1soft"), "reward"),
        ];
        for (component, kind) in cases {
            let value = serde_json::to_value(&component).unwrap();
            assert_eq!(value["kind"], kind, "wrong tag for {:?}", component);
            let parsed: StreamComponent = serde_json::from_value(value).unwrap();
            assert_eq!(parsed, component);
        }
    }

    #[test]
    fn test_unknown_kind_is_rejected() {
        let result = serde_json::from_str::<StreamComponent>(r#"{"kind":"bogus"}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_component_clone() {
        let component = StreamComponent::for_each("Lesson");
        let cloned = component.clone();
        assert_eq!(component, cloned);
    }

    #[test]
    fn test_component_debug() {
        let component = StreamComponent::for_each("Lesson");
        let debug = format!("{:?}", component);
        assert!(debug.contains("ForEach"));
    }
}