Skip to main content

sensorlm/data/
download.rs

1//! Dataset download utilities.
2//!
3//! The original SensorLM training data (59.7 M hours from 103 000 participants)
4//! is internal to Google and is not publicly available.  This module provides:
5//!
6//! 1. A **synthetic data generator** that creates plausible random wearable
7//!    tensors for local development, unit tests, and integration tests.
8//! 2. **Downloaders** for two publicly available datasets that can substitute
9//!    as proof-of-concept training data:
10//!    * [PAMAP2](https://archive.ics.uci.edu/dataset/231/pamap2+physical+activity+monitoring)
11//!    * [WESAD](https://archive.ics.uci.edu/dataset/465/wesad+wearable+stress+and+affect+detection)
12//! 3. A generic [`download_file`] helper with SHA-256 checksum verification
13//!    and a progress bar.
14
15use std::{
16    fs,
17    io::Write,
18    path::{Path, PathBuf},
19};
20
21use indicatif::{ProgressBar, ProgressStyle};
22use sha2::{Digest, Sha256};
23use tracing::{info, warn};
24
25use crate::error::{Result, SensorLMError};
26
27// ---------------------------------------------------------------------------
28// Known public datasets
29// ---------------------------------------------------------------------------
30
/// One row of the public-dataset registry: where to fetch a dataset and how
/// to validate the download.
#[derive(Clone, Debug)]
pub struct DatasetEntry {
    /// Display name of the dataset; also the key matched by [`find_dataset`].
    pub name: &'static str,
    /// URL the archive is fetched from.
    pub url: &'static str,
    /// Hex-encoded SHA-256 digest the downloaded file is expected to have.
    /// An empty string means "do not verify".
    pub sha256: &'static str,
    /// Total uncompressed size in bytes (used to size the progress bar).
    pub size_bytes: u64,
}
43
44/// Publicly downloadable datasets compatible with this pipeline.
45pub const KNOWN_DATASETS: &[DatasetEntry] = &[
46    DatasetEntry {
47        name: "PAMAP2",
48        url: "https://archive.ics.uci.edu/static/public/231/pamap2+physical+activity+monitoring.zip",
49        sha256: "",   // skip verification – checksum not published by UCI
50        size_bytes: 680_000_000,
51    },
52    DatasetEntry {
53        name: "WESAD",
54        url: "https://uni-siegen.sciebo.de/s/HGdUkoNlW1Ub0Gx/download",
55        sha256: "",
56        size_bytes: 1_800_000_000,
57    },
58];
59
60/// Resolve a dataset entry by name (case-insensitive).
61pub fn find_dataset(name: &str) -> Option<&'static DatasetEntry> {
62    KNOWN_DATASETS
63        .iter()
64        .find(|d| d.name.to_ascii_lowercase() == name.to_ascii_lowercase())
65}
66
67// ---------------------------------------------------------------------------
68// Generic HTTP downloader
69// ---------------------------------------------------------------------------
70
71/// Download a file from `url` to `dest_path`.
72///
73/// * Shows a progress bar via `indicatif`.
74/// * Verifies the SHA-256 digest if `expected_sha256` is non-empty.
75/// * Skips the download if the file already exists **and** has the correct
76///   checksum.
77///
78/// # Errors
79///
80/// Returns an error if the HTTP request fails, the write fails, or the
81/// checksum does not match.
82pub fn download_file(url: &str, dest_path: &Path, expected_sha256: &str) -> Result<()> {
83    // Check if already downloaded with correct checksum.
84    if dest_path.exists() && !expected_sha256.is_empty() {
85        let existing_hash = sha256_of_file(dest_path)?;
86        if existing_hash.eq_ignore_ascii_case(expected_sha256) {
87            info!("✓ {} already downloaded and verified.", dest_path.display());
88            return Ok(());
89        } else {
90            warn!(
91                "Checksum mismatch for {}: expected {} got {}. Re-downloading.",
92                dest_path.display(),
93                expected_sha256,
94                existing_hash
95            );
96        }
97    }
98
99    info!("Downloading {} → {}", url, dest_path.display());
100
101    // Create parent directories.
102    if let Some(parent) = dest_path.parent() {
103        fs::create_dir_all(parent)?;
104    }
105
106    let client = reqwest::blocking::Client::builder()
107        .timeout(std::time::Duration::from_secs(3600))
108        .build()
109        .map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;
110
111    let mut response = client
112        .get(url)
113        .send()
114        .map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;
115
116    let total_bytes = response.content_length().unwrap_or(0);
117    let pb = ProgressBar::new(total_bytes);
118    pb.set_style(
119        ProgressStyle::with_template(
120            "{spinner:.green} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
121        )
122        .unwrap()
123        .progress_chars("=>-"),
124    );
125
126    let mut file = fs::File::create(dest_path)?;
127    let mut downloaded = 0u64;
128    let mut buf = vec![0u8; 8192];
129
130    loop {
131        use std::io::Read;
132        let n = response
133            .read(&mut buf)
134            .map_err(|e| SensorLMError::Io(e))?;
135        if n == 0 {
136            break;
137        }
138        file.write_all(&buf[..n])?;
139        downloaded += n as u64;
140        pb.set_position(downloaded);
141    }
142    pb.finish_with_message("Download complete");
143
144    // Verify checksum.
145    if !expected_sha256.is_empty() {
146        let actual_hash = sha256_of_file(dest_path)?;
147        if !actual_hash.eq_ignore_ascii_case(expected_sha256) {
148            fs::remove_file(dest_path)?;
149            return Err(SensorLMError::DatasetError(format!(
150                "SHA-256 mismatch: expected {expected_sha256}, got {actual_hash}"
151            )));
152        }
153        info!("✓ Checksum verified.");
154    }
155
156    Ok(())
157}
158
159/// Compute the SHA-256 hex digest of a file on disk.
160fn sha256_of_file(path: &Path) -> Result<String> {
161    let bytes = fs::read(path)?;
162    let mut hasher = Sha256::new();
163    hasher.update(&bytes);
164    Ok(hex::encode(hasher.finalize()))
165}
166
167// ---------------------------------------------------------------------------
168// Default data directory
169// ---------------------------------------------------------------------------
170
171/// Return the platform-appropriate data directory for sensorlm-rs.
172///
173/// On Linux / macOS: `~/.local/share/sensorlm`
174/// On Windows: `%APPDATA%\sensorlm`
175pub fn default_data_dir() -> PathBuf {
176    dirs::data_local_dir()
177        .unwrap_or_else(|| PathBuf::from("."))
178        .join("sensorlm")
179}
180
181// ---------------------------------------------------------------------------
182// Synthetic dataset generator
183// ---------------------------------------------------------------------------
184
185use ndarray::Array2;
186use rand::{Rng as _, SeedableRng};
187use rand_distr::{Distribution, Normal};
188
/// Knobs for the synthetic wearable-data generator.
#[derive(Clone, Debug)]
pub struct SyntheticDataConfig {
    /// How many samples (one per simulated individual) to produce.
    pub num_samples: usize,
    /// Seed for the random number generator; the same seed gives the same data.
    pub seed: u64,
    /// Overlay a circadian pattern on the generated signal when `true`.
    pub add_circadian: bool,
    /// Simulate sensor dropout by flagging time-steps as missing when `true`.
    pub add_missingness: bool,
    /// Fraction of time-steps to flag as missing, in the range [0, 1].
    pub missingness_rate: f64,
}

