make_multitask_dataset

Function make_multitask_dataset 

Source
pub fn make_multitask_dataset(
    n_samples: usize,
    config: MultiTaskConfig,
) -> Result<MultiTaskDataset>
Expand description

Generates a multi-task learning dataset with `n_samples` samples, configured by `config`.

Examples found in repository:
examples/advanced_generators_demo.rs (line 238)
217fn demonstrate_multitask_learning() -> Result<(), Box<dyn std::error::Error>> {
218    println!("🎯 MULTI-TASK LEARNING DATASETS");
219    println!("{}", "-".repeat(35));
220
221    // Basic multi-task scenario
222    println!("Multi-task scenario: Healthcare prediction");
223    let config = MultiTaskConfig {
224        n_tasks: 4,
225        task_types: vec![
226            TaskType::Classification(3), // Disease classification
227            TaskType::Regression,        // Risk score prediction
228            TaskType::Classification(2), // Treatment response
229            TaskType::Ordinal(5),        // Severity rating
230        ],
231        shared_features: 20,        // Common patient features
232        task_specific_features: 10, // Task-specific biomarkers
233        task_correlation: 0.7,      // High correlation between tasks
234        task_noise: vec![0.05, 0.1, 0.08, 0.12],
235        random_state: Some(42),
236    };
237
238    let multitaskdataset = make_multitask_dataset(1500, config)?;
239
240    println!("  📊 Multi-task dataset structure:");
241    println!("    Number of tasks: {}", multitaskdataset.tasks.len());
242    println!("    Shared features: {}", multitaskdataset.shared_features);
243    println!(
244        "    Task correlation: {:.1}",
245        multitaskdataset.task_correlation
246    );
247
248    for (i, task) in multitaskdataset.tasks.iter().enumerate() {
249        println!(
250            "    Task {}: {} samples, {} features ({})",
251            i + 1,
252            task.n_samples(),
253            task.n_features(),
254            task.metadata
255                .get("task_type")
256                .unwrap_or(&"unknown".to_string())
257        );
258
259        // Analyze task characteristics
260        if let Some(target) = &task.target {
261            match task
262                .metadata
263                .get("task_type")
264                .map(|s| s.as_str())
265                .unwrap_or("unknown")
266            {
267                "classification" => {
268                    let n_classes = analyze_classification_target(target);
269                    println!("      Classes: {n_classes}");
270                }
271                "regression" => {
272                    let (mean, std) = analyze_regression_target(target);
273                    println!("      Target range: {mean:.2} ± {std:.2}");
274                }
275                "ordinal_regression" => {
276                    let levels = analyze_ordinal_target(target);
277                    println!("      Ordinal levels: {levels}");
278                }
279                _ => {}
280            }
281        }
282    }
283
284    // Transfer learning scenario
285    println!("\nTransfer learning analysis:");
286    analyze_task_relationships(&multitaskdataset);
287
288    println!();
289    Ok(())
290}