// Source path: scirs2_datasets/generators/heterogeneous.rs

/// Statistical type of one column of a heterogeneous dataset.
///
/// The variant determines how [`make_heterogeneous`] samples the column and
/// how [`encode_one_hot`] expands it into numeric columns.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum FeatureType {
    /// Real-valued column, given as `(mean, std)`. Values are drawn as
    /// `mean + class_offset + std * N(0, 1)`.
    Continuous(f64, f64),
    /// Categorical column with the given number of categories; values are
    /// `HeteroFeatureValue::Int` category indices.
    Categorical(usize),
    /// Ordinal column with the given number of levels; values are
    /// `HeteroFeatureValue::Int` level indices.
    Ordinal(usize),
    /// Two-valued column; values are `HeteroFeatureValue::Bool`.
    Binary,
}

22#[derive(Debug, Clone)]
24pub struct HeteroConfig {
25 pub n_samples: usize,
27 pub feature_types: Vec<FeatureType>,
29 pub n_features: usize,
31 pub n_classes: usize,
33 pub seed: u64,
35}
36
37impl Default for HeteroConfig {
38 fn default() -> Self {
39 Self {
40 n_samples: 500,
41 feature_types: Vec::new(),
42 n_features: 10,
43 n_classes: 2,
44 seed: 42,
45 }
46 }
47}
48
/// A single cell of a heterogeneous dataset.
///
/// Continuous columns hold `Float`, categorical and ordinal columns hold
/// `Int` indices, and binary columns hold `Bool`.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum HeteroFeatureValue {
    /// Real value of a continuous column.
    Float(f64),
    /// Category / level index of a categorical or ordinal column.
    Int(usize),
    /// Flag of a binary column.
    Bool(bool),
}

61#[derive(Debug, Clone)]
63pub struct HeteroDataset {
64 pub features: Vec<Vec<HeteroFeatureValue>>,
66 pub labels: Vec<usize>,
68 pub feature_types: Vec<FeatureType>,
70 pub feature_names: Vec<String>,
72}
73
/// Minimal 64-bit linear congruential generator (modulus 2^64).
///
/// Uses Knuth's MMIX multiplier and increment. Fast, deterministic and
/// dependency-free; not suitable for cryptographic use.
struct Lcg {
    state: u64,
}

