// Source: cubecl_ir/type.rs

1use super::{ConstantValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5    e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6    quant::scheme::{QuantParam, QuantValue},
7    tf32, ue8m0,
8};
9use derive_more::From;
10use half::{bf16, f16};
11
12#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
13#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
14#[allow(missing_docs)]
15pub enum FloatKind {
16    /// FP4, 2 bit exponent, 1 bit mantissa
17    E2M1,
18    /// FP6, 2 bit exponent, 3 bit mantissa
19    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
20    E2M3,
21    /// FP6, 3 bit exponent, 2 bit mantissa
22    /// Note: represented by an 8-bit value, with the upper two bits being insignificant
23    E3M2,
24    /// FP8, 4 bit exponent, 3 bit mantissa
25    E4M3,
26    /// FP8, 5 bit exponent, 2 bit mantissa
27    E5M2,
28    /// FP8, unsigned, 8 bit exponent, 0 bit mantissa
29    UE8M0,
30    F16,
31    BF16,
32    Flex32,
33    F32,
34    TF32,
35    F64,
36}
37
38#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
39#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
40#[allow(missing_docs)]
41pub enum IntKind {
42    I8,
43    I16,
44    I32,
45    I64,
46}
47
48#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
49#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
50#[allow(missing_docs)]
51pub enum UIntKind {
52    U8,
53    U16,
54    U32,
55    U64,
56}
57
58#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
59#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
60#[allow(missing_docs)]
61pub enum ComplexKind {
62    C32,
63    C64,
64}
65
66/// Conceptual element type, not necessarily the physical type used in the code
67#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
68#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)]
69#[allow(missing_docs)]
70pub enum ElemType {
71    Float(FloatKind),
72    Int(IntKind),
73    UInt(UIntKind),
74    Complex(ComplexKind),
75    Bool,
76}
77
78#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
79#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
80pub enum OpaqueType {
81    Barrier(BarrierLevel),
82}
83
84#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
85#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
86pub enum SemanticType {
87    BarrierToken,
88    Pipeline,
89    TensorMap,
90}
91
92/// Physical type containing one or more elements
93#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
94#[derive(Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
95pub enum StorageType {
96    /// `ElemType` is the same as the physical type
97    Scalar(ElemType),
98    /// Packed values of type `ElemType`
99    Packed(ElemType, usize),
100    /// Atomically accessed version of `ElemType`
101    Atomic(ElemType),
102    /// Opaque types that can be stored but not interacted with normally. Currently only barrier,
103    /// but may be used for arrival tokens and tensor map descriptors, for example.
104    Opaque(OpaqueType),
105}
106
107impl core::fmt::Debug for StorageType {
108    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
109        // Ensure debug is not spread into multiple lines because it makes kernel ids very hard
110        // to read.
111        struct Dummy<'a>(&'a StorageType);
112
113        impl<'a> core::fmt::Debug for Dummy<'a> {
114            fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
115                match self.0 {
116                    StorageType::Scalar(f0) => f.debug_tuple("Scalar").field(&f0).finish(),
117                    StorageType::Packed(f0, f1) => {
118                        f.debug_tuple("Packed").field(&f0).field(&f1).finish()
119                    }
120                    StorageType::Atomic(f0) => f.debug_tuple("Atomic").field(&f0).finish(),
121                    StorageType::Opaque(f0) => f.debug_tuple("Opaque").field(&f0).finish(),
122                }
123            }
124        }
125
126        write!(f, "{:?}", Dummy(self))
127    }
128}
129
130impl ElemType {
131    /// Creates an elem type that correspond to the given [`QuantParam`].
132    pub fn from_quant_param(quant_param: QuantParam) -> Self {
133        match quant_param {
134            QuantParam::F32 => Self::Float(FloatKind::F32),
135            QuantParam::F16 => Self::Float(FloatKind::F16),
136            QuantParam::BF16 => Self::Float(FloatKind::BF16),
137            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
138            QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
139        }
140    }
141
142    /// Creates an elem type that correspond to the given [`QuantValue`].
143    pub fn from_quant_value(quant_value: QuantValue) -> Self {
144        match quant_value {
145            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
146            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
147            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
148            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
149            other => panic!("Unsupported quant value {other:?}"),
150        }
151    }
152
153    /// Create a constant from a constant value.
154    ///
155    /// The output will have the same type as the element.
156    pub fn constant(&self, val: ConstantValue) -> Variable {
157        Variable::constant(val, Type::scalar(*self))
158    }
159
160    /// Get the size in bytes.
161    pub const fn size(&self) -> usize {
162        match self {
163            ElemType::Float(kind) => match kind {
164                FloatKind::E2M1
165                | FloatKind::E2M3
166                | FloatKind::E3M2
167                | FloatKind::E4M3
168                | FloatKind::E5M2
169                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
170                FloatKind::F16 => core::mem::size_of::<half::f16>(),
171                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
172                FloatKind::F32 => core::mem::size_of::<f32>(),
173                FloatKind::F64 => core::mem::size_of::<f64>(),
174                FloatKind::Flex32 => core::mem::size_of::<f32>(),
175                FloatKind::TF32 => core::mem::size_of::<f32>(),
176            },
177            ElemType::Int(kind) => match kind {
178                IntKind::I8 => core::mem::size_of::<i8>(),
179                IntKind::I16 => core::mem::size_of::<i16>(),
180                IntKind::I32 => core::mem::size_of::<i32>(),
181                IntKind::I64 => core::mem::size_of::<i64>(),
182            },
183            ElemType::UInt(kind) => match kind {
184                UIntKind::U8 => core::mem::size_of::<u8>(),
185                UIntKind::U16 => core::mem::size_of::<u16>(),
186                UIntKind::U32 => core::mem::size_of::<u32>(),
187                UIntKind::U64 => core::mem::size_of::<u64>(),
188            },
189            ElemType::Complex(kind) => match kind {
190                ComplexKind::C32 => core::mem::size_of::<f32>() * 2,
191                ComplexKind::C64 => core::mem::size_of::<f64>() * 2,
192            },
193            ElemType::Bool => core::mem::size_of::<bool>(),
194        }
195    }
196
197    /// Get the size in bits.
198    pub const fn size_bits(&self) -> usize {
199        match self {
200            ElemType::Float(kind) => match kind {
201                FloatKind::E2M3
202                | FloatKind::E3M2
203                | FloatKind::E4M3
204                | FloatKind::E5M2
205                | FloatKind::UE8M0
206                | FloatKind::F16
207                | FloatKind::BF16
208                | FloatKind::F32
209                | FloatKind::F64
210                | FloatKind::Flex32
211                | FloatKind::TF32 => self.size() * 8,
212                FloatKind::E2M1 => 4,
213            },
214            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool | ElemType::Complex(_) => {
215                self.size() * 8
216            }
217        }
218    }
219
220    pub const fn min_vector_size(&self) -> u8 {
221        match self {
222            ElemType::Float(FloatKind::E2M1) => 2,
223            _ => 1,
224        }
225    }
226
227    pub fn is_int(&self) -> bool {
228        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
229    }
230
231    pub fn is_signed_int(&self) -> bool {
232        matches!(self, ElemType::Int(_))
233    }
234
235    pub fn is_unsigned_int(&self) -> bool {
236        matches!(self, ElemType::UInt(_) | ElemType::Bool)
237    }
238
239    pub fn is_float(&self) -> bool {
240        matches!(self, ElemType::Float(_))
241    }
242
243    pub fn is_bool(&self) -> bool {
244        matches!(self, ElemType::Bool)
245    }
246
247    pub fn is_complex(&self) -> bool {
248        matches!(self, ElemType::Complex(_))
249    }
250
251    pub fn as_complex(&self) -> Option<ComplexKind> {
252        match self {
253            ElemType::Complex(kind) => Some(*kind),
254            _ => None,
255        }
256    }
257
258    pub fn as_float(&self) -> Option<FloatKind> {
259        match self {
260            ElemType::Float(kind) => Some(*kind),
261            _ => None,
262        }
263    }
264
265    pub fn max_variable(&self) -> Variable {
266        let value = match self {
267            ElemType::Float(kind) => match kind {
268                FloatKind::E2M1 => e2m1::MAX,
269                FloatKind::E2M3 => e2m3::MAX,
270                FloatKind::E3M2 => e3m2::MAX,
271                FloatKind::E4M3 => e4m3::MAX,
272                FloatKind::E5M2 => e5m2::MAX,
273                FloatKind::UE8M0 => ue8m0::MAX,
274                FloatKind::F16 => half::f16::MAX.to_f64(),
275                FloatKind::BF16 => half::bf16::MAX.to_f64(),
276                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MAX as f64,
277                FloatKind::F64 => f64::MAX,
278            }
279            .into(),
280            ElemType::Int(kind) => match kind {
281                IntKind::I8 => i8::MAX as i64,
282                IntKind::I16 => i16::MAX as i64,
283                IntKind::I32 => i32::MAX as i64,
284                IntKind::I64 => i64::MAX,
285            }
286            .into(),
287            ElemType::UInt(kind) => match kind {
288                UIntKind::U8 => u8::MAX as u64,
289                UIntKind::U16 => u16::MAX as u64,
290                UIntKind::U32 => u32::MAX as u64,
291                UIntKind::U64 => u64::MAX,
292            }
293            .into(),
294            ElemType::Complex(_) => panic!("Complex numbers have no maximum"),
295            ElemType::Bool => true.into(),
296        };
297
298        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
299    }
300
301    pub fn min_variable(&self) -> Variable {
302        let value = match self {
303            ElemType::Float(kind) => match kind {
304                FloatKind::E2M1 => e2m1::MIN,
305                FloatKind::E2M3 => e2m3::MIN,
306                FloatKind::E3M2 => e3m2::MIN,
307                FloatKind::E4M3 => e4m3::MIN,
308                FloatKind::E5M2 => e5m2::MIN,
309                FloatKind::UE8M0 => ue8m0::MIN,
310                FloatKind::F16 => half::f16::MIN.to_f64(),
311                FloatKind::BF16 => half::bf16::MIN.to_f64(),
312                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MIN as f64,
313                FloatKind::F64 => f64::MIN,
314            }
315            .into(),
316            ElemType::Int(kind) => match kind {
317                IntKind::I8 => i8::MIN as i64,
318                IntKind::I16 => i16::MIN as i64,
319                IntKind::I32 => i32::MIN as i64,
320                IntKind::I64 => i64::MIN,
321            }
322            .into(),
323            ElemType::UInt(kind) => match kind {
324                UIntKind::U8 => u8::MIN as u64,
325                UIntKind::U16 => u16::MIN as u64,
326                UIntKind::U32 => u32::MIN as u64,
327                UIntKind::U64 => u64::MIN,
328            }
329            .into(),
330            ElemType::Complex(_) => panic!("Complex numbers have no minimum"),
331            ElemType::Bool => false.into(),
332        };
333
334        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
335    }
336
337    pub fn epsilon(&self) -> f64 {
338        match self {
339            ElemType::Float(kind) => match kind {
340                FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
341                FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
342                FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
343                FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
344                FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
345                FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
346                FloatKind::F16 => half::f16::EPSILON.to_f64(),
347                FloatKind::BF16 => 0.0078125, // bf16 epsilon ≈ 2^-7
348                FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
349                FloatKind::F64 => f64::EPSILON,
350            },
351            ElemType::Int(_) | ElemType::UInt(_) => 1.0,
352            ElemType::Complex(kind) => match kind {
353                ComplexKind::C32 => f32::EPSILON.into(),
354                ComplexKind::C64 => f64::EPSILON,
355            },
356            ElemType::Bool => 1.0,
357        }
358    }
359}
360
361impl OpaqueType {
362    /// Get the size in bytes.
363    pub const fn size(&self) -> usize {
364        match self {
365            OpaqueType::Barrier(_) => 8,
366        }
367    }
368
369    /// Get the size in bits.
370    pub const fn size_bits(&self) -> usize {
371        match self {
372            OpaqueType::Barrier(_) => 64,
373        }
374    }
375}
376
377impl StorageType {
378    pub fn elem_type(&self) -> ElemType {
379        match self {
380            StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
381            StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
382        }
383    }
384
385    pub fn packing_factor(&self) -> usize {
386        match self {
387            StorageType::Packed(_, factor) => *factor,
388            _ => 1,
389        }
390    }
391
392    pub fn is_atomic(&self) -> bool {
393        matches!(self, StorageType::Atomic(_))
394    }
395
396    pub fn size(&self) -> usize {
397        self.size_bits().div_ceil(8)
398    }
399
400    pub fn size_bits(&self) -> usize {
401        match self {
402            StorageType::Packed(ty, factor) => ty.size_bits() * *factor,
403            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
404            StorageType::Opaque(ty) => ty.size_bits(),
405        }
406    }
407
408    pub fn is_int(&self) -> bool {
409        self.elem_type().is_int()
410    }
411
412    pub fn is_signed_int(&self) -> bool {
413        self.elem_type().is_signed_int()
414    }
415
416    pub fn is_unsigned_int(&self) -> bool {
417        self.elem_type().is_unsigned_int()
418    }
419
420    pub fn is_float(&self) -> bool {
421        self.elem_type().is_float()
422    }
423
424    pub fn is_bool(&self) -> bool {
425        self.elem_type().is_bool()
426    }
427
428    /// Returns an empirical epsilon for this storage type, taking quantization into account.
429    pub fn epsilon(&self) -> f64 {
430        match self {
431            StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
432            StorageType::Packed(ty, factor) => {
433                // For packed types, we can conservatively scale epsilon by the number of packed elements
434                ty.epsilon() * (*factor as f64)
435            }
436            StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
437        }
438    }
439
440    pub fn constant(&self, value: ConstantValue) -> Variable {
441        Variable::constant(value, Type::new(*self))
442    }
443}
444
445macro_rules! storage_from_elem {
446    ($($ty: ty),*) => {
447        $(impl From<$ty> for StorageType {
448            fn from(value: $ty) -> Self {
449                StorageType::Scalar(value.into())
450            }
451        })*
452    };
453}
454
455storage_from_elem!(FloatKind, IntKind, UIntKind, ComplexKind, ElemType);
456
457impl From<OpaqueType> for StorageType {
458    fn from(val: OpaqueType) -> Self {
459        StorageType::Opaque(val)
460    }
461}
462
463impl<T: Into<StorageType>> From<T> for Type {
464    fn from(val: T) -> Self {
465        Type::new(val.into())
466    }
467}
468
469impl From<SemanticType> for Type {
470    fn from(val: SemanticType) -> Self {
471        Type::semantic(val)
472    }
473}
474
475#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
476#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
477pub enum Type {
478    /// Scalar type containing a single storage element
479    Scalar(StorageType),
480    /// Vector wrapping `n` storage elements
481    Vector(StorageType, VectorSize),
482    /// No defined physical representation, purely semantic. i.e. barrier, pipeline
483    Semantic(SemanticType),
484}
485
486pub type VectorSize = usize;
487
488impl Type {
489    /// Fetch the elem of the item.
490    pub fn elem_type(&self) -> ElemType {
491        self.storage_type().elem_type()
492    }
493
494    /// Create a new item
495    pub fn new(storage: StorageType) -> Self {
496        Type::Scalar(storage)
497    }
498
499    pub fn scalar(elem: ElemType) -> Self {
500        Self::new(StorageType::Scalar(elem))
501    }
502
503    pub fn semantic(ty: SemanticType) -> Self {
504        Self::Semantic(ty)
505    }
506
507    pub fn with_vector_size(self, vector_size: VectorSize) -> Type {
508        match vector_size > 1 {
509            true => Type::Vector(self.storage_type(), vector_size),
510            false => Type::Scalar(self.storage_type()),
511        }
512    }
513
514    pub fn with_storage_type(self, storage: StorageType) -> Type {
515        let vector_size = self.vector_size();
516        Type::new(storage).with_vector_size(vector_size)
517    }
518
519    pub fn vector_size(&self) -> VectorSize {
520        match self {
521            Type::Scalar(_) => 1,
522            Type::Vector(_, vector_size) => *vector_size,
523            Type::Semantic(_) => 0,
524        }
525    }
526
527    pub fn size(&self) -> usize {
528        match self {
529            Type::Scalar(ty) => ty.size(),
530            Type::Vector(ty, vector_size) => ty.size() * *vector_size,
531            Type::Semantic(_) => 0,
532        }
533    }
534
535    pub fn size_bits(&self) -> usize {
536        match self {
537            Type::Scalar(ty) => ty.size_bits(),
538            Type::Vector(ty, vector_size) => ty.size_bits() * *vector_size,
539            Type::Semantic(_) => 0,
540        }
541    }
542
543    pub fn packing_factor(&self) -> usize {
544        match self {
545            Type::Scalar(ty) => ty.packing_factor(),
546            Type::Vector(ty, _) => ty.packing_factor(),
547            Type::Semantic(_) => 1,
548        }
549    }
550
551    pub fn is_atomic(&self) -> bool {
552        !self.is_semantic() && self.storage_type().is_atomic()
553    }
554
555    pub fn is_int(&self) -> bool {
556        !self.is_semantic() && self.storage_type().is_int()
557    }
558
559    pub fn is_signed_int(&self) -> bool {
560        !self.is_semantic() && self.storage_type().is_signed_int()
561    }
562
563    pub fn is_unsigned_int(&self) -> bool {
564        !self.is_semantic() && self.storage_type().is_unsigned_int()
565    }
566
567    pub fn is_float(&self) -> bool {
568        !self.is_semantic() && self.storage_type().is_float()
569    }
570
571    pub fn is_bool(&self) -> bool {
572        !self.is_semantic() && self.storage_type().is_bool()
573    }
574
575    pub fn storage_type(&self) -> StorageType {
576        match self {
577            Type::Scalar(ty) | Type::Vector(ty, _) => *ty,
578            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
579        }
580    }
581
582    pub fn is_semantic(&self) -> bool {
583        match self {
584            Type::Scalar(_) | Type::Vector(_, _) => false,
585            Type::Semantic(_) => true,
586        }
587    }
588
589    pub fn constant(&self, value: ConstantValue) -> Variable {
590        Variable::constant(value, *self)
591    }
592}
593
594impl Display for Type {
595    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
596        match self {
597            Type::Scalar(ty) => write!(f, "{ty}"),
598            Type::Vector(ty, vector_size) => write!(f, "vector<{ty}, {vector_size}>"),
599            Type::Semantic(ty) => write!(f, "{ty}"),
600        }
601    }
602}
603
604impl Display for StorageType {
605    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
606        match self {
607            StorageType::Scalar(ty) => write!(f, "{ty}"),
608            StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
609            StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
610            StorageType::Opaque(ty) => write!(f, "{ty}"),
611        }
612    }
613}
614
615impl Display for ElemType {
616    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
617        match self {
618            Self::Float(kind) => match kind {
619                FloatKind::E2M1 => f.write_str("e2m1"),
620                FloatKind::E2M3 => f.write_str("e2m3"),
621                FloatKind::E3M2 => f.write_str("e3m2"),
622                FloatKind::E4M3 => f.write_str("e4m3"),
623                FloatKind::E5M2 => f.write_str("e5m2"),
624                FloatKind::UE8M0 => f.write_str("ue8m0"),
625                FloatKind::F16 => f.write_str("f16"),
626                FloatKind::BF16 => f.write_str("bf16"),
627                FloatKind::Flex32 => f.write_str("flex32"),
628                FloatKind::TF32 => f.write_str("tf32"),
629                FloatKind::F32 => f.write_str("f32"),
630                FloatKind::F64 => f.write_str("f64"),
631            },
632            Self::Int(kind) => match kind {
633                IntKind::I8 => f.write_str("i8"),
634                IntKind::I16 => f.write_str("i16"),
635                IntKind::I32 => f.write_str("i32"),
636                IntKind::I64 => f.write_str("i64"),
637            },
638            Self::UInt(kind) => match kind {
639                UIntKind::U8 => f.write_str("u8"),
640                UIntKind::U16 => f.write_str("u16"),
641                UIntKind::U32 => f.write_str("u32"),
642                UIntKind::U64 => f.write_str("u64"),
643            },
644            Self::Complex(kind) => match kind {
645                ComplexKind::C32 => f.write_str("c32"),
646                ComplexKind::C64 => f.write_str("c64"),
647            },
648            Self::Bool => f.write_str("bool"),
649        }
650    }
651}
652
653impl Display for SemanticType {
654    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
655        match self {
656            SemanticType::BarrierToken => f.write_str("barrier_token"),
657            SemanticType::Pipeline => f.write_str("pipeline"),
658            SemanticType::TensorMap => f.write_str("tensor_map"),
659        }
660    }
661}
662
663impl Display for OpaqueType {
664    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
665        match self {
666            OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
667        }
668    }
669}
670
671impl From<e2m1x2> for Variable {
672    fn from(_value: e2m1x2) -> Self {
673        unimplemented!("Can't currently construct e2m1x2")
674    }
675}
676
677impl From<e2m3> for Variable {
678    fn from(_value: e2m3) -> Self {
679        unimplemented!("Can't currently construct fp6")
680    }
681}
682
683impl From<e3m2> for Variable {
684    fn from(_value: e3m2) -> Self {
685        unimplemented!("Can't currently construct fp6")
686    }
687}
688
689impl From<i8> for ConstantValue {
690    fn from(value: i8) -> Self {
691        ConstantValue::Int(value as i64)
692    }
693}
694
695impl From<i16> for ConstantValue {
696    fn from(value: i16) -> Self {
697        ConstantValue::Int(value as i64)
698    }
699}
700
701impl From<i32> for ConstantValue {
702    fn from(value: i32) -> Self {
703        ConstantValue::Int(value as i64)
704    }
705}
706
707impl From<isize> for ConstantValue {
708    fn from(value: isize) -> Self {
709        ConstantValue::Int(value as i64)
710    }
711}
712
713impl From<u8> for ConstantValue {
714    fn from(value: u8) -> Self {
715        ConstantValue::UInt(value as u64)
716    }
717}
718
719impl From<u16> for ConstantValue {
720    fn from(value: u16) -> Self {
721        ConstantValue::UInt(value as u64)
722    }
723}
724
725impl From<u32> for ConstantValue {
726    fn from(value: u32) -> Self {
727        ConstantValue::UInt(value as u64)
728    }
729}
730
731impl From<usize> for ConstantValue {
732    fn from(value: usize) -> Self {
733        ConstantValue::UInt(value as u64)
734    }
735}
736
737impl From<e2m1> for ConstantValue {
738    fn from(value: e2m1) -> Self {
739        ConstantValue::Float(value.to_f64())
740    }
741}
742
743impl From<e4m3> for ConstantValue {
744    fn from(value: e4m3) -> Self {
745        ConstantValue::Float(value.to_f64())
746    }
747}
748
749impl From<e5m2> for ConstantValue {
750    fn from(value: e5m2) -> Self {
751        ConstantValue::Float(value.to_f64())
752    }
753}
754
755impl From<ue8m0> for ConstantValue {
756    fn from(value: ue8m0) -> Self {
757        ConstantValue::Float(value.to_f64())
758    }
759}
760
761impl From<half::f16> for ConstantValue {
762    fn from(value: half::f16) -> Self {
763        ConstantValue::Float(value.to_f64())
764    }
765}
766
767impl From<half::bf16> for ConstantValue {
768    fn from(value: half::bf16) -> Self {
769        ConstantValue::Float(value.to_f64())
770    }
771}
772
773impl From<flex32> for ConstantValue {
774    fn from(value: flex32) -> Self {
775        ConstantValue::Float(value.to_f64())
776    }
777}
778
779impl From<tf32> for ConstantValue {
780    fn from(value: tf32) -> Self {
781        ConstantValue::Float(value.to_f64())
782    }
783}
784
785impl From<f32> for ConstantValue {
786    fn from(value: f32) -> Self {
787        ConstantValue::Float(value as f64)
788    }
789}
790
791macro_rules! impl_into_variable {
792    ($($ty: ty => $kind: path,)*) => {
793        $(
794            impl From<$ty> for Variable {
795                fn from(value: $ty) -> Self {
796                    Variable::new(VariableKind::Constant(value.into()), $kind.into())
797                }
798            }
799        )*
800    };
801}
802
803impl_into_variable!(
804    bool => ElemType::Bool,
805
806    i8 => IntKind::I8,
807    i16 => IntKind::I16,
808    i32 => IntKind::I32,
809    i64 => IntKind::I64,
810
811    u8 => UIntKind::U8,
812    u16 => UIntKind::U16,
813    u32 => UIntKind::U32,
814    u64 => UIntKind::U64,
815
816    e2m1 => FloatKind::E2M1,
817    e4m3 => FloatKind::E4M3,
818    e5m2 => FloatKind::E5M2,
819    ue8m0 => FloatKind::UE8M0,
820    f16 => FloatKind::F16,
821    bf16 => FloatKind::BF16,
822    f32 => FloatKind::F32,
823    flex32 => FloatKind::Flex32,
824    tf32 => FloatKind::TF32,
825    f64 => FloatKind::F64,
826
827    usize => UIntKind::U32,
828    isize => IntKind::I32,
829);
830
831impl From<num_complex::Complex<f32>> for Variable {
832    fn from(value: num_complex::Complex<f32>) -> Self {
833        Variable::new(
834            VariableKind::Constant(ConstantValue::Complex(value.re as f64, value.im as f64)),
835            StorageType::Scalar(ElemType::Complex(ComplexKind::C32)).into(),
836        )
837    }
838}
839
840impl From<num_complex::Complex<f64>> for Variable {
841    fn from(value: num_complex::Complex<f64>) -> Self {
842        Variable::new(
843            VariableKind::Constant(ConstantValue::Complex(value.re, value.im)),
844            StorageType::Scalar(ElemType::Complex(ComplexKind::C64)).into(),
845        )
846    }
847}