vortex_array/optimizer/kernels.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped registry for optimizer kernels.
5//!
6//! [`ArrayKernels`] stores function pointers that participate in array optimization and execution
7//! without adding rules or kernels to an encoding vtable. The optimizer consults it for
8//! parent-reduce rewrites before the child encoding's static `PARENT_RULES`, and the executor
9//! consults it for parent execution before the child encoding's static parent kernels. A
10//! registered function can therefore add support for an extension encoding or take precedence over
11//! a built-in rule or kernel. When several functions are registered for the same key and kind,
12//! they are tried in registration order until one applies.
13//!
//! Kernel entries are addressed by `(outer_id, child_id)`; the registration and lookup methods
//! name these two components `parent` and `child`. For both parent-reduce and execute-parent
//! kernels, `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is
//! the child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents,
//! `outer_id` is the scalar function id instead.
18//!
19//! Because registered functions have different signatures for each kernel kind, the registry
20//! maintains one storage map per function type rather than a single type-erased map.
21//!
22//! Sessions created by the top-level `vortex` crate install the default registry. Other sessions
23//! can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely on
24//! [`ArrayKernelsExt::kernels`] to insert the default value.
25
26use std::any::Any;
27use std::borrow::Borrow;
28use std::hash::BuildHasher;
29use std::sync::Arc;
30use std::sync::LazyLock;
31
32use vortex_error::VortexResult;
33use vortex_session::Ref;
34use vortex_session::SessionExt;
35use vortex_session::SessionVar;
36use vortex_session::registry::Id;
37use vortex_utils::aliases::DefaultHashBuilder;
38
39use crate::ArrayRef;
40use crate::ExecutionCtx;
41use crate::arc_swap_map::ArcSwapMap;
42use crate::array::VTable;
43use crate::arrays::Struct;
44use crate::arrays::struct_::compute::cast::struct_cast_execute_parent;
45use crate::arrays::struct_::compute::rules::struct_cast_reduce_parent;
46use crate::scalar_fn::ScalarFnVTable;
47use crate::scalar_fn::fns::cast::Cast;
48
49/// Shared hasher used to combine `(outer, child)` tuples into registry keys.
50static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
51
52/// Function pointer for a plugin-provided parent-reduce rewrite.
53///
54/// The optimizer calls this with the matched `child`, its `parent`, and the slot index where the
55/// child appears. Return `Ok(Some(new_parent))` to replace the parent, or `Ok(None)` when the
56/// rewrite does not apply.
57///
58/// Implementations must preserve the parent's logical length and dtype, matching the invariant
59/// required of static parent-reduce rules.
60pub type ReduceParentFn =
61 fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
62
/// Key type of the parent-reduce map.
///
/// Wraps the 64-bit value that `hash_fn_id` produces for a `(parent, child)` pair. Keeping it in
/// a dedicated newtype prevents it from being mixed up with the key of the execute-parent map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
struct ReduceParentFnId(u64);

impl From<u64> for ReduceParentFnId {
    /// Wraps an already-computed pair hash as a parent-reduce key.
    fn from(hash: u64) -> Self {
        ReduceParentFnId(hash)
    }
}

impl Borrow<u64> for ReduceParentFnId {
    /// Exposes the wrapped hash so lookups can be done with a plain `&u64`.
    ///
    /// The derived `Hash`/`Eq`/`Ord` act on the single `u64` field, so they agree with the
    /// borrowed form.
    fn borrow(&self) -> &u64 {
        let Self(hash) = self;
        hash
    }
}
78
79/// Function pointer for a plugin-provided parent execution.
80///
81/// The executor calls this with the matched `child`, its `parent`, the slot index where the child
82/// appears, and the current [`ExecutionCtx`]. Return `Ok(Some(new_parent))` to replace the parent
83/// with an executed result, or `Ok(None)` when the kernel does not apply.
84///
85/// Implementations must preserve the parent's logical length and dtype, matching the invariant
86/// required of static `execute_parent` kernels.
87pub type ExecuteParentFn = fn(
88 child: &ArrayRef,
89 parent: &ArrayRef,
90 child_idx: usize,
91 ctx: &mut ExecutionCtx,
92) -> VortexResult<Option<ArrayRef>>;
93
/// Key type of the execute-parent map.
///
/// Wraps the 64-bit value that `hash_fn_id` produces for a `(parent, child)` pair. Keeping it in
/// a dedicated newtype prevents it from being mixed up with the key of the parent-reduce map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
struct ExecuteParentFnId(u64);

impl From<u64> for ExecuteParentFnId {
    /// Wraps an already-computed pair hash as an execute-parent key.
    fn from(hash: u64) -> Self {
        ExecuteParentFnId(hash)
    }
}

