Skip to main content

scivex_core/
promote.rs

1//! Type promotion rules and casting utilities for tensor element types.
2//!
//! Provides NumPy-inspired type promotion via [`promote()`] (see its documentation
//! for the exact rules), compile-time type identification via [`DTypeOf`], and
//! element-wise casting via [`CastFrom`].
5
6use crate::dtype::Scalar;
7use crate::tensor::Tensor;
8
9// ---------------------------------------------------------------------------
10// DType — runtime type tag
11// ---------------------------------------------------------------------------
12
/// Runtime tag identifying the element type stored in a tensor.
///
/// There is one variant per supported primitive: four unsigned integers,
/// four signed integers and two floating-point types.
///
/// # Examples
///
/// ```
/// use scivex_core::promote::DType;
/// assert!(DType::F64.is_float());
/// assert!(!DType::I32.is_float());
/// assert_eq!(DType::F32.size_bytes(), 4);
/// ```
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum DType {
    /// Unsigned 8-bit integer (`u8`).
    U8,
    /// Unsigned 16-bit integer (`u16`).
    U16,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Signed 8-bit integer (`i8`).
    I8,
    /// Signed 16-bit integer (`i16`).
    I16,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// 32-bit floating point (`f32`).
    F32,
    /// 64-bit floating point (`f64`).
    F64,
}
36
37impl DType {
38    /// Size in bytes of a single element of this type.
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use scivex_core::promote::DType;
44    /// assert_eq!(DType::U8.size_bytes(), 1);
45    /// assert_eq!(DType::F64.size_bytes(), 8);
46    /// ```
47    #[inline]
48    pub fn size_bytes(self) -> usize {
49        match self {
50            DType::U8 | DType::I8 => 1,
51            DType::U16 | DType::I16 => 2,
52            DType::U32 | DType::I32 | DType::F32 => 4,
53            DType::U64 | DType::I64 | DType::F64 => 8,
54        }
55    }
56
57    /// Returns `true` if this is a floating-point type.
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use scivex_core::promote::DType;
63    /// assert!(DType::F32.is_float());
64    /// assert!(!DType::I64.is_float());
65    /// ```
66    #[inline]
67    pub fn is_float(self) -> bool {
68        matches!(self, DType::F32 | DType::F64)
69    }
70
71    /// Returns `true` if this is a signed type (signed integers or floats).
72    ///
73    /// # Examples
74    ///
75    /// ```
76    /// use scivex_core::promote::DType;
77    /// assert!(DType::I32.is_signed());
78    /// assert!(DType::F64.is_signed());
79    /// assert!(!DType::U8.is_signed());
80    /// ```
81    #[inline]
82    pub fn is_signed(self) -> bool {
83        matches!(
84            self,
85            DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::F32 | DType::F64
86        )
87    }
88
89    /// Returns `true` if this is an integer type (signed or unsigned).
90    ///
91    /// # Examples
92    ///
93    /// ```
94    /// use scivex_core::promote::DType;
95    /// assert!(DType::U32.is_integer());
96    /// assert!(!DType::F32.is_integer());
97    /// ```
98    #[inline]
99    pub fn is_integer(self) -> bool {
100        !self.is_float()
101    }
102}
103
104// ---------------------------------------------------------------------------
105// promote — numpy-style type promotion
106// ---------------------------------------------------------------------------
107
108/// Determine the result type when combining two dtypes (numpy-style promotion).
109///
110/// Rules:
111/// - Same type -> same type
112/// - Integer + Integer -> wider integer (preserving signedness if possible)
113/// - Signed + Unsigned -> signed type wide enough for both
114/// - Any Integer + Any Float -> the float type
115/// - F32 + F64 -> F64
116///
117/// # Examples
118///
119/// ```
120/// use scivex_core::promote::{DType, promote};
121/// assert_eq!(promote(DType::I32, DType::F64), DType::F64);
122/// assert_eq!(promote(DType::F32, DType::F64), DType::F64);
123/// assert_eq!(promote(DType::I8, DType::I32), DType::I32);
124/// ```
125pub fn promote(a: DType, b: DType) -> DType {
126    if a == b {
127        return a;
128    }
129
130    // If either is float, the result is float.
131    match (a.is_float(), b.is_float()) {
132        (true, true) => {
133            // F32 + F64 -> F64
134            if a == DType::F64 || b == DType::F64 {
135                return DType::F64;
136            }
137            return DType::F32;
138        }
139        (true, false) => return promote_int_float(b, a),
140        (false, true) => return promote_int_float(a, b),
141        (false, false) => {}
142    }
143
144    // Both are integers.
145    let a_signed = a.is_signed();
146    let b_signed = b.is_signed();
147    let a_bytes = a.size_bytes();
148    let b_bytes = b.size_bytes();
149
150    match (a_signed, b_signed) {
151        // Both same signedness: pick the wider one.
152        (true, true) | (false, false) => {
153            if a_bytes >= b_bytes {
154                a
155            } else {
156                b
157            }
158        }
159        // Mixed signedness: need a signed type wide enough for both.
160        _ => {
161            let (signed_dt, unsigned_dt) = if a_signed { (a, b) } else { (b, a) };
162            let s_bytes = signed_dt.size_bytes();
163            let u_bytes = unsigned_dt.size_bytes();
164            // The signed type must be strictly wider than the unsigned type
165            // to represent all values of both.
166            if s_bytes > u_bytes {
167                // Signed type is already wide enough.
168                signed_dt
169            } else {
170                // Need to widen: pick the next signed type larger than the unsigned type.
171                match u_bytes {
172                    1 => DType::I16,
173                    2 => DType::I32,
174                    4 => DType::I64,
175                    // u64 cannot be fully represented by i64; promote to f64.
176                    _ => DType::F64,
177                }
178            }
179        }
180    }
181}
182
183/// Promote an integer dtype combined with a float dtype.
184fn promote_int_float(_int_dt: DType, float_dt: DType) -> DType {
185    // Any integer + any float -> the float type.
186    // If the integer is very wide (64-bit) and float is F32, we still return
187    // the float type (matching numpy behavior).
188    float_dt
189}
190
191// ---------------------------------------------------------------------------
192// DTypeOf — compile-time type -> DType mapping
193// ---------------------------------------------------------------------------
194
195/// Trait to get the [`DType`] tag for a [`Scalar`] type at compile time.
196///
197/// # Examples
198///
199/// ```
200/// use scivex_core::promote::{DType, DTypeOf};
201/// assert_eq!(<f64 as DTypeOf>::dtype(), DType::F64);
202/// assert_eq!(<i32 as DTypeOf>::dtype(), DType::I32);
203/// ```
204pub trait DTypeOf: Scalar {
205    /// The runtime [`DType`] tag for this type.
206    ///
207    /// # Examples
208    ///
209    /// ```
210    /// use scivex_core::promote::{DType, DTypeOf};
211    /// assert_eq!(f32::dtype(), DType::F32);
212    /// ```
213    fn dtype() -> DType;
214}
215
216macro_rules! impl_dtype_of {
217    ($($ty:ty => $variant:ident),* $(,)?) => {
218        $(
219            impl DTypeOf for $ty {
220                #[inline]
221                fn dtype() -> DType {
222                    DType::$variant
223                }
224            }
225        )*
226    };
227}
228
229impl_dtype_of!(
230    u8 => U8,
231    u16 => U16,
232    u32 => U32,
233    u64 => U64,
234    i8 => I8,
235    i16 => I16,
236    i32 => I32,
237    i64 => I64,
238    f32 => F32,
239    f64 => F64,
240);
241
242// ---------------------------------------------------------------------------
243// CastFrom — numeric type casting
244// ---------------------------------------------------------------------------
245
/// Infallible numeric conversion between primitive scalar types.
///
/// Implemented for every ordered pair of `u8`, `u16`, `u32`, `u64`, `i8`,
/// `i16`, `i32`, `i64`, `f32`, `f64`, `usize` and `isize`. Every
/// implementation is a plain Rust `as` cast, so it never fails or panics,
/// but it can silently change the value:
/// - Widening is lossless only when the target can represent every source
///   value (`u8` -> `i16`, `i32` -> `i64`). A signed value widened into an
///   unsigned type is sign-extended: `u16::cast_from(-1_i8) == u16::MAX`.
/// - Narrowing between integers keeps only the low-order bits:
///   `u8::cast_from(300_i32) == 44`.
/// - Signed <-> unsigned of the same width reinterprets the bits:
///   `u8::cast_from(-1_i8) == 255`.
/// - Float-to-integer truncates toward zero and saturates at the target's
///   bounds; NaN becomes `0`.
/// - Integer-to-float rounds to the nearest representable value, so large
///   integers may lose precision.
/// - `f64`-to-`f32` rounds to the nearest `f32`. Values outside `f32`'s
///   range become infinity of the same sign.
/// - `usize`/`isize` follow the same rules, using the target platform's
///   pointer width.
///
/// Use `TryFrom` instead when a lossy conversion should be an error.
///
/// # Examples
///
/// ```
/// use scivex_core::promote::CastFrom;
/// let x: f64 = f64::cast_from(42_u8);
/// assert_eq!(x, 42.0_f64);
/// let y: i32 = i32::cast_from(3.9_f64);
/// assert_eq!(y, 3);
/// // Narrowing wraps, float-to-int saturates.
/// assert_eq!(u8::cast_from(300_i32), 44);
/// assert_eq!(u8::cast_from(-1.5_f64), 0);
/// ```
pub trait CastFrom<T> {
    /// Cast a value of type `T` into `Self` with `as`-cast semantics.
    ///
    /// # Examples
    ///
    /// ```
    /// use scivex_core::promote::CastFrom;
    /// assert_eq!(f32::cast_from(255_u8), 255.0_f32);
    /// ```
    fn cast_from(val: T) -> Self;
}

