// torsh_nn/parameter/mod.rs
//! Parameter management system for neural network modules
//!
//! This module provides comprehensive parameter management including:
//! - Parameter struct with thread-safe tensor access
//! - Comprehensive initialization methods
//! - Parameter analysis and diagnostics
//! - Parameter collections and batch operations
9pub mod parameter_ext;
10
11pub use parameter_ext::{
12    ParameterAnalysis, ParameterCollectionExt, ParameterConstraint, ParameterExt, ParameterGroup,
13};
14
15use crate::init::Initializer;
16use parking_lot::RwLock;
17use std::sync::Arc;
18use torsh_core::device::DeviceType;
19use torsh_core::error::Result;
20use torsh_tensor::Tensor;
21
22// Conditional imports for std/no_std compatibility
23#[cfg(feature = "std")]
24use std::collections::HashMap;
25
26#[cfg(not(feature = "std"))]
27use hashbrown::HashMap;
28
/// Parameter wrapper for trainable tensors
///
/// The tensor lives behind `Arc<RwLock<…>>`, so cloning a `Parameter`
/// aliases the same storage (a clone shares writes with the original).
#[derive(Clone, Debug)]
pub struct Parameter {
    /// Shared, lock-protected tensor storage.
    data: Arc<RwLock<Tensor>>,
    /// Whether this parameter should receive gradients. Flag only for now —
    /// torsh_tensor has no autograd yet (see `requires_grad_`).
    requires_grad: bool,
}
35
36impl Parameter {
37    /// Create a new parameter
38    pub fn new(tensor: Tensor) -> Self {
39        Self {
40            data: Arc::new(RwLock::new(tensor)),
41            requires_grad: true,
42        }
43    }
44
45    /// Create a parameter that doesn't require gradients
46    pub fn new_no_grad(tensor: Tensor) -> Self {
47        Self {
48            data: Arc::new(RwLock::new(tensor)),
49            requires_grad: false,
50        }
51    }
52
53    /// Get the underlying tensor
54    pub fn tensor(&self) -> Arc<RwLock<Tensor>> {
55        self.data.clone()
56    }
57
58    /// Create a parameter from an existing tensor Arc
59    pub fn from_tensor(tensor: Arc<RwLock<Tensor>>) -> Self {
60        Self {
61            data: tensor,
62            requires_grad: true,
63        }
64    }
65
66    /// Set whether this parameter requires gradients
67    pub fn requires_grad_(mut self, requires_grad: bool) -> Self {
68        self.requires_grad = requires_grad;
69        // Note: torsh_tensor doesn't support requires_grad yet
70        // This will be implemented when autograd is available
71        self
72    }
73
74    /// Check if parameter requires gradients
75    pub fn requires_grad(&self) -> bool {
76        self.requires_grad
77    }
78
79    /// Get parameter shape
80    pub fn shape(&self) -> Result<Vec<usize>> {
81        Ok(self.data.read().shape().dims().to_vec())
82    }
83
84    /// Get parameter size (number of elements)
85    pub fn numel(&self) -> Result<usize> {
86        Ok(self.data.read().shape().numel())
87    }
88
89    /// Move parameter to device
90    pub fn to_device(&mut self, device: DeviceType) -> Result<()> {
91        // This would move the tensor to the specified device
92        // For now, just update the device field when tensor supports it
93        let _ = device; // Suppress warning
94        Ok(())
95    }
96
97    /// Zero the parameter gradients
98    pub fn zero_grad(&mut self) {
99        // This would zero gradients when autograd is available
100        // For now, this is a placeholder
101    }
102
103    /// Clone the parameter data
104    pub fn clone_data(&self) -> Tensor {
105        self.data.read().clone()
106    }
107}
108
109/// Enhanced parameter management utilities
110impl Parameter {
111    /// Create parameter with specific initialization function
112    ///
113    /// This is the most flexible parameter creation method, allowing custom
114    /// initialization logic.
115    pub fn with_init<F>(shape: Vec<usize>, _device: DeviceType, init_fn: F) -> Result<Self>
116    where
117        F: FnOnce(Vec<usize>) -> Result<Tensor>,
118    {
119        let tensor = init_fn(shape)?;
120        Ok(Self::new(tensor))
121    }
122
123    /// Create parameter from existing data
124    ///
125    /// Convenient method to create a parameter from a vector of data.
126    pub fn from_data(data: Vec<f32>, shape: Vec<usize>) -> Result<Self> {
127        let tensor = torsh_tensor::Tensor::from_vec(data, &shape)?;
128        Ok(Self::new(tensor))
129    }
130
131    /// Create parameter with automatic shape inference
132    ///
133    /// Creates a parameter where the shape is inferred from the provided data.
134    pub fn from_data_auto_shape(data: Vec<f32>) -> Result<Self> {
135        let shape = vec![data.len()];
136        Self::from_data(data, shape)
137    }
138
139    /// Create parameter with random initialization and automatic fan calculation
140    ///
141    /// This method automatically chooses the best initialization based on the layer type.
142    pub fn auto_init(shape: Vec<usize>, device: DeviceType, layer_type: LayerType) -> Result<Self> {
143        use crate::init::InitMethod;
144
145        let init_method = match layer_type {
146            LayerType::Linear | LayerType::Dense => InitMethod::KaimingUniform {
147                mode: crate::init::FanMode::FanIn,
148                nonlinearity: crate::init::Nonlinearity::Linear,
149            },
150            LayerType::Conv => InitMethod::KaimingUniform {
151                mode: crate::init::FanMode::FanOut,
152                nonlinearity: crate::init::Nonlinearity::ReLU,
153            },
154            LayerType::RNN | LayerType::LSTM | LayerType::GRU => {
155                InitMethod::XavierUniform { gain: 1.0 }
156            }
157            LayerType::Attention => InitMethod::XavierNormal { gain: 1.0 },
158            LayerType::Embedding => InitMethod::Normal {
159                mean: 0.0,
160                std: 1.0,
161            },
162            LayerType::Bias => InitMethod::Constant { value: 0.0 },
163            LayerType::BatchNorm => InitMethod::Constant { value: 1.0 },
164            LayerType::Custom => InitMethod::KaimingUniform {
165                mode: crate::init::FanMode::FanIn,
166                nonlinearity: crate::init::Nonlinearity::ReLU,
167            },
168        };
169
170        Self::with_init_method(shape, device, init_method)
171    }
172
173    /// Create parameter filled with zeros
174    pub fn zeros(shape: Vec<usize>, _device: DeviceType) -> Result<Self> {
175        use torsh_tensor::creation::zeros;
176        let tensor = zeros(&shape)?;
177        Ok(Self::new(tensor))
178    }
179
180    /// Create parameter filled with ones
181    pub fn ones(shape: Vec<usize>, _device: DeviceType) -> Result<Self> {
182        use torsh_tensor::creation::ones;
183        let tensor = ones(&shape)?;
184        Ok(Self::new(tensor))
185    }
186
187    /// Create parameter using InitMethod enum
188    pub fn with_init_method(
189        shape: Vec<usize>,
190        _device: DeviceType,
191        method: crate::init::InitMethod,
192    ) -> Result<Self> {
193        let tensor = method.initialize(&shape)?;
194        Ok(Self::new(tensor))
195    }
196
197    /// Create parameter with uniform random initialization
198    pub fn uniform(shape: Vec<usize>, device: DeviceType, low: f32, high: f32) -> Result<Self> {
199        use crate::init::InitMethod;
200        Self::with_init_method(shape, device, InitMethod::Uniform { low, high })
201    }
202
203    /// Create parameter with normal random initialization
204    pub fn normal(shape: Vec<usize>, device: DeviceType, mean: f32, std: f32) -> Result<Self> {
205        use crate::init::InitMethod;
206        Self::with_init_method(shape, device, InitMethod::Normal { mean, std })
207    }
208
209    /// Create parameter with Xavier/Glorot uniform initialization
210    pub fn xavier_uniform(shape: Vec<usize>, device: DeviceType, gain: f32) -> Result<Self> {
211        use crate::init::InitMethod;
212        Self::with_init_method(shape, device, InitMethod::XavierUniform { gain })
213    }
214
215    /// Create parameter with Xavier/Glorot normal initialization
216    pub fn xavier_normal(shape: Vec<usize>, device: DeviceType, gain: f32) -> Result<Self> {
217        use crate::init::InitMethod;
218        Self::with_init_method(shape, device, InitMethod::XavierNormal { gain })
219    }
220
221    /// Create parameter with Kaiming/He uniform initialization
222    pub fn kaiming_uniform(
223        shape: Vec<usize>,
224        device: DeviceType,
225        nonlinearity: &str,
226    ) -> Result<Self> {
227        use crate::init::{FanMode, InitMethod, Nonlinearity};
228        let nl = match nonlinearity {
229            "relu" => Nonlinearity::ReLU,
230            "leaky_relu" => Nonlinearity::LeakyReLU {
231                negative_slope: 0.01,
232            },
233            "tanh" => Nonlinearity::Tanh,
234            "sigmoid" => Nonlinearity::Sigmoid,
235            "selu" => Nonlinearity::SELU,
236            "elu" => Nonlinearity::ELU,
237            "swish" => Nonlinearity::Swish,
238            "linear" => Nonlinearity::Linear,
239            _ => Nonlinearity::Linear,
240        };
241        Self::with_init_method(
242            shape,
243            device,
244            InitMethod::KaimingUniform {
245                mode: FanMode::FanIn,
246                nonlinearity: nl,
247            },
248        )
249    }
250
251    /// Create parameter with Kaiming/He normal initialization
252    pub fn kaiming_normal(
253        shape: Vec<usize>,
254        device: DeviceType,
255        nonlinearity: &str,
256    ) -> Result<Self> {
257        use crate::init::{FanMode, InitMethod, Nonlinearity};
258        let nl = match nonlinearity {
259            "relu" => Nonlinearity::ReLU,
260            "leaky_relu" => Nonlinearity::LeakyReLU {
261                negative_slope: 0.01,
262            },
263            "tanh" => Nonlinearity::Tanh,
264            "sigmoid" => Nonlinearity::Sigmoid,
265            "selu" => Nonlinearity::SELU,
266            "elu" => Nonlinearity::ELU,
267            "swish" => Nonlinearity::Swish,
268            "linear" => Nonlinearity::Linear,
269            _ => Nonlinearity::Linear,
270        };
271        Self::with_init_method(
272            shape,
273            device,
274            InitMethod::KaimingNormal {
275                mode: FanMode::FanIn,
276                nonlinearity: nl,
277            },
278        )
279    }
280
281    /// Create parameter with constant value
282    pub fn constant(shape: Vec<usize>, device: DeviceType, value: f32) -> Result<Self> {
283        use crate::init::InitMethod;
284        Self::with_init_method(shape, device, InitMethod::Constant { value })
285    }
286
287    /// Create parameter with orthogonal initialization
288    pub fn orthogonal(shape: Vec<usize>, device: DeviceType, gain: f32) -> Result<Self> {
289        use crate::init::InitMethod;
290        Self::with_init_method(shape, device, InitMethod::Orthogonal { gain })
291    }
292
293    /// Create parameter with sparse initialization
294    pub fn sparse(shape: Vec<usize>, device: DeviceType, sparsity: f32, std: f32) -> Result<Self> {
295        use crate::init::InitMethod;
296        Self::with_init_method(shape, device, InitMethod::Sparse { sparsity, std })
297    }
298
299    /// Create parameter with Lecun uniform initialization
300    pub fn lecun_uniform(shape: Vec<usize>, device: DeviceType) -> Result<Self> {
301        use crate::init::InitMethod;
302        Self::with_init_method(shape, device, InitMethod::LecunUniform)
303    }
304
305    /// Create parameter with Lecun normal initialization
306    pub fn lecun_normal(shape: Vec<usize>, device: DeviceType) -> Result<Self> {
307        use crate::init::InitMethod;
308        Self::with_init_method(shape, device, InitMethod::LecunNormal)
309    }
310
311    /// Create parameter with truncated normal initialization
312    pub fn truncated_normal(
313        shape: Vec<usize>,
314        device: DeviceType,
315        mean: f32,
316        std: f32,
317        a: f32,
318        b: f32,
319    ) -> Result<Self> {
320        use crate::init::InitMethod;
321        Self::with_init_method(
322            shape,
323            device,
324            InitMethod::TruncatedNormal { mean, std, a, b },
325        )
326    }
327
328    /// Create parameter with eye/identity initialization
329    pub fn eye(shape: Vec<usize>, device: DeviceType) -> Result<Self> {
330        use crate::init::InitMethod;
331        Self::with_init_method(shape, device, InitMethod::Eye)
332    }
333
334    /// Get parameter statistics
335    pub fn stats(&self) -> Result<ParameterStats> {
336        let tensor = self.data.read();
337        let data = tensor.to_vec()?;
338
339        if data.is_empty() {
340            return Ok(ParameterStats {
341                mean: 0.0,
342                std: 0.0,
343                variance: 0.0,
344                min: 0.0,
345                max: 0.0,
346                numel: 0,
347                median: 0.0,
348                q25: 0.0,
349                q75: 0.0,
350                skewness: 0.0,
351                kurtosis: 0.0,
352            });
353        }
354
355        let mean = data.iter().sum::<f32>() / data.len() as f32;
356        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
357        let std = variance.sqrt();
358        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
359        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
360
361        // Calculate additional statistics
362        let mut sorted_data = data.clone();
363        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
364
365        let median = if sorted_data.len() % 2 == 0 {
366            (sorted_data[sorted_data.len() / 2 - 1] + sorted_data[sorted_data.len() / 2]) / 2.0
367        } else {
368            sorted_data[sorted_data.len() / 2]
369        };
370
371        let q25_idx = sorted_data.len() / 4;
372        let q75_idx = 3 * sorted_data.len() / 4;
373        let q25 = sorted_data.get(q25_idx).copied().unwrap_or(min);
374        let q75 = sorted_data.get(q75_idx).copied().unwrap_or(max);
375
376        // Basic skewness and kurtosis calculations
377        let n = data.len() as f32;
378        let skewness = if std > 0.0 {
379            data.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f32>() / n
380        } else {
381            0.0
382        };
383
384        let kurtosis = if std > 0.0 {
385            data.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f32>() / n - 3.0
386        } else {
387            0.0
388        };
389
390        Ok(ParameterStats {
391            mean,
392            std,
393            variance,
394            min,
395            max,
396            numel: data.len(),
397            median,
398            q25,
399            q75,
400            skewness,
401            kurtosis,
402        })
403    }
404
405    /// Check if parameter has finite values (no NaN or infinity)
406    pub fn is_finite(&self) -> Result<bool> {
407        let tensor = self.data.read();
408        let data = tensor.to_vec()?;
409        Ok(data.iter().all(|x| x.is_finite()))
410    }
411
412    /// Reinitialize parameter with a new method
413    pub fn reinitialize(&mut self, method: crate::init::InitMethod) -> Result<()> {
414        let current_shape = self.shape()?;
415        let new_tensor = method.initialize(&current_shape)?;
416        *self.data.write() = new_tensor;
417        Ok(())
418    }
419
420    /// Get parameter norm (L2 norm)
421    pub fn norm(&self) -> Result<f32> {
422        let tensor = self.data.read();
423        let data = tensor.to_vec()?;
424        let norm = data.iter().map(|x| x * x).sum::<f32>().sqrt();
425        Ok(norm)
426    }
427
428    /// Get parameter L1 norm
429    pub fn l1_norm(&self) -> Result<f32> {
430        let tensor = self.data.read();
431        let data = tensor.to_vec()?;
432        let norm = data.iter().map(|x| x.abs()).sum::<f32>();
433        Ok(norm)
434    }
435
436    /// Get parameter L-infinity norm (max absolute value)
437    pub fn linf_norm(&self) -> Result<f32> {
438        let tensor = self.data.read();
439        let data = tensor.to_vec()?;
440        let norm = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
441        Ok(norm)
442    }
443
444    /// Clamp parameter values to a range
445    pub fn clamp(&mut self, min: f32, max: f32) -> Result<()> {
446        let mut tensor = self.data.write();
447        let data = tensor.to_vec()?;
448        let clamped_data: Vec<f32> = data.iter().map(|&x| x.clamp(min, max)).collect();
449        let shape = tensor.shape().dims().to_vec();
450        *tensor = torsh_tensor::Tensor::from_vec(clamped_data, &shape)?;
451        Ok(())
452    }
453
454    /// Apply a function to all parameter values
455    pub fn apply_fn<F>(&mut self, f: F) -> Result<()>
456    where
457        F: Fn(f32) -> f32,
458    {
459        let mut tensor = self.data.write();
460        let data = tensor.to_vec()?;
461        let transformed_data: Vec<f32> = data.iter().map(|&x| f(x)).collect();
462        let shape = tensor.shape().dims().to_vec();
463        *tensor = torsh_tensor::Tensor::from_vec(transformed_data, &shape)?;
464        Ok(())
465    }
466
467    /// Scale parameter by a factor
468    pub fn scale(&mut self, factor: f32) -> Result<()> {
469        self.apply_fn(|x| x * factor)
470    }
471
472    /// Add noise to parameter
473    pub fn add_noise(&mut self, std: f32) -> Result<()> {
474        use scirs2_core::random::thread_rng;
475        let mut rng = thread_rng();
476        let mut tensor = self.data.write();
477        let data = tensor.to_vec()?;
478        let noisy_data: Vec<f32> = data
479            .iter()
480            .map(|&x| x + rng.random::<f32>() * std)
481            .collect();
482        let shape = tensor.shape().dims().to_vec();
483        *tensor = torsh_tensor::Tensor::from_vec(noisy_data, &shape)?;
484        Ok(())
485    }
486
487    /// Get parameter histogram for analysis
488    pub fn histogram(&self, bins: usize) -> Result<Vec<(f32, usize)>> {
489        let tensor = self.data.read();
490        let data = tensor.to_vec()?;
491
492        if data.is_empty() {
493            return Ok(Vec::new());
494        }
495
496        let min_val = data.iter().copied().fold(f32::INFINITY, f32::min);
497        let max_val = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
498
499        if min_val == max_val {
500            return Ok(vec![(min_val, data.len())]);
501        }
502
503        let bin_width = (max_val - min_val) / bins as f32;
504        let mut histogram = vec![0; bins];
505
506        for &value in &data {
507            let bin_index = ((value - min_val) / bin_width).floor() as usize;
508            let bin_index = bin_index.min(bins - 1);
509            histogram[bin_index] += 1;
510        }
511
512        let result: Vec<(f32, usize)> = histogram
513            .into_iter()
514            .enumerate()
515            .map(|(i, count)| (min_val + (i as f32 + 0.5) * bin_width, count))
516            .collect();
517
518        Ok(result)
519    }
520
521    /// Check for common parameter issues
522    pub fn diagnose(&self) -> Result<ParameterDiagnostics> {
523        let stats = self.stats()?;
524        let mut issues = Vec::new();
525        let mut warnings = Vec::new();
526
527        // Check for NaN or infinite values
528        if !self.is_finite()? {
529            issues.push("Parameter contains NaN or infinite values".to_string());
530        }
531
532        // Check for suspicious statistics
533        if stats.std < 1e-6 {
534            warnings
535                .push("Very low standard deviation - parameters may be too uniform".to_string());
536        }
537
538        if stats.std > 10.0 {
539            warnings.push("Very high standard deviation - parameters may be unstable".to_string());
540        }
541
542        if stats.mean.abs() > 5.0 {
543            warnings
544                .push("High mean absolute value - parameters may be poorly centered".to_string());
545        }
546
547        // Check gradient-related issues
548        let norm = self.norm()?;
549        if norm < 1e-8 {
550            warnings
551                .push("Very small parameter norm - may indicate vanishing gradients".to_string());
552        } else if norm > 100.0 {
553            warnings
554                .push("Very large parameter norm - may indicate exploding gradients".to_string());
555        }
556
557        Ok(ParameterDiagnostics {
558            stats,
559            issues,
560            warnings,
561            norm,
562            is_finite: self.is_finite()?,
563        })
564    }
565}
566
/// Layer type enumeration for automatic parameter initialization
///
/// Used by [`Parameter::auto_init`] to select an initialization scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LayerType {
    /// Fully-connected layer (Kaiming uniform, fan-in, linear nonlinearity)
    Linear,
    /// Alias of `Linear` for frameworks using "dense" terminology
    Dense,
    /// Convolutional layer (Kaiming uniform, fan-out, ReLU)
    Conv,
    /// Vanilla recurrent layer (Xavier uniform, gain 1.0)
    RNN,
    /// LSTM layer (Xavier uniform, gain 1.0)
    LSTM,
    /// GRU layer (Xavier uniform, gain 1.0)
    GRU,
    /// Attention projection (Xavier normal, gain 1.0)
    Attention,
    /// Embedding table (normal, mean 0, std 1)
    Embedding,
    /// Bias vector (constant 0)
    Bias,
    /// BatchNorm scale (constant 1)
    BatchNorm,
    /// Fallback for anything else (Kaiming uniform, fan-in, ReLU)
    Custom,
}
582
/// Parameter statistics for analysis and debugging
#[derive(Debug, Clone)]
pub struct ParameterStats {
    /// Arithmetic mean of all values
    pub mean: f32,
    /// Population standard deviation (sqrt of `variance`)
    pub std: f32,
    /// Population variance (divides by N, not N-1)
    pub variance: f32,
    /// Smallest value
    pub min: f32,
    /// Largest value
    pub max: f32,
    /// Number of elements the statistics cover
    pub numel: usize,
    /// 50th percentile
    pub median: f32,
    /// 25th percentile
    pub q25: f32,
    /// 75th percentile
    pub q75: f32,
    /// Standardized third moment (0 for constant data)
    pub skewness: f32,
    /// Excess kurtosis (fourth moment minus 3; 0 for constant data)
    pub kurtosis: f32,
}

