// torsh_nn/core/module_ext.rs
1//! Module trait ergonomic extensions
2//!
3//! This module provides additional ergonomic helpers and utilities for the Module trait,
4//! following Rust best practices for trait extension patterns.
5
6use crate::Module;
7use torsh_core::device::DeviceType;
8use torsh_core::error::Result;
9use torsh_tensor::Tensor;
10
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14#[cfg(not(feature = "std"))]
15use hashbrown::HashMap;
16
17/// Extension trait providing additional ergonomic methods for Module
18///
19/// This trait is automatically implemented for all types that implement Module,
20/// providing additional convenience methods without requiring changes to existing code.
21///
22/// # Design Philosophy
23///
24/// This extension follows Rust's extension trait pattern to:
25/// - Add functionality without breaking backward compatibility
26/// - Keep the core Module trait focused on essential methods
27/// - Provide advanced features for users who need them
28/// - Enable fluent/builder-style APIs
29pub trait ModuleExt: Module {
30 // === Fluent API / Builder Pattern Methods ===
31
32 /// Chain forward pass with a transformation function
33 ///
34 /// This enables functional-style chaining of operations.
35 ///
36 /// # Arguments
37 /// * `input` - Input tensor
38 /// * `f` - Transformation function to apply to output
39 ///
40 /// # Returns
41 /// * `Result<Tensor>` - Transformed output
42 ///
43 /// # Example
44 /// ```ignore
45 /// let output = layer.and_then(&input, |x| x.relu())?;
46 /// ```
47 fn and_then<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
48 where
49 F: FnOnce(Tensor) -> Result<Tensor>,
50 {
51 let output = self.forward(input)?;
52 f(output)
53 }
54
55 /// Apply module and map the output with a function
56 ///
57 /// Similar to `and_then` but for non-failable transformations.
58 ///
59 /// # Arguments
60 /// * `input` - Input tensor
61 /// * `f` - Mapping function
62 ///
63 /// # Returns
64 /// * `Result<Tensor>` - Mapped output
65 fn map<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
66 where
67 F: FnOnce(Tensor) -> Tensor,
68 {
69 let output = self.forward(input)?;
70 Ok(f(output))
71 }
72
73 /// Forward pass with input transformation
74 ///
75 /// Apply a transformation to the input before forwarding.
76 ///
77 /// # Arguments
78 /// * `input` - Input tensor
79 /// * `f` - Input transformation function
80 ///
81 /// # Returns
82 /// * `Result<Tensor>` - Module output
83 fn with_input<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
84 where
85 F: FnOnce(&Tensor) -> Result<Tensor>,
86 {
87 let transformed = f(input)?;
88 self.forward(&transformed)
89 }
90
91 // === Inspection and Debugging Methods ===
92
93 /// Get human-readable summary of the module
94 ///
95 /// # Returns
96 /// * `String` - Formatted module summary
97 fn summary(&self) -> String {
98 let info = self.module_info();
99 format!(
100 "Module: {}\n\
101 Training: {}\n\
102 Parameters: {} ({} trainable)\n\
103 Memory: {:.2} MB\n\
104 Children: {}",
105 info.name,
106 info.training,
107 info.parameter_count,
108 info.trainable_parameter_count,
109 info.memory_usage_bytes as f64 / (1024.0 * 1024.0),
110 info.children_count
111 )
112 }
113
114 /// Print module summary to stdout
115 fn print_summary(&self) {
116 println!("{}", self.summary());
117 }
118
119 /// Get parameter statistics
120 ///
121 /// # Returns
122 /// * `ParameterStats` - Statistical information about parameters
123 fn parameter_stats(&self) -> ParameterStats {
124 let params = self.all_parameters();
125 let mut total_params = 0;
126 let mut trainable_params = 0;
127 let mut frozen_params = 0;
128 let mut total_memory = 0;
129
130 for param in params.values() {
131 let numel = param.numel().unwrap_or(0);
132 total_params += numel;
133 total_memory += numel * 4; // Assume f32
134
135 if param.requires_grad() {
136 trainable_params += numel;
137 } else {
138 frozen_params += numel;
139 }
140 }
141
142 ParameterStats {
143 total_parameters: total_params,
144 trainable_parameters: trainable_params,
145 frozen_parameters: frozen_params,
146 total_memory_bytes: total_memory,
147 parameter_count: params.len(),
148 }
149 }
150
151 /// Check if module has NaN or Inf in parameters
152 ///
153 /// # Returns
154 /// * `bool` - true if all parameters are finite
155 fn has_finite_parameters(&self) -> bool {
156 self.all_parameters()
157 .values()
158 .all(|p| p.is_finite().unwrap_or(false))
159 }
160
161 /// Get list of parameter names
162 ///
163 /// # Returns
164 /// * `Vec<String>` - Sorted list of parameter names
165 fn parameter_names(&self) -> Vec<String> {
166 let mut names: Vec<String> = self.all_named_parameters().keys().cloned().collect();
167 names.sort();
168 names
169 }
170
171 /// Get parameter by name
172 ///
173 /// # Arguments
174 /// * `name` - Parameter name
175 ///
176 /// # Returns
177 /// * `Option<Parameter>` - Parameter if found
178 fn get_parameter(&self, name: &str) -> Option<crate::Parameter> {
179 self.all_named_parameters().get(name).cloned()
180 }
181
182 // === Training Utilities ===
183
184 /// Freeze specific parameters by name pattern
185 ///
186 /// # Arguments
187 /// * `pattern` - String pattern to match parameter names
188 ///
189 /// # Returns
190 /// * `usize` - Number of parameters frozen
191 ///
192 /// # Note
193 /// This method currently returns a count but doesn't actually freeze parameters
194 /// because Parameter's requires_grad is immutable. This is a placeholder for
195 /// future implementation when mutable parameter access is available.
196 fn freeze_matching(&mut self, pattern: &str) -> usize {
197 let mut count = 0;
198 for (name, _param) in self.all_named_parameters() {
199 if name.contains(pattern) {
200 // TODO: Implement actual freezing when Parameter supports it
201 count += 1;
202 }
203 }
204 count
205 }
206
207 /// Unfreeze specific parameters by name pattern
208 ///
209 /// # Arguments
210 /// * `pattern` - String pattern to match parameter names
211 ///
212 /// # Returns
213 /// * `usize` - Number of parameters unfrozen
214 ///
215 /// # Note
216 /// This method currently returns a count but doesn't actually unfreeze parameters
217 /// because Parameter's requires_grad is immutable. This is a placeholder for
218 /// future implementation when mutable parameter access is available.
219 fn unfreeze_matching(&mut self, pattern: &str) -> usize {
220 let mut count = 0;
221 for (name, _param) in self.all_named_parameters() {
222 if name.contains(pattern) {
223 // TODO: Implement actual unfreezing when Parameter supports it
224 count += 1;
225 }
226 }
227 count
228 }
229
230 /// Get list of frozen parameters
231 ///
232 /// # Returns
233 /// * `Vec<String>` - Names of frozen parameters
234 fn frozen_parameters(&self) -> Vec<String> {
235 self.all_named_parameters()
236 .into_iter()
237 .filter(|(_, p)| !p.requires_grad())
238 .map(|(name, _)| name)
239 .collect()
240 }
241
242 /// Get list of trainable parameters
243 ///
244 /// # Returns
245 /// * `Vec<String>` - Names of trainable parameters
246 fn trainable_parameters(&self) -> Vec<String> {
247 self.all_named_parameters()
248 .into_iter()
249 .filter(|(_, p)| p.requires_grad())
250 .map(|(name, _)| name)
251 .collect()
252 }
253
254 // === Advanced Operations ===
255
256 /// Clone module parameters into a new state dict
257 ///
258 /// # Returns
259 /// * `HashMap<String, Tensor>` - Cloned state dictionary
260 fn clone_state_dict(&self) -> HashMap<String, Tensor> {
261 self.state_dict()
262 }
263
264 /// Apply a function to all parameters
265 ///
266 /// # Arguments
267 /// * `f` - Function to apply to each parameter
268 fn apply_to_parameters<F>(&self, mut f: F)
269 where
270 F: FnMut(&str, &crate::Parameter),
271 {
272 for (name, param) in self.all_named_parameters() {
273 f(&name, ¶m);
274 }
275 }
276
277 /// Count parameters by layer type
278 ///
279 /// # Returns
280 /// * `HashMap<String, usize>` - Parameter count per layer type
281 fn parameters_by_type(&self) -> HashMap<String, usize> {
282 let mut counts = HashMap::new();
283
284 for (name, param) in self.all_named_parameters() {
285 // Extract layer type from name (first component)
286 let layer_type = name.split('.').next().unwrap_or("unknown").to_string();
287
288 let numel = param.numel().unwrap_or(0);
289 *counts.entry(layer_type).or_insert(0) += numel;
290 }
291
292 counts
293 }
294
295 /// Validate module configuration
296 ///
297 /// Performs comprehensive validation of module state.
298 ///
299 /// # Returns
300 /// * `Result<ValidationReport>` - Validation results
301 fn validate(&self) -> Result<ValidationReport> {
302 let mut report = ValidationReport::default();
303
304 // Check for parameters
305 if !self.has_parameters() {
306 report.warnings.push("Module has no parameters".to_string());
307 }
308
309 // Check for finite parameters
310 if !self.has_finite_parameters() {
311 report
312 .errors
313 .push("Module has non-finite parameters (NaN or Inf)".to_string());
314 }
315
316 // Check memory usage
317 let memory_mb = self.memory_usage_mb();
318 if memory_mb > 1024.0 {
319 report
320 .warnings
321 .push(format!("Large memory usage: {:.2} GB", memory_mb / 1024.0));
322 }
323
324 // Check parameter count
325 let param_count = self.num_parameters();
326 if param_count > 100_000_000 {
327 report
328 .warnings
329 .push(format!("Very large model: {} parameters", param_count));
330 }
331
332 report.is_valid = report.errors.is_empty();
333 Ok(report)
334 }
335
336 /// Get device of parameters (if consistent)
337 ///
338 /// # Returns
339 /// * `Option<DeviceType>` - Device if all parameters are on same device
340 ///
341 /// # Note
342 /// Currently returns None as Parameter doesn't expose device information.
343 /// This is a placeholder for future implementation.
344 fn device(&self) -> Option<DeviceType> {
345 // TODO: Implement when Parameter exposes device information
346 // For now, assume CPU as default
347 if self.has_parameters() {
348 Some(DeviceType::Cpu)
349 } else {
350 None
351 }
352 }
353
354 /// Check if all parameters are on CPU
355 ///
356 /// # Returns
357 /// * `bool` - true if all parameters on CPU
358 fn is_cpu(&self) -> bool {
359 self.device() == Some(DeviceType::Cpu)
360 }
361
362 /// Check if all parameters are on CUDA device
363 ///
364 /// # Returns
365 /// * `bool` - true if all parameters on CUDA
366 fn is_cuda(&self) -> bool {
367 matches!(self.device(), Some(DeviceType::Cuda(_)))
368 }
369}
370
// Blanket implementation: every type implementing `Module` — including
// unsized types and trait objects, via `?Sized` — automatically gains all
// `ModuleExt` default methods with no extra code.
impl<T: Module + ?Sized> ModuleExt for T {}
373
374// === Supporting Types ===
375
/// Parameter statistics for a module
///
/// All counts are in parameter *elements* except `parameter_count`,
/// which counts distinct parameter tensors.
#[derive(Debug, Clone)]
pub struct ParameterStats {
    /// Total number of parameter elements
    pub total_parameters: usize,
    /// Number of trainable parameter elements
    pub trainable_parameters: usize,
    /// Number of frozen parameter elements
    pub frozen_parameters: usize,
    /// Total memory usage in bytes
    pub total_memory_bytes: usize,
    /// Number of distinct parameters
    pub parameter_count: usize,
}

