Skip to main content

cubecl_core/frontend/element/
base.rs

1use super::{CubePrimitive, Numeric};
2use crate::{
3    ir::{ConstantValue, Scope, Variable, VariableKind},
4    prelude::{DynamicSize, KernelBuilder, KernelLauncher, assign},
5    unexpanded,
6};
7use alloc::{boxed::Box, vec::Vec};
8use core::marker::PhantomData;
9use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
10use cubecl_ir::{ManagedVariable, VectorSize};
11use cubecl_runtime::runtime::Runtime;
12use half::{bf16, f16};
13use variadics_please::{all_tuples, all_tuples_enumerated};
14
15/// Types used in a cube function must implement this trait
16///
17/// Variables whose values will be known at runtime must
18/// have `ManagedVariable` as associated type
19/// Variables whose values will be known at compile time
20/// must have the primitive type as associated type
21///
22/// Note: Cube functions should be written using `CubeTypes`,
23/// so that the code generated uses the associated `ExpandType`.
24/// This allows Cube code to not necessitate cloning, which is cumbersome
25/// in algorithmic code. The necessary cloning will automatically appear in
26/// the generated code.
27#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
28pub trait CubeType {
29    type ExpandType: Clone + IntoMut + CubeDebug;
30}
31
32pub trait CubeEnum: Sized {
33    type RuntimeValue: Clone + CubeDebug;
34
35    fn discriminant(&self) -> NativeExpand<i32>;
36
37    /// Return the runtime value of this enum, if only one variant has a value.
38    /// Should return () for all other cases.
39    fn runtime_value(self) -> Self::RuntimeValue;
40
41    fn discriminant_of_value(&self, variant_name: &'static str) -> i32 {
42        Self::discriminant_of(variant_name)
43    }
44
45    fn discriminant_of(variant_name: &'static str) -> i32;
46}
47
48pub trait Assign {
49    /// Assign `value` to `self` in `scope`.
50    fn expand_assign(&mut self, scope: &mut Scope, value: Self);
51    /// Create a new mutable variable of this type in `scope`.
52    fn init_mut(&self, scope: &mut Scope) -> Self;
53}
54
55impl<T: CubePrimitive> Assign for T {
56    fn expand_assign(&mut self, _scope: &mut Scope, value: Self) {
57        *self = value;
58    }
59    fn init_mut(&self, _scope: &mut Scope) -> Self {
60        *self
61    }
62}
63
64impl<T: NativeAssign> Assign for NativeExpand<T> {
65    fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
66        assign::expand(scope, value, self.clone());
67    }
68    fn init_mut(&self, scope: &mut Scope) -> Self {
69        T::elem_init_mut(scope, self.expand.clone()).into()
70    }
71}
72
73impl<T: Assign> Assign for Option<T> {
74    fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
75        match (self, value) {
76            (Some(this), Some(other)) => this.expand_assign(scope, other),
77            (None, None) => {}
78            _ => panic!("Can't assign mismatched enum variants"),
79        }
80    }
81    fn init_mut(&self, scope: &mut Scope) -> Self {
82        self.as_ref().map(|value| value.init_mut(scope))
83    }
84}
85
86impl<T: Assign> Assign for Vec<T> {
87    fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
88        assert!(
89            self.len() == value.len(),
90            "Can't assign mismatched vector lengths"
91        );
92        for (this, other) in self.iter_mut().zip(value) {
93            this.expand_assign(scope, other);
94        }
95    }
96    fn init_mut(&self, scope: &mut Scope) -> Self {
97        self.iter().map(|it| it.init_mut(scope)).collect()
98    }
99}
100
101pub trait CloneExpand {
102    fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
103}
104
105impl<C: Clone> CloneExpand for C {
106    fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
107        self.clone()
108    }
109}
110
111/// Trait useful to convert a comptime value into runtime value.
112pub trait IntoRuntime: CubeType + Sized {
113    fn runtime(self) -> Self {
114        self
115    }
116
117    fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
118}
119
/// Trait for marking a function return value as comptime when the compiler can't infer it.
pub trait IntoComptime: Sized {
    /// Identity function: returns `self` untouched.
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        self
    }
}

