sklears_dummy/
type_safe.rs

1//! Type-Safe Dummy Estimators
2//!
3//! This module provides compile-time guarantees for dummy estimators using Rust's type system.
4//! It includes phantom types, state validation, and compile-time configuration verification.
5
6use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::error::SklearsError;
8use sklears_core::traits::{Fit, Predict};
9use std::marker::PhantomData;
10
/// Type-level marker describing where an estimator is in its lifecycle.
///
/// Implemented by the zero-sized [`Untrained`] and [`Trained`] markers.
pub trait EstimatorState {}

/// Lifecycle marker: the estimator has not been fitted yet.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;

/// Lifecycle marker: the estimator has been fitted.
#[derive(Debug, Clone, Copy)]
pub struct Trained;

impl EstimatorState for Untrained {}
impl EstimatorState for Trained {}

/// Type-level marker describing which kind of learning task is solved.
pub trait TaskType {}

/// Task marker: predicting discrete class labels.
#[derive(Debug, Clone, Copy)]
pub struct Classification;

/// Task marker: predicting continuous targets.
#[derive(Debug, Clone, Copy)]
pub struct Regression;

impl TaskType for Classification {}
impl TaskType for Regression {}
36
37/// Marker trait for strategy validation
38pub trait StrategyValid<T: TaskType> {}
39
40/// Validated classification strategies
41impl StrategyValid<Classification> for crate::ClassifierStrategy {}
42
43/// Validated regression strategies
44impl StrategyValid<Regression> for crate::RegressorStrategy {}
45
46/// Type-safe dummy estimator with compile-time guarantees
47#[derive(Debug, Clone)]
48pub struct TypeSafeDummyEstimator<State, Task, Strategy>
49where
50    State: EstimatorState,
51    Task: TaskType,
52    Strategy: StrategyValid<Task>,
53{
54    strategy: Strategy,
55    random_state: Option<u64>,
56    _state: PhantomData<State>,
57    _task: PhantomData<Task>,
58}
59
60/// Fitted classifier with type-safe state
61#[derive(Debug, Clone)]
62pub struct TypeSafeFittedClassifier<Strategy>
63where
64    Strategy: StrategyValid<Classification>,
65{
66    strategy: Strategy,
67    fitted_data: ClassificationFittedData,
68    random_state: Option<u64>,
69}
70
71/// Fitted regressor with type-safe state
72#[derive(Debug, Clone)]
73pub struct TypeSafeFittedRegressor<Strategy>
74where
75    Strategy: StrategyValid<Regression>,
76{
77    strategy: Strategy,
78    fitted_data: RegressionFittedData,
79    random_state: Option<u64>,
80}
81
/// Statistics a dummy classifier learns from its training labels.
#[derive(Debug, Clone)]
pub struct ClassificationFittedData {
    /// Number of training samples observed for each class label.
    pub class_counts: std::collections::HashMap<i32, usize>,
    /// Fraction of training samples belonging to each class label.
    pub class_priors: std::collections::HashMap<i32, f64>,
    /// Class label with the largest count.
    pub most_frequent_class: i32,
    /// Number of rows in the training feature matrix.
    pub n_samples: usize,
    /// Number of columns in the training feature matrix.
    pub n_features: usize,
}

