1use scirs2_core::ndarray::{Array1, Array2};
11use sklears_core::prelude::*;
12use std::marker::PhantomData;
13
/// Type-level marker for the lifecycle state of a [`TypeSafeTransformer`].
///
/// The trait is sealed: only [`Unfitted`] and [`Fitted`] implement it.
pub trait TransformState: sealed::Sealed {}

/// State marker: the transformer has not been fitted yet (only `fit` is available).
#[derive(Debug, Clone, Copy)]
pub struct Unfitted;

/// State marker: the transformer has been fitted (only `transform` is available).
#[derive(Debug, Clone, Copy)]
pub struct Fitted;

/// Private supertraits that keep code outside this module from implementing
/// [`TransformState`] or [`Dimension`].
mod sealed {
    pub trait Sealed {}
    impl Sealed for super::Unfitted {}
    impl Sealed for super::Fitted {}

    pub trait DimensionSealed {}
    impl DimensionSealed for super::Dynamic {}
    impl<const N: usize> DimensionSealed for super::Known<N> {}
}

impl TransformState for Unfitted {}
impl TransformState for Fitted {}

/// Dimension marker: the feature count is only known at runtime.
// Debug/Clone/Copy are derived (like the state markers) because
// `#[derive(...)]` on a generic type requires every type parameter to
// implement the derived trait; without these, derived impls on types
// parameterized by a dimension marker would never apply.
#[derive(Debug, Clone, Copy)]
pub struct Dynamic;

/// Dimension marker: the feature count is the compile-time constant `N`.
#[derive(Debug, Clone, Copy)]
pub struct Known<const N: usize>;

/// A type-level feature dimension, either runtime-determined ([`Dynamic`])
/// or fixed at compile time ([`Known`]).
pub trait Dimension: sealed::DimensionSealed {
    /// Returns `Some(n)` when the width is fixed at compile time, `None` otherwise.
    fn value() -> Option<usize>;
}

impl Dimension for Dynamic {
    fn value() -> Option<usize> {
        None
    }
}

impl<const N: usize> Dimension for Known<N> {
    fn value() -> Option<usize> {
        Some(N)
    }
}
69
70#[derive(Debug, Clone)]
81pub struct TypeSafeTransformer<S: TransformState, InDim: Dimension, OutDim: Dimension> {
82 config: TypeSafeConfig,
84 input_dim: Option<usize>,
86 output_dim: Option<usize>,
88 parameters: Option<TransformParameters>,
90 _state: PhantomData<S>,
92 _in_dim: PhantomData<InDim>,
94 _out_dim: PhantomData<OutDim>,
96}
97
/// Runtime options for a [`TypeSafeTransformer`].
#[derive(Clone, Debug)]
pub struct TypeSafeConfig {
    /// Presumably toggles dimension validation (default `true`).
    ///
    /// NOTE(review): nothing in this file reads this flag — every `fit` and
    /// `transform` checks column counts unconditionally. Confirm whether it is
    /// consumed elsewhere before relying on it.
    pub validate_dimensions: bool,
    /// When `true`, `fit` learns per-column mean/std and `transform`
    /// standardizes each feature (default `false`).
    pub normalize: bool,
}

