1use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::ops::Index;
7use std::sync::Arc;
8
9use DType::*;
10use itertools::Itertools;
11use static_assertions::const_assert_eq;
12use vortex_error::vortex_panic;
13
14use crate::decimal::DecimalDType;
15use crate::nullability::Nullability;
16use crate::{ExtDType, FieldDType, PType, StructFields};
17
/// Name of a single field.
///
/// Stored as a reference-counted `str`, so cloning a name bumps a reference
/// count instead of copying the text.
pub type FieldName = Arc<str>;
20
/// An ordered collection of [`FieldName`]s held behind a shared `Arc` slice.
///
/// Cloning shares the underlying slice rather than copying the names. When the
/// `serde` feature is enabled the type also derives `Serialize` and
/// `Deserialize`.
#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldNames(Arc<[FieldName]>);
25
26impl FieldNames {
27 pub fn len(&self) -> usize {
29 self.0.len()
30 }
31
32 pub fn is_empty(&self) -> bool {
34 self.len() == 0
35 }
36
37 pub fn iter(&self) -> impl ExactSizeIterator<Item = &FieldName> {
39 FieldNamesIter {
40 inner: self,
41 idx: 0,
42 }
43 }
44
45 pub fn get(&self, index: usize) -> Option<&FieldName> {
47 self.0.get(index)
48 }
49}
50
51impl AsRef<[FieldName]> for FieldNames {
52 fn as_ref(&self) -> &[FieldName] {
53 &self.0
54 }
55}
56
57impl Index<usize> for FieldNames {
58 type Output = FieldName;
59
60 fn index(&self, index: usize) -> &Self::Output {
61 &self.0[index]
62 }
63}
64
65pub struct FieldNamesIter<'a> {
67 inner: &'a FieldNames,
68 idx: usize,
69}
70
71impl<'a> Iterator for FieldNamesIter<'a> {
72 type Item = &'a FieldName;
73
74 fn next(&mut self) -> Option<Self::Item> {
75 if self.idx >= self.inner.len() {
76 return None;
77 }
78
79 let i = &self.inner.0[self.idx];
80 self.idx += 1;
81 Some(i)
82 }
83
84 fn size_hint(&self) -> (usize, Option<usize>) {
85 let len = self.inner.len() - self.idx;
86 (len, Some(len))
87 }
88}
89
90impl ExactSizeIterator for FieldNamesIter<'_> {}
91
92pub struct FieldNamesIntoIter {
94 inner: FieldNames,
95 idx: usize,
96}
97
98impl Iterator for FieldNamesIntoIter {
99 type Item = FieldName;
100
101 fn next(&mut self) -> Option<Self::Item> {
102 if self.idx >= self.inner.len() {
103 return None;
104 }
105
106 let i = self.inner.0[self.idx].clone();
107 self.idx += 1;
108 Some(i)
109 }
110
111 fn size_hint(&self) -> (usize, Option<usize>) {
112 let len = self.inner.len() - self.idx;
113 (len, Some(len))
114 }
115}
116
117impl ExactSizeIterator for FieldNamesIntoIter {}
118
119impl IntoIterator for FieldNames {
120 type Item = FieldName;
121
122 type IntoIter = FieldNamesIntoIter;
123
124 fn into_iter(self) -> Self::IntoIter {
125 FieldNamesIntoIter {
126 inner: self,
127 idx: 0,
128 }
129 }
130}
131
132impl From<Vec<FieldName>> for FieldNames {
133 fn from(value: Vec<FieldName>) -> Self {
134 Self(value.into())
135 }
136}
137
138impl From<&[&'static str]> for FieldNames {
139 fn from(value: &[&'static str]) -> Self {
140 Self(value.iter().cloned().map(Arc::from).collect())
141 }
142}
143
144impl From<&[FieldName]> for FieldNames {
145 fn from(value: &[FieldName]) -> Self {
146 Self(Arc::from(value))
147 }
148}
149
150impl<const N: usize> From<[&'static str; N]> for FieldNames {
151 fn from(value: [&'static str; N]) -> Self {
152 Self(value.into_iter().map(Arc::from).collect())
153 }
154}
155
156impl<const N: usize> From<[FieldName; N]> for FieldNames {
157 fn from(value: [FieldName; N]) -> Self {
158 Self(value.into())
159 }
160}
161
162impl<F: Into<FieldName>> FromIterator<F> for FieldNames {
163 fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
164 Self(iter.into_iter().map(|v| v.into()).collect())
165 }
166}
167
/// A logical data type, carrying a [`Nullability`] marker on most variants.
///
/// When the `serde` feature is enabled the type also derives `Serialize` and
/// `Deserialize`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The null type; it carries no nullability marker and is always reported as nullable.
    Null,
    /// Boolean values.
    Bool(Nullability),
    /// Primitive values whose concrete kind is given by a [`PType`].
    Primitive(PType, Nullability),
    /// Decimal values described by a [`DecimalDType`].
    Decimal(DecimalDType, Nullability),
    /// UTF-8 string values.
    Utf8(Nullability),
    /// Binary values.
    Binary(Nullability),
    /// A struct of named fields, described by [`StructFields`].
    Struct(StructFields, Nullability),
    /// A list; the wrapped dtype is the element dtype.
    List(Arc<DType>, Nullability),
    /// An extension type; its nullability is taken from its storage dtype.
    Extension(Arc<ExtDType>),
}
194
// Compile-time guard on the in-memory size of `DType`: 16 bytes on every
// target except wasm32.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

