Skip to main content

vortex_array/aggregate_fn/
session.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::sync::Arc;
6
7use vortex_session::Ref;
8use vortex_session::SessionExt;
9use vortex_session::SessionVar;
10
11use crate::aggregate_fn::AggregateFnId;
12use crate::aggregate_fn::AggregateFnPluginRef;
13use crate::aggregate_fn::AggregateFnVTable;
14use crate::aggregate_fn::fns::all_nan::AllNan;
15use crate::aggregate_fn::fns::all_non_distinct::AllNonDistinct;
16use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
17use crate::aggregate_fn::fns::all_non_null::AllNonNull;
18use crate::aggregate_fn::fns::all_null::AllNull;
19use crate::aggregate_fn::fns::bounded_max::BoundedMax;
20use crate::aggregate_fn::fns::bounded_min::BoundedMin;
21use crate::aggregate_fn::fns::count::Count;
22use crate::aggregate_fn::fns::count::CountGroupedKernel;
23use crate::aggregate_fn::fns::first::First;
24use crate::aggregate_fn::fns::is_constant::IsConstant;
25use crate::aggregate_fn::fns::is_sorted::IsSorted;
26use crate::aggregate_fn::fns::last::Last;
27use crate::aggregate_fn::fns::max::Max;
28use crate::aggregate_fn::fns::min::Min;
29use crate::aggregate_fn::fns::min_max::MinMax;
30use crate::aggregate_fn::fns::nan_count::NanCount;
31use crate::aggregate_fn::fns::null_count::NullCount;
32use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumEncodingKernel;
33use crate::aggregate_fn::fns::sum::Sum;
34use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes;
35use crate::aggregate_fn::kernels::DynAggregateKernel;
36use crate::aggregate_fn::kernels::DynGroupedAggregateKernel;
37use crate::arc_swap_map::ArcSwapMap;
38use crate::array::ArrayId;
39use crate::array::VTable;
40use crate::arrays::Chunked;
41use crate::arrays::Dict;
42use crate::arrays::Primitive;
43use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate;
44use crate::arrays::dict::compute::is_constant::DictIsConstantKernel;
45use crate::arrays::dict::compute::is_sorted::DictIsSortedKernel;
46use crate::arrays::dict::compute::min_max::DictMinMaxKernel;
47
48/// Session state for aggregate functions and encoding-specific aggregate kernels.
49///
50/// The default session registers the built-in aggregate functions and kernels. Additional
51/// aggregate functions and kernels may be registered by extensions when they are added to a
52/// [`VortexSession`](vortex_session::VortexSession).
53#[derive(Debug)]
54pub struct AggregateFnSession {
55    registry: ArcSwapMap<AggregateFnId, AggregateFnPluginRef>,
56
57    kernels: ArcSwapMap<AggregateKernelKey, &'static dyn DynAggregateKernel>,
58    grouped_kernels: ArcSwapMap<AggregateFnId, &'static dyn DynGroupedAggregateKernel>,
59    grouped_encoding_kernels:
60        ArcSwapMap<GroupedEncodingKernelKey, &'static dyn DynGroupedAggregateKernel>,
61}
62
63impl SessionVar for AggregateFnSession {
64    fn as_any(&self) -> &dyn Any {
65        self
66    }
67
68    fn as_any_mut(&mut self) -> &mut dyn Any {
69        self
70    }
71}
72
73type AggregateKernelKey = (ArrayId, Option<AggregateFnId>);
74type GroupedEncodingKernelKey = (ArrayId, AggregateFnId);
75
76impl Default for AggregateFnSession {
77    fn default() -> Self {
78        let this = Self {
79            registry: ArcSwapMap::default(),
80            kernels: ArcSwapMap::default(),
81            grouped_kernels: ArcSwapMap::default(),
82            grouped_encoding_kernels: ArcSwapMap::default(),
83        };
84
85        // Register the built-in aggregate functions
86        this.register(AllNonDistinct);
87        this.register(AllNonNan);
88        this.register(AllNonNull);
89        this.register(AllNan);
90        this.register(AllNull);
91        this.register(BoundedMax);
92        this.register(BoundedMin);
93        this.register(First);
94        this.register(IsConstant);
95        this.register(IsSorted);
96        this.register(Last);
97        this.register(Max);
98        this.register(Min);
99        this.register(MinMax);
100        this.register(NanCount);
101        this.register(NullCount);
102        this.register(Sum);
103        this.register(UncompressedSizeInBytes);
104
105        // Register the built-in aggregate kernels.
106        this.register_aggregate_kernel(Chunked.id(), None::<AggregateFnId>, &ChunkedArrayAggregate);
107        this.register_aggregate_kernel(Dict.id(), Some(MinMax.id()), &DictMinMaxKernel);
108        this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel);
109        this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel);
110
111        // Register the built-in grouped aggregate kernels.
112        this.register_grouped_kernel(Count.id(), &CountGroupedKernel);
113        this.register_grouped_encoding_kernel(
114            Primitive.id(),
115            Sum.id(),
116            &PrimitiveGroupedSumEncodingKernel,
117        );
118
119        this
120    }
121}
122
123impl AggregateFnSession {
124    /// Returns the aggregate function plugin registered for `id`, if any.
125    pub fn find_plugin(&self, id: &AggregateFnId) -> Option<AggregateFnPluginRef> {
126        self.registry.get(id)
127    }
128
129    /// Register an aggregate function vtable in the session, replacing any existing vtable with
130    /// the same ID.
131    pub fn register<V: AggregateFnVTable>(&self, vtable: V) {
132        let id = vtable.id();
133        let pluginref = Arc::new(vtable) as AggregateFnPluginRef;
134        self.registry.insert(id, pluginref);
135    }
136
137    /// Returns the aggregate kernel registered for `array_id` and `agg_fn_id`, if any.
138    ///
139    /// Lookup first checks for a kernel registered for the exact aggregate function, then falls
140    /// back to a kernel registered for all aggregate functions on the same array encoding.
141    pub fn find_aggregate_kernel(
142        &self,
143        array_id: impl Into<ArrayId>,
144        agg_fn_id: impl Into<AggregateFnId>,
145    ) -> Option<&'static dyn DynAggregateKernel> {
146        let id = array_id.into();
147        let fn_id = agg_fn_id.into();
148        self.kernels.read(|kernels| {
149            kernels
150                .get(&(id, Some(fn_id)))
151                .or_else(|| kernels.get(&(id, None)))
152                .copied()
153        })
154    }
155
156    /// Registers an aggregate kernel for an array encoding.
157    ///
158    /// When `agg_fn_id` is `Some`, the kernel is used only for that aggregate function. When
159    /// `agg_fn_id` is `None`, the kernel is used as the fallback for aggregate functions on the
160    /// array encoding that do not have a more specific kernel.
161    pub fn register_aggregate_kernel(
162        &self,
163        array_id: impl Into<ArrayId>,
164        agg_fn_id: Option<impl Into<AggregateFnId>>,
165        kernel: &'static dyn DynAggregateKernel,
166    ) {
167        let id = (array_id.into(), agg_fn_id.map(|id| id.into()));
168        self.kernels.insert(id, kernel);
169    }
170
171    /// Returns the grouped aggregate kernel registered for `agg_fn_id`, if any.
172    ///
173    /// These kernels are independent of the element encoding and are checked for each element
174    /// representation, after any kernel registered for the current element encoding.
175    pub fn find_grouped_kernel(
176        &self,
177        agg_fn_id: impl Into<AggregateFnId>,
178    ) -> Option<&'static dyn DynGroupedAggregateKernel> {
179        let fn_id = agg_fn_id.into();
180        self.grouped_kernels
181            .read(|kernels| kernels.get(&fn_id).copied())
182    }
183
184    /// Registers a grouped aggregate kernel for an aggregate function.
185    pub fn register_grouped_kernel(
186        &self,
187        agg_fn_id: impl Into<AggregateFnId>,
188        kernel: &'static dyn DynGroupedAggregateKernel,
189    ) {
190        let fn_id = agg_fn_id.into();
191        self.grouped_kernels.insert(fn_id, kernel)
192    }
193
194    /// Returns the grouped aggregate kernel registered for `array_id` and `agg_fn_id`, if any.
195    ///
196    /// These kernels are matched against each intermediate element encoding while the grouped
197    /// accumulator executes the element array.
198    pub fn find_grouped_encoding_kernel(
199        &self,
200        array_id: impl Into<ArrayId>,
201        agg_fn_id: impl Into<AggregateFnId>,
202    ) -> Option<&'static dyn DynGroupedAggregateKernel> {
203        let id = array_id.into();
204        let fn_id = agg_fn_id.into();
205        self.grouped_encoding_kernels
206            .read(|kernels| kernels.get(&(id, fn_id)).copied())
207    }
208
209    /// Registers a grouped aggregate kernel for a specific aggregate function and array encoding.
210    pub fn register_grouped_encoding_kernel(
211        &self,
212        array_id: impl Into<ArrayId>,
213        agg_fn_id: impl Into<AggregateFnId>,
214        kernel: &'static dyn DynGroupedAggregateKernel,
215    ) {
216        let id = array_id.into();
217        let fn_id = agg_fn_id.into();
218        self.grouped_encoding_kernels.insert((id, fn_id), kernel)
219    }
220}
221
222/// Extension trait for accessing aggregate function session data.
223pub trait AggregateFnSessionExt: SessionExt {
224    /// Returns the aggregate function session data.
225    fn aggregate_fns(&self) -> Ref<'_, AggregateFnSession> {
226        self.get::<AggregateFnSession>()
227    }
228}
229impl<S: SessionExt> AggregateFnSessionExt for S {}