Skip to main content

cubecl_core/frontend/container/vector/
ops.rs

1use core::{marker::PhantomData, ops::Not};
2use cubecl_ir::{Bitwise, ConstantValue, ElemType, Instruction, Type, UIntKind, UnaryOperator};
3use cubecl_macros::{cube, intrinsic};
4use num_traits::{NumCast, One, ToPrimitive, Zero};
5
6use crate::{
7    self as cubecl,
8    prelude::{
9        ArcTan2, InverseSqrt, IsInf, IsNan, Powf, Powi, SaturatingAdd, SaturatingSub, Trunc,
10    },
11};
12use crate::{prelude::*, unexpanded};
13
14use super::Vector;
/// Shorthand for the [`NativeExpand`] wrapper around a [`Vector`] with element
/// type `E` and size marker `N`.
type VectorExpand<E, N> = NativeExpand<Vector<E, N>>;
16
17impl<P, N: Size> core::ops::Add<Self> for Vector<P, N>
18where
19    P: Scalar,
20    P: core::ops::Add<P, Output = P>,
21{
22    type Output = Self;
23
24    fn add(self, rhs: Self) -> Self::Output {
25        Self::new(self.val + rhs.val)
26    }
27}
28
29impl<P, N: Size> core::ops::Sub<Self> for Vector<P, N>
30where
31    P: Scalar,
32    P: core::ops::Sub<P, Output = P>,
33{
34    type Output = Self;
35
36    fn sub(self, rhs: Self) -> Self::Output {
37        Self::new(self.val - rhs.val)
38    }
39}
40
41impl<P, N: Size> core::ops::Mul<Self> for Vector<P, N>
42where
43    P: Scalar,
44    P: core::ops::Mul<P, Output = P>,
45{
46    type Output = Self;
47
48    fn mul(self, rhs: Self) -> Self::Output {
49        Self::new(self.val * rhs.val)
50    }
51}
52
53impl<P, N: Size> core::ops::Div<Self> for Vector<P, N>
54where
55    P: Scalar,
56    P: core::ops::Div<P, Output = P>,
57{
58    type Output = Self;
59
60    fn div(self, rhs: Self) -> Self::Output {
61        Self::new(self.val / rhs.val)
62    }
63}
64
65impl<P, N: Size> core::ops::AddAssign<Self> for Vector<P, N>
66where
67    P: Scalar,
68    P: core::ops::AddAssign,
69{
70    fn add_assign(&mut self, rhs: Self) {
71        self.val += rhs.val;
72    }
73}
74
75impl<P, N: Size> core::ops::SubAssign<Self> for Vector<P, N>
76where
77    P: Scalar,
78    P: core::ops::SubAssign,
79{
80    fn sub_assign(&mut self, rhs: Self) {
81        self.val -= rhs.val;
82    }
83}
84
85impl<P, N: Size> core::ops::DivAssign<Self> for Vector<P, N>
86where
87    P: Scalar,
88    P: core::ops::DivAssign,
89{
90    fn div_assign(&mut self, rhs: Self) {
91        self.val /= rhs.val;
92    }
93}
94
95impl<P, N: Size> core::ops::MulAssign<Self> for Vector<P, N>
96where
97    P: Scalar,
98    P: core::ops::MulAssign,
99{
100    fn mul_assign(&mut self, rhs: Self) {
101        self.val *= rhs.val;
102    }
103}
104
105impl<P, N: Size> core::cmp::PartialEq for Vector<P, N>
106where
107    P: Scalar,
108    P: core::cmp::PartialEq,
109{
110    fn eq(&self, other: &Self) -> bool {
111        self.val.eq(&other.val)
112    }
113}
114
115impl<P, N: Size> core::cmp::PartialOrd for Vector<P, N>
116where
117    P: Scalar,
118    P: core::cmp::PartialOrd,
119{
120    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
121        self.val.partial_cmp(&other.val)
122    }
123}
124
125impl<P, N: Size> core::ops::BitAnd<Self> for Vector<P, N>
126where
127    P: Scalar,
128    P: core::ops::BitAnd<P, Output = P>,
129{
130    type Output = Self;
131
132    fn bitand(self, rhs: Self) -> Self::Output {
133        Self::new(self.val & rhs.val)
134    }
135}
136
137impl<P, N: Size> core::ops::BitOr<Self> for Vector<P, N>
138where
139    P: Scalar,
140    P: core::ops::BitOr<P, Output = P>,
141{
142    type Output = Self;
143
144    fn bitor(self, rhs: Self) -> Self::Output {
145        Self::new(self.val | rhs.val)
146    }
147}
148
149impl<P, N: Size> core::ops::BitXor<Self> for Vector<P, N>
150where
151    P: Scalar,
152    P: core::ops::BitXor<P, Output = P>,
153{
154    type Output = Self;
155
156    fn bitxor(self, rhs: Self) -> Self::Output {
157        Self::new(self.val ^ rhs.val)
158    }
159}
160
161impl<P, N: Size> core::ops::Shl<Self> for Vector<P, N>
162where
163    P: Scalar,
164    P: core::ops::Shl<P, Output = P>,
165{
166    type Output = Self;
167
168    fn shl(self, rhs: Self) -> Self::Output {
169        Self::new(self.val << rhs.val)
170    }
171}
172
173impl<P, N: Size> core::ops::Shr<Self> for Vector<P, N>
174where
175    P: Scalar,
176    P: core::ops::Shr<P, Output = P>,
177{
178    type Output = Self;
179
180    fn shr(self, rhs: Self) -> Self::Output {
181        Self::new(self.val >> rhs.val)
182    }
183}
184
185impl<P, N: Size> core::ops::BitAndAssign<Self> for Vector<P, N>
186where
187    P: Scalar,
188    P: core::ops::BitAndAssign,
189{
190    fn bitand_assign(&mut self, rhs: Self) {
191        self.val &= rhs.val;
192    }
193}
194
195impl<P, N: Size> core::ops::BitOrAssign<Self> for Vector<P, N>
196where
197    P: Scalar,
198    P: core::ops::BitOrAssign,
199{
200    fn bitor_assign(&mut self, rhs: Self) {
201        self.val |= rhs.val;
202    }
203}
204
205impl<P, N: Size> core::ops::BitXorAssign<Self> for Vector<P, N>
206where
207    P: Scalar,
208    P: core::ops::BitXorAssign,
209{
210    fn bitxor_assign(&mut self, rhs: Self) {
211        self.val ^= rhs.val;
212    }
213}
214
215impl<P, N: Size> core::ops::ShlAssign<Self> for Vector<P, N>
216where
217    P: Scalar,
218    P: core::ops::ShlAssign,
219{
220    fn shl_assign(&mut self, rhs: Self) {
221        self.val <<= rhs.val;
222    }
223}
224
225impl<P, N: Size> core::ops::ShrAssign<Self> for Vector<P, N>
226where
227    P: Scalar,
228    P: core::ops::ShrAssign,
229{
230    fn shr_assign(&mut self, rhs: Self) {
231        self.val >>= rhs.val;
232    }
233}
234
235impl<P: Scalar + Abs, N: Size> Abs for Vector<P, N> {
236    type AbsElem = P::AbsElem;
237}
238impl<P: Scalar + Log, N: Size> Log for Vector<P, N> {}
239impl<P: Scalar + Log1p, N: Size> Log1p for Vector<P, N> {}
240impl<P: Scalar + Expm1, N: Size> Expm1 for Vector<P, N> {}
241impl<P: Scalar + Erf, N: Size> Erf for Vector<P, N> {}
242impl<P: Scalar + Exp, N: Size> Exp for Vector<P, N> {}
243impl<P: Scalar + Powf, N: Size> Powf for Vector<P, N> {}
244impl<P: Scalar + Powi<I>, I: Scalar, N: Size> Powi<Vector<I, N>> for Vector<P, N> {}
245impl<P: Scalar + Sqrt, N: Size> Sqrt for Vector<P, N> {}
246impl<P: Scalar + InverseSqrt, N: Size> InverseSqrt for Vector<P, N> {}
247impl<P: Scalar + Cos, N: Size> Cos for Vector<P, N> {}
248impl<P: Scalar + Sin, N: Size> Sin for Vector<P, N> {}
249impl<P: Scalar + Tan, N: Size> Tan for Vector<P, N> {}
250impl<P: Scalar + Tanh, N: Size> Tanh for Vector<P, N> {}
251impl<P: Scalar + Sinh, N: Size> Sinh for Vector<P, N> {}
252impl<P: Scalar + Cosh, N: Size> Cosh for Vector<P, N> {}
253impl<P: Scalar + ArcSin, N: Size> ArcSin for Vector<P, N> {}
254impl<P: Scalar + ArcCos, N: Size> ArcCos for Vector<P, N> {}
255impl<P: Scalar + ArcTan, N: Size> ArcTan for Vector<P, N> {}
256impl<P: Scalar + ArcSinh, N: Size> ArcSinh for Vector<P, N> {}
257impl<P: Scalar + ArcCosh, N: Size> ArcCosh for Vector<P, N> {}
258impl<P: Scalar + ArcTanh, N: Size> ArcTanh for Vector<P, N> {}
259impl<P: Scalar + ArcTan2, N: Size> ArcTan2 for Vector<P, N> {}
260impl<P: Scalar + Recip, N: Size> Recip for Vector<P, N> {}
261impl<P: Scalar + Remainder, N: Size> Remainder for Vector<P, N> {}
262impl<P: Scalar + Round, N: Size> Round for Vector<P, N> {}
263impl<P: Scalar + Floor, N: Size> Floor for Vector<P, N> {}
264impl<P: Scalar + Ceil, N: Size> Ceil for Vector<P, N> {}
265impl<P: Scalar + Trunc, N: Size> Trunc for Vector<P, N> {}
266impl<P: Scalar + ReverseBits, N: Size> ReverseBits for Vector<P, N> {}
267impl<P: Scalar + CubeNot, N: Size> CubeNot for Vector<P, N> {}
268impl<P: Scalar + SaturatingAdd, N: Size> SaturatingAdd for Vector<P, N> {}
269impl<P: Scalar + SaturatingSub, N: Size> SaturatingSub for Vector<P, N> {}
270impl<P: Scalar + IsNan, N: Size> IsNan for Vector<P, N> {}
271impl<P: Scalar + IsInf, N: Size> IsInf for Vector<P, N> {}
272impl<P: Scalar + Normalize, N: Size> Normalize for Vector<P, N> {}
273impl<P: Scalar + Magnitude, N: Size> Magnitude for Vector<P, N> {}
274impl<P: Scalar + VectorSum, N: Size> VectorSum for Vector<P, N> {}
275impl<P: Scalar + Degrees, N: Size> Degrees for Vector<P, N> {}
276impl<P: Scalar + Radians, N: Size> Radians for Vector<P, N> {}
277
278impl<P: Scalar + Ord, N: Size> Ord for Vector<P, N> {
279    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
280        self.val.cmp(&other.val)
281    }
282}
283
284#[cube]
285impl<P: CountOnes + Scalar, N: Size> Vector<P, N> {
286    pub fn count_ones(self) -> Vector<u32, N> {
287        intrinsic!(|scope| {
288            let out_item = Type::scalar(ElemType::UInt(UIntKind::U32))
289                .with_vector_size(self.expand.ty.vector_size());
290            let out = scope.create_local(out_item);
291            scope.register(Instruction::new(
292                Bitwise::CountOnes(UnaryOperator {
293                    input: *self.expand,
294                }),
295                *out,
296            ));
297            out.into()
298        })
299    }
300}
301
302impl<P: LeadingZeros + Scalar, N: Size> LeadingZeros for Vector<P, N> {}
303impl<P: FindFirstSet + Scalar, N: Size> FindFirstSet for Vector<P, N> {}
304impl<P: TrailingZeros + Scalar, N: Size> TrailingZeros for Vector<P, N> {}
305
306impl<P: Scalar + NumCast, N: Size> NumCast for Vector<P, N> {
307    fn from<T: num_traits::ToPrimitive>(n: T) -> Option<Self> {
308        let val: P = NumCast::from(n)?;
309        Some(Self {
310            val,
311            _size: PhantomData,
312        })
313    }
314}
315impl<P: Scalar + NumCast, N: Size> ToPrimitive for Vector<P, N> {
316    fn to_i64(&self) -> Option<i64> {
317        self.val.to_i64()
318    }
319
320    fn to_u64(&self) -> Option<u64> {
321        self.val.to_u64()
322    }
323}
324
325impl<P: Not<Output = P> + Scalar, N: Size> Not for Vector<P, N> {
326    type Output = Self;
327
328    fn not(self) -> Self::Output {
329        Vector::new(self.val.not())
330    }
331}
332
333#[allow(clippy::from_over_into)]
334impl<P: Scalar + Into<NativeExpand<P>>, N: Size> Into<NativeExpand<Self>> for Vector<P, N> {
335    fn into(self) -> NativeExpand<Self> {
336        let elem: NativeExpand<P> = self.val.into();
337        elem.expand.into()
338    }
339}
340
341impl<T: Scalar + Default, N: Size> Default for Vector<T, N> {
342    fn default() -> Self {
343        Self::new(T::default())
344    }
345}
346
347impl<T: Scalar + IntoRuntime, N: Size> IntoRuntime for Vector<T, N> {
348    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
349        let val = self.val.__expand_runtime_method(scope);
350        Self::__expand_new(scope, val)
351    }
352}
353
354impl<T: Scalar + Into<ConstantValue>, N: Size> From<Vector<T, N>> for ConstantValue {
355    fn from(value: Vector<T, N>) -> Self {
356        value.val.into()
357    }
358}
359
360impl<T: Scalar + Zero, N: Size> Zero for Vector<T, N> {
361    fn zero() -> Self {
362        Self::new(T::zero())
363    }
364
365    fn is_zero(&self) -> bool {
366        self.val.is_zero()
367    }
368}
369
370impl<T: Scalar + One, N: Size> One for Vector<T, N> {
371    fn one() -> Self {
372        Self::new(T::one())
373    }
374}
375
376macro_rules! operation_literal {
377    ($lit:ty) => {
378        impl<P, N: Size> core::ops::Add<$lit> for Vector<P, N>
379        where
380            P: Scalar,
381        {
382            type Output = Self;
383
384            fn add(self, _rhs: $lit) -> Self::Output {
385                unexpanded!();
386            }
387        }
388
389        impl<P, N: Size> core::ops::Sub<$lit> for Vector<P, N>
390        where
391            P: Scalar,
392        {
393            type Output = Self;
394
395            fn sub(self, _rhs: $lit) -> Self::Output {
396                unexpanded!();
397            }
398        }
399
400        impl<P, N: Size> core::ops::Mul<$lit> for Vector<P, N>
401        where
402            P: Scalar,
403        {
404            type Output = Self;
405
406            fn mul(self, _rhs: $lit) -> Self::Output {
407                unexpanded!();
408            }
409        }
410
411        impl<P, N: Size> core::ops::Div<$lit> for Vector<P, N>
412        where
413            P: Scalar,
414        {
415            type Output = Self;
416
417            fn div(self, _rhs: $lit) -> Self::Output {
418                unexpanded!();
419            }
420        }
421    };
422}
423
424operation_literal!(f32);
425operation_literal!(f64);
426operation_literal!(usize);
427operation_literal!(i32);
428operation_literal!(i64);