Skip to main content

svod_dtype/
cast.rs

1use super::*;
2use enumset::EnumSet;
3
4impl ScalarDType {
5    const fn promotion_lattice(self) -> &'static [Self] {
6        use ScalarDType::*;
7        match self {
8            Bool => &[Int8, UInt8],
9            Int8 => &[Int16],
10            Int16 => &[Int32],
11            Int32 => &[Int64],
12            Int64 => &[FP8E4M3, FP8E5M2],
13            UInt8 => &[Int16, UInt16],
14            UInt16 => &[Int32, UInt32],
15            UInt32 => &[Int64, UInt64],
16            UInt64 => &[FP8E4M3, FP8E5M2],
17            FP8E5M2 => &[Float16, BFloat16],
18            FP8E4M3 => &[Float16, BFloat16],
19            Float16 => &[Float32],
20            BFloat16 => &[Float32],
21            Float32 => &[Float64],
22            Float64 | Void | Index => &[],
23        }
24    }
25
26    fn get_recursive_parents(self) -> EnumSet<Self> {
27        self.promotion_lattice()
28            .iter()
29            .fold(EnumSet::only(self), |dtypes, &parent| dtypes.union(parent.get_recursive_parents()))
30    }
31
32    /// Check if casting from `from` to `to` is safe (preserves value).
33    pub fn can_safe_cast(self, to: Self) -> bool {
34        // Same type (compare discriminants) or from Bool (Bool can cast to anything)
35        if self == to || matches!(self, Self::Bool) {
36            return true;
37        }
38
39        // Index type: can cast from any integer to Index
40        if matches!(to, Self::Index) {
41            return self.is_int();
42        }
43
44        let from_bytes = self.bytes();
45        let to_bytes = to.bytes();
46        match (self.is_unsigned(), self.is_signed(), self.is_float(), to.is_unsigned(), to.is_signed(), to.is_float()) {
47            // Unsigned -> Unsigned: only if target is larger
48            (true, _, _, true, _, _) => from_bytes < to_bytes,
49            // Signed -> Signed: only if target is same size or larger
50            (_, true, _, _, true, _) => from_bytes <= to_bytes,
51            // Unsigned -> Signed: only if target is strictly larger
52            (true, _, _, _, true, _) => from_bytes < to_bytes,
53            // Integer -> Float: safe if integer is Int32 or smaller
54            (_, _, false, _, _, true) => from_bytes <= Self::Int32.bytes(),
55            // Float -> Float: only if target is larger
56            (_, _, true, _, _, true) => from_bytes < to_bytes,
57            _ => false,
58        }
59    }
60}
61
62impl DType {
63    /// Check if casting from `from` to `to` is safe (preserves value).
64    pub fn can_safe_cast(from: Self, to: Self) -> bool {
65        // Extract scalars
66        let (Some(from_scalar), Some(to_scalar)) = (from.scalar(), to.scalar()) else {
67            return false;
68        };
69
70        // Check scalar cast is safe
71        if !from_scalar.can_safe_cast(to_scalar) {
72            return false;
73        }
74
75        // Vector counts must match (or broadcast from scalar)
76        from.count() == to.count() || from.count() == 1 || to.count() == 1
77    }
78
79    /// Find the least upper bound type for a set of dtypes.
80    ///
81    /// Returns the smallest type that all input types can be safely cast to.
82    ///
83    /// Type promotion rules:
84    /// - Scalar + Scalar → promoted Scalar
85    /// - `Ptr<T>` + `Ptr<T>` → `Ptr<T>` (same Ptr types)
86    /// - `Ptr<T>` + `Scalar(T)` → `Scalar(T)` (Ptr will be auto-loaded in codegen)
87    /// - `Ptr<T>` + `Scalar(U)` → promoted Scalar (if T and U are compatible)
88    pub fn least_upper_dtype(dtypes: &[Self]) -> Option<Self> {
89        if dtypes.is_empty() {
90            return None;
91        }
92
93        // Check for ImageDType first (they always win in promotion)
94        if let Some(img) = dtypes.iter().find(|d| matches!(d, DType::Image { .. })) {
95            return Some(img.clone());
96        }
97
98        // Check if all types are identical Ptr types
99        let first = &dtypes[0];
100        if matches!(first, DType::Ptr { .. }) && dtypes.iter().all(|d| d == first) {
101            return Some(first.clone());
102        }
103
104        // Find common scalar type via promotion lattice intersection
105        // Use base() to extract scalar from Ptr types for promotion
106        // This allows Ptr<Float32> + Float32 → Float32
107        let scalar_result = dtypes
108            .iter()
109            .map(|d| d.base())
110            .map(|s| s.get_recursive_parents())
111            .reduce(|lhs, rhs| lhs.intersection(rhs))?
112            .iter()
113            .min()?; // min by discriminant (= priority: lower = more specific)
114
115        // Svod extension: preserve vector count if all inputs have the same vcount > 1.
116        // Tinygrad's least_upper_dtype always returns scalar; we extend it to preserve
117        // vector width when all operands agree, avoiding unnecessary devectorize/revectorize.
118        let vcount = dtypes[0].vcount();
119        if vcount > 1 && dtypes.iter().all(|d| d.vcount() == vcount) {
120            Some(DType::Vector { scalar: scalar_result, count: vcount })
121        } else {
122            Some(DType::Scalar(scalar_result))
123        }
124    }
125}