/// Available on every sized type.
impl<T> IntoComptime for T {}
129
130/// Convert an expand type to a version with mutable registers when necessary.
131pub trait IntoMut: Sized {
132    /// Convert the variable into a potentially new mutable variable in `scope`, copying if needed.
133    fn into_mut(self, scope: &mut Scope) -> Self;
134}
135
136pub fn into_mut_assign<T: Assign>(value: T, scope: &mut Scope) -> T {
137    let mut out = value.init_mut(scope);
138    out.expand_assign(scope, value);
139    out
140}
141
142pub trait CubeDebug: Sized {
143    /// Set the debug name of this type's expansion. Should do nothing for types that don't appear
144    /// at runtime
145    #[allow(unused)]
146    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
147}
148
/// A type that can be used as a kernel comptime argument.
///
/// Note that a type doesn't need to implement `CubeComptime` to be used as
/// a comptime argument. However, this facilitates the declaration of generic cube types.
///
/// # Example
///
/// ```ignore
/// #[derive(CubeType)]
/// pub struct Example<A: CubeType, B: CubeComptime> {
///     a: A,
///     #[cube(comptime)]
///     b: B
/// }
/// ```
pub trait CubeComptime: Copy + Clone + Eq + core::hash::Hash + core::fmt::Debug {}

/// Blanket implementation: any type meeting the supertrait bounds is a `CubeComptime`.
impl<T> CubeComptime for T where T: Copy + Clone + Eq + core::hash::Hash + core::fmt::Debug {}
165
/// Argument used during the compilation of kernels.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as a value of type `Arg`.
    ///
    /// Compilation args should be the same even with different element types. However, it isn't
    /// possible to enforce it with the type system. So, we make the compilation args serializable
    /// and dynamically cast them.
    ///
    /// Without this, the compilation time is unreasonable. The performance drop isn't a concern
    /// since this is only done once when compiling a kernel for the first time.
    ///
    /// # Panics
    /// Panics when `Arg` and `Self` do not have the same size.
    ///
    /// NOTE(review): only the size is verified; callers are trusted to cast between types that
    /// actually share a representation.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        // Dynamic cast, unlike transmute it does not require statically proving the types are the
        // same size. We assert at runtime to avoid undefined behaviour and help the compiler optimize.
        assert!(core::mem::size_of::<Arg>() == core::mem::size_of::<Self>());

        // The clone's ownership is moved bitwise into the returned `Arg`, so its destructor must
        // not run here.
        let this = core::mem::ManuallyDrop::new(self.clone());

        // SAFETY: both types have the same size (asserted above) and `transmute_copy` copes with
        // differing alignments. No allocation is created under one layout and released under
        // another, which the previous `Box` round-trip did whenever the alignments differed.
        unsafe { core::mem::transmute_copy::<Self, Arg>(&*this) }
    }
}

