1use std::marker::PhantomData;
12
13use sklears_core::{
14 error::{Result, SklearsError},
15 traits::Estimator,
16 types::Float,
17};
18
19use crate::{LinearRegression, LinearRegressionConfig, Penalty, Solver};
20
21#[cfg(feature = "logistic-regression")]
22use crate::{LogisticRegression, LogisticRegressionConfig};
23
/// Ready-made configuration profiles understood by the enhanced builders.
///
/// A preset only names a profile; the concrete solver, iteration budget and
/// regularization it maps to are chosen by each builder's `apply_preset`
/// method and differ between linear and logistic regression.
///
/// The enum is field-less and `Copy`, so besides `PartialEq` it also derives
/// `Eq` and `Hash`. That makes presets usable as `HashMap` / `HashSet` keys
/// and satisfies clippy's `derive_partial_eq_without_eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelPreset {
    /// Smallest iteration budget of all presets (100 iterations in both builders).
    Quick,
    /// Mid-sized iteration budget (1000) with L2 regularization.
    Balanced,
    /// Largest iteration budget plus 5-fold cross-validation.
    HighAccuracy,
    /// L1-regularized profile with a 2000-iteration budget.
    Robust,
    /// Profile named for constrained-memory use; the solver picked for it
    /// differs per builder.
    MemoryEfficient,
    /// Elastic-net regularization with cross-validation and early stopping enabled.
    Production,
}
34
/// Marker traits mirroring the builder typestates.
///
/// They carry no methods; the builders implement them for the matching
/// state so generic code can require a particular stage via a trait bound.
pub mod validation {
    /// Implemented by builders that are in the `Configured` state
    /// (the state returned by `apply_preset`).
    pub trait Configured {}

    /// Implemented by builders whose latest state transition was `penalty(..)`.
    pub trait WithRegularization {}

