segment_tree/
segment_tree.rs

1use std::{mem, fmt};
2use std::cmp::{PartialEq, Eq};
3use std::default::Default;
4use std::hash::{Hash, Hasher};
5
6use crate::ops::{Operation, Commutative, Identity};
7use crate::maybe_owned::MaybeOwned;
8
9/// This data structure allows range queries and single element modification.
10///
11/// This tree allocates `2n * sizeof(N)` bytes of memory.
12///
13/// This tree is implemented using a segment tree.  A segment tree is a binary tree where
14/// each node contains the combination of the children under the operation.
15///
16///# Examples
17///
18/// This example solves the [range minimum query][1] problem.
19///
20/// ```rust
21/// use segment_tree::SegmentPoint;
22/// use segment_tree::ops::Min;
23///
24/// // Let's solve the range minimum query on this array.
25/// let mut tree = SegmentPoint::build(
26///     vec![10i32, 5, 6, 4, 12, 8, 9, 3, 2, 1, 5], Min
27/// ); //        0  1  2  3   4  5  6  7  8  9 10  - indices
28///
29/// // Find the minimum value in a few intervals. Note that the second argument is
30/// // exclusive.
31/// assert_eq!(tree.query(0, 2), 5);
32/// assert_eq!(tree.query(4, 8), 3);
33/// assert_eq!(tree.query(3, 11), 1);
34/// assert_eq!(tree.query(0, 11), 1);
35///
36/// // query returns the identity if given an invalid interval
37/// // The identity of min is the MAX.
38/// assert_eq!(tree.query(4, 2), std::i32::MAX);
39///
40/// // We can change individual values in the array as well.
41/// tree.modify(2, 0);
42/// assert_eq!(tree.query(0, 3), 0);
43/// assert_eq!(tree.query(3, 6), 4);
44///
45/// // We can view the values currently stored at any time.
46/// assert_eq!(tree.view(), &[10, 5, 0, 4, 12, 8, 9, 3, 2, 1, 5]);
47/// ```
48///
49/// We can also use a `SegmentPoint` to find the sum of any interval, by changing the
50/// operator to [`Add`].
51///
52/// ```rust
53/// use segment_tree::SegmentPoint;
54/// use segment_tree::ops::Add;
55///
56/// let mut tree = SegmentPoint::build(
57///     vec![10, 5, 6, 4, 12, 8, 9, 3, 2, 1, 5], Add
58/// ); //     0  1  2  3   4  5  6  7  8  9 10  - indices
59///
60/// assert_eq!(tree.query(4, 8), 12 + 8 + 9 + 3);
61/// assert_eq!(tree.query(1, 3), 5 + 6);
62///
63/// // we can still modify values in the tree
64/// tree.modify(2, 4);
65/// assert_eq!(tree.query(1, 3), 5 + 4);
66/// assert_eq!(tree.query(4, 8), 12 + 8 + 9 + 3);
67///
68/// assert_eq!(tree.view(), &[10, 5, 4, 4, 12, 8, 9, 3, 2, 1, 5]);
69/// ```
70///
71/// [1]: https://en.wikipedia.org/wiki/Range_minimum_query
72/// [`Add`]: ops/struct.Add.html
73pub struct SegmentPoint<N, O> where O: Operation<N> {
74    buf: Vec<N>,
75    n: usize,
76    op: O
77}
78
79impl<N, O: Operation<N>> SegmentPoint<N, O> {
80    /// Builds a tree using the given buffer.  If the given buffer is less than half full,
81    /// this function allocates.  This function clones every value in the input array.
82    /// Uses `O(len)` time.
83    ///
84    /// See also the function [`build_noalloc`].
85    ///
86    /// [`build_noalloc`]: struct.SegmentPoint.html#method.build_noalloc
87    pub fn build(mut buf: Vec<N>, op: O) -> SegmentPoint<N, O> where N: Clone {
88        let n = buf.len();
89        buf.reserve_exact(n);
90        for i in 0..n {
91            debug_assert!(i < buf.len());
92            let clone = unsafe { buf.get_unchecked(i).clone() }; // i < n < buf.len()
93            buf.push(clone);
94        }
95        SegmentPoint::build_noalloc(buf, op)
96    }
97    /// Set the value at the specified index and return the old value.
98    /// Uses `O(log(len))` time.
99    pub fn modify(&mut self, mut p: usize, value: N) -> N {
100        p += self.n;
101        let res = mem::replace(&mut self.buf[p], value);
102        while { p >>= 1; p > 0 } {
103            self.buf[p] = self.op.combine(&self.buf[p<<1], &self.buf[p<<1|1]);
104        }
105        res
106    }
107    /// Computes `a[l] * a[l+1] * ... * a[r-1]`.
108    /// Uses `O(log(len))` time.
109    ///
110    /// If `l >= r`, this method returns the identity.
111    ///
112    /// See [`query_noiden`] or [`query_noclone`] for a version that works with
113    /// non-[commutative operations][1].
114    ///
115    /// [`query_noiden`]: struct.SegmentPoint.html#method.query_noiden
116    /// [`query_noclone`]: struct.SegmentPoint.html#method.query_noclone
117    /// [1]: ops/trait.Commutative.html
118    pub fn query(&self, mut l: usize, mut r: usize) -> N
119    where
120        O: Commutative<N> + Identity<N>
121    {
122        let mut res = self.op.identity();
123        l += self.n; r += self.n;
124        while l < r {
125            if l&1 == 1 {
126                res = self.op.combine_left(res, &self.buf[l]);
127                l += 1;
128            }
129            if r&1 == 1 {
130                r -= 1;
131                res = self.op.combine_left(res, &self.buf[r]);
132            }
133            l >>= 1; r >>= 1;
134        }
135        res
136    }
137    /// Combine the value at `p` with `delta`.
138    /// Uses `O(log(len))` time.
139    #[inline(always)]
140    pub fn compose(&mut self, p: usize, delta: &N)
141    where
142        O: Commutative<N>
143    {
144        self.compose_right(p, delta);
145    }
146    /// Combine the value at `p` with `delta`, such that `delta` is the left argument.
147    /// Uses `O(log(len))` time.
148    pub fn compose_left(&mut self, mut p: usize, delta: &N) {
149        p += self.n;
150        self.op.combine_mut2(delta, &mut self.buf[p]);
151        while { p >>= 1; p > 0 } {
152            self.buf[p] = self.op.combine(&self.buf[p<<1], &self.buf[p<<1|1]);
153        }
154    }
155    /// Combine the value at `p` with `delta`, such that `delta` is the right argument.
156    /// Uses `O(log(len))` time.
157    pub fn compose_right(&mut self, mut p: usize, delta: &N) {
158        p += self.n;
159        self.op.combine_mut(&mut self.buf[p], delta);
160        while { p >>= 1; p > 0 } {
161            self.buf[p] = self.op.combine(&self.buf[p<<1], &self.buf[p<<1|1]);
162        }
163    }
164    /// View the values in this segment tree using a slice.  Uses `O(1)` time.
165    #[inline(always)]
166    pub fn view(&self) -> &[N] {
167        &self.buf[self.n..]
168    }
169    /// The number of elements stored in this segment tree.  Uses `O(1)` time.
170    #[inline(always)]
171    pub fn len(&self) -> usize {
172        self.n
173    }
174    /// Builds a tree using the given buffer.  The buffer must be even in size.
175    /// The first `n` values have no effect on the resulting tree,
176    /// and the remaining `n` values contains the array to build the tree on.
177    /// Uses `O(len)` time.
178    ///
179    /// This function panics if the size of the buffer is odd.
180    ///
181    /// # Example
182    ///
183    /// ```rust
184    /// use segment_tree::SegmentPoint;
185    /// use segment_tree::ops::Min;
186    ///
187    /// // make a segment point using the other build method
188    /// let mut tree = SegmentPoint::build(
189    ///     vec![1, 2, 3, 4], Min
190    /// );
191    /// // make a segment point using the build_noalloc method:
192    /// let mut tree1 = SegmentPoint::build_noalloc(
193    ///     // the first half of the values are ignored
194    ///     vec![3282, 210, 0, 245, 1, 2, 3, 4], Min
195    /// );
196    /// assert_eq!(tree, tree1);
197    ///
198    /// let mut tree2 = SegmentPoint::build_noalloc(
199    ///     // we can also try some other first few values
200    ///     vec![0, 0, 0, 0, 1, 2, 3, 4], Min
201    /// );
202    /// assert_eq!(tree1, tree2);
203    /// assert_eq!(tree, tree2);
204    /// ```
205    pub fn build_noalloc(mut buf: Vec<N>, op: O) -> SegmentPoint<N, O> {
206        let len = buf.len();
207        let n = len >> 1;
208        if len & 1 == 1 {
209            panic!("SegmentPoint::build_noalloc: odd size");
210        }
211        for i in (1..n).rev() {
212            let res = op.combine(&buf[i<<1], &buf[i<<1 | 1]);
213            buf[i] = res;
214        }
215        SegmentPoint {
216            buf: buf, op: op, n: n
217        }
218    }
219}
220impl<N, O: Operation<N>> SegmentPoint<N, O> {
221    /// Like [`query`], except it doesn't require the operation to be commutative, nor to
222    /// have any identity.
223    ///
224    /// Computes `a[l] * a[l+1] * ... * a[r-1]`.
225    ///
226    /// This method panics if `l >= r`.
227    ///
228    /// This method clones at most twice and runs in `O(log(len))` time.
229    /// See [`query_noclone`] for a version that doesn't clone.
230    ///
231    /// [`query_noclone`]: struct.SegmentPoint.html#method.query_noclone
232    /// [`query`]: struct.SegmentPoint.html#method.query
233    pub fn query_noiden(&self, mut l: usize, mut r: usize) -> N where N: Clone {
234        let mut resl = None;
235        let mut resr = None;
236        l += self.n; r += self.n;
237        while l < r {
238            if l&1 == 1 {
239                resl = match resl {
240                    None => Some(self.buf[l].clone()),
241                    Some(v) => Some(self.op.combine_left(v, &self.buf[l]))
242                };
243                l += 1;
244            }
245            if r&1 == 1 {
246                r -= 1;
247                resr = match resr {
248                    None => Some(self.buf[r].clone()),
249                    Some(v) => Some(self.op.combine_right(&self.buf[r], v))
250                }
251            }
252            l >>= 1; r >>= 1;
253        }
254        match resl {
255            None => match resr {
256                None => panic!("Empty interval."),
257                Some(r) => r,
258            },
259            Some(l) => match resr {
260                None => l,
261                Some(r) => self.op.combine_both(l, r)
262            }
263        }
264    }
265    /// Like [`query_noiden`], except it doesn't clone.
266    ///
267    /// Computes `a[l] * a[l+1] * ... * a[r-1]`.
268    ///
269    /// This method panics if `l >= r`.
270    ///
271    /// Uses `O(log(len))` time.
272    /// See also [`query`] and [`query_commut`].
273    ///
274    /// [`query_commut`]: struct.SegmentPoint.html#method.query_commut
275    /// [`query_noiden`]: struct.SegmentPoint.html#method.query_noiden
276    /// [`query`]: struct.SegmentPoint.html#method.query
277    pub fn query_noclone<'a>(&'a self, mut l: usize, mut r: usize) -> MaybeOwned<'a, N> {
278        let mut resl = None;
279        let mut resr = None;
280        l += self.n; r += self.n;
281        while l < r {
282            if l&1 == 1 {
283                resl = match resl {
284                    None => Some(MaybeOwned::Borrowed(&self.buf[l])),
285                    Some(MaybeOwned::Borrowed(ref v)) =>
286                        Some(MaybeOwned::Owned(self.op.combine(v, &self.buf[l]))),
287                    Some(MaybeOwned::Owned(v)) =>
288                        Some(MaybeOwned::Owned(self.op.combine_left(v, &self.buf[l]))),
289                };
290                l += 1;
291            }
292            if r&1 == 1 {
293                r -= 1;
294                resr = match resr {
295                    None => Some(MaybeOwned::Borrowed(&self.buf[r])),
296                    Some(MaybeOwned::Borrowed(ref v)) =>
297                        Some(MaybeOwned::Owned(self.op.combine(&self.buf[r], v))),
298                    Some(MaybeOwned::Owned(v)) =>
299                        Some(MaybeOwned::Owned(self.op.combine_right(&self.buf[r], v))),
300                }
301            }
302            l >>= 1; r >>= 1;
303        }
304        match resl {
305            None => match resr {
306                None => panic!("Empty interval."),
307                Some(v) => v,
308            },
309            Some(MaybeOwned::Borrowed(ref l)) => match resr {
310                None => MaybeOwned::Borrowed(l),
311                Some(MaybeOwned::Borrowed(ref r)) =>
312                    MaybeOwned::Owned(self.op.combine(l, r)),
313                Some(MaybeOwned::Owned(r)) =>
314                    MaybeOwned::Owned(self.op.combine_right(l, r))
315            },
316            Some(MaybeOwned::Owned(l)) => match resr {
317                None => MaybeOwned::Owned(l),
318                Some(MaybeOwned::Borrowed(ref r)) =>
319                    MaybeOwned::Owned(self.op.combine_left(l, r)),
320                Some(MaybeOwned::Owned(r)) =>
321                    MaybeOwned::Owned(self.op.combine_both(l, r))
322            }
323        }
324    }
325}
326
327impl<N: Clone, O: Operation<N> + Clone> Clone for SegmentPoint<N, O> {
328    #[inline]
329    fn clone(&self) -> SegmentPoint<N, O> {
330        SegmentPoint {
331            buf: self.buf.clone(), n: self.n, op: self.op.clone()
332        }
333    }
334}
335impl<N: fmt::Debug, O: Operation<N>> fmt::Debug for SegmentPoint<N, O> {
336    #[inline]
337    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
338        write!(f, "SegmentPoint({:?})", self.view())
339    }
340}
341impl<N: PartialEq, O: Operation<N> + PartialEq> PartialEq for SegmentPoint<N, O> {
342    #[inline]
343    fn eq(&self, other: &SegmentPoint<N, O>) -> bool {
344        self.op.eq(&other.op) && self.view().eq(other.view())
345    }
346    #[inline]
347    fn ne(&self, other: &SegmentPoint<N, O>) -> bool {
348        self.op.ne(&other.op) && self.view().ne(other.view())
349    }
350}
351impl<N: Eq, O: Operation<N> + Eq> Eq for SegmentPoint<N, O> { }
352impl<N, O: Operation<N> + Default> Default for SegmentPoint<N, O> {
353    #[inline]
354    fn default() -> SegmentPoint<N, O> {
355        SegmentPoint { buf: Vec::new(), n: 0, op: Default::default() }
356    }
357}
358impl<'a, N: 'a + Hash, O: Operation<N>> Hash for SegmentPoint<N, O> {
359    #[inline]
360    fn hash<H: Hasher>(&self, state: &mut H) {
361        self.view().hash(state);
362    }
363}
364
#[cfg(test)]
mod tests {
    use crate::SegmentPoint;
    use crate::ops::*;
    use crate::maybe_owned::MaybeOwned;
    use rand::prelude::*;
    use rand::distributions::{Distribution, Standard};
    use rand::seq::SliceRandom;
    use std::num::Wrapping;

