tch_plus/nn/
sequential.rs

//! A sequential layer used to chain multiple layers and closures.
use super::{Module, ModuleT};
use crate::Tensor;

/// A sequential layer combining multiple other layers.
#[derive(Debug)]
pub struct Sequential {
    layers: Vec<Box<dyn Module>>,
}

/// Creates a new empty sequential layer.
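///
/// # Example
///
/// A minimal usage sketch: chain a couple of closures and run a forward pass.
/// The crate-root re-exports (`nn`, `Tensor`, `Kind`, `Device`) are assumed to
/// mirror the upstream `tch` API.
///
/// ```no_run
/// use tch_plus::{nn, nn::Module, Device, Kind, Tensor};
///
/// let model = nn::seq().add_fn(|xs| xs.relu()).add_fn(|xs| xs.sigmoid());
/// let xs = Tensor::zeros(&[1, 4], (Kind::Float, Device::Cpu));
/// let ys = model.forward(&xs);
/// ```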
pub fn seq() -> Sequential {
    Sequential { layers: vec![] }
}

impl Sequential {
    /// The number of sub-layers embedded in this layer.
    pub fn len(&self) -> i64 {
        self.layers.len() as i64
    }

    /// Returns true if this layer does not have any sub-layer.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }
}

impl Module for Sequential {
    fn forward(&self, xs: &Tensor) -> Tensor {
        if self.layers.is_empty() {
            xs.shallow_clone()
        } else {
            let xs = self.layers[0].forward(xs);
            self.layers.iter().skip(1).fold(xs, |xs, layer| layer.forward(&xs))
        }
    }
}

impl Sequential {
    /// Appends a layer after all the current layers.
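    ///
    /// # Example
    ///
    /// A minimal sketch of stacking a trainable layer and a closure; it assumes
    /// `nn::VarStore` and `nn::linear` are available as in the upstream `tch`
    /// crate.
    ///
    /// ```no_run
    /// use tch_plus::{nn, Device};
    ///
    /// let vs = nn::VarStore::new(Device::Cpu);
    /// let net = nn::seq()
    ///     .add(nn::linear(vs.root() / "layer1", 784, 128, Default::default()))
    ///     .add_fn(|xs| xs.relu());
    /// ```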
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends a closure after all the current layers.
    pub fn add_fn<F>(self, f: F) -> Self
    where
        F: 'static + Fn(&Tensor) -> Tensor + Send,
    {
        self.add(super::func(f))
    }

    /// Applies the forward pass and returns the output of each of the first
    /// `n` layers, or of every layer when `n` is `None`.
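    ///
    /// # Example
    ///
    /// A minimal sketch collecting intermediate activations; the `Tensor`
    /// constructor and `Kind`/`Device` re-exports are assumed to follow the
    /// upstream `tch` API.
    ///
    /// ```no_run
    /// use tch_plus::{nn, Device, Kind, Tensor};
    ///
    /// let net = nn::seq().add_fn(|xs| xs.relu()).add_fn(|xs| xs.tanh());
    /// let xs = Tensor::zeros(&[1, 8], (Kind::Float, Device::Cpu));
    /// // Outputs of the first two layers: after relu, then after tanh.
    /// let activations = net.forward_all(&xs, Some(2));
    /// assert_eq!(activations.len(), 2);
    /// ```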
    pub fn forward_all(&self, xs: &Tensor, n: Option<usize>) -> Vec<Tensor> {
        if self.layers.is_empty() {
            vec![xs.shallow_clone()]
        } else {
            let n = n.unwrap_or(self.layers.len());
            let xs = self.layers[0].forward(xs);
            let mut vec = vec![];
            // Each fold step pushes the current input (the previous layer's
            // output) and threads the new output forward; the last output is
            // pushed after the fold, so `vec` ends up with one tensor per
            // applied layer.
            let out = self.layers.iter().take(n).skip(1).fold(xs, |xs, layer| {
                let out = layer.forward(&xs);
                vec.push(xs);
                out
            });
            vec.push(out);
            vec
        }
    }
}

/// A sequential layer combining multiple other layers, with support for a training mode.
#[derive(Debug)]
pub struct SequentialT {
    layers: Vec<Box<dyn ModuleT>>,
}

/// Creates a new empty sequential layer.
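///
/// # Example
///
/// A minimal sketch of a stack whose behaviour depends on the training flag;
/// `Tensor::dropout` and the `Kind`/`Device` re-exports are assumed to follow
/// the upstream `tch` API.
///
/// ```no_run
/// use tch_plus::{nn, nn::ModuleT, Device, Kind, Tensor};
///
/// let net = nn::seq_t()
///     .add_fn(|xs| xs.relu())
///     .add_fn_t(|xs, train| xs.dropout(0.5, train));
/// let xs = Tensor::zeros(&[1, 16], (Kind::Float, Device::Cpu));
/// let train_out = net.forward_t(&xs, true); // dropout active
/// let eval_out = net.forward_t(&xs, false); // dropout disabled
/// ```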
pub fn seq_t() -> SequentialT {
    SequentialT { layers: vec![] }
}

impl SequentialT {
    /// The number of sub-layers embedded in this layer.
    pub fn len(&self) -> i64 {
        self.layers.len() as i64
    }

    /// Returns true if this layer does not have any sub-layer.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }
}

impl ModuleT for SequentialT {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
        if self.layers.is_empty() {
            xs.shallow_clone()
        } else {
            let xs = self.layers[0].forward_t(xs, train);
            self.layers.iter().skip(1).fold(xs, |xs, layer| layer.forward_t(&xs, train))
        }
    }
}

impl SequentialT {
    /// Appends a layer after all the current layers.
    #[allow(clippy::should_implement_trait)]
    pub fn add<M: ModuleT + 'static>(mut self, layer: M) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends a closure after all the current layers.
    pub fn add_fn<F>(self, f: F) -> Self
    where
        F: 'static + Fn(&Tensor) -> Tensor + Send,
    {
        self.add(super::func(f))
    }

    /// Appends a closure taking a training-mode flag after all the current layers.
    pub fn add_fn_t<F>(self, f: F) -> Self
    where
        F: 'static + Fn(&Tensor, bool) -> Tensor + Send,
    {
        self.add(super::func_t(f))
    }

    /// Applies the forward pass and returns the output of each of the first
    /// `n` layers, or of every layer when `n` is `None`.
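    ///
    /// # Example
    ///
    /// A minimal sketch extracting intermediate activations in evaluation mode;
    /// the `Tensor` constructor and re-exports are assumed to follow the
    /// upstream `tch` API.
    ///
    /// ```no_run
    /// use tch_plus::{nn, Device, Kind, Tensor};
    ///
    /// let net = nn::seq_t()
    ///     .add_fn(|xs| xs.relu())
    ///     .add_fn_t(|xs, train| xs.dropout(0.3, train));
    /// let xs = Tensor::zeros(&[1, 8], (Kind::Float, Device::Cpu));
    /// // Outputs of the first two layers, computed with train = false.
    /// let feats = net.forward_all_t(&xs, false, Some(2));
    /// assert_eq!(feats.len(), 2);
    /// ```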
    pub fn forward_all_t(&self, xs: &Tensor, train: bool, n: Option<usize>) -> Vec<Tensor> {
        if self.layers.is_empty() {
            vec![xs.shallow_clone()]
        } else {
            let n = n.unwrap_or(self.layers.len());
            let xs = self.layers[0].forward_t(xs, train);
            let mut vec = vec![];
            let out = self.layers.iter().take(n).skip(1).fold(xs, |xs, layer| {
                let out = layer.forward_t(&xs, train);
                vec.push(xs);
                out
            });
            vec.push(out);
            vec
        }
    }
}