rustkernel_core/traits.rs
1//! Core kernel traits.
2//!
//! This module defines the fundamental kernel traits:
//! - `GpuKernel`: Base trait for all GPU kernels
//! - `BatchKernel`: Trait for batch (CPU-orchestrated) kernels
//! - `RingKernelHandler`: Trait for ring (persistent actor) kernels
//! - `IterativeKernel`: Trait for iterative (multi-pass) kernels
//! - `BatchKernelDyn` / `RingKernelDyn`: Type-erased variants for registry storage
7
8use crate::error::Result;
9use crate::kernel::KernelMetadata;
10use async_trait::async_trait;
11use ringkernel_core::{RingContext, RingMessage};
12use std::fmt::Debug;
13
14/// Base trait for all GPU kernels.
15///
16/// Provides access to kernel metadata and input validation.
17pub trait GpuKernel: Send + Sync + Debug {
18 /// Returns the kernel metadata.
19 fn metadata(&self) -> &KernelMetadata;
20
21 /// Validate kernel configuration.
22 ///
23 /// Called before kernel launch to ensure configuration is valid.
24 fn validate(&self) -> Result<()> {
25 Ok(())
26 }
27
28 /// Returns the kernel ID.
29 fn id(&self) -> &str {
30 &self.metadata().id
31 }
32
33 /// Returns true if this kernel requires GPU-native execution.
34 fn requires_gpu_native(&self) -> bool {
35 self.metadata().requires_gpu_native
36 }
37}
38
39/// Trait for batch (CPU-orchestrated) kernels.
40///
41/// Batch kernels are launched on-demand with CPU orchestration.
42/// They have 10-50μs launch overhead and state resides in CPU memory.
43///
44/// # Type Parameters
45///
46/// - `I`: Input type
47/// - `O`: Output type
48#[async_trait]
49pub trait BatchKernel<I, O>: GpuKernel
50where
51 I: Send + Sync,
52 O: Send + Sync,
53{
54 /// Execute the kernel with the given input.
55 ///
56 /// # Arguments
57 ///
58 /// * `input` - The input data for the kernel
59 ///
60 /// # Returns
61 ///
62 /// The kernel output or an error.
63 async fn execute(&self, input: I) -> Result<O>;
64
65 /// Validate the input before execution.
66 ///
67 /// Override to provide custom input validation.
68 fn validate_input(&self, _input: &I) -> Result<()> {
69 Ok(())
70 }
71}
72
73/// Trait for ring (persistent actor) kernels.
74///
75/// Ring kernels are persistent GPU actors with 100-500ns message latency.
76/// State resides permanently in GPU memory.
77///
78/// # Type Parameters
79///
80/// - `M`: Request message type
81/// - `R`: Response message type
82#[async_trait]
83pub trait RingKernelHandler<M, R>: GpuKernel
84where
85 M: RingMessage + Send + Sync,
86 R: RingMessage + Send + Sync,
87{
88 /// Handle an incoming message.
89 ///
90 /// # Arguments
91 ///
92 /// * `ctx` - The ring kernel context with GPU intrinsics
93 /// * `msg` - The incoming message
94 ///
95 /// # Returns
96 ///
97 /// The response message or an error.
98 async fn handle(&self, ctx: &mut RingContext, msg: M) -> Result<R>;
99
100 /// Initialize the kernel state.
101 ///
102 /// Called once when the kernel is first activated.
103 async fn initialize(&self, _ctx: &mut RingContext) -> Result<()> {
104 Ok(())
105 }
106
107 /// Called when the kernel is being shut down.
108 ///
109 /// Use this to clean up resources.
110 async fn shutdown(&self, _ctx: &mut RingContext) -> Result<()> {
111 Ok(())
112 }
113}
114
115/// Trait for iterative (multi-pass) kernels.
116///
117/// Provides support for algorithms that require multiple iterations
118/// to converge (e.g., PageRank, K-Means).
119///
120/// # Type Parameters
121///
122/// - `S`: State type
123/// - `I`: Input type
124/// - `O`: Output type
125#[async_trait]
126pub trait IterativeKernel<S, I, O>: GpuKernel
127where
128 S: Send + Sync + 'static,
129 I: Send + Sync + 'static,
130 O: Send + Sync + 'static,
131{
132 /// Create the initial state.
133 fn initial_state(&self, input: &I) -> S;
134
135 /// Perform one iteration.
136 ///
137 /// # Arguments
138 ///
139 /// * `state` - The current state (mutable)
140 /// * `input` - The input data
141 ///
142 /// # Returns
143 ///
144 /// The iteration result.
145 async fn iterate(&self, state: &mut S, input: &I) -> Result<IterationResult<O>>;
146
147 /// Check if the algorithm has converged.
148 ///
149 /// # Arguments
150 ///
151 /// * `state` - The current state
152 /// * `threshold` - The convergence threshold
153 ///
154 /// # Returns
155 ///
156 /// `true` if converged, `false` otherwise.
157 fn converged(&self, state: &S, threshold: f64) -> bool;
158
159 /// Maximum number of iterations.
160 fn max_iterations(&self) -> usize {
161 100
162 }
163
164 /// Default convergence threshold.
165 fn default_threshold(&self) -> f64 {
166 1e-6
167 }
168
169 /// Run the iterative algorithm to convergence.
170 async fn run_to_convergence(&self, input: I) -> Result<O> {
171 self.run_to_convergence_with_threshold(input, self.default_threshold())
172 .await
173 }
174
175 /// Run the iterative algorithm with a custom threshold.
176 async fn run_to_convergence_with_threshold(&self, input: I, threshold: f64) -> Result<O> {
177 let mut state = self.initial_state(&input);
178 let max_iter = self.max_iterations();
179
180 for _ in 0..max_iter {
181 let result = self.iterate(&mut state, &input).await?;
182
183 if let IterationResult::Converged(output) = result {
184 return Ok(output);
185 }
186
187 if self.converged(&state, threshold) {
188 if let IterationResult::Continue(output) = result {
189 return Ok(output);
190 }
191 }
192 }
193
194 // Return final state even if not converged
195 match self.iterate(&mut state, &input).await? {
196 IterationResult::Converged(output) | IterationResult::Continue(output) => Ok(output),
197 }
198 }
199}
200
201/// Result of a single iteration.
202#[derive(Debug, Clone)]
203pub enum IterationResult<O> {
204 /// Algorithm has converged with final output.
205 Converged(O),
206 /// Algorithm should continue; current intermediate output.
207 Continue(O),
208}
209
210impl<O> IterationResult<O> {
211 /// Returns true if converged.
212 #[must_use]
213 pub fn is_converged(&self) -> bool {
214 matches!(self, IterationResult::Converged(_))
215 }
216
217 /// Extract the output.
218 #[must_use]
219 pub fn into_output(self) -> O {
220 match self {
221 IterationResult::Converged(o) | IterationResult::Continue(o) => o,
222 }
223 }
224}
225
226/// Type-erased batch kernel for registry storage.
227#[async_trait]
228pub trait BatchKernelDyn: GpuKernel {
229 /// Execute with type-erased input/output.
230 async fn execute_dyn(&self, input: &[u8]) -> Result<Vec<u8>>;
231}
232
233/// Type-erased ring kernel for registry storage.
234#[async_trait]
235pub trait RingKernelDyn: GpuKernel {
236 /// Handle with type-erased messages.
237 async fn handle_dyn(&self, ctx: &mut RingContext, msg: &[u8]) -> Result<Vec<u8>>;
238}
239
240#[cfg(test)]
241mod tests {
242 use super::*;
243
244 #[test]
245 fn test_iteration_result() {
246 let converged: IterationResult<i32> = IterationResult::Converged(42);
247 assert!(converged.is_converged());
248 assert_eq!(converged.into_output(), 42);
249
250 let continuing: IterationResult<i32> = IterationResult::Continue(0);
251 assert!(!continuing.is_converged());
252 }
253}