1use super::{CubePrimitive, Numeric};
2use crate::{
3 ir::{ConstantValue, Scope, Variable, VariableKind},
4 prelude::{DynamicSize, KernelBuilder, KernelLauncher, assign},
5 unexpanded,
6};
7use alloc::{boxed::Box, vec::Vec};
8use core::marker::PhantomData;
9use cubecl_common::{e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32, tf32, ue8m0};
10use cubecl_ir::{ManagedVariable, VectorSize};
11use cubecl_runtime::runtime::Runtime;
12use half::{bf16, f16};
13use variadics_please::{all_tuples, all_tuples_enumerated};
14
15#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeType)]` on `{Self}`")]
28pub trait CubeType {
29 type ExpandType: Clone + IntoMut + CubeDebug;
30}
31
32pub trait CubeEnum: Sized {
33 type RuntimeValue: Clone + CubeDebug;
34
35 fn discriminant(&self) -> NativeExpand<i32>;
36
37 fn runtime_value(self) -> Self::RuntimeValue;
40
41 fn discriminant_of_value(&self, variant_name: &'static str) -> i32 {
42 Self::discriminant_of(variant_name)
43 }
44
45 fn discriminant_of(variant_name: &'static str) -> i32;
46}
47
48pub trait Assign {
49 fn expand_assign(&mut self, scope: &mut Scope, value: Self);
51 fn init_mut(&self, scope: &mut Scope) -> Self;
53}
54
55impl<T: CubePrimitive> Assign for T {
56 fn expand_assign(&mut self, _scope: &mut Scope, value: Self) {
57 *self = value;
58 }
59 fn init_mut(&self, _scope: &mut Scope) -> Self {
60 *self
61 }
62}
63
64impl<T: NativeAssign> Assign for NativeExpand<T> {
65 fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
66 assign::expand(scope, value, self.clone());
67 }
68 fn init_mut(&self, scope: &mut Scope) -> Self {
69 T::elem_init_mut(scope, self.expand.clone()).into()
70 }
71}
72
73impl<T: Assign> Assign for Option<T> {
74 fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
75 match (self, value) {
76 (Some(this), Some(other)) => this.expand_assign(scope, other),
77 (None, None) => {}
78 _ => panic!("Can't assign mismatched enum variants"),
79 }
80 }
81 fn init_mut(&self, scope: &mut Scope) -> Self {
82 self.as_ref().map(|value| value.init_mut(scope))
83 }
84}
85
86impl<T: Assign> Assign for Vec<T> {
87 fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
88 assert!(
89 self.len() == value.len(),
90 "Can't assign mismatched vector lengths"
91 );
92 for (this, other) in self.iter_mut().zip(value) {
93 this.expand_assign(scope, other);
94 }
95 }
96 fn init_mut(&self, scope: &mut Scope) -> Self {
97 self.iter().map(|it| it.init_mut(scope)).collect()
98 }
99}
100
101pub trait CloneExpand {
102 fn __expand_clone_method(&self, scope: &mut Scope) -> Self;
103}
104
105impl<C: Clone> CloneExpand for C {
106 fn __expand_clone_method(&self, _scope: &mut Scope) -> Self {
107 self.clone()
108 }
109}
110
111pub trait IntoRuntime: CubeType + Sized {
113 fn runtime(self) -> Self {
114 self
115 }
116
117 fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType;
118}
119
/// Lets kernel source write `x.comptime()`.
///
/// Outside expansion the method simply returns `self` unchanged.
pub trait IntoComptime: Sized {
    #[allow(clippy::wrong_self_convention)]
    fn comptime(self) -> Self {
        core::convert::identity(self)
    }
}

impl<T> IntoComptime for T {}
129
130pub trait IntoMut: Sized {
132 fn into_mut(self, scope: &mut Scope) -> Self;
134}
135
136pub fn into_mut_assign<T: Assign>(value: T, scope: &mut Scope) -> T {
137 let mut out = value.init_mut(scope);
138 out.expand_assign(scope, value);
139 out
140}
141
142pub trait CubeDebug: Sized {
143 #[allow(unused)]
146 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {}
147}
148
/// Bundle of the bounds `Debug + Hash + Eq + Clone + Copy`.
///
/// Implemented automatically for every type satisfying them.
pub trait CubeComptime
where
    Self: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy,
{
}

