sklears_preprocessing/
type_safety.rs

1//! Advanced type safety for preprocessing transformers
2//!
3//! This module provides compile-time guarantees for transformation states and pipeline composition
4//! using Rust's advanced type system features including:
5//! - Phantom types for tracking fitted/unfitted states
6//! - Const generics for compile-time dimension checking
7//! - Type-level programming for pipeline validation
8//! - Zero-cost abstractions for transformation composition
9
10use scirs2_core::ndarray::{Array1, Array2};
11use sklears_core::prelude::*;
12use std::marker::PhantomData;
13
14// ================================================================================================
15// State Markers
16// ================================================================================================
17
/// Marker trait for transformation states.
///
/// The trait is sealed through the private `sealed::Sealed` supertrait, so the only
/// implementors are [`Unfitted`] and [`Fitted`].
pub trait TransformState: sealed::Sealed {}

/// Unfitted state - transformer has not been fitted to data
#[derive(Debug, Clone, Copy)]
pub struct Unfitted;

/// Fitted state - transformer has been fitted to data and can transform
#[derive(Debug, Clone, Copy)]
pub struct Fitted;

/// Private supertraits that keep the marker traits closed to implementations
/// outside this module.
mod sealed {
    /// Supertrait of `TransformState`.
    pub trait Sealed {}
    impl Sealed for super::Unfitted {}
    impl Sealed for super::Fitted {}

    /// Supertrait of `Dimension`.
    pub trait DimensionSealed {}
    impl DimensionSealed for super::Dynamic {}
    impl<const N: usize> DimensionSealed for super::Known<N> {}
}

impl TransformState for Unfitted {}
impl TransformState for Fitted {}

// ================================================================================================
// Dimension Markers
// ================================================================================================

/// Marker for unknown dimensions (determined at runtime)
///
/// Derives `Debug`, `Clone` and `Copy` like the state markers. `#[derive]` on a generic
/// type adds the derived trait as a bound on every type parameter, so a type that is
/// generic over a [`Dimension`] marker can only be `Debug`/`Clone` if the marker is too.
#[derive(Debug, Clone, Copy)]
pub struct Dynamic;

/// Marker for known dimensions (determined at compile time)
///
/// `N` is the dimension carried in the type. Derives the same traits as [`Dynamic`]
/// for the same reason.
#[derive(Debug, Clone, Copy)]
pub struct Known<const N: usize>;

/// Trait for dimension types
///
/// Sealed through `sealed::DimensionSealed`; implemented by [`Dynamic`] and [`Known`].
pub trait Dimension: sealed::DimensionSealed {
    /// Get the dimension value if known at compile time
    ///
    /// Returns `None` for [`Dynamic`] and `Some(N)` for [`Known<N>`].
    fn value() -> Option<usize>;
}

impl Dimension for Dynamic {
    fn value() -> Option<usize> {
        None
    }
}

impl<const N: usize> Dimension for Known<N> {
    fn value() -> Option<usize> {
        Some(N)
    }
}
69
70// ================================================================================================
71// Type-Safe Transformer
72// ================================================================================================
73
74/// Type-safe transformer with compile-time state and dimension tracking
75///
76/// # Type Parameters
77/// * `S` - State marker (Unfitted or Fitted)
78/// * `InDim` - Input dimension marker (Dynamic or `Known<N>`)
79/// * `OutDim` - Output dimension marker (Dynamic or `Known<N>`)
80#[derive(Debug, Clone)]
81pub struct TypeSafeTransformer<S: TransformState, InDim: Dimension, OutDim: Dimension> {
82    /// Configuration
83    config: TypeSafeConfig,
84    /// Input dimension (runtime value)
85    input_dim: Option<usize>,
86    /// Output dimension (runtime value)
87    output_dim: Option<usize>,
88    /// Fitted parameters (only available in Fitted state)
89    parameters: Option<TransformParameters>,
90    /// State marker (zero-sized, compile-time only)
91    _state: PhantomData<S>,
92    /// Input dimension marker (zero-sized, compile-time only)
93    _in_dim: PhantomData<InDim>,
94    /// Output dimension marker (zero-sized, compile-time only)
95    _out_dim: PhantomData<OutDim>,
96}
97
/// Configuration for type-safe transformer
///
/// Comparable with `==` so callers and tests can check which configuration is in use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TypeSafeConfig {
    /// Whether to validate dimensions at runtime
    ///
    /// NOTE(review): nothing in this module reads this flag; the `fit`/`transform`
    /// methods defined here check the column count unconditionally. Confirm whether
    /// it is consumed elsewhere.
    pub validate_dimensions: bool,
    /// Whether to normalize outputs
    ///
    /// When `true`, `fit` computes a per-column mean and standard deviation and
    /// `transform` applies them.
    pub normalize: bool,
}

