seg_tree/utils/sum.rs

1use std::ops::{Add, Mul};
2
3use crate::nodes::{LazyNode, Node};
4
/// Range-sum node for a generic value type `T`.
///
/// It implements [`Node`] and [`LazyNode`], so it can be used as a node in every
/// segment tree type.
///
/// The struct itself places no bounds on `T`; the arithmetic requirements are
/// declared on the trait impls that actually need them.
#[derive(Clone, Debug)]
pub struct Sum<T> {
    /// The value currently held by this node.
    value: T,
    /// Pending lazy update that has not yet been applied to `value`, if any.
    lazy_value: Option<T>,
}
14
15impl<T> Node for Sum<T>
16where
17    T: Add<Output = T> + Clone,
18{
19    type Value = T;
20    /// The node is initialized with the value given.
21    #[inline]
22    fn initialize(v: &Self::Value) -> Self {
23        Self {
24            value: v.clone(),
25            lazy_value: None,
26        }
27    }
28    /// As this is a range sum node, the operation which is used to 'merge' two nodes is `+`.
29    #[inline]
30    fn combine(a: &Self, b: &Self) -> Self {
31        Self {
32            value: a.value.clone() + b.value.clone(),
33            lazy_value: None,
34        }
35    }
36    #[inline]
37    fn value(&self) -> &Self::Value {
38        &self.value
39    }
40}
41
42/// Implementation for sum range query node, the update adds the value to each item in the range.
43/// It assumes that `a*n`, where a: T and n: usize is well defined and `a*n = a+...+a` with 'n' a.
44/// For non-commutative operations, two things will be true `lazy_value = lazy_value + new_value`.
45impl<T> LazyNode for Sum<T>
46where
47    T: Add<Output = T> + Mul<usize, Output = T> + Clone,
48{
49    fn lazy_update(&mut self, i: usize, j: usize) {
50        if let Some(value) = self.lazy_value.take() {
51            let temp = self.value.clone() + value * (j - i + 1);
52            self.value = temp;
53        }
54    }
55
56    fn update_lazy_value(&mut self, new_value: &<Self as Node>::Value) {
57        if let Some(value) = self.lazy_value.take() {
58            self.lazy_value = Some(value + new_value.clone());
59        } else {
60            self.lazy_value = Some(new_value.clone());
61        }
62    }
63    #[inline]
64    fn lazy_value(&self) -> Option<&<Self as Node>::Value> {
65        self.lazy_value.as_ref()
66    }
67}
68
#[cfg(test)]
mod tests {
    use std::ops::{Add, Mul};

    use crate::{
        nodes::{LazyNode, Node},
        utils::Sum,
    };

    /// Deliberately non-commutative "addition": it satisfies `a + b == b`, so the
    /// tests can detect which operand order the node uses.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct NonCommutativeTest(usize);

    impl Add for NonCommutativeTest {
        type Output = Self;

        fn add(self, rhs: Self) -> Self::Output {
            rhs
        }
    }

    /// Multiplication is the identity, which is consistent with
    /// `a * n == a + ... + a` under the `a + b == b` rule (for `n >= 1`).
    impl Mul<usize> for NonCommutativeTest {
        type Output = Self;

        fn mul(self, _rhs: usize) -> Self::Output {
            self
        }
    }

    const N: usize = 1_000;

    #[test]
    fn sum_works() {
        let nodes: Vec<Sum<usize>> = (0..=N).map(|x| Sum::initialize(&x)).collect();
        let result = nodes
            .iter()
            .fold(Sum::initialize(&0), |acc, new| Sum::combine(&acc, new));
        assert_eq!(result.value(), &((N + 1) * N / 2));
    }

    #[test]
    fn non_commutative_sum_works() {
        let nodes: Vec<Sum<NonCommutativeTest>> = (0..=N)
            .map(|x| Sum::initialize(&NonCommutativeTest(x)))
            .collect();
        let result = nodes
            .iter()
            .fold(Sum::initialize(&NonCommutativeTest(0)), |acc, new| {
                Sum::combine(&acc, new)
            });
        assert_eq!(result.value(), &NonCommutativeTest(N));
    }

    #[test]
    fn update_lazy_value_works() {
        let mut node = Sum::initialize(&1);
        node.update_lazy_value(&2);
        assert_eq!(node.lazy_value(), Some(&2));
    }

    #[test]
    fn update_lazy_value_accumulates() {
        let mut node = Sum::initialize(&1_usize);
        node.update_lazy_value(&2);
        node.update_lazy_value(&3);
        assert_eq!(node.lazy_value(), Some(&5));
    }

    #[test]
    fn lazy_update_works() {
        // Node represents the range [0,10] with sum 1.
        let mut node = Sum::initialize(&1);
        node.update_lazy_value(&2);
        node.lazy_update(0, 10);
        assert_eq!(node.value(), &23);
    }

    #[test]
    fn lazy_update_clears_pending_value() {
        let mut node = Sum::initialize(&1_usize);
        node.update_lazy_value(&2);
        node.lazy_update(0, 10);
        assert_eq!(node.lazy_value(), None);
    }

    #[test]
    fn lazy_update_single_element_range() {
        // Node represents the range [5,5] with sum 4.
        let mut node = Sum::initialize(&4_usize);
        node.update_lazy_value(&3);
        node.lazy_update(5, 5);
        assert_eq!(node.value(), &7);
    }

    #[test]
    fn lazy_update_without_pending_value_is_noop() {
        let mut node = Sum::initialize(&7_usize);
        node.lazy_update(0, 10);
        assert_eq!(node.value(), &7);
        assert_eq!(node.lazy_value(), None);
    }

    #[test]
    fn non_commutative_update_lazy_value_works() {
        let mut node = Sum::initialize(&NonCommutativeTest(1));
        node.update_lazy_value(&NonCommutativeTest(2));
        assert_eq!(node.lazy_value(), Some(&NonCommutativeTest(2)));
    }

    #[test]
    fn non_commutative_update_lazy_value_keeps_order() {
        let mut node = Sum::initialize(&NonCommutativeTest(1));
        node.update_lazy_value(&NonCommutativeTest(2));
        node.update_lazy_value(&NonCommutativeTest(3));
        // `lazy_value + new_value == new_value` for this type.
        assert_eq!(node.lazy_value(), Some(&NonCommutativeTest(3)));
    }

    #[test]
    fn non_commutative_lazy_update_works() {
        // Node represents the range [0,10] with sum 1.
        let mut node = Sum::initialize(&NonCommutativeTest(1));
        node.update_lazy_value(&NonCommutativeTest(2));
        node.lazy_update(0, 10);
        assert_eq!(node.value(), &NonCommutativeTest(2));
    }
}