Skip to main content

torsh_nn/base/
mod.rs

1//! Module base infrastructure for neural network implementations
2//!
3//! This module provides the ModuleBase helper struct that serves as a foundation
4//! for implementing neural network modules with integrated parameter management,
5//! hook system, and training state.
6
7use parking_lot::RwLock;
8use std::sync::Arc;
9use torsh_core::device::DeviceType;
10use torsh_core::error::Result;
11use torsh_tensor::Tensor;
12
13// Conditional imports for std/no_std compatibility
14#[cfg(feature = "std")]
15use std::collections::HashMap;
16
17#[cfg(not(feature = "std"))]
18use hashbrown::HashMap;
19
20use crate::{HookCallback, HookHandle, HookRegistry, HookType, Parameter};
21
22/// Base module implementation helper
23pub struct ModuleBase {
24    training: bool,
25    device: DeviceType,
26    pub parameters: HashMap<String, Parameter>,
27    buffers: HashMap<String, Arc<RwLock<Tensor>>>,
28    modules: HashMap<String, Box<dyn crate::Module>>,
29    hook_registry: HookRegistry,
30}
31
32impl core::fmt::Debug for ModuleBase {
33    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34        f.debug_struct("ModuleBase")
35            .field("training", &self.training)
36            .field("device", &self.device)
37            .field("parameters_count", &self.parameters.len())
38            .field("buffers_count", &self.buffers.len())
39            .field("modules_count", &self.modules.len())
40            .field("hook_registry", &self.hook_registry)
41            .finish()
42    }
43}
44
45impl Default for ModuleBase {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl ModuleBase {
52    pub fn new() -> Self {
53        Self {
54            training: true,
55            device: DeviceType::Cpu,
56            parameters: HashMap::new(),
57            buffers: HashMap::new(),
58            modules: HashMap::new(),
59            hook_registry: HookRegistry::new(),
60        }
61    }
62
63    /// Check if in training mode
64    pub fn training(&self) -> bool {
65        self.training
66    }
67
68    /// Set training mode
69    pub fn set_training(&mut self, training: bool) {
70        self.training = training;
71        for module in self.modules.values_mut() {
72            module.set_training(training);
73        }
74    }
75
76    /// Apply function to all parameters in this module
77    pub fn apply_to_parameters<F>(&mut self, f: F) -> Result<()>
78    where
79        F: Fn(&mut Parameter) -> Result<()>,
80    {
81        use crate::ModuleApply;
82        for param in self.parameters.values_mut() {
83            f(param)?;
84        }
85        for module in self.modules.values_mut() {
86            module.apply_to_parameters(&f)?;
87        }
88        Ok(())
89    }
90
91    /// Apply function to all modules recursively
92    pub fn apply_to_modules<F>(&mut self, f: F) -> Result<()>
93    where
94        F: Fn(&mut dyn crate::Module) -> Result<()>,
95    {
96        use crate::ModuleApply;
97        for module in self.modules.values_mut() {
98            f(module.as_mut())?;
99            module.apply_to_modules(&f)?;
100        }
101        Ok(())
102    }
103
104    /// Get children modules as references
105    pub fn children(&self) -> Vec<&dyn crate::Module> {
106        self.modules.values().map(|m| m.as_ref()).collect()
107    }
108
109    /// Get named children modules
110    pub fn named_children(&self) -> Vec<(String, &dyn crate::Module)> {
111        self.modules
112            .iter()
113            .map(|(name, module)| (name.clone(), module.as_ref()))
114            .collect()
115    }
116
117    /// Get named parameters
118    pub fn named_parameters(&self) -> HashMap<String, Parameter> {
119        self.parameters.clone()
120    }
121
122    /// Move to device
123    pub fn to_device(&mut self, device: DeviceType) -> Result<()> {
124        self.device = device;
125        // In a full implementation, would move all parameters to device
126        for module in self.modules.values_mut() {
127            module.to_device(device)?;
128        }
129        Ok(())
130    }
131
132    /// Register a parameter
133    pub fn register_parameter(&mut self, name: String, param: Parameter) {
134        self.parameters.insert(name, param);
135    }
136
137    /// Register a buffer
138    pub fn register_buffer(&mut self, name: String, tensor: Tensor) {
139        self.buffers.insert(name, Arc::new(RwLock::new(tensor)));
140    }
141
142    /// Register a submodule
143    pub fn register_module(&mut self, name: String, module: Box<dyn crate::Module>) {
144        self.modules.insert(name, module);
145    }
146
147    /// Get all parameters including submodules (legacy method)
148    pub fn all_parameter_tensors(&self) -> Vec<Arc<RwLock<Tensor>>> {
149        let mut params: Vec<_> = self.parameters.values().map(|p| p.tensor()).collect();
150
151        for module in self.modules.values() {
152            let module_params = module.parameters();
153            for param in module_params.values() {
154                params.push(param.tensor());
155            }
156        }
157
158        params
159    }
160
161    /// Get all parameters with module hierarchical names
162    pub fn get_all_named_parameters(&self) -> HashMap<String, Parameter> {
163        let mut all_params = HashMap::new();
164
165        // Add own parameters
166        for (name, param) in &self.parameters {
167            all_params.insert(name.clone(), param.clone());
168        }
169
170        // Add child parameters with prefixes
171        for (module_name, module) in &self.modules {
172            for (param_name, param) in module.all_named_parameters() {
173                let full_name = if param_name.is_empty() {
174                    module_name.clone()
175                } else {
176                    format!("{module_name}.{param_name}")
177                };
178                all_params.insert(full_name, param);
179            }
180        }
181
182        all_params
183    }
184
185    /// Get all named parameters including submodules
186    pub fn all_named_parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
187        let mut params = HashMap::new();
188
189        for (name, param) in &self.parameters {
190            params.insert(name.clone(), param.tensor());
191        }
192
193        for (module_name, module) in &self.modules {
194            for (param_name, param) in module.named_parameters() {
195                params.insert(format!("{module_name}.{param_name}"), param.tensor());
196            }
197        }
198
199        params
200    }
201
202    /// Register a hook for this module
203    pub fn register_hook(&mut self, hook_type: HookType, callback: HookCallback) -> HookHandle {
204        self.hook_registry.register_hook(hook_type, callback)
205    }
206
207    /// Remove a hook by handle
208    pub fn remove_hook(&mut self, hook_type: HookType, handle: HookHandle) -> bool {
209        self.hook_registry.remove_hook(hook_type, handle)
210    }
211
212    /// Execute hooks of a specific type
213    pub fn execute_hooks(
214        &self,
215        hook_type: HookType,
216        module: &dyn crate::Module,
217        input: &Tensor,
218        output: Option<&Tensor>,
219    ) -> Result<()> {
220        self.hook_registry
221            .execute_hooks(hook_type, module, input, output)
222    }
223
224    /// Check if any hooks are registered
225    pub fn has_hooks(&self, hook_type: HookType) -> bool {
226        self.hook_registry.has_hooks(hook_type)
227    }
228
229    /// Get hook count for a specific type
230    pub fn hook_count(&self, hook_type: HookType) -> usize {
231        self.hook_registry.hook_count(hook_type)
232    }
233
234    /// Clear all hooks of a specific type
235    pub fn clear_hooks(&mut self, hook_type: HookType) {
236        self.hook_registry.clear_hooks(hook_type)
237    }
238
239    /// Clear all hooks
240    pub fn clear_all_hooks(&mut self) {
241        self.hook_registry.clear_all_hooks()
242    }
243}
244
245// =============================================================================
246// TESTS
247// =============================================================================
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HookType;
    use torsh_tensor::creation::zeros;

    /// Builds a hook callback that accepts any arguments and succeeds.
    fn noop_hook() -> HookCallback {
        Box::new(|_module, _input, _output| Ok(()))
    }

    // ---- construction -----------------------------------------------------

    #[test]
    fn test_module_base_creation() {
        let fresh = ModuleBase::new();
        assert!(fresh.training());
        assert_eq!(fresh.device, DeviceType::Cpu);
        assert!(fresh.parameters.is_empty());
        assert!(fresh.buffers.is_empty());
        assert!(fresh.modules.is_empty());
    }

    #[test]
    fn test_module_base_default() {
        let fresh = ModuleBase::default();
        assert!(fresh.training());
        assert_eq!(fresh.device, DeviceType::Cpu);
    }

    // ---- training flag ----------------------------------------------------

    #[test]
    fn test_training_mode() {
        let mut base = ModuleBase::new();
        assert!(base.training());

        // Toggle off and back on; the getter must track each change.
        base.set_training(false);
        assert!(!base.training());

        base.set_training(true);
        assert!(base.training());
    }

    // ---- registration -----------------------------------------------------

    #[test]
    fn test_register_parameter() {
        let mut base = ModuleBase::new();
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[3, 4]).unwrap()),
        );

        assert_eq!(base.parameters.len(), 1);
        assert!(base.parameters.contains_key("weight"));
    }

    #[test]
    fn test_register_multiple_parameters() {
        let mut base = ModuleBase::new();

        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[10, 5]).unwrap()),
        );
        base.register_parameter("bias".to_string(), Parameter::new(zeros(&[5]).unwrap()));

        assert_eq!(base.parameters.len(), 2);
        for expected in ["weight", "bias"] {
            assert!(base.parameters.contains_key(expected));
        }
    }

    #[test]
    fn test_register_buffer() {
        let mut base = ModuleBase::new();
        base.register_buffer("running_mean".to_string(), zeros(&[10]).unwrap());

        assert_eq!(base.buffers.len(), 1);
        assert!(base.buffers.contains_key("running_mean"));
    }

    // ---- accessors --------------------------------------------------------

    #[test]
    fn test_named_parameters() {
        let mut base = ModuleBase::new();
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[3, 4]).unwrap()),
        );

        let snapshot = base.named_parameters();
        assert_eq!(snapshot.len(), 1);
        assert!(snapshot.contains_key("weight"));
    }

    #[test]
    fn test_children_empty() {
        let base = ModuleBase::new();
        assert!(base.children().is_empty());
    }

    #[test]
    fn test_named_children_empty() {
        let base = ModuleBase::new();
        assert!(base.named_children().is_empty());
    }

    #[test]
    fn test_to_device_cpu() -> Result<()> {
        let mut base = ModuleBase::new();
        base.to_device(DeviceType::Cpu)?;
        assert_eq!(base.device, DeviceType::Cpu);
        Ok(())
    }

    #[test]
    fn test_all_parameter_tensors() {
        let mut base = ModuleBase::new();
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[2, 3]).unwrap()),
        );
        base.register_parameter("bias".to_string(), Parameter::new(zeros(&[4]).unwrap()));

        let tensors = base.all_parameter_tensors();
        assert_eq!(tensors.len(), 2);
    }

    #[test]
    fn test_all_named_parameters() {
        let mut base = ModuleBase::new();
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[3, 4]).unwrap()),
        );

        let named = base.all_named_parameters();
        assert_eq!(named.len(), 1);
    }

    // ---- hooks ------------------------------------------------------------

    #[test]
    fn test_hook_registration() {
        let mut base = ModuleBase::new();

        let handle = base.register_hook(HookType::PreForward, noop_hook());
        assert!(base.has_hooks(HookType::PreForward));
        assert_eq!(base.hook_count(HookType::PreForward), 1);

        // Removing through the returned handle must succeed and empty the slot.
        assert!(base.remove_hook(HookType::PreForward, handle));
        assert!(!base.has_hooks(HookType::PreForward));
    }

    #[test]
    fn test_hook_multiple_registration() {
        let mut base = ModuleBase::new();

        base.register_hook(HookType::PreForward, noop_hook());
        base.register_hook(HookType::PreForward, noop_hook());

        assert_eq!(base.hook_count(HookType::PreForward), 2);
    }

    #[test]
    fn test_clear_hooks() {
        let mut base = ModuleBase::new();

        base.register_hook(HookType::PreForward, noop_hook());
        base.register_hook(HookType::PreBackward, noop_hook());

        assert!(base.has_hooks(HookType::PreForward));
        assert!(base.has_hooks(HookType::PreBackward));

        // Clearing one hook type must leave the other untouched.
        base.clear_hooks(HookType::PreForward);
        assert!(!base.has_hooks(HookType::PreForward));
        assert!(base.has_hooks(HookType::PreBackward));
    }

    #[test]
    fn test_clear_all_hooks() {
        let mut base = ModuleBase::new();

        base.register_hook(HookType::PreForward, noop_hook());
        base.register_hook(HookType::PreBackward, noop_hook());

        assert!(base.has_hooks(HookType::PreForward));
        assert!(base.has_hooks(HookType::PreBackward));

        base.clear_all_hooks();
        assert!(!base.has_hooks(HookType::PreForward));
        assert!(!base.has_hooks(HookType::PreBackward));
    }

    #[test]
    fn test_hook_count_zero() {
        let base = ModuleBase::new();
        assert_eq!(base.hook_count(HookType::PreForward), 0);
        assert_eq!(base.hook_count(HookType::PreBackward), 0);
    }

    // ---- formatting -------------------------------------------------------

    #[test]
    fn test_debug_format() {
        let mut base = ModuleBase::new();
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[2, 3]).unwrap()),
        );

        let rendered = format!("{base:?}");
        for needle in ["ModuleBase", "training", "parameters_count"] {
            assert!(rendered.contains(needle));
        }
    }

    // ---- replacement semantics --------------------------------------------

    #[test]
    fn test_parameter_replacement() {
        let mut base = ModuleBase::new();

        // First registration under "weight".
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[2, 3]).unwrap()),
        );
        assert_eq!(base.parameters.len(), 1);

        // Registering again under the same name must overwrite, not add.
        base.register_parameter(
            "weight".to_string(),
            Parameter::new(zeros(&[4, 5]).unwrap()),
        );
        assert_eq!(base.parameters.len(), 1);

        // The stored tensor must be the second one.
        let handle = base
            .parameters
            .get("weight")
            .expect("weight was registered above")
            .tensor();
        let stored = handle.read();
        assert_eq!(stored.shape().dims(), &[4, 5]);
    }

    #[test]
    fn test_buffer_replacement() {
        let mut base = ModuleBase::new();

        // First registration under "running_mean".
        base.register_buffer("running_mean".to_string(), zeros(&[10]).unwrap());
        assert_eq!(base.buffers.len(), 1);

        // Registering again under the same name must overwrite, not add.
        base.register_buffer("running_mean".to_string(), zeros(&[20]).unwrap());
        assert_eq!(base.buffers.len(), 1);

        // The stored tensor must be the second one.
        let stored = base
            .buffers
            .get("running_mean")
            .expect("running_mean was registered above")
            .read();
        assert_eq!(stored.shape().dims(), &[20]);
    }

    // ---- empty-base behaviour ---------------------------------------------

    #[test]
    fn test_empty_base_all_named_parameters() {
        let base = ModuleBase::new();
        assert!(base.all_named_parameters().is_empty());
    }

    #[test]
    fn test_empty_base_get_all_named_parameters() {
        let base = ModuleBase::new();
        assert!(base.get_all_named_parameters().is_empty());
    }
}
512        let base = ModuleBase::new();
513        let all_named = base.get_all_named_parameters();
514        assert_eq!(all_named.len(), 0);
515    }
516}