sklears_utils/
type_safety.rs

1//! Type safety utilities for compile-time validation and zero-cost abstractions
2//!
3//! This module provides phantom types, zero-cost wrappers, and compile-time
4//! validation utilities to ensure type safety in machine learning operations.
5
6use crate::{UtilsError, UtilsResult};
7use scirs2_core::ndarray::{Array1, Array2};
8use std::marker::PhantomData;
9
10// ===== PHANTOM TYPES FOR STATE VALIDATION =====
11
/// Phantom type for the untrained model state.
///
/// Zero-sized marker used as a type parameter. The common traits are derived
/// so that wrappers which `#[derive]` traits over their state parameter (derive
/// adds a `State: Trait` bound) still implement those traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Untrained;

/// Phantom type for the trained model state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Trained;

/// Phantom type for data that has passed validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Validated;

/// Phantom type for data that has not been validated yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Unvalidated;
23
/// State-based wrapper for ML models.
///
/// `State` is a zero-sized marker (e.g. `Untrained` / `Trained`) that decides
/// which methods are available; it is only carried in the type via `PhantomData`.
pub struct ModelState<T, State> {
    /// The wrapped model.
    pub inner: T,
    _state: PhantomData<State>,
}

// `Debug` and `Clone` are written by hand because `#[derive]` would require
// `State: Debug` / `State: Clone`, even though `State` only appears inside
// `PhantomData`. With the derive, any marker lacking those traits silently made
// the wrapper non-`Debug`/non-`Clone`.
impl<T: std::fmt::Debug, State> std::fmt::Debug for ModelState<T, State> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModelState")
            .field("inner", &self.inner)
            .field("_state", &self._state)
            .finish()
    }
}

impl<T: Clone, State> Clone for ModelState<T, State> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            _state: PhantomData,
        }
    }
}
30
31impl<T> ModelState<T, Untrained> {
32    /// Create a new untrained model
33    pub fn new(inner: T) -> Self {
34        Self {
35            inner,
36            _state: PhantomData,
37        }
38    }
39
40    /// Transition to trained state (only available for untrained models)
41    pub fn train(self) -> ModelState<T, Trained> {
42        ModelState {
43            inner: self.inner,
44            _state: PhantomData,
45        }
46    }
47}
48
49impl<T> ModelState<T, Trained> {
50    /// Predict (only available for trained models)
51    pub fn predict<F, Input, Output>(&self, predict_fn: F, input: Input) -> Output
52    where
53        F: Fn(&T, Input) -> Output,
54    {
55        predict_fn(&self.inner, input)
56    }
57
58    /// Reset to untrained state
59    pub fn reset(self) -> ModelState<T, Untrained> {
60        ModelState {
61            inner: self.inner,
62            _state: PhantomData,
63        }
64    }
65}
66
67// ===== VALIDATED DATA TYPES =====
68
/// Data wrapper carrying its validation state in the type.
///
/// `State` is a zero-sized marker (e.g. `Unvalidated` / `Validated`) that is
/// only carried via `PhantomData`.
pub struct DataState<T, State> {
    /// The wrapped data.
    pub data: T,
    _state: PhantomData<State>,
}

// Written by hand instead of derived: `#[derive]` would add `State: Debug` /
// `State: Clone` bounds even though `State` is only phantom. With the derive,
// markers lacking those traits made the wrapper non-`Debug`/non-`Clone`.
impl<T: std::fmt::Debug, State> std::fmt::Debug for DataState<T, State> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DataState")
            .field("data", &self.data)
            .field("_state", &self._state)
            .finish()
    }
}

