1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
pub mod init;
pub use init::{f_init, init, Init};
mod var_store;
pub use var_store::{Path, VarStore, Variables};
mod module;
pub use module::{Module, ModuleT};
mod linear;
pub use linear::*;
mod conv;
pub use conv::*;
mod conv_transpose;
pub use conv_transpose::*;
mod batch_norm;
pub use batch_norm::*;
mod group_norm;
pub use group_norm::*;
mod layer_norm;
pub use layer_norm::*;
mod sparse;
pub use sparse::*;
mod rnn;
pub use rnn::*;
mod func;
pub use func::*;
mod sequential;
pub use sequential::*;
mod optimizer;
pub use optimizer::{
adam, adamw, rms_prop, sgd, Adam, AdamW, Optimizer, OptimizerConfig, RmsProp, Sgd,
};
/// Identity layer: `forward_t` returns its input unchanged.
///
/// It can serve as a no-op placeholder wherever a [`ModuleT`] is required but
/// no transformation should be applied. The type carries no state, so it is
/// cheap to copy and has a trivial default.
///
/// Construct it with `Id()` or `Id::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Id();

impl ModuleT for Id {
    /// Returns `xs.shallow_clone()`. The `_train` flag has no effect.
    fn forward_t(&self, xs: &crate::Tensor, _train: bool) -> crate::Tensor {
        // A shallow clone hands back a new tensor handle rather than `xs` itself.
        // NOTE(review): presumably this shares the underlying storage with `xs`
        // (no data copy); confirm against `Tensor::shallow_clone` docs.
        xs.shallow_clone()
    }
}
/// Lets a [`crate::CModule`] be used wherever a [`Module`] is expected by
/// calling its `forward_ts` method with a single input tensor.
///
/// # Panics
///
/// Panics if `forward_ts` returns an `Err`. [`Module::forward`] returns a bare
/// tensor, so the error cannot be propagated here. Callers that need to
/// recover from failures should call `forward_ts` directly and handle the
/// `Result` themselves.
impl Module for crate::CModule {
    fn forward(&self, xs: &crate::Tensor) -> crate::Tensor {
        // `expect` names the failing call site in the panic message.
        // Like `unwrap`, it also prints the underlying error.
        self.forward_ts(&[xs])
            .expect("CModule::forward: forward_ts returned an error")
    }
}
/// Lets a [`crate::TrainableCModule`] be used wherever a [`ModuleT`] is
/// expected by calling `forward_ts` on its `inner` module with a single input
/// tensor.
///
/// The `train` argument of [`ModuleT::forward_t`] is ignored by this impl.
///
/// # Panics
///
/// Panics if `inner.forward_ts` returns an `Err`. [`ModuleT::forward_t`]
/// returns a bare tensor, so the error cannot be propagated here.
impl ModuleT for crate::TrainableCModule {
    fn forward_t(&self, xs: &crate::Tensor, _train: bool) -> crate::Tensor {
        // NOTE(review): `_train` is not applied, so the inner module presumably
        // runs in whatever mode it was last switched to. Confirm whether callers
        // must toggle train/eval mode on the module explicitly.
        self.inner
            .forward_ts(&[xs])
            .expect("TrainableCModule::forward_t: forward_ts returned an error")
    }
}