vortex_array/stats/
array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Stats as they are stored on arrays.
5
6use std::sync::Arc;
7
8use parking_lot::RwLock;
9use vortex_error::{VortexError, VortexResult, vortex_panic};
10use vortex_scalar::{Scalar, ScalarValue};
11
12use super::{Precision, Stat, StatType, StatsProvider, StatsSet, StatsSetIntoIter};
13use crate::Array;
14use crate::compute::{
15    MinMaxResult, is_constant, is_sorted, is_strict_sorted, min_max, nan_count, sum,
16};
17
18/// A shared [`StatsSet`] stored in an array. Can be shared by copies of the array and can also be mutated in place.
19// TODO(adamg): This is a very bad name.
20#[derive(Clone, Default, Debug)]
21pub struct ArrayStats {
22    inner: Arc<RwLock<StatsSet>>,
23}
24
25/// Reference to an array's [`StatsSet`]. Can be used to get and mutate the underlying stats.
26///
27/// Constructed by calling [`ArrayStats::to_ref`].
28pub struct StatsSetRef<'a> {
29    // We need to reference back to the array
30    dyn_array_ref: &'a dyn Array,
31    array_stats: &'a ArrayStats,
32}
33
34impl ArrayStats {
35    pub fn to_ref<'a>(&'a self, array: &'a dyn Array) -> StatsSetRef<'a> {
36        StatsSetRef {
37            dyn_array_ref: array,
38            array_stats: self,
39        }
40    }
41
42    pub fn set(&self, stat: Stat, value: Precision<ScalarValue>) {
43        self.inner.write().set(stat, value);
44    }
45
46    pub fn clear(&self, stat: Stat) {
47        self.inner.write().clear(stat);
48    }
49
50    pub fn retain(&self, stats: &[Stat]) {
51        self.inner.write().retain_only(stats);
52    }
53}
54
55impl From<StatsSet> for ArrayStats {
56    fn from(value: StatsSet) -> Self {
57        Self {
58            inner: Arc::new(RwLock::new(value)),
59        }
60    }
61}
62
63impl From<ArrayStats> for StatsSet {
64    fn from(value: ArrayStats) -> Self {
65        value.inner.read().clone()
66    }
67}
68
69impl StatsSetRef<'_> {
70    pub fn set_iter(&self, iter: StatsSetIntoIter) {
71        let mut guard = self.array_stats.inner.write();
72        for (stat, value) in iter {
73            guard.set(stat, value);
74        }
75    }
76
77    pub fn inherit_from(&self, stats: StatsSetRef<'_>) {
78        stats.with_iter(|iter| self.inherit(iter));
79    }
80
81    pub fn inherit<'a>(&self, iter: impl Iterator<Item = &'a (Stat, Precision<ScalarValue>)>) {
82        // TODO(ngates): depending on statistic, this should choose the more precise one
83        let mut guard = self.array_stats.inner.write();
84        for (stat, value) in iter {
85            guard.set(*stat, value.clone());
86        }
87    }
88
89    pub fn replace(&self, stats: StatsSet) {
90        *self.array_stats.inner.write() = stats;
91    }
92
93    pub fn to_owned(&self) -> StatsSet {
94        self.array_stats.inner.read().clone()
95    }
96
97    pub fn with_iter<
98        F: for<'a> FnOnce(&mut dyn Iterator<Item = &'a (Stat, Precision<ScalarValue>)>) -> R,
99        R,
100    >(
101        &self,
102        f: F,
103    ) -> R {
104        let lock = self.array_stats.inner.read();
105        f(&mut lock.iter())
106    }
107
108    pub fn compute_stat(&self, stat: Stat) -> VortexResult<Option<Scalar>> {
109        // If it's already computed and exact, we can return it.
110        if let Some(Precision::Exact(s)) = self.get(stat) {
111            return Ok(Some(s));
112        }
113
114        Ok(match stat {
115            Stat::Min => min_max(self.dyn_array_ref)?.map(|MinMaxResult { min, max: _ }| min),
116            Stat::Max => min_max(self.dyn_array_ref)?.map(|MinMaxResult { min: _, max }| max),
117            Stat::Sum => {
118                Stat::Sum
119                    .dtype(self.dyn_array_ref.dtype())
120                    .is_some()
121                    .then(|| {
122                        // Sum is supported for this dtype.
123                        sum(self.dyn_array_ref)
124                    })
125                    .transpose()?
126            }
127            Stat::NullCount => Some(self.dyn_array_ref.invalid_count()?.into()),
128            Stat::IsConstant => {
129                if self.dyn_array_ref.is_empty() {
130                    None
131                } else {
132                    is_constant(self.dyn_array_ref)?.map(|v| v.into())
133                }
134            }
135            Stat::IsSorted => Some(is_sorted(self.dyn_array_ref)?.into()),
136            Stat::IsStrictSorted => Some(is_strict_sorted(self.dyn_array_ref)?.into()),
137            Stat::UncompressedSizeInBytes => {
138                let nbytes = self.dyn_array_ref.to_canonical()?.as_ref().nbytes();
139                self.set(stat, Precision::exact(nbytes));
140                Some(nbytes.into())
141            }
142            Stat::NaNCount => {
143                Stat::NaNCount
144                    .dtype(self.dyn_array_ref.dtype())
145                    .is_some()
146                    .then(|| {
147                        // NaNCount is supported for this dtype.
148                        nan_count(self.dyn_array_ref)
149                    })
150                    .transpose()?
151                    .map(|s| s.into())
152            }
153        })
154    }
155
156    pub fn compute_all(&self, stats: &[Stat]) -> VortexResult<StatsSet> {
157        let mut stats_set = StatsSet::default();
158        for &stat in stats {
159            if let Some(s) = self.compute_stat(stat)? {
160                stats_set.set(stat, Precision::exact(s.into_value()))
161            }
162        }
163        Ok(stats_set)
164    }
165}
166
167impl StatsSetRef<'_> {
168    pub fn get_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
169        &self,
170        stat: Stat,
171    ) -> Option<Precision<U>> {
172        self.get(stat).map(|v| {
173            v.map(|v| {
174                U::try_from(&v).unwrap_or_else(|err| {
175                    vortex_panic!(
176                        err,
177                        "Failed to get stat {} as {}",
178                        stat,
179                        std::any::type_name::<U>()
180                    )
181                })
182            })
183        })
184    }
185
186    pub fn get_as_bound<S, U>(&self) -> Option<S::Bound>
187    where
188        S: StatType<U>,
189        U: for<'a> TryFrom<&'a Scalar, Error = VortexError>,
190    {
191        self.get_as::<U>(S::STAT).map(|v| v.bound::<S>())
192    }
193
194    pub fn compute_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
195        &self,
196        stat: Stat,
197    ) -> Option<U> {
198        self.compute_stat(stat)
199            .inspect_err(|e| log::warn!("Failed to compute stat {stat}: {e}"))
200            .ok()
201            .flatten()
202            .map(|s| U::try_from(&s))
203            .transpose()
204            .unwrap_or_else(|err| {
205                vortex_panic!(
206                    err,
207                    "Failed to compute stat {} as {}",
208                    stat,
209                    std::any::type_name::<U>()
210                )
211            })
212    }
213
214    pub fn set(&self, stat: Stat, value: Precision<ScalarValue>) {
215        self.array_stats.set(stat, value);
216    }
217
218    pub fn clear(&self, stat: Stat) {
219        self.array_stats.clear(stat);
220    }
221
222    pub fn retain(&self, stats: &[Stat]) {
223        self.array_stats.retain(stats);
224    }
225
226    pub fn compute_min<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option<U> {
227        self.compute_as(Stat::Min)
228    }
229
230    pub fn compute_max<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option<U> {
231        self.compute_as(Stat::Max)
232    }
233
234    pub fn compute_is_sorted(&self) -> Option<bool> {
235        self.compute_as(Stat::IsSorted)
236    }
237
238    pub fn compute_is_strict_sorted(&self) -> Option<bool> {
239        self.compute_as(Stat::IsStrictSorted)
240    }
241
242    pub fn compute_is_constant(&self) -> Option<bool> {
243        self.compute_as(Stat::IsConstant)
244    }
245
246    pub fn compute_null_count(&self) -> Option<usize> {
247        self.compute_as(Stat::NullCount)
248    }
249
250    pub fn compute_uncompressed_size_in_bytes(&self) -> Option<usize> {
251        self.compute_as(Stat::UncompressedSizeInBytes)
252    }
253}
254
255impl StatsProvider for StatsSetRef<'_> {
256    fn get(&self, stat: Stat) -> Option<Precision<Scalar>> {
257        self.array_stats
258            .inner
259            .read()
260            .as_typed_ref(self.dyn_array_ref.dtype())
261            .get(stat)
262    }
263
264    fn len(&self) -> usize {
265        self.array_stats.inner.read().len()
266    }
267}