syntaxdot_transformers/
module.rs

1use std::fmt::Debug;
2
3use tch::Tensor;
4
5/// Module for which a computation can fail.
6pub trait FallibleModule: Debug + Send {
7    /// The error type.
8    type Error;
9
10    /// Apply the module.
11    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error>;
12}
13
14/// Module for which a computation can fail.
15pub trait FallibleModuleT: Debug + Send {
16    /// The error type.
17    type Error;
18
19    /// Apply the module.
20    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error>;
21}
22
23impl<M> FallibleModuleT for M
24where
25    M: FallibleModule,
26{
27    type Error = M::Error;
28
29    fn forward_t(&self, input: &Tensor, _train: bool) -> Result<Tensor, Self::Error> {
30        self.forward(input)
31    }
32}