sklears_kernel_approximation/
middleware.rs

1//! Middleware system for kernel approximation pipelines
2//!
3//! This module provides a flexible middleware architecture for composing
4//! kernel approximation transformations with hooks, callbacks, and monitoring.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use sklears_kernel_approximation::middleware::{Pipeline, PipelineBuilder};
10//! // Create a pipeline with multiple transformations
//! // let pipeline = PipelineBuilder::new("rbf_pipeline")
//! //     .add_stage(Box::new(rbf_sampler))
//! //     .add_hook(Box::new(logging_hook))
//! //     .add_middleware(std::sync::Arc::new(normalization_middleware))
//! //     .build();
16//! ```
17
18use scirs2_core::ndarray::Array2;
19use sklears_core::error::SklearsError;
20use std::any::Any;
21use std::sync::Arc;
22use std::time::Instant;
23
24/// Hook that can be called at various stages of the pipeline
25pub trait Hook: Send + Sync {
26    /// Called before fit
27    fn before_fit(
28        &mut self,
29        x: &Array2<f64>,
30        context: &mut HookContext,
31    ) -> Result<(), SklearsError> {
32        let _ = (x, context);
33        Ok(())
34    }
35
36    /// Called after fit
37    fn after_fit(
38        &mut self,
39        x: &Array2<f64>,
40        context: &mut HookContext,
41    ) -> Result<(), SklearsError> {
42        let _ = (x, context);
43        Ok(())
44    }
45
46    /// Called before transform
47    fn before_transform(
48        &mut self,
49        x: &Array2<f64>,
50        context: &mut HookContext,
51    ) -> Result<(), SklearsError> {
52        let _ = (x, context);
53        Ok(())
54    }
55
56    /// Called after transform
57    fn after_transform(
58        &mut self,
59        x: &Array2<f64>,
60        output: &Array2<f64>,
61        context: &mut HookContext,
62    ) -> Result<(), SklearsError> {
63        let _ = (x, output, context);
64        Ok(())
65    }
66
67    /// Called on error
68    fn on_error(&mut self, error: &SklearsError, context: &mut HookContext) {
69        let _ = (error, context);
70    }
71
72    /// Get hook name
73    fn name(&self) -> &str {
74        "Hook"
75    }
76}
77
/// Metadata handed to hooks describing the pipeline step being processed.
#[derive(Debug, Clone, Default)]
pub struct HookContext {
    /// Name of the pipeline phase (e.g., "fit", "transform")
    pub stage: String,
    /// Position of the transform inside the pipeline
    pub transform_index: usize,
    /// Name of the transform
    pub transform_name: String,
    /// Elapsed time, expressed in milliseconds
    pub elapsed_ms: f64,
    /// Free-form key/value annotations
    pub metadata: std::collections::HashMap<String, String>,
}

impl HookContext {
    /// Builds a context for the given phase and transform.
    ///
    /// `elapsed_ms` starts at `0.0` and `metadata` starts empty.
    pub fn new(stage: &str, transform_index: usize, transform_name: &str) -> Self {
        Self {
            stage: String::from(stage),
            transform_index,
            transform_name: String::from(transform_name),
            // Remaining fields (elapsed time, metadata) take their zero/empty defaults.
            ..Self::default()
        }
    }

    /// Stores `value` under `key`, replacing any value already stored there.
    pub fn add_metadata(&mut self, key: String, value: String) {
        let _replaced = self.metadata.insert(key, value);
    }
}
110
111/// Middleware that wraps transformations
112pub trait Middleware: Send + Sync {
113    /// Process before fit
114    fn process_before_fit(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
115        Ok(x.clone())
116    }
117
118    /// Process after fit
119    fn process_after_fit(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
120        Ok(x.clone())
121    }
122
123    /// Process before transform
124    fn process_before_transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
125        Ok(x.clone())
126    }
127
128    /// Process after transform
129    fn process_after_transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
130        Ok(x.clone())
131    }
132
133    /// Get middleware name
134    fn name(&self) -> &str {
135        "Middleware"
136    }
137}
138
/// Transform stage in the pipeline
///
/// A stage is a fit-then-transform step. `Pipeline::fit` fits each stage on
/// the data produced by the previous stage and then calls `transform` on that
/// same data to produce the input of the next stage.
pub trait PipelineStage: Send + Sync {
    /// Fit the stage on `x`, returning an error if fitting fails
    fn fit(&mut self, x: &Array2<f64>) -> Result<(), SklearsError>;

    /// Transform `x` using the fitted stage and return the result as a new array
    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError>;

    /// Check if stage is fitted
    fn is_fitted(&self) -> bool;

    /// Get stage name (used as `transform_name` in the `HookContext` built by `Pipeline::fit`)
    fn name(&self) -> &str;

