1use parking_lot::RwLock;
8use std::sync::Arc;
9use torsh_core::device::DeviceType;
10use torsh_core::error::Result;
11use torsh_tensor::Tensor;
12
13#[cfg(feature = "std")]
15use std::collections::HashMap;
16
17#[cfg(not(feature = "std"))]
18use hashbrown::HashMap;
19
20use crate::{HookCallback, HookHandle, HookRegistry, HookType, Parameter};
21
22pub struct ModuleBase {
24 training: bool,
25 device: DeviceType,
26 pub parameters: HashMap<String, Parameter>,
27 buffers: HashMap<String, Arc<RwLock<Tensor>>>,
28 modules: HashMap<String, Box<dyn crate::Module>>,
29 hook_registry: HookRegistry,
30}
31
32impl core::fmt::Debug for ModuleBase {
33 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34 f.debug_struct("ModuleBase")
35 .field("training", &self.training)
36 .field("device", &self.device)
37 .field("parameters_count", &self.parameters.len())
38 .field("buffers_count", &self.buffers.len())
39 .field("modules_count", &self.modules.len())
40 .field("hook_registry", &self.hook_registry)
41 .finish()
42 }
43}
44
45impl Default for ModuleBase {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl ModuleBase {
52 pub fn new() -> Self {
53 Self {
54 training: true,
55 device: DeviceType::Cpu,
56 parameters: HashMap::new(),
57 buffers: HashMap::new(),
58 modules: HashMap::new(),
59 hook_registry: HookRegistry::new(),
60 }
61 }
62
63 pub fn training(&self) -> bool {
65 self.training
66 }
67
68 pub fn set_training(&mut self, training: bool) {
70 self.training = training;
71 for module in self.modules.values_mut() {
72 module.set_training(training);
73 }
74 }
75
76 pub fn apply_to_parameters<F>(&mut self, f: F) -> Result<()>
78 where
79 F: Fn(&mut Parameter) -> Result<()>,
80 {
81 use crate::ModuleApply;
82 for param in self.parameters.values_mut() {
83 f(param)?;
84 }
85 for module in self.modules.values_mut() {
86 module.apply_to_parameters(&f)?;
87 }
88 Ok(())
89 }
90
91 pub fn apply_to_modules<F>(&mut self, f: F) -> Result<()>
93 where
94 F: Fn(&mut dyn crate::Module) -> Result<()>,
95 {
96 use crate::ModuleApply;
97 for module in self.modules.values_mut() {
98 f(module.as_mut())?;
99 module.apply_to_modules(&f)?;
100 }
101 Ok(())
102 }
103
104 pub fn children(&self) -> Vec<&dyn crate::Module> {
106 self.modules.values().map(|m| m.as_ref()).collect()
107 }
108
109 pub fn named_children(&self) -> Vec<(String, &dyn crate::Module)> {
111 self.modules
112 .iter()
113 .map(|(name, module)| (name.clone(), module.as_ref()))
114 .collect()
115 }
116
117 pub fn named_parameters(&self) -> HashMap<String, Parameter> {
119 self.parameters.clone()
120 }
121
122 pub fn to_device(&mut self, device: DeviceType) -> Result<()> {
124 self.device = device;
125 for module in self.modules.values_mut() {
127 module.to_device(device)?;
128 }
129 Ok(())
130 }
131
132 pub fn register_parameter(&mut self, name: String, param: Parameter) {
134 self.parameters.insert(name, param);
135 }
136
137 pub fn register_buffer(&mut self, name: String, tensor: Tensor) {
139 self.buffers.insert(name, Arc::new(RwLock::new(tensor)));
140 }
141
142 pub fn register_module(&mut self, name: String, module: Box<dyn crate::Module>) {
144 self.modules.insert(name, module);
145 }
146
147 pub fn all_parameter_tensors(&self) -> Vec<Arc<RwLock<Tensor>>> {
149 let mut params: Vec<_> = self.parameters.values().map(|p| p.tensor()).collect();
150
151 for module in self.modules.values() {
152 let module_params = module.parameters();
153 for param in module_params.values() {
154 params.push(param.tensor());
155 }
156 }
157
158 params
159 }
160
161 pub fn get_all_named_parameters(&self) -> HashMap<String, Parameter> {
163 let mut all_params = HashMap::new();
164
165 for (name, param) in &self.parameters {
167 all_params.insert(name.clone(), param.clone());
168 }
169
170 for (module_name, module) in &self.modules {
172 for (param_name, param) in module.all_named_parameters() {
173 let full_name = if param_name.is_empty() {
174 module_name.clone()
175 } else {
176 format!("{module_name}.{param_name}")
177 };
178 all_params.insert(full_name, param);
179 }
180 }
181
182 all_params
183 }
184
185 pub fn all_named_parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
187 let mut params = HashMap::new();
188
189 for (name, param) in &self.parameters {
190 params.insert(name.clone(), param.tensor());
191 }
192
193 for (module_name, module) in &self.modules {
194 for (param_name, param) in module.named_parameters() {
195 params.insert(format!("{module_name}.{param_name}"), param.tensor());
196 }
197 }
198
199 params
200 }
201
202 pub fn register_hook(&mut self, hook_type: HookType, callback: HookCallback) -> HookHandle {
204 self.hook_registry.register_hook(hook_type, callback)
205 }
206
207 pub fn remove_hook(&mut self, hook_type: HookType, handle: HookHandle) -> bool {
209 self.hook_registry.remove_hook(hook_type, handle)
210 }
211
212 pub fn execute_hooks(
214 &self,
215 hook_type: HookType,
216 module: &dyn crate::Module,
217 input: &Tensor,
218 output: Option<&Tensor>,
219 ) -> Result<()> {
220 self.hook_registry
221 .execute_hooks(hook_type, module, input, output)
222 }
223
224 pub fn has_hooks(&self, hook_type: HookType) -> bool {
226 self.hook_registry.has_hooks(hook_type)
227 }
228
229 pub fn hook_count(&self, hook_type: HookType) -> usize {
231 self.hook_registry.hook_count(hook_type)
232 }
233
234 pub fn clear_hooks(&mut self, hook_type: HookType) {
236 self.hook_registry.clear_hooks(hook_type)
237 }
238
239 pub fn clear_all_hooks(&mut self) {
241 self.hook_registry.clear_all_hooks()
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::zeros;

    /// Hook callback that ignores its arguments and always succeeds.
    fn noop_hook() -> HookCallback {
        Box::new(|_module, _input, _output| Ok(()))
    }

    /// A new base trains on the CPU and owns nothing yet.
    #[test]
    fn test_module_base_creation() {
        let fresh = ModuleBase::new();
        assert!(fresh.training());
        assert_eq!(fresh.device, DeviceType::Cpu);
        assert!(fresh.parameters.is_empty());
        assert!(fresh.buffers.is_empty());
        assert!(fresh.modules.is_empty());
    }

    /// `Default` yields the same starting state as `new`.
    #[test]
    fn test_module_base_default() {
        let fresh = ModuleBase::default();
        assert!(fresh.training());
        assert_eq!(fresh.device, DeviceType::Cpu);
    }

    /// The training flag can be switched off and back on.
    #[test]
    fn test_training_mode() {
        let mut subject = ModuleBase::new();
        assert!(subject.training());

        subject.set_training(false);
        assert!(!subject.training());

        subject.set_training(true);
        assert!(subject.training());
    }

    /// A single registered parameter is stored under its name.
    #[test]
    fn test_register_parameter() {
        let mut subject = ModuleBase::new();
        let weight = Parameter::new(zeros(&[3, 4]).unwrap());

        subject.register_parameter("weight".to_string(), weight);

        assert_eq!(subject.parameters.len(), 1);
        assert!(subject.parameters.contains_key("weight"));
    }

    /// Parameters registered under distinct names coexist.
    #[test]
    fn test_register_multiple_parameters() {
        let mut subject = ModuleBase::new();

        subject.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[10, 5]).unwrap()),
        );
        subject.register_parameter("bias".to_string(), Parameter::new(zeros(&[5]).unwrap()));

        assert_eq!(subject.parameters.len(), 2);
        for expected in ["weight", "bias"] {
            assert!(subject.parameters.contains_key(expected));
        }
    }

    /// A registered buffer is stored under its name.
    #[test]
    fn test_register_buffer() {
        let mut subject = ModuleBase::new();

        subject.register_buffer("running_mean".to_string(), zeros(&[10]).unwrap());

        assert_eq!(subject.buffers.len(), 1);
        assert!(subject.buffers.contains_key("running_mean"));
    }

    /// `named_parameters` reports directly registered parameters.
    #[test]
    fn test_named_parameters() {
        let mut subject = ModuleBase::new();
        subject.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[3, 4]).unwrap()),
        );

        let by_name = subject.named_parameters();

        assert_eq!(by_name.len(), 1);
        assert!(by_name.contains_key("weight"));
    }

    /// A base without registered modules has no children.
    #[test]
    fn test_children_empty() {
        let subject = ModuleBase::new();
        assert!(subject.children().is_empty());
    }

    /// A base without registered modules has no named children.
    #[test]
    fn test_named_children_empty() {
        let subject = ModuleBase::new();
        assert!(subject.named_children().is_empty());
    }

    /// Moving to the CPU succeeds and records the CPU device.
    #[test]
    fn test_to_device_cpu() -> Result<()> {
        let mut subject = ModuleBase::new();

        subject.to_device(DeviceType::Cpu)?;

        assert_eq!(subject.device, DeviceType::Cpu);
        Ok(())
    }

    /// Every directly registered parameter contributes one tensor handle.
    #[test]
    fn test_all_parameter_tensors() {
        let mut subject = ModuleBase::new();
        subject.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[2, 3]).unwrap()),
        );
        subject.register_parameter("bias".to_string(), Parameter::new(zeros(&[4]).unwrap()));

        let tensors = subject.all_parameter_tensors();

        assert_eq!(tensors.len(), 2);
    }

    /// `all_named_parameters` includes directly registered parameters.
    #[test]
    fn test_all_named_parameters() {
        let mut subject = ModuleBase::new();
        subject.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[3, 4]).unwrap()),
        );

        let by_name = subject.all_named_parameters();

        assert_eq!(by_name.len(), 1);
    }

    /// A hook can be registered, observed, and removed through its handle.
    #[test]
    fn test_hook_registration() {
        let mut subject = ModuleBase::new();

        let handle = subject.register_hook(HookType::PreForward, noop_hook());
        assert!(subject.has_hooks(HookType::PreForward));
        assert_eq!(subject.hook_count(HookType::PreForward), 1);

        let was_removed = subject.remove_hook(HookType::PreForward, handle);
        assert!(was_removed);
        assert!(!subject.has_hooks(HookType::PreForward));
    }

    /// Two hooks of the same type are both counted.
    #[test]
    fn test_hook_multiple_registration() {
        let mut subject = ModuleBase::new();

        subject.register_hook(HookType::PreForward, noop_hook());
        subject.register_hook(HookType::PreForward, noop_hook());

        assert_eq!(subject.hook_count(HookType::PreForward), 2);
    }

    /// Clearing one hook type leaves the other types untouched.
    #[test]
    fn test_clear_hooks() {
        let mut subject = ModuleBase::new();
        subject.register_hook(HookType::PreForward, noop_hook());
        subject.register_hook(HookType::PreBackward, noop_hook());

        assert!(subject.has_hooks(HookType::PreForward));
        assert!(subject.has_hooks(HookType::PreBackward));

        subject.clear_hooks(HookType::PreForward);

        assert!(!subject.has_hooks(HookType::PreForward));
        assert!(subject.has_hooks(HookType::PreBackward));
    }

    /// Clearing all hooks removes every registered type.
    #[test]
    fn test_clear_all_hooks() {
        let mut subject = ModuleBase::new();
        subject.register_hook(HookType::PreForward, noop_hook());
        subject.register_hook(HookType::PreBackward, noop_hook());

        assert!(subject.has_hooks(HookType::PreForward));
        assert!(subject.has_hooks(HookType::PreBackward));

        subject.clear_all_hooks();

        assert!(!subject.has_hooks(HookType::PreForward));
        assert!(!subject.has_hooks(HookType::PreBackward));
    }

    /// A new base reports zero hooks for the checked types.
    #[test]
    fn test_hook_count_zero() {
        let subject = ModuleBase::new();
        assert_eq!(subject.hook_count(HookType::PreForward), 0);
        assert_eq!(subject.hook_count(HookType::PreBackward), 0);
    }

    /// The `Debug` output names the struct and its summary fields.
    #[test]
    fn test_debug_format() {
        let mut subject = ModuleBase::new();
        subject.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[2, 3]).unwrap()),
        );

        let rendered = format!("{:?}", subject);

        assert!(rendered.contains("ModuleBase"));
        assert!(rendered.contains("training"));
        assert!(rendered.contains("parameters_count"));
    }

    /// Re-registering a parameter name replaces the earlier parameter.
    #[test]
    fn test_parameter_replacement() {
        let mut subject = ModuleBase::new();

        let first = Parameter::new(zeros(&[2, 3]).unwrap());
        subject.register_parameter("weight".to_string(), first);
        assert_eq!(subject.parameters.len(), 1);

        let second = Parameter::new(zeros(&[4, 5]).unwrap());
        subject.register_parameter("weight".to_string(), second);
        assert_eq!(subject.parameters.len(), 1);

        // The surviving entry must carry the shape of the second parameter.
        let handle = subject.parameters["weight"].tensor();
        let stored = handle.read();
        assert_eq!(stored.shape().dims(), &[4, 5]);
    }

    /// Re-registering a buffer name replaces the earlier buffer.
    #[test]
    fn test_buffer_replacement() {
        let mut subject = ModuleBase::new();

        subject.register_buffer("running_mean".to_string(), zeros(&[10]).unwrap());
        assert_eq!(subject.buffers.len(), 1);

        subject.register_buffer("running_mean".to_string(), zeros(&[20]).unwrap());
        assert_eq!(subject.buffers.len(), 1);

        // The surviving entry must carry the shape of the second buffer.
        let stored = subject.buffers["running_mean"].read();
        assert_eq!(stored.shape().dims(), &[20]);
    }

    /// An empty base yields an empty tensor map.
    #[test]
    fn test_empty_base_all_named_parameters() {
        let subject = ModuleBase::new();
        assert!(subject.all_named_parameters().is_empty());
    }

    /// An empty base yields an empty parameter map.
    #[test]
    fn test_empty_base_get_all_named_parameters() {
        let subject = ModuleBase::new();
        assert!(subject.get_all_named_parameters().is_empty());
    }
}