/// Blanket implementation: any type meeting the supertrait bounds is a `CompilationArg`.
impl<T: Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static>
    CompilationArg for T
{
}
189
190/// Defines how a [launch argument](LaunchArg) can be expanded.
191///
192/// TODO Verify the accuracy of the next comment.
193///
194/// Normally this type should be implemented two times for an argument.
195/// Once for the reference and the other for the mutable reference. Often time, the reference
196/// should expand the argument as an input while the mutable reference should expand the argument
197/// as an output.
198#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
199pub trait LaunchArg: CubeType + Send + Sync + 'static {
200    /// The runtime argument for the kernel.
201    type RuntimeArg<R: Runtime>: Send + Sync;
202    /// Compilation argument.
203    type CompilationArg: CompilationArg;
204
205    fn register<R: Runtime>(
206        arg: Self::RuntimeArg<R>,
207        launcher: &mut KernelLauncher<R>,
208    ) -> Self::CompilationArg;
209
210    /// Register an input variable during compilation that fill the [`KernelBuilder`].
211    fn expand(
212        arg: &Self::CompilationArg,
213        builder: &mut KernelBuilder,
214    ) -> <Self as CubeType>::ExpandType;
215
216    /// Register an output variable during compilation that fill the [`KernelBuilder`].
217    fn expand_output(
218        arg: &Self::CompilationArg,
219        builder: &mut KernelBuilder,
220    ) -> <Self as CubeType>::ExpandType {
221        Self::expand(arg, builder)
222    }
223}
224
225macro_rules! launch_tuple {
226    ($(($T:ident, $t:ident)),*) => {
227        impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
228            type RuntimeArg<R: Runtime> = ($($T::RuntimeArg<R>),*);
229            type CompilationArg = ($($T::CompilationArg),*);
230
231            fn register<R: Runtime>(runtime_arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) -> Self::CompilationArg {
232                let ($($t),*) = runtime_arg;
233                ($($T::register($t, launcher)),*)
234            }
235
236            fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
237                let ($($t),*) = arg;
238                ($($T::expand($t, builder)),*)
239            }
240
241            fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
242                let ($($t),*) = arg;
243                ($($T::expand_output($t, builder)),*)
244            }
245        }
246    };
247}
248
249all_tuples!(launch_tuple, 2, 12, T, t);
250
251/// Expand type of a native GPU type, i.e. scalar primitives, arrays, shared memory.
252#[derive(new)]
253pub struct NativeExpand<T: CubeType> {
254    pub expand: ManagedVariable,
255    pub(crate) _type: PhantomData<T>,
256}
257
258impl<T: CubeType> NativeExpand<T> {
259    /// Casts a reference of this expand element to a different type.
260    /// # Safety
261    /// There's no guarantee the new type is valid for the `ManagedVariable`
262    pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &NativeExpand<E> {
263        unsafe { core::mem::transmute::<&NativeExpand<T>, &NativeExpand<E>>(self) }
264    }
265
266    /// Casts a mutable reference of this expand element to a different type.
267    /// # Safety
268    /// There's no guarantee the new type is valid for the `ManagedVariable`
269    pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut NativeExpand<E> {
270        unsafe { core::mem::transmute::<&mut NativeExpand<T>, &mut NativeExpand<E>>(self) }
271    }
272}
273
274impl<T: CubeType> From<&NativeExpand<T>> for NativeExpand<T> {
275    fn from(value: &NativeExpand<T>) -> Self {
276        value.clone()
277    }
278}
279
280impl<T: CubeType> From<NativeExpand<T>> for Variable {
281    fn from(value: NativeExpand<T>) -> Self {
282        value.expand.into()
283    }
284}
285
286impl<T: CubeType> From<&mut NativeExpand<T>> for NativeExpand<T> {
287    fn from(value: &mut NativeExpand<T>) -> Self {
288        value.clone()
289    }
290}
291
292macro_rules! from_const {
293    ($lit:ty) => {
294        impl From<$lit> for NativeExpand<$lit> {
295            fn from(value: $lit) -> Self {
296                let variable: Variable = value.into();
297
298                ManagedVariable::Plain(variable).into()
299            }
300        }
301    };
302}
303
304from_const!(u8);
305from_const!(u16);
306from_const!(u32);
307from_const!(u64);
308from_const!(usize);
309from_const!(isize);
310from_const!(i64);
311from_const!(i8);
312from_const!(i16);
313from_const!(i32);
314from_const!(f64);
315from_const!(f16);
316from_const!(bf16);
317from_const!(flex32);
318from_const!(tf32);
319from_const!(f32);
320from_const!(e2m1);
321from_const!(e2m1x2);
322from_const!(e2m3);
323from_const!(e3m2);
324from_const!(e4m3);
325from_const!(e5m2);
326from_const!(ue8m0);
327from_const!(bool);
328from_const!(num_complex::Complex<f32>);
329from_const!(num_complex::Complex<f64>);
330
331macro_rules! tuple_cube_type {
332    ($($P:ident),*) => {
333        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
334            type ExpandType = ($($P::ExpandType,)*);
335        }
336    }
337}
338macro_rules! tuple_init {
339    ($($P:ident),*) => {
340        impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
341            #[allow(non_snake_case, unused, clippy::unused_unit)]
342            fn into_mut(self, scope: &mut Scope) -> Self {
343                let ($($P,)*) = self;
344                ($(
345                    $P.into_mut(scope),
346                )*)
347            }
348        }
349    }
350}
351macro_rules! tuple_debug {
352    ($($P:ident),*) => {
353        impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
354    }
355}
356macro_rules! tuple_runtime {
357    ($($P:ident),*) => {
358        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
359            #[allow(non_snake_case, unused, clippy::unused_unit)]
360            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
361                let ($($P,)*) = self;
362                ($(
363                    $P.__expand_runtime_method(scope),
364                )*)
365            }
366        }
367    }
368}
369macro_rules! tuple_assign {
370    ($(($n: tt, $P:ident)),*) => {
371        impl<$($P: Assign),*> Assign for ($($P,)*) {
372            #[allow(non_snake_case, unused, clippy::unused_unit)]
373            fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
374                let ($($P,)*) = self;
375                $(
376                    $P.expand_assign(scope, value.$n);
377                )*
378            }
379            #[allow(non_snake_case, unused, clippy::unused_unit)]
380            fn init_mut(&self, scope: &mut Scope) -> Self {
381                let ($($P,)*) = self;
382                ($(
383                    $P.init_mut(scope),
384                )*)
385            }
386        }
387    }
388}
389
390all_tuples!(tuple_cube_type, 0, 12, P);
391all_tuples!(tuple_debug, 0, 12, P);
392all_tuples!(tuple_init, 0, 12, P);
393all_tuples!(tuple_runtime, 0, 12, P);
394all_tuples_enumerated!(tuple_assign, 0, 12, P);
395
396impl<P: CubePrimitive> CubeDebug for P {}
397
398/// Trait for native types that can be assigned. For non-native composites, use the normal [`Assign`].
399pub trait NativeAssign: CubeType {
400    fn elem_init_mut(scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
401        init_mut_expand_element(scope, &elem)
402    }
403}
404
405impl<T: NativeAssign> IntoMut for NativeExpand<T> {
406    fn into_mut(self, scope: &mut Scope) -> Self {
407        into_mut_assign(self, scope)
408    }
409}
410
411impl<T: CubeType> CubeDebug for NativeExpand<T> {
412    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
413        scope.update_variable_name(*self.expand, name);
414    }
415}
416
417impl<T: CubeType> CubeDebug for &NativeExpand<T> {
418    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
419        scope.update_variable_name(*self.expand, name);
420    }
421}
422
423impl<T: CubeType> CubeDebug for &mut NativeExpand<T> {
424    fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
425        scope.update_variable_name(*self.expand, name);
426    }
427}
428
429impl<T: CubeType> NativeExpand<T> {
430    /// Comptime version of [`crate::frontend::Array::vector_size`].
431    pub fn vector_size(&self) -> VectorSize {
432        self.expand.ty.vector_size()
433    }
434
435    // Expanded version of vectorization factor.
436    pub fn __expand_vector_size_method(self, _scope: &mut Scope) -> VectorSize {
437        self.expand.ty.vector_size()
438    }
439
440    pub fn into_variable(self) -> Variable {
441        self.expand.consume()
442    }
443}
444
445impl<T: CubeType> Clone for NativeExpand<T> {
446    fn clone(&self) -> Self {
447        Self {
448            expand: self.expand.clone(),
449            _type: PhantomData,
450        }
451    }
452}
453
454impl<T: CubeType> From<ManagedVariable> for NativeExpand<T> {
455    fn from(expand: ManagedVariable) -> Self {
456        Self {
457            expand,
458            _type: PhantomData,
459        }
460    }
461}
462
463impl<T: CubeType> From<NativeExpand<T>> for ManagedVariable {
464    fn from(value: NativeExpand<T>) -> Self {
465        value.expand
466    }
467}
468
469impl<T: CubePrimitive> NativeExpand<T> {
470    /// Create an [`NativeExpand`] from a value that is normally a literal.
471    pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
472        let variable: ConstantValue = lit.into();
473        let variable = T::as_type(scope).constant(variable);
474
475        NativeExpand::new(ManagedVariable::Plain(variable))
476    }
477
478    /// Get the [`ConstantValue`] from the variable.
479    pub fn constant(&self) -> Option<ConstantValue> {
480        match self.expand.kind {
481            VariableKind::Constant(val) => Some(val),
482            _ => None,
483        }
484    }
485
486    pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
487        let value = self.constant().unwrap();
488        T::from_const_value(value)
489    }
490}
491
492pub(crate) fn init_mut_expand_element(
493    scope: &mut Scope,
494    element: &ManagedVariable,
495) -> ManagedVariable {
496    scope.create_local_mut(element.ty)
497}
498
499impl<T: IntoMut> IntoMut for Option<T> {
500    fn into_mut(self, scope: &mut Scope) -> Self {
501        self.map(|o| IntoMut::into_mut(o, scope))
502    }
503}
504
505impl<T: CubeType> CubeType for Vec<T> {
506    type ExpandType = Vec<T::ExpandType>;
507}
508
509impl<T: CubeType> CubeType for &mut Vec<T> {
510    type ExpandType = Vec<T::ExpandType>;
511}
512
513impl<T: IntoMut> IntoMut for Vec<T> {
514    fn into_mut(self, scope: &mut Scope) -> Self {
515        self.into_iter().map(|e| e.into_mut(scope)).collect()
516    }
517}
518impl<T: CubeDebug> CubeDebug for Vec<T> {}
519
520/// Create a constant element of the correct type during expansion.
521pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
522    scope: &mut Scope,
523    val: C,
524) -> NativeExpand<Out> {
525    let input: ConstantValue = val.into();
526    let var = Out::as_type(scope).constant(input);
527    ManagedVariable::Plain(var).into()
528}
529
530impl LaunchArg for () {
531    type RuntimeArg<R: Runtime> = ();
532    type CompilationArg = ();
533
534    fn register<R: Runtime>(_runtime_arg: Self::RuntimeArg<R>, _launcher: &mut KernelLauncher<R>) {
535        // nothing to do
536    }
537
538    fn expand(
539        _: &Self::CompilationArg,
540        _builder: &mut KernelBuilder,
541    ) -> <Self as CubeType>::ExpandType {
542    }
543}
544
545pub trait DefaultExpand: CubeType {
546    fn __expand_default(scope: &mut Scope) -> Self::ExpandType;
547}
548
549impl<T: CubeType + Default + IntoRuntime> DefaultExpand for T {
550    fn __expand_default(scope: &mut Scope) -> T::ExpandType {
551        T::default().__expand_runtime_method(scope)
552    }
553}
554
/// Zero-sized marker carrying a `usize` as a const generic parameter.
#[derive(Debug, Clone, Copy)]
pub struct Const<const N: usize>;
557
558pub trait Size: core::fmt::Debug + Clone + Copy + Send + Sync + 'static {
559    fn __expand_value(scope: &Scope) -> usize;
560    fn value() -> usize {
561        unexpanded!()
562    }
563    fn try_value_const() -> Option<usize> {
564        None
565    }
566}
567
568impl<const VALUE: usize> Size for Const<VALUE> {
569    fn __expand_value(_scope: &Scope) -> usize {
570        VALUE
571    }
572    fn value() -> usize {
573        VALUE
574    }
575    fn try_value_const() -> Option<usize> {
576        Some(VALUE)
577    }
578}
579
580impl<Marker: 'static> Size for DynamicSize<Marker> {
581    fn __expand_value(scope: &Scope) -> usize {
582        scope.resolve_size::<Self>().expect("Size to be registered")
583    }
584    fn value() -> usize {
585        unexpanded!()
586    }
587}
588
/// Define a custom type to be used for a comptime scalar type.
/// Useful for cases where generics can't work.
///
/// Expands to a marker struct named `__<name>` and a type alias `<name>` for
/// `DynamicScalar<__<name>>`, both with the given visibility.
#[macro_export]
macro_rules! define_scalar {
    ($vis:vis $name:ident) => {
        $crate::__private::paste! {
            $vis struct [<__ $name>];
            $vis type $name = $crate::prelude::DynamicScalar<[<__ $name>]>;
        }
    };
}
600
/// Define a custom type to be used for a comptime size. Useful for cases where generics can't work.
///
/// Expands to a marker struct named `__<name>` and a type alias `<name>` for
/// `DynamicSize<__<name>>`, both with the given visibility.
#[macro_export]
macro_rules! define_size {
    ($vis:vis $name:ident) => {
        $crate::__private::paste! {
            $vis struct [<__ $name>];
            $vis type $name = $crate::prelude::DynamicSize<[<__ $name>]>;
        }
    };
}