vortex_dtype/
dtype.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::{Debug, Display, Formatter};
5use std::hash::Hash;
6use std::ops::Index;
7use std::sync::Arc;
8
9use DType::*;
10use itertools::Itertools;
11use static_assertions::const_assert_eq;
12use vortex_error::vortex_panic;
13
14use crate::decimal::DecimalDType;
15use crate::nullability::Nullability;
16use crate::{ExtDType, FieldDType, PType, StructFields};
17
/// The name of a single struct field, held as a shared string slice.
pub type FieldName = Arc<str>;

/// The ordered sequence of field names that make up a struct.
///
/// The names are stored behind an `Arc`, so cloning a `FieldNames` only clones the handle.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldNames(Arc<[FieldName]>);
25
26impl FieldNames {
27    /// Returns the number of elements.
28    pub fn len(&self) -> usize {
29        self.0.len()
30    }
31
32    /// Returns true if the number of elements is 0.
33    pub fn is_empty(&self) -> bool {
34        self.len() == 0
35    }
36
37    /// Returns a borrowed iterator over the field names.
38    pub fn iter(&self) -> impl ExactSizeIterator<Item = &FieldName> {
39        FieldNamesIter {
40            inner: self,
41            idx: 0,
42        }
43    }
44
45    /// Returns a reference to a field name, or None if `index` is out of bounds.
46    pub fn get(&self, index: usize) -> Option<&FieldName> {
47        self.0.get(index)
48    }
49}
50
51impl AsRef<[FieldName]> for FieldNames {
52    fn as_ref(&self) -> &[FieldName] {
53        &self.0
54    }
55}
56
57impl Index<usize> for FieldNames {
58    type Output = FieldName;
59
60    fn index(&self, index: usize) -> &Self::Output {
61        &self.0[index]
62    }
63}
64
65/// Iterator of references to field names
66pub struct FieldNamesIter<'a> {
67    inner: &'a FieldNames,
68    idx: usize,
69}
70
71impl<'a> Iterator for FieldNamesIter<'a> {
72    type Item = &'a FieldName;
73
74    fn next(&mut self) -> Option<Self::Item> {
75        if self.idx >= self.inner.len() {
76            return None;
77        }
78
79        let i = &self.inner.0[self.idx];
80        self.idx += 1;
81        Some(i)
82    }
83
84    fn size_hint(&self) -> (usize, Option<usize>) {
85        let len = self.inner.len() - self.idx;
86        (len, Some(len))
87    }
88}
89
90impl ExactSizeIterator for FieldNamesIter<'_> {}
91
92/// Owned iterator of field names.
93pub struct FieldNamesIntoIter {
94    inner: FieldNames,
95    idx: usize,
96}
97
98impl Iterator for FieldNamesIntoIter {
99    type Item = FieldName;
100
101    fn next(&mut self) -> Option<Self::Item> {
102        if self.idx >= self.inner.len() {
103            return None;
104        }
105
106        let i = self.inner.0[self.idx].clone();
107        self.idx += 1;
108        Some(i)
109    }
110
111    fn size_hint(&self) -> (usize, Option<usize>) {
112        let len = self.inner.len() - self.idx;
113        (len, Some(len))
114    }
115}
116
117impl ExactSizeIterator for FieldNamesIntoIter {}
118
119impl IntoIterator for FieldNames {
120    type Item = FieldName;
121
122    type IntoIter = FieldNamesIntoIter;
123
124    fn into_iter(self) -> Self::IntoIter {
125        FieldNamesIntoIter {
126            inner: self,
127            idx: 0,
128        }
129    }
130}
131
132impl From<Vec<FieldName>> for FieldNames {
133    fn from(value: Vec<FieldName>) -> Self {
134        Self(value.into())
135    }
136}
137
138impl From<&[&'static str]> for FieldNames {
139    fn from(value: &[&'static str]) -> Self {
140        Self(value.iter().cloned().map(Arc::from).collect())
141    }
142}
143
144impl From<&[FieldName]> for FieldNames {
145    fn from(value: &[FieldName]) -> Self {
146        Self(Arc::from(value))
147    }
148}
149
150impl<const N: usize> From<[&'static str; N]> for FieldNames {
151    fn from(value: [&'static str; N]) -> Self {
152        Self(value.into_iter().map(Arc::from).collect())
153    }
154}
155
156impl<const N: usize> From<[FieldName; N]> for FieldNames {
157    fn from(value: [FieldName; N]) -> Self {
158        Self(value.into())
159    }
160}
161
162impl<F: Into<FieldName>> FromIterator<F> for FieldNames {
163    fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
164        Self(iter.into_iter().map(|v| v.into()).collect())
165    }
166}
167
168/// The logical types of elements in Vortex arrays.
169///
170/// Vortex arrays preserve a single logical type, while the encodings allow for multiple
171/// physical ways to encode that type.
172#[derive(Debug, Clone, PartialEq, Eq, Hash)]
173#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
174pub enum DType {
175    /// The logical null type (only has a single value, `null`)
176    Null,
177    /// The logical boolean type (`true` or `false` if non-nullable; `true`, `false`, or `null` if nullable)
178    Bool(Nullability),
179    /// Primitive, fixed-width numeric types (e.g., `u8`, `i8`, `u16`, `i16`, `u32`, `i32`, `u64`, `i64`, `f32`, `f64`)
180    Primitive(PType, Nullability),
181    /// Real numbers with fixed exact precision and scale.
182    Decimal(DecimalDType, Nullability),
183    /// UTF-8 strings
184    Utf8(Nullability),
185    /// Binary data
186    Binary(Nullability),
187    /// A struct is composed of an ordered list of fields, each with a corresponding name and DType
188    Struct(StructFields, Nullability),
189    /// A variable-length list type, parameterized by a single element DType
190    List(Arc<DType>, Nullability),
191    /// User-defined extension types
192    Extension(Arc<ExtDType>),
193}
194
195#[cfg(not(target_arch = "wasm32"))]
196const_assert_eq!(size_of::<DType>(), 16);
197
198#[cfg(target_arch = "wasm32")]
199const_assert_eq!(size_of::<DType>(), 8);
200
201impl DType {
202    /// The default DType for bytes
203    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
204
205    /// Get the nullability of the DType
206    pub fn nullability(&self) -> Nullability {
207        self.is_nullable().into()
208    }
209
210    /// Check if the DType is nullable
211    pub fn is_nullable(&self) -> bool {
212        use crate::nullability::Nullability::*;
213
214        match self {
215            Null => true,
216            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
217            Bool(n)
218            | Primitive(_, n)
219            | Decimal(_, n)
220            | Utf8(n)
221            | Binary(n)
222            | Struct(_, n)
223            | List(_, n) => matches!(n, Nullable),
224        }
225    }
226
227    /// Get a new DType with `Nullability::NonNullable` (but otherwise the same as `self`)
228    pub fn as_nonnullable(&self) -> Self {
229        self.with_nullability(Nullability::NonNullable)
230    }
231
232    /// Get a new DType with `Nullability::Nullable` (but otherwise the same as `self`)
233    pub fn as_nullable(&self) -> Self {
234        self.with_nullability(Nullability::Nullable)
235    }
236
237    /// Get a new DType with the given nullability (but otherwise the same as `self`)
238    pub fn with_nullability(&self, nullability: Nullability) -> Self {
239        match self {
240            Null => Null,
241            Bool(_) => Bool(nullability),
242            Primitive(p, _) => Primitive(*p, nullability),
243            Decimal(d, _) => Decimal(*d, nullability),
244            Utf8(_) => Utf8(nullability),
245            Binary(_) => Binary(nullability),
246            Struct(st, _) => Struct(st.clone(), nullability),
247            List(c, _) => List(c.clone(), nullability),
248            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
249        }
250    }
251
252    /// Union the nullability of this dtype with the other nullability, returning a new dtype.
253    pub fn union_nullability(&self, other: Nullability) -> Self {
254        let nullability = self.nullability() | other;
255        self.with_nullability(nullability)
256    }
257
258    /// Check if `self` and `other` are equal, ignoring nullability
259    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
260        match (self, other) {
261            (Null, Null) => true,
262            (Bool(_), Bool(_)) => true,
263            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
264            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
265            (Utf8(_), Utf8(_)) => true,
266            (Binary(_), Binary(_)) => true,
267            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
268            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
269                (lhs_dtype.names() == rhs_dtype.names())
270                    && (lhs_dtype
271                        .fields()
272                        .zip_eq(rhs_dtype.fields())
273                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
274            }
275            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
276                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
277            }
278            _ => false,
279        }
280    }
281
282    /// Check if `self` is a `StructDType`
283    pub fn is_struct(&self) -> bool {
284        matches!(self, Struct(_, _))
285    }
286
287    /// Check if `self` is a `ListDType`
288    pub fn is_list(&self) -> bool {
289        matches!(self, List(_, _))
290    }
291
292    /// Check if `self` is a primitive tpye
293    pub fn is_primitive(&self) -> bool {
294        matches!(self, Primitive(_, _))
295    }
296
297    /// Returns this DType's `PType` if it is a primitive type, otherwise panics.
298    pub fn as_ptype(&self) -> PType {
299        match self {
300            Primitive(ptype, _) => *ptype,
301            _ => vortex_panic!("DType is not a primitive type"),
302        }
303    }
304
305    /// Check if `self` is an unsigned integer
306    pub fn is_unsigned_int(&self) -> bool {
307        if let Primitive(ptype, _) = self {
308            return ptype.is_unsigned_int();
309        }
310        false
311    }
312
313    /// Check if `self` is a signed integer
314    pub fn is_signed_int(&self) -> bool {
315        if let Primitive(ptype, _) = self {
316            return ptype.is_signed_int();
317        }
318        false
319    }
320
321    /// Check if `self` is an integer (signed or unsigned)
322    pub fn is_int(&self) -> bool {
323        if let Primitive(ptype, _) = self {
324            return ptype.is_int();
325        }
326        false
327    }
328
329    /// Check if `self` is a floating point number
330    pub fn is_float(&self) -> bool {
331        if let Primitive(ptype, _) = self {
332            return ptype.is_float();
333        }
334        false
335    }
336
337    /// Check if `self` is a boolean
338    pub fn is_boolean(&self) -> bool {
339        matches!(self, Bool(_))
340    }
341
342    /// Check if `self` is a binary
343    pub fn is_binary(&self) -> bool {
344        matches!(self, Binary(_))
345    }
346
347    /// Check if `self` is a utf8
348    pub fn is_utf8(&self) -> bool {
349        matches!(self, Utf8(_))
350    }
351
352    /// Check if `self` is an extension type
353    pub fn is_extension(&self) -> bool {
354        matches!(self, Extension(_))
355    }
356
357    /// Check if `self` is a decimal type
358    pub fn is_decimal(&self) -> bool {
359        matches!(self, Decimal(..))
360    }
361
362    /// Check returns the inner decimal type if the dtype is a decimal
363    pub fn as_decimal(&self) -> Option<&DecimalDType> {
364        match self {
365            Decimal(decimal, _) => Some(decimal),
366            _ => None,
367        }
368    }
369
370    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
371    pub fn as_struct(&self) -> Option<&StructFields> {
372        match self {
373            Struct(s, _) => Some(s),
374            _ => None,
375        }
376    }
377
378    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
379    pub fn as_list_element(&self) -> Option<&Arc<DType>> {
380        match self {
381            List(s, _) => Some(s),
382            _ => None,
383        }
384    }
385
386    /// Convenience method for creating a struct dtype
387    pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
388        iter: I,
389        nullability: Nullability,
390    ) -> Self {
391        Struct(StructFields::from_iter(iter), nullability)
392    }
393
394    /// Convenience method for creating a list dtype
395    pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
396        List(Arc::new(dtype.into()), nullability)
397    }
398}
399
400impl Display for DType {
401    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
402        match self {
403            Null => write!(f, "null"),
404            Bool(n) => write!(f, "bool{n}"),
405            Primitive(pt, n) => write!(f, "{pt}{n}"),
406            Decimal(dt, n) => write!(f, "{dt}{n}"),
407            Utf8(n) => write!(f, "utf8{n}"),
408            Binary(n) => write!(f, "binary{n}"),
409            Struct(sdt, n) => write!(
410                f,
411                "{{{}}}{}",
412                sdt.names()
413                    .iter()
414                    .zip(sdt.fields())
415                    .map(|(n, dt)| format!("{n}={dt}"))
416                    .join(", "),
417                n
418            ),
419            List(edt, n) => write!(f, "list({edt}){n}"),
420            Extension(ext) => write!(
421                f,
422                "ext({}, {}{}){}",
423                ext.id(),
424                ext.storage_dtype()
425                    .with_nullability(Nullability::NonNullable),
426                ext.metadata()
427                    .map(|m| format!(", {m:?}"))
428                    .unwrap_or_else(|| "".to_string()),
429                ext.storage_dtype().nullability(),
430            ),
431        }
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// The borrowing iterator reports its exact length and yields every name in order.
    #[test]
    fn test_field_names_iter() {
        let names = ["a", "b"];
        let field_names = FieldNames::from(names);
        assert_eq!(field_names.iter().len(), names.len());

        let mut iter = field_names.iter();
        for expected in names {
            assert_eq!(iter.next(), Some(&FieldName::from(expected)));
        }
        assert_eq!(iter.next(), None);
    }

    /// The owning iterator reports its exact length and yields every name in order.
    #[test]
    fn test_field_names_owned_iter() {
        let names = ["a", "b"];
        let field_names = FieldNames::from(names);
        assert_eq!(field_names.clone().into_iter().len(), names.len());

        let mut iter = field_names.into_iter();
        for expected in names {
            assert_eq!(iter.next(), Some(FieldName::from(expected)));
        }
        assert_eq!(iter.next(), None);
    }
}
460}