tch_plus/nn/
sequential.rs1use super::{Module, ModuleT};
3use crate::Tensor;
4
5#[derive(Debug)]
7pub struct Sequential {
8 layers: Vec<Box<dyn Module>>,
9}
10
11pub fn seq() -> Sequential {
13 Sequential { layers: vec![] }
14}
15
16impl Sequential {
17 pub fn len(&self) -> i64 {
19 self.layers.len() as i64
20 }
21
22 pub fn is_empty(&self) -> bool {
24 self.layers.is_empty()
25 }
26}
27
28impl Module for Sequential {
29 fn forward(&self, xs: &Tensor) -> Tensor {
30 if self.layers.is_empty() {
31 xs.shallow_clone()
32 } else {
33 let xs = self.layers[0].forward(xs);
34 self.layers.iter().skip(1).fold(xs, |xs, layer| layer.forward(&xs))
35 }
36 }
37}
38
39impl Sequential {
40 #[allow(clippy::should_implement_trait)]
42 pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
43 self.layers.push(Box::new(layer));
44 self
45 }
46
47 pub fn add_fn<F>(self, f: F) -> Self
49 where
50 F: 'static + Fn(&Tensor) -> Tensor + Send,
51 {
52 self.add(super::func(f))
53 }
54
55 pub fn forward_all(&self, xs: &Tensor, n: Option<usize>) -> Vec<Tensor> {
57 if self.layers.is_empty() {
58 vec![xs.shallow_clone()]
59 } else {
60 let n = n.unwrap_or(self.layers.len());
61 let xs = self.layers[0].forward(xs);
62 let mut vec = vec![];
63 let out = self.layers.iter().take(n).skip(1).fold(xs, |xs, layer| {
64 let out = layer.forward(&xs);
65 vec.push(xs);
66 out
67 });
68 vec.push(out);
69 vec
70 }
71 }
72}
73
74#[derive(Debug)]
76pub struct SequentialT {
77 layers: Vec<Box<dyn ModuleT>>,
78}
79
80pub fn seq_t() -> SequentialT {
82 SequentialT { layers: vec![] }
83}
84
85impl SequentialT {
86 pub fn len(&self) -> i64 {
88 self.layers.len() as i64
89 }
90
91 pub fn is_empty(&self) -> bool {
93 self.layers.is_empty()
94 }
95}
96
97impl ModuleT for SequentialT {
98 fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
99 if self.layers.is_empty() {
100 xs.shallow_clone()
101 } else {
102 let xs = self.layers[0].forward_t(xs, train);
103 self.layers.iter().skip(1).fold(xs, |xs, layer| layer.forward_t(&xs, train))
104 }
105 }
106}
107
108impl SequentialT {
109 #[allow(clippy::should_implement_trait)]
111 pub fn add<M: ModuleT + 'static>(mut self, layer: M) -> Self {
112 self.layers.push(Box::new(layer));
113 self
114 }
115
116 pub fn add_fn<F>(self, f: F) -> Self
118 where
119 F: 'static + Fn(&Tensor) -> Tensor + Send,
120 {
121 self.add(super::func(f))
122 }
123
124 pub fn add_fn_t<F>(self, f: F) -> Self
126 where
127 F: 'static + Fn(&Tensor, bool) -> Tensor + Send,
128 {
129 self.add(super::func_t(f))
130 }
131
132 pub fn forward_all_t(&self, xs: &Tensor, train: bool, n: Option<usize>) -> Vec<Tensor> {
134 if self.layers.is_empty() {
135 vec![xs.shallow_clone()]
136 } else {
137 let n = n.unwrap_or(self.layers.len());
138 let xs = self.layers[0].forward_t(xs, train);
139 let mut vec = vec![];
140 let out = self.layers.iter().take(n).skip(1).fold(xs, |xs, layer| {
141 let out = layer.forward_t(&xs, train);
142 vec.push(xs);
143 out
144 });
145 vec.push(out);
146 vec
147 }
148 }
149}