vortex_array/
variants.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! This module defines extension functionality specific to each Vortex DType.
5use std::cmp::Ordering;
6use std::sync::Arc;
7
8use vortex_dtype::{DType, ExtDType, FieldNames, PType};
9use vortex_error::{VortexExpect, VortexResult, vortex_panic};
10use vortex_scalar::PValue;
11
12use crate::Array;
13use crate::compute::sum;
14use crate::search_sorted::IndexOrd;
15
16impl dyn Array + '_ {
17    /// Downcasts the array for null-specific behavior.
18    pub fn as_null_typed(&self) -> NullTyped<'_> {
19        matches!(self.dtype(), DType::Null)
20            .then(|| NullTyped(self))
21            .vortex_expect("Array does not have DType::Null")
22    }
23
24    /// Downcasts the array for bool-specific behavior.
25    pub fn as_bool_typed(&self) -> BoolTyped<'_> {
26        matches!(self.dtype(), DType::Bool(..))
27            .then(|| BoolTyped(self))
28            .vortex_expect("Array does not have DType::Bool")
29    }
30
31    /// Downcasts the array for primitive-specific behavior.
32    pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
33        matches!(self.dtype(), DType::Primitive(..))
34            .then(|| PrimitiveTyped(self))
35            .vortex_expect("Array does not have DType::Primitive")
36    }
37
38    /// Downcasts the array for decimal-specific behavior.
39    pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
40        matches!(self.dtype(), DType::Decimal(..))
41            .then(|| DecimalTyped(self))
42            .vortex_expect("Array does not have DType::Decimal")
43    }
44
45    /// Downcasts the array for utf8-specific behavior.
46    pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
47        matches!(self.dtype(), DType::Utf8(..))
48            .then(|| Utf8Typed(self))
49            .vortex_expect("Array does not have DType::Utf8")
50    }
51
52    /// Downcasts the array for binary-specific behavior.
53    pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
54        matches!(self.dtype(), DType::Binary(..))
55            .then(|| BinaryTyped(self))
56            .vortex_expect("Array does not have DType::Binary")
57    }
58
59    /// Downcasts the array for struct-specific behavior.
60    pub fn as_struct_typed(&self) -> StructTyped<'_> {
61        matches!(self.dtype(), DType::Struct(..))
62            .then(|| StructTyped(self))
63            .vortex_expect("Array does not have DType::Struct")
64    }
65
66    /// Downcasts the array for list-specific behavior.
67    pub fn as_list_typed(&self) -> ListTyped<'_> {
68        matches!(self.dtype(), DType::List(..))
69            .then(|| ListTyped(self))
70            .vortex_expect("Array does not have DType::List")
71    }
72
73    /// Downcasts the array for extension-specific behavior.
74    pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
75        matches!(self.dtype(), DType::Extension(..))
76            .then(|| ExtensionTyped(self))
77            .vortex_expect("Array does not have DType::Extension")
78    }
79}
80
/// Borrowed view over an array whose dtype is `DType::Null`.
///
/// Instances are produced by `as_null_typed`, which checks the dtype first.
// The wrapped array is not read by any method visible in this file, hence the
// `dead_code` allowance on the otherwise-unused field.
#[allow(dead_code)]
pub struct NullTyped<'a>(&'a dyn Array);
83
84pub struct BoolTyped<'a>(&'a dyn Array);
85
86impl BoolTyped<'_> {
87    pub fn true_count(&self) -> VortexResult<usize> {
88        let true_count = sum(self.0)?;
89        Ok(true_count
90            .as_primitive()
91            .as_::<usize>()
92            .vortex_expect("true count should never be null"))
93    }
94}
95
96pub struct PrimitiveTyped<'a>(&'a dyn Array);
97
98impl PrimitiveTyped<'_> {
99    pub fn ptype(&self) -> PType {
100        let DType::Primitive(ptype, _) = self.0.dtype() else {
101            vortex_panic!("Expected Primitive DType")
102        };
103        *ptype
104    }
105
106    /// Return the primitive value at the given index.
107    pub fn value(&self, idx: usize) -> Option<PValue> {
108        self.0
109            .is_valid(idx)
110            .vortex_expect("is valid")
111            .then(|| self.value_unchecked(idx))
112    }
113
114    /// Return the primitive value at the given index, ignoring nullability.
115    pub fn value_unchecked(&self, idx: usize) -> PValue {
116        self.0
117            .scalar_at(idx)
118            .as_primitive()
119            .pvalue()
120            .unwrap_or_else(|| PValue::zero(self.ptype()))
121    }
122}
123
124impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
125    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> Option<Ordering> {
126        self.value(idx).partial_cmp(elem)
127    }
128
129    fn index_len(&self) -> usize {
130        self.0.len()
131    }
132}
133
134// TODO(ngates): add generics to the `value` function and implement this over T.
135impl IndexOrd<PValue> for PrimitiveTyped<'_> {
136    fn index_cmp(&self, idx: usize, elem: &PValue) -> Option<Ordering> {
137        assert!(self.0.all_valid().vortex_expect("all valid"));
138        self.value_unchecked(idx).partial_cmp(elem)
139    }
140
141    fn index_len(&self) -> usize {
142        self.0.len()
143    }
144}
145
/// Borrowed view over an array whose dtype is `DType::Utf8`.
///
/// Instances are produced by `as_utf8_typed`, which checks the dtype first.
// The wrapped array is not read by any method visible in this file, hence the
// `dead_code` allowance on the otherwise-unused field.
#[allow(dead_code)]
pub struct Utf8Typed<'a>(&'a dyn Array);
148
/// Borrowed view over an array whose dtype is `DType::Binary`.
///
/// Instances are produced by `as_binary_typed`, which checks the dtype first.
// The wrapped array is not read by any method visible in this file, hence the
// `dead_code` allowance on the otherwise-unused field.
#[allow(dead_code)]
pub struct BinaryTyped<'a>(&'a dyn Array);
151
/// Borrowed view over an array whose dtype is `DType::Decimal`.
///
/// Instances are produced by `as_decimal_typed`, which checks the dtype first.
// The wrapped array is not read by any method visible in this file, hence the
// `dead_code` allowance on the otherwise-unused field.
#[allow(dead_code)]
pub struct DecimalTyped<'a>(&'a dyn Array);
154
155pub struct StructTyped<'a>(&'a dyn Array);
156
157impl StructTyped<'_> {
158    pub fn names(&self) -> &FieldNames {
159        let DType::Struct(st, _) = self.0.dtype() else {
160            unreachable!()
161        };
162        st.names()
163    }
164
165    pub fn dtypes(&self) -> Vec<DType> {
166        let DType::Struct(st, _) = self.0.dtype() else {
167            unreachable!()
168        };
169        st.fields().collect()
170    }
171
172    pub fn nfields(&self) -> usize {
173        self.names().len()
174    }
175}
176
/// Borrowed view over an array whose dtype is `DType::List`.
///
/// Instances are produced by `as_list_typed`, which checks the dtype first.
// The wrapped array is not read by any method visible in this file, hence the
// `dead_code` allowance on the otherwise-unused field.
#[allow(dead_code)]
pub struct ListTyped<'a>(&'a dyn Array);
179
180pub struct ExtensionTyped<'a>(&'a dyn Array);
181
182impl ExtensionTyped<'_> {
183    /// Returns the extension logical [`DType`].
184    pub fn ext_dtype(&self) -> &Arc<ExtDType> {
185        let DType::Extension(ext_dtype) = self.0.dtype() else {
186            vortex_panic!("Expected ExtDType")
187        };
188        ext_dtype
189    }
190}