scivex_core/promote.rs
1//! Type promotion rules and casting utilities for tensor element types.
2//!
3//! Provides numpy-style type promotion via [`promote()`], compile-time type
4//! identification via [`DTypeOf`], and element-wise casting via [`CastFrom`].
5
6use crate::dtype::Scalar;
7use crate::tensor::Tensor;
8
9// ---------------------------------------------------------------------------
10// DType — runtime type tag
11// ---------------------------------------------------------------------------
12
/// Runtime type tag for tensor element types.
///
/// There is one variant per supported primitive scalar type; use
/// [`DTypeOf::dtype`] to obtain the tag for a concrete type at compile time.
///
/// # Examples
///
/// ```
/// use scivex_core::promote::DType;
/// assert!(DType::F64.is_float());
/// assert!(!DType::I32.is_float());
/// assert_eq!(DType::F32.size_bytes(), 4);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// Unsigned 8-bit integer (`u8`).
    U8,
    /// Unsigned 16-bit integer (`u16`).
    U16,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Signed 8-bit integer (`i8`).
    I8,
    /// Signed 16-bit integer (`i16`).
    I16,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// 32-bit floating point (`f32`).
    F32,
    /// 64-bit floating point (`f64`).
    F64,
}

impl DType {
    /// Size in bytes of a single element of this type.
    ///
    /// This is a `const fn`, so it can be used in constant expressions.
    ///
    /// # Examples
    ///
    /// ```
    /// use scivex_core::promote::DType;
    /// assert_eq!(DType::U8.size_bytes(), 1);
    /// assert_eq!(DType::F64.size_bytes(), 8);
    /// const N: usize = DType::I32.size_bytes();
    /// assert_eq!(N, 4);
    /// ```
    #[inline]
    pub const fn size_bytes(self) -> usize {
        match self {
            DType::U8 | DType::I8 => 1,
            DType::U16 | DType::I16 => 2,
            DType::U32 | DType::I32 | DType::F32 => 4,
            DType::U64 | DType::I64 | DType::F64 => 8,
        }
    }

    /// Returns `true` if this is a floating-point type.
    ///
    /// # Examples
    ///
    /// ```
    /// use scivex_core::promote::DType;
    /// assert!(DType::F32.is_float());
    /// assert!(!DType::I64.is_float());
    /// ```
    #[inline]
    pub const fn is_float(self) -> bool {
        matches!(self, DType::F32 | DType::F64)
    }

    /// Returns `true` if this is a signed type (signed integers or floats).
    ///
    /// # Examples
    ///
    /// ```
    /// use scivex_core::promote::DType;
    /// assert!(DType::I32.is_signed());
    /// assert!(DType::F64.is_signed());
    /// assert!(!DType::U8.is_signed());
    /// ```
    #[inline]
    pub const fn is_signed(self) -> bool {
        matches!(
            self,
            DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::F32 | DType::F64
        )
    }

