1use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use DType::*;
9use itertools::Itertools;
10use static_assertions::const_assert_eq;
11use vortex_error::vortex_panic;
12
13use crate::decimal::DecimalDType;
14use crate::nullability::Nullability;
15use crate::{ExtDType, FieldDType, FieldName, PType, StructFields};
16
/// The logical data type of a value or array.
///
/// Every variant except [`DType::Null`] and [`DType::Extension`] carries its own
/// [`Nullability`]. `Null` always reports itself as nullable. `Extension` takes its
/// nullability from the extension's storage dtype (see [`DType::is_nullable`]).
///
/// NOTE: variant order is significant — the derived `Hash` (and, with the `serde`
/// feature, index-based serialization formats) depend on it, so do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type; it has no nullability flag and is always considered nullable.
    Null,

    /// Boolean values.
    Bool(Nullability),

    /// Primitive values of the given [`PType`].
    Primitive(PType, Nullability),

    /// Decimal values described by a [`DecimalDType`].
    Decimal(DecimalDType, Nullability),

    /// UTF-8 string values.
    Utf8(Nullability),

    /// Binary values.
    Binary(Nullability),

    /// Lists whose elements have the wrapped element dtype.
    List(Arc<DType>, Nullability),

    /// Lists whose elements have the wrapped element dtype; the `u32` is the fixed list
    /// size (rendered as `[size]` by `Display` and compared by `eq_ignore_nullability`).
    FixedSizeList(Arc<DType>, u32, Nullability),

    /// Structs whose named, typed fields are described by [`StructFields`].
    Struct(StructFields, Nullability),

    /// An extension type described by an [`ExtDType`]; its nullability is that of the
    /// extension's storage dtype.
    Extension(Arc<ExtDType>),
}
96
97#[cfg(not(target_arch = "wasm32"))]
98const_assert_eq!(size_of::<DType>(), 16);
99
100#[cfg(target_arch = "wasm32")]
101const_assert_eq!(size_of::<DType>(), 12);
102
103impl DType {
104 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
106
107 pub fn nullability(&self) -> Nullability {
109 self.is_nullable().into()
110 }
111
112 pub fn is_nullable(&self) -> bool {
114 match self {
115 Null => true,
116 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
117 Bool(null)
118 | Primitive(_, null)
119 | Decimal(_, null)
120 | Utf8(null)
121 | Binary(null)
122 | Struct(_, null)
123 | List(_, null)
124 | FixedSizeList(_, _, null) => matches!(null, Nullability::Nullable),
125 }
126 }
127
128 pub fn as_nonnullable(&self) -> Self {
130 self.with_nullability(Nullability::NonNullable)
131 }
132
133 pub fn as_nullable(&self) -> Self {
135 self.with_nullability(Nullability::Nullable)
136 }
137
138 pub fn with_nullability(&self, nullability: Nullability) -> Self {
140 match self {
141 Null => Null,
142 Bool(_) => Bool(nullability),
143 Primitive(pdt, _) => Primitive(*pdt, nullability),
144 Decimal(ddt, _) => Decimal(*ddt, nullability),
145 Utf8(_) => Utf8(nullability),
146 Binary(_) => Binary(nullability),
147 Struct(sf, _) => Struct(sf.clone(), nullability),
148 List(edt, _) => List(edt.clone(), nullability),
149 FixedSizeList(edt, size, _) => FixedSizeList(edt.clone(), *size, nullability),
150 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
151 }
152 }
153
154 pub fn union_nullability(&self, other: Nullability) -> Self {
156 let nullability = self.nullability() | other;
157 self.with_nullability(nullability)
158 }
159
160 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
162 match (self, other) {
163 (Null, Null) => true,
164 (Bool(_), Bool(_)) => true,
165 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
166 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
167 (Utf8(_), Utf8(_)) => true,
168 (Binary(_), Binary(_)) => true,
169 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
170 (FixedSizeList(lhs_dtype, lhs_size, _), FixedSizeList(rhs_dtype, rhs_size, _)) => {
171 lhs_size == rhs_size && lhs_dtype.eq_ignore_nullability(rhs_dtype)
172 }
173 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
174 (lhs_dtype.names() == rhs_dtype.names())
175 && (lhs_dtype
176 .fields()
177 .zip_eq(rhs_dtype.fields())
178 .all(|(l, r)| l.eq_ignore_nullability(&r)))
179 }
180 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
181 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
182 }
183 _ => false,
184 }
185 }
186
187 pub fn is_boolean(&self) -> bool {
189 matches!(self, Bool(_))
190 }
191
192 pub fn is_primitive(&self) -> bool {
194 matches!(self, Primitive(_, _))
195 }
196
197 pub fn as_ptype(&self) -> PType {
199 if let Primitive(ptype, _) = self {
200 *ptype
201 } else {
202 vortex_panic!("DType is not a primitive type")
203 }
204 }
205
206 pub fn is_unsigned_int(&self) -> bool {
208 if let Primitive(ptype, _) = self {
209 return ptype.is_unsigned_int();
210 }
211 false
212 }
213
214 pub fn is_signed_int(&self) -> bool {
216 if let Primitive(ptype, _) = self {
217 return ptype.is_signed_int();
218 }
219 false
220 }
221
222 pub fn is_int(&self) -> bool {
224 if let Primitive(ptype, _) = self {
225 return ptype.is_int();
226 }
227 false
228 }
229
230 pub fn is_float(&self) -> bool {
232 if let Primitive(ptype, _) = self {
233 return ptype.is_float();
234 }
235 false
236 }
237
238 pub fn is_decimal(&self) -> bool {
240 matches!(self, Decimal(..))
241 }
242
243 pub fn is_utf8(&self) -> bool {
245 matches!(self, Utf8(_))
246 }
247
248 pub fn is_binary(&self) -> bool {
250 matches!(self, Binary(_))
251 }
252
253 pub fn is_list(&self) -> bool {
255 matches!(self, List(_, _))
256 }
257
258 pub fn is_fixed_size_list(&self) -> bool {
260 matches!(self, FixedSizeList(..))
261 }
262
263 pub fn is_struct(&self) -> bool {
265 matches!(self, Struct(_, _))
266 }
267
268 pub fn is_extension(&self) -> bool {
270 matches!(self, Extension(_))
271 }
272
273 pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
275 if let Decimal(decimal, _) = self {
276 Some(decimal)
277 } else {
278 None
279 }
280 }
281
282 pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
285 if let List(edt, _) = self {
286 Some(edt)
287 } else if let FixedSizeList(edt, ..) = self {
288 Some(edt)
289 } else {
290 None
291 }
292 }
293
294 pub fn as_struct_opt(&self) -> Option<&StructFields> {
297 if let Struct(f, _) = self {
298 Some(f)
299 } else {
300 None
301 }
302 }
303
304 pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
306 List(Arc::new(dtype.into()), nullability)
307 }
308
309 pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
311 iter: I,
312 nullability: Nullability,
313 ) -> Self {
314 Struct(StructFields::from_iter(iter), nullability)
315 }
316}
317
318impl Display for DType {
319 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
320 match self {
321 Null => write!(f, "null"),
322 Bool(null) => write!(f, "bool{null}"),
323 Primitive(pdt, null) => write!(f, "{pdt}{null}"),
324 Decimal(ddt, null) => write!(f, "{ddt}{null}"),
325 Utf8(null) => write!(f, "utf8{null}"),
326 Binary(null) => write!(f, "binary{null}"),
327 Struct(sf, null) => write!(
328 f,
329 "{{{}}}{null}",
330 sf.names()
331 .iter()
332 .zip(sf.fields())
333 .map(|(field_null, dt)| format!("{field_null}={dt}"))
334 .join(", "),
335 ),
336 List(edt, null) => write!(f, "list({edt}){null}"),
337 FixedSizeList(edt, size, null) => write!(f, "fixed_size_list({edt})[{size}]{null}"),
338 Extension(ext) => write!(
339 f,
340 "ext({}, {}{}){}",
341 ext.id(),
342 ext.storage_dtype()
343 .with_nullability(Nullability::NonNullable),
344 ext.metadata()
345 .map(|m| format!(", {m:?}"))
346 .unwrap_or_else(|| "".to_string()),
347 ext.storage_dtype().nullability(),
348 ),
349 }
350 }
351}