Skip to main content

vortex_array/optimizer/
kernels.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session-scoped registry for optimizer kernels.
5//!
6//! [`ArrayKernels`] stores function pointers that participate in array optimization without
7//! adding rules to an encoding vtable. The optimizer currently consults it for parent-reduce
8//! rewrites before the child encoding's static `PARENT_RULES`. A registered function can
9//! therefore add a rule for an extension encoding or take precedence over a built-in rule.
10//!
//! Kernel entries are addressed by the `(outer_id, child_id)` pair within a per-kind table;
//! parent-reduce is currently the only kind. For parent-reduce kernels, `outer_id` is the id
//! returned by the parent array's `encoding_id()` and `child_id` is the child array's
//! `encoding_id()`. For [`ScalarFn`](crate::arrays::ScalarFn) parents, the parent id is the
//! scalar function id.
15//!
16//! Sessions created by the top-level `vortex` crate install an empty registry by default. Other
17//! sessions can add it with [`VortexSession::with`](vortex_session::VortexSession::with) or rely
18//! on [`ArrayKernelsExt::kernels`] to insert the default value.
19
20use std::any::Any;
21use std::hash::BuildHasher;
22use std::sync::Arc;
23use std::sync::LazyLock;
24
25use arc_swap::ArcSwap;
26use vortex_error::VortexResult;
27use vortex_session::Ref;
28use vortex_session::SessionExt;
29use vortex_session::SessionVar;
30use vortex_session::registry::Id;
31use vortex_utils::aliases::DefaultHashBuilder;
32use vortex_utils::aliases::hash_map::HashMap;
33
34use crate::ArrayRef;
35
36/// Shared hasher used to combine `(outer, child, FnKind)` tuples into [`FnRegistry`] keys.
37static FN_HASHER: LazyLock<DefaultHashBuilder> = LazyLock::new(DefaultHashBuilder::default);
38
39/// Function pointer for a plugin-provided parent-reduce rewrite.
40///
41/// The optimizer calls this with the matched `child`, its `parent`, and the slot index where the
42/// child appears. Return `Ok(Some(new_parent))` to replace the parent, or `Ok(None)` when the
43/// rewrite does not apply.
44///
45/// Implementations must preserve the parent's logical length and dtype, matching the invariant
46/// required of static parent-reduce rules.
47pub type ReduceParentFn =
48    fn(child: &ArrayRef, parent: &ArrayRef, child_idx: usize) -> VortexResult<Option<ArrayRef>>;
49
50/// Session-scoped registry of optimizer kernel functions.
51#[derive(Debug, Default)]
52pub struct ArrayKernels {
53    reduce_parent: ArcSwap<HashMap<u64, Arc<[ReduceParentFn]>>>,
54}
55
56impl ArrayKernels {
57    /// Create an empty [`ArrayKernels`] with no kernels registered.
58    pub fn empty() -> Self {
59        Self::default()
60    }
61
62    /// Register a [`ReduceParentFn`] for `(outer, child)`.
63    ///
64    /// The optimizer will invoke `f` when it sees a parent with encoding id `outer` holding a
65    /// child with encoding id `child` during a `reduce_parent` step, before trying the child
66    /// encoding's static `PARENT_RULES`. `outer` is usually the parent array's encoding id. For
67    /// `ScalarFnArray`, it is the scalar function id, for example `Cast.id()`.
68    ///
69    /// Replaces any function already registered for the same pair.
70    pub fn register_reduce_parent<I: IntoIterator<Item = ReduceParentFn>>(
71        &self,
72        parent: Id,
73        child: Id,
74        fns: I,
75    ) {
76        let registry = self.reduce_parent.load();
77        let id = self.hash_fn_ids(parent, child);
78        let mut owned_registry = registry.as_ref().clone();
79        if let Some(existing) = owned_registry.remove(&id) {
80            owned_registry.insert(id, existing.as_ref().iter().cloned().chain(fns).collect());
81        } else {
82            owned_registry.insert(id, fns.into_iter().collect());
83        }
84        self.reduce_parent.store(Arc::new(owned_registry));
85    }
86
87    /// Look up the [`ReduceParentFn`] registered for `(outer, child)`.
88    ///
89    /// Returns an owned [`Arc`] so the session-variable borrow can be dropped before invoking the
90    /// function.
91    pub fn find_reduce_parent(&self, parent: Id, child: Id) -> Option<Arc<[ReduceParentFn]>> {
92        let id = self.hash_fn_ids(parent, child);
93        let map = self.reduce_parent.load();
94        let entry = map.get(&id)?;
95        Some(Arc::clone(entry))
96    }
97
98    /// Combine a typed kernel id tuple into the `u64` key expected by the underlying
99    /// [`FnRegistry`]. All typed helpers use this path so registration and lookup agree.
100    fn hash_fn_ids(&self, parent: Id, child: Id) -> u64 {
101        FN_HASHER.hash_one((parent, child))
102    }
103}
104
105impl SessionVar for ArrayKernels {
106    fn as_any(&self) -> &dyn Any {
107        self
108    }
109
110    fn as_any_mut(&mut self) -> &mut dyn Any {
111        self
112    }
113}
114
115/// Extension trait for accessing optimizer kernels from a
116/// [`VortexSession`](vortex_session::VortexSession).
117pub trait ArrayKernelsExt: SessionExt {
118    /// Returns the [`ArrayKernels`] session variable, inserting a default-constructed one if
119    /// none has been registered on the session yet.
120    fn kernels(&self) -> Ref<'_, ArrayKernels> {
121        self.get::<ArrayKernels>()
122    }
123}
124
125impl<S: SessionExt> ArrayKernelsExt for S {}