1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
use crate::{data::Iter2, Device, Tensor};
/// A neural-network component whose output depends only on its input.
///
/// Implement this for layers that behave identically in training and
/// evaluation mode (e.g. linear layers, convolutions without dropout).
pub trait Module: std::fmt::Debug + Send {
/// Applies the module to the input tensor and returns the result.
fn forward(&self, xs: &Tensor) -> Tensor;
}
/// A neural-network component whose behavior may differ between training
/// and evaluation (e.g. dropout, batch-norm). The `train` flag selects
/// which mode is used.
pub trait ModuleT: std::fmt::Debug + Send {
/// Applies the module to the input tensor, in training mode when `train`
/// is true and in evaluation mode otherwise.
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor;

/// Measures classification accuracy of this model's logits over a dataset.
///
/// Runs the model in evaluation mode over `(xs, ys)` in mini-batches of
/// `batch_size` on device `d`, and returns the sample-weighted mean
/// accuracy. Note: returns NaN when the dataset is empty (0/0 division).
fn batch_accuracy_for_logits(
    &self,
    xs: &Tensor,
    ys: &Tensor,
    d: Device,
    batch_size: i64,
) -> f64 {
    // Gradient tracking is unnecessary for evaluation; disable it for
    // the whole loop via the RAII guard.
    let _no_grad = crate::no_grad_guard();
    let mut weighted_acc = 0f64;
    let mut total_samples = 0f64;
    // The final batch may be smaller than `batch_size`; weighting each
    // batch accuracy by its sample count keeps the overall mean exact.
    for (batch_xs, batch_ys) in Iter2::new(xs, ys, batch_size).return_smaller_last_batch() {
        let logits = self.forward_t(&batch_xs.to_device(d), false);
        let batch_acc = logits.accuracy_for_logits(&batch_ys.to_device(d));
        let batch_len = batch_xs.size()[0] as f64;
        weighted_acc += f64::from(&batch_acc) * batch_len;
        total_samples += batch_len;
    }
    weighted_acc / total_samples
}
}
/// Every `Module` is trivially a `ModuleT`: since its output does not
/// depend on the training mode, the `train` flag is simply ignored.
impl<T> ModuleT for T
where
    T: Module,
{
    fn forward_t(&self, xs: &Tensor, _train: bool) -> Tensor {
        // `xs` is already a `&Tensor`; the original `&xs` created a
        // needless `&&Tensor` borrow (clippy::needless_borrow).
        self.forward(xs)
    }
}
impl Tensor {
    /// Applies a module to this tensor, i.e. `m.forward(self)`.
    ///
    /// Enables the fluent style `x.apply(&layer1).apply(&layer2)`.
    pub fn apply<M: Module>(&self, m: &M) -> Tensor {
        // `self` is already a reference in a method taking `&self`;
        // the original `&self` was a needless extra borrow
        // (clippy::needless_borrow). Same fix in the methods below.
        m.forward(self)
    }

    /// Applies a module with an explicit training flag.
    pub fn apply_t<M: ModuleT>(&self, m: &M, train: bool) -> Tensor {
        m.forward_t(self, train)
    }

    /// Applies an optional module, returning a shallow clone of this
    /// tensor (shared storage, no copy) when the module is `None`.
    pub fn apply_opt<M: Module>(&self, m: &Option<M>) -> Tensor {
        match m {
            Some(m) => m.forward(self),
            None => self.shallow_clone(),
        }
    }

    /// Applies an optional module with an explicit training flag,
    /// returning a shallow clone when the module is `None`.
    pub fn apply_opt_t<M: ModuleT>(&self, m: &Option<M>, train: bool) -> Tensor {
        match m {
            Some(m) => m.forward_t(self, train),
            None => self.shallow_clone(),
        }
    }
}