// sklears_core/validation_examples.rs

//! Example implementations showing how to use the validation framework
2use crate::error::Result;
3use crate::types::Float;
4#[cfg(test)]
5use crate::validation::ValidationContext;
6use crate::validation::{ConfigValidation, Validate};
7
/// Example configuration with manual validation implementation.
///
/// Hyper-parameters for a linear regression estimator; all field checks
/// live in the `Validate` / `ConfigValidation` impls below.
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Learning rate for gradient descent
    pub learning_rate: Float,

    /// L2 regularization parameter
    pub alpha: Float,

    /// Maximum number of iterations
    pub max_iter: usize,

    /// Convergence tolerance (validated as positive and finite)
    pub tol: Float,

    /// Whether to fit intercept
    pub fit_intercept: bool,

    /// Solver method; one of "auto", "svd", "cholesky", "lsqr",
    /// "sparse_cg", "sag", "saga" (enforced by `validate`)
    pub solver: String,
}
29
30impl Validate for LinearRegressionConfig {
31    fn validate(&self) -> Result<()> {
32        // Validate learning rate
33        crate::validation::ml::validate_learning_rate(self.learning_rate)?;
34
35        // Validate regularization
36        crate::validation::ml::validate_regularization(self.alpha)?;
37
38        // Validate max_iter
39        crate::validation::ml::validate_max_iter(self.max_iter)?;
40
41        // Validate tolerance
42        crate::validation::ValidationRules::new("tol")
43            .add_rule(crate::validation::ValidationRule::Positive)
44            .add_rule(crate::validation::ValidationRule::Finite)
45            .validate_numeric(&self.tol)?;
46
47        // Validate solver
48        crate::validation::ValidationRules::new("solver")
49            .add_rule(crate::validation::ValidationRule::OneOf(vec![
50                "auto".to_string(),
51                "svd".to_string(),
52                "cholesky".to_string(),
53                "lsqr".to_string(),
54                "sparse_cg".to_string(),
55                "sag".to_string(),
56                "saga".to_string(),
57            ]))
58            .validate_string(&self.solver)?;
59
60        Ok(())
61    }
62}
63
64impl Default for LinearRegressionConfig {
65    fn default() -> Self {
66        Self {
67            learning_rate: 0.01,
68            alpha: 1.0,
69            max_iter: 1000,
70            tol: 1e-4,
71            fit_intercept: true,
72            solver: "auto".to_string(),
73        }
74    }
75}
76
77impl ConfigValidation for LinearRegressionConfig {
78    fn validate_config(&self) -> Result<()> {
79        // First run basic validation
80        self.validate()?;
81
82        // Add algorithm-specific validation
83        if self.solver == "cholesky" && !self.fit_intercept {
84            return Err(crate::error::SklearsError::InvalidParameter {
85                name: "solver".to_string(),
86                reason: "cholesky solver requires fit_intercept=true".to_string(),
87            });
88        }
89
90        Ok(())
91    }
92
93    fn get_warnings(&self) -> Vec<String> {
94        let mut warnings = Vec::new();
95
96        if self.learning_rate > 0.1 {
97            warnings
98                .push("Learning rate is quite high, consider using a smaller value".to_string());
99        }
100
101        if self.max_iter < 100 {
102            warnings.push("Maximum iterations is quite low, model may not converge".to_string());
103        }
104
105        warnings
106    }
107}
108
/// Example clustering configuration.
///
/// Hyper-parameters for a k-means style estimator; checks live in the
/// `Validate` / `ConfigValidation` impls below.
#[derive(Debug, Clone)]
pub struct KMeansConfig {
    /// Number of clusters (validated against an upper bound of 100)
    pub n_clusters: usize,

    /// Maximum number of iterations
    pub max_iter: usize,

    /// Convergence tolerance (validated as positive and finite)
    pub tol: Float,

    /// Initialization method; one of "k-means++", "random", "custom"
    pub init: String,

    /// Number of random initializations (must be at least 1)
    pub n_init: usize,

