// torsh_nn/core/module_ext.rs
//! Module trait ergonomic extensions
//!
//! This module provides additional ergonomic helpers and utilities for the Module trait,
//! following Rust best practices for trait extension patterns.

6use crate::Module;
7use torsh_core::device::DeviceType;
8use torsh_core::error::Result;
9use torsh_tensor::Tensor;
10
11#[cfg(feature = "std")]
12use std::collections::HashMap;
13
14#[cfg(not(feature = "std"))]
15use hashbrown::HashMap;
16
17/// Extension trait providing additional ergonomic methods for Module
18///
19/// This trait is automatically implemented for all types that implement Module,
20/// providing additional convenience methods without requiring changes to existing code.
21///
22/// # Design Philosophy
23///
24/// This extension follows Rust's extension trait pattern to:
25/// - Add functionality without breaking backward compatibility
26/// - Keep the core Module trait focused on essential methods
27/// - Provide advanced features for users who need them
28/// - Enable fluent/builder-style APIs
29pub trait ModuleExt: Module {
30    // === Fluent API / Builder Pattern Methods ===
31
32    /// Chain forward pass with a transformation function
33    ///
34    /// This enables functional-style chaining of operations.
35    ///
36    /// # Arguments
37    /// * `input` - Input tensor
38    /// * `f` - Transformation function to apply to output
39    ///
40    /// # Returns
41    /// * `Result<Tensor>` - Transformed output
42    ///
43    /// # Example
44    /// ```ignore
45    /// let output = layer.and_then(&input, |x| x.relu())?;
46    /// ```
47    fn and_then<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
48    where
49        F: FnOnce(Tensor) -> Result<Tensor>,
50    {
51        let output = self.forward(input)?;
52        f(output)
53    }
54
55    /// Apply module and map the output with a function
56    ///
57    /// Similar to `and_then` but for non-failable transformations.
58    ///
59    /// # Arguments
60    /// * `input` - Input tensor
61    /// * `f` - Mapping function
62    ///
63    /// # Returns
64    /// * `Result<Tensor>` - Mapped output
65    fn map<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
66    where
67        F: FnOnce(Tensor) -> Tensor,
68    {
69        let output = self.forward(input)?;
70        Ok(f(output))
71    }
72
73    /// Forward pass with input transformation
74    ///
75    /// Apply a transformation to the input before forwarding.
76    ///
77    /// # Arguments
78    /// * `input` - Input tensor
79    /// * `f` - Input transformation function
80    ///
81    /// # Returns
82    /// * `Result<Tensor>` - Module output
83    fn with_input<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
84    where
85        F: FnOnce(&Tensor) -> Result<Tensor>,
86    {
87        let transformed = f(input)?;
88        self.forward(&transformed)
89    }
90
91    // === Inspection and Debugging Methods ===
92
93    /// Get human-readable summary of the module
94    ///
95    /// # Returns
96    /// * `String` - Formatted module summary
97    fn summary(&self) -> String {
98        let info = self.module_info();
99        format!(
100            "Module: {}\n\
101             Training: {}\n\
102             Parameters: {} ({} trainable)\n\
103             Memory: {:.2} MB\n\
104             Children: {}",
105            info.name,
106            info.training,
107            info.parameter_count,
108            info.trainable_parameter_count,
109            info.memory_usage_bytes as f64 / (1024.0 * 1024.0),
110            info.children_count
111        )
112    }
113
114    /// Print module summary to stdout
115    fn print_summary(&self) {
116        println!("{}", self.summary());
117    }
118
119    /// Get parameter statistics
120    ///
121    /// # Returns
122    /// * `ParameterStats` - Statistical information about parameters
123    fn parameter_stats(&self) -> ParameterStats {
124        let params = self.all_parameters();
125        let mut total_params = 0;
126        let mut trainable_params = 0;
127        let mut frozen_params = 0;
128        let mut total_memory = 0;
129
130        for param in params.values() {
131            let numel = param.numel().unwrap_or(0);
132            total_params += numel;
133            total_memory += numel * 4; // Assume f32
134
135            if param.requires_grad() {
136                trainable_params += numel;
137            } else {
138                frozen_params += numel;
139            }
140        }
141
142        ParameterStats {
143            total_parameters: total_params,
144            trainable_parameters: trainable_params,
145            frozen_parameters: frozen_params,
146            total_memory_bytes: total_memory,
147            parameter_count: params.len(),
148        }
149    }
150
151    /// Check if module has NaN or Inf in parameters
152    ///
153    /// # Returns
154    /// * `bool` - true if all parameters are finite
155    fn has_finite_parameters(&self) -> bool {
156        self.all_parameters()
157            .values()
158            .all(|p| p.is_finite().unwrap_or(false))
159    }
160
161    /// Get list of parameter names
162    ///
163    /// # Returns
164    /// * `Vec<String>` - Sorted list of parameter names
165    fn parameter_names(&self) -> Vec<String> {
166        let mut names: Vec<String> = self.all_named_parameters().keys().cloned().collect();
167        names.sort();
168        names
169    }
170
171    /// Get parameter by name
172    ///
173    /// # Arguments
174    /// * `name` - Parameter name
175    ///
176    /// # Returns
177    /// * `Option<Parameter>` - Parameter if found
178    fn get_parameter(&self, name: &str) -> Option<crate::Parameter> {
179        self.all_named_parameters().get(name).cloned()
180    }
181
182    // === Training Utilities ===
183
184    /// Freeze specific parameters by name pattern
185    ///
186    /// # Arguments
187    /// * `pattern` - String pattern to match parameter names
188    ///
189    /// # Returns
190    /// * `usize` - Number of parameters frozen
191    ///
192    /// # Note
193    /// This method currently returns a count but doesn't actually freeze parameters
194    /// because Parameter's requires_grad is immutable. This is a placeholder for
195    /// future implementation when mutable parameter access is available.
196    fn freeze_matching(&mut self, pattern: &str) -> usize {
197        let mut count = 0;
198        for (name, _param) in self.all_named_parameters() {
199            if name.contains(pattern) {
200                // TODO: Implement actual freezing when Parameter supports it
201                count += 1;
202            }
203        }
204        count
205    }
206
207    /// Unfreeze specific parameters by name pattern
208    ///
209    /// # Arguments
210    /// * `pattern` - String pattern to match parameter names
211    ///
212    /// # Returns
213    /// * `usize` - Number of parameters unfrozen
214    ///
215    /// # Note
216    /// This method currently returns a count but doesn't actually unfreeze parameters
217    /// because Parameter's requires_grad is immutable. This is a placeholder for
218    /// future implementation when mutable parameter access is available.
219    fn unfreeze_matching(&mut self, pattern: &str) -> usize {
220        let mut count = 0;
221        for (name, _param) in self.all_named_parameters() {
222            if name.contains(pattern) {
223                // TODO: Implement actual unfreezing when Parameter supports it
224                count += 1;
225            }
226        }
227        count
228    }
229
230    /// Get list of frozen parameters
231    ///
232    /// # Returns
233    /// * `Vec<String>` - Names of frozen parameters
234    fn frozen_parameters(&self) -> Vec<String> {
235        self.all_named_parameters()
236            .into_iter()
237            .filter(|(_, p)| !p.requires_grad())
238            .map(|(name, _)| name)
239            .collect()
240    }
241
242    /// Get list of trainable parameters
243    ///
244    /// # Returns
245    /// * `Vec<String>` - Names of trainable parameters
246    fn trainable_parameters(&self) -> Vec<String> {
247        self.all_named_parameters()
248            .into_iter()
249            .filter(|(_, p)| p.requires_grad())
250            .map(|(name, _)| name)
251            .collect()
252    }
253
254    // === Advanced Operations ===
255
256    /// Clone module parameters into a new state dict
257    ///
258    /// # Returns
259    /// * `HashMap<String, Tensor>` - Cloned state dictionary
260    fn clone_state_dict(&self) -> HashMap<String, Tensor> {
261        self.state_dict()
262    }
263
264    /// Apply a function to all parameters
265    ///
266    /// # Arguments
267    /// * `f` - Function to apply to each parameter
268    fn apply_to_parameters<F>(&self, mut f: F)
269    where
270        F: FnMut(&str, &crate::Parameter),
271    {
272        for (name, param) in self.all_named_parameters() {
273            f(&name, &param);
274        }
275    }
276
277    /// Count parameters by layer type
278    ///
279    /// # Returns
280    /// * `HashMap<String, usize>` - Parameter count per layer type
281    fn parameters_by_type(&self) -> HashMap<String, usize> {
282        let mut counts = HashMap::new();
283
284        for (name, param) in self.all_named_parameters() {
285            // Extract layer type from name (first component)
286            let layer_type = name.split('.').next().unwrap_or("unknown").to_string();
287
288            let numel = param.numel().unwrap_or(0);
289            *counts.entry(layer_type).or_insert(0) += numel;
290        }
291
292        counts
293    }
294
295    /// Validate module configuration
296    ///
297    /// Performs comprehensive validation of module state.
298    ///
299    /// # Returns
300    /// * `Result<ValidationReport>` - Validation results
301    fn validate(&self) -> Result<ValidationReport> {
302        let mut report = ValidationReport::default();
303
304        // Check for parameters
305        if !self.has_parameters() {
306            report.warnings.push("Module has no parameters".to_string());
307        }
308
309        // Check for finite parameters
310        if !self.has_finite_parameters() {
311            report
312                .errors
313                .push("Module has non-finite parameters (NaN or Inf)".to_string());
314        }
315
316        // Check memory usage
317        let memory_mb = self.memory_usage_mb();
318        if memory_mb > 1024.0 {
319            report
320                .warnings
321                .push(format!("Large memory usage: {:.2} GB", memory_mb / 1024.0));
322        }
323
324        // Check parameter count
325        let param_count = self.num_parameters();
326        if param_count > 100_000_000 {
327            report
328                .warnings
329                .push(format!("Very large model: {} parameters", param_count));
330        }
331
332        report.is_valid = report.errors.is_empty();
333        Ok(report)
334    }
335
336    /// Get device of parameters (if consistent)
337    ///
338    /// # Returns
339    /// * `Option<DeviceType>` - Device if all parameters are on same device
340    ///
341    /// # Note
342    /// Currently returns None as Parameter doesn't expose device information.
343    /// This is a placeholder for future implementation.
344    fn device(&self) -> Option<DeviceType> {
345        // TODO: Implement when Parameter exposes device information
346        // For now, assume CPU as default
347        if self.has_parameters() {
348            Some(DeviceType::Cpu)
349        } else {
350            None
351        }
352    }
353
354    /// Check if all parameters are on CPU
355    ///
356    /// # Returns
357    /// * `bool` - true if all parameters on CPU
358    fn is_cpu(&self) -> bool {
359        self.device() == Some(DeviceType::Cpu)
360    }
361
362    /// Check if all parameters are on CUDA device
363    ///
364    /// # Returns
365    /// * `bool` - true if all parameters on CUDA
366    fn is_cuda(&self) -> bool {
367        matches!(self.device(), Some(DeviceType::Cuda(_)))
368    }
369}
370
// Blanket implementation: every type implementing `Module` — including unsized
// trait-object types, hence the `?Sized` bound — automatically gains all
// `ModuleExt` convenience methods without any per-type opt-in.
impl<T: Module + ?Sized> ModuleExt for T {}
373
374// === Supporting Types ===
375
/// Parameter statistics for a module
///
/// All element counts are numbers of scalar elements (not distinct tensors);
/// `parameter_count` is the number of distinct parameter tensors.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParameterStats {
    /// Total number of parameter elements
    pub total_parameters: usize,
    /// Number of trainable parameter elements
    pub trainable_parameters: usize,
    /// Number of frozen parameter elements
    pub frozen_parameters: usize,
    /// Total memory usage in bytes
    pub total_memory_bytes: usize,
    /// Number of distinct parameters
    pub parameter_count: usize,
}

