use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::marker::PhantomData;

/// Typestate marker for a model that has not been fitted yet.
#[derive(Debug, Clone, Copy)]
pub struct Untrained;

/// Typestate marker for a fitted model that can make predictions.
#[derive(Debug, Clone, Copy)]
pub struct Trained;

/// Marker types describing the learning problem a model targets.
pub mod problem_type {
    /// Single-output regression.
    #[derive(Debug, Clone, Copy)]
    pub struct Regression;

    /// Two-class classification.
    #[derive(Debug, Clone, Copy)]
    pub struct BinaryClassification;

    /// Classification with more than two classes.
    #[derive(Debug, Clone, Copy)]
    pub struct MultiClassification;

    /// Regression with multiple output targets.
    #[derive(Debug, Clone, Copy)]
    pub struct MultiOutputRegression;
}

/// Marker types describing what an optimization solver can handle.
pub mod solver_capability {
    /// Requires a smooth (differentiable) objective.
    #[derive(Debug, Clone, Copy)]
    pub struct SmoothOnly;

    /// Handles non-smooth objectives such as L1 penalties.
    #[derive(Debug, Clone, Copy)]
    pub struct NonSmoothCapable;

    /// Suited to large-scale problems.
    #[derive(Debug, Clone, Copy)]
    pub struct LargeScale;

    /// Can exploit sparse inputs.
    #[derive(Debug, Clone, Copy)]
    pub struct SparseCapable;
}

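/// A linear model whose training state, problem type, and feature count are
/// tracked in the type system: prediction is only available once the
/// typestate is `Trained`, and inputs are checked against `N_FEATURES`.
///
/// A minimal usage sketch (the fit uses this module's simplified placeholder
/// solver; `x` and `y` are assumed inputs of matching shape):
///
/// ```ignore
/// let model = TypeSafeLinearModel::<Untrained, problem_type::Regression, 3>::new_regression();
/// let trained = model.fit_typed(&x, &y)?; // consumes the untrained model
/// let preds = trained.predict_typed(&x)?; // only compiles for Trained models
/// ```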
#[derive(Debug)]
pub struct TypeSafeLinearModel<State, ProblemType, const N_FEATURES: usize> {
    _state: PhantomData<State>,
    _problem_type: PhantomData<ProblemType>,
    coefficients: Option<Array1<Float>>,
    intercept: Option<Float>,
    config: TypeSafeConfig<ProblemType>,
}

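/// Hyperparameters for a typed linear model, tagged with the problem type so
/// that configurations cannot be mixed across model kinds.
///
/// The setters chain by value; a small sketch:
///
/// ```ignore
/// let config = TypeSafeConfig::<problem_type::Regression>::new()
///     .alpha(0.1)
///     .max_iter(500);
/// ```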
#[derive(Debug, Clone)]
pub struct TypeSafeConfig<ProblemType> {
    /// Whether to fit an intercept term.
    pub fit_intercept: bool,
    /// Regularization strength.
    pub alpha: Float,
    /// Maximum number of solver iterations.
    pub max_iter: usize,
    /// Convergence tolerance.
    pub tolerance: Float,
    _problem_type: PhantomData<ProblemType>,
}

88
89impl<ProblemType> Default for TypeSafeConfig<ProblemType> {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl<ProblemType> TypeSafeConfig<ProblemType> {
96 pub fn new() -> Self {
98 Self {
99 fit_intercept: true,
100 alpha: 1.0,
101 max_iter: 1000,
102 tolerance: 1e-6,
103 _problem_type: PhantomData,
104 }
105 }
106
107 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
109 self.fit_intercept = fit_intercept;
110 self
111 }
112
113 pub fn alpha(mut self, alpha: Float) -> Self {
115 self.alpha = alpha;
116 self
117 }
118
119 pub fn max_iter(mut self, max_iter: usize) -> Self {
121 self.max_iter = max_iter;
122 self
123 }
124
125 pub fn tolerance(mut self, tolerance: Float) -> Self {
127 self.tolerance = tolerance;
128 self
129 }
130}
131
impl<const N_FEATURES: usize> TypeSafeLinearModel<Untrained, problem_type::Regression, N_FEATURES> {
    pub fn new_regression() -> Self {
        Self {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: TypeSafeConfig::new(),
        }
    }

    pub fn configure(mut self, config: TypeSafeConfig<problem_type::Regression>) -> Self {
        self.config = config;
        self
    }
}

impl<const N_FEATURES: usize>
    TypeSafeLinearModel<Untrained, problem_type::BinaryClassification, N_FEATURES>
{
    pub fn new_binary_classification() -> Self {
        Self {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: TypeSafeConfig::new(),
        }
    }
}

impl<const N_FEATURES: usize>
    TypeSafeLinearModel<Untrained, problem_type::MultiClassification, N_FEATURES>
{
    pub fn new_multi_classification() -> Self {
        Self {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: TypeSafeConfig::new(),
        }
    }
}

/// Consuming fit that converts an untrained model into its trained form.
pub trait TypeSafeFit<ProblemType, const N_FEATURES: usize> {
    type TrainedModel;

    #[allow(non_snake_case)]
    fn fit_typed(self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Self::TrainedModel>;
}

impl<const N_FEATURES: usize> TypeSafeFit<problem_type::Regression, N_FEATURES>
    for TypeSafeLinearModel<Untrained, problem_type::Regression, N_FEATURES>
{
    type TrainedModel = TypeSafeLinearModel<Trained, problem_type::Regression, N_FEATURES>;

    #[allow(non_snake_case)]
    fn fit_typed(self, X: &Array2<Float>, y: &Array1<Float>) -> Result<Self::TrainedModel> {
        if X.ncols() != N_FEATURES {
            return Err(SklearsError::DimensionMismatch {
                expected: N_FEATURES,
                actual: X.ncols(),
            });
        }

        // Normal equations for ridge regression: (X^T X + alpha * I) w = X^T y.
        let xtx = X.t().dot(X);
        let _xty = X.t().dot(y);

        // Add the L2 penalty to the diagonal of the Gram matrix.
        let mut xtx_reg = xtx;
        for i in 0..N_FEATURES {
            xtx_reg[[i, i]] += self.config.alpha;
        }

        // Placeholder solve: a real implementation would solve the regularized
        // system above; here the coefficients are a fixed stand-in.
        let coefficients = Array1::ones(N_FEATURES) * 0.5;
        let intercept = if self.config.fit_intercept {
            Some(y.mean().unwrap_or(0.0))
        } else {
            None
        };

        Ok(TypeSafeLinearModel {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: Some(coefficients),
            intercept,
            config: self.config,
        })
    }
}

/// Prediction, implemented only for models in the `Trained` state.
pub trait TypeSafePredict<ProblemType, const N_FEATURES: usize> {
    #[allow(non_snake_case)]
    fn predict_typed(&self, X: &Array2<Float>) -> Result<Array1<Float>>;
}

impl<const N_FEATURES: usize> TypeSafePredict<problem_type::Regression, N_FEATURES>
    for TypeSafeLinearModel<Trained, problem_type::Regression, N_FEATURES>
{
    #[allow(non_snake_case)]
    fn predict_typed(&self, X: &Array2<Float>) -> Result<Array1<Float>> {
        if X.ncols() != N_FEATURES {
            return Err(SklearsError::DimensionMismatch {
                expected: N_FEATURES,
                actual: X.ncols(),
            });
        }

        let coefficients = self
            .coefficients
            .as_ref()
            .ok_or_else(|| SklearsError::InvalidOperation("Model is not trained".to_string()))?;

        let mut predictions = X.dot(coefficients);

        if let Some(intercept) = self.intercept {
            predictions += intercept;
        }

        Ok(predictions)
    }
}

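/// A penalty term added to the training objective, with its (sub)gradient.
///
/// A sketch of how a gradient-based solver might combine a data term with a
/// scheme (`data_loss` and `data_grad` are hypothetical helpers):
///
/// ```ignore
/// let loss = data_loss(&w) + scheme.apply_regularization(&w);
/// let grad = data_grad(&w) + scheme.apply_regularization_gradient(&w);
/// ```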
pub trait RegularizationScheme {
    /// Penalty value at the given coefficients.
    fn apply_regularization(&self, coefficients: &Array1<Float>) -> Float;

    /// (Sub)gradient of the penalty at the given coefficients.
    fn apply_regularization_gradient(&self, coefficients: &Array1<Float>) -> Array1<Float>;

    /// Regularization strength `alpha`.
    fn strength(&self) -> Float;
}

/// L2 (ridge) regularization: penalty `0.5 * alpha * ||w||^2`.
#[derive(Debug, Clone)]
pub struct L2Scheme {
    pub alpha: Float,
}

impl RegularizationScheme for L2Scheme {
    fn apply_regularization(&self, coefficients: &Array1<Float>) -> Float {
        0.5 * self.alpha * coefficients.mapv(|x| x * x).sum()
    }

    fn apply_regularization_gradient(&self, coefficients: &Array1<Float>) -> Array1<Float> {
        self.alpha * coefficients
    }

    fn strength(&self) -> Float {
        self.alpha
    }
}

/// L1 (lasso) regularization: penalty `alpha * ||w||_1`.
#[derive(Debug, Clone)]
pub struct L1Scheme {
    pub alpha: Float,
}

impl RegularizationScheme for L1Scheme {
    fn apply_regularization(&self, coefficients: &Array1<Float>) -> Float {
        self.alpha * coefficients.mapv(|x| x.abs()).sum()
    }

    fn apply_regularization_gradient(&self, coefficients: &Array1<Float>) -> Array1<Float> {
        // Subgradient of |x|: sign(x) scaled by alpha, taken as 0 at the kink.
        coefficients.mapv(|x| {
            if x > 0.0 {
                self.alpha
            } else if x < 0.0 {
                -self.alpha
            } else {
                0.0
            }
        })
    }

    fn strength(&self) -> Float {
        self.alpha
    }
}

/// Compile-time description of whether a solver type suits a problem type.
pub trait SolverConstraint<ProblemType> {
    /// Whether the solver can handle this problem type at all.
    fn is_compatible() -> bool;

    /// Human-readable guidance for this solver/problem pairing.
    fn get_recommendations() -> &'static str;

    /// Objective properties the solver relies on.
    fn required_features() -> &'static [&'static str] {
        &[]
    }

    /// Objective properties the solver cannot handle.
    fn incompatible_features() -> &'static [&'static str] {
        &[]
    }
}

/// Static validation for a (solver, problem, regularization) combination.
pub trait ConfigurationValidator<SolverType, ProblemType, RegularizationType> {
    fn validate_config() -> std::result::Result<(), &'static str>;

    fn optimal_hyperparameters() -> ConfigurationHints;
}

/// Suggested hyperparameter settings for a validated configuration.
#[derive(Debug, Clone, Default)]
pub struct ConfigurationHints {
    pub tolerance: Option<Float>,
    pub max_iterations: Option<usize>,
    pub regularization_range: Option<(Float, Float)>,
    pub notes: Vec<&'static str>,
}

/// Compile-time checks and resource estimates tied to the feature count.
pub trait FeatureValidator<const N_FEATURES: usize> {
    fn validate_feature_count() -> std::result::Result<(), SklearsError>;

    fn memory_requirements() -> MemoryRequirements;

    fn computational_complexity() -> ComputationalComplexity;
}

/// Rough memory-footprint estimate for a model of a given size.
#[derive(Debug, Clone)]
pub struct MemoryRequirements {
    pub estimated_bytes: usize,
    pub is_memory_intensive: bool,
    pub optimization_notes: Vec<&'static str>,
}

/// Asymptotic cost estimates for the chosen solver.
#[derive(Debug, Clone)]
pub struct ComputationalComplexity {
    pub time_complexity: &'static str,
    pub space_complexity: &'static str,
    pub is_compute_intensive: bool,
}

/// Compile-time compatibility between a solver type and a penalty type.
pub trait RegularizationConstraint<SolverType, RegularizationType> {
    fn is_solver_compatible() -> bool;

    fn get_solver_recommendations() -> &'static str;

    /// Recommended `(min, max)` range for the regularization strength.
    fn optimal_strength_range() -> (Float, Float);
}

impl SolverConstraint<problem_type::Regression> for solver_capability::SmoothOnly {
    fn is_compatible() -> bool {
        true
    }

    fn get_recommendations() -> &'static str {
        "Gradient descent works well for smooth regression objectives"
    }

    fn required_features() -> &'static [&'static str] {
        &["smooth_objective", "differentiable"]
    }

    fn incompatible_features() -> &'static [&'static str] {
        &["l1_regularization", "non_smooth"]
    }
}

impl SolverConstraint<problem_type::Regression> for solver_capability::NonSmoothCapable {
    fn is_compatible() -> bool {
        true
    }

    fn get_recommendations() -> &'static str {
        "Coordinate descent is ideal for L1-regularized problems"
    }

    fn required_features() -> &'static [&'static str] {
        &["separable_objective"]
    }

    fn incompatible_features() -> &'static [&'static str] {
        &[]
    }
}

impl ConfigurationValidator<solver_capability::SmoothOnly, problem_type::Regression, L2Scheme>
    for ()
{
    fn validate_config() -> std::result::Result<(), &'static str> {
        Ok(())
    }

    fn optimal_hyperparameters() -> ConfigurationHints {
        ConfigurationHints {
            tolerance: Some(1e-6),
            max_iterations: Some(1000),
            regularization_range: Some((1e-4, 1e2)),
            notes: vec![
                "Use line search for better convergence",
                "Consider preconditioning for ill-conditioned problems",
            ],
        }
    }
}

impl ConfigurationValidator<solver_capability::NonSmoothCapable, problem_type::Regression, L1Scheme>
    for ()
{
    fn validate_config() -> std::result::Result<(), &'static str> {
        Ok(())
    }

    fn optimal_hyperparameters() -> ConfigurationHints {
        ConfigurationHints {
            tolerance: Some(1e-4),
            max_iterations: Some(10000),
            regularization_range: Some((1e-6, 1e1)),
            notes: vec![
                "Use coordinate descent for efficiency",
                "Consider warm starts for regularization path",
            ],
        }
    }
}

impl<const N_FEATURES: usize> FeatureValidator<N_FEATURES> for () {
    fn validate_feature_count() -> std::result::Result<(), SklearsError> {
        if N_FEATURES == 0 {
            Err(SklearsError::InvalidOperation(
                "Feature count must be greater than 0".to_string(),
            ))
        } else if N_FEATURES > 100_000 {
            Err(SklearsError::InvalidOperation(
                "Feature count too large - consider dimensionality reduction".to_string(),
            ))
        } else {
            Ok(())
        }
    }

    fn memory_requirements() -> MemoryRequirements {
        let bytes_per_feature = std::mem::size_of::<Float>();
        let coefficient_memory = N_FEATURES * bytes_per_feature;
        let gram_matrix_memory = N_FEATURES * N_FEATURES * bytes_per_feature;

        let total_memory = coefficient_memory + gram_matrix_memory;
        // Heuristic threshold: flag anything over ~1 MB as memory-intensive.
        let is_memory_intensive = total_memory > 1_000_000;

        let optimization_notes = if is_memory_intensive {
            vec![
                "Consider using sparse matrices",
                "Use iterative solvers to avoid Gram matrix",
            ]
        } else {
            vec!["Memory usage is reasonable"]
        };

        MemoryRequirements {
            estimated_bytes: total_memory,
            is_memory_intensive,
            optimization_notes,
        }
    }

    fn computational_complexity() -> ComputationalComplexity {
        let is_compute_intensive = N_FEATURES > 10_000;

        ComputationalComplexity {
            time_complexity: "O(n*p^2)",
            space_complexity: "O(p^2)",
            is_compute_intensive,
        }
    }
}

impl RegularizationConstraint<solver_capability::SmoothOnly, L2Scheme> for () {
    fn is_solver_compatible() -> bool {
        true
    }

    fn get_solver_recommendations() -> &'static str {
        "L2 regularization is smooth and works well with gradient-based methods"
    }

    fn optimal_strength_range() -> (Float, Float) {
        (1e-4, 1e2)
    }
}

impl RegularizationConstraint<solver_capability::SmoothOnly, L1Scheme> for () {
    fn is_solver_compatible() -> bool {
        false
    }

    fn get_solver_recommendations() -> &'static str {
        "L1 regularization is non-smooth and requires specialized solvers like coordinate descent"
    }

    fn optimal_strength_range() -> (Float, Float) {
        // Incompatible combination: there is no valid strength range.
        (0.0, 0.0)
    }
}

impl RegularizationConstraint<solver_capability::NonSmoothCapable, L1Scheme> for () {
    fn is_solver_compatible() -> bool {
        true
    }

    fn get_solver_recommendations() -> &'static str {
        "L1 regularization works excellently with coordinate descent and proximal methods"
    }

    fn optimal_strength_range() -> (Float, Float) {
        (1e-6, 1e1)
    }
}

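/// Zero-sized selector that asserts solver/problem compatibility on
/// construction. A sketch:
///
/// ```ignore
/// let selector = TypeSafeSolverSelector::<
///     solver_capability::NonSmoothCapable,
///     problem_type::Regression,
/// >::new();
/// println!("{}", selector.recommendations());
/// ```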
pub struct TypeSafeSolverSelector<SolverType, ProblemType> {
    _solver_type: PhantomData<SolverType>,
    _problem_type: PhantomData<ProblemType>,
}

impl<SolverType, ProblemType> Default for TypeSafeSolverSelector<SolverType, ProblemType>
where
    SolverType: SolverConstraint<ProblemType>,
{
    fn default() -> Self {
        Self::new()
    }
}

impl<SolverType, ProblemType> TypeSafeSolverSelector<SolverType, ProblemType>
where
    SolverType: SolverConstraint<ProblemType>,
{
    /// Panics if the solver is not compatible with the problem type.
    pub fn new() -> Self {
        assert!(
            SolverType::is_compatible(),
            "Solver not compatible with problem type"
        );

        Self {
            _solver_type: PhantomData,
            _problem_type: PhantomData,
        }
    }

    pub fn recommendations(&self) -> &'static str {
        SolverType::get_recommendations()
    }
}

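/// Stack-allocated vector and matrix helpers with lengths fixed at compile
/// time, so dimension mismatches become type errors. A sketch:
///
/// ```ignore
/// let a = [1.0, 2.0, 3.0];
/// let b = [4.0, 5.0, 6.0];
/// let dot = FixedSizeOps::<3>::dot_product(&a, &b); // 32.0
/// ```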
pub struct FixedSizeOps<const N: usize>;

impl<const N: usize> FixedSizeOps<N> {
    pub fn dot_product(a: &[Float; N], b: &[Float; N]) -> Float {
        let mut sum = 0.0;
        for i in 0..N {
            sum += a[i] * b[i];
        }
        sum
    }

    pub fn matrix_vector_multiply<const M: usize>(
        matrix: &[[Float; N]; M],
        vector: &[Float; N],
    ) -> [Float; M] {
        let mut result = [0.0; M];
        for i in 0..M {
            result[i] = Self::dot_product(&matrix[i], vector);
        }
        result
    }

    pub fn l2_norm(vector: &[Float; N]) -> Float {
        Self::dot_product(vector, vector).sqrt()
    }

    /// Scales `vector` to unit L2 norm; zero vectors are left unchanged.
    pub fn normalize(vector: &mut [Float; N]) {
        let norm = Self::l2_norm(vector);
        if norm > 0.0 {
            for elem in vector.iter_mut() {
                *elem /= norm;
            }
        }
    }
}

/// Regression model with 10 features.
pub type SmallLinearRegression = TypeSafeLinearModel<Untrained, problem_type::Regression, 10>;
/// Regression model with 100 features.
pub type MediumLinearRegression = TypeSafeLinearModel<Untrained, problem_type::Regression, 100>;
/// Regression model with 1000 features.
pub type LargeLinearRegression = TypeSafeLinearModel<Untrained, problem_type::Regression, 1000>;

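/// Builder for typed models; currently specialized for regression. A sketch:
///
/// ```ignore
/// let model = TypeSafeModelBuilder::<problem_type::Regression, 5>::new_regression()
///     .with_l2_regularization(0.1)
///     .build();
/// ```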
#[derive(Debug)]
pub struct TypeSafeModelBuilder<ProblemType, const N_FEATURES: usize> {
    config: TypeSafeConfig<ProblemType>,
}

impl<const N_FEATURES: usize> TypeSafeModelBuilder<problem_type::Regression, N_FEATURES> {
    pub fn new_regression() -> Self {
        Self {
            config: TypeSafeConfig::new(),
        }
    }

    pub fn with_l2_regularization(mut self, alpha: Float) -> Self {
        self.config.alpha = alpha;
        self
    }

    pub fn build(self) -> TypeSafeLinearModel<Untrained, problem_type::Regression, N_FEATURES> {
        TypeSafeLinearModel {
            _state: PhantomData,
            _problem_type: PhantomData,
            coefficients: None,
            intercept: None,
            config: self.config,
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    fn test_type_safe_model_creation() {
        let model: SmallLinearRegression = TypeSafeLinearModel::new_regression();
        assert!(std::mem::size_of_val(&model) > 0);
    }

    #[test]
    fn test_type_safe_config() {
        let config: TypeSafeConfig<problem_type::Regression> = TypeSafeConfig::new()
            .fit_intercept(true)
            .alpha(0.1)
            .max_iter(500)
            .tolerance(1e-8);

        assert!(config.fit_intercept);
        assert_eq!(config.alpha, 0.1);
        assert_eq!(config.max_iter, 500);
        assert_eq!(config.tolerance, 1e-8);
    }

    #[test]
    fn test_fixed_size_operations() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];

        let dot = FixedSizeOps::<3>::dot_product(&a, &b);
        assert_eq!(dot, 32.0); // 1*4 + 2*5 + 3*6 = 32

        let norm = FixedSizeOps::<3>::l2_norm(&a);
        assert!((norm - (14.0_f64).sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_matrix_vector_multiply() {
        let matrix = [[1.0, 2.0], [3.0, 4.0]];
        let vector = [5.0, 6.0];

        let result = FixedSizeOps::<2>::matrix_vector_multiply(&matrix, &vector);
        assert_eq!(result, [17.0, 39.0]); // [1*5 + 2*6, 3*5 + 4*6]
    }

    #[test]
    fn test_regularization_schemes() {
        let coefficients = Array::from_vec(vec![1.0, -2.0, 3.0]);

        let l2_scheme = L2Scheme { alpha: 0.5 };
        let l2_penalty = l2_scheme.apply_regularization(&coefficients);
        let expected_l2 = 0.5 * 0.5 * (1.0 + 4.0 + 9.0);
        assert!((l2_penalty - expected_l2).abs() < 1e-10);

        let l1_scheme = L1Scheme { alpha: 0.3 };
        let l1_penalty = l1_scheme.apply_regularization(&coefficients);
        let expected_l1 = 0.3 * (1.0 + 2.0 + 3.0);
        assert!((l1_penalty - expected_l1).abs() < 1e-10);
    }

    #[test]
    fn test_solver_selector() {
        let _selector: TypeSafeSolverSelector<
            solver_capability::SmoothOnly,
            problem_type::Regression,
        > = TypeSafeSolverSelector::new();

        // Construction succeeding is the assertion here: `new` panics on an
        // incompatible solver/problem pairing.
    }

    #[test]
    fn test_type_safe_builder() {
        let model: TypeSafeLinearModel<Untrained, problem_type::Regression, 5> =
            TypeSafeModelBuilder::new_regression()
                .with_l2_regularization(0.1)
                .build();

        assert_eq!(model.config.alpha, 0.1);
    }
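
    // A minimal end-to-end smoke test of the typestate transition
    // (Untrained -> fit_typed -> Trained -> predict_typed). The fit uses this
    // module's placeholder solver, so only the shape and availability of the
    // predictions are asserted, not their values.
    #[test]
    fn test_fit_predict_smoke() {
        let X = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .unwrap();
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let model: TypeSafeLinearModel<Untrained, problem_type::Regression, 2> =
            TypeSafeLinearModel::new_regression();
        let trained = model.fit_typed(&X, &y).unwrap();
        let predictions = trained.predict_typed(&X).unwrap();

        assert_eq!(predictions.len(), 4);
    }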

    #[test]
    fn test_normalization() {
        let mut vector = [3.0, 4.0, 0.0];
        FixedSizeOps::<3>::normalize(&mut vector);

        let norm = FixedSizeOps::<3>::l2_norm(&vector);
        assert!((norm - 1.0).abs() < 1e-10);
    }
}