Trait tch::nn::ModuleT

pub trait ModuleT: Debug + Send {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;

    fn batch_accuracy_for_logits(
        &self,
        xs: &Tensor,
        ys: &Tensor,
        d: Device,
        batch_size: i64
    ) -> f64 { ... } }

Module trait with an additional train parameter.

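The train parameter is commonly used to switch between training and evaluation behavior. For example, dropout is applied only when train is true, and batch normalization normalizes with the statistics of the current batch during training but with its running statistics during evaluation.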
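The following is a minimal sketch of the pattern, not tch code. It assumes a hypothetical `ModuleTSketch` trait that works on plain `f64` slices instead of `tch::Tensor`, so it needs no external crate. The hypothetical `Center` layer is a simplified, batch-norm-like step: with `train` set to true it subtracts the mean of the current batch, and with `train` set to false it subtracts a fixed running mean stored in the layer.

```rust
use std::fmt::Debug;

/// Hypothetical stand-in for tch's `ModuleT`, using `&[f64]` / `Vec<f64>`
/// instead of `tch::Tensor`. Same `Debug + Send` bounds as the real trait.
pub trait ModuleTSketch: Debug + Send {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64>;
}

/// Simplified batch-norm-like centering layer (hypothetical, not tch's BatchNorm).
#[derive(Debug)]
pub struct Center {
    /// Mean used in evaluation mode; a real layer would update it during training.
    pub running_mean: f64,
}

impl ModuleTSketch for Center {
    fn forward_t(&self, xs: &[f64], train: bool) -> Vec<f64> {
        // Training mode: use statistics of the current batch.
        // Evaluation mode (or an empty batch): use the stored running mean.
        let mean = if train && !xs.is_empty() {
            xs.iter().sum::<f64>() / xs.len() as f64
        } else {
            self.running_mean
        };
        xs.iter().map(|x| x - mean).collect()
    }
}

fn main() {
    let layer = Center { running_mean: 1.0 };
    let xs = [2.0, 4.0, 6.0];
    // Training: the batch mean is 4.0.
    println!("{:?}", layer.forward_t(&xs, true)); // [-2.0, 0.0, 2.0]
    // Evaluation: the stored running mean 1.0 is used instead.
    println!("{:?}", layer.forward_t(&xs, false)); // [1.0, 3.0, 5.0]
}
```

Because the same module object serves both phases, callers only flip the `train` flag rather than swapping layers between training and evaluation.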

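Required methods

fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor

Runs the module on xs; train is true during training and false during evaluation.

Provided methods

fn batch_accuracy_for_logits(&self, xs: &Tensor, ys: &Tensor, d: Device, batch_size: i64) -> f64

Returns the accuracy of the logits produced for xs against the labels ys, processing batch_size samples at a time on device d.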
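As a rough illustration of what batch_accuracy_for_logits computes, here is a std-only sketch under stated assumptions. It assumes the prediction for each sample is the argmax of its logits and that per-batch accuracies are combined as a mean weighted by batch size, so a smaller final batch does not skew the result. The model's forward pass (presumably run with train set to false) is replaced by precomputed logits, and device placement is omitted. The functions `argmax` and `batch_accuracy_for_logits_sketch` are hypothetical helpers, not part of tch.

```rust
/// Hypothetical helper: index of the largest logit (first one wins on ties).
fn argmax(row: &[f64]) -> usize {
    let mut best = 0;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best
}

/// Hypothetical sketch: accuracy of `logits` (one row per sample) against
/// class labels `ys`, evaluated `batch_size` rows at a time.
fn batch_accuracy_for_logits_sketch(logits: &[Vec<f64>], ys: &[usize], batch_size: usize) -> f64 {
    assert!(batch_size > 0, "batch_size must be positive");
    assert_eq!(logits.len(), ys.len(), "one label per row of logits");
    let mut weighted_sum = 0.0;
    let mut sample_count = 0.0;
    for (xb, yb) in logits.chunks(batch_size).zip(ys.chunks(batch_size)) {
        let mut correct = 0usize;
        for (row, &y) in xb.iter().zip(yb) {
            if argmax(row) == y {
                correct += 1;
            }
        }
        let size = xb.len() as f64;
        let batch_acc = correct as f64 / size;
        // Weight each batch's accuracy by its number of samples.
        weighted_sum += batch_acc * size;
        sample_count += size;
    }
    if sample_count == 0.0 { 0.0 } else { weighted_sum / sample_count }
}

fn main() {
    let logits = vec![vec![0.1, 0.9], vec![0.8, 0.2], vec![0.3, 0.7]];
    let ys = [1, 0, 0];
    // Batch 1 (2 rows): both correct. Batch 2 (1 row): wrong. Overall 2/3.
    let acc = batch_accuracy_for_logits_sketch(&logits, &ys, 2);
    println!("{:.3}", acc); // 0.667
}
```

Weighting by batch size makes the result independent of batch_size, which only controls how much data is processed at once.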