/// Model-component interface used by meta-learning code.
///
/// Implementors must be `Send + Sync` (supertrait bounds). Seven methods are
/// required; the other nineteen are provided, but their default bodies are not
/// visible in this listing. Every method is fallible and reports failure as
/// `TrustformersError`.
///
/// NOTE(review): this listing appears to be scraped rustdoc output rather than
/// the real source file. The "Show 26 methods" line, the elided `{ ... }`
/// default bodies and the "Expand description" text after the closing brace are
/// page residue and will not compile. Recover the actual definitions from the
/// crate source before editing this trait.
pub trait MetaLearningModel: Send + Sync {
// NOTE(review): the next line is rustdoc UI residue (7 required + 19 provided
// = 26 methods), not Rust.
Show 26 methods
// Required methods
/// Runs the model on `examples` and returns a scalar `f64`.
///
/// Takes `&mut self`, so it may update internal state. Presumably the scalar
/// is the loss later passed to `compute_gradients`. TODO confirm.
fn forward(
&mut self,
examples: &ExampleSet,
) -> Result<f64, TrustformersError>;
/// Computes an accuracy value over `examples` without mutating the model.
///
/// The scale of the result (fraction vs. percentage) is not visible here.
/// Confirm with implementors.
fn compute_accuracy(
&self,
examples: &ExampleSet,
) -> Result<f64, TrustformersError>;
/// Returns `ModelGradients` for the scalar `loss`.
///
/// NOTE(review): only a scalar is passed in, so implementations presumably
/// rely on state captured during `forward`. Confirm.
fn compute_gradients(
&self,
loss: f64,
) -> Result<ModelGradients, TrustformersError>;
/// Applies `gradients` to the model, using `lr` as the learning rate.
fn apply_gradients(
&mut self,
gradients: &ModelGradients,
lr: f64,
) -> Result<(), TrustformersError>;
/// Returns the model's current parameters as an owned `ModelParameters` value.
fn get_parameters(&self) -> Result<ModelParameters, TrustformersError>;
/// Sets the model's parameters from `params`, which is taken by value.
fn set_parameters(
&mut self,
params: ModelParameters,
) -> Result<(), TrustformersError>;
/// Maps a single `example` to a `Tensor` embedding.
fn embed(&self, example: &Example) -> Result<Tensor, TrustformersError>;
// Provided methods
// NOTE(review): default bodies below are elided as `{ ... }` in this listing.
// Their parameter names are underscore-prefixed, which suggests the defaults
// ignore their arguments (e.g. a no-op or an "unsupported" error). Confirm
// against the source before relying on any default.
/// Gradient computation that also involves `_initial_params` (second-order
/// variant, per the method name).
fn compute_second_order_gradients(
&self,
_initial_params: &ModelParameters,
_loss: f64,
) -> Result<ModelGradients, TrustformersError> { ... }
/// First-order gradient variant for the given loss (per the method name).
fn compute_first_order_gradients(
&self,
_loss: f64,
) -> Result<ModelGradients, TrustformersError> { ... }
/// Scores the relation between two embedding tensors as a single `f64`.
fn compute_relation(
&self,
_emb1: &Tensor,
_emb2: &Tensor,
) -> Result<f64, TrustformersError> { ... }
/// Stores `_example` in the model's memory (takes `&mut self`).
fn write_to_memory(
&mut self,
_example: &Example,
) -> Result<(), TrustformersError> { ... }
/// Reads from the model's memory for `_example`, producing a `MemoryOutput`.
fn read_from_memory(
&self,
_example: &Example,
) -> Result<MemoryOutput, TrustformersError> { ... }
/// Turns a `MemoryOutput` into a `MemoryPrediction`.
fn predict_from_memory(
&self,
_memory_output: &MemoryOutput,
) -> Result<MemoryPrediction, TrustformersError> { ... }
/// Clears the model's memory (takes `&mut self`).
fn clear_memory(&mut self) -> Result<(), TrustformersError> { ... }
/// Returns the model's learning rates.
///
/// Presumably there is one rate per parameter group. Confirm.
fn get_learning_rates(&self) -> Result<Vec<f64>, TrustformersError> { ... }
/// Like `apply_gradients`, but takes a slice of learning rates instead of a
/// single `lr`.
///
/// How the rates map onto the gradients is not visible here.
fn apply_gradients_with_lr(
&mut self,
_gradients: &ModelGradients,
_learning_rates: &[f64],
) -> Result<(), TrustformersError> { ... }
/// Returns gradients with respect to the learning rates (per the method
/// name), one `f64` per entry.
fn compute_lr_gradients(
&self,
_loss: f64,
) -> Result<Vec<f64>, TrustformersError> { ... }
/// Returns the current `MetaLearnerState`.
fn get_meta_learner_state(
&self,
) -> Result<MetaLearnerState, TrustformersError> { ... }
/// Derives a `ModelParameters` value from `_support_set` and `_state`.
///
/// Takes `&self`, so the model itself is not mutated.
fn apply_learned_algorithm(
&self,
_support_set: &ExampleSet,
_state: &MetaLearnerState,
) -> Result<ModelParameters, TrustformersError> { ... }
/// Evaluates `_examples` under the explicitly supplied `_params` and returns
/// an `f64`.
///
/// Presumably this is used instead of the model's own parameters. Confirm.
fn evaluate_with_params(
&self,
_examples: &ExampleSet,
_params: &ModelParameters,
) -> Result<f64, TrustformersError> { ... }
/// Accuracy counterpart of `evaluate_with_params`, using explicitly supplied
/// `_params`.
fn compute_accuracy_with_params(
&self,
_examples: &ExampleSet,
_params: &ModelParameters,
) -> Result<f64, TrustformersError> { ... }
/// Returns gradients for the meta-learner given `_loss` (per the method name).
fn compute_meta_learner_gradients(
&self,
_loss: f64,
) -> Result<ModelGradients, TrustformersError> { ... }
/// Returns the current `LSTMState`.
fn get_lstm_state(&self) -> Result<LSTMState, TrustformersError> { ... }
/// Computes updates from `_gradients`, the given `_state` and the `_step`
/// index.
///
/// Returns the `ModelUpdates` together with an `LSTMState`. Presumably that
/// is the successor state to feed into the next call. Confirm.
fn lstm_update(
&self,
_gradients: &ModelGradients,
_state: &LSTMState,
_step: usize,
) -> Result<(ModelUpdates, LSTMState), TrustformersError> { ... }
/// Applies `_updates` (e.g. as produced by `lstm_update`) to the model
/// (takes `&mut self`).
fn apply_lstm_updates(
&mut self,
_updates: &ModelUpdates,
) -> Result<(), TrustformersError> { ... }
/// Returns gradients for the LSTM component given `_loss` (per the method
/// name).
fn compute_lstm_gradients(
&self,
_loss: f64,
) -> Result<ModelGradients, TrustformersError> { ... }
// NOTE(review): "Expand description" after the closing brace below is rustdoc
// UI text, not Rust.
}Expand description
Trait defining the interface that meta-learning model components implement.