/// Summary statistics a dummy regressor learns from its training targets.
#[derive(Debug, Clone)]
pub struct RegressionFittedData {
    /// Arithmetic mean of the targets.
    pub target_mean: f64,
    /// Median of the targets.
    pub target_median: f64,
    /// Standard deviation of the targets (computed by `fit` with `ddof = 0`).
    pub target_std: f64,
    /// Smallest target value.
    pub target_min: f64,
    /// Largest target value.
    pub target_max: f64,
    /// Number of rows in the training feature matrix.
    pub n_samples: usize,
    /// Number of columns in the training feature matrix.
    pub n_features: usize,
}
115
116/// Compile-time configuration for estimators
117pub trait EstimatorConfig {
118    type TaskType: TaskType;
119    type Strategy: StrategyValid<Self::TaskType>;
120
121    /// Validate configuration at compile time
122    fn validate() -> Result<(), &'static str>;
123
124    /// Create strategy instance
125    fn create_strategy() -> Self::Strategy;
126}
127
128/// Classification configuration
129#[derive(Debug)]
130pub struct ClassificationConfig<S: StrategyValid<Classification>> {
131    _strategy: PhantomData<S>,
132}
133
134/// Regression configuration
135#[derive(Debug)]
136pub struct RegressionConfig<S: StrategyValid<Regression>> {
137    _strategy: PhantomData<S>,
138}
139
140impl<S: StrategyValid<Classification>> EstimatorConfig for ClassificationConfig<S> {
141    type TaskType = Classification;
142    type Strategy = S;
143
144    fn validate() -> Result<(), &'static str> {
145        // Compile-time validation can be extended here
146        Ok(())
147    }
148
149    fn create_strategy() -> Self::Strategy {
150        // This would need to be implemented per strategy type
151        // For now, we'll use a placeholder approach
152        panic!("Strategy creation must be implemented per type")
153    }
154}
155
156impl<S: StrategyValid<Regression>> EstimatorConfig for RegressionConfig<S> {
157    type TaskType = Regression;
158    type Strategy = S;
159
160    fn validate() -> Result<(), &'static str> {
161        // Compile-time validation can be extended here
162        Ok(())
163    }
164
165    fn create_strategy() -> Self::Strategy {
166        // This would need to be implemented per strategy type
167        panic!("Strategy creation must be implemented per type")
168    }
169}
170
171/// Builder for type-safe dummy estimators with compile-time validation
172#[derive(Debug)]
173pub struct TypeSafeEstimatorBuilder<Task, Strategy>
174where
175    Task: TaskType,
176    Strategy: StrategyValid<Task>,
177{
178    strategy: Strategy,
179    random_state: Option<u64>,
180    _task: PhantomData<Task>,
181}
182
183impl<Strategy> TypeSafeDummyEstimator<Untrained, Classification, Strategy>
184where
185    Strategy: StrategyValid<Classification> + Clone,
186{
187    /// Create a new untrained type-safe classifier
188    pub fn new(strategy: Strategy) -> Self {
189        Self {
190            strategy,
191            random_state: None,
192            _state: PhantomData,
193            _task: PhantomData,
194        }
195    }
196
197    /// Set random state (builder pattern)
198    pub fn with_random_state(mut self, seed: u64) -> Self {
199        self.random_state = Some(seed);
200        self
201    }
202
203    /// Build the estimator after validation
204    pub fn build(self) -> Result<Self, &'static str> {
205        // Compile-time and runtime validation
206        self.validate_configuration()?;
207        Ok(self)
208    }
209
210    /// Validate configuration
211    fn validate_configuration(&self) -> Result<(), &'static str> {
212        // Additional runtime validation can be added here
213        Ok(())
214    }
215}
216
217impl<Strategy> TypeSafeDummyEstimator<Untrained, Regression, Strategy>
218where
219    Strategy: StrategyValid<Regression> + Clone,
220{
221    /// Create a new untrained type-safe regressor
222    pub fn new(strategy: Strategy) -> Self {
223        Self {
224            strategy,
225            random_state: None,
226            _state: PhantomData,
227            _task: PhantomData,
228        }
229    }
230
231    /// Set random state (builder pattern)
232    pub fn with_random_state(mut self, seed: u64) -> Self {
233        self.random_state = Some(seed);
234        self
235    }
236
237    /// Build the estimator after validation
238    pub fn build(self) -> Result<Self, &'static str> {
239        // Compile-time and runtime validation
240        self.validate_configuration()?;
241        Ok(self)
242    }
243
244    /// Validate configuration
245    fn validate_configuration(&self) -> Result<(), &'static str> {
246        // Additional runtime validation can be added here
247        Ok(())
248    }
249}
250
251/// Fit implementation for type-safe classifiers
252impl<Strategy> Fit<Array2<f64>, Array1<i32>, TypeSafeFittedClassifier<Strategy>>
253    for TypeSafeDummyEstimator<Untrained, Classification, Strategy>
254where
255    Strategy: StrategyValid<Classification> + Clone + Into<crate::ClassifierStrategy> + Send + Sync,
256{
257    type Fitted = TypeSafeFittedClassifier<Strategy>;
258    fn fit(
259        self,
260        x: &Array2<f64>,
261        y: &Array1<i32>,
262    ) -> Result<TypeSafeFittedClassifier<Strategy>, SklearsError> {
263        if x.nrows() != y.len() {
264            return Err(SklearsError::ShapeMismatch {
265                expected: format!("{} samples", x.nrows()),
266                actual: format!("{} labels", y.len()),
267            });
268        }
269
270        // Calculate fitted data
271        let mut class_counts = std::collections::HashMap::new();
272        for &class in y.iter() {
273            *class_counts.entry(class).or_insert(0) += 1;
274        }
275
276        let mut class_priors = std::collections::HashMap::new();
277        let total_samples = y.len() as f64;
278        for (&class, &count) in &class_counts {
279            class_priors.insert(class, count as f64 / total_samples);
280        }
281
282        let most_frequent_class = class_counts
283            .iter()
284            .max_by_key(|(_, &count)| count)
285            .map(|(&class, _)| class)
286            .unwrap_or(0);
287
288        let fitted_data = ClassificationFittedData {
289            class_counts,
290            class_priors,
291            most_frequent_class,
292            n_samples: x.nrows(),
293            n_features: x.ncols(),
294        };
295
296        Ok(TypeSafeFittedClassifier {
297            strategy: self.strategy.clone(),
298            fitted_data,
299            random_state: self.random_state,
300        })
301    }
302}
303
304/// Fit implementation for type-safe regressors
305impl<Strategy> Fit<Array2<f64>, Array1<f64>, TypeSafeFittedRegressor<Strategy>>
306    for TypeSafeDummyEstimator<Untrained, Regression, Strategy>
307where
308    Strategy: StrategyValid<Regression> + Clone + Into<crate::RegressorStrategy> + Send + Sync,
309{
310    type Fitted = TypeSafeFittedRegressor<Strategy>;
311    fn fit(
312        self,
313        x: &Array2<f64>,
314        y: &Array1<f64>,
315    ) -> Result<TypeSafeFittedRegressor<Strategy>, SklearsError> {
316        if x.nrows() != y.len() {
317            return Err(SklearsError::ShapeMismatch {
318                expected: format!("{} samples", x.nrows()),
319                actual: format!("{} labels", y.len()),
320            });
321        }
322
323        // Calculate fitted data
324        let target_mean = y.mean().unwrap_or(0.0);
325        let target_std = y.std(0.0);
326        let target_min = y.iter().fold(f64::INFINITY, |a, &b| a.min(b));
327        let target_max = y.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
328
329        // Calculate median
330        let mut sorted_targets = y.to_vec();
331        sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
332        let target_median = if sorted_targets.len() % 2 == 0 {
333            let mid = sorted_targets.len() / 2;
334            (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
335        } else {
336            sorted_targets[sorted_targets.len() / 2]
337        };
338
339        let fitted_data = RegressionFittedData {
340            target_mean,
341            target_median,
342            target_std,
343            target_min,
344            target_max,
345            n_samples: x.nrows(),
346            n_features: x.ncols(),
347        };
348
349        Ok(TypeSafeFittedRegressor {
350            strategy: self.strategy.clone(),
351            fitted_data,
352            random_state: self.random_state,
353        })
354    }
355}
356
357/// Predict implementation for type-safe fitted classifiers
358impl<Strategy> Predict<Array2<f64>, Array1<i32>> for TypeSafeFittedClassifier<Strategy>
359where
360    Strategy: StrategyValid<Classification> + Clone + Into<crate::ClassifierStrategy>,
361{
362    fn predict(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError> {
363        // Convert to legacy strategy and use existing implementation
364        let legacy_strategy: crate::ClassifierStrategy = self.strategy.clone().into();
365        let legacy_classifier = crate::DummyClassifier::new(legacy_strategy);
366
367        // Create fake training data for legacy classifier
368        let fake_x = Array2::zeros((self.fitted_data.n_samples, self.fitted_data.n_features));
369        let fake_y: Array1<i32> = Array1::from_iter(
370            self.fitted_data
371                .class_counts
372                .iter()
373                .flat_map(|(&class, &count)| std::iter::repeat(class).take(count)),
374        );
375
376        let fitted_legacy = legacy_classifier.fit(&fake_x, &fake_y)?;
377        fitted_legacy.predict(x)
378    }
379}
380
381/// Predict implementation for type-safe fitted regressors
382impl<Strategy> Predict<Array2<f64>, Array1<f64>> for TypeSafeFittedRegressor<Strategy>
383where
384    Strategy: StrategyValid<Regression> + Clone + Into<crate::RegressorStrategy>,
385{
386    fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
387        // Convert to legacy strategy and use existing implementation
388        let legacy_strategy: crate::RegressorStrategy = self.strategy.clone().into();
389        let legacy_regressor = crate::DummyRegressor::new(legacy_strategy);
390
391        // Create fake training data for legacy regressor
392        let fake_x = Array2::zeros((self.fitted_data.n_samples, self.fitted_data.n_features));
393        let fake_y = Array1::from_elem(self.fitted_data.n_samples, self.fitted_data.target_mean);
394
395        let fitted_legacy = legacy_regressor.fit(&fake_x, &fake_y)?;
396        fitted_legacy.predict(x)
397    }
398}
399
/// Assert at compile time that `$strategy` implements `StrategyValid<$task>`.
///
/// Expands to an unnamed `const` holding a closure that is type-checked but
/// never called, so the check happens entirely during compilation.
/// `StrategyValid` and `TaskType` are resolved at the call site and must be in
/// scope there.
#[macro_export]
macro_rules! validate_strategy_at_compile_time {
    ($strategy:ty, $task:ty) => {
        const _: fn() = || {
            fn assert_strategy_valid<S, T>()
            where
                T: TaskType,
                S: StrategyValid<T>,
            {
            }

            assert_strategy_valid::<$strategy, $task>();
        };
    };
}
410
/// Shorthand for constructing an untrained [`TypeSafeDummyEstimator`].
///
/// * `type_safe_estimator!(classification, s)` builds
///   `TypeSafeDummyEstimator::<Untrained, Classification, _>::new(s)`.
/// * `type_safe_estimator!(regression, s)` builds the `Regression` equivalent.
///
/// The type names are resolved at the call site.
#[macro_export]
macro_rules! type_safe_estimator {
    (classification, $strategy:expr) => {{
        let strategy = $strategy;
        TypeSafeDummyEstimator::<Untrained, Classification, _>::new(strategy)
    }};
    (regression, $strategy:expr) => {{
        let strategy = $strategy;
        TypeSafeDummyEstimator::<Untrained, Regression, _>::new(strategy)
    }};
}
421
422/// Zero-cost abstraction for strategy validation
423pub struct ValidatedStrategy<S, T>
424where
425    S: StrategyValid<T>,
426    T: TaskType,
427{
428    strategy: S,
429    _task: PhantomData<T>,
430}
431
432impl<S, T> ValidatedStrategy<S, T>
433where
434    S: StrategyValid<T>,
435    T: TaskType,
436{
437    /// Create a validated strategy (zero-cost)
438    pub fn new(strategy: S) -> Self {
439        Self {
440            strategy,
441            _task: PhantomData,
442        }
443    }
444
445    /// Extract the strategy (zero-cost)
446    pub fn into_strategy(self) -> S {
447        self.strategy
448    }
449
450    /// Get a reference to the strategy (zero-cost)
451    pub fn strategy(&self) -> &S {
452        &self.strategy
453    }
454}
455
/// Parameters that can check their own validity.
pub trait ParameterValidation {
    /// Error produced when validation fails.
    type Error;

    /// Return `Ok(())` if the parameters are valid, otherwise an error value.
    fn validate(&self) -> Result<(), Self::Error>;
}

/// Thin owning wrapper around a parameter value.
#[derive(Debug, Clone)]
pub struct TypeSafeParameters<T> {
    value: T,
}

impl<T> TypeSafeParameters<T> {
    /// Wrap `value`.
    pub fn new(value: T) -> Self {
        TypeSafeParameters { value }
    }

    /// Borrow the wrapped value.
    pub fn get(&self) -> &T {
        let Self { value } = self;
        value
    }

    /// Unwrap and return the owned value.
    pub fn into_inner(self) -> T {
        let Self { value } = self;
        value
    }
}
486
/// Numeric parameter constrained to the inclusive range `[MIN, MAX]`.
///
/// The bounds are const generics, so each range is its own type. The value
/// itself is checked when the parameter is constructed.
#[derive(Debug, Clone, Copy)]
pub struct BoundedParameter<T, const MIN: i64, const MAX: i64> {
    value: T,
}

impl<T, const MIN: i64, const MAX: i64> BoundedParameter<T, MIN, MAX>
where
    T: PartialOrd + Copy + TryFrom<i64>,
{
    /// Create a bounded parameter, checking `value` against `[MIN, MAX]`.
    ///
    /// # Errors
    ///
    /// * `"Invalid minimum bound"` if `MIN` does not fit in `T`.
    /// * `"Invalid maximum bound"` if `MAX` does not fit in `T`.
    /// * `"Parameter value out of bounds"` if `value` is outside the range.
    pub fn new(value: T) -> Result<Self, &'static str> {
        // The lower bound is converted first, so its error wins when both fail.
        let lower = T::try_from(MIN).map_err(|_| "Invalid minimum bound")?;
        let upper = T::try_from(MAX).map_err(|_| "Invalid maximum bound")?;

        (lower..=upper)
            .contains(&value)
            .then_some(Self { value })
            .ok_or("Parameter value out of bounds")
    }

    /// Return the stored value.
    pub fn get(&self) -> T {
        self.value
    }
}
514
515// Specific implementation for i32
516impl<const MIN: i64, const MAX: i64> BoundedParameter<i32, MIN, MAX> {
517    /// Create a bounded i32 parameter
518    pub fn new_i32(value: i32) -> Result<Self, &'static str> {
519        let value_i64 = value as i64;
520        if value_i64 >= MIN && value_i64 <= MAX {
521            Ok(Self { value })
522        } else {
523            Err("Parameter value out of bounds")
524        }
525    }
526}
527
528// Specific implementation for u64
529impl<const MIN: i64, const MAX: i64> BoundedParameter<u64, MIN, MAX> {
530    /// Create a bounded u64 parameter
531    pub fn new_u64(value: u64) -> Result<Self, &'static str> {
532        let value_i64 = value as i64;
533        if value_i64 >= MIN && value_i64 <= MAX && value <= i64::MAX as u64 {
534            Ok(Self { value })
535        } else {
536            Err("Parameter value out of bounds")
537        }
538    }
539}
540
541// Specific implementation for f64
542impl<const MIN: i64, const MAX: i64> BoundedParameter<f64, MIN, MAX> {
543    /// Create a bounded f64 parameter
544    pub fn new_f64(value: f64) -> Result<Self, &'static str> {
545        let min_f64 = MIN as f64;
546        let max_f64 = MAX as f64;
547        if value >= min_f64 && value <= max_f64 {
548            Ok(Self { value })
549        } else {
550            Err("Parameter value out of bounds")
551        }
552    }
553
554    /// Get the f64 parameter value
555    pub fn get_f64(&self) -> f64 {
556        self.value
557    }
558}
559
560/// Type-safe probability parameter (0.0 to 1.0)
561pub type Probability = BoundedParameter<f64, 0, 1>;
562
563/// Type-safe positive integer parameter
564pub type PositiveInt = BoundedParameter<i32, 1, { i32::MAX as i64 }>;
565
566/// Type-safe random seed parameter
567pub type RandomSeed = BoundedParameter<u64, 0, { i64::MAX }>;
568
569/// Type-safe dimension parameter (1 to maximum dimensions)
570pub type Dimension = BoundedParameter<usize, 1, 10000>;
571
572/// Type-safe batch size parameter
573pub type BatchSize = BoundedParameter<usize, 1, 1000000>;
574
575/// Type-safe confidence level parameter (0.0 to 1.0)
576pub type ConfidenceLevel = BoundedParameter<f64, 0, 1>;
577
578/// Type-safe tolerance parameter (positive values only)
579pub type Tolerance = BoundedParameter<f64, 0, { i64::MAX }>;
580
581/// Type-safe iteration count parameter
582pub type IterationCount = BoundedParameter<u32, 1, { i32::MAX as i64 }>;
583
584/// Implementation for usize bounded parameters
585impl<const MIN: i64, const MAX: i64> BoundedParameter<usize, MIN, MAX> {
586    /// Create a bounded usize parameter
587    pub fn new_usize(value: usize) -> Result<Self, &'static str> {
588        let value_i64 = value as i64;
589        if value_i64 >= MIN && value_i64 <= MAX && value <= i64::MAX as usize {
590            Ok(Self { value })
591        } else {
592            Err("Parameter value out of bounds")
593        }
594    }
595
596    /// Get the usize parameter value
597    pub fn get_usize(&self) -> usize {
598        self.value
599    }
600}
601
602/// Implementation for u32 bounded parameters
603impl<const MIN: i64, const MAX: i64> BoundedParameter<u32, MIN, MAX> {
604    /// Create a bounded u32 parameter
605    pub fn new_u32(value: u32) -> Result<Self, &'static str> {
606        let value_i64 = value as i64;
607        if value_i64 >= MIN && value_i64 <= MAX {
608            Ok(Self { value })
609        } else {
610            Err("Parameter value out of bounds")
611        }
612    }
613
614    /// Get the u32 parameter value
615    pub fn get_u32(&self) -> u32 {
616        self.value
617    }
618}
619
/// Zero-sized selector that names a strategy by a numeric id at the type level.
///
/// The id is turned into a concrete strategy through [`StrategyFromId`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ConstStrategySelector<const STRATEGY_ID: usize>;

