1#[cfg(not(feature = "no-std"))]
7use std::any;
8#[cfg(not(feature = "no-std"))]
9use std::boxed::Box;
10#[cfg(not(feature = "no-std"))]
11use std::collections::HashMap;
12#[cfg(not(feature = "no-std"))]
13use std::fmt::Debug;
14#[cfg(not(feature = "no-std"))]
15use std::string::{String, ToString};
16#[cfg(not(feature = "no-std"))]
17use std::vec::Vec;
18
19#[cfg(feature = "no-std")]
20use alloc::boxed::Box;
21#[cfg(feature = "no-std")]
22use alloc::collections::BTreeMap as HashMap;
23#[cfg(feature = "no-std")]
24use alloc::format;
25#[cfg(feature = "no-std")]
26use alloc::string::{String, ToString};
27#[cfg(feature = "no-std")]
28use alloc::vec::Vec;
29#[cfg(feature = "no-std")]
30use core::any;
31#[cfg(feature = "no-std")]
32use core::fmt::Debug;
33
/// A single operation over a slice of `T` that may have SIMD-accelerated
/// implementations.
///
/// Implementors choose their own output and error types.
pub trait SimdOperation<T> {
    /// Value produced by a successful [`execute`](Self::execute).
    type Output;

    /// Error produced by a failed [`execute`](Self::execute).
    type Error;

    /// Static identifier of this operation.
    fn name(&self) -> &'static str;

    /// Whether this operation can run on the current platform, as reported
    /// by the implementor.
    fn is_supported(&self) -> bool;

    /// Preferred vector width reported by the implementor (presumably a
    /// lane count — confirm per implementation).
    fn optimal_width(&self) -> usize;

