1pub mod parameter_ext;
10
11pub use parameter_ext::{
12 ParameterAnalysis, ParameterCollectionExt, ParameterConstraint, ParameterExt, ParameterGroup,
13};
14
15use crate::init::Initializer;
16use parking_lot::RwLock;
17use std::sync::Arc;
18use torsh_core::device::DeviceType;
19use torsh_core::error::Result;
20use torsh_tensor::Tensor;
21
22#[cfg(feature = "std")]
24use std::collections::HashMap;
25
26#[cfg(not(feature = "std"))]
27use hashbrown::HashMap;
28
/// A learnable tensor wrapper used by neural-network modules.
///
/// The tensor lives behind `Arc<RwLock<..>>`, so cloning a `Parameter`
/// clones the *handle*: all clones share (and see mutations of) the same
/// underlying storage.
#[derive(Clone, Debug)]
pub struct Parameter {
    // Shared, lock-protected tensor storage.
    data: Arc<RwLock<Tensor>>,
    // Whether gradients should be tracked for this parameter.
    requires_grad: bool,
}
35
36impl Parameter {
37 pub fn new(tensor: Tensor) -> Self {
39 Self {
40 data: Arc::new(RwLock::new(tensor)),
41 requires_grad: true,
42 }
43 }
44
45 pub fn new_no_grad(tensor: Tensor) -> Self {
47 Self {
48 data: Arc::new(RwLock::new(tensor)),
49 requires_grad: false,
50 }
51 }
52
53 pub fn tensor(&self) -> Arc<RwLock<Tensor>> {
55 self.data.clone()
56 }
57
58 pub fn from_tensor(tensor: Arc<RwLock<Tensor>>) -> Self {
60 Self {
61 data: tensor,
62 requires_grad: true,
63 }
64 }
65
66 pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
68 self.requires_grad = requires_grad;
69 self
72 }
73
74 pub fn requires_grad(&self) -> bool {
76 self.requires_grad
77 }
78
79 pub fn shape(&self) -> Result<Vec<usize>> {
81 Ok(self.data.read().shape().dims().to_vec())
82 }
83
84 pub fn numel(&self) -> Result<usize> {
86 Ok(self.data.read().shape().numel())
87 }
88
89 pub fn to_device(&mut self, device: DeviceType) -> Result<()> {
91 let _ = device; Ok(())
95 }
96
97 pub fn zero_grad(&mut self) {
99 }
102
103 pub fn clone_data(&self) -> Tensor {
105 self.data.read().clone()
106 }
107}
108
109impl Parameter {
111 pub fn with_init<F>(shape: Vec<usize>, _device: DeviceType, init_fn: F) -> Result<Self>
116 where
117 F: FnOnce(Vec<usize>) -> Result<Tensor>,
118 {
119 let tensor = init_fn(shape)?;
120 Ok(Self::new(tensor))
121 }
122
123 pub fn from_data(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
127 let tensor = torsh_tensor::Tensor::from_vec(data, &shape)?;
128 Ok(Self::new(tensor))
129 }
130
131 pub fn from_data_auto_shape(data: Vec<f32>) -> Result<Self> {
135 let shape = vec![data.len()];
136 Self::from_data(data, shape)
137 }
138
139 pub fn auto_init(shape: Vec<usize>, device: DeviceType, layer_type: LayerType) -> Result<Self> {
143 use crate::init::InitMethod;
144
145 let init_method = match layer_type {
146 LayerType::Linear | LayerType::Dense => InitMethod::KaimingUniform {
147 mode: crate::init::FanMode::FanIn,
148 nonlinearity: crate::init::Nonlinearity::Linear,
149 },
150 LayerType::Conv => InitMethod::KaimingUniform {
151 mode: crate::init::FanMode::FanOut,
152 nonlinearity: crate::init::Nonlinearity::ReLU,
153 },
154 LayerType::RNN | LayerType::LSTM | LayerType::GRU => {
155 InitMethod::XavierUniform { gain: 1.0 }
156 }
157 LayerType::Attention => InitMethod::XavierNormal { gain: 1.0 },
158 LayerType::Embedding => InitMethod::Normal {
159 mean: 0.0,
160 std: 1.0,
161 },
162 LayerType::Bias => InitMethod::Constant { value: 0.0 },
163 LayerType::BatchNorm => InitMethod::Constant { value: 1.0 },
164 LayerType::Custom => InitMethod::KaimingUniform {
165 mode: crate::init::FanMode::FanIn,
166 nonlinearity: crate::init::Nonlinearity::ReLU,
167 },
168 };
169
170 Self::with_init_method(shape, device, init_method)
171 }
172
173 pub fn zeros(shape: Vec<usize>, _device: DeviceType) -> Result<Self> {
175 use torsh_tensor::creation::zeros;
176 let tensor = zeros(&shape)?;
177 Ok(Self::new(tensor))
178 }
179
180 pub fn ones(shape: Vec<usize>, _device: DeviceType) -> Result<Self> {
182 use torsh_tensor::creation::ones;
183 let tensor = ones(&shape)?;
184 Ok(Self::new(tensor))
185 }
186
187 pub fn with_init_method(
189 shape: Vec<usize>,
190 _device: DeviceType,
191 method: crate::init::InitMethod,
192 ) -> Result<Self> {
193 let tensor = method.initialize(&shape)?;
194 Ok(Self::new(tensor))
195 }
196
197 pub fn uniform(shape: Vec<usize>, device: DeviceType, low: f32, high: f32) -> Result<Self> {
199 use crate::init::InitMethod;
200 Self::with_init_method(shape, device, InitMethod::Uniform { low, high })
201 }
202
203 pub fn normal(shape: Vec<usize>, device: DeviceType, mean: f32, std: f32) -> Result<Self> {
205 use crate::init::InitMethod;
206 Self::with_init_method(shape, device, InitMethod::Normal { mean, std })
207 }
208
209 pub fn xavier_uniform(shape: Vec<usize>, device: DeviceType, gain: f32) -> Result<Self> {
211 use crate::init::InitMethod;
212 Self::with_init_method(shape, device, InitMethod::XavierUniform { gain })
213 }
214
215 pub fn xavier_normal(shape: Vec<usize>, device: DeviceType, gain: f32) -> Result<Self> {
217 use crate::init::InitMethod;
218 Self::with_init_method(shape, device, InitMethod::XavierNormal { gain })
219 }
220
221 pub fn kaiming_uniform(
223 shape: Vec<usize>,
224 device: DeviceType,
225 nonlinearity: &str,
226 ) -> Result<Self> {
227 use crate::init::{FanMode, InitMethod, Nonlinearity};
228 let nl = match nonlinearity {
229 "relu" => Nonlinearity::ReLU,
230 "leaky_relu" => Nonlinearity::LeakyReLU {
231 negative_slope: 0.01,
232 },
233 "tanh" => Nonlinearity::Tanh,
234 "sigmoid" => Nonlinearity::Sigmoid,
235 "selu" => Nonlinearity::SELU,
236 "elu" => Nonlinearity::ELU,
237 "swish" => Nonlinearity::Swish,
238 "linear" => Nonlinearity::Linear,
239 _ => Nonlinearity::Linear,
240 };
241 Self::with_init_method(
242 shape,
243 device,
244 InitMethod::KaimingUniform {
245 mode: FanMode::FanIn,
246 nonlinearity: nl,
247 },
248 )
249 }
250
251 pub fn kaiming_normal(
253 shape: Vec<usize>,
254 device: DeviceType,
255 nonlinearity: &str,
256 ) -> Result<Self> {
257 use crate::init::{FanMode, InitMethod, Nonlinearity};
258 let nl = match nonlinearity {
259 "relu" => Nonlinearity::ReLU,
260 "leaky_relu" => Nonlinearity::LeakyReLU {
261 negative_slope: 0.01,
262 },
263 "tanh" => Nonlinearity::Tanh,
264 "sigmoid" => Nonlinearity::Sigmoid,
265 "selu" => Nonlinearity::SELU,
266 "elu" => Nonlinearity::ELU,
267 "swish" => Nonlinearity::Swish,
268 "linear" => Nonlinearity::Linear,
269 _ => Nonlinearity::Linear,
270 };
271 Self::with_init_method(
272 shape,
273 device,
274 InitMethod::KaimingNormal {
275 mode: FanMode::FanIn,
276 nonlinearity: nl,
277 },
278 )
279 }
280
281 pub fn constant(shape: Vec<usize>, device: DeviceType, value: f32) -> Result<Self> {
283 use crate::init::InitMethod;
284 Self::with_init_method(shape, device, InitMethod::Constant { value })
285 }
286
287 pub fn orthogonal(shape: Vec<usize>, device: DeviceType, gain: f32) -> Result<Self> {
289 use crate::init::InitMethod;
290 Self::with_init_method(shape, device, InitMethod::Orthogonal { gain })
291 }
292
293 pub fn sparse(shape: Vec<usize>, device: DeviceType, sparsity: f32, std: f32) -> Result<Self> {
295 use crate::init::InitMethod;
296 Self::with_init_method(shape, device, InitMethod::Sparse { sparsity, std })
297 }
298
299 pub fn lecun_uniform(shape: Vec<usize>, device: DeviceType) -> Result<Self> {
301 use crate::init::InitMethod;
302 Self::with_init_method(shape, device, InitMethod::LecunUniform)
303 }
304
305 pub fn lecun_normal(shape: Vec<usize>, device: DeviceType) -> Result<Self> {
307 use crate::init::InitMethod;
308 Self::with_init_method(shape, device, InitMethod::LecunNormal)
309 }
310
311 pub fn truncated_normal(
313 shape: Vec<usize>,
314 device: DeviceType,
315 mean: f32,
316 std: f32,
317 a: f32,
318 b: f32,
319 ) -> Result<Self> {
320 use crate::init::InitMethod;
321 Self::with_init_method(
322 shape,
323 device,
324 InitMethod::TruncatedNormal { mean, std, a, b },
325 )
326 }
327
328 pub fn eye(shape: Vec<usize>, device: DeviceType) -> Result<Self> {
330 use crate::init::InitMethod;
331 Self::with_init_method(shape, device, InitMethod::Eye)
332 }
333
334 pub fn stats(&self) -> Result<ParameterStats> {
336 let tensor = self.data.read();
337 let data = tensor.to_vec()?;
338
339 if data.is_empty() {
340 return Ok(ParameterStats {
341 mean: 0.0,
342 std: 0.0,
343 variance: 0.0,
344 min: 0.0,
345 max: 0.0,
346 numel: 0,
347 median: 0.0,
348 q25: 0.0,
349 q75: 0.0,
350 skewness: 0.0,
351 kurtosis: 0.0,
352 });
353 }
354
355 let mean = data.iter().sum::<f32>() / data.len() as f32;
356 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
357 let std = variance.sqrt();
358 let min = data.iter().copied().fold(f32::INFINITY, f32::min);
359 let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
360
361 let mut sorted_data = data.clone();
363 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
364
365 let median = if sorted_data.len() % 2 == 0 {
366 (sorted_data[sorted_data.len() / 2 - 1] + sorted_data[sorted_data.len() / 2]) / 2.0
367 } else {
368 sorted_data[sorted_data.len() / 2]
369 };
370
371 let q25_idx = sorted_data.len() / 4;
372 let q75_idx = 3 * sorted_data.len() / 4;
373 let q25 = sorted_data.get(q25_idx).copied().unwrap_or(min);
374 let q75 = sorted_data.get(q75_idx).copied().unwrap_or(max);
375
376 let n = data.len() as f32;
378 let skewness = if std > 0.0 {
379 data.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f32>() / n
380 } else {
381 0.0
382 };
383
384 let kurtosis = if std > 0.0 {
385 data.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f32>() / n - 3.0
386 } else {
387 0.0
388 };
389
390 Ok(ParameterStats {
391 mean,
392 std,
393 variance,
394 min,
395 max,
396 numel: data.len(),
397 median,
398 q25,
399 q75,
400 skewness,
401 kurtosis,
402 })
403 }
404
405 pub fn is_finite(&self) -> Result<bool> {
407 let tensor = self.data.read();
408 let data = tensor.to_vec()?;
409 Ok(data.iter().all(|x| x.is_finite()))
410 }
411
412 pub fn reinitialize(&mut self, method: crate::init::InitMethod) -> Result<()> {
414 let current_shape = self.shape()?;
415 let new_tensor = method.initialize(¤t_shape)?;
416 *self.data.write() = new_tensor;
417 Ok(())
418 }
419
420 pub fn norm(&self) -> Result<f32> {
422 let tensor = self.data.read();
423 let data = tensor.to_vec()?;
424 let norm = data.iter().map(|x| x * x).sum::<f32>().sqrt();
425 Ok(norm)
426 }
427
428 pub fn l1_norm(&self) -> Result<f32> {
430 let tensor = self.data.read();
431 let data = tensor.to_vec()?;
432 let norm = data.iter().map(|x| x.abs()).sum::<f32>();
433 Ok(norm)
434 }
435
436 pub fn linf_norm(&self) -> Result<f32> {
438 let tensor = self.data.read();
439 let data = tensor.to_vec()?;
440 let norm = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
441 Ok(norm)
442 }
443
444 pub fn clamp(&mut self, min: f32, max: f32) -> Result<()> {
446 let mut tensor = self.data.write();
447 let data = tensor.to_vec()?;
448 let clamped_data: Vec<f32> = data.iter().map(|&x| x.clamp(min, max)).collect();
449 let shape = tensor.shape().dims().to_vec();
450 *tensor = torsh_tensor::Tensor::from_vec(clamped_data, &shape)?;
451 Ok(())
452 }
453
454 pub fn apply_fn<F>(&mut self, f: F) -> Result<()>
456 where
457 F: Fn(f32) -> f32,
458 {
459 let mut tensor = self.data.write();
460 let data = tensor.to_vec()?;
461 let transformed_data: Vec<f32> = data.iter().map(|&x| f(x)).collect();
462 let shape = tensor.shape().dims().to_vec();
463 *tensor = torsh_tensor::Tensor::from_vec(transformed_data, &shape)?;
464 Ok(())
465 }
466
467 pub fn scale(&mut self, factor: f32) -> Result<()> {
469 self.apply_fn(|x| x * factor)
470 }
471
    /// Perturbs every element with random noise scaled by `std`.
    ///
    /// NOTE(review): `rng.random::<f32>()` looks like a uniform draw, so the
    /// added noise is uniform on `[0, std)` — not zero-mean Gaussian with
    /// standard deviation `std`, as the parameter name suggests. Confirm
    /// intent against `scirs2_core::random`.
    pub fn add_noise(&mut self, std: f32) -> Result<()> {
        use scirs2_core::random::thread_rng;
        let mut rng = thread_rng();
        let mut tensor = self.data.write();
        let data = tensor.to_vec()?;
        let noisy_data: Vec<f32> = data
            .iter()
            .map(|&x| x + rng.random::<f32>() * std)
            .collect();
        // Rebuild the tensor with the same shape and the perturbed values.
        let shape = tensor.shape().dims().to_vec();
        *tensor = torsh_tensor::Tensor::from_vec(noisy_data, &shape)?;
        Ok(())
    }
486
487 pub fn histogram(&self, bins: usize) -> Result<Vec<(f32, usize)>> {
489 let tensor = self.data.read();
490 let data = tensor.to_vec()?;
491
492 if data.is_empty() {
493 return Ok(Vec::new());
494 }
495
496 let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
497 let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
498
499 if min_val == max_val {
500 return Ok(vec![(min_val, data.len())]);
501 }
502
503 let bin_width = (max_val - min_val) / bins as f32;
504 let mut histogram = vec![0; bins];
505
506 for &value in &data {
507 let bin_index = ((value - min_val) / bin_width).floor() as usize;
508 let bin_index = bin_index.min(bins - 1);
509 histogram[bin_index] += 1;
510 }
511
512 let result: Vec<(f32, usize)> = histogram
513 .into_iter()
514 .enumerate()
515 .map(|(i, count)| (min_val + (i as f32 + 0.5) * bin_width, count))
516 .collect();
517
518 Ok(result)
519 }
520
521 pub fn diagnose(&self) -> Result<ParameterDiagnostics> {
523 let stats = self.stats()?;
524 let mut issues = Vec::new();
525 let mut warnings = Vec::new();
526
527 if !self.is_finite()? {
529 issues.push("Parameter contains NaN or infinite values".to_string());
530 }
531
532 if stats.std < 1e-6 {
534 warnings
535 .push("Very low standard deviation - parameters may be too uniform".to_string());
536 }
537
538 if stats.std > 10.0 {
539 warnings.push("Very high standard deviation - parameters may be unstable".to_string());
540 }
541
542 if stats.mean.abs() > 5.0 {
543 warnings
544 .push("High mean absolute value - parameters may be poorly centered".to_string());
545 }
546
547 let norm = self.norm()?;
549 if norm < 1e-8 {
550 warnings
551 .push("Very small parameter norm - may indicate vanishing gradients".to_string());
552 } else if norm > 100.0 {
553 warnings
554 .push("Very large parameter norm - may indicate exploding gradients".to_string());
555 }
556
557 Ok(ParameterDiagnostics {
558 stats,
559 issues,
560 warnings,
561 norm,
562 is_finite: self.is_finite()?,
563 })
564 }
565}
566
/// Coarse layer categories used by [`Parameter::auto_init`] to choose a
/// default weight-initialization scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    Linear,
    Dense,
    Conv,
    RNN,
    LSTM,
    GRU,
    Attention,
    Embedding,
    Bias,
    BatchNorm,
    // Falls back to Kaiming-uniform/ReLU initialization.
    Custom,
}
582
/// Descriptive statistics computed over a parameter's elements.
#[derive(Debug, Clone)]
pub struct ParameterStats {
    pub mean: f32,
    /// Population standard deviation (sqrt of `variance`).
    pub std: f32,
    /// Population variance (divides by N, not N-1).
    pub variance: f32,
    pub min: f32,
    pub max: f32,
    /// Number of elements the statistics were computed over.
    pub numel: usize,
    pub median: f32,
    /// First quartile.
    pub q25: f32,
    /// Third quartile.
    pub q75: f32,
    /// Standardized third moment.
    pub skewness: f32,
    /// Excess kurtosis (≈ 0 for normally distributed data).
    pub kurtosis: f32,
}
598
599impl ParameterStats {
600 pub fn from_data(data: &[f32]) -> Self {
602 if data.is_empty() {
603 return Self::empty();
604 }
605
606 let n = data.len() as f32;
607 let mean = data.iter().sum::<f32>() / n;
608 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
609 let std = variance.sqrt();
610
611 let mut sorted_data = data.to_vec();
612 sorted_data.sort_by(|a, b| {
613 a.partial_cmp(b)
614 .expect("data comparison should not involve NaN")
615 });
616
617 let min = sorted_data[0];
618 let max = sorted_data[sorted_data.len() - 1];
619 let median = Self::percentile(&sorted_data, 0.5);
620 let q25 = Self::percentile(&sorted_data, 0.25);
621 let q75 = Self::percentile(&sorted_data, 0.75);
622
623 let std_cubed = std.powi(3);
625 let std_fourth = std.powi(4);
626
627 let skewness = if std_cubed > 0.0 {
628 data.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f32>() / n
629 } else {
630 0.0
631 };
632
633 let kurtosis = if std_fourth > 0.0 {
634 data.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f32>() / n - 3.0
635 } else {
636 0.0
637 };
638
639 Self {
640 mean,
641 std,
642 variance,
643 min,
644 max,
645 numel: data.len(),
646 median,
647 q25,
648 q75,
649 skewness,
650 kurtosis,
651 }
652 }
653
654 pub fn empty() -> Self {
656 Self {
657 mean: 0.0,
658 std: 0.0,
659 variance: 0.0,
660 min: 0.0,
661 max: 0.0,
662 numel: 0,
663 median: 0.0,
664 q25: 0.0,
665 q75: 0.0,
666 skewness: 0.0,
667 kurtosis: 0.0,
668 }
669 }
670
671 fn percentile(sorted_data: &[f32], p: f32) -> f32 {
673 if sorted_data.is_empty() {
674 return 0.0;
675 }
676
677 let index = p * (sorted_data.len() - 1) as f32;
678 let lower = index.floor() as usize;
679 let upper = index.ceil() as usize;
680
681 if lower == upper {
682 sorted_data[lower]
683 } else {
684 let weight = index - lower as f32;
685 sorted_data[lower] * (1.0 - weight) + sorted_data[upper] * weight
686 }
687 }
688
689 pub fn iqr(&self) -> f32 {
691 self.q75 - self.q25
692 }
693
694 pub fn is_approximately_normal(&self) -> bool {
696 self.skewness.abs() < 1.0 && self.kurtosis.abs() < 3.0
698 }
699}
700
701impl core::fmt::Display for ParameterStats {
702 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
703 write!(
704 f,
705 "ParameterStats(mean={:.4}, std={:.4}, min={:.4}, max={:.4}, numel={})",
706 self.mean, self.std, self.min, self.max, self.numel
707 )
708 }
709}
710
/// Result of [`Parameter::diagnose`]: statistics plus detected problems.
#[derive(Debug, Clone)]
pub struct ParameterDiagnostics {
    pub stats: ParameterStats,
    /// Hard problems (e.g. non-finite values).
    pub issues: Vec<String>,
    /// Soft concerns (e.g. suspiciously large/small norm or spread).
    pub warnings: Vec<String>,
    /// L2 norm of the parameter.
    pub norm: f32,
    /// True when every element is finite.
    pub is_finite: bool,
}
720
impl core::fmt::Display for ParameterDiagnostics {
    /// Multi-line human-readable report: stats, norm, finiteness, then any
    /// issues and warnings (sections omitted when empty).
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Parameter Diagnostics:")?;
        writeln!(f, " {}", self.stats)?;
        writeln!(f, " Norm: {:.6}", self.norm)?;
        writeln!(f, " Finite: {}", self.is_finite)?;

