1use crate::error::Result;
2use crate::types::FloatBounds;
3use std::fmt::Debug;
4
/// Type-state marker for an estimator that has not been fitted yet.
///
/// It is the default `State` parameter of [`Estimator`] and [`Fit`].
#[derive(Clone, Copy, Debug)]
pub struct Untrained;

/// Type-state marker for an estimator that has already been fitted.
#[derive(Clone, Copy, Debug)]
pub struct Trained;
12
13pub trait Estimator<State = Untrained> {
15 type Config: Clone + Debug + Send + Sync;
17
18 type Error: std::error::Error + Send + Sync + 'static;
20
21 type Float: FloatBounds + Send + Sync;
23
24 fn config(&self) -> &Self::Config;
26
27 fn validate_config(&self) -> Result<()> {
29 Ok(())
30 }
31
32 fn check_compatibility(&self, n_samples: usize, n_features: usize) -> Result<()> {
34 if n_samples == 0 {
35 return Err(crate::error::SklearsError::InvalidInput(
36 "Number of samples cannot be zero".to_string(),
37 ));
38 }
39 if n_features == 0 {
40 return Err(crate::error::SklearsError::InvalidInput(
41 "Number of features cannot be zero".to_string(),
42 ));
43 }
44 Ok(())
45 }
46
47 fn metadata(&self) -> EstimatorMetadata {
49 EstimatorMetadata::default()
50 }
51}
52
/// Descriptive, capability-oriented metadata about an estimator.
///
/// `Default` yields empty strings, every flag `false`, and `Linear` for both
/// complexity fields.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EstimatorMetadata {
    /// Human-readable estimator name.
    pub name: String,
    /// Version string of the estimator implementation.
    pub version: String,
    /// Free-form description.
    pub description: String,
    /// Whether sparse input is supported.
    pub supports_sparse: bool,
    /// Whether multiclass targets are supported.
    pub supports_multiclass: bool,
    /// Whether multilabel targets are supported.
    pub supports_multilabel: bool,
    /// Whether input values must be positive.
    pub requires_positive_input: bool,
    /// Whether incremental / online learning is supported.
    pub supports_online_learning: bool,
    /// Whether feature importances can be reported.
    pub supports_feature_importance: bool,
    /// Coarse memory-complexity class.
    pub memory_complexity: MemoryComplexity,
    /// Coarse time-complexity class.
    pub time_complexity: TimeComplexity,
}

/// Coarse asymptotic memory-complexity class of an estimator.
///
/// Unit-only enum, so it derives `Copy`/`PartialEq`/`Eq`/`Hash`. This lets callers
/// compare and key on it directly instead of only pattern-matching.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MemoryComplexity {
    /// O(n). This is the default.
    #[default]
    Linear,
    /// O(n^2).
    Quadratic,
    /// O(1).
    Constant,
    /// O(log n).
    Logarithmic,
}

