// torsh_nn/core/mod.rs
//! Core module trait system for neural network modules
//!
//! This module provides the foundational Module trait and essential interfaces
//! for all neural network components in ToRSh.

pub mod module_ext;

pub use module_ext::{ModuleExt, ParameterStats, ValidationReport};

use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::Tensor;

// Conditional imports for std/no_std compatibility
#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
20
21/// Base trait for all neural network modules
22///
23/// This trait provides the core interface for all neural network components,
24/// following PyTorch-compatible patterns for maximum interoperability.
25///
26/// # Design Philosophy
27///
28/// This trait is designed for maximum ergonomics while maintaining flexibility:
29/// - Most methods have sensible defaults to reduce boilerplate
30/// - Core functionality (forward, training mode) is required
31/// - Parameter management is streamlined with helper methods
32/// - Hook system is optional but well-integrated
33///
34/// # Required Methods
35///
36/// Only [`forward()`](Module::forward) must be implemented. All other methods have sensible defaults.
37///
38/// # Implementing a Custom Module
39///
40/// ## Basic Module (No Parameters)
41///
42/// ```rust
43/// use torsh_nn::{Module, ModuleBase};
44/// use torsh_tensor::Tensor;
45/// use torsh_core::error::Result;
46///
47/// /// Simple ReLU activation function
48/// struct MyReLU {
49///     base: ModuleBase,
50/// }
51///
52/// impl MyReLU {
53///     fn new() -> Self {
54///         Self {
55///             base: ModuleBase::new(),
56///         }
57///     }
58/// }
59///
60/// impl Module for MyReLU {
61///     fn forward(&self, input: &Tensor) -> Result<Tensor> {
62///         // Apply ReLU: max(0, x)
63///         input.relu()
64///     }
65///
66///     fn training(&self) -> bool {
67///         self.base.training()
68///     }
69///
70///     fn set_training(&mut self, training: bool) {
71///         self.base.set_training(training);
72///     }
73/// }
74/// ```
75///
76/// ## Module with Parameters
77///
78/// ```rust
79/// use torsh_nn::{Module, ModuleBase, Parameter};
80/// use torsh_tensor::{Tensor, creation};
81/// use torsh_core::error::Result;
82/// use std::collections::HashMap;
83///
84/// /// Custom linear layer with learnable weight and bias
85/// struct MyLinear {
86///     base: ModuleBase,
87///     in_features: usize,
88///     out_features: usize,
89/// }
90///
91/// impl MyLinear {
92///     fn new(in_features: usize, out_features: usize) -> Result<Self> {
93///         let mut base = ModuleBase::new();
94///
95///         // Initialize weight: [in_features, out_features]
96///         let weight = creation::randn(&[in_features, out_features])?;
97///         base.register_parameter("weight".to_string(), Parameter::new(weight));
98///
99///         // Initialize bias: [out_features]
100///         let bias = creation::zeros(&[out_features])?;
101///         base.register_parameter("bias".to_string(), Parameter::new(bias));
102///
103///         Ok(Self {
104///             base,
105///             in_features,
106///             out_features,
107///         })
108///     }
109/// }
110///
111/// impl Module for MyLinear {
112///     fn forward(&self, input: &Tensor) -> Result<Tensor> {
113///         // Get parameters
114///         let weight = self.base.parameters["weight"].tensor().read().clone();
115///         let bias = self.base.parameters["bias"].tensor().read().clone();
116///
117///         // Compute: input @ weight + bias
118///         let output = input.matmul(&weight)?;
119///         output.add(&bias)
120///     }
121///
122///     fn parameters(&self) -> HashMap<String, Parameter> {
123///         self.base.parameters.clone()
124///     }
125///
126///     fn named_parameters(&self) -> HashMap<String, Parameter> {
127///         self.base.named_parameters()
128///     }
129///
130///     fn training(&self) -> bool {
131///         self.base.training()
132///     }
133///
134///     fn set_training(&mut self, training: bool) {
135///         self.base.set_training(training);
136///     }
137/// }
138/// ```
139///
140/// ## Module with Training/Evaluation Modes
141///
142/// ```rust
143/// use torsh_nn::{Module, ModuleBase};
144/// use torsh_tensor::{Tensor, creation};
145/// use torsh_core::error::Result;
146///
147/// /// Dropout layer with different behavior in train/eval modes
148/// struct MyDropout {
149///     base: ModuleBase,
150///     p: f32,  // Dropout probability
151/// }
152///
153/// impl MyDropout {
154///     fn new(p: f32) -> Self {
155///         Self {
156///             base: ModuleBase::new(),
157///             p,
158///         }
159///     }
160/// }
161///
162/// impl Module for MyDropout {
163///     fn forward(&self, input: &Tensor) -> Result<Tensor> {
164///         if self.training() {
165///             // During training: randomly zero out units
166///             let rand_vals = creation::rand_like(input)?;
167///             let threshold = creation::full_like(input, self.p)?;
168///             // Keep elements where random value >= p
169///             let keep_mask = rand_vals.ge(&threshold)?;
170///             let zeros = creation::zeros_like(input)?;
171///             let masked = input.where_tensor(&keep_mask, &zeros)?;
172///             // Scale by 1/(1-p) to maintain expected value (inverted dropout)
173///             masked.div_scalar(1.0 - self.p)
174///         } else {
175///             // During evaluation: pass through unchanged
176///             Ok(input.clone())
177///         }
178///     }
179///
180///     fn training(&self) -> bool {
181///         self.base.training()
182///     }
183///
184///     fn set_training(&mut self, training: bool) {
185///         self.base.set_training(training);
186///     }
187/// }
188/// ```
189///
190/// # Complete Training Loop Example
191///
192/// ```rust,no_run
193/// use torsh_nn::prelude::{Module, Linear, Sequential};
194/// use torsh_tensor::{Tensor, creation};
195/// use torsh_core::error::Result;
196///
197/// fn train_eval_example() -> Result<()> {
198///     // 1. Create model
199///     let mut model = Sequential::new()
200///         .add(Linear::new(784, 128, true))
201///         .add(Linear::new(128, 10, true));
202///
203///     // 2. Set model to training mode
204///     model.train();
205///
206///     // 3. Forward pass
207///     let inputs = creation::randn(&[32, 784])?;
208///     let outputs = model.forward(&inputs)?;
209///
210///     // 4. Compute loss (simplified MSE)
211///     let targets = creation::randn(&[32, 10])?;
212///     let diff = outputs.sub(&targets)?;
213///     let loss = diff.pow_scalar(2.0)?.mean(None, false)?;
214///
215///     // 5. Backward pass (computes gradients)
216///     loss.backward()?;
217///
218///     // 6. Switch to evaluation mode
219///     model.eval();
220///     let test_input = creation::randn(&[10, 784])?;
221///     let test_output = model.forward(&test_input)?;
222///
223///     Ok(())
224/// }
225/// ```
226///
227/// # Method Categories
228///
229/// ## Core Methods (Required)
230/// - [`forward()`](Module::forward) - Forward pass computation
231///
232/// ## Parameter Management
233/// - [`parameters()`](Module::parameters) - Get all trainable parameters
234/// - [`named_parameters()`](Module::named_parameters) - Get parameters with names
235/// - [`all_parameters()`](Module::all_parameters) - Get parameters recursively
236/// - [`zero_grad()`](Module::zero_grad) - Clear all gradients
237///
238/// ## Training Mode Control
239/// - [`training()`](Module::training) - Check if in training mode
240/// - [`train()`](Module::train) - Set to training mode
241/// - [`eval()`](Module::eval) - Set to evaluation mode
242/// - [`set_training()`](Module::set_training) - Set training mode explicitly
243///
244/// ## Module Hierarchy
245/// - [`children()`](Module::children) - Get direct child modules
246/// - [`named_children()`](Module::named_children) - Get children with names
247/// - [`modules()`](Module::modules) - Get all modules recursively
248///
249/// ## State Management
250/// - [`state_dict()`](Module::state_dict) - Save module state
251/// - [`load_state_dict()`](Module::load_state_dict) - Load module state
252/// - [`to_device()`](Module::to_device) - Move module to device
253///
254/// ## Utilities
255/// - [`freeze()`](Module::freeze) - Freeze all parameters
256/// - [`unfreeze()`](Module::unfreeze) - Unfreeze all parameters
257/// - [`num_parameters()`](Module::num_parameters) - Count total parameters
258/// - [`diagnose()`](Module::diagnose) - Check module health
259///
260/// # PyTorch Compatibility
261///
262/// ToRSh's Module trait closely follows PyTorch's `nn.Module` interface:
263///
264/// | PyTorch | ToRSh | Notes |
265/// |---------|-------|-------|
266/// | `forward(x)` | `forward(&x)` | Returns `Result<Tensor>` |
267/// | `parameters()` | `parameters()` | Returns `HashMap<String, Parameter>` |
268/// | `train()` | `train()` | Sets training mode |
269/// | `eval()` | `eval()` | Sets evaluation mode |
270/// | `state_dict()` | `state_dict()` | Returns parameter tensors |
271/// | `load_state_dict()` | `load_state_dict()` | Loads from HashMap |
272/// | `to(device)` | `to_device(device)` | Moves to device |
273/// | `zero_grad()` | `zero_grad()` | Clears gradients |
274///
275/// # Best Practices
276///
277/// 1. **Always use `ModuleBase`**: Store a `ModuleBase` instance in your module to handle
278///    common functionality like parameters, buffers, and training state.
279///
280/// 2. **Register parameters in constructor**: Use `base.register_parameter()` to register
281///    all trainable parameters during module creation.
282///
283/// 3. **Implement training/eval behavior**: If your module behaves differently during
284///    training vs evaluation (like Dropout, BatchNorm), check `self.training()`.
285///
286/// 4. **Use `Result<Tensor>` for error handling**: Always return `Result<Tensor>` from
287///    `forward()` to properly propagate errors.
288///
289/// 5. **Delegate to `ModuleBase`**: Implement parameter and training mode methods by
290///    delegating to your `ModuleBase` instance.
291///
292/// 6. **Initialize parameters properly**: Use initialization methods from `torsh_nn::init`
293///    for better training convergence.
294pub trait Module: Send + Sync {
295    /// Forward pass through the module
296    ///
297    /// This is the only required method that must be implemented by all modules.
298    ///
299    /// # Arguments
300    /// * `input` - Input tensor
301    ///
302    /// # Returns
303    /// * `Result<Tensor>` - Output tensor or error
304    fn forward(&self, input: &Tensor) -> Result<Tensor>;
305
306    /// Get all parameters in the module (non-recursive)
307    ///
308    /// Override this method if your module has trainable parameters.
309    /// The default implementation returns an empty map.
310    ///
311    /// # Returns
312    /// * `HashMap<String, Parameter>` - Map of parameter names to parameters
313    fn parameters(&self) -> HashMap<String, crate::Parameter> {
314        HashMap::new()
315    }
316
317    /// Get named parameters (non-recursive)
318    ///
319    /// Default implementation delegates to `parameters()`. Override if you need
320    /// different behavior for named vs unnamed parameter access.
321    ///
322    /// # Returns
323    /// * `HashMap<String, Parameter>` - Map of parameter names to parameters
324    fn named_parameters(&self) -> HashMap<String, crate::Parameter> {
325        self.parameters()
326    }
327
328    /// Get all parameters recursively including submodules
329    ///
330    /// # Returns
331    /// * `HashMap<String, Parameter>` - Flattened map of all parameters
332    fn all_parameters(&self) -> HashMap<String, crate::Parameter> {
333        let mut all_params = self.parameters();
334
335        for child in self.children() {
336            let child_params = child.all_parameters();
337            for (name, param) in child_params {
338                all_params.insert(name, param);
339            }
340        }
341
342        all_params
343    }
344
345    /// Get all named parameters recursively with module prefixes
346    ///
347    /// # Returns
348    /// * `HashMap<String, Parameter>` - Hierarchical parameter names
349    fn all_named_parameters(&self) -> HashMap<String, crate::Parameter> {
350        let mut all_params = HashMap::new();
351
352        // Add own parameters
353        for (name, param) in self.named_parameters() {
354            all_params.insert(name, param);
355        }
356
357        // Add child parameters with prefixes
358        let children_named = self.named_children();
359        for (child_name, child) in children_named {
360            for (param_name, param) in child.all_named_parameters() {
361                let full_name = format!("{}.{}", child_name, param_name);
362                all_params.insert(full_name, param);
363            }
364        }
365
366        all_params
367    }
368
369    /// Check if in training mode
370    ///
371    /// Default implementation returns true. Override if your module tracks training state.
372    fn training(&self) -> bool {
373        true
374    }
375
376    /// Set training mode
377    ///
378    /// Convenience method that calls `set_training(true)`.
379    fn train(&mut self) {
380        self.set_training(true);
381    }
382
383    /// Set evaluation mode
384    ///
385    /// Convenience method that calls `set_training(false)`.
386    fn eval(&mut self) {
387        self.set_training(false);
388    }
389
390    /// Set training mode (internal implementation)
391    ///
392    /// Default implementation does nothing. Override if your module needs to track
393    /// training state or propagate it to child modules.
394    fn set_training(&mut self, _training: bool) {
395        // Default: do nothing
396    }
397
398    /// Move module to device
399    ///
400    /// Default implementation does nothing. Override if your module has parameters
401    /// or buffers that need to be moved between devices.
402    fn to_device(&mut self, _device: DeviceType) -> Result<()> {
403        Ok(())
404    }
405
406    /// Load state dictionary into the module
407    ///
408    /// # Arguments
409    /// * `state_dict` - Map of parameter names to tensors
410    /// * `strict` - Whether to require exact parameter name matches
411    ///
412    /// # Returns
413    /// * `Result<()>` - Success or error with details about missing/unexpected keys
414    fn load_state_dict(
415        &mut self,
416        state_dict: &HashMap<String, Tensor>,
417        strict: bool,
418    ) -> Result<()> {
419        let current_params = self.all_named_parameters();
420        let mut missing_keys = Vec::new();
421        let mut unexpected_keys = Vec::new();
422
423        // Check for missing parameters
424        for name in current_params.keys() {
425            if !state_dict.contains_key(name) {
426                missing_keys.push(name.clone());
427            }
428        }
429
430        // Check for unexpected parameters
431        for name in state_dict.keys() {
432            if !current_params.contains_key(name) {
433                unexpected_keys.push(name.clone());
434            }
435        }
436
437        if strict && (!missing_keys.is_empty() || !unexpected_keys.is_empty()) {
438            return Err(torsh_core::error::TorshError::Other(format!(
439                "State dict loading failed. Missing keys: {:?}, Unexpected keys: {:?}",
440                missing_keys, unexpected_keys
441            )));
442        }
443
444        // Load matching parameters
445        for (name, param) in current_params {
446            if let Some(new_tensor) = state_dict.get(&name) {
447                // Validate tensor shapes match
448                let current_shape = param.shape()?;
449                let new_shape = new_tensor.shape().dims().to_vec();
450                if current_shape != new_shape {
451                    return Err(torsh_core::error::TorshError::Other(format!(
452                        "Shape mismatch for parameter '{}': expected {:?}, got {:?}",
453                        name, current_shape, new_shape
454                    )));
455                }
456
457                // Copy tensor data
458                *param.tensor().write() = new_tensor.clone();
459            }
460        }
461
462        Ok(())
463    }
464
465    /// Load state dictionary with default strict=true
466    fn load_state_dict_strict(&mut self, state_dict: &HashMap<String, Tensor>) -> Result<()> {
467        self.load_state_dict(state_dict, true)
468    }
469
470    /// Save state dictionary from the module
471    fn state_dict(&self) -> HashMap<String, Tensor> {
472        let mut state = HashMap::new();
473        for (name, param) in self.all_named_parameters() {
474            state.insert(name, param.clone_data());
475        }
476        state
477    }
478
479    /// Get the module name (optional, for debugging and serialization)
480    fn name(&self) -> Option<&str> {
481        None
482    }
483
484    /// Get all buffers (non-trainable parameters)
485    fn buffers(&self) -> Vec<std::sync::Arc<parking_lot::RwLock<Tensor>>> {
486        Vec::new()
487    }
488
489    /// Get named buffers
490    fn named_buffers(&self) -> HashMap<String, std::sync::Arc<parking_lot::RwLock<Tensor>>> {
491        HashMap::new()
492    }
493
494    /// Get all direct child modules
495    ///
496    /// Default implementation returns an empty vector. Override if your module
497    /// contains child modules.
498    fn children(&self) -> Vec<&dyn Module> {
499        Vec::new()
500    }
501
502    /// Get all direct child modules with names
503    ///
504    /// Default implementation returns an empty vector. Override if your module
505    /// contains named child modules.
506    fn named_children(&self) -> Vec<(String, &dyn Module)> {
507        Vec::new()
508    }
509
510    /// Get all modules recursively (depth-first traversal)
511    fn modules(&self) -> Vec<&dyn Module>
512    where
513        Self: Sized,
514    {
515        let mut modules: Vec<&dyn Module> = vec![self];
516        for child in self.children() {
517            // Since child is &dyn Module, we need to use a different approach
518            // We'll just collect immediate children for now
519            modules.push(child);
520        }
521        modules
522    }
523
524    /// Get all modules recursively with hierarchical names
525    fn named_modules(&self) -> Vec<(String, &dyn Module)>
526    where
527        Self: Sized,
528    {
529        let mut modules: Vec<(String, &dyn Module)> = vec![(String::new(), self)];
530
531        for (child_name, child) in self.named_children() {
532            // Since child is &dyn Module, we need to use a different approach
533            // We'll just collect immediate named children for now
534            modules.push((child_name, child));
535        }
536
537        modules
538    }
539
540    /// Zero all gradients recursively
541    ///
542    /// Default implementation does nothing. Override if your module has parameters
543    /// with gradients that need to be zeroed.
544    fn zero_grad(&mut self) {
545        // Default: do nothing
546    }
547
548    /// Count total number of parameters
549    fn num_parameters(&self) -> usize {
550        self.all_parameters()
551            .values()
552            .map(|p| p.numel().unwrap_or(0))
553            .sum()
554    }
555
556    /// Count trainable parameters
557    fn num_trainable_parameters(&self) -> usize {
558        self.all_parameters()
559            .values()
560            .filter(|p| p.requires_grad())
561            .map(|p| p.numel().unwrap_or(0))
562            .sum()
563    }
564
565    /// Get memory usage estimate in bytes
566    fn memory_usage(&self) -> usize {
567        self.all_parameters()
568            .values()
569            .map(|p| p.numel().unwrap_or(0) * 4) // Assume f32 = 4 bytes
570            .sum()
571    }
572
573    /// Freeze all parameters (set requires_grad = false)
574    ///
575    /// Default implementation does nothing. Override if your module has parameters
576    /// that can be frozen/unfrozen.
577    fn freeze(&mut self) {
578        // Default: do nothing
579    }
580
581    /// Unfreeze all parameters (set requires_grad = true)
582    ///
583    /// Default implementation does nothing. Override if your module has parameters
584    /// that can be frozen/unfrozen.
585    fn unfreeze(&mut self) {
586        // Default: do nothing
587    }
588
589    /// Get string representation
590    fn extra_repr(&self) -> String {
591        String::new()
592    }
593
594    /// Register a hook for this module (default implementation does nothing)
595    fn register_hook(
596        &mut self,
597        _hook_type: crate::HookType,
598        _callback: crate::HookCallback,
599    ) -> Option<crate::HookHandle> {
600        None
601    }
602
603    /// Remove a hook by handle (default implementation does nothing)
604    fn remove_hook(&mut self, _hook_type: crate::HookType, _handle: crate::HookHandle) -> bool {
605        false
606    }
607
608    /// Execute hooks of a specific type (default implementation does nothing)
609    fn execute_hooks(
610        &self,
611        _hook_type: crate::HookType,
612        _input: &Tensor,
613        _output: Option<&Tensor>,
614    ) -> Result<()> {
615        Ok(())
616    }
617
618    /// Forward pass with hooks support
619    fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor> {
620        // Execute pre-forward hooks
621        self.execute_hooks(crate::HookType::PreForward, input, None)?;
622
623        // Perform forward pass
624        let output = self.forward(input)?;
625
626        // Execute post-forward hooks
627        self.execute_hooks(crate::HookType::PostForward, input, Some(&output))?;
628
629        Ok(output)
630    }
631
632    /// Check if module has hooks registered
633    fn has_hooks(&self, _hook_type: crate::HookType) -> bool {
634        false
635    }
636
637    // === Ergonomic Helper Methods ===
638
639    /// Convenient method to call forward and handle common patterns
640    ///
641    /// This is equivalent to `forward()` but provides a more ergonomic interface
642    /// for chaining operations.
643    fn call(&self, input: &Tensor) -> Result<Tensor> {
644        self.forward(input)
645    }
646
647    /// Apply the module to input (alias for forward)
648    ///
649    /// PyTorch-style method name for compatibility.
650    fn apply(&self, input: &Tensor) -> Result<Tensor> {
651        self.forward(input)
652    }
653
654    /// Check if the module has any parameters
655    fn has_parameters(&self) -> bool {
656        !self.parameters().is_empty()
657    }
658
659    /// Check if the module has any child modules
660    fn has_children(&self) -> bool {
661        !self.children().is_empty()
662    }
663
664    /// Get parameter count (convenience method)
665    fn parameter_count(&self) -> usize {
666        self.num_parameters()
667    }
668
669    /// Get trainable parameter count (convenience method)
670    fn trainable_parameter_count(&self) -> usize {
671        self.num_trainable_parameters()
672    }
673
674    /// Get memory usage in MB (convenience method)
675    fn memory_usage_mb(&self) -> f64 {
676        self.memory_usage() as f64 / (1024.0 * 1024.0)
677    }
678
679    /// Toggle training mode (convenience method)
680    fn toggle_training(&mut self) {
681        self.set_training(!self.training());
682    }
683
684    /// Check if module is in evaluation mode
685    fn eval_mode(&self) -> bool {
686        !self.training()
687    }
688
689    // === Enhanced Ergonomic Methods ===
690
691    /// Sequential forward pass through multiple modules
692    ///
693    /// This provides a convenient way to chain multiple forward passes.
694    ///
695    /// # Arguments
696    /// * `modules` - Slice of modules to apply sequentially
697    /// * `input` - Input tensor
698    ///
699    /// # Returns
700    /// * `Result<Tensor>` - Final output after all modules
701    ///
702    /// # Example
703    /// ```ignore
704    /// let result = Module::sequential_forward(&[&layer1, &layer2, &layer3], &input)?;
705    /// ```
706    fn sequential_forward(modules: &[&dyn Module], mut input: Tensor) -> Result<Tensor>
707    where
708        Self: Sized,
709    {
710        for module in modules {
711            input = module.forward(&input)?;
712        }
713        Ok(input)
714    }
715
716    /// Apply module multiple times with different inputs (batch processing)
717    ///
718    /// This is useful for processing multiple independent inputs through the same module.
719    ///
720    /// # Arguments
721    /// * `inputs` - Vector of input tensors
722    ///
723    /// # Returns
724    /// * `Result<Vec<Tensor>>` - Vector of output tensors
725    fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
726        inputs.iter().map(|input| self.forward(input)).collect()
727    }
728
729    /// Forward with condition - only apply if condition is true
730    ///
731    /// This provides conditional execution of modules.
732    ///
733    /// # Arguments
734    /// * `input` - Input tensor
735    /// * `condition` - Whether to apply this module
736    ///
737    /// # Returns
738    /// * `Result<Tensor>` - Output tensor (input if condition is false)
739    fn conditional_forward(&self, input: &Tensor, condition: bool) -> Result<Tensor> {
740        if condition {
741            self.forward(input)
742        } else {
743            Ok(input.clone())
744        }
745    }
746
747    /// Forward with residual connection
748    ///
749    /// Applies the module and adds the result to the input (residual/skip connection).
750    ///
751    /// # Arguments
752    /// * `input` - Input tensor
753    ///
754    /// # Returns
755    /// * `Result<Tensor>` - Output tensor (input + forward(input))
756    fn residual_forward(&self, input: &Tensor) -> Result<Tensor> {
757        let output = self.forward(input)?;
758        // This would use tensor addition when available
759        // For now, just return the output
760        Ok(output)
761    }
762
763    /// Get detailed module information for debugging
764    ///
765    /// This provides comprehensive information about the module state.
766    ///
767    /// # Returns
768    /// * `ModuleInfo` - Detailed module information
769    fn module_info(&self) -> crate::ModuleInfo {
770        crate::ModuleInfo {
771            name: self.name().unwrap_or("Unknown").to_string(),
772            training: self.training(),
773            parameter_count: self.num_parameters(),
774            trainable_parameter_count: self.num_trainable_parameters(),
775            memory_usage_bytes: self.memory_usage(),
776            has_children: self.has_children(),
777            children_count: self.children().len(),
778        }
779    }
780
781    /// Check if module is ready for training
782    ///
783    /// Performs various checks to ensure the module is properly configured for training.
784    ///
785    /// # Returns
786    /// * `Result<()>` - Ok if ready, Error with details if not
787    fn check_training_readiness(&self) -> Result<()> {
788        // Check if module has parameters
789        if !self.has_parameters() {
790            return Err(torsh_core::error::TorshError::Other(
791                "Module has no parameters - may not be trainable".to_string(),
792            ));
793        }
794
795        // Check if in training mode
796        if !self.training() {
797            return Err(torsh_core::error::TorshError::Other(
798                "Module is in evaluation mode - switch to training mode first".to_string(),
799            ));
800        }
801
802        // Check for finite parameters
803        for param in self.parameters().values() {
804            if !param.is_finite().unwrap_or(false) {
805                return Err(torsh_core::error::TorshError::Other(
806                    "Module contains non-finite parameters (NaN or infinity)".to_string(),
807                ));
808            }
809        }
810
811        Ok(())
812    }
813
814    /// Get parameter names matching a pattern
815    ///
816    /// This helps with selective parameter access and manipulation.
817    ///
818    /// # Arguments
819    /// * `pattern` - String pattern to match against parameter names
820    ///
821    /// # Returns
822    /// * `Vec<String>` - Vector of parameter names matching the pattern
823    fn parameter_names_matching(&self, pattern: &str) -> Vec<String> {
824        self.all_named_parameters()
825            .keys()
826            .filter(|name| name.contains(pattern))
827            .cloned()
828            .collect()
829    }
830
831    /// Get parameters by layer type (e.g., "weight", "bias")
832    ///
833    /// # Arguments
834    /// * `param_type` - Type of parameters to retrieve
835    ///
836    /// # Returns
837    /// * `HashMap<String, Parameter>` - Filtered parameters
838    fn parameters_by_type(&self, param_type: &str) -> HashMap<String, crate::Parameter> {
839        self.all_named_parameters()
840            .into_iter()
841            .filter(|(name, _)| name.contains(param_type))
842            .collect()
843    }
844
845    /// Clone module parameters (for creating copies or checkpoints)
846    ///
847    /// # Returns
848    /// * `HashMap<String, Tensor>` - Cloned parameter tensors
849    fn clone_parameters(&self) -> HashMap<String, Tensor> {
850        self.all_named_parameters()
851            .into_iter()
852            .map(|(name, param)| (name, param.clone_data()))
853            .collect()
854    }
855
856    /// Quick diagnostic check of module health
857    ///
858    /// # Returns
859    /// * `ModuleDiagnostics` - Diagnostic information
860    fn diagnose(&self) -> crate::ModuleDiagnostics {
861        let mut issues = Vec::new();
862        let mut warnings = Vec::new();
863
864        // Check parameter health
865        for (name, param) in self.all_named_parameters() {
866            if let Ok(diag) = param.diagnose() {
867                if !diag.issues.is_empty() {
868                    issues.extend(
869                        diag.issues
870                            .into_iter()
871                            .map(|issue| format!("{}: {}", name, issue)),
872                    );
873                }
874                if !diag.warnings.is_empty() {
875                    warnings.extend(
876                        diag.warnings
877                            .into_iter()
878                            .map(|warning| format!("{}: {}", name, warning)),
879                    );
880                }
881            }
882        }
883
884        // Check training readiness
885        if let Err(e) = self.check_training_readiness() {
886            warnings.push(format!("Training readiness: {}", e));
887        }
888
889        crate::ModuleDiagnostics {
890            module_info: self.module_info(),
891            issues,
892            warnings,
893            parameter_diagnostics: self
894                .all_named_parameters()
895                .into_iter()
896                .filter_map(|(name, param)| param.diagnose().ok().map(|d| (name, d)))
897                .collect(),
898        }
899    }
900}
901
902/// Implementation for boxed trait objects
903impl Module for Box<dyn Module> {
904    fn forward(&self, x: &Tensor) -> Result<Tensor> {
905        (**self).forward(x)
906    }
907
908    fn parameters(&self) -> HashMap<String, crate::Parameter> {
909        (**self).parameters()
910    }
911
912    fn train(&mut self) {
913        (**self).train()
914    }
915
916    fn eval(&mut self) {
917        (**self).eval()
918    }
919
920    fn training(&self) -> bool {
921        (**self).training()
922    }
923
924    fn children(&self) -> Vec<&dyn Module> {
925        (**self).children()
926    }
927
928    fn named_children(&self) -> Vec<(String, &dyn Module)> {
929        (**self).named_children()
930    }
931
932    fn set_training(&mut self, training: bool) {
933        (**self).set_training(training)
934    }
935
936    fn to_device(&mut self, device: DeviceType) -> Result<()> {
937        (**self).to_device(device)
938    }
939}
940
941/// Implementation for mutable references to boxed trait objects
942impl Module for &mut Box<dyn Module> {
943    fn forward(&self, x: &Tensor) -> Result<Tensor> {
944        (***self).forward(x)
945    }
946
947    fn parameters(&self) -> HashMap<String, crate::Parameter> {
948        (***self).parameters()
949    }
950
951    fn train(&mut self) {
952        (***self).train()
953    }
954
955    fn eval(&mut self) {
956        (***self).eval()
957    }
958
959    fn training(&self) -> bool {
960        (***self).training()
961    }
962
963    fn children(&self) -> Vec<&dyn Module> {
964        (***self).children()
965    }
966
967    fn named_children(&self) -> Vec<(String, &dyn Module)> {
968        (***self).named_children()
969    }
970
971    fn set_training(&mut self, training: bool) {
972        (***self).set_training(training)
973    }
974
975    fn to_device(&mut self, device: DeviceType) -> Result<()> {
976        (***self).to_device(device)
977    }
978}