Skip to main content

vortex_array/stats/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Stats as they are stored on arrays.
5
6use std::sync::Arc;
7
8use parking_lot::RwLock;
9use vortex_array::ExecutionCtx;
10use vortex_error::VortexError;
11use vortex_error::VortexResult;
12use vortex_error::vortex_panic;
13
14use super::MutTypedStatsSetRef;
15use super::StatsSet;
16use super::StatsSetIntoIter;
17use super::TypedStatsSetRef;
18use crate::ArrayRef;
19use crate::aggregate_fn::fns::is_constant::is_constant;
20use crate::aggregate_fn::fns::is_sorted::is_sorted;
21use crate::aggregate_fn::fns::is_sorted::is_strict_sorted;
22use crate::aggregate_fn::fns::min_max::MinMaxResult;
23use crate::aggregate_fn::fns::min_max::min_max;
24use crate::aggregate_fn::fns::nan_count::nan_count;
25use crate::aggregate_fn::fns::sum::sum;
26use crate::aggregate_fn::fns::uncompressed_size_in_bytes::uncompressed_size_in_bytes;
27use crate::expr::stats::Precision;
28use crate::expr::stats::Stat;
29use crate::expr::stats::StatsProvider;
30use crate::scalar::Scalar;
31use crate::scalar::ScalarValue;
32
/// A shared [`StatsSet`] stored in an array. Can be shared by copies of the array and can also be mutated in place.
///
/// Cloning an `ArrayStats` only clones the `Arc`, so every clone refers to the same underlying
/// set. The set sits behind a `RwLock`, which is why mutation works through `&self`.
// TODO(adamg): This is a very bad name.
#[derive(Clone, Default, Debug)]
pub struct ArrayStats {
    // Shared, lock-protected storage; all clones of this `ArrayStats` point at the same set.
    inner: Arc<RwLock<StatsSet>>,
}
39
/// Reference to an array's [`StatsSet`]. Can be used to get and mutate the underlying stats.
///
/// Constructed by calling [`ArrayStats::to_ref`].
pub struct StatsSetRef<'a> {
    // The array these stats belong to. Its dtype is used to interpret the stored scalar values,
    // and it is the input handed to the compute kernels when a stat has to be computed.
    dyn_array_ref: &'a ArrayRef,
    // The shared storage being read and written.
    array_stats: &'a ArrayStats,
}
48
49impl ArrayStats {
50    pub fn to_ref<'a>(&'a self, array: &'a ArrayRef) -> StatsSetRef<'a> {
51        StatsSetRef {
52            dyn_array_ref: array,
53            array_stats: self,
54        }
55    }
56
57    pub fn set(&self, stat: Stat, value: Precision<ScalarValue>) {
58        self.inner.write().set(stat, value);
59    }
60
61    pub fn clear(&self, stat: Stat) {
62        self.inner.write().clear(stat);
63    }
64
65    pub fn retain(&self, stats: &[Stat]) {
66        self.inner.write().retain_only(stats);
67    }
68}
69
70impl From<StatsSet> for ArrayStats {
71    fn from(value: StatsSet) -> Self {
72        Self {
73            inner: Arc::new(RwLock::new(value)),
74        }
75    }
76}
77
78impl From<ArrayStats> for StatsSet {
79    fn from(value: ArrayStats) -> Self {
80        value.inner.read().clone()
81    }
82}
83
84impl StatsSetRef<'_> {
85    pub(crate) fn replace(&self, stats: StatsSet) {
86        *self.array_stats.inner.write() = stats;
87    }
88
89    pub fn set_iter(&self, iter: StatsSetIntoIter) {
90        let mut guard = self.array_stats.inner.write();
91        for (stat, value) in iter {
92            guard.set(stat, value);
93        }
94    }
95
96    pub fn inherit_from(&self, stats: StatsSetRef<'_>) {
97        // Only inherit if the underlying stats are different
98        if !Arc::ptr_eq(&self.array_stats.inner, &stats.array_stats.inner) {
99            stats.with_iter(|iter| self.inherit(iter));
100        }
101    }
102
103    pub fn inherit<'a>(&self, iter: impl Iterator<Item = &'a (Stat, Precision<ScalarValue>)>) {
104        let mut guard = self.array_stats.inner.write();
105        for (stat, value) in iter {
106            if !value.is_exact() {
107                if !guard.get(*stat).is_exact() {
108                    guard.set(*stat, value.clone());
109                }
110            } else {
111                guard.set(*stat, value.clone());
112            }
113        }
114    }
115
116    pub fn with_typed_stats_set<U, F: FnOnce(TypedStatsSetRef) -> U>(&self, apply: F) -> U {
117        apply(
118            self.array_stats
119                .inner
120                .read()
121                .as_typed_ref(self.dyn_array_ref.dtype()),
122        )
123    }
124
125    pub fn with_mut_typed_stats_set<U, F: FnOnce(MutTypedStatsSetRef) -> U>(&self, apply: F) -> U {
126        apply(
127            self.array_stats
128                .inner
129                .write()
130                .as_mut_typed_ref(self.dyn_array_ref.dtype()),
131        )
132    }
133
134    pub fn to_owned(&self) -> StatsSet {
135        self.array_stats.inner.read().clone()
136    }
137
138    /// Returns a clone of the underlying [`ArrayStats`].
139    ///
140    /// Since [`ArrayStats`] uses `Arc` internally, this is a cheap reference-count increment.
141    pub fn to_array_stats(&self) -> ArrayStats {
142        self.array_stats.clone()
143    }
144
145    pub fn with_iter<
146        F: for<'a> FnOnce(&mut dyn Iterator<Item = &'a (Stat, Precision<ScalarValue>)>) -> R,
147        R,
148    >(
149        &self,
150        f: F,
151    ) -> R {
152        let lock = self.array_stats.inner.read();
153        f(&mut lock.iter())
154    }
155
156    /// Returns the value of `stat` by either fetching it from cache if it exists and is [`Precision::Exact`], or falling back to
157    /// computation. The underlying compute kernels will cache the computed stat in the latter case.
158    pub fn compute_stat(&self, stat: Stat, ctx: &mut ExecutionCtx) -> VortexResult<Option<Scalar>> {
159        // If it's already computed and exact, we can return it.
160        if let Precision::Exact(s) = self.get(stat) {
161            return Ok(Some(s));
162        }
163
164        Ok(match stat {
165            Stat::Min => min_max(self.dyn_array_ref, ctx)?.map(|MinMaxResult { min, max: _ }| min),
166            Stat::Max => min_max(self.dyn_array_ref, ctx)?.map(|MinMaxResult { min: _, max }| max),
167            Stat::Sum => {
168                Stat::Sum
169                    .dtype(self.dyn_array_ref.dtype())
170                    .is_some()
171                    .then(|| {
172                        // Sum is supported for this dtype.
173                        sum(self.dyn_array_ref, ctx)
174                    })
175                    .transpose()?
176            }
177            Stat::NullCount => self.dyn_array_ref.invalid_count(ctx).ok().map(Into::into),
178            Stat::IsConstant => {
179                if self.dyn_array_ref.is_empty() {
180                    None
181                } else {
182                    Some(is_constant(self.dyn_array_ref, ctx)?.into())
183                }
184            }
185            Stat::IsSorted => Some(is_sorted(self.dyn_array_ref, ctx)?.into()),
186            Stat::IsStrictSorted => Some(is_strict_sorted(self.dyn_array_ref, ctx)?.into()),
187            Stat::UncompressedSizeInBytes => Stat::UncompressedSizeInBytes
188                .dtype(self.dyn_array_ref.dtype())
189                .is_some()
190                .then(|| uncompressed_size_in_bytes(self.dyn_array_ref, ctx))
191                .transpose()?
192                .map(|s| s.into()),
193            Stat::NaNCount => {
194                Stat::NaNCount
195                    .dtype(self.dyn_array_ref.dtype())
196                    .is_some()
197                    .then(|| {
198                        // NaNCount is supported for this dtype.
199                        nan_count(self.dyn_array_ref, ctx)
200                    })
201                    .transpose()?
202                    .map(|s| s.into())
203            }
204        })
205    }
206
207    pub fn compute_all(&self, stats: &[Stat], ctx: &mut ExecutionCtx) -> VortexResult<StatsSet> {
208        let mut stats_set = StatsSet::default();
209        for &stat in stats {
210            if let Some(s) = self.compute_stat(stat, ctx)?
211                && let Some(value) = s.into_value()
212            {
213                stats_set.set(stat, Precision::exact(value));
214            }
215        }
216        Ok(stats_set)
217    }
218}
219
220impl StatsSetRef<'_> {
221    pub fn compute_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
222        &self,
223        stat: Stat,
224        ctx: &mut ExecutionCtx,
225    ) -> Option<U> {
226        self.compute_stat(stat, ctx)
227            .inspect_err(|e| tracing::warn!("Failed to compute stat {stat}: {e}"))
228            .ok()
229            .flatten()
230            .map(|s| U::try_from(&s))
231            .transpose()
232            .unwrap_or_else(|err| {
233                vortex_panic!(
234                    err,
235                    "Failed to compute stat {} as {}",
236                    stat,
237                    std::any::type_name::<U>()
238                )
239            })
240    }
241
242    pub fn set(&self, stat: Stat, value: Precision<ScalarValue>) {
243        self.array_stats.set(stat, value);
244    }
245
246    pub fn clear(&self, stat: Stat) {
247        self.array_stats.clear(stat);
248    }
249
250    pub fn compute_min<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
251        &self,
252        ctx: &mut ExecutionCtx,
253    ) -> Option<U> {
254        self.compute_as(Stat::Min, ctx)
255    }
256
257    pub fn compute_max<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
258        &self,
259        ctx: &mut ExecutionCtx,
260    ) -> Option<U> {
261        self.compute_as(Stat::Max, ctx)
262    }
263
264    pub fn compute_is_sorted(&self, ctx: &mut ExecutionCtx) -> Option<bool> {
265        self.compute_as(Stat::IsSorted, ctx)
266    }
267
268    pub fn compute_is_strict_sorted(&self, ctx: &mut ExecutionCtx) -> Option<bool> {
269        self.compute_as(Stat::IsStrictSorted, ctx)
270    }
271
272    pub fn compute_is_constant(&self, ctx: &mut ExecutionCtx) -> Option<bool> {
273        self.compute_as(Stat::IsConstant, ctx)
274    }
275
276    pub fn compute_null_count(&self, ctx: &mut ExecutionCtx) -> Option<usize> {
277        self.compute_as(Stat::NullCount, ctx)
278    }
279
280    pub fn compute_uncompressed_size_in_bytes(&self, ctx: &mut ExecutionCtx) -> Option<usize> {
281        self.compute_as(Stat::UncompressedSizeInBytes, ctx)
282    }
283}
284
285impl StatsProvider for StatsSetRef<'_> {
286    fn get(&self, stat: Stat) -> Precision<Scalar> {
287        self.array_stats
288            .inner
289            .read()
290            .as_typed_ref(self.dyn_array_ref.dtype())
291            .get(stat)
292    }
293
294    fn len(&self) -> usize {
295        self.array_stats.inner.read().len()
296    }
297}