sklears_core/
traits.rs

1use crate::error::Result;
2use crate::types::FloatBounds;
3use std::fmt::Debug;
4
/// Type-state marker for a model that has not been fitted yet.
///
/// This is a zero-sized struct (not a trait). It is the default `State`
/// parameter of `Estimator` and `Fit`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Untrained;

/// Type-state marker for a model that has been fitted.
///
/// Like `Untrained`, this is a zero-sized marker struct, not a trait.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Trained;
12
13/// Base trait for all estimators with enhanced type safety
14pub trait Estimator<State = Untrained> {
15    /// Configuration type for the estimator
16    type Config: Clone + Debug + Send + Sync;
17
18    /// Error type for the estimator
19    type Error: std::error::Error + Send + Sync + 'static;
20
21    /// The numeric type used by this estimator
22    type Float: FloatBounds + Send + Sync;
23
24    /// Get estimator configuration
25    fn config(&self) -> &Self::Config;
26
27    /// Validate estimator configuration with detailed error context
28    fn validate_config(&self) -> Result<()> {
29        Ok(())
30    }
31
32    /// Check if estimator is compatible with given data dimensions
33    fn check_compatibility(&self, n_samples: usize, n_features: usize) -> Result<()> {
34        if n_samples == 0 {
35            return Err(crate::error::SklearsError::InvalidInput(
36                "Number of samples cannot be zero".to_string(),
37            ));
38        }
39        if n_features == 0 {
40            return Err(crate::error::SklearsError::InvalidInput(
41                "Number of features cannot be zero".to_string(),
42            ));
43        }
44        Ok(())
45    }
46
47    /// Get estimator metadata
48    fn metadata(&self) -> EstimatorMetadata {
49        EstimatorMetadata::default()
50    }
51}
52
/// Descriptive metadata about an estimator's capabilities and cost.
///
/// All flags default to `false`, all strings to empty, and both complexity
/// fields to their `Linear` variant.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EstimatorMetadata {
    /// Human-readable estimator name.
    pub name: String,
    /// Version string of the implementation.
    pub version: String,
    /// Free-form description.
    pub description: String,
    /// Whether sparse input is supported.
    pub supports_sparse: bool,
    /// Whether multiclass targets are supported.
    pub supports_multiclass: bool,
    /// Whether multilabel targets are supported.
    pub supports_multilabel: bool,
    /// Whether input values must be positive.
    pub requires_positive_input: bool,
    /// Whether online (incremental) learning is supported.
    pub supports_online_learning: bool,
    /// Whether feature importances can be reported.
    pub supports_feature_importance: bool,
    /// Memory cost class.
    pub memory_complexity: MemoryComplexity,
    /// Training-time cost class.
    pub time_complexity: TimeComplexity,
}

/// Memory complexity characteristics, in terms of input size `n`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MemoryComplexity {
    /// O(n). This is the default.
    #[default]
    Linear,
    /// O(n²)
    Quadratic,
    /// O(1)
    Constant,
    /// O(log n)
    Logarithmic,
}