/// Maps a compile-time id to a concrete strategy type and value.
pub trait StrategyFromId<const ID: usize> {
    /// Strategy type produced for this id.
    type Strategy;

    /// Build the strategy registered for this id.
    fn create_strategy() -> Self::Strategy;
}
630
631/// Classification strategy mapping (const generic approach)
632impl StrategyFromId<0> for ConstStrategySelector<0> {
633    type Strategy = crate::ClassifierStrategy;
634
635    fn create_strategy() -> Self::Strategy {
636        crate::ClassifierStrategy::MostFrequent
637    }
638}
639
640impl StrategyFromId<1> for ConstStrategySelector<1> {
641    type Strategy = crate::ClassifierStrategy;
642
643    fn create_strategy() -> Self::Strategy {
644        crate::ClassifierStrategy::Stratified
645    }
646}
647
648impl StrategyFromId<2> for ConstStrategySelector<2> {
649    type Strategy = crate::ClassifierStrategy;
650
651    fn create_strategy() -> Self::Strategy {
652        crate::ClassifierStrategy::Uniform
653    }
654}
655
656/// Regression strategy mapping (const generic approach)
657impl StrategyFromId<10> for ConstStrategySelector<10> {
658    type Strategy = crate::RegressorStrategy;
659
660    fn create_strategy() -> Self::Strategy {
661        crate::RegressorStrategy::Mean
662    }
663}
664
665impl StrategyFromId<11> for ConstStrategySelector<11> {
666    type Strategy = crate::RegressorStrategy;
667
668    fn create_strategy() -> Self::Strategy {
669        crate::RegressorStrategy::Median
670    }
671}
672
673impl StrategyFromId<12> for ConstStrategySelector<12> {
674    type Strategy = crate::RegressorStrategy;
675
676    fn create_strategy() -> Self::Strategy {
677        crate::RegressorStrategy::Normal {
678            mean: None,
679            std: None,
680        }
681    }
682}
683
684/// Compile-time estimator with const generic strategy selection
685#[derive(Debug, Clone)]
686pub struct ConstGenericEstimator<State, Task, const STRATEGY_ID: usize>
687where
688    State: EstimatorState,
689    Task: TaskType,
690    ConstStrategySelector<STRATEGY_ID>: StrategyFromId<STRATEGY_ID>,
691{
692    random_state: Option<u64>,
693    _state: PhantomData<State>,
694    _task: PhantomData<Task>,
695}
696
697impl<Task, const STRATEGY_ID: usize> Default for ConstGenericEstimator<Untrained, Task, STRATEGY_ID>
698where
699    Task: TaskType,
700    ConstStrategySelector<STRATEGY_ID>: StrategyFromId<STRATEGY_ID>,
701{
702    fn default() -> Self {
703        Self::new()
704    }
705}
706
707impl<Task, const STRATEGY_ID: usize> ConstGenericEstimator<Untrained, Task, STRATEGY_ID>
708where
709    Task: TaskType,
710    ConstStrategySelector<STRATEGY_ID>: StrategyFromId<STRATEGY_ID>,
711{
712    /// Create new const generic estimator
713    pub fn new() -> Self {
714        Self {
715            random_state: None,
716            _state: PhantomData,
717            _task: PhantomData,
718        }
719    }
720
721    /// Set random state
722    pub fn with_random_state(mut self, seed: u64) -> Self {
723        self.random_state = Some(seed);
724        self
725    }
726
727    /// Get the strategy at compile time
728    pub fn strategy(
729    ) -> <ConstStrategySelector<STRATEGY_ID> as StrategyFromId<STRATEGY_ID>>::Strategy {
730        ConstStrategySelector::<STRATEGY_ID>::create_strategy()
731    }
732}
733
/// Transparent single-field wrapper with no behaviour of its own.
#[derive(Debug, Clone, Copy)]
pub struct ZeroCostWrapper<T> {
    inner: T,
}

