Skip to main content

vortex_array/
variants.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! This module defines extension functionality specific to each Vortex DType.
5use std::cmp::Ordering;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::LEGACY_SESSION;
16use crate::VortexSessionExecute;
17use crate::aggregate_fn::fns::sum::sum;
18use crate::arrays::BoolArray;
19use crate::arrays::bool::BoolArrayExt;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::PType;
24use crate::dtype::extension::ExtDTypeRef;
25use crate::scalar::PValue;
26use crate::scalar::Scalar;
27use crate::search_sorted::IndexOrd;
28
29impl ArrayRef {
30    /// Downcasts the array for null-specific behavior.
31    pub fn as_null_typed(&self) -> NullTyped<'_> {
32        matches!(self.dtype(), DType::Null)
33            .then(|| NullTyped(self))
34            .vortex_expect("Array does not have DType::Null")
35    }
36
37    /// Downcasts the array for bool-specific behavior.
38    pub fn as_bool_typed(&self) -> BoolTyped<'_> {
39        matches!(self.dtype(), DType::Bool(..))
40            .then(|| BoolTyped(self))
41            .vortex_expect("Array does not have DType::Bool")
42    }
43
44    /// Downcasts the array for primitive-specific behavior.
45    pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
46        matches!(self.dtype(), DType::Primitive(..))
47            .then(|| PrimitiveTyped(self))
48            .vortex_expect("Array does not have DType::Primitive")
49    }
50
51    /// Downcasts the array for decimal-specific behavior.
52    pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
53        matches!(self.dtype(), DType::Decimal(..))
54            .then(|| DecimalTyped(self))
55            .vortex_expect("Array does not have DType::Decimal")
56    }
57
58    /// Downcasts the array for utf8-specific behavior.
59    pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
60        matches!(self.dtype(), DType::Utf8(..))
61            .then(|| Utf8Typed(self))
62            .vortex_expect("Array does not have DType::Utf8")
63    }
64
65    /// Downcasts the array for binary-specific behavior.
66    pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
67        matches!(self.dtype(), DType::Binary(..))
68            .then(|| BinaryTyped(self))
69            .vortex_expect("Array does not have DType::Binary")
70    }
71
72    /// Downcasts the array for struct-specific behavior.
73    pub fn as_struct_typed(&self) -> StructTyped<'_> {
74        matches!(self.dtype(), DType::Struct(..))
75            .then(|| StructTyped(self))
76            .vortex_expect("Array does not have DType::Struct")
77    }
78
79    /// Downcasts the array for list-specific behavior.
80    pub fn as_list_typed(&self) -> ListTyped<'_> {
81        matches!(self.dtype(), DType::List(..))
82            .then(|| ListTyped(self))
83            .vortex_expect("Array does not have DType::List")
84    }
85
86    /// Downcasts the array for extension-specific behavior.
87    pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
88        matches!(self.dtype(), DType::Extension(..))
89            .then(|| ExtensionTyped(self))
90            .vortex_expect("Array does not have DType::Extension")
91    }
92
93    pub fn try_to_mask_fill_null_false(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
94        if !matches!(self.dtype(), DType::Bool(_)) {
95            vortex_bail!("mask must be bool array, has dtype {}", self.dtype());
96        }
97
98        // Convert nulls to false first in case this can be done cheaply by the encoding.
99        let array = self
100            .clone()
101            .fill_null(Scalar::bool(false, self.dtype().nullability()))?;
102
103        Ok(array
104            .execute::<BoolArray>(ctx)?
105            .to_mask_fill_null_false(ctx))
106    }
107}
108
109#[expect(dead_code)]
110pub struct NullTyped<'a>(&'a ArrayRef);
111
112pub struct BoolTyped<'a>(&'a ArrayRef);
113
114impl BoolTyped<'_> {
115    #[deprecated(
116        note = "Relies on the hidden global `LEGACY_SESSION`; use `sum(array, ctx)` with an explicit `ExecutionCtx` instead"
117    )]
118    pub fn true_count(&self) -> VortexResult<usize> {
119        let mut ctx = LEGACY_SESSION.create_execution_ctx();
120        let true_count = sum(self.0, &mut ctx)?;
121        Ok(true_count
122            .as_primitive()
123            .as_::<usize>()
124            .vortex_expect("true count should never be null"))
125    }
126}
127
128pub struct PrimitiveTyped<'a>(&'a ArrayRef);
129
130impl PrimitiveTyped<'_> {
131    pub fn ptype(&self) -> PType {
132        let DType::Primitive(ptype, _) = self.0.dtype() else {
133            vortex_panic!("Expected Primitive DType")
134        };
135        *ptype
136    }
137
138    /// Return the primitive value at the given index.
139    #[deprecated(
140        note = "Relies on the hidden global `LEGACY_SESSION`; use `is_valid`/`execute_scalar` with an explicit `ExecutionCtx` instead"
141    )]
142    #[allow(deprecated)]
143    pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
144        self.0
145            .is_valid(idx, &mut LEGACY_SESSION.create_execution_ctx())?
146            .then(|| self.value_unchecked(idx))
147            .transpose()
148    }
149
150    /// Return the primitive value at the given index, ignoring nullability.
151    #[deprecated(
152        note = "Relies on the hidden global `LEGACY_SESSION`; use `execute_scalar` with an explicit `ExecutionCtx` instead"
153    )]
154    pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
155        Ok(self
156            .0
157            .execute_scalar(idx, &mut LEGACY_SESSION.create_execution_ctx())?
158            .as_primitive()
159            .pvalue()
160            .unwrap_or_else(|| PValue::zero(&self.ptype())))
161    }
162}
163
164impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
165    #[allow(deprecated)]
166    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
167        let value = self.value(idx)?;
168        Ok(value.partial_cmp(elem))
169    }
170
171    fn index_len(&self) -> usize {
172        self.0.len()
173    }
174}
175
176// TODO(ngates): add generics to the `value` function and implement this over T.
177impl IndexOrd<PValue> for PrimitiveTyped<'_> {
178    #[allow(deprecated)]
179    fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
180        assert!(
181            self.0
182                .all_valid(&mut LEGACY_SESSION.create_execution_ctx())?
183        );
184        let value = self.value_unchecked(idx)?;
185        Ok(value.partial_cmp(elem))
186    }
187
188    fn index_len(&self) -> usize {
189        self.0.len()
190    }
191}
192
193#[expect(dead_code)]
194pub struct Utf8Typed<'a>(&'a ArrayRef);
195
196#[expect(dead_code)]
197pub struct BinaryTyped<'a>(&'a ArrayRef);
198
199#[expect(dead_code)]
200pub struct DecimalTyped<'a>(&'a ArrayRef);
201
202pub struct StructTyped<'a>(&'a ArrayRef);
203
204impl StructTyped<'_> {
205    pub fn names(&self) -> &FieldNames {
206        let DType::Struct(st, _) = self.0.dtype() else {
207            unreachable!()
208        };
209        st.names()
210    }
211
212    pub fn dtypes(&self) -> Vec<DType> {
213        let DType::Struct(st, _) = self.0.dtype() else {
214            unreachable!()
215        };
216        st.fields().collect()
217    }
218
219    pub fn nfields(&self) -> usize {
220        self.names().len()
221    }
222}
223
224#[expect(dead_code)]
225pub struct ListTyped<'a>(&'a ArrayRef);
226
227pub struct ExtensionTyped<'a>(&'a ArrayRef);
228
229impl ExtensionTyped<'_> {
230    /// Returns the extension logical [`DType`].
231    pub fn ext_dtype(&self) -> &ExtDTypeRef {
232        let DType::Extension(ext_dtype) = self.0.dtype() else {
233            vortex_panic!("Expected ExtDType")
234        };
235        ext_dtype
236    }
237}