/// Time complexity characteristics for training, in terms of input size `n`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum TimeComplexity {
    /// O(n). This is the default.
    #[default]
    Linear,
    /// O(n²)
    Quadratic,
    /// O(n log n)
    LogLinear,
    /// O(n^k)
    Polynomial,
    /// O(2^n)
    Exponential,
}
89
90/// Enhanced trait for models that can be fitted to data
91pub trait Fit<X, Y, State = Untrained> {
92    /// The fitted model type
93    type Fitted: Send + Sync;
94
95    /// Fit the model to the provided data with validation
96    fn fit(self, x: &X, y: &Y) -> Result<Self::Fitted>;
97
98    /// Fit with custom validation and early stopping
99    fn fit_with_validation(
100        self,
101        x: &X,
102        y: &Y,
103        _x_val: Option<&X>,
104        _y_val: Option<&Y>,
105    ) -> Result<(Self::Fitted, FitMetrics)>
106    where
107        Self: Sized,
108    {
109        let fitted = self.fit(x, y)?;
110        Ok((fitted, FitMetrics::default()))
111    }
112}
113
/// Metrics gathered while fitting a model.
///
/// The `Default` value has no scores, zero iterations and both flags unset.
#[derive(Debug, Clone, Default)]
pub struct FitMetrics {
    /// Score on the training data, if computed.
    pub training_score: Option<f64>,
    /// Score on the validation data, if computed.
    pub validation_score: Option<f64>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Whether the optimisation converged.
    pub convergence_achieved: bool,
    /// Whether early stopping ended training.
    pub early_stopping_triggered: bool,
}
123
124/// Enhanced trait for models that can make predictions
125pub trait Predict<X, Output> {
126    /// Make predictions on the provided data
127    fn predict(&self, x: &X) -> Result<Output>;
128
129    /// Make predictions with confidence intervals
130    fn predict_with_uncertainty(&self, x: &X) -> Result<(Output, UncertaintyMeasure)> {
131        let predictions = self.predict(x)?;
132        Ok((predictions, UncertaintyMeasure::default()))
133    }
134}
135
/// Optional uncertainty information accompanying a set of predictions.
///
/// Every field is `None` in the `Default` value.
#[derive(Debug, Clone, Default)]
pub struct UncertaintyMeasure {
    /// Per-prediction `(lower, upper)` confidence bounds.
    pub confidence_intervals: Option<Vec<(f64, f64)>>,
    /// Per-prediction variance.
    pub prediction_variance: Option<Vec<f64>>,
    /// Per-prediction epistemic (model) uncertainty.
    pub epistemic_uncertainty: Option<Vec<f64>>,
    /// Per-prediction aleatoric (data) uncertainty.
    pub aleatoric_uncertainty: Option<Vec<f64>>,
}
144
145/// Trait for models that can transform data
146pub trait Transform<X, Output = X> {
147    /// Transform the input data
148    fn transform(&self, x: &X) -> Result<Output>;
149}
150
151/// Trait for models that can transform data in-place
152pub trait TransformInplace<X> {
153    /// Transform the input data in-place
154    fn transform_inplace(&mut self, x: &mut X) -> Result<()>;
155}
156
157/// Trait for models that can be fitted and used for prediction in one step
158pub trait FitPredict<X, Y, Output> {
159    /// Fit the model and make predictions
160    fn fit_predict(self, x_train: &X, y_train: &Y, x_test: &X) -> Result<Output>;
161}
162
163/// Trait for transformers that can be fitted and transform in one step
164pub trait FitTransform<X, Y = (), Output = X> {
165    /// Fit the transformer and transform the data
166    fn fit_transform(self, x: &X, y: Option<&Y>) -> Result<Output>;
167}
168
169/// Trait for models that support incremental/online learning
170pub trait PartialFit<X, Y> {
171    /// Incrementally fit on a batch of samples
172    fn partial_fit(&mut self, x: &X, y: &Y) -> Result<()>;
173}
174
175/// Trait for models that can calculate a score
176pub trait Score<X, Y> {
177    /// The numeric type for score calculation
178    type Float: FloatBounds;
179
180    /// Calculate the score of the model on the provided data
181    fn score(&self, x: &X, y: &Y) -> Result<Self::Float>;
182}
183
184/// Trait for models that support probability predictions
185pub trait PredictProba<X, Output> {
186    /// Predict class probabilities
187    fn predict_proba(&self, x: &X) -> Result<Output>;
188}
189
190/// Trait for models that support confidence scores
191pub trait DecisionFunction<X, Output> {
192    /// Compute the decision function
193    fn decision_function(&self, x: &X) -> Result<Output>;
194}
195
196/// Trait for models that support getting parameters
197pub trait GetParams {
198    /// Get parameters as a key-value mapping
199    fn get_params(&self) -> std::collections::HashMap<String, String>;
200}
201
202/// Trait for models that support setting parameters
203pub trait SetParams {
204    /// Set parameters from a key-value mapping
205    fn set_params(&mut self, params: std::collections::HashMap<String, String>) -> Result<()>;
206}
207
208/// Trait for clustering algorithms
209pub trait Cluster<X> {
210    /// The output type for cluster assignments
211    type Labels;
212
213    /// Fit the clustering model and return cluster assignments
214    fn fit_predict(self, x: &X) -> Result<Self::Labels>;
215}
216
217// Advanced capability traits for specific ML algorithm types
218
219/// Trait for algorithms that support feature importance
220pub trait FeatureImportance {
221    /// Get feature importance scores
222    fn feature_importances(&self) -> Result<Vec<f64>>;
223
224    /// Get feature names if available
225    fn feature_names(&self) -> Option<Vec<String>> {
226        None
227    }
228}
229
230/// Trait for algorithms that support model introspection
231pub trait ModelIntrospection {
232    /// Get model parameters as interpretable structure
233    fn get_model_structure(&self) -> Result<ModelStructure>;
234
235    /// Get decision path information for a prediction
236    fn decision_path(&self, x: &[f64]) -> Result<Vec<DecisionNode>>;
237}
238
/// Structured, inspectable representation of a model's internals.
///
/// `PartialEq` is derived (not `Eq`, because of the `f64` fields) so that
/// introspection results can be compared, for example in tests.
#[derive(Debug, Clone, PartialEq)]
pub enum ModelStructure {
    /// Linear model: one weight per feature plus an intercept.
    Linear { weights: Vec<f64>, bias: f64 },
    /// Single decision tree, given by its root node.
    Tree { root: DecisionNode },
    /// Layered neural network.
    Neural { layers: Vec<LayerInfo> },
    /// Ensemble of nested model structures.
    // `Vec<Box<_>>` is kept as-is: this is a public field, so changing it
    // would break callers.
    Ensemble { base_models: Vec<Box<ModelStructure>> },
}