impl<T: Clone, State> Clone for DataState<T, State> {
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            _state: PhantomData,
        }
    }
}
75
76impl<T> DataState<T, Unvalidated> {
77    /// Create new unvalidated data
78    pub fn new(data: T) -> Self {
79        Self {
80            data,
81            _state: PhantomData,
82        }
83    }
84
85    /// Validate data and transition to validated state
86    pub fn validate<F>(self, validator: F) -> UtilsResult<DataState<T, Validated>>
87    where
88        F: FnOnce(&T) -> UtilsResult<()>,
89    {
90        validator(&self.data)?;
91        Ok(DataState {
92            data: self.data,
93            _state: PhantomData,
94        })
95    }
96}
97
98impl<T> DataState<T, Validated> {
99    /// Access validated data (only available after validation)
100    pub fn as_validated(&self) -> &T {
101        &self.data
102    }
103
104    /// Transform validated data while preserving validation state
105    pub fn map<U, F>(self, transform: F) -> DataState<U, Validated>
106    where
107        F: FnOnce(T) -> U,
108    {
109        DataState {
110            data: transform(self.data),
111            _state: PhantomData,
112        }
113    }
114}
115
116// ===== DIMENSIONAL TYPE SAFETY =====
117
/// Phantom type tagging a one-dimensional array.
///
/// The dimension markers derive the common traits so that wrappers which
/// `#[derive]` traits over their dimension parameter (derive adds a `D: Trait`
/// bound) still implement them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct D1;

/// Phantom type tagging a two-dimensional array.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct D2;

/// Phantom type tagging a three-dimensional array.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct D3;
122
/// Dimensionally-typed array wrapper.
///
/// `D` is a zero-sized dimension marker (e.g. `D1`, `D2`) carried only via `PhantomData`.
pub struct TypedArray<T, D> {
    data: T,
    _dimension: PhantomData<D>,
}

// Hand-written rather than derived: `#[derive]` would require `D: Debug` /
// `D: Clone` even though `D` is phantom. With the derive, dimension markers
// lacking those traits made the wrapper non-`Debug`/non-`Clone`.
impl<T: std::fmt::Debug, D> std::fmt::Debug for TypedArray<T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TypedArray")
            .field("data", &self.data)
            .field("_dimension", &self._dimension)
            .finish()
    }
}

impl<T: Clone, D> Clone for TypedArray<T, D> {
    fn clone(&self) -> Self {
        Self {
            data: self.data.clone(),
            _dimension: PhantomData,
        }
    }
}
129
130impl<T> TypedArray<Array1<T>, D1> {
131    /// Create a 1D typed array
132    pub fn new_1d(array: Array1<T>) -> Self {
133        Self {
134            data: array,
135            _dimension: PhantomData,
136        }
137    }
138
139    /// Get the underlying 1D array
140    pub fn as_array1(&self) -> &Array1<T> {
141        &self.data
142    }
143
144    /// Convert to 2D array (single row)
145    pub fn to_2d(self) -> TypedArray<Array2<T>, D2>
146    where
147        T: Clone,
148    {
149        let shape = (1, self.data.len());
150        let data = Array2::from_shape_vec(shape, self.data.to_vec()).unwrap();
151        TypedArray {
152            data,
153            _dimension: PhantomData,
154        }
155    }
156}
157
158impl<T> TypedArray<Array2<T>, D2> {
159    /// Create a 2D typed array
160    pub fn new_2d(array: Array2<T>) -> Self {
161        Self {
162            data: array,
163            _dimension: PhantomData,
164        }
165    }
166
167    /// Get the underlying 2D array
168    pub fn as_array2(&self) -> &Array2<T> {
169        &self.data
170    }
171
172    /// Get shape information
173    pub fn shape(&self) -> (usize, usize) {
174        let shape = self.data.shape();
175        (shape[0], shape[1])
176    }
177
178    /// Flatten to 1D array
179    pub fn flatten(self) -> TypedArray<Array1<T>, D1>
180    where
181        T: Clone,
182    {
183        let (vec, offset) = self.data.into_raw_vec_and_offset();
184        assert_eq!(offset, Some(0), "Array offset must be zero for conversion");
185        let data = Array1::from_vec(vec);
186        TypedArray {
187            data,
188            _dimension: PhantomData,
189        }
190    }
191}
192
193// ===== UNITS AND MEASUREMENTS =====
194
/// A unit of measurement, implemented by zero-sized marker types.
pub trait Unit: 'static {
    /// Human-readable name of the unit (e.g. `"meters"`).
    const NAME: &'static str;
}