impl Lcg {
    /// Creates a generator whose sequence is fully determined by `seed`.
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and returns it.
    fn next_u64(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits of the state.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Standard normal sample via the Box–Muller transform.
    ///
    /// `u1` is clamped away from zero so that `ln(u1)` stays finite.
    fn next_normal(&mut self) -> f64 {
        let u1 = self.next_f64().max(1e-10);
        let u2 = self.next_f64();
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Uniform integer in `0..n`.
    ///
    /// The low-order bits of a power-of-two-modulus LCG have very short
    /// periods (bit 0 simply alternates), so `state % n` is strongly
    /// correlated for even `n`. Instead, the full 64-bit output is mapped
    /// onto `0..n` with a widening multiply, which takes the result from the
    /// high bits.
    ///
    /// # Panics
    /// Panics if `n == 0`.
    fn next_usize_below(&mut self, n: usize) -> usize {
        assert!(n > 0, "next_usize_below: n must be positive");
        // (x * n) / 2^64 < n because x < 2^64, so the cast back is lossless.
        ((u128::from(self.next_u64()) * n as u128) >> 64) as usize
    }
}

107fn auto_feature_types(n: usize, rng: &mut Lcg) -> Vec<FeatureType> {
109 (0..n)
110 .map(|i| match i % 4 {
111 0 => {
112 let mean = rng.next_normal();
113 let std = 0.5 + rng.next_f64();
114 FeatureType::Continuous(mean, std)
115 }
116 1 => {
117 let n_cats = 2 + rng.next_usize_below(5); FeatureType::Categorical(n_cats)
119 }
120 2 => {
121 let n_levels = 3 + rng.next_usize_below(5); FeatureType::Ordinal(n_levels)
123 }
124 _ => FeatureType::Binary,
125 })
126 .collect()
127}
128
129pub fn make_heterogeneous(config: &HeteroConfig) -> HeteroDataset {
148 let mut rng = Lcg::new(config.seed);
149
150 let feature_types: Vec<FeatureType> = if config.feature_types.is_empty() {
152 auto_feature_types(config.n_features, &mut rng)
153 } else {
154 config.feature_types.clone()
155 };
156 let n_features = feature_types.len();
157
158 let feature_names: Vec<String> = feature_types
160 .iter()
161 .enumerate()
162 .map(|(i, ft)| match ft {
163 FeatureType::Continuous(_, _) => format!("cont_{i}"),
164 FeatureType::Categorical(_) => format!("cat_{i}"),
165 FeatureType::Ordinal(_) => format!("ord_{i}"),
166 FeatureType::Binary => format!("bin_{i}"),
167 })
168 .collect();
169
170 let class_cont_offsets: Vec<Vec<f64>> = (0..config.n_classes)
173 .map(|_| (0..n_features).map(|_| rng.next_normal() * 1.5).collect())
174 .collect();
175
176 let class_cat_probs: Vec<Vec<Vec<f64>>> = (0..config.n_classes)
178 .map(|_| {
179 feature_types
180 .iter()
181 .map(|ft| match ft {
182 FeatureType::Categorical(k) | FeatureType::Ordinal(k) => {
183 let mut weights: Vec<f64> = (0..*k).map(|_| rng.next_f64() + 0.1).collect();
184 let s: f64 = weights.iter().sum();
185 for w in &mut weights {
186 *w /= s;
187 }
188 weights
189 }
190 _ => vec![0.5, 0.5], })
192 .collect()
193 })
194 .collect();
195
196 let class_bin_probs: Vec<Vec<f64>> = (0..config.n_classes)
198 .map(|_| {
199 (0..n_features)
200 .map(|_| 0.1 + rng.next_f64() * 0.8)
201 .collect()
202 })
203 .collect();
204
205 let n_per_class = config.n_samples / config.n_classes;
207 let mut features: Vec<Vec<HeteroFeatureValue>> = Vec::with_capacity(config.n_samples);
208 let mut labels: Vec<usize> = Vec::with_capacity(config.n_samples);
209
210 for class_idx in 0..config.n_classes {
211 let count = if class_idx == config.n_classes - 1 {
212 config.n_samples - n_per_class * (config.n_classes - 1)
213 } else {
214 n_per_class
215 };
216 for _ in 0..count {
217 let row: Vec<HeteroFeatureValue> = feature_types
218 .iter()
219 .enumerate()
220 .map(|(j, ft)| match ft {
221 FeatureType::Continuous(mean, std) => {
222 let offset = class_cont_offsets[class_idx][j];
223 let val = (mean + offset) + rng.next_normal() * std;
224 HeteroFeatureValue::Float(val)
225 }
226 FeatureType::Categorical(k) => {
227 let probs = &class_cat_probs[class_idx][j];
228 let u = rng.next_f64();
229 let mut cumsum = 0.0;
230 let mut cat = 0;
231 for (idx, &p) in probs.iter().enumerate() {
232 cumsum += p;
233 if u < cumsum {
234 cat = idx;
235 break;
236 }
237 cat = k - 1; }
239 HeteroFeatureValue::Int(cat)
240 }
241 FeatureType::Ordinal(k) => {
242 let probs = &class_cat_probs[class_idx][j];
243 let u = rng.next_f64();
244 let mut cumsum = 0.0;
245 let mut level = 0;
246 for (idx, &p) in probs.iter().enumerate() {
247 cumsum += p;
248 if u < cumsum {
249 level = idx;
250 break;
251 }
252 level = k - 1;
253 }
254 HeteroFeatureValue::Int(level)
255 }
256 FeatureType::Binary => {
257 let p = class_bin_probs[class_idx][j];
258 HeteroFeatureValue::Bool(rng.next_f64() < p)
259 }
260 })
261 .collect();
262 features.push(row);
263 labels.push(class_idx);
264 }
265 }
266
267 let n = features.len();
269 for i in (1..n).rev() {
270 let j = rng.next_usize_below(i + 1);
271 features.swap(i, j);
272 labels.swap(i, j);
273 }
274
275 HeteroDataset {
276 features,
277 labels,
278 feature_types,
279 feature_names,
280 }
281}
282
283pub fn encode_one_hot(dataset: &HeteroDataset) -> Vec<Vec<f64>> {
302 let widths: Vec<usize> = dataset
304 .feature_types
305 .iter()
306 .map(|ft| match ft {
307 FeatureType::Continuous(_, _) => 1,
308 FeatureType::Categorical(k) => *k,
309 FeatureType::Ordinal(k) => *k,
310 FeatureType::Binary => 1,
311 })
312 .collect();
313
314 let total_width: usize = widths.iter().sum();
315
316 dataset
317 .features
318 .iter()
319 .map(|row| {
320 let mut encoded = Vec::with_capacity(total_width);
321 for (j, val) in row.iter().enumerate() {
322 match (&dataset.feature_types[j], val) {
323 (FeatureType::Continuous(_, _), HeteroFeatureValue::Float(v)) => {
324 encoded.push(*v);
325 }
326 (FeatureType::Categorical(k), HeteroFeatureValue::Int(cat)) => {
327 for c in 0..*k {
328 encoded.push(if c == *cat { 1.0 } else { 0.0 });
329 }
330 }
331 (FeatureType::Ordinal(k), HeteroFeatureValue::Int(level)) => {
332 for l in 0..*k {
333 encoded.push(if l == *level { 1.0 } else { 0.0 });
334 }
335 }
336 (FeatureType::Binary, HeteroFeatureValue::Bool(b)) => {
337 encoded.push(if *b { 1.0 } else { 0.0 });
338 }
339 (_, HeteroFeatureValue::Float(v)) => encoded.push(*v),
341 (_, HeteroFeatureValue::Int(k)) => encoded.push(*k as f64),
342 (_, HeteroFeatureValue::Bool(b)) => encoded.push(if *b { 1.0 } else { 0.0 }),
343 }
344 }
345 encoded
346 })
347 .collect()
348}
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Explicit four-column layout shared by several tests.
    fn mixed_types() -> Vec<FeatureType> {
        vec![
            FeatureType::Continuous(0.0, 1.0),
            FeatureType::Categorical(4),
            FeatureType::Ordinal(3),
            FeatureType::Binary,
        ]
    }

    #[test]
    fn test_heterogeneous_basic() {
        let config = HeteroConfig {
            n_samples: 50,
            feature_types: Vec::new(),
            n_features: 8,
            n_classes: 2,
            seed: 42,
        };
        let ds = make_heterogeneous(&config);
        assert_eq!(ds.features.len(), 50);
        assert_eq!(ds.labels.len(), 50);
        assert_eq!(ds.feature_types.len(), 8);
        assert_eq!(ds.feature_names.len(), 8);
        assert!(ds.features.iter().all(|row| row.len() == 8));
    }

    #[test]
    fn test_encode_one_hot_wider() {
        let config = HeteroConfig {
            n_samples: 20,
            feature_types: mixed_types(),
            n_features: 4,
            n_classes: 2,
            seed: 77,
        };
        let ds = make_heterogeneous(&config);
        let encoded = encode_one_hot(&ds);
        assert_eq!(encoded.len(), 20);
        for row in &encoded {
            // 1 continuous + 4 categorical + 3 ordinal + 1 binary columns.
            assert_eq!(row.len(), 9, "Expected 9 columns after one-hot encoding");
            // Each one-hot block has exactly one active slot.
            assert_eq!(row[1..5].iter().sum::<f64>(), 1.0);
            assert_eq!(row[5..8].iter().sum::<f64>(), 1.0);
            assert!(row[8] == 0.0 || row[8] == 1.0);
        }
    }

    #[test]
    fn test_explicit_feature_types() {
        let config = HeteroConfig {
            n_samples: 30,
            feature_types: vec![
                FeatureType::Continuous(2.0, 0.5),
                FeatureType::Categorical(3),
                FeatureType::Binary,
            ],
            n_features: 3,
            n_classes: 2,
            seed: 1,
        };
        let ds = make_heterogeneous(&config);
        for row in &ds.features {
            assert_eq!(row.len(), 3);
            assert!(matches!(row[0], HeteroFeatureValue::Float(_)));
            assert!(matches!(row[1], HeteroFeatureValue::Int(c) if c < 3));
            assert!(matches!(row[2], HeteroFeatureValue::Bool(_)));
        }
    }

    #[test]
    fn test_label_range() {
        let config = HeteroConfig {
            n_samples: 60,
            feature_types: Vec::new(),
            n_features: 6,
            n_classes: 3,
            seed: 5,
        };
        let ds = make_heterogeneous(&config);
        let mut counts = [0usize; 3];
        for &label in &ds.labels {
            assert!(label < 3, "Label {label} out of range [0, 3)");
            counts[label] += 1;
        }
        assert_eq!(counts, [20, 20, 20]);
    }

    #[test]
    fn test_remainder_goes_to_last_class() {
        let config = HeteroConfig {
            n_samples: 10,
            n_classes: 3,
            seed: 11,
            ..HeteroConfig::default()
        };
        let ds = make_heterogeneous(&config);
        let mut counts = [0usize; 3];
        for &label in &ds.labels {
            counts[label] += 1;
        }
        assert_eq!(counts, [3, 3, 4]);
    }

    #[test]
    fn test_deterministic_for_same_seed() {
        let config = HeteroConfig {
            n_samples: 40,
            n_classes: 3,
            seed: 9,
            ..HeteroConfig::default()
        };
        let a = make_heterogeneous(&config);
        let b = make_heterogeneous(&config);
        assert_eq!(a.features, b.features);
        assert_eq!(a.labels, b.labels);
        assert_eq!(a.feature_names, b.feature_names);
    }

    #[test]
    fn test_feature_names_follow_types() {
        let config = HeteroConfig {
            n_samples: 5,
            feature_types: mixed_types(),
            ..HeteroConfig::default()
        };
        let ds = make_heterogeneous(&config);
        assert_eq!(ds.feature_names, vec!["cont_0", "cat_1", "ord_2", "bin_3"]);
    }

    #[test]
    fn test_zero_samples() {
        let config = HeteroConfig {
            n_samples: 0,
            ..HeteroConfig::default()
        };
        let ds = make_heterogeneous(&config);
        assert!(ds.features.is_empty());
        assert!(ds.labels.is_empty());
        assert!(encode_one_hot(&ds).is_empty());
    }
}