1use scirs2_core::ndarray::{Array1, Array2};
7use sklears_core::error::SklearsError;
8use sklears_core::traits::{Fit, Predict};
9use std::marker::PhantomData;
10
/// Marker trait for the training-state type parameter of the type-safe estimators.
///
/// Implemented by [`Untrained`] and [`Trained`].
pub trait EstimatorState {}

/// Type-level marker for an estimator that has not been fitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Untrained;
impl EstimatorState for Untrained {}

/// Type-level marker for an estimator that has been fitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Trained;
impl EstimatorState for Trained {}

/// Marker trait for the learning-task type parameter.
///
/// Implemented by [`Classification`] and [`Regression`].
pub trait TaskType {}

/// Type-level marker for classification tasks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Classification;
impl TaskType for Classification {}

/// Type-level marker for regression tasks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Regression;
impl TaskType for Regression {}
36
37pub trait StrategyValid<T: TaskType> {}
39
40impl StrategyValid<Classification> for crate::ClassifierStrategy {}
42
43impl StrategyValid<Regression> for crate::RegressorStrategy {}
45
46#[derive(Debug, Clone)]
48pub struct TypeSafeDummyEstimator<State, Task, Strategy>
49where
50 State: EstimatorState,
51 Task: TaskType,
52 Strategy: StrategyValid<Task>,
53{
54 strategy: Strategy,
55 random_state: Option<u64>,
56 _state: PhantomData<State>,
57 _task: PhantomData<Task>,
58}
59
60#[derive(Debug, Clone)]
62pub struct TypeSafeFittedClassifier<Strategy>
63where
64 Strategy: StrategyValid<Classification>,
65{
66 strategy: Strategy,
67 fitted_data: ClassificationFittedData,
68 random_state: Option<u64>,
69}
70
71#[derive(Debug, Clone)]
73pub struct TypeSafeFittedRegressor<Strategy>
74where
75 Strategy: StrategyValid<Regression>,
76{
77 strategy: Strategy,
78 fitted_data: RegressionFittedData,
79 random_state: Option<u64>,
80}
81
/// Label statistics recorded when fitting a classification estimator.
#[derive(Debug, Clone)]
pub struct ClassificationFittedData {
    /// Number of training samples per class label.
    pub class_counts: std::collections::HashMap<i32, usize>,
    /// Fraction of training samples per class label.
    pub class_priors: std::collections::HashMap<i32, f64>,
    /// Label selected as the most frequent one during fitting.
    pub most_frequent_class: i32,
    /// Number of rows in the training feature matrix.
    pub n_samples: usize,
    /// Number of columns in the training feature matrix.
    pub n_features: usize,
}