// The unit markers derive the common traits so that generic wrappers which
// `#[derive]` traits over a unit parameter (derive adds a `U: Trait` bound)
// still implement them.

/// Length measured in meters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Meters;

/// Time measured in seconds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Seconds;

/// Mass measured in kilograms.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Kilograms;

/// Image distance measured in pixels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Pixels;

impl Unit for Meters {
    const NAME: &'static str = "meters";
}

impl Unit for Seconds {
    const NAME: &'static str = "seconds";
}

impl Unit for Kilograms {
    const NAME: &'static str = "kilograms";
}

impl Unit for Pixels {
    const NAME: &'static str = "pixels";
}
220
221/// Type-safe measurement with units
222#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
223pub struct Measurement<T, U: Unit> {
224    value: T,
225    _unit: PhantomData<U>,
226}
227
228impl<T, U: Unit> Measurement<T, U> {
229    /// Create a new measurement
230    pub fn new(value: T) -> Self {
231        Self {
232            value,
233            _unit: PhantomData,
234        }
235    }
236
237    /// Get the value
238    pub fn value(&self) -> &T {
239        &self.value
240    }
241
242    /// Convert to different unit (unsafe, requires manual verification)
243    ///
244    /// # Safety
245    ///
246    /// The caller must ensure that the conversion between units is mathematically valid
247    /// and that the value makes sense in the target unit system.
248    pub unsafe fn convert_unit<V: Unit>(self) -> Measurement<T, V> {
249        Measurement {
250            value: self.value,
251            _unit: PhantomData,
252        }
253    }
254}
255
256impl<T, U: Unit> std::ops::Add for Measurement<T, U>
257where
258    T: std::ops::Add<Output = T>,
259{
260    type Output = Self;
261
262    fn add(self, other: Self) -> Self::Output {
263        Self {
264            value: self.value + other.value,
265            _unit: PhantomData,
266        }
267    }
268}
269
270impl<T, U: Unit> std::ops::Sub for Measurement<T, U>
271where
272    T: std::ops::Sub<Output = T>,
273{
274    type Output = Self;
275
276    fn sub(self, other: Self) -> Self::Output {
277        Self {
278            value: self.value - other.value,
279            _unit: PhantomData,
280        }
281    }
282}
283
284// ===== COMPILE-TIME VALIDATION =====
285
/// A shape constraint encoded in a type.
///
/// The implementing type fixes the constraint at compile time (e.g. through a
/// const generic parameter). `validate_shape` then checks an actual shape
/// against it at runtime.
pub trait ShapeValidation {
    /// The shape representation this constraint inspects (e.g. `usize` for a length).
    type Shape;

    /// Returns `true` when `shape` satisfies the constraint.
    fn validate_shape(shape: Self::Shape) -> bool;
}
291
292/// Shape constraint: exactly N elements
293pub struct ExactSize<const N: usize>;
294
295impl<const N: usize> ShapeValidation for ExactSize<N> {
296    type Shape = usize;
297
298    fn validate_shape(shape: Self::Shape) -> bool {
299        shape == N
300    }
301}
302
303/// Shape constraint: minimum N elements
304pub struct MinSize<const N: usize>;
305
306impl<const N: usize> ShapeValidation for MinSize<N> {
307    type Shape = usize;
308
309    fn validate_shape(shape: Self::Shape) -> bool {
310        shape >= N
311    }
312}
313
314/// Shape-validated array
315pub struct ValidatedArray<T, V: ShapeValidation> {
316    data: Array1<T>,
317    _validator: PhantomData<V>,
318}
319
320impl<T, V: ShapeValidation<Shape = usize>> ValidatedArray<T, V> {
321    /// Create a validated array (compile-time check)
322    pub fn new(data: Array1<T>) -> Option<Self> {
323        if V::validate_shape(data.len()) {
324            Some(Self {
325                data,
326                _validator: PhantomData,
327            })
328        } else {
329            None
330        }
331    }
332
333    /// Access the validated data
334    pub fn data(&self) -> &Array1<T> {
335        &self.data
336    }
337}
338
339// ===== ZERO-COST ABSTRACTIONS =====
340
/// Zero-cost wrapper for values in the closed interval `[0, 1]`.
///
/// The invariant is established either by the constructors on concrete types
/// (for `f64`: a checked `new` and a saturating `clamp`) or by the caller through
/// [`Normalized::new_unchecked`].
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct Normalized<T>(T);

