1use std::cmp::Ordering;
6
7use vortex_error::VortexExpect;
8use vortex_error::VortexResult;
9use vortex_error::vortex_bail;
10use vortex_error::vortex_panic;
11use vortex_mask::Mask;
12
13use crate::ArrayRef;
14use crate::ExecutionCtx;
15use crate::LEGACY_SESSION;
16use crate::VortexSessionExecute;
17use crate::aggregate_fn::fns::sum::sum;
18use crate::arrays::BoolArray;
19use crate::arrays::bool::BoolArrayExt;
20use crate::builtins::ArrayBuiltins;
21use crate::dtype::DType;
22use crate::dtype::FieldNames;
23use crate::dtype::PType;
24use crate::dtype::extension::ExtDTypeRef;
25use crate::scalar::PValue;
26use crate::scalar::Scalar;
27use crate::search_sorted::IndexOrd;
28
29impl ArrayRef {
30 pub fn as_null_typed(&self) -> NullTyped<'_> {
32 matches!(self.dtype(), DType::Null)
33 .then(|| NullTyped(self))
34 .vortex_expect("Array does not have DType::Null")
35 }
36
37 pub fn as_bool_typed(&self) -> BoolTyped<'_> {
39 matches!(self.dtype(), DType::Bool(..))
40 .then(|| BoolTyped(self))
41 .vortex_expect("Array does not have DType::Bool")
42 }
43
44 pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
46 matches!(self.dtype(), DType::Primitive(..))
47 .then(|| PrimitiveTyped(self))
48 .vortex_expect("Array does not have DType::Primitive")
49 }
50
51 pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
53 matches!(self.dtype(), DType::Decimal(..))
54 .then(|| DecimalTyped(self))
55 .vortex_expect("Array does not have DType::Decimal")
56 }
57
58 pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
60 matches!(self.dtype(), DType::Utf8(..))
61 .then(|| Utf8Typed(self))
62 .vortex_expect("Array does not have DType::Utf8")
63 }
64
65 pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
67 matches!(self.dtype(), DType::Binary(..))
68 .then(|| BinaryTyped(self))
69 .vortex_expect("Array does not have DType::Binary")
70 }
71
72 pub fn as_struct_typed(&self) -> StructTyped<'_> {
74 matches!(self.dtype(), DType::Struct(..))
75 .then(|| StructTyped(self))
76 .vortex_expect("Array does not have DType::Struct")
77 }
78
79 pub fn as_list_typed(&self) -> ListTyped<'_> {
81 matches!(self.dtype(), DType::List(..))
82 .then(|| ListTyped(self))
83 .vortex_expect("Array does not have DType::List")
84 }
85
86 pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
88 matches!(self.dtype(), DType::Extension(..))
89 .then(|| ExtensionTyped(self))
90 .vortex_expect("Array does not have DType::Extension")
91 }
92
93 pub fn try_to_mask_fill_null_false(&self, ctx: &mut ExecutionCtx) -> VortexResult<Mask> {
94 if !matches!(self.dtype(), DType::Bool(_)) {
95 vortex_bail!("mask must be bool array, has dtype {}", self.dtype());
96 }
97
98 let array = self
100 .clone()
101 .fill_null(Scalar::bool(false, self.dtype().nullability()))?;
102
103 Ok(array
104 .execute::<BoolArray>(ctx)?
105 .to_mask_fill_null_false(ctx))
106 }
107}
108
109#[expect(dead_code)]
110pub struct NullTyped<'a>(&'a ArrayRef);
111
112pub struct BoolTyped<'a>(&'a ArrayRef);
113
114impl BoolTyped<'_> {
115 pub fn true_count(&self) -> VortexResult<usize> {
116 let mut ctx = LEGACY_SESSION.create_execution_ctx();
117 let true_count = sum(self.0, &mut ctx)?;
118 Ok(true_count
119 .as_primitive()
120 .as_::<usize>()
121 .vortex_expect("true count should never be null"))
122 }
123}
124
125pub struct PrimitiveTyped<'a>(&'a ArrayRef);
126
127impl PrimitiveTyped<'_> {
128 pub fn ptype(&self) -> PType {
129 let DType::Primitive(ptype, _) = self.0.dtype() else {
130 vortex_panic!("Expected Primitive DType")
131 };
132 *ptype
133 }
134
135 pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
137 self.0
138 .is_valid(idx, &mut LEGACY_SESSION.create_execution_ctx())?
139 .then(|| self.value_unchecked(idx))
140 .transpose()
141 }
142
143 pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
145 Ok(self
146 .0
147 .execute_scalar(idx, &mut LEGACY_SESSION.create_execution_ctx())?
148 .as_primitive()
149 .pvalue()
150 .unwrap_or_else(|| PValue::zero(&self.ptype())))
151 }
152}
153
154impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
155 fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
156 let value = self.value(idx)?;
157 Ok(value.partial_cmp(elem))
158 }
159
160 fn index_len(&self) -> usize {
161 self.0.len()
162 }
163}
164
165impl IndexOrd<PValue> for PrimitiveTyped<'_> {
167 fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
168 assert!(
169 self.0
170 .all_valid(&mut LEGACY_SESSION.create_execution_ctx())?
171 );
172 let value = self.value_unchecked(idx)?;
173 Ok(value.partial_cmp(elem))
174 }
175
176 fn index_len(&self) -> usize {
177 self.0.len()
178 }
179}
180
181#[expect(dead_code)]
182pub struct Utf8Typed<'a>(&'a ArrayRef);
183
184#[expect(dead_code)]
185pub struct BinaryTyped<'a>(&'a ArrayRef);
186
187#[expect(dead_code)]
188pub struct DecimalTyped<'a>(&'a ArrayRef);
189
190pub struct StructTyped<'a>(&'a ArrayRef);
191
192impl StructTyped<'_> {
193 pub fn names(&self) -> &FieldNames {
194 let DType::Struct(st, _) = self.0.dtype() else {
195 unreachable!()
196 };
197 st.names()
198 }
199
200 pub fn dtypes(&self) -> Vec<DType> {
201 let DType::Struct(st, _) = self.0.dtype() else {
202 unreachable!()
203 };
204 st.fields().collect()
205 }
206
207 pub fn nfields(&self) -> usize {
208 self.names().len()
209 }
210}
211
212#[expect(dead_code)]
213pub struct ListTyped<'a>(&'a ArrayRef);
214
215pub struct ExtensionTyped<'a>(&'a ArrayRef);
216
217impl ExtensionTyped<'_> {
218 pub fn ext_dtype(&self) -> &ExtDTypeRef {
220 let DType::Extension(ext_dtype) = self.0.dtype() else {
221 vortex_panic!("Expected ExtDType")
222 };
223 ext_dtype
224 }
225}