impl Default for TypeSafeConfig {
    /// Dimension validation enabled, normalization disabled.
    fn default() -> Self {
        Self {
            validate_dimensions: true,
            normalize: false,
        }
    }
}
115
116/// Fitted parameters for the transformer
117#[derive(Debug, Clone)]
118struct TransformParameters {
119    /// Mean for normalization
120    mean: Array1<f64>,
121    /// Standard deviation for normalization
122    std: Array1<f64>,
123}
124
125// ================================================================================================
126// Implementation for Unfitted State
127// ================================================================================================
128
129impl<InDim: Dimension, OutDim: Dimension> TypeSafeTransformer<Unfitted, InDim, OutDim> {
130    /// Create a new unfitted transformer with dynamic dimensions
131    pub fn new(config: TypeSafeConfig) -> TypeSafeTransformer<Unfitted, Dynamic, Dynamic> {
132        TypeSafeTransformer {
133            config,
134            input_dim: None,
135            output_dim: None,
136            parameters: None,
137            _state: PhantomData,
138            _in_dim: PhantomData,
139            _out_dim: PhantomData,
140        }
141    }
142
143    /// Create a new unfitted transformer with known input dimension
144    pub fn with_input_dim<const N: usize>(
145        config: TypeSafeConfig,
146    ) -> TypeSafeTransformer<Unfitted, Known<N>, Dynamic> {
147        TypeSafeTransformer {
148            config,
149            input_dim: Some(N),
150            output_dim: None,
151            parameters: None,
152            _state: PhantomData,
153            _in_dim: PhantomData,
154            _out_dim: PhantomData,
155        }
156    }
157
158    /// Create a new unfitted transformer with known input and output dimensions
159    pub fn with_dimensions<const IN: usize, const OUT: usize>(
160        config: TypeSafeConfig,
161    ) -> TypeSafeTransformer<Unfitted, Known<IN>, Known<OUT>> {
162        TypeSafeTransformer {
163            config,
164            input_dim: Some(IN),
165            output_dim: Some(OUT),
166            parameters: None,
167            _state: PhantomData,
168            _in_dim: PhantomData,
169            _out_dim: PhantomData,
170        }
171    }
172}
173
174// Fit for dynamic dimensions
175impl TypeSafeTransformer<Unfitted, Dynamic, Dynamic> {
176    /// Fit the transformer to data
177    pub fn fit(self, X: &Array2<f64>) -> Result<TypeSafeTransformer<Fitted, Dynamic, Dynamic>> {
178        let input_dim = X.ncols();
179        let output_dim = X.ncols(); // Identity transform for this example
180
181        let parameters = if self.config.normalize {
182            let mean = X
183                .mean_axis(scirs2_core::ndarray::Axis(0))
184                .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
185            let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
186            Some(TransformParameters { mean, std })
187        } else {
188            None
189        };
190
191        Ok(TypeSafeTransformer {
192            config: self.config,
193            input_dim: Some(input_dim),
194            output_dim: Some(output_dim),
195            parameters,
196            _state: PhantomData,
197            _in_dim: PhantomData,
198            _out_dim: PhantomData,
199        })
200    }
201}
202
203// Fit for known input dimension
204impl<const N: usize> TypeSafeTransformer<Unfitted, Known<N>, Dynamic> {
205    /// Fit the transformer to data with compile-time input dimension check
206    pub fn fit(self, X: &Array2<f64>) -> Result<TypeSafeTransformer<Fitted, Known<N>, Dynamic>> {
207        if X.ncols() != N {
208            return Err(SklearsError::InvalidInput(format!(
209                "Expected {} input features, got {}",
210                N,
211                X.ncols()
212            )));
213        }
214
215        let output_dim = X.ncols();
216
217        let parameters = if self.config.normalize {
218            let mean = X
219                .mean_axis(scirs2_core::ndarray::Axis(0))
220                .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
221            let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
222            Some(TransformParameters { mean, std })
223        } else {
224            None
225        };
226
227        Ok(TypeSafeTransformer {
228            config: self.config,
229            input_dim: Some(N),
230            output_dim: Some(output_dim),
231            parameters,
232            _state: PhantomData,
233            _in_dim: PhantomData,
234            _out_dim: PhantomData,
235        })
236    }
237}
238
239// Fit for known input and output dimensions
240impl<const IN: usize, const OUT: usize> TypeSafeTransformer<Unfitted, Known<IN>, Known<OUT>> {
241    /// Fit the transformer to data with compile-time dimension checks
242    pub fn fit(
243        self,
244        X: &Array2<f64>,
245    ) -> Result<TypeSafeTransformer<Fitted, Known<IN>, Known<OUT>>> {
246        if X.ncols() != IN {
247            return Err(SklearsError::InvalidInput(format!(
248                "Expected {} input features, got {}",
249                IN,
250                X.ncols()
251            )));
252        }
253
254        let parameters = if self.config.normalize {
255            let mean = X
256                .mean_axis(scirs2_core::ndarray::Axis(0))
257                .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
258            let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
259            Some(TransformParameters { mean, std })
260        } else {
261            None
262        };
263
264        Ok(TypeSafeTransformer {
265            config: self.config,
266            input_dim: Some(IN),
267            output_dim: Some(OUT),
268            parameters,
269            _state: PhantomData,
270            _in_dim: PhantomData,
271            _out_dim: PhantomData,
272        })
273    }
274}
275
276// ================================================================================================
277// Implementation for Fitted State
278// ================================================================================================
279
280// Transform for dynamic dimensions
281impl TypeSafeTransformer<Fitted, Dynamic, Dynamic> {
282    /// Transform data using the fitted transformer
283    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
284        if let Some(input_dim) = self.input_dim {
285            if X.ncols() != input_dim {
286                return Err(SklearsError::InvalidInput(format!(
287                    "Expected {} input features, got {}",
288                    input_dim,
289                    X.ncols()
290                )));
291            }
292        }
293
294        let mut result = X.clone();
295
296        if let Some(ref params) = self.parameters {
297            for i in 0..result.nrows() {
298                for j in 0..result.ncols() {
299                    result[[i, j]] = (result[[i, j]] - params.mean[j]) / params.std[j].max(1e-10);
300                }
301            }
302        }
303
304        Ok(result)
305    }
306}
307
308// Transform for known input dimension
309impl<const N: usize> TypeSafeTransformer<Fitted, Known<N>, Dynamic> {
310    /// Transform data with compile-time input dimension check
311    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
312        if X.ncols() != N {
313            return Err(SklearsError::InvalidInput(format!(
314                "Expected {} input features, got {}",
315                N,
316                X.ncols()
317            )));
318        }
319
320        let mut result = X.clone();
321
322        if let Some(ref params) = self.parameters {
323            for i in 0..result.nrows() {
324                for j in 0..result.ncols() {
325                    result[[i, j]] = (result[[i, j]] - params.mean[j]) / params.std[j].max(1e-10);
326                }
327            }
328        }
329
330        Ok(result)
331    }
332}
333
334// Transform for known input and output dimensions
335impl<const IN: usize, const OUT: usize> TypeSafeTransformer<Fitted, Known<IN>, Known<OUT>> {
336    /// Transform data with compile-time dimension checks
337    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
338        if X.ncols() != IN {
339            return Err(SklearsError::InvalidInput(format!(
340                "Expected {} input features, got {}",
341                IN,
342                X.ncols()
343            )));
344        }
345
346        let mut result = X.clone();
347
348        if let Some(ref params) = self.parameters {
349            for i in 0..result.nrows() {
350                for j in 0..result.ncols() {
351                    result[[i, j]] = (result[[i, j]] - params.mean[j]) / params.std[j].max(1e-10);
352                }
353            }
354        }
355
356        // Compile-time check: output must have OUT columns
357        if result.ncols() != OUT {
358            return Err(SklearsError::InvalidInput(format!(
359                "Expected {} output features, got {}",
360                OUT,
361                result.ncols()
362            )));
363        }
364
365        Ok(result)
366    }
367}
368
369// ================================================================================================
370// Type-Safe Pipeline
371// ================================================================================================
372
373/// Type-safe pipeline that chains transformers with compile-time validation
374pub struct TypeSafePipeline<S1, S2, D1, D2, D3>
375where
376    S1: TransformState,
377    S2: TransformState,
378    D1: Dimension,
379    D2: Dimension,
380    D3: Dimension,
381{
382    /// First transformer
383    first: TypeSafeTransformer<S1, D1, D2>,
384    /// Second transformer
385    second: TypeSafeTransformer<S2, D2, D3>,
386}
387
388/// Type-safe pipeline in unfitted state
389impl<D1: Dimension, D2: Dimension, D3: Dimension> TypeSafePipeline<Unfitted, Unfitted, D1, D2, D3> {
390    /// Create a new pipeline by chaining two unfitted transformers
391    pub fn new(
392        first: TypeSafeTransformer<Unfitted, D1, D2>,
393        second: TypeSafeTransformer<Unfitted, D2, D3>,
394    ) -> Self {
395        Self { first, second }
396    }
397}
398
399/// Fit pipeline with dynamic dimensions
400impl TypeSafePipeline<Unfitted, Unfitted, Dynamic, Dynamic, Dynamic> {
401    /// Fit the entire pipeline to data
402    pub fn fit(
403        self,
404        X: &Array2<f64>,
405    ) -> Result<TypeSafePipeline<Fitted, Fitted, Dynamic, Dynamic, Dynamic>> {
406        let first_fitted = self.first.fit(X)?;
407        let X_transformed = first_fitted.transform(X)?;
408        let second_fitted = self.second.fit(&X_transformed)?;
409
410        Ok(TypeSafePipeline {
411            first: first_fitted,
412            second: second_fitted,
413        })
414    }
415}
416
417/// Fit pipeline with known dimensions
418impl<const D1: usize, const D2: usize, const D3: usize>
419    TypeSafePipeline<Unfitted, Unfitted, Known<D1>, Known<D2>, Known<D3>>
420{
421    /// Fit the entire pipeline to data with compile-time dimension validation
422    pub fn fit(
423        self,
424        X: &Array2<f64>,
425    ) -> Result<TypeSafePipeline<Fitted, Fitted, Known<D1>, Known<D2>, Known<D3>>> {
426        let first_fitted = self.first.fit(X)?;
427        let X_transformed = first_fitted.transform(X)?;
428        let second_fitted = self.second.fit(&X_transformed)?;
429
430        Ok(TypeSafePipeline {
431            first: first_fitted,
432            second: second_fitted,
433        })
434    }
435}
436
437/// Transform for fitted pipeline with dynamic dimensions
438impl TypeSafePipeline<Fitted, Fitted, Dynamic, Dynamic, Dynamic> {
439    /// Transform data through the entire pipeline
440    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
441        let X_intermediate = self.first.transform(X)?;
442        self.second.transform(&X_intermediate)
443    }
444}
445
446/// Transform for fitted pipeline with known dimensions
447impl<const D1: usize, const D2: usize, const D3: usize>
448    TypeSafePipeline<Fitted, Fitted, Known<D1>, Known<D2>, Known<D3>>
449{
450    /// Transform data through the entire pipeline with compile-time dimension validation
451    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
452        let X_intermediate = self.first.transform(X)?;
453        self.second.transform(&X_intermediate)
454    }
455}
456
457// ================================================================================================
458// Tests
459// ================================================================================================
460
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Three samples with two features each, shared by most tests.
    fn sample_data() -> Array2<f64> {
        array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    }

    /// Configuration with normalization switched on.
    fn normalizing_config() -> TypeSafeConfig {
        TypeSafeConfig {
            validate_dimensions: true,
            normalize: true,
        }
    }

    #[test]
    fn test_dynamic_dimensions() {
        let data = sample_data();

        let transformer =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(TypeSafeConfig::default());
        let output = transformer.fit(&data).unwrap().transform(&data).unwrap();

        assert_eq!(output.nrows(), 3);
        assert_eq!(output.ncols(), 2);
    }

    #[test]
    fn test_known_input_dimension() {
        let data = sample_data();

        let transformer = TypeSafeTransformer::<Unfitted, Known<2>, Dynamic>::with_input_dim::<2>(
            TypeSafeConfig::default(),
        );
        let output = transformer.fit(&data).unwrap().transform(&data).unwrap();

        assert_eq!(output.ncols(), 2);
    }

    #[test]
    fn test_known_input_dimension_mismatch() {
        let three_columns = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let transformer = TypeSafeTransformer::<Unfitted, Known<2>, Dynamic>::with_input_dim::<2>(
            TypeSafeConfig::default(),
        );

        // The data has 3 columns while the type promises 2, so fitting must fail.
        assert!(transformer.fit(&three_columns).is_err());
    }

    #[test]
    fn test_known_dimensions() {
        let data = sample_data();

        let transformer =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<2>>::with_dimensions::<2, 2>(
                TypeSafeConfig::default(),
            );
        let output = transformer.fit(&data).unwrap().transform(&data).unwrap();

        assert_eq!(output.nrows(), 3);
        assert_eq!(output.ncols(), 2);
    }

    #[test]
    fn test_normalization() {
        let data = sample_data();

        let transformer =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(normalizing_config());
        let output = transformer.fit(&data).unwrap().transform(&data).unwrap();

        // After normalization every column mean should be approximately 0.
        let column_means = output.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
        assert!(column_means.iter().all(|mean| mean.abs() < 1e-10));
    }

    #[test]
    fn test_pipeline_dynamic() {
        let data = sample_data();

        let normalizer =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(normalizing_config());
        let passthrough =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(TypeSafeConfig::default());

        let fitted_pipeline = TypeSafePipeline::new(normalizer, passthrough)
            .fit(&data)
            .unwrap();
        let output = fitted_pipeline.transform(&data).unwrap();

        assert_eq!(output.nrows(), 3);
        assert_eq!(output.ncols(), 2);
    }

    #[test]
    fn test_pipeline_known_dimensions() {
        let data = sample_data();

        let first_stage =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<2>>::with_dimensions::<2, 2>(
                TypeSafeConfig::default(),
            );
        let second_stage =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<2>>::with_dimensions::<2, 2>(
                TypeSafeConfig::default(),
            );

        let fitted_pipeline = TypeSafePipeline::new(first_stage, second_stage)
            .fit(&data)
            .unwrap();
        let output = fitted_pipeline.transform(&data).unwrap();

        assert_eq!(output.nrows(), 3);
        assert_eq!(output.ncols(), 2);
    }

    #[test]
    fn test_state_transitions() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        // A freshly constructed transformer is in the Unfitted state.
        let unfitted =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(TypeSafeConfig::default());

        // `fit` consumes it and yields the Fitted state.
        let fitted = unfitted.fit(&data).unwrap();

        // `transform` is available in the Fitted state.
        let _output = fitted.transform(&data).unwrap();

        // No `fit` method is defined for the Fitted state, so calling `fit` on
        // `fitted` here would be rejected at compile time.
    }

    #[test]
    fn test_dimension_markers() {
        assert_eq!(<Dynamic as Dimension>::value(), None);
        assert_eq!(<Known<5> as Dimension>::value(), Some(5));
        assert_eq!(<Known<10> as Dimension>::value(), Some(10));
    }
}