sklears_linear/modular_framework.rs

//! Modular Framework for Linear Models
//!
//! This module implements a trait-based system for pluggable solvers, loss functions,
//! and regularization schemes. This addresses TODO items for architectural improvements:
//! - Separate solver implementations into trait-based system
//! - Create pluggable loss function framework
//! - Implement composable regularization schemes
//! - Add extensible prediction interface

use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::fmt::Debug;

/// Trait for optimization objectives that can be minimized
pub trait Objective: Debug + Send + Sync {
    /// Compute the objective value
    fn value(&self, coefficients: &Array1<Float>, data: &ObjectiveData) -> Result<Float>;

    /// Compute the gradient of the objective
    fn gradient(&self, coefficients: &Array1<Float>, data: &ObjectiveData)
        -> Result<Array1<Float>>;

    /// Compute both value and gradient (often more efficient than separate calls)
    fn value_and_gradient(
        &self,
        coefficients: &Array1<Float>,
        data: &ObjectiveData,
    ) -> Result<(Float, Array1<Float>)> {
        let value = self.value(coefficients, data)?;
        let gradient = self.gradient(coefficients, data)?;
        Ok((value, gradient))
    }

    /// Check if the objective supports Hessian computation
    fn supports_hessian(&self) -> bool {
        false
    }

    /// Compute the Hessian matrix (for second-order methods)
    fn hessian(
        &self,
        _coefficients: &Array1<Float>,
        _data: &ObjectiveData,
    ) -> Result<Array2<Float>> {
        Err(SklearsError::InvalidOperation(
            "Hessian computation not supported for this objective".to_string(),
        ))
    }
}

/// Data structure containing all information needed for objective computation
#[derive(Debug, Clone)]
pub struct ObjectiveData {
    /// Feature matrix (n_samples, n_features)
    pub features: Array2<Float>,
    /// Target values (n_samples,)
    pub targets: Array1<Float>,
    /// Sample weights (optional)
    pub sample_weights: Option<Array1<Float>>,
    /// Additional metadata
    pub metadata: ObjectiveMetadata,
}

/// Metadata for objective computation
#[derive(Debug, Clone, Default)]
pub struct ObjectiveMetadata {
    /// Whether to fit intercept
    pub fit_intercept: bool,
    /// Feature scaling factors (for numerical stability)
    pub feature_scale: Option<Array1<Float>>,
    /// Target scaling factors
    pub target_scale: Option<Float>,
}

/// Trait for loss functions that measure prediction error
pub trait LossFunction: Debug + Send + Sync {
    /// Compute the loss value for predictions
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float>;

    /// Compute the derivative of loss with respect to predictions
    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>>;

    /// Compute both loss and derivative (often more efficient)
    fn loss_and_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<(Float, Array1<Float>)> {
        let loss = self.loss(y_true, y_pred)?;
        let derivative = self.loss_derivative(y_true, y_pred)?;
        Ok((loss, derivative))
    }

    /// Check if this is a classification loss (vs regression)
    fn is_classification(&self) -> bool {
        false
    }

    /// Get the name of this loss function
    fn name(&self) -> &'static str;
}
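
// Illustrative example: a logistic loss behind the same interface. This is a
// sketch for documentation purposes, not a loss shipped by this crate; the
// name `LogisticLossSketch` is hypothetical. It treats `y_pred` as the linear
// predictor z = Xw and assumes targets in {0, 1}, so the per-sample loss is
// log(1 + exp(z)) - y*z and its derivative with respect to z is sigmoid(z) - y.
#[derive(Debug)]
pub struct LogisticLossSketch;

impl LossFunction for LogisticLossSketch {
    fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
        let n = y_true.len() as Float;
        // Numerically stable log(1 + exp(z)) = max(z, 0) + ln(1 + exp(-|z|))
        let total: Float = y_true
            .iter()
            .zip(y_pred.iter())
            .map(|(&y, &z)| z.max(0.0) + (-z.abs()).exp().ln_1p() - y * z)
            .sum();
        Ok(total / n)
    }

    fn loss_derivative(
        &self,
        y_true: &Array1<Float>,
        y_pred: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        let n = y_true.len() as Float;
        // d/dz [log(1 + exp(z)) - y*z] = sigmoid(z) - y, averaged over samples
        let mut grad = y_pred.mapv(|z| 1.0 / (1.0 + (-z).exp()));
        grad -= y_true;
        Ok(grad / n)
    }

    fn is_classification(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "LogisticLossSketch"
    }
}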

/// Trait for regularization penalties
pub trait Regularization: Debug + Send + Sync {
    /// Compute the regularization penalty value
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float>;

    /// Compute the regularization gradient (subgradient for non-smooth penalties)
    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>>;

    /// Apply the proximal operator (for proximal gradient methods)
    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        _step_size: Float,
    ) -> Result<Array1<Float>> {
        // Default implementation: the identity, which is appropriate for smooth penalties
        Ok(coefficients.clone())
    }

    /// Check if this regularization is non-smooth (requires proximal methods)
    fn is_non_smooth(&self) -> bool {
        false
    }

    /// Get the regularization strength
    fn strength(&self) -> Float;

    /// Get the name of this regularization
    fn name(&self) -> &'static str;
}
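
// Illustrative example: an L1 (lasso) penalty showing how a non-smooth scheme
// plugs into the trait. This is a sketch for documentation purposes, not a
// penalty shipped by this crate; the name `L1PenaltySketch` is hypothetical.
// Its proximal operator is elementwise soft-thresholding:
// prox(w_j) = sign(w_j) * max(|w_j| - step_size * alpha, 0).
#[derive(Debug)]
pub struct L1PenaltySketch {
    pub alpha: Float,
}

impl Regularization for L1PenaltySketch {
    fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
        Ok(self.alpha * coefficients.mapv(|w| w.abs()).sum())
    }

    fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
        // Subgradient of alpha * |w|; any value in [-alpha, alpha] is valid
        // at w = 0, and we pick 0 there.
        Ok(coefficients.mapv(|w| {
            if w > 0.0 {
                self.alpha
            } else if w < 0.0 {
                -self.alpha
            } else {
                0.0
            }
        }))
    }

    fn proximal_operator(
        &self,
        coefficients: &Array1<Float>,
        step_size: Float,
    ) -> Result<Array1<Float>> {
        // Soft-thresholding shrinks each coefficient toward zero by the threshold
        let threshold = step_size * self.alpha;
        Ok(coefficients.mapv(|w| w.signum() * (w.abs() - threshold).max(0.0)))
    }

    fn is_non_smooth(&self) -> bool {
        true
    }

    fn strength(&self) -> Float {
        self.alpha
    }

    fn name(&self) -> &'static str {
        "L1PenaltySketch"
    }
}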

/// Trait for optimization solvers
pub trait OptimizationSolver: Debug + Send + Sync {
    /// Configuration type for this solver
    type Config: Debug + Clone + Send + Sync;

    /// Result type returned by this solver
    type Result: Debug + Clone + Send + Sync;

    /// Solve the optimization problem; `data` is passed through to the
    /// objective's value and gradient evaluations
    fn solve(
        &self,
        objective: &dyn Objective,
        data: &ObjectiveData,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result>;

    /// Check if this solver supports the given objective type
    fn supports_objective(&self, objective: &dyn Objective) -> bool;

    /// Get the name of this solver
    fn name(&self) -> &'static str;

    /// Get solver-specific recommendations for the given problem
    fn get_recommendations(&self, _data: &ObjectiveData) -> SolverRecommendations {
        SolverRecommendations::default()
    }
}
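
// Illustrative example: a minimal fixed-step gradient-descent solver showing
// how the trait's associated `Config` and `Result` types fit together. This is
// a sketch for documentation purposes, not a solver shipped by this crate; the
// names `GradientDescentSketch`, `GdConfig`, and `GdResult` are hypothetical.
#[derive(Debug)]
pub struct GradientDescentSketch;

#[derive(Debug, Clone)]
pub struct GdConfig {
    pub step_size: Float,
    pub max_iterations: usize,
    pub tolerance: Float,
}

#[derive(Debug, Clone)]
pub struct GdResult {
    pub coefficients: Array1<Float>,
    pub objective_value: Float,
    pub n_iterations: usize,
    pub converged: bool,
}

impl OptimizationSolver for GradientDescentSketch {
    type Config = GdConfig;
    type Result = GdResult;

    fn solve(
        &self,
        objective: &dyn Objective,
        data: &ObjectiveData,
        initial_guess: &Array1<Float>,
        config: &Self::Config,
    ) -> Result<Self::Result> {
        let mut coefficients = initial_guess.clone();
        let mut value = objective.value(&coefficients, data)?;
        let mut converged = false;
        let mut n_iterations = 0;

        for _ in 0..config.max_iterations {
            // Take a fixed-size step along the negative gradient
            let gradient = objective.gradient(&coefficients, data)?;
            coefficients = coefficients - config.step_size * &gradient;
            let new_value = objective.value(&coefficients, data)?;
            n_iterations += 1;

            // Stop when the objective improvement falls below the tolerance
            if (value - new_value).abs() < config.tolerance {
                value = new_value;
                converged = true;
                break;
            }
            value = new_value;
        }

        Ok(GdResult {
            coefficients,
            objective_value: value,
            n_iterations,
            converged,
        })
    }

    fn supports_objective(&self, _objective: &dyn Objective) -> bool {
        // First-order method: only gradients are needed, so any objective works
        true
    }

    fn name(&self) -> &'static str {
        "GradientDescentSketch"
    }
}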

/// Recommendations for solver configuration
#[derive(Debug, Clone, Default)]
pub struct SolverRecommendations {
    /// Recommended maximum iterations
    pub max_iterations: Option<usize>,
    /// Recommended convergence tolerance
    pub tolerance: Option<Float>,
    /// Recommended step size or learning rate
    pub step_size: Option<Float>,
    /// Whether to use line search
    pub use_line_search: Option<bool>,
    /// Additional solver-specific advice
    pub notes: Vec<String>,
}

/// Extensible prediction interface supporting different prediction types
pub trait PredictionProvider: Debug + Send + Sync {
    /// Compute point predictions from features and fitted parameters
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>>;

    /// Prediction with confidence intervals (if supported)
    fn predict_with_confidence(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
        confidence_level: Float,
    ) -> Result<PredictionWithConfidence> {
        let predictions = self.predict(features, coefficients, intercept)?;
        Ok(PredictionWithConfidence {
            predictions,
            lower_bounds: None,
            upper_bounds: None,
            confidence_level,
        })
    }

    /// Prediction with uncertainty quantification (if supported)
    fn predict_with_uncertainty(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<PredictionWithUncertainty> {
        let predictions = self.predict(features, coefficients, intercept)?;
        Ok(PredictionWithUncertainty {
            predictions,
            uncertainties: None,
            prediction_intervals: None,
        })
    }

    /// Check if this provider supports confidence intervals
    fn supports_confidence_intervals(&self) -> bool {
        false
    }

    /// Check if this provider supports uncertainty quantification
    fn supports_uncertainty_quantification(&self) -> bool {
        false
    }

    /// Get the name of this prediction provider
    fn name(&self) -> &'static str;
}

/// Prediction result with confidence intervals
#[derive(Debug, Clone)]
pub struct PredictionWithConfidence {
    /// Point predictions
    pub predictions: Array1<Float>,
    /// Lower confidence bounds (optional)
    pub lower_bounds: Option<Array1<Float>>,
    /// Upper confidence bounds (optional)
    pub upper_bounds: Option<Array1<Float>>,
    /// Confidence level used (e.g., 0.95 for 95% confidence)
    pub confidence_level: Float,
}

/// Prediction result with uncertainty quantification
#[derive(Debug, Clone)]
pub struct PredictionWithUncertainty {
    /// Point predictions
    pub predictions: Array1<Float>,
    /// Prediction uncertainties (standard errors)
    pub uncertainties: Option<Array1<Float>>,
    /// Prediction intervals
    pub prediction_intervals: Option<Array2<Float>>, // (n_samples, 2) for [lower, upper]
}

/// Standard linear prediction provider
#[derive(Debug)]
pub struct LinearPredictionProvider;

impl PredictionProvider for LinearPredictionProvider {
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>> {
        let mut predictions = features.dot(coefficients);
        if let Some(intercept_val) = intercept {
            predictions += intercept_val;
        }
        Ok(predictions)
    }

    fn name(&self) -> &'static str {
        "LinearPrediction"
    }
}

/// Probabilistic prediction provider for classification
#[derive(Debug)]
pub struct ProbabilisticPredictionProvider;

impl PredictionProvider for ProbabilisticPredictionProvider {
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>> {
        let linear_predictions =
            LinearPredictionProvider.predict(features, coefficients, intercept)?;
        // Apply sigmoid transformation for binary classification
        let probabilities = linear_predictions.mapv(|x| 1.0 / (1.0 + (-x).exp()));
        Ok(probabilities)
    }

    fn supports_confidence_intervals(&self) -> bool {
        true
    }

    fn predict_with_confidence(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
        confidence_level: Float,
    ) -> Result<PredictionWithConfidence> {
        let predictions = self.predict(features, coefficients, intercept)?;

        // For logistic regression, confidence intervals would normally come from
        // the variance of the linear predictor (delta method with the coefficient
        // covariance). That covariance is not available here, so unit variances
        // serve as a placeholder.
        let variances = Array1::ones(features.nrows());

        let z_score = match confidence_level {
            0.90 => 1.645,
            0.95 => 1.96,
            0.99 => 2.576,
            _ => 1.96, // Default to 95%
        };

        let margins = variances.mapv(|v: Float| z_score * v.sqrt());
        // Predictions are probabilities, so clip the bounds to [0, 1]
        let lower_bounds = (&predictions - &margins).mapv(|p: Float| p.max(0.0));
        let upper_bounds = (&predictions + &margins).mapv(|p: Float| p.min(1.0));

        Ok(PredictionWithConfidence {
            predictions,
            lower_bounds: Some(lower_bounds),
            upper_bounds: Some(upper_bounds),
            confidence_level,
        })
    }

    fn name(&self) -> &'static str {
        "ProbabilisticPrediction"
    }
}

/// Bayesian prediction provider with uncertainty quantification
#[derive(Debug)]
pub struct BayesianPredictionProvider {
    /// Posterior covariance matrix
    pub posterior_covariance: Option<Array2<Float>>,
}

impl BayesianPredictionProvider {
    pub fn new(posterior_covariance: Option<Array2<Float>>) -> Self {
        Self {
            posterior_covariance,
        }
    }
}

impl PredictionProvider for BayesianPredictionProvider {
    fn predict(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<Array1<Float>> {
        LinearPredictionProvider.predict(features, coefficients, intercept)
    }

    fn supports_uncertainty_quantification(&self) -> bool {
        self.posterior_covariance.is_some()
    }

    fn predict_with_uncertainty(
        &self,
        features: &Array2<Float>,
        coefficients: &Array1<Float>,
        intercept: Option<Float>,
    ) -> Result<PredictionWithUncertainty> {
        let predictions = self.predict(features, coefficients, intercept)?;

        if let Some(ref cov) = self.posterior_covariance {
            // Compute prediction uncertainties using the posterior covariance
            let mut uncertainties = Array1::zeros(features.nrows());

            for (i, row) in features.axis_iter(Axis(0)).enumerate() {
                let variance = row.dot(&cov.dot(&row));
                uncertainties[i] = variance.sqrt();
            }

            // Compute 95% prediction intervals
            let z_score = 1.96;
            let lower_bounds = &predictions - z_score * &uncertainties;
            let upper_bounds = &predictions + z_score * &uncertainties;

            let mut prediction_intervals = Array2::zeros((features.nrows(), 2));
            prediction_intervals.column_mut(0).assign(&lower_bounds);
            prediction_intervals.column_mut(1).assign(&upper_bounds);

            Ok(PredictionWithUncertainty {
                predictions,
                uncertainties: Some(uncertainties),
                prediction_intervals: Some(prediction_intervals),
            })
        } else {
            Ok(PredictionWithUncertainty {
                predictions,
                uncertainties: None,
                prediction_intervals: None,
            })
        }
    }

    fn name(&self) -> &'static str {
        "BayesianPrediction"
    }
}

/// Configuration for the modular framework
#[derive(Debug, Clone)]
pub struct ModularConfig {
    /// Maximum iterations for optimization
    pub max_iterations: usize,
    /// Convergence tolerance
    pub tolerance: Float,
    /// Whether to enable verbose output
    pub verbose: bool,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
}

impl Default for ModularConfig {
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            verbose: false,
            random_seed: None,
        }
    }
}

/// Result of optimization through the modular framework
#[derive(Debug, Clone)]
pub struct OptimizationResult {
    /// Final coefficient values
    pub coefficients: Array1<Float>,
    /// Final intercept value (if fitted)
    pub intercept: Option<Float>,
    /// Final objective value
    pub objective_value: Float,
    /// Number of iterations performed
    pub n_iterations: usize,
    /// Whether optimization converged
    pub converged: bool,
    /// Solver-specific information
    pub solver_info: SolverInfo,
}

/// Information about the solver execution
#[derive(Debug, Clone)]
pub struct SolverInfo {
    /// Name of the solver used
    pub solver_name: String,
    /// Solver-specific metrics
    pub metrics: std::collections::HashMap<String, Float>,
    /// Warning messages
    pub warnings: Vec<String>,
    /// Convergence history (if available)
    pub convergence_history: Option<Array1<Float>>,
}

/// The main modular framework that coordinates components
#[derive(Debug)]
pub struct ModularFramework {
    config: ModularConfig,
}

impl ModularFramework {
    /// Create a new modular framework with default configuration
    pub fn new() -> Self {
        Self {
            config: ModularConfig::default(),
        }
    }

    /// Create a new modular framework with custom configuration
    pub fn with_config(config: ModularConfig) -> Self {
        Self { config }
    }

    /// Solve an optimization problem using the modular components
    pub fn solve<S: OptimizationSolver + ?Sized>(
        &self,
        loss: &dyn LossFunction,
        regularization: Option<&dyn Regularization>,
        solver: &S,
        data: &ObjectiveData,
        initial_guess: Option<&Array1<Float>>,
    ) -> Result<OptimizationResult> {
        // Create the composite objective combining loss and regularization
        let objective = CompositeObjective::new(loss, regularization);

        // Create initial guess if not provided
        let n_features = data.features.ncols();
        let init = initial_guess
            .cloned()
            .unwrap_or_else(|| Array1::zeros(n_features));

        // Create solver config from framework config
        let solver_config = self.create_solver_config::<S>(&objective, data)?;

        // Solve the problem; the data is forwarded so the solver can evaluate the objective
        let solver_result = solver.solve(&objective, data, &init, &solver_config)?;

        // Convert solver result to framework result
        self.convert_result::<S>(solver_result, &objective, data)
    }

    /// Create solver-specific configuration from framework configuration
    fn create_solver_config<S: OptimizationSolver + ?Sized>(
        &self,
        _objective: &dyn Objective,
        _data: &ObjectiveData,
    ) -> Result<S::Config> {
        // Get solver recommendations and use them to create config
        let solver_name = std::any::type_name::<S>();

        // For now, we'll return an error indicating the specific solver type
        // In practice, each solver would register a config converter
        Err(SklearsError::InvalidOperation(format!(
            "Config conversion not implemented for solver: {}",
            solver_name
        )))
    }

    /// Convert solver-specific result to framework result
    fn convert_result<S: OptimizationSolver + ?Sized>(
        &self,
        _solver_result: S::Result,
        _objective: &dyn Objective,
        _data: &ObjectiveData,
    ) -> Result<OptimizationResult> {
        // For now, we'll return an error indicating the specific result type
        // In practice, each solver result would implement a conversion trait
        let result_type = std::any::type_name::<S::Result>();

        Err(SklearsError::InvalidOperation(format!(
            "Result conversion not implemented for result type: {}",
            result_type
        )))
    }
}

impl Default for ModularFramework {
    fn default() -> Self {
        Self::new()
    }
}

/// A composite objective that combines a loss function with optional regularization
#[derive(Debug)]
pub struct CompositeObjective<'a> {
    loss: &'a dyn LossFunction,
    regularization: Option<&'a dyn Regularization>,
}

impl<'a> CompositeObjective<'a> {
    /// Create a new composite objective
    pub fn new(loss: &'a dyn LossFunction, regularization: Option<&'a dyn Regularization>) -> Self {
        Self {
            loss,
            regularization,
        }
    }
}

impl<'a> Objective for CompositeObjective<'a> {
    fn value(&self, coefficients: &Array1<Float>, data: &ObjectiveData) -> Result<Float> {
        // Compute predictions
        let predictions = data.features.dot(coefficients);

        // Compute loss
        let loss_value = self.loss.loss(&data.targets, &predictions)?;

        // Add regularization if present
        let regularization_value = if let Some(reg) = self.regularization {
            reg.penalty(coefficients)?
        } else {
            0.0
        };

        Ok(loss_value + regularization_value)
    }

    fn gradient(
        &self,
        coefficients: &Array1<Float>,
        data: &ObjectiveData,
    ) -> Result<Array1<Float>> {
        // Compute predictions
        let predictions = data.features.dot(coefficients);

        // Compute loss derivative with respect to predictions
        let loss_grad_pred = self.loss.loss_derivative(&data.targets, &predictions)?;

        // Compute gradient with respect to coefficients using chain rule
        let mut gradient = data.features.t().dot(&loss_grad_pred);

        // Add regularization gradient if present
        if let Some(reg) = self.regularization {
            let reg_grad = reg.penalty_gradient(coefficients)?;
            gradient = gradient + reg_grad;
        }

        Ok(gradient)
    }

    fn supports_hessian(&self) -> bool {
        // For simplicity, we don't support Hessian computation in the composite objective
        false
    }
}

/// Utility function to create a modular linear regression solver
pub fn create_modular_linear_regression(
    loss: Box<dyn LossFunction>,
    regularization: Option<Box<dyn Regularization>>,
    solver: Box<dyn OptimizationSolver<Config = ModularConfig, Result = OptimizationResult>>,
) -> ModularLinearModel {
    ModularLinearModel {
        loss,
        regularization,
        solver,
        framework: ModularFramework::new(),
    }
}

/// A linear model built using the modular framework
#[derive(Debug)]
pub struct ModularLinearModel {
    loss: Box<dyn LossFunction>,
    regularization: Option<Box<dyn Regularization>>,
    solver: Box<dyn OptimizationSolver<Config = ModularConfig, Result = OptimizationResult>>,
    framework: ModularFramework,
}

#[allow(non_snake_case)]
impl ModularLinearModel {
    /// Fit the model to training data
    pub fn fit(&self, X: &Array2<Float>, y: &Array1<Float>) -> Result<OptimizationResult> {
        let data = ObjectiveData {
            features: X.clone(),
            targets: y.clone(),
            sample_weights: None,
            metadata: ObjectiveMetadata::default(),
        };

        self.framework.solve(
            self.loss.as_ref(),
            self.regularization.as_deref(),
            self.solver.as_ref(),
            &data,
            None,
        )
    }

    /// Make predictions using the fitted model
    pub fn predict(&self, X: &Array2<Float>, result: &OptimizationResult) -> Result<Array1<Float>> {
        let predictions = X.dot(&result.coefficients);

        // Add intercept if present
        if let Some(intercept) = result.intercept {
            Ok(predictions + intercept)
        } else {
            Ok(predictions)
        }
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    // Test helper: dummy loss function
    #[derive(Debug)]
    struct DummyLoss;

    impl LossFunction for DummyLoss {
        fn loss(&self, y_true: &Array1<Float>, y_pred: &Array1<Float>) -> Result<Float> {
            Ok(((y_true - y_pred).mapv(|x| x * x)).sum() / (2.0 * y_true.len() as Float))
        }

        fn loss_derivative(
            &self,
            y_true: &Array1<Float>,
            y_pred: &Array1<Float>,
        ) -> Result<Array1<Float>> {
            Ok((y_pred - y_true) / (y_true.len() as Float))
        }

        fn name(&self) -> &'static str {
            "SquaredLoss"
        }
    }

    // Test helper: dummy regularization
    #[derive(Debug)]
    struct DummyRegularization {
        alpha: Float,
    }

    impl Regularization for DummyRegularization {
        fn penalty(&self, coefficients: &Array1<Float>) -> Result<Float> {
            Ok(0.5 * self.alpha * coefficients.mapv(|x| x * x).sum())
        }

        fn penalty_gradient(&self, coefficients: &Array1<Float>) -> Result<Array1<Float>> {
            Ok(self.alpha * coefficients)
        }

        fn strength(&self) -> Float {
            self.alpha
        }

        fn name(&self) -> &'static str {
            "L2Regularization"
        }
    }

    #[test]
    fn test_composite_objective() {
        let loss = DummyLoss;
        let regularization = DummyRegularization { alpha: 0.1 };
        let objective = CompositeObjective::new(&loss, Some(&regularization));

        let data = ObjectiveData {
            features: Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
            targets: Array::from_vec(vec![1.0, 2.0, 3.0]),
            sample_weights: None,
            metadata: ObjectiveMetadata::default(),
        };

        let coefficients = Array::from_vec(vec![0.5, 0.5]);

        // Test that value computation doesn't panic
        let value = objective.value(&coefficients, &data);
        assert!(value.is_ok());

        // Test that gradient computation doesn't panic
        let gradient = objective.gradient(&coefficients, &data);
        assert!(gradient.is_ok());
    }

    #[test]
    fn test_modular_config() {
        let config = ModularConfig::default();
        assert_eq!(config.max_iterations, 1000);
        assert_eq!(config.tolerance, 1e-6);
        assert!(!config.verbose);
        assert!(config.random_seed.is_none());
    }

    #[test]
    fn test_loss_function_interface() {
        let loss = DummyLoss;
        let y_true = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let y_pred = Array::from_vec(vec![1.1, 1.9, 3.1]);

        let loss_value = loss.loss(&y_true, &y_pred).unwrap();
        assert!(loss_value >= 0.0);

        let derivative = loss.loss_derivative(&y_true, &y_pred).unwrap();
        assert_eq!(derivative.len(), y_true.len());

        assert_eq!(loss.name(), "SquaredLoss");
        assert!(!loss.is_classification());
    }

    #[test]
    fn test_regularization_interface() {
        let reg = DummyRegularization { alpha: 0.1 };
        let coefficients = Array::from_vec(vec![1.0, -1.0, 2.0]);

        let penalty = reg.penalty(&coefficients).unwrap();
        assert!(penalty >= 0.0);

        let gradient = reg.penalty_gradient(&coefficients).unwrap();
        assert_eq!(gradient.len(), coefficients.len());

        assert_eq!(reg.strength(), 0.1);
        assert_eq!(reg.name(), "L2Regularization");
        assert!(!reg.is_non_smooth());
    }
    }
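
    // Exercises the gradient-descent sketch defined above against the composite
    // objective; with a small fixed step size on this tiny problem, the
    // ridge-regularized squared loss should not increase from the zero start.
    #[test]
    fn test_gradient_descent_sketch() {
        let loss = DummyLoss;
        let regularization = DummyRegularization { alpha: 0.1 };
        let objective = CompositeObjective::new(&loss, Some(&regularization));

        let data = ObjectiveData {
            features: Array::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(),
            targets: Array::from_vec(vec![1.0, 2.0, 3.0]),
            sample_weights: None,
            metadata: ObjectiveMetadata::default(),
        };

        let init = Array1::zeros(2);
        let config = GdConfig {
            step_size: 0.01,
            max_iterations: 100,
            tolerance: 1e-8,
        };

        let initial_value = objective.value(&init, &data).unwrap();
        let result = GradientDescentSketch
            .solve(&objective, &data, &init, &config)
            .unwrap();
        assert!(result.objective_value <= initial_value);
        assert!(result.n_iterations > 0);
    }
}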
796}