Skip to main content

sensorlm/data/
dataset.rs

//! Burn [`Dataset`] implementations for SensorLM.
//!
//! Two dataset types are provided:
//!
//! * [`SyntheticSensorDataset`] – Uses the built-in synthetic data generator.
//!   Good for unit tests, profiling, and quick experiments without real data.
//!
//! * [`CsvSensorDataset`] – Loads pre-processed sensor data from a CSV file.
//!   The file is expected to have one row per sample: `TIME_STEPS * NUM_CHANNELS`
//!   normalised f32 columns (conventionally named after
//!   [`crate::constants::FEATURE_NAMES`], but read by position rather than by
//!   name) followed by a `"caption"` column containing the text.
//!
//! Both implement [`burn::data::dataset::Dataset`] so they can be wrapped by
//! [`burn::data::dataloader::DataLoaderBuilder`].
15
16use std::path::Path;
17
18use burn::data::dataset::Dataset;
19
20use crate::constants::{NUM_CHANNELS, TIME_STEPS};
21use crate::data::download::{generate_synthetic_dataset, SyntheticDataConfig};
22use crate::error::{Result, SensorLMError};
23
24// ---------------------------------------------------------------------------
25// Shared item type
26// ---------------------------------------------------------------------------
27
/// A single (sensor, caption) training pair.
///
/// Items compare by value (`PartialEq`), which makes them convenient to use
/// in assertions. `Eq` is intentionally not derived because `sensor` holds
/// `f32` values.
#[derive(Debug, Clone, PartialEq)]
pub struct SensorTextItem {
    /// Normalised sensor values, flat `f32` slice of length `T × C`
    /// (row-major: index `t * C + c`).
    pub sensor: Vec<f32>,
    /// Tokenised caption as a `Vec<i32>` (token IDs, padded to `max_seq_len`).
    pub token_ids: Vec<i32>,
    /// Padding mask: `1` for real tokens, `0` for padding.
    pub attention_mask: Vec<i32>,
    /// Raw caption text (kept for debugging / evaluation).
    pub caption_text: String,
}
41
42// ---------------------------------------------------------------------------
43// Synthetic dataset
44// ---------------------------------------------------------------------------
45
46/// An in-memory dataset of synthetically generated sensor-text pairs.
47///
48/// Useful for smoke-testing the training pipeline without a real dataset.
49///
50/// # Example
51///
52/// ```rust,no_run
53/// use sensorlm::data::dataset::SyntheticSensorDataset;
54/// use burn::data::dataset::Dataset;
55/// let ds = SyntheticSensorDataset::new(512, 42, 256);
56/// println!("Dataset size: {}", ds.len());
57/// ```
58pub struct SyntheticSensorDataset {
59    items: Vec<SensorTextItem>,
60}
61
62impl SyntheticSensorDataset {
63    /// Create a new synthetic dataset.
64    ///
65    /// # Arguments
66    ///
67    /// * `num_samples` – Number of (sensor, caption) pairs to generate.
68    /// * `seed`        – Random seed.
69    /// * `max_seq_len` – Token sequence length (captions are padded / truncated).
70    pub fn new(num_samples: usize, seed: u64, max_seq_len: usize) -> Self {
71        let cfg = SyntheticDataConfig {
72            num_samples,
73            seed,
74            add_circadian: true,
75            add_missingness: true,
76            missingness_rate: 0.05,
77        };
78        let raw = generate_synthetic_dataset(&cfg);
79
80        let items = raw
81            .into_iter()
82            .map(|s| {
83                let sensor: Vec<f32> = s.sensor.iter().copied().collect();
84
85                // Very simple character-level "tokenisation" for the synthetic
86                // dataset (a real run uses a SentencePiece tokeniser).
87                let raw_ids: Vec<i32> = s
88                    .caption
89                    .chars()
90                    .take(max_seq_len)
91                    .map(|c| c as i32 % 32_000)
92                    .collect();
93
94                let len = raw_ids.len();
95                let mut token_ids = raw_ids;
96                token_ids.resize(max_seq_len, 1); // pad with id=1
97
98                let mut attention_mask = vec![1i32; len];
99                attention_mask.resize(max_seq_len, 0);
100
101                SensorTextItem {
102                    sensor,
103                    token_ids,
104                    attention_mask,
105                    caption_text: s.caption,
106                }
107            })
108            .collect();
109
110        Self { items }
111    }
112}
113
114impl Dataset<SensorTextItem> for SyntheticSensorDataset {
115    fn get(&self, index: usize) -> Option<SensorTextItem> {
116        self.items.get(index).cloned()
117    }
118
119    fn len(&self) -> usize {
120        self.items.len()
121    }
122}
123
124// ---------------------------------------------------------------------------
125// CSV-backed dataset
126// ---------------------------------------------------------------------------
127
128/// A dataset loaded from a CSV file.
129///
130/// Expected CSV schema:
131///
132/// ```text
133/// col_0, col_1, ..., col_33, caption
134/// <f32>,  <f32>, ..., <f32>, "The person is walking..."
135/// ```
136///
137/// There must be exactly `T × C` numeric columns followed by a `caption`
138/// string column.  If your CSV has one row per time-step, pre-aggregate to
139/// one row per sample before loading.
140pub struct CsvSensorDataset {
141    items: Vec<SensorTextItem>,
142}
143
144impl CsvSensorDataset {
145    /// Load all rows from a CSV file.
146    ///
147    /// # Arguments
148    ///
149    /// * `path`        – Path to the `.csv` file.
150    /// * `max_seq_len` – Target token sequence length.
151    /// * `tokenize`    – A closure that converts a caption string into token IDs.
152    pub fn from_csv<F>(path: &Path, max_seq_len: usize, tokenize: F) -> Result<Self>
153    where
154        F: Fn(&str) -> Vec<i32>,
155    {
156        let expected_sensor_len = TIME_STEPS * NUM_CHANNELS;
157        let mut items = Vec::new();
158
159        let mut rdr = csv::Reader::from_path(path)
160            .map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
161
162        for result in rdr.records() {
163            let record = result.map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
164
165            if record.len() < expected_sensor_len + 1 {
166                return Err(SensorLMError::DatasetError(format!(
167                    "Expected at least {} columns, got {}",
168                    expected_sensor_len + 1,
169                    record.len()
170                )));
171            }
172
173            let sensor: Vec<f32> = (0..expected_sensor_len)
174                .map(|i| {
175                    record[i]
176                        .trim()
177                        .parse::<f32>()
178                        .unwrap_or(0.0)
179                })
180                .collect();
181
182            let caption = record[expected_sensor_len].trim().to_string();
183
184            let mut token_ids = tokenize(&caption);
185            let real_len = token_ids.len().min(max_seq_len);
186            token_ids.truncate(real_len);
187            let mut attn = vec![1i32; real_len];
188            token_ids.resize(max_seq_len, 1);
189            attn.resize(max_seq_len, 0);
190
191            items.push(SensorTextItem {
192                sensor,
193                token_ids,
194                attention_mask: attn,
195                caption_text: caption,
196            });
197        }
198
199        Ok(Self { items })
200    }
201}
202
203impl Dataset<SensorTextItem> for CsvSensorDataset {
204    fn get(&self, index: usize) -> Option<SensorTextItem> {
205        self.items.get(index).cloned()
206    }
207
208    fn len(&self) -> usize {
209        self.items.len()
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the synthetic dataset yields the requested number of items
    /// and each item has the expected sensor and token lengths.
    #[test]
    fn test_synthetic_dataset() {
        const SAMPLES: usize = 16;
        const SEQ_LEN: usize = 256;

        let dataset = SyntheticSensorDataset::new(SAMPLES, 99, SEQ_LEN);
        assert_eq!(dataset.len(), SAMPLES);

        let first = dataset.get(0).expect("first item");
        assert_eq!(first.sensor.len(), TIME_STEPS * NUM_CHANNELS);
        assert_eq!(first.token_ids.len(), SEQ_LEN);
        assert_eq!(first.attention_mask.len(), SEQ_LEN);
    }
}