Skip to main content

vortex_array/
variants.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! This module defines extension functionality specific to each Vortex DType.
5use std::cmp::Ordering;
6
7use vortex_dtype::DType;
8use vortex_dtype::FieldNames;
9use vortex_dtype::PType;
10use vortex_dtype::extension::ExtDTypeRef;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_panic;
14
15use crate::Array;
16use crate::compute::sum;
17use crate::scalar::PValue;
18use crate::search_sorted::IndexOrd;
19
20impl dyn Array + '_ {
21    /// Downcasts the array for null-specific behavior.
22    pub fn as_null_typed(&self) -> NullTyped<'_> {
23        matches!(self.dtype(), DType::Null)
24            .then(|| NullTyped(self))
25            .vortex_expect("Array does not have DType::Null")
26    }
27
28    /// Downcasts the array for bool-specific behavior.
29    pub fn as_bool_typed(&self) -> BoolTyped<'_> {
30        matches!(self.dtype(), DType::Bool(..))
31            .then(|| BoolTyped(self))
32            .vortex_expect("Array does not have DType::Bool")
33    }
34
35    /// Downcasts the array for primitive-specific behavior.
36    pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
37        matches!(self.dtype(), DType::Primitive(..))
38            .then(|| PrimitiveTyped(self))
39            .vortex_expect("Array does not have DType::Primitive")
40    }
41
42    /// Downcasts the array for decimal-specific behavior.
43    pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
44        matches!(self.dtype(), DType::Decimal(..))
45            .then(|| DecimalTyped(self))
46            .vortex_expect("Array does not have DType::Decimal")
47    }
48
49    /// Downcasts the array for utf8-specific behavior.
50    pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
51        matches!(self.dtype(), DType::Utf8(..))
52            .then(|| Utf8Typed(self))
53            .vortex_expect("Array does not have DType::Utf8")
54    }
55
56    /// Downcasts the array for binary-specific behavior.
57    pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
58        matches!(self.dtype(), DType::Binary(..))
59            .then(|| BinaryTyped(self))
60            .vortex_expect("Array does not have DType::Binary")
61    }
62
63    /// Downcasts the array for struct-specific behavior.
64    pub fn as_struct_typed(&self) -> StructTyped<'_> {
65        matches!(self.dtype(), DType::Struct(..))
66            .then(|| StructTyped(self))
67            .vortex_expect("Array does not have DType::Struct")
68    }
69
70    /// Downcasts the array for list-specific behavior.
71    pub fn as_list_typed(&self) -> ListTyped<'_> {
72        matches!(self.dtype(), DType::List(..))
73            .then(|| ListTyped(self))
74            .vortex_expect("Array does not have DType::List")
75    }
76
77    /// Downcasts the array for extension-specific behavior.
78    pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
79        matches!(self.dtype(), DType::Extension(..))
80            .then(|| ExtensionTyped(self))
81            .vortex_expect("Array does not have DType::Extension")
82    }
83}
84
85#[expect(dead_code)]
86pub struct NullTyped<'a>(&'a dyn Array);
87
88pub struct BoolTyped<'a>(&'a dyn Array);
89
90impl BoolTyped<'_> {
91    pub fn true_count(&self) -> VortexResult<usize> {
92        let true_count = sum(self.0)?;
93        Ok(true_count
94            .as_primitive()
95            .as_::<usize>()
96            .vortex_expect("true count should never be null"))
97    }
98}
99
100pub struct PrimitiveTyped<'a>(&'a dyn Array);
101
102impl PrimitiveTyped<'_> {
103    pub fn ptype(&self) -> PType {
104        let DType::Primitive(ptype, _) = self.0.dtype() else {
105            vortex_panic!("Expected Primitive DType")
106        };
107        *ptype
108    }
109
110    /// Return the primitive value at the given index.
111    pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
112        self.0
113            .is_valid(idx)?
114            .then(|| self.value_unchecked(idx))
115            .transpose()
116    }
117
118    /// Return the primitive value at the given index, ignoring nullability.
119    pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
120        Ok(self
121            .0
122            .scalar_at(idx)?
123            .as_primitive()
124            .pvalue()
125            .unwrap_or_else(|| PValue::zero(&self.ptype())))
126    }
127}
128
129impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
130    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
131        let value = self.value(idx)?;
132        Ok(value.partial_cmp(elem))
133    }
134
135    fn index_len(&self) -> usize {
136        self.0.len()
137    }
138}
139
140// TODO(ngates): add generics to the `value` function and implement this over T.
141impl IndexOrd<PValue> for PrimitiveTyped<'_> {
142    fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
143        assert!(self.0.all_valid()?);
144        let value = self.value_unchecked(idx)?;
145        Ok(value.partial_cmp(elem))
146    }
147
148    fn index_len(&self) -> usize {
149        self.0.len()
150    }
151}
152
153#[expect(dead_code)]
154pub struct Utf8Typed<'a>(&'a dyn Array);
155
156#[expect(dead_code)]
157pub struct BinaryTyped<'a>(&'a dyn Array);
158
159#[expect(dead_code)]
160pub struct DecimalTyped<'a>(&'a dyn Array);
161
162pub struct StructTyped<'a>(&'a dyn Array);
163
164impl StructTyped<'_> {
165    pub fn names(&self) -> &FieldNames {
166        let DType::Struct(st, _) = self.0.dtype() else {
167            unreachable!()
168        };
169        st.names()
170    }
171
172    pub fn dtypes(&self) -> Vec<DType> {
173        let DType::Struct(st, _) = self.0.dtype() else {
174            unreachable!()
175        };
176        st.fields().collect()
177    }
178
179    pub fn nfields(&self) -> usize {
180        self.names().len()
181    }
182}
183
184#[expect(dead_code)]
185pub struct ListTyped<'a>(&'a dyn Array);
186
187pub struct ExtensionTyped<'a>(&'a dyn Array);
188
189impl ExtensionTyped<'_> {
190    /// Returns the extension logical [`DType`].
191    pub fn ext_dtype(&self) -> &ExtDTypeRef {
192        let DType::Extension(ext_dtype) = self.0.dtype() else {
193            vortex_panic!("Expected ExtDType")
194        };
195        ext_dtype
196    }
197}