Skip to main content

sklears_simd/
traits.rs

//! Trait-based SIMD framework for modular and composable operations
//!
//! This module provides a comprehensive trait system for SIMD operations,
//! enabling modular design, runtime dispatch, and easy extensibility.
5
6#[cfg(not(feature = "no-std"))]
7use std::any;
8#[cfg(not(feature = "no-std"))]
9use std::boxed::Box;
10#[cfg(not(feature = "no-std"))]
11use std::collections::HashMap;
12#[cfg(not(feature = "no-std"))]
13use std::fmt::Debug;
14#[cfg(not(feature = "no-std"))]
15use std::string::{String, ToString};
16#[cfg(not(feature = "no-std"))]
17use std::vec::Vec;
18
19#[cfg(feature = "no-std")]
20use alloc::boxed::Box;
21#[cfg(feature = "no-std")]
22use alloc::collections::BTreeMap as HashMap;
23#[cfg(feature = "no-std")]
24use alloc::format;
25#[cfg(feature = "no-std")]
26use alloc::string::{String, ToString};
27#[cfg(feature = "no-std")]
28use alloc::vec::Vec;
29#[cfg(feature = "no-std")]
30use core::any;
31#[cfg(feature = "no-std")]
32use core::fmt::Debug;
33
/// Foundation trait implemented by every SIMD operation.
///
/// `T` is the element type of the input slice. Each implementor chooses its
/// own result and error types through the associated types below.
pub trait SimdOperation<T> {
    /// Value produced by a successful [`execute`](Self::execute) call.
    type Output;

    /// Error produced when [`execute`](Self::execute) fails.
    type Error;

    /// Runs the operation over `input`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the operation cannot be completed; the
    /// exact conditions are defined by each implementation.
    fn execute(&self, input: &[T]) -> Result<Self::Output, Self::Error>;

    /// Reports the optimal SIMD width for this operation on the current
    /// platform.
    fn optimal_width(&self) -> usize;

    /// Reports whether the operation can be performed with SIMD on the
    /// current platform.
    fn is_supported(&self) -> bool;