impl Borrow<u64> for ExecuteParentFnId {
    /// Exposes the wrapped hash so lookups can be done with a plain `&u64`.
    ///
    /// The derived `Hash`/`Eq`/`Ord` act on the single `u64` field, so they agree with the
    /// borrowed form.
    fn borrow(&self) -> &u64 {
        let Self(hash) = self;
        hash
    }
}
109
110/// Session-scoped registry of optimizer kernel functions.
111///
112/// Each kernel kind has its own storage map, keyed by `(outer_id, child_id)`. Registering
113/// functions for an existing key appends them to that key's ordered list.
114#[derive(Debug)]
115pub struct ArrayKernels {
116 reduce_parent: ArcSwapMap<ReduceParentFnId, Arc<[ReduceParentFn]>>,
117 execute_parent: ArcSwapMap<ExecuteParentFnId, Arc<[ExecuteParentFn]>>,
118}
119
120impl Default for ArrayKernels {
121 fn default() -> ArrayKernels {
122 let this = Self::empty();
123 this.register_builtin_reduce_parent();
124 this.register_builtin_execute_parent();
125 this
126 }
127}
128
129impl ArrayKernels {
130 /// Create an empty [`ArrayKernels`] with no kernels registered.
131 pub fn empty() -> Self {
132 Self {
133 reduce_parent: ArcSwapMap::default(),
134 execute_parent: ArcSwapMap::default(),
135 }
136 }
137
138 fn register_builtin_reduce_parent(&self) {
139 self.register_reduce_parent(
140 Cast.id(),
141 Struct.id(),
142 &[struct_cast_reduce_parent as ReduceParentFn],
143 );
144 }
145
146 fn register_builtin_execute_parent(&self) {
147 self.register_execute_parent(
148 Cast.id(),
149 Struct.id(),
150 &[struct_cast_execute_parent as ExecuteParentFn],
151 );
152 }
153
154 /// Register [`ReduceParentFn`]s for `(parent, child)`.
155 ///
156 /// The optimizer invokes these functions in registration order when it sees a parent with
157 /// encoding id `parent` holding a child with encoding id `child` during a `reduce_parent`
158 /// step, before trying the child encoding's static `PARENT_RULES`. `parent` is usually the
159 /// parent array's encoding id. For `ScalarFnArray`, it is the scalar function id, for example
160 /// `Cast.id()`.
161 ///
162 /// If functions have already been registered for the same pair, these functions are appended
163 /// after them.
164 pub fn register_reduce_parent(&self, parent: Id, child: Id, fns: &[ReduceParentFn]) {
165 self.reduce_parent
166 .extend(hash_fn_id(parent, child).into(), fns);
167 }
168
169 /// Look up the [`ReduceParentFn`]s registered for `(parent, child)`.
170 ///
171 /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
172 /// functions.
173 pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
174 self.reduce_parent.get(&hash_fn_id(parent, child))
175 }
176
177 /// Register [`ExecuteParentFn`]s for `(parent, child)`.
178 ///
179 /// The executor invokes these functions in registration order when it sees a parent with
180 /// encoding id `parent` holding a child with encoding id `child` during a parent execution
181 /// step, before trying the child encoding's static parent kernels.
182 ///
183 /// If functions have already been registered for the same pair, these functions are appended
184 /// after them.
185 pub fn register_execute_parent(&self, parent: Id, child: Id, fns: &[ExecuteParentFn]) {
186 self.execute_parent
187 .extend(hash_fn_id(parent, child).into(), fns);
188 }
189
190 /// Look up the [`ExecuteParentFn`]s registered for `(parent, child)`.
191 ///
192 /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
193 /// functions.
194 pub fn find_execute_parent(&self, parent: Id, child: Id) -> Option<Arc<[ExecuteParentFn]>> {
195 self.execute_parent.get(&hash_fn_id(parent, child))
196 }
197}
198
199fn hash_fn_id(parent: Id, child: Id) -> u64 {
200 FN_HASHER.hash_one((parent, child))
201}
202
203impl SessionVar for ArrayKernels {
204 fn as_any(&self) -> &dyn Any {
205 self
206 }
207
208 fn as_any_mut(&mut self) -> &mut dyn Any {
209 self
210 }
211}
212
213/// Extension trait for accessing optimizer kernels from a
214/// [`VortexSession`](vortex_session::VortexSession).
215pub trait ArrayKernelsExt: SessionExt {
216 /// Returns the [`ArrayKernels`] session variable, inserting a default-constructed one if
217 /// none has been registered on the session yet.
218 fn kernels(&self) -> Ref<'_, ArrayKernels> {
219 self.get::<ArrayKernels>()
220 }
221}
222
223impl<S: SessionExt> ArrayKernelsExt for S {}