tenflowers_dataset/statistics/
computation.rs1use crate::Dataset;
7use std::collections::HashMap;
8use tenflowers_core::{Result, Tensor, TensorError};
9
10use super::core::{DatasetStats, Histogram, StatisticsConfig};
11
/// Stateless entry point for dataset statistics.
///
/// The type carries no data; every operation is exposed as an associated
/// function (for example `DatasetStatisticsComputer::compute`).
pub struct DatasetStatisticsComputer;
14
15impl DatasetStatisticsComputer {
16 pub fn compute<T, D>(dataset: &D, config: StatisticsConfig) -> Result<DatasetStats<T>>
18 where
19 T: Clone
20 + Default
21 + scirs2_core::numeric::Zero
22 + scirs2_core::numeric::Float
23 + std::fmt::Debug
24 + Send
25 + Sync
26 + 'static,
27 D: Dataset<T>,
28 {
29 if dataset.is_empty() {
30 return Err(TensorError::invalid_argument(
31 "Cannot compute statistics on empty dataset".to_string(),
32 ));
33 }
34
35 let sample_count = dataset.len();
36 let first_sample = dataset.get(0)?;
37 let feature_count = first_sample.0.shape().dims().iter().product::<usize>();
38
39 let mut stats = DatasetStats::new(feature_count, sample_count);
40
41 let mut all_features = Vec::new();
43 for i in 0..sample_count {
44 let (features, _) = dataset.get(i)?;
45 let feature_vec = Self::tensor_to_vec(&features)?;
46 all_features.push(feature_vec);
47 }
48
49 if config.compute_mean {
51 stats.mean = Some(Self::compute_mean(&all_features)?);
52 }
53
54 if config.compute_std {
56 let mean = if let Some(ref mean) = stats.mean {
57 mean.clone()
58 } else {
59 Self::compute_mean(&all_features)?
60 };
61 stats.std = Some(Self::compute_std(&all_features, &mean)?);
62 }
63
64 if config.compute_min_max {
66 let (min, max) = Self::compute_min_max(&all_features)?;
67 stats.min = Some(min);
68 stats.max = Some(max);
69 }
70
71 if config.compute_histogram {
73 let min = if let Some(ref min) = stats.min {
74 min.clone()
75 } else {
76 Self::compute_min_max(&all_features)?.0
77 };
78 let max = if let Some(ref max) = stats.max {
79 max.clone()
80 } else {
81 Self::compute_min_max(&all_features)?.1
82 };
83 stats.histogram = Some(Self::compute_histogram(
84 &all_features,
85 &min,
86 &max,
87 config.histogram_bins,
88 )?);
89 }
90
91 if config.compute_class_distribution {
93 let mut class_counts = HashMap::new();
94 for i in 0..sample_count {
95 let (_, label) = dataset.get(i)?;
96 let label_str = format!("{label:?}");
97 *class_counts.entry(label_str).or_insert(0) += 1;
98 }
99 stats.class_distribution = Some(class_counts);
100 }
101
102 Ok(stats)
103 }
104
105 pub fn tensor_to_vec<T>(tensor: &Tensor<T>) -> Result<Vec<T>>
107 where
108 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
109 {
110 let data = tensor.as_slice().ok_or_else(|| {
112 TensorError::invalid_argument(
113 "Cannot access tensor data (GPU tensor not supported)".to_string(),
114 )
115 })?;
116 Ok(data.to_vec())
117 }
118
119 pub fn compute_mean<T>(features: &[Vec<T>]) -> Result<Vec<T>>
121 where
122 T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
123 {
124 if features.is_empty() {
125 return Err(TensorError::invalid_argument(
126 "Cannot compute mean of empty features".to_string(),
127 ));
128 }
129
130 let feature_count = features[0].len();
131 let mut mean = vec![T::zero(); feature_count];
132
133 for feature_vec in features {
134 for (i, &value) in feature_vec.iter().enumerate() {
135 mean[i] = mean[i] + value;
136 }
137 }
138
139 let n = T::from(features.len()).expect("feature count should convert to float");
140 for mean_val in &mut mean {
141 *mean_val = *mean_val / n;
142 }
143
144 Ok(mean)
145 }
146
147 fn compute_std<T>(features: &[Vec<T>], mean: &[T]) -> Result<Vec<T>>
149 where
150 T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
151 {
152 if features.is_empty() {
153 return Err(TensorError::invalid_argument(
154 "Cannot compute std of empty features".to_string(),
155 ));
156 }
157
158 let feature_count = features[0].len();
159 let mut variance = vec![T::zero(); feature_count];
160
161 for feature_vec in features {
162 for (i, &value) in feature_vec.iter().enumerate() {
163 let diff = value - mean[i];
164 variance[i] = variance[i] + diff * diff;
165 }
166 }
167
168 let n = T::from(features.len()).expect("feature count should convert to float");
169 let mut std = Vec::new();
170 for var_val in variance {
171 let std_val = (var_val / n).sqrt();
172 std.push(std_val);
173 }
174
175 Ok(std)
176 }
177
178 fn compute_min_max<T>(features: &[Vec<T>]) -> Result<(Vec<T>, Vec<T>)>
180 where
181 T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
182 {
183 if features.is_empty() {
184 return Err(TensorError::invalid_argument(
185 "Cannot compute min/max of empty features".to_string(),
186 ));
187 }
188
189 let _feature_count = features[0].len();
190 let mut min_vals = features[0].clone();
191 let mut max_vals = features[0].clone();
192
193 for feature_vec in features.iter().skip(1) {
194 for (i, &value) in feature_vec.iter().enumerate() {
195 if value < min_vals[i] {
196 min_vals[i] = value;
197 }
198 if value > max_vals[i] {
199 max_vals[i] = value;
200 }
201 }
202 }
203
204 Ok((min_vals, max_vals))
205 }
206
207 fn compute_histogram<T>(
209 features: &[Vec<T>],
210 min_vals: &[T],
211 max_vals: &[T],
212 bins: usize,
213 ) -> Result<Histogram<T>>
214 where
215 T: Clone + Default + scirs2_core::numeric::Zero + scirs2_core::numeric::Float,
216 {
217 if features.is_empty() {
218 return Err(TensorError::invalid_argument(
219 "Cannot compute histogram of empty features".to_string(),
220 ));
221 }
222
223 let feature_idx = 0;
225 let min_val = min_vals[feature_idx];
226 let max_val = max_vals[feature_idx];
227
228 let mut bin_edges = Vec::new();
230 let step = (max_val - min_val) / T::from(bins).expect("bin count should convert to float");
231 for i in 0..=bins {
232 bin_edges.push(min_val + T::from(i).expect("bin index should convert to float") * step);
233 }
234
235 let mut counts = vec![0usize; bins];
237 for feature_vec in features {
238 let value = feature_vec[feature_idx];
239 let bin_idx = if value >= max_val {
240 bins - 1
241 } else {
242 let normalized = (value - min_val) / (max_val - min_val);
243 let idx = (normalized * T::from(bins).expect("bin count should convert to float"))
244 .to_usize()
245 .unwrap_or(0);
246 idx.min(bins - 1)
247 };
248 counts[bin_idx] += 1;
249 }
250
251 let mut bin_centers = Vec::new();
253 for i in 0..bins {
254 let center = (bin_edges[i] + bin_edges[i + 1])
255 / T::from(2).expect("constant 2 should convert to float");
256 bin_centers.push(center);
257 }
258
259 Ok(Histogram {
260 bins: bin_centers,
261 counts,
262 bin_edges,
263 })
264 }
265}