vortex_dtype/
dtype.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::ops::Index;
7use std::sync::Arc;
8
9use DType::*;
10use itertools::Itertools;
11use static_assertions::const_assert_eq;
12use vortex_error::vortex_panic;
13
14use crate::decimal::DecimalDType;
15use crate::nullability::Nullability;
16use crate::{ExtDType, FieldDType, PType, StructFields};
17
/// A name for a field in a struct.
///
/// Stored as a reference-counted `str`, so cloning a name only bumps a reference count.
pub type FieldName = Arc<str>;

/// An ordered list of field names in a struct.
///
/// The names live in a shared `Arc<[FieldName]>`; cloning a `FieldNames` clones that
/// `Arc` rather than copying the individual names.
#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldNames(Arc<[FieldName]>);
25
26impl Display for FieldNames {
27    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
28        write!(
29            f,
30            "[{}]",
31            itertools::join(self.0.iter().map(|n| format!("\"{n}\"")), ", ")
32        )
33    }
34}
35
36impl PartialEq<&FieldNames> for FieldNames {
37    fn eq(&self, other: &&FieldNames) -> bool {
38        self == *other
39    }
40}
41
42impl PartialEq<&[&str]> for FieldNames {
43    fn eq(&self, other: &&[&str]) -> bool {
44        self.len() == other.len() && self.iter().zip_eq(other.iter()).all(|(l, r)| &**l == *r)
45    }
46}
47
48impl PartialEq<&[&str]> for &FieldNames {
49    fn eq(&self, other: &&[&str]) -> bool {
50        *self == other
51    }
52}
53
54impl<const N: usize> PartialEq<[&str; N]> for FieldNames {
55    fn eq(&self, other: &[&str; N]) -> bool {
56        self == other.as_slice()
57    }
58}
59
60impl<const N: usize> PartialEq<[&str; N]> for &FieldNames {
61    fn eq(&self, other: &[&str; N]) -> bool {
62        *self == other.as_slice()
63    }
64}
65
66impl PartialEq<&[FieldName]> for FieldNames {
67    fn eq(&self, other: &&[FieldName]) -> bool {
68        self.0.as_ref() == *other
69    }
70}
71
72impl PartialEq<&[FieldName]> for &FieldNames {
73    fn eq(&self, other: &&[FieldName]) -> bool {
74        self.0.as_ref() == *other
75    }
76}
77
78impl FieldNames {
79    /// Returns the number of elements.
80    pub fn len(&self) -> usize {
81        self.0.len()
82    }
83
84    /// Returns true if the number of elements is 0.
85    pub fn is_empty(&self) -> bool {
86        self.len() == 0
87    }
88
89    /// Returns a borrowed iterator over the field names.
90    pub fn iter(&self) -> impl ExactSizeIterator<Item = &FieldName> {
91        FieldNamesIter {
92            inner: self,
93            idx: 0,
94        }
95    }
96
97    /// Returns a reference to a field name, or None if `index` is out of bounds.
98    pub fn get(&self, index: usize) -> Option<&FieldName> {
99        self.0.get(index)
100    }
101}
102
103impl AsRef<[FieldName]> for FieldNames {
104    fn as_ref(&self) -> &[FieldName] {
105        &self.0
106    }
107}
108
109impl Index<usize> for FieldNames {
110    type Output = FieldName;
111
112    fn index(&self, index: usize) -> &Self::Output {
113        &self.0[index]
114    }
115}
116
117/// Iterator of references to field names
118pub struct FieldNamesIter<'a> {
119    inner: &'a FieldNames,
120    idx: usize,
121}
122
123impl<'a> Iterator for FieldNamesIter<'a> {
124    type Item = &'a FieldName;
125
126    fn next(&mut self) -> Option<Self::Item> {
127        if self.idx >= self.inner.len() {
128            return None;
129        }
130
131        let i = &self.inner.0[self.idx];
132        self.idx += 1;
133        Some(i)
134    }
135
136    fn size_hint(&self) -> (usize, Option<usize>) {
137        let len = self.inner.len() - self.idx;
138        (len, Some(len))
139    }
140}
141
142impl ExactSizeIterator for FieldNamesIter<'_> {}
143
144/// Owned iterator of field names.
145pub struct FieldNamesIntoIter {
146    inner: FieldNames,
147    idx: usize,
148}
149
150impl Iterator for FieldNamesIntoIter {
151    type Item = FieldName;
152
153    fn next(&mut self) -> Option<Self::Item> {
154        if self.idx >= self.inner.len() {
155            return None;
156        }
157
158        let i = self.inner.0[self.idx].clone();
159        self.idx += 1;
160        Some(i)
161    }
162
163    fn size_hint(&self) -> (usize, Option<usize>) {
164        let len = self.inner.len() - self.idx;
165        (len, Some(len))
166    }
167}
168
169impl ExactSizeIterator for FieldNamesIntoIter {}
170
171impl IntoIterator for FieldNames {
172    type Item = FieldName;
173
174    type IntoIter = FieldNamesIntoIter;
175
176    fn into_iter(self) -> Self::IntoIter {
177        FieldNamesIntoIter {
178            inner: self,
179            idx: 0,
180        }
181    }
182}
183
184impl From<Vec<FieldName>> for FieldNames {
185    fn from(value: Vec<FieldName>) -> Self {
186        Self(value.into())
187    }
188}
189
190impl From<&[&'static str]> for FieldNames {
191    fn from(value: &[&'static str]) -> Self {
192        Self(value.iter().cloned().map(Arc::from).collect())
193    }
194}
195
196impl From<&[FieldName]> for FieldNames {
197    fn from(value: &[FieldName]) -> Self {
198        Self(Arc::from(value))
199    }
200}
201
202impl<const N: usize> From<[&'static str; N]> for FieldNames {
203    fn from(value: [&'static str; N]) -> Self {
204        Self(value.into_iter().map(Arc::from).collect())
205    }
206}
207
208impl<const N: usize> From<[FieldName; N]> for FieldNames {
209    fn from(value: [FieldName; N]) -> Self {
210        Self(value.into())
211    }
212}
213
214impl<F: Into<FieldName>> FromIterator<F> for FieldNames {
215    fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
216        Self(iter.into_iter().map(|v| v.into()).collect())
217    }
218}
219
/// The logical types of elements in Vortex arrays.
///
/// Vortex arrays preserve a single logical type, while the encodings allow for multiple
/// physical ways to encode that type.
///
/// Every variant except `Null` and `Extension` carries its own [`Nullability`]. `Null` is
/// always nullable, and an `Extension` takes its nullability from its storage dtype
/// (see [`DType::is_nullable`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The logical null type (only has a single value, `null`)
    Null,
    /// The logical boolean type (`true` or `false` if non-nullable; `true`, `false`, or `null` if nullable)
    Bool(Nullability),
    /// Primitive, fixed-width numeric types (e.g., `u8`, `i8`, `u16`, `i16`, `u32`, `i32`, `u64`, `i64`, `f32`, `f64`)
    Primitive(PType, Nullability),
    /// Real numbers with fixed exact precision and scale.
    Decimal(DecimalDType, Nullability),
    /// UTF-8 strings
    Utf8(Nullability),
    /// Binary data
    Binary(Nullability),
    /// A struct is composed of an ordered list of fields, each with a corresponding name and DType
    Struct(StructFields, Nullability),
    /// A variable-length list type, parameterized by a single element DType
    List(Arc<DType>, Nullability),
    /// User-defined extension types
    Extension(Arc<ExtDType>),
}

