Skip to main content

vortex_array/aggregate_fn/
erased.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Type-erased aggregate function ([`AggregateFnRef`]).
5
6use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::Arc;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_session::VortexSession;
16use vortex_utils::debug_with::DebugWith;
17
18use crate::aggregate_fn::AccumulatorRef;
19use crate::aggregate_fn::AggregateFnId;
20use crate::aggregate_fn::AggregateFnVTable;
21use crate::aggregate_fn::GroupedAccumulatorRef;
22use crate::aggregate_fn::options::AggregateFnOptions;
23use crate::aggregate_fn::typed::AggregateFnInner;
24use crate::aggregate_fn::typed::DynAggregateFn;
25use crate::dtype::DType;
26
/// A type-erased aggregate function: an [`AggregateFnVTable`] paired with its bound options.
///
/// The vtable and its options live behind an `Arc<dyn DynAggregateFn>`, so aggregate functions
/// of different concrete types can be stored side by side and dispatched dynamically.
/// Cloning an `AggregateFnRef` clones the inner `Arc`, so clones share the same underlying
/// function rather than copying it.
///
/// Equality and hashing are defined in terms of the function ID plus its options (see the
/// `PartialEq` and `Hash` impls for this type).
///
/// Use [`super::AggregateFn::new()`] to construct a typed function, and
/// [`super::AggregateFn::erased()`] to obtain an [`AggregateFnRef`] from it.
#[derive(Clone)]
pub struct AggregateFnRef(pub(super) Arc<dyn DynAggregateFn>);
36
37impl AggregateFnRef {
38    /// Returns the ID of this aggregate function.
39    pub fn id(&self) -> AggregateFnId {
40        self.0.id()
41    }
42
43    /// Returns whether the aggregate function is of the given vtable type.
44    pub fn is<V: AggregateFnVTable>(&self) -> bool {
45        self.0.as_any().is::<AggregateFnInner<V>>()
46    }
47
48    /// Returns the typed options for this aggregate function if it matches the given vtable type.
49    pub fn as_opt<V: AggregateFnVTable>(&self) -> Option<&V::Options> {
50        self.downcast_inner::<V>().map(|inner| &inner.options)
51    }
52
53    /// Returns a reference to the typed vtable if it matches the given vtable type.
54    pub fn vtable_ref<V: AggregateFnVTable>(&self) -> Option<&V> {
55        self.downcast_inner::<V>().map(|inner| &inner.vtable)
56    }
57
58    /// Downcast the inner to the concrete `AggregateFnInner<V>`.
59    fn downcast_inner<V: AggregateFnVTable>(&self) -> Option<&AggregateFnInner<V>> {
60        self.0.as_any().downcast_ref::<AggregateFnInner<V>>()
61    }
62
63    /// Returns the typed options for this aggregate function if it matches the given vtable type.
64    ///
65    /// # Panics
66    ///
67    /// Panics if the vtable type does not match.
68    pub fn as_<V: AggregateFnVTable>(&self) -> &V::Options {
69        self.as_opt::<V>()
70            .vortex_expect("Aggregate function options type mismatch")
71    }
72
73    /// The type-erased options for this aggregate function.
74    pub fn options(&self) -> AggregateFnOptions<'_> {
75        AggregateFnOptions { inner: &*self.0 }
76    }
77
78    /// Compute the return [`DType`] per group given the input element type.
79    pub fn return_dtype(&self, input_dtype: &DType) -> VortexResult<DType> {
80        self.0.return_dtype(input_dtype)
81    }
82
83    /// DType of the intermediate accumulator state.
84    pub fn state_dtype(&self, input_dtype: &DType) -> VortexResult<DType> {
85        self.0.state_dtype(input_dtype)
86    }
87
88    /// Create an accumulator for streaming aggregation.
89    pub fn accumulator(
90        &self,
91        input_dtype: &DType,
92        session: &VortexSession,
93    ) -> VortexResult<AccumulatorRef> {
94        self.0.accumulator(input_dtype, session)
95    }
96
97    /// Create a grouped accumulator for grouped streaming aggregation.
98    pub fn accumulator_grouped(
99        &self,
100        input_dtype: &DType,
101        session: &VortexSession,
102    ) -> VortexResult<GroupedAccumulatorRef> {
103        self.0.accumulator_grouped(input_dtype, session)
104    }
105}
106
107impl Debug for AggregateFnRef {
108    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
109        f.debug_struct("AggregateFnRef")
110            .field("vtable", &self.0.id())
111            .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
112            .finish()
113    }
114}
115
116impl Display for AggregateFnRef {
117    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
118        write!(f, "{}(", self.0.id())?;
119        self.0.options_display(f)?;
120        write!(f, ")")
121    }
122}
123
124impl PartialEq for AggregateFnRef {
125    fn eq(&self, other: &Self) -> bool {
126        self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
127    }
128}
129impl Eq for AggregateFnRef {}
130
131impl Hash for AggregateFnRef {
132    fn hash<H: Hasher>(&self, state: &mut H) {
133        self.0.id().hash(state);
134        self.0.options_hash(state);
135    }
136}