sensorlm/data/dataset.rs
1//! Burn [`Dataset`] implementations for SensorLM.
2//!
3//! Two dataset types are provided:
4//!
5//! * [`SyntheticSensorDataset`] – Uses the built-in synthetic data generator.
6//! Good for unit tests, profiling, and quick experiments without real data.
7//!
//! * [`CsvSensorDataset`] – Loads pre-processed sensor data from a CSV file.
//!   The file starts with a header row, followed by one row per sample.
//!   Columns are read by position, not by name. The first
//!   `TIME_STEPS * NUM_CHANNELS` columns hold the pre-normalised, flattened
//!   `f32` sensor values, and the next column holds the caption text.
12//!
13//! Both implement [`burn::data::dataset::Dataset`] so they can be wrapped by
14//! [`burn::data::dataloader::DataLoaderBuilder`].
15
16use std::path::Path;
17
18use burn::data::dataset::Dataset;
19
20use crate::constants::{NUM_CHANNELS, TIME_STEPS};
21use crate::data::download::{generate_synthetic_dataset, SyntheticDataConfig};
22use crate::error::{Result, SensorLMError};
23
24// ---------------------------------------------------------------------------
25// Shared item type
26// ---------------------------------------------------------------------------
27
/// A single (sensor, caption) training pair.
///
/// Both loaders in this module produce items where `token_ids` and
/// `attention_mask` have the same length (`max_seq_len`), so the two vectors
/// can be zipped position by position.
#[derive(Debug, Clone, PartialEq)]
pub struct SensorTextItem {
    /// Normalised sensor values, flat `f32` slice of length `T × C`
    /// (row-major: index `t * C + c`).
    pub sensor: Vec<f32>,
    /// Tokenised caption as a `Vec<i32>` (token IDs, padded to `max_seq_len`
    /// with the padding ID `1`).
    pub token_ids: Vec<i32>,
    /// Padding mask aligned with `token_ids`: `1` for real tokens, `0` for
    /// padding.
    pub attention_mask: Vec<i32>,
    /// Raw caption text (kept for debugging / evaluation).
    pub caption_text: String,
}
41
42// ---------------------------------------------------------------------------
43// Synthetic dataset
44// ---------------------------------------------------------------------------
45
46/// An in-memory dataset of synthetically generated sensor-text pairs.
47///
48/// Useful for smoke-testing the training pipeline without a real dataset.
49///
50/// # Example
51///
52/// ```rust,no_run
53/// use sensorlm::data::dataset::SyntheticSensorDataset;
54/// use burn::data::dataset::Dataset;
55/// let ds = SyntheticSensorDataset::new(512, 42, 256);
56/// println!("Dataset size: {}", ds.len());
57/// ```
58pub struct SyntheticSensorDataset {
59 items: Vec<SensorTextItem>,
60}
61
62impl SyntheticSensorDataset {
63 /// Create a new synthetic dataset.
64 ///
65 /// # Arguments
66 ///
67 /// * `num_samples` – Number of (sensor, caption) pairs to generate.
68 /// * `seed` – Random seed.
69 /// * `max_seq_len` – Token sequence length (captions are padded / truncated).
70 pub fn new(num_samples: usize, seed: u64, max_seq_len: usize) -> Self {
71 let cfg = SyntheticDataConfig {
72 num_samples,
73 seed,
74 add_circadian: true,
75 add_missingness: true,
76 missingness_rate: 0.05,
77 };
78 let raw = generate_synthetic_dataset(&cfg);
79
80 let items = raw
81 .into_iter()
82 .map(|s| {
83 let sensor: Vec<f32> = s.sensor.iter().copied().collect();
84
85 // Very simple character-level "tokenisation" for the synthetic
86 // dataset (a real run uses a SentencePiece tokeniser).
87 let raw_ids: Vec<i32> = s
88 .caption
89 .chars()
90 .take(max_seq_len)
91 .map(|c| c as i32 % 32_000)
92 .collect();
93
94 let len = raw_ids.len();
95 let mut token_ids = raw_ids;
96 token_ids.resize(max_seq_len, 1); // pad with id=1
97
98 let mut attention_mask = vec![1i32; len];
99 attention_mask.resize(max_seq_len, 0);
100
101 SensorTextItem {
102 sensor,
103 token_ids,
104 attention_mask,
105 caption_text: s.caption,
106 }
107 })
108 .collect();
109
110 Self { items }
111 }
112}
113
114impl Dataset<SensorTextItem> for SyntheticSensorDataset {
115 fn get(&self, index: usize) -> Option<SensorTextItem> {
116 self.items.get(index).cloned()
117 }
118
119 fn len(&self) -> usize {
120 self.items.len()
121 }
122}
123
124// ---------------------------------------------------------------------------
125// CSV-backed dataset
126// ---------------------------------------------------------------------------
127
128/// A dataset loaded from a CSV file.
129///
130/// Expected CSV schema:
131///
132/// ```text
133/// col_0, col_1, ..., col_33, caption
134/// <f32>, <f32>, ..., <f32>, "The person is walking..."
135/// ```
136///
137/// There must be exactly `T × C` numeric columns followed by a `caption`
138/// string column. If your CSV has one row per time-step, pre-aggregate to
139/// one row per sample before loading.
140pub struct CsvSensorDataset {
141 items: Vec<SensorTextItem>,
142}
143
144impl CsvSensorDataset {
145 /// Load all rows from a CSV file.
146 ///
147 /// # Arguments
148 ///
149 /// * `path` – Path to the `.csv` file.
150 /// * `max_seq_len` – Target token sequence length.
151 /// * `tokenize` – A closure that converts a caption string into token IDs.
152 pub fn from_csv<F>(path: &Path, max_seq_len: usize, tokenize: F) -> Result<Self>
153 where
154 F: Fn(&str) -> Vec<i32>,
155 {
156 let expected_sensor_len = TIME_STEPS * NUM_CHANNELS;
157 let mut items = Vec::new();
158
159 let mut rdr = csv::Reader::from_path(path)
160 .map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
161
162 for result in rdr.records() {
163 let record = result.map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
164
165 if record.len() < expected_sensor_len + 1 {
166 return Err(SensorLMError::DatasetError(format!(
167 "Expected at least {} columns, got {}",
168 expected_sensor_len + 1,
169 record.len()
170 )));
171 }
172
173 let sensor: Vec<f32> = (0..expected_sensor_len)
174 .map(|i| {
175 record[i]
176 .trim()
177 .parse::<f32>()
178 .unwrap_or(0.0)
179 })
180 .collect();
181
182 let caption = record[expected_sensor_len].trim().to_string();
183
184 let mut token_ids = tokenize(&caption);
185 let real_len = token_ids.len().min(max_seq_len);
186 token_ids.truncate(real_len);
187 let mut attn = vec![1i32; real_len];
188 token_ids.resize(max_seq_len, 1);
189 attn.resize(max_seq_len, 0);
190
191 items.push(SensorTextItem {
192 sensor,
193 token_ids,
194 attention_mask: attn,
195 caption_text: caption,
196 });
197 }
198
199 Ok(Self { items })
200 }
201}
202
203impl Dataset<SensorTextItem> for CsvSensorDataset {
204 fn get(&self, index: usize) -> Option<SensorTextItem> {
205 self.items.get(index).cloned()
206 }
207
208 fn len(&self) -> usize {
209 self.items.len()
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Check the padding invariants shared by both dataset types:
    /// * `token_ids` and `attention_mask` have length `max_seq_len`;
    /// * the mask is a run of `1`s followed only by `0`s;
    /// * every padded position holds token ID `1`.
    fn assert_padded(item: &SensorTextItem, max_seq_len: usize) {
        assert_eq!(item.token_ids.len(), max_seq_len);
        assert_eq!(item.attention_mask.len(), max_seq_len);
        let real = item.attention_mask.iter().take_while(|&&m| m == 1).count();
        assert!(item.attention_mask[real..].iter().all(|&m| m == 0));
        assert!(item.token_ids[real..].iter().all(|&t| t == 1));
    }

    /// Write `contents` to a per-process temporary CSV file and return its path.
    fn write_temp_csv(name: &str, contents: &str) -> std::path::PathBuf {
        let path = std::env::temp_dir().join(format!(
            "sensorlm_dataset_{}_{}.csv",
            name,
            std::process::id()
        ));
        std::fs::write(&path, contents).expect("write temp csv");
        path
    }

    /// Byte-level tokeniser used by the CSV tests.
    fn byte_tokens(s: &str) -> Vec<i32> {
        s.bytes().map(i32::from).collect()
    }

    #[test]
    fn test_synthetic_dataset() {
        let ds = SyntheticSensorDataset::new(16, 99, 256);
        assert_eq!(ds.len(), 16);
        let item = ds.get(0).expect("first item");
        assert_eq!(item.sensor.len(), TIME_STEPS * NUM_CHANNELS);
        assert_eq!(item.token_ids.len(), 256);
        assert_eq!(item.attention_mask.len(), 256);
        assert!(ds.get(16).is_none());
    }

    #[test]
    fn test_synthetic_padding_invariants() {
        let max_seq_len = 8;
        let ds = SyntheticSensorDataset::new(4, 7, max_seq_len);
        for i in 0..ds.len() {
            let item = ds.get(i).expect("in-range item");
            assert_padded(&item, max_seq_len);
            // One real token per caption character, capped at max_seq_len.
            let real = item.attention_mask.iter().filter(|&&m| m == 1).count();
            assert_eq!(real, item.caption_text.chars().count().min(max_seq_len));
        }
    }

    #[test]
    fn test_csv_dataset_round_trip() {
        let n = TIME_STEPS * NUM_CHANNELS;
        let header: Vec<String> = (0..n)
            .map(|i| format!("col_{}", i))
            .chain(std::iter::once("caption".to_string()))
            .collect();
        let row: Vec<String> = (0..n)
            .map(|_| "0.25".to_string())
            .chain(std::iter::once("walking".to_string()))
            .collect();
        let contents = format!("{}\n{}\n", header.join(","), row.join(","));
        let path = write_temp_csv("round_trip", &contents);

        let result = CsvSensorDataset::from_csv(&path, 10, byte_tokens);
        let _ = std::fs::remove_file(&path);
        let ds = result.expect("valid csv loads");

        assert_eq!(ds.len(), 1);
        let item = ds.get(0).expect("first item");
        assert_eq!(item.sensor.len(), n);
        assert!(item.sensor.iter().all(|&v| v == 0.25));
        assert_eq!(item.caption_text, "walking");
        assert_padded(&item, 10);
        assert_eq!(item.attention_mask.iter().sum::<i32>(), 7);
        assert_eq!(&item.token_ids[..7], byte_tokens("walking").as_slice());
    }

    #[test]
    fn test_csv_dataset_truncates_long_captions() {
        let n = TIME_STEPS * NUM_CHANNELS;
        let header: Vec<String> = (0..=n).map(|i| format!("c{}", i)).collect();
        let row: Vec<String> = (0..n)
            .map(|_| "1.0".to_string())
            .chain(std::iter::once("walking".to_string()))
            .collect();
        let contents = format!("{}\n{}\n", header.join(","), row.join(","));
        let path = write_temp_csv("truncate", &contents);

        let result = CsvSensorDataset::from_csv(&path, 4, byte_tokens);
        let _ = std::fs::remove_file(&path);
        let ds = result.expect("valid csv loads");

        let item = ds.get(0).expect("first item");
        assert_eq!(item.token_ids, byte_tokens("walk"));
        assert_eq!(item.attention_mask, vec![1, 1, 1, 1]);
    }

    #[test]
    fn test_csv_dataset_rejects_short_rows() {
        let path = write_temp_csv("short_row", "caption\nhello\n");
        let result = CsvSensorDataset::from_csv(&path, 4, byte_tokens);
        let _ = std::fs::remove_file(&path);
        assert!(result.is_err());
    }
}