
Module multi_task_learning


§Multi-Task Learning Framework

This module provides a framework for multi-task learning (MTL): training a single model on several related tasks at once, so that representations shared between tasks improve generalization and reduce the cost of training one model per task.

§Features

  • Multiple MTL Architectures: Hard parameter sharing, soft parameter sharing, task-specific layers
  • Loss Balancing: Various strategies for balancing losses across tasks
  • Task Weighting: Dynamic and static task weight adjustment
  • Auxiliary Tasks: Support for auxiliary tasks to improve main task performance
  • Task Clustering: Grouping related tasks for better sharing
  • Evaluation Metrics: Specialized metrics for multi-task scenarios
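
To make the first bullet concrete, the following is a minimal, standalone sketch of hard parameter sharing: every task passes through the same shared layers, and only the final head differs per task. It uses only the standard library and does not reflect this module's internals; `Linear`, `HardSharingModel`, and their fields are hypothetical names chosen for illustration.

```rust
use std::collections::HashMap;

/// A dense layer computing `W x + b` (hypothetical helper, not a crate type).
struct Linear {
    weights: Vec<Vec<f64>>, // one row per output unit
    bias: Vec<f64>,
}

impl Linear {
    fn forward(&self, x: &[f64]) -> Vec<f64> {
        self.weights
            .iter()
            .zip(&self.bias)
            .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f64>() + b)
            .collect()
    }
}

/// Hard parameter sharing: one shared trunk, one head per task.
struct HardSharingModel {
    shared: Vec<Linear>,
    heads: HashMap<String, Linear>,
}

impl HardSharingModel {
    /// Runs the shared trunk (with ReLU after each layer), then the head
    /// registered for `task`. Returns `None` for an unknown task name.
    fn forward(&self, task: &str, x: &[f64]) -> Option<Vec<f64>> {
        let head = self.heads.get(task)?;
        let mut h = x.to_vec();
        for layer in &self.shared {
            h = layer.forward(&h).into_iter().map(|v| v.max(0.0)).collect();
        }
        Some(head.forward(&h))
    }
}
```

Because the trunk parameters are reused by every task, gradients from all tasks would update the same shared weights during training, while each head is updated only by its own task.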

§Usage

use trustformers_models::multi_task_learning::{
    LossBalancingStrategy, MTLArchitecture, MTLConfig, MultiTaskLearningTrainer,
    TaskConfig, TaskType,
};

let config = MTLConfig {
    architecture: MTLArchitecture::HardParameterSharing {
        shared_layers: 8,
        task_specific_layers: 2,
    },
    loss_balancing: LossBalancingStrategy::DynamicWeightAverage,
    tasks: vec![
        TaskConfig::new("classification", TaskType::Classification { num_classes: 10 }),
        TaskConfig::new("regression", TaskType::Regression { output_dim: 1 }),
    ],
    ..Default::default()
};

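// `?` requires an enclosing function that returns a compatible `Result`.
let mut trainer = MultiTaskLearningTrainer::new(config)?;
// `task_data` is the caller's training data; it is not defined in this snippet.
trainer.train_multi_task(task_data)?;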
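
The example above selects `LossBalancingStrategy::DynamicWeightAverage`. As a rough illustration of what a dynamic weight average scheme computes, here is a standalone sketch of the commonly described formulation: each task's weight is a softmax over the ratio of its two most recent losses, scaled so the weights sum to the number of tasks. This is not taken from the crate's source; the function names `dwa_weights` and `weighted_loss` and the `temperature` parameter are assumptions for illustration, and the crate's actual behavior may differ.

```rust
/// Dynamic weight average over per-task losses (illustrative sketch only).
///
/// `prev` holds each task's loss at step t-1 and `prev2` at step t-2.
/// A task whose loss is falling more slowly gets a larger weight.
/// The returned weights sum to the number of tasks.
fn dwa_weights(prev: &[f64], prev2: &[f64], temperature: f64) -> Vec<f64> {
    assert_eq!(prev.len(), prev2.len(), "loss histories must have equal length");
    assert!(temperature > 0.0, "temperature must be positive");

    let num_tasks = prev.len() as f64;
    // Relative descent rate r_k = L_k(t-1) / L_k(t-2), guarding against division by zero.
    let ratios: Vec<f64> = prev
        .iter()
        .zip(prev2)
        .map(|(&a, &b)| a / b.max(f64::EPSILON))
        .collect();

    // Softmax over r_k / temperature, subtracting the maximum for numerical stability.
    let max = ratios.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = ratios.iter().map(|r| ((r - max) / temperature).exp()).collect();
    let sum: f64 = exps.iter().sum();

    exps.iter().map(|e| num_tasks * e / sum).collect()
}

/// Combines per-task losses into one scalar using the given weights.
fn weighted_loss(losses: &[f64], weights: &[f64]) -> f64 {
    losses.iter().zip(weights).map(|(l, w)| l * w).sum()
}
```

A higher `temperature` flattens the weights toward uniform, while a lower one concentrates weight on the tasks that are improving most slowly.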

Modules§

utils
Utilities for multi-task learning

Structs§

AuxiliaryTaskConfig
Auxiliary task configuration
GradientStats
Gradient statistics for task balancing
MTLAnalysis
Analysis of multi-task learning effectiveness
MTLConfig
Configuration for multi-task learning
MTLStats
Multi-task learning statistics
MultiTaskEvaluation
Multi-task evaluation results
MultiTaskLearningTrainer
Multi-task learning trainer
MultiTaskOutput
Output from multi-task training step
TaskBatch
Training data batch for a specific task
TaskClusteringConfig
Task clustering configuration
TaskConfig
Task configuration
TaskEvaluation
Task evaluation results
TaskHead
Task-specific neural network head
TaskSchedulerState
Task scheduler state

Enums§

AuxiliaryTaskFrequency
Frequency of auxiliary task training
AuxiliaryType
Auxiliary task types
ClusteringMethod
Clustering methods for tasks
LossBalancingStrategy
Strategies for balancing losses across tasks
MTLArchitecture
Multi-task learning architectures
RankingType
Ranking task types
RegressionLossType
Regression loss types
RegularizationType
Regularization types for soft parameter sharing
TaskPriority
Task priorities
TaskSchedulingStrategy
Task scheduling strategies
TaskType
Task types and their specific parameters