Skip to main content

vortex_array/optimizer/
kernels.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped registry for optimizer kernels.
5//!
6//! [`ArrayKernels`] stores function pointers that participate in array optimization and execution
7//! without adding rules or kernels to an encoding vtable. The optimizer consults it for
8//! parent-reduce rewrites before the child encoding's static `PARENT_RULES`, and the executor
9//! consults it for parent execution before the child encoding's static parent kernels. A
10//! registered function can therefore add support for an extension encoding or take precedence over
11//! a built-in rule or kernel. When several functions are registered for the same key and kind,
12//! they are tried in registration order until one applies.
13//!
14//! Kernel entries are addressed by `(outer_id, child_id)`. For parent-reduce and execute-parent
15//! kernels, `outer_id` is the id returned by the parent array's `encoding_id()` and `child_id` is
16//! the child array's `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the
17//! parent id is the scalar function id.
18//!
19//! Because registered functions have different signatures for each kernel kind, the registry
20//! maintains one storage map per function type rather than a single type-erased map.
21//!
22//! Sessions created by the top-level `vortex` crate install the default registry. Other sessions
23//! can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely on
24//! [`ArrayKernelsExt::kernels`] to insert the default value.
25
26use std::any::Any;
27use std::borrow::Borrow;
28use std::hash::BuildHasher;
29use std::sync::Arc;
30use std::sync::LazyLock;
31
32use vortex_error::VortexResult;
33use vortex_session::Ref;
34use vortex_session::SessionExt;
35use vortex_session::SessionVar;
36use vortex_session::registry::Id;
37use vortex_utils::aliases::DefaultHashBuilder;
38
39use crate::ArrayRef;
40use crate::ExecutionCtx;
41use crate::arc_swap_map::ArcSwapMap;
42use crate::array::VTable;
43use crate::arrays::Struct;
44use crate::arrays::struct_::compute::cast::struct_cast_execute_parent;
45use crate::arrays::struct_::compute::rules::struct_cast_reduce_parent;
46use crate::scalar_fn::ScalarFnVTable;
47use crate::scalar_fn::fns::cast::Cast;
48
49/// Shared hasher used to combine `(outer, child)` tuples into registry keys.
50static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
51
/// Function pointer for a plugin-provided parent-reduce rewrite.
///
/// The optimizer calls this with the matched `child`, its `parent`, and the slot index where the
/// child appears. Return `Ok(Some(new_parent))` to replace the parent, or `Ok(None)` when the
/// rewrite does not apply.
///
/// Implementations must preserve the parent's logical length and dtype, matching the invariant
/// required of static parent-reduce rules.
///
/// Functions of this type are added with [`ArrayKernels::register_reduce_parent`] and retrieved
/// with [`ArrayKernels::find_reduce_parent`]. Unlike [`ExecuteParentFn`], no execution context is
/// passed, so a rewrite can only inspect and rebuild arrays.
pub type ReduceParentFn =
    fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
62
/// Map key for [`ReduceParentFn`] entries: the `u64` obtained by hashing a `(parent, child)` id
/// pair.
///
/// All derived comparison and hashing impls forward to the single wrapped `u64`. That keeps them
/// consistent with the `Borrow<u64>` impl below, which is what allows lookups keyed by a bare
/// `&u64`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[repr(transparent)]
struct ReduceParentFnId(u64);

impl From<u64> for ReduceParentFnId {
    fn from(hash: u64) -> Self {
        ReduceParentFnId(hash)
    }
}

