vortex_array/aggregate_fn/
erased.rs1use std::fmt::Debug;
7use std::fmt::Display;
8use std::fmt::Formatter;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::Arc;
12
13use vortex_error::VortexExpect;
14use vortex_error::VortexResult;
15use vortex_utils::debug_with::DebugWith;
16
17use crate::aggregate_fn::AccumulatorRef;
18use crate::aggregate_fn::AggregateFnId;
19use crate::aggregate_fn::AggregateFnVTable;
20use crate::aggregate_fn::GroupedAccumulatorRef;
21use crate::aggregate_fn::options::AggregateFnOptions;
22use crate::aggregate_fn::typed::AggregateFnInner;
23use crate::aggregate_fn::typed::DynAggregateFn;
24use crate::dtype::DType;
25
26#[derive(Clone)]
34pub struct AggregateFnRef(pub(super) Arc<dyn DynAggregateFn>);
35
36impl AggregateFnRef {
37 pub fn id(&self) -> AggregateFnId {
39 self.0.id()
40 }
41
42 pub fn is<V: AggregateFnVTable>(&self) -> bool {
44 self.0.as_any().is::<AggregateFnInner<V>>()
45 }
46
47 pub fn as_opt<V: AggregateFnVTable>(&self) -> Option<&V::Options> {
49 self.downcast_inner::<V>().map(|inner| &inner.options)
50 }
51
52 pub fn vtable_ref<V: AggregateFnVTable>(&self) -> Option<&V> {
54 self.downcast_inner::<V>().map(|inner| &inner.vtable)
55 }
56
57 fn downcast_inner<V: AggregateFnVTable>(&self) -> Option<&AggregateFnInner<V>> {
59 self.0.as_any().downcast_ref::<AggregateFnInner<V>>()
60 }
61
62 pub fn as_<V: AggregateFnVTable>(&self) -> &V::Options {
68 self.as_opt::<V>()
69 .vortex_expect("Aggregate function options type mismatch")
70 }
71
72 pub fn options(&self) -> AggregateFnOptions<'_> {
74 AggregateFnOptions { inner: &*self.0 }
75 }
76
77 pub fn return_dtype(&self, input_dtype: &DType) -> Option<DType> {
81 self.0.return_dtype(input_dtype)
82 }
83
84 pub fn state_dtype(&self, input_dtype: &DType) -> Option<DType> {
88 self.0.state_dtype(input_dtype)
89 }
90
91 pub fn accumulator(&self, input_dtype: &DType) -> VortexResult<AccumulatorRef> {
93 self.0.accumulator(input_dtype)
94 }
95
96 pub fn accumulator_grouped(&self, input_dtype: &DType) -> VortexResult<GroupedAccumulatorRef> {
98 self.0.accumulator_grouped(input_dtype)
99 }
100}
101
102impl Debug for AggregateFnRef {
103 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("AggregateFnRef")
105 .field("vtable", &self.0.id())
106 .field("options", &DebugWith(|fmt| self.0.options_debug(fmt)))
107 .finish()
108 }
109}
110
111impl Display for AggregateFnRef {
112 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
113 write!(f, "{}(", self.0.id())?;
114 self.0.options_display(f)?;
115 write!(f, ")")
116 }
117}
118
119impl PartialEq for AggregateFnRef {
120 fn eq(&self, other: &Self) -> bool {
121 self.0.id() == other.0.id() && self.0.options_eq(other.0.options_any())
122 }
123}
124impl Eq for AggregateFnRef {}
125
126impl Hash for AggregateFnRef {
127 fn hash<H: Hasher>(&self, state: &mut H) {
128 self.0.id().hash(state);
129 self.0.options_hash(state);
130 }
131}