1use futures::{future::Either, stream, Stream, StreamExt};
2use std::{convert, future, hash::Hash, iter};
3
4use crate::SnapshotWithUpdates;
5
6impl<Key, Value, Snapshot, Updates> SnapshotWithUpdates<Snapshot, Updates>
7where
8 Key: Clone + Hash + Eq,
9 Value: Clone,
10 Snapshot: IntoIterator<Item = (Key, Value)>,
11 Updates: Stream,
12 Updates::Item: IntoIterator<Item = (Key, Option<Value>)>,
13{
14 pub fn join_with_optional_parent<ParentKey, ParentValue, ParentSnapshot, ParentUpdates>(
15 self,
16 other: SnapshotWithUpdates<ParentSnapshot, ParentUpdates>,
17 f: impl FnMut(&Key) -> ParentKey,
18 ) -> SnapshotWithUpdates<
19 impl IntoIterator<Item = ((ParentKey, Key), (Option<ParentValue>, Value))>,
20 impl Stream<
21 Item = impl IntoIterator<Item = ((ParentKey, Key), Option<(Option<ParentValue>, Value)>)>,
22 >,
23 >
24 where
25 ParentKey: Clone + Hash + Eq,
26 Snapshot::Item: Clone,
27 ParentSnapshot: IntoIterator<Item = (ParentKey, ParentValue)>,
28 ParentValue: Clone,
29 ParentUpdates: Stream,
30 ParentUpdates::Item: IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
31 {
32 let (snapshot, updates) = self.into_inner();
33 let (parent_snapshot, parent_updates) = other.into_inner();
34
35 let initial_state = JoinOneToManyState::new(parent_snapshot, snapshot, f);
36
37 let initial_snapshot = initial_state.snapshot.clone().into_iter().flat_map(
38 |(parent_key, (parent, inner_key_to_inner))| {
39 inner_key_to_inner
40 .into_iter()
41 .map(move |(key, value)| ((parent_key.clone(), key), (parent.clone(), value)))
42 },
43 );
44
45 let parent_updates = parent_updates
46 .map(Either::Left)
47 .map(Some)
48 .chain(stream::once(future::ready(None)));
49
50 let updates = updates
51 .map(Either::Right)
52 .map(Some)
53 .chain(stream::once(future::ready(None)));
54
55 let merged = tokio_stream::StreamExt::merge(parent_updates, updates);
56 let combined_updates = merged
57 .scan(initial_state, |state, next| {
58 let Some(next) = next else {
59 return future::ready(None);
60 };
61
62 let result = state.ingest_update_optional_parent(next);
63
64 let result = if result.len() == 0 {
65 None
66 } else {
67 Some(result)
68 };
69
70 future::ready(Some(result))
71 })
72 .flat_map(stream::iter);
73
74 SnapshotWithUpdates::new(initial_snapshot, combined_updates)
75 }
76
77 pub fn join_with_parent<ParentKey, ParentValue, ParentSnapshot, ParentUpdates>(
78 self,
79 other: SnapshotWithUpdates<ParentSnapshot, ParentUpdates>,
80 f: impl FnMut(&Key) -> ParentKey,
81 ) -> SnapshotWithUpdates<
82 impl IntoIterator<Item = ((ParentKey, Key), (ParentValue, Value))>,
83 impl Stream<Item = impl IntoIterator<Item = ((ParentKey, Key), Option<(ParentValue, Value)>)>>,
84 >
85 where
86 ParentKey: Clone + Hash + Eq,
87 Snapshot::Item: Clone,
88 ParentSnapshot: IntoIterator<Item = (ParentKey, ParentValue)>,
89 ParentValue: Clone,
90 ParentUpdates: Stream,
91 ParentUpdates::Item: IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
92 {
93 let (snapshot, updates) = self.into_inner();
94 let (parent_snapshot, parent_updates) = other.into_inner();
95
96 let initial_state = JoinOneToManyState::new(parent_snapshot, snapshot, f);
97
98 let initial_snapshot: Vec<_> = initial_state
99 .snapshot
100 .clone()
101 .into_iter()
102 .flat_map(|(parent_key, (parent, inner_key_to_inner))| {
103 parent
104 .map(move |parent| {
105 inner_key_to_inner.into_iter().map(move |(key, value)| {
106 ((parent_key.clone(), key), (parent.clone(), value))
107 })
108 })
109 .into_iter()
110 .flat_map(convert::identity)
111 })
112 .collect();
113
114 let parent_updates = parent_updates
115 .map(Either::Left)
116 .map(Some)
117 .chain(stream::once(future::ready(None)));
118
119 let updates = updates
120 .map(Either::Right)
121 .map(Some)
122 .chain(stream::once(future::ready(None)));
123
124 let merged = tokio_stream::StreamExt::merge(parent_updates, updates);
125 let combined_updates = merged
126 .scan(initial_state, |state, next| {
127 let Some(next) = next else {
128 return future::ready(None);
129 };
130
131 let result = state.ingest_update(next);
132
133 let result = if result.len() == 0 {
134 None
135 } else {
136 Some(result)
137 };
138
139 future::ready(Some(result))
140 })
141 .flat_map(stream::iter);
142
143 SnapshotWithUpdates::new(initial_snapshot, combined_updates)
144 }
145}
146
/// Incremental state for joining many children onto one parent each.
struct JoinOneToManyState<ParentKey, ParentValue, Key, Value, F> {
    /// For every parent key: the parent's current value (`None` while no parent with that key is
    /// present, e.g. when children were seen first) and the children whose key `f` maps to it.
    snapshot: im::HashMap<ParentKey, (Option<ParentValue>, im::HashMap<Key, Value>)>,
    /// Maps a child key to the key of its parent.
    f: F,
}
151
152impl<ParentKey, ParentValue, ChildKey, ChildValue, F>
153 JoinOneToManyState<ParentKey, ParentValue, ChildKey, ChildValue, F>
154where
155 ParentKey: Clone + Hash + Eq,
156 ParentValue: Clone,
157 ChildKey: Clone + Hash + Eq,
158 ChildValue: Clone,
159 F: FnMut(&ChildKey) -> ParentKey,
160{
161 fn new(
162 parents: impl IntoIterator<Item = (ParentKey, ParentValue)>,
163 children: impl IntoIterator<Item = (ChildKey, ChildValue)>,
164 mut f: F,
165 ) -> Self {
166 use im::HashMap;
167
168 let mut snapshot = parents
169 .into_iter()
170 .map(|(key, value)| (key, (Some(value), HashMap::new())))
171 .collect::<HashMap<_, _>>();
172
173 for (child_key, child_value) in children {
174 let parent_key = f(&child_key);
175 let (_, children_state) = snapshot
176 .entry(parent_key)
177 .or_insert_with(|| (None, HashMap::new()));
178
179 children_state.insert(child_key, child_value);
180 }
181
182 Self { snapshot, f }
183 }
184
185 fn ingest_update<'a>(
186 &mut self,
187 update: Either<
188 impl IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
189 impl IntoIterator<Item = (ChildKey, Option<ChildValue>)>,
190 >,
191 ) -> Vec<((ParentKey, ChildKey), Option<(ParentValue, ChildValue)>)> {
192 use im::HashMap;
193
194 match update {
195 Either::Left(parents) => parents
196 .into_iter()
197 .flat_map(|(parent_key, parent)| {
198 let (parent_state, children_state) = self
199 .snapshot
200 .entry(parent_key.clone())
201 .or_insert_with(|| (None, HashMap::new()));
202
203 let changed = match &parent {
204 Some(left) => {
205 _ = parent_state.insert(left.to_owned());
206 true
207 }
208 None => parent_state.take().is_some(),
209 };
210
211 if changed {
212 itertools::Either::Left(children_state.clone().into_iter().map(
213 move |(child_key, child)| {
214 let result = parent.as_ref().map(|parent| (parent.clone(), child));
215
216 ((parent_key.clone(), child_key), result)
217 },
218 ))
219 } else {
220 itertools::Either::Right(iter::empty())
221 }
222 })
223 .collect(),
224 Either::Right(children) => children
225 .into_iter()
226 .flat_map(|(child_key, child)| {
227 let parent_key = (self.f)(&child_key);
228
229 let (parent_state, children_state) = self
230 .snapshot
231 .entry(parent_key.clone())
232 .or_insert_with(|| (None, HashMap::new()));
233
234 let changed = match &child {
235 Some(child) => {
236 children_state.insert(child_key.clone(), child.to_owned());
237 true
238 }
239 None => children_state.remove(&child_key).is_some(),
240 };
241
242 return changed.then(|| {
243 let result = match (parent_state, child) {
244 (Some(parent), Some(child)) => Some((parent.to_owned(), child)),
245 _ => None,
246 };
247
248 ((parent_key, child_key), result)
249 });
250 })
251 .collect(),
252 }
253 }
254
255 fn ingest_update_optional_parent<'a>(
256 &mut self,
257 update: Either<
258 impl IntoIterator<Item = (ParentKey, Option<ParentValue>)>,
259 impl IntoIterator<Item = (ChildKey, Option<ChildValue>)>,
260 >,
261 ) -> Vec<(
262 (ParentKey, ChildKey),
263 Option<(Option<ParentValue>, ChildValue)>,
264 )> {
265 use im::HashMap;
266
267 match update {
268 Either::Left(parents) => parents
269 .into_iter()
270 .flat_map(|(key, parent)| {
271 let (parent_state, children_state) = self
272 .snapshot
273 .entry(key.clone())
274 .or_insert_with(|| (None, HashMap::new()));
275
276 match &parent {
277 Some(parent) => {
278 _ = parent_state.insert(parent.to_owned());
279 }
280 None => {
281 _ = parent_state.take();
282 }
283 };
284
285 children_state
286 .clone()
287 .into_iter()
288 .map(move |(child_key, child)| {
289 let result = (parent.clone(), child);
290 ((key.clone(), child_key), Some(result))
291 })
292 })
293 .collect(),
294 Either::Right(children) => children
295 .into_iter()
296 .map(|(child_key, child)| {
297 let parent_key = (self.f)(&child_key);
298
299 let (parent_state, children_state) = self
300 .snapshot
301 .entry(parent_key.clone())
302 .or_insert_with(|| (None, HashMap::new()));
303
304 match &child {
305 Some(right) => {
306 _ = children_state.insert(child_key.clone(), right.to_owned());
307 }
308 None => {
309 _ = children_state.remove(&child_key);
310 }
311 };
312
313 let result = child.map(|child| (parent_state.clone(), child));
314
315 ((parent_key, child_key), result)
316 })
317 .collect(),
318 }
319 }
320}
321
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use futures::{channel::mpsc, SinkExt, StreamExt};
    use tokio_test::{assert_pending, assert_ready};

    use super::*;

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_join_with_parent() {
        type ParentKey = i32;
        type ChildKey = (ParentKey, i64);

        struct ParentEntity {
            key: ParentKey,
            value: &'static str,
        }

        impl ParentEntity {
            fn into_inner(self) -> (ParentKey, &'static str) {
                (self.key, self.value)
            }
        }

        struct ChildEntity {
            key: ChildKey,
            value: &'static str,
        }

        impl ChildEntity {
            fn into_inner(self) -> (ChildKey, &'static str) {
                (self.key, self.value)
            }
        }

        let parent_snapshot = vec![
            ParentEntity {
                key: 1,
                value: "PARENT_1",
            },
            ParentEntity {
                key: 2,
                value: "PARENT_2",
            },
        ]
        .into_iter()
        .map(ParentEntity::into_inner);

        let children_snapshot = vec![
            ChildEntity {
                key: (1, 1011),
                value: "CHILD_1011_PARENT_1",
            },
            ChildEntity {
                key: (2, 1021),
                value: "CHILD_1021_PARENT_2",
            },
            ChildEntity {
                key: (2, 1022),
                value: "CHILD_1022_PARENT_2",
            },
            ChildEntity {
                key: (3, 1031),
                value: "CHILD_1031_PARENT_3",
            },
        ]
        .into_iter()
        .map(ChildEntity::into_inner);

        // Child (3, 1031) has no parent yet, so the inner join leaves it out.
        let expected_initial_snapshot = vec![
            ((1, (1, 1011_i64)), ("PARENT_1", "CHILD_1011_PARENT_1")),
            ((2, (2, 1021_i64)), ("PARENT_2", "CHILD_1021_PARENT_2")),
            ((2, (2, 1022_i64)), ("PARENT_2", "CHILD_1022_PARENT_2")),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        let (mut parent_tx, parent_rx) = mpsc::unbounded();
        let parent_snapshot_with_updates = SnapshotWithUpdates::new(parent_snapshot, parent_rx);

        let (mut child_tx, child_rx) = mpsc::unbounded();
        let children_snapshot_with_updates = SnapshotWithUpdates::new(children_snapshot, child_rx);

        let result = children_snapshot_with_updates
            .join_with_parent(parent_snapshot_with_updates, |child_key| child_key.0);

        let (snapshot, updates) = result.into_inner();
        let mut updates = updates
            .map(|updates| updates.into_iter().collect::<Vec<_>>())
            .boxed();
        let snapshot = snapshot.into_iter().collect::<HashMap<_, _>>();

        assert_eq!(expected_initial_snapshot, snapshot);

        struct UpdateTestCase {
            update: Either<
                Vec<(ParentKey, Option<&'static str>)>,
                Vec<(ChildKey, Option<&'static str>)>,
            >,
            expected_result: Option<Vec<((i32, (i32, i64)), Option<(&'static str, &'static str)>)>>,
        }

        let test_cases = vec![
            // Updating a parent re-emits every child joined to it.
            UpdateTestCase {
                update: Either::Left(vec![(1, Some("PARENT_1_UPDATE_1"))]),
                expected_result: Some(vec![(
                    (1, (1, 1011)),
                    Some(("PARENT_1_UPDATE_1", "CHILD_1011_PARENT_1")),
                )]),
            },
            UpdateTestCase {
                update: Either::Left(vec![(2, Some("PARENT_2_UPDATE_1"))]),
                expected_result: Some(vec![
                    (
                        (2, (2, 1021)),
                        Some(("PARENT_2_UPDATE_1", "CHILD_1021_PARENT_2")),
                    ),
                    (
                        (2, (2, 1022)),
                        Some(("PARENT_2_UPDATE_1", "CHILD_1022_PARENT_2")),
                    ),
                ]),
            },
            // A parent arriving late joins the children that were waiting for it.
            UpdateTestCase {
                update: Either::Left(vec![(3, Some("PARENT_3"))]),
                expected_result: Some(vec![(
                    (3, (3, 1031)),
                    Some(("PARENT_3", "CHILD_1031_PARENT_3")),
                )]),
            },
            // Removing a parent removes every row joined to it.
            UpdateTestCase {
                update: Either::Left(vec![(2, None)]),
                expected_result: Some(vec![((2, (2, 1021)), None), ((2, (2, 1022)), None)]),
            },
            // Removing the same parent again changes nothing.
            UpdateTestCase {
                update: Either::Left(vec![(2, None)]),
                expected_result: None,
            },
            // A new child under an existing parent is joined immediately.
            UpdateTestCase {
                update: Either::Right(vec![((1, 1012), Some("CHILD_1012_PARENT_1"))]),
                expected_result: Some(vec![(
                    (1, (1, 1012)),
                    Some(("PARENT_1_UPDATE_1", "CHILD_1012_PARENT_1")),
                )]),
            },
            // Removing a child removes its joined row.
            UpdateTestCase {
                update: Either::Right(vec![((1, 1011), None)]),
                expected_result: Some(vec![((1, (1, 1011)), None)]),
            },
            // Removing the same child again changes nothing.
            UpdateTestCase {
                update: Either::Right(vec![((1, 1011), None)]),
                expected_result: None,
            },
            // A parent without children produces no rows.
            UpdateTestCase {
                update: Either::Left(vec![(4, Some("PARENT_4"))]),
                expected_result: None,
            },
            // One batch may touch children of several parents.
            UpdateTestCase {
                update: Either::Right(vec![
                    ((4, 1041), Some("CHILD_1041_PARENT_4")),
                    ((4, 1042), Some("CHILD_1042_PARENT_4")),
                    ((3, 1031), Some("CHILD_1031_PARENT_3_UPDATE_1")),
                ]),
                expected_result: Some(vec![
                    ((4, (4, 1041)), Some(("PARENT_4", "CHILD_1041_PARENT_4"))),
                    ((4, (4, 1042)), Some(("PARENT_4", "CHILD_1042_PARENT_4"))),
                    (
                        (3, (3, 1031)),
                        Some(("PARENT_3", "CHILD_1031_PARENT_3_UPDATE_1")),
                    ),
                ]),
            },
        ];

        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        for (index, tc) in test_cases.into_iter().enumerate() {
            let UpdateTestCase {
                update,
                expected_result,
            } = tc;

            match update {
                Either::Left(parents) => parent_tx.send(parents).await.unwrap(),
                Either::Right(children) => child_tx.send(children).await.unwrap(),
            }

            let Some(mut expected_result) = expected_result else {
                assert_pending!(updates.poll_next_unpin(&mut cx));
                continue;
            };

            let update = assert_ready!(updates.poll_next_unpin(&mut cx));

            let Some(update) = update else {
                panic!("Expected Some for update");
            };

            expected_result.sort_by_key(|r| r.0);

            let mut actual_result = update.into_iter().collect::<Vec<_>>();
            actual_result.sort_by_key(|r| r.0);

            assert_eq!(actual_result, expected_result, "TestCase {index} failed")
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn test_join_with_optional_parent() {
        type ParentKey = i32;
        type ChildKey = (ParentKey, i64);
        type JoinedRow = (
            (ParentKey, ChildKey),
            Option<(Option<&'static str>, &'static str)>,
        );

        let parent_snapshot = vec![(1, "PARENT_1")];
        let children_snapshot = vec![
            ((1, 1011), "CHILD_1011_PARENT_1"),
            ((2, 1021), "CHILD_1021_PARENT_2"),
        ];

        let (mut parent_tx, parent_rx) =
            mpsc::unbounded::<Vec<(ParentKey, Option<&'static str>)>>();
        let (mut child_tx, child_rx) = mpsc::unbounded::<Vec<(ChildKey, Option<&'static str>)>>();

        let result = SnapshotWithUpdates::new(children_snapshot, child_rx)
            .join_with_optional_parent(
                SnapshotWithUpdates::new(parent_snapshot, parent_rx),
                |child_key: &ChildKey| child_key.0,
            );

        let (snapshot, updates) = result.into_inner();
        let snapshot = snapshot.into_iter().collect::<HashMap<_, _>>();
        let mut updates = updates
            .map(|batch| {
                let mut batch = batch.into_iter().collect::<Vec<JoinedRow>>();
                batch.sort_by_key(|row| row.0);
                batch
            })
            .boxed();

        // Children without a parent are still present, paired with `None`.
        let expected_snapshot = vec![
            ((1, (1, 1011)), (Some("PARENT_1"), "CHILD_1011_PARENT_1")),
            ((2, (2, 1021)), (None, "CHILD_1021_PARENT_2")),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();
        assert_eq!(expected_snapshot, snapshot);

        let waker = futures::task::noop_waker_ref();
        let mut cx = std::task::Context::from_waker(waker);

        assert_pending!(updates.poll_next_unpin(&mut cx));

        // A parent arriving late is attached to its existing children.
        parent_tx.send(vec![(2, Some("PARENT_2"))]).await.unwrap();
        let batch = assert_ready!(updates.poll_next_unpin(&mut cx)).expect("stream ended early");
        assert_eq!(
            batch,
            vec![(
                (2, (2, 1021)),
                Some((Some("PARENT_2"), "CHILD_1021_PARENT_2")),
            )]
        );

        // Removing a parent keeps its children, now paired with `None`.
        parent_tx.send(vec![(1, None)]).await.unwrap();
        let batch = assert_ready!(updates.poll_next_unpin(&mut cx)).expect("stream ended early");
        assert_eq!(
            batch,
            vec![((1, (1, 1011)), Some((None, "CHILD_1011_PARENT_1")))]
        );

        // A child without a parent is emitted right away.
        child_tx
            .send(vec![((3, 1031), Some("CHILD_1031_PARENT_3"))])
            .await
            .unwrap();
        let batch = assert_ready!(updates.poll_next_unpin(&mut cx)).expect("stream ended early");
        assert_eq!(
            batch,
            vec![((3, (3, 1031)), Some((None, "CHILD_1031_PARENT_3")))]
        );

        // Removing a child removes its row.
        child_tx.send(vec![((2, 1021), None)]).await.unwrap();
        let batch = assert_ready!(updates.poll_next_unpin(&mut cx)).expect("stream ended early");
        assert_eq!(batch, vec![((2, (2, 1021)), None)]);

        assert_pending!(updates.poll_next_unpin(&mut cx));
    }
}