1use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::sync::Arc;
7
8use DType::*;
9use itertools::Itertools;
10use static_assertions::const_assert_eq;
11use vortex_error::vortex_panic;
12
13use crate::decimal::DecimalDType;
14use crate::nullability::Nullability;
15use crate::{ExtDType, FieldDType, FieldName, PType, StructFields};
16
/// The logical data type of a value.
///
/// Every variant except [`DType::Null`] and [`DType::Extension`] carries its own
/// [`Nullability`] flag. `Null` is always nullable, and an extension type takes its
/// nullability from its storage dtype.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The type whose only value is null; it is always considered nullable.
    Null,

    /// Boolean values.
    Bool(Nullability),

    /// Primitive numeric values whose physical representation is given by a [`PType`].
    Primitive(PType, Nullability),

    /// Decimal values described by a [`DecimalDType`].
    Decimal(DecimalDType, Nullability),

    /// UTF-8 string values.
    Utf8(Nullability),

    /// Binary (byte-string) values.
    Binary(Nullability),

    /// Lists whose elements all share the wrapped element dtype.
    List(Arc<DType>, Nullability),

    /// Structs with named fields, described by [`StructFields`].
    Struct(StructFields, Nullability),

    /// An extension type wrapping an underlying storage dtype; its nullability is that of
    /// the storage dtype.
    Extension(Arc<ExtDType>),
}
87
// Compile-time guard on the in-memory size of `DType`: growing any variant's payload
// fails the build instead of silently enlarging every `DType` value.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

// wasm32 is a 32-bit target (4-byte pointers, so `Arc` is 4 bytes); the expected size
// there is 8 bytes.
#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
93
94impl DType {
95 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
97
98 pub fn nullability(&self) -> Nullability {
100 self.is_nullable().into()
101 }
102
103 pub fn is_nullable(&self) -> bool {
105 match self {
106 Null => true,
107 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
108 Bool(null)
109 | Primitive(_, null)
110 | Decimal(_, null)
111 | Utf8(null)
112 | Binary(null)
113 | Struct(_, null)
114 | List(_, null) => matches!(null, Nullability::Nullable),
115 }
116 }
117
118 pub fn as_nonnullable(&self) -> Self {
120 self.with_nullability(Nullability::NonNullable)
121 }
122
123 pub fn as_nullable(&self) -> Self {
125 self.with_nullability(Nullability::Nullable)
126 }
127
128 pub fn with_nullability(&self, nullability: Nullability) -> Self {
130 match self {
131 Null => Null,
132 Bool(_) => Bool(nullability),
133 Primitive(pdt, _) => Primitive(*pdt, nullability),
134 Decimal(ddt, _) => Decimal(*ddt, nullability),
135 Utf8(_) => Utf8(nullability),
136 Binary(_) => Binary(nullability),
137 Struct(sf, _) => Struct(sf.clone(), nullability),
138 List(edt, _) => List(edt.clone(), nullability),
139 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
140 }
141 }
142
143 pub fn union_nullability(&self, other: Nullability) -> Self {
145 let nullability = self.nullability() | other;
146 self.with_nullability(nullability)
147 }
148
149 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
151 match (self, other) {
152 (Null, Null) => true,
153 (Bool(_), Bool(_)) => true,
154 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
155 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
156 (Utf8(_), Utf8(_)) => true,
157 (Binary(_), Binary(_)) => true,
158 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
159 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
160 (lhs_dtype.names() == rhs_dtype.names())
161 && (lhs_dtype
162 .fields()
163 .zip_eq(rhs_dtype.fields())
164 .all(|(l, r)| l.eq_ignore_nullability(&r)))
165 }
166 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
167 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
168 }
169 _ => false,
170 }
171 }
172
173 pub fn is_struct(&self) -> bool {
175 matches!(self, Struct(_, _))
176 }
177
178 pub fn is_list(&self) -> bool {
180 matches!(self, List(_, _))
181 }
182
183 pub fn is_primitive(&self) -> bool {
185 matches!(self, Primitive(_, _))
186 }
187
188 pub fn as_ptype(&self) -> PType {
190 if let Primitive(ptype, _) = self {
191 *ptype
192 } else {
193 vortex_panic!("DType is not a primitive type")
194 }
195 }
196
197 pub fn is_unsigned_int(&self) -> bool {
199 if let Primitive(ptype, _) = self {
200 return ptype.is_unsigned_int();
201 }
202 false
203 }
204
205 pub fn is_signed_int(&self) -> bool {
207 if let Primitive(ptype, _) = self {
208 return ptype.is_signed_int();
209 }
210 false
211 }
212
213 pub fn is_int(&self) -> bool {
215 if let Primitive(ptype, _) = self {
216 return ptype.is_int();
217 }
218 false
219 }
220
221 pub fn is_float(&self) -> bool {
223 if let Primitive(ptype, _) = self {
224 return ptype.is_float();
225 }
226 false
227 }
228
229 pub fn is_boolean(&self) -> bool {
231 matches!(self, Bool(_))
232 }
233
234 pub fn is_binary(&self) -> bool {
236 matches!(self, Binary(_))
237 }
238
239 pub fn is_utf8(&self) -> bool {
241 matches!(self, Utf8(_))
242 }
243
244 pub fn is_extension(&self) -> bool {
246 matches!(self, Extension(_))
247 }
248
249 pub fn is_decimal(&self) -> bool {
251 matches!(self, Decimal(..))
252 }
253
254 pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
256 if let Decimal(decimal, _) = self {
257 Some(decimal)
258 } else {
259 None
260 }
261 }
262
263 pub fn as_struct_opt(&self) -> Option<&StructFields> {
265 if let Struct(f, _) = self {
266 Some(f)
267 } else {
268 None
269 }
270 }
271
272 pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
274 if let List(s, _) = self { Some(s) } else { None }
275 }
276
277 pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
279 iter: I,
280 nullability: Nullability,
281 ) -> Self {
282 Struct(StructFields::from_iter(iter), nullability)
283 }
284
285 pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
287 List(Arc::new(dtype.into()), nullability)
288 }
289}
290
291impl Display for DType {
292 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
293 match self {
294 Null => write!(f, "null"),
295 Bool(null) => write!(f, "bool{null}"),
296 Primitive(pdt, null) => write!(f, "{pdt}{null}"),
297 Decimal(ddt, null) => write!(f, "{ddt}{null}"),
298 Utf8(null) => write!(f, "utf8{null}"),
299 Binary(null) => write!(f, "binary{null}"),
300 Struct(sf, null) => write!(
301 f,
302 "{{{}}}{null}",
303 sf.names()
304 .iter()
305 .zip(sf.fields())
306 .map(|(field_null, dt)| format!("{field_null}={dt}"))
307 .join(", "),
308 ),
309 List(edt, null) => write!(f, "list({edt}){null}"),
310 Extension(ext) => write!(
311 f,
312 "ext({}, {}{}){}",
313 ext.id(),
314 ext.storage_dtype()
315 .with_nullability(Nullability::NonNullable),
316 ext.metadata()
317 .map(|m| format!(", {m:?}"))
318 .unwrap_or_else(|| "".to_string()),
319 ext.storage_dtype().nullability(),
320 ),
321 }
322 }
323}