sampling_tree/
sampling.rs

1use num::Zero;
2use rand::distributions::uniform::SampleUniform;
3use std::mem::MaybeUninit;
4// MARK: Error def
5#[allow(clippy::enum_variant_names)]
6#[allow(dead_code)]
7#[derive(Debug)]
8pub enum Error {
9    NodeNotFound(Index),
10    NodeHasNoParent(ShiftedIndex),
11    NodeAlreadyInserted(Index),
12    CannotDirectlyUpdateInternalNode(ShiftedIndex),
13    EmptyTree,
14    NumericalError,
15}
16impl std::fmt::Display for Error {
17    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
18        match self {
19            Error::NodeNotFound(index) => write!(f, "Node with index {} not found", index.0),
20            Error::NodeHasNoParent(index) => {
21                write!(f, "Node with shifted_index {} has no parent", index.0)
22            }
23            Error::NodeAlreadyInserted(index) => {
24                write!(f, "Node with index {} is already inserted", index.0)
25            }
26            Error::CannotDirectlyUpdateInternalNode(index) => write!(
27                f,
28                "Cannot directly update internal node with shifted_index {}",
29                index.0
30            ),
31            Error::EmptyTree => write!(f, "Tree is empty"),
32            Error::NumericalError => write!(f, "Numerical error"),
33        }
34    }
35}
36
37impl std::error::Error for Error {}
38
39// MARK: Newtypes
/// Position of a leaf counted among the leaves only (newtype over `usize`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Index(pub usize);

impl From<usize> for Index {
    /// Wraps a raw `usize` as an [`Index`].
    fn from(raw: usize) -> Self {
        Self(raw)
    }
}
48
/// Position of a node inside the flat node storage (newtype over `usize`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShiftedIndex(pub usize);

impl From<usize> for ShiftedIndex {
    /// Wraps a raw `usize` as a [`ShiftedIndex`].
    fn from(raw: usize) -> Self {
        Self(raw)
    }
}
57
58// MARK: NodeState enum
/// Role of a node within the tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// The node has children.
    Internal,
    /// The node has no children.
    Leaf,
}
64
65// MARK: Node trait
66pub trait Node
67where
68    Self: Sized,
69{
70    type C;
71
72    fn node_state(shifted_index: ShiftedIndex, storage_size: usize) -> NodeState {
73        let num_leaves = (storage_size + 1) / 2;
74        match shifted_index.0 < num_leaves - 1 {
75            true => NodeState::Internal,
76            false => NodeState::Leaf,
77        }
78    }
79
80    fn left_child(shifted_index: ShiftedIndex) -> ShiftedIndex {
81        ShiftedIndex(2 * shifted_index.0 + 1)
82    }
83    fn right_child(shifted_index: ShiftedIndex) -> ShiftedIndex {
84        ShiftedIndex(2 * shifted_index.0 + 2)
85    }
86    fn parent(shifted_index: ShiftedIndex) -> Result<ShiftedIndex, Error> {
87        if shifted_index.0 == 0 {
88            return Err(Error::NodeHasNoParent(shifted_index));
89        }
90        Ok(ShiftedIndex((shifted_index.0 - 1) / 2))
91    }
92    fn contribution(&self) -> Self::C;
93    fn new(contribution: Self::C) -> Self;
94    fn from_children(left: &Self, right: &Self) -> Self;
95    fn update(
96        storage: &mut Vec<Self>,
97        shifted_index: ShiftedIndex,
98        value: Self::C,
99    ) -> Result<(), Error>;
100    fn sample(storage: &[Self], rng: &mut impl rand::Rng) -> Result<ShiftedIndex, Error>;
101}
102
103// MARK: Tree struct
/// Binary tree whose nodes are kept in one flat vector.
pub struct Tree<N> {
    /// All nodes; the first `num_leaves - 1` entries are the internal nodes,
    /// the remaining entries are the leaves.
    storage: Vec<N>,
    /// Number of leaves.
    num_leaves: usize,
    /// Total number of nodes.
    num_nodes: usize,
}

