1use std::ops::{ControlFlow, Range, RangeInclusive};
2
3#[cfg(feature = "rayon")]
4use rayon::join;
5
6use crate::{ITree, Item, Node};
7
8impl<K, V, S> ITree<K, V, S>
9where
10 S: AsRef<[Node<K, V>]>,
11{
12 pub fn query<'a, I, H, R>(&'a self, interval: I, handler: H) -> ControlFlow<R>
18 where
19 K: Ord,
20 I: Interval<K>,
21 H: FnMut(&'a Item<K, V>) -> ControlFlow<R>,
22 {
23 let nodes = self.nodes.as_ref();
24
25 if !nodes.is_empty() {
26 query(&mut QueryArgs { interval, handler }, nodes)?;
27 }
28
29 ControlFlow::Continue(())
30 }
31
32 #[cfg(feature = "rayon")]
33 pub fn par_query<'a, I, H, R>(&'a self, interval: I, handler: H) -> ControlFlow<R>
39 where
40 K: Ord + Send + Sync,
41 V: Sync,
42 I: Interval<K> + Sync,
43 H: Fn(&'a Item<K, V>) -> ControlFlow<R> + Sync,
44 R: Send,
45 {
46 let nodes = self.nodes.as_ref();
47
48 if !nodes.is_empty() {
49 par_query(&QueryArgs { interval, handler }, nodes)?;
50 }
51
52 ControlFlow::Continue(())
53 }
54}
55
/// Query shape accepted by [`ITree::query`] and its parallel counterpart.
///
/// The search consults these predicates at every node it visits. A node's item is reported
/// exactly when all three return `true` for that node. Each `false` prunes part of the search,
/// so an implementation must never return `false` where an overlapping item could still be
/// found:
///
/// 1. [`go_left`](Interval::go_left) is called first. `false` skips the node's whole subtree,
///    including the node's own item.
/// 2. [`go_right`](Interval::go_right) is called next. `false` skips the node's own item and
///    its right subtree.
/// 3. [`overlaps`](Interval::overlaps) is called last, and only if `go_right` returned `true`.
///    `true` reports the node's item.
///
/// Implementations are provided for [`Range`] (half-open) and [`RangeInclusive`] queries.
pub trait Interval<K> {
    /// Returns whether the subtree bounded by `max` may contain overlapping items.
    ///
    /// `max` is the bound the tree stores for the node's subtree. The built-in implementations
    /// treat it as the largest item end in that subtree.
    fn go_left(&self, max: &K) -> bool;

    /// Returns whether an item starting at `start`, or any item in the node's right subtree,
    /// may overlap.
    fn go_right(&self, start: &K) -> bool;

