Skip to main content

sklears_simd/
middleware.rs

1//! Middleware system for operation pipelines
2//!
3//! This module provides a flexible middleware system for composing and chaining
4//! SIMD operations through configurable pipelines.
5
6use crate::traits::SimdError;
7
8#[cfg(feature = "no-std")]
9extern crate alloc;
10
11#[cfg(feature = "no-std")]
12use alloc::{
13    boxed::Box,
14    collections::BTreeMap as HashMap,
15    format,
16    string::{String, ToString},
17    vec::Vec,
18};
19#[cfg(feature = "no-std")]
20use core::any::Any;
21#[cfg(not(feature = "no-std"))]
22use std::{any::Any, collections::HashMap, string::ToString};
23
24// Note: format is already imported above for no-std pattern
25
26#[cfg(feature = "no-std")]
27use alloc::sync::Arc;
28#[cfg(not(feature = "no-std"))]
29use std::sync::Arc;
30
/// Result alias returned by middleware and pipeline operations.
///
/// The success type is chosen by the caller; the error type is always [`SimdError`].
pub type MiddlewareResult<T> = Result<T, SimdError>;
33
/// Mutable state that is handed from one middleware to the next.
///
/// A context carries the numeric payload being processed, string metadata
/// that stages use to annotate their work, and a bag of type-erased values
/// for anything that does not fit into a string.
pub struct PipelineContext {
    /// Numeric payload; middleware read and rewrite it in place
    pub data: Vec<f32>,
    /// String key/value annotations shared between middleware
    pub metadata: HashMap<String, String>,
    /// Type-erased values keyed by name (see [`PipelineContext::get_context`])
    pub context: HashMap<String, Box<dyn Any + Send + Sync>>,
}

impl PipelineContext {
    /// Creates a context that owns `data` and has no metadata or context entries.
    pub fn new(data: Vec<f32>) -> Self {
        Self {
            data,
            metadata: HashMap::new(),
            context: HashMap::new(),
        }
    }

    /// Stores `value` under `key`, replacing any previous metadata entry for that key.
    pub fn set_metadata(&mut self, key: String, value: String) {
        self.metadata.insert(key, value);
    }

    /// Returns the metadata stored under `key`, or `None` when the key is absent.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.get(key)
    }

    /// Stores an arbitrary `value` under `key`, replacing any previous entry for that key.
    pub fn set_context<T: Any + Send + Sync>(&mut self, key: String, value: T) {
        self.context.insert(key, Box::new(value));
    }

    /// Returns the value stored under `key` if it exists *and* has type `T`.
    ///
    /// A missing key and a type mismatch both yield `None`.
    pub fn get_context<T: Any + Send + Sync>(&self, key: &str) -> Option<&T> {
        self.context.get(key).and_then(|v| v.downcast_ref::<T>())
    }
}

/// Manual `Debug` because the type-erased context values cannot be formatted;
/// only their keys are shown (sorted, so the output is stable).
///
/// Having `Debug` also makes `Result<PipelineContext, E>::unwrap_err` and
/// `expect_err` usable on pipeline results.
impl core::fmt::Debug for PipelineContext {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut context_keys: Vec<&String> = self.context.keys().collect();
        context_keys.sort_unstable();

