svod_dtype/cast.rs
1use super::*;
2use enumset::EnumSet;
3
4impl ScalarDType {
5 const fn promotion_lattice(self) -> &'static [Self] {
6 use ScalarDType::*;
7 match self {
8 Bool => &[Int8, UInt8],
9 Int8 => &[Int16],
10 Int16 => &[Int32],
11 Int32 => &[Int64],
12 Int64 => &[FP8E4M3, FP8E5M2],
13 UInt8 => &[Int16, UInt16],
14 UInt16 => &[Int32, UInt32],
15 UInt32 => &[Int64, UInt64],
16 UInt64 => &[FP8E4M3, FP8E5M2],
17 FP8E5M2 => &[Float16, BFloat16],
18 FP8E4M3 => &[Float16, BFloat16],
19 Float16 => &[Float32],
20 BFloat16 => &[Float32],
21 Float32 => &[Float64],
22 Float64 | Void | Index => &[],
23 }
24 }
25
26 fn get_recursive_parents(self) -> EnumSet<Self> {
27 self.promotion_lattice()
28 .iter()
29 .fold(EnumSet::only(self), |dtypes, &parent| dtypes.union(parent.get_recursive_parents()))
30 }
31
32 /// Check if casting from `from` to `to` is safe (preserves value).
33 pub fn can_safe_cast(self, to: Self) -> bool {
34 // Same type (compare discriminants) or from Bool (Bool can cast to anything)
35 if self == to || matches!(self, Self::Bool) {
36 return true;
37 }
38
39 // Index type: can cast from any integer to Index
40 if matches!(to, Self::Index) {
41 return self.is_int();
42 }
43
44 let from_bytes = self.bytes();
45 let to_bytes = to.bytes();
46 match (self.is_unsigned(), self.is_signed(), self.is_float(), to.is_unsigned(), to.is_signed(), to.is_float()) {
47 // Unsigned -> Unsigned: only if target is larger
48 (true, _, _, true, _, _) => from_bytes < to_bytes,
49 // Signed -> Signed: only if target is same size or larger
50 (_, true, _, _, true, _) => from_bytes <= to_bytes,
51 // Unsigned -> Signed: only if target is strictly larger
52 (true, _, _, _, true, _) => from_bytes < to_bytes,
53 // Integer -> Float: safe if integer is Int32 or smaller
54 (_, _, false, _, _, true) => from_bytes <= Self::Int32.bytes(),
55 // Float -> Float: only if target is larger
56 (_, _, true, _, _, true) => from_bytes < to_bytes,
57 _ => false,
58 }
59 }
60}
61
62impl DType {
63 /// Check if casting from `from` to `to` is safe (preserves value).
64 pub fn can_safe_cast(from: Self, to: Self) -> bool {
65 // Extract scalars
66 let (Some(from_scalar), Some(to_scalar)) = (from.scalar(), to.scalar()) else {
67 return false;
68 };
69
70 // Check scalar cast is safe
71 if !from_scalar.can_safe_cast(to_scalar) {
72 return false;
73 }
74
75 // Vector counts must match (or broadcast from scalar)
76 from.count() == to.count() || from.count() == 1 || to.count() == 1
77 }
78
79 /// Find the least upper bound type for a set of dtypes.
80 ///
81 /// Returns the smallest type that all input types can be safely cast to.
82 ///
83 /// Type promotion rules:
84 /// - Scalar + Scalar → promoted Scalar
85 /// - `Ptr<T>` + `Ptr<T>` → `Ptr<T>` (same Ptr types)
86 /// - `Ptr<T>` + `Scalar(T)` → `Scalar(T)` (Ptr will be auto-loaded in codegen)
87 /// - `Ptr<T>` + `Scalar(U)` → promoted Scalar (if T and U are compatible)
88 pub fn least_upper_dtype(dtypes: &[Self]) -> Option<Self> {
89 if dtypes.is_empty() {
90 return None;
91 }
92
93 // Check for ImageDType first (they always win in promotion)
94 if let Some(img) = dtypes.iter().find(|d| matches!(d, DType::Image { .. })) {
95 return Some(img.clone());
96 }
97
98 // Check if all types are identical Ptr types
99 let first = &dtypes[0];
100 if matches!(first, DType::Ptr { .. }) && dtypes.iter().all(|d| d == first) {
101 return Some(first.clone());
102 }
103
104 // Find common scalar type via promotion lattice intersection
105 // Use base() to extract scalar from Ptr types for promotion
106 // This allows Ptr<Float32> + Float32 → Float32
107 let scalar_result = dtypes
108 .iter()
109 .map(|d| d.base())
110 .map(|s| s.get_recursive_parents())
111 .reduce(|lhs, rhs| lhs.intersection(rhs))?
112 .iter()
113 .min()?; // min by discriminant (= priority: lower = more specific)
114
115 // Svod extension: preserve vector count if all inputs have the same vcount > 1.
116 // Tinygrad's least_upper_dtype always returns scalar; we extend it to preserve
117 // vector width when all operands agree, avoiding unnecessary devectorize/revectorize.
118 let vcount = dtypes[0].vcount();
119 if vcount > 1 && dtypes.iter().all(|d| d.vcount() == vcount) {
120 Some(DType::Vector { scalar: scalar_result, count: vcount })
121 } else {
122 Some(DType::Scalar(scalar_result))
123 }
124 }
125}