1use super::{ConstantValue, Variable, VariableKind};
2use crate::{BarrierLevel, TypeHash};
3use core::fmt::Display;
4use cubecl_common::{
5 e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
6 quant::scheme::{QuantParam, QuantValue},
7 tf32, ue8m0,
8};
9use derive_more::From;
10use half::{bf16, f16};
11
/// Floating-point element kinds.
///
/// Names follow the usual `ExMy` convention (x exponent bits, y mantissa bits; a leading `U`
/// marks an unsigned format). Variant order matters because `Ord` is derived.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
    /// 4-bit float; `ElemType::size_bits` reports 4 bits for it.
    E2M1,
    /// 6-bit-class float, stored in one byte.
    E2M3,
    /// 6-bit-class float, stored in one byte.
    E3M2,
    /// 8-bit float.
    E4M3,
    /// 8-bit float.
    E5M2,
    /// Unsigned, exponent-only 8-bit format.
    UE8M0,
    /// IEEE half precision (`half::f16`).
    F16,
    /// Brain float (`half::bf16`).
    BF16,
    /// Stored like `f32`; exact semantics are defined by `cubecl_common::flex32`.
    Flex32,
    /// IEEE single precision.
    F32,
    /// Stored like `f32` (4 bytes); exact semantics are defined by `cubecl_common::tf32`.
    TF32,
    /// IEEE double precision.
    F64,
}
37
/// Signed integer element kinds, from 8 to 64 bits. Variant order matters because `Ord` is
/// derived.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
    I8,
    I16,
    I32,
    I64,
}
47
/// Unsigned integer element kinds, from 8 to 64 bits. Variant order matters because `Ord` is
/// derived.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
    U8,
    U16,
    U32,
    U64,
}
57
/// Complex element kinds.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum ComplexKind {
    /// Real and imaginary parts as `f32` (8 bytes total).
    C32,
    /// Real and imaginary parts as `f64` (16 bytes total).
    C64,
}
65
/// Scalar element type of an IR value.
///
/// `derive_more::From` provides conversions from the wrapped kind enums. Note that the
/// integer predicates (`is_int`, `is_unsigned_int`) treat `Bool` as an unsigned integer.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)]
#[allow(missing_docs)]
pub enum ElemType {
    Float(FloatKind),
    Int(IntKind),
    UInt(UIntKind),
    Complex(ComplexKind),
    Bool,
}
77
/// Storage types whose layout is opaque to the IR and which have no element type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
    /// A barrier at the given level; `OpaqueType::size` reports 8 bytes for it.
    Barrier(BarrierLevel),
}
83
/// Types that carry meaning but no storage.
///
/// `Type::storage_type` panics for them, and `Type::size` reports 0.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
    /// Displayed as `barrier_token`.
    BarrierToken,
    /// Displayed as `pipeline`.
    Pipeline,
    /// Displayed as `tensor_map`.
    TensorMap,
}
91
/// How a single (non-vector) value is stored.
///
/// `Debug` is implemented by hand rather than derived.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
    /// A single plain element.
    Scalar(ElemType),
    /// Several elements sharing one storage slot.
    ///
    /// The `usize` is the packing factor (see `StorageType::packing_factor`). The total width
    /// is the element width times that factor.
    Packed(ElemType, usize),
    /// An element displayed as `atomic<elem>` and reported by `StorageType::is_atomic`.
    Atomic(ElemType),
    /// An opaque storage type with no element type.
    Opaque(OpaqueType),
}
106
107impl core::fmt::Debug for StorageType {
108 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
109 struct Dummy<'a>(&'a StorageType);
112
113 impl<'a> core::fmt::Debug for Dummy<'a> {
114 fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
115 match self.0 {
116 StorageType::Scalar(f0) => f.debug_tuple("Scalar").field(&f0).finish(),
117 StorageType::Packed(f0, f1) => {
118 f.debug_tuple("Packed").field(&f0).field(&f1).finish()
119 }
120 StorageType::Atomic(f0) => f.debug_tuple("Atomic").field(&f0).finish(),
121 StorageType::Opaque(f0) => f.debug_tuple("Opaque").field(&f0).finish(),
122 }
123 }
124 }
125
126 write!(f, "{:?}", Dummy(self))
127 }
128}
129
130impl ElemType {
131 pub fn from_quant_param(quant_param: QuantParam) -> Self {
133 match quant_param {
134 QuantParam::F32 => Self::Float(FloatKind::F32),
135 QuantParam::F16 => Self::Float(FloatKind::F16),
136 QuantParam::BF16 => Self::Float(FloatKind::BF16),
137 QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
138 QuantParam::UE4M3 => Self::Float(FloatKind::UE8M0),
139 }
140 }
141
142 pub fn from_quant_value(quant_value: QuantValue) -> Self {
144 match quant_value {
145 QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
146 QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
147 QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
148 QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
149 other => panic!("Unsupported quant value {other:?}"),
150 }
151 }
152
153 pub fn constant(&self, val: ConstantValue) -> Variable {
157 Variable::constant(val, Type::scalar(*self))
158 }
159
160 pub const fn size(&self) -> usize {
162 match self {
163 ElemType::Float(kind) => match kind {
164 FloatKind::E2M1
165 | FloatKind::E2M3
166 | FloatKind::E3M2
167 | FloatKind::E4M3
168 | FloatKind::E5M2
169 | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
170 FloatKind::F16 => core::mem::size_of::<half::f16>(),
171 FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
172 FloatKind::F32 => core::mem::size_of::<f32>(),
173 FloatKind::F64 => core::mem::size_of::<f64>(),
174 FloatKind::Flex32 => core::mem::size_of::<f32>(),
175 FloatKind::TF32 => core::mem::size_of::<f32>(),
176 },
177 ElemType::Int(kind) => match kind {
178 IntKind::I8 => core::mem::size_of::<i8>(),
179 IntKind::I16 => core::mem::size_of::<i16>(),
180 IntKind::I32 => core::mem::size_of::<i32>(),
181 IntKind::I64 => core::mem::size_of::<i64>(),
182 },
183 ElemType::UInt(kind) => match kind {
184 UIntKind::U8 => core::mem::size_of::<u8>(),
185 UIntKind::U16 => core::mem::size_of::<u16>(),
186 UIntKind::U32 => core::mem::size_of::<u32>(),
187 UIntKind::U64 => core::mem::size_of::<u64>(),
188 },
189 ElemType::Complex(kind) => match kind {
190 ComplexKind::C32 => core::mem::size_of::<f32>() * 2,
191 ComplexKind::C64 => core::mem::size_of::<f64>() * 2,
192 },
193 ElemType::Bool => core::mem::size_of::<bool>(),
194 }
195 }
196
197 pub const fn size_bits(&self) -> usize {
199 match self {
200 ElemType::Float(kind) => match kind {
201 FloatKind::E2M3
202 | FloatKind::E3M2
203 | FloatKind::E4M3
204 | FloatKind::E5M2
205 | FloatKind::UE8M0
206 | FloatKind::F16
207 | FloatKind::BF16
208 | FloatKind::F32
209 | FloatKind::F64
210 | FloatKind::Flex32
211 | FloatKind::TF32 => self.size() * 8,
212 FloatKind::E2M1 => 4,
213 },
214 ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool | ElemType::Complex(_) => {
215 self.size() * 8
216 }
217 }
218 }
219
220 pub const fn min_vector_size(&self) -> u8 {
221 match self {
222 ElemType::Float(FloatKind::E2M1) => 2,
223 _ => 1,
224 }
225 }
226
227 pub fn is_int(&self) -> bool {
228 matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
229 }
230
231 pub fn is_signed_int(&self) -> bool {
232 matches!(self, ElemType::Int(_))
233 }
234
235 pub fn is_unsigned_int(&self) -> bool {
236 matches!(self, ElemType::UInt(_) | ElemType::Bool)
237 }
238
239 pub fn is_float(&self) -> bool {
240 matches!(self, ElemType::Float(_))
241 }
242
243 pub fn is_bool(&self) -> bool {
244 matches!(self, ElemType::Bool)
245 }
246
247 pub fn is_complex(&self) -> bool {
248 matches!(self, ElemType::Complex(_))
249 }
250
251 pub fn as_complex(&self) -> Option<ComplexKind> {
252 match self {
253 ElemType::Complex(kind) => Some(*kind),
254 _ => None,
255 }
256 }
257
258 pub fn as_float(&self) -> Option<FloatKind> {
259 match self {
260 ElemType::Float(kind) => Some(*kind),
261 _ => None,
262 }
263 }
264
265 pub fn max_variable(&self) -> Variable {
266 let value = match self {
267 ElemType::Float(kind) => match kind {
268 FloatKind::E2M1 => e2m1::MAX,
269 FloatKind::E2M3 => e2m3::MAX,
270 FloatKind::E3M2 => e3m2::MAX,
271 FloatKind::E4M3 => e4m3::MAX,
272 FloatKind::E5M2 => e5m2::MAX,
273 FloatKind::UE8M0 => ue8m0::MAX,
274 FloatKind::F16 => half::f16::MAX.to_f64(),
275 FloatKind::BF16 => half::bf16::MAX.to_f64(),
276 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MAX as f64,
277 FloatKind::F64 => f64::MAX,
278 }
279 .into(),
280 ElemType::Int(kind) => match kind {
281 IntKind::I8 => i8::MAX as i64,
282 IntKind::I16 => i16::MAX as i64,
283 IntKind::I32 => i32::MAX as i64,
284 IntKind::I64 => i64::MAX,
285 }
286 .into(),
287 ElemType::UInt(kind) => match kind {
288 UIntKind::U8 => u8::MAX as u64,
289 UIntKind::U16 => u16::MAX as u64,
290 UIntKind::U32 => u32::MAX as u64,
291 UIntKind::U64 => u64::MAX,
292 }
293 .into(),
294 ElemType::Complex(_) => panic!("Complex numbers have no maximum"),
295 ElemType::Bool => true.into(),
296 };
297
298 Variable::new(VariableKind::Constant(value), Type::scalar(*self))
299 }
300
301 pub fn min_variable(&self) -> Variable {
302 let value = match self {
303 ElemType::Float(kind) => match kind {
304 FloatKind::E2M1 => e2m1::MIN,
305 FloatKind::E2M3 => e2m3::MIN,
306 FloatKind::E3M2 => e3m2::MIN,
307 FloatKind::E4M3 => e4m3::MIN,
308 FloatKind::E5M2 => e5m2::MIN,
309 FloatKind::UE8M0 => ue8m0::MIN,
310 FloatKind::F16 => half::f16::MIN.to_f64(),
311 FloatKind::BF16 => half::bf16::MIN.to_f64(),
312 FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MIN as f64,
313 FloatKind::F64 => f64::MIN,
314 }
315 .into(),
316 ElemType::Int(kind) => match kind {
317 IntKind::I8 => i8::MIN as i64,
318 IntKind::I16 => i16::MIN as i64,
319 IntKind::I32 => i32::MIN as i64,
320 IntKind::I64 => i64::MIN,
321 }
322 .into(),
323 ElemType::UInt(kind) => match kind {
324 UIntKind::U8 => u8::MIN as u64,
325 UIntKind::U16 => u16::MIN as u64,
326 UIntKind::U32 => u32::MIN as u64,
327 UIntKind::U64 => u64::MIN,
328 }
329 .into(),
330 ElemType::Complex(_) => panic!("Complex numbers have no minimum"),
331 ElemType::Bool => false.into(),
332 };
333
334 Variable::new(VariableKind::Constant(value), Type::scalar(*self))
335 }
336
337 pub fn epsilon(&self) -> f64 {
338 match self {
339 ElemType::Float(kind) => match kind {
340 FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
341 FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
342 FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
343 FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
344 FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
345 FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
346 FloatKind::F16 => half::f16::EPSILON.to_f64(),
347 FloatKind::BF16 => 0.0078125, FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
349 FloatKind::F64 => f64::EPSILON,
350 },
351 ElemType::Int(_) | ElemType::UInt(_) => 1.0,
352 ElemType::Complex(kind) => match kind {
353 ComplexKind::C32 => f32::EPSILON.into(),
354 ComplexKind::C64 => f64::EPSILON,
355 },
356 ElemType::Bool => 1.0,
357 }
358 }
359}
360
361impl OpaqueType {
362 pub const fn size(&self) -> usize {
364 match self {
365 OpaqueType::Barrier(_) => 8,
366 }
367 }
368
369 pub const fn size_bits(&self) -> usize {
371 match self {
372 OpaqueType::Barrier(_) => 64,
373 }
374 }
375}
376
377impl StorageType {
378 pub fn elem_type(&self) -> ElemType {
379 match self {
380 StorageType::Scalar(ty) | StorageType::Packed(ty, _) | StorageType::Atomic(ty) => *ty,
381 StorageType::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
382 }
383 }
384
385 pub fn packing_factor(&self) -> usize {
386 match self {
387 StorageType::Packed(_, factor) => *factor,
388 _ => 1,
389 }
390 }
391
392 pub fn is_atomic(&self) -> bool {
393 matches!(self, StorageType::Atomic(_))
394 }
395
396 pub fn size(&self) -> usize {
397 self.size_bits().div_ceil(8)
398 }
399
400 pub fn size_bits(&self) -> usize {
401 match self {
402 StorageType::Packed(ty, factor) => ty.size_bits() * *factor,
403 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.size_bits(),
404 StorageType::Opaque(ty) => ty.size_bits(),
405 }
406 }
407
408 pub fn is_int(&self) -> bool {
409 self.elem_type().is_int()
410 }
411
412 pub fn is_signed_int(&self) -> bool {
413 self.elem_type().is_signed_int()
414 }
415
416 pub fn is_unsigned_int(&self) -> bool {
417 self.elem_type().is_unsigned_int()
418 }
419
420 pub fn is_float(&self) -> bool {
421 self.elem_type().is_float()
422 }
423
424 pub fn is_bool(&self) -> bool {
425 self.elem_type().is_bool()
426 }
427
428 pub fn epsilon(&self) -> f64 {
430 match self {
431 StorageType::Scalar(ty) | StorageType::Atomic(ty) => ty.epsilon(),
432 StorageType::Packed(ty, factor) => {
433 ty.epsilon() * (*factor as f64)
435 }
436 StorageType::Opaque(_) => panic!("Opaque type does not have an epsilon"),
437 }
438 }
439
440 pub fn constant(&self, value: ConstantValue) -> Variable {
441 Variable::constant(value, Type::new(*self))
442 }
443}
444
445macro_rules! storage_from_elem {
446 ($($ty: ty),*) => {
447 $(impl From<$ty> for StorageType {
448 fn from(value: $ty) -> Self {
449 StorageType::Scalar(value.into())
450 }
451 })*
452 };
453}
454
455storage_from_elem!(FloatKind, IntKind, UIntKind, ComplexKind, ElemType);
456
457impl From<OpaqueType> for StorageType {
458 fn from(val: OpaqueType) -> Self {
459 StorageType::Opaque(val)
460 }
461}
462
463impl<T: Into<StorageType>> From<T> for Type {
464 fn from(val: T) -> Self {
465 Type::new(val.into())
466 }
467}
468
469impl From<SemanticType> for Type {
470 fn from(val: SemanticType) -> Self {
471 Type::semantic(val)
472 }
473}
474
/// Full type of an IR value: a scalar, a vector of storage elements, or a semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
    /// A single value of the given storage.
    Scalar(StorageType),
    /// Several values of the given storage.
    ///
    /// `Type::with_vector_size` only builds this variant for sizes greater than 1.
    Vector(StorageType, VectorSize),
    /// A type without storage. Its size and vector size are reported as 0.
    Semantic(SemanticType),
}

