1use std::any::Any;
9use std::borrow::Cow;
10use std::fmt;
11
12use futures::StreamExt;
13
14use super::hints::Flags;
15use super::id_static::IdStaticSet;
16use super::AsyncSetQuery;
17use super::BoxVertexStream;
18use super::Hints;
19use super::Set;
20use crate::fmt::write_debug;
21use crate::Result;
22use crate::Vertex;
23
/// Set of the vertexes that are in `lhs` but not in `rhs`.
///
/// Iteration walks `lhs` (forwards or in reverse) and drops every vertex
/// that `rhs` contains, so the order of the result follows `lhs`.
pub struct DifferenceSet {
    /// The set being filtered; it supplies the candidates and their order.
    lhs: Set,
    /// The set whose members are excluded from the result.
    rhs: Set,
    /// Hints for this set, populated from the hints of `lhs` in `new`.
    hints: Hints,
}
32
/// Iteration state behind the streams returned by `DifferenceSet`.
struct Iter {
    /// Stream of candidate vertexes (taken from the left-hand side set).
    iter: BoxVertexStream,
    /// Candidates contained in this set are skipped.
    rhs: Set,
}
37
38impl DifferenceSet {
39 pub fn new(lhs: Set, rhs: Set) -> Self {
40 let hints = Hints::new_inherit_idmap_dag(lhs.hints());
41 hints.add_flags(
43 lhs.hints().flags()
44 & (Flags::EMPTY
45 | Flags::ID_DESC
46 | Flags::ID_ASC
47 | Flags::TOPO_DESC
48 | Flags::FILTER),
49 );
50 if let Some(id) = lhs.hints().min_id() {
51 hints.set_min_id(id);
52 }
53 if let Some(id) = lhs.hints().max_id() {
54 hints.set_max_id(id);
55 }
56 Self { lhs, rhs, hints }
57 }
58}
59
60#[async_trait::async_trait]
61impl AsyncSetQuery for DifferenceSet {
62 async fn iter(&self) -> Result<BoxVertexStream> {
63 let iter = Iter {
64 iter: self.lhs.iter().await?,
65 rhs: self.rhs.clone(),
66 };
67 Ok(iter.into_stream())
68 }
69
70 async fn iter_rev(&self) -> Result<BoxVertexStream> {
71 let iter = Iter {
72 iter: self.lhs.iter_rev().await?,
73 rhs: self.rhs.clone(),
74 };
75 Ok(iter.into_stream())
76 }
77
78 async fn contains(&self, name: &Vertex) -> Result<bool> {
79 Ok(self.lhs.contains(name).await? && !self.rhs.contains(name).await?)
80 }
81
82 async fn contains_fast(&self, name: &Vertex) -> Result<Option<bool>> {
83 let lhs_contains = self.lhs.contains_fast(name).await?;
84 if lhs_contains == Some(false) {
85 return Ok(Some(false));
86 }
87 let rhs_contains = self.rhs.contains_fast(name).await?;
88 let result = match (lhs_contains, rhs_contains) {
89 (Some(true), Some(false)) => Some(true),
90 (_, Some(true)) | (Some(false), _) => Some(false),
91 (Some(true), None) | (None, _) => None,
92 };
93 Ok(result)
94 }
95
96 async fn size_hint(&self) -> (u64, Option<u64>) {
97 let (lhs_min, lhs_max) = self.lhs.size_hint().await;
98 let (_rhs_min, rhs_max) = self.rhs.size_hint().await;
99 let min = match rhs_max {
100 None => 0,
101 Some(rhs_max) => lhs_min.saturating_sub(rhs_max),
102 };
103 (min, lhs_max)
104 }
105
106 fn as_any(&self) -> &dyn Any {
107 self
108 }
109
110 fn hints(&self) -> &Hints {
111 &self.hints
112 }
113
114 fn specialized_flatten_id(&self) -> Option<Cow<IdStaticSet>> {
115 let lhs = self.lhs.specialized_flatten_id()?;
116 let rhs = self.rhs.specialized_flatten_id()?;
117 let result = IdStaticSet::from_edit_spans(&lhs, &rhs, |a, b| a.difference(b))?;
118 Some(Cow::Owned(result))
119 }
120}
121
122impl fmt::Debug for DifferenceSet {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 write!(f, "<diff")?;
125 write_debug(f, &self.lhs)?;
126 write_debug(f, &self.rhs)?;
127 write!(f, ">")
128 }
129}
130
131impl Iter {
132 async fn next(&mut self) -> Option<Result<Vertex>> {
133 loop {
134 let result = self.iter.as_mut().next().await;
135 if let Some(Ok(ref name)) = result {
136 match self.rhs.contains(name).await {
137 Err(err) => break Some(Err(err)),
138 Ok(true) => continue,
139 _ => {}
140 }
141 }
142 break result;
143 }
144 }
145
146 fn into_stream(self) -> BoxVertexStream {
147 Box::pin(futures::stream::unfold(self, |mut state| async move {
148 let result = state.next().await;
149 result.map(|r| (r, state))
150 }))
151 }
152}
153
#[cfg(test)]
mod tests {
    use nonblocking::non_blocking as nb;

    use super::super::tests::*;
    use super::*;

    /// Builds the difference `a - b` from two byte slices, each turned into
    /// a set through `VecQuery::from_bytes`.
    fn difference(a: &[u8], b: &[u8]) -> DifferenceSet {
        let to_set = |bytes: &[u8]| Set::from_query(VecQuery::from_bytes(bytes));
        DifferenceSet::new(to_set(a), to_set(b))
    }

    #[test]
    fn test_difference_basic() -> Result<()> {
        let set = difference(b"\x11\x33\x55\x22\x44", b"\x44\x33\x66");
        check_invariants(&set)?;

        // Order follows the left-hand side, with 33 and 44 removed.
        assert_eq!(shorten_iter(ni(set.iter())), ["11", "55", "22"]);
        assert_eq!(shorten_iter(ni(set.iter_rev())), ["22", "55", "11"]);

        assert!(!nb(set.is_empty())??);
        assert_eq!(nb(set.count_slow())??, 3);
        assert_eq!(shorten_name(nb(set.first())??.unwrap()), "11");
        assert_eq!(shorten_name(nb(set.last())??.unwrap()), "22");

        // Members of the difference.
        for byte in b"\x11\x22\x55".iter().copied() {
            assert!(nb(set.contains(&to_name(byte)))??);
        }
        // Everything in rhs (whether or not it was in lhs) is absent.
        for byte in b"\x33\x44\x66".iter().copied() {
            assert!(!nb(set.contains(&to_name(byte)))??);
        }
        Ok(())
    }

    #[test]
    fn test_size_hint_sets() {
        check_size_hint_sets(|a, b| DifferenceSet::new(a, b));
    }

    quickcheck::quickcheck! {
        fn test_difference_quickcheck(a: Vec<u8>, b: Vec<u8>) -> bool {
            let set = difference(&a, &b);
            check_invariants(&set).unwrap();

            // The difference can never be larger than the left-hand side.
            let count = nb(set.count_slow()).unwrap().unwrap() as usize;
            assert!(count <= a.len());

            // Nothing from the right-hand side may remain in the result.
            for &byte in &b {
                let contained = nb(set.contains(&to_name(byte))).unwrap().ok();
                assert!(contained == Some(false));
            }

            true
        }
    }
}