impl Default for SyntheticDataConfig {
    /// 1000 samples, seed 42, circadian structure and missingness both
    /// enabled, with a missingness rate of 10 %.
    fn default() -> Self {
        SyntheticDataConfig {
            num_samples: 1000,
            seed: 42,
            add_circadian: true,
            add_missingness: true,
            missingness_rate: 0.1,
        }
    }
}
215
216/// A single synthetic wearable sample.
217#[derive(Debug, Clone)]
218pub struct SyntheticSample {
219    /// Normalised sensor tensor, shape `(TIME_STEPS, NUM_CHANNELS)`.
220    pub sensor: Array2<f32>,
221    /// Missingness mask, shape `(TIME_STEPS, NUM_CHANNELS)`.  1 = imputed.
222    pub mask: Array2<u8>,
223    /// Pre-generated caption (high-level summary).
224    pub caption: String,
225    /// Sample ID.
226    pub id: usize,
227}
228
229/// Generate a batch of synthetic wearable samples.
230///
231/// Each sample simulates one 24-hour recording window with:
232///
233/// * Normally distributed channel noise scaled by the population parameters
234///   in [`NORM_PARAMS`].
235/// * A sinusoidal circadian rhythm on heart rate and step count (optional).
236/// * Random missingness blocks mimicking sensor dropout (optional).
237pub fn generate_synthetic_dataset(cfg: &SyntheticDataConfig) -> Vec<SyntheticSample> {
238    use crate::constants::{NUM_CHANNELS, NORM_PARAMS, TIME_STEPS};
239
240    let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
241    let mut samples = Vec::with_capacity(cfg.num_samples);
242
243    for id in 0..cfg.num_samples {
244        let mut sensor = Array2::<f32>::zeros((TIME_STEPS, NUM_CHANNELS));
245        let mut mask = Array2::<u8>::zeros((TIME_STEPS, NUM_CHANNELS));
246
247        for ch in 0..NUM_CHANNELS {
248            let (_mean, _std) = NORM_PARAMS[ch]; // reserved for real normalisation
249            let noise = Normal::new(0.0f64, 0.3).unwrap();
250
251            for t in 0..TIME_STEPS {
252                let base: f64 = noise.sample(&mut rng);
253
254                // Add a gentle circadian sine wave on HR (channel 0) and
255                // steps (channel 3).
256                let circadian = if cfg.add_circadian && (ch == 0 || ch == 3) {
257                    0.5 * (2.0 * std::f64::consts::PI * t as f64 / TIME_STEPS as f64).sin()
258                } else {
259                    0.0
260                };
261
262                // Store as z-score (already normalised by construction).
263                sensor[[t, ch]] = (base + circadian) as f32;
264            }
265
266            // Simulate missingness blocks.
267            if cfg.add_missingness {
268                let n_missing = ((TIME_STEPS as f64 * cfg.missingness_rate) as usize).max(1);
269                for _ in 0..n_missing {
270                    let t: usize = rng.gen_range(0..TIME_STEPS);
271                    mask[[t, ch]] = 1;
272                }
273            }
274        }
275
276        let caption = format!(
277            "Synthetic 24-hour recording for individual {id}. \
278             Heart rate shows typical circadian variation. \
279             Activity patterns reflect normal daily movement."
280        );
281
282        samples.push(SyntheticSample { sensor, mask, caption, id });
283    }
284
285    samples
286}