// seg_tree/segment_tree/recursive.rs

1use std::mem::MaybeUninit;
2
3use crate::nodes::Node;
4
5/// Segment tree with range queries and point updates.
6/// It uses `O(n)` space, assuming that each node uses `O(1)` space.
7/// Note if you don't need to use `lower_bound`, just use the [`SegmentTree`](crate::segment_tree::SegmentTree) it uses half the memory and it's more performant.
8pub struct Recursive<T> {
9    nodes: Vec<T>,
10    n: usize,
11}
12
13impl<T> Recursive<T>
14where
15    T: Node + Clone,
16{
17    /// Builds segment tree from slice, each element of the slice will correspond to a leaf of the segment tree.
18    /// It has time complexity of `O(n*log(n))`, assuming that [combine](Node::combine) has constant time complexity.
19    #[must_use]
20    pub fn build(values: &[T]) -> Self {
21        let n = values.len();
22        let mut nodes = Vec::with_capacity(4 * n);
23        unsafe { nodes.set_len(4 * n) };
24        if n == 0 {
25            return Self {
26                nodes: Vec::new(),
27                n: 0,
28            };
29        }
30        Self::build_helper(0, 0, n - 1, values, &mut nodes);
31        let ptr = nodes.as_mut_ptr();
32        core::mem::forget(nodes);
33        let nodes = unsafe { Vec::from_raw_parts(ptr.cast::<T>(), 4 * n, 4 * n) }; // Unsafe AF, but if it's coded correctly the only nodes which will ever be accessed are already initialized
34
35        Self { nodes, n }
36    }
37
38    #[inline]
39    fn build_helper(
40        curr_node: usize,
41        i: usize,
42        j: usize,
43        values: &[T],
44        nodes: &mut [MaybeUninit<T>],
45    ) {
46        if i == j {
47            nodes[curr_node].write(values[i].clone());
48            return;
49        }
50        let mid = (i + j) / 2;
51        let left_node = 2 * curr_node + 1;
52        let right_node = 2 * curr_node + 2;
53        Self::build_helper(left_node, i, mid, values, nodes);
54        Self::build_helper(right_node, mid + 1, j, values, nodes);
55        let (top_nodes, bottom_nodes) = nodes.split_at_mut(curr_node + 1);
56        top_nodes[curr_node].write(Node::combine(
57            unsafe { bottom_nodes[left_node - curr_node - 1].assume_init_ref() },
58            unsafe { bottom_nodes[right_node - curr_node - 1].assume_init_ref() },
59        ));
60    }
61
62    /// Sets the p-th element of the segment tree to value T and update the segment tree correspondingly.
63    /// It will panic if p is not in `[0,n)`
64    /// It has time complexity of `O(log(n))`, assuming that [combine](Node::combine) has constant time complexity.
65    pub fn update(&mut self, p: usize, value: &<T as Node>::Value) {
66        self.update_helper(p, value, 0, 0, self.n - 1);
67    }
68
69    #[inline]
70    fn update_helper(
71        &mut self,
72        p: usize,
73        value: &<T as Node>::Value,
74        curr_node: usize,
75        i: usize,
76        j: usize,
77    ) {
78        if j < p || p < i {
79            return;
80        }
81        if i == j {
82            self.nodes[curr_node] = Node::initialize(value);
83            return;
84        }
85        let mid = (i + j) / 2;
86        let left_node = 2 * curr_node + 1;
87        let right_node = 2 * curr_node + 2;
88        self.update_helper(p, value, left_node, i, mid);
89        self.update_helper(p, value, right_node, mid + 1, j);
90        self.nodes[curr_node] = Node::combine(&self.nodes[left_node], &self.nodes[right_node]);
91    }
92
93    /// Returns the result from the range `[left,right]`.
94    /// It returns None if and only if range is empty.
95    /// It will **panic** if `left` or `right` are not in [0,n).
96    /// It has time complexity of `O(log(n))`, assuming that [combine](Node::combine) has constant time complexity.
97    #[allow(clippy::must_use_candidate)]
98    pub fn query(&self, left: usize, right: usize) -> Option<T> {
99        self.query_helper(left, right, 0, 0, self.n - 1)
100    }
101
102    #[inline]
103    fn query_helper(
104        &self,
105        left: usize,
106        right: usize,
107        curr_node: usize,
108        i: usize,
109        j: usize,
110    ) -> Option<T> {
111        if j < left || right < i {
112            return None;
113        }
114        let mid = (i + j) / 2;
115        let left_node = 2 * curr_node + 1;
116        let right_node = 2 * curr_node + 2;
117        if left <= i && j <= right {
118            return Some(self.nodes[curr_node].clone());
119        }
120        match (
121            self.query_helper(left, right, left_node, i, mid),
122            self.query_helper(left, right, right_node, mid + 1, j),
123        ) {
124            (Some(ans_left), Some(ans_right)) => Some(Node::combine(&ans_left, &ans_right)),
125            (Some(ans_left), None) => Some(ans_left),
126            (None, Some(ans_right)) => Some(ans_right),
127            (None, None) => None,
128        }
129    }
130
131    /// A method that finds the smallest prefix[^note] `u` such that `predicate(u.value(), value)` is `true`. The following must be true:
132    /// - `predicate` is monotonic over prefixes[^note2].
133    /// - `g` will satisfy the following, given segments `[i,j]` and `[i,k]` with `j<k` we have that `predicate([i,k].value(),value)` implies `predicate([j+1,k].value(),g([i,j].value(),value))`.
134    ///
135    /// These are two examples, the first is finding the smallest prefix which sums at least some value.
136    /// ```
137    /// # use seg_tree::{Recursive,utils::Sum,nodes::Node};
138    /// let predicate = |left_value: &usize, value: &usize|{*left_value >= *value}; // Is the sum greater or equal to value?
139    /// let g = |left_node: &usize, value: usize|{value - *left_node}; // Subtract the sum of the prefix.
140    /// # let nodes: Vec<Sum<usize>> = (0..10).map(|x| Sum::initialize(&x)).collect();
141    /// let seg_tree = Recursive::build(&nodes); // [0,1,2,3,4,5,6,7,8,9] with Sum<usize> nodes
142    /// let index = seg_tree.lower_bound(predicate, g, 3); // Will return 2 as sum([0,1,2])>=3
143    /// # let sums = vec![0,1,3,6,10,15,21,28,36,45];
144    /// # for i in 0..10{
145    /// #    assert_eq!(seg_tree.lower_bound(predicate, g, sums[i]), i);
146    /// # }
147    /// ```
148    /// The second is finding the position of the smallest value greater or equal to some value.
149    /// ```
150    /// # use seg_tree::{Recursive,utils::Max,nodes::Node};
151    /// let predicate = |left_value:&usize, value:&usize|{*left_value>=*value}; // Is the maximum greater or equal to value?
152    /// let g = |_left_node:&usize,value:usize|{value}; // Do nothing
153    /// # let nodes: Vec<Max<usize>> = (0..10).map(|x| Max::initialize(&x)).collect();
154    /// let seg_tree = Recursive::build(&nodes); // [0,1,2,3,4,5,6,7,8,9] with Max<usize> nodes
155    /// let index = seg_tree.lower_bound(predicate, g, 3); // Will return 3 as 3>=3
156    /// # for i in 0..10{
157    /// #    assert_eq!(seg_tree.lower_bound(predicate, g, i), i);
158    /// # }
159    /// ```
160    ///
161    /// [^note]: A prefix is a segment of the form `[0,i]`.
162    ///
163    /// [^note2]: Given two prefixes `u` and `v` if `u` is contained in `v` then `predicate(u.value(), value)` implies `predicate(v.value(), value)`.
164    pub fn lower_bound<F, G>(&self, predicate: F, g: G, value: <T as Node>::Value) -> usize
165    where
166        F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
167        G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
168    {
169        self.lower_bound_helper(0, 0, self.n - 1, predicate, g, value)
170    }
171    fn lower_bound_helper<F, G>(
172        &self,
173        curr_node: usize,
174        i: usize,
175        j: usize,
176        predicate: F,
177        g: G,
178        value: <T as Node>::Value,
179    ) -> usize
180    where
181        F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
182        G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
183    {
184        if i == j {
185            return i;
186        }
187        let mid = (i + j) / 2;
188        let left_node = 2 * curr_node + 1;
189        let right_node = 2 * curr_node + 2;
190        let left_value = self.nodes[left_node].value();
191        if predicate(left_value, &value) {
192            self.lower_bound_helper(left_node, i, mid, predicate, g, value)
193        } else {
194            let value = g(left_value, value);
195            self.lower_bound_helper(right_node, mid + 1, j, predicate, g, value)
196        }
197    }
198}

