pub struct AdversarialMultiTaskNetwork<S = Untrained> { /* private fields */ }
Adversarial Multi-Task Network with feature disentanglement
This network implements adversarial multi-task learning. A task discriminator is trained to predict which task a shared feature vector came from. Meanwhile, the shared feature extractor is trained adversarially to fool that discriminator. This pushes the shared representations toward containing only task-invariant information.
§Architecture
The network consists of:
- Shared layers: Learn task-invariant representations
- Private layers: Learn task-specific representations per task
- Task discriminator: Tries to predict which task the shared features came from
- Gradient reversal: The mechanism that trains the shared layers adversarially against the discriminator
§Examples
use sklears_multioutput::adversarial::{AdversarialMultiTaskNetwork, AdversarialStrategy};
use sklears_core::traits::{Predict, Fit};
// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::array;
use std::collections::HashMap;
let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];
let mut tasks = HashMap::new();
tasks.insert("task1".to_string(), array![[0.5], [1.0], [1.5], [2.0]]);
tasks.insert("task2".to_string(), array![[1.0], [0.0], [1.0], [0.0]]);
let adv_net = AdversarialMultiTaskNetwork::new()
.shared_layers(vec![20, 10])
.private_layers(vec![8])
.task_outputs(&[("task1", 1), ("task2", 1)])
.adversarial_strategy(AdversarialStrategy::GradientReversal)
.adversarial_weight(0.1)
.orthogonality_weight(0.01)
.random_state(Some(42));
Implementations§
Source§impl AdversarialMultiTaskNetwork<Untrained>
impl AdversarialMultiTaskNetwork<Untrained>
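Sourcepub fn shared_layers(self, sizes: Vec<usize>) -> Self
pub fn shared_layers(self, sizes: Vec<usize>) -> Self
Set shared layer sizes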
Sourcepub fn private_layers(self, sizes: Vec<usize>) -> Self
pub fn private_layers(self, sizes: Vec<usize>) -> Self
Set private layer sizes
Sourcepub fn task_outputs(self, tasks: &[(&str, usize)]) -> Self
pub fn task_outputs(self, tasks: &[(&str, usize)]) -> Self
Configure task outputs
Sourcepub fn adversarial_strategy(self, strategy: AdversarialStrategy) -> Self
pub fn adversarial_strategy(self, strategy: AdversarialStrategy) -> Self
Set adversarial strategy
Sourcepub fn adversarial_weight(self, weight: Float) -> Self
pub fn adversarial_weight(self, weight: Float) -> Self
Set adversarial weight
Sourcepub fn orthogonality_weight(self, weight: Float) -> Self
pub fn orthogonality_weight(self, weight: Float) -> Self
Set orthogonality weight
Sourcepub fn learning_rate(self, lr: Float) -> Self
pub fn learning_rate(self, lr: Float) -> Self
Set learning rate
Sourcepub fn random_state(self, seed: Option<u64>) -> Self
pub fn random_state(self, seed: Option<u64>) -> Self
Set random state
Source§impl AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
impl AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
Sourcepub fn adversarial_loss_curve(&self) -> &[Float]
pub fn adversarial_loss_curve(&self) -> &[Float]
Get adversarial loss curve
Sourcepub fn orthogonality_loss_curve(&self) -> &[Float]
pub fn orthogonality_loss_curve(&self) -> &[Float]
Get orthogonality loss curve
Sourcepub fn combined_loss_curve(&self) -> &[Float]
pub fn combined_loss_curve(&self) -> &[Float]
Get combined loss curve
Sourcepub fn discriminator_accuracy_curve(&self) -> &[Float]
pub fn discriminator_accuracy_curve(&self) -> &[Float]
Get discriminator accuracy curve
Trait Implementations§
Source§impl<S: Clone> Clone for AdversarialMultiTaskNetwork<S>
impl<S: Clone> Clone for AdversarialMultiTaskNetwork<S>
Source§fn clone(&self) -> AdversarialMultiTaskNetwork<S>
fn clone(&self) -> AdversarialMultiTaskNetwork<S>
Returns a duplicate of the value. Read more
1.0.0 · Source§fn clone_from(&mut self, source: &Self)
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source. Read more
Source§impl<S: Debug> Debug for AdversarialMultiTaskNetwork<S>
impl<S: Debug> Debug for AdversarialMultiTaskNetwork<S>
Source§impl Estimator for AdversarialMultiTaskNetwork<Untrained>
impl Estimator for AdversarialMultiTaskNetwork<Untrained>
Source§type Config = AdversarialConfig
type Config = AdversarialConfig
Configuration type for the estimator
Source§type Error = SklearsError
type Error = SklearsError
Error type for the estimator
Source§fn validate_config(&self) -> Result<(), SklearsError>
fn validate_config(&self) -> Result<(), SklearsError>
Validate estimator configuration with detailed error context
Source§fn check_compatibility(
&self,
n_samples: usize,
n_features: usize,
) -> Result<(), SklearsError>
fn check_compatibility( &self, n_samples: usize, n_features: usize, ) -> Result<(), SklearsError>
Check if estimator is compatible with given data dimensions
Source§fn metadata(&self) -> EstimatorMetadata
fn metadata(&self) -> EstimatorMetadata
Get estimator metadata
Source§impl Fit<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for AdversarialMultiTaskNetwork<Untrained>
impl Fit<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for AdversarialMultiTaskNetwork<Untrained>
Source§type Fitted = AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
type Fitted = AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
The fitted model type
Source§fn fit(
self,
x: &ArrayView2<'_, Float>,
y: &HashMap<String, Array2<Float>>,
) -> SklResult<Self::Fitted>
fn fit( self, x: &ArrayView2<'_, Float>, y: &HashMap<String, Array2<Float>>, ) -> SklResult<Self::Fitted>
Fit the model to the provided data with validation
Source§fn fit_with_validation(
self,
x: &X,
y: &Y,
_x_val: Option<&X>,
_y_val: Option<&Y>,
) -> Result<(Self::Fitted, FitMetrics), SklearsError>
where
    Self: Sized,
fn fit_with_validation(
self,
x: &X,
y: &Y,
_x_val: Option<&X>,
_y_val: Option<&Y>,
) -> Result<(Self::Fitted, FitMetrics), SklearsError>
where
    Self: Sized,
Fit with custom validation and early stopping
Source§impl Predict<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
impl Predict<ArrayBase<ViewRepr<&f64>, Dim<[usize; 2]>>, HashMap<String, ArrayBase<OwnedRepr<f64>, Dim<[usize; 2]>>>> for AdversarialMultiTaskNetwork<AdversarialMultiTaskNetworkTrained>
Source§fn predict(
&self,
X: &ArrayView2<'_, Float>,
) -> SklResult<HashMap<String, Array2<Float>>>
fn predict( &self, X: &ArrayView2<'_, Float>, ) -> SklResult<HashMap<String, Array2<Float>>>
Make predictions on the provided data
Source§fn predict_with_uncertainty(
&self,
x: &X,
) -> Result<(Output, UncertaintyMeasure), SklearsError>
fn predict_with_uncertainty( &self, x: &X, ) -> Result<(Output, UncertaintyMeasure), SklearsError>
Make predictions with confidence intervals
Auto Trait Implementations§
impl<S> Freeze for AdversarialMultiTaskNetwork<S>
where
    S: Freeze,
impl<S> RefUnwindSafe for AdversarialMultiTaskNetwork<S>
where
    S: RefUnwindSafe,
impl<S> Send for AdversarialMultiTaskNetwork<S>
where
    S: Send,
impl<S> Sync for AdversarialMultiTaskNetwork<S>
where
    S: Sync,
impl<S> Unpin for AdversarialMultiTaskNetwork<S>
where
    S: Unpin,
impl<S> UnwindSafe for AdversarialMultiTaskNetwork<S>
where
    S: UnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> CloneToUninit for T
where
    T: Clone,
impl<T> CloneToUninit for T
where
    T: Clone,
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§impl<T> Pointable for T
impl<T> Pointable for T
Source§impl<T> StableApi for T
where
    T: Estimator,
impl<T> StableApi for T
where
    T: Estimator,
Source§const STABLE_SINCE: &'static str = "0.1.0"
const STABLE_SINCE: &'static str = "0.1.0"
API version this type was stabilized in
Source§const HAS_EXPERIMENTAL_FEATURES: bool = false
const HAS_EXPERIMENTAL_FEATURES: bool = false
Whether this API has any experimental features