1use std::any::Any;
9use std::borrow::Cow;
10use std::fmt;
11use std::task::Poll;
12
13use futures::Stream;
14use futures::StreamExt;
15use serde::Deserialize;
16
17use super::hints::Flags;
18use super::id_static::IdStaticSet;
19use super::AsyncSetQuery;
20use super::BoxVertexStream;
21use super::Hints;
22use super::Set;
23use crate::fmt::write_debug;
24use crate::Result;
25use crate::Vertex;
26
/// Iteration order of a [`UnionSet`].
#[derive(Copy, Clone, Debug, Eq, PartialEq, Deserialize)]
pub enum UnionOrder {
    /// Yield the first set, followed by the vertexes of the second set that
    /// are not in the first set.
    FirstSecond,

    /// Alternate between the first set and the vertexes of the second set
    /// that are not in the first set. Once one side is exhausted, the rest of
    /// the other side is yielded.
    Zip,
}
39
/// Union of two sets.
///
/// Iteration yields the first set plus the vertexes of the second set that
/// are not in the first; `order` controls how the two streams are combined.
pub struct UnionSet {
    /// The two operands, stored as `[lhs, rhs]`.
    sets: [Set; 2],
    /// Hints computed in `UnionSet::new` from the hints of both operands.
    hints: Hints,
    /// Iteration order. `UnionOrder::FirstSecond` unless changed by `with_order`.
    order: UnionOrder,
    /// Test-only counter, incremented each time `count_slow` runs.
    #[cfg(test)]
    pub(crate) test_slow_count: std::sync::atomic::AtomicU64,
}
51
52impl UnionSet {
53 pub fn new(lhs: Set, rhs: Set) -> Self {
54 let hints = Hints::union(&[lhs.hints(), rhs.hints()]);
55 if hints.id_map().is_some() {
56 if let (Some(id1), Some(id2)) = (lhs.hints().min_id(), rhs.hints().min_id()) {
57 hints.set_min_id(id1.min(id2));
58 }
59 if let (Some(id1), Some(id2)) = (lhs.hints().max_id(), rhs.hints().max_id()) {
60 hints.set_max_id(id1.max(id2));
61 }
62 };
63 hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
64 if lhs.hints().contains(Flags::FILTER) || rhs.hints().contains(Flags::FILTER) {
65 hints.add_flags(Flags::FILTER);
66 }
67 Self {
68 sets: [lhs, rhs],
69 hints,
70 order: UnionOrder::FirstSecond,
71 #[cfg(test)]
72 test_slow_count: std::sync::atomic::AtomicU64::new(0),
73 }
74 }
75
76 pub fn with_order(mut self, order: UnionOrder) -> Self {
77 self.order = order;
78 self
79 }
80}
81
82#[async_trait::async_trait]
83impl AsyncSetQuery for UnionSet {
84 async fn iter(&self) -> Result<BoxVertexStream> {
85 debug_assert_eq!(self.sets.len(), 2);
86 let diff = self.sets[1].clone() - self.sets[0].clone();
87 let diff_iter = diff.iter().await?;
88 let set0_iter = self.sets[0].iter().await?;
89 let iter: BoxVertexStream = match self.order {
90 UnionOrder::FirstSecond => Box::pin(set0_iter.chain(diff_iter)),
91 UnionOrder::Zip => Box::pin(ZipStream::new(set0_iter, diff_iter)),
92 };
93 Ok(iter)
94 }
95
96 async fn iter_rev(&self) -> Result<BoxVertexStream> {
97 debug_assert_eq!(self.sets.len(), 2);
98 let diff = self.sets[1].clone() - self.sets[0].clone();
99 let diff_iter = diff.iter_rev().await?;
100 let set0_iter = self.sets[0].iter_rev().await?;
101 let iter: BoxVertexStream = match self.order {
102 UnionOrder::FirstSecond => Box::pin(diff_iter.chain(set0_iter)),
103 UnionOrder::Zip => {
104 let mut iter = self.iter().await?;
107 let mut items = Vec::new();
108 while let Some(item) = iter.next().await {
109 items.push(item);
110 }
111 Box::pin(futures::stream::iter(items.into_iter().rev()))
112 }
113 };
114 Ok(iter)
115 }
116
117 async fn size_hint(&self) -> (u64, Option<u64>) {
118 let mut min_size = 0;
119 let mut max_size = Some(0u64);
120 for set in &self.sets {
121 let (min, max) = set.size_hint().await;
122 min_size = min.min(min_size);
123 max_size = match (max_size, max) {
124 (Some(max_size), Some(max)) => max_size.checked_add(max),
125 _ => None,
126 };
127 }
128 (min_size, max_size)
129 }
130
131 async fn count_slow(&self) -> Result<u64> {
132 #[cfg(test)]
133 self.test_slow_count
134 .fetch_add(1, std::sync::atomic::Ordering::AcqRel);
135 debug_assert_eq!(self.sets.len(), 2);
136 let mut count = self.sets[0].count().await?;
139 let mut iter = self.sets[1].iter().await?;
140 while let Some(item) = iter.next().await {
141 let name = item?;
142 if !self.sets[0].contains(&name).await? {
143 count += 1;
144 }
145 }
146 Ok(count)
147 }
148
149 async fn is_empty(&self) -> Result<bool> {
150 for set in &self.sets {
151 if !set.is_empty().await? {
152 return Ok(false);
153 }
154 }
155 Ok(true)
156 }
157
158 async fn contains(&self, name: &Vertex) -> Result<bool> {
159 for set in &self.sets {
160 if set.contains(name).await? {
161 return Ok(true);
162 }
163 }
164 Ok(false)
165 }
166
167 async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
168 for set in &self.sets {
169 if let Some(result) = set.contains_fast(name).await? {
170 return Ok(Some(result));
171 }
172 }
173 Ok(None)
174 }
175
176 fn as_any(&self) -> &dyn Any {
177 self
178 }
179
180 fn hints(&self) -> &Hints {
181 &self.hints
182 }
183
184 fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
185 let mut result = self.sets[0].specialized_flatten_id()?;
186 for set in &self.sets[1..] {
187 let other = set.specialized_flatten_id()?;
188 result = Cow::Owned(IdStaticSet::from_edit_spans(&result, &other, |a, b| {
189 a.union(b)
190 })?);
191 }
192 Some(result)
193 }
194}
195
/// Stream that alternates between two vertex streams, falling back to
/// whichever stream is left once the other one has ended.
struct ZipStream {
    /// The two source streams.
    iters: [BoxVertexStream; 2],
    /// `iter_ended[i]` is set once `iters[i]` has yielded `None`.
    iter_ended: [bool; 2],
    /// Index (0 or 1) of the stream to poll next.
    next_iter: usize,
}
206
207impl ZipStream {
208 fn new(iter1: BoxVertexStream, iter2: BoxVertexStream) -> Self {
209 Self {
210 iters: [iter1, iter2],
211 iter_ended: [false, false],
212 next_iter: 0,
213 }
214 }
215}
216
217impl Stream for ZipStream {
218 type Item = Result<Vertex>;
219
220 fn poll_next(
221 mut self: std::pin::Pin<&mut Self>,
222 cx: &mut std::task::Context<'_>,
223 ) -> Poll<Option<Self::Item>> {
224 'again: loop {
225 let index = self.next_iter;
226 if self.iter_ended[index] {
227 return Poll::Ready(None);
228 }
229 match self.iters[index].as_mut().poll_next(cx) {
230 Poll::Ready(v) => {
231 if v.is_none() {
232 self.iter_ended[index] = true;
234 }
235 if !self.iter_ended[index ^ 1] {
236 self.next_iter = index ^ 1;
238 }
239 if v.is_none() {
240 continue 'again;
242 }
243 return Poll::Ready(v);
244 }
245 Poll::Pending => return Poll::Pending,
246 }
247 }
248 }
249}
250
251impl fmt::Debug for UnionSet {
252 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253 write!(f, "<or")?;
254 write_debug(f, &self.sets[0])?;
255 write_debug(f, &self.sets[1])?;
256 match self.order {
257 UnionOrder::FirstSecond => {}
258 order => write!(f, " (order={:?})", order)?,
259 }
260 write!(f, ">")
261 }
262}
263
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::super::tests::*;
    use super::*;

    /// Build a `UnionSet` (default order) from two byte slices, each wrapped
    /// in a `VecQuery`-backed `Set`.
    fn union(a: &[u8], b: &[u8]) -> UnionSet {
        let a = Set::from_query(VecQuery::from_bytes(a));
        let b = Set::from_query(VecQuery::from_bytes(b));
        UnionSet::new(a, b)
    }

    /// Default order: the first set, then the new items of the second set.
    #[test]
    fn test_union_basic() -> Result<()> {
        let set = union(b"\x11\x33\x22", b"\x44\x11\x55\x33");
        check_invariants(&set)?;
        // "11" and "33" are in both operands and appear only once.
        assert_eq!(shorten_iter(ni(set.iter())), ["11", "33", "22", "44", "55"]);
        assert_eq!(
            shorten_iter(ni(set.iter_rev())),
            ["55", "44", "22", "33", "11"]
        );
        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count())?, 5);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "55");
        for &b in b"\x11\x22\x33\x44\x55".iter() {
            assert!(nb(set.contains(&to_name(b)))?);
        }
        for &b in b"\x66\x77\x88".iter() {
            assert!(!nb(set.contains(&to_name(b)))?);
        }
        Ok(())
    }

    /// Zip order: items alternate between the two sides.
    #[test]
    fn test_union_zip_order() -> Result<()> {
        // Empty right side.
        let set = union(b"\x33\x44\x55", b"").with_order(UnionOrder::Zip);
        check_invariants(&set)?;
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "44", "55"]);

        // Empty left side.
        let set = union(b"", b"\x33\x44\x55").with_order(UnionOrder::Zip);
        check_invariants(&set)?;
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "44", "55"]);

        // Overlapping sets: the right side only contributes "22" and "11".
        let set = union(b"\x33\x44\x55", b"\x55\x33\x22\x11").with_order(UnionOrder::Zip);
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "22", "44", "11", "55"]);
        check_invariants(&set)?;

        Ok(())
    }

    /// Run the shared size-hint checks against both iteration orders.
    #[test]
    fn test_size_hint_sets() {
        check_size_hint_sets(|a, b| UnionSet::new(a, b));
        check_size_hint_sets(|a, b| UnionSet::new(a, b).with_order(UnionOrder::Zip));
    }

    quickcheck::quickcheck! {
        // Compare the union against a `HashSet` built from the same bytes.
        fn test_union_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = union(&a, &b);
            check_invariants(&set).unwrap();

            let count = nb(set.count()).unwrap() as usize;
            assert!(count <= a.len() + b.len());

            // The count must equal the number of distinct input bytes.
            let set2: HashSet<_> = a.iter().chain(b.iter()).cloned().collect();
            assert_eq!(count, set2.len());

            assert!(a.iter().all(|&b| nb(set.contains(&to_name(b))).ok() == Some(true)));
            assert!(b.iter().all(|&b| nb(set.contains(&to_name(b))).ok() == Some(true)));

            true
        }
    }
}