Skip to main content

cubecl_ir/
variable.rs

1use core::{fmt::Display, hash::Hash};
2
3use crate::{BarrierLevel, FloatKind, IntKind, StorageType, TypeHash};
4
5use super::{ComplexKind, ElemType, Matrix, Type, UIntKind};
6use cubecl_common::{e2m1, e4m3, e5m2, ue8m0};
7use derive_more::From;
8use float_ord::FloatOrd;
9
10#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
11#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
12#[allow(missing_docs)]
13pub struct Variable {
14    pub kind: VariableKind,
15    pub ty: Type,
16}
17
18impl Variable {
19    pub fn new(kind: VariableKind, item: Type) -> Self {
20        Self { kind, ty: item }
21    }
22
23    pub fn builtin(builtin: Builtin, ty: StorageType) -> Self {
24        Self::new(VariableKind::Builtin(builtin), Type::new(ty))
25    }
26
27    pub fn constant(value: ConstantValue, ty: impl Into<Type>) -> Self {
28        let ty = ty.into();
29        let value = value.cast_to(ty);
30        Self::new(VariableKind::Constant(value), ty)
31    }
32
33    pub fn elem_type(&self) -> ElemType {
34        self.ty.elem_type()
35    }
36
37    pub fn storage_type(&self) -> StorageType {
38        self.ty.storage_type()
39    }
40}
41
/// Numeric identifier carried by most [`VariableKind`] variants.
pub type Id = u32;
43
44#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
45#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash)]
46pub enum VariableKind {
47    GlobalInputArray(Id),
48    GlobalOutputArray(Id),
49    GlobalScalar(Id),
50    TensorMapInput(Id),
51    TensorMapOutput(Id),
52    LocalArray {
53        id: Id,
54        length: usize,
55        unroll_factor: usize,
56    },
57    LocalMut {
58        id: Id,
59    },
60    LocalConst {
61        id: Id,
62    },
63    Versioned {
64        id: Id,
65        version: u16,
66    },
67    Constant(ConstantValue),
68    ConstantArray {
69        id: Id,
70        length: usize,
71        unroll_factor: usize,
72    },
73    SharedArray {
74        id: Id,
75        length: usize,
76        unroll_factor: usize,
77        alignment: Option<usize>,
78    },
79    Shared {
80        id: Id,
81    },
82    Matrix {
83        id: Id,
84        mat: Matrix,
85    },
86    Builtin(Builtin),
87    Pipeline {
88        id: Id,
89        num_stages: u8,
90    },
91    BarrierToken {
92        id: Id,
93        level: BarrierLevel,
94    },
95}
96
97#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
98#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TypeHash, PartialOrd, Ord)]
99#[repr(u8)]
100pub enum Builtin {
101    UnitPos,
102    UnitPosX,
103    UnitPosY,
104    UnitPosZ,
105    CubePosCluster,
106    CubePosClusterX,
107    CubePosClusterY,
108    CubePosClusterZ,
109    CubePos,
110    CubePosX,
111    CubePosY,
112    CubePosZ,
113    CubeDim,
114    CubeDimX,
115    CubeDimY,
116    CubeDimZ,
117    CubeClusterDim,
118    CubeClusterDimX,
119    CubeClusterDimY,
120    CubeClusterDimZ,
121    CubeCount,
122    CubeCountX,
123    CubeCountY,
124    CubeCountZ,
125    PlaneDim,
126    PlanePos,
127    UnitPosPlane,
128    AbsolutePos,
129    AbsolutePosX,
130    AbsolutePosY,
131    AbsolutePosZ,
132}
133
134impl Variable {
135    /// Whether a variable is always immutable. Used for optimizations to determine whether it's
136    /// safe to inline/merge
137    pub fn is_immutable(&self) -> bool {
138        match self.kind {
139            VariableKind::GlobalOutputArray { .. } => false,
140            VariableKind::TensorMapInput(_) => true,
141            VariableKind::TensorMapOutput(_) => false,
142            VariableKind::LocalMut { .. } => false,
143            VariableKind::SharedArray { .. } => false,
144            VariableKind::Shared { .. } => false,
145            VariableKind::Matrix { .. } => false,
146            VariableKind::LocalArray { .. } => false,
147            VariableKind::GlobalInputArray { .. } => false,
148            VariableKind::GlobalScalar { .. } => true,
149            VariableKind::Versioned { .. } => true,
150            VariableKind::LocalConst { .. } => true,
151            VariableKind::Constant(_) => true,
152            VariableKind::ConstantArray { .. } => true,
153            VariableKind::Builtin(_) => true,
154            VariableKind::Pipeline { .. } => false,
155            VariableKind::BarrierToken { .. } => false,
156        }
157    }
158
159    /// Is this an array type that yields items when indexed,
160    /// or a scalar/vector that yields elems/slices when indexed?
161    pub fn is_array(&self) -> bool {
162        matches!(
163            self.kind,
164            VariableKind::GlobalInputArray { .. }
165                | VariableKind::GlobalOutputArray { .. }
166                | VariableKind::ConstantArray { .. }
167                | VariableKind::SharedArray { .. }
168                | VariableKind::LocalArray { .. }
169                | VariableKind::Matrix { .. }
170        )
171    }
172
173    /// Is this an array type that is contained in concrete memory,
174    /// or a local array/scalar/vector?
175    pub fn is_memory(&self) -> bool {
176        matches!(
177            self.kind,
178            VariableKind::GlobalInputArray { .. }
179                | VariableKind::GlobalOutputArray { .. }
180                | VariableKind::SharedArray { .. }
181        )
182    }
183
184    pub fn has_length(&self) -> bool {
185        matches!(
186            self.kind,
187            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
188        )
189    }
190
191    pub fn has_buffer_length(&self) -> bool {
192        matches!(
193            self.kind,
194            VariableKind::GlobalInputArray { .. } | VariableKind::GlobalOutputArray { .. }
195        )
196    }
197
198    /// Determines if the value is a constant with the specified value (converted if necessary)
199    pub fn is_constant(&self, value: i64) -> bool {
200        match self.kind {
201            VariableKind::Constant(ConstantValue::Int(val)) => val == value,
202            VariableKind::Constant(ConstantValue::UInt(val)) => val as i64 == value,
203            VariableKind::Constant(ConstantValue::Float(val)) => val == value as f64,
204            _ => false,
205        }
206    }
207
208    /// Determines if the value is a boolean constant with the `true` value
209    pub fn is_true(&self) -> bool {
210        match self.kind {
211            VariableKind::Constant(ConstantValue::Bool(val)) => val,
212            _ => false,
213        }
214    }
215
216    /// Determines if the value is a boolean constant with the `false` value
217    pub fn is_false(&self) -> bool {
218        match self.kind {
219            VariableKind::Constant(ConstantValue::Bool(val)) => !val,
220            _ => false,
221        }
222    }
223}
224
225/// The scalars are stored with the highest precision possible, but they might get reduced during
226/// compilation. For constant propagation, casts are always executed before converting back to the
227/// larger type to ensure deterministic output.
228#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
229#[derive(Debug, Clone, Copy, TypeHash, PartialEq, PartialOrd, From)]
230#[allow(missing_docs, clippy::derive_ord_xor_partial_ord)]
231pub enum ConstantValue {
232    Int(i64),
233    Float(f64),
234    UInt(u64),
235    Bool(bool),
236    Complex(f64, f64),
237}
238
239impl Ord for ConstantValue {
240    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
241        // Override float-float comparison with `FloatOrd` since `f64` isn't `Ord`. All other
242        // comparisons are safe to unwrap since they're either `Ord` or only compare discriminants.
243        match (self, other) {
244            (ConstantValue::Float(this), ConstantValue::Float(other)) => {
245                FloatOrd(*this).cmp(&FloatOrd(*other))
246            }
247            (
248                ConstantValue::Complex(this_re, this_im),
249                ConstantValue::Complex(other_re, other_im),
250            ) => FloatOrd(*this_re)
251                .cmp(&FloatOrd(*other_re))
252                .then_with(|| FloatOrd(*this_im).cmp(&FloatOrd(*other_im))),
253            _ => self.partial_cmp(other).unwrap(),
254        }
255    }
256}
257
258impl Eq for ConstantValue {}
259impl Hash for ConstantValue {
260    fn hash<H: core::hash::Hasher>(&self, ra_expand_state: &mut H) {
261        core::mem::discriminant(self).hash(ra_expand_state);
262        match self {
263            ConstantValue::Int(f0) => {
264                f0.hash(ra_expand_state);
265            }
266            ConstantValue::Float(f0) => {
267                FloatOrd(*f0).hash(ra_expand_state);
268            }
269            ConstantValue::UInt(f0) => {
270                f0.hash(ra_expand_state);
271            }
272            ConstantValue::Bool(f0) => {
273                f0.hash(ra_expand_state);
274            }
275            ConstantValue::Complex(f0, f1) => {
276                FloatOrd(*f0).hash(ra_expand_state);
277                FloatOrd(*f1).hash(ra_expand_state);
278            }
279        }
280    }
281}
282
283impl ConstantValue {
284    /// Returns the value of the constant as a usize.
285    ///
286    /// It will return [None] if the constant type is a float or a bool.
287    pub fn try_as_usize(&self) -> Option<usize> {
288        match self {
289            ConstantValue::UInt(val) => Some(*val as usize),
290            ConstantValue::Int(val) => Some(*val as usize),
291            ConstantValue::Float(_) => None,
292            ConstantValue::Bool(_) => None,
293            ConstantValue::Complex(_, _) => None,
294        }
295    }
296
297    /// Returns the value of the constant as a usize.
298    pub fn as_usize(&self) -> usize {
299        match self {
300            ConstantValue::UInt(val) => *val as usize,
301            ConstantValue::Int(val) => *val as usize,
302            ConstantValue::Float(val) => *val as usize,
303            ConstantValue::Bool(val) => *val as usize,
304            ConstantValue::Complex(_, _) => {
305                panic!("Complex constants can't be converted to usize")
306            }
307        }
308    }
309
310    /// Returns the value of the scalar as a u32.
311    ///
312    /// It will return [None] if the scalar type is a float or a bool.
313    pub fn try_as_u32(&self) -> Option<u32> {
314        self.try_as_u64().map(|it| it as u32)
315    }
316
317    /// Returns the value of the scalar as a u32.
318    ///
319    /// It will panic if the scalar type is a float or a bool.
320    pub fn as_u32(&self) -> u32 {
321        self.as_u64() as u32
322    }
323
324    /// Returns the value of the scalar as a u64.
325    ///
326    /// It will return [None] if the scalar type is a float or a bool.
327    pub fn try_as_u64(&self) -> Option<u64> {
328        match self {
329            ConstantValue::UInt(val) => Some(*val),
330            ConstantValue::Int(val) => Some(*val as u64),
331            ConstantValue::Float(_) => None,
332            ConstantValue::Bool(_) => None,
333            ConstantValue::Complex(_, _) => None,
334        }
335    }
336
337    /// Returns the value of the scalar as a u64.
338    pub fn as_u64(&self) -> u64 {
339        match self {
340            ConstantValue::UInt(val) => *val,
341            ConstantValue::Int(val) => *val as u64,
342            ConstantValue::Float(val) => *val as u64,
343            ConstantValue::Bool(val) => *val as u64,
344            ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to u64"),
345        }
346    }
347
348    /// Returns the value of the scalar as a i64.
349    ///
350    /// It will return [None] if the scalar type is a float or a bool.
351    pub fn try_as_i64(&self) -> Option<i64> {
352        match self {
353            ConstantValue::UInt(val) => Some(*val as i64),
354            ConstantValue::Int(val) => Some(*val),
355            ConstantValue::Float(_) => None,
356            ConstantValue::Bool(_) => None,
357            ConstantValue::Complex(_, _) => None,
358        }
359    }
360
361    /// Returns the value of the scalar as a i128.
362    pub fn as_i128(&self) -> i128 {
363        match self {
364            ConstantValue::UInt(val) => *val as i128,
365            ConstantValue::Int(val) => *val as i128,
366            ConstantValue::Float(val) => *val as i128,
367            ConstantValue::Bool(val) => *val as i128,
368            ConstantValue::Complex(_, _) => {
369                panic!("Complex constants can't be converted to i128")
370            }
371        }
372    }
373
374    /// Returns the value of the scalar as a i64.
375    pub fn as_i64(&self) -> i64 {
376        match self {
377            ConstantValue::UInt(val) => *val as i64,
378            ConstantValue::Int(val) => *val,
379            ConstantValue::Float(val) => *val as i64,
380            ConstantValue::Bool(val) => *val as i64,
381            ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to i64"),
382        }
383    }
384
385    /// Returns the value of the scalar as a i64.
386    pub fn as_i32(&self) -> i32 {
387        match self {
388            ConstantValue::UInt(val) => *val as i32,
389            ConstantValue::Int(val) => *val as i32,
390            ConstantValue::Float(val) => *val as i32,
391            ConstantValue::Bool(val) => *val as i32,
392            ConstantValue::Complex(_, _) => panic!("Complex constants can't be converted to i32"),
393        }
394    }
395
396    /// Returns the value of the scalar as a f64.
397    ///
398    /// It will return [None] if the scalar type is an int or a bool.
399    pub fn try_as_f64(&self) -> Option<f64> {
400        match self {
401            ConstantValue::Float(val) => Some(*val),
402            ConstantValue::Complex(re, _) => Some(*re),
403            _ => None,
404        }
405    }
406
407    /// Returns the value of the scalar as a f64.
408    pub fn as_f64(&self) -> f64 {
409        match self {
410            ConstantValue::UInt(val) => *val as f64,
411            ConstantValue::Int(val) => *val as f64,
412            ConstantValue::Float(val) => *val,
413            ConstantValue::Bool(val) => *val as u8 as f64,
414            ConstantValue::Complex(re, _) => *re,
415        }
416    }
417
418    /// Returns the value of the variable as a bool if it actually is a bool.
419    pub fn try_as_bool(&self) -> Option<bool> {
420        match self {
421            ConstantValue::Bool(val) => Some(*val),
422            _ => None,
423        }
424    }
425
426    /// Returns the value of the variable as a bool.
427    ///
428    /// It will panic if the scalar isn't a bool.
429    pub fn as_bool(&self) -> bool {
430        match self {
431            ConstantValue::UInt(val) => *val != 0,
432            ConstantValue::Int(val) => *val != 0,
433            ConstantValue::Float(val) => *val != 0.,
434            ConstantValue::Bool(val) => *val,
435            ConstantValue::Complex(_, _) => {
436                panic!("Complex constants can't be converted to bool")
437            }
438        }
439    }
440
441    pub fn is_zero(&self) -> bool {
442        match self {
443            ConstantValue::Int(val) => *val == 0,
444            ConstantValue::Float(val) => *val == 0.0,
445            ConstantValue::UInt(val) => *val == 0,
446            ConstantValue::Bool(val) => !*val,
447            ConstantValue::Complex(re, im) => *re == 0.0 && *im == 0.0,
448        }
449    }
450
451    pub fn is_one(&self) -> bool {
452        match self {
453            ConstantValue::Int(val) => *val == 1,
454            ConstantValue::Float(val) => *val == 1.0,
455            ConstantValue::UInt(val) => *val == 1,
456            ConstantValue::Bool(val) => *val,
457            ConstantValue::Complex(re, im) => *re == 1.0 && *im == 0.0,
458        }
459    }
460
461    pub fn cast_to(&self, other: impl Into<Type>) -> ConstantValue {
462        match other.into().storage_type() {
463            StorageType::Scalar(elem_type) => match elem_type {
464                ElemType::Float(kind) => match kind {
465                    FloatKind::E2M1 => e2m1::from_f64(self.as_f64()).to_f64(),
466                    FloatKind::E2M3 | FloatKind::E3M2 => {
467                        unimplemented!("FP6 constants not yet supported")
468                    }
469                    FloatKind::E4M3 => e4m3::from_f64(self.as_f64()).to_f64(),
470                    FloatKind::E5M2 => e5m2::from_f64(self.as_f64()).to_f64(),
471                    FloatKind::UE8M0 => ue8m0::from_f64(self.as_f64()).to_f64(),
472                    FloatKind::F16 => half::f16::from_f64(self.as_f64()).to_f64(),
473                    FloatKind::BF16 => half::bf16::from_f64(self.as_f64()).to_f64(),
474                    FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => {
475                        self.as_f64() as f32 as f64
476                    }
477                    FloatKind::F64 => self.as_f64(),
478                }
479                .into(),
480                ElemType::Int(kind) => {
481                    let value = match self {
482                        ConstantValue::Complex(re, _) => *re as i64,
483                        _ => self.as_i64(),
484                    };
485
486                    match kind {
487                        IntKind::I8 => value as i8 as i64,
488                        IntKind::I16 => value as i16 as i64,
489                        IntKind::I32 => value as i32 as i64,
490                        IntKind::I64 => value,
491                    }
492                }
493                .into(),
494                ElemType::UInt(kind) => {
495                    let value = match self {
496                        ConstantValue::Complex(re, _) => *re as u64,
497                        _ => self.as_u64(),
498                    };
499
500                    match kind {
501                        UIntKind::U8 => value as u8 as u64,
502                        UIntKind::U16 => value as u16 as u64,
503                        UIntKind::U32 => value as u32 as u64,
504                        UIntKind::U64 => value,
505                    }
506                }
507                .into(),
508                ElemType::Bool => self.as_bool().into(),
509                ElemType::Complex(kind) => match (self, kind) {
510                    (ConstantValue::Complex(re, im), ComplexKind::C32) => {
511                        ConstantValue::Complex(*re as f32 as f64, *im as f32 as f64)
512                    }
513                    (ConstantValue::Complex(re, im), ComplexKind::C64) => {
514                        ConstantValue::Complex(*re, *im)
515                    }
516                    (_, ComplexKind::C32) => {
517                        let re = self.as_f64() as f32 as f64;
518                        ConstantValue::Complex(re, 0.0)
519                    }
520                    (_, ComplexKind::C64) => ConstantValue::Complex(self.as_f64(), 0.0),
521                },
522            },
523            StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2) => {
524                e2m1::from_f64(self.as_f64()).to_f64().into()
525            }
526            StorageType::Packed(..) => unimplemented!("Unsupported packed type"),
527            StorageType::Atomic(_) => unimplemented!("Atomic constants aren't supported"),
528            StorageType::Opaque(_) => unimplemented!("Opaque constants aren't supported"),
529        }
530    }
531}
532
533impl Display for ConstantValue {
534    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
535        match self {
536            ConstantValue::Int(val) => write!(f, "{val}"),
537            ConstantValue::Float(val) => write!(f, "{val:?}"),
538            ConstantValue::UInt(val) => write!(f, "{val}"),
539            ConstantValue::Bool(val) => write!(f, "{val}"),
540            ConstantValue::Complex(re, im) => write!(f, "({re:?}, {im:?})"),
541        }
542    }
543}
544
545impl Variable {
546    pub fn vector_size(&self) -> usize {
547        self.ty.vector_size()
548    }
549
550    pub fn index(&self) -> Option<Id> {
551        match self.kind {
552            VariableKind::GlobalInputArray(id)
553            | VariableKind::GlobalOutputArray(id)
554            | VariableKind::TensorMapInput(id)
555            | VariableKind::TensorMapOutput(id)
556            | VariableKind::GlobalScalar(id)
557            | VariableKind::LocalMut { id, .. }
558            | VariableKind::Versioned { id, .. }
559            | VariableKind::LocalConst { id, .. }
560            | VariableKind::ConstantArray { id, .. }
561            | VariableKind::SharedArray { id, .. }
562            | VariableKind::Shared { id, .. }
563            | VariableKind::LocalArray { id, .. }
564            | VariableKind::Matrix { id, .. } => Some(id),
565            _ => None,
566        }
567    }
568
569    pub fn as_const(&self) -> Option<ConstantValue> {
570        match self.kind {
571            VariableKind::Constant(constant) => Some(constant),
572            _ => None,
573        }
574    }
575}
576
577impl Display for Variable {
578    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
579        match self.kind {
580            VariableKind::GlobalInputArray(id) => write!(f, "input({id})"),
581            VariableKind::GlobalOutputArray(id) => write!(f, "output({id})"),
582            VariableKind::GlobalScalar(id) => write!(f, "scalar<{}>({id})", self.ty),
583            VariableKind::TensorMapInput(id) => write!(f, "tensor_map({id})"),
584            VariableKind::TensorMapOutput(id) => write!(f, "tensor_map({id})"),
585            VariableKind::Constant(constant) => write!(f, "{}({constant})", self.ty),
586            VariableKind::LocalMut { id } => write!(f, "local({id})"),
587            VariableKind::Versioned { id, version } => {
588                write!(f, "local({id}).v{version}")
589            }
590            VariableKind::LocalConst { id } => write!(f, "binding({id})"),
591            VariableKind::ConstantArray { id, .. } => write!(f, "const_array({id})"),
592            VariableKind::SharedArray { id, .. } => write!(f, "shared_array({id})"),
593            VariableKind::Shared { id } => write!(f, "shared({id})"),
594            VariableKind::LocalArray { id, .. } => write!(f, "array({id})"),
595            VariableKind::Matrix { id, .. } => write!(f, "matrix({id})"),
596            VariableKind::Builtin(builtin) => write!(f, "{builtin:?}"),
597            VariableKind::Pipeline { id, .. } => write!(f, "pipeline({id})"),
598            VariableKind::BarrierToken { id, .. } => write!(f, "barrier_token({id})"),
599        }
600    }
601}
602
603// Useful with the cube_inline macro.
604impl From<&Variable> for Variable {
605    fn from(value: &Variable) -> Self {
606        *value
607    }
608}