    /// Returns `true` if this is an integer type (signed or unsigned).
    ///
    /// # Examples
    ///
    /// ```
    /// use scivex_core::promote::DType;
    /// assert!(DType::U32.is_integer());
    /// assert!(!DType::F32.is_integer());
    /// ```
    #[inline]
    pub const fn is_integer(self) -> bool {
        // Every variant is either an integer or a float, so "not float"
        // is exactly "integer".
        !self.is_float()
    }
}
103
104// ---------------------------------------------------------------------------
105// promote — numpy-style type promotion
106// ---------------------------------------------------------------------------
107
108/// Determine the result type when combining two dtypes (numpy-style promotion).
109///
110/// Rules:
111/// - Same type -> same type
112/// - Integer + Integer -> wider integer (preserving signedness if possible)
113/// - Signed + Unsigned -> signed type wide enough for both
114/// - Any Integer + Any Float -> the float type
115/// - F32 + F64 -> F64
116///
117/// # Examples
118///
119/// ```
120/// use scivex_core::promote::{DType, promote};
121/// assert_eq!(promote(DType::I32, DType::F64), DType::F64);
122/// assert_eq!(promote(DType::F32, DType::F64), DType::F64);
123/// assert_eq!(promote(DType::I8, DType::I32), DType::I32);
124/// ```
125pub fn promote(a: DType, b: DType) -> DType {
126 if a == b {
127 return a;
128 }
129
130 // If either is float, the result is float.
131 match (a.is_float(), b.is_float()) {
132 (true, true) => {
133 // F32 + F64 -> F64
134 if a == DType::F64 || b == DType::F64 {
135 return DType::F64;
136 }
137 return DType::F32;
138 }
139 (true, false) => return promote_int_float(b, a),
140 (false, true) => return promote_int_float(a, b),
141 (false, false) => {}
142 }
143
144 // Both are integers.
145 let a_signed = a.is_signed();
146 let b_signed = b.is_signed();
147 let a_bytes = a.size_bytes();
148 let b_bytes = b.size_bytes();
149
150 match (a_signed, b_signed) {
151 // Both same signedness: pick the wider one.
152 (true, true) | (false, false) => {
153 if a_bytes >= b_bytes {
154 a
155 } else {
156 b
157 }
158 }
159 // Mixed signedness: need a signed type wide enough for both.
160 _ => {
161 let (signed_dt, unsigned_dt) = if a_signed { (a, b) } else { (b, a) };
162 let s_bytes = signed_dt.size_bytes();
163 let u_bytes = unsigned_dt.size_bytes();
164 // The signed type must be strictly wider than the unsigned type
165 // to represent all values of both.
166 if s_bytes > u_bytes {
167 // Signed type is already wide enough.
168 signed_dt
169 } else {
170 // Need to widen: pick the next signed type larger than the unsigned type.
171 match u_bytes {
172 1 => DType::I16,
173 2 => DType::I32,
174 4 => DType::I64,
175 // u64 cannot be fully represented by i64; promote to f64.
176 _ => DType::F64,
177 }
178 }
179 }
180 }
181}
182
183/// Promote an integer dtype combined with a float dtype.
184fn promote_int_float(_int_dt: DType, float_dt: DType) -> DType {
185 // Any integer + any float -> the float type.
186 // If the integer is very wide (64-bit) and float is F32, we still return
187 // the float type (matching numpy behavior).
188 float_dt
189}
190
191// ---------------------------------------------------------------------------
192// DTypeOf — compile-time type -> DType mapping
193// ---------------------------------------------------------------------------
194
195/// Trait to get the [`DType`] tag for a [`Scalar`] type at compile time.
196///
197/// # Examples
198///
199/// ```
200/// use scivex_core::promote::{DType, DTypeOf};
201/// assert_eq!(<f64 as DTypeOf>::dtype(), DType::F64);
202/// assert_eq!(<i32 as DTypeOf>::dtype(), DType::I32);
203/// ```
204pub trait DTypeOf: Scalar {
205 /// The runtime [`DType`] tag for this type.
206 ///
207 /// # Examples
208 ///
209 /// ```
210 /// use scivex_core::promote::{DType, DTypeOf};
211 /// assert_eq!(f32::dtype(), DType::F32);
212 /// ```
213 fn dtype() -> DType;
214}
215
216macro_rules! impl_dtype_of {
217 ($($ty:ty => $variant:ident),* $(,)?) => {
218 $(
219 impl DTypeOf for $ty {
220 #[inline]
221 fn dtype() -> DType {
222 DType::$variant
223 }
224 }
225 )*
226 };
227}
228
229impl_dtype_of!(
230 u8 => U8,
231 u16 => U16,
232 u32 => U32,
233 u64 => U64,
234 i8 => I8,
235 i16 => I16,
236 i32 => I32,
237 i64 => I64,
238 f32 => F32,
239 f64 => F64,
240);
241
242// ---------------------------------------------------------------------------
243// CastFrom — numeric type casting
244// ---------------------------------------------------------------------------
245
246/// Trait for numeric type casting between scalar types.
247///
248/// Implementations use Rust `as` casts, which means:
249/// - Integer widening is lossless.
250/// - Float-to-integer truncates toward zero.
251/// - Integer-to-float may lose precision for large values.
252/// - f64-to-f32 may lose precision or become infinity.
253///
254/// # Examples
255///
256/// ```
257/// use scivex_core::promote::CastFrom;
258/// let x: f64 = f64::cast_from(42_u8);
259/// assert_eq!(x, 42.0_f64);
260/// let y: i32 = i32::cast_from(3.9_f64);
261/// assert_eq!(y, 3);
262/// ```
263pub trait CastFrom<T> {
264 /// Cast a value of type `T` into `Self`.
265 ///
266 /// # Examples
267 ///
268 /// ```
269 /// use scivex_core::promote::CastFrom;
270 /// assert_eq!(f32::cast_from(255_u8), 255.0_f32);
271 /// ```
272 fn cast_from(val: T) -> Self;
273}
274
275// Macro to generate CastFrom impls for all reasonable pairs.
276macro_rules! impl_cast_from {
277 ($src:ty => $($dst:ty),* $(,)?) => {
278 $(
279 impl CastFrom<$src> for $dst {
280 #[inline]
281 #[allow(clippy::cast_possible_truncation)]
282 #[allow(clippy::cast_possible_wrap)]
283 #[allow(clippy::cast_sign_loss)]
284 #[allow(clippy::cast_lossless)]
285 #[allow(clippy::cast_precision_loss)]
286 fn cast_from(val: $src) -> Self {
287 val as Self
288 }
289 }
290 )*
291 };
292}
293
294// From every integer/float type to every integer/float type.
295impl_cast_from!(u8 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
296impl_cast_from!(u16 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
297impl_cast_from!(u32 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
298impl_cast_from!(u64 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
299impl_cast_from!(i8 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
300impl_cast_from!(i16 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
301impl_cast_from!(i32 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
302impl_cast_from!(i64 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
303impl_cast_from!(f32 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
304impl_cast_from!(f64 => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64);
305
306// Also support usize / isize as sources and destinations.
307impl_cast_from!(usize => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, usize, isize);
308impl_cast_from!(isize => u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, usize, isize);
309impl_cast_from!(u8 => usize, isize);
310impl_cast_from!(u16 => usize, isize);
311impl_cast_from!(u32 => usize, isize);
312impl_cast_from!(u64 => usize, isize);
313impl_cast_from!(i8 => usize, isize);
314impl_cast_from!(i16 => usize, isize);
315impl_cast_from!(i32 => usize, isize);
316impl_cast_from!(i64 => usize, isize);
317impl_cast_from!(f32 => usize, isize);
318impl_cast_from!(f64 => usize, isize);
319
320// ---------------------------------------------------------------------------
321// Tensor::cast_to
322// ---------------------------------------------------------------------------
323
324impl<T: Scalar> Tensor<T> {
325 /// Cast every element of this tensor to a different scalar type.
326 ///
327 /// This allocates a new tensor with the same shape and copies each element
328 /// through [`CastFrom`].
329 ///
330 /// # Examples
331 ///
332 /// ```
333 /// # use scivex_core::Tensor;
334 /// # use scivex_core::promote::CastFrom;
335 /// let t = Tensor::from_vec(vec![1_u8, 2, 3, 4], vec![2, 2]).unwrap();
336 /// let f: Tensor<f64> = t.cast_to();
337 /// assert_eq!(f.as_slice(), &[1.0, 2.0, 3.0, 4.0]);
338 /// ```
339 pub fn cast_to<U: Scalar + CastFrom<T>>(&self) -> Tensor<U> {
340 let data: Vec<U> = self.as_slice().iter().map(|&v| U::cast_from(v)).collect();
341 // Shape is unchanged so the product of dimensions equals data.len();
342 // from_vec will never fail here.
343 Tensor::from_vec(data, self.shape().to_vec())
344 .expect("cast_to: shape unchanged, from_vec cannot fail")
345 }
346}
347
348// ---------------------------------------------------------------------------
349// Tests
350// ---------------------------------------------------------------------------
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DType` variant, for exhaustive pairwise checks.
    const ALL_DTYPES: [DType; 10] = [
        DType::U8,
        DType::U16,
        DType::U32,
        DType::U64,
        DType::I8,
        DType::I16,
        DType::I32,
        DType::I64,
        DType::F32,
        DType::F64,
    ];

    #[test]
    fn test_promote_same_type() {
        assert_eq!(promote(DType::F32, DType::F32), DType::F32);
        assert_eq!(promote(DType::I32, DType::I32), DType::I32);
        assert_eq!(promote(DType::U8, DType::U8), DType::U8);
    }

    #[test]
    fn test_promote_int_float() {
        assert_eq!(promote(DType::I32, DType::F32), DType::F32);
        assert_eq!(promote(DType::U8, DType::F64), DType::F64);
        assert_eq!(promote(DType::U64, DType::F32), DType::F32);
        assert_eq!(promote(DType::I64, DType::F64), DType::F64);
    }

    #[test]
    fn test_promote_int_widening() {
        assert_eq!(promote(DType::I8, DType::I32), DType::I32);
        assert_eq!(promote(DType::U8, DType::U16), DType::U16);
        assert_eq!(promote(DType::U16, DType::U32), DType::U32);
        assert_eq!(promote(DType::I16, DType::I64), DType::I64);
    }

    #[test]
    fn test_promote_signed_unsigned() {
        // I8 (1 byte signed) + U8 (1 byte unsigned) -> I16 (need wider signed)
        assert_eq!(promote(DType::I8, DType::U8), DType::I16);
        // I16 (2 bytes signed) + U16 (2 bytes unsigned) -> I32
        assert_eq!(promote(DType::I16, DType::U16), DType::I32);
        // I32 + U32 -> I64
        assert_eq!(promote(DType::I32, DType::U32), DType::I64);
        // I64 + U64 -> F64 (no wider signed integer)
        assert_eq!(promote(DType::I64, DType::U64), DType::F64);
        // I32 (4 bytes) + U8 (1 byte) -> I32 (signed already wider)
        assert_eq!(promote(DType::I32, DType::U8), DType::I32);
    }

    #[test]
    fn test_promote_is_commutative() {
        for &a in ALL_DTYPES.iter() {
            for &b in ALL_DTYPES.iter() {
                assert_eq!(
                    promote(a, b),
                    promote(b, a),
                    "promote({:?}, {:?}) depends on argument order",
                    a,
                    b
                );
            }
        }
    }

    #[test]
    fn test_promote_integer_result_holds_both_inputs() {
        for &a in ALL_DTYPES.iter().filter(|d| d.is_integer()) {
            for &b in ALL_DTYPES.iter().filter(|d| d.is_integer()) {
                let r = promote(a, b);
                if r.is_integer() {
                    assert!(r.size_bytes() >= a.size_bytes());
                    assert!(r.size_bytes() >= b.size_bytes());
                    assert_eq!(r.is_signed(), a.is_signed() || b.is_signed());
                } else {
                    // Only the signed + U64 case escapes to a float.
                    assert_eq!(r, DType::F64);
                }
            }
        }
    }

    #[test]
    fn test_dtype_of() {
        assert_eq!(<u8 as DTypeOf>::dtype(), DType::U8);
        assert_eq!(<u16 as DTypeOf>::dtype(), DType::U16);
        assert_eq!(<u32 as DTypeOf>::dtype(), DType::U32);
        assert_eq!(<u64 as DTypeOf>::dtype(), DType::U64);
        assert_eq!(<i8 as DTypeOf>::dtype(), DType::I8);
        assert_eq!(<i16 as DTypeOf>::dtype(), DType::I16);
        assert_eq!(<i32 as DTypeOf>::dtype(), DType::I32);
        assert_eq!(<i64 as DTypeOf>::dtype(), DType::I64);
        assert_eq!(<f32 as DTypeOf>::dtype(), DType::F32);
        assert_eq!(<f64 as DTypeOf>::dtype(), DType::F64);
    }

    #[test]
    fn test_dtype_properties() {
        assert_eq!(DType::U8.size_bytes(), 1);
        assert_eq!(DType::U16.size_bytes(), 2);
        assert_eq!(DType::U32.size_bytes(), 4);
        assert_eq!(DType::U64.size_bytes(), 8);
        assert_eq!(DType::I8.size_bytes(), 1);
        assert_eq!(DType::I16.size_bytes(), 2);
        assert_eq!(DType::I32.size_bytes(), 4);
        assert_eq!(DType::I64.size_bytes(), 8);
        assert_eq!(DType::F32.size_bytes(), 4);
        assert_eq!(DType::F64.size_bytes(), 8);

        assert!(DType::F32.is_float());
        assert!(DType::F64.is_float());
        assert!(!DType::I32.is_float());
        assert!(!DType::U8.is_float());

        assert!(DType::I8.is_signed());
        assert!(DType::F64.is_signed());
        assert!(!DType::U8.is_signed());
        assert!(!DType::U64.is_signed());

        assert!(DType::I32.is_integer());
        assert!(DType::U16.is_integer());
        assert!(!DType::F32.is_integer());
    }

    #[test]
    fn test_cast_from_u8_to_f64() {
        let t = Tensor::from_vec(vec![1_u8, 2, 3, 4, 5, 6], vec![2, 3]).unwrap();
        let f: Tensor<f64> = t.cast_to();
        assert_eq!(f.shape(), &[2, 3]);
        assert_eq!(f.as_slice(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_cast_from_f64_to_f32() {
        let t = Tensor::from_vec(vec![1.5_f64, 2.25, -3.0, 1e30], vec![4]).unwrap();
        let f: Tensor<f32> = t.cast_to();
        assert_eq!(f.shape(), &[4]);
        assert!((f.as_slice()[0] - 1.5_f32).abs() < f32::EPSILON);
        assert!((f.as_slice()[1] - 2.25_f32).abs() < f32::EPSILON);
        assert!((f.as_slice()[2] - (-3.0_f32)).abs() < f32::EPSILON);
        // 1e30 is within f32 range, so it must stay finite and close in
        // relative terms.
        let big = f64::from(f.as_slice()[3]);
        assert!(big.is_finite());
        assert!(((big - 1e30) / 1e30).abs() < 1e-6);
    }

    #[test]
    fn test_cast_from_edge_cases() {
        // Narrowing integer casts wrap.
        assert_eq!(u8::cast_from(300_i32), 44);
        // Same-width signed -> unsigned reinterprets the bits.
        assert_eq!(u8::cast_from(-1_i8), 255);
        // Widening a negative value to unsigned sign-extends.
        assert_eq!(u16::cast_from(-1_i8), 65535);
        // Float -> int truncates toward zero, saturates, and maps NaN to 0.
        assert_eq!(i32::cast_from(-3.9_f64), -3);
        assert_eq!(u8::cast_from(-3.7_f64), 0);
        assert_eq!(u8::cast_from(1e10_f64), 255);
        assert_eq!(i32::cast_from(f64::NAN), 0);
        // f64 -> f32 overflow becomes infinity.
        assert!(f32::cast_from(1e40_f64).is_infinite());
    }
}
452}