Trait tch::nn::ModuleT

pub trait ModuleT: Debug + Send {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;

    fn batch_accuracy_for_logits(
        &self,
        xs: &Tensor,
        ys: &Tensor,
        d: Device,
        batch_size: i64
    ) -> f64 { ... } }

A variant of the `Module` trait whose forward pass takes an additional `train` parameter.

The `train` parameter is commonly used to switch behavior between training and evaluation, e.g. when using dropout or batch normalization.

Required methods

fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor

Loading content...

Provided methods

fn batch_accuracy_for_logits(
    &self,
    xs: &Tensor,
    ys: &Tensor,
    d: Device,
    batch_size: i64
) -> f64

Loading content...

Implementors

impl ModuleT for BatchNorm

impl ModuleT for Id

impl ModuleT for SequentialT

impl<'a> ModuleT for FuncT<'a>

impl<T> ModuleT for T where
    T: Module

Loading content...