1use std::cmp::Ordering;
6use std::sync::Arc;
7
8use vortex_dtype::DType;
9use vortex_dtype::ExtDType;
10use vortex_dtype::FieldNames;
11use vortex_dtype::PType;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_panic;
15use vortex_scalar::PValue;
16
17use crate::Array;
18use crate::compute::sum;
19use crate::search_sorted::IndexOrd;
20
21impl dyn Array + '_ {
22 pub fn as_null_typed(&self) -> NullTyped<'_> {
24 matches!(self.dtype(), DType::Null)
25 .then(|| NullTyped(self))
26 .vortex_expect("Array does not have DType::Null")
27 }
28
29 pub fn as_bool_typed(&self) -> BoolTyped<'_> {
31 matches!(self.dtype(), DType::Bool(..))
32 .then(|| BoolTyped(self))
33 .vortex_expect("Array does not have DType::Bool")
34 }
35
36 pub fn as_primitive_typed(&self) -> PrimitiveTyped<'_> {
38 matches!(self.dtype(), DType::Primitive(..))
39 .then(|| PrimitiveTyped(self))
40 .vortex_expect("Array does not have DType::Primitive")
41 }
42
43 pub fn as_decimal_typed(&self) -> DecimalTyped<'_> {
45 matches!(self.dtype(), DType::Decimal(..))
46 .then(|| DecimalTyped(self))
47 .vortex_expect("Array does not have DType::Decimal")
48 }
49
50 pub fn as_utf8_typed(&self) -> Utf8Typed<'_> {
52 matches!(self.dtype(), DType::Utf8(..))
53 .then(|| Utf8Typed(self))
54 .vortex_expect("Array does not have DType::Utf8")
55 }
56
57 pub fn as_binary_typed(&self) -> BinaryTyped<'_> {
59 matches!(self.dtype(), DType::Binary(..))
60 .then(|| BinaryTyped(self))
61 .vortex_expect("Array does not have DType::Binary")
62 }
63
64 pub fn as_struct_typed(&self) -> StructTyped<'_> {
66 matches!(self.dtype(), DType::Struct(..))
67 .then(|| StructTyped(self))
68 .vortex_expect("Array does not have DType::Struct")
69 }
70
71 pub fn as_list_typed(&self) -> ListTyped<'_> {
73 matches!(self.dtype(), DType::List(..))
74 .then(|| ListTyped(self))
75 .vortex_expect("Array does not have DType::List")
76 }
77
78 pub fn as_extension_typed(&self) -> ExtensionTyped<'_> {
80 matches!(self.dtype(), DType::Extension(..))
81 .then(|| ExtensionTyped(self))
82 .vortex_expect("Array does not have DType::Extension")
83 }
84}
85
86#[allow(dead_code)]
87pub struct NullTyped<'a>(&'a dyn Array);
88
89pub struct BoolTyped<'a>(&'a dyn Array);
90
91impl BoolTyped<'_> {
92 pub fn true_count(&self) -> VortexResult<usize> {
93 let true_count = sum(self.0)?;
94 Ok(true_count
95 .as_primitive()
96 .as_::<usize>()
97 .vortex_expect("true count should never be null"))
98 }
99}
100
101pub struct PrimitiveTyped<'a>(&'a dyn Array);
102
103impl PrimitiveTyped<'_> {
104 pub fn ptype(&self) -> PType {
105 let DType::Primitive(ptype, _) = self.0.dtype() else {
106 vortex_panic!("Expected Primitive DType")
107 };
108 *ptype
109 }
110
111 pub fn value(&self, idx: usize) -> Option<PValue> {
113 self.0.is_valid(idx).then(|| self.value_unchecked(idx))
114 }
115
116 pub fn value_unchecked(&self, idx: usize) -> PValue {
118 self.0
119 .scalar_at(idx)
120 .as_primitive()
121 .pvalue()
122 .unwrap_or_else(|| PValue::zero(self.ptype()))
123 }
124}
125
126impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
127 fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> Option<Ordering> {
128 self.value(idx).partial_cmp(elem)
129 }
130
131 fn index_len(&self) -> usize {
132 self.0.len()
133 }
134}
135
136impl IndexOrd<PValue> for PrimitiveTyped<'_> {
138 fn index_cmp(&self, idx: usize, elem: &PValue) -> Option<Ordering> {
139 assert!(self.0.all_valid());
140 self.value_unchecked(idx).partial_cmp(elem)
141 }
142
143 fn index_len(&self) -> usize {
144 self.0.len()
145 }
146}
147
148#[allow(dead_code)]
149pub struct Utf8Typed<'a>(&'a dyn Array);
150
151#[allow(dead_code)]
152pub struct BinaryTyped<'a>(&'a dyn Array);
153
154#[allow(dead_code)]
155pub struct DecimalTyped<'a>(&'a dyn Array);
156
157pub struct StructTyped<'a>(&'a dyn Array);
158
159impl StructTyped<'_> {
160 pub fn names(&self) -> &FieldNames {
161 let DType::Struct(st, _) = self.0.dtype() else {
162 unreachable!()
163 };
164 st.names()
165 }
166
167 pub fn dtypes(&self) -> Vec<DType> {
168 let DType::Struct(st, _) = self.0.dtype() else {
169 unreachable!()
170 };
171 st.fields().collect()
172 }
173
174 pub fn nfields(&self) -> usize {
175 self.names().len()
176 }
177}
178
179#[allow(dead_code)]
180pub struct ListTyped<'a>(&'a dyn Array);
181
182pub struct ExtensionTyped<'a>(&'a dyn Array);
183
184impl ExtensionTyped<'_> {
185 pub fn ext_dtype(&self) -> &Arc<ExtDType> {
187 let DType::Extension(ext_dtype) = self.0.dtype() else {
188 vortex_panic!("Expected ExtDType")
189 };
190 ext_dtype
191 }
192}