sensorlm/data/
download.rs1use std::{
16 fs,
17 io::Write,
18 path::{Path, PathBuf},
19};
20
21use indicatif::{ProgressBar, ProgressStyle};
22use sha2::{Digest, Sha256};
23use tracing::{info, warn};
24
25use crate::error::{Result, SensorLMError};
26
/// Static description of a downloadable dataset.
///
/// All string fields are `&'static str`, so entries can live in a `const`
/// table such as `KNOWN_DATASETS`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatasetEntry {
    /// Dataset name; `find_dataset` matches it ASCII-case-insensitively.
    pub name: &'static str,
    /// URL the dataset archive is fetched from.
    pub url: &'static str,
    /// Expected SHA-256 digest as a hex string. `download_file` treats an
    /// empty string as "no checksum available" and skips verification.
    pub sha256: &'static str,
    /// Size of the download in bytes.
    // NOTE(review): presumably an approximate size used for display only —
    // nothing in this file reads it; confirm against callers.
    pub size_bytes: u64,
}
43
44pub const KNOWN_DATASETS: &[DatasetEntry] = &[
46 DatasetEntry {
47 name: "PAMAP2",
48 url: "https://archive.ics.uci.edu/static/public/231/pamap2+physical+activity+monitoring.zip",
49 sha256: "", size_bytes: 680_000_000,
51 },
52 DatasetEntry {
53 name: "WESAD",
54 url: "https://uni-siegen.sciebo.de/s/HGdUkoNlW1Ub0Gx/download",
55 sha256: "",
56 size_bytes: 1_800_000_000,
57 },
58];
59
60pub fn find_dataset(name: &str) -> Option<&'static DatasetEntry> {
62 KNOWN_DATASETS
63 .iter()
64 .find(|d| d.name.to_ascii_lowercase() == name.to_ascii_lowercase())
65}
66
67pub fn download_file(url: &str, dest_path: &Path, expected_sha256: &str) -> Result<()> {
83 if dest_path.exists() && !expected_sha256.is_empty() {
85 let existing_hash = sha256_of_file(dest_path)?;
86 if existing_hash.eq_ignore_ascii_case(expected_sha256) {
87 info!("✓ {} already downloaded and verified.", dest_path.display());
88 return Ok(());
89 } else {
90 warn!(
91 "Checksum mismatch for {}: expected {} got {}. Re-downloading.",
92 dest_path.display(),
93 expected_sha256,
94 existing_hash
95 );
96 }
97 }
98
99 info!("Downloading {} → {}", url, dest_path.display());
100
101 if let Some(parent) = dest_path.parent() {
103 fs::create_dir_all(parent)?;
104 }
105
106 let client = reqwest::blocking::Client::builder()
107 .timeout(std::time::Duration::from_secs(3600))
108 .build()
109 .map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;
110
111 let mut response = client
112 .get(url)
113 .send()
114 .map_err(|e| SensorLMError::DownloadError { url: url.to_string(), source: e })?;
115
116 let total_bytes = response.content_length().unwrap_or(0);
117 let pb = ProgressBar::new(total_bytes);
118 pb.set_style(
119 ProgressStyle::with_template(
120 "{spinner:.green} [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})",
121 )
122 .unwrap()
123 .progress_chars("=>-"),
124 );
125
126 let mut file = fs::File::create(dest_path)?;
127 let mut downloaded = 0u64;
128 let mut buf = vec![0u8; 8192];
129
130 loop {
131 use std::io::Read;
132 let n = response
133 .read(&mut buf)
134 .map_err(|e| SensorLMError::Io(e))?;
135 if n == 0 {
136 break;
137 }
138 file.write_all(&buf[..n])?;
139 downloaded += n as u64;
140 pb.set_position(downloaded);
141 }
142 pb.finish_with_message("Download complete");
143
144 if !expected_sha256.is_empty() {
146 let actual_hash = sha256_of_file(dest_path)?;
147 if !actual_hash.eq_ignore_ascii_case(expected_sha256) {
148 fs::remove_file(dest_path)?;
149 return Err(SensorLMError::DatasetError(format!(
150 "SHA-256 mismatch: expected {expected_sha256}, got {actual_hash}"
151 )));
152 }
153 info!("✓ Checksum verified.");
154 }
155
156 Ok(())
157}
158
159fn sha256_of_file(path: &Path) -> Result<String> {
161 let bytes = fs::read(path)?;
162 let mut hasher = Sha256::new();
163 hasher.update(&bytes);
164 Ok(hex::encode(hasher.finalize()))
165}
166
167pub fn default_data_dir() -> PathBuf {
176 dirs::data_local_dir()
177 .unwrap_or_else(|| PathBuf::from("."))
178 .join("sensorlm")
179}
180
181use ndarray::Array2;
186use rand::{Rng as _, SeedableRng};
187use rand_distr::{Distribution, Normal};
188
/// Settings for `generate_synthetic_dataset`.
#[derive(Debug, Clone, PartialEq)]
pub struct SyntheticDataConfig {
    /// Number of samples to generate.
    pub num_samples: usize,
    /// Seed for the random number generator; equal seeds give equal output.
    pub seed: u64,
    /// Whether to add a sinusoidal component to channels 0 and 3.
    pub add_circadian: bool,
    /// Whether to mark random time steps as missing in the mask.
    pub add_missingness: bool,
    /// Fraction of time steps drawn as missing per channel. Only read when
    /// `add_missingness` is `true`.
    // NOTE(review): presumably expected to lie in [0, 1]; the generator does
    // not validate it.
    pub missingness_rate: f64,
}
203
204impl Default for SyntheticDataConfig {
205 fn default() -> Self {
206 Self {
207 num_samples: 1000,
208 seed: 42,
209 add_circadian: true,
210 add_missingness: true,
211 missingness_rate: 0.1,
212 }
213 }
214}
215
216#[derive(Debug, Clone)]
218pub struct SyntheticSample {
219 pub sensor: Array2<f32>,
221 pub mask: Array2<u8>,
223 pub caption: String,
225 pub id: usize,
227}
228
229pub fn generate_synthetic_dataset(cfg: &SyntheticDataConfig) -> Vec<SyntheticSample> {
238 use crate::constants::{NUM_CHANNELS, NORM_PARAMS, TIME_STEPS};
239
240 let mut rng = rand::rngs::StdRng::seed_from_u64(cfg.seed);
241 let mut samples = Vec::with_capacity(cfg.num_samples);
242
243 for id in 0..cfg.num_samples {
244 let mut sensor = Array2::<f32>::zeros((TIME_STEPS, NUM_CHANNELS));
245 let mut mask = Array2::<u8>::zeros((TIME_STEPS, NUM_CHANNELS));
246
247 for ch in 0..NUM_CHANNELS {
248 let (_mean, _std) = NORM_PARAMS[ch]; let noise = Normal::new(0.0f64, 0.3).unwrap();
250
251 for t in 0..TIME_STEPS {
252 let base: f64 = noise.sample(&mut rng);
253
254 let circadian = if cfg.add_circadian && (ch == 0 || ch == 3) {
257 0.5 * (2.0 * std::f64::consts::PI * t as f64 / TIME_STEPS as f64).sin()
258 } else {
259 0.0
260 };
261
262 sensor[[t, ch]] = (base + circadian) as f32;
264 }
265
266 if cfg.add_missingness {
268 let n_missing = ((TIME_STEPS as f64 * cfg.missingness_rate) as usize).max(1);
269 for _ in 0..n_missing {
270 let t: usize = rng.gen_range(0..TIME_STEPS);
271 mask[[t, ch]] = 1;
272 }
273 }
274 }
275
276 let caption = format!(
277 "Synthetic 24-hour recording for individual {id}. \
278 Heart rate shows typical circadian variation. \
279 Activity patterns reflect normal daily movement."
280 );
281
282 samples.push(SyntheticSample { sensor, mask, caption, id });
283 }
284
285 samples
286}