Skip to main content

svod_dtype/
lib.rs

1pub mod cast;
2pub mod ext;
3
4#[cfg(any(test, feature = "proptest"))]
5pub mod test;
6
7use std::path::PathBuf;
8
/// Device specification parsed from a device string.
///
/// This enum represents the devices known to the compilation pipeline. All of them
/// except [`DeviceSpec::Disk`] can execute kernels. It's used throughout the pipeline
/// for device selection and kernel caching.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DeviceSpec {
    /// CPU device (single-threaded or multi-threaded execution)
    Cpu,
    /// CUDA GPU device with specific device ID
    Cuda { device_id: usize },
    /// Metal GPU device (Apple Silicon) with specific device ID
    Metal { device_id: usize },
    /// WebGPU device (browser or native WebGPU)
    WebGpu,
    /// File-backed device (memory-mapped, read-only). Matches Tinygrad's DISK device.
    /// Cannot execute kernels — data is transferred to compute devices via COPY.
    Disk { path: PathBuf },
}

impl DeviceSpec {
    /// Canonicalize the device spec to a standard string representation.
    ///
    /// Output per variant: `CPU`, `CUDA:<id>`, `Metal:<id>`, `WebGPU`, `DISK:<path>`.
    /// Note that the Metal and WebGPU spellings are mixed-case, unlike the
    /// upper-case names returned by [`DeviceSpec::base_type`].
    ///
    /// # Examples
    ///
    /// ```
    /// use svod_dtype::DeviceSpec;
    ///
    /// assert_eq!(DeviceSpec::Cpu.canonicalize(), "CPU");
    /// assert_eq!(DeviceSpec::Cuda { device_id: 0 }.canonicalize(), "CUDA:0");
    /// assert_eq!(DeviceSpec::Cuda { device_id: 1 }.canonicalize(), "CUDA:1");
    /// ```
    pub fn canonicalize(&self) -> String {
        match self {
            DeviceSpec::Cpu => "CPU".to_string(),
            DeviceSpec::Cuda { device_id } => format!("CUDA:{device_id}"),
            DeviceSpec::Metal { device_id } => format!("Metal:{device_id}"),
            DeviceSpec::WebGpu => "WebGPU".to_string(),
            DeviceSpec::Disk { path } => format!("DISK:{}", path.display()),
        }
    }

    /// Get maximum buffer count for this device.
    ///
    /// Returns `None` if the device has no buffer limit (effectively unlimited).
    ///
    /// Known limits:
    /// - Metal: 31 buffers (Apple Silicon hardware limit)
    /// - WebGPU: 8 buffers (WebGPU specification limit)
    /// - CUDA: 128 buffers
    /// - CPU: `None` (no practical limit)
    /// - Disk: `None` (file-backed, no kernel execution)
    ///
    /// # Examples
    ///
    /// ```
    /// use svod_dtype::DeviceSpec;
    ///
    /// assert_eq!(DeviceSpec::Cpu.max_buffers(), None);
    /// assert_eq!(DeviceSpec::Cuda { device_id: 0 }.max_buffers(), Some(128));
    /// ```
    pub fn max_buffers(&self) -> Option<usize> {
        match self {
            DeviceSpec::Cpu | DeviceSpec::Disk { .. } => None,
            DeviceSpec::Cuda { .. } => Some(128),
            DeviceSpec::Metal { .. } => Some(31),
            DeviceSpec::WebGpu => Some(8),
        }
    }