/// Coarse asymptotic time-complexity class of an estimator.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TimeComplexity {
    /// O(n). This is the default.
    #[default]
    Linear,
    /// O(n^2).
    Quadratic,
    /// O(n log n).
    LogLinear,
    /// Polynomial of unspecified degree.
    Polynomial,
    /// Exponential.
    Exponential,
}
89
90pub trait Fit<X, Y, State = Untrained> {
92 type Fitted: Send + Sync;
94
95 fn fit(self, x: &X, y: &Y) -> Result<Self::Fitted>;
97
98 fn fit_with_validation(
100 self,
101 x: &X,
102 y: &Y,
103 _x_val: Option<&X>,
104 _y_val: Option<&Y>,
105 ) -> Result<(Self::Fitted, FitMetrics)>
106 where
107 Self: Sized,
108 {
109 let fitted = self.fit(x, y)?;
110 Ok((fitted, FitMetrics::default()))
111 }
112}
113
/// Statistics reported by [`Fit::fit_with_validation`].
///
/// `Default` means "nothing recorded": no scores, zero iterations, and both flags
/// `false`.
#[derive(Clone, Debug, Default)]
pub struct FitMetrics {
    /// Score on the training data, if one was computed.
    pub training_score: Option<f64>,
    /// Score on validation data, if one was computed.
    pub validation_score: Option<f64>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether convergence was reached.
    pub convergence_achieved: bool,
    /// Whether early stopping ended training.
    pub early_stopping_triggered: bool,
}
123
124pub trait Predict<X, Output> {
126 fn predict(&self, x: &X) -> Result<Output>;
128
129 fn predict_with_uncertainty(&self, x: &X) -> Result<(Output, UncertaintyMeasure)> {
131 let predictions = self.predict(x)?;
132 Ok((predictions, UncertaintyMeasure::default()))
133 }
134}
135
/// Optional uncertainty information that can accompany predictions.
///
/// `Default` leaves every field `None`.
#[derive(Clone, Debug, Default)]
pub struct UncertaintyMeasure {
    /// `(lower, upper)` interval bounds, if provided.
    pub confidence_intervals: Option<Vec<(f64, f64)>>,
    /// Variance values, if provided.
    pub prediction_variance: Option<Vec<f64>>,
    /// Epistemic (model) uncertainty values, if provided.
    pub epistemic_uncertainty: Option<Vec<f64>>,
    /// Aleatoric (data) uncertainty values, if provided.
    pub aleatoric_uncertainty: Option<Vec<f64>>,
}
144
145pub trait Transform<X, Output = X> {
147 fn transform(&self, x: &X) -> Result<Output>;
149}
150
151pub trait TransformInplace<X> {
153 fn transform_inplace(&mut self, x: &mut X) -> Result<()>;
155}
156
157pub trait FitPredict<X, Y, Output> {
159 fn fit_predict(self, x_train: &X, y_train: &Y, x_test: &X) -> Result<Output>;
161}
162
163pub trait FitTransform<X, Y = (), Output = X> {
165 fn fit_transform(self, x: &X, y: Option<&Y>) -> Result<Output>;
167}
168
169pub trait PartialFit<X, Y> {
171 fn partial_fit(&mut self, x: &X, y: &Y) -> Result<()>;
173}
174
175pub trait Score<X, Y> {
177 type Float: FloatBounds;
179
180 fn score(&self, x: &X, y: &Y) -> Result<Self::Float>;
182}
183
184pub trait PredictProba<X, Output> {
186 fn predict_proba(&self, x: &X) -> Result<Output>;
188}
189
190pub trait DecisionFunction<X, Output> {
192 fn decision_function(&self, x: &X) -> Result<Output>;
194}
195
196pub trait GetParams {
198 fn get_params(&self) -> std::collections::HashMap<String, String>;
200}
201
202pub trait SetParams {
204 fn set_params(&mut self, params: std::collections::HashMap<String, String>) -> Result<()>;
206}
207
208pub trait Cluster<X> {
210 type Labels;
212
213 fn fit_predict(self, x: &X) -> Result<Self::Labels>;
215}
216
217pub trait FeatureImportance {
221 fn feature_importances(&self) -> Result<Vec<f64>>;
223
224 fn feature_names(&self) -> Option<Vec<String>> {
226 None
227 }
228}
229
230pub trait ModelIntrospection {
232 fn get_model_structure(&self) -> Result<ModelStructure>;
234
235 fn decision_path(&self, x: &[f64]) -> Result<Vec<DecisionNode>>;
237}
238
/// Structural description of a model, as returned by
/// [`ModelIntrospection::get_model_structure`].
#[derive(Clone, Debug)]
pub enum ModelStructure {
    /// Linear model described by its `weights` and `bias`.
    Linear { weights: Vec<f64>, bias: f64 },
    /// Tree model described by its root node.
    Tree { root: DecisionNode },
    /// Neural network described layer by layer.
    Neural { layers: Vec<LayerInfo> },
    /// Ensemble described by the structures of its members.
    Ensemble { base_models: Vec<Box<ModelStructure>> },
}

/// A single node of a decision path or tree.
#[derive(Clone, Debug)]
pub struct DecisionNode {
    /// Index of the feature tested at this node, if any.
    pub feature_index: Option<usize>,
    /// Split threshold, if any.
    pub threshold: Option<f64>,
    /// Node impurity, if reported.
    pub impurity: Option<f64>,
    /// Number of samples reaching this node.
    pub samples: usize,
    /// Node value(s), e.g. a class distribution or a regression output
    /// (TODO confirm per implementor).
    pub value: Vec<f64>,
    /// Whether this node is a leaf.
    pub is_leaf: bool,
}

