1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::ops::Index;
4use std::sync::Arc;
5
6use DType::*;
7use itertools::Itertools;
8use static_assertions::const_assert_eq;
9use vortex_error::vortex_panic;
10
11use crate::decimal::DecimalDType;
12use crate::nullability::Nullability;
13use crate::{ExtDType, FieldDType, PType, StructFields};
14
/// Name of a single field: a shared, immutable string (`Arc<str>`), so cloning a name
/// only bumps a reference count.
pub type FieldName = Arc<str>;
17
/// An ordered, immutable list of [`FieldName`]s.
///
/// The names live in a shared `Arc<[FieldName]>`, so cloning a `FieldNames` does not copy
/// the individual names. The `Default` value is the empty list.
#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldNames(Arc<[FieldName]>);
22
23impl FieldNames {
24 pub fn len(&self) -> usize {
26 self.0.len()
27 }
28
29 pub fn is_empty(&self) -> bool {
31 self.len() == 0
32 }
33
34 pub fn iter(&self) -> impl ExactSizeIterator<Item = &FieldName> {
36 FieldNamesIter {
37 inner: self,
38 idx: 0,
39 }
40 }
41
42 pub fn get(&self, index: usize) -> Option<&FieldName> {
44 self.0.get(index)
45 }
46}
47
48impl AsRef<[FieldName]> for FieldNames {
49 fn as_ref(&self) -> &[FieldName] {
50 &self.0
51 }
52}
53
54impl Index<usize> for FieldNames {
55 type Output = FieldName;
56
57 fn index(&self, index: usize) -> &Self::Output {
58 &self.0[index]
59 }
60}
61
/// Borrowing iterator over a [`FieldNames`], yielding `&FieldName` in positional order.
pub struct FieldNamesIter<'a> {
    /// The collection being iterated.
    inner: &'a FieldNames,
    /// Position of the next name to yield.
    idx: usize,
}
67
68impl<'a> Iterator for FieldNamesIter<'a> {
69 type Item = &'a FieldName;
70
71 fn next(&mut self) -> Option<Self::Item> {
72 if self.idx >= self.inner.len() {
73 return None;
74 }
75
76 let i = &self.inner.0[self.idx];
77 self.idx += 1;
78 Some(i)
79 }
80}
81
82impl ExactSizeIterator for FieldNamesIter<'_> {
83 fn len(&self) -> usize {
84 self.inner.len() - self.idx
85 }
86}
87
/// Owning iterator over a [`FieldNames`], yielding each `FieldName` by value in
/// positional order.
pub struct FieldNamesIntoIter {
    /// The collection being iterated; it is kept intact and names are cloned out of it.
    inner: FieldNames,
    /// Position of the next name to yield.
    idx: usize,
}
93
94impl Iterator for FieldNamesIntoIter {
95 type Item = FieldName;
96
97 fn next(&mut self) -> Option<Self::Item> {
98 if self.idx >= self.inner.len() {
99 return None;
100 }
101
102 let i = self.inner.0[self.idx].clone();
103 self.idx += 1;
104 Some(i)
105 }
106}
107
108impl ExactSizeIterator for FieldNamesIntoIter {
109 fn len(&self) -> usize {
110 self.inner.len() - self.idx
111 }
112}
113
114impl IntoIterator for FieldNames {
115 type Item = FieldName;
116
117 type IntoIter = FieldNamesIntoIter;
118
119 fn into_iter(self) -> Self::IntoIter {
120 FieldNamesIntoIter {
121 inner: self,
122 idx: 0,
123 }
124 }
125}
126
127impl From<Vec<FieldName>> for FieldNames {
128 fn from(value: Vec<FieldName>) -> Self {
129 Self(value.into())
130 }
131}
132
133impl From<&[&'static str]> for FieldNames {
134 fn from(value: &[&'static str]) -> Self {
135 Self(value.iter().cloned().map(Arc::from).collect())
136 }
137}
138
139impl From<&[FieldName]> for FieldNames {
140 fn from(value: &[FieldName]) -> Self {
141 Self(Arc::from(value))
142 }
143}
144
145impl<const N: usize> From<[&'static str; N]> for FieldNames {
146 fn from(value: [&'static str; N]) -> Self {
147 Self(value.into_iter().map(Arc::from).collect())
148 }
149}
150
151impl<const N: usize> From<[FieldName; N]> for FieldNames {
152 fn from(value: [FieldName; N]) -> Self {
153 Self(value.into())
154 }
155}
156
157impl<F: Into<FieldName>> FromIterator<F> for FieldNames {
158 fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
159 Self(iter.into_iter().map(|v| v.into()).collect())
160 }
161}
162
/// The logical data type of a value.
///
/// Most variants carry their own [`Nullability`]. `Null` has none (it always reports as
/// nullable), and `Extension` takes its nullability from its storage dtype.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type; displayed as `null`.
    Null,
    /// Boolean type with its nullability.
    Bool(Nullability),
    /// Primitive type identified by a [`PType`], with its nullability.
    Primitive(PType, Nullability),
    /// Decimal type described by a [`DecimalDType`], with its nullability.
    Decimal(DecimalDType, Nullability),
    /// UTF-8 string type with its nullability.
    Utf8(Nullability),
    /// Binary type with its nullability.
    Binary(Nullability),
    /// Struct type: a set of named fields plus the nullability of the struct itself.
    Struct(StructFields, Nullability),
    /// List type: the shared element dtype plus the nullability of the list itself.
    List(Arc<DType>, Nullability),
    /// Extension type, described by a shared [`ExtDType`].
    Extension(Arc<ExtDType>),
}
189
// Compile-time guard on the in-memory size of `DType`: it must be exactly 16 bytes on
// every target other than wasm32. A change to the enum that alters its size fails the build.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