        f.debug_struct("PipelineContext")
            .field("data", &self.data)
            .field("metadata", &self.metadata)
            .field("context_keys", &context_keys)
            .finish()
    }
}
74
75/// Trait for middleware components
76pub trait Middleware: Send + Sync {
77    /// Process the pipeline context
78    fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()>;
79
80    /// Get the middleware name
81    fn name(&self) -> &str;
82
83    /// Check if this middleware should be executed based on context
84    fn should_execute(&self, context: &PipelineContext) -> bool {
85        let _ = context; // Suppress unused parameter warning
86        true
87    }
88}
89
90/// Pipeline executor that runs middleware in sequence
91pub struct Pipeline {
92    /// List of middleware in execution order
93    middleware: Vec<Arc<dyn Middleware>>,
94    /// Pipeline name
95    name: String,
96    /// Whether to stop on first error
97    fail_fast: bool,
98}
99
100impl Pipeline {
101    /// Create a new pipeline
102    pub fn new(name: String) -> Self {
103        Self {
104            middleware: Vec::new(),
105            name,
106            fail_fast: true,
107        }
108    }
109
110    /// Add middleware to the pipeline
111    pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
112        self.middleware.push(Arc::new(middleware));
113        self
114    }
115
116    /// Set fail-fast behavior
117    pub fn fail_fast(mut self, fail_fast: bool) -> Self {
118        self.fail_fast = fail_fast;
119        self
120    }
121
122    /// Execute the pipeline
123    pub fn execute(&self, mut context: PipelineContext) -> MiddlewareResult<PipelineContext> {
124        context.set_metadata("pipeline_name".to_string(), self.name.clone());
125
126        for middleware in &self.middleware {
127            if middleware.should_execute(&context) {
128                if let Err(e) = middleware.process(&mut context) {
129                    if self.fail_fast {
130                        return Err(e);
131                    }
132                    // Log error but continue if not fail-fast
133                    context.set_metadata(
134                        format!("error_{}", middleware.name()),
135                        format!("Error: {}", e),
136                    );
137                }
138            }
139        }
140
141        Ok(context)
142    }
143
144    /// Get pipeline name
145    pub fn name(&self) -> &str {
146        &self.name
147    }
148
149    /// Get middleware count
150    pub fn middleware_count(&self) -> usize {
151        self.middleware.len()
152    }
153}
154
155/// Builder for creating pipelines
156pub struct PipelineBuilder {
157    name: String,
158    middleware: Vec<Arc<dyn Middleware>>,
159    fail_fast: bool,
160}
161
162impl PipelineBuilder {
163    /// Create a new pipeline builder
164    pub fn new(name: String) -> Self {
165        Self {
166            name,
167            middleware: Vec::new(),
168            fail_fast: true,
169        }
170    }
171
172    /// Add middleware to the pipeline
173    pub fn with_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
174        self.middleware.push(Arc::new(middleware));
175        self
176    }
177
178    /// Set fail-fast behavior
179    pub fn fail_fast(mut self, fail_fast: bool) -> Self {
180        self.fail_fast = fail_fast;
181        self
182    }
183
184    /// Build the pipeline
185    pub fn build(self) -> Pipeline {
186        Pipeline {
187            middleware: self.middleware,
188            name: self.name,
189            fail_fast: self.fail_fast,
190        }
191    }
192}
193
// Common middleware implementations

/// Middleware that rescales the pipeline data according to a [`NormType`].
#[derive(Debug, Clone, PartialEq)]
pub struct NormalizationMiddleware {
    /// Normalization scheme applied when the middleware runs
    norm_type: NormType,
}

/// Normalization schemes supported by [`NormalizationMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NormType {
    /// Divide every element by the sum of absolute values
    L1,
    /// Divide every element by the square root of the sum of squares
    L2,
    /// Shift by the minimum and divide by the (max - min) range
    MinMax,
}

impl NormalizationMiddleware {
    /// Creates a normalization stage that applies `norm_type`.
    pub fn new(norm_type: NormType) -> Self {
        Self { norm_type }
    }
}
214
215impl Middleware for NormalizationMiddleware {
216    fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
217        let data = &mut context.data;
218
219        match self.norm_type {
220            NormType::L1 => {
221                let sum: f32 = data.iter().map(|x| x.abs()).sum();
222                if sum != 0.0 {
223                    data.iter_mut().for_each(|x| *x /= sum);
224                }
225            }
226            NormType::L2 => {
227                let norm: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
228                if norm != 0.0 {
229                    data.iter_mut().for_each(|x| *x /= norm);
230                }
231            }
232            NormType::MinMax => {
233                let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
234                let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
235                let range = max_val - min_val;
236                if range != 0.0 {
237                    data.iter_mut().for_each(|x| *x = (*x - min_val) / range);
238                }
239            }
240        }
241
242        context.set_metadata("normalized".to_string(), format!("{:?}", self.norm_type));
243        Ok(())
244    }
245
246    fn name(&self) -> &str {
247        "normalization"
248    }
249}
250
/// Middleware that keeps only the values lying inside a closed interval.
#[derive(Debug, Clone, PartialEq)]
pub struct FilteringMiddleware {
    /// Smallest value that is kept (inclusive)
    min_threshold: f32,
    /// Largest value that is kept (inclusive)
    max_threshold: f32,
}

