1use crate::domain::Domain;
8use crate::error::Result;
9use crate::kernel::KernelMetadata;
10use crate::traits::{BatchKernel, GpuKernel};
11use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13
/// Input for the [`VectorAdd`] kernel: two operand vectors.
///
/// `a` and `b` must have equal length; this is enforced by
/// `validate_input`, not by the type itself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorAddInput {
    /// Left-hand operand.
    pub a: Vec<f32>,
    /// Right-hand operand; must be the same length as `a`.
    pub b: Vec<f32>,
}
26
/// Output of the [`VectorAdd`] kernel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorAddOutput {
    /// Element-wise sums; same length as the validated inputs.
    pub result: Vec<f32>,
}
33
/// Element-wise vector addition kernel (`core/vector-add`, batch mode).
#[derive(Debug, Clone)]
pub struct VectorAdd {
    // Static descriptor (id, domain, perf hints) built in `new`.
    metadata: KernelMetadata,
}
42
43impl VectorAdd {
44 #[must_use]
46 pub fn new() -> Self {
47 Self {
48 metadata: KernelMetadata::batch("core/vector-add", Domain::Core)
49 .with_description("Element-wise vector addition")
50 .with_throughput(10_000_000)
51 .with_latency_us(10.0),
52 }
53 }
54
55 fn add_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
57 a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
58 }
59}
60
61impl Default for VectorAdd {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
impl GpuKernel for VectorAdd {
    // Expose the static kernel descriptor.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }

    // This kernel has no construction-time invariants to check.
    fn validate(&self) -> Result<()> {
        Ok(())
    }
}
76
77#[async_trait]
78impl BatchKernel<VectorAddInput, VectorAddOutput> for VectorAdd {
79 async fn execute(&self, input: VectorAddInput) -> Result<VectorAddOutput> {
80 self.validate_input(&input)?;
81 let result = Self::add_vectors(&input.a, &input.b);
82 Ok(VectorAddOutput { result })
83 }
84
85 fn validate_input(&self, input: &VectorAddInput) -> Result<()> {
86 if input.a.len() != input.b.len() {
87 return Err(crate::error::KernelError::validation(format!(
88 "Vector lengths must match: a.len()={}, b.len()={}",
89 input.a.len(),
90 input.b.len()
91 )));
92 }
93 Ok(())
94 }
95}
96
/// Request message for the [`EchoKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EchoRequest {
    /// Payload to be echoed back unchanged.
    pub message: String,
    /// Caller-assigned sequence number, returned as-is in the response.
    pub sequence: u64,
}
109
/// Response message from the [`EchoKernel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EchoResponse {
    /// Echoed payload (moved from the request).
    pub message: String,
    /// Sequence number copied from the request.
    pub sequence: u64,
    /// Synthetic processing timestamp: `messages_processed * 100` — not a
    /// real clock reading (see `EchoKernel::process`).
    pub processed_ns: u64,
}
120
/// Mutable per-ring state for the [`EchoKernel`].
#[derive(Debug, Clone, Default)]
pub struct EchoState {
    /// Count of messages handled so far; incremented by `process`.
    pub messages_processed: u64,
    /// NOTE(review): initialized to 0 and never updated by the visible
    /// code — presumably a start-time slot to be filled elsewhere; confirm.
    pub start_ns: u64,
}
129
/// Message echo kernel for latency testing (`core/echo`, ring mode).
#[derive(Debug, Clone)]
pub struct EchoKernel {
    // Static descriptor (id, domain, perf hints) built in `new`.
    metadata: KernelMetadata,
}
138
139impl EchoKernel {
140 #[must_use]
142 pub fn new() -> Self {
143 Self {
144 metadata: KernelMetadata::ring("core/echo", Domain::Core)
145 .with_description("Message echo for latency testing")
146 .with_throughput(1_000_000)
147 .with_latency_us(0.5),
148 }
149 }
150
151 pub fn process(state: &mut EchoState, request: EchoRequest) -> EchoResponse {
153 state.messages_processed += 1;
154
155 let processed_ns = state.messages_processed * 100; EchoResponse {
159 message: request.message,
160 sequence: request.sequence,
161 processed_ns,
162 }
163 }
164
165 pub fn initialize() -> EchoState {
167 EchoState {
168 messages_processed: 0,
169 start_ns: 0,
170 }
171 }
172}
173
174impl Default for EchoKernel {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
impl GpuKernel for EchoKernel {
    // Expose the static kernel descriptor; the trait's default `validate`
    // (if any) is left untouched.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
185
/// Input for the [`MatMul`] kernel.
///
/// Matrices are stored flat in row-major order: `a` is `rows_a x cols_a`,
/// `b` is `cols_a x cols_b` (element `(i, j)` at index `i * cols + j`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatMulInput {
    /// Matrix A, row-major, `rows_a * cols_a` elements.
    pub a: Vec<f32>,
    /// Matrix B, row-major, `cols_a * cols_b` elements.
    pub b: Vec<f32>,
    /// Number of rows in A (and in the result).
    pub rows_a: usize,
    /// Number of columns in A == number of rows in B.
    pub cols_a: usize,
    /// Number of columns in B (and in the result).
    pub cols_b: usize,
}
204
/// Output of the [`MatMul`] kernel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatMulOutput {
    /// Result matrix, row-major, `rows_a * cols_b` elements.
    pub result: Vec<f32>,
}
211
/// Matrix multiplication (GEMM) kernel (`core/matmul`, batch mode,
/// flagged GPU-native in its metadata).
#[derive(Debug, Clone)]
pub struct MatMul {
    // Static descriptor (id, domain, perf hints) built in `new`.
    metadata: KernelMetadata,
}
220
221impl MatMul {
222 #[must_use]
224 pub fn new() -> Self {
225 Self {
226 metadata: KernelMetadata::batch("core/matmul", Domain::Core)
227 .with_description("Matrix multiplication (GEMM)")
228 .with_throughput(1_000_000)
229 .with_latency_us(50.0)
230 .with_gpu_native(true),
231 }
232 }
233
234 fn matmul(a: &[f32], b: &[f32], rows_a: usize, cols_a: usize, cols_b: usize) -> Vec<f32> {
236 let mut result = vec![0.0f32; rows_a * cols_b];
237
238 for i in 0..rows_a {
239 for j in 0..cols_b {
240 let mut sum = 0.0f32;
241 for k in 0..cols_a {
242 sum += a[i * cols_a + k] * b[k * cols_b + j];
243 }
244 result[i * cols_b + j] = sum;
245 }
246 }
247
248 result
249 }
250}
251
252impl Default for MatMul {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
impl GpuKernel for MatMul {
    // Expose the static kernel descriptor.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
263
264#[async_trait]
265impl BatchKernel<MatMulInput, MatMulOutput> for MatMul {
266 async fn execute(&self, input: MatMulInput) -> Result<MatMulOutput> {
267 self.validate_input(&input)?;
268 let result = Self::matmul(&input.a, &input.b, input.rows_a, input.cols_a, input.cols_b);
269 Ok(MatMulOutput { result })
270 }
271
272 fn validate_input(&self, input: &MatMulInput) -> Result<()> {
273 let expected_a = input.rows_a * input.cols_a;
274 let expected_b = input.cols_a * input.cols_b;
275
276 if input.a.len() != expected_a {
277 return Err(crate::error::KernelError::validation(format!(
278 "Matrix A size mismatch: expected {}, got {}",
279 expected_a,
280 input.a.len()
281 )));
282 }
283
284 if input.b.len() != expected_b {
285 return Err(crate::error::KernelError::validation(format!(
286 "Matrix B size mismatch: expected {}, got {}",
287 expected_b,
288 input.b.len()
289 )));
290 }
291
292 Ok(())
293 }
294}
295
/// Input for the [`ReduceSum`] kernel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReduceSumInput {
    /// Values to be summed; may be empty (sum is then 0.0).
    pub data: Vec<f32>,
}
306
/// Output of the [`ReduceSum`] kernel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReduceSumOutput {
    /// Total, accumulated in `f64` to limit rounding error.
    pub sum: f64,
    /// Number of input elements that were summed.
    pub count: usize,
}
315
/// Parallel sum-reduction kernel (`core/reduce-sum`, batch mode).
#[derive(Debug, Clone)]
pub struct ReduceSum {
    // Static descriptor (id, domain, perf hints) built in `new`.
    metadata: KernelMetadata,
}
323
324impl ReduceSum {
325 #[must_use]
327 pub fn new() -> Self {
328 Self {
329 metadata: KernelMetadata::batch("core/reduce-sum", Domain::Core)
330 .with_description("Parallel sum reduction")
331 .with_throughput(100_000_000)
332 .with_latency_us(5.0),
333 }
334 }
335}
336
337impl Default for ReduceSum {
338 fn default() -> Self {
339 Self::new()
340 }
341}
342
impl GpuKernel for ReduceSum {
    // Expose the static kernel descriptor.
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}
348
349#[async_trait]
350impl BatchKernel<ReduceSumInput, ReduceSumOutput> for ReduceSum {
351 async fn execute(&self, input: ReduceSumInput) -> Result<ReduceSumOutput> {
352 let sum: f64 = input.data.iter().map(|&x| f64::from(x)).sum();
353 Ok(ReduceSumOutput {
354 sum,
355 count: input.data.len(),
356 })
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::KernelMode;

    // Happy path: metadata identity/mode plus a 3-element add.
    #[tokio::test]
    async fn test_vector_add() {
        let kernel = VectorAdd::new();
        assert_eq!(kernel.metadata().id, "core/vector-add");
        assert_eq!(kernel.metadata().mode, KernelMode::Batch);

        let input = VectorAddInput {
            a: vec![1.0, 2.0, 3.0],
            b: vec![4.0, 5.0, 6.0],
        };

        let output = kernel.execute(input).await.unwrap();
        assert_eq!(output.result, vec![5.0, 7.0, 9.0]);
    }

    // Mismatched operand lengths must be rejected by validate_input.
    #[tokio::test]
    async fn test_vector_add_validation() {
        let kernel = VectorAdd::new();

        let input = VectorAddInput {
            a: vec![1.0, 2.0],
            b: vec![1.0, 2.0, 3.0],
        };

        let result = kernel.execute(input).await;
        assert!(result.is_err());
    }

    // Ring-mode echo: payload and sequence round-trip, counter increments.
    #[test]
    fn test_echo_kernel() {
        let kernel = EchoKernel::new();
        assert_eq!(kernel.metadata().id, "core/echo");
        assert_eq!(kernel.metadata().mode, KernelMode::Ring);

        let mut state = EchoKernel::initialize();
        let request = EchoRequest {
            message: "Hello".to_string(),
            sequence: 1,
        };

        let response = EchoKernel::process(&mut state, request);
        assert_eq!(response.message, "Hello");
        assert_eq!(response.sequence, 1);
        assert_eq!(state.messages_processed, 1);
    }

    // 2x2 GEMM with a hand-computed expected product.
    #[tokio::test]
    async fn test_matmul() {
        let kernel = MatMul::new();

        let input = MatMulInput {
            a: vec![1.0, 2.0, 3.0, 4.0],
            b: vec![5.0, 6.0, 7.0, 8.0],
            rows_a: 2,
            cols_a: 2,
            cols_b: 2,
        };

        let output = kernel.execute(input).await.unwrap();
        assert_eq!(output.result, vec![19.0, 22.0, 43.0, 50.0]);
    }

    // Sum of 1..=5 with an epsilon compare on the f64 accumulator.
    #[tokio::test]
    async fn test_reduce_sum() {
        let kernel = ReduceSum::new();

        let input = ReduceSumInput {
            data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
        };

        let output = kernel.execute(input).await.unwrap();
        assert!((output.sum - 15.0).abs() < 1e-6);
        assert_eq!(output.count, 5);
    }
}