1use crate::error::Result;
3use crate::types::Float;
4#[cfg(test)]
5use crate::validation::ValidationContext;
6use crate::validation::{ConfigValidation, Validate};
7
/// Configuration for a linear regression model.
///
/// Field constraints are enforced by the `Validate` impl below;
/// cross-field constraints by `ConfigValidation::validate_config`.
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Step size for iterative solvers (checked by `ml::validate_learning_rate`).
    pub learning_rate: Float,

    /// Regularization strength (checked by `ml::validate_regularization`).
    pub alpha: Float,

    /// Maximum number of training iterations (checked by `ml::validate_max_iter`).
    pub max_iter: usize,

    /// Convergence tolerance; must be positive and finite.
    pub tol: Float,

    /// Whether to fit an intercept term. Required to be `true` by the
    /// "cholesky" solver (see `validate_config`).
    pub fit_intercept: bool,

    /// Solver name; one of "auto", "svd", "cholesky", "lsqr",
    /// "sparse_cg", "sag", "saga".
    pub solver: String,
}
29
30impl Validate for LinearRegressionConfig {
31 fn validate(&self) -> Result<()> {
32 crate::validation::ml::validate_learning_rate(self.learning_rate)?;
34
35 crate::validation::ml::validate_regularization(self.alpha)?;
37
38 crate::validation::ml::validate_max_iter(self.max_iter)?;
40
41 crate::validation::ValidationRules::new("tol")
43 .add_rule(crate::validation::ValidationRule::Positive)
44 .add_rule(crate::validation::ValidationRule::Finite)
45 .validate_numeric(&self.tol)?;
46
47 crate::validation::ValidationRules::new("solver")
49 .add_rule(crate::validation::ValidationRule::OneOf(vec![
50 "auto".to_string(),
51 "svd".to_string(),
52 "cholesky".to_string(),
53 "lsqr".to_string(),
54 "sparse_cg".to_string(),
55 "sag".to_string(),
56 "saga".to_string(),
57 ]))
58 .validate_string(&self.solver)?;
59
60 Ok(())
61 }
62}
63
64impl Default for LinearRegressionConfig {
65 fn default() -> Self {
66 Self {
67 learning_rate: 0.01,
68 alpha: 1.0,
69 max_iter: 1000,
70 tol: 1e-4,
71 fit_intercept: true,
72 solver: "auto".to_string(),
73 }
74 }
75}
76
77impl ConfigValidation for LinearRegressionConfig {
78 fn validate_config(&self) -> Result<()> {
79 self.validate()?;
81
82 if self.solver == "cholesky" && !self.fit_intercept {
84 return Err(crate::error::SklearsError::InvalidParameter {
85 name: "solver".to_string(),
86 reason: "cholesky solver requires fit_intercept=true".to_string(),
87 });
88 }
89
90 Ok(())
91 }
92
93 fn get_warnings(&self) -> Vec<String> {
94 let mut warnings = Vec::new();
95
96 if self.learning_rate > 0.1 {
97 warnings
98 .push("Learning rate is quite high, consider using a smaller value".to_string());
99 }
100
101 if self.max_iter < 100 {
102 warnings.push("Maximum iterations is quite low, model may not converge".to_string());
103 }
104
105 warnings
106 }
107}
108
/// Configuration for the K-Means clustering algorithm.
///
/// Field constraints are enforced by the `Validate` impl below.
#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// Number of clusters to form (checked by `ml::validate_n_clusters`
    /// against an upper bound of 100).
    pub n_clusters: usize,

    /// Maximum number of iterations per run (checked by `ml::validate_max_iter`).
    pub max_iter: usize,

    /// Convergence tolerance; must be positive and finite.
    pub tol: Float,

    /// Centroid initialization strategy; one of "k-means++", "random", "custom".
    pub init: String,

    /// Number of initializations to try; must be positive.
    pub n_init: usize,

    /// Optional RNG seed; `None` presumably means non-deterministic seeding —
    /// NOTE(review): confirm against the estimator that consumes this config.
    pub random_state: Option<u64>,
}
130
131impl Validate for KMeansConfig {
132 fn validate(&self) -> Result<()> {
133 crate::validation::ml::validate_n_clusters(self.n_clusters, 100)?;
135
136 crate::validation::ml::validate_max_iter(self.max_iter)?;
138
139 crate::validation::ValidationRules::new("tol")
141 .add_rule(crate::validation::ValidationRule::Positive)
142 .add_rule(crate::validation::ValidationRule::Finite)
143 .validate_numeric(&self.tol)?;
144
145 crate::validation::ValidationRules::new("init")
147 .add_rule(crate::validation::ValidationRule::OneOf(vec![
148 "k-means++".to_string(),
149 "random".to_string(),
150 "custom".to_string(),
151 ]))
152 .validate_string(&self.init)?;
153
154 if self.n_init == 0 {
156 return Err(crate::error::SklearsError::InvalidParameter {
157 name: "n_init".to_string(),
158 reason: "must be positive".to_string(),
159 });
160 }
161
162 Ok(())
163 }
164}
165
166impl Default for KMeansConfig {
167 fn default() -> Self {
168 Self {
169 n_clusters: 8,
170 max_iter: 300,
171 tol: 1e-4,
172 init: "k-means++".to_string(),
173 n_init: 10,
174 random_state: None,
175 }
176 }
177}
178
179impl ConfigValidation for KMeansConfig {
180 fn validate_config(&self) -> Result<()> {
181 self.validate()?;
182
183 if self.n_clusters == 1 {
185 log::warn!("Using only 1 cluster - consider if clustering is necessary");
186 }
187
188 Ok(())
189 }
190}
191
/// Configuration for a multi-layer perceptron.
///
/// Field constraints are enforced by the `Validate` impl below;
/// cross-field constraints by `ConfigValidation::validate_config`.
#[derive(Debug, Clone)]
pub struct MLPConfig {
    /// Width of each hidden layer, in order; must contain at least one entry.
    pub hidden_layer_sizes: Vec<usize>,

