// Source: scirs2_datasets/generators/heterogeneous.rs
1//! Mixed feature type dataset generator
2//!
3//! Generates synthetic datasets combining continuous, categorical, ordinal,
4//! and binary features. Supports class-conditional generation so each class
5//! has different feature distributions, enabling classification benchmarks
6//! on heterogeneous tabular data.
7
/// Describes the type and parameters of a single feature column.
///
/// `PartialEq` is derived (matching [`HeteroFeatureValue`]) so callers can
/// compare feature-type specifications directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FeatureType {
    /// Continuous Gaussian feature with given (mean, std)
    Continuous(f64, f64),
    /// Categorical feature with the given number of categories
    Categorical(usize),
    /// Ordinal feature with the given number of levels (0 ..= n_levels-1)
    Ordinal(usize),
    /// Binary Bernoulli feature
    Binary,
}
21
22/// Configuration for the heterogeneous dataset generator
23#[derive(Debug, Clone)]
24pub struct HeteroConfig {
25    /// Number of samples
26    pub n_samples: usize,
27    /// Explicit feature types; if empty, types are generated automatically
28    pub feature_types: Vec<FeatureType>,
29    /// Number of features to auto-generate when `feature_types` is empty
30    pub n_features: usize,
31    /// Number of output classes
32    pub n_classes: usize,
33    /// Random seed for reproducibility
34    pub seed: u64,
35}
36
37impl Default for HeteroConfig {
38    fn default() -> Self {
39        Self {
40            n_samples: 500,
41            feature_types: Vec::new(),
42            n_features: 10,
43            n_classes: 2,
44            seed: 42,
45        }
46    }
47}
48
/// A single feature value that may be continuous, integer, or boolean.
///
/// Every variant holds only `Copy` data (`f64`, `usize`, `bool`), so the enum
/// derives `Copy` itself — values can be passed around without cloning.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum HeteroFeatureValue {
    /// Floating-point (continuous) value
    Float(f64),
    /// Non-negative integer (categorical or ordinal level)
    Int(usize),
    /// Boolean (binary feature)
    Bool(bool),
}
60
/// Heterogeneous (mixed feature type) classification dataset
///
/// `features` and `labels` share the same sample order (row `i` has label
/// `labels[i]`); `feature_types` and `feature_names` share the same column
/// order as each row of `features`.
#[derive(Debug, Clone)]
pub struct HeteroDataset {
    /// Feature matrix: each row is a sample, each column is a feature
    pub features: Vec<Vec<HeteroFeatureValue>>,
    /// Class label for each sample
    pub labels: Vec<usize>,
    /// Type of each feature column
    pub feature_types: Vec<FeatureType>,
    /// Human-readable name for each feature
    pub feature_names: Vec<String>,
}
73
/// Simple seeded LCG PRNG for deterministic generation.
///
/// Uses Knuth's MMIX multiplier/increment. Not cryptographically secure —
/// only meant to make generated datasets reproducible for a given seed.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Start a generator from the given seed.
    fn new(seed: u64) -> Self {
        Lcg { state: seed }
    }

    /// Advance the state once and return the raw 64-bit value.
    fn next_u64(&mut self) -> u64 {
        const MUL: u64 = 6_364_136_223_846_793_005;
        const INC: u64 = 1_442_695_040_888_963_407;
        self.state = self.state.wrapping_mul(MUL).wrapping_add(INC);
        self.state
    }

    /// Uniform f64 in [0, 1), built from the top 53 bits of the state.
    fn next_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Standard normal variate via the Box-Muller transform.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-10); // clamp away from 0 so ln() is finite
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        radius * angle.cos()
    }

    /// Uniform usize in [0, n).
    ///
    /// NOTE(review): has slight modulo bias and panics when `n == 0`;
    /// acceptable here since callers always pass small, nonzero `n`.
    fn next_usize_below(&mut self, n: usize) -> usize {
        (self.next_u64() % n as u64) as usize
    }
}
106
107/// Auto-generate a list of feature types by cycling through Continuous/Categorical/Ordinal/Binary
108fn auto_feature_types(n: usize, rng: &mut Lcg) -> Vec<FeatureType> {
109    (0..n)
110        .map(|i| match i % 4 {
111            0 => {
112                let mean = rng.next_normal();
113                let std = 0.5 + rng.next_f64();
114                FeatureType::Continuous(mean, std)
115            }
116            1 => {
117                let n_cats = 2 + rng.next_usize_below(5); // 2..6
118                FeatureType::Categorical(n_cats)
119            }
120            2 => {
121                let n_levels = 3 + rng.next_usize_below(5); // 3..7
122                FeatureType::Ordinal(n_levels)
123            }
124            _ => FeatureType::Binary,
125        })
126        .collect()
127}
128
129/// Generate a heterogeneous (mixed feature type) classification dataset.
130///
131/// Each class receives different feature distributions:
132/// - **Continuous**: each class shifts the mean by a class-specific offset
133/// - **Categorical**: each class uses a different Dirichlet-like distribution
134/// - **Ordinal**: each class shifts the modal level
135/// - **Binary**: each class uses a different Bernoulli probability
136///
137/// If `config.feature_types` is empty, feature types are auto-assigned in a
138/// Continuous/Categorical/Ordinal/Binary pattern.
139///
140/// # Arguments
141///
142/// * `config` - Generator configuration
143///
144/// # Returns
145///
146/// A [`HeteroDataset`] with mixed feature types and class labels.
147pub fn make_heterogeneous(config: &HeteroConfig) -> HeteroDataset {
148    let mut rng = Lcg::new(config.seed);
149
150    // Resolve feature types
151    let feature_types: Vec<FeatureType> = if config.feature_types.is_empty() {
152        auto_feature_types(config.n_features, &mut rng)
153    } else {
154        config.feature_types.clone()
155    };
156    let n_features = feature_types.len();
157
158    // Feature names
159    let feature_names: Vec<String> = feature_types
160        .iter()
161        .enumerate()
162        .map(|(i, ft)| match ft {
163            FeatureType::Continuous(_, _) => format!("cont_{i}"),
164            FeatureType::Categorical(_) => format!("cat_{i}"),
165            FeatureType::Ordinal(_) => format!("ord_{i}"),
166            FeatureType::Binary => format!("bin_{i}"),
167        })
168        .collect();
169
170    // Per-class, per-feature parameters
171    // continuous: class mean offset
172    let class_cont_offsets: Vec<Vec<f64>> = (0..config.n_classes)
173        .map(|_| (0..n_features).map(|_| rng.next_normal() * 1.5).collect())
174        .collect();
175
176    // categorical: per-class probability vector over categories
177    let class_cat_probs: Vec<Vec<Vec<f64>>> = (0..config.n_classes)
178        .map(|_| {
179            feature_types
180                .iter()
181                .map(|ft| match ft {
182                    FeatureType::Categorical(k) | FeatureType::Ordinal(k) => {
183                        let mut weights: Vec<f64> = (0..*k).map(|_| rng.next_f64() + 0.1).collect();
184                        let s: f64 = weights.iter().sum();
185                        for w in &mut weights {
186                            *w /= s;
187                        }
188                        weights
189                    }
190                    _ => vec![0.5, 0.5], // placeholder for non-categorical
191                })
192                .collect()
193        })
194        .collect();
195
196    // binary: per-class Bernoulli probability
197    let class_bin_probs: Vec<Vec<f64>> = (0..config.n_classes)
198        .map(|_| {
199            (0..n_features)
200                .map(|_| 0.1 + rng.next_f64() * 0.8)
201                .collect()
202        })
203        .collect();
204
205    // Generate balanced samples
206    let n_per_class = config.n_samples / config.n_classes;
207    let mut features: Vec<Vec<HeteroFeatureValue>> = Vec::with_capacity(config.n_samples);
208    let mut labels: Vec<usize> = Vec::with_capacity(config.n_samples);
209
210    for class_idx in 0..config.n_classes {
211        let count = if class_idx == config.n_classes - 1 {
212            config.n_samples - n_per_class * (config.n_classes - 1)
213        } else {
214            n_per_class
215        };
216        for _ in 0..count {
217            let row: Vec<HeteroFeatureValue> = feature_types
218                .iter()
219                .enumerate()
220                .map(|(j, ft)| match ft {
221                    FeatureType::Continuous(mean, std) => {
222                        let offset = class_cont_offsets[class_idx][j];
223                        let val = (mean + offset) + rng.next_normal() * std;
224                        HeteroFeatureValue::Float(val)
225                    }
226                    FeatureType::Categorical(k) => {
227                        let probs = &class_cat_probs[class_idx][j];
228                        let u = rng.next_f64();
229                        let mut cumsum = 0.0;
230                        let mut cat = 0;
231                        for (idx, &p) in probs.iter().enumerate() {
232                            cumsum += p;
233                            if u < cumsum {
234                                cat = idx;
235                                break;
236                            }
237                            cat = k - 1; // fallback
238                        }
239                        HeteroFeatureValue::Int(cat)
240                    }
241                    FeatureType::Ordinal(k) => {
242                        let probs = &class_cat_probs[class_idx][j];
243                        let u = rng.next_f64();
244                        let mut cumsum = 0.0;
245                        let mut level = 0;
246                        for (idx, &p) in probs.iter().enumerate() {
247                            cumsum += p;
248                            if u < cumsum {
249                                level = idx;
250                                break;
251                            }
252                            level = k - 1;
253                        }
254                        HeteroFeatureValue::Int(level)
255                    }
256                    FeatureType::Binary => {
257                        let p = class_bin_probs[class_idx][j];
258                        HeteroFeatureValue::Bool(rng.next_f64() < p)
259                    }
260                })
261                .collect();
262            features.push(row);
263            labels.push(class_idx);
264        }
265    }
266
267    // Shuffle samples
268    let n = features.len();
269    for i in (1..n).rev() {
270        let j = rng.next_usize_below(i + 1);
271        features.swap(i, j);
272        labels.swap(i, j);
273    }
274
275    HeteroDataset {
276        features,
277        labels,
278        feature_types,
279        feature_names,
280    }
281}
282
283/// One-hot encode all features in a heterogeneous dataset to a flat f64 vector.
284///
285/// Encoding rules:
286/// - `Float(v)` → `[v]` (pass through as single value)
287/// - `Int(k)` with categorical having `n` categories → one-hot vector of length `n`
288/// - `Int(l)` with ordinal having `n_levels` → one-hot vector of length `n_levels`
289/// - `Bool(b)` → `[0.0]` or `[1.0]`
290///
291/// The output vector per sample is wider than `n_features` whenever categorical
292/// or ordinal features have more than one category/level.
293///
294/// # Arguments
295///
296/// * `dataset` - The heterogeneous dataset to encode
297///
298/// # Returns
299///
300/// Dense f64 feature matrix with one-hot encoded categorical/ordinal features.
301pub fn encode_one_hot(dataset: &HeteroDataset) -> Vec<Vec<f64>> {
302    // Precompute the encoded widths for each feature
303    let widths: Vec<usize> = dataset
304        .feature_types
305        .iter()
306        .map(|ft| match ft {
307            FeatureType::Continuous(_, _) => 1,
308            FeatureType::Categorical(k) => *k,
309            FeatureType::Ordinal(k) => *k,
310            FeatureType::Binary => 1,
311        })
312        .collect();
313
314    let total_width: usize = widths.iter().sum();
315
316    dataset
317        .features
318        .iter()
319        .map(|row| {
320            let mut encoded = Vec::with_capacity(total_width);
321            for (j, val) in row.iter().enumerate() {
322                match (&dataset.feature_types[j], val) {
323                    (FeatureType::Continuous(_, _), HeteroFeatureValue::Float(v)) => {
324                        encoded.push(*v);
325                    }
326                    (FeatureType::Categorical(k), HeteroFeatureValue::Int(cat)) => {
327                        for c in 0..*k {
328                            encoded.push(if c == *cat { 1.0 } else { 0.0 });
329                        }
330                    }
331                    (FeatureType::Ordinal(k), HeteroFeatureValue::Int(level)) => {
332                        for l in 0..*k {
333                            encoded.push(if l == *level { 1.0 } else { 0.0 });
334                        }
335                    }
336                    (FeatureType::Binary, HeteroFeatureValue::Bool(b)) => {
337                        encoded.push(if *b { 1.0 } else { 0.0 });
338                    }
339                    // Fallback for unexpected type combinations
340                    (_, HeteroFeatureValue::Float(v)) => encoded.push(*v),
341                    (_, HeteroFeatureValue::Int(k)) => encoded.push(*k as f64),
342                    (_, HeteroFeatureValue::Bool(b)) => encoded.push(if *b { 1.0 } else { 0.0 }),
343                }
344            }
345            encoded
346        })
347        .collect()
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a config with auto-generated feature types.
    fn auto_config(n_samples: usize, n_features: usize, n_classes: usize, seed: u64) -> HeteroConfig {
        HeteroConfig {
            n_samples,
            feature_types: Vec::new(),
            n_features,
            n_classes,
            seed,
        }
    }

    #[test]
    fn test_heterogeneous_basic() {
        let ds = make_heterogeneous(&auto_config(50, 8, 2, 42));
        assert_eq!(ds.features.len(), 50);
        assert_eq!(ds.labels.len(), 50);
        assert_eq!(ds.feature_types.len(), 8);
        assert_eq!(ds.feature_names.len(), 8);
    }

    #[test]
    fn test_encode_one_hot_wider() {
        let config = HeteroConfig {
            n_samples: 20,
            feature_types: vec![
                FeatureType::Continuous(0.0, 1.0),
                FeatureType::Categorical(4),
                FeatureType::Ordinal(3),
                FeatureType::Binary,
            ],
            n_features: 4,
            n_classes: 2,
            seed: 77,
        };
        let ds = make_heterogeneous(&config);
        let encoded = encode_one_hot(&ds);
        // Column widths: 1 (cont) + 4 (cat) + 3 (ord) + 1 (bin) = 9.
        assert_eq!(
            encoded[0].len(),
            9,
            "Expected 9 columns after one-hot encoding"
        );
        // One-hot output is wider than the raw feature count (4 < 9).
        assert!(encoded[0].len() > config.n_features);
    }

    #[test]
    fn test_explicit_feature_types() {
        let config = HeteroConfig {
            n_samples: 30,
            feature_types: vec![
                FeatureType::Continuous(2.0, 0.5),
                FeatureType::Categorical(3),
                FeatureType::Binary,
            ],
            n_features: 3,
            n_classes: 2,
            seed: 1,
        };
        // Each column must yield values of its declared kind in every row.
        for row in &make_heterogeneous(&config).features {
            assert!(matches!(row[0], HeteroFeatureValue::Float(_)));
            assert!(matches!(row[1], HeteroFeatureValue::Int(_)));
            assert!(matches!(row[2], HeteroFeatureValue::Bool(_)));
        }
    }

    #[test]
    fn test_label_range() {
        let ds = make_heterogeneous(&auto_config(60, 6, 3, 5));
        for &label in &ds.labels {
            assert!(label < 3, "Label {label} out of range [0, 3)");
        }
    }
}