/// Summary of one neural-network layer.
#[derive(Clone, Debug)]
pub struct LayerInfo {
    /// Kind of layer, as free-form text.
    pub layer_type: String,
    /// Input width.
    pub input_size: usize,
    /// Output width.
    pub output_size: usize,
    /// Activation function name, as free-form text.
    pub activation: String,
}
276
277pub trait OnlineLearning<X, Y> {
279 fn partial_fit(&mut self, x: &X, y: &Y) -> Result<()>;
281
282 fn needs_more_data(&self) -> bool {
284 false
285 }
286
287 fn reset(&mut self) -> Result<()>;
289}
290
291pub trait HyperparameterOptimization {
293 type HyperparameterSpace;
294
295 fn hyperparameter_space(&self) -> Self::HyperparameterSpace;
297
298 fn validate_hyperparameters(
300 &self,
301 params: &std::collections::HashMap<String, f64>,
302 ) -> Result<()>;
303}
304
305pub trait RobustEstimation {
307 fn set_robustness_params(&mut self, outlier_fraction: f64) -> Result<()>;
309
310 fn identify_outliers(&self, x: &[&[f64]]) -> Result<Vec<bool>>;
312}
313
314pub trait SupervisedLearner<X, Y, Output>: Fit<X, Y> + Predict<X, Output>
318where
319 Self::Fitted: Predict<X, Output>,
320 Self: Sized,
321{
322 fn fit_predict(self, x_train: &X, y_train: &Y, x_test: &X) -> Result<Output> {
324 let fitted = self.fit(x_train, y_train)?;
325 fitted.predict(x_test)
326 }
327}
328
329pub trait InterpretableModel<X, Y, Output>:
331 SupervisedLearner<X, Y, Output> + FeatureImportance + ModelIntrospection
332where
333 Self::Fitted: Predict<X, Output> + FeatureImportance + ModelIntrospection,
334{
335 fn explain_prediction(&self, x: &[f64]) -> Result<PredictionExplanation> {
337 let importance = self.feature_importances()?;
338 let path = self.decision_path(x)?;
339 Ok(PredictionExplanation {
340 feature_contributions: importance,
341 decision_path: path,
342 confidence: None,
343 })
344 }
345}
346
347#[derive(Debug, Clone)]
349pub struct PredictionExplanation {
350 pub feature_contributions: Vec<f64>,
351 pub decision_path: Vec<DecisionNode>,
352 pub confidence: Option<f64>,
353}
354
355impl<T, X, Y, Output> SupervisedLearner<X, Y, Output> for T
357where
358 T: Fit<X, Y> + Predict<X, Output> + Sized,
359 T::Fitted: Predict<X, Output>,
360{
361}
362
363pub trait Classifier<X, Labels, Probabilities>:
365 SupervisedLearner<X, Labels, Labels> + PredictProba<X, Probabilities>
366where
367 Self::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities>,
368 Self: Sized,
369{
370 fn classify_with_proba(
372 self,
373 x_train: &X,
374 y_train: &Labels,
375 x_test: &X,
376 ) -> Result<(Labels, Probabilities)> {
377 let fitted = self.fit(x_train, y_train)?;
378 let predictions = fitted.predict(x_test)?;
379 let probabilities = fitted.predict_proba(x_test)?;
380 Ok((predictions, probabilities))
381 }
382}
383
384impl<T, X, Labels, Probabilities> Classifier<X, Labels, Probabilities> for T
386where
387 T: SupervisedLearner<X, Labels, Labels> + PredictProba<X, Probabilities> + Sized,
388 T::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities>,
389{
390}
391
392pub trait Regressor<X, Y>: Fit<X, Y> + Predict<X, Y> + Score<X, Y>
394where
395 Self::Fitted: Predict<X, Y> + Score<X, Y>,
396 Self: Sized,
397{
398 #[allow(clippy::type_complexity)]
400 fn regress_and_score(
401 self,
402 x_train: &X,
403 y_train: &Y,
404 x_test: &X,
405 y_test: &Y,
406 ) -> Result<(Y, <Self::Fitted as Score<X, Y>>::Float)> {
407 let fitted = self.fit(x_train, y_train)?;
408 let predictions = fitted.predict(x_test)?;
409 let score = fitted.score(x_test, y_test)?;
410 Ok((predictions, score))
411 }
412}
413
414impl<T, X, Y> Regressor<X, Y> for T
416where
417 T: Fit<X, Y> + Predict<X, Y> + Score<X, Y> + Sized,
418 T::Fitted: Predict<X, Y> + Score<X, Y>,
419{
420}
421
422pub trait Transformer<X, Y = (), Output = X>: FitTransform<X, Y, Output>
424where
425 Self: Sized,
426{
427 fn fit_then_transform(self, x: &X, y: Option<&Y>) -> Result<Output> {
429 self.fit_transform(x, y)
430 }
431}
432
433impl<T, X, Y, Output> Transformer<X, Y, Output> for T where T: FitTransform<X, Y, Output> + Sized {}
435
436pub trait MLPipeline<X, Y, Output>:
438 Fit<X, Y> + Predict<X, Output> + Transform<X, X> + Score<X, Y>
439where
440 Self::Fitted: Predict<X, Output> + Transform<X, X> + Score<X, Y, Float = Self::Float>,
441 Self: Sized,
442{
443 fn execute_pipeline(
445 self,
446 x_train: &X,
447 y_train: &Y,
448 x_test: &X,
449 y_test: &Y,
450 ) -> Result<PipelineResult<Output, X, Self::Float>> {
451 let fitted = self.fit(x_train, y_train)?;
452 let transformed_test = fitted.transform(x_test)?;
453 let predictions = fitted.predict(&transformed_test)?;
454 let score = fitted.score(x_test, y_test)?;
455
456 Ok(PipelineResult {
457 predictions,
458 score,
459 transformed_features: transformed_test,
460 })
461 }
462}
463
/// Output of `MLPipeline::execute_pipeline`.
#[derive(Clone, Debug)]
pub struct PipelineResult<Predictions, Features, Score> {
    /// Predictions for the test data.
    pub predictions: Predictions,
    /// Score of the fitted model on the test data.
    pub score: Score,
    /// Test features after the fitted model's transform.
    pub transformed_features: Features,
}
471
472impl<T, X, Y, Output> MLPipeline<X, Y, Output> for T
474where
475 T: Fit<X, Y> + Predict<X, Output> + Transform<X, X> + Score<X, Y> + Sized,
476 T::Fitted: Predict<X, Output> + Transform<X, X> + Score<X, Y, Float = T::Float>,
477{
478}
479
480pub trait OnlineLearner<X, Y, Output>: PartialFit<X, Y> + Predict<X, Output> + Score<X, Y> {
482 fn train_incrementally(
484 &mut self,
485 batches: &[(X, Y)],
486 x_test: &X,
487 y_test: &Y,
488 ) -> Result<Vec<Self::Float>> {
489 let mut scores = Vec::with_capacity(batches.len());
490
491 for (x_batch, y_batch) in batches {
492 self.partial_fit(x_batch, y_batch)?;
493 let score = self.score(x_test, y_test)?;
494 scores.push(score);
495 }
496
497 Ok(scores)
498 }
499}
500
501impl<T, X, Y, Output> OnlineLearner<X, Y, Output> for T where
503 T: PartialFit<X, Y> + Predict<X, Output> + Score<X, Y>
504{
505}
506
507pub trait ModelEvaluator<X, Y, Output> {
509 type Score: FloatBounds;
510
511 fn cross_validate(
513 &self,
514 model: impl Fit<X, Y> + Clone,
515 x: &X,
516 y: &Y,
517 cv_folds: usize,
518 ) -> Result<Vec<Self::Score>>;
519
520 fn model_selection(&self, models: Vec<impl Fit<X, Y> + Clone>, x: &X, y: &Y) -> Result<usize>; }
523
524pub mod async_traits {
526 use super::*;
527 use std::future::Future;
528 use std::pin::Pin;
529
530 pub trait AsyncFit<X, Y, State = Untrained> {
532 type Fitted;
533 type Error: std::error::Error + Send + Sync;
534
535 fn fit_async<'a>(
537 self,
538 x: &'a X,
539 y: &'a Y,
540 ) -> Pin<Box<dyn Future<Output = Result<Self::Fitted>> + Send + 'a>>
541 where
542 Self: Sized + 'a;
543 }
544
545 pub trait AsyncPredict<X, Output> {
547 type Error: std::error::Error + Send + Sync;
548
549 fn predict_async<'a>(
551 &'a self,
552 x: &'a X,
553 ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
554 }
555
556 pub trait AsyncTransform<X, Output = X> {
558 type Error: std::error::Error + Send + Sync;
559
560 fn transform_async<'a>(
562 &'a self,
563 x: &'a X,
564 ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
565 }
566}
567
568pub mod streaming {
570 use super::*;
571 use futures_core::Stream;
572 use std::pin::Pin;
573
574 pub trait StreamingFit<S, Y> {
576 type Fitted;
577 type Error: std::error::Error + Send + Sync;
578
579 fn fit_stream(
581 self,
582 stream: S,
583 targets: Y,
584 ) -> Pin<Box<dyn futures_core::Future<Output = Result<Self::Fitted>> + Send>>
585 where
586 S: Stream + Send,
587 Y: Send;
588 }
589
590 pub trait StreamingPredict<S, Output> {
592 type Error: std::error::Error + Send + Sync;
593
594 fn predict_stream<'a>(
596 &'a self,
597 stream: S,
598 ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>
599 where
600 S: Stream + Send + 'a;
601 }
602
603 pub trait StreamingTransform<S, Output> {
605 type Error: std::error::Error + Send + Sync;
606
607 fn transform_stream<'a>(
609 &'a self,
610 stream: S,
611 ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>
612 where
613 S: Stream + Send + 'a;
614 }
615
616 pub trait StreamingPartialFit<S, Y> {
618 type Error: std::error::Error + Send + Sync;
619
620 fn partial_fit_stream<'a, Item>(
626 &'a mut self,
627 stream: S,
628 ) -> Pin<Box<dyn futures_core::Future<Output = Result<()>> + Send + 'a>>
629 where
630 S: Stream<Item = (Item, Y)> + Send + 'a,
631 Item: Send + 'a,
632 Y: Send + 'a;
633 }
634}
635
/// Variants of the core traits built on generic associated types (GATs).
///
/// Input/output types can borrow for a lifetime chosen at each call.
pub mod gat_traits {
    use super::*;

