vortex_compute/arithmetic/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Arithmetic operations on buffers and vectors.
5
6use vortex_dtype::half::f16;
7
8mod buffer;
9mod buffer_checked;
10mod datum;
11mod primitive_scalar;
12mod primitive_vector;
13mod pscalar;
14mod pvector;
15mod pvector_checked;
16
/// A binary arithmetic operation whose kind is selected at the type level by the `Op` marker.
///
/// `Rhs` is the type of the right-hand operand and defaults to `Self`.
pub trait Arithmetic<Op, Rhs = Self> {
    /// The type produced by evaluating the operation.
    type Output;

    /// Evaluate `self <Op> rhs`, taking both operands by value.
    fn eval(self, rhs: Rhs) -> Self::Output;
}
25
/// Trait for infallible (unchecked) arithmetic operators.
///
/// Implemented by operator marker types whose result is always defined for the operand type,
/// e.g. [`WrappingAdd`] or [`SaturatingMul`] on integers. The overflow behavior is encoded in
/// the marker itself, so no `Option` is needed. Compare with [`CheckedOperator`], which reports
/// failure by returning `None`.
pub trait Operator<T> {
    /// Apply the operator to the two operands (`a <op> b`), always producing a value.
    fn apply(a: &T, b: &T) -> T;
}
31
/// Trait for checked arithmetic operations.
///
/// The checked counterpart of [`Arithmetic`]. Instead of always producing a value, the
/// operation reports failure by returning `None`.
pub trait CheckedArithmetic<Op, Rhs = Self> {
    /// The result type after performing the operation.
    type Output;

    /// Perform the operation, returning `None` on overflow/underflow or division by zero.
    ///
    /// The precise conditions that yield `None` depend on the `Op` marker type and on how the
    /// implementing type realizes that operation.
    fn checked_eval(self, rhs: Rhs) -> Option<Self::Output>;
}
41
/// Trait for checked arithmetic operators.
///
/// Implemented by operator marker types (such as [`Add`] or [`Div`]) whose result may be
/// undefined for some operands. Such failures are reported as `None` rather than by panicking
/// or wrapping.
pub trait CheckedOperator<T> {
    /// Apply the operator to the two operands (`a <op> b`).
    ///
    /// Returns `None` if the result is not representable in `T` (overflow/underflow) or the
    /// operation is otherwise undefined, e.g. division by zero.
    fn apply(a: &T, b: &T) -> Option<T>;
}
47
// Basic operator markers. They are used with `CheckedOperator` (integers, via `num_traits`'
// checked traits) and with `Operator` (floating-point types).

/// Marker type for arithmetic addition.
///
/// With checked arithmetic, integer addition yields `None` on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Add;
/// Marker type for arithmetic subtraction.
///
/// With checked arithmetic, integer subtraction yields `None` when the result falls outside
/// the type's range.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Sub;
/// Marker type for arithmetic multiplication.
///
/// With checked arithmetic, integer multiplication yields `None` on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Mul;
/// Marker type for arithmetic division.
///
/// With checked arithmetic, integer division yields `None` on division by zero or on overflow
/// (e.g. `i32::MIN / -1`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Div;

/// Marker type for arithmetic addition that wraps around at the type's boundary on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WrappingAdd;
/// Marker type for arithmetic subtraction that wraps around at the type's boundary on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WrappingSub;
/// Marker type for arithmetic multiplication that wraps around at the type's boundary on
/// overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WrappingMul;

/// Marker type for arithmetic addition that saturates (clamps to the type's minimum or maximum)
/// on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SaturatingAdd;
/// Marker type for arithmetic subtraction that saturates (clamps to the type's minimum or
/// maximum) on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SaturatingSub;
/// Marker type for arithmetic multiplication that saturates (clamps to the type's minimum or
/// maximum) on overflow.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SaturatingMul;
70
71impl<T: num_traits::CheckedAdd> CheckedOperator<T> for Add {
72    #[inline(always)]
73    fn apply(a: &T, b: &T) -> Option<T> {
74        num_traits::CheckedAdd::checked_add(a, b)
75    }
76}
77impl<T: num_traits::CheckedSub> CheckedOperator<T> for Sub {
78    #[inline(always)]
79    fn apply(a: &T, b: &T) -> Option<T> {
80        num_traits::CheckedSub::checked_sub(a, b)
81    }
82}
83impl<T: num_traits::CheckedMul> CheckedOperator<T> for Mul {
84    #[inline(always)]
85    fn apply(a: &T, b: &T) -> Option<T> {
86        num_traits::CheckedMul::checked_mul(a, b)
87    }
88}
89impl<T: num_traits::CheckedDiv> CheckedOperator<T> for Div {
90    #[inline(always)]
91    fn apply(a: &T, b: &T) -> Option<T> {
92        num_traits::CheckedDiv::checked_div(a, b)
93    }
94}
95
96impl<T: num_traits::WrappingAdd> Operator<T> for WrappingAdd {
97    #[inline(always)]
98    fn apply(a: &T, b: &T) -> T {
99        num_traits::WrappingAdd::wrapping_add(a, b)
100    }
101}
102impl<T: num_traits::WrappingSub> Operator<T> for WrappingSub {
103    #[inline(always)]
104    fn apply(a: &T, b: &T) -> T {
105        num_traits::WrappingSub::wrapping_sub(a, b)
106    }
107}
108impl<T: num_traits::WrappingMul> Operator<T> for WrappingMul {
109    #[inline(always)]
110    fn apply(a: &T, b: &T) -> T {
111        num_traits::WrappingMul::wrapping_mul(a, b)
112    }
113}
114
115impl<T: num_traits::SaturatingAdd> Operator<T> for SaturatingAdd {
116    #[inline(always)]
117    fn apply(a: &T, b: &T) -> T {
118        num_traits::SaturatingAdd::saturating_add(a, b)
119    }
120}
121impl<T: num_traits::SaturatingSub> Operator<T> for SaturatingSub {
122    #[inline(always)]
123    fn apply(a: &T, b: &T) -> T {
124        num_traits::SaturatingSub::saturating_sub(a, b)
125    }
126}
127impl<T: num_traits::SaturatingMul> Operator<T> for SaturatingMul {
128    #[inline(always)]
129    fn apply(a: &T, b: &T) -> T {
130        num_traits::SaturatingMul::saturating_mul(a, b)
131    }
132}
133
134/// Macro to implement arithmetic operators for floating-point types.
135///
136/// These are not deferred to the `std::ops::Add` since those implementations will panic on
137/// overflow in some cases (e.g., debug builds).
138macro_rules! impl_float {
139    ($T:ty) => {
140        impl Operator<$T> for Add {
141            #[inline(always)]
142            fn apply(a: &$T, b: &$T) -> $T {
143                a + b
144            }
145        }
146        impl Operator<$T> for Sub {
147            #[inline(always)]
148            fn apply(a: &$T, b: &$T) -> $T {
149                a - b
150            }
151        }
152        impl Operator<$T> for Mul {
153            #[inline(always)]
154            fn apply(a: &$T, b: &$T) -> $T {
155                a * b
156            }
157        }
158        impl Operator<$T> for Div {
159            #[inline(always)]
160            fn apply(a: &$T, b: &$T) -> $T {
161                a / b
162            }
163        }
164    };
165}
166
167impl_float!(f16);
168impl_float!(f32);
169impl_float!(f64);