vortex_compute/arithmetic/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Arithmetic operations on buffers and vectors.
5
6use vortex_dtype::half::f16;
7
8mod buffer;
9mod buffer_checked;
10mod pvector;
11mod pvector_checked;
12
/// Binary arithmetic whose semantics are chosen by the marker type `Op`.
///
/// `Op` is one of the unit marker structs declared in this module (for example `Add` or
/// `WrappingMul`), and `Rhs` is the type of the right-hand operand, defaulting to `Self`.
pub trait Arithmetic<Op, Rhs = Self> {
    /// Type produced by evaluating the operation.
    type Output;

    /// Consume both operands and return the result of applying `Op` to them.
    fn eval(self, rhs: Rhs) -> Self::Output;
}
21
/// Trait for infallible arithmetic operators.
///
/// Implemented by marker types whose operation always yields a value of `T`: the wrapping
/// and saturating markers for integer types, and the plain `Add`/`Sub`/`Mul`/`Div` markers
/// for floating-point types. Operators that can fail (overflow, division by zero) use
/// [`CheckedOperator`] instead.
pub trait Operator<T> {
    /// Apply the operator to the two operands and return the result.
    fn apply(a: &T, b: &T) -> T;
}
27
/// Trait for checked arithmetic operations.
///
/// Like [`Arithmetic`], but the evaluation may fail and report the failure as `None`
/// instead of producing a value.
pub trait CheckedArithmetic<Op, Rhs = Self> {
    /// The result type after performing the operation.
    type Output;

    /// Perform the operation, returning `None` if it fails, e.g. on overflow/underflow or
    /// division by zero.
    ///
    /// The precise conditions under which `None` is returned are determined by the `Op`
    /// marker type.
    fn checked_eval(self, rhs: Rhs) -> Option<Self::Output>;
}
37
/// Trait for checked arithmetic operators.
pub trait CheckedOperator<T> {
    /// Apply the operator to the two operands.
    ///
    /// Returns `None` when the result cannot be represented in `T` (overflow/underflow) or
    /// when the operation is undefined for the inputs (e.g. division by zero).
    fn apply(a: &T, b: &T) -> Option<T>;
}
43
/// Marker type for arithmetic addition.
///
/// Checked (via `num_traits::CheckedAdd`) for integer-like types; plain `+` for floats.
#[derive(Debug, Clone, Copy, Default)]
pub struct Add;
/// Marker type for arithmetic subtraction.
///
/// Checked (via `num_traits::CheckedSub`) for integer-like types; plain `-` for floats.
#[derive(Debug, Clone, Copy, Default)]
pub struct Sub;
/// Marker type for arithmetic multiplication.
///
/// Checked (via `num_traits::CheckedMul`) for integer-like types; plain `*` for floats.
#[derive(Debug, Clone, Copy, Default)]
pub struct Mul;
/// Marker type for arithmetic division.
///
/// Checked (via `num_traits::CheckedDiv`, which also rejects division by zero) for
/// integer-like types; plain `/` for floats.
#[derive(Debug, Clone, Copy, Default)]
pub struct Div;

/// Marker type for arithmetic addition that wraps on overflow.
#[derive(Debug, Clone, Copy, Default)]
pub struct WrappingAdd;
/// Marker type for arithmetic subtraction that wraps on overflow.
#[derive(Debug, Clone, Copy, Default)]
pub struct WrappingSub;
/// Marker type for arithmetic multiplication that wraps on overflow.
#[derive(Debug, Clone, Copy, Default)]
pub struct WrappingMul;

/// Marker type for arithmetic addition that saturates on overflow.
#[derive(Debug, Clone, Copy, Default)]
pub struct SaturatingAdd;
/// Marker type for arithmetic subtraction that saturates on overflow.
#[derive(Debug, Clone, Copy, Default)]
pub struct SaturatingSub;
/// Marker type for arithmetic multiplication that saturates on overflow.
#[derive(Debug, Clone, Copy, Default)]
pub struct SaturatingMul;
66
67impl<T: num_traits::CheckedAdd> CheckedOperator<T> for Add {
68    #[inline(always)]
69    fn apply(a: &T, b: &T) -> Option<T> {
70        num_traits::CheckedAdd::checked_add(a, b)
71    }
72}
73impl<T: num_traits::CheckedSub> CheckedOperator<T> for Sub {
74    #[inline(always)]
75    fn apply(a: &T, b: &T) -> Option<T> {
76        num_traits::CheckedSub::checked_sub(a, b)
77    }
78}
79impl<T: num_traits::CheckedMul> CheckedOperator<T> for Mul {
80    #[inline(always)]
81    fn apply(a: &T, b: &T) -> Option<T> {
82        num_traits::CheckedMul::checked_mul(a, b)
83    }
84}
85impl<T: num_traits::CheckedDiv> CheckedOperator<T> for Div {
86    #[inline(always)]
87    fn apply(a: &T, b: &T) -> Option<T> {
88        num_traits::CheckedDiv::checked_div(a, b)
89    }
90}
91
92impl<T: num_traits::WrappingAdd> Operator<T> for WrappingAdd {
93    #[inline(always)]
94    fn apply(a: &T, b: &T) -> T {
95        num_traits::WrappingAdd::wrapping_add(a, b)
96    }
97}
98impl<T: num_traits::WrappingSub> Operator<T> for WrappingSub {
99    #[inline(always)]
100    fn apply(a: &T, b: &T) -> T {
101        num_traits::WrappingSub::wrapping_sub(a, b)
102    }
103}
104impl<T: num_traits::WrappingMul> Operator<T> for WrappingMul {
105    #[inline(always)]
106    fn apply(a: &T, b: &T) -> T {
107        num_traits::WrappingMul::wrapping_mul(a, b)
108    }
109}
110
111impl<T: num_traits::SaturatingAdd> Operator<T> for SaturatingAdd {
112    #[inline(always)]
113    fn apply(a: &T, b: &T) -> T {
114        num_traits::SaturatingAdd::saturating_add(a, b)
115    }
116}
117impl<T: num_traits::SaturatingSub> Operator<T> for SaturatingSub {
118    #[inline(always)]
119    fn apply(a: &T, b: &T) -> T {
120        num_traits::SaturatingSub::saturating_sub(a, b)
121    }
122}
123impl<T: num_traits::SaturatingMul> Operator<T> for SaturatingMul {
124    #[inline(always)]
125    fn apply(a: &T, b: &T) -> T {
126        num_traits::SaturatingMul::saturating_mul(a, b)
127    }
128}
129
130/// Macro to implement arithmetic operators for floating-point types.
131///
132/// These are not deferred to the `std::ops::Add` since those implementations will panic on
133/// overflow in some cases (e.g., debug builds).
134macro_rules! impl_float {
135    ($T:ty) => {
136        impl Operator<$T> for Add {
137            #[inline(always)]
138            fn apply(a: &$T, b: &$T) -> $T {
139                a + b
140            }
141        }
142        impl Operator<$T> for Sub {
143            #[inline(always)]
144            fn apply(a: &$T, b: &$T) -> $T {
145                a - b
146            }
147        }
148        impl Operator<$T> for Mul {
149            #[inline(always)]
150            fn apply(a: &$T, b: &$T) -> $T {
151                a * b
152            }
153        }
154        impl Operator<$T> for Div {
155            #[inline(always)]
156            fn apply(a: &$T, b: &$T) -> $T {
157                a / b
158            }
159        }
160    };
161}
162
163impl_float!(f16);
164impl_float!(f32);
165impl_float!(f64);