impl<T: core::fmt::Debug + core::hash::Hash + Eq + Clone + Copy> CubeComptime for T {}
165
/// Compile-time argument of a launch.
///
/// The bounds (`Hash + Eq + Debug + Send + Sync + 'static`) make it usable as
/// part of a lookup key and shareable across threads.
pub trait CompilationArg:
    Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static
{
    /// Reinterprets a clone of `self` as `Arg`.
    ///
    /// Intended for cases where `Arg` and `Self` are the same type but the type
    /// system cannot prove it.
    ///
    /// # Panics
    /// If `Arg` and `Self` have different sizes.
    ///
    /// # Soundness
    /// This is an unchecked bit-level reinterpretation. The caller must ensure
    /// that the bytes of a `Self` form a valid `Arg`. Ownership of the clone's
    /// resources moves into the returned value.
    fn dynamic_cast<Arg: CompilationArg>(&self) -> Arg {
        assert!(size_of::<Arg>() == size_of::<Self>());
        // `ManuallyDrop` keeps the clone's destructor from running: its
        // resources are now owned (and later released) by the returned `Arg`.
        let this = core::mem::ManuallyDrop::new(self.clone());
        // SAFETY: the sizes are equal (asserted above). `transmute_copy`
        // tolerates differing alignment. Re-boxing the value as `Box<Arg>`
        // would not: reading it could be misaligned, and freeing it with
        // `Arg`'s layout instead of `Self`'s is undefined behavior.
        unsafe { core::mem::transmute_copy::<Self, Arg>(&*this) }
    }
}

