//! Parameter management extensions and utilities
//!
//! This module provides enhanced parameter management capabilities including:
//! - Parameter groups for differential learning rates
//! - Parameter constraints and regularization
//! - Advanced parameter inspection utilities
//! - Parameter transformation utilities

use super::Parameter;
use torsh_core::error::{Result, TorshError};

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use hashbrown::HashMap;

/// Parameter group for organizing parameters with shared hyperparameters
///
/// This is useful for implementing techniques like:
/// - Differential learning rates
/// - Layer-wise learning rate decay
/// - Parameter-specific weight decay
/// - Grouped parameter optimization
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Name of the parameter group
    pub name: String,
    /// Parameters in this group
    pub parameters: Vec<Parameter>,
    /// Learning rate multiplier for this group (1.0 = use the base rate)
    pub lr_multiplier: f32,
    /// Weight decay for this group (0.0 disables decay)
    pub weight_decay: f32,
    /// Whether to apply gradient clipping
    pub clip_gradients: bool,
    /// Maximum gradient norm for clipping (only meaningful when
    /// `clip_gradients` is true)
    pub max_grad_norm: f32,
}

41impl ParameterGroup {
42    /// Create a new parameter group
43    ///
44    /// # Arguments
45    /// * `name` - Group name
46    /// * `parameters` - Parameters in this group
47    ///
48    /// # Returns
49    /// * `ParameterGroup` - New parameter group with default settings
50    pub fn new(name: String, parameters: Vec<Parameter>) -> Self {
51        Self {
52            name,
53            parameters,
54            lr_multiplier: 1.0,
55            weight_decay: 0.0,
56            clip_gradients: false,
57            max_grad_norm: 1.0,
58        }
59    }
60
61    /// Set learning rate multiplier (builder pattern)
62    pub fn with_lr_multiplier(mut self, multiplier: f32) -> Self {
63        self.lr_multiplier = multiplier;
64        self
65    }
66
67    /// Set weight decay (builder pattern)
68    pub fn with_weight_decay(mut self, decay: f32) -> Self {
69        self.weight_decay = decay;
70        self
71    }
72
73    /// Enable gradient clipping (builder pattern)
74    pub fn with_gradient_clipping(mut self, max_norm: f32) -> Self {
75        self.clip_gradients = true;
76        self.max_grad_norm = max_norm;
77        self
78    }
79
80    /// Get total number of parameters
81    pub fn num_parameters(&self) -> usize {
82        self.parameters.iter().map(|p| p.numel().unwrap_or(0)).sum()
83    }
84
85    /// Get parameter count
86    pub fn parameter_count(&self) -> usize {
87        self.parameters.len()
88    }
89}
90
/// Parameter constraint for enforcing parameter properties
///
/// Constraints can be applied after parameter updates to ensure
/// parameters stay within valid ranges or satisfy certain properties.
///
/// NOTE(review): `ParameterConstraint::apply` is currently a placeholder —
/// every variant returns `Ok(())` without mutating data (see the TODOs in
/// the implementation).
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
    /// Clamp parameters to the inclusive range `[min, max]`
    ClampRange { min: f32, max: f32 },
    /// Ensure parameters are non-negative
    NonNegative,
    /// Normalize parameters (L2 norm = 1)
    UnitNorm,
    /// Ensure parameters are in [0, 1]
    Probability,
    /// Custom constraint, identified by `name`; applied by user code
    Custom { name: String },
}

109impl ParameterConstraint {
110    /// Apply constraint to a parameter
111    ///
112    /// # Arguments
113    /// * `parameter` - Parameter to constrain
114    ///
115    /// # Returns
116    /// * `Result<()>` - Success or error
117    pub fn apply(&self, parameter: &Parameter) -> Result<()> {
118        let tensor = parameter.tensor();
119        let _data = tensor.write();
120
121        match self {
122            ParameterConstraint::ClampRange { min, max } => {
123                // Would clamp values to [min, max]
124                let _ = (min, max);
125                // TODO: Implement when tensor supports clamp
126                Ok(())
127            }
128            ParameterConstraint::NonNegative => {
129                // Would set negative values to 0
130                // TODO: Implement when tensor supports element-wise operations
131                Ok(())
132            }
133            ParameterConstraint::UnitNorm => {
134                // Would normalize to unit norm
135                // TODO: Implement when tensor supports normalization
136                Ok(())
137            }
138            ParameterConstraint::Probability => {
139                // Would clamp to [0, 1] and normalize
140                // TODO: Implement when tensor supports operations
141                Ok(())
142            }
143            ParameterConstraint::Custom { name: _ } => {
144                // Custom constraints would be implemented by users
145                Ok(())
146            }
147        }
148    }
149
150    /// Get constraint name
151    pub fn name(&self) -> &str {
152        match self {
153            ParameterConstraint::ClampRange { .. } => "ClampRange",
154            ParameterConstraint::NonNegative => "NonNegative",
155            ParameterConstraint::UnitNorm => "UnitNorm",
156            ParameterConstraint::Probability => "Probability",
157            ParameterConstraint::Custom { name } => name,
158        }
159    }
160}
161
/// Parameter statistics and analysis
///
/// Produced by [`ParameterExt::analyze`]; all statistics are computed over
/// the flattened f32 data of a single parameter.
#[derive(Debug, Clone)]
pub struct ParameterAnalysis {
    /// Mean of parameter values
    pub mean: f32,
    /// Standard deviation of parameter values (population std: divides by N)
    pub std: f32,
    /// Minimum value
    pub min: f32,
    /// Maximum value
    pub max: f32,
    /// Number of elements
    pub numel: usize,
    /// Percentage of exactly-zero values, in the range [0, 100]
    pub sparsity: f32,
    /// Has NaN values
    pub has_nan: bool,
    /// Has Inf values
    pub has_inf: bool,
}

