1use std::ops::{Add, Mul};
2
3use crate::nodes::{LazyNode, Node};
4
/// Node that stores an aggregated value together with an optional pending
/// (lazy) value that has not yet been folded into it.
///
/// The struct itself places no trait bounds on `T`; the arithmetic and
/// cloning requirements are declared on the `impl` blocks that actually use
/// them, so the type can be named and stored for any `T`.
#[derive(Clone, Debug)]
pub struct Sum<T> {
    /// Aggregated value currently held by this node.
    value: T,
    /// Pending value waiting to be applied to `value`; `None` when nothing is
    /// pending.
    lazy_value: Option<T>,
}
14
15impl<T> Node for Sum<T>
16where
17 T: Add<Output = T> + Clone,
18{
19 type Value = T;
20 #[inline]
22 fn initialize(v: &Self::Value) -> Self {
23 Self {
24 value: v.clone(),
25 lazy_value: None,
26 }
27 }
28 #[inline]
30 fn combine(a: &Self, b: &Self) -> Self {
31 Self {
32 value: a.value.clone() + b.value.clone(),
33 lazy_value: None,
34 }
35 }
36 #[inline]
37 fn value(&self) -> &Self::Value {
38 &self.value
39 }
40}
41
42impl<T> LazyNode for Sum<T>
46where
47 T: Add<Output = T> + Mul<usize, Output = T> + Clone,
48{
49 fn lazy_update(&mut self, i: usize, j: usize) {
50 if let Some(value) = self.lazy_value.take() {
51 let temp = self.value.clone() + value * (j - i + 1);
52 self.value = temp;
53 }
54 }
55
56 fn update_lazy_value(&mut self, new_value: &<Self as Node>::Value) {
57 if let Some(value) = self.lazy_value.take() {
58 self.lazy_value = Some(value + new_value.clone());
59 } else {
60 self.lazy_value = Some(new_value.clone());
61 }
62 }
63 #[inline]
64 fn lazy_value(&self) -> Option<&<Self as Node>::Value> {
65 self.lazy_value.as_ref()
66 }
67}
68
#[cfg(test)]
mod tests {
    use std::ops::{Add, Mul};

    use crate::{
        nodes::{LazyNode, Node},
        utils::Sum,
    };

    /// Test value whose `+` returns the right-hand operand and whose
    /// `* usize` returns the left-hand operand unchanged, so results depend
    /// on the order in which operands are supplied.
    #[derive(Clone, Copy, Debug, PartialEq)]
    struct NonCommutativeTest(usize);

    impl Add for NonCommutativeTest {
        type Output = Self;

        fn add(self, other: Self) -> Self::Output {
            other
        }
    }

    impl Mul<usize> for NonCommutativeTest {
        type Output = Self;

        fn mul(self, _factor: usize) -> Self::Output {
            self
        }
    }

    /// Largest value fed into the folding tests.
    const N: usize = 1_000;

    #[test]
    fn sum_works() {
        let nodes: Vec<Sum<usize>> = (0..=N).map(|x| Sum::initialize(&x)).collect();

        let mut total: Sum<usize> = Sum::initialize(&0);
        for node in &nodes {
            total = Sum::combine(&total, node);
        }

        // 0 + 1 + ... + N
        let expected = N * (N + 1) / 2;
        assert_eq!(total.value(), &expected);
    }

    #[test]
    fn non_commutative_sum_works() {
        let nodes: Vec<Sum<NonCommutativeTest>> = (0..=N)
            .map(|x| Sum::initialize(&NonCommutativeTest(x)))
            .collect();

        let mut total: Sum<NonCommutativeTest> = Sum::initialize(&NonCommutativeTest(0));
        for node in &nodes {
            total = Sum::combine(&total, node);
        }

        // `+` keeps the right-hand operand, so the last node wins.
        assert_eq!(total.value(), &NonCommutativeTest(N));
    }

    #[test]
    fn update_lazy_value_works() {
        let mut node: Sum<usize> = Sum::initialize(&1);
        node.update_lazy_value(&2);

        let expected: usize = 2;
        assert_eq!(node.lazy_value(), Some(&expected));
    }

    #[test]
    fn lazy_update_works() {
        let mut node: Sum<usize> = Sum::initialize(&1);
        node.update_lazy_value(&2);
        node.lazy_update(0, 10);

        // 1 + 2 * (10 - 0 + 1)
        let expected: usize = 23;
        assert_eq!(node.value(), &expected);
    }

    #[test]
    fn non_commutative_update_lazy_value_works() {
        let mut node: Sum<NonCommutativeTest> = Sum::initialize(&NonCommutativeTest(1));
        node.update_lazy_value(&NonCommutativeTest(2));

        let expected = NonCommutativeTest(2);
        assert_eq!(node.lazy_value(), Some(&expected));
    }

    #[test]
    fn non_commutative_lazy_update_works() {
        let mut node: Sum<NonCommutativeTest> = Sum::initialize(&NonCommutativeTest(1));
        node.update_lazy_value(&NonCommutativeTest(2));
        node.lazy_update(0, 10);

        // `*` keeps the pending value and `+` keeps the right-hand operand.
        let expected = NonCommutativeTest(2);
        assert_eq!(node.value(), &expected);
    }
}