Skip to main content

scirs2_datasets/generators/
multilabel_advanced.rs

1//! Advanced multi-label classification data generator with label dependencies
2//!
3//! Generates synthetic multi-label datasets using a latent factor model.
4//! Label correlations arise naturally from shared latent structure, producing
5//! more realistic co-occurrence patterns than independently sampling labels.
6
/// Parameters controlling [`make_advanced_multilabel_classification`].
///
/// All fields are public; start from [`Default::default`] and override the
/// ones you need with struct-update syntax.
#[derive(Debug, Clone)]
pub struct AdvancedMultilabelConfig {
    /// How many samples (rows) to generate.
    pub n_samples: usize,
    /// Width of each feature vector.
    pub n_features: usize,
    /// Number of binary labels attached to each sample.
    pub n_labels: usize,
    /// Target fraction of label entries that are positive; the generator
    /// clamps it into the open interval (0, 1).
    pub label_density: f64,
    /// Size of the shared latent space from which both features and labels
    /// are derived, which is what correlates the labels. A value of 0 is
    /// treated as 1.
    pub n_latent: usize,
    /// When `false`, the generator does not return samples without any
    /// positive label.
    pub allow_unlabeled: bool,
    /// Seed for the internal PRNG; the same configuration (seed included)
    /// always yields the same dataset.
    pub seed: u64,
}
25
26impl Default for AdvancedMultilabelConfig {
27    fn default() -> Self {
28        Self {
29            n_samples: 500,
30            n_features: 20,
31            n_labels: 5,
32            label_density: 0.3,
33            n_latent: 10,
34            allow_unlabeled: true,
35            seed: 42,
36        }
37    }
38}
39
/// Output of [`make_advanced_multilabel_classification`]: features, labels
/// and two summary statistics computed from the labels.
#[derive(Debug, Clone)]
pub struct AdvancedMultilabelDataset {
    /// Features, one row per generated sample (`n_samples × n_features`).
    pub x: Vec<Vec<f64>>,
    /// Labels, one row per generated sample (`n_samples × n_labels`);
    /// `true` marks an active label.
    pub y: Vec<Vec<bool>>,
    /// `label_cooccurrence[a][b]` is the fraction of samples in which labels
    /// `a` and `b` are both active (`n_labels × n_labels`); the diagonal is
    /// each label's own positive rate.
    pub label_cooccurrence: Vec<Vec<f64>>,
    /// Label cardinality: average count of active labels per sample.
    pub cardinality: f64,
}
52
/// Tiny deterministic PRNG: a 64-bit linear congruential generator using
/// Knuth's MMIX multiplier and increment. Not suitable for cryptography.
struct Lcg {
    state: u64,
}