        if !self.issues.is_empty() {
            writeln!(f, " Issues:")?;
            for issue in &self.issues {
                writeln!(f, " - {issue}")?;
            }
        }

        if !self.warnings.is_empty() {
            writeln!(f, " Warnings:")?;
            for warning in &self.warnings {
                writeln!(f, " - {warning}")?;
            }
        }

        Ok(())
    }
}
745
/// A named set of parameters with bulk inspection and mutation helpers.
#[derive(Debug, Clone)]
pub struct ParameterCollection {
    // Name → parameter; iteration order is unspecified (hash map).
    parameters: HashMap<String, Parameter>,
}
754
755impl ParameterCollection {
    /// Creates an empty collection.
    pub fn new() -> Self {
        Self {
            parameters: HashMap::new(),
        }
    }

    /// Wraps an existing name→parameter map.
    pub fn from_map(parameters: HashMap<String, Parameter>) -> Self {
        Self { parameters }
    }

    /// Inserts (or replaces) a parameter under `name`.
    pub fn add(&mut self, name: String, parameter: Parameter) {
        self.parameters.insert(name, parameter);
    }

    /// Looks up a parameter by name.
    pub fn get(&self, name: &str) -> Option<&Parameter> {
        self.parameters.get(name)
    }

    /// Mutable lookup by name.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut Parameter> {
        self.parameters.get_mut(name)
    }

    /// All registered parameter names (unordered).
    pub fn names(&self) -> Vec<&String> {
        self.parameters.keys().collect()
    }

    /// Number of registered parameters.
    pub fn len(&self) -> usize {
        self.parameters.len()
    }

    /// True when no parameters are registered.
    pub fn is_empty(&self) -> bool {
        self.parameters.is_empty()
    }