    /// Step size for the optimizer (checked by `ml::validate_learning_rate`).
    pub learning_rate: Float,

    /// Maximum number of training iterations (checked by `ml::validate_max_iter`).
    pub max_iter: usize,

    /// Dropout rate; validated as a probability (checked by `ml::validate_probability`).
    pub dropout: Float,

    /// Mini-batch size; must be positive.
    pub batch_size: usize,

    /// Regularization strength (checked by `ml::validate_regularization`).
    pub alpha: Float,

    /// Activation function; one of "relu", "tanh", "sigmoid", "identity".
    pub activation: String,

    /// Optimizer name; one of "adam", "sgd", "lbfgs". The "lbfgs" solver is
    /// rejected for networks deeper than 3 layers (see `validate_config`).
    pub solver: String,
}
219
220impl Validate for MLPConfig {
221 fn validate(&self) -> Result<()> {
222 crate::validation::ValidationRules::new("hidden_layer_sizes")
224 .add_rule(crate::validation::ValidationRule::MinLength(1))
225 .validate_array(&self.hidden_layer_sizes)?;
226
227 crate::validation::ml::validate_learning_rate(self.learning_rate)?;
229
230 crate::validation::ml::validate_max_iter(self.max_iter)?;
232
233 crate::validation::ml::validate_probability(self.dropout)?;
235
236 if self.batch_size == 0 {
238 return Err(crate::error::SklearsError::InvalidParameter {
239 name: "batch_size".to_string(),
240 reason: "must be positive".to_string(),
241 });
242 }
243
244 crate::validation::ml::validate_regularization(self.alpha)?;
246
247 crate::validation::ValidationRules::new("activation")
249 .add_rule(crate::validation::ValidationRule::OneOf(vec![
250 "relu".to_string(),
251 "tanh".to_string(),
252 "sigmoid".to_string(),
253 "identity".to_string(),
254 ]))
255 .validate_string(&self.activation)?;
256
257 crate::validation::ValidationRules::new("solver")
259 .add_rule(crate::validation::ValidationRule::OneOf(vec![
260 "adam".to_string(),
261 "sgd".to_string(),
262 "lbfgs".to_string(),
263 ]))
264 .validate_string(&self.solver)?;
265
266 Ok(())
267 }
268}
269
270impl Default for MLPConfig {
271 fn default() -> Self {
272 Self {
273 hidden_layer_sizes: vec![100],
274 learning_rate: 0.001,
275 max_iter: 200,
276 dropout: 0.0,
277 batch_size: 32,
278 alpha: 1e-4,
279 activation: "relu".to_string(),
280 solver: "adam".to_string(),
281 }
282 }
283}
284
285impl ConfigValidation for MLPConfig {
286 fn validate_config(&self) -> Result<()> {
287 self.validate()?;
288
289 if self.solver == "lbfgs" && self.hidden_layer_sizes.len() > 3 {
291 return Err(crate::error::SklearsError::InvalidParameter {
292 name: "solver".to_string(),
293 reason: "lbfgs solver may be inefficient for deep networks".to_string(),
294 });
295 }
296
297 if self.batch_size > 1000 {
298 log::warn!("Large batch size may lead to poor generalization");
299 }
300
301 Ok(())
302 }
303
304 fn get_warnings(&self) -> Vec<String> {
305 let mut warnings = Vec::new();
306
307 if self.hidden_layer_sizes.iter().any(|&size| size > 1000) {
308 warnings.push("Very large hidden layers may cause overfitting".to_string());
309 }
310
311 if self.dropout > 0.5 {
312 warnings.push("High dropout rate may prevent learning".to_string());
313 }
314
315 warnings
316 }
317}
318
319pub struct CustomValidationExample {
321 pub param1: Float,
322 pub param2: usize,
323 pub dependent_param: Float,
324}
325
326impl Validate for CustomValidationExample {
327 fn validate(&self) -> Result<()> {
328 if self.param1 <= 0.0 {
330 return Err(crate::error::SklearsError::InvalidParameter {
331 name: "param1".to_string(),
332 reason: "must be positive".to_string(),
333 });
334 }
335
336 if self.param2 > 0 && self.dependent_param > self.param1 * 2.0 {
338 return Err(crate::error::SklearsError::InvalidParameter {
339 name: "dependent_param".to_string(),
340 reason: "cannot be more than twice param1 when param2 > 0".to_string(),
341 });
342 }
343
344 Ok(())
345 }
346}
347
// Note: the previous `#[allow(non_snake_case)]` on this module was unused
// (every test fn is snake_case) and has been removed. Tests now build
// invalid configs via struct-update syntax instead of mutating a default
// (avoids clippy's `field_reassign_with_default`).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_regression_config_validation() {
        // Defaults are self-consistent.
        assert!(LinearRegressionConfig::default().validate().is_ok());

        // A negative learning rate is rejected.
        let bad_learning_rate = LinearRegressionConfig {
            learning_rate: -0.1,
            ..LinearRegressionConfig::default()
        };
        assert!(bad_learning_rate.validate().is_err());

        // An unknown solver name is rejected.
        let bad_solver = LinearRegressionConfig {
            solver: "invalid_solver".to_string(),
            ..LinearRegressionConfig::default()
        };
        assert!(bad_solver.validate().is_err());
    }

    #[test]
    fn test_kmeans_config_validation() {
        // Defaults are self-consistent.
        assert!(KMeansConfig::default().validate().is_ok());

        // Zero clusters is rejected.
        let zero_clusters = KMeansConfig {
            n_clusters: 0,
            ..KMeansConfig::default()
        };
        assert!(zero_clusters.validate().is_err());

        // A negative tolerance is rejected.
        let negative_tol = KMeansConfig {
            tol: -1.0,
            ..KMeansConfig::default()
        };
        assert!(negative_tol.validate().is_err());
    }

    #[test]
    fn test_mlp_config_validation() {
        // Defaults are self-consistent.
        assert!(MLPConfig::default().validate().is_ok());

        // A network with no hidden layers is rejected.
        let no_layers = MLPConfig {
            hidden_layer_sizes: vec![],
            ..MLPConfig::default()
        };
        assert!(no_layers.validate().is_err());

        // A dropout rate above 1.0 is not a valid probability.
        let bad_dropout = MLPConfig {
            dropout: 1.5,
            ..MLPConfig::default()
        };
        assert!(bad_dropout.validate().is_err());
    }

    #[test]
    fn test_config_validation_trait() {
        let config = LinearRegressionConfig::default();

        assert!(config.validate_config().is_ok());
        // Defaults should not trigger any advisory warnings.
        assert!(config.get_warnings().is_empty());
    }

    #[test]
    fn test_validation_context() {
        let context =
            ValidationContext::new("LinearRegression", "fit").with_data_info(100, 5, "float64");

        let error = crate::error::SklearsError::InvalidParameter {
            name: "learning_rate".to_string(),
            reason: "must be positive".to_string(),
        };

        // The formatted message should embed estimator, operation, and data info.
        let formatted = context.format_error(&error);
        assert!(formatted.contains("LinearRegression"));
        assert!(formatted.contains("fit"));
        assert!(formatted.contains("100 samples"));
        assert!(formatted.contains("5 features"));
    }

    #[test]
    fn test_custom_validation() {
        // param2 == 0 disables the dependent_param cap.
        let unconstrained = CustomValidationExample {
            param1: 1.0,
            param2: 0,
            dependent_param: 1.5,
        };
        assert!(unconstrained.validate().is_ok());

        // param2 > 0 enforces dependent_param <= 2 * param1.
        let violating = CustomValidationExample {
            param1: 1.0,
            param2: 1,
            dependent_param: 3.0,
        };
        assert!(violating.validate().is_err());
    }
}