vortex_array/
variants.rs

1//! This module defines extension functionality specific to each Vortex DType.
2use std::cmp::Ordering;
3use std::sync::Arc;
4
5use vortex_dtype::{DType, ExtDType, FieldNames, PType};
6use vortex_error::{VortexExpect, VortexResult, vortex_panic};
7use vortex_scalar::PValue;
8
9use crate::Array;
10use crate::compute::sum;
11use crate::search_sorted::IndexOrd;
12
13impl dyn Array + '_ {
14    /// Downcasts the array for null-specific behavior.
15    pub fn as_null_typed(&self) -> NullTyped {
16        matches!(self.dtype(), DType::Null)
17            .then(|| NullTyped(self))
18            .vortex_expect("Array does not have DType::Null")
19    }
20
21    /// Downcasts the array for bool-specific behavior.
22    pub fn as_bool_typed(&self) -> BoolTyped {
23        matches!(self.dtype(), DType::Bool(..))
24            .then(|| BoolTyped(self))
25            .vortex_expect("Array does not have DType::Bool")
26    }
27
28    /// Downcasts the array for primitive-specific behavior.
29    pub fn as_primitive_typed(&self) -> PrimitiveTyped {
30        matches!(self.dtype(), DType::Primitive(..))
31            .then(|| PrimitiveTyped(self))
32            .vortex_expect("Array does not have DType::Primitive")
33    }
34
35    /// Downcasts the array for decimal-specific behavior.
36    pub fn as_decimal_typed(&self) -> DecimalTyped {
37        matches!(self.dtype(), DType::Decimal(..))
38            .then(|| DecimalTyped(self))
39            .vortex_expect("Array does not have DType::Decimal")
40    }
41
42    /// Downcasts the array for utf8-specific behavior.
43    pub fn as_utf8_typed(&self) -> Utf8Typed {
44        matches!(self.dtype(), DType::Utf8(..))
45            .then(|| Utf8Typed(self))
46            .vortex_expect("Array does not have DType::Utf8")
47    }
48
49    /// Downcasts the array for binary-specific behavior.
50    pub fn as_binary_typed(&self) -> BinaryTyped {
51        matches!(self.dtype(), DType::Binary(..))
52            .then(|| BinaryTyped(self))
53            .vortex_expect("Array does not have DType::Binary")
54    }
55
56    /// Downcasts the array for struct-specific behavior.
57    pub fn as_struct_typed(&self) -> StructTyped {
58        matches!(self.dtype(), DType::Struct(..))
59            .then(|| StructTyped(self))
60            .vortex_expect("Array does not have DType::Struct")
61    }
62
63    /// Downcasts the array for list-specific behavior.
64    pub fn as_list_typed(&self) -> ListTyped {
65        matches!(self.dtype(), DType::List(..))
66            .then(|| ListTyped(self))
67            .vortex_expect("Array does not have DType::List")
68    }
69
70    /// Downcasts the array for extension-specific behavior.
71    pub fn as_extension_typed(&self) -> ExtensionTyped {
72        matches!(self.dtype(), DType::Extension(..))
73            .then(|| ExtensionTyped(self))
74            .vortex_expect("Array does not have DType::Extension")
75    }
76}
77
78#[allow(dead_code)]
79pub struct NullTyped<'a>(&'a dyn Array);
80
81pub struct BoolTyped<'a>(&'a dyn Array);
82
83impl BoolTyped<'_> {
84    pub fn true_count(&self) -> VortexResult<usize> {
85        let true_count = sum(self.0)?;
86        Ok(true_count
87            .as_primitive()
88            .as_::<usize>()
89            .vortex_expect("true count should never overflow usize")
90            .vortex_expect("true count should never be null"))
91    }
92}
93
94pub struct PrimitiveTyped<'a>(&'a dyn Array);
95
96impl PrimitiveTyped<'_> {
97    pub fn ptype(&self) -> PType {
98        let DType::Primitive(ptype, _) = self.0.dtype() else {
99            vortex_panic!("Expected Primitive DType")
100        };
101        *ptype
102    }
103
104    /// Return the primitive value at the given index.
105    pub fn value(&self, idx: usize) -> Option<PValue> {
106        self.0
107            .is_valid(idx)
108            .vortex_expect("is valid")
109            .then(|| self.value_unchecked(idx))
110    }
111
112    /// Return the primitive value at the given index, ignoring nullability.
113    pub fn value_unchecked(&self, idx: usize) -> PValue {
114        self.0
115            .scalar_at(idx)
116            .vortex_expect("scalar at index")
117            .as_primitive()
118            .pvalue()
119            .unwrap_or_else(|| PValue::zero(self.ptype()))
120    }
121}
122
123impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
124    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> Option<Ordering> {
125        self.value(idx).partial_cmp(elem)
126    }
127
128    fn index_len(&self) -> usize {
129        self.0.len()
130    }
131}
132
133// TODO(ngates): add generics to the `value` function and implement this over T.
134impl IndexOrd<PValue> for PrimitiveTyped<'_> {
135    fn index_cmp(&self, idx: usize, elem: &PValue) -> Option<Ordering> {
136        assert!(self.0.all_valid().vortex_expect("all valid"));
137        self.value_unchecked(idx).partial_cmp(elem)
138    }
139
140    fn index_len(&self) -> usize {
141        self.0.len()
142    }
143}
144
145#[allow(dead_code)]
146pub struct Utf8Typed<'a>(&'a dyn Array);
147
148#[allow(dead_code)]
149pub struct BinaryTyped<'a>(&'a dyn Array);
150
151#[allow(dead_code)]
152pub struct DecimalTyped<'a>(&'a dyn Array);
153
154pub struct StructTyped<'a>(&'a dyn Array);
155
156impl StructTyped<'_> {
157    pub fn names(&self) -> &FieldNames {
158        let DType::Struct(st, _) = self.0.dtype() else {
159            unreachable!()
160        };
161        st.names()
162    }
163
164    pub fn dtypes(&self) -> Vec<DType> {
165        let DType::Struct(st, _) = self.0.dtype() else {
166            unreachable!()
167        };
168        st.fields().collect()
169    }
170
171    pub fn nfields(&self) -> usize {
172        self.names().len()
173    }
174}
175
176#[allow(dead_code)]
177pub struct ListTyped<'a>(&'a dyn Array);
178
179pub struct ExtensionTyped<'a>(&'a dyn Array);
180
181impl ExtensionTyped<'_> {
182    /// Returns the extension logical [`DType`].
183    pub fn ext_dtype(&self) -> &Arc<ExtDType> {
184        let DType::Extension(ext_dtype) = self.0.dtype() else {
185            vortex_panic!("Expected ExtDType")
186        };
187        ext_dtype
188    }
189}