// Generates `CastFrom<$src>` for each listed destination type via `as`.
//
// The clippy cast lints are silenced deliberately:
// - Truncation, wrap-around, sign loss and precision loss are the documented
//   semantics of `CastFrom`, not accidents.
// - `cast_lossless` would ask for `From` on the subset of pairs where it
//   exists, which a uniform macro cannot special-case.
macro_rules! impl_cast_from {
    ($src:ty => $($dst:ty),* $(,)?) => {
        $(
            impl CastFrom<$src> for $dst {
                #[inline]
                #[allow(
                    clippy::cast_possible_truncation,
                    clippy::cast_possible_wrap,
                    clippy::cast_sign_loss,
                    clippy::cast_lossless,
                    clippy::cast_precision_loss
                )]
                fn cast_from(val: $src) -> Self {
                    val as Self
                }
            }
        )*
    };
}

// From every fixed-width integer/float type to every fixed-width integer/float type.
impl_cast_from!(u8  => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(u16 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(u32 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(u64 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(i8  => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(i16 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(i32 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(i64 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(f32 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
impl_cast_from!(f64 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);

// Pointer-sized integers as sources (to everything, including each other)...
impl_cast_from!(usize => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, usize, isize);
impl_cast_from!(isize => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, usize, isize);
// ...and as destinations from every fixed-width type.
impl_cast_from!(u8    => usize, isize);
impl_cast_from!(u16   => usize, isize);
impl_cast_from!(u32   => usize, isize);
impl_cast_from!(u64   => usize, isize);
impl_cast_from!(i8    => usize, isize);
impl_cast_from!(i16   => usize, isize);
impl_cast_from!(i32   => usize, isize);
impl_cast_from!(i64   => usize, isize);
impl_cast_from!(f32   => usize, isize);
impl_cast_from!(f64   => usize, isize);
319
320// ---------------------------------------------------------------------------
321// Tensor::cast_to
322// ---------------------------------------------------------------------------
323
324impl<T: Scalar> Tensor<T> {
325    /// Cast every element of this tensor to a different scalar type.
326    ///
327    /// This allocates a new tensor with the same shape and copies each element
328    /// through [`CastFrom`].
329    ///
330    /// # Examples
331    ///
332    /// ```
333    /// # use scivex_core::Tensor;
334    /// # use scivex_core::promote::CastFrom;
335    /// let t = Tensor::from_vec(vec![1_u8, 2, 3, 4], vec![2, 2]).unwrap();
336    /// let f: Tensor<f64> = t.cast_to();
337    /// assert_eq!(f.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
338    /// ```
339    pub fn cast_to<U: Scalar + CastFrom<T>>(&self) -> Tensor<U> {
340        let data: Vec<U> = self.as_slice().iter().map(|&v| U::cast_from(v)).collect();
341        // Shape is unchanged so the product of dimensions equals data.len();
342        // from_vec will never fail here.
343        Tensor::from_vec(data, self.shape().to_vec())
344            .expect("cast_to: shape unchanged, from_vec cannot fail")
345    }
346}
347
348// ---------------------------------------------------------------------------
349// Tests
350// ---------------------------------------------------------------------------
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DType` variant, in declaration order.
    const ALL: [DType; 10] = [
        DType::U8,
        DType::U16,
        DType::U32,
        DType::U64,
        DType::I8,
        DType::I16,
        DType::I32,
        DType::I64,
        DType::F32,
        DType::F64,
    ];

    #[test]
    fn test_promote_same_type() {
        assert_eq!(promote(DType::F32, DType::F32), DType::F32);
        assert_eq!(promote(DType::I32, DType::I32), DType::I32);
        assert_eq!(promote(DType::U8, DType::U8), DType::U8);
    }

    #[test]
    fn test_promote_float_float() {
        assert_eq!(promote(DType::F32, DType::F64), DType::F64);
        assert_eq!(promote(DType::F64, DType::F32), DType::F64);
    }

    #[test]
    fn test_promote_int_float() {
        assert_eq!(promote(DType::I32, DType::F32), DType::F32);
        assert_eq!(promote(DType::U8, DType::F64), DType::F64);
        // Deliberately not NumPy's float64: the float side always wins.
        assert_eq!(promote(DType::U64, DType::F32), DType::F32);
        assert_eq!(promote(DType::I64, DType::F64), DType::F64);
    }

    #[test]
    fn test_promote_int_widening() {
        assert_eq!(promote(DType::I8, DType::I32), DType::I32);
        assert_eq!(promote(DType::U8, DType::U16), DType::U16);
        assert_eq!(promote(DType::U16, DType::U32), DType::U32);
        assert_eq!(promote(DType::I16, DType::I64), DType::I64);
    }

    #[test]
    fn test_promote_signed_unsigned() {
        // I8 (1 byte signed) + U8 (1 byte unsigned) -> I16 (need wider signed)
        assert_eq!(promote(DType::I8, DType::U8), DType::I16);
        // I16 (2 bytes signed) + U16 (2 bytes unsigned) -> I32
        assert_eq!(promote(DType::I16, DType::U16), DType::I32);
        // I32 + U32 -> I64
        assert_eq!(promote(DType::I32, DType::U32), DType::I64);
        // I64 + U64 -> F64 (no wider signed integer)
        assert_eq!(promote(DType::I64, DType::U64), DType::F64);
        // I32 (4 bytes) + U8 (1 byte) -> I32 (signed already wider)
        assert_eq!(promote(DType::I32, DType::U8), DType::I32);
    }

    #[test]
    fn test_promote_unsigned_wider_than_signed() {
        // The result is sized by the unsigned operand, not the signed one.
        assert_eq!(promote(DType::I8, DType::U16), DType::I32);
        assert_eq!(promote(DType::I16, DType::U32), DType::I64);
        assert_eq!(promote(DType::I8, DType::U64), DType::F64);
        assert_eq!(promote(DType::U32, DType::I64), DType::I64);
    }

    #[test]
    fn test_promote_is_symmetric() {
        for &a in ALL.iter() {
            for &b in ALL.iter() {
                assert_eq!(promote(a, b), promote(b, a), "{:?} + {:?}", a, b);
            }
        }
    }

    #[test]
    fn test_dtype_of() {
        assert_eq!(<u8 as DTypeOf>::dtype(), DType::U8);
        assert_eq!(<u16 as DTypeOf>::dtype(), DType::U16);
        assert_eq!(<u32 as DTypeOf>::dtype(), DType::U32);
        assert_eq!(<u64 as DTypeOf>::dtype(), DType::U64);
        assert_eq!(<i8 as DTypeOf>::dtype(), DType::I8);
        assert_eq!(<i16 as DTypeOf>::dtype(), DType::I16);
        assert_eq!(<i32 as DTypeOf>::dtype(), DType::I32);
        assert_eq!(<i64 as DTypeOf>::dtype(), DType::I64);
        assert_eq!(<f32 as DTypeOf>::dtype(), DType::F32);
        assert_eq!(<f64 as DTypeOf>::dtype(), DType::F64);
    }

    #[test]
    fn test_dtype_properties() {
        assert_eq!(DType::U8.size_bytes(), 1);
        assert_eq!(DType::U16.size_bytes(), 2);
        assert_eq!(DType::U32.size_bytes(), 4);
        assert_eq!(DType::U64.size_bytes(), 8);
        assert_eq!(DType::I8.size_bytes(), 1);
        assert_eq!(DType::I16.size_bytes(), 2);
        assert_eq!(DType::I32.size_bytes(), 4);
        assert_eq!(DType::I64.size_bytes(), 8);
        assert_eq!(DType::F32.size_bytes(), 4);
        assert_eq!(DType::F64.size_bytes(), 8);

        assert!(DType::F32.is_float());
        assert!(DType::F64.is_float());
        assert!(!DType::I32.is_float());
        assert!(!DType::U8.is_float());

        assert!(DType::I8.is_signed());
        assert!(DType::F64.is_signed());
        assert!(!DType::U8.is_signed());
        assert!(!DType::U64.is_signed());

        assert!(DType::I32.is_integer());
        assert!(DType::U16.is_integer());
        assert!(!DType::F32.is_integer());

        // Every tag is exactly one of integer / float.
        for &d in ALL.iter() {
            assert_ne!(d.is_integer(), d.is_float(), "{:?}", d);
        }
    }

    #[test]
    fn test_cast_from_scalar_edge_cases() {
        // Narrowing keeps the low-order bits.
        assert_eq!(u8::cast_from(300_i32), 44);
        // Same-width sign change reinterprets; widening sign-extends.
        assert_eq!(u8::cast_from(-1_i8), 255);
        assert_eq!(u16::cast_from(-1_i8), u16::MAX);
        // Float-to-int truncates toward zero, saturates, and maps NaN to 0.
        assert_eq!(i32::cast_from(-3.9_f64), -3);
        assert_eq!(u8::cast_from(-1.5_f64), 0);
        assert_eq!(u8::cast_from(1e9_f64), u8::MAX);
        assert_eq!(i32::cast_from(f64::NAN), 0);
        // f64 -> f32 overflow becomes infinity.
        assert!(f32::cast_from(1e300_f64).is_infinite());
    }

    #[test]
    fn test_cast_from_u8_to_f64() {
        let t = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let f: Tensor<f64> = t.cast_to();
        assert_eq!(f.shape(), &[2, 3]);
        assert_eq!(f.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_cast_from_f64_to_f32() {
        let t = Tensor::from_vec(vec![1.5_f64, 2.25, -3.0, 1e30], vec![4]).unwrap();
        let f: Tensor<f32> = t.cast_to();
        assert_eq!(f.shape(), &[4]);
        assert!((f.as_slice()[0] - 1.5_f32).abs() < f32::EPSILON);
        assert!((f.as_slice()[1] - 2.25_f32).abs() < f32::EPSILON);
        assert!((f.as_slice()[2] - (-3.0_f32)).abs() < f32::EPSILON);
        // 1e30 is inside f32's range: it must stay finite and round to
        // within f32 precision (relative error <= 2^-24).
        let big = f.as_slice()[3];
        assert!(big.is_finite());
        assert!((f64::from(big) - 1e30_f64).abs() / 1e30_f64 < 1e-7);
    }
}
447        assert_eq!(f.shape(), &[4]);
448        assert!((f.as_slice()[0] - 1.5_f32).abs() < f32::EPSILON);
449        assert!((f.as_slice()[1] - 2.25_f32).abs() < f32::EPSILON);
450        assert!((f.as_slice()[2] - (-3.0_f32)).abs() < f32::EPSILON);
451    }
452}