impl FilteringMiddleware {
    /// Creates a filter that keeps values `x` with
    /// `min_threshold <= x <= max_threshold`.
    ///
    /// The bounds are stored as given; no check is made that
    /// `min_threshold <= max_threshold`.
    pub fn new(min_threshold: f32, max_threshold: f32) -> Self {
        Self {
            min_threshold,
            max_threshold,
        }
    }
}
268
269impl Middleware for FilteringMiddleware {
270    fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
271        let original_len = context.data.len();
272        context
273            .data
274            .retain(|&x| x >= self.min_threshold && x <= self.max_threshold);
275
276        let filtered_count = original_len - context.data.len();
277        context.set_metadata("filtered_count".to_string(), filtered_count.to_string());
278
279        Ok(())
280    }
281
282    fn name(&self) -> &str {
283        "filtering"
284    }
285}
286
/// Middleware that applies one element-wise function to the pipeline data.
#[derive(Debug, Clone, PartialEq)]
pub struct TransformationMiddleware {
    /// Element-wise function applied when the middleware runs
    transform_type: TransformType,
}

/// Element-wise functions supported by [`TransformationMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransformType {
    /// Natural logarithm; inputs are first raised to at least `f32::EPSILON`
    Log,
    /// Exponential function
    Exp,
    /// Square root; inputs are first raised to at least `0.0`
    Sqrt,
    /// Multiply each value by itself
    Square,
    /// Absolute value
    Abs,
}

impl TransformationMiddleware {
    /// Creates a transformation stage that applies `transform_type`.
    pub fn new(transform_type: TransformType) -> Self {
        Self { transform_type }
    }
}
308
309impl Middleware for TransformationMiddleware {
310    fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
311        let data = &mut context.data;
312
313        match self.transform_type {
314            TransformType::Log => {
315                data.iter_mut().for_each(|x| *x = x.max(f32::EPSILON).ln());
316            }
317            TransformType::Exp => {
318                data.iter_mut().for_each(|x| *x = x.exp());
319            }
320            TransformType::Sqrt => {
321                data.iter_mut().for_each(|x| *x = x.max(0.0).sqrt());
322            }
323            TransformType::Square => {
324                data.iter_mut().for_each(|x| *x = *x * *x);
325            }
326            TransformType::Abs => {
327                data.iter_mut().for_each(|x| *x = x.abs());
328            }
329        }
330
331        context.set_metadata(
332            "transformed".to_string(),
333            format!("{:?}", self.transform_type),
334        );
335        Ok(())
336    }
337
338    fn name(&self) -> &str {
339        "transformation"
340    }
341}
342
/// Middleware that reduces the pipeline data to a single statistic.
#[derive(Debug, Clone, PartialEq)]
pub struct AggregationMiddleware {
    /// Statistic computed when the middleware runs
    agg_type: AggregationType,
}

/// Statistics supported by [`AggregationMiddleware`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationType {
    /// Sum of all values
    Sum,
    /// Arithmetic mean
    Mean,
    /// Largest value
    Max,
    /// Smallest value
    Min,
    /// Standard deviation, dividing by the element count (population form)
    StdDev,
}