/// Target statistics recorded when fitting a regression estimator.
#[derive(Debug, Clone)]
pub struct RegressionFittedData {
    /// Arithmetic mean of the training targets.
    pub target_mean: f64,
    /// Median of the training targets.
    pub target_median: f64,
    /// Standard deviation of the training targets.
    pub target_std: f64,
    /// Smallest training target.
    pub target_min: f64,
    /// Largest training target.
    pub target_max: f64,
    /// Number of rows in the training feature matrix.
    pub n_samples: usize,
    /// Number of columns in the training feature matrix.
    pub n_features: usize,
}
115
116pub trait EstimatorConfig {
118 type TaskType: TaskType;
119 type Strategy: StrategyValid<Self::TaskType>;
120
121 fn validate() -> Result<(), &'static str>;
123
124 fn create_strategy() -> Self::Strategy;
126}
127
128#[derive(Debug)]
130pub struct ClassificationConfig<S: StrategyValid<Classification>> {
131 _strategy: PhantomData<S>,
132}
133
134#[derive(Debug)]
136pub struct RegressionConfig<S: StrategyValid<Regression>> {
137 _strategy: PhantomData<S>,
138}
139
140impl<S: StrategyValid<Classification>> EstimatorConfig for ClassificationConfig<S> {
141 type TaskType = Classification;
142 type Strategy = S;
143
144 fn validate() -> Result<(), &'static str> {
145 Ok(())
147 }
148
149 fn create_strategy() -> Self::Strategy {
150 panic!("Strategy creation must be implemented per type")
153 }
154}
155
156impl<S: StrategyValid<Regression>> EstimatorConfig for RegressionConfig<S> {
157 type TaskType = Regression;
158 type Strategy = S;
159
160 fn validate() -> Result<(), &'static str> {
161 Ok(())
163 }
164
165 fn create_strategy() -> Self::Strategy {
166 panic!("Strategy creation must be implemented per type")
168 }
169}
170
171#[derive(Debug)]
173pub struct TypeSafeEstimatorBuilder<Task, Strategy>
174where
175 Task: TaskType,
176 Strategy: StrategyValid<Task>,
177{
178 strategy: Strategy,
179 random_state: Option<u64>,
180 _task: PhantomData<Task>,
181}
182
183impl<Strategy> TypeSafeDummyEstimator<Untrained, Classification, Strategy>
184where
185 Strategy: StrategyValid<Classification> + Clone,
186{
187 pub fn new(strategy: Strategy) -> Self {
189 Self {
190 strategy,
191 random_state: None,
192 _state: PhantomData,
193 _task: PhantomData,
194 }
195 }
196
197 pub fn with_random_state(mut self, seed: u64) -> Self {
199 self.random_state = Some(seed);
200 self
201 }
202
203 pub fn build(self) -> Result<Self, &'static str> {
205 self.validate_configuration()?;
207 Ok(self)
208 }
209
210 fn validate_configuration(&self) -> Result<(), &'static str> {
212 Ok(())
214 }
215}
216
217impl<Strategy> TypeSafeDummyEstimator<Untrained, Regression, Strategy>
218where
219 Strategy: StrategyValid<Regression> + Clone,
220{
221 pub fn new(strategy: Strategy) -> Self {
223 Self {
224 strategy,
225 random_state: None,
226 _state: PhantomData,
227 _task: PhantomData,
228 }
229 }
230
231 pub fn with_random_state(mut self, seed: u64) -> Self {
233 self.random_state = Some(seed);
234 self
235 }
236
237 pub fn build(self) -> Result<Self, &'static str> {
239 self.validate_configuration()?;
241 Ok(self)
242 }
243
244 fn validate_configuration(&self) -> Result<(), &'static str> {
246 Ok(())
248 }
249}
250
251impl<Strategy> Fit<Array2<f64>, Array1<i32>, TypeSafeFittedClassifier<Strategy>>
253 for TypeSafeDummyEstimator<Untrained, Classification, Strategy>
254where
255 Strategy: StrategyValid<Classification> + Clone + Into<crate::ClassifierStrategy> + Send + Sync,
256{
257 type Fitted = TypeSafeFittedClassifier<Strategy>;
258 fn fit(
259 self,
260 x: &Array2<f64>,
261 y: &Array1<i32>,
262 ) -> Result<TypeSafeFittedClassifier<Strategy>, SklearsError> {
263 if x.nrows() != y.len() {
264 return Err(SklearsError::ShapeMismatch {
265 expected: format!("{} samples", x.nrows()),
266 actual: format!("{} labels", y.len()),
267 });
268 }
269
270 let mut class_counts = std::collections::HashMap::new();
272 for &class in y.iter() {
273 *class_counts.entry(class).or_insert(0) += 1;
274 }
275
276 let mut class_priors = std::collections::HashMap::new();
277 let total_samples = y.len() as f64;
278 for (&class, &count) in &class_counts {
279 class_priors.insert(class, count as f64 / total_samples);
280 }
281
282 let most_frequent_class = class_counts
283 .iter()
284 .max_by_key(|(_, &count)| count)
285 .map(|(&class, _)| class)
286 .unwrap_or(0);
287
288 let fitted_data = ClassificationFittedData {
289 class_counts,
290 class_priors,
291 most_frequent_class,
292 n_samples: x.nrows(),
293 n_features: x.ncols(),
294 };
295
296 Ok(TypeSafeFittedClassifier {
297 strategy: self.strategy.clone(),
298 fitted_data,
299 random_state: self.random_state,
300 })
301 }
302}
303
304impl<Strategy> Fit<Array2<f64>, Array1<f64>, TypeSafeFittedRegressor<Strategy>>
306 for TypeSafeDummyEstimator<Untrained, Regression, Strategy>
307where
308 Strategy: StrategyValid<Regression> + Clone + Into<crate::RegressorStrategy> + Send + Sync,
309{
310 type Fitted = TypeSafeFittedRegressor<Strategy>;
311 fn fit(
312 self,
313 x: &Array2<f64>,
314 y: &Array1<f64>,
315 ) -> Result<TypeSafeFittedRegressor<Strategy>, SklearsError> {
316 if x.nrows() != y.len() {
317 return Err(SklearsError::ShapeMismatch {
318 expected: format!("{} samples", x.nrows()),
319 actual: format!("{} labels", y.len()),
320 });
321 }
322
323 let target_mean = y.mean().unwrap_or(0.0);
325 let target_std = y.std(0.0);
326 let target_min = y.iter().fold(f64::INFINITY, |a, &b| a.min(b));
327 let target_max = y.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
328
329 let mut sorted_targets = y.to_vec();
331 sorted_targets.sort_by(|a, b| a.partial_cmp(b).unwrap());
332 let target_median = if sorted_targets.len() % 2 == 0 {
333 let mid = sorted_targets.len() / 2;
334 (sorted_targets[mid - 1] + sorted_targets[mid]) / 2.0
335 } else {
336 sorted_targets[sorted_targets.len() / 2]
337 };
338
339 let fitted_data = RegressionFittedData {
340 target_mean,
341 target_median,
342 target_std,
343 target_min,
344 target_max,
345 n_samples: x.nrows(),
346 n_features: x.ncols(),
347 };
348
349 Ok(TypeSafeFittedRegressor {
350 strategy: self.strategy.clone(),
351 fitted_data,
352 random_state: self.random_state,
353 })
354 }
355}
356
357impl<Strategy> Predict<Array2<f64>, Array1<i32>> for TypeSafeFittedClassifier<Strategy>
359where
360 Strategy: StrategyValid<Classification> + Clone + Into<crate::ClassifierStrategy>,
361{
362 fn predict(&self, x: &Array2<f64>) -> Result<Array1<i32>, SklearsError> {
363 let legacy_strategy: crate::ClassifierStrategy = self.strategy.clone().into();
365 let legacy_classifier = crate::DummyClassifier::new(legacy_strategy);
366
367 let fake_x = Array2::zeros((self.fitted_data.n_samples, self.fitted_data.n_features));
369 let fake_y: Array1<i32> = Array1::from_iter(
370 self.fitted_data
371 .class_counts
372 .iter()
373 .flat_map(|(&class, &count)| std::iter::repeat(class).take(count)),
374 );
375
376 let fitted_legacy = legacy_classifier.fit(&fake_x, &fake_y)?;
377 fitted_legacy.predict(x)
378 }
379}
380
381impl<Strategy> Predict<Array2<f64>, Array1<f64>> for TypeSafeFittedRegressor<Strategy>
383where
384 Strategy: StrategyValid<Regression> + Clone + Into<crate::RegressorStrategy>,
385{
386 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>, SklearsError> {
387 let legacy_strategy: crate::RegressorStrategy = self.strategy.clone().into();
389 let legacy_regressor = crate::DummyRegressor::new(legacy_strategy);
390
391 let fake_x = Array2::zeros((self.fitted_data.n_samples, self.fitted_data.n_features));
393 let fake_y = Array1::from_elem(self.fitted_data.n_samples, self.fitted_data.target_mean);
394
395 let fitted_legacy = legacy_regressor.fit(&fake_x, &fake_y)?;
396 fitted_legacy.predict(x)
397 }
398}
399
/// Fails compilation unless strategy type `$strategy` is valid for task `$task`.
///
/// Expands to an item, so it can be used at module level or inside a function body.
/// The check sits in a closure that is never called, so it is purely type-level.
/// `StrategyValid` and `TaskType` must be in scope at the call site.
#[macro_export]
macro_rules! validate_strategy_at_compile_time {
    ($strategy:ty, $task:ty) => {
        const _: fn() = || {
            fn require_valid_pairing<S, T>()
            where
                T: TaskType,
                S: StrategyValid<T>,
            {
            }
            require_valid_pairing::<$strategy, $task>();
        };
    };
}
410
/// Shorthand for constructing an untrained `TypeSafeDummyEstimator`.
///
/// Use `type_safe_estimator!(classification, strategy)` or
/// `type_safe_estimator!(regression, strategy)`. `TypeSafeDummyEstimator`, `Untrained`,
/// `Classification` and `Regression` must be in scope at the call site.
#[macro_export]
macro_rules! type_safe_estimator {
    (classification, $strategy:expr) => {{
        <TypeSafeDummyEstimator<Untrained, Classification, _>>::new($strategy)
    }};
    (regression, $strategy:expr) => {{
        <TypeSafeDummyEstimator<Untrained, Regression, _>>::new($strategy)
    }};
}
421
422pub struct ValidatedStrategy<S, T>
424where
425 S: StrategyValid<T>,
426 T: TaskType,
427{
428 strategy: S,
429 _task: PhantomData<T>,
430}
431
432impl<S, T> ValidatedStrategy<S, T>
433where
434 S: StrategyValid<T>,
435 T: TaskType,
436{
437 pub fn new(strategy: S) -> Self {
439 Self {
440 strategy,
441 _task: PhantomData,
442 }
443 }
444
445 pub fn into_strategy(self) -> S {
447 self.strategy
448 }
449
450 pub fn strategy(&self) -> &S {
452 &self.strategy
453 }
454}
455
/// Self-validation hook for parameter objects.
pub trait ParameterValidation {
    /// Error value reported when validation fails.
    type Error;

