1use crate::{UtilsError, UtilsResult};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::marker::PhantomData;
9
// Typestate markers. They are zero-sized and carry no data, but they derive the common std
// traits: a `#[derive(..)]` on a generic wrapper that stores a marker in `PhantomData` adds a
// bound on the marker type too, so without these derives e.g. a derived `Clone` on the wrapper
// would silently not exist for the trained/validated states.

/// Type-level marker: the wrapped model has not been trained yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Untrained;

/// Type-level marker: the wrapped model has been trained.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Trained;

/// Type-level marker: the wrapped data has passed validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Validated;

/// Type-level marker: the wrapped data has not been validated yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Unvalidated;
23
24#[derive(Debug, Clone)]
26pub struct ModelState<T, State> {
27 pub inner: T,
28 _state: PhantomData<State>,
29}
30
31impl<T> ModelState<T, Untrained> {
32 pub fn new(inner: T) -> Self {
34 Self {
35 inner,
36 _state: PhantomData,
37 }
38 }
39
40 pub fn train(self) -> ModelState<T, Trained> {
42 ModelState {
43 inner: self.inner,
44 _state: PhantomData,
45 }
46 }
47}
48
49impl<T> ModelState<T, Trained> {
50 pub fn predict<F, Input, Output>(&self, predict_fn: F, input: Input) -> Output
52 where
53 F: Fn(&T, Input) -> Output,
54 {
55 predict_fn(&self.inner, input)
56 }
57
58 pub fn reset(self) -> ModelState<T, Untrained> {
60 ModelState {
61 inner: self.inner,
62 _state: PhantomData,
63 }
64 }
65}
66
67#[derive(Debug, Clone)]
71pub struct DataState<T, State> {
72 pub data: T,
73 _state: PhantomData<State>,
74}
75
76impl<T> DataState<T, Unvalidated> {
77 pub fn new(data: T) -> Self {
79 Self {
80 data,
81 _state: PhantomData,
82 }
83 }
84
85 pub fn validate<F>(self, validator: F) -> UtilsResult<DataState<T, Validated>>
87 where
88 F: FnOnce(&T) -> UtilsResult<()>,
89 {
90 validator(&self.data)?;
91 Ok(DataState {
92 data: self.data,
93 _state: PhantomData,
94 })
95 }
96}
97
98impl<T> DataState<T, Validated> {
99 pub fn as_validated(&self) -> &T {
101 &self.data
102 }
103
104 pub fn map<U, F>(self, transform: F) -> DataState<U, Validated>
106 where
107 F: FnOnce(T) -> U,
108 {
109 DataState {
110 data: transform(self.data),
111 _state: PhantomData,
112 }
113 }
114}
115
// Dimensionality markers. They derive the common std traits so that wrappers deriving impls
// over a `PhantomData<D>` field (which makes the derive require the same trait on `D`) keep
// those impls.

/// Type-level marker for one-dimensional data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct D1;

/// Type-level marker for two-dimensional data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct D2;

/// Type-level marker for three-dimensional data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct D3;
122
123#[derive(Debug, Clone)]
125pub struct TypedArray<T, D> {
126 data: T,
127 _dimension: PhantomData<D>,
128}
129
130impl<T> TypedArray<Array1<T>, D1> {
131 pub fn new_1d(array: Array1<T>) -> Self {
133 Self {
134 data: array,
135 _dimension: PhantomData,
136 }
137 }
138
139 pub fn as_array1(&self) -> &Array1<T> {
141 &self.data
142 }
143
144 pub fn to_2d(self) -> TypedArray<Array2<T>, D2>
146 where
147 T: Clone,
148 {
149 let shape = (1, self.data.len());
150 let data = Array2::from_shape_vec(shape, self.data.to_vec()).unwrap();
151 TypedArray {
152 data,
153 _dimension: PhantomData,
154 }
155 }
156}
157
158impl<T> TypedArray<Array2<T>, D2> {
159 pub fn new_2d(array: Array2<T>) -> Self {
161 Self {
162 data: array,
163 _dimension: PhantomData,
164 }
165 }
166
167 pub fn as_array2(&self) -> &Array2<T> {
169 &self.data
170 }
171
172 pub fn shape(&self) -> (usize, usize) {
174 let shape = self.data.shape();
175 (shape[0], shape[1])
176 }
177
178 pub fn flatten(self) -> TypedArray<Array1<T>, D1>
180 where
181 T: Clone,
182 {
183 let (vec, offset) = self.data.into_raw_vec_and_offset();
184 assert_eq!(offset, Some(0), "Array offset must be zero for conversion");
185 let data = Array1::from_vec(vec);
186 TypedArray {
187 data,
188 _dimension: PhantomData,
189 }
190 }
191}
192
/// A unit of measure used as a zero-sized, type-level tag on [`Measurement`].
pub trait Unit: 'static {
    /// Human-readable name of the unit.
    const NAME: &'static str;
}

