1pub mod cast;
2pub mod ext;
3
4#[cfg(any(test, feature = "proptest"))]
5pub mod test;
6
7use std::path::PathBuf;
8
/// Describes which device something is placed on.
///
/// Variants that can exist more than once carry the data needed to tell the
/// instances apart (a numeric id or a filesystem path). With the `serde`
/// feature enabled the type additionally derives `Serialize`/`Deserialize`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DeviceSpec {
    /// The CPU device; carries no payload.
    Cpu,
    /// A CUDA device.
    Cuda {
        /// Numeric id selecting the CUDA device.
        device_id: usize,
    },
    /// A Metal device.
    Metal {
        /// Numeric id selecting the Metal device.
        device_id: usize,
    },
    /// The WebGPU device; carries no payload.
    WebGpu,
    /// A disk device.
    Disk {
        /// Filesystem path identifying the disk device.
        path: PathBuf,
    },
}
29
30impl DeviceSpec {
31 pub fn canonicalize(&self) -> String {
43 match self {
44 DeviceSpec::Cpu => "CPU".to_string(),
45 DeviceSpec::Cuda { device_id } => format!("CUDA:{device_id}"),
46 DeviceSpec::Metal { device_id } => format!("Metal:{device_id}"),
47 DeviceSpec::WebGpu => "WebGPU".to_string(),
48 DeviceSpec::Disk { path } => format!("DISK:{}", path.display()),
49 }
50 }
51
52 pub fn max_buffers(&self) -> Option<usize> {
62 match self {
63 DeviceSpec::Cpu | DeviceSpec::Disk { .. } => None,
64 DeviceSpec::Cuda { .. } => Some(128),
65 DeviceSpec::Metal { .. } => Some(31),
66 DeviceSpec::WebGpu => Some(8),
67 }
68 }
69
70 pub fn base_type(&self) -> &'static str {
85 match self {
86 DeviceSpec::Cpu => "CPU",
87 DeviceSpec::Cuda { .. } => "CUDA",
88 DeviceSpec::Metal { .. } => "METAL",
89 DeviceSpec::WebGpu => "WEBGPU",
90 DeviceSpec::Disk { .. } => "DISK",
91 }
92 }
93
94 pub fn is_disk(&self) -> bool {
96 matches!(self, DeviceSpec::Disk { .. })
97 }
98}
99
/// Address space qualifier carried by pointer types.
///
/// The derived ordering follows declaration order:
/// `Global < Local < Reg`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AddrSpace {
    /// Global address space.
    Global,
    /// Local address space.
    Local,
    /// Register address space.
    Reg,
}
111
/// Element flavour of an image type.
///
/// The derived ordering follows declaration order: `Half < Float`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImageKind {
    /// Half-precision image.
    Half,
    /// Single-precision image.
    Float,
}
121
122#[derive(Debug, Hash, PartialOrd, Ord)]
124#[derive(strum::EnumCount, strum::EnumIter, strum::VariantArray, strum::FromRepr)]
125#[derive(enumset::EnumSetType)]
126#[cfg_attr(feature = "proptest", derive(proptest_derive::Arbitrary))]
127#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
128#[enumset(repr = "u32")]
129pub enum ScalarDType {
130 Bool = 0,
131
132 Int8 = 1,
134 UInt8 = 2,
135 Int16 = 3,
136 UInt16 = 4,
137 Int32 = 5,
138 UInt32 = 6,
139 Int64 = 7,
140 UInt64 = 8,
141
142 FP8E4M3 = 9,
143 FP8E5M2 = 10,
144 Float16 = 11,
145 BFloat16 = 12,
146 Float32 = 13,
147 Float64 = 14,
148
149 Void = 15,
151
152 Index = 16,
154}
155
156#[derive(Debug, Clone, PartialEq, Eq, Hash)]
158#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
159pub enum DType {
160 Scalar(ScalarDType),
162
163 Vector { scalar: ScalarDType, count: usize },
165
166 Ptr { base: Box<DType>, addrspace: AddrSpace, size: Option<usize>, vcount: usize },
170
171 Image { kind: ImageKind, shape: Vec<usize> },
173}
174
175impl ScalarDType {
176 pub const fn bytes(&self) -> usize {
177 match self {
178 Self::Bool => 1,
179 Self::Int8 => 1,
180 Self::Int16 => 2,
181 Self::Int32 => 4,
182 Self::Int64 => 8,
183 Self::UInt8 => 1,
184 Self::UInt16 => 2,
185 Self::UInt32 => 4,
186 Self::UInt64 => 8,
187 Self::FP8E4M3 => 1,
188 Self::FP8E5M2 => 1,
189 Self::Float16 => 2,
190 Self::BFloat16 => 2,
191 Self::Float32 => 4,
192 Self::Float64 => 8,
193 Self::Void => 0,
194 Self::Index => 8, }
196 }
197
198 pub const fn is_bool(&self) -> bool {
199 matches!(self, Self::Bool)
200 }
201
202 pub const fn is_signed(&self) -> bool {
203 matches!(self, Self::Int8 | Self::Int16 | Self::Int32 | Self::Int64)
204 }
205
206 pub const fn is_unsigned(&self) -> bool {
207 matches!(self, Self::UInt8 | Self::UInt16 | Self::UInt32 | Self::UInt64)
208 }
209
210 pub const fn is_int(&self) -> bool {
211 self.is_signed() || self.is_unsigned() || matches!(self, Self::Index)
212 }
213
214 pub const fn is_float(&self) -> bool {
215 matches!(self, Self::FP8E4M3 | Self::FP8E5M2 | Self::Float16 | Self::BFloat16 | Self::Float32 | Self::Float64)
216 }
217
218 pub const fn is_fp8(&self) -> bool {
219 matches!(self, Self::FP8E4M3 | Self::FP8E5M2)
220 }
221
222 pub const fn min_value(&self) -> f64 {
223 match self {
224 Self::Bool => 0.0,
225 Self::Int8 => i8::MIN as f64,
226 Self::Int16 => i16::MIN as f64,
227 Self::Int32 => i32::MIN as f64,
228 Self::Int64 => i64::MIN as f64,
229 Self::UInt8 | Self::UInt16 | Self::UInt32 | Self::UInt64 => 0.0,
230 Self::Float16 => -65504.0,
231 Self::BFloat16 => -3.3895313892515355e38,
232 Self::Float32 => f32::MIN as f64,
233 Self::Float64 => f64::MIN,
234 Self::FP8E4M3 => -448.0,
235 Self::FP8E5M2 => -57344.0,
236 Self::Void | Self::Index => 0.0,
237 }
238 }
239
240 pub const fn max_value(&self) -> f64 {
241 match self {
242 Self::Bool => 1.0,
243 Self::Int8 => i8::MAX as f64,
244 Self::Int16 => i16::MAX as f64,
245 Self::Int32 => i32::MAX as f64,
246 Self::Int64 => i64::MAX as f64,
247 Self::UInt8 => u8::MAX as f64,
248 Self::UInt16 => u16::MAX as f64,
249 Self::UInt32 => u32::MAX as f64,
250 Self::UInt64 => u64::MAX as f64,
251 Self::Float16 => 65504.0,
252 Self::BFloat16 => 3.3895313892515355e38,
253 Self::Float32 => f32::MAX as f64,
254 Self::Float64 => f64::MAX,
255 Self::FP8E4M3 => 448.0,
256 Self::FP8E5M2 => 57344.0,
257 Self::Void | Self::Index => 0.0,
258 }
259 }
260
261 pub const fn c_style(&self) -> &'static str {
262 match self {
263 Self::Bool => "bool",
264 Self::Int8 => "signed char",
265 Self::Int16 => "short",
266 Self::Int32 => "int",
267 Self::Int64 => "long",
268 Self::UInt8 => "unsigned char",
269 Self::UInt16 => "unsigned short",
270 Self::UInt32 => "unsigned int",
271 Self::UInt64 => "unsigned long",
272 Self::FP8E4M3 => "float8_e4m3",
273 Self::FP8E5M2 => "float8_e5m2",
274 Self::Float16 => "half",
275 Self::Float32 => "float",
276 Self::Float64 => "double",
277 Self::BFloat16 => "__bf16",
278 Self::Void => "void",
279 Self::Index => "size_t",
280 }
281 }
282
283 pub const fn min_positive(&self) -> f64 {
284 match self {
285 Self::Float16 => 6.103515625e-05, Self::BFloat16 => 1.175494350822288e-38, Self::Float32 => 1.1754944e-38, Self::Float64 => 2.2250738585072014e-308, _ => 1.1754944e-38, }
291 }
292
293 pub const fn finfo(&self) -> (u32, u32) {
296 match self {
297 Self::FP8E4M3 => (4, 3),
298 Self::FP8E5M2 => (5, 2),
299 Self::Float16 => (5, 10),
300 Self::BFloat16 => (8, 7),
301 Self::Float32 => (8, 23),
302 Self::Float64 => (11, 52),
303 _ => panic!("finfo: not a float type"),
304 }
305 }
306
307 pub const fn exponent_bias(&self) -> i32 {
309 let (e, _) = self.finfo();
310 (1 << (e - 1)) - 1
311 }
312
313 pub const fn float_to_uint(&self) -> ScalarDType {
315 match self {
316 Self::FP8E4M3 | Self::FP8E5M2 => Self::UInt8,
317 Self::Float16 | Self::BFloat16 => Self::UInt16,
318 Self::Float32 => Self::UInt32,
319 Self::Float64 => Self::UInt64,
320 _ => panic!("float_to_uint: not a float type"),
321 }
322 }
323
324 pub const fn bitsize(&self) -> u32 {
326 (self.bytes() * 8) as u32
327 }
328
329 pub const fn vec(self, count: usize) -> DType {
331 DType::Vector { scalar: self, count }
332 }
333}
334
335impl From<ScalarDType> for DType {
336 fn from(scalar: ScalarDType) -> Self {
337 Self::Scalar(scalar)
338 }
339}
340
341impl DType {
342 pub fn vec(&self, count: usize) -> Self {
348 if count == 1 {
349 return self.clone();
350 }
351
352 match self {
353 Self::Scalar(s) if !matches!(s, ScalarDType::Void) => Self::Vector { scalar: *s, count },
354 Self::Vector { .. } => panic!("Cannot vectorize an already vectorized type"),
355 Self::Ptr { vcount: 1, base, addrspace, size } => {
356 Self::Ptr { base: base.clone(), addrspace: *addrspace, size: *size, vcount: count }
357 }
358 Self::Ptr { vcount, .. } if *vcount == count => self.clone(),
362 Self::Ptr { vcount, .. } => {
363 panic!("Cannot vectorize an already vectorized pointer (vcount={vcount}) to different count ({count})")
364 }
365 _ => self.clone(),
366 }
367 }
368
369 pub fn ptr(self, size: Option<usize>, addrspace: AddrSpace) -> Self {
371 match self {
372 Self::Ptr { .. } => panic!("Cannot make a pointer from a pointer"),
373 _ => Self::Ptr { base: Box::new(self), addrspace, size, vcount: 1 },
374 }
375 }
376
377 pub fn scalar(&self) -> Option<ScalarDType> {
378 match self {
379 Self::Scalar(s) => Some(*s),
380 _ => None,
381 }
382 }
383
384 pub fn is_vector(&self) -> bool {
386 matches!(self, Self::Vector { .. })
387 }
388
389 pub fn is_image(&self) -> bool {
391 matches!(self, Self::Image { .. })
392 }
393
394 pub fn base(&self) -> ScalarDType {
396 match self {
397 Self::Scalar(s) => *s,
398 Self::Vector { scalar, .. } => *scalar,
399 Self::Ptr { base, .. } => base.base(),
400 Self::Image { .. } => ScalarDType::Float32, }
402 }
403
404 pub fn scalar_dtype(&self) -> DType {
422 DType::Scalar(self.base())
423 }
424
425 pub fn with_base(&self, new_base: ScalarDType) -> Self {
429 let count = self.vcount();
430 if count > 1 { Self::Scalar(new_base).vec(count) } else { Self::Scalar(new_base) }
431 }
432
433 pub fn with_ptr_base(&self, new_base: DType) -> Option<Self> {
436 match self {
437 Self::Ptr { addrspace, size, vcount, .. } => {
438 Some(Self::Ptr { base: Box::new(new_base), addrspace: *addrspace, size: *size, vcount: *vcount })
439 }
440 _ => None,
441 }
442 }
443
444 pub fn count(&self) -> usize {
446 match self {
447 Self::Vector { count, .. } => *count,
448 _ => 1,
449 }
450 }
451
452 pub fn vcount(&self) -> usize {
454 match self {
455 Self::Vector { count, .. } => *count,
456 Self::Ptr { vcount, .. } => *vcount,
457 _ => 1,
458 }
459 }
460
461 pub fn bytes(&self) -> usize {
466 match self {
467 Self::Scalar(s) => s.bytes(),
468 Self::Vector { scalar, count } => scalar.bytes() * count,
469 Self::Ptr { .. } => 8, Self::Image { .. } => 8, }
472 }
473
474 pub fn is_bool(&self) -> bool {
475 self.base() == ScalarDType::Bool
477 }
478
479 pub fn is_signed(&self) -> bool {
480 self.base().is_signed()
482 }
483
484 pub fn is_unsigned(&self) -> bool {
485 self.base().is_unsigned()
487 }
488
489 pub fn is_int(&self) -> bool {
490 self.base().is_int()
492 }
493
494 pub fn is_float(&self) -> bool {
495 self.base().is_float()
496 }
497
498 pub fn is_fp8(&self) -> bool {
499 self.base().is_fp8()
500 }
501
502 pub fn min_value(&self) -> f64 {
503 self.base().min_value()
504 }
505
506 pub fn max_value(&self) -> f64 {
507 self.base().max_value()
508 }
509
510 pub fn c_style(&self) -> String {
511 match self {
512 Self::Scalar(s) => s.c_style().to_string(),
513 Self::Vector { scalar, count } => format!("{}[{}]", scalar.c_style(), count),
514 Self::Ptr { base, addrspace, .. } => {
515 let addr_str = match addrspace {
516 AddrSpace::Global => "__global",
517 AddrSpace::Local => "__local",
518 AddrSpace::Reg => "__register",
519 };
520 format!("{} {}*", addr_str, base.c_style())
521 }
522 Self::Image { kind, .. } => match kind {
523 ImageKind::Half => "image2d_t".to_string(),
524 ImageKind::Float => "image2d_t".to_string(),
525 },
526 }
527 }
528}
529
530impl DType {
532 pub const fn bool_() -> Self {
533 Self::Scalar(ScalarDType::Bool)
534 }
535 pub const fn int8() -> Self {
536 Self::Scalar(ScalarDType::Int8)
537 }
538 pub const fn int16() -> Self {
539 Self::Scalar(ScalarDType::Int16)
540 }
541 pub const fn int32() -> Self {
542 Self::Scalar(ScalarDType::Int32)
543 }
544 pub const fn int64() -> Self {
545 Self::Scalar(ScalarDType::Int64)
546 }
547 pub const fn uint8() -> Self {
548 Self::Scalar(ScalarDType::UInt8)
549 }
550 pub const fn uint16() -> Self {
551 Self::Scalar(ScalarDType::UInt16)
552 }
553 pub const fn uint32() -> Self {
554 Self::Scalar(ScalarDType::UInt32)
555 }
556 pub const fn uint64() -> Self {
557 Self::Scalar(ScalarDType::UInt64)
558 }
559 pub const fn float16() -> Self {
560 Self::Scalar(ScalarDType::Float16)
561 }
562 pub const fn bfloat16() -> Self {
563 Self::Scalar(ScalarDType::BFloat16)
564 }
565 pub const fn float32() -> Self {
566 Self::Scalar(ScalarDType::Float32)
567 }
568 pub const fn float64() -> Self {
569 Self::Scalar(ScalarDType::Float64)
570 }
571 pub const fn void_() -> Self {
572 Self::Scalar(ScalarDType::Void)
573 }
574 pub const fn index() -> Self {
575 Self::Scalar(ScalarDType::Index)
576 }
577}
578
579#[allow(non_upper_case_globals)]
581impl DType {
582 pub const Bool: Self = Self::Scalar(ScalarDType::Bool);
583 pub const Int8: Self = Self::Scalar(ScalarDType::Int8);
584 pub const Int16: Self = Self::Scalar(ScalarDType::Int16);
585 pub const Int32: Self = Self::Scalar(ScalarDType::Int32);
586 pub const Int64: Self = Self::Scalar(ScalarDType::Int64);
587 pub const UInt8: Self = Self::Scalar(ScalarDType::UInt8);
588 pub const UInt16: Self = Self::Scalar(ScalarDType::UInt16);
589 pub const UInt32: Self = Self::Scalar(ScalarDType::UInt32);
590 pub const UInt64: Self = Self::Scalar(ScalarDType::UInt64);
591 pub const FP8E4M3: Self = Self::Scalar(ScalarDType::FP8E4M3);
592 pub const FP8E5M2: Self = Self::Scalar(ScalarDType::FP8E5M2);
593 pub const Float16: Self = Self::Scalar(ScalarDType::Float16);
594 pub const BFloat16: Self = Self::Scalar(ScalarDType::BFloat16);
595 pub const Float32: Self = Self::Scalar(ScalarDType::Float32);
596 pub const Float64: Self = Self::Scalar(ScalarDType::Float64);
597 pub const Void: Self = Self::Scalar(ScalarDType::Void);
598 pub const Index: Self = Self::Scalar(ScalarDType::Index);
599}
600
601pub trait HasDType: Clone + Default {
605 const DTYPE: DType;
606}
607
608impl HasDType for f32 {
609 const DTYPE: DType = DType::Float32;
610}
611
612impl HasDType for f64 {
613 const DTYPE: DType = DType::Float64;
614}
615
616impl HasDType for i8 {
617 const DTYPE: DType = DType::Int8;
618}
619
620impl HasDType for i16 {
621 const DTYPE: DType = DType::Int16;
622}
623
624impl HasDType for i32 {
625 const DTYPE: DType = DType::Int32;
626}
627
628impl HasDType for i64 {
629 const DTYPE: DType = DType::Int64;
630}
631
632impl HasDType for u8 {
633 const DTYPE: DType = DType::UInt8;
634}
635
636impl HasDType for u16 {
637 const DTYPE: DType = DType::UInt16;
638}
639
640impl HasDType for u32 {
641 const DTYPE: DType = DType::UInt32;
642}
643
644impl HasDType for u64 {
645 const DTYPE: DType = DType::UInt64;
646}
647
648impl HasDType for bool {
649 const DTYPE: DType = DType::Bool;
650}