// sklears_multioutput/regularization/task_clustering.rs

use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Multi-task linear regression with task-clustering regularization.
///
/// Each task gets its own coefficient matrix; tasks assigned to the same
/// cluster are pulled toward their cluster mean, while a weaker coupling
/// (scaled by 0.1) is applied across cluster boundaries.
#[derive(Debug, Clone)]
pub struct TaskClusteringRegularization<S = Untrained> {
    pub(crate) state: S,
    /// Number of task clusters.
    pub(crate) n_clusters: usize,
    /// Strength of the penalty tying same-cluster coefficients together.
    pub(crate) intra_cluster_alpha: Float,
    /// Strength of the (weaker) coupling applied across clusters.
    pub(crate) inter_cluster_alpha: Float,
    /// Maximum number of gradient-descent iterations.
    pub(crate) max_iter: usize,
    /// Convergence threshold on the change in total squared loss.
    pub(crate) tolerance: Float,
    /// Gradient-descent step size.
    pub(crate) learning_rate: Float,
    /// Expected number of outputs per task, keyed by task name.
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Whether to fit per-task intercepts.
    pub(crate) fit_intercept: bool,
    /// Optional seed for reproducible initialization.
    pub(crate) random_state: Option<u64>,
}

/// State of a fitted [`TaskClusteringRegularization`] model.
#[derive(Debug, Clone)]
pub struct TaskClusteringRegressionTrained {
    /// Learned coefficient matrix (n_features x n_outputs) per task.
    pub(crate) coefficients: HashMap<String, Array2<Float>>,
    /// Learned intercept vector per task.
    pub(crate) intercepts: HashMap<String, Array1<Float>>,
    /// Cluster assignment per task.
    pub(crate) task_clusters: HashMap<String, usize>,
    /// Cluster centroids over flattened coefficient vectors.
    pub(crate) cluster_centroids: Array2<Float>,
    /// Number of features seen during fitting.
    pub(crate) n_features: usize,
    /// Configured output sizes per task.
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Number of clusters used during fitting.
    pub(crate) n_clusters: usize,
    /// Intra-cluster regularization strength used during fitting.
    pub(crate) intra_cluster_alpha: Float,
    /// Inter-cluster regularization strength used during fitting.
    pub(crate) inter_cluster_alpha: Float,
    /// Number of iterations actually run before convergence or `max_iter`.
    pub(crate) n_iter: usize,
}

impl TaskClusteringRegularization<Untrained> {
    /// Create an untrained model with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_clusters: 2,
            intra_cluster_alpha: 1.0,
            inter_cluster_alpha: 0.1,
            max_iter: 1000,
            tolerance: 1e-4,
            learning_rate: 0.01,
            task_outputs: HashMap::new(),
            fit_intercept: true,
            random_state: None,
        }
    }

    /// Set the number of task clusters.
    pub fn n_clusters(mut self, n_clusters: usize) -> Self {
        self.n_clusters = n_clusters;
        self
    }

    /// Set the intra-cluster regularization strength.
    pub fn intra_cluster_alpha(mut self, alpha: Float) -> Self {
        self.intra_cluster_alpha = alpha;
        self
    }

    /// Set the inter-cluster regularization strength.
    pub fn inter_cluster_alpha(mut self, alpha: Float) -> Self {
        self.inter_cluster_alpha = alpha;
        self
    }

    /// Set the maximum number of gradient-descent iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the convergence tolerance on the change in total loss.
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the gradient-descent learning rate.
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the random seed for coefficient initialization.
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Declare the expected output size for each named task.
    pub fn task_outputs(mut self, outputs: &[(&str, usize)]) -> Self {
        self.task_outputs = outputs
            .iter()
            .map(|(name, size)| (name.to_string(), *size))
            .collect();
        self
    }
}
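
// A hedged configuration sketch of the builder chain above; the task names
// and hyperparameter values are illustrative only, not defaults or
// recommendations from this crate:
//
//   let model = TaskClusteringRegularization::new()
//       .n_clusters(3)
//       .intra_cluster_alpha(0.5)
//       .inter_cluster_alpha(0.05)
//       .task_outputs(&[("task_a", 2), ("task_b", 1)])
//       .random_state(42);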

impl Default for TaskClusteringRegularization<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for TaskClusteringRegularization<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for TaskClusteringRegularization<Untrained>
{
    type Fitted = TaskClusteringRegularization<TaskClusteringRegressionTrained>;

    fn fit(
        self,
        X: &ArrayView2<'_, Float>,
        y: &HashMap<String, Array2<Float>>,
    ) -> SklResult<Self::Fitted> {
        let x = X.to_owned();
        let (n_samples, n_features) = x.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        if self.n_clusters == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of clusters must be > 0".to_string(),
            ));
        }

        // Initialize each task's coefficients with small Gaussian noise and
        // its intercepts with zeros.
        // NOTE: `random_state` is stored but not consumed here; initialization
        // always draws from `thread_rng()`, so runs are not yet reproducible.
        let mut task_coefficients: HashMap<String, Array2<Float>> = HashMap::new();
        let mut task_intercepts: HashMap<String, Array1<Float>> = HashMap::new();

        let mut rng_gen = thread_rng();

        for (task_name, y_task) in y {
            let n_outputs = y_task.ncols();
            let mut coef = Array2::<Float>::zeros((n_features, n_outputs));
            let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
            for i in 0..n_features {
                for j in 0..n_outputs {
                    coef[[i, j]] = rng_gen.sample(normal_dist);
                }
            }
            let intercept = Array1::<Float>::zeros(n_outputs);
            task_coefficients.insert(task_name.clone(), coef);
            task_intercepts.insert(task_name.clone(), intercept);
        }

        let task_names: Vec<String> = y.keys().cloned().collect();
        let _n_tasks = task_names.len();

        // Flattened coefficient vectors per task (computed but not consumed
        // anywhere below in this implementation).
        let mut task_vectors = Vec::new();
        for task_name in &task_names {
            let coef = &task_coefficients[task_name];
            let flattened: Vec<Float> = coef.iter().copied().collect();
            task_vectors.push(flattened);
        }

        // Centroid buffer over flattened coefficients. The sizing assumes all
        // tasks share the first task's output width, and the centroids are
        // never updated after this zero initialization.
        let mut task_clusters: HashMap<String, usize> = HashMap::new();
        let cluster_centroids =
            Array2::<Float>::zeros((self.n_clusters, n_features * y[&task_names[0]].ncols()));

        // Assign tasks to clusters round-robin as a fixed initial partition.
        for (i, task_name) in task_names.iter().enumerate() {
            task_clusters.insert(task_name.clone(), i % self.n_clusters);
        }
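
        // A sketch of the per-task objective whose gradients the loop below
        // applies (notation ours, not from the crate): for task t with
        // weights W_t, intercept b_t, cluster c(t), and cluster mean M_{c(t)},
        //
        //   L_t = (1/2n) * ||X W_t + b_t - Y_t||^2
        //       + (alpha_intra / 2) * ||W_t - M_{c(t)}||^2
        //       + (0.1 * alpha_inter / 2) * sum_{c(s) != c(t)} ||W_t - W_s||^2
        //
        // Note that the inter-cluster term, as implemented, still couples
        // tasks across clusters (it has the same sign as the intra term);
        // the 0.1 factor merely keeps that coupling weak.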

        let mut prev_loss = Float::INFINITY;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            let mut total_loss = 0.0;

            for (task_name, y_task) in y {
                let task_cluster = task_clusters[task_name];
                let current_coef = &task_coefficients[task_name];
                let current_intercept = &task_intercepts[task_name];

                let predictions = x.dot(current_coef);
                let predictions_with_intercept = &predictions + current_intercept;

                let residuals = &predictions_with_intercept - y_task;

                // Data-term gradients for the mean-squared-error loss.
                let grad_coef = x.t().dot(&residuals) / (n_samples as Float);
                let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);

                let mut reg_grad_coef = grad_coef.clone();

                // Mean coefficient of the other tasks in this task's cluster.
                let mut cluster_center: Array2<Float> = Array2::<Float>::zeros(current_coef.dim());
                let mut cluster_count = 0;

                for (other_task, other_cluster) in &task_clusters {
                    if *other_cluster == task_cluster && other_task != task_name {
                        cluster_center = &cluster_center + &task_coefficients[other_task];
                        cluster_count += 1;
                    }
                }

                // Intra-cluster penalty: pull this task toward its cluster mean.
                if cluster_count > 0 {
                    cluster_center /= cluster_count as Float;
                    let intra_penalty =
                        &(current_coef - &cluster_center) * self.intra_cluster_alpha;
                    reg_grad_coef = reg_grad_coef + intra_penalty;
                }

                // Inter-cluster term: a weaker coupling (scaled by 0.1) to
                // tasks in other clusters.
                for (other_task, other_cluster) in &task_clusters {
                    if *other_cluster != task_cluster {
                        let inter_penalty = &(current_coef - &task_coefficients[other_task])
                            * self.inter_cluster_alpha
                            * 0.1;
                        reg_grad_coef = reg_grad_coef + inter_penalty;
                    }
                }

                // Gradient-descent update of coefficients and intercepts.
                let new_coef = current_coef - &(&reg_grad_coef * self.learning_rate);
                let new_intercept = current_intercept - &(&grad_intercept * self.learning_rate);

                task_coefficients.insert(task_name.clone(), new_coef);
                task_intercepts.insert(task_name.clone(), new_intercept);

                total_loss += residuals.mapv(|r| r * r).sum();
            }

            // Stop once the total squared loss changes by less than `tolerance`.
            if (prev_loss - total_loss).abs() < self.tolerance {
                n_iter = iteration + 1;
                break;
            }
            prev_loss = total_loss;
            n_iter = iteration + 1;
        }

        Ok(TaskClusteringRegularization {
            state: TaskClusteringRegressionTrained {
                coefficients: task_coefficients,
                intercepts: task_intercepts,
                task_clusters,
                cluster_centroids,
                n_features,
                task_outputs: self.task_outputs.clone(),
                n_clusters: self.n_clusters,
                intra_cluster_alpha: self.intra_cluster_alpha,
                inter_cluster_alpha: self.inter_cluster_alpha,
                n_iter,
            },
            n_clusters: self.n_clusters,
            intra_cluster_alpha: self.intra_cluster_alpha,
            inter_cluster_alpha: self.inter_cluster_alpha,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            learning_rate: self.learning_rate,
            task_outputs: self.task_outputs,
            fit_intercept: self.fit_intercept,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for TaskClusteringRegularization<TaskClusteringRegressionTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
        let x = X.to_owned();
        let (_n_samples, n_features) = x.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features ({}) doesn't match training data ({})",
                n_features, self.state.n_features
            )));
        }

        let mut predictions = HashMap::new();

        // Apply each task's linear model: X * W_t + b_t.
        for (task_name, coef) in &self.state.coefficients {
            let task_predictions = x.dot(coef);
            let intercept = &self.state.intercepts[task_name];
            let final_predictions = &task_predictions + intercept;
            predictions.insert(task_name.clone(), final_predictions);
        }

        Ok(predictions)
    }
}

impl TaskClusteringRegressionTrained {
    /// Learned coefficients for `task_name`, if that task was fitted.
    pub fn task_coefficients(&self, task_name: &str) -> Option<&Array2<Float>> {
        self.coefficients.get(task_name)
    }

    /// Learned intercepts for `task_name`, if that task was fitted.
    pub fn task_intercepts(&self, task_name: &str) -> Option<&Array1<Float>> {
        self.intercepts.get(task_name)
    }

    /// Cluster id assigned to `task_name`, if that task was fitted.
    pub fn task_cluster(&self, task_name: &str) -> Option<usize> {
        self.task_clusters.get(task_name).copied()
    }

    /// All task-to-cluster assignments.
    pub fn task_clusters(&self) -> &HashMap<String, usize> {
        &self.task_clusters
    }

    /// Cluster centroids over flattened coefficient vectors.
    pub fn cluster_centroids(&self) -> &Array2<Float> {
        &self.cluster_centroids
    }

    /// Number of gradient-descent iterations actually run.
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }

    /// Names of all tasks assigned to `cluster_id`.
    pub fn cluster_tasks(&self, cluster_id: usize) -> Vec<&String> {
        self.task_clusters
            .iter()
            .filter_map(|(task_name, &cluster)| {
                if cluster == cluster_id {
                    Some(task_name)
                } else {
                    None
                }
            })
            .collect()
    }
}
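
// A minimal end-to-end sketch of the fit/predict flow above. The shapes,
// task names, and iteration counts are illustrative assumptions, not
// canonical values from this crate.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn fits_and_predicts_two_tasks() {
        // Small synthetic design matrix: 8 samples, 3 features.
        let x = Array2::<Float>::from_shape_fn((8, 3), |(i, j)| (i + j) as Float);

        // Two tasks with 2 outputs each.
        let mut y = HashMap::new();
        y.insert("task_a".to_string(), Array2::<Float>::ones((8, 2)));
        y.insert("task_b".to_string(), Array2::<Float>::zeros((8, 2)));

        let model = TaskClusteringRegularization::new()
            .n_clusters(2)
            .max_iter(50)
            .learning_rate(0.01)
            .fit(&x.view(), &y)
            .expect("fit should succeed on well-shaped input");

        // Every task gets predictions with the same shape as its targets.
        let preds = model.predict(&x.view()).expect("predict should succeed");
        assert_eq!(preds["task_a"].dim(), (8, 2));
        assert_eq!(preds["task_b"].dim(), (8, 2));

        // With two tasks and two clusters, the round-robin initialization
        // uses both cluster ids.
        let mut ids: Vec<usize> = model.state.task_clusters().values().copied().collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 1]);
    }
}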