1use std::any::Any;
9use std::borrow::Cow;
10use std::cmp::Ordering;
11use std::fmt;
12
13use futures::StreamExt;
14
15use super::hints::Flags;
16use super::id_static::IdStaticSet;
17use super::AsyncSetQuery;
18use super::BoxVertexStream;
19use super::Hints;
20use super::Set;
21use crate::fmt::write_debug;
22use crate::Id;
23use crate::Result;
24use crate::Vertex;
25
26pub struct IntersectionSet {
30 lhs: Set,
31 rhs: Set,
32 hints: Hints,
33}
34
35struct Iter {
36 iter: BoxVertexStream,
37 rhs: Set,
38 ended: bool,
39
40 stop_condition: Option<StopCondition>,
42}
43
44impl Iter {
45 async fn next(&mut self) -> Option<Result<Vertex>> {
46 if self.ended {
47 return None;
48 }
49 loop {
50 let result = self.iter.as_mut().next().await;
51 if let Some(Ok(ref name)) = result {
52 match self.rhs.contains(name).await {
53 Err(err) => break Some(Err(err)),
54 Ok(false) => {
55 if let Some(ref cond) = self.stop_condition {
57 if let Some(id_convert) = self.rhs.id_convert() {
58 if let Ok(Some(id)) = id_convert.vertex_id_optional(name).await {
59 if cond.should_stop_with_id(id) {
60 self.ended = true;
61 return None;
62 }
63 }
64 }
65 }
66 continue;
67 }
68 Ok(true) => {}
69 }
70 }
71 break result;
72 }
73 }
74
75 fn into_stream(self) -> BoxVertexStream {
76 Box::pin(futures::stream::unfold(self, |mut state| async move {
77 let result = state.next().await;
78 result.map(|r| (r, state))
79 }))
80 }
81}
82
83struct StopCondition {
84 order: Ordering,
85 id: Id,
86}
87
88impl StopCondition {
89 fn should_stop_with_id(&self, id: Id) -> bool {
90 id.cmp(&self.id) == self.order
91 }
92}
93
94impl IntersectionSet {
95 pub fn new(lhs: Set, rhs: Set) -> Self {
96 let (lhs, rhs) = if lhs.hints().contains(Flags::FULL)
99 && !rhs.hints().contains(Flags::FULL)
100 && !rhs.hints().contains(Flags::FILTER)
101 && lhs.hints().dag_version() >= rhs.hints().dag_version()
102 {
103 (rhs, lhs)
104 } else {
105 (lhs, rhs)
106 };
107
108 let hints = Hints::new_inherit_idmap_dag(lhs.hints());
109 hints.add_flags(
110 lhs.hints().flags()
111 & (Flags::EMPTY
112 | Flags::ID_DESC
113 | Flags::ID_ASC
114 | Flags::TOPO_DESC
115 | Flags::FILTER),
116 );
117 if lhs.hints().dag_version() >= rhs.hints().dag_version() {
119 hints.add_flags(lhs.hints().flags() & rhs.hints().flags() & Flags::ANCESTORS);
120 }
121 let (rhs_min_id, rhs_max_id) = if hints.id_map_version() >= rhs.hints().id_map_version() {
122 (rhs.hints().min_id(), rhs.hints().max_id())
124 } else {
125 (None, None)
126 };
127 match (lhs.hints().min_id(), rhs_min_id) {
128 (Some(id), None) | (None, Some(id)) => {
129 hints.set_min_id(id);
130 }
131 (Some(id1), Some(id2)) => {
132 hints.set_min_id(id1.max(id2));
133 }
134 (None, None) => {}
135 }
136 match (lhs.hints().max_id(), rhs_max_id) {
137 (Some(id), None) | (None, Some(id)) => {
138 hints.set_max_id(id);
139 }
140 (Some(id1), Some(id2)) => {
141 hints.set_max_id(id1.min(id2));
142 }
143 (None, None) => {}
144 }
145 Self { lhs, rhs, hints }
146 }
147
148 fn is_rhs_id_map_comapatible(&self) -> bool {
149 let lhs_version = self.lhs.hints().id_map_version();
150 let rhs_version = self.rhs.hints().id_map_version();
151 lhs_version == rhs_version || (lhs_version > rhs_version && rhs_version > None)
152 }
153}
154
155#[async_trait::async_trait]
156impl AsyncSetQuery for IntersectionSet {
157 async fn iter(&self) -> Result<BoxVertexStream> {
158 let stop_condition = if !self.is_rhs_id_map_comapatible() {
159 None
160 } else if self.lhs.hints().contains(Flags::ID_ASC) {
161 self.rhs.hints().max_id().map(|id| StopCondition {
162 id,
163 order: Ordering::Greater,
164 })
165 } else if self.lhs.hints().contains(Flags::ID_DESC) {
166 self.rhs.hints().min_id().map(|id| StopCondition {
167 id,
168 order: Ordering::Less,
169 })
170 } else {
171 None
172 };
173
174 let iter = Iter {
175 iter: self.lhs.iter().await?,
176 rhs: self.rhs.clone(),
177 ended: false,
178 stop_condition,
179 };
180 Ok(iter.into_stream())
181 }
182
183 async fn iter_rev(&self) -> Result<BoxVertexStream> {
184 let stop_condition = if !self.is_rhs_id_map_comapatible() {
185 None
186 } else if self.lhs.hints().contains(Flags::ID_DESC) {
187 self.rhs.hints().max_id().map(|id| StopCondition {
188 id,
189 order: Ordering::Greater,
190 })
191 } else if self.lhs.hints().contains(Flags::ID_ASC) {
192 self.rhs.hints().min_id().map(|id| StopCondition {
193 id,
194 order: Ordering::Less,
195 })
196 } else {
197 None
198 };
199
200 let iter = Iter {
201 iter: self.lhs.iter_rev().await?,
202 rhs: self.rhs.clone(),
203 ended: false,
204 stop_condition,
205 };
206 Ok(iter.into_stream())
207 }
208
209 async fn size_hint(&self) -> (u64, Option<u64>) {
210 let lhs_max = self.lhs.size_hint().await.1;
211 let rhs_max = self.rhs.size_hint().await.1;
212 let max = match (lhs_max, rhs_max) {
213 (Some(l), Some(r)) => Some(l.min(r)),
214 _ => None,
215 };
216 (0, max)
217 }
218
219 async fn contains(&self, name: &Vertex) -> Result<bool> {
220 Ok(self.lhs.contains(name).await? && self.rhs.contains(name).await?)
221 }
222
223 async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
224 for set in &[&self.lhs, &self.rhs] {
225 let contains = set.contains_fast(name).await?;
226 match contains {
227 Some(false) | None => return Ok(contains),
228 Some(true) => {}
229 }
230 }
231 Ok(Some(true))
232 }
233
234 fn as_any(&self) -> &dyn Any {
235 self
236 }
237
238 fn hints(&self) -> &Hints {
239 &self.hints
240 }
241
242 fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
243 let lhs = self.lhs.specialized_flatten_id()?;
244 let rhs = self.rhs.specialized_flatten_id()?;
245 let result = IdStaticSet::from_edit_spans(&lhs, &rhs, |a, b| a.intersection(b))?;
246 Some(Cow::Owned(result))
247 }
248}
249
250impl fmt::Debug for IntersectionSet {
251 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
252 write!(f, "<and")?;
253 write_debug(f, &self.lhs)?;
254 write_debug(f, &self.rhs)?;
255 write!(f, ">")
256 }
257}
258
#[cfg(test)]
#[allow(clippy::redundant_clone)]
mod tests {
    use std::collections::HashSet;

