vortex_compute/arithmetic/
datum.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_vector::PrimitiveDatum;
5use vortex_vector::ScalarOps;
6use vortex_vector::VectorMutOps;
7use vortex_vector::VectorOps;
8use vortex_vector::primitive::PrimitiveScalar;
9use vortex_vector::primitive::PrimitiveVector;
10
11use crate::arithmetic::Arithmetic;
12use crate::arithmetic::CheckedArithmetic;
13
14impl<Op> CheckedArithmetic<Op> for PrimitiveDatum
15where
16    for<'a> &'a PrimitiveScalar: CheckedArithmetic<Op, Output = PrimitiveScalar>,
17    for<'a> PrimitiveVector: CheckedArithmetic<Op, &'a PrimitiveVector, Output = PrimitiveVector>,
18{
19    type Output = PrimitiveDatum;
20
21    fn checked_eval(self, rhs: PrimitiveDatum) -> Option<Self::Output> {
22        match (self, rhs) {
23            (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Scalar(sc2)) => {
24                (&sc1).checked_eval(&sc2).map(PrimitiveDatum::Scalar)
25            }
26            (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Vector(vec2)) => {
27                vec1.checked_eval(&vec2).map(PrimitiveDatum::Vector)
28            }
29            (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Scalar(sc2)) => {
30                let len = vec1.len();
31                vec1.checked_eval(&sc2.repeat(len).freeze().into_primitive())
32                    .map(PrimitiveDatum::Vector)
33            }
34            (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Vector(vec2)) => {
35                let len = vec2.len();
36                sc1.repeat(len)
37                    .freeze()
38                    .into_primitive()
39                    .checked_eval(&vec2)
40                    .map(PrimitiveDatum::Vector)
41            }
42        }
43    }
44}
45
46impl<Op> Arithmetic<Op> for PrimitiveDatum
47where
48    for<'a> &'a PrimitiveScalar: Arithmetic<Op, &'a PrimitiveScalar, Output = PrimitiveScalar>,
49    for<'a> &'a PrimitiveScalar: Arithmetic<Op, PrimitiveVector, Output = PrimitiveDatum>,
50    for<'a> PrimitiveVector: Arithmetic<Op, &'a PrimitiveVector, Output = PrimitiveVector>,
51    for<'a> PrimitiveVector: Arithmetic<Op, &'a PrimitiveScalar, Output = PrimitiveDatum>,
52{
53    type Output = PrimitiveDatum;
54
55    fn eval(self, rhs: PrimitiveDatum) -> Self::Output {
56        match (self, rhs) {
57            (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Scalar(sc2)) => {
58                PrimitiveDatum::Scalar((&sc1).eval(&sc2))
59            }
60            (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Vector(vec2)) => {
61                PrimitiveDatum::Vector(vec1.eval(&vec2))
62            }
63            (PrimitiveDatum::Vector(vec1), PrimitiveDatum::Scalar(sc2)) => vec1.eval(&sc2),
64            (PrimitiveDatum::Scalar(sc1), PrimitiveDatum::Vector(vec2)) => (&sc1).eval(vec2),
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_dtype::PTypeDowncast;
    use vortex_mask::Mask;
    use vortex_vector::Datum;
    use vortex_vector::PrimitiveDatum;
    use vortex_vector::Vector;
    use vortex_vector::primitive::PScalar;
    use vortex_vector::primitive::PVector;
    use vortex_vector::primitive::PrimitiveVector;

    use crate::arithmetic::Add;
    use crate::arithmetic::Arithmetic;

    /// Unwraps the vector arm of a [`PrimitiveDatum`], panicking if the result is a scalar.
    fn expect_vector(datum: PrimitiveDatum) -> PrimitiveVector {
        match datum {
            PrimitiveDatum::Vector(v) => v,
            PrimitiveDatum::Scalar(_) => panic!("Expected primitive vector result"),
        }
    }

    #[test]
    fn test_datum_arithmetic_in_place() {
        let left = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let right = PVector::new(buffer![10f32, 20.0, 30.0, 40.0], Mask::new_true(4));
        let left_ptr = left.elements().as_ptr();

        let left_datum = Datum::Vector(Vector::from(left));
        let right_datum = Datum::Vector(Vector::from(right));

        let result_vec = expect_vector(Arithmetic::<Add, _>::eval(
            left_datum.into_primitive(),
            right_datum.into_primitive(),
        ));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result_vec);

        assert_eq!(
            left_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer should be modified in place when input has unique ownership"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 22.0, 33.0, 44.0]);
    }

    /// When the left buffer is shared, the result must land in a fresh buffer and the shared
    /// original must be left untouched.
    #[test]
    fn test_datum_arithmetic_in_place_fail() {
        let left = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let right = PVector::new(buffer![10f32, 20.0, 30.0, 40.0], Mask::new_true(4));
        let left_ptr = left.elements().as_ptr();

        let left_datum = Datum::Vector(Vector::from(left));
        // Keep a second handle alive across the operation so the left buffer is shared.
        let left_datum2 = left_datum.clone();
        let right_datum = Datum::Vector(Vector::from(right));

        let result_vec = expect_vector(Arithmetic::<Add, _>::eval(
            left_datum.into_primitive(),
            right_datum.into_primitive(),
        ));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result_vec);

        assert_ne!(
            left_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer must not be modified in place when the input buffer is shared"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 22.0, 33.0, 44.0]);

        // The other handle must still observe the original, unmodified values.
        let original_vec = expect_vector(left_datum2.into_primitive());
        let original_pvec: &PVector<f32> = PTypeDowncast::into_f32(&original_vec);
        assert_eq!(original_pvec.elements(), &buffer![1f32, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_datum_vector_scalar_in_place() {
        let left = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let left_ptr = left.elements().as_ptr();

        let left_datum = PrimitiveDatum::Vector(left.into());
        let right_datum = PrimitiveDatum::Scalar(PScalar::new(Some(10f32)).into());

        let result_vec = expect_vector(Arithmetic::<Add, _>::eval(left_datum, right_datum));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result_vec);

        assert_eq!(
            left_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer should be modified in place for vector-scalar arithmetic"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 12.0, 13.0, 14.0]);
    }

    #[test]
    fn test_datum_scalar_vector_in_place() {
        let right = PVector::new(buffer![1f32, 2.0, 3.0, 4.0], Mask::new_true(4));
        let right_ptr = right.elements().as_ptr();

        let left_datum = PrimitiveDatum::Scalar(PScalar::new(Some(10f32)).into());
        let right_datum = PrimitiveDatum::Vector(right.into());

        let result_vec = expect_vector(Arithmetic::<Add, _>::eval(left_datum, right_datum));
        let result_pvec: &PVector<f32> = PTypeDowncast::into_f32(&result_vec);

        assert_eq!(
            right_ptr,
            result_pvec.elements().as_ptr(),
            "Buffer should be modified in place for scalar-vector arithmetic"
        );
        assert_eq!(result_pvec.elements(), &buffer![11f32, 12.0, 13.0, 14.0]);
    }
}