    /// Implemented by builders whose latest state transition was `solver(..)`.
    pub trait WithSolver {}
}
46
47#[derive(Debug, Clone)]
49pub struct EnhancedLinearRegressionBuilder<State = Unconfigured> {
50 config: LinearRegressionConfig,
51 validation_config: ValidationConfig,
52 _state: PhantomData<State>,
53}
54
/// Fluent, typestate-tagged builder for `LogisticRegression`.
///
/// Only compiled with the `logistic-regression` feature. The `State`
/// parameter is a zero-sized tag recording the most recent state-changing
/// call; it defaults to `Unconfigured`.
#[cfg(feature = "logistic-regression")]
#[derive(Clone, Debug)]
pub struct EnhancedLogisticRegressionBuilder<State = Unconfigured> {
    /// Model hyper-parameters accumulated by the setter methods.
    config: LogisticRegressionConfig,
    /// Evaluation-related settings (CV folds, split, early stopping, seed).
    validation_config: ValidationConfig,
    /// Compile-time state tag; occupies no space at runtime.
    _state: PhantomData<State>,
}
63
/// Typestate tag: a freshly created builder (the default state).
#[derive(Clone, Copy, Debug)]
pub struct Unconfigured;
67
/// Typestate tag: a preset has been applied to the builder.
#[derive(Clone, Copy, Debug)]
pub struct Configured;
71
/// Typestate tag: the builder's latest state transition was `penalty(..)`.
#[derive(Clone, Copy, Debug)]
pub struct WithRegularization;
75
/// Typestate tag: the builder's latest state transition was `solver(..)`.
#[derive(Clone, Copy, Debug)]
pub struct WithSolver;
79
80#[derive(Debug, Clone, Default)]
82pub struct ValidationConfig {
83 pub cross_validation_folds: Option<usize>,
85 pub validation_split: Option<Float>,
87 pub early_stopping: bool,
89 pub random_state: Option<u64>,
91}
92
93impl Default for EnhancedLinearRegressionBuilder<Unconfigured> {
95 fn default() -> Self {
96 Self {
97 config: LinearRegressionConfig::default(),
98 validation_config: ValidationConfig::default(),
99 _state: PhantomData,
100 }
101 }
102}
103
104impl EnhancedLinearRegressionBuilder<Unconfigured> {
105 pub fn new() -> Self {
107 Self::default()
108 }
109
110 pub fn with_preset(preset: ModelPreset) -> EnhancedLinearRegressionBuilder<Configured> {
112 let builder = Self::new();
113 builder.apply_preset(preset)
114 }
115
116 pub fn apply_preset(
118 mut self,
119 preset: ModelPreset,
120 ) -> EnhancedLinearRegressionBuilder<Configured> {
121 match preset {
122 ModelPreset::Quick => {
123 self.config.solver = Solver::Normal;
124 self.config.fit_intercept = true;
125 self.config.max_iter = 100;
126 }
127 ModelPreset::Balanced => {
128 self.config.solver = Solver::Auto;
129 self.config.fit_intercept = true;
130 self.config.max_iter = 1000;
131 self.config.penalty = Penalty::L2(0.1);
132 }
133 ModelPreset::HighAccuracy => {
134 self.config.solver = Solver::Normal;
135 self.config.fit_intercept = true;
136 self.config.max_iter = 5000;
137 self.config.penalty = Penalty::L2(0.01);
138 self.validation_config.cross_validation_folds = Some(5);
139 }
140 ModelPreset::Robust => {
141 self.config.solver = Solver::Auto;
142 self.config.fit_intercept = true;
143 self.config.penalty = Penalty::L1(0.1);
144 self.config.max_iter = 2000;
145 }
146 ModelPreset::MemoryEfficient => {
147 self.config.solver = Solver::Normal;
148 self.config.fit_intercept = true;
149 self.config.max_iter = 500;
150 }
151 ModelPreset::Production => {
152 self.config.solver = Solver::Auto;
153 self.config.fit_intercept = true;
154 self.config.penalty = Penalty::ElasticNet {
155 l1_ratio: 0.5,
156 alpha: 0.1,
157 };
158 self.config.max_iter = 3000;
159 self.validation_config.cross_validation_folds = Some(10);
160 self.validation_config.early_stopping = true;
161 }
162 }
163
164 EnhancedLinearRegressionBuilder {
165 config: self.config,
166 validation_config: self.validation_config,
167 _state: PhantomData,
168 }
169 }
170}
171
172impl<State> EnhancedLinearRegressionBuilder<State> {
173 pub fn solver(mut self, solver: Solver) -> EnhancedLinearRegressionBuilder<WithSolver> {
175 self.config.solver = solver;
176 EnhancedLinearRegressionBuilder {
177 config: self.config,
178 validation_config: self.validation_config,
179 _state: PhantomData,
180 }
181 }
182
183 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
185 self.config.fit_intercept = fit_intercept;
186 self
187 }
188
189 pub fn penalty(
191 mut self,
192 penalty: Penalty,
193 ) -> EnhancedLinearRegressionBuilder<WithRegularization> {
194 self.config.penalty = penalty;
195 EnhancedLinearRegressionBuilder {
196 config: self.config,
197 validation_config: self.validation_config,
198 _state: PhantomData,
199 }
200 }
201
202 pub fn max_iter(mut self, max_iter: usize) -> Self {
204 self.config.max_iter = max_iter;
205 self
206 }
207
208 pub fn tolerance(mut self, tol: f64) -> Self {
210 self.config.tol = tol;
211 self
212 }
213
214 pub fn warm_start(mut self, warm_start: bool) -> Self {
216 self.config.warm_start = warm_start;
217 self
218 }
219
220 pub fn with_cross_validation(mut self, folds: usize) -> Self {
222 self.validation_config.cross_validation_folds = Some(folds);
223 self
224 }
225
226 pub fn with_validation_split(mut self, split: Float) -> Self {
228 self.validation_config.validation_split = Some(split);
229 self
230 }
231
232 pub fn with_early_stopping(mut self) -> Self {
234 self.validation_config.early_stopping = true;
235 self
236 }
237
238 pub fn random_state(mut self, seed: u64) -> Self {
240 self.validation_config.random_state = Some(seed);
241 self
242 }
243
244 pub fn build(self) -> Result<LinearRegression> {
246 LinearRegression::new()
247 .penalty(self.config.penalty)
248 .solver(self.config.solver)
249 .fit_intercept(self.config.fit_intercept)
250 .max_iter(self.config.max_iter)
251 .warm_start(self.config.warm_start)
252 .validate_config()
253 }
254
255 pub fn config(&self) -> &LinearRegressionConfig {
257 &self.config
258 }
259
260 pub fn validation_config(&self) -> &ValidationConfig {
262 &self.validation_config
263 }
264}
265
#[cfg(feature = "logistic-regression")]
impl Default for EnhancedLogisticRegressionBuilder<Unconfigured> {
    /// Creates an unconfigured builder from the default model and validation configs.
    fn default() -> Self {
        let config = LogisticRegressionConfig::default();
        let validation_config = ValidationConfig::default();

        Self {
            config,
            validation_config,
            _state: PhantomData,
        }
    }
}
277
#[cfg(feature = "logistic-regression")]
impl EnhancedLogisticRegressionBuilder<Unconfigured> {
    /// Creates a builder holding the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Shorthand for `Self::new().apply_preset(preset)`.
    pub fn with_preset(preset: ModelPreset) -> EnhancedLogisticRegressionBuilder<Configured> {
        Self::new().apply_preset(preset)
    }

