vortex_compute/arithmetic/
buffer_checked.rs1use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6
7use crate::arithmetic::CheckedArithmetic;
8use crate::arithmetic::CheckedOperator;
9
10impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for Buffer<T>
12where
13 T: Copy + num_traits::Zero,
14 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
15 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
16{
17 type Output = Buffer<T>;
18
19 fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
20 match self.try_into_mut() {
21 Ok(lhs) => lhs.checked_eval(rhs),
22 Err(lhs) => (&lhs).checked_eval(rhs), }
24 }
25}
26
27impl<Op, T> CheckedArithmetic<Op, &Buffer<T>> for BufferMut<T>
29where
30 T: Copy + num_traits::Zero,
31 Op: CheckedOperator<T>,
32{
33 type Output = Buffer<T>;
34
35 fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
36 assert_eq!(self.len(), rhs.len());
37
38 let mut i = 0;
39 let mut overflow = false;
40 let buffer = self
41 .map_each_in_place(|a| {
42 let b = unsafe { *rhs.get_unchecked(i) };
44 i += 1;
45
46 Op::apply(&a, &b).unwrap_or_else(|| {
50 overflow = true;
51 T::zero()
52 })
53 })
54 .freeze();
55
56 (!overflow).then_some(buffer)
57 }
58}
59
60impl<Op, T> CheckedArithmetic<Op> for &Buffer<T>
62where
63 T: Copy + num_traits::Zero,
64 Op: CheckedOperator<T>,
65{
66 type Output = Buffer<T>;
67
68 fn checked_eval(self, rhs: &Buffer<T>) -> Option<Self::Output> {
69 assert_eq!(self.len(), rhs.len());
70
71 let mut overflow = false;
72 let buffer =
73 Buffer::<T>::from_trusted_len_iter(self.iter().zip(rhs.iter()).map(|(a, b)| {
74 Op::apply(a, b).unwrap_or_else(|| {
78 overflow = true;
79 T::zero()
80 })
81 }));
82 (!overflow).then_some(buffer)
83 }
84}
85
86impl<Op, T> CheckedArithmetic<Op, &T> for Buffer<T>
89where
90 T: Copy + num_traits::Zero,
91 BufferMut<T>: for<'a> CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
92 for<'a> &'a Buffer<T>: CheckedArithmetic<Op, &'a T, Output = Buffer<T>>,
93{
94 type Output = Buffer<T>;
95
96 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
97 match self.try_into_mut() {
98 Ok(lhs) => lhs.checked_eval(rhs),
99 Err(lhs) => (&lhs).checked_eval(rhs),
100 }
101 }
102}
103
104impl<Op, T> CheckedArithmetic<Op, &T> for BufferMut<T>
106where
107 T: Copy + num_traits::Zero,
108 Op: CheckedOperator<T>,
109{
110 type Output = Buffer<T>;
111
112 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
113 let mut overflow = false;
114 let buffer = self
115 .map_each_in_place(|a| {
116 Op::apply(&a, rhs).unwrap_or_else(|| {
117 overflow = true;
118 T::zero()
119 })
120 })
121 .freeze();
122
123 (!overflow).then_some(buffer)
124 }
125}
126
127impl<Op, T> CheckedArithmetic<Op, &T> for &Buffer<T>
129where
130 T: Copy + num_traits::Zero,
131 Op: CheckedOperator<T>,
132{
133 type Output = Buffer<T>;
134
135 fn checked_eval(self, rhs: &T) -> Option<Self::Output> {
136 let mut overflow = false;
137 let buffer = Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| {
138 Op::apply(a, rhs).unwrap_or_else(|| {
139 overflow = true;
140 T::zero()
141 })
142 }));
143
144 (!overflow).then_some(buffer)
145 }
146}
147
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::Add;
    use crate::arithmetic::CheckedArithmetic;
    use crate::arithmetic::Div;
    use crate::arithmetic::Mul;
    use crate::arithmetic::Sub;

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![1u8, 50];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_add_scalar_overflow() {
        let buf = buffer![250u8, 1];
        let result = CheckedArithmetic::<Add, _>::checked_eval(buf, &10u8);
        assert!(result.is_none());
    }

    /// Exercises the `&Buffer<T>` / `&Buffer<T>` implementation directly and checks the
    /// borrowed input is left untouched.
    #[test]
    fn test_add_borrowed_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(&left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(left, buffer![1u32, 2, 3, 4]);
    }

    /// When another handle to the left operand exists, the owned implementation must not
    /// modify the shared data (this is expected to exercise the `try_into_mut` fallback).
    #[test]
    fn test_add_shared_buffer_is_not_mutated() {
        let left = buffer![1u32, 2, 3, 4];
        let shared = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Add, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(shared, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    #[should_panic]
    fn test_length_mismatch_panics() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![1u32, 2];
        let _ = CheckedArithmetic::<Add, _>::checked_eval(left, &right);
    }

    #[test]
    #[should_panic]
    fn test_borrowed_length_mismatch_panics() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![1u32, 2];
        let _ = CheckedArithmetic::<Add, _>::checked_eval(&left, &right);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = CheckedArithmetic::<Sub, _>::checked_eval(buf, &5).unwrap();
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_sub_underflow() {
        let left = buffer![5u32, 10];
        let right = buffer![10u32, 5];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_sub_borrowed_underflow() {
        let left = buffer![5u32, 10];
        let right = buffer![10u32, 5];

        let result = CheckedArithmetic::<Sub, _>::checked_eval(&left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    /// Exercises the `&Buffer<T>` / `&T` implementation directly.
    #[test]
    fn test_mul_borrowed_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(&buf, &10u32).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_mul_overflow() {
        let left = buffer![u8::MAX, 100];
        let right = buffer![2u8, 3];

        let result = CheckedArithmetic::<Mul, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_mul_borrowed_scalar_overflow() {
        let buf = buffer![100u8, 1];
        let result = CheckedArithmetic::<Mul, _>::checked_eval(&buf, &3u8);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_buffers() {
        let left = buffer![100u32, 200, 300, 400];
        let right = buffer![10u32, 20, 30, 40];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right).unwrap();
        assert_eq!(result, buffer![10u32, 10, 10, 10]);
    }

    #[test]
    fn test_div_scalar() {
        let buf = buffer![100u32, 200, 300, 400];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &10).unwrap();
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_div_by_zero() {
        let left = buffer![10u32, 20, 30];
        let right = buffer![2u32, 0, 3];

        let result = CheckedArithmetic::<Div, _>::checked_eval(left, &right);
        assert!(result.is_none());
    }

    #[test]
    fn test_div_scalar_by_zero() {
        let buf = buffer![10u32, 20, 30];
        let result = CheckedArithmetic::<Div, _>::checked_eval(buf, &0);
        assert!(result.is_none());
    }
}