snapup/
join_with_parent.rs

1use futures::{future::Either, stream, Stream, StreamExt};
2use std::{convert, future, hash::Hash, iter};
3
4use crate::SnapshotWithUpdates;
5
6impl<Key, Value, Snapshot, Updates> SnapshotWithUpdates<Snapshot, Updates>
7where
8    Key: Clone + Hash + Eq,
9    Value: Clone,
10    Snapshot: IntoIterator<Item = (Key, Value)>,
11    Updates: Stream,
12    Updates::Item: IntoIterator<Item = (Key, Option<Value>)>,
13{
14    pub fn join_with_optional_parent<ParentKey, ParentValue, ParentSnapshot, ParentUpdates>(
15        self,
16        other: SnapshotWithUpdates<ParentSnapshot, ParentUpdates>,
17        f: impl FnMut(&Key) -> ParentKey,
18    ) -> SnapshotWithUpdates<
19        impl IntoIterator<Item = ((ParentKey, Key), (Option<ParentValue>, Value))>,
20        impl Stream<
21            Item = impl IntoIterator<Item = ((ParentKey, Key), Option<(Option<ParentValue>, Value)>)>,
22        >,
23    >
24    where
25        ParentKey: Clone + Hash + Eq,
26        Snapshot::Item: Clone,
27        ParentSnapshot: IntoIterator<Item = (ParentKey, ParentValue)>,
28        ParentValue: Clone,
29        ParentUpdates: Stream,
30        ParentUpdates::Item: IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
31    {
32        let (snapshot, updates) = self.into_inner();
33        let (parent_snapshot, parent_updates) = other.into_inner();
34
35        let initial_state = JoinOneToManyState::new(parent_snapshot, snapshot, f);
36
37        let initial_snapshot = initial_state.snapshot.clone().into_iter().flat_map(
38            |(parent_key, (parent, inner_key_to_inner))| {
39                inner_key_to_inner
40                    .into_iter()
41                    .map(move |(key, value)| ((parent_key.clone(), key), (parent.clone(), value)))
42            },
43        );
44
45        let parent_updates = parent_updates
46            .map(Either::Left)
47            .map(Some)
48            .chain(stream::once(future::ready(None)));
49
50        let updates = updates
51            .map(Either::Right)
52            .map(Some)
53            .chain(stream::once(future::ready(None)));
54
55        let merged = tokio_stream::StreamExt::merge(parent_updates, updates);
56        let combined_updates = merged
57            .scan(initial_state, |state, next| {
58                let Some(next) = next else {
59                    return future::ready(None);
60                };
61
62                let result = state.ingest_update_optional_parent(next);
63
64                let result = if result.len() == 0 {
65                    None
66                } else {
67                    Some(result)
68                };
69
70                future::ready(Some(result))
71            })
72            .flat_map(stream::iter);
73
74        SnapshotWithUpdates::new(initial_snapshot, combined_updates)
75    }
76
77    pub fn join_with_parent<ParentKey, ParentValue, ParentSnapshot, ParentUpdates>(
78        self,
79        other: SnapshotWithUpdates<ParentSnapshot, ParentUpdates>,
80        f: impl FnMut(&Key) -> ParentKey,
81    ) -> SnapshotWithUpdates<
82        impl IntoIterator<Item = ((ParentKey, Key), (ParentValue, Value))>,
83        impl Stream<Item = impl IntoIterator<Item = ((ParentKey, Key), Option<(ParentValue, Value)>)>>,
84    >
85    where
86        ParentKey: Clone + Hash + Eq,
87        Snapshot::Item: Clone,
88        ParentSnapshot: IntoIterator<Item = (ParentKey, ParentValue)>,
89        ParentValue: Clone,
90        ParentUpdates: Stream,
91        ParentUpdates::Item: IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
92    {
93        let (snapshot, updates) = self.into_inner();
94        let (parent_snapshot, parent_updates) = other.into_inner();
95
96        let initial_state = JoinOneToManyState::new(parent_snapshot, snapshot, f);
97
98        let initial_snapshot: Vec<_> = initial_state
99            .snapshot
100            .clone()
101            .into_iter()
102            .flat_map(|(parent_key, (parent, inner_key_to_inner))| {
103                parent
104                    .map(move |parent| {
105                        inner_key_to_inner.into_iter().map(move |(key, value)| {
106                            ((parent_key.clone(), key), (parent.clone(), value))
107                        })
108                    })
109                    .into_iter()
110                    .flat_map(convert::identity)
111            })
112            .collect();
113
114        let parent_updates = parent_updates
115            .map(Either::Left)
116            .map(Some)
117            .chain(stream::once(future::ready(None)));
118
119        let updates = updates
120            .map(Either::Right)
121            .map(Some)
122            .chain(stream::once(future::ready(None)));
123
124        let merged = tokio_stream::StreamExt::merge(parent_updates, updates);
125        let combined_updates = merged
126            .scan(initial_state, |state, next| {
127                let Some(next) = next else {
128                    return future::ready(None);
129                };
130
131                let result = state.ingest_update(next);
132
133                let result = if result.len() == 0 {
134                    None
135                } else {
136                    Some(result)
137                };
138
139                future::ready(Some(result))
140            })
141            .flat_map(stream::iter);
142
143        SnapshotWithUpdates::new(initial_snapshot, combined_updates)
144    }
145}
146
147struct JoinOneToManyState<ParentKey, ParentValue, Key, Value, F> {
148    snapshot: im::HashMap<ParentKey, (Option<ParentValue>, im::HashMap<Key, Value>)>,
149    f: F,
150}
151
152impl<ParentKey, ParentValue, ChildKey, ChildValue, F>
153    JoinOneToManyState<ParentKey, ParentValue, ChildKey, ChildValue, F>
154where
155    ParentKey: Clone + Hash + Eq,
156    ParentValue: Clone,
157    ChildKey: Clone + Hash + Eq,
158    ChildValue: Clone,
159    F: FnMut(&ChildKey) -> ParentKey,
160{
161    fn new(
162        parents: impl IntoIterator<Item = (ParentKey, ParentValue)>,
163        children: impl IntoIterator<Item = (ChildKey, ChildValue)>,
164        mut f: F,
165    ) -> Self {
166        use im::HashMap;
167
168        let mut snapshot = parents
169            .into_iter()
170            .map(|(key, value)| (key, (Some(value), HashMap::new())))
171            .collect::<HashMap<_, _>>();
172
173        for (child_key, child_value) in children {
174            let parent_key = f(&child_key);
175            let (_, children_state) = snapshot
176                .entry(parent_key)
177                .or_insert_with(|| (None, HashMap::new()));
178
179            children_state.insert(child_key, child_value);
180        }
181
182        Self { snapshot, f }
183    }
184
185    fn ingest_update<'a>(
186        &mut self,
187        update: Either<
188            impl IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
189            impl IntoIterator<Item = (ChildKey, Option<ChildValue>)>,
190        >,
191    ) -> Vec<((ParentKey, ChildKey), Option<(ParentValue, ChildValue)>)> {
192        use im::HashMap;
193
194        match update {
195            Either::Left(parents) => parents
196                .into_iter()
197                .flat_map(|(parent_key, parent)| {
198                    let (parent_state, children_state) = self
199                        .snapshot
200                        .entry(parent_key.clone())
201                        .or_insert_with(|| (None, HashMap::new()));
202
203                    let changed = match &parent {
204                        Some(left) => {
205                            _ = parent_state.insert(left.to_owned());
206                            true
207                        }
208                        None => parent_state.take().is_some(),
209                    };
210
211                    if changed {
212                        itertools::Either::Left(children_state.clone().into_iter().map(
213                            move |(child_key, child)| {
214                                let result = parent.as_ref().map(|parent| (parent.clone(), child));
215
216                                ((parent_key.clone(), child_key), result)
217                            },
218                        ))
219                    } else {
220                        itertools::Either::Right(iter::empty())
221                    }
222                })
223                .collect(),
224            Either::Right(children) => children
225                .into_iter()
226                .flat_map(|(child_key, child)| {
227                    let parent_key = (self.f)(&child_key);
228
229                    let (parent_state, children_state) = self
230                        .snapshot
231                        .entry(parent_key.clone())
232                        .or_insert_with(|| (None, HashMap::new()));
233
234                    let changed = match &child {
235                        Some(child) => {
236                            children_state.insert(child_key.clone(), child.to_owned());
237                            true
238                        }
239                        None => children_state.remove(&child_key).is_some(),
240                    };
241
242                    return changed.then(|| {
243                        let result = match (parent_state, child) {
244                            (Some(parent), Some(child)) => Some((parent.to_owned(), child)),
245                            _ => None,
246                        };
247
248                        ((parent_key, child_key), result)
249                    });
250                })
251                .collect(),
252        }
253    }
254
255    fn ingest_update_optional_parent<'a>(
256        &mut self,
257        update: Either<
258            impl IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
259            impl IntoIterator<Item = (ChildKey, Option<ChildValue>)>,
260        >,
261    ) -> Vec<(
262        (ParentKey, ChildKey),
263        Option<(Option<ParentValue>, ChildValue)>,
264    )> {
265        use im::HashMap;
266
267        match update {
268            Either::Left(parents) => parents
269                .into_iter()
270                .flat_map(|(key, parent)| {
271                    let (parent_state, children_state) = self
272                        .snapshot
273                        .entry(key.clone())
274                        .or_insert_with(|| (None, HashMap::new()));
275
276                    match &parent {
277                        Some(parent) => {
278                            _ = parent_state.insert(parent.to_owned());
279                        }
280                        None => {
281                            _ = parent_state.take();
282                        }
283                    };
284
285                    children_state
286                        .clone()
287                        .into_iter()
288                        .map(move |(child_key, child)| {
289                            let result = (parent.clone(), child);
290                            ((key.clone(), child_key), Some(result))
291                        })
292                })
293                .collect(),
294            Either::Right(children) => children
295                .into_iter()
296                .map(|(child_key, child)| {
297                    let parent_key = (self.f)(&child_key);
298
299                    let (parent_state, children_state) = self
300                        .snapshot
301                        .entry(parent_key.clone())
302                        .or_insert_with(|| (None, HashMap::new()));
303
304                    match &child {
305                        Some(right) => {
306                            _ = children_state.insert(child_key.clone(), right.to_owned());
307                        }
308                        None => {
309                            _ = children_state.remove(&child_key);
310                        }
311                    };
312
313                    let result = child.map(|child| (parent_state.clone(), child));
314
315                    ((parent_key, child_key), result)
316                })
317                .collect(),
318        }
319    }
320}
321
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use futures::{channel::mpsc, SinkExt, StreamExt};
    use tokio_test::{assert_pending, assert_ready};

    use super::*;

    /// Drives `join_with_parent` through a scripted sequence of parent and
    /// child update batches and checks the emitted joined changes.
    ///
    /// Each test case sends exactly one batch on one side and then polls the
    /// joined stream once: `expected_result == None` means the batch must not
    /// produce any emission (the stream stays pending).
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_join_with_parent() {
        type ParentKey = i32;
        type ChildKey = (ParentKey, i64);

        struct ParentEntity {
            key: ParentKey,
            value: &'static str,
        }

        impl ParentEntity {
            fn into_inner(self) -> (ParentKey, &'static str) {
                (self.key, self.value)
            }
        }

        struct ChildEntity {
            key: ChildKey,
            value: &'static str,
        }

        impl ChildEntity {
            fn into_inner(self) -> (ChildKey, &'static str) {
                (self.key, self.value)
            }
        }

        let parent_snapshot = vec![
            ParentEntity {
                key: 1,
                value: "PARENT_1",
            },
            ParentEntity {
                key: 2,
                value: "PARENT_2",
            },
        ]
        .into_iter()
        .map(ParentEntity::into_inner);

        // CHILD_1031 references parent 3, which is not in the parent snapshot.
        let children_snapshot = vec![
            ChildEntity {
                key: (1, 1011),
                value: "CHILD_1011_PARENT_1",
            },
            ChildEntity {
                key: (2, 1021),
                value: "CHILD_1021_PARENT_2",
            },
            ChildEntity {
                key: (2, 1022),
                value: "CHILD_1022_PARENT_2",
            },
            ChildEntity {
                key: (3, 1031),
                value: "CHILD_1031_PARENT_3",
            },
        ]
        .into_iter()
        .map(ChildEntity::into_inner);

        // The orphaned CHILD_1031 must not appear in the inner-join snapshot.
        let expected_initial_snapshot = vec![
            ((1, (1, 1011_i64)), ("PARENT_1", "CHILD_1011_PARENT_1")),
            ((2, (2, 1021_i64)), ("PARENT_2", "CHILD_1021_PARENT_2")),
            ((2, (2, 1022_i64)), ("PARENT_2", "CHILD_1022_PARENT_2")),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        let (mut parent_tx, parent_rx) = mpsc::unbounded();
        let parent_snapshot_with_updates = SnapshotWithUpdates::new(parent_snapshot, parent_rx);

        let (mut child_tx, child_rx) = mpsc::unbounded();
        let children_snapshot_with_updates = SnapshotWithUpdates::new(children_snapshot, child_rx);

        let result = children_snapshot_with_updates
            .join_with_parent(parent_snapshot_with_updates, |child_key| child_key.0);

        let (snapshot, updates) = result.into_inner();
        let mut updates = updates
            .map(|updates| updates.into_iter().collect::<Vec<_>>())
            .boxed();
        let snapshot = snapshot.into_iter().collect::<HashMap<_, _>>();

        assert_eq!(expected_initial_snapshot, snapshot);

        struct UpdateTestCase {
            update: Either<
                Vec<(ParentKey, Option<&'static str>)>,
                Vec<(ChildKey, Option<&'static str>)>,
            >,
            expected_result:
                Option<Vec<((ParentKey, ChildKey), Option<(&'static str, &'static str)>)>>,
        }

        let test_cases = vec![
            UpdateTestCase {
                update: Either::Left(vec![(1, Some("PARENT_1_UPDATE_1"))]),
                expected_result: Some(vec![(
                    (1, (1, 1011)),
                    Some(("PARENT_1_UPDATE_1", "CHILD_1011_PARENT_1")),
                )]),
            },
            UpdateTestCase {
                update: Either::Left(vec![(2, Some("PARENT_2_UPDATE_1"))]),
                expected_result: Some(vec![
                    (
                        (2, (2, 1021)),
                        Some(("PARENT_2_UPDATE_1", "CHILD_1021_PARENT_2")),
                    ),
                    (
                        (2, (2, 1022)),
                        Some(("PARENT_2_UPDATE_1", "CHILD_1022_PARENT_2")),
                    ),
                ]),
            },
            UpdateTestCase {
                update: Either::Left(vec![(3, Some("PARENT_3"))]),
                expected_result: Some(vec![(
                    (3, (3, 1031)),
                    Some(("PARENT_3", "CHILD_1031_PARENT_3")),
                )]),
            },
            UpdateTestCase {
                update: Either::Left(vec![(2, None)]),
                expected_result: Some(vec![((2, (2, 1021)), None), ((2, (2, 1022)), None)]),
            },
            // Deleted twice... it's already deleted so nothing should be emitted.
            UpdateTestCase {
                update: Either::Left(vec![(2, None)]),
                expected_result: None,
            },
            UpdateTestCase {
                update: Either::Right(vec![((1, 1012), Some("CHILD_1012_PARENT_1"))]),
                expected_result: Some(vec![(
                    (1, (1, 1012)),
                    Some(("PARENT_1_UPDATE_1", "CHILD_1012_PARENT_1")),
                )]),
            },
            UpdateTestCase {
                update: Either::Right(vec![((1, 1011), None)]),
                expected_result: Some(vec![((1, (1, 1011)), None)]),
            },
            // Deleted twice... it's already deleted so nothing should be emitted.
            UpdateTestCase {
                update: Either::Right(vec![((1, 1011), None)]),
                expected_result: None,
            },
            // PARENT_4 has no children... yet
            UpdateTestCase {
                update: Either::Left(vec![(4, Some("PARENT_4"))]),
                expected_result: None,
            },
            // Now PARENT_4 has children, and we also modified CHILD_1031
            UpdateTestCase {
                update: Either::Right(vec![
                    ((4, 1041), Some("CHILD_1041_PARENT_4")),
                    ((4, 1042), Some("CHILD_1042_PARENT_4")),
                    ((3, 1031), Some("CHILD_1031_PARENT_3_UPDATE_1")),
                ]),
                expected_result: Some(vec![
                    ((4, (4, 1041)), Some(("PARENT_4", "CHILD_1041_PARENT_4"))),
                    ((4, (4, 1042)), Some(("PARENT_4", "CHILD_1042_PARENT_4"))),
                    (
                        (3, (3, 1031)),
                        Some(("PARENT_3", "CHILD_1031_PARENT_3_UPDATE_1")),
                    ),
                ]),
            },
        ];

        // The stream is polled by hand so that "nothing emitted" can be
        // asserted as `Pending` instead of awaiting forever.
        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        for (index, tc) in test_cases.into_iter().enumerate() {
            let UpdateTestCase {
                update,
                expected_result,
            } = tc;

            match update {
                Either::Left(parents) => parent_tx.send(parents).await.unwrap(),
                Either::Right(children) => child_tx.send(children).await.unwrap(),
            }

            let Some(mut expected_result) = expected_result else {
                assert_pending!(updates.poll_next_unpin(&mut cx));
                continue;
            };

            let update = assert_ready!(updates.poll_next_unpin(&mut cx));

            let Some(update) = update else {
                panic!("Expected Some for update");
            };

            // Emission order within a batch is not part of the contract, so
            // both sides are compared after sorting by joined key.
            expected_result.sort_by_key(|r| r.0);

            let mut actual_result = update.into_iter().collect::<Vec<_>>();
            actual_result.sort_by_key(|r| r.0);

            assert_eq!(actual_result, expected_result, "TestCase {index} failed");
        }
    }
}