    /// Runs the operation over `input`.
    fn execute(&self, input: &[T]) -> Result<Self::Output, Self::Error>;
}
54
55pub trait VectorArithmetic<T> {
57 fn add(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
59
60 fn sub(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
62
63 fn mul(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
65
66 fn div(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
68
69 fn fma(&self, a: &[T], b: &[T], c: &[T]) -> Result<Vec<T>, SimdError>;
71
72 fn scale(&self, vector: &[T], scalar: T) -> Result<Vec<T>, SimdError>;
74}
75
76pub trait VectorReduction<T> {
78 fn sum(&self, vector: &[T]) -> Result<T, SimdError>;
80
81 fn min(&self, vector: &[T]) -> Result<T, SimdError>;
83
84 fn max(&self, vector: &[T]) -> Result<T, SimdError>;
86
87 fn dot_product(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
89
90 fn norm(&self, vector: &[T]) -> Result<T, SimdError>;
92
93 fn mean(&self, vector: &[T]) -> Result<T, SimdError>;
95}
96
97pub trait DistanceMetric<T> {
99 fn euclidean_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
101
102 fn manhattan_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
104
105 fn cosine_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
107
108 fn squared_euclidean_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
110}
111
112pub trait ActivationFunction<T: Copy> {
114 fn apply(&self, input: &[T]) -> Result<Vec<T>, SimdError>;
116
117 fn derivative(&self, input: &[T]) -> Result<Vec<T>, SimdError>;
119
120 fn name(&self) -> &'static str;
122
123 fn supports_inplace(&self) -> bool;
125
126 fn apply_inplace(&self, input: &mut [T]) -> Result<(), SimdError> {
128 if !self.supports_inplace() {
129 return Err(SimdError::UnsupportedOperation(
130 "In-place operation not supported".to_string(),
131 ));
132 }
133 let result = self.apply(input)?;
134 input.copy_from_slice(&result);
135 Ok(())
136 }
137}
138
139pub trait KernelFunction<T> {
141 fn compute(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
143
144 fn kernel_matrix(&self, vectors: &[&[T]]) -> Result<Vec<Vec<T>>, SimdError>;
146
147 fn name(&self) -> &'static str;
149
150 fn has_parameters(&self) -> bool;
152}
153
154pub trait MatrixOperations<T> {
156 fn matrix_vector_multiply(&self, matrix: &[Vec<T>], vector: &[T]) -> Result<Vec<T>, SimdError>;
158
159 fn matrix_multiply(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
161
162 fn transpose(&self, matrix: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
164
165 fn elementwise_add(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
167}
168
169pub trait ClusteringOperations<T> {
171 fn point_to_centroid_distances(
173 &self,
174 points: &[&[T]],
175 centroids: &[&[T]],
176 ) -> Result<Vec<Vec<T>>, SimdError>;
177
178 fn update_centroids(
180 &self,
181 points: &[&[T]],
182 assignments: &[usize],
183 k: usize,
184 ) -> Result<Vec<Vec<T>>, SimdError>;
185
186 fn wcss(
188 &self,
189 points: &[&[T]],
190 centroids: &[&[T]],
191 assignments: &[usize],
192 ) -> Result<T, SimdError>;
193}
194
/// Error type shared by the SIMD traits and helpers in this module.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly instead of matching on variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SimdError {
    /// Two inputs that must agree in length did not. `expected` is the
    /// reference length and `actual` is the offending one.
    DimensionMismatch { expected: usize, actual: usize },

    /// An input slice was empty.
    EmptyInput,

    /// The operation is not supported on this platform.
    UnsupportedPlatform,

    /// The requested operation is not supported. The payload describes it.
    UnsupportedOperation(String),

    /// A numerical problem, such as a non-finite value. The payload
    /// describes it.
    NumericalError(String),

    /// A named parameter had an invalid value.
    InvalidParameter { name: String, value: String },

    /// Memory allocation failed.
    AllocationError,

    /// An error reported by an external library.
    ExternalLibraryError(String),

    /// Invalid input, with a description.
    InvalidInput(String),

    /// Invalid argument, with a description.
    InvalidArgument(String),

    /// The requested functionality is not implemented.
    NotImplemented(String),

    /// Catch-all error with a free-form message.
    Other(String),
}

impl core::fmt::Display for SimdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::DimensionMismatch { expected, actual } => write!(
                f,
                "Dimension mismatch: expected {}, got {}",
                expected, actual
            ),
            Self::EmptyInput => f.write_str("Input data is empty"),
            Self::UnsupportedPlatform => {
                f.write_str("SIMD operation not supported on this platform")
            }
            Self::UnsupportedOperation(op) => write!(f, "Unsupported operation: {}", op),
            Self::NumericalError(msg) => write!(f, "Numerical error: {}", msg),
            Self::InvalidParameter { name, value } => {
                write!(f, "Invalid parameter {}: {}", name, value)
            }
            Self::AllocationError => f.write_str("Memory allocation failed"),
            Self::ExternalLibraryError(msg) => write!(f, "External library error: {}", msg),
            Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
            Self::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
            Self::NotImplemented(msg) => write!(f, "Not implemented: {}", msg),
            Self::Other(msg) => write!(f, "Error: {}", msg),
        }
    }
}

// The std and no-std builds each get the `Error` impl from their own path.
#[cfg(not(feature = "no-std"))]
impl std::error::Error for SimdError {}

#[cfg(feature = "no-std")]
impl core::error::Error for SimdError {}
269
270pub trait SimdDispatcher<T> {
272 type Operation;
274
275 fn select_implementation(
277 &self,
278 ) -> Box<dyn SimdOperation<T, Output = Self::Operation, Error = SimdError>>;
279
280 fn available_implementations(&self) -> Vec<&'static str>;
282
283 fn force_implementation(
285 &self,
286 name: &str,
287 ) -> Option<Box<dyn SimdOperation<T, Output = Self::Operation, Error = SimdError>>>;
288}
289
/// Runtime tuning knobs for SIMD execution.
pub trait SimdConfig {
    /// Current vector width.
    fn simd_width(&self) -> usize;

    /// Replaces the vector width.
    fn set_simd_width(&mut self, width: usize);

    /// Whether the scalar fallback path is enabled.
    fn scalar_fallback_enabled(&self) -> bool;

    /// Enables or disables the scalar fallback path.
    fn set_scalar_fallback(&mut self, enabled: bool);

    /// Current precision tolerance.
    fn precision_tolerance(&self) -> f64;

    /// Replaces the precision tolerance.
    fn set_precision_tolerance(&mut self, tolerance: f64);
}
310
311#[derive(Debug, Clone)]
313pub struct DefaultSimdConfig {
314 pub simd_width: usize,
315 pub scalar_fallback: bool,
316 pub precision_tolerance: f64,
317}
318
319impl Default for DefaultSimdConfig {
320 fn default() -> Self {
321 Self {
322 simd_width: crate::SIMD_CAPS.best_f32_width(),
323 scalar_fallback: true,
324 precision_tolerance: 1e-6,
325 }
326 }
327}
328
329impl SimdConfig for DefaultSimdConfig {
330 fn set_simd_width(&mut self, width: usize) {
331 self.simd_width = width;
332 }
333
334 fn simd_width(&self) -> usize {
335 self.simd_width
336 }
337
338 fn set_scalar_fallback(&mut self, enabled: bool) {
339 self.scalar_fallback = enabled;
340 }
341
342 fn scalar_fallback_enabled(&self) -> bool {
343 self.scalar_fallback
344 }
345
346 fn set_precision_tolerance(&mut self, tolerance: f64) {
347 self.precision_tolerance = tolerance;
348 }
349
350 fn precision_tolerance(&self) -> f64 {
351 self.precision_tolerance
352 }
353}
354
355pub trait ComposableOperation<T>: SimdOperation<T> {
357 fn compose<Other>(self, other: Other) -> ComposedOperation<Self, Other>
359 where
360 Self: Sized,
361 Other: SimdOperation<T>;
362
363 fn map<F, U>(self, f: F) -> MappedOperation<Self, F>
365 where
366 Self: Sized,
367 F: Fn(Self::Output) -> U;
368}
369
/// Two operations stored side by side, as returned by
/// [`ComposableOperation::compose`].
///
/// NOTE(review): nothing in this file reads these fields or implements
/// `SimdOperation` for this type. Presumably `first` runs before `second`
/// once an execution impl exists — confirm.
pub struct ComposedOperation<First, Second> {
    // The fields are only stored so far (see the note above), hence the lint allowance.
    #[allow(dead_code)]
    first: First,
    #[allow(dead_code)]
    second: Second,
}

impl<First, Second> ComposedOperation<First, Second> {
    /// Bundles `first` and `second` without running either.
    pub fn new(first: First, second: Second) -> Self {
        ComposedOperation { first, second }
    }
}

/// An operation paired with a function over its output, as returned by
/// [`ComposableOperation::map`].
///
/// NOTE(review): as with `ComposedOperation`, no execution impl is visible
/// here, so the fields are never read in this file.
pub struct MappedOperation<Op, F> {
    // Only stored so far, hence the lint allowance.
    #[allow(dead_code)]
    operation: Op,
    #[allow(dead_code)]
    mapper: F,
}

impl<Op, F> MappedOperation<Op, F> {
    /// Bundles `operation` with `mapper` without running either.
    pub fn new(operation: Op, mapper: F) -> Self {
        MappedOperation { operation, mapper }
    }
}
397
398pub trait ParallelSimdOperation<T>: SimdOperation<T> {
400 fn execute_parallel(&self, input: &[T], chunk_size: usize)
402 -> Result<Self::Output, Self::Error>;
403
404 fn optimal_chunk_size(&self, input_size: usize) -> usize;
406
407 fn should_parallelize(&self, input_size: usize) -> bool;
409}
410
/// Type-erased registry mapping names to arbitrary `'static + Send + Sync`
/// values.
///
/// Values are boxed as `dyn Any` by [`register`](Self::register) and
/// recovered with their concrete type through [`get`](Self::get).
pub struct SimdRegistry {
    // Under the `no-std` feature `HashMap` is an alias for
    // `alloc::collections::BTreeMap` (see this file's imports), so a single
    // declaration covers both builds.
    operations: HashMap<String, Box<dyn any::Any + Send + Sync>>,
}

impl Default for SimdRegistry {
    /// Same as [`SimdRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl SimdRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            operations: HashMap::new(),
        }
    }

    /// Stores `operation` under `name`, replacing any value previously
    /// registered under the same name.
    pub fn register<T: 'static + Send + Sync>(&mut self, name: String, operation: T) {
        self.operations.insert(name, Box::new(operation));
    }

    /// Returns the value registered under `name` if one exists and its
    /// concrete type is `T`.
    ///
    /// Returns `None` both for an unknown name and for a type mismatch.
    pub fn get<T: 'static>(&self, name: &str) -> Option<&T> {
        self.operations.get(name)?.downcast_ref::<T>()
    }

    /// Names of all registered values, in ascending order.
    pub fn list_operations(&self) -> Vec<&String> {
        let mut names: Vec<&String> = self.operations.keys().collect();
        // `std::collections::HashMap` iterates in an unspecified order, while the
        // no-std `BTreeMap` iterates sorted. Sorting makes both builds agree.
        names.sort_unstable();
        names
    }
}
450
/// Implements `SimdOperation<f32>` for `$type` by delegating to an inherent
/// `compute(&self, &[f32]) -> Result<$output, SimdError>` method.
///
/// The generated impl:
/// - rejects empty input with `SimdError::EmptyInput` before calling `compute`;
/// - reports `$crate::SIMD_CAPS.best_f32_width()` as the optimal width;
/// - reports itself as supported only when that width is greater than 1;
/// - returns the literal `$name` from `name()`.
///
/// `SimdOperation` and `SimdError` are referenced unqualified, so both must
/// be in scope at the invocation site.
///
/// ```ignore
/// impl_simd_operation!(MyOp, Vec<f32>, "my_op");
/// ```
#[macro_export]
macro_rules! impl_simd_operation {
    ($type:ty, $output:ty, $name:literal) => {
        impl SimdOperation<f32> for $type {
            type Output = $output;
            type Error = SimdError;

            fn execute(&self, input: &[f32]) -> Result<Self::Output, Self::Error> {
                match input {
                    [] => Err(SimdError::EmptyInput),
                    data => self.compute(data),
                }
            }

            fn optimal_width(&self) -> usize {
                $crate::SIMD_CAPS.best_f32_width()
            }

            fn is_supported(&self) -> bool {
                // Presumably a width of 1 means only scalar execution is available.
                1 < self.optimal_width()
            }

            fn name(&self) -> &'static str {
                $name
            }
        }
    };
}
480
481pub mod utils {
483 use super::*;
484
485 pub fn validate_same_length<T>(a: &[T], b: &[T]) -> Result<(), SimdError> {
487 if a.len() != b.len() {
488 Err(SimdError::DimensionMismatch {
489 expected: a.len(),
490 actual: b.len(),
491 })
492 } else {
493 Ok(())
494 }
495 }
496
497 pub fn validate_not_empty<T>(slice: &[T]) -> Result<(), SimdError> {
499 if slice.is_empty() {
500 Err(SimdError::EmptyInput)
501 } else {
502 Ok(())
503 }
504 }
505
506 pub fn validate_finite(slice: &[f32]) -> Result<(), SimdError> {
508 for &value in slice {
509 if !value.is_finite() {
510 return Err(SimdError::NumericalError(format!(
511 "Non-finite value encountered: {}",
512 value
513 )));
514 }
515 }
516 Ok(())
517 }
518
519 pub fn create_chunks<T>(slice: &[T], chunk_size: usize) -> impl Iterator<Item = &[T]> {
521 slice.chunks(chunk_size)
522 }
523
524 pub fn optimal_chunk_size(input_size: usize, simd_width: usize) -> usize {
526 let base_chunk = simd_width * 64; let max_chunk = input_size / 4; if max_chunk < base_chunk {
530 max_chunk.max(simd_width)
531 } else {
532 base_chunk
533 }
534 }
535}
536
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    //! Unit tests for the traits, error type, config, registry and utilities
    //! in this module. Only built with `std`, so `vec!`/`Vec` come from the
    //! std prelude.
    use super::*;

    /// Minimal operation used to exercise `impl_simd_operation!`. It adds
    /// 1.0 to each element.
    struct MockVectorAdd;

    impl MockVectorAdd {
        fn compute(&self, input: &[f32]) -> Result<Vec<f32>, SimdError> {
            Ok(input.iter().map(|&x| x + 1.0).collect())
        }
    }

    impl_simd_operation!(MockVectorAdd, Vec<f32>, "mock_vector_add");

    #[test]
    fn test_simd_operation_trait() {
        let op = MockVectorAdd;
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let result = op.execute(&input).expect("operation should succeed");
        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0]);

        assert_eq!(op.name(), "mock_vector_add");
        assert!(op.optimal_width() >= 1);
        // The macro defines `is_supported` as "optimal width greater than one".
        assert_eq!(op.is_supported(), op.optimal_width() > 1);
    }

    #[test]
    fn test_simd_operation_rejects_empty_input() {
        // The macro-generated `execute` short-circuits before calling `compute`.
        let op = MockVectorAdd;
        assert!(matches!(op.execute(&[]), Err(SimdError::EmptyInput)));
    }

    #[test]
    fn test_simd_error_display() {
        let error = SimdError::DimensionMismatch {
            expected: 4,
            actual: 3,
        };
        assert!(error.to_string().contains("Dimension mismatch"));

        let error = SimdError::EmptyInput;
        assert!(error.to_string().contains("empty"));
    }

    #[test]
    fn test_default_simd_config() {
        let mut config = DefaultSimdConfig::default();

        assert!(config.simd_width() >= 1);
        assert!(config.scalar_fallback_enabled());
        assert_eq!(config.precision_tolerance(), 1e-6);

        config.set_simd_width(8);
        assert_eq!(config.simd_width(), 8);

        config.set_scalar_fallback(false);
        assert!(!config.scalar_fallback_enabled());

        config.set_precision_tolerance(1e-8);
        assert_eq!(config.precision_tolerance(), 1e-8);
    }

    #[test]
    fn test_simd_registry() {
        let mut registry = SimdRegistry::new();

        registry.register("test_op".to_string(), MockVectorAdd);

        let operations = registry.list_operations();
        assert_eq!(operations.len(), 1);
        assert_eq!(operations[0], "test_op");

        let op = registry.get::<MockVectorAdd>("test_op");
        assert!(op.is_some());

        let nonexistent = registry.get::<MockVectorAdd>("nonexistent");
        assert!(nonexistent.is_none());
    }

    #[test]
    fn test_validation_utils() {
        use utils::*;

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = vec![7.0, 8.0];

        assert!(validate_same_length(&a, &b).is_ok());
        assert!(validate_same_length(&a, &c).is_err());

        assert!(validate_not_empty(&a).is_ok());
        assert!(validate_not_empty(&Vec::<f32>::new()).is_err());

        let finite = vec![1.0, 2.0, 3.0];
        let infinite = vec![1.0, f32::INFINITY, 3.0];
        let nan = vec![1.0, f32::NAN, 3.0];

        assert!(validate_finite(&finite).is_ok());
        assert!(validate_finite(&infinite).is_err());
        assert!(validate_finite(&nan).is_err());
    }

    #[test]
    fn test_chunk_utilities() {
        use utils::*;

        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let chunks: Vec<&[i32]> = create_chunks(&data, 3).collect();

        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0], &[1, 2, 3]);
        assert_eq!(chunks[1], &[4, 5, 6]);
        assert_eq!(chunks[2], &[7, 8, 9]);
        assert_eq!(chunks[3], &[10]);

        let chunk_size = optimal_chunk_size(1000, 8);
        assert!(chunk_size >= 8);
        assert!(chunk_size <= 1000);
    }
}
658}