vortex_array/compute/
is_constant.rs1use std::any::Any;
5use std::sync::LazyLock;
6
7use arcref::ArcRef;
8use vortex_dtype::{DType, Nullability};
9use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
10use vortex_scalar::Scalar;
11
12use crate::Array;
13use crate::arrays::{ConstantVTable, NullVTable};
14use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
15use crate::stats::{Precision, Stat, StatsProviderExt};
16use crate::vtable::VTable;
17
18pub fn is_constant(array: &dyn Array) -> VortexResult<Option<bool>> {
33 let opts = IsConstantOpts::default();
34 is_constant_opts(array, &opts)
35}
36
37pub fn is_constant_opts(array: &dyn Array, options: &IsConstantOpts) -> VortexResult<Option<bool>> {
41 let result = IS_CONSTANT_FN
42 .invoke(&InvocationArgs {
43 inputs: &[array.into()],
44 options,
45 })?
46 .unwrap_scalar()?
47 .as_bool()
48 .value();
49
50 Ok(result)
51}
52
53pub static IS_CONSTANT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
54 let compute = ComputeFn::new("is_constant".into(), ArcRef::new_ref(&IsConstant));
55 for kernel in inventory::iter::<IsConstantKernelRef> {
56 compute.register_kernel(kernel.0.clone());
57 }
58 compute
59});
60
/// Function vtable that implements [`IS_CONSTANT_FN`].
struct IsConstant;
62
63impl ComputeFnVTable for IsConstant {
64 fn invoke(
65 &self,
66 args: &InvocationArgs,
67 kernels: &[ArcRef<dyn Kernel>],
68 ) -> VortexResult<Output> {
69 let IsConstantArgs { array, options } = IsConstantArgs::try_from(args)?;
70
71 if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
73 return Ok(Scalar::from(Some(value)).into());
74 }
75
76 let value = is_constant_impl(array, options, kernels)?;
77
78 if options.cost == Cost::Canonicalize {
79 assert!(
81 value.is_some(),
82 "is constant in array {array} canonicalize returned None"
83 );
84 }
85
86 if let Some(value) = value {
88 array
89 .statistics()
90 .set(Stat::IsConstant, Precision::Exact(value.into()));
91 }
92
93 Ok(Scalar::from(value).into())
94 }
95
96 fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
97 Ok(DType::Bool(Nullability::Nullable))
100 }
101
102 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
103 Ok(1)
104 }
105
106 fn is_elementwise(&self) -> bool {
107 false
108 }
109}
110
111fn is_constant_impl(
112 array: &dyn Array,
113 options: &IsConstantOpts,
114 kernels: &[ArcRef<dyn Kernel>],
115) -> VortexResult<Option<bool>> {
116 match array.len() {
117 0 => return Ok(Some(false)),
119 1 => return Ok(Some(true)),
121 _ => {}
122 }
123
124 if array.as_opt::<ConstantVTable>().is_some() || array.as_opt::<NullVTable>().is_some() {
126 return Ok(Some(true));
127 }
128
129 let all_invalid = array.all_invalid()?;
130 if all_invalid {
131 return Ok(Some(true));
132 }
133
134 let all_valid = array.all_valid()?;
135
136 if !all_valid && !all_invalid {
138 return Ok(Some(false));
139 }
140
141 let min = array.statistics().get_scalar(Stat::Min, array.dtype());
143 let max = array.statistics().get_scalar(Stat::Max, array.dtype());
144
145 if let Some((min, max)) = min.zip(max) {
146 if min.is_exact()
148 && min == max
149 && (Stat::NaNCount.dtype(array.dtype()).is_none()
150 || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
151 {
152 return Ok(Some(true));
153 }
154 }
155
156 assert!(
157 all_valid,
158 "All values must be valid as an invariant of the VTable."
159 );
160 let args = InvocationArgs {
161 inputs: &[array.into()],
162 options,
163 };
164 for kernel in kernels {
165 if let Some(output) = kernel.invoke(&args)? {
166 return Ok(output.unwrap_scalar()?.as_bool().value());
167 }
168 }
169 if let Some(output) = array.invoke(&IS_CONSTANT_FN, &args)? {
170 return Ok(output.unwrap_scalar()?.as_bool().value());
171 }
172
173 log::debug!(
174 "No is_constant implementation found for {}",
175 array.encoding_id()
176 );
177
178 if options.cost == Cost::Canonicalize && !array.is_canonical() {
179 let array = array.to_canonical()?;
180 let is_constant = is_constant_opts(array.as_ref(), options)?;
181 return Ok(is_constant);
182 }
183
184 Ok(None)
186}
187
188pub struct IsConstantKernelRef(ArcRef<dyn Kernel>);
189inventory::collect!(IsConstantKernelRef);
190
191pub trait IsConstantKernel: VTable {
192 fn is_constant(&self, array: &Self::Array, opts: &IsConstantOpts)
199 -> VortexResult<Option<bool>>;
200}
201
202#[derive(Debug)]
203pub struct IsConstantKernelAdapter<V: VTable>(pub V);
204
205impl<V: VTable + IsConstantKernel> IsConstantKernelAdapter<V> {
206 pub const fn lift(&'static self) -> IsConstantKernelRef {
207 IsConstantKernelRef(ArcRef::new_ref(self))
208 }
209}
210
211impl<V: VTable + IsConstantKernel> Kernel for IsConstantKernelAdapter<V> {
212 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
213 let args = IsConstantArgs::try_from(args)?;
214 let Some(array) = args.array.as_opt::<V>() else {
215 return Ok(None);
216 };
217 let is_constant = V::is_constant(&self.0, array, args.options)?;
218 Ok(Some(Scalar::from(is_constant).into()))
219 }
220}
221
222struct IsConstantArgs<'a> {
223 array: &'a dyn Array,
224 options: &'a IsConstantOpts,
225}
226
227impl<'a> TryFrom<&InvocationArgs<'a>> for IsConstantArgs<'a> {
228 type Error = VortexError;
229
230 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
231 if value.inputs.len() != 1 {
232 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
233 }
234 let array = value.inputs[0]
235 .array()
236 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
237 let options = value
238 .options
239 .as_any()
240 .downcast_ref::<IsConstantOpts>()
241 .ok_or_else(|| vortex_err!("Expected options to be of type IsConstantOpts"))?;
242 Ok(Self { array, options })
243 }
244}
245
/// How much work `is_constant` is allowed to do to reach an answer.
///
/// NOTE(review): the variants appear to be listed from cheapest to most thorough. The precise
/// meaning of `Negligible` and `Specialized` is left to each `IsConstantKernel`; confirm against
/// the kernel implementations.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Cost {
    /// The cheapest budget. Kernels can test for it via `IsConstantOpts::is_negligible_cost`.
    Negligible,
    /// A budget in which encoding-specific kernels may do more work. The dispatcher does not
    /// fall back to canonicalization.
    Specialized,
    /// Fall back to canonicalizing the array when no kernel can decide. An answer is required
    /// at this level: the compute function asserts that one was produced. This is the budget
    /// used by `IsConstantOpts::default()`.
    Canonicalize,
}
261
262#[derive(Clone, Debug)]
264pub struct IsConstantOpts {
265 pub cost: Cost,
267}
268
269impl Default for IsConstantOpts {
270 fn default() -> Self {
271 Self {
272 cost: Cost::Canonicalize,
273 }
274 }
275}
276
277impl Options for IsConstantOpts {
278 fn as_any(&self) -> &dyn Any {
279 self
280 }
281}
282
283impl IsConstantOpts {
284 pub fn is_negligible_cost(&self) -> bool {
285 self.cost == Cost::Negligible
286 }
287}
288
#[cfg(test)]
mod tests {
    use crate::arrays::PrimitiveArray;
    use crate::stats::Stat;

    /// Computes and caches the `Min`/`Max` statistics of `arr`, then hands it back.
    fn with_min_max(arr: PrimitiveArray) -> PrimitiveArray {
        arr.statistics()
            .compute_all(&[Stat::Min, Stat::Max])
            .unwrap();
        arr
    }

    #[test]
    fn is_constant_min_max_no_nan() {
        // Distinct values are never constant.
        assert!(!with_min_max(PrimitiveArray::from_iter([0, 1])).is_constant());
        // A repeated value is constant.
        assert!(with_min_max(PrimitiveArray::from_iter([0, 0])).is_constant());
        // Also constant without any precomputed statistics.
        assert!(PrimitiveArray::from_option_iter([Some(0), Some(0)]).is_constant());
    }

    #[test]
    fn is_constant_min_max_with_nan() {
        // The NaN must make the array non-constant even with Min/Max computed.
        assert!(!with_min_max(PrimitiveArray::from_iter([0.0, 0.0, f32::NAN])).is_constant());

        // Infinities are ordinary values and compare equal to themselves.
        let neg_inf = Some(f32::NEG_INFINITY);
        assert!(with_min_max(PrimitiveArray::from_option_iter([neg_inf, neg_inf])).is_constant());
    }
}