1use std::mem::MaybeUninit;
2
3use crate::nodes::Node;
4
/// A segment tree stored as a flat array.
///
/// Slot `0` is the root and the children of slot `k` live at `2k + 1` and
/// `2k + 2`. Leaves hold the input elements and each internal slot holds the
/// combination of its two children.
pub struct Recursive<T> {
    /// Number of leaves, i.e. the length of the slice the tree was built from.
    n: usize,
    /// Flat node storage indexed as described above.
    nodes: Vec<T>,
}
12
13impl<T> Recursive<T>
14where
15 T: Node + Clone,
16{
17 #[must_use]
20 pub fn build(values: &[T]) -> Self {
21 let n = values.len();
22 let mut nodes = Vec::with_capacity(4 * n);
23 unsafe { nodes.set_len(4 * n) };
24 if n == 0 {
25 return Self {
26 nodes: Vec::new(),
27 n: 0,
28 };
29 }
30 Self::build_helper(0, 0, n - 1, values, &mut nodes);
31 let ptr = nodes.as_mut_ptr();
32 core::mem::forget(nodes);
33 let nodes = unsafe { Vec::from_raw_parts(ptr.cast::<T>(), 4 * n, 4 * n) }; Self { nodes, n }
36 }
37
38 #[inline]
39 fn build_helper(
40 curr_node: usize,
41 i: usize,
42 j: usize,
43 values: &[T],
44 nodes: &mut [MaybeUninit<T>],
45 ) {
46 if i == j {
47 nodes[curr_node].write(values[i].clone());
48 return;
49 }
50 let mid = (i + j) / 2;
51 let left_node = 2 * curr_node + 1;
52 let right_node = 2 * curr_node + 2;
53 Self::build_helper(left_node, i, mid, values, nodes);
54 Self::build_helper(right_node, mid + 1, j, values, nodes);
55 let (top_nodes, bottom_nodes) = nodes.split_at_mut(curr_node + 1);
56 top_nodes[curr_node].write(Node::combine(
57 unsafe { bottom_nodes[left_node - curr_node - 1].assume_init_ref() },
58 unsafe { bottom_nodes[right_node - curr_node - 1].assume_init_ref() },
59 ));
60 }
61
62 pub fn update(&mut self, p: usize, value: &<T as Node>::Value) {
66 self.update_helper(p, value, 0, 0, self.n - 1);
67 }
68
69 #[inline]
70 fn update_helper(
71 &mut self,
72 p: usize,
73 value: &<T as Node>::Value,
74 curr_node: usize,
75 i: usize,
76 j: usize,
77 ) {
78 if j < p || p < i {
79 return;
80 }
81 if i == j {
82 self.nodes[curr_node] = Node::initialize(value);
83 return;
84 }
85 let mid = (i + j) / 2;
86 let left_node = 2 * curr_node + 1;
87 let right_node = 2 * curr_node + 2;
88 self.update_helper(p, value, left_node, i, mid);
89 self.update_helper(p, value, right_node, mid + 1, j);
90 self.nodes[curr_node] = Node::combine(&self.nodes[left_node], &self.nodes[right_node]);
91 }
92
93 #[allow(clippy::must_use_candidate)]
98 pub fn query(&self, left: usize, right: usize) -> Option<T> {
99 self.query_helper(left, right, 0, 0, self.n - 1)
100 }
101
102 #[inline]
103 fn query_helper(
104 &self,
105 left: usize,
106 right: usize,
107 curr_node: usize,
108 i: usize,
109 j: usize,
110 ) -> Option<T> {
111 if j < left || right < i {
112 return None;
113 }
114 let mid = (i + j) / 2;
115 let left_node = 2 * curr_node + 1;
116 let right_node = 2 * curr_node + 2;
117 if left <= i && j <= right {
118 return Some(self.nodes[curr_node].clone());
119 }
120 match (
121 self.query_helper(left, right, left_node, i, mid),
122 self.query_helper(left, right, right_node, mid + 1, j),
123 ) {
124 (Some(ans_left), Some(ans_right)) => Some(Node::combine(&ans_left, &ans_right)),
125 (Some(ans_left), None) => Some(ans_left),
126 (None, Some(ans_right)) => Some(ans_right),
127 (None, None) => None,
128 }
129 }
130
131 pub fn lower_bound<F, G>(&self, predicate: F, g: G, value: <T as Node>::Value) -> usize
165 where
166 F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
167 G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
168 {
169 self.lower_bound_helper(0, 0, self.n - 1, predicate, g, value)
170 }
171 fn lower_bound_helper<F, G>(
172 &self,
173 curr_node: usize,
174 i: usize,
175 j: usize,
176 predicate: F,
177 g: G,
178 value: <T as Node>::Value,
179 ) -> usize
180 where
181 F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
182 G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
183 {
184 if i == j {
185 return i;
186 }
187 let mid = (i + j) / 2;
188 let left_node = 2 * curr_node + 1;
189 let right_node = 2 * curr_node + 2;
190 let left_value = self.nodes[left_node].value();
191 if predicate(left_value, &value) {
192 self.lower_bound_helper(left_node, i, mid, predicate, g, value)
193 } else {
194 let value = g(left_value, value);
195 self.lower_bound_helper(right_node, mid + 1, j, predicate, g, value)
196 }
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::Recursive;
    use crate::{nodes::Node, utils::Min};

    /// Shared fixture: a `Min` segment tree over the values `0..=10`.
    fn sample_tree() -> Recursive<Min<usize>> {
        let leaves: Vec<Min<usize>> = (0..=10).map(|x| Min::initialize(&x)).collect();
        Recursive::build(&leaves)
    }

    #[test]
    fn non_empty_query_returns_some() {
        let tree = sample_tree();
        assert!(tree.query(0, 10).is_some());
    }

    #[test]
    fn empty_query_returns_none() {
        let tree = sample_tree();
        assert!(tree.query(10, 0).is_none());
    }

    #[test]
    fn update_works() {
        let mut tree = sample_tree();
        let replacement = 20;
        tree.update(0, &replacement);
        let leaf = tree.query(0, 0).unwrap();
        assert_eq!(leaf.value(), &replacement);
    }

    #[test]
    fn query_works() {
        let tree = sample_tree();
        let min_of_tail = tree.query(1, 10).unwrap();
        assert_eq!(min_of_tail.value(), &1);
    }
}