pub struct MultiTaskLearningTrainer<M: Model> {
pub base_model: M,
pub task_heads: HashMap<String, TaskHead>,
pub config: MTLConfig,
pub task_weights: HashMap<String, f32>,
pub task_performance: HashMap<String, Vec<f32>>,
pub step_counter: usize,
pub scheduler_state: TaskSchedulerState,
pub gradient_stats: HashMap<String, GradientStats>,
}
Expand description
Multi-task learning trainer
Fields§
base_model: M — Base model (shared layers)
task_heads: HashMap<String, TaskHead> — Task-specific heads
config: MTLConfig — Configuration
task_weights: HashMap<String, f32> — Task losses and weights
task_performance: HashMap<String, Vec<f32>> — Task performance history
step_counter: usize — Current training step
scheduler_state: TaskSchedulerState — Task scheduling state
gradient_stats: HashMap<String, GradientStats> — Gradient statistics for balancing
Implementations§
Source§impl<M: Model<Input = Tensor, Output = Tensor>> MultiTaskLearningTrainer<M>
impl<M: Model<Input = Tensor, Output = Tensor>> MultiTaskLearningTrainer<M>
Source pub fn new(base_model: M, config: MTLConfig) -> Result<Self>
pub fn new(base_model: M, config: MTLConfig) -> Result<Self>
Create a new multi-task learning trainer
Source pub fn train_multi_task_step(
&mut self,
task_data: &HashMap<String, TaskBatch>,
) -> Result<MultiTaskOutput>
pub fn train_multi_task_step( &mut self, task_data: &HashMap<String, TaskBatch>, ) -> Result<MultiTaskOutput>
Train on multiple tasks for one step
Source pub fn evaluate_all_tasks(
&self,
test_data: &HashMap<String, TaskBatch>,
) -> Result<MultiTaskEvaluation>
pub fn evaluate_all_tasks( &self, test_data: &HashMap<String, TaskBatch>, ) -> Result<MultiTaskEvaluation>
Evaluate all tasks
Source pub fn get_mtl_stats(&self) -> MTLStats
pub fn get_mtl_stats(&self) -> MTLStats
Get multi-task learning statistics
Auto Trait Implementations§
impl<M> Freeze for MultiTaskLearningTrainer<M> where
M: Freeze,
impl<M> RefUnwindSafe for MultiTaskLearningTrainer<M> where
M: RefUnwindSafe,
impl<M> Send for MultiTaskLearningTrainer<M>
impl<M> Sync for MultiTaskLearningTrainer<M>
impl<M> Unpin for MultiTaskLearningTrainer<M> where
M: Unpin,
impl<M> UnsafeUnpin for MultiTaskLearningTrainer<M> where
M: UnsafeUnpin,
impl<M> UnwindSafe for MultiTaskLearningTrainer<M> where
M: UnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> Instrument for T
impl<T> Instrument for T
Source§fn instrument(self, span: Span) -> Instrumented<Self>
fn instrument(self, span: Span) -> Instrumented<Self>
Source§fn in_current_span(self) -> Instrumented<Self>
fn in_current_span(self) -> Instrumented<Self>
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left is true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self> otherwise. Read more