    /// Returns a human-readable name for this operation.
    fn name(&self) -> &'static str;
}
54
55/// Trait for vectorized arithmetic operations
56pub trait VectorArithmetic<T> {
57    /// Add two vectors element-wise
58    fn add(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
59
60    /// Subtract two vectors element-wise
61    fn sub(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
62
63    /// Multiply two vectors element-wise
64    fn mul(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
65
66    /// Divide two vectors element-wise
67    fn div(&self, a: &[T], b: &[T]) -> Result<Vec<T>, SimdError>;
68
69    /// Compute fused multiply-add: a * b + c
70    fn fma(&self, a: &[T], b: &[T], c: &[T]) -> Result<Vec<T>, SimdError>;
71
72    /// Scale a vector by a scalar
73    fn scale(&self, vector: &[T], scalar: T) -> Result<Vec<T>, SimdError>;
74}
75
76/// Trait for vector reduction operations
77pub trait VectorReduction<T> {
78    /// Sum all elements in the vector
79    fn sum(&self, vector: &[T]) -> Result<T, SimdError>;
80
81    /// Find the minimum element
82    fn min(&self, vector: &[T]) -> Result<T, SimdError>;
83
84    /// Find the maximum element
85    fn max(&self, vector: &[T]) -> Result<T, SimdError>;
86
87    /// Compute the dot product of two vectors
88    fn dot_product(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
89
90    /// Compute the L2 norm of a vector
91    fn norm(&self, vector: &[T]) -> Result<T, SimdError>;
92
93    /// Compute the mean of all elements
94    fn mean(&self, vector: &[T]) -> Result<T, SimdError>;
95}
96
97/// Trait for distance computations
98pub trait DistanceMetric<T> {
99    /// Compute Euclidean distance between two vectors
100    fn euclidean_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
101
102    /// Compute Manhattan (L1) distance
103    fn manhattan_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
104
105    /// Compute Cosine distance
106    fn cosine_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
107
108    /// Compute squared Euclidean distance (avoiding square root)
109    fn squared_euclidean_distance(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
110}
111
112/// Trait for activation functions used in neural networks
113pub trait ActivationFunction<T: Copy> {
114    /// Apply the activation function
115    fn apply(&self, input: &[T]) -> Result<Vec<T>, SimdError>;
116
117    /// Apply the derivative of the activation function
118    fn derivative(&self, input: &[T]) -> Result<Vec<T>, SimdError>;
119
120    /// Get the name of the activation function
121    fn name(&self) -> &'static str;
122
123    /// Check if this activation function supports in-place operations
124    fn supports_inplace(&self) -> bool;
125
126    /// Apply the activation function in-place (if supported)
127    fn apply_inplace(&self, input: &mut [T]) -> Result<(), SimdError> {
128        if !self.supports_inplace() {
129            return Err(SimdError::UnsupportedOperation(
130                "In-place operation not supported".to_string(),
131            ));
132        }
133        let result = self.apply(input)?;
134        input.copy_from_slice(&result);
135        Ok(())
136    }
137}
138
139/// Trait for kernel functions used in SVM and other algorithms
140pub trait KernelFunction<T> {
141    /// Compute the kernel function between two vectors
142    fn compute(&self, a: &[T], b: &[T]) -> Result<T, SimdError>;
143
144    /// Compute kernel matrix for a set of vectors
145    fn kernel_matrix(&self, vectors: &[&[T]]) -> Result<Vec<Vec<T>>, SimdError>;
146
147    /// Get the name of the kernel function
148    fn name(&self) -> &'static str;
149
150    /// Check if kernel supports hyperparameters
151    fn has_parameters(&self) -> bool;
152}
153
154/// Trait for matrix operations
155pub trait MatrixOperations<T> {
156    /// Matrix-vector multiplication
157    fn matrix_vector_multiply(&self, matrix: &[Vec<T>], vector: &[T]) -> Result<Vec<T>, SimdError>;
158
159    /// Matrix-matrix multiplication
160    fn matrix_multiply(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
161
162    /// Matrix transpose
163    fn transpose(&self, matrix: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
164
165    /// Element-wise matrix operations
166    fn elementwise_add(&self, a: &[Vec<T>], b: &[Vec<T>]) -> Result<Vec<Vec<T>>, SimdError>;
167}
168
169/// Trait for clustering operations
170pub trait ClusteringOperations<T> {
171    /// Compute distances from points to centroids
172    fn point_to_centroid_distances(
173        &self,
174        points: &[&[T]],
175        centroids: &[&[T]],
176    ) -> Result<Vec<Vec<T>>, SimdError>;
177
178    /// Update centroids based on point assignments
179    fn update_centroids(
180        &self,
181        points: &[&[T]],
182        assignments: &[usize],
183        k: usize,
184    ) -> Result<Vec<Vec<T>>, SimdError>;
185
186    /// Compute within-cluster sum of squares
187    fn wcss(
188        &self,
189        points: &[&[T]],
190        centroids: &[&[T]],
191        assignments: &[usize],
192    ) -> Result<T, SimdError>;
193}
194
/// Error type shared by the SIMD traits in this module.
#[derive(Debug, Clone)]
pub enum SimdError {
    /// Input vectors have mismatched dimensions.
    DimensionMismatch {
        /// Dimension that was expected.
        expected: usize,
        /// Dimension that was actually supplied.
        actual: usize,
    },

    /// Input data is empty.
    EmptyInput,

    /// The SIMD operation is not supported on this platform.
    UnsupportedPlatform,

    /// The operation is not implemented for this type; the payload describes
    /// the operation.
    UnsupportedOperation(String),

    /// Numerical error (overflow, underflow, NaN).
    NumericalError(String),

    /// A parameter was given an invalid value.
    InvalidParameter {
        /// Name of the offending parameter.
        name: String,
        /// Textual form of the rejected value.
        value: String,
    },

    /// Memory allocation error.
    AllocationError,

    /// Error raised while integrating with an external library.
    ExternalLibraryError(String),

    /// Invalid input data.
    InvalidInput(String),

    /// Invalid argument provided.
    InvalidArgument(String),

    /// Feature not implemented.
    NotImplemented(String),

    /// Any other error, described by its message.
    Other(String),
}
234
235impl core::fmt::Display for SimdError {
236    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
237        match self {
238            SimdError::DimensionMismatch { expected, actual } => {
239                write!(
240                    f,
241                    "Dimension mismatch: expected {}, got {}",
242                    expected, actual
243                )
244            }
245            SimdError::EmptyInput => write!(f, "Input data is empty"),
246            SimdError::UnsupportedPlatform => {
247                write!(f, "SIMD operation not supported on this platform")
248            }
249            SimdError::UnsupportedOperation(op) => write!(f, "Unsupported operation: {}", op),
250            SimdError::NumericalError(msg) => write!(f, "Numerical error: {}", msg),
251            SimdError::InvalidParameter { name, value } => {
252                write!(f, "Invalid parameter {}: {}", name, value)
253            }
254            SimdError::AllocationError => write!(f, "Memory allocation failed"),
255            SimdError::ExternalLibraryError(msg) => write!(f, "External library error: {}", msg),
256            SimdError::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
257            SimdError::InvalidArgument(msg) => write!(f, "Invalid argument: {}", msg),
258            SimdError::NotImplemented(msg) => write!(f, "Not implemented: {}", msg),
259            SimdError::Other(msg) => write!(f, "Error: {}", msg),
260        }
261    }
262}
263
264#[cfg(not(feature = "no-std"))]
265impl std::error::Error for SimdError {}
266
267#[cfg(feature = "no-std")]
268impl core::error::Error for SimdError {}
269
270/// Dispatcher trait for runtime SIMD implementation selection
271pub trait SimdDispatcher<T> {
272    /// The operation type this dispatcher handles
273    type Operation;
274
275    /// Select the best implementation for the current platform
276    fn select_implementation(
277        &self,
278    ) -> Box<dyn SimdOperation<T, Output = Self::Operation, Error = SimdError>>;
279
280    /// Get all available implementations
281    fn available_implementations(&self) -> Vec<&'static str>;
282
283    /// Force a specific implementation (for testing/benchmarking)
284    fn force_implementation(
285        &self,
286        name: &str,
287    ) -> Option<Box<dyn SimdOperation<T, Output = Self::Operation, Error = SimdError>>>;
288}
289
/// Tunable settings for SIMD operations, exposed as getter/setter pairs.
pub trait SimdConfig {
    /// Sets the preferred SIMD width.
    fn set_simd_width(&mut self, width: usize);

    /// Returns the current SIMD width.
    fn simd_width(&self) -> usize;

    /// Enables or disables automatic fallback to scalar code.
    fn set_scalar_fallback(&mut self, enabled: bool);

    /// Reports whether scalar fallback is enabled.
    fn scalar_fallback_enabled(&self) -> bool;

    /// Sets the numerical precision tolerance.
    fn set_precision_tolerance(&mut self, tolerance: f64);

    /// Returns the current precision tolerance.
    fn precision_tolerance(&self) -> f64;
}
310
/// Plain-data configuration for SIMD operations.
#[derive(Debug, Clone)]
pub struct DefaultSimdConfig {
    /// Preferred SIMD width.
    pub simd_width: usize,
    /// Whether automatic fallback to scalar code is enabled.
    pub scalar_fallback: bool,
    /// Numerical precision tolerance.
    pub precision_tolerance: f64,
}
318
319impl Default for DefaultSimdConfig {
320    fn default() -> Self {
321        Self {
322            simd_width: crate::SIMD_CAPS.best_f32_width(),
323            scalar_fallback: true,
324            precision_tolerance: 1e-6,
325        }
326    }
327}
328
329impl SimdConfig for DefaultSimdConfig {
330    fn set_simd_width(&mut self, width: usize) {
331        self.simd_width = width;
332    }
333
334    fn simd_width(&self) -> usize {
335        self.simd_width
336    }
337
338    fn set_scalar_fallback(&mut self, enabled: bool) {
339        self.scalar_fallback = enabled;
340    }
341
342    fn scalar_fallback_enabled(&self) -> bool {
343        self.scalar_fallback
344    }
345
346    fn set_precision_tolerance(&mut self, tolerance: f64) {
347        self.precision_tolerance = tolerance;
348    }
349
350    fn precision_tolerance(&self) -> f64 {
351        self.precision_tolerance
352    }
353}
354
355/// Trait for composable SIMD operations
356pub trait ComposableOperation<T>: SimdOperation<T> {
357    /// Compose this operation with another operation
358    fn compose<Other>(self, other: Other) -> ComposedOperation<Self, Other>
359    where
360        Self: Sized,
361        Other: SimdOperation<T>;
362
363    /// Apply a transformation to the output
364    fn map<F, U>(self, f: F) -> MappedOperation<Self, F>
365    where
366        Self: Sized,
367        F: Fn(Self::Output) -> U;
368}
369
/// Two operations intended to be applied in sequence.
///
/// This type only stores the two stages; no `SimdOperation` implementation
/// for it appears in this part of the file.
pub struct ComposedOperation<First, Second> {
    #[allow(dead_code)] // First stage of pipeline; used when SimdOperation impls are added
    first: First,
    #[allow(dead_code)] // Second stage of pipeline; used when SimdOperation impls are added
    second: Second,
}

impl<First, Second> ComposedOperation<First, Second> {
    /// Stores `first` and `second` as the two stages of the pipeline.
    pub fn new(first: First, second: Second) -> Self {
        ComposedOperation { first, second }
    }
}
383
/// A base operation paired with a function meant to transform its output.
///
/// This type only stores the two parts; no `SimdOperation` implementation
/// for it appears in this part of the file.
pub struct MappedOperation<Op, F> {
    #[allow(dead_code)] // Base operation; used when SimdOperation impls are added
    operation: Op,
    #[allow(dead_code)] // Output mapper function; used when SimdOperation impls are added
    mapper: F,
}

impl<Op, F> MappedOperation<Op, F> {
    /// Stores `operation` together with the output `mapper`.
    pub fn new(operation: Op, mapper: F) -> Self {
        MappedOperation { operation, mapper }
    }
}
397
398/// Trait for operations that can be parallelized
399pub trait ParallelSimdOperation<T>: SimdOperation<T> {
400    /// Execute the operation in parallel across multiple chunks
401    fn execute_parallel(&self, input: &[T], chunk_size: usize)
402        -> Result<Self::Output, Self::Error>;
403
404    /// Get the optimal chunk size for parallel execution
405    fn optimal_chunk_size(&self, input_size: usize) -> usize;
406
407    /// Check if parallel execution is beneficial for the given input size
408    fn should_parallelize(&self, input_size: usize) -> bool;
409}
410
/// Registry for SIMD operation implementations.
///
/// Entries are stored type-erased as `Box<dyn Any + Send + Sync>` and keyed
/// by name.
pub struct SimdRegistry {
    // `HashMap` is already resolved per feature by this file's imports (the
    // std `HashMap`, or `BTreeMap` aliased as `HashMap` under `no-std`), so a
    // single declaration covers both configurations.
    operations: HashMap<String, Box<dyn any::Any + Send + Sync>>,
}
418
419impl Default for SimdRegistry {
420    fn default() -> Self {
421        Self::new()
422    }
423}
424
425impl SimdRegistry {
426    /// Create a new registry
427    pub fn new() -> Self {
428        Self {
429            operations: HashMap::new(),
430        }
431    }
432
433    /// Register a new operation implementation
434    pub fn register<T: 'static + Send + Sync>(&mut self, name: String, operation: T) {
435        self.operations.insert(name, Box::new(operation));
436    }
437
438    /// Get a registered operation
439    pub fn get<T: 'static>(&self, name: &str) -> Option<&T> {
440        self.operations
441            .get(name)
442            .and_then(|op| op.downcast_ref::<T>())
443    }
444
445    /// List all registered operations
446    pub fn list_operations(&self) -> Vec<&String> {
447        self.operations.keys().collect()
448    }
449}
450
/// Implements `SimdOperation<f32>` for a type in terms of its `compute` method.
///
/// Arguments: the implementing type, the `Output` type, and a string literal
/// used as the operation name.
///
/// The expansion calls `self.compute(input)`, so `$type` must provide a
/// `compute(&self, &[f32])` method returning `Result<$output, SimdError>`.
/// `SimdOperation` and `SimdError` are referenced by bare name, so both must
/// be in scope where the macro is invoked.
#[macro_export]
macro_rules! impl_simd_operation {
    ($type:ty, $output:ty, $name:literal) => {
        impl SimdOperation<f32> for $type {
            type Output = $output;
            type Error = SimdError;

            // Empty input is rejected here, before `compute` is reached.
            fn execute(&self, input: &[f32]) -> Result<Self::Output, Self::Error> {
                if input.is_empty() {
                    Err(SimdError::EmptyInput)
                } else {
                    self.compute(input)
                }
            }

            fn optimal_width(&self) -> usize {
                $crate::SIMD_CAPS.best_f32_width()
            }

            // A width of one lane or fewer is reported as unsupported.
            fn is_supported(&self) -> bool {
                1 < self.optimal_width()
            }

            fn name(&self) -> &'static str {
                $name
            }
        }
    };
}
480
481/// Utility functions for common trait implementations
482pub mod utils {
483    use super::*;
484
485    /// Validate that two slices have the same length
486    pub fn validate_same_length<T>(a: &[T], b: &[T]) -> Result<(), SimdError> {
487        if a.len() != b.len() {
488            Err(SimdError::DimensionMismatch {
489                expected: a.len(),
490                actual: b.len(),
491            })
492        } else {
493            Ok(())
494        }
495    }
496
497    /// Validate that a slice is not empty
498    pub fn validate_not_empty<T>(slice: &[T]) -> Result<(), SimdError> {
499        if slice.is_empty() {
500            Err(SimdError::EmptyInput)
501        } else {
502            Ok(())
503        }
504    }
505
506    /// Check if all values are finite (no NaN or infinity)
507    pub fn validate_finite(slice: &[f32]) -> Result<(), SimdError> {
508        for &value in slice {
509            if !value.is_finite() {
510                return Err(SimdError::NumericalError(format!(
511                    "Non-finite value encountered: {}",
512                    value
513                )));
514            }
515        }
516        Ok(())
517    }
518
519    /// Create a chunked iterator for parallel processing
520    pub fn create_chunks<T>(slice: &[T], chunk_size: usize) -> impl Iterator<Item = &[T]> {
521        slice.chunks(chunk_size)
522    }
523
524    /// Compute optimal chunk size based on input size and hardware
525    pub fn optimal_chunk_size(input_size: usize, simd_width: usize) -> usize {
526        let base_chunk = simd_width * 64; // Process 64 SIMD vectors per chunk
527        let max_chunk = input_size / 4; // Use at most 4 chunks
528
529        if max_chunk < base_chunk {
530            max_chunk.max(simd_width)
531        } else {
532            base_chunk
533        }
534    }
535}
536
// Unit tests for the trait framework. They are compiled only for test builds
// without the `no-std` feature, so the standard prelude (`vec!`, `Vec`) is
// always available in here.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    // Mock implementation for testing: `compute` adds 1.0 to every element.
    struct MockVectorAdd;

    impl MockVectorAdd {
        fn compute(&self, input: &[f32]) -> Result<Vec<f32>, SimdError> {
            Ok(input.iter().map(|&x| x + 1.0).collect())
        }
    }

    impl_simd_operation!(MockVectorAdd, Vec<f32>, "mock_vector_add");

    /// The macro-generated impl forwards to `compute` and reports its name.
    #[test]
    fn test_simd_operation_trait() {
        let op = MockVectorAdd;
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let result = op.execute(&input).expect("operation should succeed");
        assert_eq!(result, vec![2.0, 3.0, 4.0, 5.0]);

        assert_eq!(op.name(), "mock_vector_add");
        assert!(op.optimal_width() >= 1);
    }

    /// `Display` output mentions the kind of failure.
    #[test]
    fn test_simd_error_display() {
        let error = SimdError::DimensionMismatch {
            expected: 4,
            actual: 3,
        };
        assert!(error.to_string().contains("Dimension mismatch"));

        let error = SimdError::EmptyInput;
        assert!(error.to_string().contains("empty"));
    }

    /// Defaults are sane and every setter is reflected by its getter.
    #[test]
    fn test_default_simd_config() {
        let mut config = DefaultSimdConfig::default();

        assert!(config.simd_width() >= 1);
        assert!(config.scalar_fallback_enabled());
        assert_eq!(config.precision_tolerance(), 1e-6);

        config.set_simd_width(8);
        assert_eq!(config.simd_width(), 8);

        config.set_scalar_fallback(false);
        assert!(!config.scalar_fallback_enabled());

        config.set_precision_tolerance(1e-8);
        assert_eq!(config.precision_tolerance(), 1e-8);
    }

    /// Registered operations can be listed and fetched back by name and type.
    #[test]
    fn test_simd_registry() {
        let mut registry = SimdRegistry::new();

        registry.register("test_op".to_string(), MockVectorAdd);

        let operations = registry.list_operations();
        assert_eq!(operations.len(), 1);
        assert_eq!(operations[0], "test_op");

        let op = registry.get::<MockVectorAdd>("test_op");
        assert!(op.is_some());

        let nonexistent = registry.get::<MockVectorAdd>("nonexistent");
        assert!(nonexistent.is_none());
    }

    /// Length, emptiness and finiteness validators accept and reject as expected.
    #[test]
    fn test_validation_utils() {
        use utils::*;

        // Test same length validation
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = vec![7.0, 8.0];

        assert!(validate_same_length(&a, &b).is_ok());
        assert!(validate_same_length(&a, &c).is_err());

        // Test empty validation
        assert!(validate_not_empty(&a).is_ok());
        assert!(validate_not_empty(&Vec::<f32>::new()).is_err());

        // Test finite validation
        let finite = vec![1.0, 2.0, 3.0];
        let infinite = vec![1.0, f32::INFINITY, 3.0];
        let nan = vec![1.0, f32::NAN, 3.0];

        assert!(validate_finite(&finite).is_ok());
        assert!(validate_finite(&infinite).is_err());
        assert!(validate_finite(&nan).is_err());
    }

    /// Chunking splits with a short tail; the chunk-size heuristic stays in range.
    #[test]
    fn test_chunk_utilities() {
        use utils::*;

        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let chunks: Vec<&[i32]> = create_chunks(&data, 3).collect();

        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks[0], &[1, 2, 3]);
        assert_eq!(chunks[1], &[4, 5, 6]);
        assert_eq!(chunks[2], &[7, 8, 9]);
        assert_eq!(chunks[3], &[10]);

        let chunk_size = optimal_chunk_size(1000, 8);
        assert!(chunk_size >= 8);
        assert!(chunk_size <= 1000);
    }
}