    /// A four-letter string whose `Add` is concatenation.
    ///
    /// Concatenation is not commutative, which is what makes this type useful for checking
    /// the order-preserving queries.  It is not useful in practice, since the root ends up
    /// holding the concatenation of every value.
    #[derive(PartialEq, Eq, Clone, Debug)]
    struct StrType {
        value: String
    }

    impl StrType {
        /// Concatenates every value in `list`, in order.
        fn cat(list: &[StrType]) -> StrType {
            let value: String = list.iter().map(|s| s.value.as_str()).collect();
            StrType { value }
        }

        /// The concatenation of elements `i..j` of a `cat` result.  Relies on every sampled
        /// value being exactly four bytes long.
        fn sub(&self, i: usize, j: usize) -> StrType {
            StrType { value: self.value[4 * i..4 * j].to_string() }
        }
    }

    impl Distribution<StrType> for Standard {
        /// Draws four letters in `'A'..='Z'`.
        fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> StrType {
            let mut bytes = [0u8; 4];
            for slot in bytes.iter_mut() {
                *slot = rng.gen_range(b'A', b'Z' + 1);
            }
            let value = std::str::from_utf8(&bytes).unwrap().to_owned();
            StrType { value }
        }
    }

    impl Operation<StrType> for Add {
        fn combine(&self, a: &StrType, b: &StrType) -> StrType {
            let mut value = a.value.clone();
            value.push_str(&b.value);
            StrType { value }
        }
        fn combine_mut(&self, a: &mut StrType, b: &StrType) {
            a.value.push_str(&b.value);
        }
    }

