1use std::collections::HashMap;
34use torsh_core::error::Result;
35
36use crate::{Module, Parameter};
37
38pub trait ModuleApply {
40 fn apply<F>(&mut self, f: &F) -> Result<()>
42 where
43 F: Fn(&mut dyn Module) -> Result<()>;
44
45 fn apply_to_parameters<F>(&mut self, f: &F) -> Result<()>
47 where
48 F: Fn(&mut Parameter) -> Result<()>;
49
50 fn apply_to_modules<F>(&mut self, f: &F) -> Result<()>
52 where
53 F: Fn(&mut dyn Module) -> Result<()>;
54}
55
56impl<T: Module> ModuleApply for T {
58 fn apply<F>(&mut self, f: &F) -> Result<()>
59 where
60 F: Fn(&mut dyn Module) -> Result<()>,
61 {
62 f(self)
63 }
64
65 fn apply_to_parameters<F>(&mut self, _f: &F) -> Result<()>
66 where
67 F: Fn(&mut Parameter) -> Result<()>,
68 {
69 Ok(())
71 }
72
73 fn apply_to_modules<F>(&mut self, _f: &F) -> Result<()>
74 where
75 F: Fn(&mut dyn Module) -> Result<()>,
76 {
77 Ok(())
79 }
80}
81
82pub mod analysis {
84 use super::*;
85
86 pub fn count_parameters(module: &dyn Module) -> usize {
88 module
89 .parameters()
90 .values()
91 .map(|param| param.tensor().read().shape().numel())
92 .sum()
93 }
94
95 pub fn count_trainable_parameters(module: &dyn Module) -> usize {
97 module
98 .parameters()
99 .values()
100 .filter(|param| param.requires_grad())
101 .map(|param| param.tensor().read().shape().numel())
102 .sum()
103 }
104
105 pub fn parameter_statistics(module: &dyn Module) -> ModuleParameterStats {
107 let parameters = module.parameters();
108 let total_params = parameters
109 .values()
110 .map(|param| param.tensor().read().shape().numel())
111 .sum();
112 let trainable_params = parameters
113 .values()
114 .filter(|param| param.requires_grad())
115 .map(|param| param.tensor().read().shape().numel())
116 .sum();
117
118 ModuleParameterStats {
119 total_parameters: total_params,
120 trainable_parameters: trainable_params,
121 frozen_parameters: total_params - trainable_params,
122 parameter_count_by_layer: parameters
123 .iter()
124 .map(|(name, param)| (name.clone(), param.tensor().read().shape().numel()))
125 .collect(),
126 }
127 }
128
129 pub fn is_training(module: &dyn Module) -> bool {
131 module.training()
132 }
133
134 pub fn parameter_names(module: &dyn Module) -> Vec<String> {
136 module.parameters().keys().cloned().collect()
137 }
138
139 pub fn find_parameters_by_pattern(
141 module: &dyn Module,
142 pattern: &str,
143 ) -> HashMap<String, Parameter> {
144 module
145 .parameters()
146 .into_iter()
147 .filter(|(name, _)| name.contains(pattern))
148 .collect()
149 }
150}
151
/// Aggregate parameter counts for a module.
#[derive(Debug, Clone)]
pub struct ModuleParameterStats {
    /// Element count summed over all parameters.
    pub total_parameters: usize,
    /// Element count summed over parameters that require gradients.
    pub trainable_parameters: usize,
    /// Element count of the remaining (non-trainable) parameters.
    pub frozen_parameters: usize,
    /// Element count keyed by parameter name.
    pub parameter_count_by_layer: HashMap<String, usize>,
}

impl ModuleParameterStats {
    /// Bytes assumed per parameter element.
    ///
    /// NOTE(review): presumably 32-bit floats; other element widths are not
    /// accounted for by this estimate.
    const BYTES_PER_PARAMETER: usize = 4;

    /// Share of parameters that are trainable, in percent (0.0 to 100.0).
    ///
    /// Returns `0.0` when there are no parameters at all, avoiding a division
    /// by zero.
    pub fn trainable_percentage(&self) -> f32 {
        if self.total_parameters == 0 {
            0.0
        } else {
            (self.trainable_parameters as f32 / self.total_parameters as f32) * 100.0
        }
    }

