// scirs2_datasets/generators/multilabel_advanced.rs

/// Configuration for [`make_advanced_multilabel_classification`].
#[derive(Debug, Clone)]
pub struct AdvancedMultilabelConfig {
    /// Number of samples to generate. The generator can return fewer rows when
    /// `allow_unlabeled` is `false` and too many draws are rejected.
    pub n_samples: usize,
    /// Number of feature values in each generated sample.
    pub n_features: usize,
    /// Number of labels attached to each sample.
    pub n_labels: usize,
    /// Target per-label activation probability. The generator clamps it to
    /// `[1e-6, 1 - 1e-6]` and centres the per-label biases on its logit.
    pub label_density: f64,
    /// Dimension of the latent vector that drives both features and labels;
    /// the generator treats values below 1 as 1.
    pub n_latent: usize,
    /// When `false`, samples without any active label are discarded.
    pub allow_unlabeled: bool,
    /// Seed for the generator's internal pseudo-random number generator.
    pub seed: u64,
}
25
26impl Default for AdvancedMultilabelConfig {
27 fn default() -> Self {
28 Self {
29 n_samples: 500,
30 n_features: 20,
31 n_labels: 5,
32 label_density: 0.3,
33 n_latent: 10,
34 allow_unlabeled: true,
35 seed: 42,
36 }
37 }
38}
39
/// Dataset returned by [`make_advanced_multilabel_classification`].
#[derive(Debug, Clone)]
pub struct AdvancedMultilabelDataset {
    /// Feature rows: one vector of `n_features` values per generated sample.
    pub x: Vec<Vec<f64>>,
    /// Label rows aligned with `x`: `y[i][k]` is `true` when label `k` is
    /// active for sample `i`.
    pub y: Vec<Vec<bool>>,
    /// `n_labels x n_labels` matrix whose entry `[a][b]` is the fraction of
    /// generated samples in which labels `a` and `b` are both active (the
    /// diagonal therefore holds each label's own frequency).
    pub label_cooccurrence: Vec<Vec<f64>>,
    /// Mean number of active labels per generated sample.
    pub cardinality: f64,
}
52
/// Small deterministic linear congruential generator used to draw the dataset.
struct Lcg {
    /// Current 64-bit state; it is also the most recent raw output.
    state: u64,
}

