syntaxdot_transformers/
module.rs1use std::fmt::Debug;
2
3use tch::Tensor;
4
5pub trait FallibleModule: Debug + Send {
7 type Error;
9
10 fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error>;
12}
13
14pub trait FallibleModuleT: Debug + Send {
16 type Error;
18
19 fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error>;
21}
22
23impl<M> FallibleModuleT for M
24where
25 M: FallibleModule,
26{
27 type Error = M::Error;
28
29 fn forward_t(&self, input: &Tensor, _train: bool) -> Result<Tensor, Self::Error> {
30 self.forward(input)
31 }
32}