tensorlogic_train/
regularization.rs

//! Regularization techniques for training.
//!
//! This module provides various regularization strategies to prevent overfitting:
//! - L1 regularization (Lasso): Encourages sparsity
//! - L2 regularization (Ridge): Prevents large weights
//! - Composite regularization: Combines multiple regularizers
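//!
//! # Example
//!
//! A minimal usage sketch. It is marked `ignore` because the re-export path is assumed
//! here to be `tensorlogic_train::regularization`; adjust it to match the crate's public API.
//!
//! ```ignore
//! use std::collections::HashMap;
//! use scirs2_core::ndarray::array;
//! use tensorlogic_train::regularization::{L2Regularization, Regularizer};
//!
//! let reg = L2Regularization::new(0.01);
//!
//! let mut params = HashMap::new();
//! params.insert("w".to_string(), array![[1.0, -2.0], [3.0, 0.5]]);
//!
//! // Add the penalty to the data loss and the gradients to the loss gradients.
//! let penalty = reg.compute_penalty(&params).unwrap();
//! let reg_grads = reg.compute_gradient(&params).unwrap();
//! ```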
use crate::{TrainError, TrainResult};
use scirs2_core::ndarray::{Array, Ix2};
use std::collections::HashMap;

/// Trait for regularization strategies.
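///
/// Implementors provide a scalar penalty for the given parameters and a per-parameter
/// gradient of that penalty.
///
/// # Example
///
/// A sketch of a custom regularizer. The `SelectiveL2` type below is hypothetical and shown
/// only to illustrate the trait contract; the block is marked `ignore` because the
/// `tensorlogic_train` import paths are assumed.
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::{Array, Ix2};
/// use tensorlogic_train::regularization::Regularizer;
/// use tensorlogic_train::TrainResult;
///
/// /// L2 penalty applied only to parameters whose name starts with a prefix.
/// #[derive(Debug, Clone)]
/// struct SelectiveL2 {
///     lambda: f64,
///     prefix: String,
/// }
///
/// impl Regularizer for SelectiveL2 {
///     fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
///         // Sum of squares over the matching parameters only.
///         let sum: f64 = parameters
///             .iter()
///             .filter(|(name, _)| name.starts_with(self.prefix.as_str()))
///             .flat_map(|(_, p)| p.iter())
///             .map(|&w| w * w)
///             .sum();
///         Ok(0.5 * self.lambda * sum)
///     }
///
///     fn compute_gradient(
///         &self,
///         parameters: &HashMap<String, Array<f64, Ix2>>,
///     ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
///         // λ * w for matching parameters, zeros for the rest.
///         Ok(parameters
///             .iter()
///             .map(|(name, p)| {
///                 let grad = if name.starts_with(self.prefix.as_str()) {
///                     p.mapv(|w| self.lambda * w)
///                 } else {
///                     Array::zeros(p.raw_dim())
///                 };
///                 (name.clone(), grad)
///             })
///             .collect())
///     }
/// }
/// ```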
pub trait Regularizer {
    /// Compute the regularization penalty for given parameters.
    ///
    /// # Arguments
    /// * `parameters` - Model parameters to regularize
    ///
    /// # Returns
    /// The regularization penalty value
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64>;

    /// Compute the gradient of the regularization penalty.
    ///
    /// # Arguments
    /// * `parameters` - Model parameters
    ///
    /// # Returns
    /// Gradients of the regularization penalty for each parameter
    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
}

/// L1 regularization (Lasso).
///
/// Adds a penalty proportional to the absolute value of the weights: λ * Σ|w|
/// Encourages sparsity by driving some weights to exactly zero.
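///
/// # Example
///
/// A small penalty/gradient sketch (marked `ignore` because the import path is assumed):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::regularization::{L1Regularization, Regularizer};
///
/// let reg = L1Regularization::new(0.1);
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0, -2.0]]);
///
/// // Penalty: 0.1 * (|1| + |-2|) = 0.3
/// let penalty = reg.compute_penalty(&params).unwrap();
/// assert!((penalty - 0.3).abs() < 1e-12);
///
/// // Gradient: 0.1 * sign(w) = [[0.1, -0.1]]
/// let grads = reg.compute_gradient(&params).unwrap();
/// ```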
#[derive(Debug, Clone)]
pub struct L1Regularization {
    /// Regularization strength (lambda).
    pub lambda: f64,
}

impl L1Regularization {
    /// Create a new L1 regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for L1Regularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for L1Regularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                penalty += value.abs();
            }
        }

        Ok(self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Gradient of L1: λ * sign(w).
            // Note: f64::signum(0.0) is 1.0, which is still a valid subgradient of |w| at zero.
            let grad = param.mapv(|w| self.lambda * w.signum());
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// L2 regularization (Ridge / Weight Decay).
///
/// Adds a penalty proportional to the squared weights: (λ/2) * Σw², so the gradient is simply λ * w.
/// Prevents weights from becoming too large.
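///
/// # Example
///
/// A small penalty/gradient sketch (marked `ignore` because the import path is assumed):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::regularization::{L2Regularization, Regularizer};
///
/// let reg = L2Regularization::new(0.1);
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0, 2.0]]);
///
/// // Penalty: 0.5 * 0.1 * (1 + 4) = 0.25
/// let penalty = reg.compute_penalty(&params).unwrap();
/// assert!((penalty - 0.25).abs() < 1e-12);
///
/// // Gradient: 0.1 * w = [[0.1, 0.2]]
/// let grads = reg.compute_gradient(&params).unwrap();
/// ```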
#[derive(Debug, Clone)]
pub struct L2Regularization {
    /// Regularization strength (lambda).
    pub lambda: f64,
}

impl L2Regularization {
    /// Create a new L2 regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Regularization strength
    pub fn new(lambda: f64) -> Self {
        Self { lambda }
    }
}

impl Default for L2Regularization {
    fn default() -> Self {
        Self { lambda: 0.01 }
    }
}

impl Regularizer for L2Regularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                penalty += value * value;
            }
        }

        Ok(0.5 * self.lambda * penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Gradient of L2: λ * w
            let grad = param.mapv(|w| self.lambda * w);
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// Elastic Net regularization (combination of L1 and L2).
///
/// Combines the L1 and L2 penalties: λ * (l1_ratio * Σ|w| + (1 - l1_ratio) * 0.5 * Σw²)
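///
/// # Example
///
/// A small construction/penalty sketch (marked `ignore` because the import path is assumed):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::regularization::{ElasticNetRegularization, Regularizer};
///
/// // lambda = 0.1, l1_ratio = 0.5 (equal mix of L1 and L2)
/// let reg = ElasticNetRegularization::new(0.1, 0.5).unwrap();
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0, -2.0]]);
///
/// // Penalty: 0.1 * (0.5 * (1 + 2) + 0.5 * 0.5 * (1 + 4)) = 0.275
/// let penalty = reg.compute_penalty(&params).unwrap();
/// assert!((penalty - 0.275).abs() < 1e-12);
/// ```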
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
    /// Overall regularization strength.
    pub lambda: f64,
    /// Balance between L1 and L2 (0.0 = pure L2, 1.0 = pure L1).
    pub l1_ratio: f64,
}

impl ElasticNetRegularization {
    /// Create a new Elastic Net regularizer.
    ///
    /// # Arguments
    /// * `lambda` - Overall regularization strength
    /// * `l1_ratio` - Balance between L1 and L2 (must be in [0, 1])
    pub fn new(lambda: f64, l1_ratio: f64) -> TrainResult<Self> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(TrainError::InvalidParameter(
                "l1_ratio must be between 0.0 and 1.0".to_string(),
            ));
        }
        Ok(Self { lambda, l1_ratio })
    }
}

impl Default for ElasticNetRegularization {
    fn default() -> Self {
        Self {
            lambda: 0.01,
            l1_ratio: 0.5,
        }
    }
}

impl Regularizer for ElasticNetRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut l1_penalty = 0.0;
        let mut l2_penalty = 0.0;

        for param in parameters.values() {
            for &value in param.iter() {
                l1_penalty += value.abs();
                l2_penalty += value * value;
            }
        }

        let penalty =
            self.lambda * (self.l1_ratio * l1_penalty + (1.0 - self.l1_ratio) * 0.5 * l2_penalty);

        Ok(penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut gradients = HashMap::new();

        for (name, param) in parameters {
            // Gradient: λ * (l1_ratio * sign(w) + (1 - l1_ratio) * w)
            let grad = param
                .mapv(|w| self.lambda * (self.l1_ratio * w.signum() + (1.0 - self.l1_ratio) * w));
            gradients.insert(name.clone(), grad);
        }

        Ok(gradients)
    }
}

/// Composite regularization that combines multiple regularizers.
///
/// Useful for applying different regularization strategies simultaneously.
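///
/// # Example
///
/// A sketch combining L1 and L2 into one penalty (marked `ignore` because the import path is assumed):
///
/// ```ignore
/// use std::collections::HashMap;
/// use scirs2_core::ndarray::array;
/// use tensorlogic_train::regularization::{
///     CompositeRegularization, L1Regularization, L2Regularization, Regularizer,
/// };
///
/// let mut composite = CompositeRegularization::new();
/// composite.add(L1Regularization::new(0.1));
/// composite.add(L2Regularization::new(0.1));
///
/// let mut params = HashMap::new();
/// params.insert("w".to_string(), array![[1.0, 2.0]]);
///
/// // L1: 0.1 * 3 = 0.3, L2: 0.5 * 0.1 * 5 = 0.25, total = 0.55
/// let penalty = composite.compute_penalty(&params).unwrap();
/// assert!((penalty - 0.55).abs() < 1e-12);
/// ```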
#[derive(Clone)]
pub struct CompositeRegularization {
    regularizers: Vec<Box<dyn RegularizerClone>>,
}

impl std::fmt::Debug for CompositeRegularization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeRegularization")
            .field("num_regularizers", &self.regularizers.len())
            .finish()
    }
}

/// Helper trait for cloning boxed regularizers.
trait RegularizerClone: Regularizer {
    fn clone_box(&self) -> Box<dyn RegularizerClone>;
}

impl<T: Regularizer + Clone + 'static> RegularizerClone for T {
    fn clone_box(&self) -> Box<dyn RegularizerClone> {
        Box::new(self.clone())
    }
}

impl Clone for Box<dyn RegularizerClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}

impl Regularizer for Box<dyn RegularizerClone> {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        (**self).compute_penalty(parameters)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        (**self).compute_gradient(parameters)
    }
}

impl CompositeRegularization {
    /// Create a new composite regularizer.
    pub fn new() -> Self {
        Self {
            regularizers: Vec::new(),
        }
    }

    /// Add a regularizer to the composite.
    ///
    /// # Arguments
    /// * `regularizer` - Regularizer to add
    pub fn add<R: Regularizer + Clone + 'static>(&mut self, regularizer: R) {
        self.regularizers.push(Box::new(regularizer));
    }

    /// Get the number of regularizers in the composite.
    pub fn len(&self) -> usize {
        self.regularizers.len()
    }

    /// Check if the composite is empty.
    pub fn is_empty(&self) -> bool {
        self.regularizers.is_empty()
    }
}

impl Default for CompositeRegularization {
    fn default() -> Self {
        Self::new()
    }
}

impl Regularizer for CompositeRegularization {
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
        let mut total_penalty = 0.0;

        for regularizer in &self.regularizers {
            total_penalty += regularizer.compute_penalty(parameters)?;
        }

        Ok(total_penalty)
    }

    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
        let mut total_gradients: HashMap<String, Array<f64, Ix2>> = HashMap::new();

        // Initialize with zeros
        for (name, param) in parameters {
            total_gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
        }

        // Accumulate gradients from all regularizers
        for regularizer in &self.regularizers {
            let grads = regularizer.compute_gradient(parameters)?;

            for (name, grad) in grads {
                if let Some(total_grad) = total_gradients.get_mut(&name) {
                    *total_grad = &*total_grad + &grad;
                }
            }
        }

        Ok(total_gradients)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_l1_regularization() {
        let regularizer = L1Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, -2.0], [3.0, -4.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Expected: 0.1 * (1 + 2 + 3 + 4) = 1.0
        assert!((penalty - 1.0).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // Gradient should be λ * sign(w)
        assert_eq!(grad_w[[0, 0]], 0.1); // sign(1.0) = 1.0
        assert_eq!(grad_w[[0, 1]], -0.1); // sign(-2.0) = -1.0
        assert_eq!(grad_w[[1, 0]], 0.1); // sign(3.0) = 1.0
        assert_eq!(grad_w[[1, 1]], -0.1); // sign(-4.0) = -1.0
    }

    #[test]
    fn test_l2_regularization() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Expected: 0.5 * 0.1 * (1 + 4 + 9 + 16) = 1.5
        assert!((penalty - 1.5).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // Gradient should be λ * w
        assert!((grad_w[[0, 0]] - 0.1).abs() < 1e-10); // 0.1 * 1.0
        assert!((grad_w[[0, 1]] - 0.2).abs() < 1e-10); // 0.1 * 2.0
        assert!((grad_w[[1, 0]] - 0.3).abs() < 1e-10); // 0.1 * 3.0
        assert!((grad_w[[1, 1]] - 0.4).abs() < 1e-10); // 0.1 * 4.0
    }

    #[test]
    fn test_elastic_net_regularization() {
        let regularizer = ElasticNetRegularization::new(0.1, 0.5).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);
    }

    #[test]
    fn test_elastic_net_invalid_ratio() {
        let result = ElasticNetRegularization::new(0.1, 1.5);
        assert!(result.is_err());

        let result = ElasticNetRegularization::new(0.1, -0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_composite_regularization() {
        let mut composite = CompositeRegularization::new();
        composite.add(L1Regularization::new(0.1));
        composite.add(L2Regularization::new(0.1));

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let penalty = composite.compute_penalty(&params).unwrap();
        // L1: 0.1 * (1 + 2) = 0.3
        // L2: 0.5 * 0.1 * (1 + 4) = 0.25
        // Total: 0.55
        assert!((penalty - 0.55).abs() < 1e-6);

        let gradients = composite.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);

        // Gradient should combine both L1 and L2
        // For w[0,0] = 1.0: L1 grad = 0.1, L2 grad = 0.1, total = 0.2
        assert!((grad_w[[0, 0]] - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeRegularization::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let penalty = composite.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn test_multiple_parameters() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w1".to_string(), array![[1.0, 2.0]]);
        params.insert("w2".to_string(), array![[3.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        // Expected: 0.5 * 0.1 * (1 + 4 + 9) = 0.7
        assert!((penalty - 0.7).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert_eq!(gradients.len(), 2);
        assert!(gradients.contains_key("w1"));
        assert!(gradients.contains_key("w2"));
    }

    #[test]
    fn test_zero_lambda() {
        let regularizer = L1Regularization::new(0.0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[100.0, 200.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w[[0, 0]], 0.0);
        assert_eq!(grad_w[[0, 1]], 0.0);
    }
}