vortex_array/aggregate_fn/fns/all_nan/
mod.rs1use vortex_error::VortexResult;
5use vortex_session::VortexSession;
6use vortex_session::registry::CachedId;
7
8use crate::ArrayRef;
9use crate::Columnar;
10use crate::ExecutionCtx;
11use crate::IntoArray;
12use crate::aggregate_fn::AggregateFnId;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::EmptyOptions;
15use crate::aggregate_fn::fns::nan_count::nan_count;
16use crate::dtype::DType;
17use crate::dtype::Nullability;
18use crate::scalar::Scalar;
19
/// Aggregate function that reports whether every value of a floating-point input is NaN.
///
/// A null value is not NaN, so any null in the input makes the result `false`. An empty
/// input is vacuously all-NaN and yields `true`. Inputs whose dtype is not a floating-point
/// primitive are not supported (no return dtype is defined for them).
#[derive(Clone, Debug)]
pub struct AllNan;
32
33impl AggregateFnVTable for AllNan {
34 type Options = EmptyOptions;
35 type Partial = bool;
36
37 fn id(&self) -> AggregateFnId {
38 static ID: CachedId = CachedId::new("vortex.all_nan");
39 *ID
40 }
41
42 fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
43 Ok(Some(vec![]))
44 }
45
46 fn deserialize(
47 &self,
48 _metadata: &[u8],
49 _session: &VortexSession,
50 ) -> VortexResult<Self::Options> {
51 Ok(EmptyOptions)
52 }
53
54 fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option<DType> {
55 matches!(input_dtype, DType::Primitive(ptype, _) if ptype.is_float())
56 .then_some(DType::Bool(Nullability::Nullable))
57 }
58
59 fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> {
60 self.return_dtype(options, input_dtype)
61 }
62
63 fn empty_partial(
64 &self,
65 _options: &Self::Options,
66 _input_dtype: &DType,
67 ) -> VortexResult<Self::Partial> {
68 Ok(true)
69 }
70
71 fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> {
72 *partial &= bool::try_from(&other)?;
73 Ok(())
74 }
75
76 fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
77 Ok(Scalar::bool(*partial, Nullability::Nullable))
78 }
79
80 fn reset(&self, partial: &mut Self::Partial) {
81 *partial = true;
82 }
83
84 fn is_saturated(&self, partial: &Self::Partial) -> bool {
85 !*partial
86 }
87
88 fn try_accumulate(
89 &self,
90 state: &mut Self::Partial,
91 batch: &ArrayRef,
92 ctx: &mut ExecutionCtx,
93 ) -> VortexResult<bool> {
94 if !matches!(batch.dtype(), DType::Primitive(ptype, _) if ptype.is_float()) {
95 *state = false;
96 return Ok(true);
97 }
98
99 *state &= nan_count(batch, ctx)? == batch.len();
100 Ok(true)
101 }
102
103 fn accumulate(
104 &self,
105 partial: &mut Self::Partial,
106 batch: &Columnar,
107 ctx: &mut ExecutionCtx,
108 ) -> VortexResult<()> {
109 let array = match batch {
112 Columnar::Constant(c) => c.clone().into_array(),
113 Columnar::Canonical(c) => c.clone().into_array(),
114 };
115 if !matches!(array.dtype(), DType::Primitive(ptype, _) if ptype.is_float()) {
116 *partial = false;
117 return Ok(());
118 }
119
120 *partial &= nan_count(&array, ctx)? == array.len();
121 Ok(())
122 }
123
124 fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> {
125 Ok(partials)
126 }
127
128 fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> {
129 self.to_scalar(partial)
130 }
131}
132
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::aggregate_fn::Accumulator;
    use crate::aggregate_fn::DynAccumulator;
    use crate::aggregate_fn::EmptyOptions;
    use crate::aggregate_fn::fns::all_nan::AllNan;
    use crate::array_session;
    use crate::arrays::PrimitiveArray;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;

    /// Nullable `f32`, the input dtype used by every float test below.
    fn f32_dtype() -> DType {
        DType::Primitive(PType::F32, Nullability::Nullable)
    }

    /// Runs `batches`, in order, through a fresh `AllNan` accumulator over nullable `f32`
    /// and returns the finished result as a `bool`.
    fn all_nan_over(batches: &[ArrayRef]) -> VortexResult<bool> {
        let mut ctx = array_session().create_execution_ctx();
        let mut acc = Accumulator::try_new(AllNan, EmptyOptions, f32_dtype())?;
        for batch in batches {
            acc.accumulate(batch, &mut ctx)?;
        }
        let result = acc.finish()?;
        let all_nan = bool::try_from(&result)?;
        Ok(all_nan)
    }

    #[test]
    fn all_nan_aggregate_fn() -> VortexResult<()> {
        let batch = PrimitiveArray::from_option_iter([Some(f32::NAN), Some(f32::NAN)]).into_array();
        assert!(all_nan_over(&[batch])?);
        Ok(())
    }

    #[test]
    fn all_nan_false_with_non_nan() -> VortexResult<()> {
        let batch = PrimitiveArray::from_option_iter([Some(f32::NAN), Some(1.0f32)]).into_array();
        assert!(!all_nan_over(&[batch])?);
        Ok(())
    }

    #[test]
    fn all_nan_unsupported_for_non_float_values() -> VortexResult<()> {
        let dtype = DType::Primitive(PType::I32, Nullability::Nullable);
        assert!(Accumulator::try_new(AllNan, EmptyOptions, dtype).is_err());
        Ok(())
    }

    #[test]
    fn all_nan_false_with_null() -> VortexResult<()> {
        let batch = PrimitiveArray::from_option_iter([Some(f32::NAN), None]).into_array();
        assert!(!all_nan_over(&[batch])?);
        Ok(())
    }

    #[test]
    fn all_nan_false_for_all_null_values() -> VortexResult<()> {
        let batch = PrimitiveArray::from_option_iter([None::<f32>, None]).into_array();
        assert!(!all_nan_over(&[batch])?);
        Ok(())
    }

    #[test]
    fn all_nan_true_for_empty_float_values() -> VortexResult<()> {
        assert!(all_nan_over(&[])?);
        Ok(())
    }

    #[test]
    fn all_nan_combines_across_batches() -> VortexResult<()> {
        let nans = PrimitiveArray::from_option_iter([Some(f32::NAN), Some(f32::NAN)]).into_array();
        let mixed = PrimitiveArray::from_option_iter([Some(f32::NAN), Some(2.5f32)]).into_array();

        // Every batch all-NaN: the AND across batches stays true.
        assert!(all_nan_over(&[nans.clone(), nans.clone()])?);
        // A later non-NaN batch clears the result.
        assert!(!all_nan_over(&[nans.clone(), mixed.clone()])?);
        // Once cleared, a later all-NaN batch cannot set it again.
        assert!(!all_nan_over(&[mixed, nans])?);
        Ok(())
    }
}