impl ParameterStats {
    /// Get memory usage in megabytes (1 MB = 1024 * 1024 bytes)
    pub fn memory_mb(&self) -> f64 {
        self.total_memory_bytes as f64 / (1024.0 * 1024.0)
    }

    /// Get memory usage in gigabytes (1 GB = 1024 MB)
    pub fn memory_gb(&self) -> f64 {
        self.memory_mb() / 1024.0
    }

    /// Get percentage of parameters that are trainable, in `[0.0, 100.0]`
    ///
    /// Returns 0.0 for a module with no parameters (avoids division by zero).
    pub fn trainable_percentage(&self) -> f64 {
        if self.total_parameters == 0 {
            0.0
        } else {
            (self.trainable_parameters as f64 / self.total_parameters as f64) * 100.0
        }
    }
}
411
/// Validation report for a module
///
/// Produced by `ModuleExt::validate`. Errors indicate definite problems
/// (e.g. non-finite parameters); warnings are advisory only.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ValidationReport {
    /// Whether the module is valid (no errors at validation time)
    pub is_valid: bool,
    /// List of errors found
    pub errors: Vec<String>,
    /// List of warnings
    pub warnings: Vec<String>,
}

impl ValidationReport {
    /// Check if validation passed without errors
    ///
    /// Requires both the `is_valid` flag and an empty error list, so a
    /// default (never-validated) report does not count as passed.
    pub fn passed(&self) -> bool {
        self.is_valid && self.errors.is_empty()
    }