    /// Returns `Ok(())` if `self` is valid, otherwise the implementor's error.
    fn validate(&self) -> Result<(), Self::Error>;
}
463
/// Thin owning wrapper around a parameter value.
#[derive(Debug, Clone)]
pub struct TypeSafeParameters<T> {
    value: T,
}

impl<T> TypeSafeParameters<T> {
    /// Wraps `value`.
    pub fn new(value: T) -> Self {
        TypeSafeParameters { value }
    }

    /// Borrows the wrapped value.
    pub fn get(&self) -> &T {
        let Self { value } = self;
        value
    }

    /// Unwraps and returns the owned value.
    pub fn into_inner(self) -> T {
        let Self { value } = self;
        value
    }
}
486
/// A value of type `T` that lies in the inclusive range `[MIN, MAX]`.
///
/// The bounds are `i64` const generics. Values are checked when built through
/// [`BoundedParameter::new`] or one of the type-specific `new_*` constructors. Each of
/// these returns `Err` instead of storing an out-of-range value.
#[derive(Debug, Clone, Copy)]
pub struct BoundedParameter<T, const MIN: i64, const MAX: i64> {
    value: T,
}

impl<T, const MIN: i64, const MAX: i64> BoundedParameter<T, MIN, MAX> {
    /// Returns `true` when the integer `candidate` lies in `[MIN, MAX]`.
    ///
    /// On targets up to 64 bits, `i128` holds every `i32`, `u32`, `u64` and `usize` value
    /// exactly. Widening before comparing therefore cannot wrap around the way an
    /// `as i64` cast of a large unsigned value does.
    const fn bounds_admit(candidate: i128) -> bool {
        MIN as i128 <= candidate && candidate <= MAX as i128
    }
}

