tch_plus/nn/module.rs

//! Basic module traits defining the forward pass.
use crate::{data::Iter2, Device, Tensor};

/// The simplest module trait, defining a forward function.
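///
/// A minimal sketch of an implementor (the `Relu` type below is only
/// illustrative, not part of this crate):
///
/// ```ignore
/// #[derive(Debug)]
/// struct Relu;
///
/// impl Module for Relu {
///     fn forward(&self, xs: &Tensor) -> Tensor {
///         xs.relu()
///     }
/// }
/// ```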
pub trait Module: std::fmt::Debug + Send {
    fn forward(&self, xs: &Tensor) -> Tensor;
}

/// Module trait with an additional train parameter.
///
/// The train parameter is commonly used to have different behavior between training
/// and evaluation, e.g. when using dropout or batch-normalization.
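///
/// A minimal sketch of a train-aware implementor (the `DropoutLayer` type below
/// is only illustrative, not part of this crate):
///
/// ```ignore
/// #[derive(Debug)]
/// struct DropoutLayer {
///     p: f64,
/// }
///
/// impl ModuleT for DropoutLayer {
///     fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
///         // Dropout is only applied when `train` is true.
///         xs.dropout(self.p, train)
///     }
/// }
/// ```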
pub trait ModuleT: std::fmt::Debug + Send {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;

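    /// Computes the average classification accuracy of this module's logits over the
    /// `(xs, ys)` dataset, evaluated in mini-batches of `batch_size` on device `d`.
    /// The module is run in evaluation mode (`train = false`) and gradient tracking
    /// is disabled for the whole computation.
    ///
    /// A hedged usage sketch (the `net`, `images`, and `labels` bindings are
    /// illustrative placeholders):
    ///
    /// ```ignore
    /// let acc = net.batch_accuracy_for_logits(&images, &labels, Device::Cpu, 256);
    /// println!("accuracy: {acc:.4}");
    /// ```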
    fn batch_accuracy_for_logits(
        &self,
        xs: &Tensor,
        ys: &Tensor,
        d: Device,
        batch_size: i64,
    ) -> f64 {
        let _no_grad = crate::no_grad_guard();
        let mut sum_accuracy = 0f64;
        let mut sample_count = 0f64;
        for (xs, ys) in Iter2::new(xs, ys, batch_size).return_smaller_last_batch() {
            let acc = self
                .forward_t(&xs.to_device(d), false)
                .accuracy_for_logits(&ys.to_device(d));
            let size = xs.size()[0] as f64;
            sum_accuracy += f64::try_from(&acc).unwrap() * size;
            sample_count += size;
        }
        sum_accuracy / sample_count
    }
}

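// Every plain `Module` is automatically usable as a `ModuleT`: the `train`
// flag is simply ignored.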
impl<T> ModuleT for T
where
    T: Module,
{
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Tensor {
        self.forward(xs)
    }
}

impl Tensor {
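    /// Applies the module `m` to this tensor, i.e. returns `m.forward(self)`.
    ///
    /// This allows chaining layers with method-call syntax. A hedged sketch
    /// (the `conv` and `linear` bindings are illustrative placeholders):
    ///
    /// ```ignore
    /// let ys = xs.apply(&conv).relu().apply(&linear);
    /// ```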
    pub fn apply<M: Module>(&self, m: &M) -> Tensor {
        m.forward(self)
    }

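    /// Applies the train-aware module `m` to this tensor, forwarding the `train`
    /// flag to `m.forward_t`.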
    pub fn apply_t<M: ModuleT>(&self, m: &M, train: bool) -> Tensor {
        m.forward_t(self, train)
    }

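    /// Applies the optional module `m` if it is `Some`, otherwise returns a shallow
    /// clone of this tensor (sharing the same underlying storage).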
    pub fn apply_opt<M: Module>(&self, m: &Option<M>) -> Tensor {
        match m {
            Some(m) => m.forward(self),
            None => self.shallow_clone(),
        }
    }

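    /// Same as `apply_opt` but for train-aware modules, forwarding the `train` flag.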
    pub fn apply_opt_t<M: ModuleT>(&self, m: &Option<M>, train: bool) -> Tensor {
        match m {
            Some(m) => m.forward_t(self, train),
            None => self.shallow_clone(),
        }
    }
}