Skip to main content

vortex_array/optimizer/
kernels.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped registry for optimizer kernels.
5//!
6//! [`ArrayKernels`] stores function pointers that participate in array optimization and execution
7//! without adding rules or kernels to an encoding vtable. The optimizer consults it for
8//! parent-reduce rewrites before the child encoding's static `PARENT_RULES`, and the executor
9//! consults it for parent execution before the child encoding's static parent kernels. A
10//! registered function can therefore add support for an extension encoding or take precedence over
11//! a built-in rule or kernel. When several functions are registered for the same key and kind,
12//! they are tried in registration order until one applies.
13//!
14//! Kernel entries are addressed by `(outer_id, child_id)`. For parent-reduce and execute-parent
15//! kernels, `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is
16//! the child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the
17//! parent id is the scalar function id.
18//!
19//! Because registered functions have different signatures for each kernel kind, the registry
20//! maintains one storage map per function type rather than a single type-erased map.
21//!
22//! Sessions created by the top-level `vortex` crate install the default registry. Other sessions
23//! can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely on
24//! [`ArrayKernelsExt::kernels`] to insert the default value.
25
26use std::any::Any;
27use std::borrow::Borrow;
28use std::hash::BuildHasher;
29use std::hash::Hash;
30use std::sync::Arc;
31use std::sync::LazyLock;
32
33use arc_swap::ArcSwap;
34use vortex_error::VortexResult;
35use vortex_session::Ref;
36use vortex_session::SessionExt;
37use vortex_session::SessionVar;
38use vortex_session::registry::Id;
39use vortex_utils::aliases::DefaultHashBuilder;
40use vortex_utils::aliases::hash_map::HashMap;
41
42use crate::ArrayRef;
43use crate::ExecutionCtx;
44use crate::array::VTable;
45use crate::arrays::Struct;
46use crate::arrays::struct_::compute::cast::struct_cast_execute_parent;
47use crate::arrays::struct_::compute::rules::struct_cast_reduce_parent;
48use crate::scalar_fn::ScalarFnVTable;
49use crate::scalar_fn::fns::cast::Cast;
50
/// Shared hasher used to combine `(outer, child)` tuples into registry keys.
///
/// One lazily-initialized instance is shared by every call to `hash_fn_id`, so registration and
/// lookup of the same `(parent, child)` pair produce the same key.
// NOTE(review): if `DefaultHashBuilder::default()` is randomly seeded, keys differ between
// processes and must never be persisted or compared across processes — confirm in `vortex_utils`.
static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
53
/// Function pointer for a plugin-provided parent-reduce rewrite.
///
/// The optimizer calls this with the matched `child`, its `parent`, and the slot index where the
/// child appears. Return `Ok(Some(new_parent))` to replace the parent, or `Ok(None)` when the
/// rewrite does not apply.
///
/// Implementations must preserve the parent's logical length and dtype, matching the invariant
/// required of static parent-reduce rules.
///
/// Functions of this type are registered with [`ArrayKernels::register_reduce_parent`] and
/// retrieved with [`ArrayKernels::find_reduce_parent`].
pub type ReduceParentFn =
    fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
64
/// Key type of the parent-reduce registry map.
///
/// Wraps the raw `u64` key so that parent-reduce keys are a distinct type from the keys of other
/// kernel registries. The derived `Hash` and `Eq` operate on the wrapped `u64` alone, which is
/// what allows the [`Borrow<u64>`] impl below to be used for map lookups by raw key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
struct ReduceParentFnId(u64);

impl From<u64> for ReduceParentFnId {
    /// Wraps a raw key value.
    fn from(raw: u64) -> Self {
        ReduceParentFnId(raw)
    }
}