impl Lcg {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Start the sequence from `seed` (used verbatim as the initial state).
    fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Advance the state one step and return the new state.
    fn next_u64(&mut self) -> u64 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Uniform value in `[0, 1)` built from the top 53 bits of the next state
    /// (the low bits of an LCG are the weakest).
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }

    /// Standard normal sample via Box–Muller (cosine branch only). `u1` is
    /// floored at 1e-10 so that `ln` never sees zero.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-10);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }
}
81
/// Logistic function `1 / (1 + e^(-x))`, mapping any real to `[0, 1]`.
/// Extreme inputs saturate to exactly 0.0 or 1.0 rather than producing NaN.
fn sigmoid(x: f64) -> f64 {
    let denominator = 1.0 + (-x).exp();
    denominator.recip()
}
86
87/// Generate an advanced multi-label classification dataset with label correlations.
88///
89/// The generative model:
90/// - W ∈ R^{n_latent × n_features}: feature generation matrix
91/// - L ∈ R^{n_labels × n_latent}: label-latent mapping
92/// - bias ∈ R^{n_labels}: per-label bias controlling label density
93/// - For each sample: z ~ N(0,I) in latent space
94///   - x = W^T z + noise
95///   - p_k = sigmoid(`L[k,:]` · z + bias_k)
96///   - y_k ~ Bernoulli(p_k)
97///
98/// # Arguments
99///
100/// * `config` - Generator configuration
101///
102/// # Returns
103///
104/// An [`AdvancedMultilabelDataset`] with correlated labels.
105pub fn make_advanced_multilabel_classification(
106    config: &AdvancedMultilabelConfig,
107) -> AdvancedMultilabelDataset {
108    let mut rng = Lcg::new(config.seed);
109    let n_lat = config.n_latent.max(1);
110
111    // --- Feature generation matrix W (n_latent × n_features) ---
112    let w_mat: Vec<Vec<f64>> = (0..n_lat)
113        .map(|_| (0..config.n_features).map(|_| rng.next_normal()).collect())
114        .collect();
115
116    // --- Label-latent mapping L (n_labels × n_latent) ---
117    let l_mat: Vec<Vec<f64>> = (0..config.n_labels)
118        .map(|_| (0..n_lat).map(|_| rng.next_normal()).collect())
119        .collect();
120
121    // --- Per-label biases calibrated to hit target label_density ---
122    // logit(p) = L·z + b; for z=0 bias controls marginal probability
123    // logit(label_density) drives the base probability
124    let target_logit = {
125        let d = config.label_density.clamp(1e-6, 1.0 - 1e-6);
126        (d / (1.0 - d)).ln()
127    };
128    let biases: Vec<f64> = (0..config.n_labels)
129        .map(|_| target_logit + rng.next_normal() * 0.2)
130        .collect();
131
132    // --- Generate samples ---
133    let mut x_all: Vec<Vec<f64>> = Vec::with_capacity(config.n_samples);
134    let mut y_all: Vec<Vec<bool>> = Vec::with_capacity(config.n_samples);
135
136    let mut generated = 0;
137    let mut attempts = 0;
138    let max_attempts = config.n_samples * 10 + 100;
139
140    while generated < config.n_samples && attempts < max_attempts {
141        attempts += 1;
142
143        // Latent vector z ~ N(0, I)
144        let z: Vec<f64> = (0..n_lat).map(|_| rng.next_normal()).collect();
145
146        // Feature vector x = W^T z + small noise
147        let sample_x: Vec<f64> = (0..config.n_features)
148            .map(|j| {
149                let val: f64 = (0..n_lat).map(|k| w_mat[k][j] * z[k]).sum();
150                val + rng.next_normal() * 0.1
151            })
152            .collect();
153
154        // Label probabilities
155        let probs: Vec<f64> = (0..config.n_labels)
156            .map(|k| {
157                let logit: f64 = (0..n_lat).map(|d| l_mat[k][d] * z[d]).sum::<f64>() + biases[k];
158                sigmoid(logit)
159            })
160            .collect();
161
162        // Sample labels
163        let labels: Vec<bool> = probs.iter().map(|&p| rng.next_f64() < p).collect();
164
165        // If allow_unlabeled=false, skip samples with no active labels
166        let any_active = labels.iter().any(|&b| b);
167        if !config.allow_unlabeled && !any_active {
168            continue;
169        }
170
171        x_all.push(sample_x);
172        y_all.push(labels);
173        generated += 1;
174    }
175
176    // --- Label co-occurrence matrix ---
177    let n_actual = y_all.len();
178    let mut cooccur = vec![vec![0.0f64; config.n_labels]; config.n_labels];
179    for labels in &y_all {
180        for k1 in 0..config.n_labels {
181            for k2 in 0..config.n_labels {
182                if labels[k1] && labels[k2] {
183                    cooccur[k1][k2] += 1.0;
184                }
185            }
186        }
187    }
188    if n_actual > 0 {
189        for row in &mut cooccur {
190            for val in row.iter_mut() {
191                *val /= n_actual as f64;
192            }
193        }
194    }
195
196    let cardinality = label_cardinality_impl(&y_all);
197
198    AdvancedMultilabelDataset {
199        x: x_all,
200        y: y_all,
201        label_cooccurrence: cooccur,
202        cardinality,
203    }
204}
205
/// Shared implementation of label cardinality: the number of `true` entries
/// divided by the number of rows, or 0.0 for an empty matrix.
fn label_cardinality_impl(y: &[Vec<bool>]) -> f64 {
    match y.len() {
        0 => 0.0,
        n_rows => {
            let active = y.iter().flatten().filter(|&&on| on).count();
            active as f64 / n_rows as f64
        }
    }
}
213
214/// Compute the mean number of active labels per sample.
215///
216/// # Arguments
217///
218/// * `y` - Label matrix (n_samples × n_labels)
219///
220/// # Returns
221///
222/// Mean active labels per sample (label cardinality).
223pub fn label_cardinality(y: &[Vec<bool>]) -> f64 {
224    label_cardinality_impl(y)
225}
226
/// Compute the fraction of positive labels across all samples and labels.
///
/// The denominator is the total number of label entries (the sum of all row
/// lengths). Ragged input is therefore handled consistently instead of
/// assuming every row is as long as the first one.
///
/// # Arguments
///
/// * `y` - Label matrix (n_samples × n_labels)
///
/// # Returns
///
/// Global label density in [0, 1]; `0.0` when `y` contains no entries.
pub fn label_density_score(y: &[Vec<bool>]) -> f64 {
    let (entries, positives) = y
        .iter()
        .flatten()
        .fold((0usize, 0usize), |(n, p), &on| (n + 1, p + usize::from(on)));
    if entries == 0 {
        return 0.0;
    }
    positives as f64 / entries as f64
}
248
/// Compute the Hamming loss between two label matrices.
///
/// Hamming loss is the fraction of label entries that are incorrectly
/// predicted, taken over all samples and labels. The denominator is the total
/// number of entries (the sum of row lengths), so ragged rows are counted
/// correctly.
///
/// # Arguments
///
/// * `y_true` - Ground truth label matrix
/// * `y_pred` - Predicted label matrix with the same shape as `y_true`
///
/// # Returns
///
/// Hamming loss in [0, 1]; `0.0` when there are no label entries.
///
/// # Panics
///
/// Panics if `y_true` and `y_pred` differ in their number of rows, or in the
/// length of any row. Pairing them with `zip` would otherwise silently drop
/// the unmatched entries and under-report the loss.
pub fn hamming_loss(y_true: &[Vec<bool>], y_pred: &[Vec<bool>]) -> f64 {
    assert_eq!(
        y_true.len(),
        y_pred.len(),
        "hamming_loss: y_true and y_pred have different numbers of samples"
    );
    let mut total = 0usize;
    let mut wrong = 0usize;
    for (i, (r_t, r_p)) in y_true.iter().zip(y_pred).enumerate() {
        assert_eq!(
            r_t.len(),
            r_p.len(),
            "hamming_loss: row {i} has mismatched label counts"
        );
        total += r_t.len();
        wrong += r_t.iter().zip(r_p).filter(|(t, p)| t != p).count();
    }
    if total == 0 {
        return 0.0;
    }
    wrong as f64 / total as f64
}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons against exact expected values.
    const EPS: f64 = 1e-12;

    #[test]
    fn test_cardinality_near_density() {
        let config = AdvancedMultilabelConfig {
            seed: 42,
            allow_unlabeled: true,
            n_latent: 8,
            label_density: 0.3,
            n_labels: 5,
            n_features: 15,
            n_samples: 1000,
        };
        let card = make_advanced_multilabel_classification(&config).cardinality;
        // Nominal value is label_density * n_labels = 1.5; latent correlations
        // make the realised value noisy, so only a wide band is checked.
        assert!(card > 0.5, "Cardinality too low: {}", card);
        assert!(card < 4.5, "Cardinality too high: {}", card);
    }

    #[test]
    fn test_output_shapes() {
        let ds = make_advanced_multilabel_classification(&AdvancedMultilabelConfig {
            seed: 7,
            allow_unlabeled: true,
            n_latent: 5,
            label_density: 0.4,
            n_labels: 4,
            n_features: 10,
            n_samples: 100,
        });
        assert_eq!((ds.x.len(), ds.x[0].len()), (100, 10));
        assert_eq!((ds.y.len(), ds.y[0].len()), (100, 4));
        assert_eq!(
            (ds.label_cooccurrence.len(), ds.label_cooccurrence[0].len()),
            (4, 4)
        );
    }

    #[test]
    fn test_hamming_loss_self_zero() {
        let y = vec![
            vec![true, false, true],
            vec![false, false, true],
            vec![true, true, false],
        ];
        let loss = hamming_loss(&y, &y);
        assert!(loss.abs() < EPS);
    }

    #[test]
    fn test_hamming_loss_all_wrong() {
        let y_true = vec![vec![true, true], vec![false, false]];
        // Every entry flipped.
        let y_pred: Vec<Vec<bool>> = y_true
            .iter()
            .map(|row| row.iter().map(|&b| !b).collect())
            .collect();
        let loss = hamming_loss(&y_true, &y_pred);
        assert!((loss - 1.0).abs() < EPS, "Expected 1.0, got {loss}");
    }

    #[test]
    fn test_label_density_score() {
        // A single positive among 2 × 4 = 8 entries.
        let mut y = vec![vec![false; 4]; 2];
        y[0][0] = true;
        let d = label_density_score(&y);
        assert!((d - 0.125).abs() < EPS, "Expected 0.125, got {d}");
    }

    #[test]
    fn test_no_unlabeled_when_disabled() {
        let config = AdvancedMultilabelConfig {
            seed: 55,
            allow_unlabeled: false,
            n_latent: 5,
            label_density: 0.5,
            n_labels: 3,
            n_features: 10,
            n_samples: 200,
        };
        let ds = make_advanced_multilabel_classification(&config);
        if let Some(i) = ds.y.iter().position(|labels| !labels.contains(&true)) {
            panic!("Sample {i} has no active labels but allow_unlabeled=false");
        }
    }
}