vortex_array/compute/
is_constant.rs1use std::any::Any;
2use std::sync::LazyLock;
3
4use arcref::ArcRef;
5use vortex_dtype::{DType, Nullability};
6use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
7use vortex_scalar::Scalar;
8
9use crate::Array;
10use crate::arrays::{ConstantVTable, NullVTable};
11use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Options, Output};
12use crate::stats::{Precision, Stat, StatsProviderExt};
13use crate::vtable::VTable;
14
15pub fn is_constant(array: &dyn Array) -> VortexResult<Option<bool>> {
30 let opts = IsConstantOpts::default();
31 is_constant_opts(array, &opts)
32}
33
34pub fn is_constant_opts(array: &dyn Array, options: &IsConstantOpts) -> VortexResult<Option<bool>> {
38 let result = IS_CONSTANT_FN
39 .invoke(&InvocationArgs {
40 inputs: &[array.into()],
41 options,
42 })?
43 .unwrap_scalar()?
44 .as_bool()
45 .value();
46
47 Ok(result)
48}
49
50pub static IS_CONSTANT_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
51 let compute = ComputeFn::new("is_constant".into(), ArcRef::new_ref(&IsConstant));
52 for kernel in inventory::iter::<IsConstantKernelRef> {
53 compute.register_kernel(kernel.0.clone());
54 }
55 compute
56});
57
58struct IsConstant;
59
60impl ComputeFnVTable for IsConstant {
61 fn invoke(
62 &self,
63 args: &InvocationArgs,
64 kernels: &[ArcRef<dyn Kernel>],
65 ) -> VortexResult<Output> {
66 let IsConstantArgs { array, options } = IsConstantArgs::try_from(args)?;
67
68 if let Some(Precision::Exact(value)) = array.statistics().get_as::<bool>(Stat::IsConstant) {
70 return Ok(Scalar::from(Some(value)).into());
71 }
72
73 let value = is_constant_impl(array, options, kernels)?;
74
75 if options.cost == Cost::Canonicalize {
76 assert!(
78 value.is_some(),
79 "is constant in array {array} canonicalize returned None"
80 );
81 }
82
83 if let Some(value) = value {
85 array
86 .statistics()
87 .set(Stat::IsConstant, Precision::Exact(value.into()));
88 }
89
90 Ok(Scalar::from(value).into())
91 }
92
93 fn return_dtype(&self, _args: &InvocationArgs) -> VortexResult<DType> {
94 Ok(DType::Bool(Nullability::Nullable))
97 }
98
99 fn return_len(&self, _args: &InvocationArgs) -> VortexResult<usize> {
100 Ok(1)
101 }
102
103 fn is_elementwise(&self) -> bool {
104 false
105 }
106}
107
108fn is_constant_impl(
109 array: &dyn Array,
110 options: &IsConstantOpts,
111 kernels: &[ArcRef<dyn Kernel>],
112) -> VortexResult<Option<bool>> {
113 match array.len() {
114 0 => return Ok(Some(false)),
116 1 => return Ok(Some(true)),
118 _ => {}
119 }
120
121 if array.as_opt::<ConstantVTable>().is_some() || array.as_opt::<NullVTable>().is_some() {
123 return Ok(Some(true));
124 }
125
126 let all_invalid = array.all_invalid()?;
127 if all_invalid {
128 return Ok(Some(true));
129 }
130
131 let all_valid = array.all_valid()?;
132
133 if !all_valid && !all_invalid {
135 return Ok(Some(false));
136 }
137
138 let min = array.statistics().get_scalar(Stat::Min, array.dtype());
140 let max = array.statistics().get_scalar(Stat::Max, array.dtype());
141
142 if let Some((min, max)) = min.zip(max) {
143 if min.is_exact()
145 && min == max
146 && (Stat::NaNCount.dtype(array.dtype()).is_none()
147 || array.statistics().get_as::<u64>(Stat::NaNCount) == Some(Precision::exact(0u64)))
148 {
149 return Ok(Some(true));
150 }
151 }
152
153 assert!(
154 all_valid,
155 "All values must be valid as an invariant of the VTable."
156 );
157 let args = InvocationArgs {
158 inputs: &[array.into()],
159 options,
160 };
161 for kernel in kernels {
162 if let Some(output) = kernel.invoke(&args)? {
163 return Ok(output.unwrap_scalar()?.as_bool().value());
164 }
165 }
166 if let Some(output) = array.invoke(&IS_CONSTANT_FN, &args)? {
167 return Ok(output.unwrap_scalar()?.as_bool().value());
168 }
169
170 log::debug!(
171 "No is_constant implementation found for {}",
172 array.encoding_id()
173 );
174
175 if options.cost == Cost::Canonicalize && !array.is_canonical() {
176 let array = array.to_canonical()?;
177 let is_constant = is_constant_opts(array.as_ref(), options)?;
178 return Ok(is_constant);
179 }
180
181 Ok(None)
183}
184
185pub struct IsConstantKernelRef(ArcRef<dyn Kernel>);
186inventory::collect!(IsConstantKernelRef);
187
188pub trait IsConstantKernel: VTable {
189 fn is_constant(&self, array: &Self::Array, opts: &IsConstantOpts)
196 -> VortexResult<Option<bool>>;
197}
198
199#[derive(Debug)]
200pub struct IsConstantKernelAdapter<V: VTable>(pub V);
201
202impl<V: VTable + IsConstantKernel> IsConstantKernelAdapter<V> {
203 pub const fn lift(&'static self) -> IsConstantKernelRef {
204 IsConstantKernelRef(ArcRef::new_ref(self))
205 }
206}
207
208impl<V: VTable + IsConstantKernel> Kernel for IsConstantKernelAdapter<V> {
209 fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
210 let args = IsConstantArgs::try_from(args)?;
211 let Some(array) = args.array.as_opt::<V>() else {
212 return Ok(None);
213 };
214 let is_constant = V::is_constant(&self.0, array, args.options)?;
215 Ok(Some(Scalar::from(is_constant).into()))
216 }
217}
218
219struct IsConstantArgs<'a> {
220 array: &'a dyn Array,
221 options: &'a IsConstantOpts,
222}
223
224impl<'a> TryFrom<&InvocationArgs<'a>> for IsConstantArgs<'a> {
225 type Error = VortexError;
226
227 fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
228 if value.inputs.len() != 1 {
229 vortex_bail!("Expected 1 input, found {}", value.inputs.len());
230 }
231 let array = value.inputs[0]
232 .array()
233 .ok_or_else(|| vortex_err!("Expected input 0 to be an array"))?;
234 let options = value
235 .options
236 .as_any()
237 .downcast_ref::<IsConstantOpts>()
238 .ok_or_else(|| vortex_err!("Expected options to be of type IsConstantOpts"))?;
239 Ok(Self { array, options })
240 }
241}
242
/// How much work an `is_constant` computation is allowed to spend.
///
/// NOTE(review): the variant names suggest they are listed from least to most work permitted.
/// Only `Canonicalize` is treated specially in this file; the other budgets are interpreted by
/// individual kernels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Cost {
    /// Only negligible work. Kernels can detect this level via
    /// `IsConstantOpts::is_negligible_cost`. What counts as negligible is up to each kernel.
    Negligible,
    /// Allows encoding-specific kernels to do their own work. No canonicalization fallback is
    /// attempted when they cannot answer.
    Specialized,
    /// If no kernel produces an answer, a non-canonical array is canonicalized and checked again.
    /// A definite answer is required at this level: `IsConstant::invoke` panics otherwise.
    Canonicalize,
}
258
259#[derive(Clone, Debug)]
261pub struct IsConstantOpts {
262 pub cost: Cost,
264}
265
266impl Default for IsConstantOpts {
267 fn default() -> Self {
268 Self {
269 cost: Cost::Canonicalize,
270 }
271 }
272}
273
274impl Options for IsConstantOpts {
275 fn as_any(&self) -> &dyn Any {
276 self
277 }
278}
279
280impl IsConstantOpts {
281 pub fn is_negligible_cost(&self) -> bool {
282 self.cost == Cost::Negligible
283 }
284}
285
#[cfg(test)]
mod tests {
    use crate::arrays::PrimitiveArray;
    use crate::stats::Stat;

