sklears_multioutput/multitask.rs

//! Multi-Task Neural Networks with Shared Representation Learning
//!
//! This module implements multi-task learning where multiple related tasks share
//! common representations in the lower layers while keeping task-specific layers for their final predictions.
//! This approach tends to improve generalization and performance when the tasks are related.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::random::thread_rng;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

use crate::activation::ActivationFunction;
use crate::loss::LossFunction;

/// Task balancing strategies for multi-task learning
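///
/// A strategy is chosen through the `task_balancing` builder method on
/// `MultiTaskNeuralNetwork`. A minimal sketch, using only builder methods
/// defined in this module:
///
/// ```
/// use sklears_multioutput::multitask::{MultiTaskNeuralNetwork, TaskBalancing};
///
/// // Weight each task's loss adaptively based on task difficulty.
/// let net = MultiTaskNeuralNetwork::new()
///     .task_balancing(TaskBalancing::Adaptive);
/// ```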
#[derive(Debug, Clone, PartialEq)]
pub enum TaskBalancing {
    /// Equal weights for all tasks
    Equal,
    /// Custom weights for each task
    Weighted,
    /// Adaptive weighting based on task difficulty
    Adaptive,
    /// Gradient balancing
    GradientBalancing,
}

/// Multi-Task Neural Network with Shared Representation Learning
///
/// This neural network implements multi-task learning where multiple related tasks share
/// common representations in the lower layers while keeping task-specific layers for their final predictions.
/// This approach tends to improve generalization and performance when the tasks are related.
///
/// # Architecture
///
/// The network consists of:
/// - Shared layers: Learn common representations across all tasks
/// - Task-specific layers: Learn task-specific transformations
/// - Multiple outputs: One output per task
///
/// # Examples
///
/// ```
/// use sklears_multioutput::multitask::{MultiTaskNeuralNetwork, TaskBalancing};
/// use sklears_multioutput::activation::ActivationFunction;
/// use sklears_core::traits::{Predict, Fit};
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
/// use std::collections::HashMap;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
/// let mut tasks = HashMap::new();
/// tasks.insert("task1".to_string(), array![[0.5], [1.0], [1.5], [2.0]]); // Regression task
/// tasks.insert("task2".to_string(), array![[1.0], [0.0], [1.0], [0.0]]); // Classification task
///
/// let mt_net = MultiTaskNeuralNetwork::new()
///     .shared_layers(vec![20, 10])
///     .task_specific_layers(vec![5])
///     .task_outputs(&[("task1", 1), ("task2", 1)])
///     .shared_activation(ActivationFunction::ReLU)
///     .learning_rate(0.01)
///     .max_iter(1000)
///     .task_weights(&[("task1", 1.0), ("task2", 0.8)])
///     .random_state(Some(42));
/// ```
#[derive(Debug, Clone)]
pub struct MultiTaskNeuralNetwork<S = Untrained> {
    state: S,
    /// Sizes of shared representation layers
    shared_layer_sizes: Vec<usize>,
    /// Sizes of task-specific layers
    task_specific_layer_sizes: Vec<usize>,
    /// Task names and their output dimensions
    task_outputs: HashMap<String, usize>,
    /// Task loss functions
    task_loss_functions: HashMap<String, LossFunction>,
    /// Task weights for multi-task loss computation
    task_weights: HashMap<String, Float>,
    /// Activation function for shared layers
    shared_activation: ActivationFunction,
    /// Activation function for task-specific layers
    task_activation: ActivationFunction,
    /// Output activation functions per task
    output_activations: HashMap<String, ActivationFunction>,
    /// Learning rate
    learning_rate: Float,
    /// Maximum number of iterations
    max_iter: usize,
    /// Convergence tolerance
    tolerance: Float,
    /// Random state for reproducibility
    random_state: Option<u64>,
    /// L2 regularization strength
    alpha: Float,
    /// Batch size for training
    batch_size: Option<usize>,
    /// Early stopping
    early_stopping: bool,
    /// Validation fraction for early stopping
    validation_fraction: Float,
    /// Task balancing strategy
    task_balancing: TaskBalancing,
}

/// Trained state for MultiTaskNeuralNetwork
#[derive(Debug, Clone)]
pub struct MultiTaskNeuralNetworkTrained {
    /// Weights for shared layers
    shared_weights: Vec<Array2<Float>>,
    /// Biases for shared layers
    shared_biases: Vec<Array1<Float>>,
    /// Task-specific weights per task
    task_weights: HashMap<String, Vec<Array2<Float>>>,
    /// Task-specific biases per task
    task_biases: HashMap<String, Vec<Array1<Float>>>,
    /// Output layer weights per task
    output_weights: HashMap<String, Array2<Float>>,
    /// Output layer biases per task
    output_biases: HashMap<String, Array1<Float>>,
    /// Number of input features
    n_features: usize,
    /// Task configurations
    task_outputs: HashMap<String, usize>,
    /// Network architecture
    shared_layer_sizes: Vec<usize>,
    task_specific_layer_sizes: Vec<usize>,
    shared_activation: ActivationFunction,
    task_activation: ActivationFunction,
    output_activations: HashMap<String, ActivationFunction>,
    /// Training history per task
    task_loss_curves: HashMap<String, Vec<Float>>,
    /// Combined loss curve
    combined_loss_curve: Vec<Float>,
    /// Number of iterations performed
    n_iter: usize,
}

impl MultiTaskNeuralNetwork<Untrained> {
    /// Create a new MultiTaskNeuralNetwork
    pub fn new() -> Self {
        Self {
            state: Untrained,
            shared_layer_sizes: vec![100],
            task_specific_layer_sizes: vec![50],
            task_outputs: HashMap::new(),
            task_loss_functions: HashMap::new(),
            task_weights: HashMap::new(),
            shared_activation: ActivationFunction::ReLU,
            task_activation: ActivationFunction::ReLU,
            output_activations: HashMap::new(),
            learning_rate: 0.001,
            max_iter: 1000,
            tolerance: 1e-6,
            random_state: None,
            alpha: 0.0001,
            batch_size: None,
            early_stopping: false,
            validation_fraction: 0.1,
            task_balancing: TaskBalancing::Equal,
        }
    }

    /// Set the sizes of shared representation layers
    pub fn shared_layers(mut self, sizes: Vec<usize>) -> Self {
        self.shared_layer_sizes = sizes;
        self
    }

    /// Set the sizes of task-specific layers
    pub fn task_specific_layers(mut self, sizes: Vec<usize>) -> Self {
        self.task_specific_layer_sizes = sizes;
        self
    }

    /// Configure the tasks and their output dimensions.
    ///
    /// For each task this also installs defaults: a task weight of 1.0, plus mean squared
    /// error loss with a linear output activation for single-output tasks, and cross-entropy
    /// loss with a softmax output activation otherwise. These defaults can be overridden with
    /// `task_loss_functions`, `task_weights`, and `output_activations`.
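    ///
    /// A minimal sketch of the defaults this method installs (the task names below
    /// are purely illustrative):
    ///
    /// ```
    /// use sklears_multioutput::multitask::MultiTaskNeuralNetwork;
    ///
    /// // "price" (1 output) gets the MSE/linear defaults; "category" (3 outputs)
    /// // gets the cross-entropy/softmax defaults.
    /// let net = MultiTaskNeuralNetwork::new()
    ///     .task_outputs(&[("price", 1), ("category", 3)]);
    /// ```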
    pub fn task_outputs(mut self, tasks: &[(&str, usize)]) -> Self {
        for (task_name, output_size) in tasks {
            self.task_outputs
                .insert(task_name.to_string(), *output_size);
            // Set default configurations
            self.task_loss_functions.insert(
                task_name.to_string(),
                if *output_size == 1 {
                    LossFunction::MeanSquaredError
                } else {
                    LossFunction::CrossEntropy
                },
            );
            self.task_weights.insert(task_name.to_string(), 1.0);
            self.output_activations.insert(
                task_name.to_string(),
                if *output_size == 1 {
                    ActivationFunction::Linear
                } else {
                    ActivationFunction::Softmax
                },
            );
        }
        self
    }

    /// Set loss functions for specific tasks
    pub fn task_loss_functions(mut self, loss_functions: &[(&str, LossFunction)]) -> Self {
        for (task_name, loss_fn) in loss_functions {
            self.task_loss_functions
                .insert(task_name.to_string(), *loss_fn);
        }
        self
    }

    /// Set task weights for multi-task loss computation
    pub fn task_weights(mut self, weights: &[(&str, Float)]) -> Self {
        for (task_name, weight) in weights {
            self.task_weights.insert(task_name.to_string(), *weight);
        }
        self
    }

    /// Set activation function for shared layers
    pub fn shared_activation(mut self, activation: ActivationFunction) -> Self {
        self.shared_activation = activation;
        self
    }

    /// Set activation function for task-specific layers
    pub fn task_activation(mut self, activation: ActivationFunction) -> Self {
        self.task_activation = activation;
        self
    }

    /// Set output activation functions for specific tasks
    pub fn output_activations(mut self, activations: &[(&str, ActivationFunction)]) -> Self {
        for (task_name, activation) in activations {
            self.output_activations
                .insert(task_name.to_string(), *activation);
        }
        self
    }

    /// Set learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set random state for reproducibility
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

    /// Set L2 regularization strength
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set batch size for training
    pub fn batch_size(mut self, batch_size: Option<usize>) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Enable/disable early stopping
    pub fn early_stopping(mut self, early_stopping: bool) -> Self {
        self.early_stopping = early_stopping;
        self
    }

    /// Set validation fraction for early stopping
    pub fn validation_fraction(mut self, fraction: Float) -> Self {
        self.validation_fraction = fraction;
        self
    }

    /// Set task balancing strategy
    pub fn task_balancing(mut self, strategy: TaskBalancing) -> Self {
        self.task_balancing = strategy;
        self
    }
}

impl Default for MultiTaskNeuralNetwork<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiTaskNeuralNetwork<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

// Implementation of Fit trait with simplified training logic
impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for MultiTaskNeuralNetwork<Untrained>
{
    type Fitted = MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>;

    fn fit(
        self,
        x: &ArrayView2<Float>,
        y: &HashMap<String, Array2<Float>>,
    ) -> SklResult<Self::Fitted> {
        if x.nrows() == 0 || x.ncols() == 0 {
            return Err(SklearsError::InvalidInput("Empty input data".to_string()));
        }

        if y.is_empty() {
            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
        }

        // Validate that all tasks have consistent sample counts
        let n_samples = x.nrows();
        for (task_name, task_targets) in y {
            if task_targets.nrows() != n_samples {
                return Err(SklearsError::ShapeMismatch {
                    expected: format!("{} samples", n_samples),
                    actual: format!("{} samples for task '{}'", task_targets.nrows(), task_name),
                });
            }
            if !self.task_outputs.contains_key(task_name) {
                return Err(SklearsError::InvalidInput(format!(
                    "Unknown task: {}",
                    task_name
                )));
            }
        }

        let n_features = x.ncols();
        // Reserved for randomized initialization; the simplified version below is zero-initialized.
        let _rng = thread_rng();

        // Initialize network parameters (simplified: fixed hidden sizes of 50 shared and
        // 25 task-specific units with zero-valued weights; the configured layer sizes are not used here)
        let shared_weights = vec![Array2::<Float>::zeros((n_features, 50))];
        let shared_biases = vec![Array1::<Float>::zeros(50)];
        let mut task_weights = HashMap::new();
        let mut task_biases = HashMap::new();
        let mut output_weights = HashMap::new();
        let mut output_biases = HashMap::new();

        for (task_name, &output_size) in &self.task_outputs {
            task_weights.insert(task_name.clone(), vec![Array2::<Float>::zeros((50, 25))]);
            task_biases.insert(task_name.clone(), vec![Array1::<Float>::zeros(25)]);
            output_weights.insert(task_name.clone(), Array2::<Float>::zeros((25, output_size)));
            output_biases.insert(task_name.clone(), Array1::<Float>::zeros(output_size));
        }

        // Simplified training: loss curves are pre-allocated with zeros; no gradient updates are performed here
        let mut task_loss_curves = HashMap::new();
        let combined_loss_curve = vec![0.0; self.max_iter];

        for task_name in self.task_outputs.keys() {
            task_loss_curves.insert(task_name.clone(), vec![0.0; self.max_iter]);
        }

        let trained_state = MultiTaskNeuralNetworkTrained {
            shared_weights,
            shared_biases,
            task_weights,
            task_biases,
            output_weights,
            output_biases,
            n_features,
            task_outputs: self.task_outputs.clone(),
            shared_layer_sizes: self.shared_layer_sizes.clone(),
            task_specific_layer_sizes: self.task_specific_layer_sizes.clone(),
            shared_activation: self.shared_activation,
            task_activation: self.task_activation,
            output_activations: self.output_activations.clone(),
            task_loss_curves,
            combined_loss_curve,
            n_iter: self.max_iter,
        };

        Ok(MultiTaskNeuralNetwork {
            state: trained_state,
            shared_layer_sizes: self.shared_layer_sizes,
            task_specific_layer_sizes: self.task_specific_layer_sizes,
            task_outputs: self.task_outputs,
            task_loss_functions: self.task_loss_functions,
            task_weights: self.task_weights,
            shared_activation: self.shared_activation,
            task_activation: self.task_activation,
            output_activations: self.output_activations,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            alpha: self.alpha,
            batch_size: self.batch_size,
            early_stopping: self.early_stopping,
            validation_fraction: self.validation_fraction,
            task_balancing: self.task_balancing,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
    for MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
{
    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
        let (n_samples, n_features) = x.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "x has {} features, but the model was trained with {}",
                n_features, self.state.n_features
            )));
        }

        let mut predictions = HashMap::new();

        // Simplified prediction: return zero-filled outputs with the correct shape for each task
        for (task_name, &output_size) in &self.state.task_outputs {
            let task_pred = Array2::<Float>::zeros((n_samples, output_size));
            predictions.insert(task_name.clone(), task_pred);
        }

        Ok(predictions)
    }
}

impl MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained> {
    /// Get the loss curves for all tasks
    pub fn task_loss_curves(&self) -> &HashMap<String, Vec<Float>> {
        &self.state.task_loss_curves
    }

    /// Get the combined loss curve
    pub fn combined_loss_curve(&self) -> &[Float] {
        &self.state.combined_loss_curve
    }

    /// Get the number of training iterations performed
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Get the task names and their output dimensions
    pub fn task_outputs(&self) -> &HashMap<String, usize> {
        &self.state.task_outputs
    }
}
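// A minimal smoke-test sketch for the simplified `fit`/`predict` paths above: the
// simplified implementation returns zero-valued predictions, so the test only checks
// that each task's output has the expected shape.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Predict};
    use std::collections::HashMap;

    #[test]
    fn fit_and_predict_return_per_task_outputs() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];

        // Two single-output tasks: one regression-style, one binary-classification-style.
        let mut tasks = HashMap::new();
        tasks.insert("task1".to_string(), array![[0.5], [1.0], [1.5], [2.0]]);
        tasks.insert("task2".to_string(), array![[1.0], [0.0], [1.0], [0.0]]);

        let model = MultiTaskNeuralNetwork::new()
            .task_outputs(&[("task1", 1), ("task2", 1)])
            .max_iter(10)
            .fit(&x.view(), &tasks)
            .unwrap();

        let predictions = model.predict(&x.view()).unwrap();
        assert_eq!(predictions.len(), 2);
        assert_eq!(predictions["task1"].dim(), (4, 1));
        assert_eq!(predictions["task2"].dim(), (4, 1));
    }
}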