/// A module whose forward pass takes an additional `train` flag.
///
/// The `train` flag lets an implementation behave differently during training
/// and evaluation, e.g. when using dropout or batch normalization.
///
/// NOTE(review): this snippet comes from rendered rustdoc output; the
/// `{ ... }` below is rustdoc's elision of the provided method's default body,
/// so the snippet is not compilable as-is.
pub trait ModuleT: Debug + Send {
// Required method
/// Runs the forward pass on `xs` and returns the resulting tensor.
///
/// `train` distinguishes training from evaluation (presumably `true` while
/// training), so that layers such as dropout or batch normalization can switch
/// behavior.
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;
// Provided method
/// Default-implemented helper returning an `f64`.
///
/// NOTE(review): the body is not visible here. Judging only by its name and
/// parameters, it presumably computes the classification accuracy of the
/// logits this module produces for `xs` against the labels `ys`. It would do
/// so by processing `batch_size` examples at a time on device `d`. Confirm
/// against the actual implementation, including whether `forward_t` is called
/// with `train = false`.
fn batch_accuracy_for_logits(
&self,
xs: &Tensor,
ys: &Tensor,
d: Device,
batch_size: i64
) -> f64 { ... }
}
Description
Module trait with an additional `train` parameter.
The `train` parameter is commonly used to select different behavior for training and for evaluation, e.g. when using dropout or batch normalization.