sklears_dummy/
fluent_api.rs1use crate::dummy_classifier::{DummyClassifier, Strategy as ClassifierStrategy};
7use crate::dummy_regressor::{DummyRegressor, Strategy as RegressorStrategy};
8use scirs2_core::ndarray::Array1;
9use sklears_core::types::Float;
10
/// Namespace type for the ready-made baseline configurations.
///
/// The type carries no data: its associated functions return pre-populated
/// [`ClassifierConfig`] or [`RegressorConfig`] values. Because it is a
/// stateless marker, it implements the full set of cheap common traits
/// (`Copy`, `Default`, equality and hashing) in addition to `Debug`/`Clone`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ConfigPresets;
14
15impl ConfigPresets {
16 pub fn imbalanced_classification() -> ClassifierConfig {
18 ClassifierConfig::new()
19 .strategy(ClassifierStrategy::MostFrequent)
20 .with_description("Optimized for imbalanced datasets")
21 }
22
23 pub fn balanced_multiclass() -> ClassifierConfig {
25 ClassifierConfig::new()
26 .strategy(ClassifierStrategy::Stratified)
27 .with_description("Balanced multiclass classification")
28 }
29
30 pub fn uncertainty_aware_classification() -> ClassifierConfig {
32 ClassifierConfig::new()
33 .strategy(ClassifierStrategy::Bayesian)
34 .with_description("Provides uncertainty estimates")
35 }
36
37 pub fn time_series_forecasting() -> RegressorConfig {
39 RegressorConfig::new()
40 .strategy(RegressorStrategy::SeasonalNaive(12))
41 .with_description("Time series forecasting baseline")
42 }
43
44 pub fn high_variance_regression() -> RegressorConfig {
46 RegressorConfig::new()
47 .strategy(RegressorStrategy::Median)
48 .with_description("Robust to high variance and outliers")
49 }
50
51 pub fn probabilistic_regression() -> RegressorConfig {
53 RegressorConfig::new()
54 .strategy(RegressorStrategy::Normal {
55 mean: None,
56 std: None,
57 })
58 .with_description("Provides probabilistic predictions")
59 }
60
61 pub fn competition_baseline() -> RegressorConfig {
63 RegressorConfig::new()
64 .strategy(RegressorStrategy::Auto)
65 .with_description("Adaptive baseline for competitions")
66 }
67
68 pub fn streaming_baseline() -> RegressorConfig {
70 RegressorConfig::new()
71 .strategy(RegressorStrategy::Mean)
72 .with_description("Suitable for streaming scenarios")
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct ClassifierConfig {
79 strategy: ClassifierStrategy,
80 random_state: Option<u64>,
81 constant: Option<i32>,
82 bayesian_alpha: Option<Array1<Float>>,
83 description: Option<String>,
84}
85
86impl ClassifierConfig {
87 pub fn new() -> Self {
89 Self {
90 strategy: ClassifierStrategy::Auto,
91 random_state: None,
92 constant: None,
93 bayesian_alpha: None,
94 description: None,
95 }
96 }
97
98 pub fn strategy(mut self, strategy: ClassifierStrategy) -> Self {
100 self.strategy = strategy;
101 self
102 }
103
104 pub fn random_state(mut self, seed: u64) -> Self {
106 self.random_state = Some(seed);
107 self
108 }
109
110 pub fn constant(mut self, value: i32) -> Self {
112 self.constant = Some(value);
113 self
114 }
115
116 pub fn bayesian_prior(mut self, alpha: Array1<Float>) -> Self {
118 self.bayesian_alpha = Some(alpha);
119 self
120 }
121
122 pub fn with_description<S: Into<String>>(mut self, description: S) -> Self {
124 self.description = Some(description.into());
125 self
126 }
127
128 pub fn reproducible(self) -> Self {
130 self.random_state(42)
131 }
132
133 pub fn fast_mode(self) -> Self {
135 self.strategy(ClassifierStrategy::MostFrequent)
136 }
137
138 pub fn balanced_mode(self) -> Self {
140 self.strategy(ClassifierStrategy::Stratified)
141 }
142
143 pub fn uncertainty_mode(self) -> Self {
145 self.strategy(ClassifierStrategy::Bayesian)
146 }
147
148 pub fn build(self) -> DummyClassifier {
150 let mut classifier = DummyClassifier::new(self.strategy);
151
152 if let Some(seed) = self.random_state {
153 classifier = classifier.with_random_state(seed);
154 }
155
156 if let Some(constant) = self.constant {
157 classifier = classifier.with_constant(constant);
158 }
159
160 if let Some(alpha) = self.bayesian_alpha {
161 classifier = classifier.with_bayesian_prior(alpha);
162 }
163
164 classifier
165 }
166
167 pub fn description(&self) -> Option<&str> {
169 self.description.as_deref()
170 }
171}
172
173impl Default for ClassifierConfig {
174 fn default() -> Self {
175 Self::new()
176 }
177}
178
179#[derive(Debug, Clone)]
181pub struct RegressorConfig {
182 strategy: RegressorStrategy,
183 random_state: Option<u64>,
184 constant: Option<Float>,
185 description: Option<String>,
186}
187
188impl RegressorConfig {
189 pub fn new() -> Self {
191 Self {
192 strategy: RegressorStrategy::Auto,
193 random_state: None,
194 constant: None,
195 description: None,
196 }
197 }
198
199 pub fn strategy(mut self, strategy: RegressorStrategy) -> Self {
201 self.strategy = strategy;
202 self
203 }
204
205 pub fn random_state(mut self, seed: u64) -> Self {
207 self.random_state = Some(seed);
208 self
209 }
210
211 pub fn constant(mut self, value: Float) -> Self {
213 self.constant = Some(value);
214 self
215 }
216
217 pub fn with_description<S: Into<String>>(mut self, description: S) -> Self {
219 self.description = Some(description.into());
220 self
221 }
222
223 pub fn reproducible(self) -> Self {
225 self.random_state(42)
226 }
227
228 pub fn fast_mode(self) -> Self {
230 self.strategy(RegressorStrategy::Mean)
231 }
232
233 pub fn robust_mode(self) -> Self {
235 self.strategy(RegressorStrategy::Median)
236 }
237
238 pub fn probabilistic_mode(self) -> Self {
240 self.strategy(RegressorStrategy::Normal {
241 mean: None,
242 std: None,
243 })
244 }
245
246 pub fn time_series_mode(self) -> Self {
248 self.strategy(RegressorStrategy::SeasonalNaive(12))
249 }
250
251 pub fn build(self) -> DummyRegressor {
253 let mut regressor = DummyRegressor::new(self.strategy);
254
255 if let Some(seed) = self.random_state {
256 regressor = regressor.with_random_state(seed);
257 }
258
259 if let Some(constant) = self.constant {
260 regressor = regressor.with_constant(constant);
261 }
262
263 regressor
264 }
265
266 pub fn description(&self) -> Option<&str> {
268 self.description.as_deref()
269 }
270}
271
272impl Default for RegressorConfig {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
/// Extension point for attaching a preprocessing step to a value.
///
/// `T` is the type the preprocessing callable both receives and returns.
/// No implementation of this trait is visible in this file.
pub trait PreprocessingChain<T> {
    /// Consumes `self` and returns a value of the same type, given a
    /// `preprocessor` callable that maps a `T` to a `T`. What is done with
    /// the callable is left to each implementor.
    fn with_preprocessing<F: Fn(T) -> T>(self, preprocessor: F) -> Self;
}
285
286pub trait ClassifierFluentExt {
288 fn configure() -> ClassifierConfig;
290
291 fn for_imbalanced_data() -> DummyClassifier;
293 fn for_balanced_multiclass() -> DummyClassifier;
294 fn for_uncertainty_estimation() -> DummyClassifier;
295 fn for_fast_baseline() -> DummyClassifier;
296}
297
298impl ClassifierFluentExt for DummyClassifier {
299 fn configure() -> ClassifierConfig {
300 ClassifierConfig::new()
301 }
302
303 fn for_imbalanced_data() -> DummyClassifier {
304 ConfigPresets::imbalanced_classification().build()
305 }
306
307 fn for_balanced_multiclass() -> DummyClassifier {
308 ConfigPresets::balanced_multiclass().build()
309 }
310
311 fn for_uncertainty_estimation() -> DummyClassifier {
312 ConfigPresets::uncertainty_aware_classification().build()
313 }
314
315 fn for_fast_baseline() -> DummyClassifier {
316 ClassifierConfig::new().fast_mode().build()
317 }
318}
319
320pub trait RegressorFluentExt {
322 fn configure() -> RegressorConfig;
324
325 fn for_time_series() -> DummyRegressor;
327 fn for_high_variance() -> DummyRegressor;
328 fn for_probabilistic() -> DummyRegressor;
329 fn for_competition() -> DummyRegressor;
330 fn for_streaming() -> DummyRegressor;
331}
332
333impl RegressorFluentExt for DummyRegressor {
334 fn configure() -> RegressorConfig {
335 RegressorConfig::new()
336 }
337
338 fn for_time_series() -> DummyRegressor {
339 ConfigPresets::time_series_forecasting().build()
340 }
341
342 fn for_high_variance() -> DummyRegressor {
343 ConfigPresets::high_variance_regression().build()
344 }
345
346 fn for_probabilistic() -> DummyRegressor {
347 ConfigPresets::probabilistic_regression().build()
348 }
349
350 fn for_competition() -> DummyRegressor {
351 ConfigPresets::competition_baseline().build()
352 }
353
354 fn for_streaming() -> DummyRegressor {
355 ConfigPresets::streaming_baseline().build()
356 }
357}
358
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr1;

    /// A classifier configuration keeps its description and hands the
    /// strategy and seed over to the built classifier.
    #[test]
    fn test_classifier_config_builder() {
        let builder = ClassifierConfig::new()
            .strategy(ClassifierStrategy::MostFrequent)
            .random_state(42)
            .with_description("Test configuration");
        assert_eq!(builder.description(), Some("Test configuration"));

        let built = builder.build();
        assert_eq!(built.strategy, ClassifierStrategy::MostFrequent);
        assert_eq!(built.random_state, Some(42));
    }

    /// A regressor configuration with a constant set builds a regressor
    /// whose strategy is `Constant` with that value.
    #[test]
    fn test_regressor_config_builder() {
        let builder = RegressorConfig::new()
            .strategy(RegressorStrategy::Mean)
            .random_state(123)
            .constant(5.0)
            .with_description("Test regressor");
        assert_eq!(builder.description(), Some("Test regressor"));

        let built = builder.build();
        assert_eq!(built.strategy, RegressorStrategy::Constant(5.0));
        assert_eq!(built.random_state, Some(123));
    }

    /// The extension-trait shortcuts produce estimators with the preset
    /// strategies.
    #[test]
    fn test_fluent_extensions() {
        let imbalanced = DummyClassifier::for_imbalanced_data();
        assert_eq!(imbalanced.strategy, ClassifierStrategy::MostFrequent);

        let seasonal = DummyRegressor::for_time_series();
        assert!(matches!(
            seasonal.strategy,
            RegressorStrategy::SeasonalNaive(_)
        ));
    }

    /// Presets carry their documented description strings.
    #[test]
    fn test_config_presets() {
        let classifier_preset = ConfigPresets::imbalanced_classification();
        assert_eq!(
            classifier_preset.description(),
            Some("Optimized for imbalanced datasets")
        );

        let regressor_preset = ConfigPresets::probabilistic_regression();
        assert_eq!(
            regressor_preset.description(),
            Some("Provides probabilistic predictions")
        );
    }

    /// A long setter chain ends in a classifier carrying every setting.
    #[test]
    fn test_method_chaining() {
        let prior = arr1(&[1.0, 1.0, 1.0]);
        let built = ClassifierConfig::new()
            .strategy(ClassifierStrategy::Bayesian)
            .reproducible()
            .bayesian_prior(prior)
            .with_description("Chained configuration")
            .build();

        assert_eq!(built.strategy, ClassifierStrategy::Bayesian);
        assert_eq!(built.random_state, Some(42));
        assert!(built.bayesian_alpha_.is_some());
    }

    /// Each mode shortcut selects its corresponding strategy.
    #[test]
    fn test_mode_configurations() {
        assert_eq!(
            ClassifierConfig::new().fast_mode().strategy,
            ClassifierStrategy::MostFrequent
        );
        assert_eq!(
            ClassifierConfig::new().balanced_mode().strategy,
            ClassifierStrategy::Stratified
        );
        assert_eq!(
            ClassifierConfig::new().uncertainty_mode().strategy,
            ClassifierStrategy::Bayesian
        );
    }
}