// sklears_multioutput/regularization/task_relationship.rs
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, Axis};
9use scirs2_core::random::thread_rng;
10use scirs2_core::random::RandNormal;
11use sklears_core::{
12 error::{Result as SklResult, SklearsError},
13 traits::{Estimator, Fit, Predict, Untrained},
14 types::Float,
15};
16use std::collections::HashMap;
17
/// Method used to measure similarity between two tasks' target matrices.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskSimilarityMethod {
    /// Absolute Pearson correlation of the flattened targets.
    Correlation,
    /// Absolute cosine similarity of the flattened targets.
    Cosine,
    /// `1 / (1 + euclidean_distance)` over the flattened targets.
    Euclidean,
    /// Mutual information (currently implemented as a fallback to `Correlation`).
    MutualInformation,
}
30
/// Multi-task linear regression that learns pairwise task relationships:
/// during gradient-descent fitting, tasks whose targets are sufficiently
/// similar pull each other's coefficients together via an extra penalty.
#[derive(Debug, Clone)]
pub struct TaskRelationshipLearning<S = Untrained> {
    // Type-state marker: `Untrained` or `TaskRelationshipLearningTrained`.
    pub(crate) state: S,
    // Weight of the cross-task coefficient-coupling penalty.
    pub(crate) relationship_strength: Float,
    // Minimum similarity for two tasks to regularize each other.
    pub(crate) similarity_threshold: Float,
    // Base regularization strength; not read by the visible `fit` — TODO confirm intended use.
    pub(crate) base_alpha: Float,
    // Maximum number of gradient-descent iterations.
    pub(crate) max_iter: usize,
    // Convergence tolerance on the change in total squared-error loss.
    pub(crate) tolerance: Float,
    // Gradient-descent step size.
    pub(crate) learning_rate: Float,
    // Map of task name -> declared output count (metadata; not read by the visible `fit`).
    pub(crate) task_outputs: HashMap<String, usize>,
    // Whether to fit intercepts; the visible `fit` always updates them — TODO confirm.
    pub(crate) fit_intercept: bool,
    // How pairwise task similarity is computed.
    pub(crate) similarity_method: TaskSimilarityMethod,
}
78
/// Fitted state for [`TaskRelationshipLearning`]: one linear model per task
/// plus the learned pairwise task-similarity matrix.
#[derive(Debug, Clone)]
pub struct TaskRelationshipLearningTrained {
    // Per-task coefficient matrices (n_features x n_outputs).
    pub(crate) coefficients: HashMap<String, Array2<Float>>,
    // Per-task intercept vectors (length n_outputs).
    pub(crate) intercepts: HashMap<String, Array1<Float>>,
    // Pairwise task similarities; rows/columns follow `task_names` order.
    pub(crate) relationship_matrix: Array2<Float>,
    // Task names defining the index order of `relationship_matrix`.
    pub(crate) task_names: Vec<String>,
    // Number of input features seen during fitting.
    pub(crate) n_features: usize,
    // Task-output metadata copied from the untrained configuration.
    pub(crate) task_outputs: HashMap<String, usize>,
    // Relationship-penalty weight used during fitting.
    pub(crate) relationship_strength: Float,
    // Similarity threshold used during fitting and by `related_tasks`.
    pub(crate) similarity_threshold: Float,
    // Similarity method used to build `relationship_matrix`.
    pub(crate) similarity_method: TaskSimilarityMethod,
    // Number of gradient-descent iterations actually run.
    pub(crate) n_iter: usize,
}
101
102impl TaskRelationshipLearning<Untrained> {
103 pub fn new() -> Self {
105 Self {
106 state: Untrained,
107 relationship_strength: 1.0,
108 similarity_threshold: 0.5,
109 base_alpha: 1.0,
110 max_iter: 1000,
111 tolerance: 1e-4,
112 learning_rate: 0.01,
113 task_outputs: HashMap::new(),
114 fit_intercept: true,
115 similarity_method: TaskSimilarityMethod::Correlation,
116 }
117 }
118
119 pub fn relationship_strength(mut self, strength: Float) -> Self {
121 self.relationship_strength = strength;
122 self
123 }
124
125 pub fn similarity_threshold(mut self, threshold: Float) -> Self {
127 self.similarity_threshold = threshold;
128 self
129 }
130
131 pub fn base_alpha(mut self, alpha: Float) -> Self {
133 self.base_alpha = alpha;
134 self
135 }
136
137 pub fn similarity_method(mut self, method: TaskSimilarityMethod) -> Self {
139 self.similarity_method = method;
140 self
141 }
142
143 pub fn max_iter(mut self, max_iter: usize) -> Self {
145 self.max_iter = max_iter;
146 self
147 }
148
149 pub fn tolerance(mut self, tolerance: Float) -> Self {
151 self.tolerance = tolerance;
152 self
153 }
154
155 pub fn learning_rate(mut self, lr: Float) -> Self {
157 self.learning_rate = lr;
158 self
159 }
160
161 pub fn task_outputs(mut self, outputs: &[(&str, usize)]) -> Self {
163 self.task_outputs = outputs
164 .iter()
165 .map(|(name, size)| (name.to_string(), *size))
166 .collect();
167 self
168 }
169}
170
171impl Default for TaskRelationshipLearning<Untrained> {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
// Minimal `Estimator` implementation: this estimator carries no separate
// configuration object, so `Config` is the unit type.
impl Estimator for TaskRelationshipLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a reference to a zero-sized constant; constant promotion
        // gives it a `'static` lifetime, so returning it is sound.
        &()
    }
}
186
187impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
188 for TaskRelationshipLearning<Untrained>
189{
190 type Fitted = TaskRelationshipLearning<TaskRelationshipLearningTrained>;
191
192 fn fit(
193 self,
194 X: &ArrayView2<'_, Float>,
195 y: &HashMap<String, Array2<Float>>,
196 ) -> SklResult<Self::Fitted> {
197 let x = X.to_owned();
198 let (n_samples, n_features) = x.dim();
199
200 if n_samples == 0 || n_features == 0 {
201 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
202 }
203
204 let task_names: Vec<String> = y.keys().cloned().collect();
205 let n_tasks = task_names.len();
206
207 let mut task_coefficients: HashMap<String, Array2<Float>> = HashMap::new();
209 let mut task_intercepts: HashMap<String, Array1<Float>> = HashMap::new();
210
211 let mut rng_gen = thread_rng();
212
213 for (task_name, y_task) in y {
214 let n_outputs = y_task.ncols();
215 let mut coef = Array2::<Float>::zeros((n_features, n_outputs));
216 let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
217 for i in 0..n_features {
218 for j in 0..n_outputs {
219 coef[[i, j]] = rng_gen.sample(normal_dist);
220 }
221 }
222 let intercept = Array1::<Float>::zeros(n_outputs);
223 task_coefficients.insert(task_name.clone(), coef);
224 task_intercepts.insert(task_name.clone(), intercept);
225 }
226
227 let mut relationship_matrix = Array2::<Float>::zeros((n_tasks, n_tasks));
229
230 for (i, task_i) in task_names.iter().enumerate() {
231 for (j, task_j) in task_names.iter().enumerate() {
232 if i != j {
233 let similarity = self.compute_task_similarity(
234 &y[task_i],
235 &y[task_j],
236 &self.similarity_method,
237 );
238 relationship_matrix[[i, j]] = similarity;
239 } else {
240 relationship_matrix[[i, j]] = 1.0;
241 }
242 }
243 }
244
245 let mut prev_loss = Float::INFINITY;
247 let mut n_iter = 0;
248
249 for iteration in 0..self.max_iter {
250 let mut total_loss = 0.0;
251
252 for (task_name, y_task) in y {
254 let current_coef = &task_coefficients[task_name];
255 let current_intercept = &task_intercepts[task_name];
256
257 let predictions = x.dot(current_coef);
259 let predictions_with_intercept = &predictions + current_intercept;
260
261 let residuals = &predictions_with_intercept - y_task;
263
264 let grad_coef = x.t().dot(&residuals) / (n_samples as Float);
266 let grad_intercept = residuals.sum_axis(Axis(0)) / (n_samples as Float);
267
268 let mut reg_grad_coef = grad_coef.clone();
270
271 let task_idx = task_names.iter().position(|t| t == task_name).unwrap();
273
274 for (other_idx, other_task) in task_names.iter().enumerate() {
276 if other_task != task_name {
277 let similarity = relationship_matrix[[task_idx, other_idx]];
278 if similarity > self.similarity_threshold {
279 let relationship_penalty = &(current_coef
280 - &task_coefficients[other_task])
281 * self.relationship_strength
282 * similarity;
283 reg_grad_coef = reg_grad_coef + relationship_penalty;
284 }
285 }
286 }
287
288 let new_coef = current_coef - &(®_grad_coef * self.learning_rate);
290 let new_intercept = current_intercept - &(&grad_intercept * self.learning_rate);
291
292 task_coefficients.insert(task_name.clone(), new_coef);
293 task_intercepts.insert(task_name.clone(), new_intercept);
294
295 total_loss += residuals.mapv(|x| x * x).sum();
297 }
298
299 if (prev_loss - total_loss).abs() < self.tolerance {
301 n_iter = iteration + 1;
302 break;
303 }
304 prev_loss = total_loss;
305 n_iter = iteration + 1;
306 }
307
308 Ok(TaskRelationshipLearning {
309 state: TaskRelationshipLearningTrained {
310 coefficients: task_coefficients,
311 intercepts: task_intercepts,
312 relationship_matrix,
313 task_names,
314 n_features,
315 task_outputs: self.task_outputs.clone(),
316 relationship_strength: self.relationship_strength,
317 similarity_threshold: self.similarity_threshold,
318 similarity_method: self.similarity_method.clone(),
319 n_iter,
320 },
321 relationship_strength: self.relationship_strength,
322 similarity_threshold: self.similarity_threshold,
323 base_alpha: self.base_alpha,
324 max_iter: self.max_iter,
325 tolerance: self.tolerance,
326 learning_rate: self.learning_rate,
327 task_outputs: self.task_outputs,
328 fit_intercept: self.fit_intercept,
329 similarity_method: self.similarity_method,
330 })
331 }
332}
333
334impl TaskRelationshipLearning<Untrained> {
335 fn compute_task_similarity(
336 &self,
337 y1: &Array2<Float>,
338 y2: &Array2<Float>,
339 method: &TaskSimilarityMethod,
340 ) -> Float {
341 match method {
342 TaskSimilarityMethod::Correlation => {
343 let y1_flat: Vec<Float> = y1.iter().copied().collect();
345 let y2_flat: Vec<Float> = y2.iter().copied().collect();
346
347 if y1_flat.len() != y2_flat.len() {
348 return 0.0;
349 }
350
351 let mean1: Float = y1_flat.iter().sum::<Float>() / y1_flat.len() as Float;
352 let mean2: Float = y2_flat.iter().sum::<Float>() / y2_flat.len() as Float;
353
354 let mut num = 0.0;
355 let mut den1 = 0.0;
356 let mut den2 = 0.0;
357
358 for (v1, v2) in y1_flat.iter().zip(y2_flat.iter()) {
359 let d1 = v1 - mean1;
360 let d2 = v2 - mean2;
361 num += d1 * d2;
362 den1 += d1 * d1;
363 den2 += d2 * d2;
364 }
365
366 if den1 > 0.0 && den2 > 0.0 {
367 (num / (den1.sqrt() * den2.sqrt())).abs()
368 } else {
369 0.0
370 }
371 }
372 TaskSimilarityMethod::Cosine => {
373 let y1_flat: Vec<Float> = y1.iter().copied().collect();
375 let y2_flat: Vec<Float> = y2.iter().copied().collect();
376
377 let dot_product: Float =
378 y1_flat.iter().zip(y2_flat.iter()).map(|(a, b)| a * b).sum();
379 let norm1: Float = y1_flat.iter().map(|x| x * x).sum::<Float>().sqrt();
380 let norm2: Float = y2_flat.iter().map(|x| x * x).sum::<Float>().sqrt();
381
382 if norm1 > 0.0 && norm2 > 0.0 {
383 (dot_product / (norm1 * norm2)).abs()
384 } else {
385 0.0
386 }
387 }
388 TaskSimilarityMethod::Euclidean => {
389 let y1_flat: Vec<Float> = y1.iter().copied().collect();
391 let y2_flat: Vec<Float> = y2.iter().copied().collect();
392
393 let distance: Float = y1_flat
394 .iter()
395 .zip(y2_flat.iter())
396 .map(|(a, b)| (a - b) * (a - b))
397 .sum::<Float>()
398 .sqrt();
399
400 1.0 / (1.0 + distance)
401 }
402 TaskSimilarityMethod::MutualInformation => {
403 self.compute_task_similarity(y1, y2, &TaskSimilarityMethod::Correlation)
405 }
406 }
407 }
408}
409
410impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
411 for TaskRelationshipLearning<TaskRelationshipLearningTrained>
412{
413 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
414 let x = X.to_owned();
415 let (n_samples, n_features) = x.dim();
416
417 if n_features != self.state.n_features {
418 return Err(SklearsError::InvalidInput(
419 "Number of features doesn't match training data".to_string(),
420 ));
421 }
422
423 let mut predictions = HashMap::new();
424
425 for (task_name, coef) in &self.state.coefficients {
426 let task_predictions = x.dot(coef);
427 let intercept = &self.state.intercepts[task_name];
428 let final_predictions = &task_predictions + intercept;
429 predictions.insert(task_name.clone(), final_predictions);
430 }
431
432 Ok(predictions)
433 }
434}
435
436impl TaskRelationshipLearningTrained {
437 pub fn task_coefficients(&self, task_name: &str) -> Option<&Array2<Float>> {
439 self.coefficients.get(task_name)
440 }
441
442 pub fn task_intercepts(&self, task_name: &str) -> Option<&Array1<Float>> {
444 self.intercepts.get(task_name)
445 }
446
447 pub fn relationship_matrix(&self) -> &Array2<Float> {
449 &self.relationship_matrix
450 }
451
452 pub fn task_names(&self) -> &Vec<String> {
454 &self.task_names
455 }
456
457 pub fn task_similarity(&self, task1: &str, task2: &str) -> Option<Float> {
459 let idx1 = self.task_names.iter().position(|t| t == task1)?;
460 let idx2 = self.task_names.iter().position(|t| t == task2)?;
461 Some(self.relationship_matrix[[idx1, idx2]])
462 }
463
464 pub fn related_tasks(&self, task_name: &str) -> Vec<(&String, Float)> {
466 if let Some(task_idx) = self.task_names.iter().position(|t| t == task_name) {
467 self.task_names
468 .iter()
469 .enumerate()
470 .filter_map(|(other_idx, other_task)| {
471 if other_idx != task_idx {
472 let similarity = self.relationship_matrix[[task_idx, other_idx]];
473 if similarity > self.similarity_threshold {
474 Some((other_task, similarity))
475 } else {
476 None
477 }
478 } else {
479 None
480 }
481 })
482 .collect()
483 } else {
484 Vec::new()
485 }
486 }
487
488 pub fn n_iter(&self) -> usize {
490 self.n_iter
491 }
492}