rustkernel_core/
traits.rs

//! Core kernel traits.
//!
//! This module defines the fundamental kernel traits:
//! - `GpuKernel`: Base trait for all GPU kernels
//! - `BatchKernel`: Trait for batch (CPU-orchestrated) kernels
//! - `RingKernelHandler`: Trait for ring (persistent actor) kernels
//! - `IterativeKernel`: Trait for iterative (multi-pass) kernels
7
8use crate::error::Result;
9use crate::kernel::KernelMetadata;
10use async_trait::async_trait;
11use ringkernel_core::{RingContext, RingMessage};
12use std::fmt::Debug;
13
14/// Base trait for all GPU kernels.
15///
16/// Provides access to kernel metadata and input validation.
17pub trait GpuKernel: Send + Sync + Debug {
18    /// Returns the kernel metadata.
19    fn metadata(&self) -> &KernelMetadata;
20
21    /// Validate kernel configuration.
22    ///
23    /// Called before kernel launch to ensure configuration is valid.
24    fn validate(&self) -> Result<()> {
25        Ok(())
26    }
27
28    /// Returns the kernel ID.
29    fn id(&self) -> &str {
30        &self.metadata().id
31    }
32
33    /// Returns true if this kernel requires GPU-native execution.
34    fn requires_gpu_native(&self) -> bool {
35        self.metadata().requires_gpu_native
36    }
37}
38
39/// Trait for batch (CPU-orchestrated) kernels.
40///
41/// Batch kernels are launched on-demand with CPU orchestration.
42/// They have 10-50μs launch overhead and state resides in CPU memory.
43///
44/// # Type Parameters
45///
46/// - `I`: Input type
47/// - `O`: Output type
48#[async_trait]
49pub trait BatchKernel<I, O>: GpuKernel
50where
51    I: Send + Sync,
52    O: Send + Sync,
53{
54    /// Execute the kernel with the given input.
55    ///
56    /// # Arguments
57    ///
58    /// * `input` - The input data for the kernel
59    ///
60    /// # Returns
61    ///
62    /// The kernel output or an error.
63    async fn execute(&self, input: I) -> Result<O>;
64
65    /// Validate the input before execution.
66    ///
67    /// Override to provide custom input validation.
68    fn validate_input(&self, _input: &I) -> Result<()> {
69        Ok(())
70    }
71}
72
73/// Trait for ring (persistent actor) kernels.
74///
75/// Ring kernels are persistent GPU actors with 100-500ns message latency.
76/// State resides permanently in GPU memory.
77///
78/// # Type Parameters
79///
80/// - `M`: Request message type
81/// - `R`: Response message type
82#[async_trait]
83pub trait RingKernelHandler<M, R>: GpuKernel
84where
85    M: RingMessage + Send + Sync,
86    R: RingMessage + Send + Sync,
87{
88    /// Handle an incoming message.
89    ///
90    /// # Arguments
91    ///
92    /// * `ctx` - The ring kernel context with GPU intrinsics
93    /// * `msg` - The incoming message
94    ///
95    /// # Returns
96    ///
97    /// The response message or an error.
98    async fn handle(&self, ctx: &mut RingContext, msg: M) -> Result<R>;
99
100    /// Initialize the kernel state.
101    ///
102    /// Called once when the kernel is first activated.
103    async fn initialize(&self, _ctx: &mut RingContext) -> Result<()> {
104        Ok(())
105    }
106
107    /// Called when the kernel is being shut down.
108    ///
109    /// Use this to clean up resources.
110    async fn shutdown(&self, _ctx: &mut RingContext) -> Result<()> {
111        Ok(())
112    }
113}
114
115/// Trait for iterative (multi-pass) kernels.
116///
117/// Provides support for algorithms that require multiple iterations
118/// to converge (e.g., PageRank, K-Means).
119///
120/// # Type Parameters
121///
122/// - `S`: State type
123/// - `I`: Input type
124/// - `O`: Output type
125#[async_trait]
126pub trait IterativeKernel<S, I, O>: GpuKernel
127where
128    S: Send + Sync + 'static,
129    I: Send + Sync + 'static,
130    O: Send + Sync + 'static,
131{
132    /// Create the initial state.
133    fn initial_state(&self, input: &I) -> S;
134
135    /// Perform one iteration.
136    ///
137    /// # Arguments
138    ///
139    /// * `state` - The current state (mutable)
140    /// * `input` - The input data
141    ///
142    /// # Returns
143    ///
144    /// The iteration result.
145    async fn iterate(&self, state: &mut S, input: &I) -> Result<IterationResult<O>>;
146
147    /// Check if the algorithm has converged.
148    ///
149    /// # Arguments
150    ///
151    /// * `state` - The current state
152    /// * `threshold` - The convergence threshold
153    ///
154    /// # Returns
155    ///
156    /// `true` if converged, `false` otherwise.
157    fn converged(&self, state: &S, threshold: f64) -> bool;
158
159    /// Maximum number of iterations.
160    fn max_iterations(&self) -> usize {
161        100
162    }
163
164    /// Default convergence threshold.
165    fn default_threshold(&self) -> f64 {
166        1e-6
167    }
168
169    /// Run the iterative algorithm to convergence.
170    async fn run_to_convergence(&self, input: I) -> Result<O> {
171        self.run_to_convergence_with_threshold(input, self.default_threshold())
172            .await
173    }
174
175    /// Run the iterative algorithm with a custom threshold.
176    async fn run_to_convergence_with_threshold(&self, input: I, threshold: f64) -> Result<O> {
177        let mut state = self.initial_state(&input);
178        let max_iter = self.max_iterations();
179
180        for _ in 0..max_iter {
181            let result = self.iterate(&mut state, &input).await?;
182
183            if let IterationResult::Converged(output) = result {
184                return Ok(output);
185            }
186
187            if self.converged(&state, threshold) {
188                if let IterationResult::Continue(output) = result {
189                    return Ok(output);
190                }
191            }
192        }
193
194        // Return final state even if not converged
195        match self.iterate(&mut state, &input).await? {
196            IterationResult::Converged(output) | IterationResult::Continue(output) => Ok(output),
197        }
198    }
199}
200
/// Outcome of one iteration step.
#[derive(Clone, Debug)]
pub enum IterationResult<O> {
    /// The algorithm has converged; carries the final output.
    Converged(O),
    /// The algorithm should keep going; carries the current intermediate output.
    Continue(O),
}