    /// Returns whether the item ending at `end` overlaps, given that
    /// [`go_right`](Interval::go_right) already accepted its start.
    fn overlaps(&self, end: &K) -> bool;
}
61
62impl<K> Interval<K> for Range<K>
63where
64 K: Ord,
65{
66 fn go_left(&self, max: &K) -> bool {
67 &self.start < max
68 }
69
70 fn go_right(&self, start: &K) -> bool {
71 &self.end > start
72 }
73
74 fn overlaps(&self, end: &K) -> bool {
75 &self.start < end
76 }
77}
78
79impl<K> Interval<K> for RangeInclusive<K>
80where
81 K: Ord,
82{
83 fn go_left(&self, max: &K) -> bool {
84 self.start() <= max
85 }
86
87 fn go_right(&self, start: &K) -> bool {
88 self.end() >= start
89 }
90
91 fn overlaps(&self, end: &K) -> bool {
92 self.start() <= end
93 }
94}
95
/// Bundles the query interval with the per-item callback. The recursive search helpers take
/// it by reference (`&mut` for the sequential search, `&` for the parallel one).
struct QueryArgs<I, H> {
    /// Decides which subtrees are searched and which items are reported.
    interval: I,
    /// Called with every reported item; a `Break` result ends the search.
    handler: H,
}
100
101fn query<'a, I, H, K, V, R>(
102 args: &mut QueryArgs<I, H>,
103 mut nodes: &'a [Node<K, V>],
104) -> ControlFlow<R>
105where
106 K: Ord,
107 I: Interval<K>,
108 H: FnMut(&'a (Range<K>, V)) -> ControlFlow<R>,
109{
110 loop {
111 let (left, [mid, right @ ..]) = nodes.split_at(nodes.len() / 2) else {
112 unreachable!()
113 };
114
115 let mut go_left = false;
116 let mut go_right = false;
117
118 if args.interval.go_left(&mid.1) {
119 if !left.is_empty() {
120 go_left = true;
121 }
122
123 if args.interval.go_right(&(mid.0).0.start) {
124 if !right.is_empty() {
125 go_right = true;
126 }
127
128 if args.interval.overlaps(&(mid.0).0.end) {
129 (args.handler)(&mid.0)?;
130 }
131 }
132 }
133
134 match (go_left, go_right) {
135 (true, true) => {
136 query(args, left)?;
137
138 nodes = right;
139 }
140 (true, false) => nodes = left,
141 (false, true) => nodes = right,
142 (false, false) => return ControlFlow::Continue(()),
143 }
144 }
145}
146
/// Parallel counterpart of the sequential tree search over `nodes`.
///
/// The tree uses the same layout: the node at `nodes.len() / 2` is the root, with its subtrees
/// in the slices before and after it. Whenever both subtrees need searching, they are searched
/// concurrently with [`rayon::join`]. Both sides always run to completion. The left side's
/// `Break` takes precedence over the right side's.
///
/// # Panics
///
/// Panics if `nodes` is empty; callers must check this beforehand.
#[cfg(feature = "rayon")]
fn par_query<'a, I, H, K, V, R>(
    args: &QueryArgs<I, H>,
    mut nodes: &'a [Node<K, V>],
) -> ControlFlow<R>
where
    K: Ord + Send + Sync,
    V: Sync,
    I: Interval<K> + Sync,
    H: Fn(&'a (Range<K>, V)) -> ControlFlow<R> + Sync,
    R: Send,
{
    loop {
        let (left, rest) = nodes.split_at(nodes.len() / 2);
        let Some((mid, right)) = rest.split_first() else {
            unreachable!()
        };

        // `mid.1` bounds the whole subtree rooted at `mid`; if the interval rules it out,
        // neither `mid` nor its descendants can match.
        if !args.interval.go_left(&mid.1) {
            return ControlFlow::Continue(());
        }

        let item = &mid.0;
        // A rejected start excludes `mid`'s own item as well as its right subtree.
        let search_right = args.interval.go_right(&item.0.start);
        if search_right && args.interval.overlaps(&item.0.end) {
            (args.handler)(item)?;
        }

        let search_left = !left.is_empty();
        let search_right = search_right && !right.is_empty();

        nodes = match (search_left, search_right) {
            // Fork: both subtrees are searched in parallel and both results are awaited.
            (true, true) => {
                let (left_flow, right_flow) =
                    join(|| par_query(args, left), || par_query(args, right));

                left_flow?;
                return right_flow;
            }
            (true, false) => left,
            (false, true) => right,
            (false, false) => return ControlFlow::Continue(()),
        };
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "rayon")]
    use std::sync::Mutex;

    use proptest::{collection::vec, test_runner::TestRunner};

    /// Domain from which interval endpoints and query bounds are drawn.
    const DOM: Range<i32> = -1000..1000;
    /// Number of intervals stored in every generated tree.
    const LEN: usize = 1000;

    /// Orders ranges by `(start, end)` so that results can be compared independently of the
    /// order in which the tree visited them.
    fn sorted(mut ranges: Vec<&Range<i32>>) -> Vec<&Range<i32>> {
        ranges.sort_unstable_by_key(|range| (range.start, range.end));
        ranges
    }

    /// Half-open queries report exactly the items a linear scan considers overlapping.
    #[test]
    fn query_random() {
        TestRunner::default()
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(start, end, query_start, query_end)| {
                    let items = start.iter().zip(&end).map(|(&start, &end)| (start..end, ()));
                    let tree = ITree::<_, _>::new(items);

                    let mut found = Vec::new();
                    tree.query(query_start..query_end, |(range, ())| {
                        found.push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();

                    let expected = tree
                        .iter()
                        .filter(|(range, ())| query_end > range.start && query_start < range.end)
                        .map(|(range, ())| range)
                        .collect::<Vec<_>>();

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap()
    }

    /// Inclusive queries report exactly the items a linear scan considers overlapping.
    #[test]
    fn query_random_inclusive() {
        TestRunner::default()
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(start, end, query_start, query_end)| {
                    let items = start.iter().zip(&end).map(|(&start, &end)| (start..end, ()));
                    let tree = ITree::<_, _>::new(items);

                    let mut found = Vec::new();
                    tree.query(query_start..=query_end, |(range, ())| {
                        found.push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();

                    let expected = tree
                        .iter()
                        .filter(|(range, ())| query_end >= range.start && query_start <= range.end)
                        .map(|(range, ())| range)
                        .collect::<Vec<_>>();

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap()
    }

    /// The parallel half-open query agrees with a linear scan.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_query_random() {
        TestRunner::default()
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(start, end, query_start, query_end)| {
                    let items = start.iter().zip(&end).map(|(&start, &end)| (start..end, ()));
                    let tree = ITree::<_, _>::par_new(items);

                    let found = Mutex::new(Vec::new());
                    tree.par_query(query_start..query_end, |(range, ())| {
                        found.lock().unwrap().push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();
                    let found = found.into_inner().unwrap();

                    let expected = tree
                        .iter()
                        .filter(|(range, ())| query_end > range.start && query_start < range.end)
                        .map(|(range, ())| range)
                        .collect::<Vec<_>>();

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap()
    }

    /// The parallel inclusive query agrees with a linear scan.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_query_random_inclusive() {
        TestRunner::default()
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(start, end, query_start, query_end)| {
                    let items = start.iter().zip(&end).map(|(&start, &end)| (start..end, ()));
                    let tree = ITree::<_, _>::par_new(items);

                    let found = Mutex::new(Vec::new());
                    tree.par_query(query_start..=query_end, |(range, ())| {
                        found.lock().unwrap().push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();
                    let found = found.into_inner().unwrap();

                    let expected = tree
                        .iter()
                        .filter(|(range, ())| query_end >= range.start && query_start <= range.end)
                        .map(|(range, ())| range)
                        .collect::<Vec<_>>();

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap()
    }
}