    /// Computes and caches the min/max statistics so that `is_constant` can use its
    /// statistics-based shortcut.
    fn cache_min_max(array: &PrimitiveArray) {
        array
            .statistics()
            .compute_all(&[Stat::Min, Stat::Max])
            .unwrap();
    }

    #[test]
    fn is_constant_min_max_no_nan() {
        let distinct = PrimitiveArray::from_iter([0, 1]);
        cache_min_max(&distinct);
        assert!(!distinct.is_constant());

        let repeated = PrimitiveArray::from_iter([0, 0]);
        cache_min_max(&repeated);
        assert!(repeated.is_constant());

        // No precomputed min/max statistics on this one.
        let repeated_nullable = PrimitiveArray::from_option_iter([Some(0), Some(0)]);
        assert!(repeated_nullable.is_constant());
    }

    #[test]
    fn is_constant_min_max_with_nan() {
        let with_nan = PrimitiveArray::from_iter([0.0, 0.0, f32::NAN]);
        cache_min_max(&with_nan);
        assert!(!with_nan.is_constant());

        let neg_infinities =
            PrimitiveArray::from_option_iter([Some(f32::NEG_INFINITY), Some(f32::NEG_INFINITY)]);
        cache_min_max(&neg_infinities);
        assert!(neg_infinities.is_constant());
    }
}
324}