    /// Optional random seed
    pub random_state: Option<u64>,
}
130
131impl Validate for KMeansConfig {
132    fn validate(&self) -> Result<()> {
133        // Validate n_clusters
134        crate::validation::ml::validate_n_clusters(self.n_clusters, 100)?;
135
136        // Validate max_iter
137        crate::validation::ml::validate_max_iter(self.max_iter)?;
138
139        // Validate tolerance
140        crate::validation::ValidationRules::new("tol")
141            .add_rule(crate::validation::ValidationRule::Positive)
142            .add_rule(crate::validation::ValidationRule::Finite)
143            .validate_numeric(&self.tol)?;
144
145        // Validate initialization method
146        crate::validation::ValidationRules::new("init")
147            .add_rule(crate::validation::ValidationRule::OneOf(vec![
148                "k-means++".to_string(),
149                "random".to_string(),
150                "custom".to_string(),
151            ]))
152            .validate_string(&self.init)?;
153
154        // Validate n_init
155        if self.n_init == 0 {
156            return Err(crate::error::SklearsError::InvalidParameter {
157                name: "n_init".to_string(),
158                reason: "must be positive".to_string(),
159            });
160        }
161
162        Ok(())
163    }
164}
165
166impl Default for KMeansConfig {
167    fn default() -> Self {
168        Self {
169            n_clusters: 8,
170            max_iter: 300,
171            tol: 1e-4,
172            init: "k-means++".to_string(),
173            n_init: 10,
174            random_state: None,
175        }
176    }
177}
178
179impl ConfigValidation for KMeansConfig {
180    fn validate_config(&self) -> Result<()> {
181        self.validate()?;
182
183        // Additional validation for clustering
184        if self.n_clusters == 1 {
185            log::warn!("Using only 1 cluster - consider if clustering is necessary");
186        }
187
188        Ok(())
189    }
190}
191
/// Example neural network configuration with complex validation.
///
/// Hyper-parameters for a multi-layer perceptron; checks live in the
/// `Validate` / `ConfigValidation` impls below.
#[derive(Debug, Clone)]
pub struct MLPConfig {
    /// Hidden layer sizes (validated as non-empty)
    pub hidden_layer_sizes: Vec<usize>,

    /// Learning rate
    pub learning_rate: Float,

    /// Maximum number of iterations
    pub max_iter: usize,

    /// Dropout probability (checked by `ml::validate_probability`)
    pub dropout: Float,

    /// Batch size (must be at least 1)
    pub batch_size: usize,

    /// L2 regularization
    pub alpha: Float,

    /// Activation function; one of "relu", "tanh", "sigmoid", "identity"
    pub activation: String,