    /// Overlays the settings of `preset` on this builder and moves it to the
    /// `Configured` state.
    ///
    /// Every preset sets the solver, iteration budget, penalty and tolerance.
    /// The cross-validation fold count and early stopping are only
    /// overwritten by the presets that specify them; otherwise the value
    /// already stored on the builder is kept.
    pub fn apply_preset(
        mut self,
        preset: ModelPreset,
    ) -> EnhancedLogisticRegressionBuilder<Configured> {
        // One row per preset:
        // (solver, max_iter, penalty, tol, CV-fold override, enable early stopping).
        // `None` / `false` mean "leave the current value untouched".
        let (solver, max_iter, penalty, tol, cv_folds, early_stopping) = match preset {
            ModelPreset::Quick => (Solver::Lbfgs, 100, Penalty::L2(1.0), 1e-3, None, false),
            ModelPreset::Balanced => (Solver::Auto, 1000, Penalty::L2(1.0), 1e-4, None, false),
            ModelPreset::HighAccuracy => (
                Solver::Lbfgs,
                10000,
                Penalty::ElasticNet {
                    l1_ratio: 0.5,
                    alpha: 1.0,
                },
                1e-6,
                Some(5),
                false,
            ),
            ModelPreset::Robust => (Solver::Saga, 2000, Penalty::L1(1.0), 1e-4, None, false),
            ModelPreset::MemoryEfficient => {
                (Solver::Sag, 1000, Penalty::L2(1.0), 1e-3, None, false)
            }
            ModelPreset::Production => (
                Solver::Lbfgs,
                5000,
                Penalty::ElasticNet {
                    l1_ratio: 0.1,
                    alpha: 1.0,
                },
                1e-5,
                Some(5),
                true,
            ),
        };

        self.config.solver = solver;
        self.config.max_iter = max_iter;
        self.config.penalty = penalty;
        self.config.tol = tol;
        if let Some(folds) = cv_folds {
            self.validation_config.cross_validation_folds = Some(folds);
        }
        if early_stopping {
            self.validation_config.early_stopping = true;
        }

        // Same data, new typestate tag.
        let Self {
            config,
            validation_config,
            ..
        } = self;
        EnhancedLogisticRegressionBuilder {
            config,
            validation_config,
            _state: PhantomData,
        }
    }
}
351
#[cfg(feature = "logistic-regression")]
impl<State> EnhancedLogisticRegressionBuilder<State> {
    /// Re-tags the builder with another typestate, carrying all settings over.
    fn retag<Next>(self) -> EnhancedLogisticRegressionBuilder<Next> {
        let Self {
            config,
            validation_config,
            ..
        } = self;
        EnhancedLogisticRegressionBuilder {
            config,
            validation_config,
            _state: PhantomData,
        }
    }

    /// Sets the regularization penalty and moves the builder to the
    /// `WithRegularization` state (replacing any previous tag).
    pub fn penalty(
        mut self,
        penalty: Penalty,
    ) -> EnhancedLogisticRegressionBuilder<WithRegularization> {
        self.config.penalty = penalty;
        self.retag()
    }