    /// GAT-based counterpart of [`Estimator`].
    ///
    /// It only declares associated types and has no methods.
    pub trait EstimatorGAT<State = Untrained> {
        /// Configuration type.
        type Config;

        /// Error type.
        type Error: std::error::Error;

        /// Floating-point scalar type.
        type Float: FloatBounds;

        /// Input type, which may borrow for `'a` (requires `Self: 'a`).
        type Input<'a>
        where
            Self: 'a;

        /// Output type, which may borrow for `'a` (requires `Self: 'a`).
        type Output<'a>
        where
            Self: 'a;

        /// Parameter type.
        type Parameters;
    }

    /// GAT-based counterpart of [`Fit`], whose input and target may be borrowed views.
    ///
    /// NOTE(review): `Error` is declared, but `fit_gat` returns the crate-wide `Result`.
    pub trait FitGAT<State = Untrained> {
        type Input<'a>
        where
            Self: 'a;
        type Target<'a>
        where
            Self: 'a;
        type Fitted;
        type Error: std::error::Error;

        /// Consumes the estimator and fits it on `input` / `target`, both borrowed for `'a`.
        fn fit_gat<'a>(
            self,
            input: Self::Input<'a>,
            target: Self::Target<'a>,
        ) -> Result<Self::Fitted>
        where
            Self: 'a;
    }

    /// GAT-based counterpart of [`Transform`].
    ///
    /// The output may borrow for the same `'a` as the input.
    pub trait TransformGAT {
        type Input<'a>
        where
            Self: 'a;

        type Output<'a>
        where
            Self: 'a;

        type Error: std::error::Error;

        /// Transforms `input` into an output tied to the same lifetime `'a`.
        ///
        /// Unlike `FitGAT::fit_gat`, this declaration has no explicit `Self: 'a` clause.
        fn transform_gat<'a>(&self, input: Self::Input<'a>) -> Result<Self::Output<'a>>;
    }

    /// Lazily maps an iterator of borrowed items to processed items.
    pub trait IteratorProcessor {
        type Item<'a>
        where
            Self: 'a;

        type ProcessedItem<'a>
        where
            Self: 'a;

        type Error: std::error::Error;

        /// Wraps `iter` in an iterator yielding one `Result<ProcessedItem>` per item.
        ///
        /// Uses return-position `impl Trait` in a trait method. `'input: 'output`
        /// lets the result borrow from the input items.
        fn process_iter<'input, 'output, I>(
            &self,
            iter: I,
        ) -> impl Iterator<Item = Result<Self::ProcessedItem<'output>>> + 'output
        where
            I: Iterator<Item = Self::Item<'input>> + 'input,
            'input: 'output,
            Self: 'input + 'output;
    }
}
759
/// Trait "families" that group the core estimator traits by algorithm kind.
///
/// The module also provides descriptive enums and blanket implementations for most
/// families.
///
/// NOTE(review): each blanket impl at the bottom of this module uses the same bounds
/// as its trait's supertraits. Every implementor is therefore covered by the blanket
/// impl, and a hand-written impl would overlap with it. The hard-coded values those
/// impls return (e.g. `learning_type` is always `Classification`) cannot be
/// overridden per type. Confirm this is intended before relying on them.
pub mod trait_families {
    use super::*;