impl AggregationMiddleware {
    /// Creates an aggregation stage that computes `agg_type`.
    pub fn new(agg_type: AggregationType) -> Self {
        Self { agg_type }
    }
}
364
365impl Middleware for AggregationMiddleware {
366    fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
367        let data = &context.data;
368
369        if data.is_empty() {
370            return Err(SimdError::InvalidInput(
371                "Empty data for aggregation".to_string(),
372            ));
373        }
374
375        let result = match self.agg_type {
376            AggregationType::Sum => data.iter().sum::<f32>(),
377            AggregationType::Mean => data.iter().sum::<f32>() / data.len() as f32,
378            AggregationType::Max => data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)),
379            AggregationType::Min => data.iter().fold(f32::INFINITY, |a, &b| a.min(b)),
380            AggregationType::StdDev => {
381                let mean = data.iter().sum::<f32>() / data.len() as f32;
382                let variance =
383                    data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
384                variance.sqrt()
385            }
386        };
387
388        context.set_metadata("aggregation_result".to_string(), result.to_string());
389        context.set_metadata(
390            "aggregation_type".to_string(),
391            format!("{:?}", self.agg_type),
392        );
393
394        Ok(())
395    }
396
397    fn name(&self) -> &str {
398        "aggregation"
399    }
400}
401
402/// Conditional middleware that executes based on context
403pub struct ConditionalMiddleware {
404    /// Condition function
405    condition: Box<dyn Fn(&PipelineContext) -> bool + Send + Sync>,
406    /// Wrapped middleware
407    middleware: Arc<dyn Middleware>,
408}
409
410impl ConditionalMiddleware {
411    pub fn new<F, M>(condition: F, middleware: M) -> Self
412    where
413        F: Fn(&PipelineContext) -> bool + Send + Sync + 'static,
414        M: Middleware + 'static,
415    {
416        Self {
417            condition: Box::new(condition),
418            middleware: Arc::new(middleware),
419        }
420    }
421}
422
423impl Middleware for ConditionalMiddleware {
424    fn process(&self, context: &mut PipelineContext) -> MiddlewareResult<()> {
425        if (self.condition)(context) {
426            self.middleware.process(context)
427        } else {
428            Ok(())
429        }
430    }
431
432    fn name(&self) -> &str {
433        "conditional"
434    }
435
436    fn should_execute(&self, context: &PipelineContext) -> bool {
437        (self.condition)(context)
438    }
439}
440
// Unit tests for the middleware pipeline. They are compiled only for std
// test builds, so the std prelude provides `vec!`, `String`, and `Vec`.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_context_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let context = PipelineContext::new(data.clone());

        assert_eq!(context.data, data);
        assert!(context.metadata.is_empty());
        assert!(context.context.is_empty());
    }

    #[test]
    fn test_pipeline_context_metadata() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0]);

        context.set_metadata("test_key".to_string(), "test_value".to_string());
        assert_eq!(
            context.get_metadata("test_key"),
            Some(&"test_value".to_string())
        );
        assert_eq!(context.get_metadata("nonexistent"), None);
    }

    #[test]
    fn test_pipeline_context_context_data() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0]);

        context.set_context("test_int".to_string(), 42i32);
        assert_eq!(context.get_context::<i32>("test_int"), Some(&42i32));
        // Asking for the wrong type yields `None` rather than panicking.
        assert_eq!(context.get_context::<f32>("test_int"), None);
    }

    #[test]
    fn test_normalization_middleware_l2() {
        let mut context = PipelineContext::new(vec![3.0, 4.0, 0.0]);
        let middleware = NormalizationMiddleware::new(NormType::L2);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        // L2 norm of [3, 4, 0] is 5, so normalized should be [0.6, 0.8, 0.0]
        assert!((context.data[0] - 0.6).abs() < 1e-6);
        assert!((context.data[1] - 0.8).abs() < 1e-6);
        assert!((context.data[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_filtering_middleware() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let middleware = FilteringMiddleware::new(2.0, 4.0);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        // Both bounds are inclusive.
        assert_eq!(context.data, vec![2.0, 3.0, 4.0]);
        assert_eq!(
            context.get_metadata("filtered_count"),
            Some(&"2".to_string())
        );
    }

    #[test]
    fn test_transformation_middleware_sqrt() {
        let mut context = PipelineContext::new(vec![1.0, 4.0, 9.0, 16.0]);
        let middleware = TransformationMiddleware::new(TransformType::Sqrt);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(context.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_aggregation_middleware_mean() {
        let mut context = PipelineContext::new(vec![1.0, 2.0, 3.0, 4.0]);
        let middleware = AggregationMiddleware::new(AggregationType::Mean);

        middleware
            .process(&mut context)
            .expect("operation should succeed");

        assert_eq!(
            context.get_metadata("aggregation_result"),
            Some(&"2.5".to_string())
        );
        assert_eq!(
            context.get_metadata("aggregation_type"),
            Some(&"Mean".to_string())
        );
    }

    #[test]
    fn test_pipeline_builder() {
        let pipeline = PipelineBuilder::new("test_pipeline".to_string())
            .with_middleware(NormalizationMiddleware::new(NormType::L2))
            .with_middleware(FilteringMiddleware::new(0.1, 0.9))
            .fail_fast(false)
            .build();

        assert_eq!(pipeline.name(), "test_pipeline");
        assert_eq!(pipeline.middleware_count(), 2);
    }

    #[test]
    fn test_pipeline_execution() {
        let pipeline = Pipeline::new("test_pipeline".to_string())
            .add_middleware(NormalizationMiddleware::new(NormType::L2))
            .add_middleware(TransformationMiddleware::new(TransformType::Square));

        let context = PipelineContext::new(vec![3.0, 4.0, 0.0]);
        let result = pipeline.execute(context).expect("operation should succeed");

        // After L2 normalization: [0.6, 0.8, 0.0]
        // After squaring: [0.36, 0.64, 0.0]
        assert!((result.data[0] - 0.36).abs() < 1e-6);
        assert!((result.data[1] - 0.64).abs() < 1e-6);
        assert!((result.data[2] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_pipeline_fail_fast_modes() {
        // The filter removes every element, so the aggregation stage that
        // follows fails on empty data.
        let build = |fail_fast: bool| {
            PipelineBuilder::new("fail_fast_pipeline".to_string())
                .with_middleware(FilteringMiddleware::new(10.0, 20.0))
                .with_middleware(AggregationMiddleware::new(AggregationType::Mean))
                .fail_fast(fail_fast)
                .build()
        };

        // Fail-fast: the stage error is returned to the caller.
        let strict = build(true);
        assert!(strict
            .execute(PipelineContext::new(vec![1.0, 2.0, 3.0]))
            .is_err());

        // Not fail-fast: the error is recorded in metadata and execution succeeds.
        let lenient = build(false);
        let result = lenient
            .execute(PipelineContext::new(vec![1.0, 2.0, 3.0]))
            .expect("non-fail-fast pipeline should record the error and continue");
        assert!(result.data.is_empty());
        assert_eq!(
            result.get_metadata("filtered_count"),
            Some(&"3".to_string())
        );
        assert!(result.get_metadata("error_aggregation").is_some());
        assert!(result.get_metadata("aggregation_result").is_none());
    }

    #[test]
    fn test_conditional_middleware() {
        let condition = |context: &PipelineContext| context.data.len() > 2;
        let middleware =
            ConditionalMiddleware::new(condition, NormalizationMiddleware::new(NormType::L2));

        // Test with data length > 2 (should execute)
        let mut context = PipelineContext::new(vec![3.0, 4.0, 0.0]);
        middleware
            .process(&mut context)
            .expect("operation should succeed");
        assert!((context.data[0] - 0.6).abs() < 1e-6);

        // Test with data length <= 2 (should not execute)
        let mut context = PipelineContext::new(vec![3.0, 4.0]);
        let original_data = context.data.clone();
        middleware
            .process(&mut context)
            .expect("operation should succeed");
        assert_eq!(context.data, original_data); // Should be unchanged
    }

    #[test]
    fn test_empty_data_handling() {
        let mut context = PipelineContext::new(vec![]);
        let middleware = AggregationMiddleware::new(AggregationType::Mean);

        let result = middleware.process(&mut context);
        assert!(result.is_err());
    }

    #[test]
    fn test_pipeline_metadata() {
        let pipeline = Pipeline::new("test_pipeline".to_string())
            .add_middleware(NormalizationMiddleware::new(NormType::L2));

        let context = PipelineContext::new(vec![1.0, 2.0, 3.0]);
        let result = pipeline.execute(context).expect("operation should succeed");

        assert_eq!(
            result.get_metadata("pipeline_name"),
            Some(&"test_pipeline".to_string())
        );
        assert_eq!(result.get_metadata("normalized"), Some(&"L2".to_string()));
    }
}