cubecl_core/frontend/
scalar.rs1use alloc::vec::Vec;
2use cubecl::prelude::*;
3use cubecl_common::{e4m3, e5m2, ue8m0};
4use serde::{Deserialize, Serialize};
5
6use crate::{
7 self as cubecl, ScalarArgType, intrinsic,
8 ir::{ElemType, FloatKind, IntKind, ManagedVariable, UIntKind},
9};
10
11#[derive(Clone, Copy, Debug)]
12pub struct InputScalar {
16 data: [u8; 8],
17 dtype: StorageType,
18}
19
20#[derive(Clone)]
21pub struct InputScalarExpand {
22 pub expand: ManagedVariable,
23}
24
25impl CubeType for InputScalar {
26 type ExpandType = InputScalarExpand;
27}
28
29impl IntoMut for InputScalarExpand {
30 fn into_mut(self, _scope: &mut Scope) -> Self {
31 self
32 }
33}
34
35impl CubeDebug for InputScalarExpand {}
36
37impl InputScalar {
38 pub fn new<E: num_traits::ToPrimitive>(val: E, dtype: impl Into<StorageType>) -> Self {
44 let dtype: StorageType = dtype.into();
45 let mut out = InputScalar {
46 data: Default::default(),
47 dtype,
48 };
49 fn write<E: ScalarArgType>(val: impl num_traits::ToPrimitive, out: &mut [u8]) {
50 let val = [E::from(val).unwrap()];
51 let bytes = E::as_bytes(&val);
52 out[..bytes.len()].copy_from_slice(bytes);
53 }
54 match dtype {
55 StorageType::Scalar(elem) => match elem {
56 ElemType::Float(float_kind) => match float_kind {
57 FloatKind::F16 => write::<half::f16>(val, &mut out.data),
58 FloatKind::BF16 => write::<half::bf16>(val, &mut out.data),
59 FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => {
60 write::<f32>(val, &mut out.data)
61 }
62 FloatKind::F64 => write::<f64>(val, &mut out.data),
63 FloatKind::E2M1 | FloatKind::E2M3 | FloatKind::E3M2 => {
64 unimplemented!("fp6 CPU conversion not yet implemented")
65 }
66 FloatKind::E4M3 => write::<e4m3>(val, &mut out.data),
67 FloatKind::E5M2 => write::<e5m2>(val, &mut out.data),
68 FloatKind::UE8M0 => write::<ue8m0>(val, &mut out.data),
69 },
70 ElemType::Int(int_kind) => match int_kind {
71 IntKind::I8 => write::<i8>(val, &mut out.data),
72 IntKind::I16 => write::<i16>(val, &mut out.data),
73 IntKind::I32 => write::<i32>(val, &mut out.data),
74 IntKind::I64 => write::<i64>(val, &mut out.data),
75 },
76 ElemType::UInt(uint_kind) => match uint_kind {
77 UIntKind::U8 => write::<u8>(val, &mut out.data),
78 UIntKind::U16 => write::<u16>(val, &mut out.data),
79 UIntKind::U32 => write::<u32>(val, &mut out.data),
80 UIntKind::U64 => write::<u64>(val, &mut out.data),
81 },
82 ElemType::Bool => panic!("Bool isn't a scalar"),
83 ElemType::Complex(_) => unimplemented!("Complex not supported for scalar input"),
84 },
85 other => unimplemented!("{other} not supported for scalars"),
86 };
87 out
88 }
89}
90
91#[cube]
92impl InputScalar {
93 pub fn get<C: Scalar>(&self) -> C {
97 intrinsic!(|scope| {
98 let dtype = C::as_type(scope);
99 if self.expand.ty == dtype {
100 return self.expand.into();
101 }
102 let new_var = scope.create_local(dtype);
103 cast::expand::<C, C>(scope, self.expand.into(), new_var.clone().into());
104 new_var.into()
105 })
106 }
107}
108
109impl InputScalar {
110 pub fn as_bytes(&self) -> Vec<u8> {
111 self.data[..self.dtype.size()].to_vec()
112 }
113}
114
115impl LaunchArg for InputScalar {
116 type RuntimeArg<R: Runtime> = InputScalar;
117 type CompilationArg = InputScalarCompilationArg;
118
119 fn register<R: Runtime>(
120 arg: Self::RuntimeArg<R>,
121 launcher: &mut KernelLauncher<R>,
122 ) -> Self::CompilationArg {
123 let dtype = arg.dtype;
124 launcher.register_scalar_raw(&arg.data[..dtype.size()], dtype);
125 InputScalarCompilationArg::new(arg.dtype)
126 }
127
128 fn expand(
129 arg: &Self::CompilationArg,
130 builder: &mut KernelBuilder,
131 ) -> <Self as CubeType>::ExpandType {
132 let expand = builder.scalar(arg.ty);
133 InputScalarExpand { expand }
134 }
135}
136
137#[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Hash, Debug)]
138pub struct InputScalarCompilationArg {
139 ty: StorageType,
140}
141
142impl InputScalarCompilationArg {
143 pub fn new(ty: StorageType) -> Self {
144 Self { ty }
145 }
146}