/// Decision node information for model interpretability.
#[derive(Debug, Clone, PartialEq)]
pub struct DecisionNode {
    /// Index of the feature this node splits on (`None` when there is no split).
    pub feature_index: Option<usize>,
    /// Split threshold, if any.
    pub threshold: Option<f64>,
    /// Node impurity, if computed.
    pub impurity: Option<f64>,
    /// Number of samples that reached this node.
    pub samples: usize,
    /// Value(s) stored at the node.
    pub value: Vec<f64>,
    /// Whether the node is a leaf.
    pub is_leaf: bool,
}

/// Neural network layer information.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerInfo {
    /// Layer kind, as free-form text.
    pub layer_type: String,
    /// Input width.
    pub input_size: usize,
    /// Output width.
    pub output_size: usize,
    /// Activation function name.
    pub activation: String,
}
276
277/// Trait for online/incremental learning algorithms
278pub trait OnlineLearning<X, Y> {
279    /// Update model with new data batch
280    fn partial_fit(&mut self, x: &X, y: &Y) -> Result<()>;
281
282    /// Check if the model needs more data
283    fn needs_more_data(&self) -> bool {
284        false
285    }
286
287    /// Reset the model to initial state
288    fn reset(&mut self) -> Result<()>;
289}
290
291/// Trait for algorithms with hyperparameter optimization
292pub trait HyperparameterOptimization {
293    type HyperparameterSpace;
294
295    /// Get recommended hyperparameter search space
296    fn hyperparameter_space(&self) -> Self::HyperparameterSpace;
297
298    /// Validate hyperparameter combination
299    fn validate_hyperparameters(
300        &self,
301        params: &std::collections::HashMap<String, f64>,
302    ) -> Result<()>;
303}
304
305/// Trait for robust algorithms that handle outliers
306pub trait RobustEstimation {
307    /// Set robustness parameters
308    fn set_robustness_params(&mut self, outlier_fraction: f64) -> Result<()>;
309
310    /// Identify potential outliers in training data
311    fn identify_outliers(&self, x: &[&[f64]]) -> Result<Vec<bool>>;
312}
313
314// Enhanced composite traits for common ML patterns
315
316/// Composite trait for supervised learning algorithms that can fit and predict
317pub trait SupervisedLearner<X, Y, Output>: Fit<X, Y> + Predict<X, Output>
318where
319    Self::Fitted: Predict<X, Output>,
320    Self: Sized,
321{
322    /// Default implementation for fit and predict in one step
323    fn fit_predict(self, x_train: &X, y_train: &Y, x_test: &X) -> Result<Output> {
324        let fitted = self.fit(x_train, y_train)?;
325        fitted.predict(x_test)
326    }
327}
328
329/// Composite trait for interpretable models
330pub trait InterpretableModel<X, Y, Output>:
331    SupervisedLearner<X, Y, Output> + FeatureImportance + ModelIntrospection
332where
333    Self::Fitted: Predict<X, Output> + FeatureImportance + ModelIntrospection,
334{
335    /// Generate model explanation for specific prediction
336    fn explain_prediction(&self, x: &[f64]) -> Result<PredictionExplanation> {
337        let importance = self.feature_importances()?;
338        let path = self.decision_path(x)?;
339        Ok(PredictionExplanation {
340            feature_contributions: importance,
341            decision_path: path,
342            confidence: None,
343        })
344    }
345}
346
347/// Explanation for a specific prediction
348#[derive(Debug, Clone)]
349pub struct PredictionExplanation {
350    pub feature_contributions: Vec<f64>,
351    pub decision_path: Vec<DecisionNode>,
352    pub confidence: Option<f64>,
353}
354
355/// Blanket implementation for any type that implements both Fit and Predict
356impl<T, X, Y, Output> SupervisedLearner<X, Y, Output> for T
357where
358    T: Fit<X, Y> + Predict<X, Output> + Sized,
359    T::Fitted: Predict<X, Output>,
360{
361}
362
363/// Composite trait for classifiers that provide both predictions and probabilities
364pub trait Classifier<X, Labels, Probabilities>:
365    SupervisedLearner<X, Labels, Labels> + PredictProba<X, Probabilities>
366where
367    Self::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities>,
368    Self: Sized,
369{
370    /// Default implementation for classification with probability scores
371    fn classify_with_proba(
372        self,
373        x_train: &X,
374        y_train: &Labels,
375        x_test: &X,
376    ) -> Result<(Labels, Probabilities)> {
377        let fitted = self.fit(x_train, y_train)?;
378        let predictions = fitted.predict(x_test)?;
379        let probabilities = fitted.predict_proba(x_test)?;
380        Ok((predictions, probabilities))
381    }
382}
383
384/// Blanket implementation for classifier types
385impl<T, X, Labels, Probabilities> Classifier<X, Labels, Probabilities> for T
386where
387    T: SupervisedLearner<X, Labels, Labels> + PredictProba<X, Probabilities> + Sized,
388    T::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities>,
389{
390}
391
392/// Composite trait for regressors with scoring capability
393pub trait Regressor<X, Y>: Fit<X, Y> + Predict<X, Y> + Score<X, Y>
394where
395    Self::Fitted: Predict<X, Y> + Score<X, Y>,
396    Self: Sized,
397{
398    /// Default implementation for regression with scoring
399    #[allow(clippy::type_complexity)]
400    fn regress_and_score(
401        self,
402        x_train: &X,
403        y_train: &Y,
404        x_test: &X,
405        y_test: &Y,
406    ) -> Result<(Y, <Self::Fitted as Score<X, Y>>::Float)> {
407        let fitted = self.fit(x_train, y_train)?;
408        let predictions = fitted.predict(x_test)?;
409        let score = fitted.score(x_test, y_test)?;
410        Ok((predictions, score))
411    }
412}
413
414/// Blanket implementation for regressor types
415impl<T, X, Y> Regressor<X, Y> for T
416where
417    T: Fit<X, Y> + Predict<X, Y> + Score<X, Y> + Sized,
418    T::Fitted: Predict<X, Y> + Score<X, Y>,
419{
420}
421
422/// Composite trait for transformers that can fit and transform  
423pub trait Transformer<X, Y = (), Output = X>: FitTransform<X, Y, Output>
424where
425    Self: Sized,
426{
427    /// Default implementation that leverages fit_transform
428    fn fit_then_transform(self, x: &X, y: Option<&Y>) -> Result<Output> {
429        self.fit_transform(x, y)
430    }
431}
432
433/// Blanket implementation for transformer types  
434impl<T, X, Y, Output> Transformer<X, Y, Output> for T where T: FitTransform<X, Y, Output> + Sized {}
435
436/// Composite trait for complete ML pipelines
437pub trait MLPipeline<X, Y, Output>:
438    Fit<X, Y> + Predict<X, Output> + Transform<X, X> + Score<X, Y>
439where
440    Self::Fitted: Predict<X, Output> + Transform<X, X> + Score<X, Y, Float = Self::Float>,
441    Self: Sized,
442{
443    /// Execute a complete ML pipeline: fit, transform, predict, and score
444    fn execute_pipeline(
445        self,
446        x_train: &X,
447        y_train: &Y,
448        x_test: &X,
449        y_test: &Y,
450    ) -> Result<PipelineResult<Output, X, Self::Float>> {
451        let fitted = self.fit(x_train, y_train)?;
452        let transformed_test = fitted.transform(x_test)?;
453        let predictions = fitted.predict(&transformed_test)?;
454        let score = fitted.score(x_test, y_test)?;
455
456        Ok(PipelineResult {
457            predictions,
458            score,
459            transformed_features: transformed_test,
460        })
461    }
462}
463
464/// Result type for ML pipeline execution
465#[derive(Debug, Clone)]
466pub struct PipelineResult<Predictions, Features, Score> {
467    pub predictions: Predictions,
468    pub score: Score,
469    pub transformed_features: Features,
470}
471
472/// Blanket implementation for complete pipeline types
473impl<T, X, Y, Output> MLPipeline<X, Y, Output> for T
474where
475    T: Fit<X, Y> + Predict<X, Output> + Transform<X, X> + Score<X, Y> + Sized,
476    T::Fitted: Predict<X, Output> + Transform<X, X> + Score<X, Y, Float = T::Float>,
477{
478}
479
480/// Composite trait for online learners that support incremental learning
481pub trait OnlineLearner<X, Y, Output>: PartialFit<X, Y> + Predict<X, Output> + Score<X, Y> {
482    /// Train incrementally and evaluate performance
483    fn train_incrementally(
484        &mut self,
485        batches: &[(X, Y)],
486        x_test: &X,
487        y_test: &Y,
488    ) -> Result<Vec<Self::Float>> {
489        let mut scores = Vec::with_capacity(batches.len());
490
491        for (x_batch, y_batch) in batches {
492            self.partial_fit(x_batch, y_batch)?;
493            let score = self.score(x_test, y_test)?;
494            scores.push(score);
495        }
496
497        Ok(scores)
498    }
499}
500
501/// Blanket implementation for online learner types
502impl<T, X, Y, Output> OnlineLearner<X, Y, Output> for T where
503    T: PartialFit<X, Y> + Predict<X, Output> + Score<X, Y>
504{
505}
506
507/// Trait for model evaluation and comparison
508pub trait ModelEvaluator<X, Y, Output> {
509    type Score: FloatBounds;
510
511    /// Evaluate model performance using cross-validation
512    fn cross_validate(
513        &self,
514        model: impl Fit<X, Y> + Clone,
515        x: &X,
516        y: &Y,
517        cv_folds: usize,
518    ) -> Result<Vec<Self::Score>>;
519
520    /// Compare multiple models and return the best one
521    fn model_selection(&self, models: Vec<impl Fit<X, Y> + Clone>, x: &X, y: &Y) -> Result<usize>; // Returns index of best model
522}
523
524/// Async versions of core traits for streaming and non-blocking operations
525pub mod async_traits {
526    use super::*;
527    use std::future::Future;
528    use std::pin::Pin;
529
530    /// Async version of Fit trait for non-blocking training
531    pub trait AsyncFit<X, Y, State = Untrained> {
532        type Fitted;
533        type Error: std::error::Error + Send + Sync;
534
535        /// Fit the model asynchronously
536        fn fit_async<'a>(
537            self,
538            x: &'a X,
539            y: &'a Y,
540        ) -> Pin<Box<dyn Future<Output = Result<Self::Fitted>> + Send + 'a>>
541        where
542            Self: Sized + 'a;
543    }
544
545    /// Async version of Predict trait for non-blocking prediction
546    pub trait AsyncPredict<X, Output> {
547        type Error: std::error::Error + Send + Sync;
548
549        /// Make predictions asynchronously
550        fn predict_async<'a>(
551            &'a self,
552            x: &'a X,
553        ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
554    }
555
556    /// Async version of Transform trait for non-blocking transformation
557    pub trait AsyncTransform<X, Output = X> {
558        type Error: std::error::Error + Send + Sync;
559
560        /// Transform data asynchronously
561        fn transform_async<'a>(
562            &'a self,
563            x: &'a X,
564        ) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
565    }
566}
567
568/// Streaming data processing traits for large datasets
569pub mod streaming {
570    use super::*;
571    use futures_core::Stream;
572    use std::pin::Pin;
573
574    /// Trait for processing streaming data
575    pub trait StreamingFit<S, Y> {
576        type Fitted;
577        type Error: std::error::Error + Send + Sync;
578
579        /// Fit model on streaming data
580        fn fit_stream(
581            self,
582            stream: S,
583            targets: Y,
584        ) -> Pin<Box<dyn futures_core::Future<Output = Result<Self::Fitted>> + Send>>
585        where
586            S: Stream + Send,
587            Y: Send;
588    }
589
590    /// Trait for streaming predictions
591    pub trait StreamingPredict<S, Output> {
592        type Error: std::error::Error + Send + Sync;
593
594        /// Make predictions on streaming data
595        fn predict_stream<'a>(
596            &'a self,
597            stream: S,
598        ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>
599        where
600            S: Stream + Send + 'a;
601    }
602
603    /// Trait for streaming transformations
604    pub trait StreamingTransform<S, Output> {
605        type Error: std::error::Error + Send + Sync;
606
607        /// Transform streaming data
608        fn transform_stream<'a>(
609            &'a self,
610            stream: S,
611        ) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>
612        where
613            S: Stream + Send + 'a;
614    }
615
616    /// Trait for incremental learning on streaming data
617    pub trait StreamingPartialFit<S, Y> {
618        type Error: std::error::Error + Send + Sync;
619
620        /// Incrementally fit on streaming batches
621        ///
622        /// # Lifetime Parameters
623        ///
624        /// The returned future must not outlive the mutable reference to self
625        fn partial_fit_stream<'a, Item>(
626            &'a mut self,
627            stream: S,
628        ) -> Pin<Box<dyn futures_core::Future<Output = Result<()>> + Send + 'a>>
629        where
630            S: Stream<Item = (Item, Y)> + Send + 'a,
631            Item: Send + 'a,
632            Y: Send + 'a;
633    }
634}
635
636/// Generic Associated Types (GATs) enhanced traits
637pub mod gat_traits {
638    use super::*;
639
640    /// Enhanced Estimator trait with GATs for better generic flexibility
641    pub trait EstimatorGAT<State = Untrained> {
642        /// Configuration type
643        type Config;
644
645        /// Error type
646        type Error: std::error::Error;
647
648        /// Numeric type for computations
649        type Float: FloatBounds;
650
651        /// Input data type
652        type Input<'a>
653        where
654            Self: 'a;
655
656        /// Output type
657        type Output<'a>
658        where
659            Self: 'a;
660
661        /// Parameters type
662        type Parameters;
663    }
664
665    /// GAT-enhanced Fit trait for better lifetime management
666    pub trait FitGAT<State = Untrained> {
667        /// Associated types
668        type Input<'a>
669        where
670            Self: 'a;
671        type Target<'a>
672        where
673            Self: 'a;
674        type Fitted;
675        type Error: std::error::Error;
676
677        /// Fit with GATs for flexible lifetime management
678        ///
679        /// # Lifetime Parameters
680        ///
681        /// * `'a` - Lifetime of input and target data, must be valid for the duration of fitting
682        ///
683        /// # Safety
684        ///
685        /// The implementer must ensure that the input and target data remain valid
686        /// for the entire duration of the fitting process.
687        fn fit_gat<'a>(
688            self,
689            input: Self::Input<'a>,
690            target: Self::Target<'a>,
691        ) -> Result<Self::Fitted>
692        where
693            Self: 'a; // Ensure self lives at least as long as the input data
694    }
695
696    /// GAT-enhanced Transform trait for zero-copy operations
697    pub trait TransformGAT {
698        /// Input type with lifetime
699        type Input<'a>
700        where
701            Self: 'a;
702
703        /// Output type with lifetime
704        type Output<'a>
705        where
706            Self: 'a;
707
708        /// Error type
709        type Error: std::error::Error;
710
711        /// Transform with zero-copy when possible
712        ///
713        /// # Lifetime Parameters
714        ///
715        /// * `'a` - Lifetime of input data, the output may borrow from the input
716        ///
717        /// # Zero-Copy Semantics
718        ///
719        /// This method is designed to enable zero-copy operations where the output
720        /// can borrow from the input data without requiring additional allocations.
721        /// The lifetime parameter ensures memory safety for borrowed data.
722        fn transform_gat<'a>(&self, input: Self::Input<'a>) -> Result<Self::Output<'a>>;
723    }
724
725    /// Iterator-based data processing with GATs
726    pub trait IteratorProcessor {
727        /// Item type
728        type Item<'a>
729        where
730            Self: 'a;
731
732        /// Processed item type
733        type ProcessedItem<'a>
734        where
735            Self: 'a;
736
737        /// Error type
738        type Error: std::error::Error;
739
740        /// Process iterator items
741        ///
742        /// # Lifetime Parameters
743        ///
744        /// * `'input` - Lifetime of the input iterator and its items
745        /// * `'output` - Lifetime of the processed output items
746        ///
747        /// The input lifetime must outlive the output lifetime to ensure
748        /// that any borrowed data remains valid.
749        fn process_iter<'input, 'output, I>(
750            &self,
751            iter: I,
752        ) -> impl Iterator<Item = Result<Self::ProcessedItem<'output>>> + 'output
753        where
754            I: Iterator<Item = Self::Item<'input>> + 'input,
755            'input: 'output, // Input must outlive output
756            Self: 'input + 'output;
757    }
758}
759
760/// Trait families for organizing related functionality hierarchically  
761pub mod trait_families {
762    use super::*;
763
764    /// Core ML trait family - base functionality for all ML algorithms
765    pub trait CoreMLFamily<State = Untrained>: Estimator<State> + GetParams + SetParams {
766        /// Get algorithm family name (e.g., "supervised", "unsupervised", "reinforcement")
767        fn algorithm_family(&self) -> &'static str;
768
769        /// Get algorithm category (e.g., "classification", "regression", "clustering")
770        fn algorithm_category(&self) -> &'static str;
771
772        /// Check if the algorithm supports a specific capability
773        fn supports_capability(&self, capability: &str) -> bool;
774    }
775
776    /// Supervised learning trait family with hierarchical relationships
777    pub trait SupervisedLearningFamily<X, Y, Output>:
778        CoreMLFamily + Fit<X, Y> + Predict<X, Output> + Score<X, Y>
779    where
780        Self::Fitted: Predict<X, Output> + Score<X, Y>,
781    {
782        /// Type of supervised learning (classification or regression)
783        fn learning_type(&self) -> SupervisedType;
784
785        /// Whether the algorithm supports feature importance
786        fn supports_feature_importance(&self) -> bool {
787            false
788        }
789
790        /// Whether the algorithm supports incremental learning
791        fn supports_incremental_learning(&self) -> bool {
792            false
793        }
794    }
795
796    /// Classification trait family with specialized classification capabilities
797    pub trait ClassificationFamily<X, Labels, Probabilities>:
798        SupervisedLearningFamily<X, Labels, Labels> + PredictProba<X, Probabilities>
799    where
800        Self::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities> + Score<X, Labels>,
801    {
802        /// Type of classification problem
803        fn classification_type(&self) -> ClassificationType;
804
805        /// Whether the classifier supports probability calibration
806        fn supports_calibration(&self) -> bool {
807            false
808        }
809
810        /// Whether the classifier supports multi-label classification
811        fn supports_multilabel(&self) -> bool {
812            false
813        }
814    }
815
816    /// Regression trait family with specialized regression capabilities
817    pub trait RegressionFamily<X, Y>: SupervisedLearningFamily<X, Y, Y> + Score<X, Y>
818    where
819        Self::Fitted: Predict<X, Y> + Score<X, Y>,
820    {
821        /// Type of regression problem
822        fn regression_type(&self) -> RegressionType;
823
824        /// Whether the regressor supports prediction intervals
825        fn supports_prediction_intervals(&self) -> bool {
826            false
827        }
828
829        /// Whether the regressor supports robust fitting
830        fn supports_robust_fitting(&self) -> bool {
831            false
832        }
833    }
834
835    /// Unsupervised learning trait family
836    pub trait UnsupervisedLearningFamily<X>: CoreMLFamily + Transform<X> {
837        /// Type of unsupervised learning
838        fn unsupervised_type(&self) -> UnsupervisedType;
839
840        /// Whether the algorithm supports inverse transform
841        fn supports_inverse_transform(&self) -> bool {
842            false
843        }
844
845        /// Whether the algorithm is deterministic
846        fn is_deterministic(&self) -> bool {
847            true
848        }
849    }
850
851    /// Clustering trait family with specialized clustering capabilities
852    pub trait ClusteringFamily<X>: UnsupervisedLearningFamily<X> + Cluster<X> {
853        /// Type of clustering algorithm
854        fn clustering_type(&self) -> ClusteringType;
855
856        /// Whether the algorithm supports hierarchical clustering
857        fn supports_hierarchical(&self) -> bool {
858            false
859        }
860
861        /// Whether the algorithm can handle varying cluster numbers
862        fn supports_variable_clusters(&self) -> bool {
863            false
864        }
865
866        /// Whether the algorithm supports cluster centers
867        fn supports_cluster_centers(&self) -> bool {
868            false
869        }
870    }
871
872    /// Dimensionality reduction trait family
873    pub trait DimensionalityReductionFamily<X>:
874        UnsupervisedLearningFamily<X> + FitTransform<X, (), X>
875    {
876        /// Type of dimensionality reduction
877        fn reduction_type(&self) -> DimensionalityReductionType;
878
879        /// Target number of dimensions (if applicable)
880        fn target_dimensions(&self) -> Option<usize>;
881
882        /// Whether the transformation preserves distances
883        fn preserves_distances(&self) -> bool {
884            false
885        }
886    }
887
888    /// Ensemble trait family for meta-algorithms
889    pub trait EnsembleFamily<X, Y, Output>: SupervisedLearningFamily<X, Y, Output>
890    where
891        Self::Fitted: Predict<X, Output> + Score<X, Y>,
892    {
893        /// Type of ensemble method
894        fn ensemble_type(&self) -> EnsembleType;
895
896        /// Number of base estimators
897        fn n_estimators(&self) -> usize;
898
899        /// Whether the ensemble supports out-of-bag scoring
900        fn supports_oob_score(&self) -> bool {
901            false
902        }
903    }
904
905    /// Neural network trait family
906    pub trait NeuralNetworkFamily<X, Y, Output>:
907        SupervisedLearningFamily<X, Y, Output> + PartialFit<X, Y>
908    where
909        Self::Fitted: Predict<X, Output> + Score<X, Y>,
910    {
911        /// Type of neural network architecture
912        fn network_type(&self) -> NetworkType;
913
914        /// Number of layers in the network
915        fn n_layers(&self) -> usize;
916
917        /// Whether the network supports dropout
918        fn supports_dropout(&self) -> bool {
919            false
920        }
921
922        /// Whether the network supports batch normalization
923        fn supports_batch_norm(&self) -> bool {
924            false
925        }
926    }
927
// Enums for categorizing algorithms. This is a plain comment: a `///` doc
// comment followed by a blank line would attach to `SupervisedType`.

/// Kind of supervised learning task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SupervisedType {
    Classification,
    Regression,
}

