vortex_dtype/
dtype.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::Hash;
3use std::ops::Index;
4use std::sync::Arc;
5
6use DType::*;
7use itertools::Itertools;
8use static_assertions::const_assert_eq;
9use vortex_error::vortex_panic;
10
11use crate::decimal::DecimalDType;
12use crate::nullability::Nullability;
13use crate::{ExtDType, FieldDType, PType, StructFields};
14
/// A name for a field in a struct.
///
/// This is a reference-counted `str`: cloning a `FieldName` bumps a reference count rather than
/// copying the string bytes.
pub type FieldName = Arc<str>;
17
/// An ordered list of field names in a struct.
///
/// The names are held in a shared `Arc<[FieldName]>`, so cloning a `FieldNames` shares the
/// underlying list instead of copying it. When the `serde` feature is enabled the type also
/// derives `Serialize` and `Deserialize`.
#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FieldNames(Arc<[FieldName]>);
22
23impl FieldNames {
24    /// Returns the number of elements.
25    pub fn len(&self) -> usize {
26        self.0.len()
27    }
28
29    /// Returns true if the number of elements is 0.
30    pub fn is_empty(&self) -> bool {
31        self.len() == 0
32    }
33
34    /// Returns a borrowed iterator over the field names.
35    pub fn iter(&self) -> impl ExactSizeIterator<Item = &FieldName> {
36        FieldNamesIter {
37            inner: self,
38            idx: 0,
39        }
40    }
41
42    /// Returns a reference to a field name, or None if `index` is out of bounds.
43    pub fn get(&self, index: usize) -> Option<&FieldName> {
44        self.0.get(index)
45    }
46}
47
48impl AsRef<[FieldName]> for FieldNames {
49    fn as_ref(&self) -> &[FieldName] {
50        &self.0
51    }
52}
53
54impl Index<usize> for FieldNames {
55    type Output = FieldName;
56
57    fn index(&self, index: usize) -> &Self::Output {
58        &self.0[index]
59    }
60}
61
62/// Iterator of references to field names
63pub struct FieldNamesIter<'a> {
64    inner: &'a FieldNames,
65    idx: usize,
66}
67
68impl<'a> Iterator for FieldNamesIter<'a> {
69    type Item = &'a FieldName;
70
71    fn next(&mut self) -> Option<Self::Item> {
72        if self.idx >= self.inner.len() {
73            return None;
74        }
75
76        let i = &self.inner.0[self.idx];
77        self.idx += 1;
78        Some(i)
79    }
80}
81
82impl ExactSizeIterator for FieldNamesIter<'_> {
83    fn len(&self) -> usize {
84        self.inner.len() - self.idx
85    }
86}
87
88/// Owned iterator of field names.
89pub struct FieldNamesIntoIter {
90    inner: FieldNames,
91    idx: usize,
92}
93
94impl Iterator for FieldNamesIntoIter {
95    type Item = FieldName;
96
97    fn next(&mut self) -> Option<Self::Item> {
98        if self.idx >= self.inner.len() {
99            return None;
100        }
101
102        let i = self.inner.0[self.idx].clone();
103        self.idx += 1;
104        Some(i)
105    }
106}
107
108impl ExactSizeIterator for FieldNamesIntoIter {
109    fn len(&self) -> usize {
110        self.inner.len() - self.idx
111    }
112}
113
114impl IntoIterator for FieldNames {
115    type Item = FieldName;
116
117    type IntoIter = FieldNamesIntoIter;
118
119    fn into_iter(self) -> Self::IntoIter {
120        FieldNamesIntoIter {
121            inner: self,
122            idx: 0,
123        }
124    }
125}
126
127impl From<Vec<FieldName>> for FieldNames {
128    fn from(value: Vec<FieldName>) -> Self {
129        Self(value.into())
130    }
131}
132
133impl From<&[&'static str]> for FieldNames {
134    fn from(value: &[&'static str]) -> Self {
135        Self(value.iter().cloned().map(Arc::from).collect())
136    }
137}
138
139impl From<&[FieldName]> for FieldNames {
140    fn from(value: &[FieldName]) -> Self {
141        Self(Arc::from(value))
142    }
143}
144
145impl<const N: usize> From<[&'static str; N]> for FieldNames {
146    fn from(value: [&'static str; N]) -> Self {
147        Self(value.into_iter().map(Arc::from).collect())
148    }
149}
150
151impl<const N: usize> From<[FieldName; N]> for FieldNames {
152    fn from(value: [FieldName; N]) -> Self {
153        Self(value.into())
154    }
155}
156
157impl<F: Into<FieldName>> FromIterator<F> for FieldNames {
158    fn from_iter<T: IntoIterator<Item = F>>(iter: T) -> Self {
159        Self(iter.into_iter().map(|v| v.into()).collect())
160    }
161}
162
/// The logical types of elements in Vortex arrays.
///
/// Vortex arrays preserve a single logical type, while the encodings allow for multiple
/// physical ways to encode that type.
///
/// Most variants carry an explicit [`Nullability`]. The exceptions are `Null`, which is always
/// nullable, and `Extension`, which takes its nullability from its storage dtype (see
/// `DType::is_nullable`).
///
/// NOTE: the derived `Hash` (and serde, when enabled) depend on variant order, so do not reorder
/// variants casually.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum DType {
    /// The logical null type (only has a single value, `null`)
    Null,
    /// The logical boolean type (`true` or `false` if non-nullable; `true`, `false`, or `null` if nullable)
    Bool(Nullability),
    /// Primitive, fixed-width numeric types (e.g., `u8`, `i8`, `u16`, `i16`, `u32`, `i32`, `u64`, `i64`, `f32`, `f64`)
    Primitive(PType, Nullability),
    /// Real numbers with fixed exact precision and scale.
    Decimal(DecimalDType, Nullability),
    /// UTF-8 strings
    Utf8(Nullability),
    /// Binary data
    Binary(Nullability),
    /// A struct is composed of an ordered list of fields, each with a corresponding name and DType
    Struct(StructFields, Nullability),
    /// A variable-length list type, parameterized by a single element DType.
    ///
    /// The element dtype sits behind an `Arc`, so cloning a list dtype shares it.
    List(Arc<DType>, Nullability),
    /// User-defined extension types.
    ///
    /// No nullability is stored on this variant; it is read from the extension's storage dtype.
    Extension(Arc<ExtDType>),
}
189
// Compile-time guard on the in-memory size of `DType`: 16 bytes on non-wasm32 targets and
// 8 bytes on wasm32. Any layout change that alters this size (e.g. a larger variant payload)
// fails to compile here.
// NOTE(review): presumably this keeps `DType` cheap to move/clone — confirm the intent before
// relaxing the bound.
#[cfg(not(target_arch = "wasm32"))]
const_assert_eq!(size_of::<DType>(), 16);

#[cfg(target_arch = "wasm32")]
const_assert_eq!(size_of::<DType>(), 8);
195
196impl DType {
197    /// The default DType for bytes
198    pub const BYTES: Self = Primitive(PType::U8, Nullability::NonNullable);
199
200    /// Get the nullability of the DType
201    pub fn nullability(&self) -> Nullability {
202        self.is_nullable().into()
203    }
204
205    /// Check if the DType is nullable
206    pub fn is_nullable(&self) -> bool {
207        use crate::nullability::Nullability::*;
208
209        match self {
210            Null => true,
211            Extension(ext_dtype) => ext_dtype.storage_dtype().is_nullable(),
212            Bool(n)
213            | Primitive(_, n)
214            | Decimal(_, n)
215            | Utf8(n)
216            | Binary(n)
217            | Struct(_, n)
218            | List(_, n) => matches!(n, Nullable),
219        }
220    }
221
222    /// Get a new DType with `Nullability::NonNullable` (but otherwise the same as `self`)
223    pub fn as_nonnullable(&self) -> Self {
224        self.with_nullability(Nullability::NonNullable)
225    }
226
227    /// Get a new DType with `Nullability::Nullable` (but otherwise the same as `self`)
228    pub fn as_nullable(&self) -> Self {
229        self.with_nullability(Nullability::Nullable)
230    }
231
232    /// Get a new DType with the given nullability (but otherwise the same as `self`)
233    pub fn with_nullability(&self, nullability: Nullability) -> Self {
234        match self {
235            Null => Null,
236            Bool(_) => Bool(nullability),
237            Primitive(p, _) => Primitive(*p, nullability),
238            Decimal(d, _) => Decimal(*d, nullability),
239            Utf8(_) => Utf8(nullability),
240            Binary(_) => Binary(nullability),
241            Struct(st, _) => Struct(st.clone(), nullability),
242            List(c, _) => List(c.clone(), nullability),
243            Extension(ext) => Extension(Arc::new(ext.with_nullability(nullability))),
244        }
245    }
246
247    /// Union the nullability of this dtype with the other nullability, returning a new dtype.
248    pub fn union_nullability(&self, other: Nullability) -> Self {
249        let nullability = self.nullability() | other;
250        self.with_nullability(nullability)
251    }
252
253    /// Check if `self` and `other` are equal, ignoring nullability
254    pub fn eq_ignore_nullability(&self, other: &Self) -> bool {
255        match (self, other) {
256            (Null, Null) => true,
257            (Bool(_), Bool(_)) => true,
258            (Primitive(lhs_ptype, _), Primitive(rhs_ptype, _)) => lhs_ptype == rhs_ptype,
259            (Decimal(lhs, _), Decimal(rhs, _)) => lhs == rhs,
260            (Utf8(_), Utf8(_)) => true,
261            (Binary(_), Binary(_)) => true,
262            (List(lhs_dtype, _), List(rhs_dtype, _)) => lhs_dtype.eq_ignore_nullability(rhs_dtype),
263            (Struct(lhs_dtype, _), Struct(rhs_dtype, _)) => {
264                (lhs_dtype.names() == rhs_dtype.names())
265                    && (lhs_dtype
266                        .fields()
267                        .zip_eq(rhs_dtype.fields())
268                        .all(|(l, r)| l.eq_ignore_nullability(&r)))
269            }
270            (Extension(lhs_extdtype), Extension(rhs_extdtype)) => {
271                lhs_extdtype.as_ref().eq_ignore_nullability(rhs_extdtype)
272            }
273            _ => false,
274        }
275    }
276
277    /// Check if `self` is a `StructDType`
278    pub fn is_struct(&self) -> bool {
279        matches!(self, Struct(_, _))
280    }
281
282    /// Check if `self` is a primitive tpye
283    pub fn is_primitive(&self) -> bool {
284        matches!(self, Primitive(_, _))
285    }
286
287    /// Returns this DType's `PType` if it is a primitive type, otherwise panics.
288    pub fn as_ptype(&self) -> PType {
289        match self {
290            Primitive(ptype, _) => *ptype,
291            _ => vortex_panic!("DType is not a primitive type"),
292        }
293    }
294
295    /// Check if `self` is an unsigned integer
296    pub fn is_unsigned_int(&self) -> bool {
297        if let Primitive(ptype, _) = self {
298            return ptype.is_unsigned_int();
299        }
300        false
301    }
302
303    /// Check if `self` is a signed integer
304    pub fn is_signed_int(&self) -> bool {
305        if let Primitive(ptype, _) = self {
306            return ptype.is_signed_int();
307        }
308        false
309    }
310
311    /// Check if `self` is an integer (signed or unsigned)
312    pub fn is_int(&self) -> bool {
313        if let Primitive(ptype, _) = self {
314            return ptype.is_int();
315        }
316        false
317    }
318
319    /// Check if `self` is a floating point number
320    pub fn is_float(&self) -> bool {
321        if let Primitive(ptype, _) = self {
322            return ptype.is_float();
323        }
324        false
325    }
326
327    /// Check if `self` is a boolean
328    pub fn is_boolean(&self) -> bool {
329        matches!(self, Bool(_))
330    }
331
332    /// Check if `self` is a binary
333    pub fn is_binary(&self) -> bool {
334        matches!(self, Binary(_))
335    }
336
337    /// Check if `self` is a utf8
338    pub fn is_utf8(&self) -> bool {
339        matches!(self, Utf8(_))
340    }
341
342    /// Check if `self` is an extension type
343    pub fn is_extension(&self) -> bool {
344        matches!(self, Extension(_))
345    }
346
347    /// Check if `self` is a decimal type
348    pub fn is_decimal(&self) -> bool {
349        matches!(self, Decimal(..))
350    }
351
352    /// Check returns the inner decimal type if the dtype is a decimal
353    pub fn as_decimal(&self) -> Option<&DecimalDType> {
354        match self {
355            Decimal(decimal, _) => Some(decimal),
356            _ => None,
357        }
358    }
359
360    /// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
361    pub fn as_struct(&self) -> Option<&StructFields> {
362        match self {
363            Struct(s, _) => Some(s),
364            _ => None,
365        }
366    }
367
368    /// Get the inner dtype if `self` is a `ListDType`, otherwise `None`
369    pub fn as_list_element(&self) -> Option<&Arc<DType>> {
370        match self {
371            List(s, _) => Some(s),
372            _ => None,
373        }
374    }
375
376    /// Convenience method for creating a struct dtype
377    pub fn struct_<I: IntoIterator<Item = (impl Into<FieldName>, impl Into<FieldDType>)>>(
378        iter: I,
379        nullability: Nullability,
380    ) -> Self {
381        Struct(StructFields::from_iter(iter), nullability)
382    }
383
384    /// Convenience method for creating a list dtype
385    pub fn list(dtype: impl Into<DType>, nullability: Nullability) -> Self {
386        List(Arc::new(dtype.into()), nullability)
387    }
388}
389
390impl Display for DType {
391    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
392        match self {
393            Null => write!(f, "null"),
394            Bool(n) => write!(f, "bool{n}"),
395            Primitive(pt, n) => write!(f, "{pt}{n}"),
396            Decimal(dt, n) => write!(f, "{dt}{n}"),
397            Utf8(n) => write!(f, "utf8{n}"),
398            Binary(n) => write!(f, "binary{n}"),
399            Struct(sdt, n) => write!(
400                f,
401                "{{{}}}{}",
402                sdt.names()
403                    .iter()
404                    .zip(sdt.fields())
405                    .map(|(n, dt)| format!("{n}={dt}"))
406                    .join(", "),
407                n
408            ),
409            List(edt, n) => write!(f, "list({edt}){n}"),
410            Extension(ext) => write!(
411                f,
412                "ext({}, {}{}){}",
413                ext.id(),
414                ext.storage_dtype()
415                    .with_nullability(Nullability::NonNullable),
416                ext.metadata()
417                    .map(|m| format!(", {m:?}"))
418                    .unwrap_or_else(|| "".to_string()),
419                ext.storage_dtype().nullability(),
420            ),
421        }
422    }
423}