    /// Root of the family hierarchy.
    ///
    /// It is an [`Estimator`] with [`GetParams`] / [`SetParams`] that also describes
    /// its own algorithm.
    pub trait CoreMLFamily<State = Untrained>: Estimator<State> + GetParams + SetParams {
        /// Algorithm family name.
        fn algorithm_family(&self) -> &'static str;

        /// Algorithm category name.
        fn algorithm_category(&self) -> &'static str;

        /// Whether the named `capability` is supported.
        fn supports_capability(&self, capability: &str) -> bool;
    }

    /// Supervised estimators that fit on `(X, Y)`, predict `Output`, and score on `(X, Y)`.
    pub trait SupervisedLearningFamily<X, Y, Output>:
        CoreMLFamily + Fit<X, Y> + Predict<X, Output> + Score<X, Y>
    where
        Self::Fitted: Predict<X, Output> + Score<X, Y>,
    {
        /// Supervised task kind.
        ///
        /// NOTE(review): the blanket impl below returns `Classification` for every
        /// implementor, regressors included.
        fn learning_type(&self) -> SupervisedType;

        /// Defaults to `false`.
        fn supports_feature_importance(&self) -> bool {
            false
        }

        /// Defaults to `false`.
        fn supports_incremental_learning(&self) -> bool {
            false
        }
    }

