pub struct MultiTaskNeuralNetwork<S = Untrained> { /* private fields */ }
Multi-Task Neural Network with Shared Representation Learning
This neural network implements multi-task learning: several related tasks share common representations in the lower (shared) layers, while each task has its own task-specific layers that produce its final predictions. Sharing representations in this way can improve generalization and predictive performance when the tasks are related.
§Architecture
The network consists of:
- Shared layers: Learn common representations across all tasks
- Task-specific layers: Learn task-specific transformations
- Multiple outputs: One output per task
§Examples
use sklears_multioutput::multitask::{MultiTaskNeuralNetwork, TaskBalancing};
use sklears_multioutput::activation::ActivationFunction;
use sklears_core::traits::{Predict, Fit};
// Use SciRS2-Core's re-exported ndarray for array construction (SciRS2 Policy)
use scirs2_core::ndarray::array;
use std::collections::HashMap;
let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
let mut tasks = HashMap::new();
tasks.insert("task1".to_string(), array![[0.5], [1.0], [1.5], [2.0]]); // Regression task
tasks.insert("task2".to_string(), array![[1.0], [0.0], [1.0], [0.0]]); // Classification task
let mt_net = MultiTaskNeuralNetwork::new()
.shared_layers(vec![20, 10])
.task_specific_layers(vec![5])
.task_outputs(&[("task1", 1), ("task2", 1)])
.shared_activation(ActivationFunction::ReLU)
.learning_rate(0.01)
.max_iter(1000)
.task_weights(&[("task1", 1.0), ("task2", 0.8)])
.random_state(Some(42));
Implementations§
Source§impl MultiTaskNeuralNetwork<Untrained>
impl MultiTaskNeuralNetwork<Untrained>
pub fn shared_layers(self, sizes: Vec<usize>) -> Self
Set the sizes of shared representation layers
Sourcepub fn task_specific_layers(self, sizes: Vec<usize>) -> Self
pub fn task_specific_layers(self, sizes: Vec<usize>) -> Self
Set the sizes of task-specific layers
Sourcepub fn task_outputs(self, tasks: &[(&str, usize)]) -> Self
pub fn task_outputs(self, tasks: &[(&str, usize)]) -> Self
Configure task outputs
Sourcepub fn task_loss_functions(
self,
loss_functions: &[(&str, LossFunction)],
) -> Self
pub fn task_loss_functions( self, loss_functions: &[(&str, LossFunction)], ) -> Self
Set loss functions for specific tasks
Sourcepub fn task_weights(self, weights: &[(&str, Float)]) -> Self
pub fn task_weights(self, weights: &[(&str, Float)]) -> Self
Set task weights for multi-task loss computation
pub fn shared_activation(self, activation: ActivationFunction) -> Self
Set activation function for shared layers
Sourcepub fn task_activation(self, activation: ActivationFunction) -> Self
pub fn task_activation(self, activation: ActivationFunction) -> Self
Set activation function for task-specific layers
Sourcepub fn output_activations(
self,
activations: &[(&str, ActivationFunction)],
) -> Self
pub fn output_activations( self, activations: &[(&str, ActivationFunction)], ) -> Self
Set output activation functions for specific tasks
Sourcepub fn learning_rate(self, lr: Float) -> Self
pub fn learning_rate(self, lr: Float) -> Self
Set learning rate
Sourcepub fn random_state(self, seed: Option<u64>) -> Self
pub fn random_state(self, seed: Option<u64>) -> Self
Set random state for reproducibility
Sourcepub fn batch_size(self, batch_size: Option<usize>) -> Self
pub fn batch_size(self, batch_size: Option<usize>) -> Self
Set batch size for training
Sourcepub fn early_stopping(self, early_stopping: bool) -> Self
pub fn early_stopping(self, early_stopping: bool) -> Self
Enable/disable early stopping
Sourcepub fn validation_fraction(self, fraction: Float) -> Self
pub fn validation_fraction(self, fraction: Float) -> Self
Set validation fraction for early stopping
Sourcepub fn task_balancing(self, strategy: TaskBalancing) -> Self
pub fn task_balancing(self, strategy: TaskBalancing) -> Self
Set task balancing strategy
Source§impl MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
impl MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
Sourcepub fn task_loss_curves(&self) -> &HashMap<String, Vec<Float>>
pub fn task_loss_curves(&self) -> &HashMap<String, Vec<Float>>
Get the loss curves for all tasks
Sourcepub fn combined_loss_curve(&self) -> &[Float]
pub fn combined_loss_curve(&self) -> &[Float]
Get the combined loss curve
Sourcepub fn task_outputs(&self) -> &HashMap<String, usize>
pub fn task_outputs(&self) -> &HashMap<String, usize>
Get task configurations
Trait Implementations§
Source§impl<S: Clone> Clone for MultiTaskNeuralNetwork<S>
impl<S: Clone> Clone for MultiTaskNeuralNetwork<S>
Source§fn clone(&self) -> MultiTaskNeuralNetwork<S>
fn clone(&self) -> MultiTaskNeuralNetwork<S>
Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source (stable since Rust 1.0.0).
Source§impl<S: Debug> Debug for MultiTaskNeuralNetwork<S>
impl<S: Debug> Debug for MultiTaskNeuralNetwork<S>
Source§impl Default for MultiTaskNeuralNetwork<Untrained>
impl Default for MultiTaskNeuralNetwork<Untrained>
Source§impl Estimator for MultiTaskNeuralNetwork<Untrained>
impl Estimator for MultiTaskNeuralNetwork<Untrained>
Source§type Error = SklearsError
type Error = SklearsError
Source§fn validate_config(&self) -> Result<(), SklearsError>
fn validate_config(&self) -> Result<(), SklearsError>
Source§fn check_compatibility(
&self,
n_samples: usize,
n_features: usize,
) -> Result<(), SklearsError>
fn check_compatibility( &self, n_samples: usize, n_features: usize, ) -> Result<(), SklearsError>
Source§fn metadata(&self) -> EstimatorMetadata
fn metadata(&self) -> EstimatorMetadata
Source§impl Fit<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for MultiTaskNeuralNetwork<Untrained>
impl Fit<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for MultiTaskNeuralNetwork<Untrained>
Source§type Fitted = MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
type Fitted = MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
Source§fn fit(
self,
x: &ArrayView2<'_, Float>,
y: &HashMap<String, Array2<Float>>,
) -> SklResult<Self::Fitted>
fn fit( self, x: &ArrayView2<'_, Float>, y: &HashMap<String, Array2<Float>>, ) -> SklResult<Self::Fitted>
Source§fn fit_with_validation(
self,
x: &X,
y: &Y,
_x_val: Option<&X>,
_y_val: Option<&Y>,
) -> Result<(Self::Fitted, FitMetrics), SklearsError>
where
Self: Sized,
fn fit_with_validation(
self,
x: &X,
y: &Y,
_x_val: Option<&X>,
_y_val: Option<&Y>,
) -> Result<(Self::Fitted, FitMetrics), SklearsError>
where
Self: Sized,
Source§impl Predict<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
impl Predict<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
Source§fn predict(
&self,
X: &ArrayView2<'_, Float>,
) -> SklResult<HashMap<String, Array2<Float>>>
fn predict( &self, X: &ArrayView2<'_, Float>, ) -> SklResult<HashMap<String, Array2<Float>>>
Source§fn predict_with_uncertainty(
&self,
x: &X,
) -> Result<(Output, UncertaintyMeasure), SklearsError>
fn predict_with_uncertainty( &self, x: &X, ) -> Result<(Output, UncertaintyMeasure), SklearsError>
Auto Trait Implementations§
impl<S> Freeze for MultiTaskNeuralNetwork<S>where
S: Freeze,
impl<S> RefUnwindSafe for MultiTaskNeuralNetwork<S>where
S: RefUnwindSafe,
impl<S> Send for MultiTaskNeuralNetwork<S>where
S: Send,
impl<S> Sync for MultiTaskNeuralNetwork<S>where
S: Sync,
impl<S> Unpin for MultiTaskNeuralNetwork<S>where
S: Unpin,
impl<S> UnwindSafe for MultiTaskNeuralNetwork<S>where
S: UnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for Twhere
T: ?Sized,
impl<T> BorrowMut<T> for Twhere
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Source§impl<T> CloneToUninit for Twhere
T: Clone,
impl<T> CloneToUninit for Twhere
T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.