scirs2_datasets/generators/sparse_classification.rs

/// Configuration for the sparse classification dataset generator.
#[derive(Debug, Clone)]
pub struct SparseClassConfig {
    /// Number of samples to generate.
    pub n_samples: usize,
    /// Total number of feature columns; most remain zero.
    pub n_features: usize,
    /// Number of informative (nonzero) features.
    pub n_informative: usize,
    /// Number of target classes (assumed >= 1).
    pub n_classes: usize,
    /// Scale applied to class centroids; larger values separate classes more.
    pub class_sep: f64,
    /// Seed for the deterministic LCG used throughout generation.
    pub seed: u64,
}
24
25impl Default for SparseClassConfig {
26 fn default() -> Self {
27 Self {
28 n_samples: 1000,
29 n_features: 10000,
30 n_informative: 20,
31 n_classes: 2,
32 class_sep: 1.0,
33 seed: 42,
34 }
35 }
36}
37
/// A generated sparse classification dataset.
#[derive(Debug, Clone)]
pub struct SparseClassDataset {
    /// Dense row-major feature matrix of shape (n_samples, n_features).
    pub x: Vec<Vec<f64>>,
    /// Class label for each sample.
    pub y: Vec<usize>,
    /// Sorted indices of the informative features.
    pub informative_features: Vec<usize>,
    /// Per-feature weights; zero everywhere except informative columns.
    pub feature_weights: Vec<f64>,
}

/// Minimal linear congruential generator using Knuth's MMIX constants,
/// kept dependency-free so generation is reproducible from a seed.
struct Lcg {
    state: u64,
}

impl Lcg {
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    fn next_u64(&mut self) -> u64 {
        // Knuth's MMIX multiplier and increment.
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform in [0, 1): the top 53 bits of the state fill an f64 mantissa.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Standard normal draw via the Box-Muller transform; u1 is clamped
    /// away from zero so ln(u1) stays finite.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-10);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Uniform in [0, n); the slight modulo bias is acceptable for
    /// synthetic data generation. Panics if n == 0.
    fn next_usize_below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}
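
// A minimal sanity sketch of the generator above (hypothetical helper, not
// called by the library): next_f64 should stay in [0, 1), and a run of
// next_normal draws should produce values of both signs.
#[allow(dead_code)]
fn demo_lcg() {
    let mut rng = Lcg::new(123);
    for _ in 0..1_000 {
        let u = rng.next_f64();
        assert!((0.0..1.0).contains(&u));
    }
    let saw_negative = (0..100).any(|_| rng.next_normal() < 0.0);
    assert!(saw_negative);
}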

/// Generates a sparse classification dataset.
///
/// Only the `n_informative` randomly chosen feature columns receive nonzero
/// values; every other column stays 0.0, so the sparsity ratio approaches
/// `1 - n_informative / n_features`. Samples are split as evenly as possible
/// across `n_classes`, drawn around per-class Gaussian centroids scaled by
/// `class_sep`, then shuffled.
pub fn make_sparse_classification(config: &SparseClassConfig) -> SparseClassDataset {
    let mut rng = Lcg::new(config.seed);

    let n_inf = config.n_informative.min(config.n_features);
    // Choose n_inf distinct feature indices via a partial Fisher-Yates shuffle.
    let mut informative_features: Vec<usize> = {
        let mut indices: Vec<usize> = (0..config.n_features).collect();
        for i in 0..n_inf {
            let j = i + rng.next_usize_below(config.n_features - i);
            indices.swap(i, j);
        }
        indices[..n_inf].to_vec()
    };
    informative_features.sort_unstable();

    // One Gaussian centroid per class over the informative subspace,
    // scaled by class_sep to control class overlap.
    let centroids: Vec<Vec<f64>> = (0..config.n_classes)
        .map(|_| {
            (0..n_inf)
                .map(|_| rng.next_normal() * config.class_sep)
                .collect()
        })
        .collect();

    // Per-feature weights: the mean centroid value for informative columns,
    // zero everywhere else.
    let mut feature_weights = vec![0.0f64; config.n_features];
    for (idx, &fi) in informative_features.iter().enumerate() {
        let mean_val: f64 =
            centroids.iter().map(|c| c[idx]).sum::<f64>() / config.n_classes as f64;
        feature_weights[fi] = mean_val;
    }

    // Split samples as evenly as possible; the last class absorbs the remainder.
    let n_per_class = config.n_samples / config.n_classes;
    let mut x: Vec<Vec<f64>> = Vec::with_capacity(config.n_samples);
    let mut y: Vec<usize> = Vec::with_capacity(config.n_samples);

    for (class_idx, centroid) in centroids.iter().enumerate() {
        let count = if class_idx == config.n_classes - 1 {
            config.n_samples - n_per_class * (config.n_classes - 1)
        } else {
            n_per_class
        };
        for _ in 0..count {
            let mut sample = vec![0.0f64; config.n_features];
            for (inf_idx, &fi) in informative_features.iter().enumerate() {
                // Centroid value plus Gaussian noise with standard deviation 0.5.
                sample[fi] = centroid[inf_idx] + rng.next_normal() * 0.5;
            }
            x.push(sample);
            y.push(class_idx);
        }
    }

    // Fisher-Yates shuffle so samples are not ordered by class.
    let n = x.len();
    for i in (1..n).rev() {
        let j = rng.next_usize_below(i + 1);
        x.swap(i, j);
        y.swap(i, j);
    }

    SparseClassDataset {
        x,
        y,
        informative_features,
        feature_weights,
    }
}
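
// Usage sketch (hypothetical helper, not part of the module's API): generate
// a small dataset with a custom configuration and report its shape and
// sparsity using the functions defined in this file.
#[allow(dead_code)]
fn demo_make_sparse_classification() {
    let config = SparseClassConfig {
        n_samples: 100,
        n_features: 2_000,
        n_informative: 25,
        ..SparseClassConfig::default()
    };
    let ds = make_sparse_classification(&config);
    println!(
        "{} samples x {} features, sparsity {:.4}",
        ds.x.len(),
        ds.x[0].len(),
        sparsity_ratio(&ds.x)
    );
}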

/// Fraction of entries in `x` that are exactly 0.0.
///
/// Empty or zero-width matrices are treated as fully sparse (ratio 1.0).
pub fn sparsity_ratio(x: &[Vec<f64>]) -> f64 {
    if x.is_empty() {
        return 1.0;
    }
    let n_cols = x[0].len();
    if n_cols == 0 {
        return 1.0;
    }
    let total = (x.len() * n_cols) as f64;
    let zeros = x
        .iter()
        .flat_map(|row| row.iter())
        .filter(|&&v| v == 0.0)
        .count() as f64;
    zeros / total
}
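
// Because only informative columns are ever written (and Gaussian draws are
// almost surely nonzero), the expected sparsity is roughly
// 1 - n_informative / n_features. This hypothetical helper makes that
// estimate explicit for comparison against sparsity_ratio.
#[allow(dead_code)]
fn expected_sparsity(config: &SparseClassConfig) -> f64 {
    let n_inf = config.n_informative.min(config.n_features) as f64;
    1.0 - n_inf / config.n_features as f64
}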

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparsity_high() {
        let config = SparseClassConfig {
            n_samples: 200,
            n_features: 1000,
            n_informative: 10,
            n_classes: 2,
            class_sep: 1.0,
            seed: 42,
        };
        let ds = make_sparse_classification(&config);
        let ratio = sparsity_ratio(&ds.x);
        assert!(ratio > 0.98, "Sparsity ratio should be > 0.98, got {ratio}");
    }

    #[test]
    fn test_label_balance() {
        let config = SparseClassConfig {
            n_samples: 100,
            n_features: 500,
            n_informative: 5,
            n_classes: 2,
            class_sep: 1.0,
            seed: 7,
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.y.len(), 100);
        let class0 = ds.y.iter().filter(|&&c| c == 0).count();
        let class1 = ds.y.iter().filter(|&&c| c == 1).count();
        assert!((40..=60).contains(&class0), "Class 0 count: {class0}");
        assert!((40..=60).contains(&class1), "Class 1 count: {class1}");
    }

    #[test]
    fn test_informative_feature_count() {
        let config = SparseClassConfig {
            n_samples: 50,
            n_features: 200,
            n_informative: 8,
            n_classes: 3,
            class_sep: 1.5,
            seed: 99,
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.informative_features.len(), 8);
        for &fi in &ds.informative_features {
            assert!(fi < 200, "Informative feature index out of range: {fi}");
        }
    }

    #[test]
    fn test_non_informative_are_zero() {
        let config = SparseClassConfig {
            n_samples: 20,
            n_features: 100,
            n_informative: 5,
            n_classes: 2,
            class_sep: 1.0,
            seed: 13,
        };
        let ds = make_sparse_classification(&config);
        let inf_set: std::collections::HashSet<usize> =
            ds.informative_features.iter().copied().collect();
        for row in &ds.x {
            for (j, &val) in row.iter().enumerate() {
                if !inf_set.contains(&j) {
                    assert_eq!(val, 0.0, "Non-informative feature {j} should be zero");
                }
            }
        }
    }
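
    // feature_weights spans the full feature axis, but only informative
    // columns are ever assigned, so every other entry must remain 0.0.
    #[test]
    fn test_feature_weights_support() {
        let config = SparseClassConfig {
            n_samples: 10,
            n_features: 50,
            n_informative: 4,
            ..Default::default()
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.feature_weights.len(), 50);
        for (j, &w) in ds.feature_weights.iter().enumerate() {
            if !ds.informative_features.contains(&j) {
                assert_eq!(w, 0.0, "Weight for non-informative feature {j} should be zero");
            }
        }
    }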

    #[test]
    fn test_default_config_shape() {
        let config = SparseClassConfig {
            n_samples: 50,
            n_features: 200,
            n_informative: 10,
            ..Default::default()
        };
        let ds = make_sparse_classification(&config);
        assert_eq!(ds.x.len(), 50);
        assert_eq!(ds.x[0].len(), 200);
    }
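
    // The seeded LCG makes generation fully deterministic, so two calls with
    // an identical config should reproduce the exact same dataset.
    #[test]
    fn test_deterministic_for_fixed_seed() {
        let config = SparseClassConfig {
            n_samples: 30,
            n_features: 100,
            n_informative: 5,
            ..Default::default()
        };
        let a = make_sparse_classification(&config);
        let b = make_sparse_classification(&config);
        assert_eq!(a.y, b.y);
        assert_eq!(a.x, b.x);
        assert_eq!(a.informative_features, b.informative_features);
    }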
}