Skip to main content

dag/set/
union.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8use std::any::Any;
9use std::borrow::Cow;
10use std::fmt;
11use std::task::Poll;
12
13use futures::Stream;
14use futures::StreamExt;
15use serde::Deserialize;
16
17use super::hints::Flags;
18use super::id_static::IdStaticSet;
19use super::AsyncSetQuery;
20use super::BoxVertexStream;
21use super::Hints;
22use super::Set;
23use crate::fmt::write_debug;
24use crate::Result;
25use crate::Vertex;
26
27#[derive(Copy, Clone, Debug, Eq, PartialEq, Deserialize)]
28pub enum UnionOrder {
29    /// The first set is iterated first using its own order.
30    /// Then the second set is iterated, with duplications skipped.
31    FirstSecond,
32
33    /// Take one item from the first set, then one item from the second set
34    /// (if not exist in the first set), and repeat. Note this is slightly
35    /// different from "zip" as the second set is treated as not having
36    /// items duplicated with the first set.
37    Zip,
38}
39
40/// Union of 2 sets.
41///
42/// See [`UnionOrder`] for iteration order.
43pub struct UnionSet {
44    sets: [Set; 2],
45    hints: Hints,
46    order: UnionOrder,
47    // Count of the "count_slow" calls.
48    #[cfg(test)]
49    pub(crate) test_slow_count: std::sync::atomic::AtomicU64,
50}
51
52impl UnionSet {
53    pub fn new(lhs: Set, rhs: Set) -> Self {
54        let hints = Hints::union(&[lhs.hints(), rhs.hints()]);
55        if hints.id_map().is_some() {
56            if let (Some(id1), Some(id2)) = (lhs.hints().min_id(), rhs.hints().min_id()) {
57                hints.set_min_id(id1.min(id2));
58            }
59            if let (Some(id1), Some(id2)) = (lhs.hints().max_id(), rhs.hints().max_id()) {
60                hints.set_max_id(id1.max(id2));
61            }
62        };
63        hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
64        if lhs.hints().contains(Flags::FILTER) || rhs.hints().contains(Flags::FILTER) {
65            hints.add_flags(Flags::FILTER);
66        }
67        Self {
68            sets: [lhs, rhs],
69            hints,
70            order: UnionOrder::FirstSecond,
71            #[cfg(test)]
72            test_slow_count: std::sync::atomic::AtomicU64::new(0),
73        }
74    }
75
76    pub fn with_order(mut self, order: UnionOrder) -> Self {
77        self.order = order;
78        self
79    }
80}
81
82#[async_trait::async_trait]
83impl AsyncSetQuery for UnionSet {
84    async fn iter(&self) -> Result<BoxVertexStream> {
85        debug_assert_eq!(self.sets.len(), 2);
86        let diff = self.sets[1].clone() - self.sets[0].clone();
87        let diff_iter = diff.iter().await?;
88        let set0_iter = self.sets[0].iter().await?;
89        let iter: BoxVertexStream = match self.order {
90            UnionOrder::FirstSecond => Box::pin(set0_iter.chain(diff_iter)),
91            UnionOrder::Zip => Box::pin(ZipStream::new(set0_iter, diff_iter)),
92        };
93        Ok(iter)
94    }
95
96    async fn iter_rev(&self) -> Result<BoxVertexStream> {
97        debug_assert_eq!(self.sets.len(), 2);
98        let diff = self.sets[1].clone() - self.sets[0].clone();
99        let diff_iter = diff.iter_rev().await?;
100        let set0_iter = self.sets[0].iter_rev().await?;
101        let iter: BoxVertexStream = match self.order {
102            UnionOrder::FirstSecond => Box::pin(diff_iter.chain(set0_iter)),
103            UnionOrder::Zip => {
104                // note: cannot use ZipStream::new(diff_iter, set_iter) when two iters have
105                // different lengths.
106                let mut iter = self.iter().await?;
107                let mut items = Vec::new();
108                while let Some(item) = iter.next().await {
109                    items.push(item);
110                }
111                Box::pin(futures::stream::iter(items.into_iter().rev()))
112            }
113        };
114        Ok(iter)
115    }
116
117    async fn size_hint(&self) -> (u64, Option<u64>) {
118        let mut min_size = 0;
119        let mut max_size = Some(0u64);
120        for set in &self.sets {
121            let (min, max) = set.size_hint().await;
122            min_size = min.min(min_size);
123            max_size = match (max_size, max) {
124                (Some(max_size), Some(max)) => max_size.checked_add(max),
125                _ => None,
126            };
127        }
128        (min_size, max_size)
129    }
130
131    async fn count_slow(&self) -> Result<u64> {
132        #[cfg(test)]
133        self.test_slow_count
134            .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
135        debug_assert_eq!(self.sets.len(), 2);
136        // This is more efficient if sets[0] is a large set that has a fast path
137        // for "count()".
138        let mut count = self.sets[0].count().await?;
139        let mut iter = self.sets[1].iter().await?;
140        while let Some(item) = iter.next().await {
141            let name = item?;
142            if !self.sets[0].contains(&name).await? {
143                count += 1;
144            }
145        }
146        Ok(count)
147    }
148
149    async fn is_empty(&self) -> Result<bool> {
150        for set in &self.sets {
151            if !set.is_empty().await? {
152                return Ok(false);
153            }
154        }
155        Ok(true)
156    }
157
158    async fn contains(&self, name: &Vertex) -> Result<bool> {
159        for set in &self.sets {
160            if set.contains(name).await? {
161                return Ok(true);
162            }
163        }
164        Ok(false)
165    }
166
167    async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
168        for set in &self.sets {
169            if let Some(result) = set.contains_fast(name).await? {
170                return Ok(Some(result));
171            }
172        }
173        Ok(None)
174    }
175
176    fn as_any(&self) -> &dyn Any {
177        self
178    }
179
180    fn hints(&self) -> &Hints {
181        &self.hints
182    }
183
184    fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
185        let mut result = self.sets[0].specialized_flatten_id()?;
186        for set in &self.sets[1..] {
187            let other = set.specialized_flatten_id()?;
188            result = Cow::Owned(IdStaticSet::from_edit_spans(&result, &other, |a, b| {
189                a.union(b)
190            })?);
191        }
192        Some(result)
193    }
194}
195
196/// Iterate through iter1 and iter2 in turn until both iters end.
197/// For example, ZipStream([1,2], [3,4,5,6]) produces: [1,3,2,4,5,6].
198struct ZipStream {
199    // note: iters[1] should not overlap with iter[0]
200    iters: [BoxVertexStream; 2],
201    // Whether the stream has ended.
202    iter_ended: [bool; 2],
203    // Which to pull next, 0 or 1.
204    next_iter: usize,
205}
206
207impl ZipStream {
208    fn new(iter1: BoxVertexStream, iter2: BoxVertexStream) -> Self {
209        Self {
210            iters: [iter1, iter2],
211            iter_ended: [false, false],
212            next_iter: 0,
213        }
214    }
215}
216
217impl Stream for ZipStream {
218    type Item = Result<Vertex>;
219
220    fn poll_next(
221        mut self: std::pin::Pin<&mut Self>,
222        cx: &mut std::task::Context<'_>,
223    ) -> Poll<Option<Self::Item>> {
224        'again: loop {
225            let index = self.next_iter;
226            if self.iter_ended[index] {
227                return Poll::Ready(None);
228            }
229            match self.iters[index].as_mut().poll_next(cx) {
230                Poll::Ready(v) => {
231                    if v.is_none() {
232                        // Mark the current iterator as ended.
233                        self.iter_ended[index] = true;
234                    }
235                    if !self.iter_ended[index ^ 1] {
236                        // Switch to the other iterator if it hasn't ended.
237                        self.next_iter = index ^ 1;
238                    }
239                    if v.is_none() {
240                        // Try the other iterator.
241                        continue 'again;
242                    }
243                    return Poll::Ready(v);
244                }
245                Poll::Pending => return Poll::Pending,
246            }
247        }
248    }
249}
250
251impl fmt::Debug for UnionSet {
252    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253        write!(f, "<or")?;
254        write_debug(f, &self.sets[0])?;
255        write_debug(f, &self.sets[1])?;
256        match self.order {
257            UnionOrder::FirstSecond => {}
258            order => write!(f, " (order={:?})", order)?,
259        }
260        write!(f, ">")
261    }
262}
263
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::tests::*;
    use super::*;

    /// Build a `UnionSet` from two byte slices, each turned into a set
    /// via `VecQuery::from_bytes`.
    fn union(a: &[u8], b: &[u8]) -> UnionSet {
        let lhs = Set::from_query(VecQuery::from_bytes(a));
        let rhs = Set::from_query(VecQuery::from_bytes(b));
        UnionSet::new(lhs, rhs)
    }

    #[test]
    fn test_union_basic() -> Result<()> {
        // 'a' overlaps with 'b'. UnionSet should de-duplicate items.
        let set = union(b"\x11\x33\x22", b"\x44\x11\x55\x33");
        check_invariants(&set)?;

        let forward = shorten_iter(ni(set.iter()));
        assert_eq!(forward, ["11", "33", "22", "44", "55"]);
        let backward = shorten_iter(ni(set.iter_rev()));
        assert_eq!(backward, ["55", "44", "22", "33", "11"]);

        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count())?, 5);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "55");

        // Every item of either operand is a member of the union.
        for present in b"\x11\x22\x33\x44\x55".iter().copied() {
            assert!(nb(set.contains(&to_name(present)))?);
        }
        // Items of neither operand are not.
        for absent in b"\x66\x77\x88".iter().copied() {
            assert!(!nb(set.contains(&to_name(absent)))?);
        }
        Ok(())
    }

    #[test]
    fn test_union_zip_order() -> Result<()> {
        // Empty right-hand side.
        let only_lhs = union(b"\x33\x44\x55", b"").with_order(UnionOrder::Zip);
        check_invariants(&only_lhs)?;
        assert_eq!(shorten_iter(ni(only_lhs.iter())), ["33", "44", "55"]);

        // Empty left-hand side.
        let only_rhs = union(b"", b"\x33\x44\x55").with_order(UnionOrder::Zip);
        check_invariants(&only_rhs)?;
        assert_eq!(shorten_iter(ni(only_rhs.iter())), ["33", "44", "55"]);

        // Overlapping operands of different lengths.
        let mixed = union(b"\x33\x44\x55", b"\x55\x33\x22\x11").with_order(UnionOrder::Zip);
        assert_eq!(
            shorten_iter(ni(mixed.iter())),
            ["33", "22", "44", "11", "55"]
        );
        check_invariants(&mixed)?;

        Ok(())
    }

    #[test]
    fn test_size_hint_sets() {
        check_size_hint_sets(|lhs, rhs| UnionSet::new(lhs, rhs));
        check_size_hint_sets(|lhs, rhs| UnionSet::new(lhs, rhs).with_order(UnionOrder::Zip));
    }

    quickcheck::quickcheck! {
        fn test_union_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = union(&a, &b);
            check_invariants(&set).unwrap();

            // The count never exceeds the combined length of the inputs...
            let count = nb(set.count()).unwrap() as usize;
            assert!(count <= a.len() + b.len());

            // ...and equals the number of distinct bytes across both.
            let distinct: HashSet<u8> = a.iter().chain(b.iter()).copied().collect();
            assert_eq!(count, distinct.len());

            // Every input byte is a member of the union.
            let is_member = |byte: u8| matches!(nb(set.contains(&to_name(byte))), Ok(true));
            assert!(a.iter().all(|&byte| is_member(byte)));
            assert!(b.iter().all(|&byte| is_member(byte)));

            true
        }
    }
}