vortex_array/optimizer/kernels.rs
1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped registry for optimizer kernels.
5//!
6//! [`ArrayKernels`] stores function pointers that participate in array optimization without
7//! adding rules to an encoding vtable. The optimizer currently consults it for parent-reduce
8//! rewrites before the child encoding's static `PARENT_RULES`. A registered function can
9//! therefore add a rule for an extension encoding or take precedence over a built-in rule.
10//!
//! Kernel entries are addressed by a `(parent_id, child_id)` pair within the table for their
//! kernel kind (currently only parent-reduce). For parent-reduce kernels, `parent_id` is the id
//! returned by the parent array's `encoding_id()` and `child_id` is the child array's
//! `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the parent id is the
//! scalar function id.
15//!
16//! Sessions created by the top-level `vortex` crate install an empty registry by default. Other
17//! sessions can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely
18//! on [`ArrayKernelsExt::kernels`] to insert the default value.
19
20use std::any::Any;
21use std::hash::BuildHasher;
22use std::sync::Arc;
23use std::sync::LazyLock;
24
25use arc_swap::ArcSwap;
26use vortex_error::VortexResult;
27use vortex_session::Ref;
28use vortex_session::SessionExt;
29use vortex_session::SessionVar;
30use vortex_session::registry::Id;
31use vortex_utils::aliases::DefaultHashBuilder;
32use vortex_utils::aliases::hash_map::HashMap;
33
34use crate::ArrayRef;
35
36/// Shared hasher used to combine `(outer, child, FnKind)` tuples into [`FnRegistry`] keys.
37static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
38
39/// Function pointer for a plugin-provided parent-reduce rewrite.
40///
41/// The optimizer calls this with the matched `child`, its `parent`, and the slot index where the
42/// child appears. Return `Ok(Some(new_parent))` to replace the parent, or `Ok(None)` when the
43/// rewrite does not apply.
44///
45/// Implementations must preserve the parent's logical length and dtype, matching the invariant
46/// required of static parent-reduce rules.
47pub type ReduceParentFn =
48 fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
49
50/// Session-scoped registry of optimizer kernel functions.
51#[derive(Debug, Default)]
52pub struct ArrayKernels {
53 reduce_parent: ArcSwap<HashMap<u64, Arc<[ReduceParentFn]>>>,
54}
55
56impl ArrayKernels {
57 /// Create an empty [`ArrayKernels`] with no kernels registered.
58 pub fn empty() -> Self {
59 Self::default()
60 }
61
62 /// Register a [`ReduceParentFn`] for `(outer, child)`.
63 ///
64 /// The optimizer will invoke `f` when it sees a parent with encoding id `outer` holding a
65 /// child with encoding id `child` during a `reduce_parent` step, before trying the child
66 /// encoding's static `PARENT_RULES`. `outer` is usually the parent array's encoding id. For
67 /// `ScalarFnArray`, it is the scalar function id, for example `Cast.id()`.
68 ///
69 /// Replaces any function already registered for the same pair.
70 pub fn register_reduce_parent<I: IntoIterator<Item = ReduceParentFn>>(
71 &self,
72 parent: Id,
73 child: Id,
74 fns: I,
75 ) {
76 let registry = self.reduce_parent.load();
77 let id = self.hash_fn_ids(parent, child);
78 let mut owned_registry = registry.as_ref().clone();
79 if let Some(existing) = owned_registry.remove(&id) {
80 owned_registry.insert(id, existing.as_ref().iter().cloned().chain(fns).collect());
81 } else {
82 owned_registry.insert(id, fns.into_iter().collect());
83 }
84 self.reduce_parent.store(Arc::new(owned_registry));
85 }
86
87 /// Look up the [`ReduceParentFn`] registered for `(outer, child)`.
88 ///
89 /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
90 /// function.
91 pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
92 let id = self.hash_fn_ids(parent, child);
93 let map = self.reduce_parent.load();
94 let entry = map.get(&id)?;
95 Some(Arc::clone(entry))
96 }
97
98 /// Combine a typed kernel id tuple into the `u64` key expected by the underlying
99 /// [`FnRegistry`]. All typed helpers use this path so registration and lookup agree.
100 fn hash_fn_ids(&self, parent: Id, child: Id) -> u64 {
101 FN_HASHER.hash_one((parent, child))
102 }
103}
104
105impl SessionVar for ArrayKernels {
106 fn as_any(&self) -> &dyn Any {
107 self
108 }
109
110 fn as_any_mut(&mut self) -> &mut dyn Any {
111 self
112 }
113}
114
115/// Extension trait for accessing optimizer kernels from a
116/// [`VortexSession`](vortex_session::VortexSession).
117pub trait ArrayKernelsExt: SessionExt {
118 /// Returns the [`ArrayKernels`] session variable, inserting a default-constructed one if
119 /// none has been registered on the session yet.
120 fn kernels(&self) -> Ref<'_, ArrayKernels> {
121 self.get::<ArrayKernels>()
122 }
123}
124
125impl<S: SessionExt> ArrayKernelsExt for S {}