sklears_core/
compile_time_validation.rs

//! Compile-time validation framework for ML configurations.
//!
//! This module provides traits and types that track, in the type system, whether a machine
//! learning model configuration has been validated, together with parameter validators whose
//! bounds are fixed at compile time but whose checks run at runtime.
5use crate::error::SklearsError;
6use std::marker::PhantomData;
7
/// Opt-in marker for configuration types that may be promoted to the validated state.
///
/// The trait has no methods. `ValidatedConfig::validate` and
/// `CompileTimeValidated::Config` both require it, so implementing it is what makes a
/// configuration type usable with those APIs.
pub trait ValidConfig {}

/// Marker for types that represent an already-validated configuration.
///
/// Carries no methods; it only exists to be named in trait bounds.
pub trait Validated {}

/// Marker for types that must go through validation before use.
///
/// Carries no methods; it only exists to be named in trait bounds.
pub trait RequiresValidation {}
16
/// Zero-sized token that records a validation state `T` purely in the type system.
///
/// `T` is intended to be one of the state markers such as [`Unvalidated`] or
/// [`ValidatedState`]; no value of `T` is ever stored (only `PhantomData<T>`).
pub struct ValidationState<T> {
    _phantom: PhantomData<T>,
}

/// Marker type for unvalidated configurations
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Unvalidated;

/// Marker type for validated configurations
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ValidatedState;

impl<T> ValidationState<T> {
    /// Create the (zero-sized) state token.
    ///
    /// `const` so the token can be used in constant contexts.
    pub const fn new() -> Self {
        Self {
            _phantom: PhantomData,
        }
    }
}

impl<T> Default for ValidationState<T> {
    fn default() -> Self {
        Self::new()
    }
}

// The following impls are written by hand because `#[derive]` would require `T: Clone` /
// `T: Debug`, even though only `PhantomData<T>` is stored and no `T` value exists.
impl<T> Clone for ValidationState<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ValidationState<T> {}

impl<T> std::fmt::Debug for ValidationState<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ValidationState")
    }
}
41
42/// Configuration wrapper that tracks validation state at compile time
43pub struct ValidatedConfig<T, S = Unvalidated> {
44    pub config: T,
45    _state: PhantomData<S>,
46}
47
48impl<T> ValidatedConfig<T, Unvalidated> {
49    /// Create a new unvalidated configuration
50    pub fn new(config: T) -> Self {
51        Self {
52            config,
53            _state: PhantomData,
54        }
55    }
56
57    /// Validate the configuration at compile time
58    pub fn validate(self) -> Result<ValidatedConfig<T, ValidatedState>, SklearsError>
59    where
60        T: ValidConfig,
61    {
62        // Runtime validation can still be performed here
63        Ok(ValidatedConfig {
64            config: self.config,
65            _state: PhantomData,
66        })
67    }
68}
69
70impl<T> ValidatedConfig<T, ValidatedState> {
71    /// Get the validated configuration
72    pub fn inner(&self) -> &T {
73        &self.config
74    }
75
76    /// Consume the wrapper and return the validated configuration
77    pub fn into_inner(self) -> T {
78        self.config
79    }
80}
81
/// A validation rule for a single parameter value of type `T`.
///
/// The rule is chosen at compile time through the implementing type and invoked as an
/// associated function (`V::validate(&x)`), so no validator instance is needed. The check
/// itself runs at runtime.
pub trait ParameterValidator<T> {
    /// Error reported when a value is rejected.
    type Error;

