1use crate::Tensor;
3
4pub struct Func<'a> {
6 f: Box<dyn 'a + Fn(&Tensor) -> Tensor + Send>,
7}
8
9impl std::fmt::Debug for Func<'_> {
10 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
11 write!(f, "func")
12 }
13}
14
15pub fn func<'a, F>(f: F) -> Func<'a>
16where
17 F: 'a + Fn(&Tensor) -> Tensor + Send,
18{
19 Func { f: Box::new(f) }
20}
21
22impl super::module::Module for Func<'_> {
23 fn forward(&self, xs: &Tensor) -> Tensor {
24 (*self.f)(xs)
25 }
26}
27
28#[allow(clippy::type_complexity)]
30pub struct FuncT<'a> {
31 f: Box<dyn 'a + Fn(&Tensor, bool) -> Tensor + Send>,
32}
33
34impl std::fmt::Debug for FuncT<'_> {
35 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
36 write!(f, "funcT")
37 }
38}
39
40pub fn func_t<'a, F>(f: F) -> FuncT<'a>
41where
42 F: 'a + Fn(&Tensor, bool) -> Tensor + Send,
43{
44 FuncT { f: Box::new(f) }
45}
46
47impl super::module::ModuleT for FuncT<'_> {
48 fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
49 (*self.f)(xs, train)
50 }
51}