// The unit markers derive the common std traits so that any wrapper deriving impls over a
// `PhantomData<U>` field (which makes the derive require the same trait on `U`) keeps them.

/// Unit marker for meters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Meters;

/// Unit marker for seconds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Seconds;

/// Unit marker for kilograms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Kilograms;

/// Unit marker for pixels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Pixels;

impl Unit for Meters {
    const NAME: &'static str = "meters";
}

impl Unit for Seconds {
    const NAME: &'static str = "seconds";
}

impl Unit for Kilograms {
    const NAME: &'static str = "kilograms";
}

impl Unit for Pixels {
    const NAME: &'static str = "pixels";
}
220
221#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
223pub struct Measurement<T, U: Unit> {
224 value: T,
225 _unit: PhantomData<U>,
226}
227
228impl<T, U: Unit> Measurement<T, U> {
229 pub fn new(value: T) -> Self {
231 Self {
232 value,
233 _unit: PhantomData,
234 }
235 }
236
237 pub fn value(&self) -> &T {
239 &self.value
240 }
241
242 pub unsafe fn convert_unit<V: Unit>(self) -> Measurement<T, V> {
249 Measurement {
250 value: self.value,
251 _unit: PhantomData,
252 }
253 }
254}
255
256impl<T, U: Unit> std::ops::Add for Measurement<T, U>
257where
258 T: std::ops::Add<Output = T>,
259{
260 type Output = Self;
261
262 fn add(self, other: Self) -> Self::Output {
263 Self {
264 value: self.value + other.value,
265 _unit: PhantomData,
266 }
267 }
268}
269
270impl<T, U: Unit> std::ops::Sub for Measurement<T, U>
271where
272 T: std::ops::Sub<Output = T>,
273{
274 type Output = Self;
275
276 fn sub(self, other: Self) -> Self::Output {
277 Self {
278 value: self.value - other.value,
279 _unit: PhantomData,
280 }
281 }
282}
283
/// A type-level rule that decides whether a runtime shape is acceptable.
pub trait ShapeValidation {
    /// The runtime shape description the rule inspects.
    type Shape;

    /// Returns `true` when `shape` satisfies the rule.
    fn validate_shape(shape: Self::Shape) -> bool;
}

/// Accepts a length of exactly `N`.
pub struct ExactSize<const N: usize>;

impl<const N: usize> ShapeValidation for ExactSize<N> {
    type Shape = usize;

    fn validate_shape(len: usize) -> bool {
        N == len
    }
}

/// Accepts any length of at least `N`.
pub struct MinSize<const N: usize>;

impl<const N: usize> ShapeValidation for MinSize<N> {
    type Shape = usize;

