Trait tch::nn::ModuleT

pub trait ModuleT: Debug + Send {
    // Required method
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;

    // Provided method
    fn batch_accuracy_for_logits(
        &self,
        xs: &Tensor,
        ys: &Tensor,
        d: Device,
        batch_size: i64
    ) -> f64 { ... }
}

Module trait with an additional train parameter.

The train parameter is commonly used to switch between training and evaluation behavior, e.g. when using dropout or batch normalization.
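As a rough illustration of how the train flag changes behavior, here is a self-contained sketch in plain Rust with no tch dependency. ModuleTLike, Dropout, and the xorshift generator are hypothetical names introduced only for this example, and slices of f64 stand in for Tensor. It implements a simplified inverted dropout: while training, each element is zeroed with probability p and the survivors are scaled by 1/(1 - p); in evaluation the input passes through unchanged.

```rust
use std::cell::Cell;
use std::fmt::Debug;

/// Hypothetical stand-in for tch's `ModuleT`; `&[f64]` replaces `&Tensor`.
trait ModuleTLike: Debug + Send {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64>;
}

/// Simplified inverted dropout (hypothetical type, not tch's API).
#[derive(Debug)]
struct Dropout {
    /// Probability of zeroing each element during training.
    p: f64,
    /// xorshift64 state; `Cell` allows updating it behind `&self`.
    rng_state: Cell<u64>,
}

impl Dropout {
    fn new(p: f64, seed: u64) -> Self {
        assert!((0.0..1.0).contains(&p), "p must be in [0, 1)");
        // xorshift64 never leaves the all-zero state, so avoid seeding with 0.
        Dropout { p, rng_state: Cell::new(seed.max(1)) }
    }

    /// Advances the xorshift64 generator and returns a value in [0, 1).
    fn next_uniform(&self) -> f64 {
        let mut x = self.rng_state.get();
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng_state.set(x);
        // Keep the top 53 bits so the value fits exactly in an f64 mantissa.
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}

impl ModuleTLike for Dropout {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64> {
        if !train {
            // Evaluation: identity, no randomness.
            return xs.to_vec();
        }
        // Training: drop with probability p and rescale survivors so the
        // expected value of each element is unchanged.
        let scale = 1.0 / (1.0 - self.p);
        xs.iter()
            .map(|&x| if self.next_uniform() < self.p { 0.0 } else { x * scale })
            .collect()
    }
}

fn main() {
    let dropout = Dropout::new(0.5, 42);
    let xs = [1.0, 2.0, 3.0, 4.0];
    println!("eval:  {:?}", dropout.forward_t(&xs, false));
    println!("train: {:?}", dropout.forward_t(&xs, true));
}
```

In real tch code you would not hand-roll this: Tensor has a dropout method that takes the probability and the train flag, so a forward_t implementation can simply pass its train argument through.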

Required Methods


fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor
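An implementation of forward_t in a composite model typically forwards the same train flag to every sub-layer. The dependency-free sketch below uses hypothetical stand-in types (ModuleTLike, Affine, BatchNorm1d, SequentialTLike, with f64 slices in place of Tensor). It chains a mode-independent affine layer with a simplified batch normalization that uses batch statistics while training and fixed running statistics in evaluation; updating the running statistics is deliberately omitted.

```rust
use std::fmt::Debug;

/// Hypothetical stand-in for tch's `ModuleT`, over a 1-D batch of scalars.
trait ModuleTLike: Debug + Send {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64>;
}

/// Affine layer that behaves identically in both modes: y = w * x + b.
#[derive(Debug)]
struct Affine {
    w: f64,
    b: f64,
}

impl ModuleTLike for Affine {
    fn forward_t(&self, xs: &[f64], _train: bool) -> Vec<f64> {
        xs.iter().map(|&x| self.w * x + self.b).collect()
    }
}

/// Simplified batch normalization: batch statistics while training,
/// fixed running statistics in evaluation (running stats never updated here).
#[derive(Debug)]
struct BatchNorm1d {
    running_mean: f64,
    running_var: f64,
    eps: f64,
}

impl ModuleTLike for BatchNorm1d {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64> {
        let (mean, var) = if train {
            let n = xs.len() as f64;
            let mean = xs.iter().sum::<f64>() / n;
            let var = xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
            (mean, var)
        } else {
            (self.running_mean, self.running_var)
        };
        let denom = (var + self.eps).sqrt();
        xs.iter().map(|x| (x - mean) / denom).collect()
    }
}

/// Sequential container that passes the same `train` flag to every layer.
#[derive(Debug, Default)]
struct SequentialTLike {
    layers: Vec<Box<dyn ModuleTLike>>,
}

impl SequentialTLike {
    fn add<M: ModuleTLike + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }
}

impl ModuleTLike for SequentialTLike {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64> {
        self.layers
            .iter()
            .fold(xs.to_vec(), |acc, layer| layer.forward_t(&acc, train))
    }
}

fn main() {
    let model = SequentialTLike::default()
        .add(Affine { w: 2.0, b: 1.0 })
        .add(BatchNorm1d { running_mean: 0.0, running_var: 1.0, eps: 1e-5 });
    let xs = [1.0, 2.0, 3.0];
    println!("eval:  {:?}", model.forward_t(&xs, false));
    println!("train: {:?}", model.forward_t(&xs, true));
}
```

The container itself has no mode-specific logic; it only threads train through, so each layer decides what the flag means. tch provides nn::seq_t for building such a train-aware sequence.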

Provided Methods


fn batch_accuracy_for_logits( &self, xs: &Tensor, ys: &Tensor, d: Device, batch_size: i64 ) -> f64
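This page does not describe batch_accuracy_for_logits. As far as we understand tch's provided implementation, it runs forward_t with train = false over xs in chunks of batch_size (moving each chunk to device d), compares the argmax of each logit row with the labels in ys, and averages the per-batch accuracies weighted by batch size. The following is a dependency-free sketch of that idea under those assumptions only: ModuleTLike, argmax, and IdentityLogits are hypothetical, Vec<Vec<f64>> stands in for Tensor, and the device parameter is dropped.

```rust
use std::fmt::Debug;

/// Hypothetical stand-in for `ModuleT`: each input row maps to one row of logits.
trait ModuleTLike: Debug + Send {
    fn forward_t(&self, xs: &[Vec<f64>], train: bool) -> Vec<Vec<f64>>;

    /// Sketch of batched accuracy (no device argument): evaluation mode,
    /// argmax per logit row, per-batch accuracy weighted by batch size.
    fn batch_accuracy_for_logits(&self, xs: &[Vec<f64>], ys: &[usize], batch_size: usize) -> f64 {
        assert_eq!(xs.len(), ys.len(), "one label per input row");
        assert!(batch_size > 0, "batch_size must be positive");
        let mut weighted_sum = 0.0;
        let mut sample_count = 0.0;
        // `chunks` yields a smaller final batch when the length is not a multiple of batch_size.
        for (xb, yb) in xs.chunks(batch_size).zip(ys.chunks(batch_size)) {
            let logits = self.forward_t(xb, false);
            let correct = logits
                .iter()
                .zip(yb.iter())
                .filter(|&(row, &y)| argmax(row) == y)
                .count();
            let size = xb.len() as f64;
            weighted_sum += (correct as f64 / size) * size;
            sample_count += size;
        }
        weighted_sum / sample_count
    }
}

/// Index of the largest value in `row` (the first index wins on ties).
fn argmax(row: &[f64]) -> usize {
    row.iter()
        .enumerate()
        .fold(0, |best, (i, &v)| if v > row[best] { i } else { best })
}

/// Hypothetical module whose inputs already are logits.
#[derive(Debug)]
struct IdentityLogits;

impl ModuleTLike for IdentityLogits {
    fn forward_t(&self, xs: &[Vec<f64>], _train: bool) -> Vec<Vec<f64>> {
        xs.to_vec()
    }
}

fn main() {
    let xs = vec![vec![0.9, 0.1], vec![0.2, 0.8], vec![0.6, 0.4]];
    let ys: Vec<usize> = vec![0, 1, 1];
    let acc = IdentityLogits.batch_accuracy_for_logits(&xs, &ys, 2);
    println!("accuracy = {acc:.3}");
}
```

Weighting each batch's accuracy by its size keeps a short final batch from counting as much as a full one, so the result equals the plain fraction of correctly classified samples.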

Implementors