// On wasm32 the expected size of `DType` is 8 bytes.
#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
195
196impl DType {
197 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
199
200 pub fn nullability(&self) -> Nullability {
202 self.is_nullable().into()
203 }
204
205 pub fn is_nullable(&self) -> bool {
207 use crate::nullability::Nullability::*;
208
209 match self {
210 Null => true,
211 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
212 Bool(n)
213 | Primitive(_, n)
214 | Decimal(_, n)
215 | Utf8(n)
216 | Binary(n)
217 | Struct(_, n)
218 | List(_, n) => matches!(n, Nullable),
219 }
220 }
221
222 pub fn as_nonnullable(&self) -> Self {
224 self.with_nullability(Nullability::NonNullable)
225 }
226
227 pub fn as_nullable(&self) -> Self {
229 self.with_nullability(Nullability::Nullable)
230 }
231
232 pub fn with_nullability(&self, nullability: Nullability) -> Self {
234 match self {
235 Null => Null,
236 Bool(_) => Bool(nullability),
237 Primitive(p, _) => Primitive(*p, nullability),
238 Decimal(d, _) => Decimal(*d, nullability),
239 Utf8(_) => Utf8(nullability),
240 Binary(_) => Binary(nullability),
241 Struct(st, _) => Struct(st.clone(), nullability),
242 List(c, _) => List(c.clone(), nullability),
243 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
244 }
245 }
246
247 pub fn union_nullability(&self, other: Nullability) -> Self {
249 let nullability = self.nullability() | other;
250 self.with_nullability(nullability)
251 }
252
253 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
255 match (self, other) {
256 (Null, Null) => true,
257 (Bool(_), Bool(_)) => true,
258 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
259 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
260 (Utf8(_), Utf8(_)) => true,
261 (Binary(_), Binary(_)) => true,
262 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
263 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
264 (lhs_dtype.names() == rhs_dtype.names())
265 && (lhs_dtype
266 .fields()
267 .zip_eq(rhs_dtype.fields())
268 .all(|(l, r)| l.eq_ignore_nullability(&r)))
269 }
270 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
271 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
272 }
273 _ => false,
274 }
275 }
276
277 pub fn is_struct(&self) -> bool {
279 matches!(self, Struct(_, _))
280 }
281
282 pub fn is_primitive(&self) -> bool {
284 matches!(self, Primitive(_, _))
285 }
286
287 pub fn as_ptype(&self) -> PType {
289 match self {
290 Primitive(ptype, _) => *ptype,
291 _ => vortex_panic!("DType is not a primitive type"),
292 }
293 }
294
295 pub fn is_unsigned_int(&self) -> bool {
297 if let Primitive(ptype, _) = self {
298 return ptype.is_unsigned_int();
299 }
300 false
301 }
302
303 pub fn is_signed_int(&self) -> bool {
305 if let Primitive(ptype, _) = self {
306 return ptype.is_signed_int();
307 }
308 false
309 }
310
311 pub fn is_int(&self) -> bool {
313 if let Primitive(ptype, _) = self {
314 return ptype.is_int();
315 }
316 false
317 }
318
319 pub fn is_float(&self) -> bool {
321 if let Primitive(ptype, _) = self {
322 return ptype.is_float();
323 }
324 false
325 }
326
327 pub fn is_boolean(&self) -> bool {
329 matches!(self, Bool(_))
330 }
331
332 pub fn is_binary(&self) -> bool {
334 matches!(self, Binary(_))
335 }
336
337 pub fn is_utf8(&self) -> bool {
339 matches!(self, Utf8(_))
340 }
341
342 pub fn is_extension(&self) -> bool {
344 matches!(self, Extension(_))
345 }
346
347 pub fn is_decimal(&self) -> bool {
349 matches!(self, Decimal(..))
350 }
351
352 pub fn as_decimal(&self) -> Option<&DecimalDType> {
354 match self {
355 Decimal(decimal, _) => Some(decimal),
356 _ => None,
357 }
358 }
359
360 pub fn as_struct(&self) -> Option<&StructFields> {
362 match self {
363 Struct(s, _) => Some(s),
364 _ => None,
365 }
366 }
367
368 pub fn as_list_element(&self) -> Option<&Arc<DType>> {
370 match self {
371 List(s, _) => Some(s),
372 _ => None,
373 }
374 }
375
376 pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
378 iter: I,
379 nullability: Nullability,
380 ) -> Self {
381 Struct(StructFields::from_iter(iter), nullability)
382 }
383
384 pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
386 List(Arc::new(dtype.into()), nullability)
387 }
388}
389
390impl Display for DType {
391 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
392 match self {
393 Null => write!(f, "null"),
394 Bool(n) => write!(f, "bool{n}"),
395 Primitive(pt, n) => write!(f, "{pt}{n}"),
396 Decimal(dt, n) => write!(f, "{dt}{n}"),
397 Utf8(n) => write!(f, "utf8{n}"),
398 Binary(n) => write!(f, "binary{n}"),
399 Struct(sdt, n) => write!(
400 f,
401 "{{{}}}{}",
402 sdt.names()
403 .iter()
404 .zip(sdt.fields())
405 .map(|(n, dt)| format!("{n}={dt}"))
406 .join(", "),
407 n
408 ),
409 List(edt, n) => write!(f, "list({edt}){n}"),
410 Extension(ext) => write!(
411 f,
412 "ext({}, {}{}){}",
413 ext.id(),
414 ext.storage_dtype()
415 .with_nullability(Nullability::NonNullable),
416 ext.metadata()
417 .map(|m| format!(", {m:?}"))
418 .unwrap_or_else(|| "".to_string()),
419 ext.storage_dtype().nullability(),
420 ),
421 }
422 }
423}