impl<O> IterationResult<O> {
    /// Returns `true` for `Converged`, `false` for `Continue`.
    #[must_use]
    pub fn is_converged(&self) -> bool {
        match self {
            Self::Converged(_) => true,
            Self::Continue(_) => false,
        }
    }

    /// Consumes the result and yields the carried output, whichever the variant.
    #[must_use]
    pub fn into_output(self) -> O {
        let (Self::Converged(payload) | Self::Continue(payload)) = self;
        payload
    }
}
225
226/// Type-erased batch kernel for registry storage.
227#[async_trait]
228pub trait BatchKernelDyn: GpuKernel {
229    /// Execute with type-erased input/output.
230    async fn execute_dyn(&self, input: &[u8]) -> Result<Vec<u8>>;
231}
232
233/// Type-erased ring kernel for registry storage.
234#[async_trait]
235pub trait RingKernelDyn: GpuKernel {
236    /// Handle with type-erased messages.
237    async fn handle_dyn(&self, ctx: &mut RingContext, msg: &[u8]) -> Result<Vec<u8>>;
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `is_converged` on both variants and `into_output` on `Converged`.
    #[test]
    fn test_iteration_result() {
        // A converged result reports convergence and yields its payload.
        let done = IterationResult::<i32>::Converged(42);
        assert!(done.is_converged());
        assert_eq!(done.into_output(), 42);

        // A continuing result must not report convergence.
        let pending = IterationResult::<i32>::Continue(0);
        assert!(!pending.is_converged());
    }
}