//! Frontend unary operations (`cubecl_core/frontend/operation/unary.rs`).

use core::ops::Not;
use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, ue8m0};
use cubecl_ir::{Bitwise, Comparison, Operator};
use half::{bf16, f16};

use crate::{
    flex32,
    ir::{Arithmetic, ManagedVariable, Scope},
    prelude::{CubePrimitive, CubePrimitiveExpand, CubeType, NativeExpand, Reinterpret, Scalar},
    tf32, unexpanded,
};

use super::base::{unary_expand, unary_expand_fixed_output};

15pub mod not {
16    use super::*;
17
18    pub fn expand<T: CubeNot>(scope: &mut Scope, x: NativeExpand<T>) -> NativeExpand<T> {
19        if x.expand.ty.is_bool() {
20            unary_expand(scope, x.into(), Operator::Not).into()
21        } else {
22            unary_expand(scope, x.into(), Bitwise::BitwiseNot).into()
23        }
24    }
25}
26
27pub mod neg {
28    use super::*;
29
30    pub fn expand<E: CubePrimitive>(scope: &mut Scope, x: NativeExpand<E>) -> NativeExpand<E> {
31        unary_expand(scope, x.into(), Arithmetic::Neg).into()
32    }
33}
34
/// Generates a unary operation trait whose result has the same type as its input.
///
/// `impl_unary_func!(Name, method, Op, T1, T2, ...)` emits:
/// - `NameExpand`, implemented for `NativeExpand<T>` by lowering through `unary_expand`
///   with `Op`;
/// - `Name`, with a host-side `method` whose body is `unexpanded!()` and an expand-time
///   `__expand_method` that forwards to `NameExpand`;
/// - an empty `impl Name for Ti {}` for every listed type.
macro_rules! impl_unary_func {
    ($trait_ident:ident, $fn_ident:ident, $op:expr, $($target:ty),*) => {
        paste::paste! {
            pub trait [<$trait_ident Expand>] {
                fn [<__expand_ $fn_ident _method>](self, scope: &mut Scope) -> Self;
            }

            pub trait $trait_ident:
                CubePrimitive + CubeType<ExpandType: [<$trait_ident Expand>]> + Sized
            {
                // Host-side placeholder; only the expanded form does real work.
                #[allow(unused_variables)]
                fn $fn_ident(self) -> Self {
                    unexpanded!()
                }

                fn [<__expand_ $fn_ident>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self> {
                    x.[<__expand_ $fn_ident _method>](scope)
                }
            }

            impl<T: $trait_ident + CubePrimitive> [<$trait_ident Expand>] for NativeExpand<T> {
                fn [<__expand_ $fn_ident _method>](self, scope: &mut Scope) -> Self {
                    unary_expand(scope, self.into(), $op).into()
                }
            }

            $(impl $trait_ident for $target {})*
        }
    };
}
63impl Exp for f32 {
64    fn exp(self) -> Self {
65        self.exp()
66    }
67}
68
69pub trait Abs:
70    CubePrimitive
71    + CubeType<
72        ExpandType: AbsExpand<
73            AbsElem = Self::AbsElem,
74            AbsOut = NativeExpand<Self::WithScalar<Self::AbsElem>>,
75        >,
76    > + Sized
77{
78    type AbsElem: Scalar;
79
80    #[allow(unused_variables)]
81    fn abs(self) -> Self::WithScalar<Self::AbsElem> {
82        unexpanded!()
83    }
84
85    fn __expand_abs(
86        scope: &mut Scope,
87        x: NativeExpand<Self>,
88    ) -> NativeExpand<Self::WithScalar<Self::AbsElem>> {
89        x.__expand_abs_method(scope)
90    }
91}
92
/// Expand-side half of `Abs`: emits the absolute-value operation for an expanded value.
pub trait AbsExpand: CubePrimitiveExpand {
    /// Scalar element type of the absolute value.
    type AbsElem: Scalar;
    /// Expanded type returned by `__expand_abs_method`.
    type AbsOut;

    /// Emits the absolute-value operation into `scope` and returns its result.
    fn __expand_abs_method(self, scope: &mut Scope) -> Self::AbsOut;
}
100impl<T: Abs> AbsExpand for NativeExpand<T> {
101    type AbsElem = T::AbsElem;
102    type AbsOut = NativeExpand<T::WithScalar<T::AbsElem>>;
103
104    fn __expand_abs_method(self, scope: &mut Scope) -> Self::AbsOut {
105        let expand_element: ManagedVariable = self.into();
106        let item = <T::AbsElem as CubePrimitive>::as_type(scope)
107            .with_vector_size(expand_element.ty.vector_size());
108        unary_expand_fixed_output(scope, expand_element, item, Arithmetic::Abs).into()
109    }
110}
111
/// Implements `Abs` for each listed type with `AbsElem = Self`, so the absolute value keeps
/// the input's element type.
macro_rules! impl_abs_same_type {
    ($($elem:ty),*) => {
        $(
            impl Abs for $elem {
                type AbsElem = $elem;
            }
        )*
    };
}
/// Generates a unary operation trait whose result is the input's scalar element type.
///
/// `impl_unary_func_scalar_out!(Name, method, Op, T1, T2, ...)` emits:
/// - `Name`, with a host-side stub `method` and an expand-time `__expand_method`, both
///   producing `Self::Scalar`;
/// - `NameExpand`, implemented for `NativeExpand<T>`, which emits `Op` into the scope with an
///   explicit output type;
/// - an empty `impl Name for Ti {}` for every listed type.
macro_rules! impl_unary_func_scalar_out {
    ($trait_name:ident, $method_name:ident, $operator:expr, $($type:ty),*) => {
        paste::paste! {
            pub trait $trait_name: CubePrimitive
                + CubeType<ExpandType: [<$trait_name Expand>]
                + CubePrimitiveExpand<Scalar = NativeExpand<Self::Scalar>>>
                + Sized {
                // Host-side placeholder (body is `unexpanded!()`). The return type is
                // `Self::Scalar` so it agrees with `__expand_*` below, which yields
                // `NativeExpand<Self::Scalar>`; it previously returned `Self`, which disagreed
                // with the expanded form whenever `Self` is not its own scalar type.
                #[allow(unused_variables)]
                fn $method_name(self) -> Self::Scalar {
                    unexpanded!()
                }

                fn [<__expand_ $method_name>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::Scalar> {
                    x.[<__expand_ $method_name _method>](scope)
                }
            }

            pub trait [<$trait_name Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar;
            }

            $(impl $trait_name for $type {})*
            impl<T: $trait_name + CubePrimitive> [<$trait_name Expand>] for NativeExpand<T> {
                fn [<__expand_ $method_name _method>](self, scope: &mut Scope) -> Self::Scalar {
                    let expand_element: ManagedVariable = self.into();
                    // NOTE(review): assumes a vector size of 0 encodes a scalar type in the
                    // IR — confirm against the type definitions.
                    let item = expand_element.ty.with_vector_size(0);
                    unary_expand_fixed_output(scope, expand_element, item, $operator).into()
                }
            }
        }
    }
}
/// Generates a unary operation trait whose result element type is fixed (e.g. `u32`, `bool`).
///
/// `impl_unary_func_fixed_out_ty!(Name, method, Out, Op, T1, T2, ...)` emits `NameExpand`,
/// `Name` (host stub `method` plus `__expand_method`, both producing `Self::WithScalar<Out>`),
/// the expand-side impl for `NativeExpand<T>`, and an empty `impl Name for Ti {}` per type.
/// The emitted operation's output type is `Out` with the input's vector size.
macro_rules! impl_unary_func_fixed_out_ty {
    ($trait_ident:ident, $fn_ident:ident, $out:ty, $op:expr, $($target:ty),*) => {
        paste::paste! {
            pub trait [<$trait_ident Expand>]: CubePrimitiveExpand {
                fn [<__expand_ $fn_ident _method>](self, scope: &mut Scope) -> Self::WithScalar<$out>;
            }

            pub trait $trait_ident:
                CubePrimitive
                + CubeType<
                    ExpandType: [<$trait_ident Expand>]
                        + CubePrimitiveExpand<WithScalar<$out> = NativeExpand<Self::WithScalar<$out>>>
                > + Sized
            {
                // Host-side placeholder; only the expanded form does real work.
                #[allow(unused_variables, clippy::wrong_self_convention)]
                fn $fn_ident(self) -> Self::WithScalar<$out> {
                    unexpanded!()
                }

                fn [<__expand_ $fn_ident>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self::WithScalar<$out>> {
                    x.[<__expand_ $fn_ident _method>](scope)
                }
            }

            impl<T: $trait_ident + CubePrimitive> [<$trait_ident Expand>] for NativeExpand<T> {
                fn [<__expand_ $fn_ident _method>](self, scope: &mut Scope) -> Self::WithScalar<$out> {
                    let input: ManagedVariable = self.into();
                    let vector_size = input.ty.vector_size();
                    let out_ty = <$out as CubePrimitive>::as_type(scope).with_vector_size(vector_size);
                    unary_expand_fixed_output(scope, input, out_ty, $op).into()
                }
            }

            $(impl $trait_ident for $target {})*
        }
    };
}
/// Generates `Cube<Trait>` and `<Trait>Expand` for the `!` operator.
///
/// This needs special handling because Rust uses the single `core::ops::Not` trait for both
/// logical NOT (on `bool`) and bitwise NOT (on integers); the expanded form therefore goes
/// through `not::expand`, which picks the operation from the operand type.
macro_rules! impl_not {
    ($base:ident, $op_fn:ident, $($target:ty),*) => {
        paste::paste! {
            pub trait [<$base Expand>] {
                fn [<__expand_ $op_fn _method>](self, scope: &mut Scope) -> Self;
            }

            pub trait [<Cube $base>]:
                $base<Output = Self> + CubePrimitive + CubeType<ExpandType: [<$base Expand>]>
            {
                fn [<__expand_ $op_fn>](scope: &mut Scope, x: NativeExpand<Self>) -> NativeExpand<Self> {
                    x.[<__expand_ $op_fn _method>](scope)
                }
            }

            impl<T: [<Cube $base>] + CubePrimitive> [<$base Expand>] for NativeExpand<T> {
                fn [<__expand_ $op_fn _method>](self, scope: &mut Scope) -> Self {
                    not::expand(scope, self.into())
                }
            }

            $(impl [<Cube $base>] for $target {})*
        }
    };
}
208impl_not!(
209    Not, not, bool, u8, u16, u32, u64, i8, i16, i32, i64, isize, usize
210);
211
212impl_abs_same_type!(
213    e2m1, e4m3, e5m2, ue8m0, f16, bf16, flex32, tf32, f32, f64, i8, i16, i32, i64, u8, u16, u32,
214    u64, usize, isize
215);
216impl_unary_func!(
217    Exp,
218    exp,
219    Arithmetic::Exp,
220    f16,
221    bf16,
222    flex32,
223    tf32,
224    // f32,
225    f64,
226    num_complex::Complex<f32>,
227    num_complex::Complex<f64>
228);
229impl_unary_func!(
230    Log,
231    ln,
232    Arithmetic::Log,
233    f16,
234    bf16,
235    flex32,
236    tf32,
237    f32,
238    f64,
239    num_complex::Complex<f32>,
240    num_complex::Complex<f64>
241);
242impl_unary_func!(
243    Log1p,
244    log1p,
245    Arithmetic::Log1p,
246    f16,
247    bf16,
248    flex32,
249    tf32,
250    f32,
251    f64
252);
253impl_unary_func!(
254    Expm1,
255    exp_m1,
256    Arithmetic::Expm1,
257    f16,
258    bf16,
259    flex32,
260    tf32,
261    f32,
262    f64
263);
264impl_unary_func!(
265    Cos,
266    cos,
267    Arithmetic::Cos,
268    f16,
269    bf16,
270    flex32,
271    tf32,
272    f32,
273    f64,
274    num_complex::Complex<f32>,
275    num_complex::Complex<f64>
276);
277impl_unary_func!(
278    Sin,
279    sin,
280    Arithmetic::Sin,
281    f16,
282    bf16,
283    flex32,
284    tf32,
285    f32,
286    f64,
287    num_complex::Complex<f32>,
288    num_complex::Complex<f64>
289);
290impl_unary_func!(Tan, tan, Arithmetic::Tan, f16, bf16, flex32, tf32, f32, f64);
291impl_unary_func!(
292    Tanh,
293    tanh,
294    Arithmetic::Tanh,
295    f16,
296    bf16,
297    flex32,
298    tf32,
299    f32,
300    f64,
301    num_complex::Complex<f32>,
302    num_complex::Complex<f64>
303);
304impl_unary_func!(
305    Sinh,
306    sinh,
307    Arithmetic::Sinh,
308    f16,
309    bf16,
310    flex32,
311    tf32,
312    f32,
313    f64
314);
315impl_unary_func!(
316    Cosh,
317    cosh,
318    Arithmetic::Cosh,
319    f16,
320    bf16,
321    flex32,
322    tf32,
323    f32,
324    f64
325);
326impl_unary_func!(
327    ArcCos,
328    acos,
329    Arithmetic::ArcCos,
330    f16,
331    bf16,
332    flex32,
333    tf32,
334    f32,
335    f64
336);
337impl_unary_func!(
338    ArcSin,
339    asin,
340    Arithmetic::ArcSin,
341    f16,
342    bf16,
343    flex32,
344    tf32,
345    f32,
346    f64
347);
348impl_unary_func!(
349    ArcTan,
350    atan,
351    Arithmetic::ArcTan,
352    f16,
353    bf16,
354    flex32,
355    tf32,
356    f32,
357    f64
358);
359impl_unary_func!(
360    ArcSinh,
361    asinh,
362    Arithmetic::ArcSinh,
363    f16,
364    bf16,
365    flex32,
366    tf32,
367    f32,
368    f64
369);
370impl_unary_func!(
371    ArcCosh,
372    acosh,
373    Arithmetic::ArcCosh,
374    f16,
375    bf16,
376    flex32,
377    tf32,
378    f32,
379    f64
380);
381impl_unary_func!(
382    ArcTanh,
383    atanh,
384    Arithmetic::ArcTanh,
385    f16,
386    bf16,
387    flex32,
388    tf32,
389    f32,
390    f64
391);
392impl_unary_func!(
393    Degrees,
394    to_degrees,
395    Arithmetic::Degrees,
396    f16,
397    bf16,
398    flex32,
399    tf32,
400    f32,
401    f64
402);
403impl_unary_func!(
404    Radians,
405    to_radians,
406    Arithmetic::Radians,
407    f16,
408    bf16,
409    flex32,
410    tf32,
411    f32,
412    f64
413);
414impl_unary_func!(
415    Sqrt,
416    sqrt,
417    Arithmetic::Sqrt,
418    f16,
419    bf16,
420    flex32,
421    tf32,
422    f32,
423    f64,
424    num_complex::Complex<f32>,
425    num_complex::Complex<f64>
426);
427impl_unary_func!(
428    InverseSqrt,
429    inverse_sqrt,
430    Arithmetic::InverseSqrt,
431    f16,
432    bf16,
433    flex32,
434    tf32,
435    f32,
436    f64
437);
438impl_unary_func!(
439    Round,
440    round,
441    Arithmetic::Round,
442    f16,
443    bf16,
444    flex32,
445    tf32,
446    f32,
447    f64
448);
449impl_unary_func!(
450    Floor,
451    floor,
452    Arithmetic::Floor,
453    f16,
454    bf16,
455    flex32,
456    tf32,
457    f32,
458    f64
459);
460impl_unary_func!(
461    Ceil,
462    ceil,
463    Arithmetic::Ceil,
464    f16,
465    bf16,
466    flex32,
467    tf32,
468    f32,
469    f64
470);
471impl_unary_func!(
472    Trunc,
473    trunc,
474    Arithmetic::Trunc,
475    f16,
476    bf16,
477    flex32,
478    tf32,
479    f32,
480    f64
481);
482impl_unary_func!(Erf, erf, Arithmetic::Erf, f16, bf16, flex32, tf32, f32, f64);
483impl_unary_func!(
484    Recip,
485    recip,
486    Arithmetic::Recip,
487    f16,
488    bf16,
489    flex32,
490    tf32,
491    f32,
492    f64
493);
494impl_unary_func_scalar_out!(
495    Magnitude,
496    magnitude,
497    Arithmetic::Magnitude,
498    f16,
499    bf16,
500    flex32,
501    tf32,
502    f32,
503    f64
504);
505impl_unary_func_scalar_out!(
506    VectorSum,
507    vector_sum,
508    Arithmetic::VectorSum,
509    e2m1,
510    e4m3,
511    e5m2,
512    ue8m0,
513    f16,
514    bf16,
515    flex32,
516    tf32,
517    f32,
518    f64,
519    i8,
520    i16,
521    i32,
522    i64,
523    u8,
524    u16,
525    u32,
526    u64,
527    usize,
528    isize
529);
530impl_unary_func!(
531    Normalize,
532    normalize,
533    Arithmetic::Normalize,
534    f16,
535    bf16,
536    flex32,
537    tf32,
538    f32,
539    f64
540);
541impl_unary_func_fixed_out_ty!(
542    CountOnes,
543    count_ones,
544    u32,
545    Bitwise::CountOnes,
546    u8,
547    i8,
548    u16,
549    i16,
550    u32,
551    i32,
552    u64,
553    i64,
554    usize,
555    isize
556);
557impl_unary_func!(
558    ReverseBits,
559    reverse_bits,
560    Bitwise::ReverseBits,
561    u8,
562    i8,
563    u16,
564    i16,
565    u32,
566    i32,
567    u64,
568    i64,
569    usize,
570    isize
571);
572
573impl_unary_func_fixed_out_ty!(
574    LeadingZeros,
575    leading_zeros,
576    u32,
577    Bitwise::LeadingZeros,
578    u8,
579    i8,
580    u16,
581    i16,
582    u32,
583    i32,
584    u64,
585    i64,
586    usize,
587    isize
588);
589impl_unary_func_fixed_out_ty!(
590    TrailingZeros,
591    trailing_zeros,
592    u32,
593    Bitwise::TrailingZeros,
594    u8,
595    i8,
596    u16,
597    i16,
598    u32,
599    i32,
600    u64,
601    i64,
602    usize,
603    isize
604);
605impl_unary_func_fixed_out_ty!(
606    FindFirstSet,
607    find_first_set,
608    u32,
609    Bitwise::FindFirstSet,
610    u8,
611    i8,
612    u16,
613    i16,
614    u32,
615    i32,
616    u64,
617    i64,
618    usize,
619    isize
620);
621impl_unary_func_fixed_out_ty!(
622    IsNan,
623    is_nan,
624    bool,
625    Comparison::IsNan,
626    f16,
627    bf16,
628    flex32,
629    tf32,
630    f32,
631    f64
632);
633impl_unary_func_fixed_out_ty!(
634    IsInf,
635    is_inf,
636    bool,
637    Comparison::IsInf,
638    f16,
639    bf16,
640    flex32,
641    tf32,
642    f32,
643    f64
644);
645
646pub trait FloatBits:
647    CubePrimitive + CubeType<ExpandType: FloatBitsExpand<Bits = Self::Bits>>
648{
649    type Bits: CubePrimitive;
650
651    fn __expand_from_bits(scope: &mut Scope, bits: NativeExpand<Self::Bits>) -> NativeExpand<Self> {
652        Self::__expand_reinterpret(scope, bits)
653    }
654
655    fn __expand_to_bits(scope: &mut Scope, this: NativeExpand<Self>) -> NativeExpand<Self::Bits> {
656        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, this)
657    }
658}
659
/// Method-call form of `FloatBits`, implemented for expanded values.
pub trait FloatBitsExpand: Sized {
    /// Integer type holding the raw bits of the underlying float type.
    type Bits: CubePrimitive;

    /// Reinterprets `self` as `Self::Bits`, emitting the operation into `scope`.
    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits>;
}
666impl<F: FloatBits> FloatBitsExpand for NativeExpand<F> {
667    type Bits = F::Bits;
668
669    fn __expand_to_bits_method(self, scope: &mut Scope) -> NativeExpand<Self::Bits> {
670        <Self::Bits as Reinterpret>::__expand_reinterpret(scope, self)
671    }
672}
673
674impl FloatBits for e2m1x2 {
675    type Bits = u8;
676}
677
678impl FloatBits for e5m2 {
679    type Bits = u8;
680}
681
682impl FloatBits for e4m3 {
683    type Bits = u8;
684}
685
686impl FloatBits for f16 {
687    type Bits = u16;
688}
689
690impl FloatBits for bf16 {
691    type Bits = u16;
692}
693
694impl FloatBits for f32 {
695    type Bits = u32;
696}
697
698impl FloatBits for f64 {
699    type Bits = u64;
700}