1use std::cmp::Ordering;
3use std::sync::Arc;
4
5use vortex_dtype::{DType, ExtDType, FieldNames, PType};
6use vortex_error::{VortexExpect, VortexResult, vortex_panic};
7use vortex_scalar::PValue;
8
9use crate::Array;
10use crate::compute::sum;
11use crate::search_sorted::IndexOrd;
12
13impl dyn Array + '_ {
14 pub fn as_null_typed(&self) -> NullTyped {
16 matches!(self.dtype(), DType::Null)
17 .then(|| NullTyped(self))
18 .vortex_expect("Array does not have DType::Null")
19 }
20
21 pub fn as_bool_typed(&self) -> BoolTyped {
23 matches!(self.dtype(), DType::Bool(..))
24 .then(|| BoolTyped(self))
25 .vortex_expect("Array does not have DType::Bool")
26 }
27
28 pub fn as_primitive_typed(&self) -> PrimitiveTyped {
30 matches!(self.dtype(), DType::Primitive(..))
31 .then(|| PrimitiveTyped(self))
32 .vortex_expect("Array does not have DType::Primitive")
33 }
34
35 pub fn as_decimal_typed(&self) -> DecimalTyped {
37 matches!(self.dtype(), DType::Decimal(..))
38 .then(|| DecimalTyped(self))
39 .vortex_expect("Array does not have DType::Decimal")
40 }
41
42 pub fn as_utf8_typed(&self) -> Utf8Typed {
44 matches!(self.dtype(), DType::Utf8(..))
45 .then(|| Utf8Typed(self))
46 .vortex_expect("Array does not have DType::Utf8")
47 }
48
49 pub fn as_binary_typed(&self) -> BinaryTyped {
51 matches!(self.dtype(), DType::Binary(..))
52 .then(|| BinaryTyped(self))
53 .vortex_expect("Array does not have DType::Binary")
54 }
55
56 pub fn as_struct_typed(&self) -> StructTyped {
58 matches!(self.dtype(), DType::Struct(..))
59 .then(|| StructTyped(self))
60 .vortex_expect("Array does not have DType::Struct")
61 }
62
63 pub fn as_list_typed(&self) -> ListTyped {
65 matches!(self.dtype(), DType::List(..))
66 .then(|| ListTyped(self))
67 .vortex_expect("Array does not have DType::List")
68 }
69
70 pub fn as_extension_typed(&self) -> ExtensionTyped {
72 matches!(self.dtype(), DType::Extension(..))
73 .then(|| ExtensionTyped(self))
74 .vortex_expect("Array does not have DType::Extension")
75 }
76}
77
/// Typed view over an array whose dtype is [`DType::Null`].
///
/// Constructed by `as_null_typed`, which checks the dtype before wrapping the array.
// No method shown alongside this type reads the wrapped array, hence `dead_code` is allowed.
// NOTE(review): presumably kept for future null-specific accessors — confirm.
#[allow(dead_code)]
pub struct NullTyped<'a>(&'a dyn Array);
80
81pub struct BoolTyped<'a>(&'a dyn Array);
82
83impl BoolTyped<'_> {
84 pub fn true_count(&self) -> VortexResult<usize> {
85 let true_count = sum(self.0)?;
86 Ok(true_count
87 .as_primitive()
88 .as_::<usize>()
89 .vortex_expect("true count should never overflow usize")
90 .vortex_expect("true count should never be null"))
91 }
92}
93
94pub struct PrimitiveTyped<'a>(&'a dyn Array);
95
96impl PrimitiveTyped<'_> {
97 pub fn ptype(&self) -> PType {
98 let DType::Primitive(ptype, _) = self.0.dtype() else {
99 vortex_panic!("Expected Primitive DType")
100 };
101 *ptype
102 }
103
104 pub fn value(&self, idx: usize) -> Option<PValue> {
106 self.0
107 .is_valid(idx)
108 .vortex_expect("is valid")
109 .then(|| self.value_unchecked(idx))
110 }
111
112 pub fn value_unchecked(&self, idx: usize) -> PValue {
114 self.0
115 .scalar_at(idx)
116 .vortex_expect("scalar at index")
117 .as_primitive()
118 .pvalue()
119 .unwrap_or_else(|| PValue::zero(self.ptype()))
120 }
121}
122
123impl IndexOrd<Option<PValue>> for PrimitiveTyped<'_> {
124 fn index_cmp(&self, idx: usize, elem: &Option<PValue>) -> Option<Ordering> {
125 self.value(idx).partial_cmp(elem)
126 }
127
128 fn index_len(&self) -> usize {
129 self.0.len()
130 }
131}
132
133impl IndexOrd<PValue> for PrimitiveTyped<'_> {
135 fn index_cmp(&self, idx: usize, elem: &PValue) -> Option<Ordering> {
136 assert!(self.0.all_valid().vortex_expect("all valid"));
137 self.value_unchecked(idx).partial_cmp(elem)
138 }
139
140 fn index_len(&self) -> usize {
141 self.0.len()
142 }
143}
144
/// Typed view over an array whose dtype is [`DType::Utf8`].
///
/// Constructed by `as_utf8_typed`, which checks the dtype before wrapping the array.
// No method shown alongside this type reads the wrapped array, hence `dead_code` is allowed.
// NOTE(review): presumably a placeholder for UTF-8-specific accessors — confirm.
#[allow(dead_code)]
pub struct Utf8Typed<'a>(&'a dyn Array);
147
/// Typed view over an array whose dtype is [`DType::Binary`].
///
/// Constructed by `as_binary_typed`, which checks the dtype before wrapping the array.
// No method shown alongside this type reads the wrapped array, hence `dead_code` is allowed.
// NOTE(review): presumably a placeholder for binary-specific accessors — confirm.
#[allow(dead_code)]
pub struct BinaryTyped<'a>(&'a dyn Array);
150
/// Typed view over an array whose dtype is [`DType::Decimal`].
///
/// Constructed by `as_decimal_typed`, which checks the dtype before wrapping the array.
// No method shown alongside this type reads the wrapped array, hence `dead_code` is allowed.
// NOTE(review): presumably a placeholder for decimal-specific accessors — confirm.
#[allow(dead_code)]
pub struct DecimalTyped<'a>(&'a dyn Array);
153
154pub struct StructTyped<'a>(&'a dyn Array);
155
156impl StructTyped<'_> {
157 pub fn names(&self) -> &FieldNames {
158 let DType::Struct(st, _) = self.0.dtype() else {
159 unreachable!()
160 };
161 st.names()
162 }
163
164 pub fn dtypes(&self) -> Vec<DType> {
165 let DType::Struct(st, _) = self.0.dtype() else {
166 unreachable!()
167 };
168 st.fields().collect()
169 }
170
171 pub fn nfields(&self) -> usize {
172 self.names().len()
173 }
174}
175
/// Typed view over an array whose dtype is [`DType::List`].
///
/// Constructed by `as_list_typed`, which checks the dtype before wrapping the array.
// No method shown alongside this type reads the wrapped array, hence `dead_code` is allowed.
// NOTE(review): presumably a placeholder for list-specific accessors — confirm.
#[allow(dead_code)]
pub struct ListTyped<'a>(&'a dyn Array);
178
179pub struct ExtensionTyped<'a>(&'a dyn Array);
180
181impl ExtensionTyped<'_> {
182 pub fn ext_dtype(&self) -> &Arc<ExtDType> {
184 let DType::Extension(ext_dtype) = self.0.dtype() else {
185 vortex_panic!("Expected ExtDType")
186 };
187 ext_dtype
188 }
189}