torsh_nn/composition/
mod.rs1use torsh_core::device::DeviceType;
8use torsh_core::error::Result;
9use torsh_tensor::Tensor;
10
11#[cfg(feature = "std")]
13use std::collections::HashMap;
14
15#[cfg(not(feature = "std"))]
16use hashbrown::HashMap;
17
18use crate::{Module, ModuleConfig, Parameter};
19
/// Builder interface for constructing a module of type `T`.
///
/// Every setter takes the builder by value and returns `Self`, so calls can be
/// chained; [`ModuleBuilder::build`] consumes the builder to produce the module.
pub trait ModuleBuilder<T> {
    /// Consumes the builder and produces the configured module.
    ///
    /// # Errors
    /// Returns an error if construction fails; the failure conditions are defined
    /// by each implementor (not visible here).
    fn build(self) -> Result<T>;

    /// Sets the training-mode flag for the module being built.
    /// NOTE(review): presumably this is the mode the built module starts in — confirm
    /// against implementors.
    fn training(self, training: bool) -> Self;

    /// Sets the target device for the module being built.
    /// NOTE(review): whether parameters are allocated on this device at build time is
    /// up to the implementor.
    fn device(self, device: DeviceType) -> Self;

    /// Runs `config_fn` once with mutable access to a [`ModuleConfig`], allowing
    /// arbitrary configuration tweaks before building.
    /// NOTE(review): presumably the builder's own config is passed — verify in implementors.
    fn config<F>(self, config_fn: F) -> Self
    where
        F: FnOnce(&mut ModuleConfig);
}
36
/// Combinators for building larger modules out of existing [`Module`]s.
///
/// A blanket implementation in this file covers every `Module + 'static` type, so
/// these methods are available on any such module; each one takes `self` by value
/// and wraps it in a new composite module.
pub trait ModuleComposition {
    /// Chains `self` and `other` sequentially: the returned [`ComposedModule`] feeds
    /// the output of `self` into `other`.
    fn then<Other: Module + 'static>(self, other: Other) -> ComposedModule<Self, Other>
    where
        Self: Sized + 'static;

    /// Runs `self` and `other` on the same input via a [`ParallelModule`].
    ///
    /// NOTE(review): `ParallelModule::forward` currently returns only the output of
    /// `self`; the output of `other` is computed (its errors propagate) and discarded.
    fn parallel<Other: Module + 'static>(self, other: Other) -> ParallelModule<Self, Other>
    where
        Self: Sized + 'static;

    /// Wraps `self` in a [`ResidualModule`].
    ///
    /// NOTE(review): `ResidualModule::forward` currently returns the wrapped module's
    /// output unchanged — the input is not added back yet.
    fn residual(self) -> ResidualModule<Self>
    where
        Self: Sized + 'static;

    /// Wraps `self` in a [`ConditionalModule`] that calls `condition_fn` on every
    /// forward pass. When it returns `true` the wrapped module runs; otherwise the
    /// input is passed through (cloned) unchanged.
    fn conditional<F>(self, condition_fn: F) -> ConditionalModule<Self, F>
    where
        Self: Sized + 'static,
        F: Fn() -> bool + Send + Sync;
}
68
69pub struct ComposedModule<First, Second> {
71 first: First,
72 second: Second,
73}
74
75impl<First: Module, Second: Module> Module for ComposedModule<First, Second> {
76 fn forward(&self, input: &Tensor) -> Result<Tensor> {
77 let intermediate = self.first.forward(input)?;
78 self.second.forward(&intermediate)
79 }
80
81 fn parameters(&self) -> HashMap<String, Parameter> {
82 let mut params = self.first.parameters();
83 let second_params = self.second.parameters();
84 for (name, param) in second_params {
85 params.insert(format!("second.{}", name), param);
86 }
87 params
88 }
89
90 fn training(&self) -> bool {
91 self.first.training() && self.second.training()
92 }
93
94 fn set_training(&mut self, training: bool) {
95 self.first.set_training(training);
96 self.second.set_training(training);
97 }
98}
99
100pub struct ParallelModule<First, Second> {
102 first: First,
103 second: Second,
104}
105
106impl<First: Module, Second: Module> Module for ParallelModule<First, Second> {
107 fn forward(&self, input: &Tensor) -> Result<Tensor> {
108 let first_output = self.first.forward(input)?;
109 let _second_output = self.second.forward(input)?;
110 Ok(first_output)
113 }
114
115 fn parameters(&self) -> HashMap<String, Parameter> {
116 let mut params = HashMap::new();
117 for (name, param) in self.first.parameters() {
118 params.insert(format!("first.{name}"), param);
119 }
120 for (name, param) in self.second.parameters() {
121 params.insert(format!("second.{name}"), param);
122 }
123 params
124 }
125
126 fn training(&self) -> bool {
127 self.first.training() && self.second.training()
128 }
129
130 fn set_training(&mut self, training: bool) {
131 self.first.set_training(training);
132 self.second.set_training(training);
133 }
134}
135
136pub struct ResidualModule<M> {
138 module: M,
139}
140
141impl<M: Module> Module for ResidualModule<M> {
142 fn forward(&self, input: &Tensor) -> Result<Tensor> {
143 let output = self.module.forward(input)?;
144 Ok(output)
147 }
148
149 fn parameters(&self) -> HashMap<String, Parameter> {
150 self.module.parameters()
151 }
152
153 fn training(&self) -> bool {
154 self.module.training()
155 }
156
157 fn set_training(&mut self, training: bool) {
158 self.module.set_training(training);
159 }
160}
161
162pub struct ConditionalModule<M, F> {
164 module: M,
165 condition_fn: F,
166}
167
168impl<M: Module, F: Fn() -> bool + Send + Sync> Module for ConditionalModule<M, F> {
169 fn forward(&self, input: &Tensor) -> Result<Tensor> {
170 if (self.condition_fn)() {
171 self.module.forward(input)
172 } else {
173 Ok(input.clone())
174 }
175 }
176
177 fn parameters(&self) -> HashMap<String, Parameter> {
178 self.module.parameters()
179 }
180
181 fn training(&self) -> bool {
182 self.module.training()
183 }
184
185 fn set_training(&mut self, training: bool) {
186 self.module.set_training(training);
187 }
188}
189
190impl<T: Module + 'static> ModuleComposition for T {
192 fn then<Other: Module + 'static>(self, other: Other) -> ComposedModule<Self, Other> {
193 ComposedModule {
194 first: self,
195 second: other,
196 }
197 }
198
199 fn parallel<Other: Module + 'static>(self, other: Other) -> ParallelModule<Self, Other> {
200 ParallelModule {
201 first: self,
202 second: other,
203 }
204 }
205
206 fn residual(self) -> ResidualModule<Self> {
207 ResidualModule { module: self }
208 }
209
210 fn conditional<F>(self, condition_fn: F) -> ConditionalModule<Self, F>
211 where
212 F: Fn() -> bool + Send + Sync,
213 {
214 ConditionalModule {
215 module: self,
216 condition_fn,
217 }
218 }
219}