1use num::Zero;
2use rand::distributions::uniform::SampleUniform;
3use std::mem::MaybeUninit;
4#[allow(clippy::enum_variant_names)]
6#[allow(dead_code)]
7#[derive(Debug)]
8pub enum Error {
9 NodeNotFound(Index),
10 NodeHasNoParent(ShiftedIndex),
11 NodeAlreadyInserted(Index),
12 CannotDirectlyUpdateInternalNode(ShiftedIndex),
13 EmptyTree,
14 NumericalError,
15}
16impl std::fmt::Display for Error {
17 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18 match self {
19 Error::NodeNotFound(index) => write!(f, "Node with index {} not found", index.0),
20 Error::NodeHasNoParent(index) => {
21 write!(f, "Node with shifted_index {} has no parent", index.0)
22 }
23 Error::NodeAlreadyInserted(index) => {
24 write!(f, "Node with index {} is already inserted", index.0)
25 }
26 Error::CannotDirectlyUpdateInternalNode(index) => write!(
27 f,
28 "Cannot directly update internal node with shifted_index {}",
29 index.0
30 ),
31 Error::EmptyTree => write!(f, "Tree is empty"),
32 Error::NumericalError => write!(f, "Numerical error"),
33 }
34 }
35}
36
37impl std::error::Error for Error {}
38
/// Index of a leaf as seen by users of [`Tree`]: leaf `0` is the first leaf.
///
/// It is translated to a position in the tree's flat storage via
/// `Tree::get_shifted_node_index`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Index(pub usize);

impl From<usize> for Index {
    /// Wraps a raw `usize` as a leaf index.
    fn from(index: usize) -> Self {
        Self(index)
    }
}
48
/// Position of a node in a tree's flat storage array.
///
/// Internal nodes occupy the first slots and leaves the remaining ones; the
/// children of position `i` are `2i + 1` and `2i + 2`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShiftedIndex(pub usize);

