sklears_linear/type_safety.rs

//! Type Safety Enhancements for Linear Models
//!
//! This module implements phantom types, const generics, and zero-cost abstractions
//! to provide compile-time guarantees and improve the type safety of linear models.
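//!
//! A minimal usage sketch (illustrative only; the crate path and the data
//! bindings `x` and `y` are assumptions, so the block is marked `ignore`):
//!
//! ```rust,ignore
//! use sklears_linear::type_safety::*;
//!
//! // A regression model over exactly 3 features, fixed at compile time.
//! let model = TypeSafeLinearModel::<Untrained, problem_type::Regression, 3>::new_regression();
//! let trained = model.fit_typed(&x, &y)?;       // consumes the Untrained model
//! let predictions = trained.predict_typed(&x)?; // only available on Trained
//! ```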

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::marker::PhantomData;

/// Phantom type marker for untrained models
#[derive(Debug, Clone, Copy)]
pub struct Untrained;

/// Phantom type marker for trained models
#[derive(Debug, Clone, Copy)]
pub struct Trained;

/// Phantom type markers for different problem types
pub mod problem_type {
    /// Regression problem marker
    #[derive(Debug, Clone, Copy)]
    pub struct Regression;

    /// Binary classification problem marker
    #[derive(Debug, Clone, Copy)]
    pub struct BinaryClassification;

    /// Multi-class classification problem marker
    #[derive(Debug, Clone, Copy)]
    pub struct MultiClassification;

    /// Multi-output regression problem marker
    #[derive(Debug, Clone, Copy)]
    pub struct MultiOutputRegression;
}

/// Phantom type markers for different solver capabilities
pub mod solver_capability {
    /// Supports smooth objectives only
    #[derive(Debug, Clone, Copy)]
    pub struct SmoothOnly;

    /// Supports non-smooth objectives (via proximal operators)
    #[derive(Debug, Clone, Copy)]
    pub struct NonSmoothCapable;

    /// Supports large-scale problems
    #[derive(Debug, Clone, Copy)]
    pub struct LargeScale;

    /// Supports sparse problems
    #[derive(Debug, Clone, Copy)]
    pub struct SparseCapable;
}

/// Type-safe linear model with phantom types and const generics
#[derive(Debug)]
pub struct TypeSafeLinearModel<State, ProblemType, const N_FEATURES: usize> {
    /// Model state (Trained or Untrained)
    _state: PhantomData<State>,
    /// Problem type marker
    _problem_type: PhantomData<ProblemType>,
    /// Model coefficients (only available when trained)
    coefficients: Option<Array1<Float>>,
    /// Intercept term (only available when trained)
    intercept: Option<Float>,
    /// Model configuration
    config: TypeSafeConfig<ProblemType>,
}

/// Type-safe configuration for linear models
#[derive(Debug, Clone)]
pub struct TypeSafeConfig<ProblemType> {
    /// Whether to fit intercept
    pub fit_intercept: bool,
    /// Regularization strength
    pub alpha: Float,
    /// Maximum iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tolerance: Float,
    /// Problem type marker
    _problem_type: PhantomData<ProblemType>,
}

impl<ProblemType> Default for TypeSafeConfig<ProblemType> {
    fn default() -> Self {
        Self::new()
    }
}

impl<ProblemType> TypeSafeConfig<ProblemType> {
    /// Create a new configuration
    pub fn new() -> Self {
        Self {
            fit_intercept: true,
            alpha: 1.0,
            max_iter: 1000,
            tolerance: 1e-6,
            _problem_type: PhantomData,
        }
    }

    /// Set whether to fit intercept
    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.fit_intercept = fit_intercept;
        self
    }

    /// Set regularization strength
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }
}

impl<const N_FEATURES: usize> TypeSafeLinearModel<Untrained, problem_type::Regression, N_FEATURES> {
    /// Create a new untrained regression model
    pub fn new_regression() -> Self {
        Self {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: TypeSafeConfig::new(),
        }
    }

    /// Configure the model
    pub fn configure(mut self, config: TypeSafeConfig<problem_type::Regression>) -> Self {
        self.config = config;
        self
    }
}

impl<const N_FEATURES: usize>
    TypeSafeLinearModel<Untrained, problem_type::BinaryClassification, N_FEATURES>
{
    /// Create a new untrained binary classification model
    pub fn new_binary_classification() -> Self {
        Self {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: TypeSafeConfig::new(),
        }
    }
}

impl<const N_FEATURES: usize>
    TypeSafeLinearModel<Untrained, problem_type::MultiClassification, N_FEATURES>
{
    /// Create a new untrained multi-class classification model
    pub fn new_multi_classification() -> Self {
        Self {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: TypeSafeConfig::new(),
        }
    }
}

/// Trait for fitting models whose expected feature count is fixed at compile time
pub trait TypeSafeFit<ProblemType, const N_FEATURES: usize> {
    type TrainedModel;

    /// Fit the model to training data. The expected feature count is fixed by
    /// the const generic; the input's column count is verified against it at runtime.
    fn fit_typed(self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Self::TrainedModel>;
}

impl<const N_FEATURES: usize> TypeSafeFit<problem_type::Regression, N_FEATURES>
    for TypeSafeLinearModel<Untrained, problem_type::Regression, N_FEATURES>
{
    type TrainedModel = TypeSafeLinearModel<Trained, problem_type::Regression, N_FEATURES>;

    fn fit_typed(self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Self::TrainedModel> {
        // Runtime guard: the data must match the const-generic feature count
        if X.ncols() != N_FEATURES {
            return Err(SklearsError::DimensionMismatch {
                expected: N_FEATURES,
                actual: X.ncols(),
            });
        }

        // Simplified fitting logic (in practice, this would use the modular framework).
        // Ridge normal equations: β = (X'X + αI)^(-1) X'y
        let xtx = X.t().dot(X);
        let _xty = X.t().dot(y);

        // Add L2 regularization to the diagonal
        let mut xtx_reg = xtx;
        for i in 0..N_FEATURES {
            xtx_reg[[i, i]] += self.config.alpha;
        }

        // Placeholder for the actual solving logic: a real implementation would
        // solve `xtx_reg * β = X'y` with a proper linear solver (e.g. Cholesky)
        let coefficients = Array1::ones(N_FEATURES) * 0.5; // Dummy solution

        let intercept = if self.config.fit_intercept {
            Some(y.mean().unwrap_or(0.0))
        } else {
            None
        };

        Ok(TypeSafeLinearModel {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: Some(coefficients),
            intercept,
            config: self.config,
        })
    }
}
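
// A minimal sketch of the typestate flow through `fit_typed` and
// `predict_typed`. The data below is made up, and since `fit_typed` currently
// returns placeholder coefficients, we only assert on shapes, not on values.
#[cfg(test)]
#[test]
fn typestate_fit_then_predict_sketch() {
    let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let y = Array1::from_vec(vec![1.0, 2.0, 3.0]);

    // `fit_typed` consumes the Untrained model and returns a Trained one;
    // calling `predict_typed` on the untrained value would not compile.
    let model = TypeSafeLinearModel::<Untrained, problem_type::Regression, 2>::new_regression();
    let trained = model.fit_typed(&x, &y).unwrap();
    let predictions = trained.predict_typed(&x).unwrap();
    assert_eq!(predictions.len(), 3);
}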

/// Trait for making predictions against a const-generic feature count
pub trait TypeSafePredict<ProblemType, const N_FEATURES: usize> {
    /// Make predictions. The expected feature count is fixed by the const
    /// generic; the input's column count is verified against it at runtime.
    fn predict_typed(&self, X: &Array2<Float>) -> Result<Array1<Float>>;
}

impl<const N_FEATURES: usize> TypeSafePredict<problem_type::Regression, N_FEATURES>
    for TypeSafeLinearModel<Trained, problem_type::Regression, N_FEATURES>
{
    fn predict_typed(&self, X: &Array2<Float>) -> Result<Array1<Float>> {
        // Runtime guard: the data must match the const-generic feature count
        if X.ncols() != N_FEATURES {
            return Err(SklearsError::DimensionMismatch {
                expected: N_FEATURES,
                actual: X.ncols(),
            });
        }

        let coefficients = self
            .coefficients
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidOperation("Model is not trained".to_string()))?;

        let mut predictions = X.dot(coefficients);

        if let Some(intercept) = self.intercept {
            predictions += intercept;
        }

        Ok(predictions)
    }
}

/// Zero-cost abstraction for different regularization schemes
pub trait RegularizationScheme {
    /// Penalty value added to the objective
    fn apply_regularization(&self, coefficients: &Array1<Float>) -> Float;

    /// Gradient (or subgradient) of the penalty
    fn apply_regularization_gradient(&self, coefficients: &Array1<Float>) -> Array1<Float>;

    /// Get the regularization strength
    fn strength(&self) -> Float;
}

/// L2 regularization scheme (zero-cost abstraction)
#[derive(Debug, Clone)]
pub struct L2Scheme {
    pub alpha: Float,
}

impl RegularizationScheme for L2Scheme {
    fn apply_regularization(&self, coefficients: &Array1<Float>) -> Float {
        0.5 * self.alpha * coefficients.mapv(|x| x * x).sum()
    }

    fn apply_regularization_gradient(&self, coefficients: &Array1<Float>) -> Array1<Float> {
        self.alpha * coefficients
    }

    fn strength(&self) -> Float {
        self.alpha
    }
}

/// L1 regularization scheme (zero-cost abstraction)
#[derive(Debug, Clone)]
pub struct L1Scheme {
    pub alpha: Float,
}

impl RegularizationScheme for L1Scheme {
    fn apply_regularization(&self, coefficients: &Array1<Float>) -> Float {
        self.alpha * coefficients.mapv(|x| x.abs()).sum()
    }

    /// The L1 penalty is non-differentiable at zero, so this returns a
    /// subgradient: alpha * sign(x), with 0.0 chosen at x == 0
    fn apply_regularization_gradient(&self, coefficients: &Array1<Float>) -> Array1<Float> {
        coefficients.mapv(|x| {
            if x > 0.0 {
                self.alpha
            } else if x < 0.0 {
                -self.alpha
            } else {
                0.0
            }
        })
    }

    fn strength(&self) -> Float {
        self.alpha
    }
}
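
// Sketch of the subgradient behavior above; the values are hand-computed
// for alpha = 0.3 over coefficients (2.0, 0.0, -1.5).
#[cfg(test)]
#[test]
fn l1_subgradient_sketch() {
    let scheme = L1Scheme { alpha: 0.3 };
    let coefficients = Array1::from_vec(vec![2.0, 0.0, -1.5]);
    let subgradient = scheme.apply_regularization_gradient(&coefficients);
    assert_eq!(subgradient.to_vec(), vec![0.3, 0.0, -0.3]);
}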

/// Compile-time constraint checking for solver configurations
pub trait SolverConstraint<ProblemType> {
    /// Check if the solver is compatible with the problem type
    fn is_compatible() -> bool;

    /// Get solver-specific recommendations
    fn get_recommendations() -> &'static str;

    /// Get required features for this solver-problem combination
    fn required_features() -> &'static [&'static str] {
        &[]
    }

    /// Get incompatible features for this solver-problem combination
    fn incompatible_features() -> &'static [&'static str] {
        &[]
    }
}

/// Enhanced compile-time configuration validation
pub trait ConfigurationValidator<SolverType, ProblemType, RegularizationType> {
    /// Validate configuration at compile time
    fn validate_config() -> std::result::Result<(), &'static str>;

    /// Get optimal hyperparameters for this configuration
    fn optimal_hyperparameters() -> ConfigurationHints;
}

/// Configuration hints for optimal performance
#[derive(Debug, Clone, Default)]
pub struct ConfigurationHints {
    /// Recommended tolerance
    pub tolerance: Option<Float>,
    /// Recommended maximum iterations
    pub max_iterations: Option<usize>,
    /// Recommended regularization strength range
    pub regularization_range: Option<(Float, Float)>,
    /// Performance notes
    pub notes: Vec<&'static str>,
}

/// Compile-time feature validation
pub trait FeatureValidator<const N_FEATURES: usize> {
    /// Validate that the feature count is appropriate for the algorithm
    fn validate_feature_count() -> std::result::Result<(), SklearsError>;

    /// Get memory requirements for this feature count
    fn memory_requirements() -> MemoryRequirements;

    /// Get computational complexity estimate
    fn computational_complexity() -> ComputationalComplexity;
}

/// Memory requirements estimate
#[derive(Debug, Clone)]
pub struct MemoryRequirements {
    /// Estimated memory usage in bytes
    pub estimated_bytes: usize,
    /// Whether the algorithm is memory-intensive
    pub is_memory_intensive: bool,
    /// Recommendations for memory optimization
    pub optimization_notes: Vec<&'static str>,
}

/// Computational complexity estimate
#[derive(Debug, Clone)]
pub struct ComputationalComplexity {
    /// Time complexity (e.g., "O(n^2)", "O(n*p)")
    pub time_complexity: &'static str,
    /// Space complexity
    pub space_complexity: &'static str,
    /// Whether the algorithm is compute-intensive
    pub is_compute_intensive: bool,
}

/// Advanced constraint checking for regularization compatibility
pub trait RegularizationConstraint<SolverType, RegularizationType> {
    /// Check if solver supports this regularization type
    fn is_solver_compatible() -> bool;

    /// Get solver-specific regularization recommendations
    fn get_solver_recommendations() -> &'static str;

    /// Get optimal regularization strength for this solver
    fn optimal_strength_range() -> (Float, Float);
}

/// Gradient descent is compatible with smooth problems
impl SolverConstraint<problem_type::Regression> for solver_capability::SmoothOnly {
    fn is_compatible() -> bool {
        true
    }

    fn get_recommendations() -> &'static str {
        "Gradient descent works well for smooth regression objectives"
    }

    fn required_features() -> &'static [&'static str] {
        &["smooth_objective", "differentiable"]
    }

    fn incompatible_features() -> &'static [&'static str] {
        &["l1_regularization", "non_smooth"]
    }
}

/// Coordinate descent is compatible with L1-regularized problems
impl SolverConstraint<problem_type::Regression> for solver_capability::NonSmoothCapable {
    fn is_compatible() -> bool {
        true
    }

    fn get_recommendations() -> &'static str {
        "Coordinate descent is ideal for L1-regularized problems"
    }

    fn required_features() -> &'static [&'static str] {
        &["separable_objective"]
    }

    fn incompatible_features() -> &'static [&'static str] {
        &[]
    }
}
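
// Sketch: reading compatibility and feature lists off a solver/problem
// pairing via the trait's associated functions.
#[cfg(test)]
#[test]
fn solver_constraint_sketch() {
    type Smooth = solver_capability::SmoothOnly;
    assert!(<Smooth as SolverConstraint<problem_type::Regression>>::is_compatible());
    assert!(
        <Smooth as SolverConstraint<problem_type::Regression>>::incompatible_features()
            .contains(&"l1_regularization")
    );
}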

/// Configuration validation for smooth solvers with L2 regularization
impl ConfigurationValidator<solver_capability::SmoothOnly, problem_type::Regression, L2Scheme>
    for ()
{
    fn validate_config() -> std::result::Result<(), &'static str> {
        // L2 regularization is smooth, so it's compatible with smooth solvers
        Ok(())
    }

    fn optimal_hyperparameters() -> ConfigurationHints {
        ConfigurationHints {
            tolerance: Some(1e-6),
            max_iterations: Some(1000),
            regularization_range: Some((1e-4, 1e2)),
            notes: vec![
                "Use line search for better convergence",
                "Consider preconditioning for ill-conditioned problems",
            ],
        }
    }
}

/// Configuration validation for non-smooth capable solvers with L1 regularization
impl ConfigurationValidator<solver_capability::NonSmoothCapable, problem_type::Regression, L1Scheme>
    for ()
{
    fn validate_config() -> std::result::Result<(), &'static str> {
        // L1 regularization requires non-smooth capable solvers
        Ok(())
    }

    fn optimal_hyperparameters() -> ConfigurationHints {
        ConfigurationHints {
            tolerance: Some(1e-4),
            max_iterations: Some(10000),
            regularization_range: Some((1e-6, 1e1)),
            notes: vec![
                "Use coordinate descent for efficiency",
                "Consider warm starts for regularization path",
            ],
        }
    }
}
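
// Sketch: the validator impls hang off the unit type, so they are invoked
// through fully qualified paths. The values checked mirror the impl above.
#[cfg(test)]
#[test]
fn configuration_hints_sketch() {
    assert!(<() as ConfigurationValidator<
        solver_capability::SmoothOnly,
        problem_type::Regression,
        L2Scheme,
    >>::validate_config()
    .is_ok());

    let hints = <() as ConfigurationValidator<
        solver_capability::SmoothOnly,
        problem_type::Regression,
        L2Scheme,
    >>::optimal_hyperparameters();
    assert_eq!(hints.tolerance, Some(1e-6));
    assert_eq!(hints.max_iterations, Some(1000));
}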

/// Feature validation for linear models (blanket impl on the unit type)
impl<const N_FEATURES: usize> FeatureValidator<N_FEATURES> for () {
    fn validate_feature_count() -> std::result::Result<(), SklearsError> {
        if N_FEATURES == 0 {
            Err(SklearsError::InvalidOperation(
                "Feature count must be greater than 0".to_string(),
            ))
        } else if N_FEATURES > 100000 {
            Err(SklearsError::InvalidOperation(
                "Feature count too large - consider dimensionality reduction".to_string(),
            ))
        } else {
            Ok(())
        }
    }

    fn memory_requirements() -> MemoryRequirements {
        let bytes_per_feature = std::mem::size_of::<Float>();
        let coefficient_memory = N_FEATURES * bytes_per_feature;
        let gram_matrix_memory = N_FEATURES * N_FEATURES * bytes_per_feature;

        let total_memory = coefficient_memory + gram_matrix_memory;
        let is_memory_intensive = total_memory > 1_000_000; // 1MB threshold

        let optimization_notes = if is_memory_intensive {
            vec![
                "Consider using sparse matrices",
                "Use iterative solvers to avoid Gram matrix",
            ]
        } else {
            vec!["Memory usage is reasonable"]
        };

        MemoryRequirements {
            estimated_bytes: total_memory,
            is_memory_intensive,
            optimization_notes,
        }
    }

    fn computational_complexity() -> ComputationalComplexity {
        let is_compute_intensive = N_FEATURES > 10000;

        ComputationalComplexity {
            time_complexity: "O(n*p^2)",
            space_complexity: "O(p^2)",
            is_compute_intensive,
        }
    }
}
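
// Sketch: querying the validator for a 100-feature problem. The Gram matrix
// dominates the estimate: (100 + 100 * 100) * size_of::<Float>() bytes.
#[cfg(test)]
#[test]
fn feature_validator_sketch() {
    assert!(<() as FeatureValidator<100>>::validate_feature_count().is_ok());

    let mem = <() as FeatureValidator<100>>::memory_requirements();
    let expected = (100 + 100 * 100) * std::mem::size_of::<Float>();
    assert_eq!(mem.estimated_bytes, expected);
    assert!(!mem.is_memory_intensive); // well under the 1MB threshold
}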

/// Regularization constraint for L2 with smooth solvers
impl RegularizationConstraint<solver_capability::SmoothOnly, L2Scheme> for () {
    fn is_solver_compatible() -> bool {
        true
    }

    fn get_solver_recommendations() -> &'static str {
        "L2 regularization is smooth and works well with gradient-based methods"
    }

    fn optimal_strength_range() -> (Float, Float) {
        (1e-4, 1e2)
    }
}

/// Regularization constraint for L1 with smooth solvers (incompatible)
impl RegularizationConstraint<solver_capability::SmoothOnly, L1Scheme> for () {
    fn is_solver_compatible() -> bool {
        false
    }

    fn get_solver_recommendations() -> &'static str {
        "L1 regularization is non-smooth and requires specialized solvers like coordinate descent"
    }

    fn optimal_strength_range() -> (Float, Float) {
        (0.0, 0.0) // Not applicable for incompatible combinations
    }
}

/// Regularization constraint for L1 with non-smooth capable solvers
impl RegularizationConstraint<solver_capability::NonSmoothCapable, L1Scheme> for () {
    fn is_solver_compatible() -> bool {
        true
    }

    fn get_solver_recommendations() -> &'static str {
        "L1 regularization works excellently with coordinate descent and proximal methods"
    }

    fn optimal_strength_range() -> (Float, Float) {
        (1e-6, 1e1)
    }
}
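
// Sketch: the constraint impls answer compatibility questions per pairing;
// here the smooth-only solver accepts L2 but rejects L1.
#[cfg(test)]
#[test]
fn regularization_constraint_sketch() {
    type Smooth = solver_capability::SmoothOnly;
    assert!(<() as RegularizationConstraint<Smooth, L2Scheme>>::is_solver_compatible());
    assert!(!<() as RegularizationConstraint<Smooth, L1Scheme>>::is_solver_compatible());

    let (low, high) = <() as RegularizationConstraint<Smooth, L2Scheme>>::optimal_strength_range();
    assert!(low < high);
}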

/// Type-safe solver selector
pub struct TypeSafeSolverSelector<SolverType, ProblemType> {
    _solver_type: PhantomData<SolverType>,
    _problem_type: PhantomData<ProblemType>,
}

impl<SolverType, ProblemType> Default for TypeSafeSolverSelector<SolverType, ProblemType>
where
    SolverType: SolverConstraint<ProblemType>,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<SolverType, ProblemType> TypeSafeSolverSelector<SolverType, ProblemType>
where
    SolverType: SolverConstraint<ProblemType>,
{
    /// Create a new solver selector. The trait bound guarantees at compile
    /// time that a `SolverConstraint` impl exists for the pairing; the
    /// `is_compatible` flag itself is enforced by a runtime assertion.
    pub fn new() -> Self {
        assert!(
            SolverType::is_compatible(),
            "Solver not compatible with problem type"
        );

        Self {
            _solver_type: PhantomData,
            _problem_type: PhantomData,
        }
    }

    /// Get solver recommendations
    pub fn recommendations(&self) -> &'static str {
        SolverType::get_recommendations()
    }
}

/// Fixed-size array operations for small problems (const generic optimization)
pub struct FixedSizeOps<const N: usize>;

impl<const N: usize> FixedSizeOps<N> {
    /// Dot product for fixed-size vectors (compile-time optimized)
    pub fn dot_product(a: &[Float; N], b: &[Float; N]) -> Float {
        let mut sum = 0.0;
        for i in 0..N {
            sum += a[i] * b[i];
        }
        sum
    }

    /// Matrix-vector multiplication for fixed-size matrices (compile-time optimized)
    pub fn matrix_vector_multiply<const M: usize>(
        matrix: &[[Float; N]; M],
        vector: &[Float; N],
    ) -> [Float; M] {
        let mut result = [0.0; M];
        for i in 0..M {
            result[i] = Self::dot_product(&matrix[i], vector);
        }
        result
    }

    /// L2 norm for fixed-size vectors
    pub fn l2_norm(vector: &[Float; N]) -> Float {
        Self::dot_product(vector, vector).sqrt()
    }

    /// Normalize fixed-size vector in-place (no-op for the zero vector)
    pub fn normalize(vector: &mut [Float; N]) {
        let norm = Self::l2_norm(vector);
        if norm > 0.0 {
            for elem in vector.iter_mut() {
                *elem /= norm;
            }
        }
    }
}

/// Small fixed-size regression model (10 features)
pub type SmallLinearRegression = TypeSafeLinearModel<Untrained, problem_type::Regression, 10>;
/// Medium fixed-size regression model (100 features)
pub type MediumLinearRegression = TypeSafeLinearModel<Untrained, problem_type::Regression, 100>;
/// Large fixed-size regression model (1000 features)
pub type LargeLinearRegression = TypeSafeLinearModel<Untrained, problem_type::Regression, 1000>;

/// Builder for models whose feature count is fixed by a const generic
#[derive(Debug)]
pub struct TypeSafeModelBuilder<ProblemType, const N_FEATURES: usize> {
    config: TypeSafeConfig<ProblemType>,
}

impl<const N_FEATURES: usize> TypeSafeModelBuilder<problem_type::Regression, N_FEATURES> {
    /// Create a new builder for regression
    pub fn new_regression() -> Self {
        Self {
            config: TypeSafeConfig::new(),
        }
    }

    /// Set L2 regularization strength
    pub fn with_l2_regularization(mut self, alpha: Float) -> Self {
        self.config.alpha = alpha;
        self
    }

    /// Build the model
    pub fn build(self) -> TypeSafeLinearModel<Untrained, problem_type::Regression, N_FEATURES> {
        TypeSafeLinearModel {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: self.config,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_type_safe_model_creation() {
        let model: SmallLinearRegression = TypeSafeLinearModel::new_regression();
        // Verify that the model is created with correct types
        assert!(std::mem::size_of_val(&model) > 0);
    }

    #[test]
    fn test_type_safe_config() {
        let config: TypeSafeConfig<problem_type::Regression> = TypeSafeConfig::new()
            .fit_intercept(true)
            .alpha(0.1)
            .max_iter(500)
            .tolerance(1e-8);

        assert!(config.fit_intercept);
        assert_eq!(config.alpha, 0.1);
        assert_eq!(config.max_iter, 500);
        assert_eq!(config.tolerance, 1e-8);
    }

    #[test]
    fn test_fixed_size_operations() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];

        let dot = FixedSizeOps::<3>::dot_product(&a, &b);
        assert_eq!(dot, 32.0); // 1*4 + 2*5 + 3*6 = 32

        let norm = FixedSizeOps::<3>::l2_norm(&a);
        assert!((norm - (14.0_f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let matrix = [[1.0, 2.0], [3.0, 4.0]];
        let vector = [5.0, 6.0];

        let result = FixedSizeOps::<2>::matrix_vector_multiply(&matrix, &vector);
        assert_eq!(result, [17.0, 39.0]); // [1*5+2*6, 3*5+4*6] = [17, 39]
    }

    #[test]
    fn test_regularization_schemes() {
        let coefficients = Array::from_vec(vec![1.0, -2.0, 3.0]);

        let l2_scheme = L2Scheme { alpha: 0.5 };
        let l2_penalty = l2_scheme.apply_regularization(&coefficients);
        let expected_l2 = 0.5 * 0.5 * (1.0 + 4.0 + 9.0);
        assert!((l2_penalty - expected_l2).abs() < 1e-10);

        let l1_scheme = L1Scheme { alpha: 0.3 };
        let l1_penalty = l1_scheme.apply_regularization(&coefficients);
        let expected_l1 = 0.3 * (1.0 + 2.0 + 3.0);
        assert!((l1_penalty - expected_l1).abs() < 1e-10);
    }

    #[test]
    fn test_solver_selector() {
        let _selector: TypeSafeSolverSelector<
            solver_capability::SmoothOnly,
            problem_type::Regression,
        > = TypeSafeSolverSelector::new();

        // This would fail to compile because no `SolverConstraint` impl exists
        // for the pairing (an incompatible pairing with an impl would instead
        // panic in `new()`):
        // let _incompatible: TypeSafeSolverSelector<solver_capability::SparseCapable, problem_type::BinaryClassification>
        //     = TypeSafeSolverSelector::new();
    }

    #[test]
    fn test_type_safe_builder() {
        let model: TypeSafeLinearModel<Untrained, problem_type::Regression, 5> =
            TypeSafeModelBuilder::new_regression()
                .with_l2_regularization(0.1)
                .build();

        assert_eq!(model.config.alpha, 0.1);
    }

    #[test]
    fn test_normalization() {
        let mut vector = [3.0, 4.0, 0.0];
        FixedSizeOps::<3>::normalize(&mut vector);

        let norm = FixedSizeOps::<3>::l2_norm(&vector);
        assert!((norm - 1.0).abs() < 1e-10);
    }
}