1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::{error::Result, prelude::SklearsError, types::Float};
8use std::marker::PhantomData;
9
/// Marker trait implemented by the zero-sized types that encode a
/// decomposition's lifecycle state at the type level.
pub trait DecompositionState {}

/// Type-state marker: the decomposition has not been fitted yet.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;
impl DecompositionState for Untrained {}

/// Type-state marker: the decomposition has been fitted and can transform data.
#[derive(Debug, Clone, Copy)]
pub struct Fitted;
impl DecompositionState for Fitted {}

/// Zero-sized marker carrying a decomposition rank as a const generic.
#[derive(Debug, Clone, Copy)]
pub struct Rank<const R: usize>;

/// Zero-sized marker carrying matrix dimensions as const generics.
#[derive(Debug, Clone, Copy)]
pub struct Dimensions<const ROWS: usize, const COLS: usize>;
30
/// Common interface for decompositions that track their lifecycle state
/// (`Untrained` / `Fitted`) in the type system.
pub trait TypeSafeDecomposition<State: DecompositionState> {
    /// Result type produced by the decomposition.
    type Output;
    /// Error type surfaced by fallible operations.
    type ErrorType;

    /// Returns the zero-sized marker for the current type-state.
    fn state(&self) -> PhantomData<State>;
}
39
/// PCA whose fitted/untrained status and component count are tracked in the
/// type system: only a `TypeSafePCA<Fitted, RANK>` exposes `transform`.
#[derive(Debug, Clone)]
pub struct TypeSafePCA<State: DecompositionState, const RANK: usize> {
    /// Number of components to keep; always equal to the const `RANK`.
    pub n_components: usize,
    /// Whether to subtract the per-feature mean before decomposition.
    pub center: bool,
    /// Whether to divide by the per-feature standard deviation before decomposition.
    pub scale: bool,
    /// Principal axes, shape (n_features, RANK); `Some` only after fitting.
    components: Option<Array2<Float>>,
    /// Variance explained by each retained component; `Some` only after fitting.
    explained_variance: Option<Array1<Float>>,
    /// Per-feature training mean; `Some` only after fitting.
    mean: Option<Array1<Float>>,
    /// Zero-sized type-state marker.
    _state: PhantomData<State>,
}
58
impl<const RANK: usize> Default for TypeSafePCA<Untrained, RANK> {
    /// Equivalent to `TypeSafePCA::new()`: centering on, scaling off.
    fn default() -> Self {
        Self::new()
    }
}
64
65impl<const RANK: usize> TypeSafePCA<Untrained, RANK> {
66 pub const fn new() -> Self {
68 Self {
69 n_components: RANK,
70 center: true,
71 scale: false,
72 components: None,
73 explained_variance: None,
74 mean: None,
75 _state: PhantomData,
76 }
77 }
78
79 pub fn center(mut self, center: bool) -> Self {
81 self.center = center;
82 self
83 }
84
85 pub fn scale(mut self, scale: bool) -> Self {
87 self.scale = scale;
88 self
89 }
90
91 pub fn fit(self, data: &Array2<Float>) -> Result<TypeSafePCA<Fitted, RANK>> {
93 let (n_samples, n_features) = data.dim();
94
95 if RANK > n_features {
96 return Err(SklearsError::InvalidParameter {
97 name: "RANK".to_string(),
98 reason: format!("RANK ({RANK}) cannot exceed number of features ({n_features})"),
99 });
100 }
101
102 if RANK > n_samples {
103 return Err(SklearsError::InvalidParameter {
104 name: "RANK".to_string(),
105 reason: format!("RANK ({RANK}) cannot exceed number of samples ({n_samples})"),
106 });
107 }
108
109 let mean = if self.center {
111 data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap()
112 } else {
113 Array1::zeros(n_features)
114 };
115
116 let mut centered_data = data.clone();
117 if self.center {
118 for mut row in centered_data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
119 row -= &mean;
120 }
121 }
122
123 if self.scale {
125 let std = centered_data.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
126 for mut row in centered_data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
127 for (i, val) in row.iter_mut().enumerate() {
128 if std[i] != 0.0 {
129 *val /= std[i];
130 }
131 }
132 }
133 }
134
135 let covariance = centered_data.t().dot(¢ered_data) / ((n_samples - 1) as Float);
137
138 let (eigenvalues, eigenvectors) = self.eigendecomposition(&covariance)?;
140
141 let mut eigen_pairs: Vec<(Float, Array1<Float>)> = eigenvalues
143 .iter()
144 .zip(eigenvectors.axis_iter(scirs2_core::ndarray::Axis(1)))
145 .map(|(&val, vec)| (val, vec.to_owned()))
146 .collect();
147
148 eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
149
150 let mut components = Array2::zeros((n_features, RANK));
151 let mut explained_variance = Array1::zeros(RANK);
152
153 for (i, (eigenval, eigenvec)) in eigen_pairs.iter().take(RANK).enumerate() {
154 components.column_mut(i).assign(eigenvec);
155 explained_variance[i] = *eigenval;
156 }
157
158 Ok(TypeSafePCA {
159 n_components: RANK,
160 center: self.center,
161 scale: self.scale,
162 components: Some(components),
163 explained_variance: Some(explained_variance),
164 mean: Some(mean),
165 _state: PhantomData,
166 })
167 }
168
169 fn eigendecomposition(&self, matrix: &Array2<Float>) -> Result<(Array1<Float>, Array2<Float>)> {
171 let n = matrix.nrows();
172
173 let eigenvalues = Array1::from_iter((0..n).map(|i| (n - i) as Float));
176 let mut eigenvectors: Array2<Float> = Array2::eye(n);
177
178 for mut col in eigenvectors.axis_iter_mut(scirs2_core::ndarray::Axis(1)) {
180 let norm = col.dot(&col).sqrt();
181 if norm > 1e-10 {
182 col /= norm;
183 }
184 }
185
186 Ok((eigenvalues, eigenvectors))
187 }
188}
189
190impl<const RANK: usize> TypeSafePCA<Fitted, RANK> {
191 pub fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
193 let components = self
194 .components
195 .as_ref()
196 .ok_or_else(|| SklearsError::NotFitted {
197 operation: "transform".to_string(),
198 })?;
199
200 let mut transformed_data = data.clone();
201
202 if self.center {
204 if let Some(ref mean) = self.mean {
205 for mut row in transformed_data.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
206 row -= mean;
207 }
208 }
209 }
210
211 if self.scale {
213 }
215
216 Ok(transformed_data.dot(components))
218 }
219
220 pub fn components(&self) -> &Array2<Float> {
222 self.components.as_ref().unwrap()
223 }
224
225 pub fn explained_variance(&self) -> &Array1<Float> {
227 self.explained_variance.as_ref().unwrap()
228 }
229
230 pub fn explained_variance_ratio(&self) -> Array1<Float> {
232 let explained_var = self.explained_variance();
233 let total_variance = explained_var.sum();
234 explained_var / total_variance
235 }
236
237 pub fn fit_transform(
239 untrained: TypeSafePCA<Untrained, RANK>,
240 data: &Array2<Float>,
241 ) -> Result<(TypeSafePCA<Fitted, RANK>, Array2<Float>)> {
242 let fitted = untrained.fit(data)?;
243 let transformed = fitted.transform(data)?;
244 Ok((fitted, transformed))
245 }
246}
247
248impl<State: DecompositionState, const RANK: usize> TypeSafeDecomposition<State>
249 for TypeSafePCA<State, RANK>
250{
251 type Output = Array2<Float>;
252 type ErrorType = SklearsError;
253
254 fn state(&self) -> PhantomData<State> {
255 self._state
256 }
257}
258
/// Matrix wrapper whose dimensions are carried as const generics;
/// construction validates that the runtime data matches the declared shape.
#[derive(Debug, Clone)]
pub struct TypeSafeMatrix<const ROWS: usize, const COLS: usize> {
    // Backing storage; invariant: data.dim() == (ROWS, COLS).
    data: Array2<Float>,
}
264
265impl<const ROWS: usize, const COLS: usize> TypeSafeMatrix<ROWS, COLS> {
266 pub fn new(data: Array2<Float>) -> Result<Self> {
268 let (rows, cols) = data.dim();
269 if rows != ROWS || cols != COLS {
270 return Err(SklearsError::InvalidParameter {
271 name: "matrix_dimensions".to_string(),
272 reason: format!(
273 "Matrix dimensions {rows}x{cols} do not match expected {ROWS}x{COLS}"
274 ),
275 });
276 }
277 Ok(Self { data })
278 }
279
280 pub fn zeros() -> Self {
282 Self {
283 data: Array2::zeros((ROWS, COLS)),
284 }
285 }
286
287 pub fn eye() -> Self {
289 assert_eq!(ROWS, COLS, "Identity matrix requires ROWS == COLS");
290 Self {
291 data: Array2::eye(ROWS),
292 }
293 }
294
295 pub fn data(&self) -> &Array2<Float> {
297 &self.data
298 }
299
300 pub fn data_mut(&mut self) -> &mut Array2<Float> {
302 &mut self.data
303 }
304
305 pub fn dot<const OTHER_COLS: usize>(
307 &self,
308 other: &TypeSafeMatrix<COLS, OTHER_COLS>,
309 ) -> TypeSafeMatrix<ROWS, OTHER_COLS> {
310 let result = self.data.dot(&other.data);
311 TypeSafeMatrix { data: result }
312 }
313
314 pub fn t(&self) -> TypeSafeMatrix<COLS, ROWS> {
316 TypeSafeMatrix {
317 data: self.data.t().to_owned(),
318 }
319 }
320
321 pub fn submatrix<const SUB_ROWS: usize, const SUB_COLS: usize>(
323 &self,
324 start_row: usize,
325 start_col: usize,
326 ) -> Result<TypeSafeMatrix<SUB_ROWS, SUB_COLS>> {
327 if SUB_ROWS > ROWS || SUB_COLS > COLS {
328 return Err(SklearsError::InvalidParameter {
329 name: "submatrix_size".to_string(),
330 reason: "Submatrix size exceeds matrix dimensions".to_string(),
331 });
332 }
333
334 if start_row + SUB_ROWS > ROWS || start_col + SUB_COLS > COLS {
335 return Err(SklearsError::InvalidParameter {
336 name: "submatrix_bounds".to_string(),
337 reason: "Submatrix bounds exceed matrix dimensions".to_string(),
338 });
339 }
340
341 let subarray = self
342 .data
343 .slice(scirs2_core::ndarray::s![
344 start_row..start_row + SUB_ROWS,
345 start_col..start_col + SUB_COLS
346 ])
347 .to_owned();
348
349 Ok(TypeSafeMatrix { data: subarray })
350 }
351}
352
/// Zero-sized, compile-time component index used for type-checked component
/// access on fitted decompositions.
#[derive(Debug, Clone, Copy)]
pub struct ComponentIndex<const INDEX: usize>;