797
798 pub fn apply_to_all<F>(&mut self, mut f: F) -> Result<()>
800 where
801 F: FnMut(&mut Parameter) -> Result<()>,
802 {
803 for param in self.parameters.values_mut() {
804 f(param)?;
805 }
806 Ok(())
807 }
808
809 pub fn stats(&self) -> Result<HashMap<String, ParameterStats>> {
811 let mut stats = HashMap::new();
812 for (name, param) in &self.parameters {
813 stats.insert(name.clone(), param.stats()?);
814 }
815 Ok(stats)
816 }
817
818 pub fn diagnose(&self) -> Result<HashMap<String, ParameterDiagnostics>> {
820 let mut diagnostics = HashMap::new();
821 for (name, param) in &self.parameters {
822 diagnostics.insert(name.clone(), param.diagnose()?);
823 }
824 Ok(diagnostics)
825 }
826
827 pub fn total_parameters(&self) -> usize {
829 self.parameters
830 .values()
831 .map(|p| p.numel().unwrap_or(0))
832 .sum()
833 }
834
835 pub fn total_memory_usage(&self) -> usize {
837 self.parameters
838 .values()
839 .map(|p| p.numel().unwrap_or(0) * 4) .sum()
841 }
842
843 pub fn freeze_all(&mut self) {
845 for param in self.parameters.values_mut() {
846 param.requires_grad = false;
847 }
848 }
849
850 pub fn unfreeze_all(&mut self) {
852 for param in self.parameters.values_mut() {
853 param.requires_grad = true;
854 }
855 }
856
    /// Multiplies every parameter's elements by `factor`.
    pub fn scale_all(&mut self, factor: f32) -> Result<()> {
        self.apply_to_all(|param| param.scale(factor))
    }

    /// Clamps every parameter's elements into `[min, max]`.
    pub fn clamp_all(&mut self, min: f32, max: f32) -> Result<()> {
        self.apply_to_all(|param| param.clamp(min, max))
    }

    /// Adds random noise (see `Parameter::add_noise`) to every parameter.
    pub fn add_noise_all(&mut self, std: f32) -> Result<()> {
        self.apply_to_all(|param| param.add_noise(std))
    }