    fn validate_shape(len: usize) -> bool {
        N <= len
    }
}
313
314pub struct ValidatedArray<T, V: ShapeValidation> {
316 data: Array1<T>,
317 _validator: PhantomData<V>,
318}
319
320impl<T, V: ShapeValidation<Shape = usize>> ValidatedArray<T, V> {
321 pub fn new(data: Array1<T>) -> Option<Self> {
323 if V::validate_shape(data.len()) {
324 Some(Self {
325 data,
326 _validator: PhantomData,
327 })
328 } else {
329 None
330 }
331 }
332
333 pub fn data(&self) -> &Array1<T> {
335 &self.data
336 }
337}
338
339#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
343pub struct Normalized<T>(T);
344
345impl<T> Normalized<T> {
346 pub unsafe fn new_unchecked(value: T) -> Self {
353 Self(value)
354 }
355
356 pub fn get(self) -> T {
358 self.0
359 }
360}
361
362impl Normalized<f64> {
363 pub fn new(value: f64) -> UtilsResult<Self> {
365 if (0.0..=1.0).contains(&value) {
366 Ok(Self(value))
367 } else {
368 Err(UtilsError::InvalidParameter(format!(
369 "Value {value} is not in range [0, 1]"
370 )))
371 }
372 }
373
374 pub fn clamp(value: f64) -> Self {
376 Self(value.clamp(0.0, 1.0))
377 }
378}
379
380#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
382pub struct Positive<T>(T);
383
384impl<T> Positive<T> {
385 pub fn get(self) -> T {
387 self.0
388 }
389}
390
391impl Positive<f64> {
392 pub fn new(value: f64) -> UtilsResult<Self> {
394 if value > 0.0 {
395 Ok(Self(value))
396 } else {
397 Err(UtilsError::InvalidParameter(format!(
398 "Value {value} is not positive"
399 )))
400 }
401 }
402}
403
404#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
406pub struct NonNegative<T>(T);
407
408impl<T> NonNegative<T> {
409 pub fn get(self) -> T {
411 self.0
412 }
413}
414
415impl NonNegative<f64> {
416 pub fn new(value: f64) -> UtilsResult<Self> {
418 if value >= 0.0 {
419 Ok(Self(value))
420 } else {
421 Err(UtilsError::InvalidParameter(format!(
422 "Value {value} is negative"
423 )))
424 }
425 }
426}
427
/// Fails compilation when the constant expression `$condition` evaluates to `false`.
///
/// Expands to an anonymous `const` item, so the check runs during const evaluation and a
/// false condition surfaces as a compile error with the message
/// "Compile-time assertion failed".
#[macro_export]
macro_rules! const_assert {
    ($condition:expr) => {
        const _: () = assert!($condition, "Compile-time assertion failed");
    };
}
441
/// Returns early from the enclosing function with `UtilsError::ShapeMismatch` when
/// `$array.shape()` differs from `$expected`.
///
/// The enclosing function must return `Result<_, UtilsError>` (the error is not converted
/// with `.into()`). Each argument is evaluated exactly once, and temporaries passed as
/// arguments live for the whole check. The error type is referenced through `$crate`, so it
/// does not need to be imported at the call site.
#[macro_export]
macro_rules! assert_shape {
    ($array:expr, $expected:expr) => {
        // Bind through a `match` (as `std::assert_eq!` does) so each argument is evaluated
        // once while any temporaries stay alive until the end of the check.
        match (&$array, &$expected) {
            (array, expected) => {
                if array.shape() != *expected {
                    return ::core::result::Result::Err($crate::UtilsError::ShapeMismatch {
                        expected: expected.to_vec(),
                        actual: array.shape().to_vec(),
                    });
                }
            }
        }
    };
}
454
/// A natural number encoded as a type and readable at compile time through `VALUE`.
pub trait TypeNum {
    /// The number this type stands for.
    const VALUE: usize;
}

/// Type-level `0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Zero;
/// Type-level `1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct One;
/// Type-level `2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Two;
/// Type-level `3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Three;

impl TypeNum for Zero {
    const VALUE: usize = 0;
}
impl TypeNum for One {
    const VALUE: usize = 1;
}
impl TypeNum for Two {
    const VALUE: usize = 2;
}
impl TypeNum for Three {
    const VALUE: usize = 3;
}

/// Type-level addition: `<A as Add<B>>::Output` is the numeral whose `VALUE` equals
/// `A::VALUE + B::VALUE`.
///
/// Implemented for every pair of numerals whose sum is itself representable (at most
/// `Three`); other sums have no impl and are rejected at compile time.
pub trait Add<Rhs> {
    /// The type-level sum.
    type Output: TypeNum;
}

// Sums with a left operand of `Zero`.
impl Add<Zero> for Zero {
    type Output = Zero;
}
impl Add<One> for Zero {
    type Output = One;
}
impl Add<Two> for Zero {
    type Output = Two;
}
impl Add<Three> for Zero {
    type Output = Three;
}

