Skip to main content

dag/set/
intersection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::borrow::Cow;
10use std::cmp::Ordering;
11use std::fmt;
12
13use futures::StreamExt;
14
15use super::hints::Flags;
16use super::id_static::IdStaticSet;
17use super::AsyncSetQuery;
18use super::BoxVertexStream;
19use super::Hints;
20use super::Set;
21use crate::fmt::write_debug;
22use crate::Id;
23use crate::Result;
24use crate::Vertex;
25
26/// Intersection of 2 sets.
27///
28/// The iteration order is defined by the first set.
29pub struct IntersectionSet {
30    lhs: Set,
31    rhs: Set,
32    hints: Hints,
33}
34
35struct Iter {
36    iter: BoxVertexStream,
37    rhs: Set,
38    ended: bool,
39
40    /// Optional fast path for stop.
41    stop_condition: Option<StopCondition>,
42}
43
44impl Iter {
45    async fn next(&mut self) -> Option<Result<Vertex>> {
46        if self.ended {
47            return None;
48        }
49        loop {
50            let result = self.iter.as_mut().next().await;
51            if let Some(Ok(ref name)) = result {
52                match self.rhs.contains(name).await {
53                    Err(err) => break Some(Err(err)),
54                    Ok(false) => {
55                        // Check if we can stop iteration early using hints.
56                        if let Some(ref cond) = self.stop_condition {
57                            if let Some(id_convert) = self.rhs.id_convert() {
58                                if let Ok(Some(id)) = id_convert.vertex_id_optional(name).await {
59                                    if cond.should_stop_with_id(id) {
60                                        self.ended = true;
61                                        return None;
62                                    }
63                                }
64                            }
65                        }
66                        continue;
67                    }
68                    Ok(true) => {}
69                }
70            }
71            break result;
72        }
73    }
74
75    fn into_stream(self) -> BoxVertexStream {
76        Box::pin(futures::stream::unfold(self, |mut state| async move {
77            let result = state.next().await;
78            result.map(|r| (r, state))
79        }))
80    }
81}
82
83struct StopCondition {
84    order: Ordering,
85    id: Id,
86}
87
88impl StopCondition {
89    fn should_stop_with_id(&self, id: Id) -> bool {
90        id.cmp(&self.id) == self.order
91    }
92}
93
94impl IntersectionSet {
95    pub fn new(lhs: Set, rhs: Set) -> Self {
96        // More efficient if `lhs` is smaller. But `lhs` order matters.
97        // Swap `lhs` and `rhs` if `lhs` is `FULL`.
98        let (lhs, rhs) = if lhs.hints().contains(Flags::FULL)
99            && !rhs.hints().contains(Flags::FULL)
100            && !rhs.hints().contains(Flags::FILTER)
101            && lhs.hints().dag_version() >= rhs.hints().dag_version()
102        {
103            (rhs, lhs)
104        } else {
105            (lhs, rhs)
106        };
107
108        let hints = Hints::new_inherit_idmap_dag(lhs.hints());
109        hints.add_flags(
110            lhs.hints().flags()
111                & (Flags::EMPTY
112                    | Flags::ID_DESC
113                    | Flags::ID_ASC
114                    | Flags::TOPO_DESC
115                    | Flags::FILTER),
116        );
117        // Only keep the ANCESTORS flag if lhs and rhs use a compatible Dag.
118        if lhs.hints().dag_version() >= rhs.hints().dag_version() {
119            hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
120        }
121        let (rhs_min_id, rhs_max_id) = if hints.id_map_version() >= rhs.hints().id_map_version() {
122            // rhs ids are all known by lhs.
123            (rhs.hints().min_id(), rhs.hints().max_id())
124        } else {
125            (None, None)
126        };
127        match (lhs.hints().min_id(), rhs_min_id) {
128            (Some(id), None) | (None, Some(id)) => {
129                hints.set_min_id(id);
130            }
131            (Some(id1), Some(id2)) => {
132                hints.set_min_id(id1.max(id2));
133            }
134            (None, None) => {}
135        }
136        match (lhs.hints().max_id(), rhs_max_id) {
137            (Some(id), None) | (None, Some(id)) => {
138                hints.set_max_id(id);
139            }
140            (Some(id1), Some(id2)) => {
141                hints.set_max_id(id1.min(id2));
142            }
143            (None, None) => {}
144        }
145        Self { lhs, rhs, hints }
146    }
147
148    fn is_rhs_id_map_comapatible(&self) -> bool {
149        let lhs_version = self.lhs.hints().id_map_version();
150        let rhs_version = self.rhs.hints().id_map_version();
151        lhs_version == rhs_version || (lhs_version > rhs_version && rhs_version > None)
152    }
153}
154
155#[async_trait::async_trait]
156impl AsyncSetQuery for IntersectionSet {
157    async fn iter(&self) -> Result<BoxVertexStream> {
158        let stop_condition = if !self.is_rhs_id_map_comapatible() {
159            None
160        } else if self.lhs.hints().contains(Flags::ID_ASC) {
161            self.rhs.hints().max_id().map(|id| StopCondition {
162                id,
163                order: Ordering::Greater,
164            })
165        } else if self.lhs.hints().contains(Flags::ID_DESC) {
166            self.rhs.hints().min_id().map(|id| StopCondition {
167                id,
168                order: Ordering::Less,
169            })
170        } else {
171            None
172        };
173
174        let iter = Iter {
175            iter: self.lhs.iter().await?,
176            rhs: self.rhs.clone(),
177            ended: false,
178            stop_condition,
179        };
180        Ok(iter.into_stream())
181    }
182
183    async fn iter_rev(&self) -> Result<BoxVertexStream> {
184        let stop_condition = if !self.is_rhs_id_map_comapatible() {
185            None
186        } else if self.lhs.hints().contains(Flags::ID_DESC) {
187            self.rhs.hints().max_id().map(|id| StopCondition {
188                id,
189                order: Ordering::Greater,
190            })
191        } else if self.lhs.hints().contains(Flags::ID_ASC) {
192            self.rhs.hints().min_id().map(|id| StopCondition {
193                id,
194                order: Ordering::Less,
195            })
196        } else {
197            None
198        };
199
200        let iter = Iter {
201            iter: self.lhs.iter_rev().await?,
202            rhs: self.rhs.clone(),
203            ended: false,
204            stop_condition,
205        };
206        Ok(iter.into_stream())
207    }
208
209    async fn size_hint(&self) -> (u64, Option<u64>) {
210        let lhs_max = self.lhs.size_hint().await.1;
211        let rhs_max = self.rhs.size_hint().await.1;
212        let max = match (lhs_max, rhs_max) {
213            (Some(l), Some(r)) => Some(l.min(r)),
214            _ => None,
215        };
216        (0, max)
217    }
218
219    async fn contains(&self, name: &Vertex) -> Result<bool> {
220        Ok(self.lhs.contains(name).await? && self.rhs.contains(name).await?)
221    }
222
223    async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
224        for set in &[&self.lhs, &self.rhs] {
225            let contains = set.contains_fast(name).await?;
226            match contains {
227                Some(false) | None => return Ok(contains),
228                Some(true) => {}
229            }
230        }
231        Ok(Some(true))
232    }
233
234    fn as_any(&self) -> &dyn Any {
235        self
236    }
237
238    fn hints(&self) -> &Hints {
239        &self.hints
240    }
241
242    fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
243        let lhs = self.lhs.specialized_flatten_id()?;
244        let rhs = self.rhs.specialized_flatten_id()?;
245        let result = IdStaticSet::from_edit_spans(&lhs, &rhs, |a, b| a.intersection(b))?;
246        Some(Cow::Owned(result))
247    }
248}
249
250impl fmt::Debug for IntersectionSet {
251    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
252        write!(f, "<and")?;
253        write_debug(f, &self.lhs)?;
254        write_debug(f, &self.rhs)?;
255        write!(f, ">")
256    }
257}
258
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use std::collections::HashSet;

    use super::super::id_lazy::test_utils::lazy_set;
    use super::super::id_lazy::test_utils::lazy_set_inherit;
    use super::super::tests::*;
    use super::*;
    use crate::Id;

    /// Builds an `IntersectionSet` from two byte slices via `VecQuery::from_bytes`.
    fn intersection(a: &[u8], b: &[u8]) -> IntersectionSet {
        let lhs = Set::from_query(VecQuery::from_bytes(a));
        let rhs = Set::from_query(VecQuery::from_bytes(b));
        IntersectionSet::new(lhs, rhs)
    }

    #[test]
    fn test_intersection_basic() -> Result<()> {
        let set = intersection(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;

        // Both directions follow the order of the first operand.
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "44"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["44", "33"]);

        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count_slow())?, 2);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "33");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "44");

        // Bytes present in only one operand are not members.
        for &absent in b"\x11\x22\x55\x66".iter() {
            assert!(!nb(set.contains(&to_name(absent)))?);
        }
        Ok(())
    }

    #[test]
    fn test_intersection_min_max_id_fast_path() {
        // The min_ids are intentionally wrong to test the fast paths.
        let lhs = lazy_set(&[0x70, 0x60, 0x50, 0x40, 0x30, 0x20]);
        let rhs = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &lhs);
        let lhs = Set::from_query(lhs);
        let rhs = Set::from_query(rhs);
        lhs.hints().add_flags(Flags::ID_DESC);
        rhs.hints().set_min_id(Id(0x40));
        rhs.hints().set_max_id(Id(0x50));

        let set = IntersectionSet::new(lhs, rhs.clone());
        // No "20" - filtered out by min id fast path.
        assert_eq!(shorten_iter(ni(set.iter())), ["70", "50", "40"]);
        // No "70" - filtered out by max id fast path.
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["20", "40", "50"]);

        // Test the reversed sort order.
        let lhs = lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]);
        let rhs = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &lhs);
        let lhs = Set::from_query(lhs);
        let rhs = Set::from_query(rhs);
        lhs.hints().add_flags(Flags::ID_ASC);
        rhs.hints().set_min_id(Id(0x40));
        rhs.hints().set_max_id(Id(0x50));
        let set = IntersectionSet::new(lhs, rhs.clone());
        // No "70".
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50"]);
        // No "20".
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40"]);

        // If two sets have incompatible IdMap, fast paths are not used.
        let lhs = Set::from_query(lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]));
        lhs.hints().add_flags(Flags::ID_ASC);
        let set = IntersectionSet::new(lhs, rhs.clone());
        // Should contain "70" and "20".
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50", "70"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40", "20"]);
    }

    #[test]
    fn test_size_hint_sets() {
        check_size_hint_sets(|a, b| IntersectionSet::new(a, b));
    }

    quickcheck::quickcheck! {
        fn test_intersection_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = intersection(&a, &b);
            check_invariants(&set).unwrap();

            // The intersection cannot have more items than either input.
            let count = nb(set.count_slow()).unwrap() as usize;
            assert!(count <= a.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &a);
            assert!(count <= b.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &b);

            // The members found by probing with `a` must match those found
            // by probing with `b`.
            let in_set = |byte: u8| nb(set.contains(&to_name(byte))).ok() == Some(true);
            let contains_a: HashSet<u8> = a.into_iter().filter(|&byte| in_set(byte)).collect();
            let contains_b: HashSet<u8> = b.into_iter().filter(|&byte| in_set(byte)).collect();
            assert_eq!(contains_a, contains_b);

            true
        }
    }
}