impl ParameterStats {
    /// Create parameter statistics from data
    ///
    /// Returns [`Self::empty`] for an empty slice.
    ///
    /// Fixed: sorting previously used `partial_cmp(..).expect(..)` and
    /// panicked whenever the data contained NaN — which parameters can
    /// (that's why `Parameter::is_finite` exists). `f32::total_cmp` gives a
    /// total order (NaNs sort after all numbers); NaNs still propagate into
    /// the moment statistics as NaN.
    pub fn from_data(data: &[f32]) -> Self {
        if data.is_empty() {
            return Self::empty();
        }

        let n = data.len() as f32;
        let mean = data.iter().sum::<f32>() / n;
        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        let std = variance.sqrt();

        let mut sorted_data = data.to_vec();
        sorted_data.sort_by(|a, b| a.total_cmp(b));

        let min = sorted_data[0];
        let max = sorted_data[sorted_data.len() - 1];
        let median = Self::percentile(&sorted_data, 0.5);
        let q25 = Self::percentile(&sorted_data, 0.25);
        let q75 = Self::percentile(&sorted_data, 0.75);

        // Standardized moments, guarded so constant data (std == 0)
        // yields 0 instead of dividing by zero.
        let skewness = if std > 0.0 {
            data.iter().map(|x| ((x - mean) / std).powi(3)).sum::<f32>() / n
        } else {
            0.0
        };

        let kurtosis = if std > 0.0 {
            data.iter().map(|x| ((x - mean) / std).powi(4)).sum::<f32>() / n - 3.0
        } else {
            0.0
        };

        Self {
            mean,
            std,
            variance,
            min,
            max,
            numel: data.len(),
            median,
            q25,
            q75,
            skewness,
            kurtosis,
        }
    }

