sklears_dummy/
fluent_api.rs

1//! Fluent API for baseline estimator configuration
2//!
3//! This module provides a comprehensive fluent API for configuring dummy estimators
4//! with method chaining, configuration presets, and streamlined parameter setting.
5
6use crate::dummy_classifier::{DummyClassifier, Strategy as ClassifierStrategy};
7use crate::dummy_regressor::{DummyRegressor, Strategy as RegressorStrategy};
8use scirs2_core::ndarray::Array1;
9use sklears_core::types::Float;
10
11/// Configuration presets for common use cases
12#[derive(Debug, Clone)]
13pub struct ConfigPresets;
14
15impl ConfigPresets {
16    /// Configuration for highly imbalanced datasets
17    pub fn imbalanced_classification() -> ClassifierConfig {
18        ClassifierConfig::new()
19            .strategy(ClassifierStrategy::MostFrequent)
20            .with_description("Optimized for imbalanced datasets")
21    }
22
23    /// Configuration for balanced multiclass classification
24    pub fn balanced_multiclass() -> ClassifierConfig {
25        ClassifierConfig::new()
26            .strategy(ClassifierStrategy::Stratified)
27            .with_description("Balanced multiclass classification")
28    }
29
30    /// Configuration for uncertainty-aware classification
31    pub fn uncertainty_aware_classification() -> ClassifierConfig {
32        ClassifierConfig::new()
33            .strategy(ClassifierStrategy::Bayesian)
34            .with_description("Provides uncertainty estimates")
35    }
36
37    /// Configuration for time series forecasting baselines
38    pub fn time_series_forecasting() -> RegressorConfig {
39        RegressorConfig::new()
40            .strategy(RegressorStrategy::SeasonalNaive(12))
41            .with_description("Time series forecasting baseline")
42    }
43
44    /// Configuration for high-variance regression data
45    pub fn high_variance_regression() -> RegressorConfig {
46        RegressorConfig::new()
47            .strategy(RegressorStrategy::Median)
48            .with_description("Robust to high variance and outliers")
49    }
50
51    /// Configuration for probabilistic regression
52    pub fn probabilistic_regression() -> RegressorConfig {
53        RegressorConfig::new()
54            .strategy(RegressorStrategy::Normal {
55                mean: None,
56                std: None,
57            })
58            .with_description("Provides probabilistic predictions")
59    }
60
61    /// Configuration for competition-grade baselines
62    pub fn competition_baseline() -> RegressorConfig {
63        RegressorConfig::new()
64            .strategy(RegressorStrategy::Auto)
65            .with_description("Adaptive baseline for competitions")
66    }
67
68    /// Configuration for streaming/online learning
69    pub fn streaming_baseline() -> RegressorConfig {
70        RegressorConfig::new()
71            .strategy(RegressorStrategy::Mean)
72            .with_description("Suitable for streaming scenarios")
73    }
74}
75
76/// Fluent configuration builder for DummyClassifier
77#[derive(Debug, Clone)]
78pub struct ClassifierConfig {
79    strategy: ClassifierStrategy,
80    random_state: Option<u64>,
81    constant: Option<i32>,
82    bayesian_alpha: Option<Array1<Float>>,
83    description: Option<String>,
84}
85
86impl ClassifierConfig {
87    /// Create a new configuration builder
88    pub fn new() -> Self {
89        Self {
90            strategy: ClassifierStrategy::Auto,
91            random_state: None,
92            constant: None,
93            bayesian_alpha: None,
94            description: None,
95        }
96    }
97
98    /// Set the prediction strategy
99    pub fn strategy(mut self, strategy: ClassifierStrategy) -> Self {
100        self.strategy = strategy;
101        self
102    }
103
104    /// Set random state for reproducible results
105    pub fn random_state(mut self, seed: u64) -> Self {
106        self.random_state = Some(seed);
107        self
108    }
109
110    /// Set constant value for constant strategy
111    pub fn constant(mut self, value: i32) -> Self {
112        self.constant = Some(value);
113        self
114    }
115
116    /// Set Bayesian prior parameters
117    pub fn bayesian_prior(mut self, alpha: Array1<Float>) -> Self {
118        self.bayesian_alpha = Some(alpha);
119        self
120    }
121
122    /// Add description for documentation
123    pub fn with_description<S: Into<String>>(mut self, description: S) -> Self {
124        self.description = Some(description.into());
125        self
126    }
127
128    /// Enable reproducible mode with fixed seed
129    pub fn reproducible(self) -> Self {
130        self.random_state(42)
131    }
132
133    /// Configure for fast predictions (minimal computation)
134    pub fn fast_mode(self) -> Self {
135        self.strategy(ClassifierStrategy::MostFrequent)
136    }
137
138    /// Configure for balanced predictions
139    pub fn balanced_mode(self) -> Self {
140        self.strategy(ClassifierStrategy::Stratified)
141    }
142
143    /// Configure for uncertainty quantification
144    pub fn uncertainty_mode(self) -> Self {
145        self.strategy(ClassifierStrategy::Bayesian)
146    }
147
148    /// Build the configured DummyClassifier
149    pub fn build(self) -> DummyClassifier {
150        let mut classifier = DummyClassifier::new(self.strategy);
151
152        if let Some(seed) = self.random_state {
153            classifier = classifier.with_random_state(seed);
154        }
155
156        if let Some(constant) = self.constant {
157            classifier = classifier.with_constant(constant);
158        }
159
160        if let Some(alpha) = self.bayesian_alpha {
161            classifier = classifier.with_bayesian_prior(alpha);
162        }
163
164        classifier
165    }
166
167    /// Get the configuration description
168    pub fn description(&self) -> Option<&str> {
169        self.description.as_deref()
170    }
171}
172
173impl Default for ClassifierConfig {
174    fn default() -> Self {
175        Self::new()
176    }
177}
178
179/// Fluent configuration builder for DummyRegressor
180#[derive(Debug, Clone)]
181pub struct RegressorConfig {
182    strategy: RegressorStrategy,
183    random_state: Option<u64>,
184    constant: Option<Float>,
185    description: Option<String>,
186}
187
188impl RegressorConfig {
189    /// Create a new configuration builder
190    pub fn new() -> Self {
191        Self {
192            strategy: RegressorStrategy::Auto,
193            random_state: None,
194            constant: None,
195            description: None,
196        }
197    }
198
199    /// Set the prediction strategy
200    pub fn strategy(mut self, strategy: RegressorStrategy) -> Self {
201        self.strategy = strategy;
202        self
203    }
204
205    /// Set random state for reproducible results
206    pub fn random_state(mut self, seed: u64) -> Self {
207        self.random_state = Some(seed);
208        self
209    }
210
211    /// Set constant value for constant strategy
212    pub fn constant(mut self, value: Float) -> Self {
213        self.constant = Some(value);
214        self
215    }
216
217    /// Add description for documentation
218    pub fn with_description<S: Into<String>>(mut self, description: S) -> Self {
219        self.description = Some(description.into());
220        self
221    }
222
223    /// Enable reproducible mode with fixed seed
224    pub fn reproducible(self) -> Self {
225        self.random_state(42)
226    }
227
228    /// Configure for fast predictions (minimal computation)
229    pub fn fast_mode(self) -> Self {
230        self.strategy(RegressorStrategy::Mean)
231    }
232
233    /// Configure for robust predictions (outlier resistant)
234    pub fn robust_mode(self) -> Self {
235        self.strategy(RegressorStrategy::Median)
236    }
237
238    /// Configure for probabilistic predictions
239    pub fn probabilistic_mode(self) -> Self {
240        self.strategy(RegressorStrategy::Normal {
241            mean: None,
242            std: None,
243        })
244    }
245
246    /// Configure for time series forecasting
247    pub fn time_series_mode(self) -> Self {
248        self.strategy(RegressorStrategy::SeasonalNaive(12))
249    }
250
251    /// Build the configured DummyRegressor
252    pub fn build(self) -> DummyRegressor {
253        let mut regressor = DummyRegressor::new(self.strategy);
254
255        if let Some(seed) = self.random_state {
256            regressor = regressor.with_random_state(seed);
257        }
258
259        if let Some(constant) = self.constant {
260            regressor = regressor.with_constant(constant);
261        }
262
263        regressor
264    }
265
266    /// Get the configuration description
267    pub fn description(&self) -> Option<&str> {
268        self.description.as_deref()
269    }
270}
271
272impl Default for RegressorConfig {
273    fn default() -> Self {
274        Self::new()
275    }
276}
277
/// Chaining hook for applying a `T -> T` preprocessing step.
///
/// Implementors consume `self`, apply `preprocessor` as they see fit, and hand
/// back a value of the same type so further calls can be chained.
pub trait PreprocessingChain<T> {
    /// Apply `preprocessor` and return the (possibly transformed) receiver.
    fn with_preprocessing<F: Fn(T) -> T>(self, preprocessor: F) -> Self;
}
285
286/// Enhanced fluent API extensions for DummyClassifier
287pub trait ClassifierFluentExt {
288    /// Create a fluent configuration builder
289    fn configure() -> ClassifierConfig;
290
291    /// Quick setup for common scenarios
292    fn for_imbalanced_data() -> DummyClassifier;
293    fn for_balanced_multiclass() -> DummyClassifier;
294    fn for_uncertainty_estimation() -> DummyClassifier;
295    fn for_fast_baseline() -> DummyClassifier;
296}
297
298impl ClassifierFluentExt for DummyClassifier {
299    fn configure() -> ClassifierConfig {
300        ClassifierConfig::new()
301    }
302
303    fn for_imbalanced_data() -> DummyClassifier {
304        ConfigPresets::imbalanced_classification().build()
305    }
306
307    fn for_balanced_multiclass() -> DummyClassifier {
308        ConfigPresets::balanced_multiclass().build()
309    }
310
311    fn for_uncertainty_estimation() -> DummyClassifier {
312        ConfigPresets::uncertainty_aware_classification().build()
313    }
314
315    fn for_fast_baseline() -> DummyClassifier {
316        ClassifierConfig::new().fast_mode().build()
317    }
318}
319
320/// Enhanced fluent API extensions for DummyRegressor
321pub trait RegressorFluentExt {
322    /// Create a fluent configuration builder
323    fn configure() -> RegressorConfig;
324
325    /// Quick setup for common scenarios
326    fn for_time_series() -> DummyRegressor;
327    fn for_high_variance() -> DummyRegressor;
328    fn for_probabilistic() -> DummyRegressor;
329    fn for_competition() -> DummyRegressor;
330    fn for_streaming() -> DummyRegressor;
331}
332
333impl RegressorFluentExt for DummyRegressor {
334    fn configure() -> RegressorConfig {
335        RegressorConfig::new()
336    }
337
338    fn for_time_series() -> DummyRegressor {
339        ConfigPresets::time_series_forecasting().build()
340    }
341
342    fn for_high_variance() -> DummyRegressor {
343        ConfigPresets::high_variance_regression().build()
344    }
345
346    fn for_probabilistic() -> DummyRegressor {
347        ConfigPresets::probabilistic_regression().build()
348    }
349
350    fn for_competition() -> DummyRegressor {
351        ConfigPresets::competition_baseline().build()
352    }
353
354    fn for_streaming() -> DummyRegressor {
355        ConfigPresets::streaming_baseline().build()
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    #[test]
    fn test_classifier_config_builder() {
        let config = ClassifierConfig::new()
            .strategy(ClassifierStrategy::MostFrequent)
            .random_state(42)
            .with_description("Test configuration");

        assert_eq!(config.description(), Some("Test configuration"));
        let classifier = config.build();
        assert_eq!(classifier.strategy, ClassifierStrategy::MostFrequent);
        assert_eq!(classifier.random_state, Some(42));
    }

    #[test]
    fn test_regressor_config_builder() {
        let config = RegressorConfig::new()
            .strategy(RegressorStrategy::Mean)
            .random_state(123)
            .constant(5.0)
            .with_description("Test regressor");

        assert_eq!(config.description(), Some("Test regressor"));
        let regressor = config.build();
        // A configured constant takes over the strategy in the built regressor.
        assert_eq!(regressor.strategy, RegressorStrategy::Constant(5.0));
        assert_eq!(regressor.random_state, Some(123));
    }

    #[test]
    fn test_fluent_extensions() {
        let classifier = DummyClassifier::for_imbalanced_data();
        assert_eq!(classifier.strategy, ClassifierStrategy::MostFrequent);

        let regressor = DummyRegressor::for_time_series();
        assert!(matches!(
            regressor.strategy,
            RegressorStrategy::SeasonalNaive(_)
        ));

        let fast = DummyClassifier::for_fast_baseline();
        assert_eq!(fast.strategy, ClassifierStrategy::MostFrequent);

        let robust = DummyRegressor::for_high_variance();
        assert_eq!(robust.strategy, RegressorStrategy::Median);
    }

    #[test]
    fn test_config_presets() {
        let config = ConfigPresets::imbalanced_classification();
        assert_eq!(
            config.description(),
            Some("Optimized for imbalanced datasets")
        );

        let config = ConfigPresets::probabilistic_regression();
        assert_eq!(
            config.description(),
            Some("Provides probabilistic predictions")
        );
    }

    #[test]
    fn test_preset_strategies() {
        assert_eq!(
            ConfigPresets::imbalanced_classification().strategy,
            ClassifierStrategy::MostFrequent
        );
        assert_eq!(
            ConfigPresets::balanced_multiclass().strategy,
            ClassifierStrategy::Stratified
        );
        assert_eq!(
            ConfigPresets::uncertainty_aware_classification().strategy,
            ClassifierStrategy::Bayesian
        );
        assert_eq!(
            ConfigPresets::time_series_forecasting().strategy,
            RegressorStrategy::SeasonalNaive(12)
        );
        assert_eq!(
            ConfigPresets::high_variance_regression().strategy,
            RegressorStrategy::Median
        );
        assert_eq!(
            ConfigPresets::probabilistic_regression().strategy,
            RegressorStrategy::Normal {
                mean: None,
                std: None,
            }
        );
        assert_eq!(
            ConfigPresets::competition_baseline().strategy,
            RegressorStrategy::Auto
        );
        assert_eq!(
            ConfigPresets::streaming_baseline().strategy,
            RegressorStrategy::Mean
        );
    }

    #[test]
    fn test_method_chaining() {
        let classifier = ClassifierConfig::new()
            .strategy(ClassifierStrategy::Bayesian)
            .reproducible()
            .bayesian_prior(arr1(&[1.0, 1.0, 1.0]))
            .with_description("Chained configuration")
            .build();

        assert_eq!(classifier.strategy, ClassifierStrategy::Bayesian);
        assert_eq!(classifier.random_state, Some(42));
        assert!(classifier.bayesian_alpha_.is_some());
    }

    #[test]
    fn test_mode_configurations() {
        let fast_config = ClassifierConfig::new().fast_mode();
        assert_eq!(fast_config.strategy, ClassifierStrategy::MostFrequent);

        let balanced_config = ClassifierConfig::new().balanced_mode();
        assert_eq!(balanced_config.strategy, ClassifierStrategy::Stratified);

        let uncertainty_config = ClassifierConfig::new().uncertainty_mode();
        assert_eq!(uncertainty_config.strategy, ClassifierStrategy::Bayesian);
    }

    #[test]
    fn test_regressor_mode_configurations() {
        let fast_config = RegressorConfig::new().fast_mode();
        assert_eq!(fast_config.strategy, RegressorStrategy::Mean);

        let robust_config = RegressorConfig::new().robust_mode();
        assert_eq!(robust_config.strategy, RegressorStrategy::Median);

        let probabilistic_config = RegressorConfig::new().probabilistic_mode();
        assert_eq!(
            probabilistic_config.strategy,
            RegressorStrategy::Normal {
                mean: None,
                std: None,
            }
        );

        let time_series_config = RegressorConfig::new().time_series_mode();
        assert_eq!(
            time_series_config.strategy,
            RegressorStrategy::SeasonalNaive(12)
        );
    }

    #[test]
    fn test_defaults_match_new() {
        let classifier_config = ClassifierConfig::default();
        assert_eq!(classifier_config.strategy, ClassifierStrategy::Auto);
        assert_eq!(classifier_config.description(), None);

        let regressor_config = RegressorConfig::default();
        assert_eq!(regressor_config.strategy, RegressorStrategy::Auto);
        assert_eq!(regressor_config.description(), None);
    }

    #[test]
    fn test_regressor_reproducible_seed() {
        let regressor = RegressorConfig::new().fast_mode().reproducible().build();
        assert_eq!(regressor.random_state, Some(42));
    }
}