impl ParameterStats {
    /// Memory usage expressed in megabytes (MiB).
    pub fn memory_mb(&self) -> f64 {
        const BYTES_PER_MB: f64 = 1024.0 * 1024.0;
        self.total_memory_bytes as f64 / BYTES_PER_MB
    }

    /// Memory usage expressed in gigabytes (GiB).
    pub fn memory_gb(&self) -> f64 {
        self.memory_mb() / 1024.0
    }

    /// Percentage of parameter elements that are trainable (0.0 for an
    /// empty module, avoiding a division by zero).
    pub fn trainable_percentage(&self) -> f64 {
        match self.total_parameters {
            0 => 0.0,
            total => (self.trainable_parameters as f64 / total as f64) * 100.0,
        }
    }
}
411
/// Validation report for a module
///
/// Produced by `ModuleExt::validate`; `Default` yields an empty report
/// with `is_valid == false` (the validator sets the flag explicitly).
#[derive(Debug, Clone, Default)]
pub struct ValidationReport {
    /// Whether the module is valid
    pub is_valid: bool,
    /// List of errors found
    pub errors: Vec<String>,
    /// List of warnings
    pub warnings: Vec<String>,
}

impl ValidationReport {
    /// True when the report is marked valid and carries no errors.
    pub fn passed(&self) -> bool {
        self.is_valid && self.errors.is_empty()
    }

    /// Total number of recorded issues (errors plus warnings).
    pub fn issue_count(&self) -> usize {
        self.errors.len() + self.warnings.len()
    }

    /// Render the report as a human-readable multi-line string.
    ///
    /// Sections for errors and warnings are emitted only when non-empty.
    pub fn format(&self) -> String {
        let status = if self.is_valid { "PASSED" } else { "FAILED" };
        let mut out = format!("Validation: {}\n", status);

        // Emit each non-empty section as a header followed by its items.
        for (header, items) in [("\nErrors:\n", &self.errors), ("\nWarnings:\n", &self.warnings)] {
            if !items.is_empty() {
                out.push_str(header);
                for item in items {
                    out.push_str(&format!("  - {}\n", item));
                }
            }
        }

        out
    }
}
460
#[cfg(test)]
mod tests {

    // Tests would go here - skipped for brevity.
    // TODO(review): add coverage for ParameterStats math, ValidationReport
    // formatting, and the ModuleExt combinators once a mock Module exists.
}