// Compile-time guard on the size of `DType`: 16 bytes on non-wasm32 targets and 8 bytes on
// wasm32. Any change to a variant's payload that grows the enum fails to compile here.
// The `List` and `Extension` payloads are held behind `Arc`, which helps keep it small.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
252
253impl DType {
254    /// The default DType for bytes
255    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
256
257    /// Get the nullability of the DType
258    pub fn nullability(&self) -> Nullability {
259        self.is_nullable().into()
260    }
261
262    /// Check if the DType is nullable
263    pub fn is_nullable(&self) -> bool {
264        use crate::nullability::Nullability::*;
265
266        match self {
267            Null => true,
268            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
269            Bool(n)
270            | Primitive(_, n)
271            | Decimal(_, n)
272            | Utf8(n)
273            | Binary(n)
274            | Struct(_, n)
275            | List(_, n) => matches!(n, Nullable),
276        }
277    }
278
279    /// Get a new DType with `Nullability::NonNullable` (but otherwise the same as `self`)
280    pub fn as_nonnullable(&self) -> Self {
281        self.with_nullability(Nullability::NonNullable)
282    }
283
284    /// Get a new DType with `Nullability::Nullable` (but otherwise the same as `self`)
285    pub fn as_nullable(&self) -> Self {
286        self.with_nullability(Nullability::Nullable)
287    }
288
289    /// Get a new DType with the given nullability (but otherwise the same as `self`)
290    pub fn with_nullability(&self, nullability: Nullability) -> Self {
291        match self {
292            Null => Null,
293            Bool(_) => Bool(nullability),
294            Primitive(p, _) => Primitive(*p, nullability),
295            Decimal(d, _) => Decimal(*d, nullability),
296            Utf8(_) => Utf8(nullability),
297            Binary(_) => Binary(nullability),
298            Struct(st, _) => Struct(st.clone(), nullability),
299            List(c, _) => List(c.clone(), nullability),
300            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
301        }
302    }
303
304    /// Union the nullability of this dtype with the other nullability, returning a new dtype.
305    pub fn union_nullability(&self, other: Nullability) -> Self {
306        let nullability = self.nullability() | other;
307        self.with_nullability(nullability)
308    }
309
310    /// Check if `self` and `other` are equal, ignoring nullability
311    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
312        match (self, other) {
313            (Null, Null) => true,
314            (Bool(_), Bool(_)) => true,
315            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
316            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
317            (Utf8(_), Utf8(_)) => true,
318            (Binary(_), Binary(_)) => true,
319            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
320            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
321                (lhs_dtype.names() == rhs_dtype.names())
322                    && (lhs_dtype
323                        .fields()
324                        .zip_eq(rhs_dtype.fields())
325                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
326            }
327            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
328                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
329            }
330            _ => false,
331        }
332    }
333
334    /// Check if `self` is a `StructDType`
335    pub fn is_struct(&self) -> bool {
336        matches!(self, Struct(_, _))
337    }
338
339    /// Check if `self` is a `ListDType`
340    pub fn is_list(&self) -> bool {
341        matches!(self, List(_, _))
342    }
343
344    /// Check if `self` is a primitive tpye
345    pub fn is_primitive(&self) -> bool {
346        matches!(self, Primitive(_, _))
347    }
348
349    /// Returns this DType's `PType` if it is a primitive type, otherwise panics.
350    pub fn as_ptype(&self) -> PType {
351        match self {
352            Primitive(ptype, _) => *ptype,
353            _ => vortex_panic!("DType is not a primitive type"),
354        }
355    }
356
357    /// Check if `self` is an unsigned integer
358    pub fn is_unsigned_int(&self) -> bool {
359        if let Primitive(ptype, _) = self {
360            return ptype.is_unsigned_int();
361        }
362        false
363    }
364
365    /// Check if `self` is a signed integer
366    pub fn is_signed_int(&self) -> bool {
367        if let Primitive(ptype, _) = self {
368            return ptype.is_signed_int();
369        }
370        false
371    }
372
373    /// Check if `self` is an integer (signed or unsigned)
374    pub fn is_int(&self) -> bool {
375        if let Primitive(ptype, _) = self {
376            return ptype.is_int();
377        }
378        false
379    }
380
381    /// Check if `self` is a floating point number
382    pub fn is_float(&self) -> bool {
383        if let Primitive(ptype, _) = self {
384            return ptype.is_float();
385        }
386        false
387    }
388
389    /// Check if `self` is a boolean
390    pub fn is_boolean(&self) -> bool {
391        matches!(self, Bool(_))
392    }
393
394    /// Check if `self` is a binary
395    pub fn is_binary(&self) -> bool {
396        matches!(self, Binary(_))
397    }
398
399    /// Check if `self` is a utf8
400    pub fn is_utf8(&self) -> bool {
401        matches!(self, Utf8(_))
402    }
403
404    /// Check if `self` is an extension type
405    pub fn is_extension(&self) -> bool {
406        matches!(self, Extension(_))
407    }
408
409    /// Check if `self` is a decimal type
410    pub fn is_decimal(&self) -> bool {
411        matches!(self, Decimal(..))
412    }
413
414    /// Check returns the inner decimal type if the dtype is a decimal
415    pub fn as_decimal_opt(&self) -> Option<&DecimalDType> {
416        match self {
417            Decimal(decimal, _) => Some(decimal),
418            _ => None,
419        }
420    }
421
422    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
423    pub fn as_struct_opt(&self) -> Option<&StructFields> {
424        match self {
425            Struct(s, _) => Some(s),
426            _ => None,
427        }
428    }
429
430    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
431    pub fn as_list_element_opt(&self) -> Option<&Arc<DType>> {
432        match self {
433            List(s, _) => Some(s),
434            _ => None,
435        }
436    }
437
438    /// Convenience method for creating a struct dtype
439    pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
440        iter: I,
441        nullability: Nullability,
442    ) -> Self {
443        Struct(StructFields::from_iter(iter), nullability)
444    }
445
446    /// Convenience method for creating a list dtype
447    pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
448        List(Arc::new(dtype.into()), nullability)
449    }
450}
451
452impl Display for DType {
453    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
454        match self {
455            Null => write!(f, "null"),
456            Bool(n) => write!(f, "bool{n}"),
457            Primitive(pt, n) => write!(f, "{pt}{n}"),
458            Decimal(dt, n) => write!(f, "{dt}{n}"),
459            Utf8(n) => write!(f, "utf8{n}"),
460            Binary(n) => write!(f, "binary{n}"),
461            Struct(sdt, n) => write!(
462                f,
463                "{{{}}}{}",
464                sdt.names()
465                    .iter()
466                    .zip(sdt.fields())
467                    .map(|(n, dt)| format!("{n}={dt}"))
468                    .join(", "),
469                n
470            ),
471            List(edt, n) => write!(f, "list({edt}){n}"),
472            Extension(ext) => write!(
473                f,
474                "ext({}, {}{}){}",
475                ext.id(),
476                ext.storage_dtype()
477                    .with_nullability(Nullability::NonNullable),
478                ext.metadata()
479                    .map(|m| format!(", {m:?}"))
480                    .unwrap_or_else(|| "".to_string()),
481                ext.storage_dtype().nullability(),
482            ),
483        }
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_field_names_iter() {
        let expected = ["a", "b"];
        let field_names = FieldNames::from(expected);
        assert_eq!(field_names.iter().len(), expected.len());

        // Borrowed iteration yields every name in order, then stops.
        let mut borrowed = field_names.iter();
        for name in expected {
            assert_eq!(borrowed.next(), Some(&FieldName::from(name)));
        }
        assert_eq!(borrowed.next(), None);
    }

    #[test]
    fn test_field_names_owned_iter() {
        let expected = ["a", "b"];
        let field_names = FieldNames::from(expected);
        assert_eq!(field_names.clone().into_iter().len(), expected.len());

        // Owned iteration yields every name in order, then stops.
        let mut owned = field_names.into_iter();
        for name in expected {
            assert_eq!(owned.next(), Some(FieldName::from(name)));
        }
        assert_eq!(owned.next(), None);
    }

    #[test]
    fn test_field_names_equality() {
        let field_names = FieldNames::from(["field1", "field2", "field3"]);
        let as_str_slice: &[&str] = &["field1", "field2", "field3"];
        let as_field_names: Vec<FieldName> = as_str_slice
            .iter()
            .copied()
            .map(FieldName::from)
            .collect();

        // FieldNames == &FieldNames
        assert_eq!(field_names, &field_names);

        // FieldNames / &FieldNames == &[&str]
        assert_eq!(field_names, as_str_slice);
        assert_eq!(&field_names, as_str_slice);

        // FieldNames / &FieldNames == [&str; N]
        assert_eq!(field_names, ["field1", "field2", "field3"]);
        assert_eq!(&field_names, ["field1", "field2", "field3"]);

        // FieldNames / &FieldNames == &[FieldName]
        assert_eq!(field_names, as_field_names.as_slice());
        assert_eq!(&field_names, as_field_names.as_slice());

        // A shorter, different, or longer list must compare unequal.
        assert_ne!(field_names, &["field1", "field2"][..]);
        assert_ne!(field_names, ["different", "fields", "here"]);
        assert_ne!(field_names, &["field1", "field2", "field3", "extra"][..]);
    }

    #[test]
    fn test_field_names_display() {
        let rendered = FieldNames::from(["a", "b", "c"]).to_string();
        assert_eq!(rendered, r#"["a", "b", "c"]"#);
    }
}