    /// Supervised estimators that also predict class probabilities.
    pub trait ClassificationFamily<X, Labels, Probabilities>:
        SupervisedLearningFamily<X, Labels, Labels> + PredictProba<X, Probabilities>
    where
        Self::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities> + Score<X, Labels>,
    {
        /// Kind of classification problem.
        ///
        /// NOTE(review): the blanket impl below always returns `Binary`.
        fn classification_type(&self) -> ClassificationType;

        /// Defaults to `false`.
        fn supports_calibration(&self) -> bool {
            false
        }

        /// Defaults to `false`.
        fn supports_multilabel(&self) -> bool {
            false
        }
    }

    /// Supervised estimators whose predictions have the target type `Y`.
    pub trait RegressionFamily<X, Y>: SupervisedLearningFamily<X, Y, Y> + Score<X, Y>
    where
        Self::Fitted: Predict<X, Y> + Score<X, Y>,
    {
        /// Kind of regression.
        ///
        /// NOTE(review): the blanket impl below always returns `Linear`.
        fn regression_type(&self) -> RegressionType;

        /// Defaults to `false`.
        fn supports_prediction_intervals(&self) -> bool {
            false
        }

        /// Defaults to `false`.
        fn supports_robust_fitting(&self) -> bool {
            false
        }
    }