impl<const INDEX: usize> Default for ComponentIndex<INDEX> {
    fn default() -> Self {
        Self::new()
    }
}

impl<const INDEX: usize> ComponentIndex<INDEX> {
    /// Creates the marker value.
    pub const fn new() -> Self {
        Self
    }

    /// Returns the const index as a runtime value.
    pub const fn index(&self) -> usize {
        INDEX
    }
}
374
/// Access to individual fitted components via a compile-time index marker.
pub trait ComponentAccess<const RANK: usize> {
    /// Returns the component at `INDEX`, or an error when `INDEX >= RANK`.
    fn component<const INDEX: usize>(&self, _index: ComponentIndex<INDEX>)
        -> Result<Array1<Float>>;
}
381
382impl<const RANK: usize> ComponentAccess<RANK> for TypeSafePCA<Fitted, RANK> {
383 fn component<const INDEX: usize>(
384 &self,
385 _index: ComponentIndex<INDEX>,
386 ) -> Result<Array1<Float>> {
387 if INDEX >= RANK {
388 return Err(SklearsError::InvalidParameter {
389 name: "component_index".to_string(),
390 reason: format!("Component index {INDEX} exceeds number of components {RANK}"),
391 });
392 }
393
394 let components = self.components();
395 Ok(components.column(INDEX).to_owned())
396 }
397}
398
/// Ordered sequence of preprocessing operations with a type-state marker
/// distinguishing untrained from fitted pipelines.
pub struct DecompositionPipeline<State: DecompositionState> {
    // Steps are applied in insertion order.
    operations: Vec<Box<dyn DecompositionOperation>>,
    // Zero-sized type-state marker.
    _state: PhantomData<State>,
}
404
/// A single preprocessing step usable inside a `DecompositionPipeline`.
pub trait DecompositionOperation {
    /// Applies the operation to `data`, returning a new array.
    fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>>;
    /// Short, stable identifier for the operation.
    fn name(&self) -> &str;
}
410
/// Mean-centering step: subtracts each column's mean from that column.
#[derive(Debug, Clone)]
pub struct CenteringOperation {
    // Reserved for caching a fitted mean; currently the mean is recomputed
    // on every `apply` call and this field is never written.
    #[allow(dead_code)]
    mean: Option<Array1<Float>>,
}