871
872 pub fn filter_by_name(&self, pattern: &str) -> ParameterCollection {
874 let filtered: HashMap<String, Parameter> = self
875 .parameters
876 .iter()
877 .filter(|(name, _)| name.contains(pattern))
878 .map(|(name, param)| (name.clone(), param.clone()))
879 .collect();
880 ParameterCollection::from_map(filtered)
881 }
882
883 pub fn filter_by<F>(&self, predicate: F) -> ParameterCollection
885 where
886 F: Fn(&String, &Parameter) -> bool,
887 {
888 let filtered: HashMap<String, Parameter> = self
889 .parameters
890 .iter()
891 .filter(|(name, param)| predicate(name, param))
892 .map(|(name, param)| (name.clone(), param.clone()))
893 .collect();
894 ParameterCollection::from_map(filtered)
895 }
896
    /// Builds a human-readable multi-line summary: totals, estimated memory
    /// usage, and one detail line per parameter.
    ///
    /// Fails if any parameter's statistics cannot be computed.
    pub fn summary_report(&self) -> Result<String> {
        let mut report = String::new();
        report.push_str("Parameter Collection Summary\n");
        report.push_str(&format!("Total parameters: {}\n", self.len()));
        report.push_str(&format!("Total elements: {}\n", self.total_parameters()));
        report.push_str(&format!(
            "Memory usage: {:.2} MB\n",
            self.total_memory_usage() as f64 / (1024.0 * 1024.0)
        ));
        report.push_str("\nParameter Details:\n");

        // Iteration order is hash-map order, so detail lines are unordered.
        for (name, param) in &self.parameters {
            let stats = param.stats()?;
            report.push_str(&format!(
                " {}: {} elements, mean={:.4}, std={:.4}\n",
                name, stats.numel, stats.mean, stats.std
            ));
        }

        Ok(report)
    }