    #[test]
    fn segment_tree_build() {
        let mut rng = thread_rng();
        let vals: Vec<Wrapping<i32>> = rng.sample_iter(&Standard)
            .map(Wrapping).take(130).collect();
        for len in 0..vals.len() {
            let values = vals[..len].to_vec();
            println!("{:?}", values);
            // `build_noalloc` ignores the first half, so zeros are as good as anything.
            let mut doubled = vec![Wrapping(0); 2 * len];
            doubled[len..].copy_from_slice(&values);
            let mut built = SegmentPoint::build(values, Add).buf;
            let mut built_noalloc = SegmentPoint::build_noalloc(doubled, Add).buf;
            // Slot 0 is unused: `build` leaves a copy of the first value there while
            // `build_noalloc` keeps whatever it was given, so normalize it first.
            if len > 0 {
                built[0] = Wrapping(0);
                built_noalloc[0] = Wrapping(0);
            }
            println!("build");
            println!("{:?}", built);
            println!("build_noalloc");
            println!("{:?}", built_noalloc);
            assert_eq!(built, built_noalloc);
            assert_eq!(built.len(), 2 * len);
        }
    }

    #[test]
    fn segment_tree_query() {
        let mut rng = thread_rng();
        let vals: Vec<StrType> = rng.sample_iter(&Standard).take(130).collect();
        for len in 0..vals.len() {
            let values = &vals[..len];
            let tree = SegmentPoint::build(values.to_vec(), Add);
            let whole = StrType::cat(values);
            println!("n: {} tree.buf.len: {}", len, tree.buf.len());
            for lo in 0..len {
                for hi in lo + 1..=len {
                    println!("i: {}, j: {}", lo, hi);
                    let expected = whole.sub(lo, hi);
                    assert_eq!(tree.query_noiden(lo, hi), expected);
                    assert_eq!(tree.query_noclone(lo, hi), MaybeOwned::Owned(expected));
                }
            }
        }
    }

