Skip to main content

vortex_array/
variants.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! This module defines extension functionality specific to each Vortex DType.
5use std::cmp::Ordering;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::Array;
14use crate::ExecutionCtx;
15use crate::arrays::BoolArray;
16use crate::builtins::ArrayBuiltins;
17use crate::compute::sum;
18use crate::dtype::DType;
19use crate::dtype::FieldNames;
20use crate::dtype::PType;
21use crate::dtype::extension::ExtDTypeRef;
22use crate::scalar::PValue;
23use crate::scalar::Scalar;
24use crate::search_sorted::IndexOrd;
25
26impl dyn Array + '_ {
27    /// Downcasts the array for null-specific behavior.
28    pub fn as_null_typed(&self) -> NullTyped<'_> {
29        matches!(self.dtype(), DType::Null)
30            .then(|| NullTyped(self))
31            .vortex_expect("Array does not have DType::Null")
32    }
33
34    /// Downcasts the array for bool-specific behavior.
35    pub fn as_bool_typed(&self) -> BoolTyped<'_> {
36        matches!(self.dtype(), DType::Bool(..))
37            .then(|| BoolTyped(self))
38            .vortex_expect("Array does not have DType::Bool")
39    }
40
41    /// Downcasts the array for primitive-specific behavior.
42    pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
43        matches!(self.dtype(), DType::Primitive(..))
44            .then(|| PrimitiveTyped(self))
45            .vortex_expect("Array does not have DType::Primitive")
46    }
47
48    /// Downcasts the array for decimal-specific behavior.
49    pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
50        matches!(self.dtype(), DType::Decimal(..))
51            .then(|| DecimalTyped(self))
52            .vortex_expect("Array does not have DType::Decimal")
53    }
54
55    /// Downcasts the array for utf8-specific behavior.
56    pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
57        matches!(self.dtype(), DType::Utf8(..))
58            .then(|| Utf8Typed(self))
59            .vortex_expect("Array does not have DType::Utf8")
60    }
61
62    /// Downcasts the array for binary-specific behavior.
63    pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
64        matches!(self.dtype(), DType::Binary(..))
65            .then(|| BinaryTyped(self))
66            .vortex_expect("Array does not have DType::Binary")
67    }
68
69    /// Downcasts the array for struct-specific behavior.
70    pub fn as_struct_typed(&self) -> StructTyped<'_> {
71        matches!(self.dtype(), DType::Struct(..))
72            .then(|| StructTyped(self))
73            .vortex_expect("Array does not have DType::Struct")
74    }
75
76    /// Downcasts the array for list-specific behavior.
77    pub fn as_list_typed(&self) -> ListTyped<'_> {
78        matches!(self.dtype(), DType::List(..))
79            .then(|| ListTyped(self))
80            .vortex_expect("Array does not have DType::List")
81    }
82
83    /// Downcasts the array for extension-specific behavior.
84    pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
85        matches!(self.dtype(), DType::Extension(..))
86            .then(|| ExtensionTyped(self))
87            .vortex_expect("Array does not have DType::Extension")
88    }
89
90    pub fn try_to_mask_fill_null_false(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
91        if !matches!(self.dtype(), DType::Bool(_)) {
92            vortex_bail!("mask must be bool array, has dtype {}", self.dtype());
93        }
94
95        // Convert nulls to false first in case this can be done cheaply by the encoding.
96        let array = self
97            .to_array()
98            .fill_null(Scalar::bool(false, self.dtype().nullability()))?;
99
100        Ok(array.execute::<BoolArray>(ctx)?.to_mask_fill_null_false())
101    }
102}
103
104#[expect(dead_code)]
105pub struct NullTyped<'a>(&'a dyn Array);
106
107pub struct BoolTyped<'a>(&'a dyn Array);
108
109impl BoolTyped<'_> {
110    pub fn true_count(&self) -> VortexResult<usize> {
111        let true_count = sum(&self.0.to_array())?;
112        Ok(true_count
113            .as_primitive()
114            .as_::<usize>()
115            .vortex_expect("true count should never be null"))
116    }
117}
118
119pub struct PrimitiveTyped<'a>(&'a dyn Array);
120
121impl PrimitiveTyped<'_> {
122    pub fn ptype(&self) -> PType {
123        let DType::Primitive(ptype, _) = self.0.dtype() else {
124            vortex_panic!("Expected Primitive DType")
125        };
126        *ptype
127    }
128
129    /// Return the primitive value at the given index.
130    pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
131        self.0
132            .is_valid(idx)?
133            .then(|| self.value_unchecked(idx))
134            .transpose()
135    }
136
137    /// Return the primitive value at the given index, ignoring nullability.
138    pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
139        Ok(self
140            .0
141            .scalar_at(idx)?
142            .as_primitive()
143            .pvalue()
144            .unwrap_or_else(|| PValue::zero(&self.ptype())))
145    }
146}
147
148impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
149    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
150        let value = self.value(idx)?;
151        Ok(value.partial_cmp(elem))
152    }
153
154    fn index_len(&self) -> usize {
155        self.0.len()
156    }
157}
158
159// TODO(ngates): add generics to the `value` function and implement this over T.
160impl IndexOrd<PValue> for PrimitiveTyped<'_> {
161    fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
162        assert!(self.0.all_valid()?);
163        let value = self.value_unchecked(idx)?;
164        Ok(value.partial_cmp(elem))
165    }
166
167    fn index_len(&self) -> usize {
168        self.0.len()
169    }
170}
171
172#[expect(dead_code)]
173pub struct Utf8Typed<'a>(&'a dyn Array);
174
175#[expect(dead_code)]
176pub struct BinaryTyped<'a>(&'a dyn Array);
177
178#[expect(dead_code)]
179pub struct DecimalTyped<'a>(&'a dyn Array);
180
181pub struct StructTyped<'a>(&'a dyn Array);
182
183impl StructTyped<'_> {
184    pub fn names(&self) -> &FieldNames {
185        let DType::Struct(st, _) = self.0.dtype() else {
186            unreachable!()
187        };
188        st.names()
189    }
190
191    pub fn dtypes(&self) -> Vec<DType> {
192        let DType::Struct(st, _) = self.0.dtype() else {
193            unreachable!()
194        };
195        st.fields().collect()
196    }
197
198    pub fn nfields(&self) -> usize {
199        self.names().len()
200    }
201}
202
203#[expect(dead_code)]
204pub struct ListTyped<'a>(&'a dyn Array);
205
206pub struct ExtensionTyped<'a>(&'a dyn Array);
207
208impl ExtensionTyped<'_> {
209    /// Returns the extension logical [`DType`].
210    pub fn ext_dtype(&self) -> &ExtDTypeRef {
211        let DType::Extension(ext_dtype) = self.0.dtype() else {
212            vortex_panic!("Expected ExtDType")
213        };
214        ext_dtype
215    }
216}