    use super::super::id_lazy::test_utils::lazy_set;
    use super::super::id_lazy::test_utils::lazy_set_inherit;
    use super::super::tests::*;
    use super::*;
    use crate::Id;

    /// Intersects two `VecQuery`-backed sets whose vertices are the given bytes.
    fn intersection(lhs_bytes: &[u8], rhs_bytes: &[u8]) -> IntersectionSet {
        let lhs = Set::from_query(VecQuery::from_bytes(lhs_bytes));
        let rhs = Set::from_query(VecQuery::from_bytes(rhs_bytes));
        IntersectionSet::new(lhs, rhs)
    }

    #[test]
    fn test_intersection_basic() -> Result<()> {
        let set = intersection(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;

        // Iteration follows the order of the first set.
        assert_eq!(shorten_iter(ni(set.iter())), ["33", "44"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["44", "33"]);
        assert!(!nb(set.is_empty())?);
        assert_eq!(nb(set.count_slow())?, 2);
        assert_eq!(shorten_name(nb(set.first())?.unwrap()), "33");
        assert_eq!(shorten_name(nb(set.last())?.unwrap()), "44");

        // Vertices present on only one side are not members.
        for byte in [0x11u8, 0x22, 0x55, 0x66] {
            assert!(!nb(set.contains(&to_name(byte)))?);
        }
        Ok(())
    }

    #[test]
    fn test_intersection_min_max_id_fast_path() {
        // The right-hand set also holds 0x70, 0x65, 0x35 and 0x20, which lie
        // outside the [0x40, 0x50] range its hints claim. The deliberately
        // inaccurate hints make the early stop visible in the output.
        let narrow_rhs_hints = |set: &Set| {
            set.hints().set_min_id(Id(0x40));
            set.hints().set_max_id(Id(0x50));
        };

        // lhs hinted as descending by id.
        let lhs = lazy_set(&[0x70, 0x60, 0x50, 0x40, 0x30, 0x20]);
        let rhs = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &lhs);
        let lhs = Set::from_query(lhs);
        let rhs = Set::from_query(rhs);
        lhs.hints().add_flags(Flags::ID_DESC);
        narrow_rhs_hints(&rhs);

        let set = IntersectionSet::new(lhs, rhs.clone());
        // Stops at 0x30 (< 0x40), so 0x20 is never reached.
        assert_eq!(shorten_iter(ni(set.iter())), ["70", "50", "40"]);
        // Stops at 0x60 (> 0x50), so 0x70 is never reached.
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["20", "40", "50"]);

        // lhs hinted as ascending by id.
        let lhs = lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]);
        let rhs = lazy_set_inherit(&[0x70, 0x65, 0x50, 0x40, 0x35, 0x20], &lhs);
        let lhs = Set::from_query(lhs);
        let rhs = Set::from_query(rhs);
        lhs.hints().add_flags(Flags::ID_ASC);
        narrow_rhs_hints(&rhs);

        let set = IntersectionSet::new(lhs, rhs.clone());
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40"]);

        // This lhs is not built with `lazy_set_inherit` from rhs. Presumably its
        // id map is not compatible, and the fast path stays disabled: every
        // common vertex is yielded.
        let lhs = Set::from_query(lazy_set(&[0x20, 0x30, 0x40, 0x50, 0x60, 0x70]));
        lhs.hints().add_flags(Flags::ID_ASC);

        let set = IntersectionSet::new(lhs, rhs.clone());
        assert_eq!(shorten_iter(ni(set.iter())), ["20", "40", "50", "70"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["70", "50", "40", "20"]);
    }

    #[test]
    fn test_size_hint_sets() {
        check_size_hint_sets(IntersectionSet::new);
    }

    quickcheck::quickcheck! {
        fn test_intersection_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = intersection(&a, &b);
            check_invariants(&set).unwrap();

            // The intersection is never larger than either operand.
            let count = nb(set.count_slow()).unwrap() as usize;
            assert!(count <= a.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &a);
            assert!(count <= b.len(), "len({:?}) = {} should <= len({:?})" , &set, count, &b);

            // Members found via either input must agree.
            let in_set = |v: u8| nb(set.contains(&to_name(v))).ok() == Some(true);
            let contains_a: HashSet<u8> = a.into_iter().filter(|&v| in_set(v)).collect();
            let contains_b: HashSet<u8> = b.into_iter().filter(|&v| in_set(v)).collect();
            assert_eq!(contains_a, contains_b);

            true
        }
    }
}