/// Number of elements in a [`Type::Vector`].
pub type VectorSize = usize;
487
488impl Type {
489 pub fn elem_type(&self) -> ElemType {
491 self.storage_type().elem_type()
492 }
493
494 pub fn new(storage: StorageType) -> Self {
496 Type::Scalar(storage)
497 }
498
499 pub fn scalar(elem: ElemType) -> Self {
500 Self::new(StorageType::Scalar(elem))
501 }
502
503 pub fn semantic(ty: SemanticType) -> Self {
504 Self::Semantic(ty)
505 }
506
507 pub fn with_vector_size(self, vector_size: VectorSize) -> Type {
508 match vector_size > 1 {
509 true => Type::Vector(self.storage_type(), vector_size),
510 false => Type::Scalar(self.storage_type()),
511 }
512 }
513
514 pub fn with_storage_type(self, storage: StorageType) -> Type {
515 let vector_size = self.vector_size();
516 Type::new(storage).with_vector_size(vector_size)
517 }
518
519 pub fn vector_size(&self) -> VectorSize {
520 match self {
521 Type::Scalar(_) => 1,
522 Type::Vector(_, vector_size) => *vector_size,
523 Type::Semantic(_) => 0,
524 }
525 }
526
527 pub fn size(&self) -> usize {
528 match self {
529 Type::Scalar(ty) => ty.size(),
530 Type::Vector(ty, vector_size) => ty.size() * *vector_size,
531 Type::Semantic(_) => 0,
532 }
533 }
534
535 pub fn size_bits(&self) -> usize {
536 match self {
537 Type::Scalar(ty) => ty.size_bits(),
538 Type::Vector(ty, vector_size) => ty.size_bits() * *vector_size,
539 Type::Semantic(_) => 0,
540 }
541 }
542
543 pub fn packing_factor(&self) -> usize {
544 match self {
545 Type::Scalar(ty) => ty.packing_factor(),
546 Type::Vector(ty, _) => ty.packing_factor(),
547 Type::Semantic(_) => 1,
548 }
549 }
550
551 pub fn is_atomic(&self) -> bool {
552 !self.is_semantic() && self.storage_type().is_atomic()
553 }
554
555 pub fn is_int(&self) -> bool {
556 !self.is_semantic() && self.storage_type().is_int()
557 }
558
559 pub fn is_signed_int(&self) -> bool {
560 !self.is_semantic() && self.storage_type().is_signed_int()
561 }
562
563 pub fn is_unsigned_int(&self) -> bool {
564 !self.is_semantic() && self.storage_type().is_unsigned_int()
565 }
566
567 pub fn is_float(&self) -> bool {
568 !self.is_semantic() && self.storage_type().is_float()
569 }
570
571 pub fn is_bool(&self) -> bool {
572 !self.is_semantic() && self.storage_type().is_bool()
573 }
574
575 pub fn storage_type(&self) -> StorageType {
576 match self {
577 Type::Scalar(ty) | Type::Vector(ty, _) => *ty,
578 Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
579 }
580 }
581
582 pub fn is_semantic(&self) -> bool {
583 match self {
584 Type::Scalar(_) | Type::Vector(_, _) => false,
585 Type::Semantic(_) => true,
586 }
587 }
588
589 pub fn constant(&self, value: ConstantValue) -> Variable {
590 Variable::constant(value, *self)
591 }
592}
593
594impl Display for Type {
595 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
596 match self {
597 Type::Scalar(ty) => write!(f, "{ty}"),
598 Type::Vector(ty, vector_size) => write!(f, "vector<{ty}, {vector_size}>"),
599 Type::Semantic(ty) => write!(f, "{ty}"),
600 }
601 }
602}
603
604impl Display for StorageType {
605 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
606 match self {
607 StorageType::Scalar(ty) => write!(f, "{ty}"),
608 StorageType::Packed(ty, factor) => write!(f, "packed<{ty}, {factor}>"),
609 StorageType::Atomic(ty) => write!(f, "atomic<{ty}>"),
610 StorageType::Opaque(ty) => write!(f, "{ty}"),
611 }
612 }
613}
614
615impl Display for ElemType {
616 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
617 match self {
618 Self::Float(kind) => match kind {
619 FloatKind::E2M1 => f.write_str("e2m1"),
620 FloatKind::E2M3 => f.write_str("e2m3"),
621 FloatKind::E3M2 => f.write_str("e3m2"),
622 FloatKind::E4M3 => f.write_str("e4m3"),
623 FloatKind::E5M2 => f.write_str("e5m2"),
624 FloatKind::UE8M0 => f.write_str("ue8m0"),
625 FloatKind::F16 => f.write_str("f16"),
626 FloatKind::BF16 => f.write_str("bf16"),
627 FloatKind::Flex32 => f.write_str("flex32"),
628 FloatKind::TF32 => f.write_str("tf32"),
629 FloatKind::F32 => f.write_str("f32"),
630 FloatKind::F64 => f.write_str("f64"),
631 },
632 Self::Int(kind) => match kind {
633 IntKind::I8 => f.write_str("i8"),
634 IntKind::I16 => f.write_str("i16"),
635 IntKind::I32 => f.write_str("i32"),
636 IntKind::I64 => f.write_str("i64"),
637 },
638 Self::UInt(kind) => match kind {
639 UIntKind::U8 => f.write_str("u8"),
640 UIntKind::U16 => f.write_str("u16"),
641 UIntKind::U32 => f.write_str("u32"),
642 UIntKind::U64 => f.write_str("u64"),
643 },
644 Self::Complex(kind) => match kind {
645 ComplexKind::C32 => f.write_str("c32"),
646 ComplexKind::C64 => f.write_str("c64"),
647 },
648 Self::Bool => f.write_str("bool"),
649 }
650 }
651}
652
653impl Display for SemanticType {
654 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
655 match self {
656 SemanticType::BarrierToken => f.write_str("barrier_token"),
657 SemanticType::Pipeline => f.write_str("pipeline"),
658 SemanticType::TensorMap => f.write_str("tensor_map"),
659 }
660 }
661}
662
663impl Display for OpaqueType {
664 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
665 match self {
666 OpaqueType::Barrier(level) => write!(f, "barrier<{level}>"),
667 }
668 }
669}
670
671impl From<e2m1x2> for Variable {
672 fn from(_value: e2m1x2) -> Self {
673 unimplemented!("Can't currently construct e2m1x2")
674 }
675}
676
677impl From<e2m3> for Variable {
678 fn from(_value: e2m3) -> Self {
679 unimplemented!("Can't currently construct fp6")
680 }
681}
682
683impl From<e3m2> for Variable {
684 fn from(_value: e3m2) -> Self {
685 unimplemented!("Can't currently construct fp6")
686 }
687}
688
689impl From<i8> for ConstantValue {
690 fn from(value: i8) -> Self {
691 ConstantValue::Int(value as i64)
692 }
693}
694
695impl From<i16> for ConstantValue {
696 fn from(value: i16) -> Self {
697 ConstantValue::Int(value as i64)
698 }
699}
700
701impl From<i32> for ConstantValue {
702 fn from(value: i32) -> Self {
703 ConstantValue::Int(value as i64)
704 }
705}
706
707impl From<isize> for ConstantValue {
708 fn from(value: isize) -> Self {
709 ConstantValue::Int(value as i64)
710 }
711}
712
713impl From<u8> for ConstantValue {
714 fn from(value: u8) -> Self {
715 ConstantValue::UInt(value as u64)
716 }
717}
718
719impl From<u16> for ConstantValue {
720 fn from(value: u16) -> Self {
721 ConstantValue::UInt(value as u64)
722 }
723}
724
725impl From<u32> for ConstantValue {
726 fn from(value: u32) -> Self {
727 ConstantValue::UInt(value as u64)
728 }
729}
730
731impl From<usize> for ConstantValue {
732 fn from(value: usize) -> Self {
733 ConstantValue::UInt(value as u64)
734 }
735}
736
737impl From<e2m1> for ConstantValue {
738 fn from(value: e2m1) -> Self {
739 ConstantValue::Float(value.to_f64())
740 }
741}
742
743impl From<e4m3> for ConstantValue {
744 fn from(value: e4m3) -> Self {
745 ConstantValue::Float(value.to_f64())
746 }
747}
748
749impl From<e5m2> for ConstantValue {
750 fn from(value: e5m2) -> Self {
751 ConstantValue::Float(value.to_f64())
752 }
753}
754
755impl From<ue8m0> for ConstantValue {
756 fn from(value: ue8m0) -> Self {
757 ConstantValue::Float(value.to_f64())
758 }
759}
760
761impl From<half::f16> for ConstantValue {
762 fn from(value: half::f16) -> Self {
763 ConstantValue::Float(value.to_f64())
764 }
765}
766
767impl From<half::bf16> for ConstantValue {
768 fn from(value: half::bf16) -> Self {
769 ConstantValue::Float(value.to_f64())
770 }
771}
772
773impl From<flex32> for ConstantValue {
774 fn from(value: flex32) -> Self {
775 ConstantValue::Float(value.to_f64())
776 }
777}
778
779impl From<tf32> for ConstantValue {
780 fn from(value: tf32) -> Self {
781 ConstantValue::Float(value.to_f64())
782 }
783}
784
785impl From<f32> for ConstantValue {
786 fn from(value: f32) -> Self {
787 ConstantValue::Float(value as f64)
788 }
789}
790
791macro_rules! impl_into_variable {
792 ($($ty: ty => $kind: path,)*) => {
793 $(
794 impl From<$ty> for Variable {
795 fn from(value: $ty) -> Self {
796 Variable::new(VariableKind::Constant(value.into()), $kind.into())
797 }
798 }
799 )*
800 };
801}
802
803impl_into_variable!(
804 bool => ElemType::Bool,
805
806 i8 => IntKind::I8,
807 i16 => IntKind::I16,
808 i32 => IntKind::I32,
809 i64 => IntKind::I64,
810
811 u8 => UIntKind::U8,
812 u16 => UIntKind::U16,
813 u32 => UIntKind::U32,
814 u64 => UIntKind::U64,
815
816 e2m1 => FloatKind::E2M1,
817 e4m3 => FloatKind::E4M3,
818 e5m2 => FloatKind::E5M2,
819 ue8m0 => FloatKind::UE8M0,
820 f16 => FloatKind::F16,
821 bf16 => FloatKind::BF16,
822 f32 => FloatKind::F32,
823 flex32 => FloatKind::Flex32,
824 tf32 => FloatKind::TF32,
825 f64 => FloatKind::F64,
826
827 usize => UIntKind::U32,
828 isize => IntKind::I32,
829);
830
831impl From<num_complex::Complex<f32>> for Variable {
832 fn from(value: num_complex::Complex<f32>) -> Self {
833 Variable::new(
834 VariableKind::Constant(ConstantValue::Complex(value.re as f64, value.im as f64)),
835 StorageType::Scalar(ElemType::Complex(ComplexKind::C32)).into(),
836 )
837 }
838}
839
840impl From<num_complex::Complex<f64>> for Variable {
841 fn from(value: num_complex::Complex<f64>) -> Self {
842 Variable::new(
843 VariableKind::Constant(ConstantValue::Complex(value.re, value.im)),
844 StorageType::Scalar(ElemType::Complex(ComplexKind::C64)).into(),
845 )
846 }
847}