use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::fmt::Debug;

/// An optimization objective: a scalar function of the coefficient vector,
/// evaluated against training data.
pub trait Objective: Debug + Send + Sync {
    /// Evaluates the objective at the given coefficients.
    fn value(&self, coefficients: &Array1<Float>, data: &ObjectiveData) -> Result<Float>;

    /// Computes the gradient of the objective with respect to the coefficients.
    fn gradient(&self, coefficients: &Array1<Float>, data: &ObjectiveData)
        -> Result<Array1<Float>>;

    /// Computes the value and gradient together; override when a fused
    /// evaluation is cheaper than two separate calls.
    fn value_and_gradient(
        &self,
        coefficients: &Array1<Float>,
        data: &ObjectiveData,
    ) -> Result<(Float, Array1<Float>)> {
        let value = self.value(coefficients, data)?;
        let gradient = self.gradient(coefficients, data)?;
        Ok((value, gradient))
    }

    /// Whether this objective can provide a Hessian.
    fn supports_hessian(&self) -> bool {
        false
    }

    /// Computes the Hessian matrix; errors unless `supports_hessian` is true.
    fn hessian(
        &self,
        _coefficients: &Array1<Float>,
        _data: &ObjectiveData,
    ) -> Result<Array2<Float>> {
        Err(SklearsError::InvalidOperation(
            "Hessian computation not supported for this objective".to_string(),
        ))
    }
}
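
// Illustrative sketch (not part of the framework): a hand-rolled `Objective`
// for plain least squares, f(w) = 0.5 * ||Xw - y||^2, with gradient
// X^T (Xw - y). The type name is hypothetical and exists only as an example
// of the trait contract above.
#[derive(Debug)]
#[allow(dead_code)]
struct ExampleLeastSquares;

impl Objective for ExampleLeastSquares {
    fn value(&self, coefficients: &Array1<Float>, data: &ObjectiveData) -> Result<Float> {
        let residuals = data.features.dot(coefficients) - &data.targets;
        Ok(0.5 * residuals.mapv(|r| r * r).sum())
    }

    fn gradient(
        &self,
        coefficients: &Array1<Float>,
        data: &ObjectiveData,
    ) -> Result<Array1<Float>> {
        let residuals = data.features.dot(coefficients) - &data.targets;
        Ok(data.features.t().dot(&residuals))
    }
}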

/// Training data and metadata consumed by an [`Objective`].
#[derive(Debug, Clone)]
pub struct ObjectiveData {
    /// Feature matrix of shape `(n_samples, n_features)`.
    pub features: Array2<Float>,
    /// Target vector of length `n_samples`.
    pub targets: Array1<Float>,
    /// Optional per-sample weights.
    pub sample_weights: Option<Array1<Float>>,
    /// Additional fitting metadata.
    pub metadata: ObjectiveMetadata,
}

/// Metadata describing how the data was prepared for fitting.
#[derive(Debug, Clone, Default)]
pub struct ObjectiveMetadata {
    /// Whether an intercept term should be fitted.
    pub fit_intercept: bool,
    /// Optional per-feature scaling factors.
    pub feature_scale: Option<Array1<Float>>,
    /// Optional target scaling factor.
    pub target_scale: Option<Float>,
}
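
// Sketch of assembling an `ObjectiveData` by hand; shapes and zero-filled
// values are placeholders, and this helper is illustrative rather than
// framework API.
#[allow(dead_code)]
fn example_objective_data() -> ObjectiveData {
    ObjectiveData {
        features: Array2::zeros((4, 2)),
        targets: Array1::zeros(4),
        sample_weights: None,
        metadata: ObjectiveMetadata::default(),
    }
}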

/// A pointwise loss comparing targets with predictions.
pub trait LossFunction: Debug + Send + Sync {
    /// Computes the scalar loss between true and predicted values.
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float>;

    /// Computes the derivative of the loss with respect to the predictions.
    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>>;

    /// Computes loss and derivative together; override for fused evaluation.
    fn loss_and_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<(Float, Array1<Float>)> {
        let loss = self.loss(y_true, y_pred)?;
        let derivative = self.loss_derivative(y_true, y_pred)?;
        Ok((loss, derivative))
    }

    /// Whether this loss is intended for classification tasks.
    fn is_classification(&self) -> bool {
        false
    }

    /// A short, stable identifier for this loss.
    fn name(&self) -> &'static str;
}
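
// Illustrative `LossFunction` sketch (hypothetical type, not framework API):
// mean absolute error, with the (sub)gradient sign(y_pred - y_true) / n taken
// with respect to the predictions.
#[derive(Debug)]
#[allow(dead_code)]
struct ExampleAbsoluteLoss;

impl LossFunction for ExampleAbsoluteLoss {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        Ok((y_true - y_pred).mapv(|r| r.abs()).sum() / y_true.len() as Float)
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        Ok((y_pred - y_true).mapv(|r| r.signum()) / (y_true.len() as Float))
    }

    fn name(&self) -> &'static str {
        "ExampleAbsoluteLoss"
    }
}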

/// A penalty term added to the objective to regularize the coefficients.
pub trait Regularization: Debug + Send + Sync {
    /// Computes the penalty value for the given coefficients.
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float>;

    /// Computes the gradient (or a subgradient) of the penalty.
    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>>;

    /// Applies the proximal operator for this penalty. The default is the
    /// identity, which is appropriate only for smooth penalties handled via
    /// `penalty_gradient`; non-smooth penalties should override it.
    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        _step_size: Float,
    ) -> Result<Array1<Float>> {
        Ok(coefficients.clone())
    }

    /// Whether the penalty is non-smooth (e.g. L1) and needs proximal methods.
    fn is_non_smooth(&self) -> bool {
        false
    }

    /// The regularization strength.
    fn strength(&self) -> Float;

    /// A short, stable identifier for this penalty.
    fn name(&self) -> &'static str;
}
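
// Illustrative non-smooth penalty (hypothetical type, not framework API):
// L1 regularization with the soft-thresholding proximal operator
// prox_{t*alpha}(w)_j = sign(w_j) * max(|w_j| - t*alpha, 0).
#[derive(Debug)]
#[allow(dead_code)]
struct ExampleL1 {
    alpha: Float,
}

impl Regularization for ExampleL1 {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        Ok(self.alpha * coefficients.mapv(|w| w.abs()).sum())
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        // sign(w) is one valid subgradient choice for |w|.
        Ok(coefficients.mapv(|w| self.alpha * w.signum()))
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        let threshold = self.alpha * step_size;
        Ok(coefficients.mapv(|w| w.signum() * (w.abs() - threshold).max(0.0)))
    }

    fn is_non_smooth(&self) -> bool {
        true
    }

    fn strength(&self) -> Float {
        self.alpha
    }

    fn name(&self) -> &'static str {
        "ExampleL1"
    }
}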

/// A solver that minimizes an [`Objective`] starting from an initial guess.
pub trait OptimizationSolver: Debug + Send + Sync {
    /// Solver-specific configuration.
    type Config: Debug + Clone + Send + Sync;

    /// Solver-specific result type.
    type Result: Debug + Clone + Send + Sync;

    /// Runs the solver on the given objective.
    fn solve(
        &self,
        objective: &dyn Objective,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result>;

    /// Whether this solver can handle the given objective.
    fn supports_objective(&self, objective: &dyn Objective) -> bool;

    /// A short, stable identifier for this solver.
    fn name(&self) -> &'static str;

    /// Returns data-dependent hyperparameter recommendations, if any.
    fn get_recommendations(&self, _data: &ObjectiveData) -> SolverRecommendations {
        SolverRecommendations::default()
    }
}
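
// Sketch of a solver implementing this trait: fixed-step gradient descent.
// Note that `solve` receives no `ObjectiveData`, so this illustrative solver
// carries the data itself. The type is hypothetical, not framework API; the
// step size and stopping rule are deliberately naive.
#[derive(Debug)]
#[allow(dead_code)]
struct ExampleGradientDescent {
    data: ObjectiveData,
    step_size: Float,
}

impl OptimizationSolver for ExampleGradientDescent {
    type Config = ModularConfig;
    type Result = OptimizationResult;

    fn solve(
        &self,
        objective: &dyn Objective,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result> {
        let mut coefficients = initial_guess.clone();
        let mut converged = false;
        let mut n_iterations = 0;
        let mut objective_value = objective.value(&coefficients, &self.data)?;

        for _ in 0..config.max_iterations {
            let gradient = objective.gradient(&coefficients, &self.data)?;
            n_iterations += 1;
            // Stop once the gradient norm drops below the tolerance.
            if gradient.mapv(|g| g * g).sum().sqrt() < config.tolerance {
                converged = true;
                break;
            }
            coefficients = coefficients - self.step_size * &gradient;
            objective_value = objective.value(&coefficients, &self.data)?;
        }

        Ok(OptimizationResult {
            coefficients,
            intercept: None,
            objective_value,
            n_iterations,
            converged,
            solver_info: SolverInfo {
                solver_name: self.name().to_string(),
                metrics: std::collections::HashMap::new(),
                warnings: Vec::new(),
                convergence_history: None,
            },
        })
    }

    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "ExampleGradientDescent"
    }
}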

/// Hyperparameter hints a solver can suggest for a given dataset.
#[derive(Debug, Clone, Default)]
pub struct SolverRecommendations {
    /// Suggested iteration cap.
    pub max_iterations: Option<usize>,
    /// Suggested convergence tolerance.
    pub tolerance: Option<Float>,
    /// Suggested step size.
    pub step_size: Option<Float>,
    /// Whether a line search is advisable.
    pub use_line_search: Option<bool>,
    /// Free-form notes for the caller.
    pub notes: Vec<String>,
}

/// Produces predictions from a fitted coefficient vector.
pub trait PredictionProvider: Debug + Send + Sync {
    /// Computes point predictions for the given features.
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>>;

    /// Computes predictions with confidence bounds. The default returns point
    /// predictions with no bounds; providers that support intervals override it.
    fn predict_with_confidence(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
        confidence_level: Float,
    ) -> Result<PredictionWithConfidence> {
        let predictions = self.predict(features, coefficients, intercept)?;
        Ok(PredictionWithConfidence {
            predictions,
            lower_bounds: None,
            upper_bounds: None,
            confidence_level,
        })
    }

    /// Computes predictions with uncertainty estimates. The default returns
    /// point predictions with no uncertainties.
    fn predict_with_uncertainty(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<PredictionWithUncertainty> {
        let predictions = self.predict(features, coefficients, intercept)?;
        Ok(PredictionWithUncertainty {
            predictions,
            uncertainties: None,
            prediction_intervals: None,
        })
    }

    /// Whether this provider can produce confidence intervals.
    fn supports_confidence_intervals(&self) -> bool {
        false
    }

    /// Whether this provider can quantify predictive uncertainty.
    fn supports_uncertainty_quantification(&self) -> bool {
        false
    }

    /// A short, stable identifier for this provider.
    fn name(&self) -> &'static str;
}

/// Point predictions plus optional confidence bounds.
#[derive(Debug, Clone)]
pub struct PredictionWithConfidence {
    /// Point predictions.
    pub predictions: Array1<Float>,
    /// Lower confidence bounds, if available.
    pub lower_bounds: Option<Array1<Float>>,
    /// Upper confidence bounds, if available.
    pub upper_bounds: Option<Array1<Float>>,
    /// The confidence level the bounds correspond to.
    pub confidence_level: Float,
}

/// Point predictions plus optional uncertainty estimates.
#[derive(Debug, Clone)]
pub struct PredictionWithUncertainty {
    /// Point predictions.
    pub predictions: Array1<Float>,
    /// Per-sample standard deviations, if available.
    pub uncertainties: Option<Array1<Float>>,
    /// Prediction intervals as an `(n_samples, 2)` matrix of lower and upper
    /// bounds, if available.
    pub prediction_intervals: Option<Array2<Float>>,
}

/// Plain linear predictions: `X · w` plus an optional intercept.
#[derive(Debug)]
pub struct LinearPredictionProvider;

impl PredictionProvider for LinearPredictionProvider {
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>> {
        let mut predictions = features.dot(coefficients);
        if let Some(intercept_val) = intercept {
            predictions += intercept_val;
        }
        Ok(predictions)
    }

    fn name(&self) -> &'static str {
        "LinearPrediction"
    }
}

/// Logistic predictions: the linear score passed through a sigmoid.
#[derive(Debug)]
pub struct ProbabilisticPredictionProvider;

impl PredictionProvider for ProbabilisticPredictionProvider {
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>> {
        let linear_predictions =
            LinearPredictionProvider.predict(features, coefficients, intercept)?;
        let probabilities = linear_predictions.mapv(|x| 1.0 / (1.0 + (-x).exp()));
        Ok(probabilities)
    }

    fn supports_confidence_intervals(&self) -> bool {
        true
    }

    fn predict_with_confidence(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
        confidence_level: Float,
    ) -> Result<PredictionWithConfidence> {
        let predictions = self.predict(features, coefficients, intercept)?;

        let _linear_predictions =
            LinearPredictionProvider.predict(features, coefficients, intercept)?;
        // Placeholder: unit variance for every sample. A real implementation
        // would derive per-sample variances from the model's covariance.
        let variances = Array1::ones(features.nrows());
        // Standard normal quantiles for common confidence levels; anything
        // else falls back to the 95% quantile. (Comparisons are used instead
        // of a `match`, since float-literal patterns are deprecated.)
        let z_score = if (confidence_level - 0.90).abs() < 1e-12 {
            1.645
        } else if (confidence_level - 0.99).abs() < 1e-12 {
            2.576
        } else {
            1.96
        };

        let margins = variances.mapv(|v: Float| z_score * v.sqrt());
        let lower_bounds = &predictions - &margins;
        let upper_bounds = &predictions + &margins;

        Ok(PredictionWithConfidence {
            predictions,
            lower_bounds: Some(lower_bounds),
            upper_bounds: Some(upper_bounds),
            confidence_level,
        })
    }

    fn name(&self) -> &'static str {
        "ProbabilisticPrediction"
    }
}

/// Linear predictions with Bayesian uncertainty from a posterior covariance.
#[derive(Debug)]
pub struct BayesianPredictionProvider {
    /// Posterior covariance of the coefficients, if available.
    pub posterior_covariance: Option<Array2<Float>>,
}

impl BayesianPredictionProvider {
    pub fn new(posterior_covariance: Option<Array2<Float>>) -> Self {
        Self {
            posterior_covariance,
        }
    }
}

impl PredictionProvider for BayesianPredictionProvider {
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>> {
        LinearPredictionProvider.predict(features, coefficients, intercept)
    }

    fn supports_uncertainty_quantification(&self) -> bool {
        self.posterior_covariance.is_some()
    }

    fn predict_with_uncertainty(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<PredictionWithUncertainty> {
        let predictions = self.predict(features, coefficients, intercept)?;

        if let Some(ref cov) = self.posterior_covariance {
            let mut uncertainties = Array1::zeros(features.nrows());

            // Predictive standard deviation per sample: sqrt(x_i^T Σ x_i).
            for (i, row) in features.axis_iter(Axis(0)).enumerate() {
                let variance = row.dot(&cov.dot(&row));
                uncertainties[i] = variance.sqrt();
            }

            // 95% interval using the standard normal quantile.
            let z_score = 1.96;
            let lower_bounds = &predictions - z_score * &uncertainties;
            let upper_bounds = &predictions + z_score * &uncertainties;

            let mut prediction_intervals = Array2::zeros((features.nrows(), 2));
            prediction_intervals.column_mut(0).assign(&lower_bounds);
            prediction_intervals.column_mut(1).assign(&upper_bounds);

            Ok(PredictionWithUncertainty {
                predictions,
                uncertainties: Some(uncertainties),
                prediction_intervals: Some(prediction_intervals),
            })
        } else {
            Ok(PredictionWithUncertainty {
                predictions,
                uncertainties: None,
                prediction_intervals: None,
            })
        }
    }

    fn name(&self) -> &'static str {
        "BayesianPrediction"
    }
}
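
// Usage sketch: with an identity posterior covariance, the per-sample
// predictive standard deviation reduces to the row norm ||x_i||, e.g.
// sqrt(3^2 + 4^2) = 5 below. Illustrative helper, not framework API.
#[allow(dead_code)]
fn example_bayesian_uncertainty() -> Result<PredictionWithUncertainty> {
    let provider = BayesianPredictionProvider::new(Some(Array2::eye(2)));
    // The shape is statically correct, so this unwrap cannot fail.
    let features = Array2::from_shape_vec((1, 2), vec![3.0, 4.0]).unwrap();
    let coefficients = Array1::from_vec(vec![1.0, 1.0]);
    provider.predict_with_uncertainty(&features, &coefficients, None)
}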

/// Shared configuration for the modular framework.
#[derive(Debug, Clone)]
pub struct ModularConfig {
    /// Maximum number of solver iterations.
    pub max_iterations: usize,
    /// Convergence tolerance.
    pub tolerance: Float,
    /// Whether to emit progress output.
    pub verbose: bool,
    /// Optional seed for stochastic solvers.
    pub random_seed: Option<u64>,
}

impl Default for ModularConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            verbose: false,
            random_seed: None,
        }
    }
}

/// The outcome of a solver run, normalized across solvers.
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Fitted coefficient vector.
    pub coefficients: Array1<Float>,
    /// Fitted intercept, if one was estimated.
    pub intercept: Option<Float>,
    /// Final objective value.
    pub objective_value: Float,
    /// Number of iterations performed.
    pub n_iterations: usize,
    /// Whether the solver reported convergence.
    pub converged: bool,
    /// Diagnostic information from the solver.
    pub solver_info: SolverInfo,
}

/// Diagnostics attached to an [`OptimizationResult`].
#[derive(Debug, Clone)]
pub struct SolverInfo {
    /// Name of the solver that produced the result.
    pub solver_name: String,
    /// Solver-specific numeric metrics.
    pub metrics: std::collections::HashMap<String, Float>,
    /// Warnings raised during solving.
    pub warnings: Vec<String>,
    /// Objective values per iteration, if tracked.
    pub convergence_history: Option<Array1<Float>>,
}

/// Orchestrates loss, regularization, and solver into a single fit.
#[derive(Debug)]
pub struct ModularFramework {
    config: ModularConfig,
}

impl ModularFramework {
    /// Creates a framework with the default configuration.
    pub fn new() -> Self {
        Self {
            config: ModularConfig::default(),
        }
    }

    /// Creates a framework with a custom configuration.
    pub fn with_config(config: ModularConfig) -> Self {
        Self { config }
    }

    /// Composes the loss and regularization into an objective and runs the
    /// solver on it, starting from `initial_guess` (or zeros).
    pub fn solve<S: OptimizationSolver + ?Sized>(
        &self,
        loss: &dyn LossFunction,
        regularization: Option<&dyn Regularization>,
        solver: &S,
        data: &ObjectiveData,
        initial_guess: Option<&Array1<Float>>,
    ) -> Result<OptimizationResult> {
        let objective = CompositeObjective::new(loss, regularization);

        let n_features = data.features.ncols();
        let init = initial_guess
            .cloned()
            .unwrap_or_else(|| Array1::zeros(n_features));

        let solver_config = self.create_solver_config::<S>(&objective, data)?;

        let solver_result = solver.solve(&objective, &init, &solver_config)?;

        self.convert_result::<S>(solver_result, &objective, data)
    }

    /// Builds a solver-specific config. Currently a placeholder: per-solver
    /// conversions are not yet implemented, so this always errors.
    fn create_solver_config<S: OptimizationSolver + ?Sized>(
        &self,
        _objective: &dyn Objective,
        _data: &ObjectiveData,
    ) -> Result<S::Config> {
        let solver_name = std::any::type_name::<S>();

        Err(SklearsError::InvalidOperation(format!(
            "Config conversion not implemented for solver: {}",
            solver_name
        )))
    }

    /// Normalizes a solver-specific result into an [`OptimizationResult`].
    /// Currently a placeholder that always errors.
    fn convert_result<S: OptimizationSolver + ?Sized>(
        &self,
        _solver_result: S::Result,
        _objective: &dyn Objective,
        _data: &ObjectiveData,
    ) -> Result<OptimizationResult> {
        let result_type = std::any::type_name::<S::Result>();

        Err(SklearsError::InvalidOperation(format!(
            "Result conversion not implemented for result type: {}",
            result_type
        )))
    }
}

impl Default for ModularFramework {
    fn default() -> Self {
        Self::new()
    }
}
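
// Call-shape sketch for `ModularFramework::solve` (hypothetical helper, not
// framework API). With the placeholder `create_solver_config` above, this
// currently returns an error rather than a fitted result.
#[allow(dead_code)]
fn example_framework_call<S: OptimizationSolver>(
    loss: &dyn LossFunction,
    solver: &S,
    data: &ObjectiveData,
) -> Result<OptimizationResult> {
    ModularFramework::new().solve(loss, None, solver, data, None)
}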

/// An [`Objective`] composed of a loss term plus an optional penalty term.
#[derive(Debug)]
pub struct CompositeObjective<'a> {
    loss: &'a dyn LossFunction,
    regularization: Option<&'a dyn Regularization>,
}

impl<'a> CompositeObjective<'a> {
    pub fn new(loss: &'a dyn LossFunction, regularization: Option<&'a dyn Regularization>) -> Self {
        Self {
            loss,
            regularization,
        }
    }
}

impl<'a> Objective for CompositeObjective<'a> {
    fn value(&self, coefficients: &Array1<Float>, data: &ObjectiveData) -> Result<Float> {
        let predictions = data.features.dot(coefficients);

        let loss_value = self.loss.loss(&data.targets, &predictions)?;

        let regularization_value = if let Some(reg) = self.regularization {
            reg.penalty(coefficients)?
        } else {
            0.0
        };

        Ok(loss_value + regularization_value)
    }

    fn gradient(
        &self,
        coefficients: &Array1<Float>,
        data: &ObjectiveData,
    ) -> Result<Array1<Float>> {
        let predictions = data.features.dot(coefficients);

        // Chain rule: d/dw loss(y, Xw) = X^T * (dloss/dpred).
        let loss_grad_pred = self.loss.loss_derivative(&data.targets, &predictions)?;

        let mut gradient = data.features.t().dot(&loss_grad_pred);

        if let Some(reg) = self.regularization {
            let reg_grad = reg.penalty_gradient(coefficients)?;
            gradient = gradient + reg_grad;
        }

        Ok(gradient)
    }

    fn supports_hessian(&self) -> bool {
        false
    }
}
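
// Sketch of a central finite-difference check for composite gradients,
// g_j ≈ (f(w + h e_j) - f(w - h e_j)) / (2h). Handy when wiring up new
// loss/regularization pairs; a free helper function, not framework API.
#[allow(dead_code)]
fn finite_difference_gradient(
    objective: &dyn Objective,
    coefficients: &Array1<Float>,
    data: &ObjectiveData,
    h: Float,
) -> Result<Array1<Float>> {
    let mut grad = Array1::zeros(coefficients.len());
    for j in 0..coefficients.len() {
        let mut plus = coefficients.clone();
        let mut minus = coefficients.clone();
        plus[j] += h;
        minus[j] -= h;
        grad[j] = (objective.value(&plus, data)? - objective.value(&minus, data)?) / (2.0 * h);
    }
    Ok(grad)
}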

/// Builds a [`ModularLinearModel`] from boxed components.
pub fn create_modular_linear_regression(
    loss: Box<dyn LossFunction>,
    regularization: Option<Box<dyn Regularization>>,
    solver: Box<dyn OptimizationSolver<Config = ModularConfig, Result = OptimizationResult>>,
) -> ModularLinearModel {
    ModularLinearModel {
        loss,
        regularization,
        solver,
        framework: ModularFramework::new(),
    }
}

/// A linear model assembled from interchangeable loss, regularization, and
/// solver components.
#[derive(Debug)]
pub struct ModularLinearModel {
    loss: Box<dyn LossFunction>,
    regularization: Option<Box<dyn Regularization>>,
    solver: Box<dyn OptimizationSolver<Config = ModularConfig, Result = OptimizationResult>>,
    framework: ModularFramework,
}

impl ModularLinearModel {
    /// Fits the model to the given features and targets.
    #[allow(non_snake_case)]
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<OptimizationResult> {
        let data = ObjectiveData {
            features: X.clone(),
            targets: y.clone(),
            sample_weights: None,
            metadata: ObjectiveMetadata::default(),
        };

        self.framework.solve(
            self.loss.as_ref(),
            self.regularization.as_deref(),
            self.solver.as_ref(),
            &data,
            None,
        )
    }

    /// Predicts targets for new features using a fitted result.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &Array2<Float>, result: &OptimizationResult) -> Result<Array1<Float>> {
        let predictions = X.dot(&result.coefficients);

        if let Some(intercept) = result.intercept {
            Ok(predictions + intercept)
        } else {
            Ok(predictions)
        }
    }
}
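
// Call-shape sketch for the assembled model (hypothetical helper, not
// framework API): fit, then predict with the fitted result. With the current
// placeholder conversions in `ModularFramework`, `fit` returns an error.
#[allow(dead_code)]
fn example_model_usage(
    model: &ModularLinearModel,
    x: &Array2<Float>,
    y: &Array1<Float>,
) -> Result<Array1<Float>> {
    let fitted = model.fit(x, y)?;
    model.predict(x, &fitted)
}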

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[derive(Debug)]
    struct DummyLoss;

    impl LossFunction for DummyLoss {
        fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
            Ok(((y_true - y_pred).mapv(|x| x * x)).sum() / (2.0 * y_true.len() as Float))
        }

        fn loss_derivative(
            &self,
            y_true: &Array1<Float>,
            y_pred: &Array1<Float>,
        ) -> Result<Array1<Float>> {
            Ok((y_pred - y_true) / (y_true.len() as Float))
        }

        fn name(&self) -> &'static str {
            "SquaredLoss"
        }
    }

    #[derive(Debug)]
    struct DummyRegularization {
        alpha: Float,
    }

    impl Regularization for DummyRegularization {
        fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
            Ok(0.5 * self.alpha * coefficients.mapv(|x| x * x).sum())
        }

        fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
            Ok(self.alpha * coefficients)
        }

        fn strength(&self) -> Float {
            self.alpha
        }

        fn name(&self) -> &'static str {
            "L2Regularization"
        }
    }

    #[test]
    fn test_composite_objective() {
        let loss = DummyLoss;
        let regularization = DummyRegularization { alpha: 0.1 };
        let objective = CompositeObjective::new(&loss, Some(&regularization));

        let data = ObjectiveData {
            features: Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
            targets: Array::from_vec(vec![1.0, 2.0, 3.0]),
            sample_weights: None,
            metadata: ObjectiveMetadata::default(),
        };

        let coefficients = Array::from_vec(vec![0.5, 0.5]);

        let value = objective.value(&coefficients, &data);
        assert!(value.is_ok());

        let gradient = objective.gradient(&coefficients, &data);
        assert!(gradient.is_ok());
    }

    #[test]
    fn test_modular_config() {
        let config = ModularConfig::default();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.tolerance, 1e-6);
        assert!(!config.verbose);
        assert!(config.random_seed.is_none());
    }

    #[test]
    fn test_loss_function_interface() {
        let loss = DummyLoss;
        let y_true = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array::from_vec(vec![1.1, 1.9, 3.1]);

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        assert!(loss_value >= 0.0);

        let derivative = loss.loss_derivative(&y_true, &y_pred).unwrap();
        assert_eq!(derivative.len(), y_true.len());

        assert_eq!(loss.name(), "SquaredLoss");
        assert!(!loss.is_classification());
    }
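
    // Added sanity-check sketch for the prediction providers above: the
    // linear provider applies the intercept, and the probabilistic provider's
    // sigmoid outputs stay strictly inside (0, 1).
    #[test]
    fn test_prediction_providers() {
        let features = Array::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).unwrap();
        let coefficients = Array::from_vec(vec![2.0, -1.0]);

        let linear = LinearPredictionProvider
            .predict(&features, &coefficients, Some(0.5))
            .unwrap();
        assert_eq!(linear, Array::from_vec(vec![2.5, -0.5]));

        let probs = ProbabilisticPredictionProvider
            .predict(&features, &coefficients, None)
            .unwrap();
        assert!(probs.iter().all(|&p| p > 0.0 && p < 1.0));
    }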

    #[test]
    fn test_regularization_interface() {
        let reg = DummyRegularization { alpha: 0.1 };
        let coefficients = Array::from_vec(vec![1.0, -1.0, 2.0]);

        let penalty = reg.penalty(&coefficients).unwrap();
        assert!(penalty >= 0.0);

        let gradient = reg.penalty_gradient(&coefficients).unwrap();
        assert_eq!(gradient.len(), coefficients.len());

        assert_eq!(reg.strength(), 0.1);
        assert_eq!(reg.name(), "L2Regularization");
        assert!(!reg.is_non_smooth());
    }
}
796}