//! Extension utilities for [`Parameter`]: parameter grouping, value
//! constraints, and analysis helpers.

use super::Parameter;
use torsh_core::error::{Result, TorshError};

#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// A named group of parameters that share optimizer settings
/// (per-group learning-rate multiplier, weight decay, and gradient clipping).
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Human-readable group name (also used as the key when grouping).
    pub name: String,
    /// Parameters belonging to this group.
    pub parameters: Vec<Parameter>,
    /// Factor applied to the base learning rate for this group (default 1.0).
    pub lr_multiplier: f32,
    /// Weight-decay coefficient for this group (default 0.0).
    pub weight_decay: f32,
    /// Whether gradient clipping is enabled for this group (default false).
    pub clip_gradients: bool,
    /// Maximum gradient norm applied when `clip_gradients` is true.
    pub max_grad_norm: f32,
}
40
41impl ParameterGroup {
42 pub fn new(name: String, parameters: Vec<Parameter>) -> Self {
51 Self {
52 name,
53 parameters,
54 lr_multiplier: 1.0,
55 weight_decay: 0.0,
56 clip_gradients: false,
57 max_grad_norm: 1.0,
58 }
59 }
60
61 pub fn with_lr_multiplier(mut self, multiplier: f32) -> Self {
63 self.lr_multiplier = multiplier;
64 self
65 }
66
67 pub fn with_weight_decay(mut self, decay: f32) -> Self {
69 self.weight_decay = decay;
70 self
71 }
72
73 pub fn with_gradient_clipping(mut self, max_norm: f32) -> Self {
75 self.clip_gradients = true;
76 self.max_grad_norm = max_norm;
77 self
78 }
79
80 pub fn num_parameters(&self) -> usize {
82 self.parameters.iter().map(|p| p.numel().unwrap_or(0)).sum()
83 }
84
85 pub fn parameter_count(&self) -> usize {
87 self.parameters.len()
88 }
89}
90
/// Constraints that can be attached to a parameter's values.
///
/// NOTE(review): `apply` is currently a no-op stub for every variant, so the
/// descriptions below state *intended* semantics, not enforced behavior.
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
    /// Intended: clamp every element into the inclusive range `[min, max]`.
    ClampRange { min: f32, max: f32 },
    /// Intended: force all elements to be non-negative.
    NonNegative,
    /// Intended: rescale the parameter to unit L2 norm.
    UnitNorm,
    /// Intended: constrain values to a valid probability distribution.
    Probability,
    /// A user-defined constraint identified by `name`.
    Custom { name: String },
}
108
109impl ParameterConstraint {
110 pub fn apply(&self, parameter: &Parameter) -> Result<()> {
118 let tensor = parameter.tensor();
119 let _data = tensor.write();
120
121 match self {
122 ParameterConstraint::ClampRange { min, max } => {
123 let _ = (min, max);
125 Ok(())
127 }
128 ParameterConstraint::NonNegative => {
129 Ok(())
132 }
133 ParameterConstraint::UnitNorm => {
134 Ok(())
137 }
138 ParameterConstraint::Probability => {
139 Ok(())
142 }
143 ParameterConstraint::Custom { name: _ } => {
144 Ok(())
146 }
147 }
148 }
149
150 pub fn name(&self) -> &str {
152 match self {
153 ParameterConstraint::ClampRange { .. } => "ClampRange",
154 ParameterConstraint::NonNegative => "NonNegative",
155 ParameterConstraint::UnitNorm => "UnitNorm",
156 ParameterConstraint::Probability => "Probability",
157 ParameterConstraint::Custom { name } => name,
158 }
159 }
160}
161
/// Summary statistics for a parameter tensor, as produced by
/// `ParameterExt::analyze`.
#[derive(Debug, Clone)]
pub struct ParameterAnalysis {
    /// Arithmetic mean of all elements.
    pub mean: f32,
    /// Population standard deviation (variance divides by N, not N-1).
    pub std: f32,
    /// Smallest element (NaNs are skipped by the `f32::min` fold).
    pub min: f32,
    /// Largest element (NaNs are skipped by the `f32::max` fold).
    pub max: f32,
    /// Total number of elements.
    pub numel: usize,
    /// Percentage of exactly-zero elements, in the range [0, 100].
    pub sparsity: f32,
    /// True if any element is NaN.
    pub has_nan: bool,
    /// True if any element is +/- infinity.
    pub has_inf: bool,
}
182
/// Extension methods for an individual [`Parameter`]: statistics, norms,
/// and convenience conversions.
pub trait ParameterExt {
    /// Computes summary statistics over all elements.
    ///
    /// # Errors
    /// Fails if the tensor data cannot be read or the parameter is empty.
    fn analyze(&self) -> Result<ParameterAnalysis>;

    /// Returns `true` if every element is finite (no NaN / infinity).
    fn is_finite(&self) -> Result<bool>;

    /// L2 (Euclidean) norm of the parameter values.
    fn norm(&self) -> Result<f32>;

    /// L1 norm (sum of absolute values).
    fn l1_norm(&self) -> Result<f32>;

    /// L2 norm of the gradient.
    /// NOTE(review): the current impl is a stub that always returns 0.0.
    fn grad_norm(&self) -> Result<f32>;

    /// Whether a gradient is currently stored.
    /// NOTE(review): the current impl is a stub that always returns `false`.
    fn has_grad(&self) -> bool;

    /// Copies the parameter data out as a flat `Vec<f32>`.
    fn to_vec(&self) -> Result<Vec<f32>>;

    /// Name of the element dtype (the current impl is hard-wired to "f32").
    fn dtype_name(&self) -> &str;

    /// Approximate memory footprint of the parameter data, in bytes.
    fn memory_bytes(&self) -> usize;