impl Borrow<u64> for ReduceParentFnId {
    fn borrow(&self) -> &u64 {
        let ReduceParentFnId(hash) = self;
        hash
    }
}
78
/// Function pointer for a plugin-provided parent execution.
///
/// The executor calls this with the matched `child`, its `parent`, the slot index where the child
/// appears, and the current [`ExecutionCtx`]. Return `Ok(Some(new_parent))` to replace the parent
/// with an executed result, or `Ok(None)` when the kernel does not apply.
///
/// Implementations must preserve the parent's logical length and dtype, matching the invariant
/// required of static `execute_parent` kernels.
///
/// Functions of this type are added with [`ArrayKernels::register_execute_parent`] and retrieved
/// with [`ArrayKernels::find_execute_parent`]. Unlike [`ReduceParentFn`], they also receive a
/// mutable [`ExecutionCtx`].
pub type ExecuteParentFn = fn(
    child: &ArrayRef,
    parent: &ArrayRef,
    child_idx: usize,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>>;
93
/// Map key for [`ExecuteParentFn`] entries: the `u64` obtained by hashing a `(parent, child)` id
/// pair.
///
/// All derived comparison and hashing impls forward to the single wrapped `u64`. That keeps them
/// consistent with the `Borrow<u64>` impl below, which is what allows lookups keyed by a bare
/// `&u64`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[repr(transparent)]
struct ExecuteParentFnId(u64);

impl From<u64> for ExecuteParentFnId {
    fn from(hash: u64) -> Self {
        ExecuteParentFnId(hash)
    }
}

impl Borrow<u64> for ExecuteParentFnId {
    fn borrow(&self) -> &u64 {
        let ExecuteParentFnId(hash) = self;
        hash
    }
}
109
/// Session-scoped registry of optimizer kernel functions.
///
/// Each kernel kind has its own storage map, keyed by `(outer_id, child_id)`. Registering
/// functions for an existing key appends them to that key's ordered list.
///
/// The maps do not store the id pair itself. The key is the `u64` produced by hashing the pair
/// (see `hash_fn_id`), wrapped in a per-kind newtype key.
#[derive(Debug)]
pub struct ArrayKernels {
    /// Parent-reduce rewrites, consulted by the optimizer before a child's static `PARENT_RULES`.
    reduce_parent: ArcSwapMap<ReduceParentFnId, Arc<[ReduceParentFn]>>,
    /// Parent-execution kernels, consulted by the executor before a child's static parent kernels.
    execute_parent: ArcSwapMap<ExecuteParentFnId, Arc<[ExecuteParentFn]>>,
}
119
120impl Default for ArrayKernels {
121    fn default() -> ArrayKernels {
122        let this = Self::empty();
123        this.register_builtin_reduce_parent();
124        this.register_builtin_execute_parent();
125        this
126    }
127}
128
129impl ArrayKernels {
130    /// Create an empty [`ArrayKernels`] with no kernels registered.
131    pub fn empty() -> Self {
132        Self {
133            reduce_parent: ArcSwapMap::default(),
134            execute_parent: ArcSwapMap::default(),
135        }
136    }
137
138    fn register_builtin_reduce_parent(&self) {
139        self.register_reduce_parent(
140            Cast.id(),
141            Struct.id(),
142            &[struct_cast_reduce_parent as ReduceParentFn],
143        );
144    }
145
146    fn register_builtin_execute_parent(&self) {
147        self.register_execute_parent(
148            Cast.id(),
149            Struct.id(),
150            &[struct_cast_execute_parent as ExecuteParentFn],
151        );
152    }
153
154    /// Register [`ReduceParentFn`]s for `(parent, child)`.
155    ///
156    /// The optimizer invokes these functions in registration order when it sees a parent with
157    /// encoding id `parent` holding a child with encoding id `child` during a `reduce_parent`
158    /// step, before trying the child encoding's static `PARENT_RULES`. `parent` is usually the
159    /// parent array's encoding id. For `ScalarFnArray`, it is the scalar function id, for example
160    /// `Cast.id()`.
161    ///
162    /// If functions have already been registered for the same pair, these functions are appended
163    /// after them.
164    pub fn register_reduce_parent(&self, parent: Id, child: Id, fns: &[ReduceParentFn]) {
165        self.reduce_parent
166            .extend(hash_fn_id(parent, child).into(), fns);
167    }
168
169    /// Look up the [`ReduceParentFn`]s registered for `(parent, child)`.
170    ///
171    /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
172    /// functions.
173    pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
174        self.reduce_parent.get(&hash_fn_id(parent, child))
175    }
176
177    /// Register [`ExecuteParentFn`]s for `(parent, child)`.
178    ///
179    /// The executor invokes these functions in registration order when it sees a parent with
180    /// encoding id `parent` holding a child with encoding id `child` during a parent execution
181    /// step, before trying the child encoding's static parent kernels.
182    ///
183    /// If functions have already been registered for the same pair, these functions are appended
184    /// after them.
185    pub fn register_execute_parent(&self, parent: Id, child: Id, fns: &[ExecuteParentFn]) {
186        self.execute_parent
187            .extend(hash_fn_id(parent, child).into(), fns);
188    }
189
190    /// Look up the [`ExecuteParentFn`]s registered for `(parent, child)`.
191    ///
192    /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
193    /// functions.
194    pub fn find_execute_parent(&self, parent: Id, child: Id) -> Option<Arc<[ExecuteParentFn]>> {
195        self.execute_parent.get(&hash_fn_id(parent, child))
196    }
197}
198
199fn hash_fn_id(parent: Id, child: Id) -> u64 {
200    FN_HASHER.hash_one((parent, child))
201}
202
203impl SessionVar for ArrayKernels {
204    fn as_any(&self) -> &dyn Any {
205        self
206    }
207
208    fn as_any_mut(&mut self) -> &mut dyn Any {
209        self
210    }
211}
212
/// Extension trait for accessing optimizer kernels from a
/// [`VortexSession`](vortex_session::VortexSession).
///
/// A blanket impl makes this trait available on every type that implements [`SessionExt`].
pub trait ArrayKernelsExt: SessionExt {
    /// Returns the [`ArrayKernels`] session variable, inserting a default-constructed one if
    /// none has been registered on the session yet.
    ///
    /// This is a thin wrapper over `SessionExt::get::<ArrayKernels>()`, which supplies the
    /// insert-on-miss behavior. A default-constructed registry comes from
    /// [`ArrayKernels::default`] and therefore already contains the built-in kernels.
    fn kernels(&self) -> Ref<'_, ArrayKernels> {
        self.get::<ArrayKernels>()
    }
}
222
223impl<S: SessionExt> ArrayKernelsExt for S {}