cubecl_core/frontend/element/
float.rs1use cubecl_ir::{ConstantValue, Scope, StorageType, Type};
2use half::{bf16, f16};
3
4use crate::{
5 self as cubecl,
6 ir::{ElemType, FloatKind},
7 prelude::*,
8};
9
10use super::Numeric;
11
12mod fp4;
13mod fp6;
14mod fp8;
15mod relaxed;
16mod tensor_float;
17
18pub trait Float:
20 Numeric
21 + FloatOps
22 + Exp
23 + Log
24 + Log1p
25 + Expm1
26 + Cos
27 + Sin
28 + Tan
29 + Tanh
30 + Sinh
31 + Cosh
32 + ArcCos
33 + ArcSin
34 + ArcTan
35 + ArcSinh
36 + ArcCosh
37 + ArcTanh
38 + Degrees
39 + Radians
40 + ArcTan2
41 + Powf
42 + Powi<i32>
43 + Hypot
44 + Rhypot
45 + Sqrt
46 + InverseSqrt
47 + Round
48 + Floor
49 + Ceil
50 + Trunc
51 + Erf
52 + Recip
53 + Magnitude
54 + Normalize
55 + Dot
56 + IsNan
57 + IsInf
58 + Into<Self::ExpandType>
59 + core::ops::Neg<Output = Self>
60 + core::cmp::PartialOrd
61 + core::cmp::PartialEq
62{
63 const DIGITS: u32;
64 const EPSILON: Self;
65 const INFINITY: Self;
66 const MANTISSA_DIGITS: u32;
67 const MAX_10_EXP: i32;
68 const MAX_EXP: i32;
69 const MIN_10_EXP: i32;
70 const MIN_EXP: i32;
71 const MIN_POSITIVE: Self;
72 const NAN: Self;
73 const NEG_INFINITY: Self;
74 const RADIX: u32;
75
76 fn new(val: f32) -> Self;
77 fn __expand_new(scope: &mut Scope, val: f32) -> <Self as CubeType>::ExpandType {
78 __expand_new(scope, val)
79 }
80}
81
82#[cube]
83pub trait FloatOps: CubePrimitive + PartialOrd + Sized {
84 fn min(self, other: Self) -> Self {
85 cubecl::prelude::min(self, other)
86 }
87
88 fn max(self, other: Self) -> Self {
89 cubecl::prelude::max(self, other)
90 }
91
92 fn clamp(self, min: Self, max: Self) -> Self {
93 clamp(self, min, max)
94 }
95}
96
97impl<T: Float> FloatOps for T {}
98impl<T: FloatOps + CubePrimitive> FloatOpsExpand for NativeExpand<T> {
99 fn __expand_min_method(self, scope: &mut Scope, other: Self) -> Self {
100 min::expand(scope, self, other)
101 }
102
103 fn __expand_max_method(self, scope: &mut Scope, other: Self) -> Self {
104 max::expand(scope, self, other)
105 }
106
107 fn __expand_clamp_method(self, scope: &mut Scope, min: Self, max: Self) -> Self {
108 clamp::expand(scope, self, min, max)
109 }
110}
111
/// Implements the cube element traits (`CubeType`, `Scalar`, `CubePrimitive`,
/// `IntoRuntime`, `IntoMut`, `NativeAssign`, `Numeric`, `Float`) for a host float type.
///
/// Accepted forms:
/// - `impl_float!(half T, Kind)`: `T` is converted from `f64` with its inherent `T::from_f64`.
/// - `impl_float!(T, Kind)`: `T` is converted from `f64` with an `as` cast.
/// - `impl_float!(T, Kind, conv)`: `conv` is any expression callable as `conv(x: f64) -> T`;
///   the two forms above forward to this one.
///
/// `Kind` is the matching `FloatKind` variant name.
macro_rules! impl_float {
    (half $primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, $primitive::from_f64);
    };
    ($primitive:ident, $kind:ident) => {
        impl_float!($primitive, $kind, |x: f64| x as $primitive);
    };
    ($primitive:ident, $kind:ident, $from_f64:expr) => {
        impl CubeType for $primitive {
            type ExpandType = NativeExpand<$primitive>;
        }

        impl Scalar for $primitive {}

        impl CubePrimitive for $primitive {
            type Scalar = Self;
            type Size = Const<1>;
            type WithScalar<S: Scalar> = S;

            fn as_type_native() -> Option<Type> {
                let storage = StorageType::Scalar(ElemType::Float(FloatKind::$kind));
                Some(storage.into())
            }

            fn from_const_value(value: ConstantValue) -> Self {
                // Only float constants are expected for a float element type.
                match value {
                    ConstantValue::Float(raw) => $from_f64(raw),
                    _ => unreachable!(),
                }
            }
        }

        impl IntoRuntime for $primitive {
            fn __expand_runtime_method(self, _scope: &mut Scope) -> NativeExpand<Self> {
                self.into()
            }
        }

        impl IntoMut for $primitive {
            fn into_mut(self, _scope: &mut Scope) -> Self {
                self
            }
        }

        impl NativeAssign for $primitive {}

        impl Numeric for $primitive {
            // Lowest finite value (i.e. `-MAX`), not the smallest positive one.
            fn min_value() -> Self {
                <$primitive as num_traits::Float>::min_value()
            }

            // Largest finite value.
            fn max_value() -> Self {
                <$primitive as num_traits::Float>::max_value()
            }
        }

        impl Float for $primitive {
            const NAN: Self = $primitive::NAN;
            const INFINITY: Self = $primitive::INFINITY;
            const NEG_INFINITY: Self = $primitive::NEG_INFINITY;
            const EPSILON: Self = $primitive::EPSILON;
            const MIN_POSITIVE: Self = $primitive::MIN_POSITIVE;

            const RADIX: u32 = $primitive::RADIX;
            const DIGITS: u32 = $primitive::DIGITS;
            const MANTISSA_DIGITS: u32 = $primitive::MANTISSA_DIGITS;

            const MIN_EXP: i32 = $primitive::MIN_EXP;
            const MAX_EXP: i32 = $primitive::MAX_EXP;
            const MIN_10_EXP: i32 = $primitive::MIN_10_EXP;
            const MAX_10_EXP: i32 = $primitive::MAX_10_EXP;

            fn new(val: f32) -> Self {
                // Widening f32 -> f64 is exact, so only the final conversion can round.
                let widened = val as f64;
                $from_f64(widened)
            }
        }
    };
}
187impl_float!(half f16, F16);
188impl_float!(half bf16, BF16);
189impl_float!(f32, F32);
190impl_float!(f64, F64);