/// Extension trait for Parameter with advanced utilities
pub trait ParameterExt {
    /// Analyze parameter statistics
    ///
    /// # Returns
    /// * `Result<ParameterAnalysis>` - Statistical analysis of parameter
    ///
    /// # Errors
    /// Fails for empty parameters or when tensor data cannot be read.
    fn analyze(&self) -> Result<ParameterAnalysis>;

    /// Check if parameter values are finite
    ///
    /// # Returns
    /// * `Result<bool>` - true if all values are finite (not NaN or Inf)
    fn is_finite(&self) -> Result<bool>;

    /// Get parameter L2 norm (square root of the sum of squares)
    ///
    /// # Returns
    /// * `Result<f32>` - L2 norm of parameter
    fn norm(&self) -> Result<f32>;

    /// Get parameter L1 norm (sum of absolute values)
    ///
    /// # Returns
    /// * `Result<f32>` - L1 norm of parameter
    fn l1_norm(&self) -> Result<f32>;

    /// Compute parameter gradient norm (when available)
    ///
    /// NOTE(review): gradient tracking is not wired up yet; the provided
    /// implementation returns 0.0.
    ///
    /// # Returns
    /// * `Result<f32>` - Gradient norm
    fn grad_norm(&self) -> Result<f32>;

    /// Check if parameter has gradient (when available)
    ///
    /// NOTE(review): always false in the provided implementation until
    /// gradient support lands.
    ///
    /// # Returns
    /// * `bool` - true if gradient is available
    fn has_grad(&self) -> bool;

    /// Get parameter as read-only data vector (an owned snapshot)
    ///
    /// # Returns
    /// * `Result<Vec<f32>>` - Parameter data
    fn to_vec(&self) -> Result<Vec<f32>>;

    /// Get parameter dtype name
    ///
    /// # Returns
    /// * `&str` - Data type name
    fn dtype_name(&self) -> &str;

    /// Get memory usage of the data buffer in bytes
    ///
    /// # Returns
    /// * `usize` - Memory usage in bytes
    fn memory_bytes(&self) -> usize;

    /// Clone parameter with new requires_grad setting
    ///
    /// # Arguments
    /// * `requires_grad` - New requires_grad setting
    ///
    /// # Returns
    /// * `Parameter` - Cloned parameter
    fn clone_with_grad(&self, requires_grad: bool) -> Parameter;
}