impl<T: Clone + PartialEq + Eq + core::hash::Hash + core::fmt::Debug + Send + Sync + 'static>
    CompilationArg for T
{
}
189
190#[diagnostic::on_unimplemented(note = "Consider using `#[derive(CubeLaunch)]` on `{Self}`")]
199pub trait LaunchArg: CubeType + Send + Sync + 'static {
200 type RuntimeArg<R: Runtime>: Send + Sync;
202 type CompilationArg: CompilationArg;
204
205 fn register<R: Runtime>(
206 arg: Self::RuntimeArg<R>,
207 launcher: &mut KernelLauncher<R>,
208 ) -> Self::CompilationArg;
209
210 fn expand(
212 arg: &Self::CompilationArg,
213 builder: &mut KernelBuilder,
214 ) -> <Self as CubeType>::ExpandType;
215
216 fn expand_output(
218 arg: &Self::CompilationArg,
219 builder: &mut KernelBuilder,
220 ) -> <Self as CubeType>::ExpandType {
221 Self::expand(arg, builder)
222 }
223}
224
225macro_rules! launch_tuple {
226 ($(($T:ident, $t:ident)),*) => {
227 impl<$($T: LaunchArg),*> LaunchArg for ($($T),*) {
228 type RuntimeArg<R: Runtime> = ($($T::RuntimeArg<R>),*);
229 type CompilationArg = ($($T::CompilationArg),*);
230
231 fn register<R: Runtime>(runtime_arg: Self::RuntimeArg<R>, launcher: &mut KernelLauncher<R>) -> Self::CompilationArg {
232 let ($($t),*) = runtime_arg;
233 ($($T::register($t, launcher)),*)
234 }
235
236 fn expand(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
237 let ($($t),*) = arg;
238 ($($T::expand($t, builder)),*)
239 }
240
241 fn expand_output(arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> ($(<$T as CubeType>::ExpandType),*) {
242 let ($($t),*) = arg;
243 ($($T::expand_output($t, builder)),*)
244 }
245 }
246 };
247}
248
249all_tuples!(launch_tuple, 2, 12, T, t);
250
251#[derive(new)]
253pub struct NativeExpand<T: CubeType> {
254 pub expand: ManagedVariable,
255 pub(crate) _type: PhantomData<T>,
256}
257
258impl<T: CubeType> NativeExpand<T> {
259 pub unsafe fn as_type_ref_unchecked<E: CubeType>(&self) -> &NativeExpand<E> {
263 unsafe { core::mem::transmute::<&NativeExpand<T>, &NativeExpand<E>>(self) }
264 }
265
266 pub unsafe fn as_type_mut_unchecked<E: CubeType>(&mut self) -> &mut NativeExpand<E> {
270 unsafe { core::mem::transmute::<&mut NativeExpand<T>, &mut NativeExpand<E>>(self) }
271 }
272}
273
274impl<T: CubeType> From<&NativeExpand<T>> for NativeExpand<T> {
275 fn from(value: &NativeExpand<T>) -> Self {
276 value.clone()
277 }
278}
279
280impl<T: CubeType> From<NativeExpand<T>> for Variable {
281 fn from(value: NativeExpand<T>) -> Self {
282 value.expand.into()
283 }
284}
285
286impl<T: CubeType> From<&mut NativeExpand<T>> for NativeExpand<T> {
287 fn from(value: &mut NativeExpand<T>) -> Self {
288 value.clone()
289 }
290}
291
292macro_rules! from_const {
293 ($lit:ty) => {
294 impl From<$lit> for NativeExpand<$lit> {
295 fn from(value: $lit) -> Self {
296 let variable: Variable = value.into();
297
298 ManagedVariable::Plain(variable).into()
299 }
300 }
301 };
302}
303
304from_const!(u8);
305from_const!(u16);
306from_const!(u32);
307from_const!(u64);
308from_const!(usize);
309from_const!(isize);
310from_const!(i64);
311from_const!(i8);
312from_const!(i16);
313from_const!(i32);
314from_const!(f64);
315from_const!(f16);
316from_const!(bf16);
317from_const!(flex32);
318from_const!(tf32);
319from_const!(f32);
320from_const!(e2m1);
321from_const!(e2m1x2);
322from_const!(e2m3);
323from_const!(e3m2);
324from_const!(e4m3);
325from_const!(e5m2);
326from_const!(ue8m0);
327from_const!(bool);
328from_const!(num_complex::Complex<f32>);
329from_const!(num_complex::Complex<f64>);
330
// Implements `CubeType` for a tuple: its `ExpandType` is the tuple of the
// elements' `ExpandType`s.
macro_rules! tuple_cube_type {
    ($($P:ident),*) => {
        impl<$($P: CubeType),*> CubeType for ($($P,)*) {
            type ExpandType = ($($P::ExpandType,)*);
        }
    }
}
// Implements `IntoMut` for a tuple by converting each element in order.
macro_rules! tuple_init {
    ($($P:ident),*) => {
        impl<$($P: IntoMut),*> IntoMut for ($($P,)*) {
            // The type-parameter idents are reused as binding names (hence
            // `non_snake_case`). For the empty tuple, `scope` is unused and the
            // body is `()`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn into_mut(self, scope: &mut Scope) -> Self {
                let ($($P,)*) = self;
                ($(
                    $P.into_mut(scope),
                )*)
            }
        }
    }
}
// Implements `CubeDebug` for a tuple using the trait's default method.
macro_rules! tuple_debug {
    ($($P:ident),*) => {
        impl<$($P: CubeDebug),*> CubeDebug for ($($P,)*) {}
    }
}
// Implements `IntoRuntime` for a tuple by expanding each element in order.
macro_rules! tuple_runtime {
    ($($P:ident),*) => {
        impl<$($P: IntoRuntime),*> IntoRuntime for ($($P,)*) {
            // Same lint rationale as in `tuple_init`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn __expand_runtime_method(self, scope: &mut Scope) -> Self::ExpandType {
                let ($($P,)*) = self;
                ($(
                    $P.__expand_runtime_method(scope),
                )*)
            }
        }
    }
}
// Implements `Assign` for a tuple. `all_tuples_enumerated!` supplies each
// element's index `$n`, so the matching field of `value` can be moved out and
// assigned into the corresponding element of `self`.
macro_rules! tuple_assign {
    ($(($n: tt, $P:ident)),*) => {
        impl<$($P: Assign),*> Assign for ($($P,)*) {
            // Same lint rationale as in `tuple_init`.
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn expand_assign(&mut self, scope: &mut Scope, value: Self) {
                // Binds each element as `&mut` through default binding modes.
                let ($($P,)*) = self;
                $(
                    $P.expand_assign(scope, value.$n);
                )*
            }
            #[allow(non_snake_case, unused, clippy::unused_unit)]
            fn init_mut(&self, scope: &mut Scope) -> Self {
                let ($($P,)*) = self;
                ($(
                    $P.init_mut(scope),
                )*)
            }
        }
    }
}