    /// Get the base device type string (strips device ID / path).
    ///
    /// Used for device factory lookup and cache key construction.
    /// Unlike `canonicalize()`, this returns a static, always upper-case string
    /// without the device ID or path.
    ///
    /// # Examples
    ///
    /// ```
    /// use svod_dtype::DeviceSpec;
    ///
    /// assert_eq!(DeviceSpec::Cpu.base_type(), "CPU");
    /// assert_eq!(DeviceSpec::Cuda { device_id: 0 }.base_type(), "CUDA");
    /// assert_eq!(DeviceSpec::Cuda { device_id: 1 }.base_type(), "CUDA");
    /// ```
    pub fn base_type(&self) -> &'static str {
        match self {
            DeviceSpec::Cpu => "CPU",
            DeviceSpec::Cuda { .. } => "CUDA",
            DeviceSpec::Metal { .. } => "METAL",
            DeviceSpec::WebGpu => "WEBGPU",
            DeviceSpec::Disk { .. } => "DISK",
        }
    }

    /// Check if this is a DISK (file-backed) device.
    pub fn is_disk(&self) -> bool {
        matches!(self, DeviceSpec::Disk { .. })
    }
}
99
/// Memory address space that a pointer type refers to.
///
/// The derived ordering follows declaration order: `Global < Local < Reg`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum AddrSpace {
    /// Global (device) memory.
    Global,
    /// Local (shared) memory.
    Local,
    /// Register memory.
    Reg,
}
111
/// Element precision of an image (texture) type.
///
/// The derived ordering follows declaration order: `Half < Float`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ImageKind {
    /// Half-precision image.
    Half,
    /// Single-precision (float) image.
    Float,
}
121
122/// Scalar data types (base numeric types).
123#[derive(Debug, Hash, PartialOrd, Ord)]
124#[derive(strum::EnumCount, strum::EnumIter, strum::VariantArray, strum::FromRepr)]
125#[derive(enumset::EnumSetType)]
126#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
127#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
128#[enumset(repr = "u32")]
129pub enum ScalarDType {
130    Bool = 0,
131
132    // Interleaved signed/unsigned for correct LUB priority (lower = more specific)
133    Int8 = 1,
134    UInt8 = 2,
135    Int16 = 3,
136    UInt16 = 4,
137    Int32 = 5,
138    UInt32 = 6,
139    Int64 = 7,
140    UInt64 = 8,
141
142    FP8E4M3 = 9,
143    FP8E5M2 = 10,
144    Float16 = 11,
145    BFloat16 = 12,
146    Float32 = 13,
147    Float64 = 14,
148
149    /// Void type for metadata operations (no data).
150    Void = 15,
151
152    /// Index type for array indexing and loop iteration.
153    Index = 16,
154}
155
156/// Data type including scalars, vectors, pointers, and images.
157#[derive(Debug, Clone, PartialEq, Eq, Hash)]
158#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
159pub enum DType {
160    /// Scalar type (single value).
161    Scalar(ScalarDType),
162
163    /// Vector type (SIMD).
164    Vector { scalar: ScalarDType, count: usize },
165
166    /// Pointer type.
167    /// `vcount` is the vector count of the pointer itself (1 = scalar pointer, >1 = vector of pointers).
168    /// This matches Tinygrad's PtrDType.v field.
169    Ptr { base: Box<DType>, addrspace: AddrSpace, size: Option<usize>, vcount: usize },
170
171    /// Image type (for texture operations).
172    Image { kind: ImageKind, shape: Vec<usize> },
173}
174
175impl ScalarDType {
176    pub const fn bytes(&self) -> usize {
177        match self {
178            Self::Bool => 1,
179            Self::Int8 => 1,
180            Self::Int16 => 2,
181            Self::Int32 => 4,
182            Self::Int64 => 8,
183            Self::UInt8 => 1,
184            Self::UInt16 => 2,
185            Self::UInt32 => 4,
186            Self::UInt64 => 8,
187            Self::FP8E4M3 => 1,
188            Self::FP8E5M2 => 1,
189            Self::Float16 => 2,
190            Self::BFloat16 => 2,
191            Self::Float32 => 4,
192            Self::Float64 => 8,
193            Self::Void => 0,
194            Self::Index => 8, // Treat as 64-bit index
195        }
196    }
197
198    pub const fn is_bool(&self) -> bool {
199        matches!(self, Self::Bool)
200    }
201
202    pub const fn is_signed(&self) -> bool {
203        matches!(self, Self::Int8 | Self::Int16 | Self::Int32 | Self::Int64)
204    }
205
206    pub const fn is_unsigned(&self) -> bool {
207        matches!(self, Self::UInt8 | Self::UInt16 | Self::UInt32 | Self::UInt64)
208    }
209
210    pub const fn is_int(&self) -> bool {
211        self.is_signed() || self.is_unsigned() || matches!(self, Self::Index)
212    }
213
214    pub const fn is_float(&self) -> bool {
215        matches!(self, Self::FP8E4M3 | Self::FP8E5M2 | Self::Float16 | Self::BFloat16 | Self::Float32 | Self::Float64)
216    }
217
218    pub const fn is_fp8(&self) -> bool {
219        matches!(self, Self::FP8E4M3 | Self::FP8E5M2)
220    }
221
222    pub const fn min_value(&self) -> f64 {
223        match self {
224            Self::Bool => 0.0,
225            Self::Int8 => i8::MIN as f64,
226            Self::Int16 => i16::MIN as f64,
227            Self::Int32 => i32::MIN as f64,
228            Self::Int64 => i64::MIN as f64,
229            Self::UInt8 | Self::UInt16 | Self::UInt32 | Self::UInt64 => 0.0,
230            Self::Float16 => -65504.0,
231            Self::BFloat16 => -3.3895313892515355e38,
232            Self::Float32 => f32::MIN as f64,
233            Self::Float64 => f64::MIN,
234            Self::FP8E4M3 => -448.0,
235            Self::FP8E5M2 => -57344.0,
236            Self::Void | Self::Index => 0.0,
237        }
238    }
239
240    pub const fn max_value(&self) -> f64 {
241        match self {
242            Self::Bool => 1.0,
243            Self::Int8 => i8::MAX as f64,
244            Self::Int16 => i16::MAX as f64,
245            Self::Int32 => i32::MAX as f64,
246            Self::Int64 => i64::MAX as f64,
247            Self::UInt8 => u8::MAX as f64,
248            Self::UInt16 => u16::MAX as f64,
249            Self::UInt32 => u32::MAX as f64,
250            Self::UInt64 => u64::MAX as f64,
251            Self::Float16 => 65504.0,
252            Self::BFloat16 => 3.3895313892515355e38,
253            Self::Float32 => f32::MAX as f64,
254            Self::Float64 => f64::MAX,
255            Self::FP8E4M3 => 448.0,
256            Self::FP8E5M2 => 57344.0,
257            Self::Void | Self::Index => 0.0,
258        }
259    }
260
261    pub const fn c_style(&self) -> &'static str {
262        match self {
263            Self::Bool => "bool",
264            Self::Int8 => "signed char",
265            Self::Int16 => "short",
266            Self::Int32 => "int",
267            Self::Int64 => "long",
268            Self::UInt8 => "unsigned char",
269            Self::UInt16 => "unsigned short",
270            Self::UInt32 => "unsigned int",
271            Self::UInt64 => "unsigned long",
272            Self::FP8E4M3 => "float8_e4m3",
273            Self::FP8E5M2 => "float8_e5m2",
274            Self::Float16 => "half",
275            Self::Float32 => "float",
276            Self::Float64 => "double",
277            Self::BFloat16 => "__bf16",
278            Self::Void => "void",
279            Self::Index => "size_t",
280        }
281    }
282
283    pub const fn min_positive(&self) -> f64 {
284        match self {
285            Self::Float16 => 6.103515625e-05,         // 2^-14
286            Self::BFloat16 => 1.175494350822288e-38,  // 2^-126 (same exponent range as f32)
287            Self::Float32 => 1.1754944e-38,           // f32::MIN_POSITIVE
288            Self::Float64 => 2.2250738585072014e-308, // f64::MIN_POSITIVE
289            _ => 1.1754944e-38,                       // default to f32 range
290        }
291    }
292
293    /// (exponent_bits, mantissa_bits) for float types.
294    /// Matches Tinygrad's `dtypes.finfo()`.
295    pub const fn finfo(&self) -> (u32, u32) {
296        match self {
297            Self::FP8E4M3 => (4, 3),
298            Self::FP8E5M2 => (5, 2),
299            Self::Float16 => (5, 10),
300            Self::BFloat16 => (8, 7),
301            Self::Float32 => (8, 23),
302            Self::Float64 => (11, 52),
303            _ => panic!("finfo: not a float type"),
304        }
305    }
306
307    /// Exponent bias: `(1 << (exp_bits - 1)) - 1`.
308    pub const fn exponent_bias(&self) -> i32 {
309        let (e, _) = self.finfo();
310        (1 << (e - 1)) - 1
311    }
312
313    /// Map float dtype to uint storage equivalent of the same bit width.
314    pub const fn float_to_uint(&self) -> ScalarDType {
315        match self {
316            Self::FP8E4M3 | Self::FP8E5M2 => Self::UInt8,
317            Self::Float16 | Self::BFloat16 => Self::UInt16,
318            Self::Float32 => Self::UInt32,
319            Self::Float64 => Self::UInt64,
320            _ => panic!("float_to_uint: not a float type"),
321        }
322    }
323
324    /// Bit size of this scalar type.
325    pub const fn bitsize(&self) -> u32 {
326        (self.bytes() * 8) as u32
327    }
328
329    /// Create a vector DType from this scalar type.
330    pub const fn vec(self, count: usize) -> DType {
331        DType::Vector { scalar: self, count }
332    }
333}
334
335impl From<ScalarDType> for DType {
336    fn from(scalar: ScalarDType) -> Self {
337        Self::Scalar(scalar)
338    }
339}
340
341impl DType {
342    // =========================================================================
343    // Type Constructors
344    // =========================================================================
345
346    /// Create a vector type from this dtype.
347    pub fn vec(&self, count: usize) -> Self {
348        if count == 1 {
349            return self.clone();
350        }
351
352        match self {
353            Self::Scalar(s) if !matches!(s, ScalarDType::Void) => Self::Vector { scalar: *s, count },
354            Self::Vector { .. } => panic!("Cannot vectorize an already vectorized type"),
355            Self::Ptr { vcount: 1, base, addrspace, size } => {
356                Self::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: count }
357            }
358            // Already vectorized to target count — idempotent (transient state during
359            // graph rewrite when VECTORIZE(CAST(buf)) is reconstructed before the
360            // INDEX(VECTORIZE(CAST(...))) pattern consumes it).
361            Self::Ptr { vcount, .. } if *vcount == count => self.clone(),
362            Self::Ptr { vcount, .. } => {
363                panic!("Cannot vectorize an already vectorized pointer (vcount={vcount}) to different count ({count})")
364            }
365            _ => self.clone(),
366        }
367    }
368
369    /// Create a pointer type from this dtype.
370    pub fn ptr(self, size: Option<usize>, addrspace: AddrSpace) -> Self {
371        match self {
372            Self::Ptr { .. } => panic!("Cannot make a pointer from a pointer"),
373            _ => Self::Ptr { base: Box::new(self), addrspace, size, vcount: 1 },
374        }
375    }
376
377    pub fn scalar(&self) -> Option<ScalarDType> {
378        match self {
379            Self::Scalar(s) => Some(*s),
380            _ => None,
381        }
382    }
383
384    /// Check if this is a vector type.
385    pub fn is_vector(&self) -> bool {
386        matches!(self, Self::Vector { .. })
387    }
388
389    /// Check if this is an image (texture) type.
390    pub fn is_image(&self) -> bool {
391        matches!(self, Self::Image { .. })
392    }
393
394    /// Get the base scalar type (works for both scalars and vectors).
395    pub fn base(&self) -> ScalarDType {
396        match self {
397            Self::Scalar(s) => *s,
398            Self::Vector { scalar, .. } => *scalar,
399            Self::Ptr { base, .. } => base.base(),
400            Self::Image { .. } => ScalarDType::Float32, // Images use float32 by default
401        }
402    }
403
404    /// Get scalar DType (works on both Scalar and Vector).
405    ///
406    /// Unlike `base()` which returns `ScalarDType`, this returns `DType`.
407    /// This enables chaining with `.vec()`.
408    ///
409    /// # Examples
410    ///
411    /// ```
412    /// use svod_dtype::DType;
413    ///
414    /// let vec_dtype = DType::Float32.vec(4);
415    /// assert_eq!(vec_dtype.scalar_dtype(), DType::Float32);
416    ///
417    /// // Enable chaining: dtype.scalar_dtype().vec(new_count)
418    /// let new_vec = vec_dtype.scalar_dtype().vec(8);
419    /// assert_eq!(new_vec, DType::Float32.vec(8));
420    /// ```
421    pub fn scalar_dtype(&self) -> DType {
422        DType::Scalar(self.base())
423    }
424
425    /// Create a new dtype with a different base scalar type, preserving vector count.
426    ///
427    /// Useful for type conversions like bool→uint8 where the structure is preserved.
428    pub fn with_base(&self, new_base: ScalarDType) -> Self {
429        let count = self.vcount();
430        if count > 1 { Self::Scalar(new_base).vec(count) } else { Self::Scalar(new_base) }
431    }
432
433    /// For Ptr types: replace the base dtype while preserving addrspace, size, and vcount.
434    /// Returns None if not a Ptr.
435    pub fn with_ptr_base(&self, new_base: DType) -> Option<Self> {
436        match self {
437            Self::Ptr { addrspace, size, vcount, .. } => {
438                Some(Self::Ptr { base: Box::new(new_base), addrspace: *addrspace, size: *size, vcount: *vcount })
439            }
440            _ => None,
441        }
442    }
443
444    /// Get the vector count (1 for scalars).
445    pub fn count(&self) -> usize {
446        match self {
447            Self::Vector { count, .. } => *count,
448            _ => 1,
449        }
450    }
451
452    /// Get effective vectorization count (for pointers to vectors).
453    pub fn vcount(&self) -> usize {
454        match self {
455            Self::Vector { count, .. } => *count,
456            Self::Ptr { vcount, .. } => *vcount,
457            _ => 1,
458        }
459    }
460
461    // =========================================================================
462    // Type Properties
463    // =========================================================================
464
465    pub fn bytes(&self) -> usize {
466        match self {
467            Self::Scalar(s) => s.bytes(),
468            Self::Vector { scalar, count } => scalar.bytes() * count,
469            Self::Ptr { .. } => 8,   // Pointers are 64-bit
470            Self::Image { .. } => 8, // Image handles are pointers
471        }
472    }
473
474    pub fn is_bool(&self) -> bool {
475        // Use base() to handle both Scalar and Vector types
476        self.base() == ScalarDType::Bool
477    }
478
479    pub fn is_signed(&self) -> bool {
480        // Use base() to handle both Scalar and Vector types
481        self.base().is_signed()
482    }
483
484    pub fn is_unsigned(&self) -> bool {
485        // Use base() to handle both Scalar and Vector types
486        self.base().is_unsigned()
487    }
488
489    pub fn is_int(&self) -> bool {
490        // Use base() to handle both Scalar and Vector types
491        self.base().is_int()
492    }
493
494    pub fn is_float(&self) -> bool {
495        self.base().is_float()
496    }
497
498    pub fn is_fp8(&self) -> bool {
499        self.base().is_fp8()
500    }
501
502    pub fn min_value(&self) -> f64 {
503        self.base().min_value()
504    }
505
506    pub fn max_value(&self) -> f64 {
507        self.base().max_value()
508    }
509
510    pub fn c_style(&self) -> String {
511        match self {
512            Self::Scalar(s) => s.c_style().to_string(),
513            Self::Vector { scalar, count } => format!("{}[{}]", scalar.c_style(), count),
514            Self::Ptr { base, addrspace, .. } => {
515                let addr_str = match addrspace {
516                    AddrSpace::Global => "__global",
517                    AddrSpace::Local => "__local",
518                    AddrSpace::Reg => "__register",
519                };
520                format!("{} {}*", addr_str, base.c_style())
521            }
522            Self::Image { kind, .. } => match kind {
523                ImageKind::Half => "image2d_t".to_string(),
524                ImageKind::Float => "image2d_t".to_string(),
525            },
526        }
527    }
528}
529
530// Convenient constructors for common scalar types
531impl DType {
532    pub const fn bool_() -> Self {
533        Self::Scalar(ScalarDType::Bool)
534    }
535    pub const fn int8() -> Self {
536        Self::Scalar(ScalarDType::Int8)
537    }
538    pub const fn int16() -> Self {
539        Self::Scalar(ScalarDType::Int16)
540    }
541    pub const fn int32() -> Self {
542        Self::Scalar(ScalarDType::Int32)
543    }
544    pub const fn int64() -> Self {
545        Self::Scalar(ScalarDType::Int64)
546    }
547    pub const fn uint8() -> Self {
548        Self::Scalar(ScalarDType::UInt8)
549    }
550    pub const fn uint16() -> Self {
551        Self::Scalar(ScalarDType::UInt16)
552    }
553    pub const fn uint32() -> Self {
554        Self::Scalar(ScalarDType::UInt32)
555    }
556    pub const fn uint64() -> Self {
557        Self::Scalar(ScalarDType::UInt64)
558    }
559    pub const fn float16() -> Self {
560        Self::Scalar(ScalarDType::Float16)
561    }
562    pub const fn bfloat16() -> Self {
563        Self::Scalar(ScalarDType::BFloat16)
564    }
565    pub const fn float32() -> Self {
566        Self::Scalar(ScalarDType::Float32)
567    }
568    pub const fn float64() -> Self {
569        Self::Scalar(ScalarDType::Float64)
570    }
571    pub const fn void_() -> Self {
572        Self::Scalar(ScalarDType::Void)
573    }
574    pub const fn index() -> Self {
575        Self::Scalar(ScalarDType::Index)
576    }
577}
578
579// Legacy aliases for compatibility
580#[allow(non_upper_case_globals)]
581impl DType {
582    pub const Bool: Self = Self::Scalar(ScalarDType::Bool);
583    pub const Int8: Self = Self::Scalar(ScalarDType::Int8);
584    pub const Int16: Self = Self::Scalar(ScalarDType::Int16);
585    pub const Int32: Self = Self::Scalar(ScalarDType::Int32);
586    pub const Int64: Self = Self::Scalar(ScalarDType::Int64);
587    pub const UInt8: Self = Self::Scalar(ScalarDType::UInt8);
588    pub const UInt16: Self = Self::Scalar(ScalarDType::UInt16);
589    pub const UInt32: Self = Self::Scalar(ScalarDType::UInt32);
590    pub const UInt64: Self = Self::Scalar(ScalarDType::UInt64);
591    pub const FP8E4M3: Self = Self::Scalar(ScalarDType::FP8E4M3);
592    pub const FP8E5M2: Self = Self::Scalar(ScalarDType::FP8E5M2);
593    pub const Float16: Self = Self::Scalar(ScalarDType::Float16);
594    pub const BFloat16: Self = Self::Scalar(ScalarDType::BFloat16);
595    pub const Float32: Self = Self::Scalar(ScalarDType::Float32);
596    pub const Float64: Self = Self::Scalar(ScalarDType::Float64);
597    pub const Void: Self = Self::Scalar(ScalarDType::Void);
598    pub const Index: Self = Self::Scalar(ScalarDType::Index);
599}
600
601/// Trait for types that have an associated DType.
602///
603/// This trait is used for type-safe tensor data extraction (e.g., `to_ndarray::<T>()`).
604pub trait HasDType: Clone + Default {
605    const DTYPE: DType;
606}
607
608impl HasDType for f32 {
609    const DTYPE: DType = DType::Float32;
610}
611
612impl HasDType for f64 {
613    const DTYPE: DType = DType::Float64;
614}
615
616impl HasDType for i8 {
617    const DTYPE: DType = DType::Int8;
618}
619
620impl HasDType for i16 {
621    const DTYPE: DType = DType::Int16;
622}
623
624impl HasDType for i32 {
625    const DTYPE: DType = DType::Int32;
626}
627
628impl HasDType for i64 {
629    const DTYPE: DType = DType::Int64;
630}
631
632impl HasDType for u8 {
633    const DTYPE: DType = DType::UInt8;
634}
635
636impl HasDType for u16 {
637    const DTYPE: DType = DType::UInt16;
638}
639
640impl HasDType for u32 {
641    const DTYPE: DType = DType::UInt32;
642}
643
644impl HasDType for u64 {
645    const DTYPE: DType = DType::UInt64;
646}
647
648impl HasDType for bool {
649    const DTYPE: DType = DType::Bool;
650}