impl<T, const MIN: i64, const MAX: i64> BoundedParameter<T, MIN, MAX>
where
    T: PartialOrd + Copy + TryFrom<i64>,
{
    /// Checks `value` against the bounds after converting the bounds into `T`.
    ///
    /// # Errors
    ///
    /// - `"Invalid minimum bound"` / `"Invalid maximum bound"` when `MIN` / `MAX` cannot be
    ///   represented as `T` (checked in that order).
    /// - `"Parameter value out of bounds"` when `value` lies outside the range.
    pub fn new(value: T) -> Result<Self, &'static str> {
        let lower: T = T::try_from(MIN).map_err(|_| "Invalid minimum bound")?;
        let upper: T = T::try_from(MAX).map_err(|_| "Invalid maximum bound")?;
        let inside = value >= lower && value <= upper;
        if !inside {
            return Err("Parameter value out of bounds");
        }
        Ok(Self { value })
    }

    /// Returns the stored value.
    pub fn get(&self) -> T {
        self.value
    }
}

impl<const MIN: i64, const MAX: i64> BoundedParameter<i32, MIN, MAX> {
    /// Creates an `i32` parameter, rejecting values outside `[MIN, MAX]`.
    pub fn new_i32(value: i32) -> Result<Self, &'static str> {
        if Self::bounds_admit(i128::from(value)) {
            Ok(Self { value })
        } else {
            Err("Parameter value out of bounds")
        }
    }
}

impl<const MIN: i64, const MAX: i64> BoundedParameter<u64, MIN, MAX> {
    /// Creates a `u64` parameter, rejecting values outside `[MIN, MAX]`.
    pub fn new_u64(value: u64) -> Result<Self, &'static str> {
        if Self::bounds_admit(i128::from(value)) {
            Ok(Self { value })
        } else {
            Err("Parameter value out of bounds")
        }
    }
}

impl<const MIN: i64, const MAX: i64> BoundedParameter<usize, MIN, MAX> {
    /// Creates a `usize` parameter, rejecting values outside `[MIN, MAX]`.
    pub fn new_usize(value: usize) -> Result<Self, &'static str> {
        if Self::bounds_admit(value as i128) {
            Ok(Self { value })
        } else {
            Err("Parameter value out of bounds")
        }
    }

    /// Returns the stored value.
    pub fn get_usize(&self) -> usize {
        self.value
    }
}

impl<const MIN: i64, const MAX: i64> BoundedParameter<u32, MIN, MAX> {
    /// Creates a `u32` parameter, rejecting values outside `[MIN, MAX]`.
    pub fn new_u32(value: u32) -> Result<Self, &'static str> {
        if Self::bounds_admit(i128::from(value)) {
            Ok(Self { value })
        } else {
            Err("Parameter value out of bounds")
        }
    }

    /// Returns the stored value.
    pub fn get_u32(&self) -> u32 {
        self.value
    }
}

impl<const MIN: i64, const MAX: i64> BoundedParameter<f64, MIN, MAX> {
    /// Creates an `f64` parameter, rejecting values outside `[MIN as f64, MAX as f64]`.
    ///
    /// NaN fails both comparisons and is therefore rejected.
    pub fn new_f64(value: f64) -> Result<Self, &'static str> {
        let (lower, upper) = (MIN as f64, MAX as f64);
        if value >= lower && value <= upper {
            Ok(Self { value })
        } else {
            Err("Parameter value out of bounds")
        }
    }

    /// Returns the stored value.
    pub fn get_f64(&self) -> f64 {
        self.value
    }
}

/// Probability in `[0, 1]`.
pub type Probability = BoundedParameter<f64, 0, 1>;

/// Strictly positive `i32`.
pub type PositiveInt = BoundedParameter<i32, 1, { i32::MAX as i64 }>;

/// Seed restricted to `[0, i64::MAX]`.
pub type RandomSeed = BoundedParameter<u64, 0, { i64::MAX }>;

/// Dimension in `[1, 10000]`.
pub type Dimension = BoundedParameter<usize, 1, 10000>;

/// Batch size in `[1, 1000000]`.
pub type BatchSize = BoundedParameter<usize, 1, 1000000>;

/// Confidence level in `[0, 1]`.
pub type ConfidenceLevel = BoundedParameter<f64, 0, 1>;

/// Non-negative tolerance.
pub type Tolerance = BoundedParameter<f64, 0, { i64::MAX }>;