/// Prints the internal nodes and the leaves as two separate lists.
impl<N: std::fmt::Debug> std::fmt::Debug for Tree<N> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The leaves start right after the `num_leaves - 1` internal nodes.
        let first_leaf = self.num_leaves - 1;
        let (internal_nodes, leaves) = self.storage.split_at(first_leaf);
        write!(f, "Internal nodes: {:?}\nLeaves: {:?}", internal_nodes, leaves)
    }
}
125
126impl<N> Tree<N>
127where
128    N: Node,
129{
130    pub fn get_shifted_node_index(&self, node_index: Index) -> Result<ShiftedIndex, Error> {
131        let shifted_node_index = node_index.0 + self.num_leaves - 1;
132        match shifted_node_index < self.num_nodes {
133            false => Err(Error::NodeNotFound(node_index)),
134            true => Ok(ShiftedIndex(shifted_node_index)),
135        }
136    }
137    pub fn get_node_index(&self, shifted_node_index: ShiftedIndex) -> Result<Index, Error> {
138        let node_index = shifted_node_index.0 - (self.num_leaves - 1);
139        match node_index < self.num_leaves {
140            false => Err(Error::NodeNotFound(Index(node_index))),
141            true => Ok(Index(node_index)),
142        }
143    }
144
145    pub fn get_contribution(&self, node_index: Index) -> Result<N::C, Error> {
146        let shifted_index = self.get_shifted_node_index(node_index)?;
147        Ok(self.storage[shifted_index.0].contribution())
148    }
149
150    pub fn from_iterable<I>(mut iterator: I) -> Result<Self, Error>
151    where
152        I: Iterator<Item = N::C> + ExactSizeIterator,
153    {
154        let num_leaves = iterator.len();
155        if num_leaves == 0 {
156            return Err(Error::EmptyTree);
157        }
158        let num_nodes = 2 * num_leaves - 1;
159        let mut storage: Vec<MaybeUninit<N>> = Vec::with_capacity(num_nodes);
160        // SAFTEY: We have reserved enough space for the elements and
161        // we now initialize them
162        unsafe {
163            storage.set_len(num_nodes);
164            // stick the leaves at the end of the storage:
165            storage[num_leaves - 1..].iter_mut().for_each(|uninit| {
166                let uninit_ptr = uninit.as_mut_ptr();
167                let leaf = iterator.next().unwrap();
168                std::ptr::write(uninit_ptr, N::new(leaf));
169            });
170            // Now we fill up the rest of the tree backwards:
171            (0..num_leaves - 1).rev().for_each(|i| {
172                let left = N::left_child(i.into());
173                let right = N::right_child(i.into());
174                let parent = N::from_children(
175                    storage[left.0].as_ptr().as_ref().unwrap(),
176                    storage[right.0].as_ptr().as_ref().unwrap(),
177                );
178                storage[i] = MaybeUninit::new(parent);
179            });
180        }
181        unsafe {
182            // Now transmute the storage to the final form and return:
183            let storage: Vec<N> = std::mem::transmute(storage);
184            Ok(Self {
185                storage,
186                num_leaves,
187                num_nodes,
188            })
189        }
190    }
191    pub fn sample(&self, rng: &mut impl rand::Rng) -> Result<Index, Error> {
192        N::sample(&self.storage, rng).and_then(|shifted_index| self.get_node_index(shifted_index))
193    }
194    pub fn update(&mut self, node_index: Index, value: N::C) -> Result<(), Error> {
195        let shifted_index = self.get_shifted_node_index(node_index)?;
196        N::update(&mut self.storage, shifted_index, value)
197    }
198
199    pub fn contribution(&self, node_index: Index) -> Result<N::C, Error> {
200        let shifted_index = self.get_shifted_node_index(node_index)?;
201        Ok(self.storage[shifted_index.0].contribution())
202    }
203}
204
205// MARK: UnstableNode
/// Tree node that stores a single contribution value.
// NOTE(review): the name presumably refers to rounding error accumulating when
// contributions are adjusted incrementally — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct UnstableNode<C> {
    /// Value carried by this node.
    contribution: C,
}
210
211impl<C> Node for UnstableNode<C>
212where
213    C: Copy
214        + Clone
215        + std::ops::Add<Output = C>
216        + std::ops::Sub<Output = C>
217        + Zero
218        + std::ops::AddAssign
219        + std::ops::SubAssign
220        + SampleUniform
221        + std::cmp::PartialOrd,
222{
223    type C = C;
224    fn contribution(&self) -> Self::C {
225        self.contribution
226    }
227
228    fn new(contribution: Self::C) -> Self {
229        UnstableNode { contribution }
230    }
231
232    fn from_children(left: &Self, right: &Self) -> Self {
233        UnstableNode {
234            contribution: left.contribution + right.contribution,
235        }
236    }
237    fn update(
238        storage: &mut Vec<Self>,
239        shifted_index: ShiftedIndex,
240        value: Self::C,
241    ) -> Result<(), Error> {
242        let storage_size = storage.len();
243        let leaf = storage.get_mut(shifted_index.0).unwrap();
244        match Self::node_state(shifted_index, storage_size) {
245            NodeState::Internal => Err(Error::CannotDirectlyUpdateInternalNode(shifted_index)),
246            NodeState::Leaf => {
247                let old_value = &mut leaf.contribution;
248                let (abs_diff, sign): (C, bool) = match *old_value <= value {
249                    true => (value - *old_value, true),
250                    false => (*old_value - value, false),
251                };
252                if abs_diff.is_zero() {
253                    Ok(())
254                } else {
255                    match sign {
256                        true => *old_value += abs_diff,
257                        false => *old_value -= abs_diff,
258                    }
259                    let mut node_shifted_index = shifted_index;
260                    while let Ok(parent_shifted_index) = Self::parent(node_shifted_index) {
261                        let parent = storage.get_mut(parent_shifted_index.0).unwrap();
262
263                        match Self::node_state(parent_shifted_index, storage_size) {
264                            NodeState::Internal => match sign {
265                                true => parent.contribution += abs_diff,
266                                false => parent.contribution -= abs_diff,
267                            },
268                            NodeState::Leaf => unreachable!("Internal node has leaf parent"),
269                        }
270                        node_shifted_index = parent_shifted_index;
271                    }
272                    Ok(())
273                }
274            }
275        }
276    }
277
278    fn sample(storage: &[Self], rng: &mut impl rand::Rng) -> Result<ShiftedIndex, Error> {
279        if storage.is_empty() {
280            return Err(Error::EmptyTree);
281        }
282        let storage_size = storage.len();
283        let mut shifted_index = ShiftedIndex(0);
284        while Self::node_state(shifted_index, storage_size) == NodeState::Internal {
285            let left = Self::left_child(shifted_index);
286            let right = Self::right_child(shifted_index);
287            let left_contribution = unsafe { storage.get_unchecked(left.0).contribution() };
288            let right_contribution = unsafe { storage.get_unchecked(right.0).contribution() };
289            let total_contribution = left_contribution + right_contribution;
290
291            let sample: C = rng.gen_range(C::zero()..total_contribution);
292            if sample < left_contribution {
293                shifted_index = left;
294            } else {
295                shifted_index = right;
296            }
297        }
298        Ok(shifted_index)
299    }
300}