1use std::fmt::Debug;
5use std::fmt::Display;
6use std::fmt::Formatter;
7use std::hash::Hash;
8use std::sync::Arc;
9
10use DType::*;
11use itertools::Itertools;
12use vortex_error::VortexExpect;
13use vortex_error::vortex_panic;
14
15use crate::FieldDType;
16use crate::FieldName;
17use crate::PType;
18use crate::StructFields;
19use crate::decimal::DecimalDType;
20use crate::decimal::DecimalType;
21use crate::extension::ExtDTypeRef;
22use crate::nullability::Nullability;
23
/// A data type descriptor, together with its nullability.
///
/// Most variants carry their own [`Nullability`] flag. `Null` has no flag and is always
/// reported as nullable; `Extension` reports the nullability of its storage dtype
/// (see [`DType::is_nullable`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DType {
    /// The null type. It carries no nullability flag and is always treated as nullable.
    Null,

    /// Boolean values.
    Bool(Nullability),

    /// Primitive values of the given [`PType`]; the element width comes from the `PType`.
    Primitive(PType, Nullability),

    /// Decimal values described by a [`DecimalDType`].
    Decimal(DecimalDType, Nullability),

    /// String data (UTF-8, going by the variant name). Has no fixed element size.
    Utf8(Nullability),

    /// Binary data. Has no fixed element size.
    Binary(Nullability),

    /// A list with the given element dtype. Has no fixed element size.
    List(Arc<DType>, Nullability),

    /// A list with the given element dtype and a fixed number of elements (the `u32`) per list.
    FixedSizeList(Arc<DType>, u32, Nullability),

    /// A struct with named fields, described by [`StructFields`].
    Struct(StructFields, Nullability),

    /// An extension dtype. Nullability, nesting and element size are all taken from its
    /// storage dtype.
    Extension(ExtDTypeRef),
}
99
/// Implemented by types that have an associated [`DType`].
pub trait NativeDType {
    /// Returns the [`DType`] associated with the implementing type.
    ///
    /// This is an associated function (no receiver), so the dtype depends only on the type
    /// itself, never on a particular value.
    fn dtype() -> DType;
}
111
112#[cfg(not(target_arch = "wasm32"))]
114const _: [(); size_of::<DType>()] = [(); 24]; #[cfg(target_arch = "wasm32")]
118const _: [(); size_of::<DType>()] = [(); 12];
119
120impl DType {
121 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
123
124 #[inline]
126 pub fn nullability(&self) -> Nullability {
127 self.is_nullable().into()
128 }
129
130 #[inline]
132 pub fn is_nullable(&self) -> bool {
133 match self {
134 Null => true,
135 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
136 Bool(null)
137 | Primitive(_, null)
138 | Decimal(_, null)
139 | Utf8(null)
140 | Binary(null)
141 | Struct(_, null)
142 | List(_, null)
143 | FixedSizeList(_, _, null) => matches!(null, Nullability::Nullable),
144 }
145 }
146
147 pub fn as_nonnullable(&self) -> Self {
149 self.with_nullability(Nullability::NonNullable)
150 }
151
152 pub fn as_nullable(&self) -> Self {
154 self.with_nullability(Nullability::Nullable)
155 }
156
157 pub fn with_nullability(&self, nullability: Nullability) -> Self {
159 match self {
160 Null => Null,
161 Bool(_) => Bool(nullability),
162 Primitive(pdt, _) => Primitive(*pdt, nullability),
163 Decimal(ddt, _) => Decimal(*ddt, nullability),
164 Utf8(_) => Utf8(nullability),
165 Binary(_) => Binary(nullability),
166 Struct(sf, _) => Struct(sf.clone(), nullability),
167 List(edt, _) => List(edt.clone(), nullability),
168 FixedSizeList(edt, size, _) => FixedSizeList(edt.clone(), *size, nullability),
169 Extension(ext) => Extension(ext.with_nullability(nullability)),
170 }
171 }
172
173 pub fn union_nullability(&self, other: Nullability) -> Self {
175 let nullability = self.nullability() | other;
176 self.with_nullability(nullability)
177 }
178
179 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
181 match (self, other) {
182 (Null, Null) => true,
183 (Bool(_), Bool(_)) => true,
184 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
185 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
186 (Utf8(_), Utf8(_)) => true,
187 (Binary(_), Binary(_)) => true,
188 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
189 (FixedSizeList(lhs_dtype, lhs_size, _), FixedSizeList(rhs_dtype, rhs_size, _)) => {
190 lhs_size == rhs_size && lhs_dtype.eq_ignore_nullability(rhs_dtype)
191 }
192 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
193 (lhs_dtype.names() == rhs_dtype.names())
194 && (lhs_dtype
195 .fields()
196 .zip_eq(rhs_dtype.fields())
197 .all(|(l, r)| l.eq_ignore_nullability(&r)))
198 }
199 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
200 lhs_extdtype.eq_ignore_nullability(rhs_extdtype)
201 }
202 _ => false,
203 }
204 }
205
206 pub fn eq_with_nullability_subset(&self, other: &Self) -> bool {
215 if self.is_nullable() {
216 self == other
217 } else {
218 self.eq_ignore_nullability(other)
219 }
220 }
221
222 pub fn eq_with_nullability_superset(&self, other: &Self) -> bool {
233 if self.is_nullable() {
234 self.eq_ignore_nullability(other)
235 } else {
236 self == other
237 }
238 }
239
240 pub fn is_boolean(&self) -> bool {
242 matches!(self, Bool(_))
243 }
244
245 pub fn is_primitive(&self) -> bool {
247 matches!(self, Primitive(_, _))
248 }
249
250 pub fn as_ptype(&self) -> PType {
252 if let Primitive(ptype, _) = self {
253 *ptype
254 } else {
255 vortex_panic!("DType {self} is not a primitive type")
256 }
257 }
258
259 pub fn is_unsigned_int(&self) -> bool {
261 if let Primitive(ptype, _) = self {
262 return ptype.is_unsigned_int();
263 }
264 false
265 }
266
267 pub fn is_signed_int(&self) -> bool {
269 if let Primitive(ptype, _) = self {
270 return ptype.is_signed_int();
271 }
272 false
273 }
274
275 pub fn is_int(&self) -> bool {
277 if let Primitive(ptype, _) = self {
278 return ptype.is_int();
279 }
280 false
281 }
282
283 pub fn is_float(&self) -> bool {
285 if let Primitive(ptype, _) = self {
286 return ptype.is_float();
287 }
288 false
289 }
290
291 pub fn is_decimal(&self) -> bool {
293 matches!(self, Decimal(..))
294 }
295
296 pub fn is_utf8(&self) -> bool {
298 matches!(self, Utf8(_))
299 }
300
301 pub fn is_binary(&self) -> bool {
303 matches!(self, Binary(_))
304 }
305
306 pub fn is_list(&self) -> bool {
308 matches!(self, List(_, _))
309 }
310
311 pub fn is_fixed_size_list(&self) -> bool {
313 matches!(self, FixedSizeList(..))
314 }
315
316 pub fn is_struct(&self) -> bool {
318 matches!(self, Struct(_, _))
319 }
320
321 pub fn is_extension(&self) -> bool {
323 matches!(self, Extension(_))
324 }
325
326 pub fn is_nested(&self) -> bool {
329 match self {
330 List(..) | FixedSizeList(..) | Struct(..) => true,
331 Extension(ext) => ext.storage_dtype().is_nested(),
332 _ => false,
333 }
334 }
335
336 pub fn element_size(&self) -> Option<usize> {
342 match self {
343 Null => Some(0),
344 Bool(_) => Some(1),
345 Primitive(ptype, _) => Some(ptype.byte_width()),
346 Decimal(decimal, _) => {
347 Some(DecimalType::smallest_decimal_value_type(decimal).byte_width())
348 }
349 Utf8(_) | Binary(_) | List(..) => None,
350 FixedSizeList(elem_dtype, list_size, _) => {
351 elem_dtype.element_size().map(|s| s * *list_size as usize)
352 }
353 Struct(fields, ..) => {
354 let mut sum = 0_usize;
355 for f in fields.fields() {
356 let element_size = f.element_size()?;
357 sum = sum
358 .checked_add(element_size)
359 .vortex_expect("sum of field sizes is bigger than usize");
360 }
361 Some(sum)
362 }
363 Extension(ext) => ext.storage_dtype().element_size(),
364 }
365 }
366
367 pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
369 if let Decimal(decimal, _) = self {
370 Some(decimal)
371 } else {
372 None
373 }
374 }
375
376 pub fn into_decimal_opt(self) -> Option<DecimalDType> {
378 if let Decimal(decimal, _) = self {
379 Some(decimal)
380 } else {
381 None
382 }
383 }
384
385 pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
389 if let List(edt, _) = self {
390 Some(edt)
391 } else {
392 None
393 }
394 }
395
396 pub fn into_list_element_opt(self) -> Option<Arc<DType>> {
398 if let List(edt, _) = self {
399 Some(edt)
400 } else {
401 None
402 }
403 }
404
405 pub fn as_fixed_size_list_element_opt(&self) -> Option<&Arc<DType>> {
410 if let FixedSizeList(edt, ..) = self {
411 Some(edt)
412 } else {
413 None
414 }
415 }
416
417 pub fn into_fixed_size_list_element_opt(self) -> Option<Arc<DType>> {
419 if let FixedSizeList(edt, ..) = self {
420 Some(edt)
421 } else {
422 None
423 }
424 }
425
426 pub fn as_any_size_list_element_opt(&self) -> Option<&Arc<DType>> {
429 if let FixedSizeList(edt, ..) = self {
430 Some(edt)
431 } else if let List(edt, ..) = self {
432 Some(edt)
433 } else {
434 None
435 }
436 }
437
438 pub fn into_any_size_list_element_opt(self) -> Option<Arc<DType>> {
440 if let FixedSizeList(edt, ..) = self {
441 Some(edt)
442 } else if let List(edt, ..) = self {
443 Some(edt)
444 } else {
445 None
446 }
447 }
448
449 pub fn as_struct_fields(&self) -> &StructFields {
455 if let Struct(f, _) = self {
456 return f;
457 }
458 vortex_panic!("DType is not a Struct")
459 }
460
461 pub fn into_struct_fields(self) -> StructFields {
463 if let Struct(f, _) = self {
464 return f;
465 }
466 vortex_panic!("DType is not a Struct")
467 }
468
469 pub fn as_struct_fields_opt(&self) -> Option<&StructFields> {
471 if let Struct(f, _) = self {
472 Some(f)
473 } else {
474 None
475 }
476 }
477
478 pub fn into_struct_fields_opt(self) -> Option<StructFields> {
480 if let Struct(f, _) = self {
481 Some(f)
482 } else {
483 None
484 }
485 }
486
487 pub fn as_extension(&self) -> &ExtDTypeRef {
489 let Extension(ext) = self else {
490 vortex_panic!("DType is not an Extension")
491 };
492 ext
493 }
494
495 pub fn as_extension_opt(&self) -> Option<&ExtDTypeRef> {
497 if let Extension(ext) = self {
498 Some(ext)
499 } else {
500 None
501 }
502 }
503
504 pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
506 List(Arc::new(dtype.into()), nullability)
507 }
508
509 pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
511 iter: I,
512 nullability: Nullability,
513 ) -> Self {
514 Struct(StructFields::from_iter(iter), nullability)
515 }
516}
517
518impl Display for DType {
519 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
520 match self {
521 Null => write!(f, "null"),
522 Bool(null) => write!(f, "bool{null}"),
523 Primitive(pdt, null) => write!(f, "{pdt}{null}"),
524 Decimal(ddt, null) => write!(f, "{ddt}{null}"),
525 Utf8(null) => write!(f, "utf8{null}"),
526 Binary(null) => write!(f, "binary{null}"),
527 Struct(sf, null) => write!(
528 f,
529 "{{{}}}{null}",
530 sf.names()
531 .iter()
532 .zip(sf.fields())
533 .map(|(field_null, dt)| format!("{field_null}={dt}"))
534 .join(", "),
535 ),
536 List(edt, null) => write!(f, "list({edt}){null}"),
537 FixedSizeList(edt, size, null) => write!(f, "fixed_size_list({edt})[{size}]{null}"),
538 Extension(ext) => write!(f, "{}", ext),
539 }
540 }
541}
542
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::DType;
    use crate::Nullability::NonNullable;
    use crate::Nullability::Nullable;
    use crate::PType;
    use crate::datetime::Date;
    use crate::datetime::Time;
    use crate::datetime::TimeUnit;
    use crate::datetime::Timestamp;
    use crate::decimal::DecimalDType;

    /// Extension dtypes differing only in nullability compare equal when nullability is
    /// ignored, while extension dtypes with different time zones do not.
    #[test]
    fn test_ext_dtype_eq_ignore_nullability() {
        let nullable_time = DType::Extension(Time::new(TimeUnit::Seconds, Nullable).erased());
        let non_nullable_time =
            DType::Extension(Time::new(TimeUnit::Seconds, NonNullable).erased());
        assert!(nullable_time.eq_ignore_nullability(&non_nullable_time));

        let utc_timestamp = DType::Extension(
            Timestamp::new_with_tz(TimeUnit::Seconds, Some("UTC".into()), Nullable).erased(),
        );
        let et_timestamp = DType::Extension(
            Timestamp::new_with_tz(TimeUnit::Seconds, Some("ET".into()), Nullable).erased(),
        );
        assert!(!utc_timestamp.eq_ignore_nullability(&et_timestamp));
    }

    /// The null dtype occupies no space per element.
    #[test]
    fn element_size_null() {
        let size = DType::Null.element_size();
        assert_eq!(size, Some(0));
    }

    /// Booleans report an element size of one.
    #[test]
    fn element_size_bool() {
        let size = DType::Bool(NonNullable).element_size();
        assert_eq!(size, Some(1));
    }

    /// Primitive dtypes report the byte width of their `PType`.
    #[test]
    fn element_size_primitives() {
        let cases = [(PType::U8, 1), (PType::I32, 4), (PType::F64, 8)];
        for (ptype, expected_width) in cases {
            assert_eq!(
                DType::Primitive(ptype, NonNullable).element_size(),
                Some(expected_width)
            );
        }
    }

    /// A decimal with precision 10 and scale 2 is sized as an 8-byte value.
    #[test]
    fn element_size_decimal() {
        let dtype = DType::Decimal(DecimalDType::new(10, 2), NonNullable);
        assert_eq!(dtype.element_size(), Some(8));
    }

    /// Fixed-size lists multiply the element size by the list length, recursively.
    #[test]
    fn element_size_fixed_size_list() {
        let f64_dtype = Arc::new(DType::Primitive(PType::F64, NonNullable));

        let flat = DType::FixedSizeList(f64_dtype.clone(), 1000, NonNullable);
        assert_eq!(flat.element_size(), Some(8000));

        let inner = Arc::new(DType::FixedSizeList(f64_dtype, 20, NonNullable));
        let nested = DType::FixedSizeList(inner, 1000, NonNullable);
        assert_eq!(nested.element_size(), Some(160_000));
    }

    /// A 100-long list of 10-long lists of `f64` is 100 * 10 * 8 bytes.
    #[test]
    fn element_size_nested_fixed_size_list() {
        let f64_dtype = Arc::new(DType::Primitive(PType::F64, NonNullable));
        let ten_f64s = Arc::new(DType::FixedSizeList(f64_dtype, 10, NonNullable));
        let outer = DType::FixedSizeList(ten_f64s, 100, NonNullable);
        assert_eq!(outer.element_size(), Some(8000));
    }

    /// Extension dtypes report the element size of their storage dtype.
    #[test]
    fn element_size_extension() {
        let date = DType::Extension(Date::new(TimeUnit::Days, NonNullable).erased());
        assert_eq!(date.element_size(), Some(4));
    }

    /// Strings, binary blobs and lists have no fixed element size.
    #[test]
    fn element_size_variable_width() {
        let variable_width = [
            DType::Utf8(NonNullable),
            DType::Binary(NonNullable),
            DType::List(
                Arc::new(DType::Primitive(PType::I32, NonNullable)),
                NonNullable,
            ),
        ];
        for dtype in variable_width {
            assert_eq!(dtype.element_size(), None);
        }
    }
}