impl<T> ZeroCostWrapper<T> {
    /// Wrap `inner`.
    #[inline(always)]
    pub const fn new(inner: T) -> Self {
        ZeroCostWrapper { inner }
    }

    /// Unwrap and return the owned value.
    #[inline(always)]
    pub fn into_inner(self) -> T {
        let ZeroCostWrapper { inner } = self;
        inner
    }

    /// Borrow the wrapped value.
    #[inline(always)]
    pub const fn get(&self) -> &T {
        &self.inner
    }

    /// Apply `f` to the wrapped value and wrap the result.
    #[inline(always)]
    pub fn map<U, F>(self, f: F) -> ZeroCostWrapper<U>
    where
        F: FnOnce(T) -> U,
    {
        let mapped = f(self.inner);
        ZeroCostWrapper { inner: mapped }
    }
}
768
/// Zero-sized tag recording three statistical properties of a strategy as
/// const generic booleans: deterministic, stateless, and requires fitting.
#[derive(Debug, Clone, Copy)]
pub struct StatisticalPhantom<
    const IS_DETERMINISTIC: bool,
    const IS_STATELESS: bool,
    const REQUIRES_FITTING: bool,
> {
    _marker: PhantomData<()>,
}

impl<const DET: bool, const STATELESS: bool, const FIT: bool> Default
    for StatisticalPhantom<DET, STATELESS, FIT>
{
    /// Same as [`StatisticalPhantom::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<const DET: bool, const STATELESS: bool, const FIT: bool>
    StatisticalPhantom<DET, STATELESS, FIT>
{
    /// Create the (zero-sized) tag value.
    pub const fn new() -> Self {
        StatisticalPhantom {
            _marker: PhantomData,
        }
    }

    /// Value of the `IS_DETERMINISTIC` parameter.
    pub const fn is_deterministic() -> bool {
        DET
    }

    /// Value of the `IS_STATELESS` parameter.
    pub const fn is_stateless() -> bool {
        STATELESS
    }

    /// Value of the `REQUIRES_FITTING` parameter.
    pub const fn requires_fitting() -> bool {
        FIT
    }
}

/// Tag for strategies that are deterministic, stateless and need no fitting.
pub type SimpleDeterministicStrategy = StatisticalPhantom<true, true, false>;

/// Tag for stochastic, stateful strategies that must be fitted.
pub type StochasticFittedStrategy = StatisticalPhantom<false, false, true>;
818
/// Exposes a strategy's statistical properties as associated constants.
pub trait StrategyProperties {
    /// Whether the strategy is deterministic.
    const IS_DETERMINISTIC: bool;
    /// Whether the strategy is stateless.
    const IS_STATELESS: bool;
    /// Whether the strategy must be fitted before predicting.
    const REQUIRES_FITTING: bool;

    /// Type-level tag describing these properties, e.g. a `StatisticalPhantom`
    /// instantiation.
    type Phantom: 'static;

    /// Return a value of the tag type.
    fn phantom() -> Self::Phantom;
}
830
831/// Enhanced type-safe estimator with statistical properties
832#[derive(Debug, Clone)]
833pub struct StatisticallyTypedEstimator<State, Task, Strategy, Properties>
834where
835    State: EstimatorState,
836    Task: TaskType,
837    Strategy: StrategyValid<Task> + StrategyProperties<Phantom = Properties>,
838    Properties: 'static,
839{
840    strategy: Strategy,
841    random_state: Option<u64>,
842    _state: PhantomData<State>,
843    _task: PhantomData<Task>,
844    _properties: PhantomData<Properties>,
845}
846
847impl<Task, Strategy, Properties> StatisticallyTypedEstimator<Untrained, Task, Strategy, Properties>
848where
849    Task: TaskType,
850    Strategy: StrategyValid<Task> + StrategyProperties<Phantom = Properties> + Clone,
851    Properties: 'static,
852{
853    /// Create new statistically typed estimator
854    pub fn new(strategy: Strategy) -> Self {
855        Self {
856            strategy,
857            random_state: None,
858            _state: PhantomData,
859            _task: PhantomData,
860            _properties: PhantomData,
861        }
862    }
863
864    /// Check if this estimator's strategy is deterministic (compile-time)
865    pub const fn is_deterministic() -> bool {
866        Strategy::IS_DETERMINISTIC
867    }
868
869    /// Check if this estimator's strategy is stateless (compile-time)
870    pub const fn is_stateless() -> bool {
871        Strategy::IS_STATELESS
872    }
873
874    /// Check if this estimator requires fitting (compile-time)
875    pub const fn requires_fitting() -> bool {
876        Strategy::REQUIRES_FITTING
877    }
878}
879
/// Compile-time configuration validation macro.
///
/// `validate_estimator_config!(Estimator, State, Task)` fails to compile unless
/// all of the following hold:
/// * `State: EstimatorState`
/// * `Task: TaskType`
/// * `Estimator: TypeSafeEstimator<State, Task>`
///
/// The checks sit in a closure stored in an unnamed `const`. The closure is
/// type-checked, which enforces the bounds, but it is never evaluated. The
/// previous expansion called non-`const fn`s directly inside a
/// `const _: () = { ... }` initializer, which is rejected by the compiler, and
/// it never checked `$estimator_type`.
///
/// The trait names are resolved at the call site, so `EstimatorState`,
/// `TaskType` and `TypeSafeEstimator` must be in scope there.
#[macro_export]
macro_rules! validate_estimator_config {
    ($estimator_type:ty, $state:ty, $task:ty) => {
        const _: fn() = || {
            fn _assert_estimator_state<S: EstimatorState>() {}
            fn _assert_task_type<T: TaskType>() {}
            fn _assert_type_safe_estimator<E, S, T>()
            where
                E: TypeSafeEstimator<S, T>,
                S: EstimatorState,
                T: TaskType,
            {
            }

            _assert_estimator_state::<$state>();
            _assert_task_type::<$task>();
            _assert_type_safe_estimator::<$estimator_type, $state, $task>();
        };
    };
}
900
/// Compile-time strategy compatibility validation macro.
///
/// `assert_strategy_compatible!(Strategy, Task)` fails to compile unless
/// `Strategy: StrategyValid<Task>` (with `Task: TaskType`).
///
/// The check sits in a never-evaluated closure stored in an unnamed `const`.
/// The previous expansion called a non-`const fn` inside a
/// `const _: () = { ... }` initializer, which does not compile. `StrategyValid`
/// and `TaskType` are resolved at the call site.
#[macro_export]
macro_rules! assert_strategy_compatible {
    ($strategy:ty, $task:ty) => {
        const _: fn() = || {
            fn _assert_compatible<S, T>()
            where
                S: StrategyValid<T>,
                T: TaskType,
            {
            }

            _assert_compatible::<$strategy, $task>();
        };
    };
}
911
/// Feature flag fixed at compile time through a const generic `bool`.
#[derive(Debug, Clone, Copy)]
pub struct CompileTimeFeature<const ENABLED: bool>;

impl<const ENABLED: bool> CompileTimeFeature<ENABLED> {
    /// Return the value of `ENABLED`.
    pub const fn is_enabled() -> bool {
        ENABLED
    }

    /// Run `f` and wrap its result in `Some` when the feature is enabled;
    /// otherwise return `None` without calling `f`.
    #[inline(always)]
    pub fn when_enabled<F, R>(f: F) -> Option<R>
    where
        F: FnOnce() -> R,
    {
        ENABLED.then(f)
    }
}
935
936/// Type-safe statistical operations with compile-time guarantees
937pub trait TypeSafeStatisticalOps<T> {
938    /// Compute mean with type safety
939    fn safe_mean(&self) -> Option<T>;
940
941    /// Compute variance with type safety
942    fn safe_variance(&self) -> Option<T>;
943
944    /// Compute standard deviation with type safety
945    fn safe_std(&self) -> Option<T>;
946}
947
948impl TypeSafeStatisticalOps<f64> for Array1<f64> {
949    fn safe_mean(&self) -> Option<f64> {
950        if self.is_empty() {
951            None
952        } else {
953            Some(self.mean().unwrap_or(0.0))
954        }
955    }
956
957    fn safe_variance(&self) -> Option<f64> {
958        if self.len() < 2 {
959            None
960        } else {
961            Some(self.var(1.0)) // Bessel's correction
962        }
963    }
964
965    fn safe_std(&self) -> Option<f64> {
966        self.safe_variance().map(|v| v.sqrt())
967    }
968}
969
/// `#[repr(C)]` wrapper around a single value.
///
/// `repr(C)` fixes `data` at offset 0. The trailing `[u8; 0]` field has size 0
/// and alignment 1, so it changes neither the size nor the alignment:
/// `OptimizedLayout<T>` is laid out exactly like `T`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedLayout<T> {
    data: T,
    _pad: [u8; 0],
}