#[cfg(test)]
mod tests {
    use super::Recursive;
    use crate::{nodes::Node, utils::Min};

    /// Shared fixture: a tree of `Min<usize>` nodes over the leaves `0, 1, ..., 10`.
    fn zero_to_ten() -> Recursive<Min<usize>> {
        let leaves: Vec<Min<usize>> = (0..=10).map(|x| Min::initialize(&x)).collect();
        Recursive::build(&leaves)
    }

    #[test]
    fn non_empty_query_returns_some() {
        let tree = zero_to_ten();
        assert!(tree.query(0, 10).is_some());
    }

    #[test]
    fn empty_query_returns_none() {
        // `left > right` describes an empty range.
        let tree = zero_to_ten();
        assert!(tree.query(10, 0).is_none());
    }

    #[test]
    fn update_works() {
        let mut tree = zero_to_ten();
        let replacement = 20;
        tree.update(0, &replacement);
        let leaf = tree.query(0, 0).unwrap();
        assert_eq!(leaf.value(), &replacement);
    }

    #[test]
    fn query_works() {
        // The minimum over leaves 1..=10 is the leaf holding 1.
        let tree = zero_to_ten();
        let smallest = tree.query(1, 10).unwrap();
        assert_eq!(smallest.value(), &1);
    }
}