sif_itree/
query.rs

1use std::ops::{ControlFlow, Range, RangeInclusive};
2
3#[cfg(feature = "rayon")]
4use rayon::join;
5
6use crate::{ITree, Item, Node};
7
8impl<K, V, S> ITree<K, V, S>
9where
10    S: AsRef<[Node<K, V>]>,
11{
12    /// Query for all intervals overlapping the given interval
13    ///
14    /// The stored intervals are interpeted as half-open or closed,
15    /// depending on the type of the given query interval,
16    /// either [`Range`] or [`RangeInclusive`].
17    pub fn query<'a, I, H, R>(&'a self, interval: I, handler: H) -> ControlFlow<R>
18    where
19        K: Ord,
20        I: Interval<K>,
21        H: FnMut(&'a Item<K, V>) -> ControlFlow<R>,
22    {
23        let nodes = self.nodes.as_ref();
24
25        if !nodes.is_empty() {
26            query(&mut QueryArgs { interval, handler }, nodes)?;
27        }
28
29        ControlFlow::Continue(())
30    }
31
32    #[cfg(feature = "rayon")]
33    /// Query for all intervals overlapping the given interval, in parallel
34    ///
35    /// The stored intervals are interpeted as half-open or closed,
36    /// depending on the type of the given query interval,
37    /// either [`Range`] or [`RangeInclusive`].
38    pub fn par_query<'a, I, H, R>(&'a self, interval: I, handler: H) -> ControlFlow<R>
39    where
40        K: Ord + Send + Sync,
41        V: Sync,
42        I: Interval<K> + Sync,
43        H: Fn(&'a Item<K, V>) -> ControlFlow<R> + Sync,
44        R: Send,
45    {
46        let nodes = self.nodes.as_ref();
47
48        if !nodes.is_empty() {
49            par_query(&QueryArgs { interval, handler }, nodes)?;
50        }
51
52        ControlFlow::Continue(())
53    }
54}
55
/// A query interval over keys of type `K`.
///
/// The three predicates are the comparisons the tree traversal performs
/// against a node; implementing them per range type is what selects between
/// half-open and closed interval semantics.
pub trait Interval<K> {
    /// Compares the start of the query against `max`.
    ///
    /// The traversal passes the second field of a node here (presumably the
    /// largest end bound found in that node's subtree) and skips the whole
    /// subtree if this returns `false`.
    fn go_left(&self, max: &K) -> bool;

    /// Compares the end of the query against the `start` of a stored interval.
    ///
    /// If this returns `false`, the traversal skips the node itself and
    /// everything to its right.
    fn go_right(&self, start: &K) -> bool;

    /// Compares the start of the query against the `end` of a stored interval.
    ///
    /// The traversal calls this only after [`Interval::go_right`] succeeded
    /// for the same interval, and reports the interval if this also succeeds.
    fn overlaps(&self, end: &K) -> bool;
}

/// Half-open semantics: every comparison is strict.
impl<K: Ord> Interval<K> for Range<K> {
    fn go_left(&self, max: &K) -> bool {
        self.start < *max
    }

    fn go_right(&self, start: &K) -> bool {
        self.end > *start
    }

    fn overlaps(&self, end: &K) -> bool {
        self.start < *end
    }
}

/// Closed semantics: touching bounds count as overlapping.
impl<K: Ord> Interval<K> for RangeInclusive<K> {
    fn go_left(&self, max: &K) -> bool {
        *self.start() <= *max
    }

    fn go_right(&self, start: &K) -> bool {
        *self.end() >= *start
    }