impl Lcg {
    /// Multiplier applied to the state on every step.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Increment added to the state on every step.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator whose initial state is `seed`.
    fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Advances the state by one step (wrapping arithmetic) and returns it.
    fn next_u64(&mut self) -> u64 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next raw
    /// output.
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        let scale = (1u64 << 53) as f64;
        top_bits as f64 / scale
    }

    /// Draws a standard-normal value with the Box-Muller transform (cosine
    /// branch), consuming two uniform draws.
    fn next_normal(&mut self) -> f64 {
        // Floor the first uniform so the logarithm never sees zero.
        let radial = self.next_f64().max(1e-10);
        let angular = self.next_f64();

        let radius = (-2.0 * radial.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * angular;
        radius * angle.cos()
    }
}
81
/// Logistic function `1 / (1 + e^(-x))`, mapping any real `x` into `[0, 1]`.
fn sigmoid(x: f64) -> f64 {
    let decay = f64::exp(-x);
    (1.0 + decay).recip()
}
86
87pub fn make_advanced_multilabel_classification(
106 config: &AdvancedMultilabelConfig,
107) -> AdvancedMultilabelDataset {
108 let mut rng = Lcg::new(config.seed);
109 let n_lat = config.n_latent.max(1);
110
111 let w_mat: Vec<Vec<f64>> = (0..n_lat)
113 .map(|_| (0..config.n_features).map(|_| rng.next_normal()).collect())
114 .collect();
115
116 let l_mat: Vec<Vec<f64>> = (0..config.n_labels)
118 .map(|_| (0..n_lat).map(|_| rng.next_normal()).collect())
119 .collect();
120
121 let target_logit = {
125 let d = config.label_density.clamp(1e-6, 1.0 - 1e-6);
126 (d / (1.0 - d)).ln()
127 };
128 let biases: Vec<f64> = (0..config.n_labels)
129 .map(|_| target_logit + rng.next_normal() * 0.2)
130 .collect();
131
132 let mut x_all: Vec<Vec<f64>> = Vec::with_capacity(config.n_samples);
134 let mut y_all: Vec<Vec<bool>> = Vec::with_capacity(config.n_samples);
135
136 let mut generated = 0;
137 let mut attempts = 0;
138 let max_attempts = config.n_samples * 10 + 100;
139
140 while generated < config.n_samples && attempts < max_attempts {
141 attempts += 1;
142
143 let z: Vec<f64> = (0..n_lat).map(|_| rng.next_normal()).collect();
145
146 let sample_x: Vec<f64> = (0..config.n_features)
148 .map(|j| {
149 let val: f64 = (0..n_lat).map(|k| w_mat[k][j] * z[k]).sum();
150 val + rng.next_normal() * 0.1
151 })
152 .collect();
153
154 let probs: Vec<f64> = (0..config.n_labels)
156 .map(|k| {
157 let logit: f64 = (0..n_lat).map(|d| l_mat[k][d] * z[d]).sum::<f64>() + biases[k];
158 sigmoid(logit)
159 })
160 .collect();
161
162 let labels: Vec<bool> = probs.iter().map(|&p| rng.next_f64() < p).collect();
164
165 let any_active = labels.iter().any(|&b| b);
167 if !config.allow_unlabeled && !any_active {
168 continue;
169 }
170
171 x_all.push(sample_x);
172 y_all.push(labels);
173 generated += 1;
174 }
175
176 let n_actual = y_all.len();
178 let mut cooccur = vec![vec![0.0f64; config.n_labels]; config.n_labels];
179 for labels in &y_all {
180 for k1 in 0..config.n_labels {
181 for k2 in 0..config.n_labels {
182 if labels[k1] && labels[k2] {
183 cooccur[k1][k2] += 1.0;
184 }
185 }
186 }
187 }
188 if n_actual > 0 {
189 for row in &mut cooccur {
190 for val in row.iter_mut() {
191 *val /= n_actual as f64;
192 }
193 }
194 }
195
196 let cardinality = label_cardinality_impl(&y_all);
197
198 AdvancedMultilabelDataset {
199 x: x_all,
200 y: y_all,
201 label_cooccurrence: cooccur,
202 cardinality,
203 }
204}
205
/// Mean number of `true` entries per row of `y`; `0.0` when `y` has no rows.
fn label_cardinality_impl(y: &[Vec<bool>]) -> f64 {
    let rows = y.len();
    if rows == 0 {
        return 0.0;
    }

    let mut active = 0usize;
    for row in y {
        active += row.iter().filter(|&&on| on).count();
    }
    active as f64 / rows as f64
}
213
/// Returns the label cardinality of `y`: the mean number of active (`true`)
/// labels per row, or `0.0` when `y` has no rows.
///
/// Public entry point that delegates to `label_cardinality_impl`.
pub fn label_cardinality(y: &[Vec<bool>]) -> f64 {
    label_cardinality_impl(y)
}
226
/// Fraction of all label entries in `y` that are active (`true`).
///
/// The denominator is the total number of entries across all rows rather than
/// `rows * y[0].len()`, so rows of differing length are counted correctly and
/// the result always lies in `[0, 1]`. For a rectangular label matrix this is
/// exactly `positives / (n_samples * n_labels)`.
///
/// Returns `0.0` when `y` has no rows or no entries at all.
pub fn label_density_score(y: &[Vec<bool>]) -> f64 {
    // Single pass: count active entries and total entries together.
    let (positives, slots) = y.iter().fold((0usize, 0usize), |(pos, all), row| {
        let active = row.iter().filter(|&&on| on).count();
        (pos + active, all + row.len())
    });

    if slots == 0 {
        return 0.0;
    }
    positives as f64 / slots as f64
}
248
/// Hamming loss: the fraction of label entries in `y_true` that `y_pred` gets
/// wrong.
///
/// `y_true` is the reference. Every entry of `y_true` is compared with the
/// entry at the same row and column of `y_pred`; an entry counts as wrong when
/// the prediction differs **or is missing** (a shorter prediction row, or no
/// prediction row at all). Missing predictions are therefore never silently
/// treated as correct. Predictions with no counterpart in `y_true` are ignored.
///
/// The denominator is the actual number of entries in `y_true`, so the result
/// always lies in `[0, 1]`. When both inputs have the same rectangular shape
/// this is the usual `mismatches / (n_samples * n_labels)`.
///
/// Returns `0.0` when `y_true` has no rows or no entries.
pub fn hamming_loss(y_true: &[Vec<bool>], y_pred: &[Vec<bool>]) -> f64 {
    let total: usize = y_true.iter().map(Vec::len).sum();
    if total == 0 {
        return 0.0;
    }

    let wrong: usize = y_true
        .iter()
        .enumerate()
        .map(|(i, truth)| {
            // A missing prediction row behaves like an empty one.
            let pred: &[bool] = y_pred.get(i).map(Vec::as_slice).unwrap_or(&[]);
            truth
                .iter()
                .enumerate()
                // `get` yields `None` past the end of `pred`, which never
                // equals `Some(expected)` and so counts as an error.
                .filter(|&(j, expected)| pred.get(j) != Some(expected))
                .count()
        })
        .sum();

    wrong as f64 / total as f64
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Mean label count for 5 labels at density 0.3 must stay strictly inside
    /// the loose band (0.5, 4.5).
    #[test]
    fn test_cardinality_near_density() {
        let ds = make_advanced_multilabel_classification(&AdvancedMultilabelConfig {
            n_samples: 1000,
            n_features: 15,
            n_labels: 5,
            label_density: 0.3,
            n_latent: 8,
            allow_unlabeled: true,
            seed: 42,
        });
        let cardinality = ds.cardinality;
        assert!(cardinality > 0.5, "Cardinality too low: {}", cardinality);
        assert!(cardinality < 4.5, "Cardinality too high: {}", cardinality);
    }

    /// Every output container has the dimensions requested in the config.
    #[test]
    fn test_output_shapes() {
        let (samples, features, labels) = (100, 10, 4);
        let ds = make_advanced_multilabel_classification(&AdvancedMultilabelConfig {
            n_samples: samples,
            n_features: features,
            n_labels: labels,
            label_density: 0.4,
            n_latent: 5,
            allow_unlabeled: true,
            seed: 7,
        });

        assert_eq!(ds.x.len(), samples);
        assert_eq!(ds.x[0].len(), features);
        assert_eq!(ds.y.len(), samples);
        assert_eq!(ds.y[0].len(), labels);
        assert_eq!(ds.label_cooccurrence.len(), labels);
        assert_eq!(ds.label_cooccurrence[0].len(), labels);
    }

    /// Comparing a label matrix with itself yields zero loss.
    #[test]
    fn test_hamming_loss_self_zero() {
        let y = vec![
            vec![true, false, true],
            vec![false, false, true],
            vec![true, true, false],
        ];
        let self_loss = hamming_loss(&y, &y);
        assert!(self_loss.abs() < 1e-12);
    }

    /// Fully inverted predictions yield a loss of one.
    #[test]
    fn test_hamming_loss_all_wrong() {
        let y_true = vec![vec![true, true], vec![false, false]];
        let y_pred: Vec<Vec<bool>> = vec![vec![false, false], vec![true, true]];
        let loss = hamming_loss(&y_true, &y_pred);
        assert!((loss - 1.0).abs() < 1e-12, "Expected 1.0, got {loss}");
    }

    /// One active entry out of eight gives a density of 1/8.
    #[test]
    fn test_label_density_score() {
        let y = vec![
            vec![true, false, false, false],
            vec![false, false, false, false],
        ];
        let d = label_density_score(&y);
        assert!((d - 0.125).abs() < 1e-12, "Expected 0.125, got {d}");
    }

    /// With `allow_unlabeled = false` every returned sample has an active label.
    #[test]
    fn test_no_unlabeled_when_disabled() {
        let ds = make_advanced_multilabel_classification(&AdvancedMultilabelConfig {
            n_samples: 200,
            n_features: 10,
            n_labels: 3,
            label_density: 0.5,
            n_latent: 5,
            allow_unlabeled: false,
            seed: 55,
        });
        for (i, labels) in ds.y.iter().enumerate() {
            assert!(
                labels.contains(&true),
                "Sample {i} has no active labels but allow_unlabeled=false"
            );
        }
    }
}