    /// Return `Ok(())` if `value` satisfies the rule, otherwise the rule's error.
    fn validate(value: &T) -> Result<(), Self::Error>;
}
89
90/// Compile-time range validator
91pub struct RangeValidator<const MIN: i64, const MAX: i64>;
92
93impl<const MIN: i64, const MAX: i64> ParameterValidator<i32> for RangeValidator<MIN, MAX> {
94    type Error = SklearsError;
95
96    fn validate(value: &i32) -> Result<(), Self::Error> {
97        if (*value as i64) < MIN || (*value as i64) > MAX {
98            Err(SklearsError::InvalidParameter {
99                name: "value".to_string(),
100                reason: format!("Value {value} not in range [{MIN}, {MAX}]"),
101            })
102        } else {
103            Ok(())
104        }
105    }
106}
107
108impl<const MIN: i64, const MAX: i64> ParameterValidator<f64> for RangeValidator<MIN, MAX> {
109    type Error = SklearsError;
110
111    fn validate(value: &f64) -> Result<(), Self::Error> {
112        if (*value as i64) < MIN || (*value as i64) > MAX {
113            Err(SklearsError::InvalidParameter {
114                name: "value".to_string(),
115                reason: format!("Value {value} not in range [{MIN}, {MAX}]"),
116            })
117        } else {
118            Ok(())
119        }
120    }
121}
122
123/// Positive number validator
124pub struct PositiveValidator;
125
126impl ParameterValidator<f64> for PositiveValidator {
127    type Error = SklearsError;
128
129    fn validate(value: &f64) -> Result<(), Self::Error> {
130        if *value <= 0.0 {
131            Err(SklearsError::InvalidParameter {
132                name: "value".to_string(),
133                reason: format!("Value {value} must be positive"),
134            })
135        } else {
136            Ok(())
137        }
138    }
139}
140
141impl ParameterValidator<i32> for PositiveValidator {
142    type Error = SklearsError;
143
144    fn validate(value: &i32) -> Result<(), Self::Error> {
145        if *value <= 0 {
146            Err(SklearsError::InvalidParameter {
147                name: "value".to_string(),
148                reason: format!("Value {value} must be positive"),
149            })
150        } else {
151            Ok(())
152        }
153    }
154}
155
156/// Probability validator (0.0 to 1.0)
157pub struct ProbabilityValidator;
158
159impl ParameterValidator<f64> for ProbabilityValidator {
160    type Error = SklearsError;
161
162    fn validate(value: &f64) -> Result<(), Self::Error> {
163        if *value < 0.0 || *value > 1.0 {
164            Err(SklearsError::InvalidParameter {
165                name: "probability".to_string(),
166                reason: format!("Probability {value} must be between 0.0 and 1.0"),
167            })
168        } else {
169            Ok(())
170        }
171    }
172}
173
/// Validate a parameter value with a `ParameterValidator` and evaluate to that value.
///
/// `validated_param!(name: Type, Validator, expr)` evaluates `expr` exactly once and runs
/// `<Validator as ParameterValidator<Type>>::validate` on it. On success the macro
/// evaluates to the value. On failure the error is propagated with `?`, so the macro can
/// only be used inside a function whose return type can absorb the validator's error.
///
/// `name` is accepted for call-site readability; the expansion does not use it.
/// Although the validator is chosen at compile time, the check runs at runtime.
#[macro_export]
macro_rules! validated_param {
    ($name:ident: $type:ty, $validator:ty, $value:expr) => {{
        // Bind once. Expanding `$value` twice would repeat its side effects (or move it
        // twice).
        let value = $value;
        <$validator as $crate::compile_time_validation::ParameterValidator<$type>>::validate(
            &value,
        )?;
        value
    }};
}
184
185/// Trait for algorithms that support compile-time configuration validation
186pub trait CompileTimeValidated {
187    type Config: ValidConfig;
188    type ValidatedConfig;
189
190    /// Create a validated configuration
191    fn validate_config(config: Self::Config) -> Result<Self::ValidatedConfig, SklearsError>;
192}
193
194/// Example validated configuration for linear regression
195#[derive(Debug, Clone)]
196pub struct LinearRegressionConfig {
197    pub fit_intercept: bool,
198    pub positive: bool,
199    pub alpha: f64,
200    pub max_iter: i32,
201}
202
203impl ValidConfig for LinearRegressionConfig {}
204
205impl LinearRegressionConfig {
206    /// Create a new configuration with compile-time validation
207    pub fn builder() -> LinearRegressionConfigBuilder<Unvalidated> {
208        LinearRegressionConfigBuilder::new()
209    }
210}
211
212/// Builder for LinearRegressionConfig with compile-time validation
213pub struct LinearRegressionConfigBuilder<S = Unvalidated> {
214    config: LinearRegressionConfig,
215    _state: PhantomData<S>,
216}
217
218impl LinearRegressionConfigBuilder<Unvalidated> {
219    pub fn new() -> Self {
220        Self {
221            config: LinearRegressionConfig {
222                fit_intercept: true,
223                positive: false,
224                alpha: 1.0,
225                max_iter: 1000,
226            },
227            _state: PhantomData,
228        }
229    }
230
231    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
232        self.config.fit_intercept = fit_intercept;
233        self
234    }
235
236    pub fn positive(mut self, positive: bool) -> Self {
237        self.config.positive = positive;
238        self
239    }
240
241    /// Set alpha with compile-time validation
242    pub fn alpha(mut self, alpha: f64) -> Result<Self, SklearsError> {
243        PositiveValidator::validate(&alpha)?;
244        self.config.alpha = alpha;
245        Ok(self)
246    }
247
248    /// Set max_iter with compile-time validation
249    pub fn max_iter(mut self, max_iter: i32) -> Result<Self, SklearsError> {
250        RangeValidator::<1, 10000>::validate(&max_iter)?;
251        self.config.max_iter = max_iter;
252        Ok(self)
253    }
254
255    /// Build the validated configuration
256    pub fn build(self) -> Result<LinearRegressionConfigBuilder<ValidatedState>, SklearsError> {
257        // Additional cross-parameter validation can be done here
258        Ok(LinearRegressionConfigBuilder {
259            config: self.config,
260            _state: PhantomData,
261        })
262    }
263}
264
265impl Default for LinearRegressionConfigBuilder<Unvalidated> {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
271impl LinearRegressionConfigBuilder<ValidatedState> {
272    /// Get the validated configuration
273    pub fn config(&self) -> &LinearRegressionConfig {
274        &self.config
275    }
276
277    /// Consume the builder and return the validated configuration
278    pub fn into_config(self) -> LinearRegressionConfig {
279        self.config
280    }
281}
282
283/// Trait for dimension validation at compile time
284pub trait DimensionValidator<const N: usize> {
285    fn validate_dimensions(&self) -> Result<(), SklearsError>;
286}
287
/// Fixed-length array wrapper whose length `N` is part of its type.
pub struct FixedArray<T, const N: usize> {
    data: [T; N],
}

impl<T, const N: usize> FixedArray<T, N> {
    /// Wrap an array of exactly `N` elements.
    pub fn new(data: [T; N]) -> Self {
        FixedArray { data }
    }

    /// Number of elements, which is always `N`.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` exactly when `N == 0`.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Borrow the elements as a slice.
    pub fn as_slice(&self) -> &[T] {
        self.data.as_slice()
    }
}
310
311impl<T, const N: usize> DimensionValidator<N> for FixedArray<T, N> {
312    fn validate_dimensions(&self) -> Result<(), SklearsError> {
313        // Compile-time dimension validation is automatic with const generics
314        Ok(())
315    }
316}
317
/// Declares, per (solver, regularization) pair, whether the combination is supported.
///
/// Implemented on a solver marker type, with the regularization marker as `S`. A pair
/// without an implementation cannot be passed to `validate_solver_regularization` at all;
/// an implemented pair reports its support through `is_compatible`.
pub trait SolverCompatibility<S> {
    /// `true` if the implementing solver supports regularization `S`.
    fn is_compatible() -> bool;
}

/// Marker type for the SGD solver.
pub struct SGDSolver;
/// Marker type for the L-BFGS solver.
pub struct LBFGSSolver;
/// Marker type for the coordinate-descent solver.
pub struct CoordinateDescentSolver;

/// Marker type for L1 regularization.
pub struct L1Regularization;
/// Marker type for L2 regularization.
pub struct L2Regularization;
/// Marker type for elastic-net regularization.
pub struct ElasticNetRegularization;

// Declared pairs. No pair involving `SGDSolver` is declared here.
impl SolverCompatibility<L1Regularization> for CoordinateDescentSolver {
    fn is_compatible() -> bool {
        true
    }
}

// Declared but unsupported: the pair type-checks, and the runtime answer is `false`.
impl SolverCompatibility<L1Regularization> for LBFGSSolver {
    fn is_compatible() -> bool {
        false
    }
}

impl SolverCompatibility<L2Regularization> for LBFGSSolver {
    fn is_compatible() -> bool {
        true
    }
}

impl SolverCompatibility<ElasticNetRegularization> for CoordinateDescentSolver {
    fn is_compatible() -> bool {
        true
    }
}
357
358/// Compile-time solver validation
359pub fn validate_solver_regularization<S, R>() -> Result<(), SklearsError>
360where
361    S: SolverCompatibility<R>,
362{
363    if S::is_compatible() {
364        Ok(())
365    } else {
366        Err(SklearsError::InvalidParameter {
367            name: "solver".to_string(),
368            reason: "Solver is not compatible with the specified regularization".to_string(),
369        })
370    }
371}
372
// NOTE(review): nothing in this module uses non-snake-case names, so this `allow` looks
// unnecessary; it is kept to avoid changing the lint configuration.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// A plain, valid linear-regression configuration.
    fn default_config() -> LinearRegressionConfig {
        LinearRegressionConfig {
            fit_intercept: true,
            positive: false,
            alpha: 1.0,
            max_iter: 1000,
        }
    }

    #[test]
    fn test_validated_config_creation() {
        let outcome = ValidatedConfig::new(default_config()).validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_config_builder_validation() -> Result<(), SklearsError> {
        let builder = LinearRegressionConfig::builder()
            .fit_intercept(true)
            .alpha(0.5)?
            .max_iter(500)?;
        assert!(builder.build().is_ok());
        Ok(())
    }

    #[test]
    fn test_invalid_alpha() {
        assert!(LinearRegressionConfig::builder().alpha(-1.0).is_err());
    }

    #[test]
    fn test_invalid_max_iter() {
        assert!(LinearRegressionConfig::builder().max_iter(-1).is_err());
    }

    #[test]
    fn test_fixed_array_dimensions() {
        let arr = FixedArray::new([1, 2, 3, 4, 5]);
        assert_eq!(5, arr.len());
        assert!(arr.validate_dimensions().is_ok());
    }

    #[test]
    fn test_solver_compatibility() {
        // Declared compatible: compiles and returns Ok.
        let cd_l1 = validate_solver_regularization::<CoordinateDescentSolver, L1Regularization>();
        assert!(cd_l1.is_ok());

        // Declared incompatible: still compiles, but returns Err.
        let lbfgs_l1 = validate_solver_regularization::<LBFGSSolver, L1Regularization>();
        assert!(lbfgs_l1.is_err());
    }

    #[test]
    fn test_range_validator() {
        type OneToHundred = RangeValidator<1, 100>;
        for (input, accepted) in [(50_i32, true), (0, false), (101, false)] {
            assert_eq!(OneToHundred::validate(&input).is_ok(), accepted, "input {input}");
        }
    }

    #[test]
    fn test_positive_validator() {
        assert!(PositiveValidator::validate(&1.0_f64).is_ok());
        for rejected in [0.0_f64, -1.0] {
            assert!(PositiveValidator::validate(&rejected).is_err(), "input {rejected}");
        }
    }

    #[test]
    fn test_probability_validator() {
        for accepted in [0.5_f64, 0.0, 1.0] {
            assert!(ProbabilityValidator::validate(&accepted).is_ok(), "input {accepted}");
        }
        for rejected in [-0.1_f64, 1.1] {
            assert!(ProbabilityValidator::validate(&rejected).is_err(), "input {rejected}");
        }
    }
}