// torsh_nn/core/mod.rs
1//! Core module trait system for neural network modules
2//!
3//! This module provides the foundational Module trait and essential interfaces
4//! for all neural network components in ToRSh.
5
6pub mod module_ext;
7
8pub use module_ext::{ModuleExt, ParameterStats, ValidationReport};
9
10use torsh_core::device::DeviceType;
11use torsh_core::error::Result;
12use torsh_tensor::Tensor;
13
14// Conditional imports for std/no_std compatibility
15#[cfg(feature = "std")]
16use std::collections::HashMap;
17
18#[cfg(not(feature = "std"))]
19use hashbrown::HashMap;
20
21/// Base trait for all neural network modules
22///
23/// This trait provides the core interface for all neural network components,
24/// following PyTorch-compatible patterns for maximum interoperability.
25///
26/// # Design Philosophy
27///
28/// This trait is designed for maximum ergonomics while maintaining flexibility:
29/// - Most methods have sensible defaults to reduce boilerplate
30/// - Core functionality (forward, training mode) is required
31/// - Parameter management is streamlined with helper methods
32/// - Hook system is optional but well-integrated
33///
34/// # Required Methods
35///
36/// Only [`forward()`](Module::forward) must be implemented. All other methods have sensible defaults.
37///
38/// # Implementing a Custom Module
39///
40/// ## Basic Module (No Parameters)
41///
42/// ```rust
43/// use torsh_nn::{Module, ModuleBase};
44/// use torsh_tensor::Tensor;
45/// use torsh_core::error::Result;
46///
47/// /// Simple ReLU activation function
48/// struct MyReLU {
49/// base: ModuleBase,
50/// }
51///
52/// impl MyReLU {
53/// fn new() -> Self {
54/// Self {
55/// base: ModuleBase::new(),
56/// }
57/// }
58/// }
59///
60/// impl Module for MyReLU {
61/// fn forward(&self, input: &Tensor) -> Result<Tensor> {
62/// // Apply ReLU: max(0, x)
63/// input.relu()
64/// }
65///
66/// fn training(&self) -> bool {
67/// self.base.training()
68/// }
69///
70/// fn set_training(&mut self, training: bool) {
71/// self.base.set_training(training);
72/// }
73/// }
74/// ```
75///
76/// ## Module with Parameters
77///
78/// ```rust
79/// use torsh_nn::{Module, ModuleBase, Parameter};
80/// use torsh_tensor::{Tensor, creation};
81/// use torsh_core::error::Result;
82/// use std::collections::HashMap;
83///
84/// /// Custom linear layer with learnable weight and bias
85/// struct MyLinear {
86/// base: ModuleBase,
87/// in_features: usize,
88/// out_features: usize,
89/// }
90///
91/// impl MyLinear {
92/// fn new(in_features: usize, out_features: usize) -> Result<Self> {
93/// let mut base = ModuleBase::new();
94///
95/// // Initialize weight: [in_features, out_features]
96/// let weight = creation::randn(&[in_features, out_features])?;
97/// base.register_parameter("weight".to_string(), Parameter::new(weight));
98///
99/// // Initialize bias: [out_features]
100/// let bias = creation::zeros(&[out_features])?;
101/// base.register_parameter("bias".to_string(), Parameter::new(bias));
102///
103/// Ok(Self {
104/// base,
105/// in_features,
106/// out_features,
107/// })
108/// }
109/// }
110///
111/// impl Module for MyLinear {
112/// fn forward(&self, input: &Tensor) -> Result<Tensor> {
113/// // Get parameters
114/// let weight = self.base.parameters["weight"].tensor().read().clone();
115/// let bias = self.base.parameters["bias"].tensor().read().clone();
116///
117/// // Compute: input @ weight + bias
118/// let output = input.matmul(&weight)?;
119/// output.add(&bias)
120/// }
121///
122/// fn parameters(&self) -> HashMap<String, Parameter> {
123/// self.base.parameters.clone()
124/// }
125///
126/// fn named_parameters(&self) -> HashMap<String, Parameter> {
127/// self.base.named_parameters()
128/// }
129///
130/// fn training(&self) -> bool {
131/// self.base.training()
132/// }
133///
134/// fn set_training(&mut self, training: bool) {
135/// self.base.set_training(training);
136/// }
137/// }
138/// ```
139///
140/// ## Module with Training/Evaluation Modes
141///
142/// ```rust
143/// use torsh_nn::{Module, ModuleBase};
144/// use torsh_tensor::{Tensor, creation};
145/// use torsh_core::error::Result;
146///
147/// /// Dropout layer with different behavior in train/eval modes
148/// struct MyDropout {
149/// base: ModuleBase,
150/// p: f32, // Dropout probability
151/// }
152///
153/// impl MyDropout {
154/// fn new(p: f32) -> Self {
155/// Self {
156/// base: ModuleBase::new(),
157/// p,
158/// }
159/// }
160/// }
161///
162/// impl Module for MyDropout {
163/// fn forward(&self, input: &Tensor) -> Result<Tensor> {
164/// if self.training() {
165/// // During training: randomly zero out units
166/// let rand_vals = creation::rand_like(input)?;
167/// let threshold = creation::full_like(input, self.p)?;
168/// // Keep elements where random value >= p
169/// let keep_mask = rand_vals.ge(&threshold)?;
170/// let zeros = creation::zeros_like(input)?;
171/// let masked = input.where_tensor(&keep_mask, &zeros)?;
172/// // Scale by 1/(1-p) to maintain expected value (inverted dropout)
173/// masked.div_scalar(1.0 - self.p)
174/// } else {
175/// // During evaluation: pass through unchanged
176/// Ok(input.clone())
177/// }
178/// }
179///
180/// fn training(&self) -> bool {
181/// self.base.training()
182/// }
183///
184/// fn set_training(&mut self, training: bool) {
185/// self.base.set_training(training);
186/// }
187/// }
188/// ```
189///
190/// # Complete Training Loop Example
191///
192/// ```rust,no_run
193/// use torsh_nn::prelude::{Module, Linear, Sequential};
194/// use torsh_tensor::{Tensor, creation};
195/// use torsh_core::error::Result;
196///
197/// fn train_eval_example() -> Result<()> {
198/// // 1. Create model
199/// let mut model = Sequential::new()
200/// .add(Linear::new(784, 128, true))
201/// .add(Linear::new(128, 10, true));
202///
203/// // 2. Set model to training mode
204/// model.train();
205///
206/// // 3. Forward pass
207/// let inputs = creation::randn(&[32, 784])?;
208/// let outputs = model.forward(&inputs)?;
209///
210/// // 4. Compute loss (simplified MSE)
211/// let targets = creation::randn(&[32, 10])?;
212/// let diff = outputs.sub(&targets)?;
213/// let loss = diff.pow_scalar(2.0)?.mean(None, false)?;
214///
215/// // 5. Backward pass (computes gradients)
216/// loss.backward()?;
217///
218/// // 6. Switch to evaluation mode
219/// model.eval();
220/// let test_input = creation::randn(&[10, 784])?;
221/// let test_output = model.forward(&test_input)?;
222///
223/// Ok(())
224/// }
225/// ```
226///
227/// # Method Categories
228///
229/// ## Core Methods (Required)
230/// - [`forward()`](Module::forward) - Forward pass computation
231///
232/// ## Parameter Management
233/// - [`parameters()`](Module::parameters) - Get all trainable parameters
234/// - [`named_parameters()`](Module::named_parameters) - Get parameters with names
235/// - [`all_parameters()`](Module::all_parameters) - Get parameters recursively
236/// - [`zero_grad()`](Module::zero_grad) - Clear all gradients
237///
238/// ## Training Mode Control
239/// - [`training()`](Module::training) - Check if in training mode
240/// - [`train()`](Module::train) - Set to training mode
241/// - [`eval()`](Module::eval) - Set to evaluation mode
242/// - [`set_training()`](Module::set_training) - Set training mode explicitly
243///
244/// ## Module Hierarchy
245/// - [`children()`](Module::children) - Get direct child modules
246/// - [`named_children()`](Module::named_children) - Get children with names
247/// - [`modules()`](Module::modules) - Get all modules recursively
248///
249/// ## State Management
250/// - [`state_dict()`](Module::state_dict) - Save module state
251/// - [`load_state_dict()`](Module::load_state_dict) - Load module state
252/// - [`to_device()`](Module::to_device) - Move module to device
253///
254/// ## Utilities
255/// - [`freeze()`](Module::freeze) - Freeze all parameters
256/// - [`unfreeze()`](Module::unfreeze) - Unfreeze all parameters
257/// - [`num_parameters()`](Module::num_parameters) - Count total parameters
258/// - [`diagnose()`](Module::diagnose) - Check module health
259///
260/// # PyTorch Compatibility
261///
262/// ToRSh's Module trait closely follows PyTorch's `nn.Module` interface:
263///
264/// | PyTorch | ToRSh | Notes |
265/// |---------|-------|-------|
266/// | `forward(x)` | `forward(&x)` | Returns `Result<Tensor>` |
267/// | `parameters()` | `parameters()` | Returns `HashMap<String, Parameter>` |
268/// | `train()` | `train()` | Sets training mode |
269/// | `eval()` | `eval()` | Sets evaluation mode |
270/// | `state_dict()` | `state_dict()` | Returns parameter tensors |
271/// | `load_state_dict()` | `load_state_dict()` | Loads from HashMap |
272/// | `to(device)` | `to_device(device)` | Moves to device |
273/// | `zero_grad()` | `zero_grad()` | Clears gradients |
274///
275/// # Best Practices
276///
277/// 1. **Always use `ModuleBase`**: Store a `ModuleBase` instance in your module to handle
278/// common functionality like parameters, buffers, and training state.
279///
280/// 2. **Register parameters in constructor**: Use `base.register_parameter()` to register
281/// all trainable parameters during module creation.
282///
283/// 3. **Implement training/eval behavior**: If your module behaves differently during
284/// training vs evaluation (like Dropout, BatchNorm), check `self.training()`.
285///
286/// 4. **Use `Result<Tensor>` for error handling**: Always return `Result<Tensor>` from
287/// `forward()` to properly propagate errors.
288///
289/// 5. **Delegate to `ModuleBase`**: Implement parameter and training mode methods by
290/// delegating to your `ModuleBase` instance.
291///
292/// 6. **Initialize parameters properly**: Use initialization methods from `torsh_nn::init`
293/// for better training convergence.
294pub trait Module: Send + Sync {
295 /// Forward pass through the module
296 ///
297 /// This is the only required method that must be implemented by all modules.
298 ///
299 /// # Arguments
300 /// * `input` - Input tensor
301 ///
302 /// # Returns
303 /// * `Result<Tensor>` - Output tensor or error
304 fn forward(&self, input: &Tensor) -> Result<Tensor>;
305
306 /// Get all parameters in the module (non-recursive)
307 ///
308 /// Override this method if your module has trainable parameters.
309 /// The default implementation returns an empty map.
310 ///
311 /// # Returns
312 /// * `HashMap<String, Parameter>` - Map of parameter names to parameters
313 fn parameters(&self) -> HashMap<String, crate::Parameter> {
314 HashMap::new()
315 }
316
317 /// Get named parameters (non-recursive)
318 ///
319 /// Default implementation delegates to `parameters()`. Override if you need
320 /// different behavior for named vs unnamed parameter access.
321 ///
322 /// # Returns
323 /// * `HashMap<String, Parameter>` - Map of parameter names to parameters
324 fn named_parameters(&self) -> HashMap<String, crate::Parameter> {
325 self.parameters()
326 }
327
328 /// Get all parameters recursively including submodules
329 ///
330 /// # Returns
331 /// * `HashMap<String, Parameter>` - Flattened map of all parameters
332 fn all_parameters(&self) -> HashMap<String, crate::Parameter> {
333 let mut all_params = self.parameters();
334
335 for child in self.children() {
336 let child_params = child.all_parameters();
337 for (name, param) in child_params {
338 all_params.insert(name, param);
339 }
340 }
341
342 all_params
343 }
344
345 /// Get all named parameters recursively with module prefixes
346 ///
347 /// # Returns
348 /// * `HashMap<String, Parameter>` - Hierarchical parameter names
349 fn all_named_parameters(&self) -> HashMap<String, crate::Parameter> {
350 let mut all_params = HashMap::new();
351
352 // Add own parameters
353 for (name, param) in self.named_parameters() {
354 all_params.insert(name, param);
355 }
356
357 // Add child parameters with prefixes
358 let children_named = self.named_children();
359 for (child_name, child) in children_named {
360 for (param_name, param) in child.all_named_parameters() {
361 let full_name = format!("{}.{}", child_name, param_name);
362 all_params.insert(full_name, param);
363 }
364 }
365
366 all_params
367 }
368
369 /// Check if in training mode
370 ///
371 /// Default implementation returns true. Override if your module tracks training state.
372 fn training(&self) -> bool {
373 true
374 }
375
376 /// Set training mode
377 ///
378 /// Convenience method that calls `set_training(true)`.
379 fn train(&mut self) {
380 self.set_training(true);
381 }
382
383 /// Set evaluation mode
384 ///
385 /// Convenience method that calls `set_training(false)`.
386 fn eval(&mut self) {
387 self.set_training(false);
388 }
389
390 /// Set training mode (internal implementation)
391 ///
392 /// Default implementation does nothing. Override if your module needs to track
393 /// training state or propagate it to child modules.
394 fn set_training(&mut self, _training: bool) {
395 // Default: do nothing
396 }
397
398 /// Move module to device
399 ///
400 /// Default implementation does nothing. Override if your module has parameters
401 /// or buffers that need to be moved between devices.
402 fn to_device(&mut self, _device: DeviceType) -> Result<()> {
403 Ok(())
404 }
405
406 /// Load state dictionary into the module
407 ///
408 /// # Arguments
409 /// * `state_dict` - Map of parameter names to tensors
410 /// * `strict` - Whether to require exact parameter name matches
411 ///
412 /// # Returns
413 /// * `Result<()>` - Success or error with details about missing/unexpected keys
414 fn load_state_dict(
415 &mut self,
416 state_dict: &HashMap<String, Tensor>,
417 strict: bool,
418 ) -> Result<()> {
419 let current_params = self.all_named_parameters();
420 let mut missing_keys = Vec::new();
421 let mut unexpected_keys = Vec::new();
422
423 // Check for missing parameters
424 for name in current_params.keys() {
425 if !state_dict.contains_key(name) {
426 missing_keys.push(name.clone());
427 }
428 }
429
430 // Check for unexpected parameters
431 for name in state_dict.keys() {
432 if !current_params.contains_key(name) {
433 unexpected_keys.push(name.clone());
434 }
435 }
436
437 if strict && (!missing_keys.is_empty() || !unexpected_keys.is_empty()) {
438 return Err(torsh_core::error::TorshError::Other(format!(
439 "State dict loading failed. Missing keys: {:?}, Unexpected keys: {:?}",
440 missing_keys, unexpected_keys
441 )));
442 }
443
444 // Load matching parameters
445 for (name, param) in current_params {
446 if let Some(new_tensor) = state_dict.get(&name) {
447 // Validate tensor shapes match
448 let current_shape = param.shape()?;
449 let new_shape = new_tensor.shape().dims().to_vec();
450 if current_shape != new_shape {
451 return Err(torsh_core::error::TorshError::Other(format!(
452 "Shape mismatch for parameter '{}': expected {:?}, got {:?}",
453 name, current_shape, new_shape
454 )));
455 }
456
457 // Copy tensor data
458 *param.tensor().write() = new_tensor.clone();
459 }
460 }
461
462 Ok(())
463 }
464
465 /// Load state dictionary with default strict=true
466 fn load_state_dict_strict(&mut self, state_dict: &HashMap<String, Tensor>) -> Result<()> {
467 self.load_state_dict(state_dict, true)
468 }
469
470 /// Save state dictionary from the module
471 fn state_dict(&self) -> HashMap<String, Tensor> {
472 let mut state = HashMap::new();
473 for (name, param) in self.all_named_parameters() {
474 state.insert(name, param.clone_data());
475 }
476 state
477 }
478
479 /// Get the module name (optional, for debugging and serialization)
480 fn name(&self) -> Option<&str> {
481 None
482 }
483
484 /// Get all buffers (non-trainable parameters)
485 fn buffers(&self) -> Vec<std::sync::Arc<parking_lot::RwLock<Tensor>>> {
486 Vec::new()
487 }
488
489 /// Get named buffers
490 fn named_buffers(&self) -> HashMap<String, std::sync::Arc<parking_lot::RwLock<Tensor>>> {
491 HashMap::new()
492 }
493
494 /// Get all direct child modules
495 ///
496 /// Default implementation returns an empty vector. Override if your module
497 /// contains child modules.
498 fn children(&self) -> Vec<&dyn Module> {
499 Vec::new()
500 }
501
502 /// Get all direct child modules with names
503 ///
504 /// Default implementation returns an empty vector. Override if your module
505 /// contains named child modules.
506 fn named_children(&self) -> Vec<(String, &dyn Module)> {
507 Vec::new()
508 }
509
510 /// Get all modules recursively (depth-first traversal)
511 fn modules(&self) -> Vec<&dyn Module>
512 where
513 Self: Sized,
514 {
515 let mut modules: Vec<&dyn Module> = vec![self];
516 for child in self.children() {
517 // Since child is &dyn Module, we need to use a different approach
518 // We'll just collect immediate children for now
519 modules.push(child);
520 }
521 modules
522 }
523
524 /// Get all modules recursively with hierarchical names
525 fn named_modules(&self) -> Vec<(String, &dyn Module)>
526 where
527 Self: Sized,
528 {
529 let mut modules: Vec<(String, &dyn Module)> = vec![(String::new(), self)];
530
531 for (child_name, child) in self.named_children() {
532 // Since child is &dyn Module, we need to use a different approach
533 // We'll just collect immediate named children for now
534 modules.push((child_name, child));
535 }
536
537 modules
538 }
539
540 /// Zero all gradients recursively
541 ///
542 /// Default implementation does nothing. Override if your module has parameters
543 /// with gradients that need to be zeroed.
544 fn zero_grad(&mut self) {
545 // Default: do nothing
546 }
547
548 /// Count total number of parameters
549 fn num_parameters(&self) -> usize {
550 self.all_parameters()
551 .values()
552 .map(|p| p.numel().unwrap_or(0))
553 .sum()
554 }
555
556 /// Count trainable parameters
557 fn num_trainable_parameters(&self) -> usize {
558 self.all_parameters()
559 .values()
560 .filter(|p| p.requires_grad())
561 .map(|p| p.numel().unwrap_or(0))
562 .sum()
563 }
564
565 /// Get memory usage estimate in bytes
566 fn memory_usage(&self) -> usize {
567 self.all_parameters()
568 .values()
569 .map(|p| p.numel().unwrap_or(0) * 4) // Assume f32 = 4 bytes
570 .sum()
571 }
572
573 /// Freeze all parameters (set requires_grad = false)
574 ///
575 /// Default implementation does nothing. Override if your module has parameters
576 /// that can be frozen/unfrozen.
577 fn freeze(&mut self) {
578 // Default: do nothing
579 }
580
581 /// Unfreeze all parameters (set requires_grad = true)
582 ///
583 /// Default implementation does nothing. Override if your module has parameters
584 /// that can be frozen/unfrozen.
585 fn unfreeze(&mut self) {
586 // Default: do nothing
587 }
588
589 /// Get string representation
590 fn extra_repr(&self) -> String {
591 String::new()
592 }
593
594 /// Register a hook for this module (default implementation does nothing)
595 fn register_hook(
596 &mut self,
597 _hook_type: crate::HookType,
598 _callback: crate::HookCallback,
599 ) -> Option<crate::HookHandle> {
600 None
601 }
602
603 /// Remove a hook by handle (default implementation does nothing)
604 fn remove_hook(&mut self, _hook_type: crate::HookType, _handle: crate::HookHandle) -> bool {
605 false
606 }
607
608 /// Execute hooks of a specific type (default implementation does nothing)
609 fn execute_hooks(
610 &self,
611 _hook_type: crate::HookType,
612 _input: &Tensor,
613 _output: Option<&Tensor>,
614 ) -> Result<()> {
615 Ok(())
616 }
617
618 /// Forward pass with hooks support
619 fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor> {
620 // Execute pre-forward hooks
621 self.execute_hooks(crate::HookType::PreForward, input, None)?;
622
623 // Perform forward pass
624 let output = self.forward(input)?;
625
626 // Execute post-forward hooks
627 self.execute_hooks(crate::HookType::PostForward, input, Some(&output))?;
628
629 Ok(output)
630 }
631
632 /// Check if module has hooks registered
633 fn has_hooks(&self, _hook_type: crate::HookType) -> bool {
634 false
635 }
636
637 // === Ergonomic Helper Methods ===
638
639 /// Convenient method to call forward and handle common patterns
640 ///
641 /// This is equivalent to `forward()` but provides a more ergonomic interface
642 /// for chaining operations.
643 fn call(&self, input: &Tensor) -> Result<Tensor> {
644 self.forward(input)
645 }
646
647 /// Apply the module to input (alias for forward)
648 ///
649 /// PyTorch-style method name for compatibility.
650 fn apply(&self, input: &Tensor) -> Result<Tensor> {
651 self.forward(input)
652 }
653
654 /// Check if the module has any parameters
655 fn has_parameters(&self) -> bool {
656 !self.parameters().is_empty()
657 }
658
659 /// Check if the module has any child modules
660 fn has_children(&self) -> bool {
661 !self.children().is_empty()
662 }
663
664 /// Get parameter count (convenience method)
665 fn parameter_count(&self) -> usize {
666 self.num_parameters()
667 }
668
669 /// Get trainable parameter count (convenience method)
670 fn trainable_parameter_count(&self) -> usize {
671 self.num_trainable_parameters()
672 }
673
674 /// Get memory usage in MB (convenience method)
675 fn memory_usage_mb(&self) -> f64 {
676 self.memory_usage() as f64 / (1024.0 * 1024.0)
677 }
678
679 /// Toggle training mode (convenience method)
680 fn toggle_training(&mut self) {
681 self.set_training(!self.training());
682 }
683
684 /// Check if module is in evaluation mode
685 fn eval_mode(&self) -> bool {
686 !self.training()
687 }
688
689 // === Enhanced Ergonomic Methods ===
690
691 /// Sequential forward pass through multiple modules
692 ///
693 /// This provides a convenient way to chain multiple forward passes.
694 ///
695 /// # Arguments
696 /// * `modules` - Slice of modules to apply sequentially
697 /// * `input` - Input tensor
698 ///
699 /// # Returns
700 /// * `Result<Tensor>` - Final output after all modules
701 ///
702 /// # Example
703 /// ```ignore
704 /// let result = Module::sequential_forward(&[&layer1, &layer2, &layer3], &input)?;
705 /// ```
706 fn sequential_forward(modules: &[&dyn Module], mut input: Tensor) -> Result<Tensor>
707 where
708 Self: Sized,
709 {
710 for module in modules {
711 input = module.forward(&input)?;
712 }
713 Ok(input)
714 }
715
716 /// Apply module multiple times with different inputs (batch processing)
717 ///
718 /// This is useful for processing multiple independent inputs through the same module.
719 ///
720 /// # Arguments
721 /// * `inputs` - Vector of input tensors
722 ///
723 /// # Returns
724 /// * `Result<Vec<Tensor>>` - Vector of output tensors
725 fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
726 inputs.iter().map(|input| self.forward(input)).collect()
727 }
728
729 /// Forward with condition - only apply if condition is true
730 ///
731 /// This provides conditional execution of modules.
732 ///
733 /// # Arguments
734 /// * `input` - Input tensor
735 /// * `condition` - Whether to apply this module
736 ///
737 /// # Returns
738 /// * `Result<Tensor>` - Output tensor (input if condition is false)
739 fn conditional_forward(&self, input: &Tensor, condition: bool) -> Result<Tensor> {
740 if condition {
741 self.forward(input)
742 } else {
743 Ok(input.clone())
744 }
745 }
746
747 /// Forward with residual connection
748 ///
749 /// Applies the module and adds the result to the input (residual/skip connection).
750 ///
751 /// # Arguments
752 /// * `input` - Input tensor
753 ///
754 /// # Returns
755 /// * `Result<Tensor>` - Output tensor (input + forward(input))
756 fn residual_forward(&self, input: &Tensor) -> Result<Tensor> {
757 let output = self.forward(input)?;
758 // This would use tensor addition when available
759 // For now, just return the output
760 Ok(output)
761 }
762
763 /// Get detailed module information for debugging
764 ///
765 /// This provides comprehensive information about the module state.
766 ///
767 /// # Returns
768 /// * `ModuleInfo` - Detailed module information
769 fn module_info(&self) -> crate::ModuleInfo {
770 crate::ModuleInfo {
771 name: self.name().unwrap_or("Unknown").to_string(),
772 training: self.training(),
773 parameter_count: self.num_parameters(),
774 trainable_parameter_count: self.num_trainable_parameters(),
775 memory_usage_bytes: self.memory_usage(),
776 has_children: self.has_children(),
777 children_count: self.children().len(),
778 }
779 }
780
781 /// Check if module is ready for training
782 ///
783 /// Performs various checks to ensure the module is properly configured for training.
784 ///
785 /// # Returns
786 /// * `Result<()>` - Ok if ready, Error with details if not
787 fn check_training_readiness(&self) -> Result<()> {
788 // Check if module has parameters
789 if !self.has_parameters() {
790 return Err(torsh_core::error::TorshError::Other(
791 "Module has no parameters - may not be trainable".to_string(),
792 ));
793 }
794
795 // Check if in training mode
796 if !self.training() {
797 return Err(torsh_core::error::TorshError::Other(
798 "Module is in evaluation mode - switch to training mode first".to_string(),
799 ));
800 }
801
802 // Check for finite parameters
803 for param in self.parameters().values() {
804 if !param.is_finite().unwrap_or(false) {
805 return Err(torsh_core::error::TorshError::Other(
806 "Module contains non-finite parameters (NaN or infinity)".to_string(),
807 ));
808 }
809 }
810
811 Ok(())
812 }
813
814 /// Get parameter names matching a pattern
815 ///
816 /// This helps with selective parameter access and manipulation.
817 ///
818 /// # Arguments
819 /// * `pattern` - String pattern to match against parameter names
820 ///
821 /// # Returns
822 /// * `Vec<String>` - Vector of parameter names matching the pattern
823 fn parameter_names_matching(&self, pattern: &str) -> Vec<String> {
824 self.all_named_parameters()
825 .keys()
826 .filter(|name| name.contains(pattern))
827 .cloned()
828 .collect()
829 }
830
831 /// Get parameters by layer type (e.g., "weight", "bias")
832 ///
833 /// # Arguments
834 /// * `param_type` - Type of parameters to retrieve
835 ///
836 /// # Returns
837 /// * `HashMap<String, Parameter>` - Filtered parameters
838 fn parameters_by_type(&self, param_type: &str) -> HashMap<String, crate::Parameter> {
839 self.all_named_parameters()
840 .into_iter()
841 .filter(|(name, _)| name.contains(param_type))
842 .collect()
843 }
844
845 /// Clone module parameters (for creating copies or checkpoints)
846 ///
847 /// # Returns
848 /// * `HashMap<String, Tensor>` - Cloned parameter tensors
849 fn clone_parameters(&self) -> HashMap<String, Tensor> {
850 self.all_named_parameters()
851 .into_iter()
852 .map(|(name, param)| (name, param.clone_data()))
853 .collect()
854 }
855
856 /// Quick diagnostic check of module health
857 ///
858 /// # Returns
859 /// * `ModuleDiagnostics` - Diagnostic information
860 fn diagnose(&self) -> crate::ModuleDiagnostics {
861 let mut issues = Vec::new();
862 let mut warnings = Vec::new();
863
864 // Check parameter health
865 for (name, param) in self.all_named_parameters() {
866 if let Ok(diag) = param.diagnose() {
867 if !diag.issues.is_empty() {
868 issues.extend(
869 diag.issues
870 .into_iter()
871 .map(|issue| format!("{}: {}", name, issue)),
872 );
873 }
874 if !diag.warnings.is_empty() {
875 warnings.extend(
876 diag.warnings
877 .into_iter()
878 .map(|warning| format!("{}: {}", name, warning)),
879 );
880 }
881 }
882 }
883
884 // Check training readiness
885 if let Err(e) = self.check_training_readiness() {
886 warnings.push(format!("Training readiness: {}", e));
887 }
888
889 crate::ModuleDiagnostics {
890 module_info: self.module_info(),
891 issues,
892 warnings,
893 parameter_diagnostics: self
894 .all_named_parameters()
895 .into_iter()
896 .filter_map(|(name, param)| param.diagnose().ok().map(|d| (name, d)))
897 .collect(),
898 }
899 }
900}
901
902/// Implementation for boxed trait objects
903impl Module for Box<dyn Module> {
904 fn forward(&self, x: &Tensor) -> Result<Tensor> {
905 (**self).forward(x)
906 }
907
908 fn parameters(&self) -> HashMap<String, crate::Parameter> {
909 (**self).parameters()
910 }
911
912 fn train(&mut self) {
913 (**self).train()
914 }
915
916 fn eval(&mut self) {
917 (**self).eval()
918 }
919
920 fn training(&self) -> bool {
921 (**self).training()
922 }
923
924 fn children(&self) -> Vec<&dyn Module> {
925 (**self).children()
926 }
927
928 fn named_children(&self) -> Vec<(String, &dyn Module)> {
929 (**self).named_children()
930 }
931
932 fn set_training(&mut self, training: bool) {
933 (**self).set_training(training)
934 }
935
936 fn to_device(&mut self, device: DeviceType) -> Result<()> {
937 (**self).to_device(device)
938 }
939}
940
941/// Implementation for mutable references to boxed trait objects
942impl Module for &mut Box<dyn Module> {
943 fn forward(&self, x: &Tensor) -> Result<Tensor> {
944 (***self).forward(x)
945 }
946
947 fn parameters(&self) -> HashMap<String, crate::Parameter> {
948 (***self).parameters()
949 }
950
951 fn train(&mut self) {
952 (***self).train()
953 }
954
955 fn eval(&mut self) {
956 (***self).eval()
957 }
958
959 fn training(&self) -> bool {
960 (***self).training()
961 }
962
963 fn children(&self) -> Vec<&dyn Module> {
964 (***self).children()
965 }
966
967 fn named_children(&self) -> Vec<(String, &dyn Module)> {
968 (***self).named_children()
969 }
970
971 fn set_training(&mut self, training: bool) {
972 (***self).set_training(training)
973 }
974
975 fn to_device(&mut self, device: DeviceType) -> Result<()> {
976 (***self).to_device(device)
977 }
978}