/// Kind of classification problem.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClassificationType {
    Binary,
    Multiclass,
    Multilabel,
}

/// Kind of regression model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RegressionType {
    Linear,
    Nonlinear,
    Robust,
}

/// Kind of unsupervised learning task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnsupervisedType {
    Clustering,
    DimensionalityReduction,
    DensityEstimation,
    OutlierDetection,
}

/// Kind of clustering algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClusteringType {
    Partitional,
    Hierarchical,
    DensityBased,
    GridBased,
}

/// Kind of dimensionality reduction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DimensionalityReductionType {
    Linear,
    Nonlinear,
    Manifold,
}

/// How an ensemble combines its base estimators.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnsembleType {
    Bagging,
    Boosting,
    Voting,
    Stacking,
}

/// Neural network architecture kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NetworkType {
    Feedforward,
    Convolutional,
    Recurrent,
    Transformer,
}
988
989    /// Blanket implementations for automatic trait family membership
990    impl<T, X, Y, Output> SupervisedLearningFamily<X, Y, Output> for T
991    where
992        T: CoreMLFamily + Fit<X, Y> + Predict<X, Output> + Score<X, Y>,
993        T::Fitted: Predict<X, Output> + Score<X, Y>,
994    {
995        fn learning_type(&self) -> SupervisedType {
996            // Default implementation - should be overridden
997            SupervisedType::Classification
998        }
999    }
1000
1001    impl<T, X, Labels, Probabilities> ClassificationFamily<X, Labels, Probabilities> for T
1002    where
1003        T: SupervisedLearningFamily<X, Labels, Labels> + PredictProba<X, Probabilities>,
1004        T::Fitted: Predict<X, Labels> + PredictProba<X, Probabilities> + Score<X, Labels>,
1005    {
1006        fn classification_type(&self) -> ClassificationType {
1007            // Default implementation - should be overridden
1008            ClassificationType::Binary
1009        }
1010    }
1011
1012    impl<T, X, Y> RegressionFamily<X, Y> for T
1013    where
1014        T: SupervisedLearningFamily<X, Y, Y> + Score<X, Y>,
1015        T::Fitted: Predict<X, Y> + Score<X, Y>,
1016    {
1017        fn regression_type(&self) -> RegressionType {
1018            // Default implementation - should be overridden
1019            RegressionType::Linear
1020        }
1021    }
1022
1023    impl<T, X> UnsupervisedLearningFamily<X> for T
1024    where
1025        T: CoreMLFamily + Transform<X>,
1026    {
1027        fn unsupervised_type(&self) -> UnsupervisedType {
1028            // Default implementation - should be overridden
1029            UnsupervisedType::Clustering
1030        }
1031    }
1032
1033    impl<T, X> ClusteringFamily<X> for T
1034    where
1035        T: UnsupervisedLearningFamily<X> + Cluster<X>,
1036    {
1037        fn clustering_type(&self) -> ClusteringType {
1038            // Default implementation - should be overridden
1039            ClusteringType::Partitional
1040        }
1041    }
1042
1043    impl<T, X> DimensionalityReductionFamily<X> for T
1044    where
1045        T: UnsupervisedLearningFamily<X> + FitTransform<X, (), X>,
1046    {
1047        fn reduction_type(&self) -> DimensionalityReductionType {
1048            // Default implementation - should be overridden
1049            DimensionalityReductionType::Linear
1050        }
1051
1052        fn target_dimensions(&self) -> Option<usize> {
1053            None
1054        }
1055    }
1056}
1057
1058/// Advanced trait combinations for specialized use cases
1059pub mod specialized {
1060    use super::*;
1061
1062    pub trait HybridLearner<X, Y, Output>:
1063        Fit<X, Y> + PartialFit<X, Y> + Predict<X, Output>
1064    where
1065        Self::Fitted: Predict<X, Output> + PartialFit<X, Y>,
1066    {
1067        fn set_learning_mode(&mut self, online: bool);
1068
1069        fn is_online_mode(&self) -> bool;
1070    }
1071
1072    /// Trait for interpretable models
1073    pub trait InterpretableModel<X, Y, Output> {
1074        /// Feature importance type
1075        type Importance;
1076
1077        /// Get feature importance scores
1078        fn feature_importance(&self) -> Result<Self::Importance>;
1079
1080        /// Get model explanation for a prediction
1081        fn explain_prediction(&self, input: &X) -> Result<String>;
1082
1083        /// Get global model explanation
1084        fn explain_model(&self) -> Result<String>;
1085    }
1086
1087    /// Trait for models with confidence estimation
1088    pub trait ConfidenceModel<X, Output> {
1089        /// Confidence score type
1090        type Confidence: FloatBounds;
1091
1092        /// Predict with confidence scores
1093        fn predict_with_confidence(&self, x: &X) -> Result<(Output, Vec<Self::Confidence>)>;
1094
1095        /// Get prediction uncertainty
1096        fn prediction_uncertainty(&self, x: &X) -> Result<Self::Confidence>;
1097    }
1098
1099    /// Trait for models that support differential privacy
1100    pub trait PrivacyPreservingModel<X, Y> {
1101        /// Privacy budget type
1102        type PrivacyBudget: FloatBounds;
1103
1104        /// Set privacy parameters
1105        fn set_privacy_budget(&mut self, budget: Self::PrivacyBudget);
1106
1107        /// Get remaining privacy budget
1108        fn remaining_privacy_budget(&self) -> Self::PrivacyBudget;
1109
1110        /// Check if operation is within privacy budget
1111        fn is_privacy_safe(&self, operation_cost: Self::PrivacyBudget) -> bool;
1112    }
1113}