1use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::ops::Index;
7use std::sync::Arc;
8
9use DType::*;
10use itertools::Itertools;
11use static_assertions::const_assert_eq;
12use vortex_error::vortex_panic;
13
14use crate::decimal::DecimalDType;
15use crate::nullability::Nullability;
16use crate::{ExtDType, FieldDType, PType, StructFields};
17
/// Name of a single field.
///
/// This is an alias for a reference-counted `str`, so cloning a name shares the underlying
/// text instead of copying it.
pub type FieldName = Arc<str>;
20
21#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
23#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
24pub struct FieldNames(Arc<[FieldName]>);
25
26impl Display for FieldNames {
27 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28 write!(
29 f,
30 "[{}]",
31 itertools::join(self.0.iter().map(|n| format!("\"{n}\"")), ", ")
32 )
33 }
34}
35
36impl PartialEq<&FieldNames> for FieldNames {
37 fn eq(&self, other: &&FieldNames) -> bool {
38 self == *other
39 }
40}
41
42impl PartialEq<&[&str]> for FieldNames {
43 fn eq(&self, other: &&[&str]) -> bool {
44 self.len() == other.len() && self.iter().zip_eq(other.iter()).all(|(l, r)| &**l == *r)
45 }
46}
47
48impl PartialEq<&[&str]> for &FieldNames {
49 fn eq(&self, other: &&[&str]) -> bool {
50 *self == other
51 }
52}
53
54impl<const N: usize> PartialEq<[&str; N]> for FieldNames {
55 fn eq(&self, other: &[&str; N]) -> bool {
56 self == other.as_slice()
57 }
58}
59
60impl<const N: usize> PartialEq<[&str; N]> for &FieldNames {
61 fn eq(&self, other: &[&str; N]) -> bool {
62 *self == other.as_slice()
63 }
64}
65
66impl PartialEq<&[FieldName]> for FieldNames {
67 fn eq(&self, other: &&[FieldName]) -> bool {
68 self.0.as_ref() == *other
69 }
70}
71
72impl PartialEq<&[FieldName]> for &FieldNames {
73 fn eq(&self, other: &&[FieldName]) -> bool {
74 self.0.as_ref() == *other
75 }
76}
77
78impl FieldNames {
79 pub fn len(&self) -> usize {
81 self.0.len()
82 }
83
84 pub fn is_empty(&self) -> bool {
86 self.len() == 0
87 }
88
89 pub fn iter(&self) -> impl ExactSizeIterator<Item = &FieldName> {
91 FieldNamesIter {
92 inner: self,
93 idx: 0,
94 }
95 }
96
97 pub fn get(&self, index: usize) -> Option<&FieldName> {
99 self.0.get(index)
100 }
101}
102
103impl AsRef<[FieldName]> for FieldNames {
104 fn as_ref(&self) -> &[FieldName] {
105 &self.0
106 }
107}
108
109impl Index<usize> for FieldNames {
110 type Output = FieldName;
111
112 fn index(&self, index: usize) -> &Self::Output {
113 &self.0[index]
114 }
115}
116
117pub struct FieldNamesIter<'a> {
119 inner: &'a FieldNames,
120 idx: usize,
121}
122
123impl<'a> Iterator for FieldNamesIter<'a> {
124 type Item = &'a FieldName;
125
126 fn next(&mut self) -> Option<Self::Item> {
127 if self.idx >= self.inner.len() {
128 return None;
129 }
130
131 let i = &self.inner.0[self.idx];
132 self.idx += 1;
133 Some(i)
134 }
135
136 fn size_hint(&self) -> (usize, Option<usize>) {
137 let len = self.inner.len() - self.idx;
138 (len, Some(len))
139 }
140}
141
142impl ExactSizeIterator for FieldNamesIter<'_> {}
143
144pub struct FieldNamesIntoIter {
146 inner: FieldNames,
147 idx: usize,
148}
149
150impl Iterator for FieldNamesIntoIter {
151 type Item = FieldName;
152
153 fn next(&mut self) -> Option<Self::Item> {
154 if self.idx >= self.inner.len() {
155 return None;
156 }
157
158 let i = self.inner.0[self.idx].clone();
159 self.idx += 1;
160 Some(i)
161 }
162
163 fn size_hint(&self) -> (usize, Option<usize>) {
164 let len = self.inner.len() - self.idx;
165 (len, Some(len))
166 }
167}
168
169impl ExactSizeIterator for FieldNamesIntoIter {}
170
171impl IntoIterator for FieldNames {
172 type Item = FieldName;
173
174 type IntoIter = FieldNamesIntoIter;
175
176 fn into_iter(self) -> Self::IntoIter {
177 FieldNamesIntoIter {
178 inner: self,
179 idx: 0,
180 }
181 }
182}
183
184impl From<Vec<FieldName>> for FieldNames {
185 fn from(value: Vec<FieldName>) -> Self {
186 Self(value.into())
187 }
188}
189
190impl From<&[&'static str]> for FieldNames {
191 fn from(value: &[&'static str]) -> Self {
192 Self(value.iter().cloned().map(Arc::from).collect())
193 }
194}
195
196impl From<&[FieldName]> for FieldNames {
197 fn from(value: &[FieldName]) -> Self {
198 Self(Arc::from(value))
199 }
200}
201
202impl<const N: usize> From<[&'static str; N]> for FieldNames {
203 fn from(value: [&'static str; N]) -> Self {
204 Self(value.into_iter().map(Arc::from).collect())
205 }
206}
207
208impl<const N: usize> From<[FieldName; N]> for FieldNames {
209 fn from(value: [FieldName; N]) -> Self {
210 Self(value.into())
211 }
212}
213
214impl<F: Into<FieldName>> FromIterator<F> for FieldNames {
215 fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
216 Self(iter.into_iter().map(|v| v.into()).collect())
217 }
218}
219
220#[derive(Debug, Clone, PartialEq, Eq, Hash)]
225#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
226pub enum DType {
227 Null,
229 Bool(Nullability),
231 Primitive(PType, Nullability),
233 Decimal(DecimalDType, Nullability),
235 Utf8(Nullability),
237 Binary(Nullability),
239 Struct(StructFields, Nullability),
241 List(Arc<DType>, Nullability),
243 Extension(Arc<ExtDType>),
245}
246
247#[cfg(not(target_arch = "wasm32"))]
248const_assert_eq!(size_of::<DType>(), 16);
249
250#[cfg(target_arch = "wasm32")]
251const_assert_eq!(size_of::<DType>(), 8);
252
253impl DType {
254 pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
256
257 pub fn nullability(&self) -> Nullability {
259 self.is_nullable().into()
260 }
261
262 pub fn is_nullable(&self) -> bool {
264 use crate::nullability::Nullability::*;
265
266 match self {
267 Null => true,
268 Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
269 Bool(n)
270 | Primitive(_, n)
271 | Decimal(_, n)
272 | Utf8(n)
273 | Binary(n)
274 | Struct(_, n)
275 | List(_, n) => matches!(n, Nullable),
276 }
277 }
278
279 pub fn as_nonnullable(&self) -> Self {
281 self.with_nullability(Nullability::NonNullable)
282 }
283
284 pub fn as_nullable(&self) -> Self {
286 self.with_nullability(Nullability::Nullable)
287 }
288
289 pub fn with_nullability(&self, nullability: Nullability) -> Self {
291 match self {
292 Null => Null,
293 Bool(_) => Bool(nullability),
294 Primitive(p, _) => Primitive(*p, nullability),
295 Decimal(d, _) => Decimal(*d, nullability),
296 Utf8(_) => Utf8(nullability),
297 Binary(_) => Binary(nullability),
298 Struct(st, _) => Struct(st.clone(), nullability),
299 List(c, _) => List(c.clone(), nullability),
300 Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
301 }
302 }
303
304 pub fn union_nullability(&self, other: Nullability) -> Self {
306 let nullability = self.nullability() | other;
307 self.with_nullability(nullability)
308 }
309
310 pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
312 match (self, other) {
313 (Null, Null) => true,
314 (Bool(_), Bool(_)) => true,
315 (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
316 (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
317 (Utf8(_), Utf8(_)) => true,
318 (Binary(_), Binary(_)) => true,
319 (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
320 (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
321 (lhs_dtype.names() == rhs_dtype.names())
322 && (lhs_dtype
323 .fields()
324 .zip_eq(rhs_dtype.fields())
325 .all(|(l, r)| l.eq_ignore_nullability(&r)))
326 }
327 (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
328 lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
329 }
330 _ => false,
331 }
332 }
333
334 pub fn is_struct(&self) -> bool {
336 matches!(self, Struct(_, _))
337 }
338
339 pub fn is_list(&self) -> bool {
341 matches!(self, List(_, _))
342 }
343
344 pub fn is_primitive(&self) -> bool {
346 matches!(self, Primitive(_, _))
347 }
348
349 pub fn as_ptype(&self) -> PType {
351 match self {
352 Primitive(ptype, _) => *ptype,
353 _ => vortex_panic!("DType is not a primitive type"),
354 }
355 }
356
357 pub fn is_unsigned_int(&self) -> bool {
359 if let Primitive(ptype, _) = self {
360 return ptype.is_unsigned_int();
361 }
362 false
363 }
364
365 pub fn is_signed_int(&self) -> bool {
367 if let Primitive(ptype, _) = self {
368 return ptype.is_signed_int();
369 }
370 false
371 }
372
373 pub fn is_int(&self) -> bool {
375 if let Primitive(ptype, _) = self {
376 return ptype.is_int();
377 }
378 false
379 }
380
381 pub fn is_float(&self) -> bool {
383 if let Primitive(ptype, _) = self {
384 return ptype.is_float();
385 }
386 false
387 }
388
389 pub fn is_boolean(&self) -> bool {
391 matches!(self, Bool(_))
392 }
393
394 pub fn is_binary(&self) -> bool {
396 matches!(self, Binary(_))
397 }
398
399 pub fn is_utf8(&self) -> bool {
401 matches!(self, Utf8(_))
402 }
403
404 pub fn is_extension(&self) -> bool {
406 matches!(self, Extension(_))
407 }
408
409 pub fn is_decimal(&self) -> bool {
411 matches!(self, Decimal(..))
412 }
413
414 pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
416 match self {
417 Decimal(decimal, _) => Some(decimal),
418 _ => None,
419 }
420 }
421
422 pub fn as_struct_opt(&self) -> Option<&StructFields> {
424 match self {
425 Struct(s, _) => Some(s),
426 _ => None,
427 }
428 }
429
430 pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
432 match self {
433 List(s, _) => Some(s),
434 _ => None,
435 }
436 }
437
438 pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
440 iter: I,
441 nullability: Nullability,
442 ) -> Self {
443 Struct(StructFields::from_iter(iter), nullability)
444 }
445
446 pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
448 List(Arc::new(dtype.into()), nullability)
449 }
450}
451
452impl Display for DType {
453 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
454 match self {
455 Null => write!(f, "null"),
456 Bool(n) => write!(f, "bool{n}"),
457 Primitive(pt, n) => write!(f, "{pt}{n}"),
458 Decimal(dt, n) => write!(f, "{dt}{n}"),
459 Utf8(n) => write!(f, "utf8{n}"),
460 Binary(n) => write!(f, "binary{n}"),
461 Struct(sdt, n) => write!(
462 f,
463 "{{{}}}{}",
464 sdt.names()
465 .iter()
466 .zip(sdt.fields())
467 .map(|(n, dt)| format!("{n}={dt}"))
468 .join(", "),
469 n
470 ),
471 List(edt, n) => write!(f, "list({edt}){n}"),
472 Extension(ext) => write!(
473 f,
474 "ext({}, {}{}){}",
475 ext.id(),
476 ext.storage_dtype()
477 .with_nullability(Nullability::NonNullable),
478 ext.metadata()
479 .map(|m| format!(", {m:?}"))
480 .unwrap_or_else(|| "".to_string()),
481 ext.storage_dtype().nullability(),
482 ),
483 }
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// The borrowing iterator reports its exact length and yields names in order.
    #[test]
    fn test_field_names_iter() {
        let raw = ["a", "b"];
        let field_names = FieldNames::from(raw);
        assert_eq!(field_names.iter().len(), raw.len());

        let mut remaining = field_names.iter();
        assert_eq!(remaining.next(), Some(&FieldName::from("a")));
        assert_eq!(remaining.next(), Some(&FieldName::from("b")));
        assert_eq!(remaining.next(), None);
    }

    /// The owning iterator reports its exact length and yields names by value, in order.
    #[test]
    fn test_field_names_owned_iter() {
        let raw = ["a", "b"];
        let field_names = FieldNames::from(raw);
        assert_eq!(field_names.clone().into_iter().len(), raw.len());

        let mut remaining = field_names.into_iter();
        assert_eq!(remaining.next(), Some(FieldName::from("a")));
        assert_eq!(remaining.next(), Some(FieldName::from("b")));
        assert_eq!(remaining.next(), None);
    }

    /// Exercises every `PartialEq` impl, for both owned and borrowed left-hand sides.
    #[test]
    fn test_field_names_equality() {
        let field_names = FieldNames::from(["field1", "field2", "field3"]);

        // FieldNames == &FieldNames
        let field_names_ref = &field_names;
        assert_eq!(field_names, field_names_ref);

        // FieldNames == &[&str] and &FieldNames == &[&str]
        let str_slice: &[&str] = &["field1", "field2", "field3"];
        assert_eq!(field_names, str_slice);
        assert_eq!(&field_names, str_slice);

        // FieldNames == [&str; N] and &FieldNames == [&str; N]
        assert_eq!(field_names, ["field1", "field2", "field3"]);
        assert_eq!(&field_names, ["field1", "field2", "field3"]);

        // FieldNames == &[FieldName] and &FieldNames == &[FieldName]
        let field_name_vec: Vec<FieldName> = vec![
            FieldName::from("field1"),
            FieldName::from("field2"),
            FieldName::from("field3"),
        ];
        let field_name_slice = field_name_vec.as_slice();
        assert_eq!(field_names, field_name_slice);
        assert_eq!(&field_names, field_name_slice);

        // Shorter, different and longer inputs must all compare unequal.
        assert_ne!(field_names, &["field1", "field2"][..]);
        assert_ne!(field_names, ["different", "fields", "here"]);
        assert_ne!(field_names, &["field1", "field2", "field3", "extra"][..]);
    }

    /// `Display` renders a bracketed list of quoted names.
    #[test]
    fn test_field_names_display() {
        let names = FieldNames::from(["a", "b", "c"]);
        let rendered = names.to_string();

        assert_eq!(rendered, r#"["a", "b", "c"]"#);
    }
}