impl Borrow<u64> for ReduceParentFnId {
    /// Exposes the raw key so a map keyed by this type can be queried with a plain `&u64`.
    fn borrow(&self) -> &u64 {
        let Self(raw) = self;
        raw
    }
}
80
/// Function pointer for a plugin-provided parent execution.
///
/// The executor calls this with the matched `child`, its `parent`, the slot index where the child
/// appears, and the current [`ExecutionCtx`]. Return `Ok(Some(new_parent))` to replace the parent
/// with an executed result, or `Ok(None)` when the kernel does not apply.
///
/// Implementations must preserve the parent's logical length and dtype, matching the invariant
/// required of static `execute_parent` kernels.
///
/// Functions of this type are registered with [`ArrayKernels::register_execute_parent`] and
/// retrieved with [`ArrayKernels::find_execute_parent`].
pub type ExecuteParentFn = fn(
    child: &ArrayRef,
    parent: &ArrayRef,
    child_idx: usize,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>>;
95
/// Key type of the execute-parent registry map.
///
/// Wraps the raw `u64` key so that execute-parent keys are a distinct type from the keys of other
/// kernel registries. The derived `Hash` and `Eq` operate on the wrapped `u64` alone, which is
/// what allows the [`Borrow<u64>`] impl below to be used for map lookups by raw key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
struct ExecuteParentFnId(u64);

impl From<u64> for ExecuteParentFnId {
    /// Wraps a raw key value.
    fn from(raw: u64) -> Self {
        ExecuteParentFnId(raw)
    }
}

impl Borrow<u64> for ExecuteParentFnId {
    /// Exposes the raw key so a map keyed by this type can be queried with a plain `&u64`.
    fn borrow(&self) -> &u64 {
        let Self(raw) = self;
        raw
    }
}
111
/// Session-scoped registry of optimizer kernel functions.
///
/// Each kernel kind has its own storage map, keyed by `(outer_id, child_id)`. Registering
/// functions for an existing key appends them to that key's ordered list.
///
/// Each map sits behind an [`ArcSwap`]: registration clones the current map, updates the copy and
/// swaps it in, while lookups load the current snapshot. This is why registration only needs
/// `&self`.
///
/// NOTE(review): the map keys are the 64-bit hash of the id pair (see `hash_fn_id`), not the pair
/// itself, so two distinct pairs whose hashes collide would share one function list — confirm
/// this is an accepted trade-off.
#[derive(Debug)]
pub struct ArrayKernels {
    /// Parent-reduce functions per hashed `(parent, child)` key, in registration order.
    reduce_parent: ArcSwap<HashMap<ReduceParentFnId, Arc<[ReduceParentFn]>>>,
    /// Execute-parent functions per hashed `(parent, child)` key, in registration order.
    execute_parent: ArcSwap<HashMap<ExecuteParentFnId, Arc<[ExecuteParentFn]>>>,
}
121
122impl Default for ArrayKernels {
123    fn default() -> ArrayKernels {
124        let this = Self::empty();
125        this.register_builtin_reduce_parent();
126        this.register_builtin_execute_parent();
127        this
128    }
129}
130
131impl ArrayKernels {
132    /// Create an empty [`ArrayKernels`] with no kernels registered.
133    pub fn empty() -> Self {
134        Self {
135            reduce_parent: ArcSwap::from_pointee(HashMap::default()),
136            execute_parent: ArcSwap::from_pointee(HashMap::default()),
137        }
138    }
139
140    fn register_builtin_reduce_parent(&self) {
141        self.register_reduce_parent(
142            Cast.id(),
143            Struct.id(),
144            &[struct_cast_reduce_parent as ReduceParentFn],
145        );
146    }
147
148    fn register_builtin_execute_parent(&self) {
149        self.register_execute_parent(
150            Cast.id(),
151            Struct.id(),
152            &[struct_cast_execute_parent as ExecuteParentFn],
153        );
154    }
155
156    /// Register [`ReduceParentFn`]s for `(parent, child)`.
157    ///
158    /// The optimizer invokes these functions in registration order when it sees a parent with
159    /// encoding id `parent` holding a child with encoding id `child` during a `reduce_parent`
160    /// step, before trying the child encoding's static `PARENT_RULES`. `parent` is usually the
161    /// parent array's encoding id. For `ScalarFnArray`, it is the scalar function id, for example
162    /// `Cast.id()`.
163    ///
164    /// If functions have already been registered for the same pair, these functions are appended
165    /// after them.
166    pub fn register_reduce_parent(&self, parent: Id, child: Id, fns: &[ReduceParentFn]) {
167        self.reduce_parent.rcu(move |registry| {
168            update_fns(registry.as_ref().clone(), hash_fn_id(parent, child), fns)
169        });
170    }
171
172    /// Look up the [`ReduceParentFn`]s registered for `(parent, child)`.
173    ///
174    /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
175    /// functions.
176    pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
177        let id = hash_fn_id(parent, child);
178        self.reduce_parent.load().get(&id).cloned()
179    }
180
181    /// Register [`ExecuteParentFn`]s for `(parent, child)`.
182    ///
183    /// The executor invokes these functions in registration order when it sees a parent with
184    /// encoding id `parent` holding a child with encoding id `child` during a parent execution
185    /// step, before trying the child encoding's static parent kernels.
186    ///
187    /// If functions have already been registered for the same pair, these functions are appended
188    /// after them.
189    pub fn register_execute_parent(&self, parent: Id, child: Id, fns: &[ExecuteParentFn]) {
190        self.execute_parent.rcu(move |registry| {
191            update_fns(registry.as_ref().clone(), hash_fn_id(parent, child), fns)
192        });
193    }
194
195    /// Look up the [`ExecuteParentFn`]s registered for `(parent, child)`.
196    ///
197    /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
198    /// functions.
199    pub fn find_execute_parent(&self, parent: Id, child: Id) -> Option<Arc<[ExecuteParentFn]>> {
200        let id = hash_fn_id(parent, child);
201        self.execute_parent.load().get(&id).cloned()
202    }
203}
204
/// Computes the registry key for a `(parent, child)` id pair.
///
/// Hashes the pair as a tuple with the shared `FN_HASHER`, so the result depends on the order of
/// the two ids. The same function is used for both registration and lookup.
fn hash_fn_id(parent: Id, child: Id) -> u64 {
    FN_HASHER.hash_one((parent, child))
}
208
/// Returns `existing` with `fns` added under the key `id`.
///
/// If `id` already has a function list, the result for that key is the old list followed by
/// `fns`, so registration order is preserved. Otherwise a new list containing exactly `fns` is
/// inserted. Entries for other keys are left untouched.
///
/// The map is taken and returned by value so the caller can build an updated copy of a shared
/// registry and swap it in.
fn update_fns<F, K>(
    mut existing: HashMap<K, Arc<[F]>>,
    id: u64,
    fns: &[F],
) -> HashMap<K, Arc<[F]>>
where
    F: Clone,
    K: Borrow<u64> + Eq + Hash + From<u64>,
{
    // A single entry lookup replaces the previous remove-then-insert pair, which hashed and
    // probed the map twice for an existing key.
    existing
        .entry(K::from(id))
        .and_modify(|registered| {
            // `Arc<[F]>` cannot grow in place, so build the concatenated list and replace it.
            *registered = registered.iter().chain(fns).cloned().collect();
        })
        .or_insert_with(|| Arc::from(fns));
    existing
}
224
/// Lets [`ArrayKernels`] be stored as a session variable.
// NOTE(review): the `Any` upcasts are presumably what the session uses to downcast a stored
// variable back to its concrete type — confirm against `vortex_session::SessionVar`.
impl SessionVar for ArrayKernels {
    /// Returns `self` as a `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns `self` as a `&mut dyn Any`.
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}
234
/// Extension trait for accessing optimizer kernels from a
/// [`VortexSession`](vortex_session::VortexSession).
pub trait ArrayKernelsExt: SessionExt {
    /// Returns the [`ArrayKernels`] session variable, inserting a default-constructed one if
    /// none has been registered on the session yet.
    // NOTE(review): this method only delegates to `SessionExt::get`; the insert-if-missing
    // behaviour described above is that method's contract — confirm against `vortex_session`.
    fn kernels(&self) -> Ref<'_, ArrayKernels> {
        self.get::<ArrayKernels>()
    }
}

// Blanket impl: every type implementing `SessionExt` gets `kernels()` without further code.
impl<S: SessionExt> ArrayKernelsExt for S {}