tch_plus/nn/
func.rs

1//! Layers defined by closures.
2use crate::Tensor;
3
4/// A layer defined by a simple closure.
5pub struct Func<'a> {
6    f: Box<dyn 'a + Fn(&Tensor) -> Tensor + Send>,
7}
8
9impl std::fmt::Debug for Func<'_> {
10    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
11        write!(f, "func")
12    }
13}
14
15pub fn func<'a, F>(f: F) -> Func<'a>
16where
17    F: 'a + Fn(&Tensor) -> Tensor + Send,
18{
19    Func { f: Box::new(f) }
20}
21
22impl super::module::Module for Func<'_> {
23    fn forward(&self, xs: &Tensor) -> Tensor {
24        (*self.f)(xs)
25    }
26}
27
28/// A layer defined by a closure with an additional training parameter.
29#[allow(clippy::type_complexity)]
30pub struct FuncT<'a> {
31    f: Box<dyn 'a + Fn(&Tensor, bool) -> Tensor + Send>,
32}
33
34impl std::fmt::Debug for FuncT<'_> {
35    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
36        write!(f, "funcT")
37    }
38}
39
40pub fn func_t<'a, F>(f: F) -> FuncT<'a>
41where
42    F: 'a + Fn(&Tensor, bool) -> Tensor + Send,
43{
44    FuncT { f: Box::new(f) }
45}
46
47impl super::module::ModuleT for FuncT<'_> {
48    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
49        (*self.f)(xs, train)
50    }
51}