1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::sync::Arc;
4
5use DType::*;
6use itertools::Itertools;
7use static_assertions::const_assert_eq;
8use vortex_error::vortex_panic;
9
10use crate::decimal::DecimalDType;
11use crate::nullability::Nullability;
12use crate::{ExtDType, PType, StructFields};
13
/// Name of a single field (e.g. of a struct dtype), stored as a reference-counted `str`.
pub type FieldName = Arc<str>;
/// An ordered, reference-counted slice of [`FieldName`]s.
pub type FieldNames = Arc<[FieldName]>;
18
/// The logical data type of a value.
///
/// Every variant except [`DType::Null`] and [`DType::Extension`] carries an explicit
/// [`Nullability`] flag. `Null` is always nullable, and an extension dtype takes its
/// nullability from its storage dtype (see [`DType::is_nullable`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type; it carries no nullability flag and always reports as nullable.
    Null,
    /// Boolean values.
    Bool(Nullability),
    /// A primitive type identified by its [`PType`].
    Primitive(PType, Nullability),
    /// A decimal type described by a [`DecimalDType`].
    Decimal(DecimalDType, Nullability),
    /// UTF-8 string values.
    Utf8(Nullability),
    /// Binary values.
    Binary(Nullability),
    /// A struct type whose [`StructFields`] supply the field names and field dtypes.
    Struct(Arc<StructFields>, Nullability),
    /// A list type; the boxed dtype is the element dtype (see [`DType::as_list_element`]).
    List(Arc<DType>, Nullability),
    /// An extension type wrapping a storage dtype; its nullability is that of the storage dtype.
    Extension(Arc<ExtDType>),
}
45
// Compile-time guard on the in-memory size of `DType`: the build fails if a change to the
// enum (e.g. a larger variant payload) alters its size. The expected size differs on wasm32,
// presumably because pointers are narrower on that target — TODO confirm.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
51
52impl DType {
53 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
55
56 pub fn nullability(&self) -> Nullability {
58 self.is_nullable().into()
59 }
60
61 pub fn is_nullable(&self) -> bool {
63 use crate::nullability::Nullability::*;
64
65 match self {
66 Null => true,
67 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
68 Bool(n)
69 | Primitive(_, n)
70 | Decimal(_, n)
71 | Utf8(n)
72 | Binary(n)
73 | Struct(_, n)
74 | List(_, n) => matches!(n, Nullable),
75 }
76 }
77
78 pub fn as_nonnullable(&self) -> Self {
80 self.with_nullability(Nullability::NonNullable)
81 }
82
83 pub fn as_nullable(&self) -> Self {
85 self.with_nullability(Nullability::Nullable)
86 }
87
88 pub fn with_nullability(&self, nullability: Nullability) -> Self {
90 match self {
91 Null => Null,
92 Bool(_) => Bool(nullability),
93 Primitive(p, _) => Primitive(*p, nullability),
94 Decimal(d, _) => Decimal(*d, nullability),
95 Utf8(_) => Utf8(nullability),
96 Binary(_) => Binary(nullability),
97 Struct(st, _) => Struct(st.clone(), nullability),
98 List(c, _) => List(c.clone(), nullability),
99 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
100 }
101 }
102
103 pub fn union_nullability(&self, other: Nullability) -> Self {
105 let nullability = self.nullability() | other;
106 self.with_nullability(nullability)
107 }
108
109 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
111 match (self, other) {
112 (Null, Null) => true,
113 (Bool(_), Bool(_)) => true,
114 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
115 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
116 (Utf8(_), Utf8(_)) => true,
117 (Binary(_), Binary(_)) => true,
118 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
119 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
120 (lhs_dtype.names() == rhs_dtype.names())
121 && (lhs_dtype
122 .fields()
123 .zip_eq(rhs_dtype.fields())
124 .all(|(l, r)| l.eq_ignore_nullability(&r)))
125 }
126 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
127 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
128 }
129 _ => false,
130 }
131 }
132
133 pub fn is_struct(&self) -> bool {
135 matches!(self, Struct(_, _))
136 }
137
138 pub fn is_primitive(&self) -> bool {
140 matches!(self, Primitive(_, _))
141 }
142
143 pub fn as_ptype(&self) -> PType {
145 match self {
146 Primitive(ptype, _) => *ptype,
147 _ => vortex_panic!("DType is not a primitive type"),
148 }
149 }
150
151 pub fn is_unsigned_int(&self) -> bool {
153 if let Primitive(ptype, _) = self {
154 return ptype.is_unsigned_int();
155 }
156 false
157 }
158
159 pub fn is_signed_int(&self) -> bool {
161 if let Primitive(ptype, _) = self {
162 return ptype.is_signed_int();
163 }
164 false
165 }
166
167 pub fn is_int(&self) -> bool {
169 if let Primitive(ptype, _) = self {
170 return ptype.is_int();
171 }
172 false
173 }
174
175 pub fn is_float(&self) -> bool {
177 if let Primitive(ptype, _) = self {
178 return ptype.is_float();
179 }
180 false
181 }
182
183 pub fn is_boolean(&self) -> bool {
185 matches!(self, Bool(_))
186 }
187
188 pub fn is_binary(&self) -> bool {
190 matches!(self, Binary(_))
191 }
192
193 pub fn is_utf8(&self) -> bool {
195 matches!(self, Utf8(_))
196 }
197
198 pub fn is_extension(&self) -> bool {
200 matches!(self, Extension(_))
201 }
202
203 pub fn is_decimal(&self) -> bool {
205 matches!(self, Decimal(..))
206 }
207
208 pub fn as_decimal(&self) -> Option<&DecimalDType> {
210 match self {
211 Decimal(decimal, _) => Some(decimal),
212 _ => None,
213 }
214 }
215
216 pub fn as_struct(&self) -> Option<&Arc<StructFields>> {
218 match self {
219 Struct(s, _) => Some(s),
220 _ => None,
221 }
222 }
223
224 pub fn as_list_element(&self) -> Option<&Arc<DType>> {
226 match self {
227 List(s, _) => Some(s),
228 _ => None,
229 }
230 }
231}
232
233impl Display for DType {
234 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
235 match self {
236 Null => write!(f, "null"),
237 Bool(n) => write!(f, "bool{n}"),
238 Primitive(pt, n) => write!(f, "{pt}{n}"),
239 Decimal(dt, n) => write!(f, "{dt}{n}"),
240 Utf8(n) => write!(f, "utf8{n}"),
241 Binary(n) => write!(f, "binary{n}"),
242 Struct(sdt, n) => write!(
243 f,
244 "{{{}}}{}",
245 sdt.names()
246 .iter()
247 .zip(sdt.fields())
248 .map(|(n, dt)| format!("{n}={dt}"))
249 .join(", "),
250 n
251 ),
252 List(edt, n) => write!(f, "list({edt}){n}"),
253 Extension(ext) => write!(
254 f,
255 "ext({}, {}{}){}",
256 ext.id(),
257 ext.storage_dtype()
258 .with_nullability(Nullability::NonNullable),
259 ext.metadata()
260 .map(|m| format!(", {m:?}"))
261 .unwrap_or_else(|| "".to_string()),
262 ext.storage_dtype().nullability(),
263 ),
264 }
265 }
266}