// Generate the tuple impls for the arity range passed to `all_tuples!`
// (0 to 12).
all_tuples!(tuple_cube_type, 0, 12, P);
all_tuples!(tuple_debug, 0, 12, P);
all_tuples!(tuple_init, 0, 12, P);
all_tuples!(tuple_runtime, 0, 12, P);
all_tuples_enumerated!(tuple_assign, 0, 12, P);
395
396impl<P: CubePrimitive> CubeDebug for P {}
397
398pub trait NativeAssign: CubeType {
400 fn elem_init_mut(scope: &mut Scope, elem: ManagedVariable) -> ManagedVariable {
401 init_mut_expand_element(scope, &elem)
402 }
403}
404
405impl<T: NativeAssign> IntoMut for NativeExpand<T> {
406 fn into_mut(self, scope: &mut Scope) -> Self {
407 into_mut_assign(self, scope)
408 }
409}
410
411impl<T: CubeType> CubeDebug for NativeExpand<T> {
412 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
413 scope.update_variable_name(*self.expand, name);
414 }
415}
416
417impl<T: CubeType> CubeDebug for &NativeExpand<T> {
418 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
419 scope.update_variable_name(*self.expand, name);
420 }
421}
422
423impl<T: CubeType> CubeDebug for &mut NativeExpand<T> {
424 fn set_debug_name(&self, scope: &mut Scope, name: &'static str) {
425 scope.update_variable_name(*self.expand, name);
426 }
427}
428
429impl<T: CubeType> NativeExpand<T> {
430 pub fn vector_size(&self) -> VectorSize {
432 self.expand.ty.vector_size()
433 }
434
435 pub fn __expand_vector_size_method(self, _scope: &mut Scope) -> VectorSize {
437 self.expand.ty.vector_size()
438 }
439
440 pub fn into_variable(self) -> Variable {
441 self.expand.consume()
442 }
443}
444
445impl<T: CubeType> Clone for NativeExpand<T> {
446 fn clone(&self) -> Self {
447 Self {
448 expand: self.expand.clone(),
449 _type: PhantomData,
450 }
451 }
452}
453
454impl<T: CubeType> From<ManagedVariable> for NativeExpand<T> {
455 fn from(expand: ManagedVariable) -> Self {
456 Self {
457 expand,
458 _type: PhantomData,
459 }
460 }
461}
462
463impl<T: CubeType> From<NativeExpand<T>> for ManagedVariable {
464 fn from(value: NativeExpand<T>) -> Self {
465 value.expand
466 }
467}
468
469impl<T: CubePrimitive> NativeExpand<T> {
470 pub fn from_lit<L: Into<ConstantValue>>(scope: &Scope, lit: L) -> Self {
472 let variable: ConstantValue = lit.into();
473 let variable = T::as_type(scope).constant(variable);
474
475 NativeExpand::new(ManagedVariable::Plain(variable))
476 }
477
478 pub fn constant(&self) -> Option<ConstantValue> {
480 match self.expand.kind {
481 VariableKind::Constant(val) => Some(val),
482 _ => None,
483 }
484 }
485
486 pub fn __expand_into_lit_unchecked_method(self, _scope: &mut Scope) -> T {
487 let value = self.constant().unwrap();
488 T::from_const_value(value)
489 }
490}
491
492pub(crate) fn init_mut_expand_element(
493 scope: &mut Scope,
494 element: &ManagedVariable,
495) -> ManagedVariable {
496 scope.create_local_mut(element.ty)
497}
498
499impl<T: IntoMut> IntoMut for Option<T> {
500 fn into_mut(self, scope: &mut Scope) -> Self {
501 self.map(|o| IntoMut::into_mut(o, scope))
502 }
503}
504
505impl<T: CubeType> CubeType for Vec<T> {
506 type ExpandType = Vec<T::ExpandType>;
507}
508
509impl<T: CubeType> CubeType for &mut Vec<T> {
510 type ExpandType = Vec<T::ExpandType>;
511}
512
513impl<T: IntoMut> IntoMut for Vec<T> {
514 fn into_mut(self, scope: &mut Scope) -> Self {
515 self.into_iter().map(|e| e.into_mut(scope)).collect()
516 }
517}
518impl<T: CubeDebug> CubeDebug for Vec<T> {}
519
520pub(crate) fn __expand_new<C: Numeric, Out: Numeric>(
522 scope: &mut Scope,
523 val: C,
524) -> NativeExpand<Out> {
525 let input: ConstantValue = val.into();
526 let var = Out::as_type(scope).constant(input);
527 ManagedVariable::Plain(var).into()
528}
529
530impl LaunchArg for () {
531 type RuntimeArg<R: Runtime> = ();
532 type CompilationArg = ();
533
534 fn register<R: Runtime>(_runtime_arg: Self::RuntimeArg<R>, _launcher: &mut KernelLauncher<R>) {
535 }
537
538 fn expand(
539 _: &Self::CompilationArg,
540 _builder: &mut KernelBuilder,
541 ) -> <Self as CubeType>::ExpandType {
542 }
543}
544
545pub trait DefaultExpand: CubeType {
546 fn __expand_default(scope: &mut Scope) -> Self::ExpandType;
547}
548
549impl<T: CubeType + Default + IntoRuntime> DefaultExpand for T {
550 fn __expand_default(scope: &mut Scope) -> T::ExpandType {
551 T::default().__expand_runtime_method(scope)
552 }
553}
554
/// Type-level size `N`, known at compile time.
#[derive(Debug, Clone, Copy)]
pub struct Const<const N: usize>;
557
558pub trait Size: core::fmt::Debug + Clone + Copy + Send + Sync + 'static {
559 fn __expand_value(scope: &Scope) -> usize;
560 fn value() -> usize {
561 unexpanded!()
562 }
563 fn try_value_const() -> Option<usize> {
564 None
565 }
566}
567
568impl<const VALUE: usize> Size for Const<VALUE> {
569 fn __expand_value(_scope: &Scope) -> usize {
570 VALUE
571 }
572 fn value() -> usize {
573 VALUE
574 }
575 fn try_value_const() -> Option<usize> {
576 Some(VALUE)
577 }
578}
579
580impl<Marker: 'static> Size for DynamicSize<Marker> {
581 fn __expand_value(scope: &Scope) -> usize {
582 scope.resolve_size::<Self>().expect("Size to be registered")
583 }
584 fn value() -> usize {
585 unexpanded!()
586 }
587}
588
/// Declares a dynamic scalar type alias backed by a fresh marker struct.
///
/// `define_scalar!(pub MyScalar);` expands to `pub struct __MyScalar;` and
/// `pub type MyScalar = DynamicScalar<__MyScalar>;`.
#[macro_export]
macro_rules! define_scalar {
    ($visibility:vis $alias:ident) => {
        $crate::__private::paste! {
            $visibility struct [<__ $alias>];
            $visibility type $alias = $crate::prelude::DynamicScalar<[<__ $alias>]>;
        }
    };
}
600
/// Declares a dynamic size type alias backed by a fresh marker struct.
///
/// `define_size!(pub MySize);` expands to `pub struct __MySize;` and
/// `pub type MySize = DynamicSize<__MySize>;`.
#[macro_export]
macro_rules! define_size {
    ($visibility:vis $alias:ident) => {
        $crate::__private::paste! {
            $visibility struct [<__ $alias>];
            $visibility type $alias = $crate::prelude::DynamicSize<[<__ $alias>]>;
        }
    };
}
610}