    /// Clone the stage into a new boxed trait object
    fn clone_stage(&self) -> Box<dyn PipelineStage>;

    /// Get as Any for downcasting to the concrete stage type
    fn as_any(&self) -> &dyn Any;
}
159
/// Hook that keeps an in-memory log of timings and array shapes.
#[derive(Default)]
pub struct LoggingHook {
    /// Log lines in the order they were recorded
    logs: Vec<String>,
}

impl LoggingHook {
    /// Creates a hook with an empty log.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns every log line recorded so far, oldest first.
    pub fn logs(&self) -> &[String] {
        self.logs.as_slice()
    }
}
182
183impl Hook for LoggingHook {
184    fn before_fit(
185        &mut self,
186        x: &Array2<f64>,
187        context: &mut HookContext,
188    ) -> Result<(), SklearsError> {
189        let log = format!(
190            "[{}] Before fit - transform: {}, shape: {:?}",
191            context.stage,
192            context.transform_name,
193            x.dim()
194        );
195        self.logs.push(log);
196        Ok(())
197    }
198
199    fn after_fit(
200        &mut self,
201        _x: &Array2<f64>,
202        context: &mut HookContext,
203    ) -> Result<(), SklearsError> {
204        let log = format!(
205            "[{}] After fit - transform: {}, time: {:.2}ms",
206            context.stage, context.transform_name, context.elapsed_ms
207        );
208        self.logs.push(log);
209        Ok(())
210    }
211
212    fn before_transform(
213        &mut self,
214        x: &Array2<f64>,
215        context: &mut HookContext,
216    ) -> Result<(), SklearsError> {
217        let log = format!(
218            "[{}] Before transform - transform: {}, shape: {:?}",
219            context.stage,
220            context.transform_name,
221            x.dim()
222        );
223        self.logs.push(log);
224        Ok(())
225    }
226
227    fn after_transform(
228        &mut self,
229        _x: &Array2<f64>,
230        output: &Array2<f64>,
231        context: &mut HookContext,
232    ) -> Result<(), SklearsError> {
233        let log = format!(
234            "[{}] After transform - transform: {}, output shape: {:?}, time: {:.2}ms",
235            context.stage,
236            context.transform_name,
237            output.dim(),
238            context.elapsed_ms
239        );
240        self.logs.push(log);
241        Ok(())
242    }
243
244    fn name(&self) -> &str {
245        "LoggingHook"
246    }
247}
248
249/// Normalization middleware
250pub struct NormalizationMiddleware {
251    mean: Option<Array2<f64>>,
252    std: Option<Array2<f64>>,
253}
254
255impl NormalizationMiddleware {
256    /// Create a new normalization middleware
257    pub fn new() -> Self {
258        Self {
259            mean: None,
260            std: None,
261        }
262    }
263}
264
265impl Default for NormalizationMiddleware {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271impl Middleware for NormalizationMiddleware {
272    fn process_before_fit(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
273        // Calculate mean and std
274        let mean = x
275            .mean_axis(scirs2_core::ndarray::Axis(0))
276            .ok_or_else(|| SklearsError::InvalidInput("Cannot compute mean".to_string()))?;
277        let std = x.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
278
279        // Normalize
280        let mut normalized = x.clone();
281        for (i, mut col) in normalized
282            .axis_iter_mut(scirs2_core::ndarray::Axis(1))
283            .enumerate()
284        {
285            let std_val = std[i].max(1e-8); // Avoid division by zero
286            for elem in col.iter_mut() {
287                *elem = (*elem - mean[i]) / std_val;
288            }
289        }
290
291        Ok(normalized)
292    }
293
294    fn process_before_transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
295        self.process_before_fit(x)
296    }
297
298    fn name(&self) -> &str {
299        "NormalizationMiddleware"
300    }
301}
302
303/// Pipeline for composing multiple kernel approximations
304pub struct Pipeline {
305    stages: Vec<Box<dyn PipelineStage>>,
306    hooks: Vec<Box<dyn Hook>>,
307    middleware: Vec<Arc<dyn Middleware>>,
308    name: String,
309    is_fitted: bool,
310}
311
312impl Pipeline {
313    /// Create a new pipeline
314    pub fn new(name: String) -> Self {
315        Self {
316            stages: Vec::new(),
317            hooks: Vec::new(),
318            middleware: Vec::new(),
319            name,
320            is_fitted: false,
321        }
322    }
323
324    /// Add a stage to the pipeline
325    pub fn add_stage(&mut self, stage: Box<dyn PipelineStage>) {
326        self.stages.push(stage);
327    }
328
329    /// Add a hook
330    pub fn add_hook(&mut self, hook: Box<dyn Hook>) {
331        self.hooks.push(hook);
332    }
333
334    /// Add middleware
335    pub fn add_middleware(&mut self, middleware: Arc<dyn Middleware>) {
336        self.middleware.push(middleware);
337    }
338
339    /// Fit the pipeline
340    pub fn fit(&mut self, x: &Array2<f64>) -> Result<(), SklearsError> {
341        let mut current_data = x.clone();
342
343        // Apply middleware before fit
344        for mw in &self.middleware {
345            current_data = mw.process_before_fit(&current_data)?;
346        }
347
348        // Fit each stage
349        for (idx, stage) in self.stages.iter_mut().enumerate() {
350            let start = Instant::now();
351            let mut context = HookContext::new("fit", idx, stage.name());
352
353            // Call before fit hooks
354            for hook in &mut self.hooks {
355                hook.before_fit(&current_data, &mut context)?;
356            }
357
358            // Fit the stage
359            match stage.fit(&current_data) {
360                Ok(_) => {
361                    context.elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
362
363                    // Call after fit hooks
364                    for hook in &mut self.hooks {
365                        hook.after_fit(&current_data, &mut context)?;
366                    }
367
368                    // Transform for next stage
369                    current_data = stage.transform(&current_data)?;
370                }
371                Err(e) => {
372                    for hook in &mut self.hooks {
373                        hook.on_error(&e, &mut context);
374                    }
375                    return Err(e);
376                }
377            }
378        }
379
380        // Apply middleware after fit
381        for mw in &self.middleware {
382            current_data = mw.process_after_fit(&current_data)?;
383        }
384
385        self.is_fitted = true;
386        Ok(())
387    }
388
389    /// Transform using the fitted pipeline
390    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
391        if !self.is_fitted {
392            return Err(SklearsError::NotFitted {
393                operation: "Pipeline must be fitted before transform".to_string(),
394            });
395        }
396
397        let mut current_data = x.clone();
398
399        // Apply middleware before transform
400        for mw in &self.middleware {
401            current_data = mw.process_before_transform(&current_data)?;
402        }
403
404        // Transform through each stage
405        for stage in self.stages.iter() {
406            // Note: Hook calls during transform would require interior mutability
407            // since transform takes &self. For now, hooks are only called during fit.
408            // To enable transform hooks, consider using Arc<Mutex<dyn Hook>> instead.
409
410            // Transform
411            match stage.transform(&current_data) {
412                Ok(output) => {
413                    current_data = output;
414                }
415                Err(e) => {
416                    return Err(e);
417                }
418            }
419        }
420
421        // Apply middleware after transform
422        for mw in &self.middleware {
423            current_data = mw.process_after_transform(&current_data)?;
424        }
425
426        Ok(current_data)
427    }
428
429    /// Get pipeline name
430    pub fn name(&self) -> &str {
431        &self.name
432    }
433
434    /// Check if pipeline is fitted
435    pub fn is_fitted(&self) -> bool {
436        self.is_fitted
437    }
438
439    /// Get number of stages
440    pub fn len(&self) -> usize {
441        self.stages.len()
442    }
443
444    /// Check if pipeline is empty
445    pub fn is_empty(&self) -> bool {
446        self.stages.is_empty()
447    }
448}
449
450/// Builder for creating pipelines
451pub struct PipelineBuilder {
452    pipeline: Pipeline,
453}
454
455impl PipelineBuilder {
456    /// Create a new pipeline builder
457    pub fn new(name: &str) -> Self {
458        Self {
459            pipeline: Pipeline::new(name.to_string()),
460        }
461    }
462
463    /// Add a stage
464    pub fn add_stage(mut self, stage: Box<dyn PipelineStage>) -> Self {
465        self.pipeline.add_stage(stage);
466        self
467    }
468
469    /// Add a hook
470    pub fn add_hook(mut self, hook: Box<dyn Hook>) -> Self {
471        self.pipeline.add_hook(hook);
472        self
473    }
474
475    /// Add middleware
476    pub fn add_middleware(mut self, middleware: Arc<dyn Middleware>) -> Self {
477        self.pipeline.add_middleware(middleware);
478        self
479    }
480
481    /// Build the pipeline
482    pub fn build(self) -> Pipeline {
483        self.pipeline
484    }
485}
486
487/// Validation hook that checks for NaN and Inf
488pub struct ValidationHook;
489
490impl Hook for ValidationHook {
491    fn after_transform(
492        &mut self,
493        _x: &Array2<f64>,
494        output: &Array2<f64>,
495        _context: &mut HookContext,
496    ) -> Result<(), SklearsError> {
497        for &val in output.iter() {
498            if val.is_nan() || val.is_infinite() {
499                return Err(SklearsError::InvalidInput(
500                    "Output contains NaN or Inf values".to_string(),
501                ));
502            }
503        }
504        Ok(())
505    }
506
507    fn name(&self) -> &str {
508        "ValidationHook"
509    }
510}
511
/// Hook that collects one `(transform name, elapsed milliseconds)` entry per recorded step.
#[derive(Default)]
pub struct PerformanceHook {
    /// Recorded entries, oldest first
    timings: Vec<(String, f64)>,
}

