// sklears_decomposition/type_safe.rs

//! Type-safe decomposition abstractions using Rust's type system
//!
//! This module provides zero-cost abstractions for matrix decomposition methods
//! that leverage Rust's type system for improved safety and performance.
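//!
//! A minimal usage sketch of the type-state flow (kept as an illustration rather
//! than a compiled doctest; it assumes `Float` resolves to an `f64`-like scalar
//! and that `array!` is re-exported from `scirs2_core::ndarray`):
//!
//! ```ignore
//! use scirs2_core::ndarray::array;
//!
//! let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
//! let pca: TypeSafePCA<Untrained, 2> = TypeSafePCA::new();
//! let fitted = pca.fit(&data)?;              // consumes the Untrained value
//! let projected = fitted.transform(&data)?;  // transform exists only on Fitted
//! ```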

use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{error::Result, prelude::SklearsError, types::Float};
use std::marker::PhantomData;

/// Marker trait for decomposition states
pub trait DecompositionState {}

/// Untrained decomposition state
#[derive(Debug, Clone, Copy)]
pub struct Untrained;
impl DecompositionState for Untrained {}

/// Fitted decomposition state
#[derive(Debug, Clone, Copy)]
pub struct Fitted;
impl DecompositionState for Fitted {}

/// Phantom type for matrix rank
#[derive(Debug, Clone, Copy)]
pub struct Rank<const R: usize>;

/// Phantom type for matrix dimensions
#[derive(Debug, Clone, Copy)]
pub struct Dimensions<const ROWS: usize, const COLS: usize>;

/// Type-safe decomposition trait
pub trait TypeSafeDecomposition<State: DecompositionState> {
    type Output;
    type ErrorType;

    /// Get the current state of the decomposition
    fn state(&self) -> PhantomData<State>;
}

/// Type-safe PCA with compile-time rank checking
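///
/// The const parameter `RANK` is the number of components carried in the type, so
/// `n_components` is fixed at construction. A small sketch (illustrative, not a
/// compiled doctest):
///
/// ```ignore
/// let pca: TypeSafePCA<Untrained, 2> = TypeSafePCA::new();
/// assert_eq!(pca.n_components, 2); // always equal to RANK
/// ```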
#[derive(Debug, Clone)]
pub struct TypeSafePCA<State: DecompositionState, const RANK: usize> {
    /// Number of components (must match RANK)
    pub n_components: usize,
    /// Whether to center the data
    pub center: bool,
    /// Whether to scale the data
    pub scale: bool,
    /// Fitted components (only available in Fitted state)
    components: Option<Array2<Float>>,
    /// Explained variance (only available in Fitted state)
    explained_variance: Option<Array1<Float>>,
    /// Mean (only available in Fitted state)
    mean: Option<Array1<Float>>,
    /// Per-feature scale factors (only available in Fitted state when `scale` is set)
    scale_factors: Option<Array1<Float>>,
    /// State phantom
    _state: PhantomData<State>,
}

impl<const RANK: usize> Default for TypeSafePCA<Untrained, RANK> {
    fn default() -> Self {
        Self::new()
    }
}

impl<const RANK: usize> TypeSafePCA<Untrained, RANK> {
    /// Create a new type-safe PCA with compile-time rank checking
    pub const fn new() -> Self {
        Self {
            n_components: RANK,
            center: true,
            scale: false,
            components: None,
            explained_variance: None,
            mean: None,
            scale_factors: None,
            _state: PhantomData,
        }
    }

    /// Set whether to center the data
    pub fn center(mut self, center: bool) -> Self {
        self.center = center;
        self
    }

    /// Set whether to scale the data
    pub fn scale(mut self, scale: bool) -> Self {
        self.scale = scale;
        self
    }

    /// Fit the PCA model and transition to the `Fitted` state
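    ///
    /// Rank validation happens at fit time: if `RANK` exceeds the number of
    /// features or samples, an `InvalidParameter` error is returned. Sketch
    /// (illustrative, not a compiled doctest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let data = array![[1.0, 2.0], [3.0, 4.0]]; // 2 samples, 2 features
    /// let pca: TypeSafePCA<Untrained, 3> = TypeSafePCA::new();
    /// assert!(pca.fit(&data).is_err()); // RANK = 3 > 2 features
    /// ```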
    pub fn fit(self, data: &Array2<Float>) -> Result<TypeSafePCA<Fitted, RANK>> {
        let (n_samples, n_features) = data.dim();

        if RANK > n_features {
            return Err(SklearsError::InvalidParameter {
                name: "RANK".to_string(),
                reason: format!("RANK ({RANK}) cannot exceed number of features ({n_features})"),
            });
        }

        if RANK > n_samples {
            return Err(SklearsError::InvalidParameter {
                name: "RANK".to_string(),
                reason: format!("RANK ({RANK}) cannot exceed number of samples ({n_samples})"),
            });
        }

        // Center the data if requested
        let mean = if self.center {
            data.mean_axis(Axis(0)).unwrap()
        } else {
            Array1::zeros(n_features)
        };

        let mut centered_data = data.clone();
        if self.center {
            for mut row in centered_data.axis_iter_mut(Axis(0)) {
                row -= &mean;
            }
        }

        // Scale the data if requested, keeping the per-feature standard deviations
        // so that `transform` can apply the same scaling later
        let scale_factors = if self.scale {
            let std = centered_data.std_axis(Axis(0), 0.0);
            for mut row in centered_data.axis_iter_mut(Axis(0)) {
                for (i, val) in row.iter_mut().enumerate() {
                    if std[i] != 0.0 {
                        *val /= std[i];
                    }
                }
            }
            Some(std)
        } else {
            None
        };

        // Compute covariance matrix
        let covariance = centered_data.t().dot(&centered_data) / ((n_samples - 1) as Float);

        // Eigendecomposition (simplified placeholder; in practice this would call LAPACK)
        let (eigenvalues, eigenvectors) = self.eigendecomposition(&covariance)?;

        // Sort by eigenvalue in descending order and keep the top RANK pairs
        let mut eigen_pairs: Vec<(Float, Array1<Float>)> = eigenvalues
            .iter()
            .zip(eigenvectors.axis_iter(Axis(1)))
            .map(|(&val, vec)| (val, vec.to_owned()))
            .collect();

        eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));

        let mut components = Array2::zeros((n_features, RANK));
        let mut explained_variance = Array1::zeros(RANK);

        for (i, (eigenval, eigenvec)) in eigen_pairs.iter().take(RANK).enumerate() {
            components.column_mut(i).assign(eigenvec);
            explained_variance[i] = *eigenval;
        }

        Ok(TypeSafePCA {
            n_components: RANK,
            center: self.center,
            scale: self.scale,
            components: Some(components),
            explained_variance: Some(explained_variance),
            mean: Some(mean),
            scale_factors,
            _state: PhantomData,
        })
    }

    /// Simplified eigendecomposition (placeholder for actual LAPACK call)
    fn eigendecomposition(&self, matrix: &Array2<Float>) -> Result<(Array1<Float>, Array2<Float>)> {
        let n = matrix.nrows();

        // This is a placeholder implementation
        // In practice, you would use LAPACK for eigendecomposition
        let eigenvalues = Array1::from_iter((0..n).map(|i| (n - i) as Float));
        let mut eigenvectors: Array2<Float> = Array2::eye(n);

        // Normalize eigenvectors
        for mut col in eigenvectors.axis_iter_mut(Axis(1)) {
            let norm = col.dot(&col).sqrt();
            if norm > 1e-10 {
                col /= norm;
            }
        }

        Ok((eigenvalues, eigenvectors))
    }
}
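
/// Illustrative sketch only (not called by `fit`): one way to approximate the
/// dominant eigenpair of a symmetric matrix with plain power iteration, using
/// nothing beyond the `ndarray` operations already imported in this module. The
/// placeholder `eigendecomposition` above stands in until a LAPACK-backed routine
/// replaces it; this helper is hypothetical and exists purely as documentation.
#[allow(dead_code)]
fn power_iteration_sketch(matrix: &Array2<Float>, iterations: usize) -> (Float, Array1<Float>) {
    let n = matrix.nrows();
    // Start from a uniform unit vector and repeatedly apply the matrix, renormalizing.
    let mut v = Array1::from_elem(n, 1.0 / (n as Float).sqrt());
    for _ in 0..iterations {
        let w = matrix.dot(&v);
        let norm = w.dot(&w).sqrt();
        if norm <= 1e-12 {
            // Degenerate direction; keep the previous estimate.
            break;
        }
        v = w / norm;
    }
    // The Rayleigh quotient of the converged direction estimates the top eigenvalue.
    let eigenvalue = v.dot(&matrix.dot(&v));
    (eigenvalue, v)
}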

impl<const RANK: usize> TypeSafePCA<Fitted, RANK> {
    /// Transform data using the fitted components
    pub fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
        let components = self
            .components
            .as_ref()
            .ok_or_else(|| SklearsError::NotFitted {
                operation: "transform".to_string(),
            })?;

        let mut transformed_data = data.clone();

        // Apply the same centering as during fit
        if self.center {
            if let Some(ref mean) = self.mean {
                for mut row in transformed_data.axis_iter_mut(Axis(0)) {
                    row -= mean;
                }
            }
        }

        // Apply the same scaling as during fit, using the stored per-feature factors
        if self.scale {
            if let Some(ref scale_factors) = self.scale_factors {
                for mut row in transformed_data.axis_iter_mut(Axis(0)) {
                    for (i, val) in row.iter_mut().enumerate() {
                        if scale_factors[i] != 0.0 {
                            *val /= scale_factors[i];
                        }
                    }
                }
            }
        }

        // Project onto components
        Ok(transformed_data.dot(components))
    }

    /// Get the fitted components (guaranteed to be RANK columns)
    pub fn components(&self) -> &Array2<Float> {
        self.components.as_ref().unwrap()
    }

    /// Get explained variance (guaranteed to be RANK elements)
    pub fn explained_variance(&self) -> &Array1<Float> {
        self.explained_variance.as_ref().unwrap()
    }

    /// Get explained variance ratios
    pub fn explained_variance_ratio(&self) -> Array1<Float> {
        let explained_var = self.explained_variance();
        let total_variance = explained_var.sum();
        explained_var / total_variance
    }

    /// Fit and transform in one step
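    ///
    /// Because this is an associated function on the `Fitted` type, spelling the
    /// state out with a turbofish is one unambiguous way to call it. Sketch
    /// (illustrative, not a compiled doctest; `data` stands for any `Array2<Float>`):
    ///
    /// ```ignore
    /// let pca: TypeSafePCA<Untrained, 2> = TypeSafePCA::new();
    /// let (fitted, transformed) = TypeSafePCA::<Fitted, 2>::fit_transform(pca, &data)?;
    /// ```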
    pub fn fit_transform(
        untrained: TypeSafePCA<Untrained, RANK>,
        data: &Array2<Float>,
    ) -> Result<(TypeSafePCA<Fitted, RANK>, Array2<Float>)> {
        let fitted = untrained.fit(data)?;
        let transformed = fitted.transform(data)?;
        Ok((fitted, transformed))
    }
}

impl<State: DecompositionState, const RANK: usize> TypeSafeDecomposition<State>
    for TypeSafePCA<State, RANK>
{
    type Output = Array2<Float>;
    type ErrorType = SklearsError;

    fn state(&self) -> PhantomData<State> {
        self._state
    }
}

/// Type-safe matrix with compile-time dimension checking
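///
/// Both dimensions are const generics, so a shape mismatch in `dot` is rejected by
/// the compiler rather than at runtime. Sketch (illustrative, not a compiled doctest):
///
/// ```ignore
/// let a: TypeSafeMatrix<2, 3> = TypeSafeMatrix::zeros();
/// let b: TypeSafeMatrix<3, 4> = TypeSafeMatrix::zeros();
/// let c: TypeSafeMatrix<2, 4> = a.dot(&b); // inner dimensions must agree
/// // a.dot(&a) would not type-check: the inner dimensions 3 and 2 differ
/// ```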
#[derive(Debug, Clone)]
pub struct TypeSafeMatrix<const ROWS: usize, const COLS: usize> {
    data: Array2<Float>,
}

impl<const ROWS: usize, const COLS: usize> TypeSafeMatrix<ROWS, COLS> {
    /// Create a new type-safe matrix with compile-time dimension checking
    pub fn new(data: Array2<Float>) -> Result<Self> {
        let (rows, cols) = data.dim();
        if rows != ROWS || cols != COLS {
            return Err(SklearsError::InvalidParameter {
                name: "matrix_dimensions".to_string(),
                reason: format!(
                    "Matrix dimensions {rows}x{cols} do not match expected {ROWS}x{COLS}"
                ),
            });
        }
        Ok(Self { data })
    }

    /// Create a zero matrix
    pub fn zeros() -> Self {
        Self {
            data: Array2::zeros((ROWS, COLS)),
        }
    }

    /// Create an identity matrix (only valid for square matrices where ROWS == COLS)
    pub fn eye() -> Self {
        assert_eq!(ROWS, COLS, "Identity matrix requires ROWS == COLS");
        Self {
            data: Array2::eye(ROWS),
        }
    }

    /// Get the underlying data
    pub fn data(&self) -> &Array2<Float> {
        &self.data
    }

    /// Get mutable access to the underlying data
    pub fn data_mut(&mut self) -> &mut Array2<Float> {
        &mut self.data
    }

    /// Matrix multiplication with compile-time dimension checking
    pub fn dot<const OTHER_COLS: usize>(
        &self,
        other: &TypeSafeMatrix<COLS, OTHER_COLS>,
    ) -> TypeSafeMatrix<ROWS, OTHER_COLS> {
        let result = self.data.dot(&other.data);
        TypeSafeMatrix { data: result }
    }

    /// Transpose the matrix
    pub fn t(&self) -> TypeSafeMatrix<COLS, ROWS> {
        TypeSafeMatrix {
            data: self.data.t().to_owned(),
        }
    }

    /// Extract a submatrix with runtime size checking
    pub fn submatrix<const SUB_ROWS: usize, const SUB_COLS: usize>(
        &self,
        start_row: usize,
        start_col: usize,
    ) -> Result<TypeSafeMatrix<SUB_ROWS, SUB_COLS>> {
        if SUB_ROWS > ROWS || SUB_COLS > COLS {
            return Err(SklearsError::InvalidParameter {
                name: "submatrix_size".to_string(),
                reason: "Submatrix size exceeds matrix dimensions".to_string(),
            });
        }

        if start_row + SUB_ROWS > ROWS || start_col + SUB_COLS > COLS {
            return Err(SklearsError::InvalidParameter {
                name: "submatrix_bounds".to_string(),
                reason: "Submatrix bounds exceed matrix dimensions".to_string(),
            });
        }

        let subarray = self
            .data
            .slice(scirs2_core::ndarray::s![
                start_row..start_row + SUB_ROWS,
                start_col..start_col + SUB_COLS
            ])
            .to_owned();

        Ok(TypeSafeMatrix { data: subarray })
    }
}

/// Type-safe component indexing
#[derive(Debug, Clone, Copy)]
pub struct ComponentIndex<const INDEX: usize>;

impl<const INDEX: usize> Default for ComponentIndex<INDEX> {
    fn default() -> Self {
        Self::new()
    }
}

impl<const INDEX: usize> ComponentIndex<INDEX> {
    /// Create a new component index
    pub const fn new() -> Self {
        Self
    }

    /// Get the index value
    pub const fn index(&self) -> usize {
        INDEX
    }
}

/// Type-safe component access for fitted decomposition
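///
/// Sketch of indexed access on a fitted PCA (illustrative, not a compiled doctest;
/// `fitted` stands for any `TypeSafePCA<Fitted, RANK>` obtained from `fit`):
///
/// ```ignore
/// let first = fitted.component(ComponentIndex::<0>::new())?;
/// // An index >= RANK is reported as an InvalidParameter error at runtime.
/// ```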
pub trait ComponentAccess<const RANK: usize> {
    /// Get a specific component by index (runtime checked)
    fn component<const INDEX: usize>(&self, _index: ComponentIndex<INDEX>)
        -> Result<Array1<Float>>;
}

impl<const RANK: usize> ComponentAccess<RANK> for TypeSafePCA<Fitted, RANK> {
    fn component<const INDEX: usize>(
        &self,
        _index: ComponentIndex<INDEX>,
    ) -> Result<Array1<Float>> {
        if INDEX >= RANK {
            return Err(SklearsError::InvalidParameter {
                name: "component_index".to_string(),
                reason: format!("Component index {INDEX} exceeds number of components {RANK}"),
            });
        }

        let components = self.components();
        Ok(components.column(INDEX).to_owned())
    }
}

/// Zero-cost decomposition pipeline builder
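///
/// Operations are applied in the order they are added. Sketch (illustrative, not a
/// compiled doctest; `data` and `new_data` stand for any `Array2<Float>`):
///
/// ```ignore
/// let pipeline = DecompositionPipeline::new().center().scale();
/// let preprocessed = pipeline.apply(&data)?;       // Untrained: ad-hoc application
/// let fitted = pipeline.fit(&data)?;               // transition to Fitted
/// let transformed = fitted.transform(&new_data)?;  // reuse on new data
/// ```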
pub struct DecompositionPipeline<State: DecompositionState> {
    operations: Vec<Box<dyn DecompositionOperation>>,
    _state: PhantomData<State>,
}

/// Trait for decomposition operations in the pipeline
pub trait DecompositionOperation {
    fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>>;
    fn name(&self) -> &str;
}

/// Centering operation
#[derive(Debug, Clone)]
pub struct CenteringOperation {
    #[allow(dead_code)]
    mean: Option<Array1<Float>>,
}

impl Default for CenteringOperation {
    fn default() -> Self {
        Self::new()
    }
}

impl CenteringOperation {
    pub fn new() -> Self {
        Self { mean: None }
    }
}

impl DecompositionOperation for CenteringOperation {
    fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
        let mean = data.mean_axis(Axis(0)).unwrap();
        let mut centered = data.clone();
        for mut row in centered.axis_iter_mut(Axis(0)) {
            row -= &mean;
        }
        Ok(centered)
    }

    fn name(&self) -> &str {
        "centering"
    }
}

/// Scaling operation
#[derive(Debug, Clone)]
pub struct ScalingOperation {
    #[allow(dead_code)]
    scale: Option<Array1<Float>>,
}

impl Default for ScalingOperation {
    fn default() -> Self {
        Self::new()
    }
}

impl ScalingOperation {
    pub fn new() -> Self {
        Self { scale: None }
    }
}

impl DecompositionOperation for ScalingOperation {
    fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
        let std = data.std_axis(Axis(0), 0.0);
        let mut scaled = data.clone();
        for mut row in scaled.axis_iter_mut(Axis(0)) {
            for (i, val) in row.iter_mut().enumerate() {
                if std[i] != 0.0 {
                    *val /= std[i];
                }
            }
        }
        Ok(scaled)
    }

    fn name(&self) -> &str {
        "scaling"
    }
}

impl Default for DecompositionPipeline<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl DecompositionPipeline<Untrained> {
    /// Create a new decomposition pipeline
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            _state: PhantomData,
        }
    }

    /// Add a centering operation to the pipeline
    pub fn center(mut self) -> Self {
        self.operations.push(Box::new(CenteringOperation::new()));
        self
    }

    /// Add a scaling operation to the pipeline
    pub fn scale(mut self) -> Self {
        self.operations.push(Box::new(ScalingOperation::new()));
        self
    }

    /// Apply the pipeline to data
    pub fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
        let mut result = data.clone();
        for operation in &self.operations {
            result = operation.apply(&result)?;
        }
        Ok(result)
    }

    /// Fit the pipeline and transition to the fitted state
    pub fn fit(self, _data: &Array2<Float>) -> Result<DecompositionPipeline<Fitted>> {
        // In a real implementation, the fitted parameters (means, scale factors)
        // would be computed here and stored on the operations
        Ok(DecompositionPipeline {
            operations: self.operations,
            _state: PhantomData,
        })
    }
}

impl DecompositionPipeline<Fitted> {
    /// Apply the fitted pipeline to new data
    pub fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
        let mut result = data.clone();
        for operation in &self.operations {
            result = operation.apply(&result)?;
        }
        Ok(result)
    }
}

/// Runtime matrix shape validation for matrix multiplication
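///
/// With `TypeSafeMatrix`, incompatible shapes are normally rejected at compile time
/// by `dot`; this helper re-checks the const dimensions at runtime and reports a
/// mismatch as an `InvalidParameter` error. Sketch (illustrative, not a compiled doctest):
///
/// ```ignore
/// let a: TypeSafeMatrix<2, 3> = TypeSafeMatrix::zeros();
/// let b: TypeSafeMatrix<3, 2> = TypeSafeMatrix::zeros();
/// assert!(validate_matrix_multiplication(&a, &b).is_ok()); // 2x3 · 3x2 is valid
/// ```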
pub fn validate_matrix_multiplication<
    const A_ROWS: usize,
    const A_COLS: usize,
    const B_ROWS: usize,
    const B_COLS: usize,
>(
    _a: &TypeSafeMatrix<A_ROWS, A_COLS>,
    _b: &TypeSafeMatrix<B_ROWS, B_COLS>,
) -> Result<()> {
    if A_COLS != B_ROWS {
        return Err(SklearsError::InvalidParameter {
            name: "matrix_multiplication".to_string(),
            reason: format!(
                "Cannot multiply {A_ROWS}x{A_COLS} matrix with {B_ROWS}x{B_COLS} matrix"
            ),
        });
    }
    Ok(())
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_type_safe_pca_creation() {
        let pca: TypeSafePCA<Untrained, 2> = TypeSafePCA::new();
        assert_eq!(pca.n_components, 2);
        assert!(pca.center);
        assert!(!pca.scale);
    }

    #[test]
    fn test_type_safe_pca_fit() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];

        let pca: TypeSafePCA<Untrained, 2> = TypeSafePCA::new();
        let fitted_pca = pca.fit(&data).unwrap();

        assert_eq!(fitted_pca.components().dim(), (3, 2));
        assert_eq!(fitted_pca.explained_variance().len(), 2);
    }

    #[test]
    fn test_type_safe_pca_transform() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let pca: TypeSafePCA<Untrained, 2> = TypeSafePCA::new();
        let fitted_pca = pca.fit(&data).unwrap();
        let transformed = fitted_pca.transform(&data).unwrap();

        assert_eq!(transformed.dim(), (3, 2));
    }

    #[test]
    fn test_type_safe_pca_rank_validation() {
        let data = array![
            [1.0, 2.0], // Only 2 features
            [3.0, 4.0],
        ];

        // This should fail because RANK=3 > n_features=2
        let pca: TypeSafePCA<Untrained, 3> = TypeSafePCA::new();
        let result = pca.fit(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_type_safe_matrix_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let matrix: TypeSafeMatrix<2, 2> = TypeSafeMatrix::new(data).unwrap();
        assert_eq!(matrix.data().dim(), (2, 2));
    }

    #[test]
    fn test_type_safe_matrix_dimension_validation() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2x3 matrix
        let result: Result<TypeSafeMatrix<3, 3>> = TypeSafeMatrix::new(data);
        assert!(result.is_err());
    }

    #[test]
    fn test_type_safe_matrix_multiplication() {
        let a_data = array![[1.0, 2.0], [3.0, 4.0]];
        let b_data = array![[5.0, 6.0], [7.0, 8.0]];

        let a: TypeSafeMatrix<2, 2> = TypeSafeMatrix::new(a_data).unwrap();
        let b: TypeSafeMatrix<2, 2> = TypeSafeMatrix::new(b_data).unwrap();

        let result: TypeSafeMatrix<2, 2> = a.dot(&b);
        assert_eq!(result.data().dim(), (2, 2));
    }

    #[test]
    fn test_type_safe_matrix_transpose() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let matrix: TypeSafeMatrix<2, 3> = TypeSafeMatrix::new(data).unwrap();
        let transposed: TypeSafeMatrix<3, 2> = matrix.t();
        assert_eq!(transposed.data().dim(), (3, 2));
    }

    #[test]
    fn test_component_index() {
        let index: ComponentIndex<0> = ComponentIndex::new();
        assert_eq!(index.index(), 0);

        let index: ComponentIndex<5> = ComponentIndex::new();
        assert_eq!(index.index(), 5);
    }

    #[test]
    fn test_decomposition_pipeline() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let pipeline = DecompositionPipeline::new().center().scale();

        let processed = pipeline.apply(&data).unwrap();
        assert_eq!(processed.dim(), data.dim());

        let fitted_pipeline = pipeline.fit(&data).unwrap();
        let transformed = fitted_pipeline.transform(&data).unwrap();
        assert_eq!(transformed.dim(), data.dim());
    }

    #[test]
    fn test_matrix_shape_validation() {
        let a_data = array![[1.0, 2.0], [3.0, 4.0]];
        let b_data = array![[5.0, 6.0], [7.0, 8.0]];

        let a: TypeSafeMatrix<2, 2> = TypeSafeMatrix::new(a_data).unwrap();
        let b: TypeSafeMatrix<2, 2> = TypeSafeMatrix::new(b_data).unwrap();

        // This should compile and succeed
        let result = validate_matrix_multiplication(&a, &b);
        assert!(result.is_ok());
    }
}