/// Iteration count in `[1, i32::MAX]`.
pub type IterationCount = BoundedParameter<u32, 1, { i32::MAX as i64 }>;
619
620#[derive(Debug, Clone, Copy, PartialEq, Eq)]
622pub struct ConstStrategySelector<const STRATEGY_ID: usize>;
623
624pub trait StrategyFromId<const ID: usize> {
626 type Strategy;
627
628 fn create_strategy() -> Self::Strategy;
629}
630
631impl StrategyFromId<0> for ConstStrategySelector<0> {
633 type Strategy = crate::ClassifierStrategy;
634
635 fn create_strategy() -> Self::Strategy {
636 crate::ClassifierStrategy::MostFrequent
637 }
638}
639
640impl StrategyFromId<1> for ConstStrategySelector<1> {
641 type Strategy = crate::ClassifierStrategy;
642
643 fn create_strategy() -> Self::Strategy {
644 crate::ClassifierStrategy::Stratified
645 }
646}
647
648impl StrategyFromId<2> for ConstStrategySelector<2> {
649 type Strategy = crate::ClassifierStrategy;
650
651 fn create_strategy() -> Self::Strategy {
652 crate::ClassifierStrategy::Uniform
653 }
654}
655
656impl StrategyFromId<10> for ConstStrategySelector<10> {
658 type Strategy = crate::RegressorStrategy;
659
660 fn create_strategy() -> Self::Strategy {
661 crate::RegressorStrategy::Mean
662 }
663}
664
665impl StrategyFromId<11> for ConstStrategySelector<11> {
666 type Strategy = crate::RegressorStrategy;
667
668 fn create_strategy() -> Self::Strategy {
669 crate::RegressorStrategy::Median
670 }
671}
672
673impl StrategyFromId<12> for ConstStrategySelector<12> {
674 type Strategy = crate::RegressorStrategy;
675
676 fn create_strategy() -> Self::Strategy {
677 crate::RegressorStrategy::Normal {
678 mean: None,
679 std: None,
680 }
681 }
682}
683
684#[derive(Debug, Clone)]
686pub struct ConstGenericEstimator<State, Task, const STRATEGY_ID: usize>
687where
688 State: EstimatorState,
689 Task: TaskType,
690 ConstStrategySelector<STRATEGY_ID>: StrategyFromId<STRATEGY_ID>,
691{
692 random_state: Option<u64>,
693 _state: PhantomData<State>,
694 _task: PhantomData<Task>,
695}
696
697impl<Task, const STRATEGY_ID: usize> Default for ConstGenericEstimator<Untrained, Task, STRATEGY_ID>
698where
699 Task: TaskType,
700 ConstStrategySelector<STRATEGY_ID>: StrategyFromId<STRATEGY_ID>,
701{
702 fn default() -> Self {
703 Self::new()
704 }
705}
706
707impl<Task, const STRATEGY_ID: usize> ConstGenericEstimator<Untrained, Task, STRATEGY_ID>
708where
709 Task: TaskType,
710 ConstStrategySelector<STRATEGY_ID>: StrategyFromId<STRATEGY_ID>,
711{
712 pub fn new() -> Self {
714 Self {
715 random_state: None,
716 _state: PhantomData,
717 _task: PhantomData,
718 }
719 }
720
721 pub fn with_random_state(mut self, seed: u64) -> Self {
723 self.random_state = Some(seed);
724 self
725 }
726
727 pub fn strategy(
729 ) -> <ConstStrategySelector<STRATEGY_ID> as StrategyFromId<STRATEGY_ID>>::Strategy {
730 ConstStrategySelector::<STRATEGY_ID>::create_strategy()
731 }
732}
733
/// Newtype wrapper with the same in-memory representation as `T`.
///
/// `#[repr(transparent)]` guarantees that the wrapper has exactly `T`'s size, alignment
/// and ABI, which is what makes the wrapper "zero cost".
#[repr(transparent)]
#[derive(Debug, Clone, Copy)]
pub struct ZeroCostWrapper<T> {
    inner: T,
}

impl<T> ZeroCostWrapper<T> {
    /// Wraps `inner`; usable in `const` contexts.
    #[inline]
    #[must_use]
    pub const fn new(inner: T) -> Self {
        Self { inner }
    }

    /// Consumes the wrapper and returns the wrapped value.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Borrows the wrapped value.
    #[inline]
    #[must_use]
    pub const fn get(&self) -> &T {
        &self.inner
    }

    /// Applies `f` to the wrapped value and wraps the result.
    #[inline]
    pub fn map<U, F>(self, f: F) -> ZeroCostWrapper<U>
    where
        F: FnOnce(T) -> U,
    {
        ZeroCostWrapper::new(f(self.inner))
    }
}
768
/// Zero-sized marker that encodes three strategy properties as const generic flags.
#[derive(Debug, Clone, Copy, Default)]
pub struct StatisticalPhantom<
    const IS_DETERMINISTIC: bool,
    const IS_STATELESS: bool,
    const REQUIRES_FITTING: bool,
> {
    _marker: PhantomData<()>,
}