    /// Estimated parameter memory in bytes, at 4 bytes per element.
    ///
    /// The multiplication saturates at `usize::MAX` instead of overflowing,
    /// so very large counts (notably on 32-bit targets) neither panic in debug
    /// builds nor wrap around to a misleadingly small number in release.
    pub fn memory_usage_bytes(&self) -> usize {
        self.total_parameters
            .saturating_mul(Self::BYTES_PER_PARAMETER)
    }

    /// [`memory_usage_bytes`](Self::memory_usage_bytes) expressed in mebibytes
    /// (1024 * 1024 bytes).
    pub fn memory_usage_mb(&self) -> f32 {
        self.memory_usage_bytes() as f32 / (1024.0 * 1024.0)
    }
}
185
186pub mod introspection {
188 use super::*;
189
190 pub fn print_parameter_summary(module: &dyn Module, module_name: &str) {
192 let stats = analysis::parameter_statistics(module);
193 println!("=== {} Parameter Summary ===", module_name);
194 println!("Total parameters: {}", stats.total_parameters);
195 println!(
196 "Trainable parameters: {} ({:.1}%)",
197 stats.trainable_parameters,
198 stats.trainable_percentage()
199 );
200 println!("Frozen parameters: {}", stats.frozen_parameters);
201 println!("Memory usage: {:.2} MB", stats.memory_usage_mb());
202 println!("Training mode: {}", analysis::is_training(module));
203
204 if !stats.parameter_count_by_layer.is_empty() {
205 println!("\nParameters by layer:");
206 let mut layers: Vec<_> = stats.parameter_count_by_layer.iter().collect();
207 layers.sort_by_key(|(name, _)| name.as_str());
208 for (layer, count) in layers {
209 println!(" {}: {}", layer, count);
210 }
211 }
212 println!();
213 }
214
215 pub fn health_check(module: &dyn Module) -> Vec<String> {
217 let mut issues = Vec::new();
218 let stats = analysis::parameter_statistics(module);
219
220 if stats.total_parameters == 0 {
222 issues.push("Module has no parameters".to_string());
223 }
224
225 if stats.total_parameters > 0 && stats.trainable_parameters == 0 {
227 issues.push("All parameters are frozen - module won't train".to_string());
228 }
229
230 if stats.memory_usage_bytes() > 1024 * 1024 * 1024 {
232 issues.push(format!(
233 "Large model detected: {:.1} GB",
234 stats.memory_usage_bytes() as f32 / (1024.0 * 1024.0 * 1024.0)
235 ));
236 }
237
238 issues
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::linear::Linear;

    /// Total count for a `Linear::new(10, 5, true)` layer.
    #[test]
    fn test_parameter_counting() {
        // Presumably 10 * 5 weights plus 5 bias terms.
        let layer = Linear::new(10, 5, true);
        let total = analysis::count_parameters(&layer);
        assert_eq!(total, 55);
    }

    /// A freshly built layer reports every parameter as trainable.
    #[test]
    fn test_parameter_stats() {
        // Presumably 4 * 2 weights plus 2 bias terms.
        let layer = Linear::new(4, 2, true);
        let summary = analysis::parameter_statistics(&layer);

        assert_eq!(summary.total_parameters, 10);
        assert_eq!(summary.trainable_parameters, 10);
        assert_eq!(summary.frozen_parameters, 0);
        assert_eq!(summary.trainable_percentage(), 100.0);
    }

    /// Memory estimates are derived from the total count at 4 bytes each.
    #[test]
    fn test_memory_calculation() {
        let summary = ModuleParameterStats {
            total_parameters: 1000,
            trainable_parameters: 800,
            frozen_parameters: 200,
            parameter_count_by_layer: HashMap::new(),
        };

        assert_eq!(summary.memory_usage_bytes(), 4000);
        assert_eq!(summary.memory_usage_mb(), 4000.0 / (1024.0 * 1024.0));
    }
}