1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use DType::*;
6use itertools::Itertools;
7use static_assertions::const_assert_eq;
8
9use crate::nullability::Nullability;
10use crate::{ExtDType, PType, StructDType};
11
/// Name of a single struct field; shared via `Arc` so cloning a name does not copy the string.
pub type FieldName = Arc<str>;
/// Ordered list of struct field names, shared behind a single `Arc` (cloning is a refcount bump).
pub type FieldNames = Arc<[FieldName]>;
16
/// Type descriptor for values, including whether they may be null.
///
/// Every variant except [`DType::Null`] and [`DType::Extension`] carries its own
/// [`Nullability`] flag. `Null` is always nullable, and an extension type takes its
/// nullability from its storage type (see [`DType::is_nullable`]).
///
/// The in-memory size of this enum is pinned by compile-time assertions in this module, so
/// any new variant should keep large payloads behind an `Arc`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// Serialization support is opt-in through the crate's `serde` feature.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type. [`DType::is_nullable`] always reports `true` for it.
    Null,
    /// Boolean values.
    Bool(Nullability),
    /// A primitive type identified by a [`PType`] (for example `PType::U8`).
    Primitive(PType, Nullability),
    /// String values (displayed as `utf8`; the name indicates UTF-8 encoding).
    Utf8(Nullability),
    /// Binary values (displayed as `binary`).
    Binary(Nullability),
    /// A struct with named fields, described by a shared [`StructDType`].
    Struct(Arc<StructDType>, Nullability),
    /// A list whose elements all have the given (shared) element type.
    List(Arc<DType>, Nullability),
    /// An extension type layered over a storage `DType`; its nullability is that of the
    /// storage type.
    Extension(Arc<ExtDType>),
}
41
42#[cfg(not(target_arch = "wasm32"))]
43const_assert_eq!(size_of::<DType>(), 16);
44
45#[cfg(target_arch = "wasm32")]
46const_assert_eq!(size_of::<DType>(), 8);
47
48impl DType {
49 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
51
52 pub fn nullability(&self) -> Nullability {
54 self.is_nullable().into()
55 }
56
57 pub fn is_nullable(&self) -> bool {
59 use crate::nullability::Nullability::*;
60
61 match self {
62 Null => true,
63 Bool(n) => matches!(n, Nullable),
64 Primitive(_, n) => matches!(n, Nullable),
65 Utf8(n) => matches!(n, Nullable),
66 Binary(n) => matches!(n, Nullable),
67 Struct(_, n) => matches!(n, Nullable),
68 List(_, n) => matches!(n, Nullable),
69 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
70 }
71 }
72
73 pub fn as_nonnullable(&self) -> Self {
75 self.with_nullability(Nullability::NonNullable)
76 }
77
78 pub fn as_nullable(&self) -> Self {
80 self.with_nullability(Nullability::Nullable)
81 }
82
83 pub fn with_nullability(&self, nullability: Nullability) -> Self {
85 match self {
86 Null => Null,
87 Bool(_) => Bool(nullability),
88 Primitive(p, _) => Primitive(*p, nullability),
89 Utf8(_) => Utf8(nullability),
90 Binary(_) => Binary(nullability),
91 Struct(st, _) => Struct(st.clone(), nullability),
92 List(c, _) => List(c.clone(), nullability),
93 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
94 }
95 }
96
97 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
99 match (self, other) {
100 (Null, Null) => true,
101 (Null, _) => false,
102 (Bool(_), Bool(_)) => true,
103 (Bool(_), _) => false,
104 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
105 (Primitive(..), _) => false,
106 (Utf8(_), Utf8(_)) => true,
107 (Utf8(_), _) => false,
108 (Binary(_), Binary(_)) => true,
109 (Binary(_), _) => false,
110 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
111 (List(..), _) => false,
112 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
113 (lhs_dtype.names() == rhs_dtype.names())
114 && (lhs_dtype
115 .fields()
116 .zip_eq(rhs_dtype.fields())
117 .all(|(l, r)| l.eq_ignore_nullability(&r)))
118 }
119 (Struct(..), _) => false,
120 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
121 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
122 }
123 (Extension(_), _) => false,
124 }
125 }
126
127 pub fn is_struct(&self) -> bool {
129 matches!(self, Struct(_, _))
130 }
131
132 pub fn is_unsigned_int(&self) -> bool {
134 PType::try_from(self).is_ok_and(PType::is_unsigned_int)
135 }
136
137 pub fn is_signed_int(&self) -> bool {
139 PType::try_from(self).is_ok_and(PType::is_signed_int)
140 }
141
142 pub fn is_int(&self) -> bool {
144 PType::try_from(self).is_ok_and(PType::is_int)
145 }
146
147 pub fn is_float(&self) -> bool {
149 PType::try_from(self).is_ok_and(PType::is_float)
150 }
151
152 pub fn is_boolean(&self) -> bool {
154 matches!(self, Bool(_))
155 }
156
157 pub fn is_extension(&self) -> bool {
159 matches!(self, Extension(_))
160 }
161
162 pub fn as_struct(&self) -> Option<&Arc<StructDType>> {
164 match self {
165 Struct(s, _) => Some(s),
166 _ => None,
167 }
168 }
169
170 pub fn as_list_element(&self) -> Option<&DType> {
172 match self {
173 List(s, _) => Some(s.as_ref()),
174 _ => None,
175 }
176 }
177}
178
179impl Display for DType {
180 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
181 match self {
182 Null => write!(f, "null"),
183 Bool(n) => write!(f, "bool{}", n),
184 Primitive(pt, n) => write!(f, "{}{}", pt, n),
185 Utf8(n) => write!(f, "utf8{}", n),
186 Binary(n) => write!(f, "binary{}", n),
187 Struct(sdt, n) => write!(
188 f,
189 "{{{}}}{}",
190 sdt.names()
191 .iter()
192 .zip(sdt.fields())
193 .map(|(n, dt)| format!("{}={}", n, dt))
194 .join(", "),
195 n
196 ),
197 List(edt, n) => write!(f, "list({}){}", edt, n),
198 Extension(ext) => write!(
199 f,
200 "ext({}, {}{}){}",
201 ext.id(),
202 ext.storage_dtype()
203 .with_nullability(Nullability::NonNullable),
204 ext.metadata()
205 .map(|m| format!(", {:?}", m))
206 .unwrap_or_else(|| "".to_string()),
207 ext.storage_dtype().nullability(),
208 ),
209 }
210 }
211}