1use core::mem::MaybeUninit;
2
3use crate::nodes::{LazyNode, Node};
4
5pub struct LazyRecursive<T: LazyNode> {
8 nodes: Vec<T>,
9 n: usize,
10}
11
12impl<T: LazyNode + Clone> LazyRecursive<T> {
13 pub fn build(values: &[T]) -> Self {
16 let n = values.len();
17 if n == 0 {
18 return Self {
19 nodes: Vec::new(),
20 n,
21 };
22 }
23 let mut nodes = Vec::with_capacity(4 * n);
24 unsafe { nodes.set_len(4 * n) };
25 Self::build_helper(0, 0, n - 1, values, &mut nodes);
26 let ptr = nodes.as_mut_ptr();
27 core::mem::forget(nodes);
28 let nodes = unsafe { Vec::from_raw_parts(ptr.cast::<T>(), 4 * n, 4 * n) };
29 Self { nodes, n }
30 }
31
32 fn build_helper(
33 curr_node: usize,
34 i: usize,
35 j: usize,
36 values: &[T],
37 nodes: &mut [MaybeUninit<T>],
38 ) {
39 if i == j {
40 nodes[curr_node].write(values[i].clone());
41 return;
42 }
43 let mid = (i + j) / 2;
44 let left_node = 2 * curr_node + 1;
45 let right_node = 2 * curr_node + 2;
46 Self::build_helper(left_node, i, mid, values, nodes);
47 Self::build_helper(right_node, mid + 1, j, values, nodes);
48 let (top_nodes, bottom_nodes) = nodes.split_at_mut(curr_node + 1);
49 top_nodes[curr_node].write(Node::combine(
50 unsafe { bottom_nodes[left_node - curr_node - 1].assume_init_ref() },
51 unsafe { bottom_nodes[right_node - curr_node - 1].assume_init_ref() },
52 ));
53 }
54
55 fn push(&mut self, u: usize, i: usize, j: usize) {
56 let (parent_slice, sons_slice) = self.nodes.split_at_mut(u + 1);
58 if let Some(value) = parent_slice[u].lazy_value() {
59 if i != j {
60 sons_slice[u].update_lazy_value(value); sons_slice[u + 1].update_lazy_value(value); }
63 }
64 self.nodes[u].lazy_update(i, j);
65 }
66
67 pub fn update(&mut self, i: usize, j: usize, value: &<T as Node>::Value) {
71 self.update_helper(i, j, value, 0, 0, self.n - 1);
72 }
73
74 fn update_helper(
75 &mut self,
76 left: usize,
77 right: usize,
78 value: &<T as Node>::Value,
79 curr_node: usize,
80 i: usize,
81 j: usize,
82 ) {
83 if self.nodes[curr_node].lazy_value().is_some() {
84 self.push(curr_node, i, j);
85 }
86 if j < left || right < i {
87 return;
88 }
89 if left <= i && j <= right {
90 self.nodes[curr_node].update_lazy_value(value);
91 self.push(curr_node, i, j);
92 return;
93 }
94 let mid = (i + j) / 2;
95 let left_node = 2 * curr_node + 1;
96 let right_node = 2 * curr_node + 2;
97 self.update_helper(left, right, value, left_node, i, mid);
98 self.update_helper(left, right, value, right_node, mid + 1, j);
99 self.nodes[curr_node] = Node::combine(&self.nodes[left_node], &self.nodes[right_node]);
100 }
101
102 pub fn query(&mut self, left: usize, right: usize) -> Option<T> {
107 self.query_helper(left, right, 0, 0, self.n - 1)
108 }
109
110 fn query_helper(
111 &mut self,
112 left: usize,
113 right: usize,
114 curr_node: usize,
115 i: usize,
116 j: usize,
117 ) -> Option<T> {
118 if j < left || right < i {
119 return None;
120 }
121 let mid = (i + j) / 2;
122 let left_node = 2 * curr_node + 1;
123 let right_node = 2 * curr_node + 2;
124 if self.nodes[curr_node].lazy_value().is_some() {
125 self.push(curr_node, i, j);
126 }
127 if left <= i && j <= right {
128 return Some(self.nodes[curr_node].clone());
129 }
130 match (
131 self.query_helper(left, right, left_node, i, mid),
132 self.query_helper(left, right, right_node, mid + 1, j),
133 ) {
134 (Some(ans_left), Some(ans_right)) => Some(Node::combine(&ans_left, &ans_right)),
135 (Some(ans_left), None) => Some(ans_left),
136 (None, Some(ans_right)) => Some(ans_right),
137 (None, None) => None,
138 }
139 }
140
141 pub fn lower_bound<F, G>(&self, predicate: F, g: G, value: <T as Node>::Value) -> usize
176 where
177 F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
178 G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
179 {
180 self.lower_bound_helper(0, 0, self.n - 1, predicate, g, value)
181 }
182 fn lower_bound_helper<F, G>(
183 &self,
184 curr_node: usize,
185 i: usize,
186 j: usize,
187 predicate: F,
188 g: G,
189 value: <T as Node>::Value,
190 ) -> usize
191 where
192 F: Fn(&<T as Node>::Value, &<T as Node>::Value) -> bool,
193 G: Fn(&<T as Node>::Value, <T as Node>::Value) -> <T as Node>::Value,
194 {
195 if i == j {
196 return i;
197 }
198 let mid = (i + j) / 2;
199 let left_node = 2 * curr_node + 1;
200 let right_node = 2 * curr_node + 2;
201 let left_value = self.nodes[left_node].value();
202 if predicate(left_value, &value) {
203 self.lower_bound_helper(left_node, i, mid, predicate, g, value)
204 } else {
205 let value = g(left_value, value);
206 self.lower_bound_helper(right_node, mid + 1, j, predicate, g, value)
207 }
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::LazyRecursive;
    use crate::{
        nodes::Node,
        utils::{LazySetWrapper, Min},
    };

    /// Shorthand for the lazy-set wrapper around `Min` used by every test below.
    type LSMin<T> = LazySetWrapper<Min<T>>;

    /// Builds a tree over `0..len`, so that leaf `k` holds the value `k`.
    fn identity_tree(len: usize) -> LazyRecursive<LSMin<usize>> {
        let leaves: Vec<LSMin<usize>> = (0..len).map(|k| LSMin::initialize(&k)).collect();
        LazyRecursive::build(&leaves)
    }

    #[test]
    fn build_works() {
        let len = 16;
        let mut tree = identity_tree(len);
        for idx in 0..len {
            let leaf = tree.query(idx, idx).unwrap();
            assert_eq!(leaf.value(), &idx);
        }
    }

    #[test]
    fn non_empty_query_returns_some() {
        let mut tree = identity_tree(10);
        assert!(tree.query(0, 9).is_some());
    }

    #[test]
    fn empty_query_returns_none() {
        let mut tree = identity_tree(10);
        assert!(tree.query(10, 0).is_none());
    }

    #[test]
    fn update_works() {
        let mut tree = identity_tree(10);
        let assigned = 20;
        tree.update(0, 9, &assigned);
        assert_eq!(tree.query(0, 1).unwrap().value(), &assigned);
    }

    #[test]
    fn query_works() {
        let mut tree = identity_tree(10);
        assert_eq!(tree.query(1, 9).unwrap().value(), &1);
    }
}