vortex_compute/arithmetic/
buffer.rs1use vortex_buffer::{Buffer, BufferMut};
5
6use crate::arithmetic::{Arithmetic, Operator};
7
8impl<Op, T> Arithmetic<Op, &Buffer<T>> for Buffer<T>
10where
11 T: Copy,
12 BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
13 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
14{
15 type Output = Buffer<T>;
16
17 fn eval(self, rhs: &Buffer<T>) -> Self::Output {
18 match self.try_into_mut() {
19 Ok(lhs) => lhs.eval(rhs),
20 Err(lhs) => (&lhs).eval(rhs), }
22 }
23}
24
25impl<Op, T> Arithmetic<Op, &Buffer<T>> for BufferMut<T>
27where
28 T: Copy + num_traits::Zero,
29 Op: Operator<T>,
30{
31 type Output = Buffer<T>;
32
33 fn eval(self, rhs: &Buffer<T>) -> Self::Output {
34 assert_eq!(self.len(), rhs.len());
35
36 let mut i = 0;
37 self.map_each_in_place(|a| {
38 let b = unsafe { *rhs.get_unchecked(i) };
40 i += 1;
41
42 Op::apply(&a, &b)
43 })
44 .freeze()
45 }
46}
47
48impl<Op, T> Arithmetic<Op> for &Buffer<T>
50where
51 Op: Operator<T>,
52{
53 type Output = Buffer<T>;
54
55 fn eval(self, rhs: &Buffer<T>) -> Self::Output {
56 assert_eq!(self.len(), rhs.len());
57 Buffer::<T>::from_trusted_len_iter(
58 self.iter().zip(rhs.iter()).map(|(a, b)| Op::apply(a, b)),
59 )
60 }
61}
62
63impl<Op, T> Arithmetic<Op, &T> for Buffer<T>
66where
67 BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
68 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
69{
70 type Output = Buffer<T>;
71
72 fn eval(self, rhs: &T) -> Self::Output {
73 match self.try_into_mut() {
74 Ok(lhs) => lhs.eval(rhs),
75 Err(lhs) => (&lhs).eval(rhs),
76 }
77 }
78}
79
80impl<Op, T> Arithmetic<Op, &T> for BufferMut<T>
82where
83 T: Copy,
84 Op: Operator<T>,
85{
86 type Output = Buffer<T>;
87
88 fn eval(self, rhs: &T) -> Self::Output {
89 self.map_each_in_place(|a| Op::apply(&a, rhs)).freeze()
90 }
91}
92
93impl<Op, T> Arithmetic<Op, &T> for &Buffer<T>
95where
96 Op: Operator<T>,
97{
98 type Output = Buffer<T>;
99
100 fn eval(self, rhs: &T) -> Self::Output {
101 Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| Op::apply(a, rhs)))
102 }
103}
104
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::{Arithmetic, WrappingAdd, WrappingMul, WrappingSub};

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    /// Keeping a second handle to `left` exercises the by-reference fallback. The
    /// shared data must not be modified by the operation.
    #[test]
    fn test_add_buffers_shared() {
        let left = buffer![1u32, 2, 3, 4];
        let still_shared = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(still_shared, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    #[should_panic]
    fn test_add_buffers_length_mismatch() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![10u32, 20];

        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }

    #[test]
    #[should_panic]
    fn test_add_buffers_length_mismatch_shared() {
        let left = buffer![1u32, 2, 3];
        let _still_shared = left.clone();
        let right = buffer![10u32, 20];

        let _ = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_add_scalar_wraps() {
        let buf = buffer![u32::MAX, 1];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &1);
        assert_eq!(result, buffer![0u32, 2]);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = Arithmetic::<WrappingSub, _>::eval(buf, &5);
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingMul, _>::eval(buf, &10);
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    /// Scalar path with a shared buffer: result is correct and the input is untouched.
    #[test]
    fn test_mul_scalar_shared() {
        let buf = buffer![1u32, 2, 3, 4];
        let still_shared = buf.clone();

        let result = Arithmetic::<WrappingMul, _>::eval(buf, &10);
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
        assert_eq!(still_shared, buffer![1u32, 2, 3, 4]);
    }
}