// Sums with a left operand of `One`.
impl Add<Zero> for One {
    type Output = One;
}
impl Add<One> for One {
    type Output = Two;
}
impl Add<Two> for One {
    type Output = Three;
}

// Sums with a left operand of `Two`.
impl Add<Zero> for Two {
    type Output = Two;
}
impl Add<One> for Two {
    type Output = Three;
}

// Sums with a left operand of `Three`.
impl Add<Zero> for Three {
    type Output = Three;
}
503
504pub struct MatrixMul<L: TypeNum, M: TypeNum, N: TypeNum> {
506 _phantom: PhantomData<(L, M, N)>,
507}
508
509impl<L: TypeNum, M: TypeNum, N: TypeNum> MatrixMul<L, M, N> {
510 pub fn multiply(left: &Array2<f64>, right: &Array2<f64>) -> UtilsResult<Array2<f64>> {
512 let left_shape = left.shape();
514 let right_shape = right.shape();
515
516 if left_shape[1] != right_shape[0] {
517 return Err(UtilsError::ShapeMismatch {
518 expected: vec![left_shape[0], right_shape[1]],
519 actual: vec![left_shape[0], left_shape[1], right_shape[0], right_shape[1]],
520 });
521 }
522
523 Ok(left.dot(right))
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_state_transitions() {
        #[derive(Debug, Clone)]
        struct MockModel {
            value: i32,
        }

        let trained = ModelState::new(MockModel { value: 42 }).train();

        let prediction = trained.predict(|model, offset: i32| model.value + offset, 10);
        assert_eq!(prediction, 52);

        // Resetting consumes the trained wrapper and hands back an untrained one.
        let _back_to_untrained = trained.reset();
    }

    #[test]
    fn test_data_validation() {
        let checked = DataState::new(vec![1, 2, 3, 4, 5])
            .validate(|values| {
                if values.iter().any(|&x| x <= 0) {
                    Err(UtilsError::InvalidParameter(
                        "Negative values found".to_string(),
                    ))
                } else {
                    Ok(())
                }
            })
            .unwrap();

        assert_eq!(checked.as_validated().len(), 5);

        let lengths = checked.map(|values| values.len());
        assert_eq!(*lengths.as_validated(), 5);
    }

    #[test]
    fn test_typed_arrays() {
        let row = TypedArray::new_1d(Array1::from_vec(vec![1.0, 2.0, 3.0])).to_2d();
        assert_eq!(row.shape(), (1, 3));

        let flat = row.flatten();
        assert_eq!(flat.as_array1().len(), 3);
    }

    #[test]
    fn test_measurements() {
        let first_leg = Measurement::<f64, Meters>::new(10.0);
        let second_leg = Measurement::<f64, Meters>::new(5.0);

        let total = first_leg + second_leg;
        assert_eq!(*total.value(), 15.0);

        let _duration = Measurement::<f64, Seconds>::new(2.0);
    }

    #[test]
    fn test_normalized_values() {
        assert_eq!(Normalized::new(0.5).unwrap().get(), 0.5);
        assert!(Normalized::new(1.5).is_err());
        assert_eq!(Normalized::clamp(1.5).get(), 1.0);
    }

    #[test]
    fn test_positive_values() {
        assert_eq!(Positive::new(5.0).unwrap().get(), 5.0);
        for rejected in [-1.0, 0.0] {
            assert!(Positive::new(rejected).is_err());
        }
    }

    #[test]
    fn test_validated_arrays() {
        let data = Array1::from_vec(vec![1, 2, 3]);

        assert!(ValidatedArray::<i32, ExactSize<3>>::new(data.clone()).is_some());
        assert!(ValidatedArray::<i32, ExactSize<5>>::new(data.clone()).is_none());
        assert!(ValidatedArray::<i32, MinSize<2>>::new(data).is_some());
    }

    #[test]
    fn test_matrix_multiplication_validation() {
        let left = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let right = Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();

        let product = MatrixMul::<Two, Three, Two>::multiply(&left, &right).unwrap();
        assert_eq!(product.shape(), &[2, 2]);

        let mismatched = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(MatrixMul::<Two, Three, Two>::multiply(&left, &mismatched).is_err());
    }
}
658}