    /// Create empty statistics (all fields zero)
    pub fn empty() -> Self {
        Self {
            mean: 0.0,
            std: 0.0,
            variance: 0.0,
            min: 0.0,
            max: 0.0,
            numel: 0,
            median: 0.0,
            q25: 0.0,
            q75: 0.0,
            skewness: 0.0,
            kurtosis: 0.0,
        }
    }

    /// Calculate the `p`-quantile (0.0..=1.0) from already-sorted data,
    /// using linear interpolation between the two nearest ranks.
    fn percentile(sorted_data: &[f32], p: f32) -> f32 {
        if sorted_data.is_empty() {
            return 0.0;
        }

        let index = p * (sorted_data.len() - 1) as f32;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            sorted_data[lower]
        } else {
            let weight = index - lower as f32;
            sorted_data[lower] * (1.0 - weight) + sorted_data[upper] * weight
        }
    }

    /// Get interquartile range (`q75 - q25`)
    pub fn iqr(&self) -> f32 {
        self.q75 - self.q25
    }

    /// Heuristic normality check: |skewness| < 1 and |excess kurtosis| < 3.
    pub fn is_approximately_normal(&self) -> bool {
        self.skewness.abs() < 1.0 && self.kurtosis.abs() < 3.0
    }
}

impl core::fmt::Display for ParameterStats {
    /// Compact single-line summary for logs and reports.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(
            f,
            "ParameterStats(mean={:.4}, std={:.4}, min={:.4}, max={:.4}, numel={})",
            self.mean, self.std, self.min, self.max, self.numel
        )
    }
}
710
/// Parameter diagnostics for debugging and analysis
///
/// Produced by [`Parameter::diagnose`]; see that method for the thresholds
/// behind `issues` and `warnings`.
#[derive(Debug, Clone)]
pub struct ParameterDiagnostics {
    /// Summary statistics of the parameter values
    pub stats: ParameterStats,
    /// Hard problems (currently: non-finite values present)
    pub issues: Vec<String>,
    /// Heuristic concerns (extreme std / mean / norm)
    pub warnings: Vec<String>,
    /// L2 norm of the parameter values
    pub norm: f32,
    /// True when no value is NaN or infinite
    pub is_finite: bool,
}
720
721impl core::fmt::Display for ParameterDiagnostics {
722    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
723        writeln!(f, "Parameter Diagnostics:")?;
724        writeln!(f, "  {}", self.stats)?;
725        writeln!(f, "  Norm: {:.6}", self.norm)?;
726        writeln!(f, "  Finite: {}", self.is_finite)?;
727
728        if !self.issues.is_empty() {
729            writeln!(f, "  Issues:")?;
730            for issue in &self.issues {
731                writeln!(f, "    - {issue}")?;
732            }
733        }
734
735        if !self.warnings.is_empty() {
736            writeln!(f, "  Warnings:")?;
737            for warning in &self.warnings {
738                writeln!(f, "    - {warning}")?;
739            }
740        }
741
742        Ok(())
743    }
744}
745
/// Parameter collection utility for managing multiple parameters
///
/// This provides convenient methods for working with collections of parameters,
/// such as applying operations to all parameters in a module.
#[derive(Debug, Clone)]
pub struct ParameterCollection {
    /// Name -> parameter map (iteration order is unspecified)
    parameters: HashMap<String, Parameter>,
}
754
755impl ParameterCollection {
756    /// Create a new parameter collection
757    pub fn new() -> Self {
758        Self {
759            parameters: HashMap::new(),
760        }
761    }
762
763    /// Create from a parameter map
764    pub fn from_map(parameters: HashMap<String, Parameter>) -> Self {
765        Self { parameters }
766    }
767
768    /// Add a parameter to the collection
769    pub fn add(&mut self, name: String, parameter: Parameter) {
770        self.parameters.insert(name, parameter);
771    }
772
773    /// Get a parameter by name
774    pub fn get(&self, name: &str) -> Option<&Parameter> {
775        self.parameters.get(name)
776    }
777
778    /// Get a mutable parameter by name
779    pub fn get_mut(&mut self, name: &str) -> Option<&mut Parameter> {
780        self.parameters.get_mut(name)
781    }
782
783    /// Get all parameter names
784    pub fn names(&self) -> Vec<&String> {
785        self.parameters.keys().collect()
786    }
787
788    /// Get the number of parameters
789    pub fn len(&self) -> usize {
790        self.parameters.len()
791    }
792
793    /// Check if the collection is empty
794    pub fn is_empty(&self) -> bool {
795        self.parameters.is_empty()
796    }
797
798    /// Apply a function to all parameters
799    pub fn apply_to_all<F>(&mut self, mut f: F) -> Result<()>
800    where
801        F: FnMut(&mut Parameter) -> Result<()>,
802    {
803        for param in self.parameters.values_mut() {
804            f(param)?;
805        }
806        Ok(())
807    }
808
809    /// Get statistics for all parameters
810    pub fn stats(&self) -> Result<HashMap<String, ParameterStats>> {
811        let mut stats = HashMap::new();
812        for (name, param) in &self.parameters {
813            stats.insert(name.clone(), param.stats()?);
814        }
815        Ok(stats)
816    }
817
818    /// Get diagnostics for all parameters
819    pub fn diagnose(&self) -> Result<HashMap<String, ParameterDiagnostics>> {
820        let mut diagnostics = HashMap::new();
821        for (name, param) in &self.parameters {
822            diagnostics.insert(name.clone(), param.diagnose()?);
823        }
824        Ok(diagnostics)
825    }
826
827    /// Get total parameter count
828    pub fn total_parameters(&self) -> usize {
829        self.parameters
830            .values()
831            .map(|p| p.numel().unwrap_or(0))
832            .sum()
833    }
834
835    /// Get total memory usage
836    pub fn total_memory_usage(&self) -> usize {
837        self.parameters
838            .values()
839            .map(|p| p.numel().unwrap_or(0) * 4) // Assume f32
840            .sum()
841    }
842
843    /// Freeze all parameters
844    pub fn freeze_all(&mut self) {
845        for param in self.parameters.values_mut() {
846            param.requires_grad = false;
847        }
848    }
849
850    /// Unfreeze all parameters
851    pub fn unfreeze_all(&mut self) {
852        for param in self.parameters.values_mut() {
853            param.requires_grad = true;
854        }
855    }
856
857    /// Scale all parameters by a factor
858    pub fn scale_all(&mut self, factor: f32) -> Result<()> {
859        self.apply_to_all(|param| param.scale(factor))
860    }
861
862    /// Clamp all parameters to a range
863    pub fn clamp_all(&mut self, min: f32, max: f32) -> Result<()> {
864        self.apply_to_all(|param| param.clamp(min, max))
865    }
866
867    /// Add noise to all parameters
868    pub fn add_noise_all(&mut self, std: f32) -> Result<()> {
869        self.apply_to_all(|param| param.add_noise(std))
870    }
871
872    /// Filter parameters by name pattern
873    pub fn filter_by_name(&self, pattern: &str) -> ParameterCollection {
874        let filtered: HashMap<String, Parameter> = self
875            .parameters
876            .iter()
877            .filter(|(name, _)| name.contains(pattern))
878            .map(|(name, param)| (name.clone(), param.clone()))
879            .collect();
880        ParameterCollection::from_map(filtered)
881    }
882
883    /// Filter parameters by predicate
884    pub fn filter_by<F>(&self, predicate: F) -> ParameterCollection
885    where
886        F: Fn(&String, &Parameter) -> bool,
887    {
888        let filtered: HashMap<String, Parameter> = self
889            .parameters
890            .iter()
891            .filter(|(name, param)| predicate(name, param))
892            .map(|(name, param)| (name.clone(), param.clone()))
893            .collect();
894        ParameterCollection::from_map(filtered)
895    }
896
897    /// Create a summary report of all parameters
898    pub fn summary_report(&self) -> Result<String> {
899        let mut report = String::new();
900        report.push_str("Parameter Collection Summary\n");
901        report.push_str(&format!("Total parameters: {}\n", self.len()));
902        report.push_str(&format!("Total elements: {}\n", self.total_parameters()));
903        report.push_str(&format!(
904            "Memory usage: {:.2} MB\n",
905            self.total_memory_usage() as f64 / (1024.0 * 1024.0)
906        ));
907        report.push_str("\nParameter Details:\n");
908
909        for (name, param) in &self.parameters {
910            let stats = param.stats()?;
911            report.push_str(&format!(
912                "  {}: {} elements, mean={:.4}, std={:.4}\n",
913                name, stats.numel, stats.mean, stats.std
914            ));
915        }
916
917        Ok(report)
918    }
919}
920
impl Default for ParameterCollection {
    /// Equivalent to [`ParameterCollection::new`] — an empty collection.
    fn default() -> Self {
        Self::new()
    }
}
926
impl From<HashMap<String, Parameter>> for ParameterCollection {
    /// Wrap an existing name -> parameter map without copying entries.
    fn from(parameters: HashMap<String, Parameter>) -> Self {
        Self::from_map(parameters)
    }
}
932
impl From<ParameterCollection> for HashMap<String, Parameter> {
    /// Unwrap the collection back into its underlying map.
    fn from(val: ParameterCollection) -> Self {
        val.parameters
    }
}
938
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use torsh_core::device::DeviceType;
    use torsh_tensor::creation::zeros;

    // Coverage: Parameter construction, initializers, statistics, in-place
    // operations, diagnostics, ParameterStats, and ParameterCollection.
    // Randomized initializer tests use loose statistical tolerances so they
    // stay deterministic-enough across RNG seeds.

    // ========================================================================
    // Parameter Creation Tests
    // ========================================================================

    #[test]
    fn test_parameter_new() -> Result<()> {
        let tensor = zeros(&[3, 4])?;
        let param = Parameter::new(tensor);

        // `new` defaults to requires_grad = true
        assert!(param.requires_grad());
        assert_eq!(param.shape()?, vec![3, 4]);
        assert_eq!(param.numel()?, 12);
        Ok(())
    }

    #[test]
    fn test_parameter_new_no_grad() -> Result<()> {
        let tensor = zeros(&[2, 3])?;
        let param = Parameter::new_no_grad(tensor);

        assert!(!param.requires_grad());
        assert_eq!(param.shape()?, vec![2, 3]);
        Ok(())
    }

    #[test]
    fn test_parameter_from_tensor() -> Result<()> {
        let tensor = zeros(&[5])?;
        // Wrap the tensor in Arc<RwLock<..>> first — `from_tensor` shares it.
        let arc_tensor = Arc::new(RwLock::new(tensor));
        let param = Parameter::from_tensor(arc_tensor);

        assert!(param.requires_grad());
        assert_eq!(param.shape()?, vec![5]);
        Ok(())
    }

    #[test]
    fn test_parameter_requires_grad_setter() -> Result<()> {
        let tensor = zeros(&[2, 2])?;
        // Builder-style setter consumes and returns the parameter.
        let param = Parameter::new(tensor).requires_grad_(false);

        assert!(!param.requires_grad());
        Ok(())
    }

    #[test]
    fn test_parameter_from_data() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let param = Parameter::from_data(data.clone(), vec![2, 2])?;

        assert_eq!(param.shape()?, vec![2, 2]);
        assert_eq!(param.numel()?, 4);

        let tensor_data = param.clone_data().to_vec()?;
        assert_eq!(tensor_data, data);
        Ok(())
    }

    #[test]
    fn test_parameter_from_data_auto_shape() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0];
        // Shape is inferred as 1-D with length equal to the data length.
        let param = Parameter::from_data_auto_shape(data.clone())?;

        assert_eq!(param.shape()?, vec![3]);
        assert_eq!(param.numel()?, 3);
        Ok(())
    }

    // ========================================================================
    // Parameter Initialization Tests
    // ========================================================================

    #[test]
    fn test_parameter_zeros() -> Result<()> {
        let param = Parameter::zeros(vec![2, 3], DeviceType::Cpu)?;
        let data = param.clone_data().to_vec()?;

        assert_eq!(param.numel()?, 6);
        assert!(data.iter().all(|&x| x == 0.0));
        Ok(())
    }

    #[test]
    fn test_parameter_ones() -> Result<()> {
        let param = Parameter::ones(vec![3, 2], DeviceType::Cpu)?;
        let data = param.clone_data().to_vec()?;

        assert_eq!(param.numel()?, 6);
        assert!(data.iter().all(|&x| x == 1.0));
        Ok(())
    }

    #[test]
    fn test_parameter_constant() -> Result<()> {
        let param = Parameter::constant(vec![2, 2], DeviceType::Cpu, 5.0)?;
        let data = param.clone_data().to_vec()?;

        assert!(data.iter().all(|&x| (x - 5.0).abs() < 1e-6));
        Ok(())
    }

    #[test]
    fn test_parameter_uniform() -> Result<()> {
        let param = Parameter::uniform(vec![100], DeviceType::Cpu, -1.0, 1.0)?;
        let data = param.clone_data().to_vec()?;

        // Check all values are in range
        assert!(data.iter().all(|&x| x >= -1.0 && x <= 1.0));

        // Check distribution statistics
        let mean = data.iter().sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 0.2); // Should be close to 0
        Ok(())
    }

    #[test]
    fn test_parameter_normal() -> Result<()> {
        let param = Parameter::normal(vec![1000], DeviceType::Cpu, 0.0, 1.0)?;
        let data = param.clone_data().to_vec()?;

        let mean = data.iter().sum::<f32>() / data.len() as f32;
        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;

        // Normal distribution should have mean ≈ 0 and variance ≈ 1
        assert!(mean.abs() < 0.1);
        assert!((variance - 1.0).abs() < 0.2);
        Ok(())
    }

    #[test]
    fn test_parameter_xavier_uniform() -> Result<()> {
        let shape = vec![100, 50];
        let param = Parameter::xavier_uniform(shape.clone(), DeviceType::Cpu, 1.0)?;

        assert_eq!(param.shape()?, shape);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_xavier_normal() -> Result<()> {
        let shape = vec![50, 100];
        let param = Parameter::xavier_normal(shape.clone(), DeviceType::Cpu, 1.0)?;

        assert_eq!(param.shape()?, shape);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_kaiming_uniform() -> Result<()> {
        let shape = vec![64, 32];
        let param = Parameter::kaiming_uniform(shape.clone(), DeviceType::Cpu, "relu")?;

        assert_eq!(param.shape()?, shape);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_kaiming_normal() -> Result<()> {
        let shape = vec![32, 64];
        let param = Parameter::kaiming_normal(shape.clone(), DeviceType::Cpu, "relu")?;

        assert_eq!(param.shape()?, shape);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_lecun_uniform() -> Result<()> {
        let shape = vec![50, 50];
        let param = Parameter::lecun_uniform(shape.clone(), DeviceType::Cpu)?;

        assert_eq!(param.shape()?, shape);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_lecun_normal() -> Result<()> {
        let shape = vec![50, 50];
        let param = Parameter::lecun_normal(shape.clone(), DeviceType::Cpu)?;

        assert_eq!(param.shape()?, shape);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_truncated_normal() -> Result<()> {
        let param = Parameter::truncated_normal(vec![100], DeviceType::Cpu, 0.0, 1.0, -2.0, 2.0)?;

        let data = param.clone_data().to_vec()?;
        // All values should be within truncation bounds
        assert!(data.iter().all(|&x| x >= -2.0 && x <= 2.0));
        Ok(())
    }

    #[test]
    fn test_parameter_eye() -> Result<()> {
        let param = Parameter::eye(vec![3, 3], DeviceType::Cpu)?;
        let data = param.clone_data().to_vec()?;

        // Check diagonal elements are 1 (row-major flat indices)
        assert_relative_eq!(data[0], 1.0, epsilon = 1e-6); // [0,0]
        assert_relative_eq!(data[4], 1.0, epsilon = 1e-6); // [1,1]
        assert_relative_eq!(data[8], 1.0, epsilon = 1e-6); // [2,2]

        // Check off-diagonal elements are 0
        assert_relative_eq!(data[1], 0.0, epsilon = 1e-6);
        assert_relative_eq!(data[2], 0.0, epsilon = 1e-6);
        Ok(())
    }

    #[test]
    fn test_parameter_auto_init_linear() -> Result<()> {
        let param = Parameter::auto_init(vec![10, 5], DeviceType::Cpu, LayerType::Linear)?;

        assert_eq!(param.shape()?, vec![10, 5]);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_auto_init_conv() -> Result<()> {
        let param = Parameter::auto_init(vec![3, 3, 32, 64], DeviceType::Cpu, LayerType::Conv)?;

        assert_eq!(param.shape()?, vec![3, 3, 32, 64]);
        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_auto_init_embedding() -> Result<()> {
        let param = Parameter::auto_init(vec![1000, 128], DeviceType::Cpu, LayerType::Embedding)?;

        assert_eq!(param.shape()?, vec![1000, 128]);
        assert!(param.is_finite()?);
        Ok(())
    }

    // ========================================================================
    // Parameter Statistics Tests
    // ========================================================================

    #[test]
    fn test_parameter_stats() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let param = Parameter::from_data(data, vec![5])?;
        let stats = param.stats()?;

        assert_relative_eq!(stats.mean, 3.0, epsilon = 1e-5);
        assert_relative_eq!(stats.min, 1.0, epsilon = 1e-5);
        assert_relative_eq!(stats.max, 5.0, epsilon = 1e-5);
        assert_eq!(stats.numel, 5);
        assert_relative_eq!(stats.median, 3.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_parameter_stats_empty() -> Result<()> {
        // Empty parameter: stats should be zeroed, not an error.
        let param: Parameter = Parameter::from_data(vec![], vec![0])?;
        let stats = param.stats()?;

        assert_eq!(stats.numel, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
        Ok(())
    }

    #[test]
    fn test_parameter_norm_l2() -> Result<()> {
        let data = vec![3.0, 4.0]; // 3^2 + 4^2 = 25, sqrt = 5
        let param = Parameter::from_data(data, vec![2])?;
        let norm = param.norm()?;

        assert_relative_eq!(norm, 5.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_parameter_norm_l1() -> Result<()> {
        let data = vec![3.0, -4.0, 5.0]; // |3| + |-4| + |5| = 12
        let param = Parameter::from_data(data, vec![3])?;
        let norm = param.l1_norm()?;

        assert_relative_eq!(norm, 12.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_parameter_norm_linf() -> Result<()> {
        let data = vec![1.0, -5.0, 3.0]; // max(|1|, |-5|, |3|) = 5
        let param = Parameter::from_data(data, vec![3])?;
        let norm = param.linf_norm()?;

        assert_relative_eq!(norm, 5.0, epsilon = 1e-5);
        Ok(())
    }

    // ========================================================================
    // Parameter Operations Tests
    // ========================================================================

    #[test]
    fn test_parameter_clamp() -> Result<()> {
        let data = vec![-5.0, 0.0, 5.0, 10.0];
        let mut param = Parameter::from_data(data, vec![4])?;

        param.clamp(0.0, 5.0)?;

        let clamped = param.clone_data().to_vec()?;
        assert_relative_eq!(clamped[0], 0.0, epsilon = 1e-5); // -5 clamped to 0
        assert_relative_eq!(clamped[1], 0.0, epsilon = 1e-5);
        assert_relative_eq!(clamped[2], 5.0, epsilon = 1e-5);
        assert_relative_eq!(clamped[3], 5.0, epsilon = 1e-5); // 10 clamped to 5
        Ok(())
    }

    #[test]
    fn test_parameter_scale() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0];
        let mut param = Parameter::from_data(data, vec![3])?;

        param.scale(2.0)?;

        let scaled = param.clone_data().to_vec()?;
        assert_relative_eq!(scaled[0], 2.0, epsilon = 1e-5);
        assert_relative_eq!(scaled[1], 4.0, epsilon = 1e-5);
        assert_relative_eq!(scaled[2], 6.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_parameter_apply_fn() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0];
        let mut param = Parameter::from_data(data, vec![3])?;

        param.apply_fn(|x| x * x)?; // Square each element

        let result = param.clone_data().to_vec()?;
        assert_relative_eq!(result[0], 1.0, epsilon = 1e-5);
        assert_relative_eq!(result[1], 4.0, epsilon = 1e-5);
        assert_relative_eq!(result[2], 9.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_parameter_add_noise() -> Result<()> {
        let data = vec![0.0; 100];
        let mut param = Parameter::from_data(data, vec![100])?;

        param.add_noise(0.1)?;

        let noisy = param.clone_data().to_vec()?;
        // After adding noise, values should no longer all be 0
        let all_zero = noisy.iter().all(|&x| x == 0.0);
        assert!(!all_zero);
        Ok(())
    }

    #[test]
    fn test_parameter_is_finite() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0];
        let param = Parameter::from_data(data, vec![3])?;

        assert!(param.is_finite()?);
        Ok(())
    }

    #[test]
    fn test_parameter_reinitialize() -> Result<()> {
        let mut param = Parameter::zeros(vec![5], DeviceType::Cpu)?;

        use crate::init::InitMethod;
        param.reinitialize(InitMethod::Constant { value: 7.0 })?;

        let data = param.clone_data().to_vec()?;
        assert!(data.iter().all(|&x| (x - 7.0).abs() < 1e-6));
        Ok(())
    }

    #[test]
    fn test_parameter_histogram() -> Result<()> {
        // Uniformly spaced values 0..100 should spread evenly over 10 bins.
        let data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let param = Parameter::from_data(data, vec![100])?;

        let histogram = param.histogram(10)?;

        assert_eq!(histogram.len(), 10);
        // Each bin should have roughly 10 elements
        for (_, count) in histogram {
            assert!(count >= 9 && count <= 11);
        }
        Ok(())
    }

    #[test]
    fn test_parameter_histogram_constant() -> Result<()> {
        let data = vec![5.0; 10];
        let param = Parameter::from_data(data, vec![10])?;

        let histogram = param.histogram(5)?;

        // All values are the same, should return single bin
        assert_eq!(histogram.len(), 1);
        assert_eq!(histogram[0].1, 10);
        Ok(())
    }

    #[test]
    fn test_parameter_diagnose_normal() -> Result<()> {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let param = Parameter::from_data(data, vec![4])?;

        let diagnostics = param.diagnose()?;

        assert!(diagnostics.is_finite);
        assert!(diagnostics.issues.is_empty());
        assert_eq!(diagnostics.stats.numel, 4);
        Ok(())
    }

    // ========================================================================
    // ParameterStats Tests
    // ========================================================================

    #[test]
    fn test_parameter_stats_from_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = ParameterStats::from_data(&data);

        assert_relative_eq!(stats.mean, 3.0, epsilon = 1e-5);
        assert_relative_eq!(stats.median, 3.0, epsilon = 1e-5);
        assert_relative_eq!(stats.min, 1.0, epsilon = 1e-5);
        assert_relative_eq!(stats.max, 5.0, epsilon = 1e-5);
        assert_eq!(stats.numel, 5);
    }

    #[test]
    fn test_parameter_stats_empty_constructor() {
        let stats = ParameterStats::empty();

        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
        assert_eq!(stats.numel, 0);
    }

    #[test]
    fn test_parameter_stats_iqr() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = ParameterStats::from_data(&data);

        let iqr = stats.iqr();
        assert!(iqr > 0.0);
    }

    #[test]
    fn test_parameter_stats_is_approximately_normal() {
        // Create approximately normal data
        let data: Vec<f32> = vec![-1.0, -0.5, 0.0, 0.5, 1.0, -0.3, 0.3, -0.8, 0.8, -0.2, 0.2];
        let stats = ParameterStats::from_data(&data);

        // This is a simple heuristic check
        assert!(stats.skewness.abs() < 2.0); // Not too skewed
    }

    // ========================================================================
    // ParameterCollection Tests
    // ========================================================================

    #[test]
    fn test_parameter_collection_new() {
        let collection = ParameterCollection::new();

        assert_eq!(collection.len(), 0);
        assert!(collection.is_empty());
    }

    #[test]
    fn test_parameter_collection_add_get() -> Result<()> {
        let mut collection = ParameterCollection::new();

        let param = Parameter::zeros(vec![3, 4], DeviceType::Cpu)?;
        collection.add("weight".to_string(), param);

        assert_eq!(collection.len(), 1);
        assert!(collection.get("weight").is_some());
        assert!(collection.get("bias").is_none());
        Ok(())
    }

    #[test]
    fn test_parameter_collection_names() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::zeros(vec![2, 2], DeviceType::Cpu)?,
        );
        collection.add(
            "bias".to_string(),
            Parameter::zeros(vec![2], DeviceType::Cpu)?,
        );

        let names = collection.names();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&&"weight".to_string()));
        assert!(names.contains(&&"bias".to_string()));
        Ok(())
    }

    #[test]
    fn test_parameter_collection_total_parameters() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::zeros(vec![2, 3], DeviceType::Cpu)?,
        ); // 6 params
        collection.add(
            "bias".to_string(),
            Parameter::zeros(vec![3], DeviceType::Cpu)?,
        ); // 3 params

        assert_eq!(collection.total_parameters(), 9);
        Ok(())
    }

    #[test]
    fn test_parameter_collection_total_memory_usage() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::zeros(vec![10], DeviceType::Cpu)?,
        );

        let memory = collection.total_memory_usage();
        assert_eq!(memory, 10 * 4); // 10 f32 elements * 4 bytes
        Ok(())
    }

    #[test]
    fn test_parameter_collection_freeze_unfreeze() -> Result<()> {
        let mut collection = ParameterCollection::new();

        let param = Parameter::zeros(vec![2], DeviceType::Cpu)?;
        collection.add("weight".to_string(), param);

        collection.freeze_all();
        assert!(!collection.get("weight").unwrap().requires_grad());

        collection.unfreeze_all();
        assert!(collection.get("weight").unwrap().requires_grad());
        Ok(())
    }

    #[test]
    fn test_parameter_collection_scale_all() -> Result<()> {
        let mut collection = ParameterCollection::new();

        let param = Parameter::ones(vec![3], DeviceType::Cpu)?;
        collection.add("weight".to_string(), param);

        collection.scale_all(2.0)?;

        let weight = collection.get("weight").unwrap();
        let data = weight.clone_data().to_vec()?;
        assert!(data.iter().all(|&x| (x - 2.0).abs() < 1e-5));
        Ok(())
    }

    #[test]
    fn test_parameter_collection_clamp_all() -> Result<()> {
        let mut collection = ParameterCollection::new();

        let data = vec![-5.0, 0.0, 5.0];
        let param = Parameter::from_data(data, vec![3])?;
        collection.add("weight".to_string(), param);

        collection.clamp_all(-1.0, 1.0)?;

        let weight = collection.get("weight").unwrap();
        let clamped = weight.clone_data().to_vec()?;
        assert!(clamped.iter().all(|&x| x >= -1.0 && x <= 1.0));
        Ok(())
    }

    #[test]
    fn test_parameter_collection_filter_by_name() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "layer1.weight".to_string(),
            Parameter::zeros(vec![2], DeviceType::Cpu)?,
        );
        collection.add(
            "layer1.bias".to_string(),
            Parameter::zeros(vec![2], DeviceType::Cpu)?,
        );
        collection.add(
            "layer2.weight".to_string(),
            Parameter::zeros(vec![2], DeviceType::Cpu)?,
        );

        // Substring match: "layer1" hits both layer1.* entries.
        let filtered = collection.filter_by_name("layer1");
        assert_eq!(filtered.len(), 2);

        let filtered_weight = collection.filter_by_name("weight");
        assert_eq!(filtered_weight.len(), 2);
        Ok(())
    }

    #[test]
    fn test_parameter_collection_filter_by_predicate() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::zeros(vec![10], DeviceType::Cpu)?,
        );
        collection.add(
            "bias".to_string(),
            Parameter::zeros(vec![5], DeviceType::Cpu)?,
        );

        // Filter parameters with > 5 elements
        let filtered = collection.filter_by(|_, param| param.numel().unwrap_or(0) > 5);
        assert_eq!(filtered.len(), 1);
        assert!(filtered.get("weight").is_some());
        Ok(())
    }

    #[test]
    fn test_parameter_collection_stats() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::ones(vec![5], DeviceType::Cpu)?,
        );

        let stats = collection.stats()?;
        assert_eq!(stats.len(), 1);

        let weight_stats = stats.get("weight").unwrap();
        assert_relative_eq!(weight_stats.mean, 1.0, epsilon = 1e-5);
        Ok(())
    }

    #[test]
    fn test_parameter_collection_diagnose() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::ones(vec![3], DeviceType::Cpu)?,
        );

        let diagnostics = collection.diagnose()?;
        assert_eq!(diagnostics.len(), 1);

        let weight_diag = diagnostics.get("weight").unwrap();
        assert!(weight_diag.is_finite);
        Ok(())
    }

    #[test]
    fn test_parameter_collection_summary_report() -> Result<()> {
        let mut collection = ParameterCollection::new();

        collection.add(
            "weight".to_string(),
            Parameter::ones(vec![10], DeviceType::Cpu)?,
        );
        collection.add(
            "bias".to_string(),
            Parameter::zeros(vec![5], DeviceType::Cpu)?,
        );

        let report = collection.summary_report()?;

        assert!(report.contains("Total parameters: 2"));
        assert!(report.contains("Total elements: 15"));
        assert!(report.contains("weight"));
        assert!(report.contains("bias"));
        Ok(())
    }

    #[test]
    fn test_parameter_collection_from_hashmap() -> Result<()> {
        let mut map = HashMap::new();
        map.insert(
            "weight".to_string(),
            Parameter::zeros(vec![3], DeviceType::Cpu)?,
        );

        let collection = ParameterCollection::from_map(map);
        assert_eq!(collection.len(), 1);
        Ok(())
    }

    #[test]
    fn test_parameter_collection_into_hashmap() -> Result<()> {
        let mut collection = ParameterCollection::new();
        collection.add(
            "weight".to_string(),
            Parameter::zeros(vec![3], DeviceType::Cpu)?,
        );

        let map: HashMap<String, Parameter> = collection.into();
        assert_eq!(map.len(), 1);
        assert!(map.contains_key("weight"));
        Ok(())
    }
}