impl<T> Normalized<T> {
    /// Wraps `value` without checking that it lies in `[0, 1]`.
    ///
    /// # Safety
    ///
    /// The caller must ensure that `value` is within `[0, 1]` (and, for floats,
    /// is not NaN). Breaking this contract does not by itself cause memory
    /// unsafety in this type. However, any code that relies on the `[0, 1]`
    /// invariant, possibly including `unsafe` code, may then misbehave.
    pub unsafe fn new_unchecked(value: T) -> Self {
        Self(value)
    }

    /// Consumes the wrapper and returns the inner value.
    pub fn get(self) -> T {
        self.0
    }
}
361
362impl Normalized<f64> {
363    /// Create a normalized value with validation
364    pub fn new(value: f64) -> UtilsResult<Self> {
365        if (0.0..=1.0).contains(&value) {
366            Ok(Self(value))
367        } else {
368            Err(UtilsError::InvalidParameter(format!(
369                "Value {value} is not in range [0, 1]"
370            )))
371        }
372    }
373
374    /// Clamp value to [0, 1] range
375    pub fn clamp(value: f64) -> Self {
376        Self(value.clamp(0.0, 1.0))
377    }
378}
379
/// Zero-cost wrapper for strictly positive values (`> 0`).
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct Positive<T>(T);

impl<T> Positive<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn get(self) -> T {
        let Positive(inner) = self;
        inner
    }
}
390
391impl Positive<f64> {
392    /// Create a positive value with validation
393    pub fn new(value: f64) -> UtilsResult<Self> {
394        if value > 0.0 {
395            Ok(Self(value))
396        } else {
397            Err(UtilsError::InvalidParameter(format!(
398                "Value {value} is not positive"
399            )))
400        }
401    }
402}
403
/// Zero-cost wrapper for non-negative values (`>= 0`).
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct NonNegative<T>(T);

impl<T> NonNegative<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn get(self) -> T {
        let NonNegative(inner) = self;
        inner
    }
}
414
415impl NonNegative<f64> {
416    /// Create a non-negative value with validation
417    pub fn new(value: f64) -> UtilsResult<Self> {
418        if value >= 0.0 {
419            Ok(Self(value))
420        } else {
421            Err(UtilsError::InvalidParameter(format!(
422                "Value {value} is negative"
423            )))
424        }
425    }
426}
427
428// ===== COMPILE-TIME ASSERTIONS =====
429
/// Compile-time assertion.
///
/// Expands to an anonymous `const` item whose initializer panics when
/// `$condition` is false, turning the failure into a compilation error. The
/// condition must therefore be evaluable in a const context.
#[macro_export]
macro_rules! const_assert {
    ($condition:expr) => {
        const _: () = assert!($condition, "Compile-time assertion failed");
    };
}
441
/// Runtime shape check that early-returns a shape-mismatch error.
///
/// Compares `$array.shape()` with `$expected`. If they differ, the enclosing
/// function executes `return Err(UtilsError::ShapeMismatch { expected, actual })`,
/// so the macro can only be used inside a function returning
/// `Result<_, UtilsError>`.
///
/// `$array` and `$expected` are each evaluated exactly once and only borrowed,
/// never moved. The error type is referenced through `$crate`, so call sites do
/// not need `UtilsError` in scope.
#[macro_export]
macro_rules! assert_shape {
    ($array:expr, $expected:expr) => {{
        // Borrowing (rather than binding by value) keeps temporaries alive via
        // lifetime extension and avoids moving caller-owned values.
        let __array = &$array;
        let __expected = &$expected;
        if __array.shape() != *__expected {
            return Err($crate::UtilsError::ShapeMismatch {
                expected: __expected.to_vec(),
                actual: __array.shape().to_vec(),
            });
        }
    }};
}
454
455// ===== TYPE-LEVEL COMPUTATION =====
456
/// Type-level natural numbers.
///
/// Each implementing type carries its numeric value as the associated constant
/// `VALUE`, which is usable in const contexts.
pub trait TypeNum {
    /// The number this type represents.
    const VALUE: usize;
}

