sif_itree/
query.rs

1use std::ops::{ControlFlow, Range};
2
3#[cfg(feature = "rayon")]
4use rayon::join;
5
6use crate::{ITree, Item, Node};
7
8impl<K, V, S> ITree<K, V, S>
9where
10    S: AsRef<[Node<K, V>]>,
11{
12    /// Query for all intervals overlapping the given interval
13    pub fn query<'a, H, R>(&'a self, interval: Range<K>, handler: H) -> ControlFlow<R>
14    where
15        K: Ord,
16        H: FnMut(&'a Item<K, V>) -> ControlFlow<R>,
17    {
18        let nodes = self.nodes.as_ref();
19
20        if !nodes.is_empty() {
21            query(&mut QueryArgs { interval, handler }, nodes)?;
22        }
23
24        ControlFlow::Continue(())
25    }
26
27    #[cfg(feature = "rayon")]
28    /// Query for all intervals overlapping the given interval, in parallel
29    pub fn par_query<'a, H, R>(&'a self, interval: Range<K>, handler: H) -> ControlFlow<R>
30    where
31        K: Ord + Send + Sync,
32        V: Sync,
33        H: Fn(&'a Item<K, V>) -> ControlFlow<R> + Sync,
34        R: Send,
35    {
36        let nodes = self.nodes.as_ref();
37
38        if !nodes.is_empty() {
39            par_query(&QueryArgs { interval, handler }, nodes)?;
40        }
41
42        ControlFlow::Continue(())
43    }
44}
45
/// State shared by every step of a single query: the interval being searched for
/// and the user callback that receives each overlapping item.
struct QueryArgs<K, H> {
    /// Half-open interval the stored items are tested against.
    interval: Range<K>,
    /// Callback invoked with each overlapping item; its `ControlFlow` result decides
    /// whether the search keeps going.
    handler: H,
}
50
51fn query<'a, K, V, H, R>(args: &mut QueryArgs<K, H>, mut nodes: &'a [Node<K, V>]) -> ControlFlow<R>
52where
53    K: Ord,
54    H: FnMut(&'a (Range<K>, V)) -> ControlFlow<R>,
55{
56    loop {
57        let (left, [mid, right @ ..]) = nodes.split_at(nodes.len() / 2) else {
58            unreachable!()
59        };
60
61        let mut go_left = false;
62        let mut go_right = false;
63
64        if args.interval.start < mid.1 {
65            if !left.is_empty() {
66                go_left = true;
67            }
68
69            if args.interval.end > (mid.0).0.start {
70                if !right.is_empty() {
71                    go_right = true;
72                }
73
74                if args.interval.start < (mid.0).0.end {
75                    (args.handler)(&mid.0)?;
76                }
77            }
78        }
79
80        match (go_left, go_right) {
81            (true, true) => {
82                query(args, left)?;
83
84                nodes = right;
85            }
86            (true, false) => nodes = left,
87            (false, true) => nodes = right,
88            (false, false) => return ControlFlow::Continue(()),
89        }
90    }
91}
92
/// Parallel search behind [`ITree::par_query`].
///
/// It walks the same implicit tree as the sequential helper:
/// - `nodes` must be non-empty;
/// - the node at `len / 2` is the root;
/// - `.1` bounds every `end` in its subtree.
///
/// When both halves need searching, they are handed to [`rayon::join`] and explored
/// concurrently. A break in one half does not cancel the other half of the same fork.
/// If both halves break, the lower half's value is returned.
#[cfg(feature = "rayon")]
fn par_query<'a, K, V, H, R>(
    args: &QueryArgs<K, H>,
    mut nodes: &'a [Node<K, V>],
) -> ControlFlow<R>
where
    K: Ord + Send + Sync,
    V: Sync,
    H: Fn(&'a (Range<K>, V)) -> ControlFlow<R> + Sync,
    R: Send,
{
    loop {
        let (lower, rest) = nodes.split_at(nodes.len() / 2);
        // `rest` is never empty: `nodes` starts non-empty and we only ever descend
        // into non-empty halves.
        let Some((node, upper)) = rest.split_first() else {
            unreachable!()
        };
        let item = &node.0;

        // Something in this subtree may still end after the query starts.
        let reaches_subtree = args.interval.start < node.1;
        // The query ends after this node starts, so nodes at or after it in start
        // order may overlap too.
        let reaches_upper = reaches_subtree && args.interval.end > item.0.start;

        if reaches_upper && args.interval.start < item.0.end {
            (args.handler)(item)?;
        }

        let descend_lower = reaches_subtree && !lower.is_empty();
        let descend_upper = reaches_upper && !upper.is_empty();

        nodes = match (descend_lower, descend_upper) {
            (false, false) => return ControlFlow::Continue(()),
            (true, false) => lower,
            (false, true) => upper,
            (true, true) => {
                let (lower_flow, upper_flow) =
                    join(|| par_query(args, lower), || par_query(args, upper));
                lower_flow?;
                return upper_flow;
            }
        };
    }
}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "rayon")]
    use std::sync::Mutex;

    use proptest::{collection::vec, test_runner::TestRunner};

    /// Range from which interval endpoints and query bounds are drawn.
    const DOM: Range<i32> = -1000..1000;
    /// Number of intervals stored in every generated tree.
    const LEN: usize = 1000_usize;

    /// `ITree::query` must report exactly the items a linear scan finds.
    #[test]
    fn query_random() {
        let strategy = (vec(DOM, LEN), vec(DOM, LEN), DOM, DOM);

        let outcome = TestRunner::default().run(&strategy, |(starts, ends, lo, hi)| {
            let tree = ITree::<_, _>::new(
                starts
                    .iter()
                    .zip(&ends)
                    .map(|(&s, &e)| (s..e, ())),
            );

            let mut found = Vec::new();
            tree.query(lo..hi, |(range, ())| {
                found.push(range);
                ControlFlow::<()>::Continue(())
            })
            .continue_value()
            .unwrap();

            // Reference answer: the half-open overlap test applied to every item.
            let mut expected = tree
                .iter()
                .filter(|(range, ())| hi > range.start && lo < range.end)
                .map(|(range, ())| range)
                .collect::<Vec<_>>();

            found.sort_unstable_by_key(|r| (r.start, r.end));
            expected.sort_unstable_by_key(|r| (r.start, r.end));
            assert_eq!(found, expected);

            Ok(())
        });

        outcome.unwrap()
    }

    /// `ITree::par_query` must report exactly the items a linear scan finds.
    #[cfg(feature = "rayon")]
    #[test]
    fn par_query_random() {
        let strategy = (vec(DOM, LEN), vec(DOM, LEN), DOM, DOM);

        let outcome = TestRunner::default().run(&strategy, |(starts, ends, lo, hi)| {
            let tree = ITree::<_, _>::par_new(
                starts
                    .iter()
                    .zip(&ends)
                    .map(|(&s, &e)| (s..e, ())),
            );

            // The handler may run on several threads at once, so collect behind a lock.
            let sink = Mutex::new(Vec::new());
            tree.par_query(lo..hi, |(range, ())| {
                sink.lock().unwrap().push(range);
                ControlFlow::<()>::Continue(())
            })
            .continue_value()
            .unwrap();
            let mut found = sink.into_inner().unwrap();

            // Reference answer: the half-open overlap test applied to every item.
            let mut expected = tree
                .iter()
                .filter(|(range, ())| hi > range.start && lo < range.end)
                .map(|(range, ())| range)
                .collect::<Vec<_>>();

            found.sort_unstable_by_key(|r| (r.start, r.end));
            expected.sort_unstable_by_key(|r| (r.start, r.end));
            assert_eq!(found, expected);

            Ok(())
        });

        outcome.unwrap()
    }
}