    #[test]
    fn segment_tree_query_commut() {
        let mut rng = thread_rng();
        let vals: Vec<Wrapping<i32>> = rng.sample_iter(&Standard)
            .map(Wrapping).take(130).collect();
        for len in 0..vals.len() {
            let values = vals[..len].to_vec();
            let tree = SegmentPoint::build(values.clone(), Add);
            assert_eq!(tree.view(), &values[..]);

            // prefix[k] = values[0] + ... + values[k], with wrapping arithmetic.
            let prefix: Vec<Wrapping<i32>> = values
                .iter()
                .scan(Wrapping(0i32), |acc, &v| {
                    *acc += v;
                    Some(*acc)
                })
                .collect();

            println!("n: {} tree.buf.len: {}", len, tree.buf.len());
            for lo in 0..len {
                for hi in lo + 1..=len {
                    println!("i: {}, j: {}", lo, hi);
                    let expected = if lo == 0 {
                        prefix[hi - 1]
                    } else {
                        prefix[hi - 1] - prefix[lo - 1]
                    };
                    assert_eq!(tree.query(lo, hi), expected);
                }
            }
        }
    }

    #[test]
    fn segment_tree_modify() {
        let mut rng = thread_rng();
        let vals1: Vec<Wrapping<i32>> = rng.sample_iter(&Standard)
            .map(Wrapping).take(130).collect();
        let vals2: Vec<Wrapping<i32>> = rng.sample_iter(&Standard)
            .map(Wrapping).take(130).collect();
        for len in 0..vals1.len() {
            let mut order: Vec<usize> = (0..len).collect();
            order.shuffle(&mut rng);
            let mut current = vals1[..len].to_vec();
            let mut tree = SegmentPoint::build(current.clone(), Add);
            for idx in order {
                let replacement = vals2[idx];
                tree.modify(idx, replacement);
                current[idx] = replacement;
                // Every node except the unused slot 0 must match a fresh rebuild.
                let rebuilt = SegmentPoint::build(current.clone(), Add);
                assert_eq!(tree.buf[1..], rebuilt.buf[1..]);
            }
        }
    }
}