    /// Get total number of issues (errors + warnings)
    pub fn issue_count(&self) -> usize {
        self.errors.len() + self.warnings.len()
    }

    /// Format as human-readable string
    ///
    /// Produces a "Validation: PASSED/FAILED" header followed by optional
    /// "Errors:" and "Warnings:" sections with one bulleted line per issue.
    pub fn format(&self) -> String {
        let mut result = String::new();

        result.push_str(&format!(
            "Validation: {}\n",
            if self.is_valid { "PASSED" } else { "FAILED" }
        ));

        if !self.errors.is_empty() {
            result.push_str("\nErrors:\n");
            for error in &self.errors {
                result.push_str(&format!("  - {}\n", error));
            }
        }

        if !self.warnings.is_empty() {
            result.push_str("\nWarnings:\n");
            for warning in &self.warnings {
                result.push_str(&format!("  - {}\n", warning));
            }
        }

        result
    }
}

// User-facing text belongs on `Display`; delegate to the existing `format()`
// so `to_string()` and `{}` formatting work. `core::fmt` keeps this
// no_std-compatible, matching the crate's hashbrown fallback.
impl core::fmt::Display for ValidationReport {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(&self.format())
    }
}
460
#[cfg(test)]
mod tests {

    // TODO: add unit tests covering the ModuleExt helpers (with a mock Module),
    // ParameterStats arithmetic, and ValidationReport formatting.
}