    /// Selects the solver and moves the builder to the `WithSolver` state
    /// (replacing any previous tag).
    pub fn solver(mut self, solver: Solver) -> EnhancedLogisticRegressionBuilder<WithSolver> {
        self.config.solver = solver;
        self.retag()
    }

    /// Sets the maximum number of iterations; the state is unchanged.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Stores the convergence tolerance in the configuration.
    pub fn tolerance(mut self, tol: f64) -> Self {
        self.config.tol = tol;
        self
    }

    /// Chooses whether an intercept term is fitted; the state is unchanged.
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.config.fit_intercept = fit_intercept;
        self
    }

    /// Requests cross-validation with `folds` folds (stored as-is, not range-checked).
    pub fn with_cross_validation(mut self, folds: usize) -> Self {
        self.validation_config.cross_validation_folds = Some(folds);
        self
    }

    /// Records a validation split value (stored as-is, not range-checked).
    pub fn with_validation_split(mut self, split: Float) -> Self {
        self.validation_config.validation_split = Some(split);
        self
    }

    /// Turns the early-stopping flag on. There is no setter to turn it off again.
    pub fn with_early_stopping(mut self) -> Self {
        self.validation_config.early_stopping = true;
        self
    }

    /// Records the seed in both the model config and the validation config.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.config.random_state = Some(seed);
        self.validation_config.random_state = Some(seed);
        self
    }

    /// Consumes the builder and produces a `LogisticRegression`.
    ///
    /// Forwards penalty, solver, `max_iter` and `fit_intercept` to the
    /// model's own builder methods. This implementation always returns `Ok`.
    // NOTE(review): `config.tol`, `config.random_state` and the whole
    // `validation_config` are never read here, and no validation step runs
    // (unlike the linear builder) — confirm whether that is intended.
    pub fn build(self) -> Result<LogisticRegression> {
        let cfg = self.config;
        let model = LogisticRegression::new()
            .penalty(cfg.penalty)
            .solver(cfg.solver)
            .max_iter(cfg.max_iter)
            .fit_intercept(cfg.fit_intercept);
        Ok(model)
    }

    /// Read-only view of the model configuration accumulated so far.
    pub fn config(&self) -> &LogisticRegressionConfig {
        &self.config
    }

    /// Read-only view of the validation settings accumulated so far.
    pub fn validation_config(&self) -> &ValidationConfig {
        &self.validation_config
    }
}
439
440impl validation::Configured for EnhancedLinearRegressionBuilder<Configured> {}
442impl validation::WithRegularization for EnhancedLinearRegressionBuilder<WithRegularization> {}
443impl validation::WithSolver for EnhancedLinearRegressionBuilder<WithSolver> {}
444
// Tie each logistic-builder typestate to its marker trait; compiled only
// with the `logistic-regression` feature, like the builder itself.

/// A preset has been applied.
#[cfg(feature = "logistic-regression")]
impl validation::Configured for EnhancedLogisticRegressionBuilder<Configured> {}

/// The latest state transition was `penalty(..)`.
#[cfg(feature = "logistic-regression")]
impl validation::WithRegularization for EnhancedLogisticRegressionBuilder<WithRegularization> {}

/// The latest state transition was `solver(..)`.
#[cfg(feature = "logistic-regression")]
impl validation::WithSolver for EnhancedLogisticRegressionBuilder<WithSolver> {}
451
/// Consuming self-check for a configured model.
///
/// Implementors inspect their own configuration and either hand themselves
/// back unchanged or report why the configuration is unusable.
pub trait ModelValidation {
    /// Error type produced when validation fails.
    type Error;