    fn overlaps(&self, end: &K) -> bool {
        *self.start() <= *end
    }
}
95
/// State shared by every step of a single tree traversal.
///
/// Bundling both values lets the recursive helpers pass a single reference
/// instead of two separate arguments.
struct QueryArgs<I, H> {
    /// The interval being searched for.
    interval: I,
    /// Callback invoked for every stored item that overlaps `interval`.
    handler: H,
}
100
101fn query<'a, I, H, K, V, R>(
102    args: &mut QueryArgs<I, H>,
103    mut nodes: &'a [Node<K, V>],
104) -> ControlFlow<R>
105where
106    K: Ord,
107    I: Interval<K>,
108    H: FnMut(&'a (Range<K>, V)) -> ControlFlow<R>,
109{
110    loop {
111        let (left, [mid, right @ ..]) = nodes.split_at(nodes.len() / 2) else {
112            unreachable!()
113        };
114
115        let mut go_left = false;
116        let mut go_right = false;
117
118        if args.interval.go_left(&mid.1) {
119            if !left.is_empty() {
120                go_left = true;
121            }
122
123            if args.interval.go_right(&(mid.0).0.start) {
124                if !right.is_empty() {
125                    go_right = true;
126                }
127
128                if args.interval.overlaps(&(mid.0).0.end) {
129                    (args.handler)(&mid.0)?;
130                }
131            }
132        }
133
134        match (go_left, go_right) {
135            (true, true) => {
136                query(args, left)?;
137
138                nodes = right;
139            }
140            (true, false) => nodes = left,
141            (false, true) => nodes = right,
142            (false, false) => return ControlFlow::Continue(()),
143        }
144    }
145}
146
/// Parallel counterpart of `query`: walks the subtree stored in `nodes`,
/// calling the handler for every item that overlaps the query interval.
///
/// The middle element of the slice is treated as the root of the subtree.
/// When both halves have to be searched they are handed to `join`;
/// otherwise the loop simply continues with the single remaining half.
///
/// `nodes` must not be empty; the public entry point checks this before
/// calling. A `Break` returned by the handler for the current node is
/// propagated immediately. After a `join`, both halves have finished and a
/// `Break` from the left half takes precedence over one from the right.
#[cfg(feature = "rayon")]
fn par_query<'a, I, H, K, V, R>(
    args: &QueryArgs<I, H>,
    mut nodes: &'a [Node<K, V>],
) -> ControlFlow<R>
where
    K: Ord + Send + Sync,
    V: Sync,
    I: Interval<K> + Sync,
    H: Fn(&'a (Range<K>, V)) -> ControlFlow<R> + Sync,
    R: Send,
{
    loop {
        let (lower, rest) = nodes.split_at(nodes.len() / 2);
        let Some((node, upper)) = rest.split_first() else {
            unreachable!()
        };
        let (item, max) = (&node.0, &node.1);

        // The query starts past everything recorded for this subtree.
        if !args.interval.go_left(max) {
            return ControlFlow::Continue(());
        }

        let descend_lower = !lower.is_empty();
        let mut descend_upper = false;

        if args.interval.go_right(&item.0.start) {
            descend_upper = !upper.is_empty();

            if args.interval.overlaps(&item.0.end) {
                (args.handler)(item)?;
            }
        }

        nodes = match (descend_lower, descend_upper) {
            (true, true) => {
                let (lower_flow, upper_flow) =
                    join(|| par_query(args, lower), || par_query(args, upper));

                lower_flow?;

                return upper_flow;
            }
            (true, false) => lower,
            (false, true) => upper,
            (false, false) => return ControlFlow::Continue(()),
        };
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "rayon")]
    use std::sync::Mutex;

    use proptest::{collection::vec, test_runner::TestRunner};

    /// Domain from which interval bounds and query bounds are drawn.
    const DOM: Range<i32> = -1000..1000;
    /// Number of intervals stored in every generated tree.
    const LEN: usize = 1000;

    /// Collects the ranges and orders them by `(start, end)`, so that result
    /// sets can be compared regardless of the order they were produced in.
    fn sorted<'a>(ranges: impl IntoIterator<Item = &'a Range<i32>>) -> Vec<&'a Range<i32>> {
        let mut ranges = ranges.into_iter().collect::<Vec<_>>();
        ranges.sort_unstable_by_key(|range| (range.start, range.end));
        ranges
    }

    /// Half-open queries must agree with a brute-force scan over all items.
    #[test]
    fn query_random() {
        let mut runner = TestRunner::default();

        runner
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(starts, ends, lo, hi)| {
                    let tree = ITree::<_, _>::new(
                        starts
                            .iter()
                            .zip(&ends)
                            .map(|(&start, &end)| (start..end, ())),
                    );

                    let mut found = Vec::new();
                    tree.query(lo..hi, |(range, ())| {
                        found.push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();

                    let expected = tree.iter().filter_map(|(range, ())| {
                        (hi > range.start && lo < range.end).then_some(range)
                    });

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap();
    }

    /// Closed queries must agree with a brute-force scan over all items.
    #[test]
    fn query_random_inclusive() {
        let mut runner = TestRunner::default();

        runner
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(starts, ends, lo, hi)| {
                    let tree = ITree::<_, _>::new(
                        starts
                            .iter()
                            .zip(&ends)
                            .map(|(&start, &end)| (start..end, ())),
                    );

                    let mut found = Vec::new();
                    tree.query(lo..=hi, |(range, ())| {
                        found.push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();

                    let expected = tree.iter().filter_map(|(range, ())| {
                        (hi >= range.start && lo <= range.end).then_some(range)
                    });

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap();
    }

    /// Parallel half-open queries must agree with a brute-force scan.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_query_random() {
        let mut runner = TestRunner::default();

        runner
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(starts, ends, lo, hi)| {
                    let tree = ITree::<_, _>::par_new(
                        starts
                            .iter()
                            .zip(&ends)
                            .map(|(&start, &end)| (start..end, ())),
                    );

                    // The handler may run on several threads, hence the lock.
                    let found = Mutex::new(Vec::new());
                    tree.par_query(lo..hi, |(range, ())| {
                        found.lock().unwrap().push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();
                    let found = found.into_inner().unwrap();

                    let expected = tree.iter().filter_map(|(range, ())| {
                        (hi > range.start && lo < range.end).then_some(range)
                    });

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap();
    }

    /// Parallel closed queries must agree with a brute-force scan.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_query_random_inclusive() {
        let mut runner = TestRunner::default();

        runner
            .run(
                &(vec(DOM, LEN), vec(DOM, LEN), DOM, DOM),
                |(starts, ends, lo, hi)| {
                    let tree = ITree::<_, _>::par_new(
                        starts
                            .iter()
                            .zip(&ends)
                            .map(|(&start, &end)| (start..end, ())),
                    );

                    // The handler may run on several threads, hence the lock.
                    let found = Mutex::new(Vec::new());
                    tree.par_query(lo..=hi, |(range, ())| {
                        found.lock().unwrap().push(range);
                        ControlFlow::<()>::Continue(())
                    })
                    .continue_value()
                    .unwrap();
                    let found = found.into_inner().unwrap();

                    let expected = tree.iter().filter_map(|(range, ())| {
                        (hi >= range.start && lo <= range.end).then_some(range)
                    });

                    assert_eq!(sorted(found), sorted(expected));

                    Ok(())
                },
            )
            .unwrap();
    }
}