1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
use std::fmt::Debug;

use tch::Tensor;

/// A module whose forward computation may fail.
///
/// Unlike [`FallibleModuleT`], the forward pass of this trait takes no
/// training flag. Every implementor of this trait also implements
/// [`FallibleModuleT`] through a blanket implementation that ignores the flag.
pub trait FallibleModule: Send + Debug {
    /// Error produced when the forward computation fails.
    type Error;

    /// Run the module on `input`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the computation fails.
    fn forward(&self, input: &Tensor) -> Result<Tensor, Self::Error>;
}

/// A module whose forward computation may fail and may depend on whether
/// it is run in training or inference mode.
///
/// This is the train-mode-aware counterpart of [`FallibleModule`]. Every
/// [`FallibleModule`] implements this trait automatically, and that
/// implementation ignores the `train` flag.
pub trait FallibleModuleT: Debug + Send {
    /// Error produced when the forward computation fails.
    type Error;

    /// Apply the module to `input`.
    ///
    /// `train` is `true` when the module is applied during training and
    /// `false` during inference. Implementations may use it to switch
    /// train-only behavior on or off.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if the computation fails.
    fn forward_t(&self, input: &Tensor, train: bool) -> Result<Tensor, Self::Error>;
}

impl<M> FallibleModuleT for M
where
    M: FallibleModule,
{
    type Error = M::Error;

    fn forward_t(&self, input: &Tensor, _train: bool) -> Result<Tensor, Self::Error> {
        self.forward(input)
    }
}