/// Type-level `0`.
pub struct Zero;

impl TypeNum for Zero {
    const VALUE: usize = 0;
}

/// Type-level `1`.
pub struct One;

impl TypeNum for One {
    const VALUE: usize = 1;
}

/// Type-level `2`.
pub struct Two;

impl TypeNum for Two {
    const VALUE: usize = 2;
}

/// Type-level `3`.
pub struct Three;

impl TypeNum for Three {
    const VALUE: usize = 3;
}
479
480/// Add two type-level numbers
481pub trait Add<Rhs> {
482    type Output: TypeNum;
483}
484
485impl Add<Zero> for Zero {
486    type Output = Zero;
487}
488impl Add<One> for Zero {
489    type Output = One;
490}
491impl Add<Two> for Zero {
492    type Output = Two;
493}
494impl Add<Zero> for One {
495    type Output = One;
496}
497impl Add<One> for One {
498    type Output = Two;
499}
500impl Add<Two> for One {
501    type Output = Three;
502}
503
504/// Compile-time validated matrix multiplication
505pub struct MatrixMul<L: TypeNum, M: TypeNum, N: TypeNum> {
506    _phantom: PhantomData<(L, M, N)>,
507}
508
509impl<L: TypeNum, M: TypeNum, N: TypeNum> MatrixMul<L, M, N> {
510    /// Validate matrix multiplication at compile time
511    pub fn multiply(left: &Array2<f64>, right: &Array2<f64>) -> UtilsResult<Array2<f64>> {
512        // Runtime validation (would be compile-time in a full implementation)
513        let left_shape = left.shape();
514        let right_shape = right.shape();
515
516        if left_shape[1] != right_shape[0] {
517            return Err(UtilsError::ShapeMismatch {
518                expected: vec![left_shape[0], right_shape[1]],
519                actual: vec![left_shape[0], left_shape[1], right_shape[0], right_shape[1]],
520            });
521        }
522
523        Ok(left.dot(right))
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_state_transitions() {
        #[derive(Debug, Clone)]
        struct MockModel {
            value: i32,
        }

        let model = MockModel { value: 42 };
        let untrained = ModelState::new(model);

        // Can only train untrained models
        let trained = untrained.train();

        // Can only predict with trained models
        let result = trained.predict(|model, input: i32| model.value + input, 10);
        assert_eq!(result, 52);

        // Can reset trained model to untrained
        let _reset = trained.reset();
    }

    #[test]
    fn test_data_validation() {
        let data = vec![1, 2, 3, 4, 5];
        let unvalidated = DataState::new(data);

        // Validate that all elements are positive
        let validated = unvalidated
            .validate(|data| {
                if data.iter().all(|&x| x > 0) {
                    Ok(())
                } else {
                    Err(UtilsError::InvalidParameter(
                        "Negative values found".to_string(),
                    ))
                }
            })
            .unwrap();

        // Can access validated data
        let validated_data = validated.as_validated();
        assert_eq!(validated_data.len(), 5);

        // Transform while preserving validation
        let transformed = validated.map(|data| data.len());
        assert_eq!(*transformed.as_validated(), 5);
    }

    #[test]
    fn test_data_validation_failure() {
        let rejected = DataState::new(vec![1, -2, 3]).validate(|data| {
            if data.iter().all(|&x| x > 0) {
                Ok(())
            } else {
                Err(UtilsError::InvalidParameter(
                    "Negative values found".to_string(),
                ))
            }
        });
        assert!(rejected.is_err());
    }

    #[test]
    fn test_typed_arrays() {
        let array1d = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let typed1d = TypedArray::new_1d(array1d);

        // Convert to 2D
        let typed2d = typed1d.to_2d();
        assert_eq!(typed2d.shape(), (1, 3));

        // Flatten back to 1D
        let flattened = typed2d.flatten();
        assert_eq!(flattened.as_array1().len(), 3);
    }

    #[test]
    fn test_measurements() {
        let distance1 = Measurement::<f64, Meters>::new(10.0);
        let distance2 = Measurement::<f64, Meters>::new(5.0);

        let total_distance = distance1 + distance2;
        assert_eq!(*total_distance.value(), 15.0);

        let _time = Measurement::<f64, Seconds>::new(2.0);
        // This would not compile: distance1 + time (different units)
    }

    #[test]
    fn test_unit_names_and_type_level_numbers() {
        assert_eq!(Meters::NAME, "meters");
        assert_eq!(Seconds::NAME, "seconds");
        assert_eq!(Kilograms::NAME, "kilograms");
        assert_eq!(Pixels::NAME, "pixels");
        assert_eq!(<<Zero as Add<One>>::Output as TypeNum>::VALUE, 1);
        assert_eq!(<<One as Add<Two>>::Output as TypeNum>::VALUE, 3);
    }

    #[test]
    fn test_normalized_values() {
        // Valid normalized value
        let norm1 = Normalized::new(0.5).unwrap();
        assert_eq!(norm1.get(), 0.5);

        // Invalid normalized values
        assert!(Normalized::new(1.5).is_err());
        assert!(Normalized::new(f64::NAN).is_err());

        // Clamped value
        let norm2 = Normalized::clamp(1.5);
        assert_eq!(norm2.get(), 1.0);
    }

    #[test]
    fn test_positive_values() {
        let pos = Positive::new(5.0).unwrap();
        assert_eq!(pos.get(), 5.0);

        assert!(Positive::new(-1.0).is_err());
        assert!(Positive::new(0.0).is_err());
        assert!(Positive::new(f64::NAN).is_err());
    }

    #[test]
    fn test_non_negative_values() {
        assert_eq!(NonNegative::new(0.0).unwrap().get(), 0.0);
        assert_eq!(NonNegative::new(2.5).unwrap().get(), 2.5);
        assert!(NonNegative::new(-0.5).is_err());
        assert!(NonNegative::new(f64::NAN).is_err());
    }

    #[test]
    fn test_validated_arrays() {
        let data = Array1::from_vec(vec![1, 2, 3]);

        // Should succeed for ExactSize<3>
        let validated = ValidatedArray::<i32, ExactSize<3>>::new(data.clone());
        assert!(validated.is_some());

        // Should fail for ExactSize<5>
        let validated = ValidatedArray::<i32, ExactSize<5>>::new(data.clone());
        assert!(validated.is_none());

        // Should fail for MinSize<4>
        let validated = ValidatedArray::<i32, MinSize<4>>::new(data.clone());
        assert!(validated.is_none());

        // Should succeed for MinSize<2>, and the data is kept unchanged
        let validated = ValidatedArray::<i32, MinSize<2>>::new(data).unwrap();
        assert_eq!(validated.data().len(), 3);
    }

    #[test]
    fn test_matrix_multiplication_validation() {
        let left = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        let right = Array2::from_shape_vec((3, 2), vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]).unwrap();

        let result = MatrixMul::<Two, Three, Two>::multiply(&left, &right).unwrap();
        assert_eq!(result.shape(), &[2, 2]);

        // Should fail with incompatible shapes
        let wrong_right = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(MatrixMul::<Two, Three, Two>::multiply(&left, &wrong_right).is_err());
    }
}