1use std::cmp::Ordering;
6use std::sync::Arc;
7
8use vortex_dtype::{DType, ExtDType, FieldNames, PType};
9use vortex_error::{VortexExpect, VortexResult, vortex_panic};
10use vortex_scalar::PValue;
11
12use crate::Array;
13use crate::compute::sum;
14use crate::search_sorted::IndexOrd;
15
16impl dyn Array + '_ {
17 pub fn as_null_typed(&self) -> NullTyped<'_> {
19 matches!(self.dtype(), DType::Null)
20 .then(|| NullTyped(self))
21 .vortex_expect("Array does not have DType::Null")
22 }
23
24 pub fn as_bool_typed(&self) -> BoolTyped<'_> {
26 matches!(self.dtype(), DType::Bool(..))
27 .then(|| BoolTyped(self))
28 .vortex_expect("Array does not have DType::Bool")
29 }
30
31 pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
33 matches!(self.dtype(), DType::Primitive(..))
34 .then(|| PrimitiveTyped(self))
35 .vortex_expect("Array does not have DType::Primitive")
36 }
37
38 pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
40 matches!(self.dtype(), DType::Decimal(..))
41 .then(|| DecimalTyped(self))
42 .vortex_expect("Array does not have DType::Decimal")
43 }
44
45 pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
47 matches!(self.dtype(), DType::Utf8(..))
48 .then(|| Utf8Typed(self))
49 .vortex_expect("Array does not have DType::Utf8")
50 }
51
52 pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
54 matches!(self.dtype(), DType::Binary(..))
55 .then(|| BinaryTyped(self))
56 .vortex_expect("Array does not have DType::Binary")
57 }
58
59 pub fn as_struct_typed(&self) -> StructTyped<'_> {
61 matches!(self.dtype(), DType::Struct(..))
62 .then(|| StructTyped(self))
63 .vortex_expect("Array does not have DType::Struct")
64 }
65
66 pub fn as_list_typed(&self) -> ListTyped<'_> {
68 matches!(self.dtype(), DType::List(..))
69 .then(|| ListTyped(self))
70 .vortex_expect("Array does not have DType::List")
71 }
72
73 pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
75 matches!(self.dtype(), DType::Extension(..))
76 .then(|| ExtensionTyped(self))
77 .vortex_expect("Array does not have DType::Extension")
78 }
79}
80
81#[allow(dead_code)]
82pub struct NullTyped<'a>(&'a dyn Array);
83
84pub struct BoolTyped<'a>(&'a dyn Array);
85
86impl BoolTyped<'_> {
87 pub fn true_count(&self) -> VortexResult<usize> {
88 let true_count = sum(self.0)?;
89 Ok(true_count
90 .as_primitive()
91 .as_::<usize>()
92 .vortex_expect("true count should never overflow usize")
93 .vortex_expect("true count should never be null"))
94 }
95}
96
97pub struct PrimitiveTyped<'a>(&'a dyn Array);
98
99impl PrimitiveTyped<'_> {
100 pub fn ptype(&self) -> PType {
101 let DType::Primitive(ptype, _) = self.0.dtype() else {
102 vortex_panic!("Expected Primitive DType")
103 };
104 *ptype
105 }
106
107 pub fn value(&self, idx: usize) -> Option<PValue> {
109 self.0
110 .is_valid(idx)
111 .vortex_expect("is valid")
112 .then(|| self.value_unchecked(idx))
113 }
114
115 pub fn value_unchecked(&self, idx: usize) -> PValue {
117 self.0
118 .scalar_at(idx)
119 .vortex_expect("scalar at index")
120 .as_primitive()
121 .pvalue()
122 .unwrap_or_else(|| PValue::zero(self.ptype()))
123 }
124}
125
126impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
127 fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> Option<Ordering> {
128 self.value(idx).partial_cmp(elem)
129 }
130
131 fn index_len(&self) -> usize {
132 self.0.len()
133 }
134}
135
136impl IndexOrd<PValue> for PrimitiveTyped<'_> {
138 fn index_cmp(&self, idx: usize, elem: &PValue) -> Option<Ordering> {
139 assert!(self.0.all_valid().vortex_expect("all valid"));
140 self.value_unchecked(idx).partial_cmp(elem)
141 }
142
143 fn index_len(&self) -> usize {
144 self.0.len()
145 }
146}
147
148#[allow(dead_code)]
149pub struct Utf8Typed<'a>(&'a dyn Array);
150
151#[allow(dead_code)]
152pub struct BinaryTyped<'a>(&'a dyn Array);
153
154#[allow(dead_code)]
155pub struct DecimalTyped<'a>(&'a dyn Array);
156
157pub struct StructTyped<'a>(&'a dyn Array);
158
159impl StructTyped<'_> {
160 pub fn names(&self) -> &FieldNames {
161 let DType::Struct(st, _) = self.0.dtype() else {
162 unreachable!()
163 };
164 st.names()
165 }
166
167 pub fn dtypes(&self) -> Vec<DType> {
168 let DType::Struct(st, _) = self.0.dtype() else {
169 unreachable!()
170 };
171 st.fields().collect()
172 }
173
174 pub fn nfields(&self) -> usize {
175 self.names().len()
176 }
177}
178
179#[allow(dead_code)]
180pub struct ListTyped<'a>(&'a dyn Array);
181
182pub struct ExtensionTyped<'a>(&'a dyn Array);
183
184impl ExtensionTyped<'_> {
185 pub fn ext_dtype(&self) -> &Arc<ExtDType> {
187 let DType::Extension(ext_dtype) = self.0.dtype() else {
188 vortex_panic!("Expected ExtDType")
189 };
190 ext_dtype
191 }
192}