impl<const IS_DETERMINISTIC: bool, const IS_STATELESS: bool, const REQUIRES_FITTING: bool>
    StatisticalPhantom<IS_DETERMINISTIC, IS_STATELESS, REQUIRES_FITTING>
{
    /// Creates the marker; usable in `const` contexts.
    pub const fn new() -> Self {
        StatisticalPhantom {
            _marker: PhantomData,
        }
    }

    /// Returns the `IS_DETERMINISTIC` flag.
    pub const fn is_deterministic() -> bool {
        IS_DETERMINISTIC
    }

    /// Returns the `IS_STATELESS` flag.
    pub const fn is_stateless() -> bool {
        IS_STATELESS
    }

    /// Returns the `REQUIRES_FITTING` flag.
    pub const fn requires_fitting() -> bool {
        REQUIRES_FITTING
    }
}

/// Flags: deterministic, stateless, no fitting required.
pub type SimpleDeterministicStrategy = StatisticalPhantom<true, true, false>;

/// Flags: non-deterministic, stateful, fitting required.
pub type StochasticFittedStrategy = StatisticalPhantom<false, false, true>;
818
/// Compile-time properties declared by a strategy type.
pub trait StrategyProperties {
    /// Implementor-declared flag: whether the strategy is deterministic.
    const IS_DETERMINISTIC: bool;
    /// Implementor-declared flag: whether the strategy is stateless.
    const IS_STATELESS: bool;
    /// Implementor-declared flag: whether the strategy requires fitting.
    const REQUIRES_FITTING: bool;

    /// Marker type describing these properties.
    type Phantom: 'static;

    /// Returns a value of the marker type.
    fn phantom() -> Self::Phantom;
}
830
831#[derive(Debug, Clone)]
833pub struct StatisticallyTypedEstimator<State, Task, Strategy, Properties>
834where
835 State: EstimatorState,
836 Task: TaskType,
837 Strategy: StrategyValid<Task> + StrategyProperties<Phantom = Properties>,
838 Properties: 'static,
839{
840 strategy: Strategy,
841 random_state: Option<u64>,
842 _state: PhantomData<State>,
843 _task: PhantomData<Task>,
844 _properties: PhantomData<Properties>,
845}
846
847impl<Task, Strategy, Properties> StatisticallyTypedEstimator<Untrained, Task, Strategy, Properties>
848where
849 Task: TaskType,
850 Strategy: StrategyValid<Task> + StrategyProperties<Phantom = Properties> + Clone,
851 Properties: 'static,
852{
853 pub fn new(strategy: Strategy) -> Self {
855 Self {
856 strategy,
857 random_state: None,
858 _state: PhantomData,
859 _task: PhantomData,
860 _properties: PhantomData,
861 }
862 }
863
864 pub const fn is_deterministic() -> bool {
866 Strategy::IS_DETERMINISTIC
867 }
868
869 pub const fn is_stateless() -> bool {
871 Strategy::IS_STATELESS
872 }
873
874 pub const fn requires_fitting() -> bool {
876 Strategy::REQUIRES_FITTING
877 }
878}
879
/// Fails compilation unless:
/// - `$state` implements `EstimatorState`,
/// - `$task` implements `TaskType`, and
/// - `$estimator_type` implements `TypeSafeEstimator<$state, $task>`.
///
/// The checks sit in a closure that is never called, so they are purely type-level.
/// Calling a non-`const` fn directly in a `const` item is rejected by the compiler
/// (E0015). `EstimatorState`, `TaskType` and `TypeSafeEstimator` must be in scope at the
/// call site.
#[macro_export]
macro_rules! validate_estimator_config {
    ($estimator_type:ty, $state:ty, $task:ty) => {
        const _: fn() = || {
            fn _assert_estimator_state<S: EstimatorState>() {}
            fn _assert_task_type<T: TaskType>() {}
            fn _assert_type_safe_estimator<
                E: TypeSafeEstimator<S, T>,
                S: EstimatorState,
                T: TaskType,
            >() {
            }

            _assert_estimator_state::<$state>();
            _assert_task_type::<$task>();
            _assert_type_safe_estimator::<$estimator_type, $state, $task>();
        };
    };
}

/// Fails compilation unless strategy `$strategy` is valid for task `$task`.
///
/// Like `validate_estimator_config!`, the check sits in a never-called closure.
/// `StrategyValid` and `TaskType` must be in scope at the call site.
#[macro_export]
macro_rules! assert_strategy_compatible {
    ($strategy:ty, $task:ty) => {
        const _: fn() = || {
            fn _assert_compatible<S: StrategyValid<T>, T: TaskType>() {}
            _assert_compatible::<$strategy, $task>();
        };
    };
}
911
/// Zero-sized feature flag whose state is the const generic `ENABLED`.
#[derive(Debug, Clone, Copy)]
pub struct CompileTimeFeature<const ENABLED: bool>;

impl<const ENABLED: bool> CompileTimeFeature<ENABLED> {
    /// Returns the `ENABLED` parameter.
    pub const fn is_enabled() -> bool {
        ENABLED
    }

