sklears_multioutput/regularization/task_clustering.rs

//! Task Clustering Regularization for Multi-Task Learning
//!
//! This method clusters tasks based on their similarity and applies different
//! regularization strengths within and across clusters. Tasks in the same cluster
//! are encouraged to have similar parameters, while tasks in different clusters
//! are allowed to be more different.
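//!
//! Roughly speaking, the gradient updates below pull each task's coefficients toward the
//! mean coefficients of the other tasks in its cluster with strength `intra_cluster_alpha`,
//! and apply a much weaker pull (scaled by `0.1 * inter_cluster_alpha`) toward the
//! coefficients of tasks in other clusters.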

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Task Clustering Regularization for Multi-Task Learning
///
/// This method clusters tasks based on their similarity and applies different
/// regularization strengths within and across clusters. Tasks in the same cluster
/// are encouraged to have similar parameters, while tasks in different clusters
/// are allowed to be more different.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::regularization::TaskClusteringRegularization;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
/// use std::collections::HashMap;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let mut y_tasks = HashMap::new();
/// y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
/// y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);
/// y_tasks.insert("task3".to_string(), array![[2.0], [3.0], [2.2], [3.1]]);
///
/// let task_clustering = TaskClusteringRegularization::new()
///     .n_clusters(2)
///     .intra_cluster_alpha(0.1)  // Strong regularization within clusters
///     .inter_cluster_alpha(0.01) // Weak regularization across clusters
///     .max_iter(1000);
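///
/// // A possible continuation (illustrative sketch): declare task output sizes,
/// // fit on the three tasks, then predict back on the training inputs.
/// let trained = task_clustering
///     .task_outputs(&[("task1", 1), ("task2", 1), ("task3", 1)])
///     .fit(&X.view(), &y_tasks)
///     .unwrap();
/// let predictions = trained.predict(&X.view()).unwrap();
/// assert_eq!(predictions["task1"].dim(), (4, 1));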
/// ```
#[derive(Debug, Clone)]
pub struct TaskClusteringRegularization<S = Untrained> {
    pub(crate) state: S,
    /// Number of task clusters
    pub(crate) n_clusters: usize,
    /// Regularization strength within clusters
    pub(crate) intra_cluster_alpha: Float,
    /// Regularization strength across clusters
    pub(crate) inter_cluster_alpha: Float,
    /// Maximum iterations for clustering
    pub(crate) max_iter: usize,
    /// Convergence tolerance
    pub(crate) tolerance: Float,
    /// Learning rate
    pub(crate) learning_rate: Float,
    /// Task configurations
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Include intercept term
    pub(crate) fit_intercept: bool,
    /// Random state for reproducible clustering
    pub(crate) random_state: Option<u64>,
}

/// Trained state for TaskClusteringRegularization
#[derive(Debug, Clone)]
pub struct TaskClusteringRegressionTrained {
    /// Coefficients for each task
    pub(crate) coefficients: HashMap<String, Array2<Float>>,
    /// Intercepts for each task
    pub(crate) intercepts: HashMap<String, Array1<Float>>,
    /// Task cluster assignments
    pub(crate) task_clusters: HashMap<String, usize>,
    /// Cluster centroids for task parameters
    pub(crate) cluster_centroids: Array2<Float>,
    /// Number of input features
    pub(crate) n_features: usize,
    /// Task configurations
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Training parameters
    pub(crate) n_clusters: usize,
    pub(crate) intra_cluster_alpha: Float,
    pub(crate) inter_cluster_alpha: Float,
    /// Training iterations performed
    pub(crate) n_iter: usize,
}

impl TaskClusteringRegularization<Untrained> {
    /// Create a new TaskClusteringRegularization instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_clusters: 2,
            intra_cluster_alpha: 1.0,
            inter_cluster_alpha: 0.1,
            max_iter: 1000,
            tolerance: 1e-4,
            learning_rate: 0.01,
            task_outputs: HashMap::new(),
            fit_intercept: true,
            random_state: None,
        }
    }

    /// Set number of task clusters
    pub fn n_clusters(mut self, n_clusters: usize) -> Self {
        self.n_clusters = n_clusters;
        self
    }

    /// Set intra-cluster regularization strength
    pub fn intra_cluster_alpha(mut self, alpha: Float) -> Self {
        self.intra_cluster_alpha = alpha;
        self
    }

    /// Set inter-cluster regularization strength
    pub fn inter_cluster_alpha(mut self, alpha: Float) -> Self {
        self.inter_cluster_alpha = alpha;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set random state for reproducible clustering
    pub fn random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set task outputs
    pub fn task_outputs(mut self, outputs: &[(&str, usize)]) -> Self {
        self.task_outputs = outputs
            .iter()
            .map(|(name, size)| (name.to_string(), *size))
            .collect();
        self
    }
}

impl Default for TaskClusteringRegularization<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for TaskClusteringRegularization<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for TaskClusteringRegularization<Untrained>
{
    type Fitted = TaskClusteringRegularization<TaskClusteringRegressionTrained>;

    fn fit(
        self,
        X: &ArrayView2<'_, Float>,
        y: &HashMap<String, Array2<Float>>,
    ) -> SklResult<Self::Fitted> {
        let x = X.to_owned();
        let (n_samples, n_features) = x.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        if self.n_clusters == 0 {
            return Err(SklearsError::InvalidInput(
                "Number of clusters must be > 0".to_string(),
            ));
        }

        if y.is_empty() {
            return Err(SklearsError::InvalidInput(
                "At least one task must be provided".to_string(),
            ));
        }

        // Initialize task coefficients randomly
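        // Note: `random_state` is stored on the builder but not consumed here; initialization
        // draws from `thread_rng()`, so the random starting point is not reproducible across runs.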
        let mut task_coefficients: HashMap<String, Array2<Float>> = HashMap::new();
        let mut task_intercepts: HashMap<String, Array1<Float>> = HashMap::new();

        let mut rng_gen = thread_rng();

        for (task_name, y_task) in y {
            let n_outputs = y_task.ncols();
            let mut coef = Array2::<Float>::zeros((n_features, n_outputs));
            let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
            for i in 0..n_features {
                for j in 0..n_outputs {
                    coef[[i, j]] = rng_gen.sample(normal_dist);
                }
            }
            let intercept = Array1::<Float>::zeros(n_outputs);
            task_coefficients.insert(task_name.clone(), coef);
            task_intercepts.insert(task_name.clone(), intercept);
        }

        // Cluster tasks by their parameters. The current implementation uses a simple
        // round-robin assignment (see below) rather than a full k-means refinement.
        let task_names: Vec<String> = y.keys().cloned().collect();
        let _n_tasks = task_names.len();

        // Flatten coefficients per task; retained for a future centroid-based clustering
        // step but not used by the round-robin assignment below.
        let mut _task_vectors = Vec::new();
        for task_name in &task_names {
            let coef = &task_coefficients[task_name];
            let flattened: Vec<Float> = coef.iter().copied().collect();
            _task_vectors.push(flattened);
        }

        let mut task_clusters: HashMap<String, usize> = HashMap::new();
        // Placeholder centroids (all zeros); stored in the trained state but not refined
        // during training.
        let cluster_centroids =
            Array2::<Float>::zeros((self.n_clusters, n_features * y[&task_names[0]].ncols()));

        // Assign tasks to clusters round-robin
        for (i, task_name) in task_names.iter().enumerate() {
            task_clusters.insert(task_name.clone(), i % self.n_clusters);
        }

        // Training loop with task clustering
        let mut prev_loss = Float::INFINITY;
        let mut n_iter = 0;

        for iteration in 0..self.max_iter {
            let mut total_loss = 0.0;

            // Update coefficients for each task
            for (task_name, y_task) in y {
                let task_cluster = task_clusters[task_name];
                let current_coef = &task_coefficients[task_name];
                let current_intercept = &task_intercepts[task_name];

                // Compute predictions
                let predictions = x.dot(current_coef);
                let predictions_with_intercept = &predictions + current_intercept;

                // Compute residuals
                let residuals = &predictions_with_intercept - y_task;

                // Compute gradients
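                // (mean-squared-error gradients: (1/n) * X^T * (prediction - target) for the
                // coefficients, and the per-output mean residual for the intercepts)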
                let grad_coef = x.t().dot(&residuals) / (n_samples as Float);
                let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);

                // Add clustering regularization
                let mut reg_grad_coef = grad_coef.clone();

                // Intra-cluster regularization
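                // Gradient of the intra-cluster penalty: pull this task's coefficients toward
                // the mean of the coefficients of the *other* tasks in the same cluster.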
                let mut cluster_center: Array2<Float> = Array2::<Float>::zeros(current_coef.dim());
                let mut cluster_count = 0;

                for (other_task, other_cluster) in &task_clusters {
                    if *other_cluster == task_cluster && other_task != task_name {
                        cluster_center = &cluster_center + &task_coefficients[other_task];
                        cluster_count += 1;
                    }
                }

                if cluster_count > 0 {
                    cluster_center /= cluster_count as Float;
                    let intra_penalty =
                        &(current_coef - &cluster_center) * self.intra_cluster_alpha;
                    reg_grad_coef = reg_grad_coef + intra_penalty;
                }

                // Inter-cluster regularization (weaker)
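                // Gradient of the inter-cluster penalty: a much weaker pull (scaled by 0.1)
                // toward the coefficients of each task assigned to a different cluster.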
                for (other_task, other_cluster) in &task_clusters {
                    if *other_cluster != task_cluster {
                        let inter_penalty = &(current_coef - &task_coefficients[other_task])
                            * self.inter_cluster_alpha
                            * 0.1;
                        reg_grad_coef = reg_grad_coef + inter_penalty;
                    }
                }

                // Update parameters
                let new_coef = current_coef - &(&reg_grad_coef * self.learning_rate);
                let new_intercept = current_intercept - &(&grad_intercept * self.learning_rate);

                task_coefficients.insert(task_name.clone(), new_coef);
                task_intercepts.insert(task_name.clone(), new_intercept);

                // Add to loss
                total_loss += residuals.mapv(|r| r * r).sum();
            }

            // Check convergence
            if (prev_loss - total_loss).abs() < self.tolerance {
                n_iter = iteration + 1;
                break;
            }
            prev_loss = total_loss;
            n_iter = iteration + 1;
        }

        Ok(TaskClusteringRegularization {
            state: TaskClusteringRegressionTrained {
                coefficients: task_coefficients,
                intercepts: task_intercepts,
                task_clusters,
                cluster_centroids,
                n_features,
                task_outputs: self.task_outputs.clone(),
                n_clusters: self.n_clusters,
                intra_cluster_alpha: self.intra_cluster_alpha,
                inter_cluster_alpha: self.inter_cluster_alpha,
                n_iter,
            },
            n_clusters: self.n_clusters,
            intra_cluster_alpha: self.intra_cluster_alpha,
            inter_cluster_alpha: self.inter_cluster_alpha,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            learning_rate: self.learning_rate,
            task_outputs: self.task_outputs,
            fit_intercept: self.fit_intercept,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for TaskClusteringRegularization<TaskClusteringRegressionTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
        let x = X.to_owned();
        let (_n_samples, n_features) = x.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features ({}) doesn't match training data ({})",
                n_features, self.state.n_features
            )));
        }

        let mut predictions = HashMap::new();

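        // Per-task linear prediction: X . W_t + b_t, with the intercept broadcast across rows.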
        for (task_name, coef) in &self.state.coefficients {
            let task_predictions = x.dot(coef);
            let intercept = &self.state.intercepts[task_name];
            let final_predictions = &task_predictions + intercept;
            predictions.insert(task_name.clone(), final_predictions);
        }

        Ok(predictions)
    }
}

impl TaskClusteringRegressionTrained {
    /// Get coefficients for a specific task
    pub fn task_coefficients(&self, task_name: &str) -> Option<&Array2<Float>> {
        self.coefficients.get(task_name)
    }

    /// Get intercepts for a specific task
    pub fn task_intercepts(&self, task_name: &str) -> Option<&Array1<Float>> {
        self.intercepts.get(task_name)
    }

    /// Get cluster assignment for a task
    pub fn task_cluster(&self, task_name: &str) -> Option<usize> {
        self.task_clusters.get(task_name).copied()
    }

    /// Get all task cluster assignments
    pub fn task_clusters(&self) -> &HashMap<String, usize> {
        &self.task_clusters
    }

    /// Get cluster centroids
    pub fn cluster_centroids(&self) -> &Array2<Float> {
        &self.cluster_centroids
    }

    /// Get number of iterations performed
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }

    /// Get tasks in a specific cluster
    pub fn cluster_tasks(&self, cluster_id: usize) -> Vec<&String> {
        self.task_clusters
            .iter()
            .filter_map(|(task_name, &cluster)| {
                if cluster == cluster_id {
                    Some(task_name)
                } else {
                    None
                }
            })
            .collect()
    }
}