    /// Solver; one of "adam", "sgd", "lbfgs"
    pub solver: String,
}
219
220impl Validate for MLPConfig {
221    fn validate(&self) -> Result<()> {
222        // Validate hidden layer sizes
223        crate::validation::ValidationRules::new("hidden_layer_sizes")
224            .add_rule(crate::validation::ValidationRule::MinLength(1))
225            .validate_array(&self.hidden_layer_sizes)?;
226
227        // Validate learning rate
228        crate::validation::ml::validate_learning_rate(self.learning_rate)?;
229
230        // Validate max_iter
231        crate::validation::ml::validate_max_iter(self.max_iter)?;
232
233        // Validate dropout probability
234        crate::validation::ml::validate_probability(self.dropout)?;
235
236        // Validate batch size
237        if self.batch_size == 0 {
238            return Err(crate::error::SklearsError::InvalidParameter {
239                name: "batch_size".to_string(),
240                reason: "must be positive".to_string(),
241            });
242        }
243
244        // Validate regularization
245        crate::validation::ml::validate_regularization(self.alpha)?;
246
247        // Validate activation function
248        crate::validation::ValidationRules::new("activation")
249            .add_rule(crate::validation::ValidationRule::OneOf(vec![
250                "relu".to_string(),
251                "tanh".to_string(),
252                "sigmoid".to_string(),
253                "identity".to_string(),
254            ]))
255            .validate_string(&self.activation)?;
256
257        // Validate solver
258        crate::validation::ValidationRules::new("solver")
259            .add_rule(crate::validation::ValidationRule::OneOf(vec![
260                "adam".to_string(),
261                "sgd".to_string(),
262                "lbfgs".to_string(),
263            ]))
264            .validate_string(&self.solver)?;
265
266        Ok(())
267    }
268}
269
270impl Default for MLPConfig {
271    fn default() -> Self {
272        Self {
273            hidden_layer_sizes: vec![100],
274            learning_rate: 0.001,
275            max_iter: 200,
276            dropout: 0.0,
277            batch_size: 32,
278            alpha: 1e-4,
279            activation: "relu".to_string(),
280            solver: "adam".to_string(),
281        }
282    }
283}
284
285impl ConfigValidation for MLPConfig {
286    fn validate_config(&self) -> Result<()> {
287        self.validate()?;
288
289        // Complex validation logic
290        if self.solver == "lbfgs" && self.hidden_layer_sizes.len() > 3 {
291            return Err(crate::error::SklearsError::InvalidParameter {
292                name: "solver".to_string(),
293                reason: "lbfgs solver may be inefficient for deep networks".to_string(),
294            });
295        }
296
297        if self.batch_size > 1000 {
298            log::warn!("Large batch size may lead to poor generalization");
299        }
300
301        Ok(())
302    }
303
304    fn get_warnings(&self) -> Vec<String> {
305        let mut warnings = Vec::new();
306
307        if self.hidden_layer_sizes.iter().any(|&size| size > 1000) {
308            warnings.push("Very large hidden layers may cause overfitting".to_string());
309        }
310
311        if self.dropout > 0.5 {
312            warnings.push("High dropout rate may prevent learning".to_string());
313        }
314
315        warnings
316    }
317}
318
319/// Example of manual validation implementation for complex cases
320pub struct CustomValidationExample {
321    pub param1: Float,
322    pub param2: usize,
323    pub dependent_param: Float,
324}
325
326impl Validate for CustomValidationExample {
327    fn validate(&self) -> Result<()> {
328        // Basic validations
329        if self.param1 <= 0.0 {
330            return Err(crate::error::SklearsError::InvalidParameter {
331                name: "param1".to_string(),
332                reason: "must be positive".to_string(),
333            });
334        }
335
336        // Complex cross-parameter validation
337        if self.param2 > 0 && self.dependent_param > self.param1 * 2.0 {
338            return Err(crate::error::SklearsError::InvalidParameter {
339                name: "dependent_param".to_string(),
340                reason: "cannot be more than twice param1 when param2 > 0".to_string(),
341            });
342        }
343
344        Ok(())
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults validate; a bad learning rate or unknown solver must fail.
    #[test]
    fn test_linear_regression_config_validation() {
        let config = LinearRegressionConfig::default();
        assert!(config.validate().is_ok());

        // Negative learning rate is rejected.
        let config = LinearRegressionConfig {
            learning_rate: -0.1,
            ..LinearRegressionConfig::default()
        };
        assert!(config.validate().is_err());

        // Unknown solver name is rejected.
        let config = LinearRegressionConfig {
            solver: "invalid_solver".to_string(),
            ..LinearRegressionConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// Defaults validate; zero clusters or a negative tolerance must fail.
    #[test]
    fn test_kmeans_config_validation() {
        let config = KMeansConfig::default();
        assert!(config.validate().is_ok());

        // Zero clusters is rejected.
        let config = KMeansConfig {
            n_clusters: 0,
            ..KMeansConfig::default()
        };
        assert!(config.validate().is_err());

        // Negative tolerance is rejected.
        let config = KMeansConfig {
            tol: -1.0,
            ..KMeansConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// Defaults validate; empty layers or dropout > 1 must fail.
    #[test]
    fn test_mlp_config_validation() {
        let config = MLPConfig::default();
        assert!(config.validate().is_ok());

        // An empty hidden-layer list violates MinLength(1).
        let config = MLPConfig {
            hidden_layer_sizes: vec![],
            ..MLPConfig::default()
        };
        assert!(config.validate().is_err());

        // dropout = 1.5 is rejected by the probability check.
        let config = MLPConfig {
            dropout: 1.5,
            ..MLPConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// The ConfigValidation trait accepts the default config with no warnings.
    #[test]
    fn test_config_validation_trait() {
        let config = LinearRegressionConfig::default();

        // Cross-parameter validation passes for the defaults.
        assert!(config.validate_config().is_ok());

        // No advisories expected for the default configuration.
        assert!(config.get_warnings().is_empty());
    }

    /// Context formatting embeds estimator, method, and data shape.
    #[test]
    fn test_validation_context() {
        let context =
            ValidationContext::new("LinearRegression", "fit").with_data_info(100, 5, "float64");

        let error = crate::error::SklearsError::InvalidParameter {
            name: "learning_rate".to_string(),
            reason: "must be positive".to_string(),
        };

        let formatted = context.format_error(&error);
        assert!(formatted.contains("LinearRegression"));
        assert!(formatted.contains("fit"));
        assert!(formatted.contains("100 samples"));
        assert!(formatted.contains("5 features"));
    }

    /// Cross-parameter constraint fires only when param2 > 0.
    #[test]
    fn test_custom_validation() {
        // param2 == 0 disables the cross-parameter constraint.
        let example = CustomValidationExample {
            param1: 1.0,
            param2: 0,
            dependent_param: 1.5,
        };
        assert!(example.validate().is_ok());

        // dependent_param > 2 * param1 with param2 > 0 must fail.
        let example2 = CustomValidationExample {
            param1: 1.0,
            param2: 1,
            dependent_param: 3.0,
        };
        assert!(example2.validate().is_err());
    }
}