    /// Calls `f` and returns `Some(result)` when the feature is enabled. Otherwise
    /// returns `None` without calling `f`.
    #[inline(always)]
    pub fn when_enabled<F, R>(f: F) -> Option<R>
    where
        F: FnOnce() -> R,
    {
        ENABLED.then(f)
    }
}
935
936pub trait TypeSafeStatisticalOps<T> {
938 fn safe_mean(&self) -> Option<T>;
940
941 fn safe_variance(&self) -> Option<T>;
943
944 fn safe_std(&self) -> Option<T>;
946}
947
948impl TypeSafeStatisticalOps<f64> for Array1<f64> {
949 fn safe_mean(&self) -> Option<f64> {
950 if self.is_empty() {
951 None
952 } else {
953 Some(self.mean().unwrap_or(0.0))
954 }
955 }
956
957 fn safe_variance(&self) -> Option<f64> {
958 if self.len() < 2 {
959 None
960 } else {
961 Some(self.var(1.0)) }
963 }
964
965 fn safe_std(&self) -> Option<f64> {
966 self.safe_variance().map(|v| v.sqrt())
967 }
968}
969
/// Single-value wrapper with a C-compatible (`#[repr(C)]`) layout.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct OptimizedLayout<T> {
    data: T,
    // Zero-length array: occupies no bytes and has alignment 1.
    _pad: [u8; 0],
}

impl<T> OptimizedLayout<T> {
    /// Wraps `data`; usable in `const` contexts.
    #[inline(always)]
    pub const fn new(data: T) -> Self {
        OptimizedLayout { _pad: [], data }
    }

    /// Consumes the wrapper and returns the wrapped value.
    #[inline(always)]
    pub fn into_data(self) -> T {
        let Self { data, .. } = self;
        data
    }

