vortex_array/aggregate_fn/
session.rs1use std::any::Any;
5use std::sync::Arc;
6
7use vortex_session::Ref;
8use vortex_session::SessionExt;
9use vortex_session::SessionVar;
10
11use crate::aggregate_fn::AggregateFnId;
12use crate::aggregate_fn::AggregateFnPluginRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::fns::all_nan::AllNan;
15use crate::aggregate_fn::fns::all_non_distinct::AllNonDistinct;
16use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
17use crate::aggregate_fn::fns::all_non_null::AllNonNull;
18use crate::aggregate_fn::fns::all_null::AllNull;
19use crate::aggregate_fn::fns::bounded_max::BoundedMax;
20use crate::aggregate_fn::fns::bounded_min::BoundedMin;
21use crate::aggregate_fn::fns::count::Count;
22use crate::aggregate_fn::fns::count::CountGroupedKernel;
23use crate::aggregate_fn::fns::first::First;
24use crate::aggregate_fn::fns::is_constant::IsConstant;
25use crate::aggregate_fn::fns::is_sorted::IsSorted;
26use crate::aggregate_fn::fns::last::Last;
27use crate::aggregate_fn::fns::max::Max;
28use crate::aggregate_fn::fns::min::Min;
29use crate::aggregate_fn::fns::min_max::MinMax;
30use crate::aggregate_fn::fns::nan_count::NanCount;
31use crate::aggregate_fn::fns::null_count::NullCount;
32use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumEncodingKernel;
33use crate::aggregate_fn::fns::sum::Sum;
34use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes;
35use crate::aggregate_fn::kernels::DynAggregateKernel;
36use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
37use crate::arc_swap_map::ArcSwapMap;
38use crate::array::ArrayId;
39use crate::array::VTable;
40use crate::arrays::Chunked;
41use crate::arrays::Dict;
42use crate::arrays::Primitive;
43use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate;
44use crate::arrays::dict::compute::is_constant::DictIsConstantKernel;
45use crate::arrays::dict::compute::is_sorted::DictIsSortedKernel;
46use crate::arrays::dict::compute::min_max::DictMinMaxKernel;
47
48#[derive(Debug)]
54pub struct AggregateFnSession {
55 registry: ArcSwapMap<AggregateFnId, AggregateFnPluginRef>,
56
57 kernels: ArcSwapMap<AggregateKernelKey, &'static dyn DynAggregateKernel>,
58 grouped_kernels: ArcSwapMap<AggregateFnId, &'static dyn DynGroupedAggregateKernel>,
59 grouped_encoding_kernels:
60 ArcSwapMap<GroupedEncodingKernelKey, &'static dyn DynGroupedAggregateKernel>,
61}
62
63impl SessionVar for AggregateFnSession {
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn as_any_mut(&mut self) -> &mut dyn Any {
69 self
70 }
71}
72
73type AggregateKernelKey = (ArrayId, Option<AggregateFnId>);
74type GroupedEncodingKernelKey = (ArrayId, AggregateFnId);
75
76impl Default for AggregateFnSession {
77 fn default() -> Self {
78 let this = Self {
79 registry: ArcSwapMap::default(),
80 kernels: ArcSwapMap::default(),
81 grouped_kernels: ArcSwapMap::default(),
82 grouped_encoding_kernels: ArcSwapMap::default(),
83 };
84
85 this.register(AllNonDistinct);
87 this.register(AllNonNan);
88 this.register(AllNonNull);
89 this.register(AllNan);
90 this.register(AllNull);
91 this.register(BoundedMax);
92 this.register(BoundedMin);
93 this.register(First);
94 this.register(IsConstant);
95 this.register(IsSorted);
96 this.register(Last);
97 this.register(Max);
98 this.register(Min);
99 this.register(MinMax);
100 this.register(NanCount);
101 this.register(NullCount);
102 this.register(Sum);
103 this.register(UncompressedSizeInBytes);
104
105 this.register_aggregate_kernel(Chunked.id(), None::<AggregateFnId>, &ChunkedArrayAggregate);
107 this.register_aggregate_kernel(Dict.id(), Some(MinMax.id()), &DictMinMaxKernel);
108 this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel);
109 this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel);
110
111 this.register_grouped_kernel(Count.id(), &CountGroupedKernel);
113 this.register_grouped_encoding_kernel(
114 Primitive.id(),
115 Sum.id(),
116 &PrimitiveGroupedSumEncodingKernel,
117 );
118
119 this
120 }
121}
122
123impl AggregateFnSession {
124 pub fn find_plugin(&self, id: &AggregateFnId) -> Option<AggregateFnPluginRef> {
126 self.registry.get(id)
127 }
128
129 pub fn register<V: AggregateFnVTable>(&self, vtable: V) {
132 let id = vtable.id();
133 let pluginref = Arc::new(vtable) as AggregateFnPluginRef;
134 self.registry.insert(id, pluginref);
135 }
136
137 pub fn find_aggregate_kernel(
142 &self,
143 array_id: impl Into<ArrayId>,
144 agg_fn_id: impl Into<AggregateFnId>,
145 ) -> Option<&'static dyn DynAggregateKernel> {
146 let id = array_id.into();
147 let fn_id = agg_fn_id.into();
148 self.kernels.read(|kernels| {
149 kernels
150 .get(&(id, Some(fn_id)))
151 .or_else(|| kernels.get(&(id, None)))
152 .copied()
153 })
154 }
155
156 pub fn register_aggregate_kernel(
162 &self,
163 array_id: impl Into<ArrayId>,
164 agg_fn_id: Option<impl Into<AggregateFnId>>,
165 kernel: &'static dyn DynAggregateKernel,
166 ) {
167 let id = (array_id.into(), agg_fn_id.map(|id| id.into()));
168 self.kernels.insert(id, kernel);
169 }
170
171 pub fn find_grouped_kernel(
176 &self,
177 agg_fn_id: impl Into<AggregateFnId>,
178 ) -> Option<&'static dyn DynGroupedAggregateKernel> {
179 let fn_id = agg_fn_id.into();
180 self.grouped_kernels
181 .read(|kernels| kernels.get(&fn_id).copied())
182 }
183
184 pub fn register_grouped_kernel(
186 &self,
187 agg_fn_id: impl Into<AggregateFnId>,
188 kernel: &'static dyn DynGroupedAggregateKernel,
189 ) {
190 let fn_id = agg_fn_id.into();
191 self.grouped_kernels.insert(fn_id, kernel)
192 }
193
194 pub fn find_grouped_encoding_kernel(
199 &self,
200 array_id: impl Into<ArrayId>,
201 agg_fn_id: impl Into<AggregateFnId>,
202 ) -> Option<&'static dyn DynGroupedAggregateKernel> {
203 let id = array_id.into();
204 let fn_id = agg_fn_id.into();
205 self.grouped_encoding_kernels
206 .read(|kernels| kernels.get(&(id, fn_id)).copied())
207 }
208
209 pub fn register_grouped_encoding_kernel(
211 &self,
212 array_id: impl Into<ArrayId>,
213 agg_fn_id: impl Into<AggregateFnId>,
214 kernel: &'static dyn DynGroupedAggregateKernel,
215 ) {
216 let id = array_id.into();
217 let fn_id = agg_fn_id.into();
218 self.grouped_encoding_kernels.insert((id, fn_id), kernel)
219 }
220}
221
222pub trait AggregateFnSessionExt: SessionExt {
224 fn aggregate_fns(&self) -> Ref<'_, AggregateFnSession> {
226 self.get::<AggregateFnSession>()
227 }
228}
229impl<S: SessionExt> AggregateFnSessionExt for S {}