1#[cfg(not(feature = "no-std"))]
7use std::collections::HashMap;
8#[cfg(not(feature = "no-std"))]
9use std::fmt;
10#[cfg(not(feature = "no-std"))]
11use std::string::ToString;
12#[cfg(not(feature = "no-std"))]
13use std::sync::{Arc, Mutex, RwLock};
14
15#[cfg(feature = "no-std")]
16use alloc::collections::BTreeMap as HashMap;
17#[cfg(feature = "no-std")]
18use alloc::format;
19#[cfg(feature = "no-std")]
20use alloc::string::{String, ToString};
21#[cfg(feature = "no-std")]
22use alloc::sync::Arc;
23#[cfg(feature = "no-std")]
24use alloc::vec::Vec;
25#[cfg(feature = "no-std")]
26use core::fmt;
27#[cfg(feature = "no-std")]
28use spin::{Mutex, RwLock};
29
30pub trait SimdOperation: Send + Sync {
32 fn name(&self) -> &str;
34
35 fn version(&self) -> &str;
37
38 fn description(&self) -> &str;
40
41 fn execute_f32(&self, input: &[f32], output: &mut [f32]) -> Result<(), PluginError>;
43
44 fn execute_f64(&self, input: &[f64], output: &mut [f64]) -> Result<(), PluginError>;
46
47 fn required_input_size(&self, output_size: usize) -> usize {
49 output_size }
51
52 fn supports_inplace(&self) -> bool {
54 false
55 }
56
57 fn simd_requirements(&self) -> SimdRequirements {
59 SimdRequirements::default()
60 }
61}
62
/// Constraints a [`SimdOperation`] declares about the SIMD environment it needs.
///
/// `PartialEq`/`Eq` are derived so requirements can be compared directly (for example in
/// tests or when de-duplicating capabilities).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimdRequirements {
    /// Smallest SIMD width the operation can use. `PluginRegistry::find_by_capability` keeps
    /// a plugin only when this is `<=` the width the caller offers. Units are presumably
    /// lanes; confirm against callers.
    pub min_width: usize,
    /// Width the operation would like to run at. Not read by the registry.
    pub preferred_width: usize,
    /// Whether the operation expects aligned buffers. Not enforced by the registry.
    pub requires_aligned_memory: bool,
    /// Names of required target features. Presumably CPU feature names such as "avx2";
    /// not checked by the registry.
    pub requires_specific_features: Vec<String>,
}

impl Default for SimdRequirements {
    /// Minimal requirements: `min_width` 1, `preferred_width` 4, no alignment requirement and
    /// no specific features.
    fn default() -> Self {
        Self {
            min_width: 1,
            preferred_width: 4,
            requires_aligned_memory: false,
            requires_specific_features: Vec::new(),
        }
    }
}
82
/// Errors produced by SIMD operations and by the plugin registry.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors directly instead of
/// only pattern-matching on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PluginError {
    /// The operation rejected its input slice (for example, too short for the window).
    InvalidInput(String),
    /// The operation rejected its output slice.
    InvalidOutput(String),
    /// Input and output lengths are incompatible. The fields are `(input_len, output_len)`,
    /// matching the `Display` text "input {} vs output {}".
    IncompatibleSizes(usize, usize),
    /// The requested operation is not supported.
    UnsupportedOperation(String),
    /// The operation failed while running.
    ExecutionFailed(String),
    /// The registry refused the plugin (empty or already-registered name).
    RegistrationFailed(String),
    /// No plugin is registered under the given name.
    NotFound(String),
}

impl fmt::Display for PluginError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
            Self::InvalidOutput(msg) => write!(f, "Invalid output: {}", msg),
            Self::IncompatibleSizes(input, output) => write!(
                f,
                "Incompatible sizes: input {} vs output {}",
                input, output
            ),
            Self::UnsupportedOperation(op) => write!(f, "Unsupported operation: {}", op),
            Self::ExecutionFailed(msg) => write!(f, "Execution failed: {}", msg),
            Self::RegistrationFailed(msg) => write!(f, "Registration failed: {}", msg),
            Self::NotFound(name) => write!(f, "Plugin not found: {}", name),
        }
    }
}

