Skip to main content

tenflowers_dataset/statistics/
computation.rs

1//! Statistics computation implementation
2//!
3//! This module contains the core logic for computing dataset statistics
4//! including mean, standard deviation, histograms, and min/max values.
5
6use crate::Dataset;
7use std::collections::HashMap;
8use tenflowers_core::{Result, Tensor, TensorError};
9
10use super::core::{DatasetStats, Histogram, StatisticsConfig};
11
/// Statistics computer for datasets.
///
/// This is a stateless unit type: all functionality is exposed through
/// associated functions such as [`DatasetStatisticsComputer::compute`], so no
/// instance is required to use it. The common marker traits are derived so the
/// type can be freely copied, printed, and default-constructed when embedded
/// in other structures.
#[derive(Debug, Clone, Copy, Default)]
pub struct DatasetStatisticsComputer;
14
15impl DatasetStatisticsComputer {
16    /// Compute statistics for a dataset
17    pub fn compute<T, D>(dataset: &D, config: StatisticsConfig) -> Result<DatasetStats<T>>
18    where
19        T: Clone
20            + Default
21            + scirs2_core::numeric::Zero
22            + scirs2_core::numeric::Float
23            + std::fmt::Debug
24            + Send
25            + Sync
26            + 'static,
27        D: Dataset<T>,
28    {
29        if dataset.is_empty() {
30            return Err(TensorError::invalid_argument(
31                "Cannot compute statistics on empty dataset".to_string(),
32            ));
33        }
34
35        let sample_count = dataset.len();
36        let first_sample = dataset.get(0)?;
37        let feature_count = first_sample.0.shape().dims().iter().product::<usize>();
38
39        let mut stats = DatasetStats::new(feature_count, sample_count);
40
41        // Collect all feature vectors
42        let mut all_features = Vec::new();
43        for i in 0..sample_count {
44            let (features, _) = dataset.get(i)?;
45            let feature_vec = Self::tensor_to_vec(&features)?;
46            all_features.push(feature_vec);
47        }
48
49        // Compute mean
50        if config.compute_mean {
51            stats.mean = Some(Self::compute_mean(&all_features)?);
52        }
53
54        // Compute standard deviation
55        if config.compute_std {
56            let mean = if let Some(ref mean) = stats.mean {
57                mean.clone()
58            } else {
59                Self::compute_mean(&all_features)?
60            };
61            stats.std = Some(Self::compute_std(&all_features, &mean)?);
62        }
63
64        // Compute min/max
65        if config.compute_min_max {
66            let (min, max) = Self::compute_min_max(&all_features)?;
67            stats.min = Some(min);
68            stats.max = Some(max);
69        }
70
71        // Compute histogram
72        if config.compute_histogram {
73            let min = if let Some(ref min) = stats.min {
74                min.clone()
75            } else {
76                Self::compute_min_max(&all_features)?.0
77            };
78            let max = if let Some(ref max) = stats.max {
79                max.clone()
80            } else {
81                Self::compute_min_max(&all_features)?.1
82            };
83            stats.histogram = Some(Self::compute_histogram(
84                &all_features,
85                &min,
86                &max,
87                config.histogram_bins,
88            )?);
89        }
90
91        // Compute class distribution
92        if config.compute_class_distribution {
93            let mut class_counts = HashMap::new();
94            for i in 0..sample_count {
95                let (_, label) = dataset.get(i)?;
96                let label_str = format!("{label:?}");
97                *class_counts.entry(label_str).or_insert(0) += 1;
98            }
99            stats.class_distribution = Some(class_counts);
100        }
101
102        Ok(stats)
103    }
104
105    /// Convert tensor to vector
106    pub fn tensor_to_vec<T>(tensor: &Tensor<T>) -> Result<Vec<T>>
107    where
108        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
109    {
110        // Get the raw data from tensor
111        let data = tensor.as_slice().ok_or_else(|| {
112            TensorError::invalid_argument(
113                "Cannot access tensor data (GPU tensor not supported)".to_string(),
114            )
115        })?;
116        Ok(data.to_vec())
117    }
118
119    /// Compute mean of feature vectors
120    pub fn compute_mean<T>(features: &[Vec<T>]) -> Result<Vec<T>>
121    where
122        T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
123    {
124        if features.is_empty() {
125            return Err(TensorError::invalid_argument(
126                "Cannot compute mean of empty features".to_string(),
127            ));
128        }
129
130        let feature_count = features[0].len();
131        let mut mean = vec![T::zero(); feature_count];
132
133        for feature_vec in features {
134            for (i, &value) in feature_vec.iter().enumerate() {
135                mean[i] = mean[i] + value;
136            }
137        }
138
139        let n = T::from(features.len()).expect("feature count should convert to float");
140        for mean_val in &mut mean {
141            *mean_val = *mean_val / n;
142        }
143
144        Ok(mean)
145    }
146
147    /// Compute standard deviation of feature vectors
148    fn compute_std<T>(features: &[Vec<T>], mean: &[T]) -> Result<Vec<T>>
149    where
150        T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
151    {
152        if features.is_empty() {
153            return Err(TensorError::invalid_argument(
154                "Cannot compute std of empty features".to_string(),
155            ));
156        }
157
158        let feature_count = features[0].len();
159        let mut variance = vec![T::zero(); feature_count];
160
161        for feature_vec in features {
162            for (i, &value) in feature_vec.iter().enumerate() {
163                let diff = value - mean[i];
164                variance[i] = variance[i] + diff * diff;
165            }
166        }
167
168        let n = T::from(features.len()).expect("feature count should convert to float");
169        let mut std = Vec::new();
170        for var_val in variance {
171            let std_val = (var_val / n).sqrt();
172            std.push(std_val);
173        }
174
175        Ok(std)
176    }
177
178    /// Compute min and max of feature vectors
179    fn compute_min_max<T>(features: &[Vec<T>]) -> Result<(Vec<T>, Vec<T>)>
180    where
181        T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
182    {
183        if features.is_empty() {
184            return Err(TensorError::invalid_argument(
185                "Cannot compute min/max of empty features".to_string(),
186            ));
187        }
188
189        let _feature_count = features[0].len();
190        let mut min_vals = features[0].clone();
191        let mut max_vals = features[0].clone();
192
193        for feature_vec in features.iter().skip(1) {
194            for (i, &value) in feature_vec.iter().enumerate() {
195                if value < min_vals[i] {
196                    min_vals[i] = value;
197                }
198                if value > max_vals[i] {
199                    max_vals[i] = value;
200                }
201            }
202        }
203
204        Ok((min_vals, max_vals))
205    }
206
207    /// Compute histogram of feature vectors
208    fn compute_histogram<T>(
209        features: &[Vec<T>],
210        min_vals: &[T],
211        max_vals: &[T],
212        bins: usize,
213    ) -> Result<Histogram<T>>
214    where
215        T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
216    {
217        if features.is_empty() {
218            return Err(TensorError::invalid_argument(
219                "Cannot compute histogram of empty features".to_string(),
220            ));
221        }
222
223        // For simplicity, compute histogram for the first feature only
224        let feature_idx = 0;
225        let min_val = min_vals[feature_idx];
226        let max_val = max_vals[feature_idx];
227
228        // Create bin edges
229        let mut bin_edges = Vec::new();
230        let step = (max_val - min_val) / T::from(bins).expect("bin count should convert to float");
231        for i in 0..=bins {
232            bin_edges.push(min_val + T::from(i).expect("bin index should convert to float") * step);
233        }
234
235        // Count values in each bin
236        let mut counts = vec![0usize; bins];
237        for feature_vec in features {
238            let value = feature_vec[feature_idx];
239            let bin_idx = if value >= max_val {
240                bins - 1
241            } else {
242                let normalized = (value - min_val) / (max_val - min_val);
243                let idx = (normalized * T::from(bins).expect("bin count should convert to float"))
244                    .to_usize()
245                    .unwrap_or(0);
246                idx.min(bins - 1)
247            };
248            counts[bin_idx] += 1;
249        }
250
251        // Create bin centers
252        let mut bin_centers = Vec::new();
253        for i in 0..bins {
254            let center = (bin_edges[i] + bin_edges[i + 1])
255                / T::from(2).expect("constant 2 should convert to float");
256            bin_centers.push(center);
257        }
258
259        Ok(Histogram {
260            bins: bin_centers,
261            counts,
262            bin_edges,
263        })
264    }
265}