1use std::cmp::Ordering;
6
7use vortex_dtype::DType;
8use vortex_dtype::FieldNames;
9use vortex_dtype::PType;
10use vortex_dtype::extension::ExtDTypeRef;
11use vortex_error::VortexExpect;
12use vortex_error::VortexResult;
13use vortex_error::vortex_panic;
14
15use crate::Array;
16use crate::compute::sum;
17use crate::scalar::PValue;
18use crate::search_sorted::IndexOrd;
19
20impl dyn Array + '_ {
21 pub fn as_null_typed(&self) -> NullTyped<'_> {
23 matches!(self.dtype(), DType::Null)
24 .then(|| NullTyped(self))
25 .vortex_expect("Array does not have DType::Null")
26 }
27
28 pub fn as_bool_typed(&self) -> BoolTyped<'_> {
30 matches!(self.dtype(), DType::Bool(..))
31 .then(|| BoolTyped(self))
32 .vortex_expect("Array does not have DType::Bool")
33 }
34
35 pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
37 matches!(self.dtype(), DType::Primitive(..))
38 .then(|| PrimitiveTyped(self))
39 .vortex_expect("Array does not have DType::Primitive")
40 }
41
42 pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
44 matches!(self.dtype(), DType::Decimal(..))
45 .then(|| DecimalTyped(self))
46 .vortex_expect("Array does not have DType::Decimal")
47 }
48
49 pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
51 matches!(self.dtype(), DType::Utf8(..))
52 .then(|| Utf8Typed(self))
53 .vortex_expect("Array does not have DType::Utf8")
54 }
55
56 pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
58 matches!(self.dtype(), DType::Binary(..))
59 .then(|| BinaryTyped(self))
60 .vortex_expect("Array does not have DType::Binary")
61 }
62
63 pub fn as_struct_typed(&self) -> StructTyped<'_> {
65 matches!(self.dtype(), DType::Struct(..))
66 .then(|| StructTyped(self))
67 .vortex_expect("Array does not have DType::Struct")
68 }
69
70 pub fn as_list_typed(&self) -> ListTyped<'_> {
72 matches!(self.dtype(), DType::List(..))
73 .then(|| ListTyped(self))
74 .vortex_expect("Array does not have DType::List")
75 }
76
77 pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
79 matches!(self.dtype(), DType::Extension(..))
80 .then(|| ExtensionTyped(self))
81 .vortex_expect("Array does not have DType::Extension")
82 }
83}
84
/// Borrowed view over an array whose dtype was checked to be `DType::Null`
/// (built by `as_null_typed`).
///
/// The wrapped array is not read by any method yet, hence `#[expect(dead_code)]`.
#[expect(dead_code)]
pub struct NullTyped<'a>(&'a dyn Array);
87
88pub struct BoolTyped<'a>(&'a dyn Array);
89
90impl BoolTyped<'_> {
91 pub fn true_count(&self) -> VortexResult<usize> {
92 let true_count = sum(self.0)?;
93 Ok(true_count
94 .as_primitive()
95 .as_::<usize>()
96 .vortex_expect("true count should never be null"))
97 }
98}
99
100pub struct PrimitiveTyped<'a>(&'a dyn Array);
101
102impl PrimitiveTyped<'_> {
103 pub fn ptype(&self) -> PType {
104 let DType::Primitive(ptype, _) = self.0.dtype() else {
105 vortex_panic!("Expected Primitive DType")
106 };
107 *ptype
108 }
109
110 pub fn value(&self, idx: usize) -> VortexResult<Option<PValue>> {
112 self.0
113 .is_valid(idx)?
114 .then(|| self.value_unchecked(idx))
115 .transpose()
116 }
117
118 pub fn value_unchecked(&self, idx: usize) -> VortexResult<PValue> {
120 Ok(self
121 .0
122 .scalar_at(idx)?
123 .as_primitive()
124 .pvalue()
125 .unwrap_or_else(|| PValue::zero(&self.ptype())))
126 }
127}
128
129impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
130 fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> VortexResult<Option<Ordering>> {
131 let value = self.value(idx)?;
132 Ok(value.partial_cmp(elem))
133 }
134
135 fn index_len(&self) -> usize {
136 self.0.len()
137 }
138}
139
140impl IndexOrd<PValue> for PrimitiveTyped<'_> {
142 fn index_cmp(&self, idx: usize, elem: &PValue) -> VortexResult<Option<Ordering>> {
143 assert!(self.0.all_valid()?);
144 let value = self.value_unchecked(idx)?;
145 Ok(value.partial_cmp(elem))
146 }
147
148 fn index_len(&self) -> usize {
149 self.0.len()
150 }
151}
152
/// Borrowed view over an array whose dtype was checked to be `DType::Utf8`
/// (built by `as_utf8_typed`).
///
/// The wrapped array is not read by any method yet, hence `#[expect(dead_code)]`.
#[expect(dead_code)]
pub struct Utf8Typed<'a>(&'a dyn Array);

/// Borrowed view over an array whose dtype was checked to be `DType::Binary`
/// (built by `as_binary_typed`).
///
/// The wrapped array is not read by any method yet, hence `#[expect(dead_code)]`.
#[expect(dead_code)]
pub struct BinaryTyped<'a>(&'a dyn Array);

/// Borrowed view over an array whose dtype was checked to be `DType::Decimal`
/// (built by `as_decimal_typed`).
///
/// The wrapped array is not read by any method yet, hence `#[expect(dead_code)]`.
#[expect(dead_code)]
pub struct DecimalTyped<'a>(&'a dyn Array);
162pub struct StructTyped<'a>(&'a dyn Array);
163
164impl StructTyped<'_> {
165 pub fn names(&self) -> &FieldNames {
166 let DType::Struct(st, _) = self.0.dtype() else {
167 unreachable!()
168 };
169 st.names()
170 }
171
172 pub fn dtypes(&self) -> Vec<DType> {
173 let DType::Struct(st, _) = self.0.dtype() else {
174 unreachable!()
175 };
176 st.fields().collect()
177 }
178
179 pub fn nfields(&self) -> usize {
180 self.names().len()
181 }
182}
183
/// Borrowed view over an array whose dtype was checked to be `DType::List`
/// (built by `as_list_typed`).
///
/// The wrapped array is not read by any method yet, hence `#[expect(dead_code)]`.
#[expect(dead_code)]
pub struct ListTyped<'a>(&'a dyn Array);
186
187pub struct ExtensionTyped<'a>(&'a dyn Array);
188
189impl ExtensionTyped<'_> {
190 pub fn ext_dtype(&self) -> &ExtDTypeRef {
192 let DType::Extension(ext_dtype) = self.0.dtype() else {
193 vortex_panic!("Expected ExtDType")
194 };
195 ext_dtype
196 }
197}