// rustkernel_core/test_kernels.rs

1//! Test kernels for validation and benchmarking.
2//!
3//! This module provides simple kernels for testing the kernel framework:
4//! - `VectorAdd`: Batch kernel for vector addition
5//! - `EchoKernel`: Ring kernel for message echo (latency testing)
6
7use crate::domain::Domain;
8use crate::error::Result;
9use crate::kernel::KernelMetadata;
10use crate::traits::{BatchKernel, GpuKernel};
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13
14// ============================================================================
15// VectorAdd Batch Kernel
16// ============================================================================
17
/// Input for vector addition.
///
/// Both vectors must have the same length; this is enforced by
/// `VectorAdd::validate_input` before execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorAddInput {
    /// First vector.
    pub a: Vec<f32>,
    /// Second vector. Must match `a` in length.
    pub b: Vec<f32>,
}
26
/// Output from vector addition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorAddOutput {
    /// Result vector (element-wise `a + b`), same length as the inputs.
    pub result: Vec<f32>,
}
33
/// Simple vector addition kernel.
///
/// This is a batch kernel that adds two vectors element-wise.
/// Used for testing and validation of the kernel framework.
///
/// Construct via [`VectorAdd::new`] (or `Default`); run via the
/// `BatchKernel` implementation.
#[derive(Debug, Clone)]
pub struct VectorAdd {
    // Kernel metadata (id "core/vector-add", batch mode) fixed at construction.
    metadata: KernelMetadata,
}
42
43impl VectorAdd {
44    /// Create a new VectorAdd kernel.
45    #[must_use]
46    pub fn new() -> Self {
47        Self {
48            metadata: KernelMetadata::batch("core/vector-add", Domain::Core)
49                .with_description("Element-wise vector addition")
50                .with_throughput(10_000_000)
51                .with_latency_us(10.0),
52        }
53    }
54
55    /// Perform vector addition (CPU implementation).
56    fn add_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
57        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
58    }
59}
60
61impl Default for VectorAdd {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
impl GpuKernel for VectorAdd {
    /// Metadata set at construction (id `core/vector-add`, batch mode).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }

    /// No configurable state to check; this kernel is always valid.
    fn validate(&self) -> Result<()> {
        Ok(())
    }
}
76
77#[async_trait]
78impl BatchKernel<VectorAddInput, VectorAddOutput> for VectorAdd {
79    async fn execute(&self, input: VectorAddInput) -> Result<VectorAddOutput> {
80        self.validate_input(&input)?;
81        let result = Self::add_vectors(&input.a, &input.b);
82        Ok(VectorAddOutput { result })
83    }
84
85    fn validate_input(&self, input: &VectorAddInput) -> Result<()> {
86        if input.a.len() != input.b.len() {
87            return Err(crate::error::KernelError::validation(format!(
88                "Vector lengths must match: a.len()={}, b.len()={}",
89                input.a.len(),
90                input.b.len()
91            )));
92        }
93        Ok(())
94    }
95}
96
97// ============================================================================
98// Echo Ring Kernel
99// ============================================================================
100
/// Echo request message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EchoRequest {
    /// Message to echo back verbatim in the response.
    pub message: String,
    /// Sequence number for ordering; copied unchanged into the response.
    pub sequence: u64,
}
109
/// Echo response message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EchoResponse {
    /// Echoed message (identical to the request's).
    pub message: String,
    /// Original sequence number from the request.
    pub sequence: u64,
    /// Timestamp of processing (nanoseconds since kernel start).
    ///
    /// NOTE: in this module `EchoKernel::process` fills this with a
    /// placeholder (`messages_processed * 100`), not a real clock value.
    pub processed_ns: u64,
}
120
/// Echo kernel state.
#[derive(Debug, Clone, Default)]
pub struct EchoState {
    /// Number of messages processed; incremented once per `process` call.
    pub messages_processed: u64,
    /// Start time for latency measurement.
    // NOTE(review): never set to anything but 0 in this module — presumably
    // populated by the runtime in a real deployment; confirm with callers.
    pub start_ns: u64,
}
129
/// Simple echo kernel for latency testing.
///
/// This is a ring kernel that echoes back messages with timing information.
/// Used for testing message round-trip latency. State lives in a separate
/// [`EchoState`] passed to the associated `process` function.
#[derive(Debug, Clone)]
pub struct EchoKernel {
    // Kernel metadata (id "core/echo", ring mode) fixed at construction.
    metadata: KernelMetadata,
}
138
139impl EchoKernel {
140    /// Create a new EchoKernel.
141    #[must_use]
142    pub fn new() -> Self {
143        Self {
144            metadata: KernelMetadata::ring("core/echo", Domain::Core)
145                .with_description("Message echo for latency testing")
146                .with_throughput(1_000_000)
147                .with_latency_us(0.5),
148        }
149    }
150
151    /// Process an echo request.
152    pub fn process(state: &mut EchoState, request: EchoRequest) -> EchoResponse {
153        state.messages_processed += 1;
154
155        // Simple timestamp (would use HLC in real implementation)
156        let processed_ns = state.messages_processed * 100; // Placeholder
157
158        EchoResponse {
159            message: request.message,
160            sequence: request.sequence,
161            processed_ns,
162        }
163    }
164
165    /// Initialize state.
166    pub fn initialize() -> EchoState {
167        EchoState {
168            messages_processed: 0,
169            start_ns: 0,
170        }
171    }
172}
173
174impl Default for EchoKernel {
175    fn default() -> Self {
176        Self::new()
177    }
178}
179
impl GpuKernel for EchoKernel {
    /// Metadata set at construction (id `core/echo`, ring mode).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
185
186// ============================================================================
187// Matrix Multiply Batch Kernel (for benchmarking)
188// ============================================================================
189
/// Input for matrix multiplication.
///
/// Buffer sizes must match the declared dimensions
/// (`a.len() == rows_a * cols_a`, `b.len() == cols_a * cols_b`);
/// this is enforced by `MatMul::validate_input`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatMulInput {
    /// First matrix (row-major, dimensions rows_a x cols_a).
    pub a: Vec<f32>,
    /// Second matrix (row-major, dimensions cols_a x cols_b).
    pub b: Vec<f32>,
    /// Rows in matrix A.
    pub rows_a: usize,
    /// Columns in matrix A (= rows in matrix B).
    pub cols_a: usize,
    /// Columns in matrix B.
    pub cols_b: usize,
}
204
/// Output from matrix multiplication.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatMulOutput {
    /// Result matrix (row-major, dimensions rows_a x cols_b).
    pub result: Vec<f32>,
}
211
/// Matrix multiplication kernel.
///
/// This is a batch kernel that multiplies two matrices.
/// Used for benchmarking compute throughput. The in-module implementation
/// is a naive CPU triple loop; metadata flags the kernel as GPU-native.
#[derive(Debug, Clone)]
pub struct MatMul {
    // Kernel metadata (id "core/matmul", batch mode, gpu_native) fixed at construction.
    metadata: KernelMetadata,
}
220
221impl MatMul {
222    /// Create a new MatMul kernel.
223    #[must_use]
224    pub fn new() -> Self {
225        Self {
226            metadata: KernelMetadata::batch("core/matmul", Domain::Core)
227                .with_description("Matrix multiplication (GEMM)")
228                .with_throughput(1_000_000)
229                .with_latency_us(50.0)
230                .with_gpu_native(true),
231        }
232    }
233
234    /// Perform matrix multiplication (naive CPU implementation).
235    fn matmul(a: &[f32], b: &[f32], rows_a: usize, cols_a: usize, cols_b: usize) -> Vec<f32> {
236        let mut result = vec![0.0f32; rows_a * cols_b];
237
238        for i in 0..rows_a {
239            for j in 0..cols_b {
240                let mut sum = 0.0f32;
241                for k in 0..cols_a {
242                    sum += a[i * cols_a + k] * b[k * cols_b + j];
243                }
244                result[i * cols_b + j] = sum;
245            }
246        }
247
248        result
249    }
250}
251
252impl Default for MatMul {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
impl GpuKernel for MatMul {
    /// Metadata set at construction (id `core/matmul`, batch mode, gpu_native).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
263
264#[async_trait]
265impl BatchKernel<MatMulInput, MatMulOutput> for MatMul {
266    async fn execute(&self, input: MatMulInput) -> Result<MatMulOutput> {
267        self.validate_input(&input)?;
268        let result = Self::matmul(&input.a, &input.b, input.rows_a, input.cols_a, input.cols_b);
269        Ok(MatMulOutput { result })
270    }
271
272    fn validate_input(&self, input: &MatMulInput) -> Result<()> {
273        let expected_a = input.rows_a * input.cols_a;
274        let expected_b = input.cols_a * input.cols_b;
275
276        if input.a.len() != expected_a {
277            return Err(crate::error::KernelError::validation(format!(
278                "Matrix A size mismatch: expected {}, got {}",
279                expected_a,
280                input.a.len()
281            )));
282        }
283
284        if input.b.len() != expected_b {
285            return Err(crate::error::KernelError::validation(format!(
286                "Matrix B size mismatch: expected {}, got {}",
287                expected_b,
288                input.b.len()
289            )));
290        }
291
292        Ok(())
293    }
294}
295
296// ============================================================================
297// Reduce Sum Kernel
298// ============================================================================
299
/// Input for sum reduction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReduceSumInput {
    /// Data to sum. May be empty (the sum is then 0.0 with count 0).
    pub data: Vec<f32>,
}
306
/// Output from sum reduction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReduceSumOutput {
    /// Sum of all elements, accumulated in f64 to reduce rounding error.
    pub sum: f64,
    /// Count of elements in the input.
    pub count: usize,
}
315
/// Sum reduction kernel.
///
/// Reduces a vector to its sum. Used for testing parallel reduction patterns.
#[derive(Debug, Clone)]
pub struct ReduceSum {
    // Kernel metadata (id "core/reduce-sum", batch mode) fixed at construction.
    metadata: KernelMetadata,
}
323
impl ReduceSum {
    /// Create a new ReduceSum kernel.
    ///
    /// Registers batch-mode metadata under id `core/reduce-sum` in the
    /// Core domain with a 100M-element/s throughput hint and a 5 µs
    /// latency hint.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("core/reduce-sum", Domain::Core)
                .with_description("Parallel sum reduction")
                .with_throughput(100_000_000)
                .with_latency_us(5.0),
        }
    }
}
336
337impl Default for ReduceSum {
338    fn default() -> Self {
339        Self::new()
340    }
341}
342
impl GpuKernel for ReduceSum {
    /// Metadata set at construction (id `core/reduce-sum`, batch mode).
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
348
349#[async_trait]
350impl BatchKernel<ReduceSumInput, ReduceSumOutput> for ReduceSum {
351    async fn execute(&self, input: ReduceSumInput) -> Result<ReduceSumOutput> {
352        let sum: f64 = input.data.iter().map(|&x| f64::from(x)).sum();
353        Ok(ReduceSumOutput {
354            sum,
355            count: input.data.len(),
356        })
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::KernelMode;

    // Happy path: metadata is registered as expected and two 3-element
    // vectors add element-wise.
    #[tokio::test]
    async fn test_vector_add() {
        let kernel = VectorAdd::new();
        assert_eq!(kernel.metadata().id, "core/vector-add");
        assert_eq!(kernel.metadata().mode, KernelMode::Batch);

        let input = VectorAddInput {
            a: vec![1.0, 2.0, 3.0],
            b: vec![4.0, 5.0, 6.0],
        };

        let output = kernel.execute(input).await.unwrap();
        assert_eq!(output.result, vec![5.0, 7.0, 9.0]);
    }

    // Mismatched lengths (2 vs 3) must be rejected by validate_input
    // before any addition happens.
    #[tokio::test]
    async fn test_vector_add_validation() {
        let kernel = VectorAdd::new();

        let input = VectorAddInput {
            a: vec![1.0, 2.0],
            b: vec![1.0, 2.0, 3.0],
        };

        let result = kernel.execute(input).await;
        assert!(result.is_err());
    }

    // Echo is a ring kernel; processing a request echoes the message and
    // sequence back and increments the state counter.
    #[test]
    fn test_echo_kernel() {
        let kernel = EchoKernel::new();
        assert_eq!(kernel.metadata().id, "core/echo");
        assert_eq!(kernel.metadata().mode, KernelMode::Ring);

        let mut state = EchoKernel::initialize();
        let request = EchoRequest {
            message: "Hello".to_string(),
            sequence: 1,
        };

        let response = EchoKernel::process(&mut state, request);
        assert_eq!(response.message, "Hello");
        assert_eq!(response.sequence, 1);
        assert_eq!(state.messages_processed, 1);
    }

    // Known 2x2 GEMM with a hand-computed expected result.
    #[tokio::test]
    async fn test_matmul() {
        let kernel = MatMul::new();

        // 2x2 * 2x2 matrix multiplication
        let input = MatMulInput {
            a: vec![1.0, 2.0, 3.0, 4.0],
            b: vec![5.0, 6.0, 7.0, 8.0],
            rows_a: 2,
            cols_a: 2,
            cols_b: 2,
        };

        let output = kernel.execute(input).await.unwrap();
        // [1,2] * [5,6] = [1*5+2*7, 1*6+2*8] = [19, 22]
        // [3,4]   [7,8]   [3*5+4*7, 3*6+4*8]   [43, 50]
        assert_eq!(output.result, vec![19.0, 22.0, 43.0, 50.0]);
    }

    // 1+2+3+4+5 = 15; sum is f64 so compare with a tolerance.
    #[tokio::test]
    async fn test_reduce_sum() {
        let kernel = ReduceSum::new();

        let input = ReduceSumInput {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        };

        let output = kernel.execute(input).await.unwrap();
        assert!((output.sum - 15.0).abs() < 1e-6);
        assert_eq!(output.count, 5);
    }
}