sensorlm/data/preprocessing.rs
1//! Sensor signal normalisation and denormalisation.
2//!
3//! The raw sensor values captured by wearable devices span very different
4//! physical scales (e.g. heart rate ≈ 60–180 bpm vs. LF power ≈ 0–10 000).
5//! Before feeding data into the ViT encoder every channel is z-score
6//! normalised using population-level statistics.
7//!
8//! # Normalisation formula
9//!
10//! ```text
11//! z[t, c] = (x[t, c] - mean[c]) / std[c]
12//! ```
13//!
14//! # Denormalisation (used in caption generation)
15//!
16//! ```text
17//! x[t, c] = z[t, c] * std[c] + mean[c]
18//! ```
19//!
20//! After denormalisation certain channels are clamped to be non-negative
21//! (e.g. `steps`, `sleep_coefficient`).
22
23use ndarray::{Array2, ArrayView2};
24
25use crate::constants::{FEATURE_NAMES, NORM_PARAMS, NON_NEGATIVE_CHANNELS, NUM_CHANNELS};
26use crate::error::{Result, SensorLMError};
27
28// ---------------------------------------------------------------------------
29// Channel statistics
30// ---------------------------------------------------------------------------
31
/// Per-channel mean / standard-deviation pairs used for z-score scaling.
///
/// Usually obtained through [`ChannelStats::from_constants`], which copies the
/// values out of [`NORM_PARAMS`]. Both vectors are indexed by channel.
#[derive(Clone, Debug)]
pub struct ChannelStats {
    /// Population mean of every channel, in channel-index order.
    pub mean: Vec<f64>,
    /// Population standard deviation of every channel, in channel-index order.
    pub std: Vec<f64>,
}
40
41impl ChannelStats {
42 /// Build [`ChannelStats`] from the compile-time [`NORM_PARAMS`] table.
43 pub fn from_constants() -> Self {
44 let mean: Vec<f64> = NORM_PARAMS.iter().map(|(m, _)| *m).collect();
45 let std: Vec<f64> = NORM_PARAMS.iter().map(|(_, s)| *s).collect();
46 Self { mean, std }
47 }
48}
49
50// ---------------------------------------------------------------------------
51// Normalise
52// ---------------------------------------------------------------------------
53
54/// Z-score normalise a `(T, C)` raw sensor array in-place.
55///
56/// `data` must have shape `[T, NUM_CHANNELS]`. The function normalises along
57/// the channel axis using population statistics from [`NORM_PARAMS`].
58///
59/// # Errors
60///
61/// Returns [`SensorLMError::ShapeMismatch`] if the number of columns ≠
62/// `NUM_CHANNELS`.
63pub fn normalize(data: &mut Array2<f64>) -> Result<()> {
64 let (t, c) = (data.nrows(), data.ncols());
65 if c != NUM_CHANNELS {
66 return Err(SensorLMError::ShapeMismatch {
67 expected: vec![t, NUM_CHANNELS],
68 actual: vec![t, c],
69 });
70 }
71 for ch in 0..NUM_CHANNELS {
72 let (mean, std) = NORM_PARAMS[ch];
73 if std == 0.0 {
74 continue;
75 }
76 let mut col = data.column_mut(ch);
77 col.mapv_inplace(|x| (x - mean) / std);
78 }
79 Ok(())
80}
81
82/// Z-score normalise a raw sensor array and return a new array.
83///
84/// See [`normalize`] for details.
85pub fn normalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
86 let mut out = data.to_owned();
87 normalize(&mut out)?;
88 Ok(out)
89}
90
91// ---------------------------------------------------------------------------
92// Denormalise
93// ---------------------------------------------------------------------------
94
95/// Reverse a previous call to [`normalize`].
96///
97/// Converts normalised values back to physical units and clamps channels
98/// listed in [`NON_NEGATIVE_CHANNELS`] to `≥ 0`.
99pub fn denormalize(data: &mut Array2<f64>) -> Result<()> {
100 let (t, c) = (data.nrows(), data.ncols());
101 if c != NUM_CHANNELS {
102 return Err(SensorLMError::ShapeMismatch {
103 expected: vec![t, NUM_CHANNELS],
104 actual: vec![t, c],
105 });
106 }
107 for ch in 0..NUM_CHANNELS {
108 let (mean, std) = NORM_PARAMS[ch];
109 let mut col = data.column_mut(ch);
110 col.mapv_inplace(|z| z * std + mean);
111 }
112 // Clamp non-negative channels.
113 for &ch in NON_NEGATIVE_CHANNELS {
114 let mut col = data.column_mut(ch);
115 col.mapv_inplace(|x| x.max(0.0));
116 }
117 Ok(())
118}
119
120/// Denormalise without mutating: returns a new owned array.
121pub fn denormalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
122 let mut out = data.to_owned();
123 denormalize(&mut out)?;
124 Ok(out)
125}
126
127// ---------------------------------------------------------------------------
128// Missingness handling
129// ---------------------------------------------------------------------------
130
131/// Apply a missingness mask, replacing imputed values with `NaN`.
132///
133/// `mask[t, c] == 1` signals that the value at `(t, c)` was imputed (not
134/// observed). This function sets those positions to `NaN` so they are
135/// excluded from mean / statistics computations in the captioning pipeline.
136///
137/// Both arrays must have the same shape.
138pub fn apply_mask(data: &mut Array2<f64>, mask: &Array2<u8>) -> Result<()> {
139 if data.shape() != mask.shape() {
140 return Err(SensorLMError::ShapeMismatch {
141 expected: data.shape().to_vec(),
142 actual: mask.shape().to_vec(),
143 });
144 }
145 for (d, m) in data.iter_mut().zip(mask.iter()) {
146 if *m == 1 {
147 *d = f64::NAN;
148 }
149 }
150 Ok(())
151}
152
153// ---------------------------------------------------------------------------
154// Downsample
155// ---------------------------------------------------------------------------
156
157/// Average-pool a `(C, T)` array down to `(C, target_t)` time-steps.
158///
159/// Used by the structural caption generator to reduce the 1440 time-steps
160/// to a manageable 36 points (factor 40).
161///
162/// # Panics
163///
164/// Panics if `T` is not divisible by `target_t`.
165pub fn average_downsample_ct(data: &Array2<f64>, target_t: usize) -> Array2<f64> {
166 let (channels, t) = (data.nrows(), data.ncols());
167 assert_eq!(t % target_t, 0, "T must be divisible by target_t");
168 let factor = t / target_t;
169 let mut out = Array2::<f64>::zeros((channels, target_t));
170 for c in 0..channels {
171 for i in 0..target_t {
172 let slice = data.slice(ndarray::s![c, i * factor..(i + 1) * factor]);
173 out[[c, i]] = slice.mean().unwrap_or(0.0);
174 }
175 }
176 out
177}
178
179// ---------------------------------------------------------------------------
180// Compute per-channel stats (used by captioning)
181// ---------------------------------------------------------------------------
182
183/// Compute `(mean, max, min, std)` for every channel in a `(T, C)` array.
184///
185/// NaN values are ignored in all statistics. Returns a vector of length `C`
186/// where each entry is `(mean, max, min, std)`. If a channel is entirely NaN
187/// the tuple fields will be NaN.
188pub fn channel_stats(data: &Array2<f64>) -> Vec<(f64, f64, f64, f64)> {
189 let c = data.ncols();
190 (0..c)
191 .map(|ch| {
192 let col: Vec<f64> = data
193 .column(ch)
194 .iter()
195 .copied()
196 .filter(|v| !v.is_nan())
197 .collect();
198 if col.is_empty() {
199 return (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
200 }
201 let n = col.len() as f64;
202 let mean = col.iter().sum::<f64>() / n;
203 let max = col.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
204 let min = col.iter().cloned().fold(f64::INFINITY, f64::min);
205 let var = col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
206 let std = var.sqrt();
207 (mean, max, min, std)
208 })
209 .collect()
210}
211
212// ---------------------------------------------------------------------------
213// Flat f32 ↔ ndarray helpers (for burn tensor creation)
214// ---------------------------------------------------------------------------
215
216/// Convert a flat `Vec<f32>` (row-major, shape `[T, C]`) into an ndarray
217/// after normalising.
218///
219/// Returns a `(T, C)` [`Array2<f64>`] with z-score normalised values.
220pub fn f32_slice_to_normalised(raw: &[f32], t: usize, c: usize) -> Result<Array2<f64>> {
221 if raw.len() != t * c {
222 return Err(SensorLMError::ShapeMismatch {
223 expected: vec![t * c],
224 actual: vec![raw.len()],
225 });
226 }
227 let data_f64: Vec<f64> = raw.iter().map(|&x| x as f64).collect();
228 let mut arr = Array2::from_shape_vec((t, c), data_f64)
229 .map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
230 normalize(&mut arr)?;
231 Ok(arr)
232}
233
234/// Return the human-readable name for a channel index.
235///
236/// Returns `"unknown"` for out-of-range indices.
237pub fn channel_name(idx: usize) -> &'static str {
238 FEATURE_NAMES.get(idx).copied().unwrap_or("unknown")
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Normalising and then denormalising must give back the input for every
    /// channel that `normalize` actually rescales.
    #[test]
    fn test_roundtrip_normalise() {
        const ROWS: usize = 10;
        let original = Array2::<f64>::from_elem((ROWS, NUM_CHANNELS), 1.0);
        let mut data = original.clone();
        normalize(&mut data).unwrap();
        denormalize(&mut data).unwrap();
        for ch in 0..NUM_CHANNELS {
            let (mean, std) = NORM_PARAMS[ch];
            // `normalize` skips zero-std channels, so they are not part of
            // the round-trip contract checked here.
            if std == 0.0 {
                continue;
            }
            // The input (1.0) is positive, so the non-negative clamp cannot
            // alter it; only floating-point rounding separates the two
            // arrays. Scale the tolerance with the size of the mean.
            let tol = 1e-9 * (1.0 + mean.abs());
            for t in 0..ROWS {
                let diff = (original[[t, ch]] - data[[t, ch]]).abs();
                assert!(
                    diff <= tol,
                    "channel {ch} did not round-trip (diff = {diff})"
                );
            }
        }
    }

    /// Strongly negative z-scores must never leave a clamped channel below 0.
    #[test]
    fn test_non_negative_clamp() {
        let mut data = Array2::<f64>::from_elem((5, NUM_CHANNELS), -100.0);
        denormalize(&mut data).unwrap();
        for &ch in NON_NEGATIVE_CHANNELS {
            for t in 0..5 {
                assert!(data[[t, ch]] >= 0.0, "channel {ch} should be >= 0");
            }
        }
    }

    /// Pooling a constant signal keeps the constant and yields the
    /// requested shape.
    #[test]
    fn test_downsample() {
        let data = Array2::<f64>::ones((NUM_CHANNELS, 1440));
        let ds = average_downsample_ct(&data, 36);
        assert_eq!(ds.shape(), &[NUM_CHANNELS, 36]);
        assert!(ds.iter().all(|&v| (v - 1.0).abs() < 1e-12));
    }
}