// On wasm32 the expected size is 8 bytes.
#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
200
201impl DType {
202 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
204
205 pub fn nullability(&self) -> Nullability {
207 self.is_nullable().into()
208 }
209
210 pub fn is_nullable(&self) -> bool {
212 use crate::nullability::Nullability::*;
213
214 match self {
215 Null => true,
216 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
217 Bool(n)
218 | Primitive(_, n)
219 | Decimal(_, n)
220 | Utf8(n)
221 | Binary(n)
222 | Struct(_, n)
223 | List(_, n) => matches!(n, Nullable),
224 }
225 }
226
227 pub fn as_nonnullable(&self) -> Self {
229 self.with_nullability(Nullability::NonNullable)
230 }
231
232 pub fn as_nullable(&self) -> Self {
234 self.with_nullability(Nullability::Nullable)
235 }
236
237 pub fn with_nullability(&self, nullability: Nullability) -> Self {
239 match self {
240 Null => Null,
241 Bool(_) => Bool(nullability),
242 Primitive(p, _) => Primitive(*p, nullability),
243 Decimal(d, _) => Decimal(*d, nullability),
244 Utf8(_) => Utf8(nullability),
245 Binary(_) => Binary(nullability),
246 Struct(st, _) => Struct(st.clone(), nullability),
247 List(c, _) => List(c.clone(), nullability),
248 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
249 }
250 }
251
252 pub fn union_nullability(&self, other: Nullability) -> Self {
254 let nullability = self.nullability() | other;
255 self.with_nullability(nullability)
256 }
257
258 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
260 match (self, other) {
261 (Null, Null) => true,
262 (Bool(_), Bool(_)) => true,
263 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
264 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
265 (Utf8(_), Utf8(_)) => true,
266 (Binary(_), Binary(_)) => true,
267 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
268 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
269 (lhs_dtype.names() == rhs_dtype.names())
270 && (lhs_dtype
271 .fields()
272 .zip_eq(rhs_dtype.fields())
273 .all(|(l, r)| l.eq_ignore_nullability(&r)))
274 }
275 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
276 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
277 }
278 _ => false,
279 }
280 }
281
282 pub fn is_struct(&self) -> bool {
284 matches!(self, Struct(_, _))
285 }
286
287 pub fn is_list(&self) -> bool {
289 matches!(self, List(_, _))
290 }
291
292 pub fn is_primitive(&self) -> bool {
294 matches!(self, Primitive(_, _))
295 }
296
297 pub fn as_ptype(&self) -> PType {
299 match self {
300 Primitive(ptype, _) => *ptype,
301 _ => vortex_panic!("DType is not a primitive type"),
302 }
303 }
304
305 pub fn is_unsigned_int(&self) -> bool {
307 if let Primitive(ptype, _) = self {
308 return ptype.is_unsigned_int();
309 }
310 false
311 }
312
313 pub fn is_signed_int(&self) -> bool {
315 if let Primitive(ptype, _) = self {
316 return ptype.is_signed_int();
317 }
318 false
319 }
320
321 pub fn is_int(&self) -> bool {
323 if let Primitive(ptype, _) = self {
324 return ptype.is_int();
325 }
326 false
327 }
328
329 pub fn is_float(&self) -> bool {
331 if let Primitive(ptype, _) = self {
332 return ptype.is_float();
333 }
334 false
335 }
336
337 pub fn is_boolean(&self) -> bool {
339 matches!(self, Bool(_))
340 }
341
342 pub fn is_binary(&self) -> bool {
344 matches!(self, Binary(_))
345 }
346
347 pub fn is_utf8(&self) -> bool {
349 matches!(self, Utf8(_))
350 }
351
352 pub fn is_extension(&self) -> bool {
354 matches!(self, Extension(_))
355 }
356
357 pub fn is_decimal(&self) -> bool {
359 matches!(self, Decimal(..))
360 }
361
362 pub fn as_decimal(&self) -> Option<&DecimalDType> {
364 match self {
365 Decimal(decimal, _) => Some(decimal),
366 _ => None,
367 }
368 }
369
370 pub fn as_struct(&self) -> Option<&StructFields> {
372 match self {
373 Struct(s, _) => Some(s),
374 _ => None,
375 }
376 }
377
378 pub fn as_list_element(&self) -> Option<&Arc<DType>> {
380 match self {
381 List(s, _) => Some(s),
382 _ => None,
383 }
384 }
385
386 pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
388 iter: I,
389 nullability: Nullability,
390 ) -> Self {
391 Struct(StructFields::from_iter(iter), nullability)
392 }
393
394 pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
396 List(Arc::new(dtype.into()), nullability)
397 }
398}
399
400impl Display for DType {
401 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
402 match self {
403 Null => write!(f, "null"),
404 Bool(n) => write!(f, "bool{n}"),
405 Primitive(pt, n) => write!(f, "{pt}{n}"),
406 Decimal(dt, n) => write!(f, "{dt}{n}"),
407 Utf8(n) => write!(f, "utf8{n}"),
408 Binary(n) => write!(f, "binary{n}"),
409 Struct(sdt, n) => write!(
410 f,
411 "{{{}}}{}",
412 sdt.names()
413 .iter()
414 .zip(sdt.fields())
415 .map(|(n, dt)| format!("{n}={dt}"))
416 .join(", "),
417 n
418 ),
419 List(edt, n) => write!(f, "list({edt}){n}"),
420 Extension(ext) => write!(
421 f,
422 "ext({}, {}{}){}",
423 ext.id(),
424 ext.storage_dtype()
425 .with_nullability(Nullability::NonNullable),
426 ext.metadata()
427 .map(|m| format!(", {m:?}"))
428 .unwrap_or_else(|| "".to_string()),
429 ext.storage_dtype().nullability(),
430 ),
431 }
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// The borrowing iterator reports an exact length and yields each name in order.
    #[test]
    fn test_field_names_iter() {
        let names = ["a", "b"];
        let field_names = FieldNames::from(names);
        assert_eq!(field_names.iter().len(), names.len());

        let (first, second) = (FieldName::from("a"), FieldName::from("b"));
        let mut iter = field_names.iter();
        assert_eq!(iter.next(), Some(&first));
        assert_eq!(iter.next(), Some(&second));
        assert_eq!(iter.next(), None);
    }

    /// The owning iterator reports an exact length and yields each name in order.
    #[test]
    fn test_field_names_owned_iter() {
        let names = ["a", "b"];
        let field_names = FieldNames::from(names);
        assert_eq!(field_names.clone().into_iter().len(), names.len());

        let mut iter = field_names.into_iter();
        assert_eq!(iter.next(), Some(FieldName::from("a")));
        assert_eq!(iter.next(), Some(FieldName::from("b")));
        assert_eq!(iter.next(), None);
    }
}