1use scirs2_core::ndarray::{Array1, Array2};
18use sklears_core::{traits::Untrained, types::Float};
19use std::collections::HashMap;
20
21#[path = "regularization/simd_ops.rs"]
23pub mod simd_ops;
24
25#[path = "regularization/group_lasso.rs"]
26pub mod group_lasso;
27
28#[path = "regularization/nuclear_norm.rs"]
29pub mod nuclear_norm;
30
31#[path = "regularization/task_clustering.rs"]
32pub mod task_clustering;
33
34#[path = "regularization/task_relationship.rs"]
35pub mod task_relationship;
36
37#[path = "regularization/meta_learning.rs"]
38pub mod meta_learning;
39
40pub use group_lasso::{GroupLasso, GroupLassoTrained};
42pub use meta_learning::{MetaLearningMultiTask, MetaLearningMultiTaskTrained};
43pub use nuclear_norm::{NuclearNormRegression, NuclearNormRegressionTrained};
44pub use task_clustering::{TaskClusteringRegressionTrained, TaskClusteringRegularization};
45pub use task_relationship::{
46 TaskRelationshipLearning, TaskRelationshipLearningTrained, TaskSimilarityMethod,
47};
48
49#[derive(Debug, Clone)]
56pub struct MultiTaskElasticNet<S = Untrained> {
57 state: S,
58 alpha: Float,
60 l1_ratio: Float,
62 feature_groups: Vec<Vec<usize>>,
64 group_alpha: Float,
66 max_iter: usize,
68 tolerance: Float,
70 learning_rate: Float,
72 task_outputs: HashMap<String, usize>,
74 fit_intercept: bool,
76}
77
78#[derive(Debug, Clone)]
82pub struct MultiTaskElasticNetTrained {
83 coefficients: HashMap<String, Array2<Float>>,
85 intercepts: HashMap<String, Array1<Float>>,
87 n_features: usize,
89 task_outputs: HashMap<String, usize>,
91 alpha: Float,
93 l1_ratio: Float,
94 group_alpha: Float,
95 n_iter: usize,
97}
98
99#[derive(Debug, Clone, PartialEq, Default)]
101pub enum RegularizationStrategy {
102 #[default]
104 None,
105 L1(Float),
107 L2(Float),
109 ElasticNet { alpha: Float, l1_ratio: Float },
111 GroupLasso { alpha: Float },
113 NuclearNorm { alpha: Float },
115 TaskClustering {
117 n_clusters: usize,
118 intra_cluster_alpha: Float,
119 inter_cluster_alpha: Float,
120 },
121 TaskRelationship {
123 relationship_strength: Float,
124 similarity_threshold: Float,
125 },
126 MetaLearning {
128 meta_learning_rate: Float,
129 inner_learning_rate: Float,
130 n_inner_steps: usize,
131 },
132}
133
// `X` is used for design matrices in the fit/predict tests, matching the usual
// math notation; hence the snake-case lint is relaxed for this module only.
#[allow(non_snake_case)]
#[cfg(test)]
mod regularization_tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Predict};
    use std::collections::HashMap;

    #[test]
    fn test_group_lasso_creation() {
        let group_lasso = GroupLasso::new()
            .alpha(0.1)
            .feature_groups(vec![vec![0, 1], vec![2, 3]])
            .max_iter(100)
            .tolerance(1e-6)
            .learning_rate(0.01);

        assert_abs_diff_eq!(group_lasso.alpha, 0.1);
        assert_eq!(group_lasso.feature_groups, vec![vec![0, 1], vec![2, 3]]);
        assert_eq!(group_lasso.max_iter, 100);
        assert_abs_diff_eq!(group_lasso.tolerance, 1e-6);
        assert_abs_diff_eq!(group_lasso.learning_rate, 0.01);
    }

    #[test]
    fn test_group_lasso_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 1.0, 2.0, 3.0],
            [4.0, 2.0, 1.0, 2.0]
        ];

        let mut y_tasks = HashMap::new();
        y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
        y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);

        let feature_groups = vec![vec![0, 1], vec![2, 3]];

        let group_lasso = GroupLasso::new()
            .alpha(0.01)
            .feature_groups(feature_groups)
            .task_outputs(&[("task1", 1), ("task2", 1)])
            .max_iter(50)
            .tolerance(1e-4)
            .learning_rate(0.01);

        let trained = group_lasso.fit(&X.view(), &y_tasks).unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert!(predictions.contains_key("task1"));
        assert!(predictions.contains_key("task2"));

        let task1_pred = &predictions["task1"];
        let task2_pred = &predictions["task2"];

        assert_eq!(task1_pred.shape(), &[4, 1]);
        assert_eq!(task2_pred.shape(), &[4, 1]);

        // Sparsity is a fraction of groups, so it must lie in [0, 1].
        let sparsity = trained.group_sparsity();
        assert!(
            (0.0..=1.0).contains(&sparsity),
            "group sparsity {} is outside [0, 1]",
            sparsity
        );
        assert!(trained.task_coefficients("task1").is_some());
        assert!(trained.task_intercepts("task1").is_some());
        assert!(trained.n_iter() <= 50);
    }

    #[test]
    fn test_nuclear_norm_regression_creation() {
        let nuclear_norm = NuclearNormRegression::new()
            .alpha(0.1)
            .max_iter(100)
            .tolerance(1e-6)
            .learning_rate(0.01)
            .target_rank(Some(5));

        assert_abs_diff_eq!(nuclear_norm.alpha, 0.1);
        assert_eq!(nuclear_norm.max_iter, 100);
        assert_abs_diff_eq!(nuclear_norm.tolerance, 1e-6);
        assert_abs_diff_eq!(nuclear_norm.learning_rate, 0.01);
        assert_eq!(nuclear_norm.target_rank, Some(5));
    }

    #[test]
    fn test_nuclear_norm_regression_fit_predict() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];

        let mut y_tasks = HashMap::new();
        y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
        y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);

        let nuclear_norm = NuclearNormRegression::new()
            .alpha(0.01)
            .task_outputs(&[("task1", 1), ("task2", 1)])
            .max_iter(50)
            .tolerance(1e-4)
            .learning_rate(0.01);

        let trained = nuclear_norm.fit(&X.view(), &y_tasks).unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert!(predictions.contains_key("task1"));
        assert!(predictions.contains_key("task2"));

        let task1_pred = &predictions["task1"];
        let task2_pred = &predictions["task2"];

        assert_eq!(task1_pred.shape(), &[4, 1]);
        assert_eq!(task2_pred.shape(), &[4, 1]);

        assert!(trained.task_coefficient_matrix("task1").is_some());
        // A `>= 0` check on an integer rank is vacuous. Instead bound it by the
        // problem size: 2 features x 2 single-output tasks gives rank <= 2.
        // (Assumes effective_rank() reports the rank of that coefficient matrix.)
        assert!(trained.effective_rank() <= 2);
        assert!(!trained.singular_values().is_empty());
        assert!(trained.n_iter() <= 50);
    }

    #[test]
    fn test_regularization_strategies() {
        let strategies = vec![
            RegularizationStrategy::None,
            RegularizationStrategy::L1(0.1),
            RegularizationStrategy::L2(0.1),
            RegularizationStrategy::ElasticNet {
                alpha: 0.1,
                l1_ratio: 0.5,
            },
            RegularizationStrategy::GroupLasso { alpha: 0.1 },
            RegularizationStrategy::NuclearNorm { alpha: 0.1 },
            RegularizationStrategy::TaskClustering {
                n_clusters: 2,
                intra_cluster_alpha: 0.1,
                inter_cluster_alpha: 0.01,
            },
            RegularizationStrategy::TaskRelationship {
                relationship_strength: 0.1,
                similarity_threshold: 0.5,
            },
            RegularizationStrategy::MetaLearning {
                meta_learning_rate: 0.01,
                inner_learning_rate: 0.1,
                n_inner_steps: 5,
            },
        ];

        assert_eq!(strategies.len(), 9);
        assert_eq!(strategies[0], RegularizationStrategy::None);
        assert_eq!(strategies[1], RegularizationStrategy::L1(0.1));
        assert_eq!(RegularizationStrategy::default(), RegularizationStrategy::None);

        // Every variant above must compare equal only to itself.
        for (i, a) in strategies.iter().enumerate() {
            for (j, b) in strategies.iter().enumerate() {
                assert_eq!(i == j, a == b, "strategies {} and {} compared wrongly", i, j);
            }
        }
    }

    #[test]
    fn test_task_similarity_methods() {
        let methods = vec![
            TaskSimilarityMethod::Correlation,
            TaskSimilarityMethod::Cosine,
            TaskSimilarityMethod::Euclidean,
            TaskSimilarityMethod::MutualInformation,
        ];

        assert_eq!(methods.len(), 4);
        assert_eq!(methods[0], TaskSimilarityMethod::Correlation);
        assert_eq!(methods[1], TaskSimilarityMethod::Cosine);
    }
}