#[cfg(not(feature = "no-std"))]
impl std::error::Error for PluginError {}

// `core::error::Error` is only available on newer toolchains (stabilised in Rust 1.81).
#[cfg(feature = "no-std")]
impl core::error::Error for PluginError {}
122
123#[derive(Debug, Clone)]
125pub struct PluginMetadata {
126 pub name: String,
127 pub version: String,
128 pub description: String,
129 pub author: String,
130 pub license: String,
131 pub dependencies: Vec<String>,
132 pub simd_requirements: SimdRequirements,
133}
134
135impl Default for PluginMetadata {
136 fn default() -> Self {
137 Self {
138 name: "Unknown".to_string(),
139 version: "0.1.0".to_string(),
140 description: "Custom SIMD operation".to_string(),
141 author: "Unknown".to_string(),
142 license: "MIT".to_string(),
143 dependencies: Vec::new(),
144 simd_requirements: SimdRequirements::default(),
145 }
146 }
147}
148
149pub struct Plugin {
151 pub metadata: PluginMetadata,
152 pub operation: Arc<dyn SimdOperation>,
153}
154
155impl Plugin {
156 pub fn new(operation: Arc<dyn SimdOperation>) -> Self {
157 let metadata = PluginMetadata {
158 name: operation.name().to_string(),
159 version: operation.version().to_string(),
160 description: operation.description().to_string(),
161 ..Default::default()
162 };
163
164 Self {
165 metadata,
166 operation,
167 }
168 }
169
170 pub fn with_metadata(operation: Arc<dyn SimdOperation>, metadata: PluginMetadata) -> Self {
171 Self {
172 metadata,
173 operation,
174 }
175 }
176}
177
178pub struct PluginRegistry {
180 plugins: RwLock<HashMap<String, Arc<Plugin>>>,
181 execution_stats: Mutex<HashMap<String, ExecutionStats>>,
182}
183
/// Per-plugin counters that [`PluginRegistry`] updates on every `execute_f32`/`execute_f64`
/// call that reaches the plugin's operation.
///
/// `PartialEq`/`Eq` are derived so snapshots returned by `get_stats` can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExecutionStats {
    /// Number of calls, successful or not.
    pub total_calls: u64,
    /// Sum of `input.len()` over all calls, including failed ones.
    pub total_elements_processed: u64,
    /// Accumulated time spent inside the operation, measured with `std::time::Instant`
    /// (std builds only).
    #[cfg(not(feature = "no-std"))]
    pub total_execution_time: std::time::Duration,
    /// `Display` text of the most recent failure. A later successful call does not clear it.
    pub last_error: Option<String>,
}
192
193impl Default for PluginRegistry {
194 fn default() -> Self {
195 Self::new()
196 }
197}
198
199impl PluginRegistry {
200 pub fn new() -> Self {
202 Self {
203 plugins: RwLock::new(HashMap::new()),
204 execution_stats: Mutex::new(HashMap::new()),
205 }
206 }
207
208 #[cfg(not(feature = "no-std"))]
210 fn read_plugins(&self) -> std::sync::RwLockReadGuard<'_, HashMap<String, Arc<Plugin>>> {
211 self.plugins.read().expect("operation should succeed")
212 }
213
214 #[cfg(feature = "no-std")]
215 fn read_plugins(&self) -> spin::RwLockReadGuard<'_, HashMap<String, Arc<Plugin>>> {
216 self.plugins.read()
217 }
218
219 #[cfg(not(feature = "no-std"))]
221 fn write_plugins(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<String, Arc<Plugin>>> {
222 self.plugins.write().expect("operation should succeed")
223 }
224
225 #[cfg(feature = "no-std")]
226 fn write_plugins(&self) -> spin::RwLockWriteGuard<'_, HashMap<String, Arc<Plugin>>> {
227 self.plugins.write()
228 }
229
230 #[cfg(not(feature = "no-std"))]
232 fn lock_stats(&self) -> std::sync::MutexGuard<'_, HashMap<String, ExecutionStats>> {
233 self.execution_stats
234 .lock()
235 .expect("lock should not be poisoned")
236 }
237
238 #[cfg(feature = "no-std")]
239 fn lock_stats(&self) -> spin::MutexGuard<'_, HashMap<String, ExecutionStats>, spin::Spin> {
240 self.execution_stats.lock()
241 }
242
243 pub fn register(&self, plugin: Plugin) -> Result<(), PluginError> {
245 let name = plugin.metadata.name.clone();
246
247 self.validate_plugin(&plugin)?;
249
250 let mut plugins = self.write_plugins();
252 plugins.insert(name.clone(), Arc::new(plugin));
253
254 let mut stats = self.lock_stats();
256 stats.insert(name, ExecutionStats::default());
257
258 Ok(())
259 }
260
261 pub fn unregister(&self, name: &str) -> Result<(), PluginError> {
263 let mut plugins = self.write_plugins();
264 plugins.remove(name);
265
266 let mut stats = self.lock_stats();
267 stats.remove(name);
268
269 Ok(())
270 }
271
272 pub fn get(&self, name: &str) -> Result<Arc<Plugin>, PluginError> {
274 let plugins = self.read_plugins();
275 plugins
276 .get(name)
277 .cloned()
278 .ok_or_else(|| PluginError::NotFound(name.to_string()))
279 }
280
281 pub fn list(&self) -> Vec<String> {
283 let plugins = self.read_plugins();
284 plugins.keys().cloned().collect()
285 }
286
287 pub fn execute_f32(
289 &self,
290 name: &str,
291 input: &[f32],
292 output: &mut [f32],
293 ) -> Result<(), PluginError> {
294 let plugin = self.get(name)?;
295
296 #[cfg(not(feature = "no-std"))]
297 let start_time = std::time::Instant::now();
298 let result = plugin.operation.execute_f32(input, output);
299 #[cfg(not(feature = "no-std"))]
300 let execution_time = start_time.elapsed();
301
302 #[cfg(not(feature = "no-std"))]
304 self.update_stats(name, input.len(), execution_time, result.as_ref().err());
305 #[cfg(feature = "no-std")]
306 self.update_stats(name, input.len(), result.as_ref().err());
307
308 result
309 }
310
311 pub fn execute_f64(
313 &self,
314 name: &str,
315 input: &[f64],
316 output: &mut [f64],
317 ) -> Result<(), PluginError> {
318 let plugin = self.get(name)?;
319
320 #[cfg(not(feature = "no-std"))]
321 let start_time = std::time::Instant::now();
322 let result = plugin.operation.execute_f64(input, output);
323 #[cfg(not(feature = "no-std"))]
324 let execution_time = start_time.elapsed();
325
326 #[cfg(not(feature = "no-std"))]
328 self.update_stats(name, input.len(), execution_time, result.as_ref().err());
329 #[cfg(feature = "no-std")]
330 self.update_stats(name, input.len(), result.as_ref().err());
331
332 result
333 }
334
335 pub fn get_stats(&self, name: &str) -> Option<ExecutionStats> {
337 let stats = self.lock_stats();
338 stats.get(name).cloned()
339 }
340
341 pub fn clear_stats(&self) {
343 let mut stats = self.lock_stats();
344 for stat in stats.values_mut() {
345 *stat = ExecutionStats::default();
346 }
347 }
348
349 pub fn find_by_capability(&self, requires_inplace: bool, min_width: usize) -> Vec<String> {
351 let plugins = self.read_plugins();
352 plugins
353 .iter()
354 .filter(|(_, plugin)| {
355 let op = &plugin.operation;
356 (!requires_inplace || op.supports_inplace())
357 && op.simd_requirements().min_width <= min_width
358 })
359 .map(|(name, _)| name.clone())
360 .collect()
361 }
362
363 fn validate_plugin(&self, plugin: &Plugin) -> Result<(), PluginError> {
364 let name = &plugin.metadata.name;
365
366 let plugins = self.read_plugins();
368 if plugins.contains_key(name) {
369 return Err(PluginError::RegistrationFailed(format!(
370 "Plugin '{}' is already registered",
371 name
372 )));
373 }
374
375 if name.is_empty() {
377 return Err(PluginError::RegistrationFailed(
378 "Plugin name cannot be empty".to_string(),
379 ));
380 }
381
382 Ok(())
383 }
384
385 #[cfg(not(feature = "no-std"))]
386 fn update_stats(
387 &self,
388 name: &str,
389 elements: usize,
390 time: std::time::Duration,
391 error: Option<&PluginError>,
392 ) {
393 let mut stats = self.lock_stats();
394 if let Some(stat) = stats.get_mut(name) {
395 stat.total_calls += 1;
396 stat.total_elements_processed += elements as u64;
397 stat.total_execution_time += time;
398 if let Some(err) = error {
399 stat.last_error = Some(err.to_string());
400 }
401 }
402 }
403
404 #[cfg(feature = "no-std")]
405 fn update_stats(&self, name: &str, elements: usize, error: Option<&PluginError>) {
406 let mut stats = self.lock_stats();
407 if let Some(stat) = stats.get_mut(name) {
408 stat.total_calls += 1;
409 stat.total_elements_processed += elements as u64;
410 if let Some(err) = error {
411 stat.last_error = Some(err.to_string());
412 }
413 }
414 }
415}
416
417pub static GLOBAL_REGISTRY: once_cell::sync::Lazy<PluginRegistry> =
419 once_cell::sync::Lazy::new(PluginRegistry::new);
420
421pub mod global {
423 use super::*;
424
425 pub fn register(plugin: Plugin) -> Result<(), PluginError> {
427 GLOBAL_REGISTRY.register(plugin)
428 }
429
430 pub fn execute_f32(name: &str, input: &[f32], output: &mut [f32]) -> Result<(), PluginError> {
432 GLOBAL_REGISTRY.execute_f32(name, input, output)
433 }
434
435 pub fn execute_f64(name: &str, input: &[f64], output: &mut [f64]) -> Result<(), PluginError> {
437 GLOBAL_REGISTRY.execute_f64(name, input, output)
438 }
439
440 pub fn list() -> Vec<String> {
442 GLOBAL_REGISTRY.list()
443 }
444
445 pub fn get_stats(name: &str) -> Option<ExecutionStats> {
447 GLOBAL_REGISTRY.get_stats(name)
448 }
449}
450
451pub mod examples {
453 use super::*;
454
455 pub struct SquareOperation;
457
458 impl SimdOperation for SquareOperation {
459 fn name(&self) -> &str {
460 "square"
461 }
462 fn version(&self) -> &str {
463 "1.0.0"
464 }
465 fn description(&self) -> &str {
466 "Square each element"
467 }
468
469 fn execute_f32(&self, input: &[f32], output: &mut [f32]) -> Result<(), PluginError> {
470 if input.len() != output.len() {
471 return Err(PluginError::IncompatibleSizes(input.len(), output.len()));
472 }
473
474 for (i, &val) in input.iter().enumerate() {
475 output[i] = val * val;
476 }
477 Ok(())
478 }
479
480 fn execute_f64(&self, input: &[f64], output: &mut [f64]) -> Result<(), PluginError> {
481 if input.len() != output.len() {
482 return Err(PluginError::IncompatibleSizes(input.len(), output.len()));
483 }
484
485 for (i, &val) in input.iter().enumerate() {
486 output[i] = val * val;
487 }
488 Ok(())
489 }
490
491 fn supports_inplace(&self) -> bool {
492 true
493 }
494 }
495
496 pub struct MovingAverageOperation {
498 window_size: usize,
499 }
500
501 impl MovingAverageOperation {
502 pub fn new(window_size: usize) -> Self {
503 Self { window_size }
504 }
505 }
506
507 impl SimdOperation for MovingAverageOperation {
508 fn name(&self) -> &str {
509 "moving_average"
510 }
511 fn version(&self) -> &str {
512 "1.0.0"
513 }
514 fn description(&self) -> &str {
515 "Compute moving average with configurable window"
516 }
517
518 fn execute_f32(&self, input: &[f32], output: &mut [f32]) -> Result<(), PluginError> {
519 if input.len() < self.window_size {
520 return Err(PluginError::InvalidInput(
521 "Input too small for window size".to_string(),
522 ));
523 }
524
525 let expected_output_size = input.len() - self.window_size + 1;
526 if output.len() != expected_output_size {
527 return Err(PluginError::IncompatibleSizes(
528 expected_output_size,
529 output.len(),
530 ));
531 }
532
533 for i in 0..output.len() {
534 let sum: f32 = input[i..i + self.window_size].iter().sum();
535 output[i] = sum / self.window_size as f32;
536 }
537 Ok(())
538 }
539
540 fn execute_f64(&self, input: &[f64], output: &mut [f64]) -> Result<(), PluginError> {
541 if input.len() < self.window_size {
542 return Err(PluginError::InvalidInput(
543 "Input too small for window size".to_string(),
544 ));
545 }
546
547 let expected_output_size = input.len() - self.window_size + 1;
548 if output.len() != expected_output_size {
549 return Err(PluginError::IncompatibleSizes(
550 expected_output_size,
551 output.len(),
552 ));
553 }
554
555 for i in 0..output.len() {
556 let sum: f64 = input[i..i + self.window_size].iter().sum();
557 output[i] = sum / self.window_size as f64;
558 }
559 Ok(())
560 }
561
562 fn required_input_size(&self, output_size: usize) -> usize {
563 output_size + self.window_size - 1
564 }
565 }
566}
567
// Unit tests for the registry, the global helpers and the example operations (std builds only).
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::examples::*;
    use super::*;

    #[test]
    fn test_plugin_registration() {
        let registry = PluginRegistry::new();
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);

        assert!(registry.register(plugin).is_ok());
        assert!(registry.list().contains(&"square".to_string()));
    }

    #[test]
    fn test_duplicate_and_empty_names_rejected() {
        let registry = PluginRegistry::new();
        registry
            .register(Plugin::new(Arc::new(SquareOperation)))
            .expect("first registration should succeed");

        let duplicate = registry.register(Plugin::new(Arc::new(SquareOperation)));
        assert!(matches!(duplicate, Err(PluginError::RegistrationFailed(_))));

        let unnamed = Plugin::with_metadata(
            Arc::new(SquareOperation),
            PluginMetadata {
                name: String::new(),
                ..Default::default()
            },
        );
        assert!(matches!(
            registry.register(unnamed),
            Err(PluginError::RegistrationFailed(_))
        ));

        assert_eq!(registry.list(), vec!["square".to_string()]);
    }

    #[test]
    fn test_unregister_removes_plugin_and_stats() {
        let registry = PluginRegistry::new();
        registry
            .register(Plugin::new(Arc::new(SquareOperation)))
            .expect("registration should succeed");

        registry
            .unregister("square")
            .expect("unregister should succeed");

        assert!(matches!(
            registry.get("square"),
            Err(PluginError::NotFound(_))
        ));
        assert!(registry.get_stats("square").is_none());
        assert!(registry.list().is_empty());
    }

    #[test]
    fn test_plugin_execution() {
        let registry = PluginRegistry::new();
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);

        registry.register(plugin).expect("operation should succeed");

        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];

        registry
            .execute_f32("square", &input, &mut output)
            .expect("operation should succeed");
        assert_eq!(output, vec![1.0, 4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_moving_average_plugin() {
        let registry = PluginRegistry::new();
        let operation = Arc::new(MovingAverageOperation::new(3));
        let plugin = Plugin::new(operation);

        registry.register(plugin).expect("operation should succeed");

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; 3];
        registry
            .execute_f32("moving_average", &input, &mut output)
            .expect("operation should succeed");

        assert_eq!(output, vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_plugin_stats() {
        let registry = PluginRegistry::new();
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);

        registry.register(plugin).expect("operation should succeed");

        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 2];

        registry
            .execute_f32("square", &input, &mut output)
            .expect("operation should succeed");

        let stats = registry
            .get_stats("square")
            .expect("operation should succeed");
        assert_eq!(stats.total_calls, 1);
        assert_eq!(stats.total_elements_processed, 2);
    }

    #[test]
    fn test_failed_execution_is_recorded_and_clearable() {
        let registry = PluginRegistry::new();
        registry
            .register(Plugin::new(Arc::new(SquareOperation)))
            .expect("registration should succeed");

        let input = vec![1.0, 2.0];
        let mut output = vec![0.0];
        assert!(registry.execute_f32("square", &input, &mut output).is_err());

        let stats = registry
            .get_stats("square")
            .expect("stats should exist for a registered plugin");
        assert_eq!(stats.total_calls, 1);
        assert_eq!(stats.total_elements_processed, 2);
        assert_eq!(
            stats.last_error.as_deref(),
            Some("Incompatible sizes: input 2 vs output 1")
        );

        registry.clear_stats();
        let cleared = registry
            .get_stats("square")
            .expect("clearing keeps the stats entry");
        assert_eq!(cleared.total_calls, 0);
        assert!(cleared.last_error.is_none());
    }

    #[test]
    fn test_find_by_capability() {
        let registry = PluginRegistry::new();
        registry
            .register(Plugin::new(Arc::new(SquareOperation)))
            .expect("registration should succeed");
        registry
            .register(Plugin::new(Arc::new(MovingAverageOperation::new(2))))
            .expect("registration should succeed");

        let mut any = registry.find_by_capability(false, 1);
        any.sort();
        assert_eq!(
            any,
            vec!["moving_average".to_string(), "square".to_string()]
        );
        assert_eq!(
            registry.find_by_capability(true, 1),
            vec!["square".to_string()]
        );
        assert!(registry.find_by_capability(false, 0).is_empty());
    }

    #[test]
    fn test_global_registry() {
        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);

        global::register(plugin).expect("operation should succeed");

        let input = vec![2.0, 3.0];
        let mut output = vec![0.0; 2];

        global::execute_f32("square", &input, &mut output).expect("operation should succeed");
        assert_eq!(output, vec![4.0, 9.0]);

        let plugins = global::list();
        assert!(plugins.contains(&"square".to_string()));
    }

    #[test]
    fn test_error_handling() {
        let registry = PluginRegistry::new();

        let input = vec![1.0];
        let mut output = vec![0.0];
        let result = registry.execute_f32("nonexistent", &input, &mut output);
        assert!(matches!(result, Err(PluginError::NotFound(_))));

        let operation = Arc::new(SquareOperation);
        let plugin = Plugin::new(operation);
        registry.register(plugin).expect("operation should succeed");

        let input = vec![1.0, 2.0];
        let mut output = vec![0.0];
        let result = registry.execute_f32("square", &input, &mut output);
        assert!(matches!(result, Err(PluginError::IncompatibleSizes(_, _))));
    }
}