1use crate::{data::Iter2, Device, Tensor};
3
4pub trait Module: std::fmt::Debug + Send {
6 fn forward(&self, xs: &Tensor) -> Tensor;
7}
8
9pub trait ModuleT: std::fmt::Debug + Send {
14 fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;
15
16 fn batch_accuracy_for_logits(
17 &self,
18 xs: &Tensor,
19 ys: &Tensor,
20 d: Device,
21 batch_size: i64,
22 ) -> f64 {
23 let _no_grad = crate::no_grad_guard();
24 let mut sum_accuracy = 0f64;
25 let mut sample_count = 0f64;
26 for (xs, ys) in Iter2::new(xs, ys, batch_size).return_smaller_last_batch() {
27 let acc = self.forward_t(&xs.to_device(d), false).accuracy_for_logits(&ys.to_device(d));
28 let size = xs.size()[0] as f64;
29 sum_accuracy += f64::try_from(&acc).unwrap() * size;
30 sample_count += size;
31 }
32 sum_accuracy / sample_count
33 }
34}
35
36impl<T> ModuleT for T
37where
38 T: Module,
39{
40 fn forward_t(&self, xs: &Tensor, _train: bool) -> Tensor {
41 self.forward(xs)
42 }
43}
44
45impl Tensor {
46 pub fn apply<M: Module>(&self, m: &M) -> Tensor {
47 m.forward(self)
48 }
49
50 pub fn apply_t<M: ModuleT>(&self, m: &M, train: bool) -> Tensor {
51 m.forward_t(self, train)
52 }
53
54 pub fn apply_opt<M: Module>(&self, m: &Option<M>) -> Tensor {
55 match m {
56 Some(m) => m.forward(self),
57 None => self.shallow_clone(),
58 }
59 }
60
61 pub fn apply_opt_t<M: ModuleT>(&self, m: &Option<M>, train: bool) -> Tensor {
62 match m {
63 Some(m) => m.forward_t(self, train),
64 None => self.shallow_clone(),
65 }
66 }
67}