    /// Deep-copies the parameter with the requested `requires_grad` flag.
    fn clone_with_grad(&self, requires_grad: bool) -> Parameter;
}
248
249impl ParameterExt for Parameter {
250 fn analyze(&self) -> Result<ParameterAnalysis> {
251 let tensor = self.tensor();
252 let data_guard = tensor.read();
253 let data = data_guard.to_vec()?;
254
255 let numel = data.len();
256 if numel == 0 {
257 return Err(TorshError::InvalidArgument(
258 "Cannot analyze empty parameter".to_string(),
259 ));
260 }
261
262 let sum: f32 = data.iter().sum();
263 let mean = sum / numel as f32;
264
265 let variance: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / numel as f32;
266 let std = variance.sqrt();
267
268 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270
271 let zero_count = data.iter().filter(|&&x| x == 0.0).count();
272 let sparsity = (zero_count as f32 / numel as f32) * 100.0;
273
274 let has_nan = data.iter().any(|&x| x.is_nan());
275 let has_inf = data.iter().any(|&x| x.is_infinite());
276
277 Ok(ParameterAnalysis {
278 mean,
279 std,
280 min,
281 max,
282 numel,
283 sparsity,
284 has_nan,
285 has_inf,
286 })
287 }
288
289 fn is_finite(&self) -> Result<bool> {
290 let tensor = self.tensor();
291 let data = tensor.read().to_vec()?;
292 Ok(data.iter().all(|&x| x.is_finite()))
293 }
294
295 fn norm(&self) -> Result<f32> {
296 let tensor = self.tensor();
297 let data = tensor.read().to_vec()?;
298 let sum_sq: f32 = data.iter().map(|&x| x * x).sum();
299 Ok(sum_sq.sqrt())
300 }
301
302 fn l1_norm(&self) -> Result<f32> {
303 let tensor = self.tensor();
304 let data = tensor.read().to_vec()?;
305 Ok(data.iter().map(|&x| x.abs()).sum())
306 }
307
308 fn grad_norm(&self) -> Result<f32> {
309 Ok(0.0)
311 }
312
313 fn has_grad(&self) -> bool {
314 false
316 }
317
318 fn to_vec(&self) -> Result<Vec<f32>> {
319 let tensor = self.tensor();
320 let data_guard = tensor.read();
321 data_guard.to_vec()
322 }
323
324 fn dtype_name(&self) -> &str {
325 "f32" }
327
328 fn memory_bytes(&self) -> usize {
329 self.numel().unwrap_or(0) * 4 }
331
332 fn clone_with_grad(&self, requires_grad: bool) -> Parameter {
333 let tensor = self.clone_data();
334 if requires_grad {
335 Parameter::new(tensor)
336 } else {
337 Parameter::new_no_grad(tensor)
338 }
339 }
340}
341
/// Extension methods for a whole `ParameterCollection`: aggregate element
/// counts, pattern-based grouping, and predicate filtering.
pub trait ParameterCollectionExt {
    /// Total number of scalar elements across every parameter.
    fn total_numel(&self) -> usize;

    /// Builds named [`ParameterGroup`]s: a parameter joins a group when its
    /// name contains any of that group's substring patterns. Groups that
    /// match no parameters are omitted from the result.
    fn group_by_patterns(
        &self,
        groups: &HashMap<String, Vec<String>>,
    ) -> HashMap<String, ParameterGroup>;

    /// Returns the parameters for which `predicate(name, param)` is true.
    fn filter<F>(&self, predicate: F) -> HashMap<String, Parameter>
    where
        F: Fn(&str, &Parameter) -> bool;

    /// Parameters with `requires_grad() == true`.
    fn trainable(&self) -> HashMap<String, Parameter>;

    /// Parameters with `requires_grad() == false`.
    fn frozen(&self) -> HashMap<String, Parameter>;
}
367
368impl ParameterCollectionExt for super::ParameterCollection {
369 fn total_numel(&self) -> usize {
370 self.names()
372 .iter()
373 .filter_map(|name| self.get(name))
374 .map(|p| p.numel().unwrap_or(0))
375 .sum()
376 }
377
378 fn group_by_patterns(
379 &self,
380 groups: &HashMap<String, Vec<String>>,
381 ) -> HashMap<String, ParameterGroup> {
382 let mut result = HashMap::new();
383
384 for (group_name, patterns) in groups {
385 let mut group_params = Vec::new();
386
387 for param_name in self.names() {
388 if patterns.iter().any(|pattern| param_name.contains(pattern)) {
389 if let Some(param) = self.get(param_name) {
390 group_params.push(param.clone());
391 }
392 }
393 }
394
395 if !group_params.is_empty() {
396 result.insert(
397 group_name.clone(),
398 ParameterGroup::new(group_name.clone(), group_params),
399 );
400 }
401 }
402
403 result
404 }
405
406 fn filter<F>(&self, predicate: F) -> HashMap<String, Parameter>
407 where
408 F: Fn(&str, &Parameter) -> bool,
409 {
410 let mut result = HashMap::new();
411
412 for name in self.names() {
413 if let Some(param) = self.get(name) {
414 if predicate(name, param) {
415 result.insert(name.clone(), param.clone());
416 }
417 }
418 }
419
420 result
421 }
422
423 fn trainable(&self) -> HashMap<String, Parameter> {
424 self.filter(|_, param| param.requires_grad())
425 }
426
427 fn frozen(&self) -> HashMap<String, Parameter> {
428 self.filter(|_, param| !param.requires_grad())
429 }
430}
431
#[cfg(test)]
mod tests {
    // NOTE(review): empty placeholder — unit tests for the extension traits
    // (e.g. analyze/norm/sparsity on a small Parameter, group_by_patterns
    // matching) have not been written yet.
}