1use crate::error::SklearsError;
6use std::marker::PhantomData;
7
/// Marker trait for configuration types that may be promoted to the validated
/// state.
///
/// It is the bound required by `ValidatedConfig::validate` and by
/// `CompileTimeValidated::Config`.
pub trait ValidConfig {}

/// Marker trait with no required items.
///
/// NOTE(review): nothing visible here implements or uses it; presumably it
/// tags types that have already passed validation — confirm intent.
pub trait Validated {}

/// Marker trait with no required items.
///
/// NOTE(review): nothing visible here implements or uses it; presumably it
/// tags types that must be validated before use — confirm intent.
pub trait RequiresValidation {}
16
/// Zero-sized, generic state tag.
///
/// It stores no data. The type parameter `T` is carried only through
/// `PhantomData`, so the tag can be used purely at the type level.
pub struct ValidationState<T> {
    // Records `T` for the type system without storing a value of type `T`.
    _phantom: PhantomData<T>,
}
21
/// Type-state marker for a value that has not been validated yet.
///
/// It is the default `S` parameter of `ValidatedConfig` and
/// `LinearRegressionConfigBuilder`.
pub struct Unvalidated;

/// Type-state marker for a value that has passed validation.
pub struct ValidatedState;
27
28impl<T> ValidationState<T> {
29 pub fn new() -> Self {
30 Self {
31 _phantom: PhantomData,
32 }
33 }
34}
35
36impl<T> Default for ValidationState<T> {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42pub struct ValidatedConfig<T, S = Unvalidated> {
44 pub config: T,
45 _state: PhantomData<S>,
46}
47
48impl<T> ValidatedConfig<T, Unvalidated> {
49 pub fn new(config: T) -> Self {
51 Self {
52 config,
53 _state: PhantomData,
54 }
55 }
56
57 pub fn validate(self) -> Result<ValidatedConfig<T, ValidatedState>, SklearsError>
59 where
60 T: ValidConfig,
61 {
62 Ok(ValidatedConfig {
64 config: self.config,
65 _state: PhantomData,
66 })
67 }
68}
69
70impl<T> ValidatedConfig<T, ValidatedState> {
71 pub fn inner(&self) -> &T {
73 &self.config
74 }
75
76 pub fn into_inner(self) -> T {
78 self.config
79 }
80}
81
/// A stateless validator for parameter values of type `T`.
///
/// Validation is an associated function (there is no `self`), so validator
/// types never need to be instantiated. The implementors in this file are
/// unit structs.
pub trait ParameterValidator<T> {
    /// Error produced when a value is rejected.
    type Error;

    /// Checks `value` and returns `Ok(())` when it is acceptable.
    ///
    /// # Errors
    /// Returns `Self::Error` describing why `value` was rejected.
    fn validate(value: &T) -> Result<(), Self::Error>;
}
89
/// Validator accepting values in the inclusive range `[MIN, MAX]`.
///
/// The bounds are `i64` const generics. `ParameterValidator` is implemented for
/// both `i32` and `f64` values.
pub struct RangeValidator<const MIN: i64, const MAX: i64>;
92
93impl<const MIN: i64, const MAX: i64> ParameterValidator<i32> for RangeValidator<MIN, MAX> {
94 type Error = SklearsError;
95
96 fn validate(value: &i32) -> Result<(), Self::Error> {
97 if (*value as i64) < MIN || (*value as i64) > MAX {
98 Err(SklearsError::InvalidParameter {
99 name: "value".to_string(),
100 reason: format!("Value {value} not in range [{MIN}, {MAX}]"),
101 })
102 } else {
103 Ok(())
104 }
105 }
106}
107
108impl<const MIN: i64, const MAX: i64> ParameterValidator<f64> for RangeValidator<MIN, MAX> {
109 type Error = SklearsError;
110
111 fn validate(value: &f64) -> Result<(), Self::Error> {
112 if (*value as i64) < MIN || (*value as i64) > MAX {
113 Err(SklearsError::InvalidParameter {
114 name: "value".to_string(),
115 reason: format!("Value {value} not in range [{MIN}, {MAX}]"),
116 })
117 } else {
118 Ok(())
119 }
120 }
121}
122
/// Validator accepting strictly positive values.
///
/// `ParameterValidator` is implemented for both `f64` and `i32`.
pub struct PositiveValidator;
125
126impl ParameterValidator<f64> for PositiveValidator {
127 type Error = SklearsError;
128
129 fn validate(value: &f64) -> Result<(), Self::Error> {
130 if *value <= 0.0 {
131 Err(SklearsError::InvalidParameter {
132 name: "value".to_string(),
133 reason: format!("Value {value} must be positive"),
134 })
135 } else {
136 Ok(())
137 }
138 }
139}
140
141impl ParameterValidator<i32> for PositiveValidator {
142 type Error = SklearsError;
143
144 fn validate(value: &i32) -> Result<(), Self::Error> {
145 if *value <= 0 {
146 Err(SklearsError::InvalidParameter {
147 name: "value".to_string(),
148 reason: format!("Value {value} must be positive"),
149 })
150 } else {
151 Ok(())
152 }
153 }
154}
155
/// Validator accepting `f64` probabilities in the closed interval `[0.0, 1.0]`.
pub struct ProbabilityValidator;
158
159impl ParameterValidator<f64> for ProbabilityValidator {
160 type Error = SklearsError;
161
162 fn validate(value: &f64) -> Result<(), Self::Error> {
163 if *value < 0.0 || *value > 1.0 {
164 Err(SklearsError::InvalidParameter {
165 name: "probability".to_string(),
166 reason: format!("Probability {value} must be between 0.0 and 1.0"),
167 })
168 } else {
169 Ok(())
170 }
171 }
172}
173
/// Validates `$value` with `$validator`'s `ParameterValidator<$type>` impl and
/// evaluates to the value on success.
///
/// On failure the validator's error is propagated with `?`. The macro
/// therefore has to be used inside a function whose error type can absorb that
/// error. `$name` makes the call site read better but is not used by the
/// expansion.
///
/// `$value` is evaluated exactly once: it is bound to a local before
/// validation. Expressions with side effects or non-trivial cost therefore do
/// not run twice, which the previous two-occurrence expansion caused.
#[macro_export]
macro_rules! validated_param {
    ($name:ident: $type:ty, $validator:ty, $value:expr) => {{
        let __validated_value = $value;
        <$validator as $crate::compile_time_validation::ParameterValidator<$type>>::validate(
            &__validated_value,
        )?;
        __validated_value
    }};
}
184
185pub trait CompileTimeValidated {
187 type Config: ValidConfig;
188 type ValidatedConfig;
189
190 fn validate_config(config: Self::Config) -> Result<Self::ValidatedConfig, SklearsError>;
192}
193
/// Hyper-parameters for a linear regression model.
///
/// The fields are public, so direct construction bypasses validation. Going
/// through `LinearRegressionConfig::builder()` instead enforces `alpha > 0` and
/// `max_iter` in `[1, 10000]`.
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Whether to fit an intercept term.
    pub fit_intercept: bool,
    /// NOTE(review): presumably constrains coefficients to be non-negative — confirm.
    pub positive: bool,
    /// Must be positive when set through the builder. NOTE(review): presumably
    /// the regularization strength — confirm.
    pub alpha: f64,
    /// Iteration cap; must lie in `[1, 10000]` when set through the builder.
    pub max_iter: i32,
}
202
203impl ValidConfig for LinearRegressionConfig {}
204
205impl LinearRegressionConfig {
206 pub fn builder() -> LinearRegressionConfigBuilder<Unvalidated> {
208 LinearRegressionConfigBuilder::new()
209 }
210}
211
212pub struct LinearRegressionConfigBuilder<S = Unvalidated> {
214 config: LinearRegressionConfig,
215 _state: PhantomData<S>,
216}
217
218impl LinearRegressionConfigBuilder<Unvalidated> {
219 pub fn new() -> Self {
220 Self {
221 config: LinearRegressionConfig {
222 fit_intercept: true,
223 positive: false,
224 alpha: 1.0,
225 max_iter: 1000,
226 },
227 _state: PhantomData,
228 }
229 }
230
231 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
232 self.config.fit_intercept = fit_intercept;
233 self
234 }
235
236 pub fn positive(mut self, positive: bool) -> Self {
237 self.config.positive = positive;
238 self
239 }
240
241 pub fn alpha(mut self, alpha: f64) -> Result<Self, SklearsError> {
243 PositiveValidator::validate(&alpha)?;
244 self.config.alpha = alpha;
245 Ok(self)
246 }
247
248 pub fn max_iter(mut self, max_iter: i32) -> Result<Self, SklearsError> {
250 RangeValidator::<1, 10000>::validate(&max_iter)?;
251 self.config.max_iter = max_iter;
252 Ok(self)
253 }
254
255 pub fn build(self) -> Result<LinearRegressionConfigBuilder<ValidatedState>, SklearsError> {
257 Ok(LinearRegressionConfigBuilder {
259 config: self.config,
260 _state: PhantomData,
261 })
262 }
263}
264
265impl Default for LinearRegressionConfigBuilder<Unvalidated> {
266 fn default() -> Self {
267 Self::new()
268 }
269}
270
271impl LinearRegressionConfigBuilder<ValidatedState> {
272 pub fn config(&self) -> &LinearRegressionConfig {
274 &self.config
275 }
276
277 pub fn into_config(self) -> LinearRegressionConfig {
279 self.config
280 }
281}
282
283pub trait DimensionValidator<const N: usize> {
285 fn validate_dimensions(&self) -> Result<(), SklearsError>;
286}
287
/// An owned array whose length `N` is fixed by its type.
pub struct FixedArray<T, const N: usize> {
    // The backing storage; always exactly `N` elements.
    data: [T; N],
}
292
293impl<T, const N: usize> FixedArray<T, N> {
294 pub fn new(data: [T; N]) -> Self {
295 Self { data }
296 }
297
298 pub fn len(&self) -> usize {
299 N
300 }
301
302 pub fn is_empty(&self) -> bool {
303 N == 0
304 }
305
306 pub fn as_slice(&self) -> &[T] {
307 &self.data
308 }
309}
310
311impl<T, const N: usize> DimensionValidator<N> for FixedArray<T, N> {
312 fn validate_dimensions(&self) -> Result<(), SklearsError> {
313 Ok(())
315 }
316}
317
/// Type-level compatibility table between a solver (the implementing type) and
/// a regularization scheme `S`.
///
/// A pair without an impl cannot be passed to `validate_solver_regularization`
/// at all. A pair whose impl returns `false` is rejected at runtime.
pub trait SolverCompatibility<S> {
    /// Whether the implementing solver supports regularization `S`.
    fn is_compatible() -> bool;
}
322
/// Type-level tag for the SGD solver.
pub struct SGDSolver;
/// Type-level tag for the L-BFGS solver.
pub struct LBFGSSolver;
/// Type-level tag for the coordinate-descent solver.
pub struct CoordinateDescentSolver;
327
/// Type-level tag for L1 regularization.
pub struct L1Regularization;
/// Type-level tag for L2 regularization.
pub struct L2Regularization;
/// Type-level tag for elastic-net regularization.
pub struct ElasticNetRegularization;
332
333impl SolverCompatibility<L1Regularization> for CoordinateDescentSolver {
335 fn is_compatible() -> bool {
336 true
337 }
338}
339
340impl SolverCompatibility<L1Regularization> for LBFGSSolver {
341 fn is_compatible() -> bool {
342 false }
344}
345
346impl SolverCompatibility<L2Regularization> for LBFGSSolver {
347 fn is_compatible() -> bool {
348 true
349 }
350}
351
352impl SolverCompatibility<ElasticNetRegularization> for CoordinateDescentSolver {
353 fn is_compatible() -> bool {
354 true
355 }
356}
357
358pub fn validate_solver_regularization<S, R>() -> Result<(), SklearsError>
360where
361 S: SolverCompatibility<R>,
362{
363 if S::is_compatible() {
364 Ok(())
365 } else {
366 Err(SklearsError::InvalidParameter {
367 name: "solver".to_string(),
368 reason: "Solver is not compatible with the specified regularization".to_string(),
369 })
370 }
371}
372
// NOTE(review): no non-snake-case names appear in this module; the allow below
// looks unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    // Unit tests for the type-state wrappers, the builder, the parameter
    // validators and the solver compatibility table.
    use super::*;

    #[test]
    fn test_validated_config_creation() {
        let raw = LinearRegressionConfig {
            fit_intercept: true,
            positive: false,
            alpha: 1.0,
            max_iter: 1000,
        };

        let outcome = ValidatedConfig::new(raw).validate();
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_config_builder_validation() {
        let builder = LinearRegressionConfig::builder().fit_intercept(true);
        let builder = builder.alpha(0.5).unwrap();
        let builder = builder.max_iter(500).unwrap();

        assert!(builder.build().is_ok());
    }

    #[test]
    fn test_invalid_alpha() {
        // A negative alpha must be rejected by the builder's setter.
        let rejected = LinearRegressionConfig::builder().alpha(-1.0);
        assert!(rejected.is_err());
    }

    #[test]
    fn test_invalid_max_iter() {
        // max_iter must lie in [1, 10000]; -1 is below the lower bound.
        let rejected = LinearRegressionConfig::builder().max_iter(-1);
        assert!(rejected.is_err());
    }

    #[test]
    fn test_fixed_array_dimensions() {
        let values = FixedArray::new([1, 2, 3, 4, 5]);
        assert_eq!(values.len(), 5);

        let check = values.validate_dimensions();
        assert!(check.is_ok());
    }

    #[test]
    fn test_solver_compatibility() {
        let cd_with_l1 =
            validate_solver_regularization::<CoordinateDescentSolver, L1Regularization>();
        let lbfgs_with_l1 = validate_solver_regularization::<LBFGSSolver, L1Regularization>();

        assert!(cd_with_l1.is_ok());
        assert!(lbfgs_with_l1.is_err());
    }

    #[test]
    fn test_range_validator() {
        let inside = RangeValidator::<1, 100>::validate(&50);
        let below = RangeValidator::<1, 100>::validate(&0);
        let above = RangeValidator::<1, 100>::validate(&101);

        assert!(inside.is_ok());
        assert!(below.is_err());
        assert!(above.is_err());
    }

    #[test]
    fn test_positive_validator() {
        let positive = PositiveValidator::validate(&1.0);
        let zero = PositiveValidator::validate(&0.0);
        let negative = PositiveValidator::validate(&-1.0);

        assert!(positive.is_ok());
        assert!(zero.is_err());
        assert!(negative.is_err());
    }

    #[test]
    fn test_probability_validator() {
        // Both endpoints of [0.0, 1.0] are accepted.
        let middle = ProbabilityValidator::validate(&0.5);
        let lower_edge = ProbabilityValidator::validate(&0.0);
        let upper_edge = ProbabilityValidator::validate(&1.0);
        let below = ProbabilityValidator::validate(&-0.1);
        let above = ProbabilityValidator::validate(&1.1);

        assert!(middle.is_ok());
        assert!(lower_edge.is_ok());
        assert!(upper_edge.is_ok());
        assert!(below.is_err());
        assert!(above.is_err());
    }
}
458}