impl From<usize> for ShiftedIndex {
    /// Wraps a raw storage position.
    fn from(index: usize) -> Self {
        Self(index)
    }
}
57
/// Whether a storage position holds an internal node or a leaf.
// `Debug`/`Copy`/`Eq` are cheap for a field-less enum and make it usable in
// assertions, logs and `match` guards without borrowing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// A node with two children whose value is derived from them.
    Internal,
    /// A node at the bottom of the tree.
    Leaf,
}
64
65pub trait Node
67where
68 Self: Sized,
69{
70 type C;
71
72 fn node_state(shifted_index: ShiftedIndex, storage_size: usize) -> NodeState {
73 let num_leaves = (storage_size + 1) / 2;
74 match shifted_index.0 < num_leaves - 1 {
75 true => NodeState::Internal,
76 false => NodeState::Leaf,
77 }
78 }
79
80 fn left_child(shifted_index: ShiftedIndex) -> ShiftedIndex {
81 ShiftedIndex(2 * shifted_index.0 + 1)
82 }
83 fn right_child(shifted_index: ShiftedIndex) -> ShiftedIndex {
84 ShiftedIndex(2 * shifted_index.0 + 2)
85 }
86 fn parent(shifted_index: ShiftedIndex) -> Result<ShiftedIndex, Error> {
87 if shifted_index.0 == 0 {
88 return Err(Error::NodeHasNoParent(shifted_index));
89 }
90 Ok(ShiftedIndex((shifted_index.0 - 1) / 2))
91 }
92 fn contribution(&self) -> Self::C;
93 fn new(contribution: Self::C) -> Self;
94 fn from_children(left: &Self, right: &Self) -> Self;
95 fn update(
96 storage: &mut Vec<Self>,
97 shifted_index: ShiftedIndex,
98 value: Self::C,
99 ) -> Result<(), Error>;
100 fn sample(storage: &[Self], rng: &mut impl rand::Rng) -> Result<ShiftedIndex, Error>;
101}
102
/// A complete binary tree over a fixed number of leaves, stored as one flat array.
///
/// Positions `0..num_leaves - 1` hold internal nodes and the remaining
/// `num_leaves` positions hold the leaves, in order.
pub struct Tree<N> {
    /// Every node in heap order: internal nodes first, then leaves.
    storage: Vec<N>,
    /// Number of leaves (at least one for trees built by `from_iterable`).
    num_leaves: usize,
    /// Total number of nodes, `2 * num_leaves - 1`.
    num_nodes: usize,
}
109
110impl<N> std::fmt::Debug for Tree<N>
111where
112 N: std::fmt::Debug,
113{
114 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
115 let internal_nodes: &[N] = &self.storage[..self.num_leaves - 1];
116 let leaves: &[N] = &self.storage[self.num_leaves - 1..];
117
118 write!(
119 f,
120 "Internal nodes: {:?}\nLeaves: {:?}",
121 internal_nodes, leaves
122 )
123 }
124}
125
126impl<N> Tree<N>
127where
128 N: Node,
129{
130 pub fn get_shifted_node_index(&self, node_index: Index) -> Result<ShiftedIndex, Error> {
131 let shifted_node_index = node_index.0 + self.num_leaves - 1;
132 match shifted_node_index < self.num_nodes {
133 false => Err(Error::NodeNotFound(node_index)),
134 true => Ok(ShiftedIndex(shifted_node_index)),
135 }
136 }
137 pub fn get_node_index(&self, shifted_node_index: ShiftedIndex) -> Result<Index, Error> {
138 let node_index = shifted_node_index.0 - (self.num_leaves - 1);
139 match node_index < self.num_leaves {
140 false => Err(Error::NodeNotFound(Index(node_index))),
141 true => Ok(Index(node_index)),
142 }
143 }
144
145 pub fn get_contribution(&self, node_index: Index) -> Result<N::C, Error> {
146 let shifted_index = self.get_shifted_node_index(node_index)?;
147 Ok(self.storage[shifted_index.0].contribution())
148 }
149
150 pub fn from_iterable<I>(mut iterator: I) -> Result<Self, Error>
151 where
152 I: Iterator<Item = N::C> + ExactSizeIterator,
153 {
154 let num_leaves = iterator.len();
155 if num_leaves == 0 {
156 return Err(Error::EmptyTree);
157 }
158 let num_nodes = 2 * num_leaves - 1;
159 let mut storage: Vec<MaybeUninit<N>> = Vec::with_capacity(num_nodes);
160 unsafe {
163 storage.set_len(num_nodes);
164 storage[num_leaves - 1..].iter_mut().for_each(|uninit| {
166 let uninit_ptr = uninit.as_mut_ptr();
167 let leaf = iterator.next().unwrap();
168 std::ptr::write(uninit_ptr, N::new(leaf));
169 });
170 (0..num_leaves - 1).rev().for_each(|i| {
172 let left = N::left_child(i.into());
173 let right = N::right_child(i.into());
174 let parent = N::from_children(
175 storage[left.0].as_ptr().as_ref().unwrap(),
176 storage[right.0].as_ptr().as_ref().unwrap(),
177 );
178 storage[i] = MaybeUninit::new(parent);
179 });
180 }
181 unsafe {
182 let storage: Vec<N> = std::mem::transmute(storage);
184 Ok(Self {
185 storage,
186 num_leaves,
187 num_nodes,
188 })
189 }
190 }
191 pub fn sample(&self, rng: &mut impl rand::Rng) -> Result<Index, Error> {
192 N::sample(&self.storage, rng).and_then(|shifted_index| self.get_node_index(shifted_index))
193 }
194 pub fn update(&mut self, node_index: Index, value: N::C) -> Result<(), Error> {
195 let shifted_index = self.get_shifted_node_index(node_index)?;
196 N::update(&mut self.storage, shifted_index, value)
197 }
198
199 pub fn contribution(&self, node_index: Index) -> Result<N::C, Error> {
200 let shifted_index = self.get_shifted_node_index(node_index)?;
201 Ok(self.storage[shifted_index.0].contribution())
202 }
203}
204
/// Tree node that stores its contribution directly.
///
/// On update, ancestors are adjusted by the difference rather than recomputed,
/// which presumably gives the type its name: floating-point sums can drift.
#[derive(Debug)]
pub struct UnstableNode<C> {
    /// A leaf's own value, or, for an internal node, the sum of its children's values.
    contribution: C,
}
210
211impl<C> Node for UnstableNode<C>
212where
213 C: Copy
214 + Clone
215 + std::ops::Add<Output = C>
216 + std::ops::Sub<Output = C>
217 + Zero
218 + std::ops::AddAssign
219 + std::ops::SubAssign
220 + SampleUniform
221 + std::cmp::PartialOrd,
222{
223 type C = C;
224 fn contribution(&self) -> Self::C {
225 self.contribution
226 }
227
228 fn new(contribution: Self::C) -> Self {
229 UnstableNode { contribution }
230 }
231
232 fn from_children(left: &Self, right: &Self) -> Self {
233 UnstableNode {
234 contribution: left.contribution + right.contribution,
235 }
236 }
237 fn update(
238 storage: &mut Vec<Self>,
239 shifted_index: ShiftedIndex,
240 value: Self::C,
241 ) -> Result<(), Error> {
242 let storage_size = storage.len();
243 let leaf = storage.get_mut(shifted_index.0).unwrap();
244 match Self::node_state(shifted_index, storage_size) {
245 NodeState::Internal => Err(Error::CannotDirectlyUpdateInternalNode(shifted_index)),
246 NodeState::Leaf => {
247 let old_value = &mut leaf.contribution;
248 let (abs_diff, sign): (C, bool) = match *old_value <= value {
249 true => (value - *old_value, true),
250 false => (*old_value - value, false),
251 };
252 if abs_diff.is_zero() {
253 Ok(())
254 } else {
255 match sign {
256 true => *old_value += abs_diff,
257 false => *old_value -= abs_diff,
258 }
259 let mut node_shifted_index = shifted_index;
260 while let Ok(parent_shifted_index) = Self::parent(node_shifted_index) {
261 let parent = storage.get_mut(parent_shifted_index.0).unwrap();
262
263 match Self::node_state(parent_shifted_index, storage_size) {
264 NodeState::Internal => match sign {
265 true => parent.contribution += abs_diff,
266 false => parent.contribution -= abs_diff,
267 },
268 NodeState::Leaf => unreachable!("Internal node has leaf parent"),
269 }
270 node_shifted_index = parent_shifted_index;
271 }
272 Ok(())
273 }
274 }
275 }
276 }
277
278 fn sample(storage: &[Self], rng: &mut impl rand::Rng) -> Result<ShiftedIndex, Error> {
279 if storage.is_empty() {
280 return Err(Error::EmptyTree);
281 }
282 let storage_size = storage.len();
283 let mut shifted_index = ShiftedIndex(0);
284 while Self::node_state(shifted_index, storage_size) == NodeState::Internal {
285 let left = Self::left_child(shifted_index);
286 let right = Self::right_child(shifted_index);
287 let left_contribution = unsafe { storage.get_unchecked(left.0).contribution() };
288 let right_contribution = unsafe { storage.get_unchecked(right.0).contribution() };
289 let total_contribution = left_contribution + right_contribution;
290
291 let sample: C = rng.gen_range(C::zero()..total_contribution);
292 if sample < left_contribution {
293 shifted_index = left;
294 } else {
295 shifted_index = right;
296 }
297 }
298 Ok(shifted_index)
299 }
300}