// torsh_nn/utils/mod.rs
//! Module utilities and helper functions
//!
//! This module provides utility traits and helper functions for working with
//! neural network modules, including module application patterns, parameter analysis,
//! and model introspection utilities.
//!
//! # Examples
//!
//! ## Parameter Analysis
//! ```rust,ignore
//! use torsh_nn::layers::linear::Linear;
//! use torsh_nn::utils::analysis;
//!
//! let model = Linear::new(784, 10, true);
//! let param_count = analysis::count_parameters(&model);
//! let trainable_count = analysis::count_trainable_parameters(&model);
//! println!("Total parameters: {}, Trainable: {}", param_count, trainable_count);
//! ```
//!
//! ## Model Introspection
//! ```rust,ignore
//! use torsh_nn::layers::linear::Linear;
//! use torsh_nn::utils::introspection;
//!
//! let model = Linear::new(784, 128, true);
//! introspection::print_parameter_summary(&model, "MyModel");
//! let issues = introspection::health_check(&model);
//! if !issues.is_empty() {
//!     println!("Model issues detected: {:?}", issues);
//! }
//! ```

33use std::collections::HashMap;
34use torsh_core::error::Result;
35
36use crate::{Module, Parameter};
37
/// Extension trait for applying functions to modules (separate to maintain dyn compatibility)
///
/// These methods take generic closure parameters; keeping them on `Module`
/// itself would make the trait non-object-safe, so `&dyn Module` could no
/// longer be used. A blanket impl below covers every `Module` type.
pub trait ModuleApply {
    /// Apply a function to all submodules recursively
    ///
    /// NOTE(review): `apply` and `apply_to_modules` have identical signatures
    /// and near-identical docs — the intended semantic difference is not clear
    /// from this file; confirm and document the distinction (or merge them).
    fn apply<F>(&mut self, f: &F) -> Result<()>
    where
        F: Fn(&mut dyn Module) -> Result<()>;

    /// Apply function to all parameters recursively
    fn apply_to_parameters<F>(&mut self, f: &F) -> Result<()>
    where
        F: Fn(&mut Parameter) -> Result<()>;

    /// Apply function to all modules recursively
    fn apply_to_modules<F>(&mut self, f: &F) -> Result<()>
    where
        F: Fn(&mut dyn Module) -> Result<()>;
}
55
/// Blanket implementation for all modules
///
/// NOTE(review): because this is a blanket impl (`impl<T: Module> ...`),
/// Rust's coherence rules (E0119) forbid any concrete `Module` type from
/// providing its own `ModuleApply` impl. The "override in implementing
/// types" comments below are therefore unachievable as written — real
/// recursion over parameters/submodules would have to be exposed through
/// the `Module` trait itself (e.g. as object-safe visitor methods) and
/// dispatched to from here.
impl<T: Module> ModuleApply for T {
    fn apply<F>(&mut self, f: &F) -> Result<()>
    where
        F: Fn(&mut dyn Module) -> Result<()>,
    {
        // Applies `f` to `self` only; no recursion into submodules occurs here.
        f(self)
    }

    fn apply_to_parameters<F>(&mut self, _f: &F) -> Result<()>
    where
        F: Fn(&mut Parameter) -> Result<()>,
    {
        // Default implementation does nothing - override in implementing types
        // (see NOTE above: overriding is not actually possible with a blanket impl)
        Ok(())
    }

    fn apply_to_modules<F>(&mut self, _f: &F) -> Result<()>
    where
        F: Fn(&mut dyn Module) -> Result<()>,
    {
        // Default implementation does nothing - override in implementing types
        // (see NOTE above: overriding is not actually possible with a blanket impl)
        Ok(())
    }
}
81
82/// Utility functions for neural network analysis and debugging
83pub mod analysis {
84    use super::*;
85
86    /// Count the total number of parameters in a module
87    pub fn count_parameters(module: &dyn Module) -> usize {
88        module
89            .parameters()
90            .values()
91            .map(|param| param.tensor().read().shape().numel())
92            .sum()
93    }
94
95    /// Count only trainable parameters (parameters with requires_grad = true)
96    pub fn count_trainable_parameters(module: &dyn Module) -> usize {
97        module
98            .parameters()
99            .values()
100            .filter(|param| param.requires_grad())
101            .map(|param| param.tensor().read().shape().numel())
102            .sum()
103    }
104
105    /// Get detailed parameter statistics for a module
106    pub fn parameter_statistics(module: &dyn Module) -> ModuleParameterStats {
107        let parameters = module.parameters();
108        let total_params = parameters
109            .values()
110            .map(|param| param.tensor().read().shape().numel())
111            .sum();
112        let trainable_params = parameters
113            .values()
114            .filter(|param| param.requires_grad())
115            .map(|param| param.tensor().read().shape().numel())
116            .sum();
117
118        ModuleParameterStats {
119            total_parameters: total_params,
120            trainable_parameters: trainable_params,
121            frozen_parameters: total_params - trainable_params,
122            parameter_count_by_layer: parameters
123                .iter()
124                .map(|(name, param)| (name.clone(), param.tensor().read().shape().numel()))
125                .collect(),
126        }
127    }
128
129    /// Check if a module is in training mode
130    pub fn is_training(module: &dyn Module) -> bool {
131        module.training()
132    }
133
134    /// Get the names of all parameters in a module
135    pub fn parameter_names(module: &dyn Module) -> Vec<String> {
136        module.parameters().keys().cloned().collect()
137    }
138
139    /// Find parameters by name pattern (simple substring matching)
140    pub fn find_parameters_by_pattern(
141        module: &dyn Module,
142        pattern: &str,
143    ) -> HashMap<String, Parameter> {
144        module
145            .parameters()
146            .into_iter()
147            .filter(|(name, _)| name.contains(pattern))
148            .collect()
149    }
150}
151
/// Parameter statistics structure for module analysis
#[derive(Debug, Clone)]
pub struct ModuleParameterStats {
    /// Total number of parameters (trainable + frozen)
    pub total_parameters: usize,
    /// Number of trainable parameters
    pub trainable_parameters: usize,
    /// Number of frozen parameters
    pub frozen_parameters: usize,
    /// Parameter count by layer name
    pub parameter_count_by_layer: HashMap<String, usize>,
}