919}
920
/// An empty collection is the sensible default.
impl Default for ParameterCollection {
    fn default() -> Self {
        Self::new()
    }
}

/// Builds a collection directly from a name→parameter map.
impl From<HashMap<String, Parameter>> for ParameterCollection {
    fn from(parameters: HashMap<String, Parameter>) -> Self {
        Self::from_map(parameters)
    }
}

/// Recovers the underlying name→parameter map from a collection.
impl From<ParameterCollection> for HashMap<String, Parameter> {
    fn from(val: ParameterCollection) -> Self {
        val.parameters
    }
}
938
939#[cfg(test)]
940mod tests {
941 use super::*;
942 use approx::assert_relative_eq;
943 use torsh_core::device::DeviceType;
944 use torsh_tensor::creation::zeros;
945
946 #[test]
951 fn test_parameter_new() -> Result<()> {
952 let tensor = zeros(&[3, 4])?;
953 let param = Parameter::new(tensor);
954
955 assert!(param.requires_grad());
956 assert_eq!(param.shape()?, vec![3, 4]);
957 assert_eq!(param.numel()?, 12);
958 Ok(())
959 }
960
961 #[test]
962 fn test_parameter_new_no_grad() -> Result<()> {
963 let tensor = zeros(&[2, 3])?;
964 let param = Parameter::new_no_grad(tensor);
965
966 assert!(!param.requires_grad());
967 assert_eq!(param.shape()?, vec![2, 3]);
968 Ok(())
969 }
970
971 #[test]
972 fn test_parameter_from_tensor() -> Result<()> {
973 let tensor = zeros(&[5])?;
974 let arc_tensor = Arc::new(RwLock::new(tensor));
975 let param = Parameter::from_tensor(arc_tensor);
976
977 assert!(param.requires_grad());
978 assert_eq!(param.shape()?, vec![5]);
979 Ok(())
980 }
981
982 #[test]
983 fn test_parameter_requires_grad_setter() -> Result<()> {
984 let tensor = zeros(&[2, 2])?;
985 let param = Parameter::new(tensor).requires_grad_(false);
986
987 assert!(!param.requires_grad());
988 Ok(())
989 }
990
991 #[test]
992 fn test_parameter_from_data() -> Result<()> {
993 let data = vec![1.0, 2.0, 3.0, 4.0];
994 let param = Parameter::from_data(data.clone(), vec![2, 2])?;
995
996 assert_eq!(param.shape()?, vec![2, 2]);
997 assert_eq!(param.numel()?, 4);
998
999 let tensor_data = param.clone_data().to_vec()?;
1000 assert_eq!(tensor_data, data);
1001 Ok(())
1002 }
1003
1004 #[test]
1005 fn test_parameter_from_data_auto_shape() -> Result<()> {
1006 let data = vec![1.0, 2.0, 3.0];
1007 let param = Parameter::from_data_auto_shape(data.clone())?;
1008
1009 assert_eq!(param.shape()?, vec![3]);
1010 assert_eq!(param.numel()?, 3);
1011 Ok(())
1012 }
1013
1014 #[test]
1019 fn test_parameter_zeros() -> Result<()> {
1020 let param = Parameter::zeros(vec![2, 3], DeviceType::Cpu)?;
1021 let data = param.clone_data().to_vec()?;
1022
1023 assert_eq!(param.numel()?, 6);
1024 assert!(data.iter().all(|&x| x == 0.0));
1025 Ok(())
1026 }
1027
1028 #[test]
1029 fn test_parameter_ones() -> Result<()> {
1030 let param = Parameter::ones(vec![3, 2], DeviceType::Cpu)?;
1031 let data = param.clone_data().to_vec()?;
1032
1033 assert_eq!(param.numel()?, 6);
1034 assert!(data.iter().all(|&x| x == 1.0));
1035 Ok(())
1036 }
1037
1038 #[test]
1039 fn test_parameter_constant() -> Result<()> {
1040 let param = Parameter::constant(vec![2, 2], DeviceType::Cpu, 5.0)?;
1041 let data = param.clone_data().to_vec()?;
1042
1043 assert!(data.iter().all(|&x| (x - 5.0).abs() < 1e-6));
1044 Ok(())
1045 }
1046
1047 #[test]
1048 fn test_parameter_uniform() -> Result<()> {
1049 let param = Parameter::uniform(vec![100], DeviceType::Cpu, -1.0, 1.0)?;
1050 let data = param.clone_data().to_vec()?;
1051
1052 assert!(data.iter().all(|&x| x >= -1.0 && x <= 1.0));
1054
1055 let mean = data.iter().sum::<f32>() / data.len() as f32;
1057 assert!(mean.abs() < 0.2); Ok(())
1059 }
1060
1061 #[test]
1062 fn test_parameter_normal() -> Result<()> {
1063 let param = Parameter::normal(vec![1000], DeviceType::Cpu, 0.0, 1.0)?;
1064 let data = param.clone_data().to_vec()?;
1065
1066 let mean = data.iter().sum::<f32>() / data.len() as f32;
1067 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
1068
1069 assert!(mean.abs() < 0.1);
1071 assert!((variance - 1.0).abs() < 0.2);
1072 Ok(())
1073 }
1074
1075 #[test]
1076 fn test_parameter_xavier_uniform() -> Result<()> {
1077 let shape = vec![100, 50];
1078 let param = Parameter::xavier_uniform(shape.clone(), DeviceType::Cpu, 1.0)?;
1079
1080 assert_eq!(param.shape()?, shape);
1081 assert!(param.is_finite()?);
1082 Ok(())
1083 }
1084
1085 #[test]
1086 fn test_parameter_xavier_normal() -> Result<()> {
1087 let shape = vec![50, 100];
1088 let param = Parameter::xavier_normal(shape.clone(), DeviceType::Cpu, 1.0)?;
1089
1090 assert_eq!(param.shape()?, shape);
1091 assert!(param.is_finite()?);
1092 Ok(())
1093 }
1094
1095 #[test]
1096 fn test_parameter_kaiming_uniform() -> Result<()> {
1097 let shape = vec![64, 32];
1098 let param = Parameter::kaiming_uniform(shape.clone(), DeviceType::Cpu, "relu")?;
1099
1100 assert_eq!(param.shape()?, shape);
1101 assert!(param.is_finite()?);
1102 Ok(())
1103 }
1104
1105 #[test]
1106 fn test_parameter_kaiming_normal() -> Result<()> {
1107 let shape = vec![32, 64];
1108 let param = Parameter::kaiming_normal(shape.clone(), DeviceType::Cpu, "relu")?;
1109
1110 assert_eq!(param.shape()?, shape);
1111 assert!(param.is_finite()?);
1112 Ok(())
1113 }
1114
1115 #[test]
1116 fn test_parameter_lecun_uniform() -> Result<()> {
1117 let shape = vec![50, 50];
1118 let param = Parameter::lecun_uniform(shape.clone(), DeviceType::Cpu)?;
1119
1120 assert_eq!(param.shape()?, shape);
1121 assert!(param.is_finite()?);
1122 Ok(())
1123 }
1124
1125 #[test]
1126 fn test_parameter_lecun_normal() -> Result<()> {
1127 let shape = vec![50, 50];
1128 let param = Parameter::lecun_normal(shape.clone(), DeviceType::Cpu)?;
1129
1130 assert_eq!(param.shape()?, shape);
1131 assert!(param.is_finite()?);
1132 Ok(())
1133 }
1134
1135 #[test]
1136 fn test_parameter_truncated_normal() -> Result<()> {
1137 let param = Parameter::truncated_normal(vec![100], DeviceType::Cpu, 0.0, 1.0, -2.0, 2.0)?;
1138
1139 let data = param.clone_data().to_vec()?;
1140 assert!(data.iter().all(|&x| x >= -2.0 && x <= 2.0));
1142 Ok(())
1143 }
1144
1145 #[test]
1146 fn test_parameter_eye() -> Result<()> {
1147 let param = Parameter::eye(vec![3, 3], DeviceType::Cpu)?;
1148 let data = param.clone_data().to_vec()?;
1149
1150 assert_relative_eq!(data[0], 1.0, epsilon = 1e-6); assert_relative_eq!(data[4], 1.0, epsilon = 1e-6); assert_relative_eq!(data[8], 1.0, epsilon = 1e-6); assert_relative_eq!(data[1], 0.0, epsilon = 1e-6);
1157 assert_relative_eq!(data[2], 0.0, epsilon = 1e-6);
1158 Ok(())
1159 }
1160
1161 #[test]
1162 fn test_parameter_auto_init_linear() -> Result<()> {
1163 let param = Parameter::auto_init(vec![10, 5], DeviceType::Cpu, LayerType::Linear)?;
1164
1165 assert_eq!(param.shape()?, vec![10, 5]);
1166 assert!(param.is_finite()?);
1167 Ok(())
1168 }
1169
1170 #[test]
1171 fn test_parameter_auto_init_conv() -> Result<()> {
1172 let param = Parameter::auto_init(vec![3, 3, 32, 64], DeviceType::Cpu, LayerType::Conv)?;
1173
1174 assert_eq!(param.shape()?, vec![3, 3, 32, 64]);
1175 assert!(param.is_finite()?);
1176 Ok(())
1177 }
1178
1179 #[test]
1180 fn test_parameter_auto_init_embedding() -> Result<()> {
1181 let param = Parameter::auto_init(vec![1000, 128], DeviceType::Cpu, LayerType::Embedding)?;
1182
1183 assert_eq!(param.shape()?, vec![1000, 128]);
1184 assert!(param.is_finite()?);
1185 Ok(())
1186 }
1187
1188 #[test]
1193 fn test_parameter_stats() -> Result<()> {
1194 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1195 let param = Parameter::from_data(data, vec![5])?;
1196 let stats = param.stats()?;
1197
1198 assert_relative_eq!(stats.mean, 3.0, epsilon = 1e-5);
1199 assert_relative_eq!(stats.min, 1.0, epsilon = 1e-5);
1200 assert_relative_eq!(stats.max, 5.0, epsilon = 1e-5);
1201 assert_eq!(stats.numel, 5);
1202 assert_relative_eq!(stats.median, 3.0, epsilon = 1e-5);
1203 Ok(())
1204 }
1205
1206 #[test]
1207 fn test_parameter_stats_empty() -> Result<()> {
1208 let param: Parameter = Parameter::from_data(vec![], vec![0])?;
1209 let stats = param.stats()?;
1210
1211 assert_eq!(stats.numel, 0);
1212 assert_eq!(stats.mean, 0.0);
1213 assert_eq!(stats.std, 0.0);
1214 Ok(())
1215 }
1216
1217 #[test]
1218 fn test_parameter_norm_l2() -> Result<()> {
1219 let data = vec![3.0, 4.0]; let param = Parameter::from_data(data, vec![2])?;
1221 let norm = param.norm()?;
1222
1223 assert_relative_eq!(norm, 5.0, epsilon = 1e-5);
1224 Ok(())
1225 }
1226
1227 #[test]
1228 fn test_parameter_norm_l1() -> Result<()> {
1229 let data = vec![3.0, -4.0, 5.0]; let param = Parameter::from_data(data, vec![3])?;
1231 let norm = param.l1_norm()?;
1232
1233 assert_relative_eq!(norm, 12.0, epsilon = 1e-5);
1234 Ok(())
1235 }
1236
1237 #[test]
1238 fn test_parameter_norm_linf() -> Result<()> {
1239 let data = vec![1.0, -5.0, 3.0]; let param = Parameter::from_data(data, vec![3])?;
1241 let norm = param.linf_norm()?;
1242
1243 assert_relative_eq!(norm, 5.0, epsilon = 1e-5);
1244 Ok(())
1245 }
1246
1247 #[test]
1252 fn test_parameter_clamp() -> Result<()> {
1253 let data = vec![-5.0, 0.0, 5.0, 10.0];
1254 let mut param = Parameter::from_data(data, vec![4])?;
1255
1256 param.clamp(0.0, 5.0)?;
1257
1258 let clamped = param.clone_data().to_vec()?;
1259 assert_relative_eq!(clamped[0], 0.0, epsilon = 1e-5); assert_relative_eq!(clamped[1], 0.0, epsilon = 1e-5);
1261 assert_relative_eq!(clamped[2], 5.0, epsilon = 1e-5);
1262 assert_relative_eq!(clamped[3], 5.0, epsilon = 1e-5); Ok(())
1264 }
1265
1266 #[test]
1267 fn test_parameter_scale() -> Result<()> {
1268 let data = vec![1.0, 2.0, 3.0];
1269 let mut param = Parameter::from_data(data, vec![3])?;
1270
1271 param.scale(2.0)?;
1272
1273 let scaled = param.clone_data().to_vec()?;
1274 assert_relative_eq!(scaled[0], 2.0, epsilon = 1e-5);
1275 assert_relative_eq!(scaled[1], 4.0, epsilon = 1e-5);
1276 assert_relative_eq!(scaled[2], 6.0, epsilon = 1e-5);
1277 Ok(())
1278 }
1279
1280 #[test]
1281 fn test_parameter_apply_fn() -> Result<()> {
1282 let data = vec![1.0, 2.0, 3.0];
1283 let mut param = Parameter::from_data(data, vec![3])?;
1284
1285 param.apply_fn(|x| x * x)?; let result = param.clone_data().to_vec()?;
1288 assert_relative_eq!(result[0], 1.0, epsilon = 1e-5);
1289 assert_relative_eq!(result[1], 4.0, epsilon = 1e-5);
1290 assert_relative_eq!(result[2], 9.0, epsilon = 1e-5);
1291 Ok(())
1292 }
1293
1294 #[test]
1295 fn test_parameter_add_noise() -> Result<()> {
1296 let data = vec![0.0; 100];
1297 let mut param = Parameter::from_data(data, vec![100])?;
1298
1299 param.add_noise(0.1)?;
1300
1301 let noisy = param.clone_data().to_vec()?;
1302 let all_zero = noisy.iter().all(|&x| x == 0.0);
1304 assert!(!all_zero);
1305 Ok(())
1306 }
1307
1308 #[test]
1309 fn test_parameter_is_finite() -> Result<()> {
1310 let data = vec![1.0, 2.0, 3.0];
1311 let param = Parameter::from_data(data, vec![3])?;
1312
1313 assert!(param.is_finite()?);
1314 Ok(())
1315 }
1316
1317 #[test]
1318 fn test_parameter_reinitialize() -> Result<()> {
1319 let mut param = Parameter::zeros(vec![5], DeviceType::Cpu)?;
1320
1321 use crate::init::InitMethod;
1322 param.reinitialize(InitMethod::Constant { value: 7.0 })?;
1323
1324 let data = param.clone_data().to_vec()?;
1325 assert!(data.iter().all(|&x| (x - 7.0).abs() < 1e-6));
1326 Ok(())
1327 }
1328
1329 #[test]
1330 fn test_parameter_histogram() -> Result<()> {
1331 let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
1332 let param = Parameter::from_data(data, vec![100])?;
1333
1334 let histogram = param.histogram(10)?;
1335
1336 assert_eq!(histogram.len(), 10);
1337 for (_, count) in histogram {
1339 assert!(count >= 9 && count <= 11);
1340 }
1341 Ok(())
1342 }
1343
1344 #[test]
1345 fn test_parameter_histogram_constant() -> Result<()> {
1346 let data = vec![5.0; 10];
1347 let param = Parameter::from_data(data, vec![10])?;
1348
1349 let histogram = param.histogram(5)?;
1350
1351 assert_eq!(histogram.len(), 1);
1353 assert_eq!(histogram[0].1, 10);
1354 Ok(())
1355 }
1356
1357 #[test]
1358 fn test_parameter_diagnose_normal() -> Result<()> {
1359 let data = vec![1.0, 2.0, 3.0, 4.0];
1360 let param = Parameter::from_data(data, vec![4])?;
1361
1362 let diagnostics = param.diagnose()?;
1363
1364 assert!(diagnostics.is_finite);
1365 assert!(diagnostics.issues.is_empty());
1366 assert_eq!(diagnostics.stats.numel, 4);
1367 Ok(())
1368 }
1369
1370 #[test]
1375 fn test_parameter_stats_from_data() {
1376 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1377 let stats = ParameterStats::from_data(&data);
1378
1379 assert_relative_eq!(stats.mean, 3.0, epsilon = 1e-5);
1380 assert_relative_eq!(stats.median, 3.0, epsilon = 1e-5);
1381 assert_relative_eq!(stats.min, 1.0, epsilon = 1e-5);
1382 assert_relative_eq!(stats.max, 5.0, epsilon = 1e-5);
1383 assert_eq!(stats.numel, 5);
1384 }
1385
1386 #[test]
1387 fn test_parameter_stats_empty_constructor() {
1388 let stats = ParameterStats::empty();
1389
1390 assert_eq!(stats.mean, 0.0);
1391 assert_eq!(stats.std, 0.0);
1392 assert_eq!(stats.numel, 0);
1393 }
1394
1395 #[test]
1396 fn test_parameter_stats_iqr() {
1397 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
1398 let stats = ParameterStats::from_data(&data);
1399
1400 let iqr = stats.iqr();
1401 assert!(iqr > 0.0);
1402 }
1403
    #[test]
    // NOTE(review): despite the name, this test only bounds |skewness| on a
    // symmetric sample; it never calls an `is_approximately_normal` predicate.
    // TODO confirm whether ParameterStats exposes one that should be exercised here.
    fn test_parameter_stats_is_approximately_normal() {
        // Symmetric data around zero should have small skewness.
        let data: Vec<f32> = vec![-1.0, -0.5, 0.0, 0.5, 1.0, -0.3, 0.3, -0.8, 0.8, -0.2, 0.2];
        let stats = ParameterStats::from_data(&data);

        assert!(stats.skewness.abs() < 2.0); }
1413
1414 #[test]
1419 fn test_parameter_collection_new() {
1420 let collection = ParameterCollection::new();
1421
1422 assert_eq!(collection.len(), 0);
1423 assert!(collection.is_empty());
1424 }
1425
1426 #[test]
1427 fn test_parameter_collection_add_get() -> Result<()> {
1428 let mut collection = ParameterCollection::new();
1429
1430 let param = Parameter::zeros(vec![3, 4], DeviceType::Cpu)?;
1431 collection.add("weight".to_string(), param);
1432
1433 assert_eq!(collection.len(), 1);
1434 assert!(collection.get("weight").is_some());
1435 assert!(collection.get("bias").is_none());
1436 Ok(())
1437 }
1438
1439 #[test]
1440 fn test_parameter_collection_names() -> Result<()> {
1441 let mut collection = ParameterCollection::new();
1442
1443 collection.add(
1444 "weight".to_string(),
1445 Parameter::zeros(vec![2, 2], DeviceType::Cpu)?,
1446 );
1447 collection.add(
1448 "bias".to_string(),
1449 Parameter::zeros(vec![2], DeviceType::Cpu)?,
1450 );
1451
1452 let names = collection.names();
1453 assert_eq!(names.len(), 2);
1454 assert!(names.contains(&&"weight".to_string()));
1455 assert!(names.contains(&&"bias".to_string()));
1456 Ok(())
1457 }
1458
1459 #[test]
1460 fn test_parameter_collection_total_parameters() -> Result<()> {
1461 let mut collection = ParameterCollection::new();
1462
1463 collection.add(
1464 "weight".to_string(),
1465 Parameter::zeros(vec![2, 3], DeviceType::Cpu)?,
1466 ); collection.add(
1468 "bias".to_string(),
1469 Parameter::zeros(vec![3], DeviceType::Cpu)?,
1470 ); assert_eq!(collection.total_parameters(), 9);
1473 Ok(())
1474 }
1475
1476 #[test]
1477 fn test_parameter_collection_total_memory_usage() -> Result<()> {
1478 let mut collection = ParameterCollection::new();
1479
1480 collection.add(
1481 "weight".to_string(),
1482 Parameter::zeros(vec![10], DeviceType::Cpu)?,
1483 );
1484
1485 let memory = collection.total_memory_usage();
1486 assert_eq!(memory, 10 * 4); Ok(())
1488 }
1489
1490 #[test]
1491 fn test_parameter_collection_freeze_unfreeze() -> Result<()> {
1492 let mut collection = ParameterCollection::new();
1493
1494 let param = Parameter::zeros(vec![2], DeviceType::Cpu)?;
1495 collection.add("weight".to_string(), param);
1496
1497 collection.freeze_all();
1498 assert!(!collection.get("weight").unwrap().requires_grad());
1499
1500 collection.unfreeze_all();
1501 assert!(collection.get("weight").unwrap().requires_grad());
1502 Ok(())
1503 }
1504
1505 #[test]
1506 fn test_parameter_collection_scale_all() -> Result<()> {
1507 let mut collection = ParameterCollection::new();
1508
1509 let param = Parameter::ones(vec![3], DeviceType::Cpu)?;
1510 collection.add("weight".to_string(), param);
1511
1512 collection.scale_all(2.0)?;
1513
1514 let weight = collection.get("weight").unwrap();
1515 let data = weight.clone_data().to_vec()?;
1516 assert!(data.iter().all(|&x| (x - 2.0).abs() < 1e-5));
1517 Ok(())
1518 }
1519
1520 #[test]
1521 fn test_parameter_collection_clamp_all() -> Result<()> {
1522 let mut collection = ParameterCollection::new();
1523
1524 let data = vec![-5.0, 0.0, 5.0];
1525 let param = Parameter::from_data(data, vec![3])?;
1526 collection.add("weight".to_string(), param);
1527
1528 collection.clamp_all(-1.0, 1.0)?;
1529
1530 let weight = collection.get("weight").unwrap();
1531 let clamped = weight.clone_data().to_vec()?;
1532 assert!(clamped.iter().all(|&x| x >= -1.0 && x <= 1.0));
1533 Ok(())
1534 }
1535
1536 #[test]
1537 fn test_parameter_collection_filter_by_name() -> Result<()> {
1538 let mut collection = ParameterCollection::new();
1539
1540 collection.add(
1541 "layer1.weight".to_string(),
1542 Parameter::zeros(vec![2], DeviceType::Cpu)?,
1543 );
1544 collection.add(
1545 "layer1.bias".to_string(),
1546 Parameter::zeros(vec![2], DeviceType::Cpu)?,
1547 );
1548 collection.add(
1549 "layer2.weight".to_string(),
1550 Parameter::zeros(vec![2], DeviceType::Cpu)?,
1551 );
1552
1553 let filtered = collection.filter_by_name("layer1");
1554 assert_eq!(filtered.len(), 2);
1555
1556 let filtered_weight = collection.filter_by_name("weight");
1557 assert_eq!(filtered_weight.len(), 2);
1558 Ok(())
1559 }
1560
1561 #[test]
1562 fn test_parameter_collection_filter_by_predicate() -> Result<()> {
1563 let mut collection = ParameterCollection::new();
1564
1565 collection.add(
1566 "weight".to_string(),
1567 Parameter::zeros(vec![10], DeviceType::Cpu)?,
1568 );
1569 collection.add(
1570 "bias".to_string(),
1571 Parameter::zeros(vec![5], DeviceType::Cpu)?,
1572 );
1573
1574 let filtered = collection.filter_by(|_, param| param.numel().unwrap_or(0) > 5);
1576 assert_eq!(filtered.len(), 1);
1577 assert!(filtered.get("weight").is_some());
1578 Ok(())
1579 }
1580
1581 #[test]
1582 fn test_parameter_collection_stats() -> Result<()> {
1583 let mut collection = ParameterCollection::new();
1584
1585 collection.add(
1586 "weight".to_string(),
1587 Parameter::ones(vec![5], DeviceType::Cpu)?,
1588 );
1589
1590 let stats = collection.stats()?;
1591 assert_eq!(stats.len(), 1);
1592
1593 let weight_stats = stats.get("weight").unwrap();
1594 assert_relative_eq!(weight_stats.mean, 1.0, epsilon = 1e-5);
1595 Ok(())
1596 }
1597
1598 #[test]
1599 fn test_parameter_collection_diagnose() -> Result<()> {
1600 let mut collection = ParameterCollection::new();
1601
1602 collection.add(
1603 "weight".to_string(),
1604 Parameter::ones(vec![3], DeviceType::Cpu)?,
1605 );
1606
1607 let diagnostics = collection.diagnose()?;
1608 assert_eq!(diagnostics.len(), 1);
1609
1610 let weight_diag = diagnostics.get("weight").unwrap();
1611 assert!(weight_diag.is_finite);
1612 Ok(())
1613 }
1614
1615 #[test]
1616 fn test_parameter_collection_summary_report() -> Result<()> {
1617 let mut collection = ParameterCollection::new();
1618
1619 collection.add(
1620 "weight".to_string(),
1621 Parameter::ones(vec![10], DeviceType::Cpu)?,
1622 );
1623 collection.add(
1624 "bias".to_string(),
1625 Parameter::zeros(vec![5], DeviceType::Cpu)?,
1626 );
1627
1628 let report = collection.summary_report()?;
1629
1630 assert!(report.contains("Total parameters: 2"));
1631 assert!(report.contains("Total elements: 15"));
1632 assert!(report.contains("weight"));
1633 assert!(report.contains("bias"));
1634 Ok(())
1635 }
1636
1637 #[test]
1638 fn test_parameter_collection_from_hashmap() -> Result<()> {
1639 let mut map = HashMap::new();
1640 map.insert(
1641 "weight".to_string(),
1642 Parameter::zeros(vec![3], DeviceType::Cpu)?,
1643 );
1644
1645 let collection = ParameterCollection::from_map(map);
1646 assert_eq!(collection.len(), 1);
1647 Ok(())
1648 }
1649
1650 #[test]
1651 fn test_parameter_collection_into_hashmap() -> Result<()> {
1652 let mut collection = ParameterCollection::new();
1653 collection.add(
1654 "weight".to_string(),
1655 Parameter::zeros(vec![3], DeviceType::Cpu)?,
1656 );
1657
1658 let map: HashMap<String, Parameter> = collection.into();
1659 assert_eq!(map.len(), 1);
1660 assert!(map.contains_key("weight"));
1661 Ok(())
1662 }
1663}