impl Default for TypeSafeConfig {
    /// Validation on, normalization off.
    fn default() -> Self {
        let normalize = false;
        let validate_dimensions = true;
        TypeSafeConfig {
            validate_dimensions,
            normalize,
        }
    }
}
115
116#[derive(Debug, Clone)]
118struct TransformParameters {
119 mean: Array1<f64>,
121 std: Array1<f64>,
123}
124
125impl<InDim: Dimension, OutDim: Dimension> TypeSafeTransformer<Unfitted, InDim, OutDim> {
130 pub fn new(config: TypeSafeConfig) -> TypeSafeTransformer<Unfitted, Dynamic, Dynamic> {
132 TypeSafeTransformer {
133 config,
134 input_dim: None,
135 output_dim: None,
136 parameters: None,
137 _state: PhantomData,
138 _in_dim: PhantomData,
139 _out_dim: PhantomData,
140 }
141 }
142
143 pub fn with_input_dim<const N: usize>(
145 config: TypeSafeConfig,
146 ) -> TypeSafeTransformer<Unfitted, Known<N>, Dynamic> {
147 TypeSafeTransformer {
148 config,
149 input_dim: Some(N),
150 output_dim: None,
151 parameters: None,
152 _state: PhantomData,
153 _in_dim: PhantomData,
154 _out_dim: PhantomData,
155 }
156 }
157
158 pub fn with_dimensions<const IN: usize, const OUT: usize>(
160 config: TypeSafeConfig,
161 ) -> TypeSafeTransformer<Unfitted, Known<IN>, Known<OUT>> {
162 TypeSafeTransformer {
163 config,
164 input_dim: Some(IN),
165 output_dim: Some(OUT),
166 parameters: None,
167 _state: PhantomData,
168 _in_dim: PhantomData,
169 _out_dim: PhantomData,
170 }
171 }
172}
173
174impl TypeSafeTransformer<Unfitted, Dynamic, Dynamic> {
176 pub fn fit(self, X: &Array2<f64>) -> Result<TypeSafeTransformer<Fitted, Dynamic, Dynamic>> {
178 let input_dim = X.ncols();
179 let output_dim = X.ncols(); let parameters = if self.config.normalize {
182 let mean = X
183 .mean_axis(scirs2_core::ndarray::Axis(0))
184 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
185 let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
186 Some(TransformParameters { mean, std })
187 } else {
188 None
189 };
190
191 Ok(TypeSafeTransformer {
192 config: self.config,
193 input_dim: Some(input_dim),
194 output_dim: Some(output_dim),
195 parameters,
196 _state: PhantomData,
197 _in_dim: PhantomData,
198 _out_dim: PhantomData,
199 })
200 }
201}
202
203impl<const N: usize> TypeSafeTransformer<Unfitted, Known<N>, Dynamic> {
205 pub fn fit(self, X: &Array2<f64>) -> Result<TypeSafeTransformer<Fitted, Known<N>, Dynamic>> {
207 if X.ncols() != N {
208 return Err(SklearsError::InvalidInput(format!(
209 "Expected {} input features, got {}",
210 N,
211 X.ncols()
212 )));
213 }
214
215 let output_dim = X.ncols();
216
217 let parameters = if self.config.normalize {
218 let mean = X
219 .mean_axis(scirs2_core::ndarray::Axis(0))
220 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
221 let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
222 Some(TransformParameters { mean, std })
223 } else {
224 None
225 };
226
227 Ok(TypeSafeTransformer {
228 config: self.config,
229 input_dim: Some(N),
230 output_dim: Some(output_dim),
231 parameters,
232 _state: PhantomData,
233 _in_dim: PhantomData,
234 _out_dim: PhantomData,
235 })
236 }
237}
238
239impl<const IN: usize, const OUT: usize> TypeSafeTransformer<Unfitted, Known<IN>, Known<OUT>> {
241 pub fn fit(
243 self,
244 X: &Array2<f64>,
245 ) -> Result<TypeSafeTransformer<Fitted, Known<IN>, Known<OUT>>> {
246 if X.ncols() != IN {
247 return Err(SklearsError::InvalidInput(format!(
248 "Expected {} input features, got {}",
249 IN,
250 X.ncols()
251 )));
252 }
253
254 let parameters = if self.config.normalize {
255 let mean = X
256 .mean_axis(scirs2_core::ndarray::Axis(0))
257 .ok_or_else(|| SklearsError::InvalidInput("Failed to compute mean".to_string()))?;
258 let std = X.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
259 Some(TransformParameters { mean, std })
260 } else {
261 None
262 };
263
264 Ok(TypeSafeTransformer {
265 config: self.config,
266 input_dim: Some(IN),
267 output_dim: Some(OUT),
268 parameters,
269 _state: PhantomData,
270 _in_dim: PhantomData,
271 _out_dim: PhantomData,
272 })
273 }
274}
275
276impl TypeSafeTransformer<Fitted, Dynamic, Dynamic> {
282 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
284 if let Some(input_dim) = self.input_dim {
285 if X.ncols() != input_dim {
286 return Err(SklearsError::InvalidInput(format!(
287 "Expected {} input features, got {}",
288 input_dim,
289 X.ncols()
290 )));
291 }
292 }
293
294 let mut result = X.clone();
295
296 if let Some(ref params) = self.parameters {
297 for i in 0..result.nrows() {
298 for j in 0..result.ncols() {
299 result[[i, j]] = (result[[i, j]] - params.mean[j]) / params.std[j].max(1e-10);
300 }
301 }
302 }
303
304 Ok(result)
305 }
306}
307
308impl<const N: usize> TypeSafeTransformer<Fitted, Known<N>, Dynamic> {
310 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
312 if X.ncols() != N {
313 return Err(SklearsError::InvalidInput(format!(
314 "Expected {} input features, got {}",
315 N,
316 X.ncols()
317 )));
318 }
319
320 let mut result = X.clone();
321
322 if let Some(ref params) = self.parameters {
323 for i in 0..result.nrows() {
324 for j in 0..result.ncols() {
325 result[[i, j]] = (result[[i, j]] - params.mean[j]) / params.std[j].max(1e-10);
326 }
327 }
328 }
329
330 Ok(result)
331 }
332}
333
334impl<const IN: usize, const OUT: usize> TypeSafeTransformer<Fitted, Known<IN>, Known<OUT>> {
336 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
338 if X.ncols() != IN {
339 return Err(SklearsError::InvalidInput(format!(
340 "Expected {} input features, got {}",
341 IN,
342 X.ncols()
343 )));
344 }
345
346 let mut result = X.clone();
347
348 if let Some(ref params) = self.parameters {
349 for i in 0..result.nrows() {
350 for j in 0..result.ncols() {
351 result[[i, j]] = (result[[i, j]] - params.mean[j]) / params.std[j].max(1e-10);
352 }
353 }
354 }
355
356 if result.ncols() != OUT {
358 return Err(SklearsError::InvalidInput(format!(
359 "Expected {} output features, got {}",
360 OUT,
361 result.ncols()
362 )));
363 }
364
365 Ok(result)
366 }
367}
368
369pub struct TypeSafePipeline<S1, S2, D1, D2, D3>
375where
376 S1: TransformState,
377 S2: TransformState,
378 D1: Dimension,
379 D2: Dimension,
380 D3: Dimension,
381{
382 first: TypeSafeTransformer<S1, D1, D2>,
384 second: TypeSafeTransformer<S2, D2, D3>,
386}
387
388impl<D1: Dimension, D2: Dimension, D3: Dimension> TypeSafePipeline<Unfitted, Unfitted, D1, D2, D3> {
390 pub fn new(
392 first: TypeSafeTransformer<Unfitted, D1, D2>,
393 second: TypeSafeTransformer<Unfitted, D2, D3>,
394 ) -> Self {
395 Self { first, second }
396 }
397}
398
399impl TypeSafePipeline<Unfitted, Unfitted, Dynamic, Dynamic, Dynamic> {
401 pub fn fit(
403 self,
404 X: &Array2<f64>,
405 ) -> Result<TypeSafePipeline<Fitted, Fitted, Dynamic, Dynamic, Dynamic>> {
406 let first_fitted = self.first.fit(X)?;
407 let X_transformed = first_fitted.transform(X)?;
408 let second_fitted = self.second.fit(&X_transformed)?;
409
410 Ok(TypeSafePipeline {
411 first: first_fitted,
412 second: second_fitted,
413 })
414 }
415}
416
417impl<const D1: usize, const D2: usize, const D3: usize>
419 TypeSafePipeline<Unfitted, Unfitted, Known<D1>, Known<D2>, Known<D3>>
420{
421 pub fn fit(
423 self,
424 X: &Array2<f64>,
425 ) -> Result<TypeSafePipeline<Fitted, Fitted, Known<D1>, Known<D2>, Known<D3>>> {
426 let first_fitted = self.first.fit(X)?;
427 let X_transformed = first_fitted.transform(X)?;
428 let second_fitted = self.second.fit(&X_transformed)?;
429
430 Ok(TypeSafePipeline {
431 first: first_fitted,
432 second: second_fitted,
433 })
434 }
435}
436
437impl TypeSafePipeline<Fitted, Fitted, Dynamic, Dynamic, Dynamic> {
439 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
441 let X_intermediate = self.first.transform(X)?;
442 self.second.transform(&X_intermediate)
443 }
444}
445
446impl<const D1: usize, const D2: usize, const D3: usize>
448 TypeSafePipeline<Fitted, Fitted, Known<D1>, Known<D2>, Known<D3>>
449{
450 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
452 let X_intermediate = self.first.transform(X)?;
453 self.second.transform(&X_intermediate)
454 }
455}
456
/// Unit tests for the type-state transformer and pipeline.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_dynamic_dimensions() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let config = TypeSafeConfig::default();
        let transformer: TypeSafeTransformer<Unfitted, Dynamic, Dynamic> =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(config);
        let fitted = transformer.fit(&X).unwrap();
        let result = fitted.transform(&X).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_dynamic_transform_rejects_wrong_width() {
        let train = array![[1.0, 2.0], [3.0, 4.0]];
        let wide = array![[1.0, 2.0, 3.0]];

        let fitted =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(TypeSafeConfig::default())
                .fit(&train)
                .unwrap();

        assert!(fitted.transform(&wide).is_err());
    }

    #[test]
    fn test_known_input_dimension() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let config = TypeSafeConfig::default();
        let transformer: TypeSafeTransformer<Unfitted, Known<2>, Dynamic> =
            TypeSafeTransformer::<Unfitted, Known<2>, Dynamic>::with_input_dim(config);
        let fitted = transformer.fit(&X).unwrap();
        let result = fitted.transform(&X).unwrap();

        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_known_input_dimension_mismatch() {
        let X = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let config = TypeSafeConfig::default();
        let transformer: TypeSafeTransformer<Unfitted, Known<2>, Dynamic> =
            TypeSafeTransformer::<Unfitted, Known<2>, Dynamic>::with_input_dim(config);

        assert!(transformer.fit(&X).is_err());
    }

    #[test]
    fn test_known_dimensions() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let config = TypeSafeConfig::default();
        let transformer: TypeSafeTransformer<Unfitted, Known<2>, Known<2>> =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<2>>::with_dimensions(config);
        let fitted = transformer.fit(&X).unwrap();
        let result = fitted.transform(&X).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_known_output_dimension_mismatch_is_rejected() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        let transformer: TypeSafeTransformer<Unfitted, Known<2>, Known<3>> =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<3>>::with_dimensions(
                TypeSafeConfig::default(),
            );

        // Transformation never changes the column count, so a 2 -> 3 transformer
        // must fail no later than its first transform.
        let outcome = transformer.fit(&X).and_then(|fitted| fitted.transform(&X));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_normalization() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let config = TypeSafeConfig {
            validate_dimensions: true,
            normalize: true,
        };
        let transformer: TypeSafeTransformer<Unfitted, Dynamic, Dynamic> =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(config);
        let fitted = transformer.fit(&X).unwrap();
        let result = fitted.transform(&X).unwrap();

        let mean = result.mean_axis(scirs2_core::ndarray::Axis(0)).unwrap();
        for &val in mean.iter() {
            assert!(val.abs() < 1e-10);
        }

        // `fit` stores the population std (ddof = 0), so normalized columns have
        // unit population std.
        let std = result.std_axis(scirs2_core::ndarray::Axis(0), 0.0);
        for &val in std.iter() {
            assert!((val - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_constant_feature_normalizes_to_zero() {
        let X = array![[1.0, 7.0], [3.0, 7.0], [5.0, 7.0]];

        let config = TypeSafeConfig {
            validate_dimensions: true,
            normalize: true,
        };
        let fitted = TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(config)
            .fit(&X)
            .unwrap();
        let result = fitted.transform(&X).unwrap();

        // Zero variance is clamped to 1e-10, so the column becomes exactly zero, never NaN.
        for &val in result.column(1).iter() {
            assert!(val.abs() < 1e-12);
        }
    }

    #[test]
    fn test_pipeline_dynamic() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let config1 = TypeSafeConfig {
            validate_dimensions: true,
            normalize: true,
        };
        let config2 = TypeSafeConfig::default();

        let transformer1: TypeSafeTransformer<Unfitted, Dynamic, Dynamic> =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(config1);
        let transformer2: TypeSafeTransformer<Unfitted, Dynamic, Dynamic> =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(config2);

        let pipeline = TypeSafePipeline::new(transformer1, transformer2);
        let fitted_pipeline = pipeline.fit(&X).unwrap();
        let result = fitted_pipeline.transform(&X).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_pipeline_known_dimensions() {
        let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];

        let config1 = TypeSafeConfig::default();
        let config2 = TypeSafeConfig::default();

        let transformer1: TypeSafeTransformer<Unfitted, Known<2>, Known<2>> =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<2>>::with_dimensions(config1);
        let transformer2: TypeSafeTransformer<Unfitted, Known<2>, Known<2>> =
            TypeSafeTransformer::<Unfitted, Known<2>, Known<2>>::with_dimensions(config2);

        let pipeline = TypeSafePipeline::new(transformer1, transformer2);
        let fitted_pipeline = pipeline.fit(&X).unwrap();
        let result = fitted_pipeline.transform(&X).unwrap();

        assert_eq!(result.nrows(), 3);
        assert_eq!(result.ncols(), 2);
    }

    #[test]
    fn test_state_transitions() {
        let X = array![[1.0, 2.0], [3.0, 4.0]];

        let unfitted: TypeSafeTransformer<Unfitted, Dynamic, Dynamic> =
            TypeSafeTransformer::<Unfitted, Dynamic, Dynamic>::new(TypeSafeConfig::default());

        // `fit` consumes the unfitted value; only the fitted one exposes `transform`.
        let fitted = unfitted.fit(&X).unwrap();
        let result = fitted.transform(&X).unwrap();

        // Normalization is off by default, so the transform is the identity.
        assert_eq!(result, X);
    }

    #[test]
    fn test_dimension_markers() {
        assert_eq!(Dynamic::value(), None);
        assert_eq!(Known::<5>::value(), Some(5));
        assert_eq!(Known::<10>::value(), Some(10));
    }
}