sklears_multioutput/regularization/
task_relationship.rs

1//! Task Relationship Learning for Multi-Task Learning
2//!
3//! This method learns explicit relationships between tasks and uses this information
4//! to regularize the learning process. Tasks that are determined to be related
5//! are encouraged to have similar parameters.
6
7// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
8use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
9use scirs2_core::random::thread_rng;
10use scirs2_core::random::RandNormal;
11use sklears_core::{
12    error::{Result as SklResult, SklearsError},
13    traits::{Estimator, Fit, Predict, Untrained},
14    types::Float,
15};
16use std::collections::HashMap;
17
/// Methods for computing pairwise similarity between task output matrices.
///
/// The similarity scores drive relationship regularization: tasks scoring
/// above the configured threshold are pulled toward similar parameters.
// Fieldless enum: Copy/Eq/Hash are free and make call sites cheaper
// (no `.clone()` needed) and usable as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TaskSimilarityMethod {
    /// Absolute Pearson correlation of the flattened task outputs
    Correlation,
    /// Absolute cosine similarity of the flattened task outputs
    Cosine,
    /// Inverse Euclidean distance similarity: `1 / (1 + distance)`
    Euclidean,
    /// Mutual information-based similarity (approximated via correlation)
    MutualInformation,
}
30
/// Task Relationship Learning for Multi-Task Learning
///
/// This method learns explicit relationships between tasks and uses this information
/// to regularize the learning process. Tasks that are determined to be related
/// (pairwise similarity above `similarity_threshold`) are encouraged to have
/// similar parameters, with the pull scaled by `relationship_strength`.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::regularization::TaskRelationshipLearning;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
/// use std::collections::HashMap;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let mut y_tasks = HashMap::new();
/// y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
/// y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);
///
/// let task_relationship = TaskRelationshipLearning::new()
///     .relationship_strength(0.1)
///     .similarity_threshold(0.5)
///     .max_iter(1000);
/// ```
#[derive(Debug, Clone)]
pub struct TaskRelationshipLearning<S = Untrained> {
    // Type-state marker: `Untrained` before fitting, trained state afterwards.
    pub(crate) state: S,
    /// Strength of relationship regularization (default 1.0)
    pub(crate) relationship_strength: Float,
    /// Threshold for task similarity to be considered related (default 0.5)
    pub(crate) similarity_threshold: Float,
    /// Base regularization strength (default 1.0)
    ///
    /// NOTE(review): stored and settable via the builder, but not currently
    /// read by `fit` — confirm intended use.
    pub(crate) base_alpha: Float,
    /// Maximum gradient-descent iterations (default 1000)
    pub(crate) max_iter: usize,
    /// Convergence tolerance on the change in total squared-error loss (default 1e-4)
    pub(crate) tolerance: Float,
    /// Gradient-descent learning rate (default 0.01)
    pub(crate) learning_rate: Float,
    /// Expected output dimensionality per task name; `fit` infers the actual
    /// sizes from `y`, so this map is carried through as configuration only
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Include intercept term (default true)
    ///
    /// NOTE(review): intercepts are always fitted regardless of this flag —
    /// confirm whether `fit` should honor it.
    pub(crate) fit_intercept: bool,
    /// Method for computing task similarity (default `Correlation`)
    pub(crate) similarity_method: TaskSimilarityMethod,
}
78
/// Trained state for TaskRelationshipLearning
///
/// Produced by `fit`; holds the per-task linear models, the learned task
/// relationship matrix, and the hyperparameters that were used in training.
#[derive(Debug, Clone)]
pub struct TaskRelationshipLearningTrained {
    /// Coefficients for each task, shape `(n_features, n_outputs)`
    pub(crate) coefficients: HashMap<String, Array2<Float>>,
    /// Intercepts for each task, length `n_outputs`
    pub(crate) intercepts: HashMap<String, Array1<Float>>,
    /// Task relationship matrix (pairwise similarity scores; diagonal is 1.0)
    pub(crate) relationship_matrix: Array2<Float>,
    /// Task names, ordered to match the rows/columns of `relationship_matrix`
    pub(crate) task_names: Vec<String>,
    /// Number of input features seen during training
    pub(crate) n_features: usize,
    /// Task configurations copied from the untrained estimator
    pub(crate) task_outputs: HashMap<String, usize>,
    /// Relationship regularization strength used during training
    pub(crate) relationship_strength: Float,
    /// Similarity threshold used to decide which tasks count as related
    pub(crate) similarity_threshold: Float,
    /// Method used to build `relationship_matrix`
    pub(crate) similarity_method: TaskSimilarityMethod,
    /// Training iterations performed (may stop early on convergence)
    pub(crate) n_iter: usize,
}
101
102impl TaskRelationshipLearning<Untrained> {
103    /// Create a new TaskRelationshipLearning instance
104    pub fn new() -> Self {
105        Self {
106            state: Untrained,
107            relationship_strength: 1.0,
108            similarity_threshold: 0.5,
109            base_alpha: 1.0,
110            max_iter: 1000,
111            tolerance: 1e-4,
112            learning_rate: 0.01,
113            task_outputs: HashMap::new(),
114            fit_intercept: true,
115            similarity_method: TaskSimilarityMethod::Correlation,
116        }
117    }
118
119    /// Set relationship regularization strength
120    pub fn relationship_strength(mut self, strength: Float) -> Self {
121        self.relationship_strength = strength;
122        self
123    }
124
125    /// Set similarity threshold for relationships
126    pub fn similarity_threshold(mut self, threshold: Float) -> Self {
127        self.similarity_threshold = threshold;
128        self
129    }
130
131    /// Set base regularization strength
132    pub fn base_alpha(mut self, alpha: Float) -> Self {
133        self.base_alpha = alpha;
134        self
135    }
136
137    /// Set task similarity method
138    pub fn similarity_method(mut self, method: TaskSimilarityMethod) -> Self {
139        self.similarity_method = method;
140        self
141    }
142
143    /// Set maximum iterations
144    pub fn max_iter(mut self, max_iter: usize) -> Self {
145        self.max_iter = max_iter;
146        self
147    }
148
149    /// Set tolerance
150    pub fn tolerance(mut self, tolerance: Float) -> Self {
151        self.tolerance = tolerance;
152        self
153    }
154
155    /// Set learning rate
156    pub fn learning_rate(mut self, lr: Float) -> Self {
157        self.learning_rate = lr;
158        self
159    }
160
161    /// Set task outputs
162    pub fn task_outputs(mut self, outputs: &[(&str, usize)]) -> Self {
163        self.task_outputs = outputs
164            .iter()
165            .map(|(name, size)| (name.to_string(), *size))
166            .collect();
167        self
168    }
169}
170
171impl Default for TaskRelationshipLearning<Untrained> {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
impl Estimator for TaskRelationshipLearning<Untrained> {
    /// No separate configuration object: all settings live on the struct itself.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // A reference to the unit value is valid for any lifetime.
        &()
    }
}
186
187impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
188    for TaskRelationshipLearning<Untrained>
189{
190    type Fitted = TaskRelationshipLearning<TaskRelationshipLearningTrained>;
191
192    fn fit(
193        self,
194        X: &ArrayView2<'_, Float>,
195        y: &HashMap<String, Array2<Float>>,
196    ) -> SklResult<Self::Fitted> {
197        let x = X.to_owned();
198        let (n_samples, n_features) = x.dim();
199
200        if n_samples == 0 || n_features == 0 {
201            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
202        }
203
204        let task_names: Vec<String> = y.keys().cloned().collect();
205        let n_tasks = task_names.len();
206
207        // Initialize task coefficients
208        let mut task_coefficients: HashMap<String, Array2<Float>> = HashMap::new();
209        let mut task_intercepts: HashMap<String, Array1<Float>> = HashMap::new();
210
211        let mut rng_gen = thread_rng();
212
213        for (task_name, y_task) in y {
214            let n_outputs = y_task.ncols();
215            let mut coef = Array2::<Float>::zeros((n_features, n_outputs));
216            let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
217            for i in 0..n_features {
218                for j in 0..n_outputs {
219                    coef[[i, j]] = rng_gen.sample(normal_dist);
220                }
221            }
222            let intercept = Array1::<Float>::zeros(n_outputs);
223            task_coefficients.insert(task_name.clone(), coef);
224            task_intercepts.insert(task_name.clone(), intercept);
225        }
226
227        // Compute task similarity matrix
228        let mut relationship_matrix = Array2::<Float>::zeros((n_tasks, n_tasks));
229
230        for (i, task_i) in task_names.iter().enumerate() {
231            for (j, task_j) in task_names.iter().enumerate() {
232                if i != j {
233                    let similarity = self.compute_task_similarity(
234                        &y[task_i],
235                        &y[task_j],
236                        &self.similarity_method,
237                    );
238                    relationship_matrix[[i, j]] = similarity;
239                } else {
240                    relationship_matrix[[i, j]] = 1.0;
241                }
242            }
243        }
244
245        // Training loop
246        let mut prev_loss = Float::INFINITY;
247        let mut n_iter = 0;
248
249        for iteration in 0..self.max_iter {
250            let mut total_loss = 0.0;
251
252            // Update coefficients for each task
253            for (task_name, y_task) in y {
254                let current_coef = &task_coefficients[task_name];
255                let current_intercept = &task_intercepts[task_name];
256
257                // Compute predictions
258                let predictions = x.dot(current_coef);
259                let predictions_with_intercept = &predictions + current_intercept;
260
261                // Compute residuals
262                let residuals = &predictions_with_intercept - y_task;
263
264                // Compute gradients
265                let grad_coef = x.t().dot(&residuals) / (n_samples as Float);
266                let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);
267
268                // Add relationship regularization
269                let mut reg_grad_coef = grad_coef.clone();
270
271                // Find current task index
272                let task_idx = task_names.iter().position(|t| t == task_name).unwrap();
273
274                // Add relationship penalties
275                for (other_idx, other_task) in task_names.iter().enumerate() {
276                    if other_task != task_name {
277                        let similarity = relationship_matrix[[task_idx, other_idx]];
278                        if similarity > self.similarity_threshold {
279                            let relationship_penalty = &(current_coef
280                                - &task_coefficients[other_task])
281                                * self.relationship_strength
282                                * similarity;
283                            reg_grad_coef = reg_grad_coef + relationship_penalty;
284                        }
285                    }
286                }
287
288                // Update parameters
289                let new_coef = current_coef - &(&reg_grad_coef * self.learning_rate);
290                let new_intercept = current_intercept - &(&grad_intercept * self.learning_rate);
291
292                task_coefficients.insert(task_name.clone(), new_coef);
293                task_intercepts.insert(task_name.clone(), new_intercept);
294
295                // Add to loss
296                total_loss += residuals.mapv(|x| x * x).sum();
297            }
298
299            // Check convergence
300            if (prev_loss - total_loss).abs() < self.tolerance {
301                n_iter = iteration + 1;
302                break;
303            }
304            prev_loss = total_loss;
305            n_iter = iteration + 1;
306        }
307
308        Ok(TaskRelationshipLearning {
309            state: TaskRelationshipLearningTrained {
310                coefficients: task_coefficients,
311                intercepts: task_intercepts,
312                relationship_matrix,
313                task_names,
314                n_features,
315                task_outputs: self.task_outputs.clone(),
316                relationship_strength: self.relationship_strength,
317                similarity_threshold: self.similarity_threshold,
318                similarity_method: self.similarity_method.clone(),
319                n_iter,
320            },
321            relationship_strength: self.relationship_strength,
322            similarity_threshold: self.similarity_threshold,
323            base_alpha: self.base_alpha,
324            max_iter: self.max_iter,
325            tolerance: self.tolerance,
326            learning_rate: self.learning_rate,
327            task_outputs: self.task_outputs,
328            fit_intercept: self.fit_intercept,
329            similarity_method: self.similarity_method,
330        })
331    }
332}
333
334impl TaskRelationshipLearning<Untrained> {
335    fn compute_task_similarity(
336        &self,
337        y1: &Array2<Float>,
338        y2: &Array2<Float>,
339        method: &TaskSimilarityMethod,
340    ) -> Float {
341        match method {
342            TaskSimilarityMethod::Correlation => {
343                // Compute correlation between task outputs
344                let y1_flat: Vec<Float> = y1.iter().copied().collect();
345                let y2_flat: Vec<Float> = y2.iter().copied().collect();
346
347                if y1_flat.len() != y2_flat.len() {
348                    return 0.0;
349                }
350
351                let mean1: Float = y1_flat.iter().sum::<Float>() / y1_flat.len() as Float;
352                let mean2: Float = y2_flat.iter().sum::<Float>() / y2_flat.len() as Float;
353
354                let mut num = 0.0;
355                let mut den1 = 0.0;
356                let mut den2 = 0.0;
357
358                for (v1, v2) in y1_flat.iter().zip(y2_flat.iter()) {
359                    let d1 = v1 - mean1;
360                    let d2 = v2 - mean2;
361                    num += d1 * d2;
362                    den1 += d1 * d1;
363                    den2 += d2 * d2;
364                }
365
366                if den1 > 0.0 && den2 > 0.0 {
367                    (num / (den1.sqrt() * den2.sqrt())).abs()
368                } else {
369                    0.0
370                }
371            }
372            TaskSimilarityMethod::Cosine => {
373                // Compute cosine similarity
374                let y1_flat: Vec<Float> = y1.iter().copied().collect();
375                let y2_flat: Vec<Float> = y2.iter().copied().collect();
376
377                let dot_product: Float =
378                    y1_flat.iter().zip(y2_flat.iter()).map(|(a, b)| a * b).sum();
379                let norm1: Float = y1_flat.iter().map(|x| x * x).sum::<Float>().sqrt();
380                let norm2: Float = y2_flat.iter().map(|x| x * x).sum::<Float>().sqrt();
381
382                if norm1 > 0.0 && norm2 > 0.0 {
383                    (dot_product / (norm1 * norm2)).abs()
384                } else {
385                    0.0
386                }
387            }
388            TaskSimilarityMethod::Euclidean => {
389                // Compute inverse euclidean distance as similarity
390                let y1_flat: Vec<Float> = y1.iter().copied().collect();
391                let y2_flat: Vec<Float> = y2.iter().copied().collect();
392
393                let distance: Float = y1_flat
394                    .iter()
395                    .zip(y2_flat.iter())
396                    .map(|(a, b)| (a - b) * (a - b))
397                    .sum::<Float>()
398                    .sqrt();
399
400                1.0 / (1.0 + distance)
401            }
402            TaskSimilarityMethod::MutualInformation => {
403                // Simple approximation of mutual information using correlation
404                self.compute_task_similarity(y1, y2, &TaskSimilarityMethod::Correlation)
405            }
406        }
407    }
408}
409
410impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
411    for TaskRelationshipLearning<TaskRelationshipLearningTrained>
412{
413    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
414        let x = X.to_owned();
415        let (n_samples, n_features) = x.dim();
416
417        if n_features != self.state.n_features {
418            return Err(SklearsError::InvalidInput(
419                "Number of features doesn't match training data".to_string(),
420            ));
421        }
422
423        let mut predictions = HashMap::new();
424
425        for (task_name, coef) in &self.state.coefficients {
426            let task_predictions = x.dot(coef);
427            let intercept = &self.state.intercepts[task_name];
428            let final_predictions = &task_predictions + intercept;
429            predictions.insert(task_name.clone(), final_predictions);
430        }
431
432        Ok(predictions)
433    }
434}
435
436impl TaskRelationshipLearningTrained {
437    /// Get coefficients for a specific task
438    pub fn task_coefficients(&self, task_name: &str) -> Option<&Array2<Float>> {
439        self.coefficients.get(task_name)
440    }
441
442    /// Get intercepts for a specific task
443    pub fn task_intercepts(&self, task_name: &str) -> Option<&Array1<Float>> {
444        self.intercepts.get(task_name)
445    }
446
447    /// Get the relationship matrix (task similarity scores)
448    pub fn relationship_matrix(&self) -> &Array2<Float> {
449        &self.relationship_matrix
450    }
451
452    /// Get task names in order
453    pub fn task_names(&self) -> &Vec<String> {
454        &self.task_names
455    }
456
457    /// Get similarity score between two tasks
458    pub fn task_similarity(&self, task1: &str, task2: &str) -> Option<Float> {
459        let idx1 = self.task_names.iter().position(|t| t == task1)?;
460        let idx2 = self.task_names.iter().position(|t| t == task2)?;
461        Some(self.relationship_matrix[[idx1, idx2]])
462    }
463
464    /// Get related tasks for a given task (similarity above threshold)
465    pub fn related_tasks(&self, task_name: &str) -> Vec<(&String, Float)> {
466        if let Some(task_idx) = self.task_names.iter().position(|t| t == task_name) {
467            self.task_names
468                .iter()
469                .enumerate()
470                .filter_map(|(other_idx, other_task)| {
471                    if other_idx != task_idx {
472                        let similarity = self.relationship_matrix[[task_idx, other_idx]];
473                        if similarity > self.similarity_threshold {
474                            Some((other_task, similarity))
475                        } else {
476                            None
477                        }
478                    } else {
479                        None
480                    }
481                })
482                .collect()
483        } else {
484            Vec::new()
485        }
486    }
487
488    /// Get number of iterations performed
489    pub fn n_iter(&self) -> usize {
490        self.n_iter
491    }
492}