    /// Validates the configuration, returning `self` on success.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor rejects the configuration.
    // The fully qualified `std::result::Result` is deliberate: this file
    // imports a one-parameter `Result` alias that would otherwise shadow it.
    fn validate_config(self) -> std::result::Result<Self, Self::Error>
    where
        Self: Sized;
}
461
462impl ModelValidation for LinearRegression {
463 type Error = SklearsError;
464
465 fn validate_config(self) -> std::result::Result<Self, Self::Error> {
466 match self.config().penalty {
468 Penalty::L1(_) | Penalty::ElasticNet { .. } => {
469 if matches!(self.config().solver, Solver::Normal) {
470 return Err(SklearsError::InvalidInput(
471 "Normal equations solver does not support L1 regularization. Use CoordinateDescent or other iterative solver.".to_string()
472 ));
473 }
474 }
475 _ => {}
476 }
477
478 if self.config().max_iter == 0 {
479 return Err(SklearsError::InvalidInput(
480 "max_iter must be greater than 0".to_string(),
481 ));
482 }
483
484 Ok(self)
485 }
486}
487
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear presets must build successfully and carry their solver/penalty through.
    #[test]
    fn test_enhanced_linear_regression_builder_presets() {
        let build = |preset: ModelPreset| {
            EnhancedLinearRegressionBuilder::with_preset(preset)
                .build()
                .unwrap()
        };

        let quick_model = build(ModelPreset::Quick);
        assert_eq!(quick_model.config().solver, Solver::Normal);

        let balanced_model = build(ModelPreset::Balanced);
        assert_eq!(balanced_model.config().solver, Solver::Auto);

        let production_model = build(ModelPreset::Production);
        assert!(matches!(
            production_model.config().penalty,
            Penalty::ElasticNet { .. }
        ));
    }

    /// Logistic presets must build successfully and carry their solver through.
    #[test]
    #[cfg(feature = "logistic-regression")]
    fn test_enhanced_logistic_regression_builder_presets() {
        let build = |preset: ModelPreset| {
            EnhancedLogisticRegressionBuilder::with_preset(preset)
                .build()
                .unwrap()
        };

        let quick_model = build(ModelPreset::Quick);
        assert_eq!(quick_model.config().solver, Solver::Lbfgs);

        let robust_model = build(ModelPreset::Robust);
        assert_eq!(robust_model.config().solver, Solver::Saga);
    }

    /// Values set through chained setters must reach the built model.
    #[test]
    fn test_builder_method_chaining() {
        let with_model_settings = EnhancedLinearRegressionBuilder::new()
            .solver(Solver::CoordinateDescent)
            .penalty(Penalty::L1(0.5))
            .max_iter(2000)
            .fit_intercept(false);
        let with_validation_settings = with_model_settings
            .with_cross_validation(5)
            .with_early_stopping();
        let model = with_validation_settings.build().unwrap();

        assert_eq!(model.config().solver, Solver::CoordinateDescent);
        assert!(matches!(model.config().penalty, Penalty::L1(_)));
        assert_eq!(model.config().max_iter, 2000);
        assert!(!model.config().fit_intercept);
    }

    /// Setter calls must be visible through the builder's `config()` accessor.
    #[test]
    #[cfg(feature = "logistic-regression")]
    fn test_fluent_api() {
        let builder = EnhancedLogisticRegressionBuilder::new()
            .penalty(Penalty::L2(2.0))
            .solver(Solver::Saga)
            .max_iter(1500)
            .tolerance(1e-5)
            .random_state(42);
        let cfg = builder.config();

        assert!(matches!(cfg.penalty, Penalty::L2(_)));
        assert_eq!(cfg.solver, Solver::Saga);
        assert_eq!(cfg.max_iter, 1500);
        assert_eq!(cfg.random_state, Some(42));
    }

    /// The normal-equations solver combined with an L1 penalty must be rejected.
    #[test]
    fn test_configuration_validation() {
        let invalid_builder = EnhancedLinearRegressionBuilder::new()
            .solver(Solver::Normal)
            .penalty(Penalty::L1(1.0));
        let result = invalid_builder.build();

        assert!(result.is_err());
    }

    /// Distinct presets must produce distinct settings.
    #[test]
    fn test_preset_configurations_differ() {
        let quick = EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Quick);
        let production = EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Production);

        let (quick_iters, production_iters) =
            (quick.config().max_iter, production.config().max_iter);
        assert_ne!(quick_iters, production_iters);

        let quick_folds = quick.validation_config().cross_validation_folds;
        let production_folds = production.validation_config().cross_validation_folds;
        assert_ne!(quick_folds, production_folds);
    }
}