segment_tree/
fenwick.rs

1use crate::ops::{Commutative, Invertible};
2use crate::maybe_owned::MaybeOwned;
3
4use std::default::Default;
5use std::hash::{Hash, Hasher};
6
7/// This data structure allows prefix queries and single element modification.
8///
9/// This tree allocates `n * sizeof(N)` bytes of memory, and can be resized.
10///
11/// This data structure is implemented using a [Fenwick tree][1], which is also known as a
12/// binary indexed tree.
13///
14/// The similar crate [`prefix-sum`] might also be of interest.
15///
16/// # Examples
17///
18/// Showcase of functionality:
19///
20/// ```rust
21/// use segment_tree::ops::Add;
22/// use segment_tree::PrefixPoint;
23///
24/// let buf = vec![10, 5, 30, 40];
25///
26/// let mut pp = PrefixPoint::build(buf, Add);
27///
/// // If we query, we get the sum up to and including the specified index.
29/// assert_eq!(pp.query(0), 10);
30/// assert_eq!(pp.query(1), 15);
31/// assert_eq!(pp.query(2), 45);
32/// assert_eq!(pp.query(3), 85);
33///
34/// // Add five to the second value.
35/// pp.modify(1, 5);
36/// assert_eq!(pp.query(0), 10);
37/// assert_eq!(pp.query(1), 20);
38/// assert_eq!(pp.query(2), 50);
39/// assert_eq!(pp.query(3), 90);
40///
/// // Multiply every value by 2.
42/// pp.map(|v| *v *= 2);
43/// assert_eq!(pp.query(0), 20);
44/// assert_eq!(pp.query(1), 40);
45/// assert_eq!(pp.query(2), 100);
46/// assert_eq!(pp.query(3), 180);
47///
/// // Divide by two to undo.
49/// pp.map(|v| *v /= 2);
50/// // Add some more values.
51/// pp.extend(vec![0, 10].into_iter());
52/// assert_eq!(pp.query(0), 10);
53/// assert_eq!(pp.query(1), 20);
54/// assert_eq!(pp.query(2), 50);
55/// assert_eq!(pp.query(3), 90);
56/// assert_eq!(pp.query(4), 90);
57/// assert_eq!(pp.query(5), 100);
58///
59/// // Get the underlying values.
60/// assert_eq!(pp.get(0), 10);
61/// assert_eq!(pp.get(1), 10);
62/// assert_eq!(pp.get(2), 30);
63/// assert_eq!(pp.get(3), 40);
64/// assert_eq!(pp.get(4), 0);
65/// assert_eq!(pp.get(5), 10);
66///
67/// // Remove the last value
68/// pp.truncate(5);
69/// assert_eq!(pp.get(0), 10);
70/// assert_eq!(pp.get(1), 10);
71/// assert_eq!(pp.get(2), 30);
72/// assert_eq!(pp.get(3), 40);
73/// assert_eq!(pp.get(4), 0);
74///
75/// // Get back the original values.
76/// assert_eq!(pp.unwrap(), vec![10, 10, 30, 40, 0]);
77/// ```
78///
79/// You can also use other operators:
80///
81/// ```rust
82/// use segment_tree::ops::Mul;
83/// use segment_tree::PrefixPoint;
84///
85/// let buf = vec![10, 5, 30, 40];
86///
87/// let mut pp = PrefixPoint::build(buf, Mul);
88///
89/// assert_eq!(pp.query(0), 10);
90/// assert_eq!(pp.query(1), 50);
91/// assert_eq!(pp.query(2), 1500);
92/// assert_eq!(pp.query(3), 60000);
93/// ```
94///
95/// [1]: https://en.wikipedia.org/wiki/Fenwick_tree
96/// [`prefix-sum`]: https://crates.io/crates/prefix-sum
97pub struct PrefixPoint<N, O> where O: Commutative<N> {
98    buf: Vec<N>,
99    op: O
100}
101
/// Returns the value of the least significant set bit of `i`, e.g. `lsb(0b1100) == 0b100`.
///
/// Returns `0` when `i == 0`, since no bit is set.
///
/// This is the two's complement identity `i & -i`.  The negation is done with
/// `wrapping_neg` so that `i == 0` does not overflow; the equivalent `1 + !i` panics for
/// zero when overflow checks are enabled.
#[inline(always)]
fn lsb(i: usize) -> usize {
    i & i.wrapping_neg()
}
107
108/// Could also be done with slice_at_mut, but that's a giant pain
109#[inline(always)]
110unsafe fn combine_mut<N, O: Commutative<N>>(buf: &mut Vec<N>, i: usize, j: usize, op: &O) {
111    debug_assert!(i != j);
112    let ptr1 = &mut buf[i] as *mut N;
113    let ptr2 = &buf[j] as *const N;
114    op.combine_mut(&mut *ptr1, &*ptr2);
115}
116/// Could also be done with slice_at_mut, but that's a giant pain
117#[inline(always)]
118unsafe fn uncombine_mut<N, O: Invertible<N>>(buf: &mut Vec<N>, i: usize, j: usize, op: &O) {
119    debug_assert!(i != j);
120    let ptr1 = &mut buf[i] as *mut N;
121    let ptr2 = &buf[j] as *const N;
122    op.uncombine(&mut *ptr1, &*ptr2);
123}
124
125impl<N, O: Commutative<N>> PrefixPoint<N, O> {
126    /// Creates a `PrefixPoint` containing the given values.
127    /// Uses `O(len)` time.
128    pub fn build(mut buf: Vec<N>, op: O) -> PrefixPoint<N, O> {
129        let len = buf.len();
130        for i in 0..len {
131            let j = i + lsb(i+1);
132            if j < len {
133                unsafe {
134                    combine_mut::<N, O>(&mut buf, j, i, &op);
135                }
136            }
137        }
138        PrefixPoint { buf: buf, op: op }
139    }
140    /// Returns the number of values in this tree.
141    /// Uses `O(1)` time.
142    pub fn len(&self) -> usize {
143        self.buf.len()
144    }
145    /// Computes `a[0] * a[1] * ... * a[i]`.  Note that `i` is inclusive.
146    /// Uses `O(log(i))` time.
147    #[inline]
148    pub fn query(&self, mut i: usize) -> N where N: Clone {
149        let mut sum = self.buf[i].clone();
150        i -= lsb(1+i) - 1;
151        while i > 0 {
152            sum = self.op.combine_left(sum, &self.buf[i-1]);
153            i -= lsb(i);
154        }
155        sum
156    }
157    /// Computes `a[0] * a[1] * ... * a[i]`.  Note that `i` is inclusive.
158    /// Uses `O(log(i))` time.
159    #[inline]
160    pub fn query_noclone(&self, mut i: usize) -> MaybeOwned<N> {
161        let mut sum = MaybeOwned::Borrowed(&self.buf[i]);
162        i -= lsb(1+i) - 1;
163        while i > 0 {
164            sum = MaybeOwned::Owned(match sum {
165                MaybeOwned::Borrowed(ref v) => self.op.combine(v, &self.buf[i-1]),
166                MaybeOwned::Owned(v) => self.op.combine_left(v, &self.buf[i-1]),
167            });
168            i -= lsb(i);
169        }
170        sum
171    }
172    /// Combine the value at `i` with `delta`.
173    /// Uses `O(log(len))` time.
174    #[inline]
175    pub fn modify(&mut self, mut i: usize, delta: N) {
176        let len = self.len();
177        while i < len {
178            self.op.combine_mut(&mut self.buf[i], &delta);
179            i += lsb(i+1);
180        }
181    }
182    /// Truncates the `PrefixPoint` to the given size.  If `size >= len`, this method does
183    /// nothing.  Uses `O(1)` time.
184    #[inline(always)]
185    pub fn truncate(&mut self, size: usize) {
186        self.buf.truncate(size);
187    }
188    /// Calls `shrink_to_fit` on the interval vector.
189    #[inline(always)]
190    pub fn shrink_to_fit(&mut self) {
191        self.buf.shrink_to_fit();
192    }
193    /// Replace every value in the type with `f(value)`.
194    /// This function assumes that `combine(f(a), f(b)) = f(combine(a, b))`.
195    ///
196    /// The function is applied `len` times.
197    #[inline]
198    pub fn map<F: FnMut(&mut N)>(&mut self, mut f: F) {
199        for val in &mut self.buf {
200            f(val);
201        }
202    }
203}
204impl<N, O: Commutative<N>> Extend<N> for PrefixPoint<N, O> {
205    /// Adds the given values to the `PrefixPoint`, increasing its size.
206    /// Uses `O(len)` time.
207    fn extend<I: IntoIterator<Item=N>>(&mut self, values: I) {
208        let oldlen = self.len();
209        self.buf.extend(values);
210        let len = self.len();
211        for i in 0..len {
212            let j = i + lsb(i+1);
213            if oldlen <= j && j < len {
214                unsafe {
215                    combine_mut::<N, O>(&mut self.buf, j, i, &self.op);
216                }
217            }
218        }
219    }
220}
221impl<N, O: Commutative<N> + Invertible<N>> PrefixPoint<N, O> {
222    /// Returns the value at `i`.
223    /// Uses `O(log(i))` time.
224    /// Store your own copy of the array if you want constant time.
225    pub fn get(&self, mut i: usize) -> N where N: Clone {
226        let mut sum = self.buf[i].clone();
227        let z = 1 + i - lsb(i+1);
228        while i != z {
229            self.op.uncombine(&mut sum, &self.buf[i-1]);
230            i -= lsb(i);
231        }
232        sum
233    }
234    /// Change the value at the index to be the specified value.
235    /// Uses `O(log(i))` time.
236    pub fn set(&mut self, i: usize, mut value: N) where N: Clone {
237        let current = self.get(i);
238        self.op.uncombine(&mut value, &current);
239        self.modify(i, value);
240    }
241    /// Compute the underlying array of values.
242    /// Uses `O(len)` time.
243    pub fn unwrap(self) -> Vec<N> {
244        let mut buf = self.buf;
245        let len = buf.len();
246        for i in (0..len).rev() {
247            let j = i + lsb(i+1);
248            if j < len {
249                unsafe {
250                    uncombine_mut::<N, O>(&mut buf, j, i, &self.op);
251                }
252            }
253        }
254        buf
255    }
256    /// Compute the underlying array of values.
257    /// Uses `O(len)` time.
258    pub fn unwrap_clone(&self) -> Vec<N> where N: Clone {
259        let len = self.buf.len();
260        let mut buf = self.buf.clone();
261        for i in (0..len).rev() {
262            let j = i + lsb(i+1);
263            if j < len {
264                unsafe {
265                    uncombine_mut::<N, O>(&mut buf, j, i, &self.op);
266                }
267            }
268        }
269        buf
270    }
271}
272
273impl<N: Clone, O: Commutative<N> + Clone> Clone for PrefixPoint<N, O> {
274    fn clone(&self) -> PrefixPoint<N, O> {
275        PrefixPoint {
276            buf: self.buf.clone(), op: self.op.clone()
277        }
278    }
279}
280impl<N, O: Commutative<N> + Default> Default for PrefixPoint<N, O> {
281    #[inline]
282    fn default() -> PrefixPoint<N, O> {
283        PrefixPoint { buf: Vec::new(), op: Default::default() }
284    }
285}
286impl<'a, N: 'a + Hash, O: Commutative<N>> Hash for PrefixPoint<N, O> {
287    #[inline]
288    fn hash<H: Hasher>(&self, state: &mut H) {
289        self.buf.hash(state);
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::*;
    use std::num::Wrapping;
    use crate::ops::Add;

    /// The concrete tree type exercised by every test in this module.
    type Fenwick = PrefixPoint<Wrapping<i32>, Add>;

    /// Replaces the contents of `buf` with its running sum, in place: afterwards the
    /// element at index `k` is the sum of the original elements at indices `0..=k`.
    pub fn compute_prefix_sum<N: ::std::ops::Add<Output=N> + Copy>(buf: &mut [N]) {
        let mut rest = buf.iter_mut();
        // The first element is already its own prefix sum; an empty slice has nothing to do.
        if let Some(first) = rest.next() {
            let mut running = *first;
            for slot in rest {
                running = running + *slot;
                *slot = running;
            }
        }
    }

    /// Produces `len` random wrapping integers.
    fn random_vec(rng: &mut ThreadRng, len: usize) -> Vec<Wrapping<i32>> {
        rng.sample_iter(&Standard).map(Wrapping).take(len).collect()
    }

    /// Checks that `fenwick` represents exactly `values`: both ways of recovering the
    /// values agree with them, and every prefix query and point lookup is correct.
    fn assert_tracks(fenwick: &Fenwick, values: &[Wrapping<i32>]) {
        let mut prefix = values.to_vec();
        compute_prefix_sum(&mut prefix);
        assert_eq!(&fenwick.clone().unwrap()[..], values);
        assert_eq!(&fenwick.clone().unwrap_clone()[..], values);
        for (j, (total, single)) in prefix.iter().zip(values).enumerate() {
            assert_eq!(*total, fenwick.query(j));
            assert_eq!(*single, fenwick.get(j));
        }
    }

    #[test]
    fn fenwick_query() {
        let mut rng = thread_rng();
        for n in 0..130 {
            let mut vec = random_vec(&mut rng, n);
            let fenwick = PrefixPoint::build(vec.clone(), Add);
            compute_prefix_sum(&mut vec);
            for (i, expected) in vec.iter().enumerate() {
                assert_eq!(*expected, fenwick.query(i));
                assert_eq!(expected, fenwick.query_noclone(i).borrow());
            }
        }
    }

    #[test]
    fn fenwick_map() {
        let mut rng = thread_rng();
        for n in 0..130 {
            let mut vec = random_vec(&mut rng, n);
            let mut fenwick = PrefixPoint::build(vec.clone(), Add);
            assert_eq!(fenwick.clone().unwrap(), vec);
            assert_eq!(fenwick.clone().unwrap_clone(), vec);
            compute_prefix_sum(&mut vec);
            // Scaling distributes over addition, so every prefix sum scales as well.
            fenwick.map(|n| *n = Wrapping(12) * *n);
            for (i, expected) in vec.iter().enumerate() {
                assert_eq!(*expected * Wrapping(12), fenwick.query(i));
            }
        }
    }

    #[test]
    fn fenwick_modify() {
        let mut rng = thread_rng();
        for n in 0..130 {
            let mut vec = random_vec(&mut rng, n);
            let diff = random_vec(&mut rng, n);
            let mut fenwick = PrefixPoint::build(vec.clone(), Add);
            for (i, &delta) in diff.iter().enumerate() {
                assert_tracks(&fenwick, &vec);
                vec[i] += delta;
                fenwick.modify(i, delta);
            }
        }
    }

    #[test]
    fn fenwick_set() {
        let mut rng = thread_rng();
        for n in 0..130 {
            let mut vec = random_vec(&mut rng, n);
            let diff = random_vec(&mut rng, n);
            let mut fenwick = PrefixPoint::build(vec.clone(), Add);
            for (i, &replacement) in diff.iter().enumerate() {
                assert_tracks(&fenwick, &vec);
                vec[i] = replacement;
                fenwick.set(i, replacement);
            }
        }
    }

    #[test]
    fn fenwick_extend() {
        let mut rng = thread_rng();
        for n in 0..130 {
            let vec = random_vec(&mut rng, n);
            let mut sum = vec.clone();
            compute_prefix_sum(&mut sum);
            for total in 0..sum.len() {
                // Build from the first half of the first `total` values, then extend
                // with the second half.
                let (head, tail) = vec[..total].split_at(total / 2);
                let mut fenwick = PrefixPoint::build(head.to_vec(), Add);
                fenwick.extend(tail.iter().cloned());
                for (j, expected) in sum[..total].iter().enumerate() {
                    assert_eq!(*expected, fenwick.query(j));
                }
            }
        }
    }
}