impl PerformanceHook {
    /// Creates a hook with no recorded timings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns every recorded `(name, milliseconds)` entry, oldest first.
    pub fn timings(&self) -> &[(String, f64)] {
        self.timings.as_slice()
    }

    /// Returns the sum of all recorded durations, in milliseconds.
    pub fn total_time(&self) -> f64 {
        self.timings.iter().map(|entry| entry.1).sum::<f64>()
    }
}
541
542impl Hook for PerformanceHook {
543    fn after_transform(
544        &mut self,
545        _x: &Array2<f64>,
546        _output: &Array2<f64>,
547        context: &mut HookContext,
548    ) -> Result<(), SklearsError> {
549        self.timings
550            .push((context.transform_name.clone(), context.elapsed_ms));
551        Ok(())
552    }
553
554    fn name(&self) -> &str {
555        "PerformanceHook"
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Minimal stage for the tests: doubles every element once it has been fitted.
    struct DummyStage {
        name: String,
        fitted: bool,
    }

    impl DummyStage {
        /// Creates an unfitted stage called `name`.
        fn new(name: &str) -> Self {
            let name = String::from(name);
            Self {
                name,
                fitted: false,
            }
        }
    }

    impl PipelineStage for DummyStage {
        fn fit(&mut self, _x: &Array2<f64>) -> Result<(), SklearsError> {
            self.fitted = true;
            Ok(())
        }

        fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, SklearsError> {
            if self.fitted {
                Ok(x.mapv(|value| value * 2.0))
            } else {
                Err(SklearsError::NotFitted {
                    operation: "Stage not fitted".to_string(),
                })
            }
        }

        fn is_fitted(&self) -> bool {
            self.fitted
        }

        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn clone_stage(&self) -> Box<dyn PipelineStage> {
            let copy = DummyStage {
                name: self.name.clone(),
                fitted: self.fitted,
            };
            Box::new(copy)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Fitting then transforming through a single doubling stage doubles the input.
    #[test]
    fn test_pipeline_basic() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];

        let mut pipeline = Pipeline::new("test".to_string());
        pipeline.add_stage(Box::new(DummyStage::new("stage1")));
        pipeline.fit(&input).unwrap();

        let doubled = pipeline.transform(&input).unwrap();
        assert_eq!(doubled[[0, 0]], 2.0);
        assert_eq!(doubled[[1, 1]], 8.0);
    }

    /// The builder carries the added stage into the built pipeline.
    #[test]
    fn test_pipeline_builder() {
        let builder = PipelineBuilder::new("test")
            .add_stage(Box::new(DummyStage::new("stage1")))
            .add_hook(Box::new(LoggingHook::new()));
        let pipeline = builder.build();

        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.len(), 1);
    }

    /// One `before_fit` call produces exactly one log line.
    #[test]
    fn test_logging_hook() {
        let data = array![[1.0, 2.0]];
        let mut ctx = HookContext::new("fit", 0, "test_stage");
        let mut hook = LoggingHook::new();

        hook.before_fit(&data, &mut ctx).unwrap();
        assert_eq!(hook.logs().len(), 1);
    }

    /// Finite output passes validation; output containing NaN is rejected.
    #[test]
    fn test_validation_hook() {
        let input = array![[1.0, 2.0]];
        let mut ctx = HookContext::new("transform", 0, "test");
        let mut hook = ValidationHook;

        let finite = array![[1.0, 2.0]];
        assert!(hook.after_transform(&input, &finite, &mut ctx).is_ok());

        let with_nan = array![[f64::NAN, 2.0]];
        assert!(hook.after_transform(&input, &with_nan, &mut ctx).is_err());
    }

    /// `after_transform` records the context's elapsed time.
    #[test]
    fn test_performance_hook() {
        let input = array![[1.0, 2.0]];
        let output = array![[2.0, 4.0]];
        let mut ctx = HookContext::new("transform", 0, "test");
        ctx.elapsed_ms = 10.0;

        let mut hook = PerformanceHook::new();
        hook.after_transform(&input, &output, &mut ctx).unwrap();

        assert_eq!(hook.timings().len(), 1);
        assert_eq!(hook.total_time(), 10.0);
    }

    /// Constructor arguments and added metadata are visible on the context.
    #[test]
    fn test_hook_context() {
        let mut ctx = HookContext::new("fit", 0, "test_stage");
        ctx.add_metadata("key".to_string(), "value".to_string());

        assert_eq!(ctx.stage, "fit");
        assert_eq!(ctx.transform_index, 0);
        assert!(ctx.metadata.contains_key("key"));
    }
}