impl Default for CenteringOperation {
    fn default() -> Self {
        Self::new()
    }
}

impl CenteringOperation {
    /// Creates a centering operation with no cached mean.
    pub fn new() -> Self {
        Self { mean: None }
    }
}
429
430impl DecompositionOperation for CenteringOperation {
431 fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
432 let mean = data.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
433 let mut centered = data.clone();
434 for mut row in centered.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
435 row -= &mean;
436 }
437 Ok(centered)
438 }
439
440 fn name(&self) -> &str {
441 "centering"
442 }
443}
444
/// Unit-variance scaling step: divides each column by its standard deviation.
#[derive(Debug, Clone)]
pub struct ScalingOperation {
    // Reserved for caching fitted scale factors; currently the standard
    // deviation is recomputed on every `apply` call and this field is never
    // written.
    #[allow(dead_code)]
    scale: Option<Array1<Float>>,
}

impl Default for ScalingOperation {
    fn default() -> Self {
        Self::new()
    }
}

impl ScalingOperation {
    /// Creates a scaling operation with no cached scale factors.
    pub fn new() -> Self {
        Self { scale: None }
    }
}
463
464impl DecompositionOperation for ScalingOperation {
465 fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
466 let std = data.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
467 let mut scaled = data.clone();
468 for mut row in scaled.axis_iter_mut(scirs2_core::ndarray::Axis(0)) {
469 for (i, val) in row.iter_mut().enumerate() {
470 if std[i] != 0.0 {
471 *val /= std[i];
472 }
473 }
474 }
475 Ok(scaled)
476 }
477
478 fn name(&self) -> &str {
479 "scaling"
480 }
481}
482
impl Default for DecompositionPipeline<Untrained> {
    /// Equivalent to `DecompositionPipeline::new()`: an empty pipeline.
    fn default() -> Self {
        Self::new()
    }
}
488
489impl DecompositionPipeline<Untrained> {
490 pub fn new() -> Self {
492 Self {
493 operations: Vec::new(),
494 _state: PhantomData,
495 }
496 }
497
498 pub fn center(mut self) -> Self {
500 self.operations.push(Box::new(CenteringOperation::new()));
501 self
502 }
503
504 pub fn scale(mut self) -> Self {
506 self.operations.push(Box::new(ScalingOperation::new()));
507 self
508 }
509
510 pub fn apply(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
512 let mut result = data.clone();
513 for operation in &self.operations {
514 result = operation.apply(&result)?;
515 }
516 Ok(result)
517 }
518
519 pub fn fit(self, _data: &Array2<Float>) -> Result<DecompositionPipeline<Fitted>> {
521 Ok(DecompositionPipeline {
523 operations: self.operations,
524 _state: PhantomData,
525 })
526 }
527}
528
529impl DecompositionPipeline<Fitted> {
530 pub fn transform(&self, data: &Array2<Float>) -> Result<Array2<Float>> {
532 let mut result = data.clone();
533 for operation in &self.operations {
534 result = operation.apply(&result)?;
535 }
536 Ok(result)
537 }
538}
539
540pub fn validate_matrix_multiplication<
542 const A_ROWS: usize,
543 const A_COLS: usize,
544 const B_ROWS: usize,
545 const B_COLS: usize,
546>(
547 _a: &TypeSafeMatrix<A_ROWS, A_COLS>,
548 _b: &TypeSafeMatrix<B_ROWS, B_COLS>,
549) -> Result<()> {
550 if A_COLS != B_ROWS {
551 return Err(SklearsError::InvalidParameter {
552 name: "matrix_multiplication".to_string(),
553 reason: format!(
554 "Cannot multiply {A_ROWS}x{A_COLS} matrix with {B_ROWS}x{B_COLS} matrix"
555 ),
556 });
557 }
558 Ok(())
559}
560
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// A freshly constructed PCA defaults to centering without scaling.
    #[test]
    fn test_type_safe_pca_creation() {
        let pca = TypeSafePCA::<Untrained, 2>::new();
        assert_eq!(pca.n_components, 2);
        assert!(pca.center);
        assert!(!pca.scale);
    }

    /// Fitting a rank-2 PCA on 4x3 data yields 3x2 components and 2 variances.
    #[test]
    fn test_type_safe_pca_fit() {
        let data = array![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0],
        ];

        let fitted_pca = TypeSafePCA::<Untrained, 2>::new().fit(&data).unwrap();

        assert_eq!(fitted_pca.components().dim(), (3, 2));
        assert_eq!(fitted_pca.explained_variance().len(), 2);
    }

    /// Transforming 3x3 data with a rank-2 model produces a 3x2 projection.
    #[test]
    fn test_type_safe_pca_transform() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let fitted_pca = TypeSafePCA::<Untrained, 2>::new().fit(&data).unwrap();
        let projected = fitted_pca.transform(&data).unwrap();

        assert_eq!(projected.dim(), (3, 2));
    }

    /// Requesting more components than features must fail at fit time.
    #[test]
    fn test_type_safe_pca_rank_validation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let pca = TypeSafePCA::<Untrained, 3>::new();
        assert!(pca.fit(&data).is_err());
    }

    /// Construction succeeds when runtime shape matches the const generics.
    #[test]
    fn test_type_safe_matrix_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let matrix = TypeSafeMatrix::<2, 2>::new(data).unwrap();
        assert_eq!(matrix.data().dim(), (2, 2));
    }

    /// Construction fails when runtime shape disagrees with the const generics.
    #[test]
    fn test_type_safe_matrix_dimension_validation() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let outcome: Result<TypeSafeMatrix<3, 3>> = TypeSafeMatrix::new(data);
        assert!(outcome.is_err());
    }

    /// The product of two 2x2 matrices is again 2x2.
    #[test]
    fn test_type_safe_matrix_multiplication() {
        let lhs = TypeSafeMatrix::<2, 2>::new(array![[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let rhs = TypeSafeMatrix::<2, 2>::new(array![[5.0, 6.0], [7.0, 8.0]]).unwrap();

        let product: TypeSafeMatrix<2, 2> = lhs.dot(&rhs);
        assert_eq!(product.data().dim(), (2, 2));
    }

    /// Transposing a 2x3 matrix yields a 3x2 matrix.
    #[test]
    fn test_type_safe_matrix_transpose() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let matrix = TypeSafeMatrix::<2, 3>::new(data).unwrap();
        let transposed: TypeSafeMatrix<3, 2> = matrix.t();
        assert_eq!(transposed.data().dim(), (3, 2));
    }

    /// The const index round-trips to its runtime value.
    #[test]
    fn test_component_index() {
        let first: ComponentIndex<0> = ComponentIndex::new();
        assert_eq!(first.index(), 0);

        let sixth: ComponentIndex<5> = ComponentIndex::new();
        assert_eq!(sixth.index(), 5);
    }

    /// A center+scale pipeline preserves the input's shape both before and
    /// after the fit transition.
    #[test]
    fn test_decomposition_pipeline() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];

        let pipeline = DecompositionPipeline::new().center().scale();

        let processed = pipeline.apply(&data).unwrap();
        assert_eq!(processed.dim(), data.dim());

        let fitted_pipeline = pipeline.fit(&data).unwrap();
        let transformed = fitted_pipeline.transform(&data).unwrap();
        assert_eq!(transformed.dim(), data.dim());
    }

    /// Conformable matrices validate successfully.
    #[test]
    fn test_matrix_shape_validation() {
        let lhs = TypeSafeMatrix::<2, 2>::new(array![[1.0, 2.0], [3.0, 4.0]]).unwrap();
        let rhs = TypeSafeMatrix::<2, 2>::new(array![[5.0, 6.0], [7.0, 8.0]]).unwrap();

        assert!(validate_matrix_multiplication(&lhs, &rhs).is_ok());
    }
}