Skip to main content

vortex_array/
variants.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! This module defines extension functionality specific to each Vortex DType.
5use std::cmp::Ordering;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::DynArray;
14use crate::ExecutionCtx;
15use crate::LEGACY_SESSION;
16use crate::VortexSessionExecute;
17use crate::aggregate_fn::fns::sum::sum;
18use crate::arrays::BoolArray;
19use crate::builtins::ArrayBuiltins;
20use crate::dtype::DType;
21use crate::dtype::FieldNames;
22use crate::dtype::PType;
23use crate::dtype::extension::ExtDTypeRef;
24use crate::scalar::PValue;
25use crate::scalar::Scalar;
26use crate::search_sorted::IndexOrd;
27
28impl dyn DynArray + '_ {
29    /// Downcasts the array for null-specific behavior.
30    pub fn as_null_typed(&self) -> NullTyped<'_> {
31        matches!(self.dtype(), DType::Null)
32            .then(|| NullTyped(self))
33            .vortex_expect("Array does not have DType::Null")
34    }
35
36    /// Downcasts the array for bool-specific behavior.
37    pub fn as_bool_typed(&self) -> BoolTyped<'_> {
38        matches!(self.dtype(), DType::Bool(..))
39            .then(|| BoolTyped(self))
40            .vortex_expect("Array does not have DType::Bool")
41    }
42
43    /// Downcasts the array for primitive-specific behavior.
44    pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
45        matches!(self.dtype(), DType::Primitive(..))
46            .then(|| PrimitiveTyped(self))
47            .vortex_expect("Array does not have DType::Primitive")
48    }
49
50    /// Downcasts the array for decimal-specific behavior.
51    pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
52        matches!(self.dtype(), DType::Decimal(..))
53            .then(|| DecimalTyped(self))
54            .vortex_expect("Array does not have DType::Decimal")
55    }
56
57    /// Downcasts the array for utf8-specific behavior.
58    pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
59        matches!(self.dtype(), DType::Utf8(..))
60            .then(|| Utf8Typed(self))
61            .vortex_expect("Array does not have DType::Utf8")
62    }
63
64    /// Downcasts the array for binary-specific behavior.
65    pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
66        matches!(self.dtype(), DType::Binary(..))
67            .then(|| BinaryTyped(self))
68            .vortex_expect("Array does not have DType::Binary")
69    }
70
71    /// Downcasts the array for struct-specific behavior.
72    pub fn as_struct_typed(&self) -> StructTyped<'_> {
73        matches!(self.dtype(), DType::Struct(..))
74            .then(|| StructTyped(self))
75            .vortex_expect("Array does not have DType::Struct")
76    }
77
78    /// Downcasts the array for list-specific behavior.
79    pub fn as_list_typed(&self) -> ListTyped<'_> {
80        matches!(self.dtype(), DType::List(..))
81            .then(|| ListTyped(self))
82            .vortex_expect("Array does not have DType::List")
83    }
84
85    /// Downcasts the array for extension-specific behavior.
86    pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
87        matches!(self.dtype(), DType::Extension(..))
88            .then(|| ExtensionTyped(self))
89            .vortex_expect("Array does not have DType::Extension")
90    }
91
92    pub fn try_to_mask_fill_null_false(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
93        if !matches!(self.dtype(), DType::Bool(_)) {
94            vortex_bail!("mask must be bool array, has dtype {}", self.dtype());
95        }
96
97        // Convert nulls to false first in case this can be done cheaply by the encoding.
98        let array = self
99            .to_array()
100            .fill_null(Scalar::bool(false, self.dtype().nullability()))?;
101
102        Ok(array.execute::<BoolArray>(ctx)?.to_mask_fill_null_false())
103    }
104}
105
106#[expect(dead_code)]
107pub struct NullTyped<'a>(&'a dyn DynArray);
108
109pub struct BoolTyped<'a>(&'a dyn DynArray);
110
111impl BoolTyped<'_> {
112    pub fn true_count(&self) -> VortexResult<usize> {
113        let mut ctx = LEGACY_SESSION.create_execution_ctx();
114        let true_count = sum(&self.0.to_array(), &mut ctx)?;
115        Ok(true_count
116            .as_primitive()
117            .as_::<usize>()
118            .vortex_expect("true count should never be null"))
119    }
120}
121
122pub struct PrimitiveTyped<'a>(&'a dyn DynArray);
123
124impl PrimitiveTyped<'_> {
125    pub fn ptype(&self) -> PType {
126        let DType::Primitive(ptype, _) = self.0.dtype() else {
127            vortex_panic!("Expected Primitive DType")
128        };
129        *ptype
130    }
131
132    /// Return the primitive value at the given index.
133    pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
134        self.0
135            .is_valid(idx)?
136            .then(|| self.value_unchecked(idx))
137            .transpose()
138    }
139
140    /// Return the primitive value at the given index, ignoring nullability.
141    pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
142        Ok(self
143            .0
144            .scalar_at(idx)?
145            .as_primitive()
146            .pvalue()
147            .unwrap_or_else(|| PValue::zero(&self.ptype())))
148    }
149}
150
151impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
152    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
153        let value = self.value(idx)?;
154        Ok(value.partial_cmp(elem))
155    }
156
157    fn index_len(&self) -> usize {
158        self.0.len()
159    }
160}
161
162// TODO(ngates): add generics to the `value` function and implement this over T.
163impl IndexOrd<PValue> for PrimitiveTyped<'_> {
164    fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
165        assert!(self.0.all_valid()?);
166        let value = self.value_unchecked(idx)?;
167        Ok(value.partial_cmp(elem))
168    }
169
170    fn index_len(&self) -> usize {
171        self.0.len()
172    }
173}
174
175#[expect(dead_code)]
176pub struct Utf8Typed<'a>(&'a dyn DynArray);
177
178#[expect(dead_code)]
179pub struct BinaryTyped<'a>(&'a dyn DynArray);
180
181#[expect(dead_code)]
182pub struct DecimalTyped<'a>(&'a dyn DynArray);
183
184pub struct StructTyped<'a>(&'a dyn DynArray);
185
186impl StructTyped<'_> {
187    pub fn names(&self) -> &FieldNames {
188        let DType::Struct(st, _) = self.0.dtype() else {
189            unreachable!()
190        };
191        st.names()
192    }
193
194    pub fn dtypes(&self) -> Vec<DType> {
195        let DType::Struct(st, _) = self.0.dtype() else {
196            unreachable!()
197        };
198        st.fields().collect()
199    }
200
201    pub fn nfields(&self) -> usize {
202        self.names().len()
203    }
204}
205
206#[expect(dead_code)]
207pub struct ListTyped<'a>(&'a dyn DynArray);
208
209pub struct ExtensionTyped<'a>(&'a dyn DynArray);
210
211impl ExtensionTyped<'_> {
212    /// Returns the extension logical [`DType`].
213    pub fn ext_dtype(&self) -> &ExtDTypeRef {
214        let DType::Extension(ext_dtype) = self.0.dtype() else {
215            vortex_panic!("Expected ExtDType")
216        };
217        ext_dtype
218    }
219}