    /// Borrows the wrapped value.
    #[inline(always)]
    pub const fn data(&self) -> &T {
        &self.data
    }
}
997
998pub trait TypeSafeEstimator<State: EstimatorState, Task: TaskType> {
1000 type Strategy: StrategyValid<Task>;
1001
1002 fn state(&self) -> State;
1004
1005 fn task_type(&self) -> Task;
1007
1008 fn validate(&self) -> Result<(), &'static str>;
1010}
1011
1012impl<State, Task, Strategy> TypeSafeEstimator<State, Task>
1013 for TypeSafeDummyEstimator<State, Task, Strategy>
1014where
1015 State: EstimatorState + Default,
1016 Task: TaskType + Default,
1017 Strategy: StrategyValid<Task>,
1018{
1019 type Strategy = Strategy;
1020
1021 fn state(&self) -> State {
1022 State::default()
1023 }
1024
1025 fn task_type(&self) -> Task {
1026 Task::default()
1027 }
1028
1029 fn validate(&self) -> Result<(), &'static str> {
1030 Ok(())
1032 }
1033}
1034
1035impl Default for Untrained {
1037 fn default() -> Self {
1038 Untrained
1039 }
1040}
1041
1042impl Default for Trained {
1043 fn default() -> Self {
1044 Trained
1045 }
1046}
1047
1048impl Default for Classification {
1049 fn default() -> Self {
1050 Classification
1051 }
1052}
1053
1054impl Default for Regression {
1055 fn default() -> Self {
1056 Regression
1057 }
1058}
1059
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ClassifierStrategy, RegressorStrategy};
    use scirs2_core::ndarray::{array, Array2};

    #[test]
    fn test_type_safe_classifier_creation() {
        let strategy = ClassifierStrategy::MostFrequent;
        let classifier = TypeSafeDummyEstimator::<Untrained, Classification, _>::new(strategy);

        assert!(classifier.validate().is_ok());
    }

    #[test]
    fn test_type_safe_regressor_creation() {
        let strategy = RegressorStrategy::Mean;
        let regressor = TypeSafeDummyEstimator::<Untrained, Regression, _>::new(strategy);

        assert!(regressor.validate().is_ok());
    }

    #[test]
    fn test_bounded_parameters() {
        let prob = Probability::new_f64(0.5);
        assert!(prob.is_ok());
        assert_eq!(prob.unwrap().get_f64(), 0.5);

        let invalid_prob = Probability::new_f64(1.5);
        assert!(invalid_prob.is_err());

        let pos_int = PositiveInt::new_i32(42);
        assert!(pos_int.is_ok());
        assert_eq!(pos_int.unwrap().get(), 42);

        let invalid_int = PositiveInt::new_i32(-1);
        assert!(invalid_int.is_err());
    }

    #[test]
    fn test_validated_strategy() {
        let strategy = ClassifierStrategy::MostFrequent;
        let validated = ValidatedStrategy::<_, Classification>::new(strategy.clone());

        assert_eq!(
            format!("{:?}", validated.strategy()),
            format!("{:?}", strategy)
        );

        let extracted = validated.into_strategy();
        assert_eq!(format!("{:?}", extracted), format!("{:?}", strategy));
    }

    #[test]
    fn test_type_safe_parameters() {
        let params = TypeSafeParameters::new(42.0);
        assert_eq!(*params.get(), 42.0);
        assert_eq!(params.into_inner(), 42.0);
    }

    #[test]
    fn test_estimator_state_transitions() {
        let strategy = ClassifierStrategy::MostFrequent;
        let untrained = TypeSafeDummyEstimator::<Untrained, Classification, _>::new(strategy);

        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let y = array![0, 0, 1, 1];

        let fitted = untrained.fit(&x, &y).unwrap();
        let predictions = fitted.predict(&x);

        assert!(predictions.is_ok());
        assert_eq!(predictions.unwrap().len(), 4);
    }

    // The label distribution is unambiguous (no ties), so the expected statistics do not
    // depend on tie-breaking.
    #[test]
    fn test_classifier_fit_records_label_statistics() {
        let x = Array2::<f64>::zeros((3, 2));
        let y = array![0, 1, 1];

        let fitted = TypeSafeDummyEstimator::<Untrained, Classification, _>::new(
            ClassifierStrategy::MostFrequent,
        )
        .fit(&x, &y)
        .unwrap();
        let data = &fitted.fitted_data;

        assert_eq!(data.most_frequent_class, 1);
        assert_eq!(data.class_counts[&0], 1);
        assert_eq!(data.class_counts[&1], 2);
        assert!((data.class_priors[&1] - 2.0 / 3.0).abs() < 1e-12);
        assert_eq!((data.n_samples, data.n_features), (3, 2));
    }

    #[test]
    fn test_regressor_fit_records_target_statistics() {
        let x = Array2::<f64>::zeros((4, 1));
        let y = array![3.0, 1.0, 10.0, 2.0];

        let fitted =
            TypeSafeDummyEstimator::<Untrained, Regression, _>::new(RegressorStrategy::Mean)
                .fit(&x, &y)
                .unwrap();
        let data = &fitted.fitted_data;

        assert_eq!(data.target_mean, 4.0);
        assert_eq!(data.target_median, 2.5);
        assert_eq!(data.target_min, 1.0);
        assert_eq!(data.target_max, 10.0);
        assert_eq!((data.n_samples, data.n_features), (4, 1));
    }

    #[test]
    fn test_fit_rejects_mismatched_lengths() {
        let x = Array2::<f64>::zeros((3, 1));
        let y = array![0, 1];

        let result = TypeSafeDummyEstimator::<Untrained, Classification, _>::new(
            ClassifierStrategy::MostFrequent,
        )
        .fit(&x, &y);

        assert!(matches!(result, Err(SklearsError::ShapeMismatch { .. })));
    }

    #[test]
    fn test_const_generic_estimator() {
        let _classifier = ConstGenericEstimator::<Untrained, Classification, 0>::new();
        let strategy = ConstGenericEstimator::<Untrained, Classification, 0>::strategy();

        assert_eq!(format!("{:?}", strategy), "MostFrequent");

        let _regressor = ConstGenericEstimator::<Untrained, Regression, 10>::new();
        let reg_strategy = ConstGenericEstimator::<Untrained, Regression, 10>::strategy();

        assert_eq!(format!("{:?}", reg_strategy), "Mean");
    }

    #[test]
    fn test_zero_cost_wrapper() {
        let wrapper = ZeroCostWrapper::new(42);
        assert_eq!(*wrapper.get(), 42);
        assert_eq!(wrapper.into_inner(), 42);

        let mapped = ZeroCostWrapper::new(10).map(|x| x * 2);
        assert_eq!(mapped.into_inner(), 20);
    }

    #[test]
    fn test_statistical_phantom() {
        let _phantom = SimpleDeterministicStrategy::new();
        assert!(SimpleDeterministicStrategy::is_deterministic());
        assert!(SimpleDeterministicStrategy::is_stateless());
        assert!(!SimpleDeterministicStrategy::requires_fitting());

        let _stochastic = StochasticFittedStrategy::new();
        assert!(!StochasticFittedStrategy::is_deterministic());
        assert!(!StochasticFittedStrategy::is_stateless());
        assert!(StochasticFittedStrategy::requires_fitting());
    }

    #[test]
    fn test_compile_time_feature() {
        type EnabledFeature = CompileTimeFeature<true>;
        assert!(EnabledFeature::is_enabled());

        let result = EnabledFeature::when_enabled(|| 42);
        assert_eq!(result, Some(42));

        type DisabledFeature = CompileTimeFeature<false>;
        assert!(!DisabledFeature::is_enabled());

        let no_result = DisabledFeature::when_enabled(|| 42);
        assert_eq!(no_result, None);
    }

    #[test]
    fn test_type_safe_statistical_ops() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_eq!(data.safe_mean(), Some(3.0));
        assert!(data.safe_variance().is_some());
        assert!(data.safe_std().is_some());

        let empty: Array1<f64> = array![];
        assert_eq!(empty.safe_mean(), None);
        assert_eq!(empty.safe_variance(), None);
        assert_eq!(empty.safe_std(), None);

        // A sample variance needs at least two observations.
        let single = array![42.0];
        assert_eq!(single.safe_mean(), Some(42.0));
        assert_eq!(single.safe_variance(), None);
    }

    #[test]
    fn test_optimized_layout() {
        let layout = OptimizedLayout::new(42);
        assert_eq!(*layout.data(), 42);
        assert_eq!(layout.into_data(), 42);
    }
}