vortex_compute/arithmetic/
buffer.rs1use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6
7use crate::arithmetic::Arithmetic;
8use crate::arithmetic::Operator;
9
10impl<Op, T> Arithmetic<Op, &Buffer<T>> for Buffer<T>
12where
13 T: Copy,
14 BufferMut<T>: for<'a> Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
15 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a Buffer<T>, Output = Buffer<T>>,
16{
17 type Output = Buffer<T>;
18
19 fn eval(self, rhs: &Buffer<T>) -> Self::Output {
20 match self.try_into_mut() {
21 Ok(lhs) => lhs.eval(rhs),
22 Err(lhs) => (&lhs).eval(rhs), }
24 }
25}
26
27impl<Op, T> Arithmetic<Op, &Buffer<T>> for BufferMut<T>
29where
30 T: Copy + num_traits::Zero,
31 Op: Operator<T>,
32{
33 type Output = Buffer<T>;
34
35 fn eval(self, rhs: &Buffer<T>) -> Self::Output {
36 assert_eq!(self.len(), rhs.len());
37
38 let mut i = 0;
39 self.map_each_in_place(|a| {
40 let b = unsafe { *rhs.get_unchecked(i) };
42 i += 1;
43
44 Op::apply(&a, &b)
45 })
46 .freeze()
47 }
48}
49
50impl<Op, T> Arithmetic<Op> for &Buffer<T>
52where
53 Op: Operator<T>,
54{
55 type Output = Buffer<T>;
56
57 fn eval(self, rhs: &Buffer<T>) -> Self::Output {
58 assert_eq!(self.len(), rhs.len());
59 Buffer::<T>::from_trusted_len_iter(
60 self.iter().zip(rhs.iter()).map(|(a, b)| Op::apply(a, b)),
61 )
62 }
63}
64
65impl<Op, T> Arithmetic<Op, &T> for Buffer<T>
68where
69 BufferMut<T>: for<'a> Arithmetic<Op, &'a T, Output = Buffer<T>>,
70 for<'a> &'a Buffer<T>: Arithmetic<Op, &'a T, Output = Buffer<T>>,
71{
72 type Output = Buffer<T>;
73
74 fn eval(self, rhs: &T) -> Self::Output {
75 match self.try_into_mut() {
76 Ok(lhs) => lhs.eval(rhs),
77 Err(lhs) => (&lhs).eval(rhs),
78 }
79 }
80}
81
82impl<Op, T> Arithmetic<Op, &T> for BufferMut<T>
84where
85 T: Copy,
86 Op: Operator<T>,
87{
88 type Output = Buffer<T>;
89
90 fn eval(self, rhs: &T) -> Self::Output {
91 self.map_each_in_place(|a| Op::apply(&a, rhs)).freeze()
92 }
93}
94
95impl<Op, T> Arithmetic<Op, &T> for &Buffer<T>
97where
98 Op: Operator<T>,
99{
100 type Output = Buffer<T>;
101
102 fn eval(self, rhs: &T) -> Self::Output {
103 Buffer::<T>::from_trusted_len_iter(self.iter().map(|a| Op::apply(a, rhs)))
104 }
105}
106
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use crate::arithmetic::Arithmetic;
    use crate::arithmetic::WrappingAdd;
    use crate::arithmetic::WrappingMul;
    use crate::arithmetic::WrappingSub;

    #[test]
    fn test_add_buffers() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
    }

    #[test]
    fn test_add_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
    }

    #[test]
    fn test_sub_buffers() {
        let left = buffer![10u32, 20, 30, 40];
        let right = buffer![1u32, 2, 3, 4];

        let result = Arithmetic::<WrappingSub, _>::eval(left, &right);
        assert_eq!(result, buffer![9u32, 18, 27, 36]);
    }

    #[test]
    fn test_sub_scalar() {
        let buf = buffer![10u32, 20, 30, 40];
        let result = Arithmetic::<WrappingSub, _>::eval(buf, &5);
        assert_eq!(result, buffer![5u32, 15, 25, 35]);
    }

    #[test]
    fn test_mul_buffers() {
        let left = buffer![2u32, 3, 4, 5];
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingMul, _>::eval(left, &right);
        assert_eq!(result, buffer![20u32, 60, 120, 200]);
    }

    #[test]
    fn test_mul_scalar() {
        let buf = buffer![1u32, 2, 3, 4];
        let result = Arithmetic::<WrappingMul, _>::eval(buf, &10);
        assert_eq!(result, buffer![10u32, 20, 30, 40]);
    }

    #[test]
    fn test_add_buffers_borrowed_lhs() {
        let left = buffer![1u32, 2, 3, 4];
        let right = buffer![10u32, 20, 30, 40];

        // Exercises the `&Buffer` implementation directly.
        let result = Arithmetic::<WrappingAdd, _>::eval(&left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        assert_eq!(left, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_buffers_shared_lhs() {
        let left = buffer![1u32, 2, 3, 4];
        // A second handle to the same buffer is expected to make `try_into_mut` fail,
        // forcing the borrowed fallback path.
        let alias = left.clone();
        let right = buffer![10u32, 20, 30, 40];

        let result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
        assert_eq!(result, buffer![11u32, 22, 33, 44]);
        // The surviving handle must still observe the original values.
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    fn test_add_scalar_shared_lhs() {
        let buf = buffer![1u32, 2, 3, 4];
        let alias = buf.clone();

        let result = Arithmetic::<WrappingAdd, _>::eval(buf, &10);
        assert_eq!(result, buffer![11u32, 12, 13, 14]);
        assert_eq!(alias, buffer![1u32, 2, 3, 4]);
    }

    #[test]
    #[should_panic]
    fn test_add_buffers_length_mismatch() {
        let left = buffer![1u32, 2, 3];
        let right = buffer![10u32, 20];

        let _result = Arithmetic::<WrappingAdd, _>::eval(left, &right);
    }
}