Skip to main content

vortex_array/stats/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Stats as they are stored on arrays.
5
6use std::sync::Arc;
7
8use parking_lot::RwLock;
9use vortex_array::ExecutionCtx;
10use vortex_error::VortexError;
11use vortex_error::VortexResult;
12use vortex_error::vortex_panic;
13
14use super::MutTypedStatsSetRef;
15use super::StatsSet;
16use super::StatsSetIntoIter;
17use super::TypedStatsSetRef;
18use crate::ArrayRef;
19use crate::aggregate_fn::NumericalAggregateOpts;
20use crate::aggregate_fn::fns::is_constant::is_constant;
21use crate::aggregate_fn::fns::is_sorted::is_sorted;
22use crate::aggregate_fn::fns::is_sorted::is_strict_sorted;
23use crate::aggregate_fn::fns::min_max::MinMaxResult;
24use crate::aggregate_fn::fns::min_max::min_max;
25use crate::aggregate_fn::fns::nan_count::nan_count;
26use crate::aggregate_fn::fns::sum::sum;
27use crate::aggregate_fn::fns::uncompressed_size_in_bytes::uncompressed_size_in_bytes;
28use crate::expr::stats::Precision;
29use crate::expr::stats::Stat;
30use crate::expr::stats::StatsProvider;
31use crate::scalar::Scalar;
32use crate::scalar::ScalarValue;
33
34/// A shared [`StatsSet`] stored in an array. Can be shared by copies of the array and can also be mutated in place.
35// TODO(adamg): This is a very bad name.
36#[derive(Clone, Default, Debug)]
37pub struct ArrayStats {
38    inner: Arc<RwLock<StatsSet>>,
39}
40
41/// Reference to an array's [`StatsSet`]. Can be used to get and mutate the underlying stats.
42///
43/// Constructed by calling [`ArrayStats::to_ref`].
44pub struct StatsSetRef<'a> {
45    // We need to reference back to the array
46    dyn_array_ref: &'a ArrayRef,
47    array_stats: &'a ArrayStats,
48}
49
50impl ArrayStats {
51    pub fn to_ref<'a>(&'a self, array: &'a ArrayRef) -> StatsSetRef<'a> {
52        StatsSetRef {
53            dyn_array_ref: array,
54            array_stats: self,
55        }
56    }
57
58    pub fn set(&self, stat: Stat, value: Precision<ScalarValue>) {
59        self.inner.write().set(stat, value);
60    }
61
62    pub fn clear(&self, stat: Stat) {
63        self.inner.write().clear(stat);
64    }
65
66    pub fn retain(&self, stats: &[Stat]) {
67        self.inner.write().retain_only(stats);
68    }
69}
70
71impl From<StatsSet> for ArrayStats {
72    fn from(value: StatsSet) -> Self {
73        Self {
74            inner: Arc::new(RwLock::new(value)),
75        }
76    }
77}
78
79impl From<ArrayStats> for StatsSet {
80    fn from(value: ArrayStats) -> Self {
81        value.inner.read().clone()
82    }
83}
84
85impl StatsSetRef<'_> {
86    pub(crate) fn replace(&self, stats: StatsSet) {
87        *self.array_stats.inner.write() = stats;
88    }
89
90    pub fn set_iter(&self, iter: StatsSetIntoIter) {
91        let mut guard = self.array_stats.inner.write();
92        for (stat, value) in iter {
93            guard.set(stat, value);
94        }
95    }
96
97    pub fn inherit_from(&self, stats: StatsSetRef<'_>) {
98        // Only inherit if the underlying stats are different
99        if !Arc::ptr_eq(&self.array_stats.inner, &stats.array_stats.inner) {
100            stats.with_iter(|iter| self.inherit(iter));
101        }
102    }
103
104    pub fn inherit<'a>(&self, iter: impl Iterator<Item = &'a (Stat, Precision<ScalarValue>)>) {
105        let mut guard = self.array_stats.inner.write();
106        for (stat, value) in iter {
107            if !value.is_exact() {
108                if !guard.get(*stat).is_exact() {
109                    guard.set(*stat, value.clone());
110                }
111            } else {
112                guard.set(*stat, value.clone());
113            }
114        }
115    }
116
117    pub fn with_typed_stats_set<U, F: FnOnce(TypedStatsSetRef) -> U>(&self, apply: F) -> U {
118        apply(
119            self.array_stats
120                .inner
121                .read()
122                .as_typed_ref(self.dyn_array_ref.dtype()),
123        )
124    }
125
126    pub fn with_mut_typed_stats_set<U, F: FnOnce(MutTypedStatsSetRef) -> U>(&self, apply: F) -> U {
127        apply(
128            self.array_stats
129                .inner
130                .write()
131                .as_mut_typed_ref(self.dyn_array_ref.dtype()),
132        )
133    }
134
135    pub fn to_owned(&self) -> StatsSet {
136        self.array_stats.inner.read().clone()
137    }
138
139    /// Returns a clone of the underlying [`ArrayStats`].
140    ///
141    /// Since [`ArrayStats`] uses `Arc` internally, this is a cheap reference-count increment.
142    pub fn to_array_stats(&self) -> ArrayStats {
143        self.array_stats.clone()
144    }
145
146    pub fn with_iter<
147        F: for<'a> FnOnce(&mut dyn Iterator<Item = &'a (Stat, Precision<ScalarValue>)>) -> R,
148        R,
149    >(
150        &self,
151        f: F,
152    ) -> R {
153        let lock = self.array_stats.inner.read();
154        f(&mut lock.iter())
155    }
156
157    /// Returns the value of `stat` by either fetching it from cache if it exists and is [`Precision::Exact`], or falling back to
158    /// computation. The underlying compute kernels will cache the computed stat in the latter case.
159    pub fn compute_stat(&self, stat: Stat, ctx: &mut ExecutionCtx) -> VortexResult<Option<Scalar>> {
160        // If it's already computed and exact, we can return it.
161        if let Precision::Exact(s) = self.get(stat) {
162            return Ok(Some(s));
163        }
164
165        Ok(match stat {
166            Stat::Min => min_max(self.dyn_array_ref, ctx, NumericalAggregateOpts::default())?
167                .map(|MinMaxResult { min, max: _ }| min),
168            Stat::Max => min_max(self.dyn_array_ref, ctx, NumericalAggregateOpts::default())?
169                .map(|MinMaxResult { min: _, max }| max),
170            Stat::Sum => {
171                Stat::Sum
172                    .dtype(self.dyn_array_ref.dtype())
173                    .is_some()
174                    .then(|| {
175                        // Sum is supported for this dtype.
176                        sum(self.dyn_array_ref, ctx)
177                    })
178                    .transpose()?
179            }
180            Stat::NullCount => self.dyn_array_ref.invalid_count(ctx).ok().map(Into::into),
181            Stat::IsConstant => {
182                if self.dyn_array_ref.is_empty() {
183                    None
184                } else {
185                    Some(is_constant(self.dyn_array_ref, ctx)?.into())
186                }
187            }
188            Stat::IsSorted => Some(is_sorted(self.dyn_array_ref, ctx)?.into()),
189            Stat::IsStrictSorted => Some(is_strict_sorted(self.dyn_array_ref, ctx)?.into()),
190            Stat::UncompressedSizeInBytes => Stat::UncompressedSizeInBytes
191                .dtype(self.dyn_array_ref.dtype())
192                .is_some()
193                .then(|| uncompressed_size_in_bytes(self.dyn_array_ref, ctx))
194                .transpose()?
195                .map(|s| s.into()),
196            Stat::NaNCount => {
197                Stat::NaNCount
198                    .dtype(self.dyn_array_ref.dtype())
199                    .is_some()
200                    .then(|| {
201                        // NaNCount is supported for this dtype.
202                        nan_count(self.dyn_array_ref, ctx)
203                    })
204                    .transpose()?
205                    .map(|s| s.into())
206            }
207        })
208    }
209
210    pub fn compute_all(&self, stats: &[Stat], ctx: &mut ExecutionCtx) -> VortexResult<StatsSet> {
211        let mut stats_set = StatsSet::default();
212        for &stat in stats {
213            if let Some(s) = self.compute_stat(stat, ctx)?
214                && let Some(value) = s.into_value()
215            {
216                stats_set.set(stat, Precision::exact(value));
217            }
218        }
219        Ok(stats_set)
220    }
221}
222
223impl StatsSetRef<'_> {
224    pub fn compute_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
225        &self,
226        stat: Stat,
227        ctx: &mut ExecutionCtx,
228    ) -> Option<U> {
229        self.compute_stat(stat, ctx)
230            .inspect_err(|e| tracing::warn!("Failed to compute stat {stat}: {e}"))
231            .ok()
232            .flatten()
233            .map(|s| U::try_from(&s))
234            .transpose()
235            .unwrap_or_else(|err| {
236                vortex_panic!(
237                    err,
238                    "Failed to compute stat {} as {}",
239                    stat,
240                    std::any::type_name::<U>()
241                )
242            })
243    }
244
245    pub fn set(&self, stat: Stat, value: Precision<ScalarValue>) {
246        self.array_stats.set(stat, value);
247    }
248
249    pub fn clear(&self, stat: Stat) {
250        self.array_stats.clear(stat);
251    }
252
253    pub fn compute_min<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
254        &self,
255        ctx: &mut ExecutionCtx,
256    ) -> Option<U> {
257        self.compute_as(Stat::Min, ctx)
258    }
259
260    pub fn compute_max<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
261        &self,
262        ctx: &mut ExecutionCtx,
263    ) -> Option<U> {
264        self.compute_as(Stat::Max, ctx)
265    }
266
267    pub fn compute_is_sorted(&self, ctx: &mut ExecutionCtx) -> Option<bool> {
268        self.compute_as(Stat::IsSorted, ctx)
269    }
270
271    pub fn compute_is_strict_sorted(&self, ctx: &mut ExecutionCtx) -> Option<bool> {
272        self.compute_as(Stat::IsStrictSorted, ctx)
273    }
274
275    pub fn compute_is_constant(&self, ctx: &mut ExecutionCtx) -> Option<bool> {
276        self.compute_as(Stat::IsConstant, ctx)
277    }
278
279    pub fn compute_null_count(&self, ctx: &mut ExecutionCtx) -> Option<usize> {
280        self.compute_as(Stat::NullCount, ctx)
281    }
282
283    pub fn compute_uncompressed_size_in_bytes(&self, ctx: &mut ExecutionCtx) -> Option<usize> {
284        self.compute_as(Stat::UncompressedSizeInBytes, ctx)
285    }
286}
287
288impl StatsProvider for StatsSetRef<'_> {
289    fn get(&self, stat: Stat) -> Precision<Scalar> {
290        self.array_stats
291            .inner
292            .read()
293            .as_typed_ref(self.dyn_array_ref.dtype())
294            .get(stat)
295    }
296
297    fn len(&self) -> usize {
298        self.array_stats.inner.read().len()
299    }
300}