Skip to main content

cubecl_core/frontend/element/float.rs

1use cubecl_ir::{ConstantValue, Scope, StorageType, Type};
2use half::{bf16, f16};
3
4use crate::{
5    self as cubecl,
6    ir::{ElemType, FloatKind},
7    prelude::*,
8};
9
10use super::Numeric;
11
12mod fp4;
13mod fp6;
14mod fp8;
15mod relaxed;
16mod tensor_float;
17
/// Floating-point element types, used as input in float kernels.
///
/// Bundles all of the float operation traits listed below with host-side negation
/// (`core::ops::Neg`) and comparison (`PartialOrd` / `PartialEq`).
///
/// The associated constants share their names with the constants of Rust's primitive
/// float types (`f32::EPSILON`, `f64::MAX_EXP`, ...); the implementations in this file
/// forward each one to the implementing type's constant of the same name.
pub trait Float:
    Numeric
    + FloatOps
    + Exp
    + Log
    + Log1p
    + Expm1
    + Cos
    + Sin
    + Tan
    + Tanh
    + Sinh
    + Cosh
    + ArcCos
    + ArcSin
    + ArcTan
    + ArcSinh
    + ArcCosh
    + ArcTanh
    + Degrees
    + Radians
    + ArcTan2
    + Powf
    + Powi<i32>
    + Hypot
    + Rhypot
    + Sqrt
    + InverseSqrt
    + Round
    + Floor
    + Ceil
    + Trunc
    + Erf
    + Recip
    + Magnitude
    + Normalize
    + Dot
    + IsNan
    + IsInf
    + Into<Self::ExpandType>
    + core::ops::Neg<Output = Self>
    + core::cmp::PartialOrd
    + core::cmp::PartialEq
{
    /// Approximate number of significant decimal digits (cf. `f32::DIGITS`).
    const DIGITS: u32;
    /// Machine epsilon (cf. `f32::EPSILON`).
    const EPSILON: Self;
    /// Positive infinity (cf. `f32::INFINITY`).
    const INFINITY: Self;
    /// Number of significant binary digits (cf. `f32::MANTISSA_DIGITS`).
    const MANTISSA_DIGITS: u32;
    /// Maximum decimal exponent (cf. `f32::MAX_10_EXP`).
    const MAX_10_EXP: i32;
    /// Maximum binary exponent (cf. `f32::MAX_EXP`).
    const MAX_EXP: i32;
    /// Minimum decimal exponent (cf. `f32::MIN_10_EXP`).
    const MIN_10_EXP: i32;
    /// Minimum binary exponent (cf. `f32::MIN_EXP`).
    const MIN_EXP: i32;
    /// Smallest positive normal value (cf. `f32::MIN_POSITIVE`).
    const MIN_POSITIVE: Self;
    /// Not-a-Number (cf. `f32::NAN`).
    const NAN: Self;
    /// Negative infinity (cf. `f32::NEG_INFINITY`).
    const NEG_INFINITY: Self;
    /// Radix of the internal representation (cf. `f32::RADIX`).
    const RADIX: u32;

    /// Creates a value of this float type from an `f32`.
    ///
    /// For types narrower than `f32` (e.g. `f16`, `bf16`) the value may be rounded.
    fn new(val: f32) -> Self;

    /// Expand-side counterpart of [`Float::new`].
    ///
    /// Delegates to the free function `__expand_new` brought into scope by this file's
    /// imports (it is not defined here). Presumably invoked by `#[cube]`-generated code
    /// when `Float::new` is called inside a kernel — TODO confirm.
    fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
        __expand_new(scope, val)
    }
}
81
/// Ordering helpers (`min`, `max`, `clamp`) for cube primitives that support `PartialOrd`.
///
/// Every default method forwards to the free function of the same name. The companion
/// `FloatOpsExpand` trait used elsewhere in this file is presumably generated by `#[cube]`.
#[cube]
pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
    /// Minimum of `self` and `other`, computed by `cubecl::prelude::min`.
    fn min(self, other: Self) -> Self {
        // Fully qualified, presumably to make clear this is the free function and not
        // this method.
        cubecl::prelude::min(self, other)
    }

    /// Maximum of `self` and `other`, computed by `cubecl::prelude::max`.
    fn max(self, other: Self) -> Self {
        cubecl::prelude::max(self, other)
    }

    /// Clamps `self` between `min` and `max` using the free `clamp` function.
    fn clamp(self, min: Self, max: Self) -> Self {
        clamp(self, min, max)
    }
}
96
97impl<T: Float> FloatOps for T {}
98impl<T: FloatOps + CubePrimitive> FloatOpsExpand for NativeExpand<T> {
99    fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
100        min::expand(scope, self, other)
101    }
102
103    fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
104        max::expand(scope, self, other)
105    }
106
107    fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
108        clamp::expand(scope, self, min, max)
109    }
110}
111
// Implements the cubecl element traits (`CubeType`, `Scalar`, `CubePrimitive`,
// `IntoRuntime`, `Numeric`, `NativeAssign`, `IntoMut`, `Float`) for one host float type.
//
// Invocation forms:
// - `impl_float!(half T, Kind)`: for `half`-crate types; values are built with `T::from_f64`.
// - `impl_float!(T, Kind)`: for Rust primitive floats; values are built with an `as` cast.
// - `impl_float!(T, Kind, new)`: the shared body the two forms above forward to. `new` is an
//   expression called with a float value that returns `T`. It is pasted in textually at
//   each use, so every call site gets its own closure instance.
//
// `Kind` is the `FloatKind` variant that identifies `T` in the IR.
macro_rules! impl_float {
    (half $primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| $primitive::from_f64(val));
    };
    ($primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |val| val as $primitive);
    };
    ($primitive:ident, $kind:ident, $new:expr) => {
        impl CubeType for $primitive {
            type ExpandType = NativeExpand<$primitive>;
        }

        impl Scalar for $primitive {}
        impl CubePrimitive for $primitive {
            type Scalar = Self;
            type Size = Const<1>;
            type WithScalar<S: Scalar> = S;

            /// Return the element type to use on GPU: a scalar of the matching `FloatKind`.
            fn as_type_native() -> Option<Type> {
                Some(StorageType::Scalar(ElemType::Float(FloatKind::$kind)).into())
            }

            /// Rebuilds a host value from an IR constant.
            ///
            /// # Panics
            ///
            /// If `value` is not a `ConstantValue::Float`.
            fn from_const_value(value: ConstantValue) -> Self {
                let ConstantValue::Float(value) = value else {
                    // Name the target type so a mismatched constant is diagnosable.
                    unreachable!(
                        "expected `ConstantValue::Float` for `{}`",
                        stringify!($primitive)
                    )
                };
                $new(value)
            }
        }

        impl IntoRuntime for $primitive {
            // Converts the host value into its expand type via `Into`; the scope is unused.
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl Numeric for $primitive {
            // Smallest finite value, as reported by `num_traits::Float::min_value`.
            fn min_value() -> Self {
                <Self as num_traits::Float>::min_value()
            }
            // Largest finite value, as reported by `num_traits::Float::max_value`.
            fn max_value() -> Self {
                <Self as num_traits::Float>::max_value()
            }
        }

        impl NativeAssign for $primitive {}

        impl IntoMut for $primitive {
            // Ignores the scope and returns the value unchanged.
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl Float for $primitive {
            // Each constant forwards to the type's own constant of the same name.
            const DIGITS: u32 = $primitive::DIGITS;
            const EPSILON: Self = $primitive::EPSILON;
            const INFINITY: Self = $primitive::INFINITY;
            const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;
            const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;
            const MAX_EXP: i32 = $primitive::MAX_EXP;
            const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
            const MIN_EXP: i32 = $primitive::MIN_EXP;
            const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;
            const NAN: Self = $primitive::NAN;
            const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
            const RADIX: u32 = $primitive::RADIX;

            fn new(val: f32) -> Self {
                // `f32 -> f64` is exact, so any rounding happens only inside `$new`.
                $new(val as f64)
            }
        }
    };
}
181                $new(val as f64)
182            }
183        }
184    };
185}
186
187impl_float!(half f16, F16);
188impl_float!(half bf16, BF16);
189impl_float!(f32, F32);
190impl_float!(f64, F64);