1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
use std::fmt::Debug;
use tch::Tensor;
/// A module with a single forward pass that reports failure through a
/// `Result` instead of returning a bare [`Tensor`].
///
/// Implementors must be `Debug` and `Send`.
pub trait FallibleModule
where
    Self: Debug + Send,
{
    /// Error type produced by [`FallibleModule::forward`].
    type Error;

    /// Applies the module to `input`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementation's forward computation
    /// fails.
    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error>;
}
/// A module whose forward pass takes an extra `train` flag and reports
/// failure through a `Result` instead of returning a bare [`Tensor`].
///
/// Implementors must be `Debug` and `Send`.
pub trait FallibleModuleT
where
    Self: Debug + Send,
{
    /// Error type produced by [`FallibleModuleT::forward_t`].
    type Error;

    /// Applies the module to `input`.
    ///
    /// `train` is handed to the implementation unchanged. Presumably it
    /// selects training vs. inference behaviour; confirm against
    /// implementors.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementation's forward computation
    /// fails.
    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error>;
}
impl<M> FallibleModuleT for M
where
M: FallibleModule,
{
type Error = M::Error;
fn forward_t(&self, input: &Tensor, _train: bool) -> Result<Tensor, Self::Error> {
self.forward(input)
}
}