vortex_array/
variants.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! This module defines extension functionality specific to each Vortex DType.
5use std::cmp::Ordering;
6use std::sync::Arc;
7
8use vortex_dtype::DType;
9use vortex_dtype::ExtDType;
10use vortex_dtype::FieldNames;
11use vortex_dtype::PType;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_panic;
15use vortex_scalar::PValue;
16
17use crate::Array;
18use crate::compute::sum;
19use crate::search_sorted::IndexOrd;
20
21impl dyn Array + '_ {
22    /// Downcasts the array for null-specific behavior.
23    pub fn as_null_typed(&self) -> NullTyped<'_> {
24        matches!(self.dtype(), DType::Null)
25            .then(|| NullTyped(self))
26            .vortex_expect("Array does not have DType::Null")
27    }
28
29    /// Downcasts the array for bool-specific behavior.
30    pub fn as_bool_typed(&self) -> BoolTyped<'_> {
31        matches!(self.dtype(), DType::Bool(..))
32            .then(|| BoolTyped(self))
33            .vortex_expect("Array does not have DType::Bool")
34    }
35
36    /// Downcasts the array for primitive-specific behavior.
37    pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
38        matches!(self.dtype(), DType::Primitive(..))
39            .then(|| PrimitiveTyped(self))
40            .vortex_expect("Array does not have DType::Primitive")
41    }
42
43    /// Downcasts the array for decimal-specific behavior.
44    pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
45        matches!(self.dtype(), DType::Decimal(..))
46            .then(|| DecimalTyped(self))
47            .vortex_expect("Array does not have DType::Decimal")
48    }
49
50    /// Downcasts the array for utf8-specific behavior.
51    pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
52        matches!(self.dtype(), DType::Utf8(..))
53            .then(|| Utf8Typed(self))
54            .vortex_expect("Array does not have DType::Utf8")
55    }
56
57    /// Downcasts the array for binary-specific behavior.
58    pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
59        matches!(self.dtype(), DType::Binary(..))
60            .then(|| BinaryTyped(self))
61            .vortex_expect("Array does not have DType::Binary")
62    }
63
64    /// Downcasts the array for struct-specific behavior.
65    pub fn as_struct_typed(&self) -> StructTyped<'_> {
66        matches!(self.dtype(), DType::Struct(..))
67            .then(|| StructTyped(self))
68            .vortex_expect("Array does not have DType::Struct")
69    }
70
71    /// Downcasts the array for list-specific behavior.
72    pub fn as_list_typed(&self) -> ListTyped<'_> {
73        matches!(self.dtype(), DType::List(..))
74            .then(|| ListTyped(self))
75            .vortex_expect("Array does not have DType::List")
76    }
77
78    /// Downcasts the array for extension-specific behavior.
79    pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
80        matches!(self.dtype(), DType::Extension(..))
81            .then(|| ExtensionTyped(self))
82            .vortex_expect("Array does not have DType::Extension")
83    }
84}
85
/// Borrowed view over an array whose dtype is [`DType::Null`].
///
/// Constructed by `as_null_typed` on `dyn Array`, which checks the dtype first.
// No method visible in this file reads the wrapped array yet, hence the lint allowance.
#[allow(dead_code)]
pub struct NullTyped<'a>(&'a dyn Array);
88
89pub struct BoolTyped<'a>(&'a dyn Array);
90
91impl BoolTyped<'_> {
92    pub fn true_count(&self) -> VortexResult<usize> {
93        let true_count = sum(self.0)?;
94        Ok(true_count
95            .as_primitive()
96            .as_::<usize>()
97            .vortex_expect("true count should never be null"))
98    }
99}
100
101pub struct PrimitiveTyped<'a>(&'a dyn Array);
102
103impl PrimitiveTyped<'_> {
104    pub fn ptype(&self) -> PType {
105        let DType::Primitive(ptype, _) = self.0.dtype() else {
106            vortex_panic!("Expected Primitive DType")
107        };
108        *ptype
109    }
110
111    /// Return the primitive value at the given index.
112    pub fn value(&self, idx: usize) -> Option<PValue> {
113        self.0.is_valid(idx).then(|| self.value_unchecked(idx))
114    }
115
116    /// Return the primitive value at the given index, ignoring nullability.
117    pub fn value_unchecked(&self, idx: usize) -> PValue {
118        self.0
119            .scalar_at(idx)
120            .as_primitive()
121            .pvalue()
122            .unwrap_or_else(|| PValue::zero(self.ptype()))
123    }
124}
125
126impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
127    fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> Option<Ordering> {
128        self.value(idx).partial_cmp(elem)
129    }
130
131    fn index_len(&self) -> usize {
132        self.0.len()
133    }
134}
135
136// TODO(ngates): add generics to the `value` function and implement this over T.
137impl IndexOrd<PValue> for PrimitiveTyped<'_> {
138    fn index_cmp(&self, idx: usize, elem: &PValue) -> Option<Ordering> {
139        assert!(self.0.all_valid());
140        self.value_unchecked(idx).partial_cmp(elem)
141    }
142
143    fn index_len(&self) -> usize {
144        self.0.len()
145    }
146}
147
/// Borrowed view over an array whose dtype is [`DType::Utf8`].
///
/// Constructed by `as_utf8_typed` on `dyn Array`, which checks the dtype first.
// No method visible in this file reads the wrapped array yet, hence the lint allowance.
#[allow(dead_code)]
pub struct Utf8Typed<'a>(&'a dyn Array);
150
/// Borrowed view over an array whose dtype is [`DType::Binary`].
///
/// Constructed by `as_binary_typed` on `dyn Array`, which checks the dtype first.
// No method visible in this file reads the wrapped array yet, hence the lint allowance.
#[allow(dead_code)]
pub struct BinaryTyped<'a>(&'a dyn Array);
153
/// Borrowed view over an array whose dtype is [`DType::Decimal`].
///
/// Constructed by `as_decimal_typed` on `dyn Array`, which checks the dtype first.
// No method visible in this file reads the wrapped array yet, hence the lint allowance.
#[allow(dead_code)]
pub struct DecimalTyped<'a>(&'a dyn Array);
156
157pub struct StructTyped<'a>(&'a dyn Array);
158
159impl StructTyped<'_> {
160    pub fn names(&self) -> &FieldNames {
161        let DType::Struct(st, _) = self.0.dtype() else {
162            unreachable!()
163        };
164        st.names()
165    }
166
167    pub fn dtypes(&self) -> Vec<DType> {
168        let DType::Struct(st, _) = self.0.dtype() else {
169            unreachable!()
170        };
171        st.fields().collect()
172    }
173
174    pub fn nfields(&self) -> usize {
175        self.names().len()
176    }
177}
178
/// Borrowed view over an array whose dtype is [`DType::List`].
///
/// Constructed by `as_list_typed` on `dyn Array`, which checks the dtype first.
// No method visible in this file reads the wrapped array yet, hence the lint allowance.
#[allow(dead_code)]
pub struct ListTyped<'a>(&'a dyn Array);
181
182pub struct ExtensionTyped<'a>(&'a dyn Array);
183
184impl ExtensionTyped<'_> {
185    /// Returns the extension logical [`DType`].
186    pub fn ext_dtype(&self) -> &Arc<ExtDType> {
187        let DType::Extension(ext_dtype) = self.0.dtype() else {
188            vortex_panic!("Expected ExtDType")
189        };
190        ext_dtype
191    }
192}