1use crate::{TrainError, TrainResult};
9use scirs2_core::ndarray::{Array, Ix2};
10use std::collections::HashMap;
11
/// Strategy interface for regularization: a penalty added to the training
/// loss plus its gradient with respect to each named 2-D parameter matrix.
pub trait Regularizer {
    /// Computes the scalar penalty over all parameter matrices.
    fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64>;

    /// Computes, for each parameter name, the gradient of the penalty with
    /// respect to that parameter (same shape as the parameter matrix).
    fn compute_gradient(
        &self,
        parameters: &HashMap<String, Array<f64, Ix2>>,
    ) -> TrainResult<HashMap<String, Array<f64, Ix2>>>;
}
35
/// L1 (lasso) regularization: penalty = `lambda * Σ|w|`, which encourages
/// sparse parameter values.
#[derive(Debug, Clone)]
pub struct L1Regularization {
    /// Regularization strength multiplying the L1 norm.
    pub lambda: f64,
}
45
46impl L1Regularization {
47 pub fn new(lambda: f64) -> Self {
52 Self { lambda }
53 }
54}
55
56impl Default for L1Regularization {
57 fn default() -> Self {
58 Self { lambda: 0.01 }
59 }
60}
61
62impl Regularizer for L1Regularization {
63 fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
64 let mut penalty = 0.0;
65
66 for param in parameters.values() {
67 for &value in param.iter() {
68 penalty += value.abs();
69 }
70 }
71
72 Ok(self.lambda * penalty)
73 }
74
75 fn compute_gradient(
76 &self,
77 parameters: &HashMap<String, Array<f64, Ix2>>,
78 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
79 let mut gradients = HashMap::new();
80
81 for (name, param) in parameters {
82 let grad = param.mapv(|w| self.lambda * w.signum());
84 gradients.insert(name.clone(), grad);
85 }
86
87 Ok(gradients)
88 }
89}
90
/// L2 (ridge/weight-decay) regularization: penalty = `0.5 * lambda * Σ w²`.
#[derive(Debug, Clone)]
pub struct L2Regularization {
    /// Regularization strength multiplying the squared L2 norm.
    pub lambda: f64,
}
100
101impl L2Regularization {
102 pub fn new(lambda: f64) -> Self {
107 Self { lambda }
108 }
109}
110
111impl Default for L2Regularization {
112 fn default() -> Self {
113 Self { lambda: 0.01 }
114 }
115}
116
117impl Regularizer for L2Regularization {
118 fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
119 let mut penalty = 0.0;
120
121 for param in parameters.values() {
122 for &value in param.iter() {
123 penalty += value * value;
124 }
125 }
126
127 Ok(0.5 * self.lambda * penalty)
128 }
129
130 fn compute_gradient(
131 &self,
132 parameters: &HashMap<String, Array<f64, Ix2>>,
133 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
134 let mut gradients = HashMap::new();
135
136 for (name, param) in parameters {
137 let grad = param.mapv(|w| self.lambda * w);
139 gradients.insert(name.clone(), grad);
140 }
141
142 Ok(gradients)
143 }
144}
145
/// Elastic-net regularization: a convex mix of L1 and L2 penalties,
/// `lambda * (l1_ratio * Σ|w| + (1 - l1_ratio) * 0.5 * Σ w²)`.
#[derive(Debug, Clone)]
pub struct ElasticNetRegularization {
    /// Overall regularization strength.
    pub lambda: f64,
    /// Mixing weight in `[0, 1]`: 1.0 is pure L1, 0.0 is pure L2.
    pub l1_ratio: f64,
}
156
157impl ElasticNetRegularization {
158 pub fn new(lambda: f64, l1_ratio: f64) -> TrainResult<Self> {
164 if !(0.0..=1.0).contains(&l1_ratio) {
165 return Err(TrainError::InvalidParameter(
166 "l1_ratio must be between 0.0 and 1.0".to_string(),
167 ));
168 }
169 Ok(Self { lambda, l1_ratio })
170 }
171}
172
173impl Default for ElasticNetRegularization {
174 fn default() -> Self {
175 Self {
176 lambda: 0.01,
177 l1_ratio: 0.5,
178 }
179 }
180}
181
182impl Regularizer for ElasticNetRegularization {
183 fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
184 let mut l1_penalty = 0.0;
185 let mut l2_penalty = 0.0;
186
187 for param in parameters.values() {
188 for &value in param.iter() {
189 l1_penalty += value.abs();
190 l2_penalty += value * value;
191 }
192 }
193
194 let penalty =
195 self.lambda * (self.l1_ratio * l1_penalty + (1.0 - self.l1_ratio) * 0.5 * l2_penalty);
196
197 Ok(penalty)
198 }
199
200 fn compute_gradient(
201 &self,
202 parameters: &HashMap<String, Array<f64, Ix2>>,
203 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
204 let mut gradients = HashMap::new();
205
206 for (name, param) in parameters {
207 let grad = param
209 .mapv(|w| self.lambda * (self.l1_ratio * w.signum() + (1.0 - self.l1_ratio) * w));
210 gradients.insert(name.clone(), grad);
211 }
212
213 Ok(gradients)
214 }
215}
216
/// Combines several regularizers; penalties and gradients are summed
/// across all of them.
#[derive(Clone)]
pub struct CompositeRegularization {
    // Boxed via the private `RegularizerClone` helper trait so the
    // heterogeneous list stays cloneable.
    regularizers: Vec<Box<dyn RegularizerClone>>,
}
224
// Manual Debug: the boxed trait objects are not themselves Debug, so only
// the number of contained regularizers is reported.
impl std::fmt::Debug for CompositeRegularization {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompositeRegularization")
            .field("num_regularizers", &self.regularizers.len())
            .finish()
    }
}
232
/// Private helper trait that makes `Box<dyn Regularizer>`-style storage
/// cloneable (the standard object-safe clone-box pattern).
trait RegularizerClone: Regularizer {
    /// Clones `self` into a fresh boxed trait object.
    fn clone_box(&self) -> Box<dyn RegularizerClone>;
}
237
// Blanket impl: every cloneable, 'static Regularizer automatically
// satisfies RegularizerClone.
impl<T: Regularizer + Clone + 'static> RegularizerClone for T {
    fn clone_box(&self) -> Box<dyn RegularizerClone> {
        Box::new(self.clone())
    }
}
243
// Lets `Vec<Box<dyn RegularizerClone>>` (and thus CompositeRegularization)
// derive Clone, by delegating to clone_box.
impl Clone for Box<dyn RegularizerClone> {
    fn clone(&self) -> Self {
        self.clone_box()
    }
}
249
250impl Regularizer for Box<dyn RegularizerClone> {
251 fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
252 (**self).compute_penalty(parameters)
253 }
254
255 fn compute_gradient(
256 &self,
257 parameters: &HashMap<String, Array<f64, Ix2>>,
258 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
259 (**self).compute_gradient(parameters)
260 }
261}
262
263impl CompositeRegularization {
264 pub fn new() -> Self {
266 Self {
267 regularizers: Vec::new(),
268 }
269 }
270
271 pub fn add<R: Regularizer + Clone + 'static>(&mut self, regularizer: R) {
276 self.regularizers.push(Box::new(regularizer));
277 }
278
279 pub fn len(&self) -> usize {
281 self.regularizers.len()
282 }
283
284 pub fn is_empty(&self) -> bool {
286 self.regularizers.is_empty()
287 }
288}
289
290impl Default for CompositeRegularization {
291 fn default() -> Self {
292 Self::new()
293 }
294}
295
296impl Regularizer for CompositeRegularization {
297 fn compute_penalty(&self, parameters: &HashMap<String, Array<f64, Ix2>>) -> TrainResult<f64> {
298 let mut total_penalty = 0.0;
299
300 for regularizer in &self.regularizers {
301 total_penalty += regularizer.compute_penalty(parameters)?;
302 }
303
304 Ok(total_penalty)
305 }
306
307 fn compute_gradient(
308 &self,
309 parameters: &HashMap<String, Array<f64, Ix2>>,
310 ) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
311 let mut total_gradients: HashMap<String, Array<f64, Ix2>> = HashMap::new();
312
313 for (name, param) in parameters {
315 total_gradients.insert(name.clone(), Array::zeros(param.raw_dim()));
316 }
317
318 for regularizer in &self.regularizers {
320 let grads = regularizer.compute_gradient(parameters)?;
321
322 for (name, grad) in grads {
323 if let Some(total_grad) = total_gradients.get_mut(&name) {
324 *total_grad = &*total_grad + &grad;
325 }
326 }
327 }
328
329 Ok(total_gradients)
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // NOTE(review): the original text had `&params` mis-encoded as `¶ms`
    // (an HTML-entity collapse), which does not compile; restored here.
    // Multi-statement assert lines were also unsquashed.

    #[test]
    fn test_l1_regularization() {
        let regularizer = L1Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, -2.0], [3.0, -4.0]]);

        // 0.1 * (1 + 2 + 3 + 4) = 1.0
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 1.0).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // lambda * signum(w) elementwise
        assert_eq!(grad_w[[0, 0]], 0.1);
        assert_eq!(grad_w[[0, 1]], -0.1);
        assert_eq!(grad_w[[1, 0]], 0.1);
        assert_eq!(grad_w[[1, 1]], -0.1);
    }

    #[test]
    fn test_l2_regularization() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0], [3.0, 4.0]]);

        // 0.5 * 0.1 * (1 + 4 + 9 + 16) = 1.5
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 1.5).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();

        // lambda * w elementwise
        assert!((grad_w[[0, 0]] - 0.1).abs() < 1e-10);
        assert!((grad_w[[0, 1]] - 0.2).abs() < 1e-10);
        assert!((grad_w[[1, 0]] - 0.3).abs() < 1e-10);
        assert!((grad_w[[1, 1]] - 0.4).abs() < 1e-10);
    }

    #[test]
    fn test_elastic_net_regularization() {
        let regularizer = ElasticNetRegularization::new(0.1, 0.5).unwrap();

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!(penalty > 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);
    }

    #[test]
    fn test_elastic_net_invalid_ratio() {
        let result = ElasticNetRegularization::new(0.1, 1.5);
        assert!(result.is_err());

        let result = ElasticNetRegularization::new(0.1, -0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_composite_regularization() {
        let mut composite = CompositeRegularization::new();
        composite.add(L1Regularization::new(0.1));
        composite.add(L2Regularization::new(0.1));

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0, 2.0]]);

        // L1: 0.1 * 3 = 0.3; L2: 0.5 * 0.1 * 5 = 0.25; total = 0.55
        let penalty = composite.compute_penalty(&params).unwrap();
        assert!((penalty - 0.55).abs() < 1e-6);

        let gradients = composite.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w.shape(), &[1, 2]);

        // L1 grad 0.1 + L2 grad 0.1 at w = 1.0
        assert!((grad_w[[0, 0]] - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_composite_empty() {
        let composite = CompositeRegularization::new();
        assert!(composite.is_empty());
        assert_eq!(composite.len(), 0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[1.0]]);

        let penalty = composite.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);
    }

    #[test]
    fn test_multiple_parameters() {
        let regularizer = L2Regularization::new(0.1);

        let mut params = HashMap::new();
        params.insert("w1".to_string(), array![[1.0, 2.0]]);
        params.insert("w2".to_string(), array![[3.0]]);

        // 0.5 * 0.1 * (1 + 4 + 9) = 0.7
        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert!((penalty - 0.7).abs() < 1e-6);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        assert_eq!(gradients.len(), 2);
        assert!(gradients.contains_key("w1"));
        assert!(gradients.contains_key("w2"));
    }

    #[test]
    fn test_zero_lambda() {
        let regularizer = L1Regularization::new(0.0);

        let mut params = HashMap::new();
        params.insert("w".to_string(), array![[100.0, 200.0]]);

        let penalty = regularizer.compute_penalty(&params).unwrap();
        assert_eq!(penalty, 0.0);

        let gradients = regularizer.compute_gradient(&params).unwrap();
        let grad_w = gradients.get("w").unwrap();
        assert_eq!(grad_w[[0, 0]], 0.0);
        assert_eq!(grad_w[[0, 1]], 0.0);
    }
}