// tch_plus/nn/mod.rs

1//! A small neural-network library based on Torch.
2//!
3//! This library tries to stay as close as possible to the original
4//! Python and C++ implementations.
5pub mod init;
6pub use init::{f_init, init, Init};
7
8mod var_store;
9pub use var_store::{Path, VarStore, Variables};
10
11mod module;
12pub use module::{Module, ModuleT};
13
14mod linear;
15pub use linear::*;
16
17mod conv;
18pub use conv::*;
19
20mod conv_transpose;
21pub use conv_transpose::*;
22
23mod batch_norm;
24pub use batch_norm::*;
25
26mod group_norm;
27pub use group_norm::*;
28
29mod layer_norm;
30pub use layer_norm::*;
31
32mod sparse;
33pub use sparse::*;
34
35mod rnn;
36pub use rnn::*;
37
38mod func;
39pub use func::*;
40
41mod sequential;
42pub use sequential::*;
43
44mod optimizer;
45pub use optimizer::{
46    adam, adamw, rms_prop, sgd, Adam, AdamW, Optimizer, OptimizerConfig, RmsProp, Sgd,
47};
48
49/// An identity layer. This just propagates its tensor input as output.
50#[derive(Debug)]
51pub struct Id();
52
53impl ModuleT for Id {
54    fn forward_t(&self, xs: &crate::Tensor, _train: bool) -> crate::Tensor {
55        xs.shallow_clone()
56    }
57}
58
59impl Module for crate::CModule {
60    fn forward(&self, xs: &crate::Tensor) -> crate::Tensor {
61        self.forward_ts(&[xs]).unwrap()
62    }
63}
64
65impl ModuleT for crate::TrainableCModule {
66    fn forward_t(&self, xs: &crate::Tensor, _train: bool) -> crate::Tensor {
67        self.inner.forward_ts(&[xs]).unwrap()
68    }
69}