zer_compute/kernel.rs

//! Core trait definitions for the kernel dispatch system.
//!
//! [`Kernel`] is a marker trait that associates typed input/output with a
//! compute operation.  [`KernelDispatch<K>`] is implemented by each concrete
//! backend device (`CudaDevice`, `VulkanDevice`, `CpuDevice`) **and** by
//! [`DeviceBackend`] itself, so callers never need to name the backend:
//!
//! ```rust,ignore
//! let output = backend.run::<CompareScore>(input)?;
//! ```
//!
//! # Adding a new kernel
//!
//! 1. Create `src/kernels/my_kernel.rs`, define a zero-sized marker struct
//!    `MyKernel`, and `impl Kernel for MyKernel` with its typed `Input`/`Output`.
//! 2. For CUDA, add `src/backend/cuda/launch/my_kernel.rs`
//!    with `impl KernelDispatch<MyKernel> for CudaDevice`.
//! 3. Add the CPU fallback `impl KernelDispatch<MyKernel> for CpuDevice` in
//!    `src/backend/cpu/launch/my_kernel.rs`.
//! 4. Add `impl KernelDispatch<MyKernel> for DeviceBackend` in
//!    `src/backend/mod.rs` (a match that delegates to the above).
//! 5. Register the new kernel in `build.rs` so its CUDA source is compiled to PTX.
//!
//! [`DeviceBackend`]: crate::backend::DeviceBackend

use crate::error::GpuError;

/// Marker trait that binds typed `Input` and `Output` to a compute operation.
///
/// Implement this on a zero-sized marker struct. The struct itself carries no
/// data; it only names the operation so Rust can resolve the right dispatch.
pub trait Kernel: Sized + 'static {
    /// Input type for this kernel.  The lifetime parameter allows borrowing
    /// host data (records, schema) without copying.
    type Input<'a>;
    /// Output type produced after the kernel completes and results are
    /// downloaded back to host memory.
    type Output;
}

/// Execute kernel `K` on `self`.
///
/// Implemented by:
/// - Backend devices (`CudaDevice`, `CpuDevice`): the actual
///   upload / launch / download logic lives here.
/// - [`DeviceBackend`]: a thin match that delegates to the active variant.
///
/// Callers should go through [`DeviceBackend::run`] rather than calling
/// `dispatch` directly.
///
/// [`DeviceBackend`]: crate::backend::DeviceBackend
/// [`DeviceBackend::run`]: crate::backend::DeviceBackend::run
pub trait KernelDispatch<K: Kernel> {
    fn dispatch(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>;
}
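The delegation chain described in the module docs (a concrete device implements `KernelDispatch<K>`, `DeviceBackend` matches on its variant and delegates, and `run` forwards to `dispatch`) can be sketched in isolation. This is a minimal, self-contained sketch under stated assumptions, not the crate's implementation: `Kernel` and `KernelDispatch` are re-declared so the snippet compiles on its own, and `SumScores`, the single-variant `DeviceBackend`, the `GpuError::EmptyInput` variant, and the `where Self: KernelDispatch<K>` bound on `run` are all hypothetical stand-ins.

```rust
// Hypothetical stand-in for crate::error::GpuError.
#[derive(Debug, PartialEq)]
pub enum GpuError {
    EmptyInput,
}

// Re-declared copies of the two traits so this sketch is self-contained.
pub trait Kernel: Sized + 'static {
    type Input<'a>;
    type Output;
}

pub trait KernelDispatch<K: Kernel> {
    fn dispatch(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>;
}

/// Hypothetical zero-sized marker kernel: sums a borrowed slice of scores.
pub struct SumScores;

impl Kernel for SumScores {
    // Borrow host data instead of copying it.
    type Input<'a> = &'a [f32];
    type Output = f32;
}

/// Hypothetical CPU device; a real one would own buffers, streams, etc.
pub struct CpuDevice;

impl KernelDispatch<SumScores> for CpuDevice {
    // The parameter is spelled as the projection so the impl's lifetime
    // parameters match the trait declaration exactly.
    fn dispatch(&self, input: <SumScores as Kernel>::Input<'_>) -> Result<f32, GpuError> {
        if input.is_empty() {
            return Err(GpuError::EmptyInput);
        }
        Ok(input.iter().sum())
    }
}

/// Hypothetical backend enum; the real one presumably also has GPU variants.
pub enum DeviceBackend {
    Cpu(CpuDevice),
}

impl KernelDispatch<SumScores> for DeviceBackend {
    // Thin match that delegates to whichever device is active.
    fn dispatch(&self, input: <SumScores as Kernel>::Input<'_>) -> Result<f32, GpuError> {
        match self {
            DeviceBackend::Cpu(dev) => KernelDispatch::<SumScores>::dispatch(dev, input),
        }
    }
}

impl DeviceBackend {
    /// Assumed shape of `run`: generic over any kernel this backend can dispatch.
    pub fn run<K: Kernel>(&self, input: K::Input<'_>) -> Result<K::Output, GpuError>
    where
        Self: KernelDispatch<K>,
    {
        <Self as KernelDispatch<K>>::dispatch(self, input)
    }
}

fn main() {
    let backend = DeviceBackend::Cpu(CpuDevice);
    let scores = [0.5_f32, 1.25, 2.0];
    // The caller names only the kernel, never the concrete device.
    let total = backend.run::<SumScores>(&scores[..]).unwrap();
    println!("{total}");
}
```

The impls write the input type as `<SumScores as Kernel>::Input<'_>` rather than `&[f32]`. A lifetime that appears only inside an associated-type projection is early-bound, so using the plain reference type in the impl can make the method's lifetime parameters differ from the trait's and be rejected with E0195.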