impl<T> OptimizedLayout<T> {
    /// Wrap `data`.
    #[inline(always)]
    pub const fn new(data: T) -> Self {
        OptimizedLayout { _pad: [], data }
    }

    /// Unwrap and return the owned value.
    #[inline(always)]
    pub fn into_data(self) -> T {
        let OptimizedLayout { data, .. } = self;
        data
    }

    /// Borrow the wrapped value.
    #[inline(always)]
    pub const fn data(&self) -> &T {
        &self.data
    }
}
997
998/// Trait for estimators with compile-time guarantees
999pub trait TypeSafeEstimator<State: EstimatorState, Task: TaskType> {
1000    type Strategy: StrategyValid<Task>;
1001
1002    /// Get the current state (compile-time known)
1003    fn state(&self) -> State;
1004
1005    /// Get the task type (compile-time known)
1006    fn task_type(&self) -> Task;
1007
1008    /// Validate the estimator configuration
1009    fn validate(&self) -> Result<(), &'static str>;
1010}
1011
1012impl<State, Task, Strategy> TypeSafeEstimator<State, Task>
1013    for TypeSafeDummyEstimator<State, Task, Strategy>
1014where
1015    State: EstimatorState + Default,
1016    Task: TaskType + Default,
1017    Strategy: StrategyValid<Task>,
1018{
1019    type Strategy = Strategy;
1020
1021    fn state(&self) -> State {
1022        State::default()
1023    }
1024
1025    fn task_type(&self) -> Task {
1026        Task::default()
1027    }
1028
1029    fn validate(&self) -> Result<(), &'static str> {
1030        // Validation logic here
1031        Ok(())
1032    }
1033}
1034
1035/// Implement Default for state markers to enable compile-time checks
1036impl Default for Untrained {
1037    fn default() -> Self {
1038        Untrained
1039    }
1040}
1041
1042impl Default for Trained {
1043    fn default() -> Self {
1044        Trained
1045    }
1046}
1047
1048impl Default for Classification {
1049    fn default() -> Self {
1050        Classification
1051    }
1052}
1053
1054impl Default for Regression {
1055    fn default() -> Self {
1056        Regression
1057    }
1058}
1059
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ClassifierStrategy, RegressorStrategy};
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_type_safe_classifier_creation() {
        let strategy = ClassifierStrategy::MostFrequent;
        let classifier = TypeSafeDummyEstimator::<Untrained, Classification, _>::new(strategy);

        // This should compile - correct types
        assert!(classifier.validate().is_ok());
    }

    #[test]
    fn test_type_safe_regressor_creation() {
        let strategy = RegressorStrategy::Mean;
        let regressor = TypeSafeDummyEstimator::<Untrained, Regression, _>::new(strategy);

        // This should compile - correct types
        assert!(regressor.validate().is_ok());
    }

    #[test]
    fn test_bounded_parameters() {
        // Valid probability
        let prob = Probability::new_f64(0.5);
        assert!(prob.is_ok());
        assert_eq!(prob.unwrap().get_f64(), 0.5);

        // Invalid probability
        let invalid_prob = Probability::new_f64(1.5);
        assert!(invalid_prob.is_err());

        // Valid positive integer
        let pos_int = PositiveInt::new_i32(42);
        assert!(pos_int.is_ok());
        assert_eq!(pos_int.unwrap().get(), 42);

        // Invalid positive integer
        let invalid_int = PositiveInt::new_i32(-1);
        assert!(invalid_int.is_err());
    }

    #[test]
    fn test_validated_strategy() {
        let strategy = ClassifierStrategy::MostFrequent;
        let validated = ValidatedStrategy::<_, Classification>::new(strategy.clone());

        assert_eq!(
            format!("{:?}", validated.strategy()),
            format!("{:?}", strategy)
        );

        let extracted = validated.into_strategy();
        assert_eq!(format!("{:?}", extracted), format!("{:?}", strategy));
    }

    #[test]
    fn test_type_safe_parameters() {
        let params = TypeSafeParameters::new(42.0);
        assert_eq!(*params.get(), 42.0);
        assert_eq!(params.into_inner(), 42.0);
    }

    #[test]
    fn test_estimator_state_transitions() {
        let strategy = ClassifierStrategy::MostFrequent;
        let untrained = TypeSafeDummyEstimator::<Untrained, Classification, _>::new(strategy);

        // Can't predict without fitting - this is ensured by the type system
        // fitted.predict() would require a TypeSafeFittedClassifier

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let fitted = untrained.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x);

        assert!(predictions.is_ok());
        assert_eq!(predictions.unwrap().len(), 4);
    }

    #[test]
    fn test_const_generic_estimator() {
        // Most frequent classifier (ID = 0); constructing it checks the
        // const-generic instantiation compiles.
        let _classifier = ConstGenericEstimator::<Untrained, Classification, 0>::new();
        let strategy = ConstGenericEstimator::<Untrained, Classification, 0>::strategy();

        assert_eq!(format!("{:?}", strategy), "MostFrequent");

        // Mean regressor (ID = 10)
        let _regressor = ConstGenericEstimator::<Untrained, Regression, 10>::new();
        let reg_strategy = ConstGenericEstimator::<Untrained, Regression, 10>::strategy();

        assert_eq!(format!("{:?}", reg_strategy), "Mean");
    }

    #[test]
    fn test_zero_cost_wrapper() {
        let wrapper = ZeroCostWrapper::new(42);
        assert_eq!(*wrapper.get(), 42);
        assert_eq!(wrapper.into_inner(), 42);

        let mapped = ZeroCostWrapper::new(10).map(|x| x * 2);
        assert_eq!(mapped.into_inner(), 20);
    }

    #[test]
    fn test_statistical_phantom() {
        // Test deterministic, stateless strategy (construction must compile)
        let _phantom = SimpleDeterministicStrategy::new();
        assert!(SimpleDeterministicStrategy::is_deterministic());
        assert!(SimpleDeterministicStrategy::is_stateless());
        assert!(!SimpleDeterministicStrategy::requires_fitting());

        // Test stochastic, fitted strategy (construction must compile)
        let _stochastic = StochasticFittedStrategy::new();
        assert!(!StochasticFittedStrategy::is_deterministic());
        assert!(!StochasticFittedStrategy::is_stateless());
        assert!(StochasticFittedStrategy::requires_fitting());
    }

    #[test]
    fn test_compile_time_feature() {
        // Enabled feature
        type EnabledFeature = CompileTimeFeature<true>;
        assert!(EnabledFeature::is_enabled());

        let result = EnabledFeature::when_enabled(|| 42);
        assert_eq!(result, Some(42));

        // Disabled feature
        type DisabledFeature = CompileTimeFeature<false>;
        assert!(!DisabledFeature::is_enabled());

        let no_result = DisabledFeature::when_enabled(|| 42);
        assert_eq!(no_result, None);
    }

    #[test]
    fn test_type_safe_statistical_ops() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_eq!(data.safe_mean(), Some(3.0));

        // Sample variance (ddof = 1): sum of squared deviations 10 / (5 - 1) = 2.5.
        let variance = data
            .safe_variance()
            .expect("five elements have a sample variance");
        assert!((variance - 2.5).abs() < 1e-12);

        let std = data
            .safe_std()
            .expect("five elements have a sample standard deviation");
        assert!((std - 2.5_f64.sqrt()).abs() < 1e-12);

        // Empty array
        let empty: Array1<f64> = array![];
        assert_eq!(empty.safe_mean(), None);
        assert_eq!(empty.safe_variance(), None);
        assert_eq!(empty.safe_std(), None);

        // Single element
        let single = array![42.0];
        assert_eq!(single.safe_mean(), Some(42.0));
        assert_eq!(single.safe_variance(), None); // Need at least 2 elements
        assert_eq!(single.safe_std(), None);
    }

    #[test]
    fn test_optimized_layout() {
        let layout = OptimizedLayout::new(42);
        assert_eq!(*layout.data(), 42);
        assert_eq!(layout.into_data(), 42);

        // The wrapper must not change size or alignment relative to its payload.
        assert_eq!(
            std::mem::size_of::<OptimizedLayout<u64>>(),
            std::mem::size_of::<u64>()
        );
        assert_eq!(
            std::mem::align_of::<OptimizedLayout<u64>>(),
            std::mem::align_of::<u64>()
        );
    }

    // Compile-time test - uncomment to verify compile-time checking
    // #[test]
    // fn test_compile_time_validation() {
    //     // This should not compile - wrong task type for strategy
    //     // let invalid = TypeSafeDummyEstimator::<Untrained, Regression, ClassifierStrategy>::new(
    //     //     ClassifierStrategy::MostFrequent
    //     // );
    //
    //     // Test compile-time macros
    //     validate_estimator_config!(TypeSafeDummyEstimator<Untrained, Classification, ClassifierStrategy>, Untrained, Classification);
    //     assert_strategy_compatible!(ClassifierStrategy, Classification);
    // }
}