249impl ParameterExt for Parameter {
250    fn analyze(&self) -> Result<ParameterAnalysis> {
251        let tensor = self.tensor();
252        let data_guard = tensor.read();
253        let data = data_guard.to_vec()?;
254
255        let numel = data.len();
256        if numel == 0 {
257            return Err(TorshError::InvalidArgument(
258                "Cannot analyze empty parameter".to_string(),
259            ));
260        }
261
262        let sum: f32 = data.iter().sum();
263        let mean = sum / numel as f32;
264
265        let variance: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / numel as f32;
266        let std = variance.sqrt();
267
268        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270
271        let zero_count = data.iter().filter(|&&x| x == 0.0).count();
272        let sparsity = (zero_count as f32 / numel as f32) * 100.0;
273
274        let has_nan = data.iter().any(|&x| x.is_nan());
275        let has_inf = data.iter().any(|&x| x.is_infinite());
276
277        Ok(ParameterAnalysis {
278            mean,
279            std,
280            min,
281            max,
282            numel,
283            sparsity,
284            has_nan,
285            has_inf,
286        })
287    }
288
289    fn is_finite(&self) -> Result<bool> {
290        let tensor = self.tensor();
291        let data = tensor.read().to_vec()?;
292        Ok(data.iter().all(|&x| x.is_finite()))
293    }
294
295    fn norm(&self) -> Result<f32> {
296        let tensor = self.tensor();
297        let data = tensor.read().to_vec()?;
298        let sum_sq: f32 = data.iter().map(|&x| x * x).sum();
299        Ok(sum_sq.sqrt())
300    }
301
302    fn l1_norm(&self) -> Result<f32> {
303        let tensor = self.tensor();
304        let data = tensor.read().to_vec()?;
305        Ok(data.iter().map(|&x| x.abs()).sum())
306    }
307
308    fn grad_norm(&self) -> Result<f32> {
309        // TODO: Implement when gradient support is available
310        Ok(0.0)
311    }
312
313    fn has_grad(&self) -> bool {
314        // TODO: Implement when gradient support is available
315        false
316    }
317
318    fn to_vec(&self) -> Result<Vec<f32>> {
319        let tensor = self.tensor();
320        let data_guard = tensor.read();
321        data_guard.to_vec()
322    }
323
324    fn dtype_name(&self) -> &str {
325        "f32" // Currently all parameters are f32
326    }
327
328    fn memory_bytes(&self) -> usize {
329        self.numel().unwrap_or(0) * 4 // f32 = 4 bytes
330    }
331
332    fn clone_with_grad(&self, requires_grad: bool) -> Parameter {
333        let tensor = self.clone_data();
334        if requires_grad {
335            Parameter::new(tensor)
336        } else {
337            Parameter::new_no_grad(tensor)
338        }
339    }
340}
341
/// Extension trait for ParameterCollection with additional utilities
///
/// This trait is implemented for the existing ParameterCollection type
/// to add advanced functionality without modifying the core implementation.
pub trait ParameterCollectionExt {
    /// Get total parameter count
    ///
    /// # Returns
    /// * `usize` - Sum of element counts over every parameter in the collection
    fn total_numel(&self) -> usize;

    /// Group parameters by name pattern
    ///
    /// # Arguments
    /// * `groups` - Map from group name to substring patterns; a parameter
    ///   joins a group when its name contains any of that group's patterns
    ///
    /// # Returns
    /// * Map from group name to the resulting group; groups matching no
    ///   parameters are omitted
    fn group_by_patterns(
        &self,
        groups: &HashMap<String, Vec<String>>,
    ) -> HashMap<String, ParameterGroup>;

    /// Filter parameters by property
    ///
    /// # Arguments
    /// * `predicate` - Called with each (name, parameter) pair; keep on true
    fn filter<F>(&self, predicate: F) -> HashMap<String, Parameter>
    where
        F: Fn(&str, &Parameter) -> bool;

    /// Get trainable parameters only (requires_grad == true)
    fn trainable(&self) -> HashMap<String, Parameter>;

    /// Get frozen parameters only (requires_grad == false)
    fn frozen(&self) -> HashMap<String, Parameter>;
}

368impl ParameterCollectionExt for super::ParameterCollection {
369    fn total_numel(&self) -> usize {
370        // Access through public methods
371        self.names()
372            .iter()
373            .filter_map(|name| self.get(name))
374            .map(|p| p.numel().unwrap_or(0))
375            .sum()
376    }
377
378    fn group_by_patterns(
379        &self,
380        groups: &HashMap<String, Vec<String>>,
381    ) -> HashMap<String, ParameterGroup> {
382        let mut result = HashMap::new();
383
384        for (group_name, patterns) in groups {
385            let mut group_params = Vec::new();
386
387            for param_name in self.names() {
388                if patterns.iter().any(|pattern| param_name.contains(pattern)) {
389                    if let Some(param) = self.get(param_name) {
390                        group_params.push(param.clone());
391                    }
392                }
393            }
394
395            if !group_params.is_empty() {
396                result.insert(
397                    group_name.clone(),
398                    ParameterGroup::new(group_name.clone(), group_params),
399                );
400            }
401        }
402
403        result
404    }
405
406    fn filter<F>(&self, predicate: F) -> HashMap<String, Parameter>
407    where
408        F: Fn(&str, &Parameter) -> bool,
409    {
410        let mut result = HashMap::new();
411
412        for name in self.names() {
413            if let Some(param) = self.get(name) {
414                if predicate(name, param) {
415                    result.insert(name.clone(), param.clone());
416                }
417            }
418        }
419
420        result
421    }
422
423    fn trainable(&self) -> HashMap<String, Parameter> {
424        self.filter(|_, param| param.requires_grad())
425    }
426
427    fn frozen(&self) -> HashMap<String, Parameter> {
428        self.filter(|_, param| !param.requires_grad())
429    }
430}
431
#[cfg(test)]
mod tests {
    // TODO(review): add coverage for ParameterGroup builders,
    // ParameterConstraint::name, and ParameterExt::analyze.

    // Tests would go here
}