impl ModuleParameterStats {
    /// Get the percentage of parameters that are trainable
    ///
    /// Returns 0.0 for a module with no parameters at all (avoids NaN
    /// from a 0/0 division).
    pub fn trainable_percentage(&self) -> f32 {
        match self.total_parameters {
            0 => 0.0,
            total => (self.trainable_parameters as f32 / total as f32) * 100.0,
        }
    }

    /// Get memory usage estimate in bytes (assuming f32 parameters)
    pub fn memory_usage_bytes(&self) -> usize {
        // Estimate only: every parameter is assumed to be a 4-byte f32.
        self.total_parameters * std::mem::size_of::<f32>()
    }

    /// Get memory usage in MB
    pub fn memory_usage_mb(&self) -> f32 {
        const BYTES_PER_MB: f32 = 1024.0 * 1024.0;
        self.memory_usage_bytes() as f32 / BYTES_PER_MB
    }
}
185
186/// Utility functions for model introspection and debugging
187pub mod introspection {
188    use super::*;
189
190    /// Print a summary of module parameters
191    pub fn print_parameter_summary(module: &dyn Module, module_name: &str) {
192        let stats = analysis::parameter_statistics(module);
193        println!("=== {} Parameter Summary ===", module_name);
194        println!("Total parameters: {}", stats.total_parameters);
195        println!(
196            "Trainable parameters: {} ({:.1}%)",
197            stats.trainable_parameters,
198            stats.trainable_percentage()
199        );
200        println!("Frozen parameters: {}", stats.frozen_parameters);
201        println!("Memory usage: {:.2} MB", stats.memory_usage_mb());
202        println!("Training mode: {}", analysis::is_training(module));
203
204        if !stats.parameter_count_by_layer.is_empty() {
205            println!("\nParameters by layer:");
206            let mut layers: Vec<_> = stats.parameter_count_by_layer.iter().collect();
207            layers.sort_by_key(|(name, _)| name.as_str());
208            for (layer, count) in layers {
209                println!("  {}: {}", layer, count);
210            }
211        }
212        println!();
213    }
214
215    /// Check for common issues in module configuration
216    pub fn health_check(module: &dyn Module) -> Vec<String> {
217        let mut issues = Vec::new();
218        let stats = analysis::parameter_statistics(module);
219
220        // Check for modules with no parameters
221        if stats.total_parameters == 0 {
222            issues.push("Module has no parameters".to_string());
223        }
224
225        // Check for modules with all frozen parameters
226        if stats.total_parameters > 0 && stats.trainable_parameters == 0 {
227            issues.push("All parameters are frozen - module won't train".to_string());
228        }
229
230        // Check for very large models (>1GB)
231        if stats.memory_usage_bytes() > 1024 * 1024 * 1024 {
232            issues.push(format!(
233                "Large model detected: {:.1} GB",
234                stats.memory_usage_bytes() as f32 / (1024.0 * 1024.0 * 1024.0)
235            ));
236        }
237
238        issues
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::linear::Linear;

    #[test]
    fn test_parameter_counting() {
        // weight 10x5 plus bias 5 => 55 elements
        let layer = Linear::new(10, 5, true);
        assert_eq!(analysis::count_parameters(&layer), 55);
    }

    #[test]
    fn test_parameter_stats() {
        // weight 4x2 plus bias 2 => 10 elements, all trainable
        let layer = Linear::new(4, 2, true);
        let stats = analysis::parameter_statistics(&layer);
        assert_eq!(stats.total_parameters, 10);
        assert_eq!(stats.trainable_parameters, 10);
        assert_eq!(stats.frozen_parameters, 0);
        assert_eq!(stats.trainable_percentage(), 100.0);
    }

    #[test]
    fn test_memory_calculation() {
        let stats = ModuleParameterStats {
            total_parameters: 1000,
            trainable_parameters: 800,
            frozen_parameters: 200,
            parameter_count_by_layer: HashMap::new(),
        };
        // 1000 assumed-f32 values at 4 bytes apiece
        assert_eq!(stats.memory_usage_bytes(), 4000);
        assert_eq!(stats.memory_usage_mb(), 4000.0 / (1024.0 * 1024.0));
    }
}