    /// Unsupervised estimators exposed through [`Transform`].
    pub trait UnsupervisedLearningFamily<X>: CoreMLFamily + Transform<X> {
        /// Kind of unsupervised task.
        ///
        /// NOTE(review): the blanket impl below always returns `Clustering`, even for
        /// dimensionality-reduction types.
        fn unsupervised_type(&self) -> UnsupervisedType;

        /// Defaults to `false`.
        fn supports_inverse_transform(&self) -> bool {
            false
        }

        /// Defaults to `true`.
        fn is_deterministic(&self) -> bool {
            true
        }
    }

    /// Unsupervised estimators that also implement [`Cluster`].
    pub trait ClusteringFamily<X>: UnsupervisedLearningFamily<X> + Cluster<X> {
        /// Clustering strategy.
        ///
        /// NOTE(review): the blanket impl below always returns `Partitional`.
        fn clustering_type(&self) -> ClusteringType;

        /// Defaults to `false`.
        fn supports_hierarchical(&self) -> bool {
            false
        }

        /// Defaults to `false`.
        fn supports_variable_clusters(&self) -> bool {
            false
        }

        /// Defaults to `false`.
        fn supports_cluster_centers(&self) -> bool {
            false
        }
    }

    /// Unsupervised estimators that also implement `FitTransform<X, (), X>`.
    pub trait DimensionalityReductionFamily<X>:
        UnsupervisedLearningFamily<X> + FitTransform<X, (), X>
    {
        /// Kind of reduction.
        ///
        /// NOTE(review): the blanket impl below always returns `Linear`.
        fn reduction_type(&self) -> DimensionalityReductionType;

        /// Target dimensionality, if fixed. The blanket impl below returns `None`.
        fn target_dimensions(&self) -> Option<usize>;

        /// Defaults to `false`.
        fn preserves_distances(&self) -> bool {
            false
        }
    }

    /// Supervised ensembles. No blanket impl; implementors provide these methods.
    pub trait EnsembleFamily<X, Y, Output>: SupervisedLearningFamily<X, Y, Output>
    where
        Self::Fitted: Predict<X, Output> + Score<X, Y>,
    {
        /// Ensemble strategy.
        fn ensemble_type(&self) -> EnsembleType;

        /// Number of member estimators.
        fn n_estimators(&self) -> usize;

        /// Defaults to `false`.
        fn supports_oob_score(&self) -> bool {
            false
        }
    }

    /// Supervised neural networks that also support [`PartialFit`].
    ///
    /// There is no blanket impl.
    pub trait NeuralNetworkFamily<X, Y, Output>:
        SupervisedLearningFamily<X, Y, Output> + PartialFit<X, Y>
    where
        Self::Fitted: Predict<X, Output> + Score<X, Y>,
    {
        /// Network architecture kind.
        fn network_type(&self) -> NetworkType;

        /// Number of layers.
        fn n_layers(&self) -> usize;

        /// Defaults to `false`.
        fn supports_dropout(&self) -> bool {
            false
        }

        /// Defaults to `false`.
        fn supports_batch_norm(&self) -> bool {
            false
        }
    }

