1use std::cmp::Ordering;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::Array;
14use crate::ExecutionCtx;
15use crate::arrays::BoolArray;
16use crate::builtins::ArrayBuiltins;
17use crate::compute::sum;
18use crate::dtype::DType;
19use crate::dtype::FieldNames;
20use crate::dtype::PType;
21use crate::dtype::extension::ExtDTypeRef;
22use crate::scalar::PValue;
23use crate::scalar::Scalar;
24use crate::search_sorted::IndexOrd;
25
26impl dyn Array + '_ {
27 pub fn as_null_typed(&self) -> NullTyped<'_> {
29 matches!(self.dtype(), DType::Null)
30 .then(|| NullTyped(self))
31 .vortex_expect("Array does not have DType::Null")
32 }
33
34 pub fn as_bool_typed(&self) -> BoolTyped<'_> {
36 matches!(self.dtype(), DType::Bool(..))
37 .then(|| BoolTyped(self))
38 .vortex_expect("Array does not have DType::Bool")
39 }
40
41 pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
43 matches!(self.dtype(), DType::Primitive(..))
44 .then(|| PrimitiveTyped(self))
45 .vortex_expect("Array does not have DType::Primitive")
46 }
47
48 pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
50 matches!(self.dtype(), DType::Decimal(..))
51 .then(|| DecimalTyped(self))
52 .vortex_expect("Array does not have DType::Decimal")
53 }
54
55 pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
57 matches!(self.dtype(), DType::Utf8(..))
58 .then(|| Utf8Typed(self))
59 .vortex_expect("Array does not have DType::Utf8")
60 }
61
62 pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
64 matches!(self.dtype(), DType::Binary(..))
65 .then(|| BinaryTyped(self))
66 .vortex_expect("Array does not have DType::Binary")
67 }
68
69 pub fn as_struct_typed(&self) -> StructTyped<'_> {
71 matches!(self.dtype(), DType::Struct(..))
72 .then(|| StructTyped(self))
73 .vortex_expect("Array does not have DType::Struct")
74 }
75
76 pub fn as_list_typed(&self) -> ListTyped<'_> {
78 matches!(self.dtype(), DType::List(..))
79 .then(|| ListTyped(self))
80 .vortex_expect("Array does not have DType::List")
81 }
82
83 pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
85 matches!(self.dtype(), DType::Extension(..))
86 .then(|| ExtensionTyped(self))
87 .vortex_expect("Array does not have DType::Extension")
88 }
89
90 pub fn try_to_mask_fill_null_false(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
91 if !matches!(self.dtype(), DType::Bool(_)) {
92 vortex_bail!("mask must be bool array, has dtype {}", self.dtype());
93 }
94
95 let array = self
97 .to_array()
98 .fill_null(Scalar::bool(false, self.dtype().nullability()))?;
99
100 Ok(array.execute::<BoolArray>(ctx)?.to_mask_fill_null_false())
101 }
102}
103
/// Typed view over an array whose dtype is [`DType::Null`]; obtained from `as_null_typed`.
// `expect` rather than `allow`: the wrapped field is currently never read, and the
// expectation will flag itself once an accessor starts using it.
#[expect(dead_code)]
pub struct NullTyped<'a>(&'a dyn Array);
106
107pub struct BoolTyped<'a>(&'a dyn Array);
108
109impl BoolTyped<'_> {
110 pub fn true_count(&self) -> VortexResult<usize> {
111 let true_count = sum(&self.0.to_array())?;
112 Ok(true_count
113 .as_primitive()
114 .as_::<usize>()
115 .vortex_expect("true count should never be null"))
116 }
117}
118
119pub struct PrimitiveTyped<'a>(&'a dyn Array);
120
121impl PrimitiveTyped<'_> {
122 pub fn ptype(&self) -> PType {
123 let DType::Primitive(ptype, _) = self.0.dtype() else {
124 vortex_panic!("Expected Primitive DType")
125 };
126 *ptype
127 }
128
129 pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
131 self.0
132 .is_valid(idx)?
133 .then(|| self.value_unchecked(idx))
134 .transpose()
135 }
136
137 pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
139 Ok(self
140 .0
141 .scalar_at(idx)?
142 .as_primitive()
143 .pvalue()
144 .unwrap_or_else(|| PValue::zero(&self.ptype())))
145 }
146}
147
148impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
149 fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
150 let value = self.value(idx)?;
151 Ok(value.partial_cmp(elem))
152 }
153
154 fn index_len(&self) -> usize {
155 self.0.len()
156 }
157}
158
159impl IndexOrd<PValue> for PrimitiveTyped<'_> {
161 fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
162 assert!(self.0.all_valid()?);
163 let value = self.value_unchecked(idx)?;
164 Ok(value.partial_cmp(elem))
165 }
166
167 fn index_len(&self) -> usize {
168 self.0.len()
169 }
170}
171
/// Typed view over an array whose dtype is [`DType::Utf8`]; obtained from `as_utf8_typed`.
// `expect` rather than `allow`: the wrapped field is currently never read.
#[expect(dead_code)]
pub struct Utf8Typed<'a>(&'a dyn Array);

/// Typed view over an array whose dtype is [`DType::Binary`]; obtained from
/// `as_binary_typed`.
// `expect` rather than `allow`: the wrapped field is currently never read.
#[expect(dead_code)]
pub struct BinaryTyped<'a>(&'a dyn Array);

/// Typed view over an array whose dtype is [`DType::Decimal`]; obtained from
/// `as_decimal_typed`.
// `expect` rather than `allow`: the wrapped field is currently never read.
#[expect(dead_code)]
pub struct DecimalTyped<'a>(&'a dyn Array);
180
181pub struct StructTyped<'a>(&'a dyn Array);
182
183impl StructTyped<'_> {
184 pub fn names(&self) -> &FieldNames {
185 let DType::Struct(st, _) = self.0.dtype() else {
186 unreachable!()
187 };
188 st.names()
189 }
190
191 pub fn dtypes(&self) -> Vec<DType> {
192 let DType::Struct(st, _) = self.0.dtype() else {
193 unreachable!()
194 };
195 st.fields().collect()
196 }
197
198 pub fn nfields(&self) -> usize {
199 self.names().len()
200 }
201}
202
/// Typed view over an array whose dtype is [`DType::List`]; obtained from `as_list_typed`.
// `expect` rather than `allow`: the wrapped field is currently never read.
#[expect(dead_code)]
pub struct ListTyped<'a>(&'a dyn Array);
205
206pub struct ExtensionTyped<'a>(&'a dyn Array);
207
208impl ExtensionTyped<'_> {
209 pub fn ext_dtype(&self) -> &ExtDTypeRef {
211 let DType::Extension(ext_dtype) = self.0.dtype() else {
212 vortex_panic!("Expected ExtDType")
213 };
214 ext_dtype
215 }
216}