    /// Supervised task kind returned by `SupervisedLearningFamily::learning_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum SupervisedType {
        Classification,
        Regression,
    }

    /// Returned by `ClassificationFamily::classification_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ClassificationType {
        Binary,
        Multiclass,
        Multilabel,
    }

    /// Returned by `RegressionFamily::regression_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum RegressionType {
        Linear,
        Nonlinear,
        Robust,
    }

    /// Returned by `UnsupervisedLearningFamily::unsupervised_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum UnsupervisedType {
        Clustering,
        DimensionalityReduction,
        DensityEstimation,
        OutlierDetection,
    }

    /// Returned by `ClusteringFamily::clustering_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ClusteringType {
        Partitional,
        Hierarchical,
        DensityBased,
        GridBased,
    }

    /// Returned by `DimensionalityReductionFamily::reduction_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum DimensionalityReductionType {
        Linear,
        Nonlinear,
        Manifold,
    }

    /// Returned by `EnsembleFamily::ensemble_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum EnsembleType {
        Bagging,
        Boosting,
        Voting,
        Stacking,
    }

    /// Returned by `NeuralNetworkFamily::network_type`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum NetworkType {
        Feedforward,
        Convolutional,
        Recurrent,
        Transformer,
    }

    // Blanket impl: every qualifying type reports `Classification`, and the
    // `supports_*` defaults cannot be overridden (see the module-level note).
    impl<T, X, Y, Output> SupervisedLearningFamily<X, Y, Output> for T
    where
        T: CoreMLFamily + Fit<X, Y> + Predict<X, Output> + Score<X, Y>,
        T::Fitted: Predict<X, Output> + Score<X, Y>,
    {
        fn learning_type(&self) -> SupervisedType {
            SupervisedType::Classification
        }
    }

    // Blanket impl: every qualifying type reports `Binary`.
    impl<T, X, Labels, Probabilities> ClassificationFamily<X, Labels, Probabilities> for T
    where
        T: SupervisedLearningFamily<X, Labels, Labels> + PredictProba<X, Probabilities>,
        T::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities> + Score<X, Labels>,
    {
        fn classification_type(&self) -> ClassificationType {
            ClassificationType::Binary
        }
    }

    // Blanket impl: every qualifying type reports `Linear`.
    impl<T, X, Y> RegressionFamily<X, Y> for T
    where
        T: SupervisedLearningFamily<X, Y, Y> + Score<X, Y>,
        T::Fitted: Predict<X, Y> + Score<X, Y>,
    {
        fn regression_type(&self) -> RegressionType {
            RegressionType::Linear
        }
    }

    // Blanket impl: every `CoreMLFamily + Transform` type reports `Clustering`.
    impl<T, X> UnsupervisedLearningFamily<X> for T
    where
        T: CoreMLFamily + Transform<X>,
    {
        fn unsupervised_type(&self) -> UnsupervisedType {
            UnsupervisedType::Clustering
        }
    }

    // Blanket impl: every qualifying type reports `Partitional`.
    impl<T, X> ClusteringFamily<X> for T
    where
        T: UnsupervisedLearningFamily<X> + Cluster<X>,
    {
        fn clustering_type(&self) -> ClusteringType {
            ClusteringType::Partitional
        }
    }

    // Blanket impl: every qualifying type reports `Linear` with no fixed target
    // dimensionality.
    impl<T, X> DimensionalityReductionFamily<X> for T
    where
        T: UnsupervisedLearningFamily<X> + FitTransform<X, (), X>,
    {
        fn reduction_type(&self) -> DimensionalityReductionType {
            DimensionalityReductionType::Linear
        }

        fn target_dimensions(&self) -> Option<usize> {
            None
        }
    }
}
1057
1058pub mod specialized {
1060 use super::*;
1061
1062 pub trait HybridLearner<X, Y, Output>:
1063 Fit<X, Y> + PartialFit<X, Y> + Predict<X, Output>
1064 where
1065 Self::Fitted: Predict<X, Output> + PartialFit<X, Y>,
1066 {
1067 fn set_learning_mode(&mut self, online: bool);
1068
1069 fn is_online_mode(&self) -> bool;
1070 }
1071
1072 pub trait InterpretableModel<X, Y, Output> {
1074 type Importance;
1076
1077 fn feature_importance(&self) -> Result<Self::Importance>;
1079
1080 fn explain_prediction(&self, input: &X) -> Result<String>;
1082
1083 fn explain_model(&self) -> Result<String>;
1085 }
1086
1087 pub trait ConfidenceModel<X, Output> {
1089 type Confidence: FloatBounds;
1091
1092 fn predict_with_confidence(&self, x: &X) -> Result<(Output, Vec<Self::Confidence>)>;
1094
1095 fn prediction_uncertainty(&self, x: &X) -> Result<Self::Confidence>;
1097 }
1098
1099 pub trait PrivacyPreservingModel<X, Y> {
1101 type PrivacyBudget: FloatBounds;
1103
1104 fn set_privacy_budget(&mut self, budget: Self::PrivacyBudget);
1106
1107 fn remaining_privacy_budget(&self) -> Self::PrivacyBudget;
1109
1110 fn is_privacy_safe(&self, operation_cost: Self::PrivacyBudget) -> bool;
1112 }
1113}