Skip to main content

sensorlm/data/
preprocessing.rs

1//! Sensor signal normalisation and denormalisation.
2//!
3//! The raw sensor values captured by wearable devices span very different
4//! physical scales (e.g. heart rate ≈ 60–180 bpm vs. LF power ≈ 0–10 000).
5//! Before feeding data into the ViT encoder every channel is z-score
6//! normalised using population-level statistics.
7//!
8//! # Normalisation formula
9//!
10//! ```text
11//! z[t, c] = (x[t, c] - mean[c]) / std[c]
12//! ```
13//!
14//! # Denormalisation (used in caption generation)
15//!
16//! ```text
17//! x[t, c] = z[t, c] * std[c] + mean[c]
18//! ```
19//!
20//! After denormalisation certain channels are clamped to be non-negative
21//! (e.g. `steps`, `sleep_coefficient`).
22
23use ndarray::{Array2, ArrayView2};
24
25use crate::constants::{FEATURE_NAMES, NORM_PARAMS, NON_NEGATIVE_CHANNELS, NUM_CHANNELS};
26use crate::error::{Result, SensorLMError};
27
28// ---------------------------------------------------------------------------
29// Channel statistics
30// ---------------------------------------------------------------------------
31
32/// Per-channel normalisation statistics resolved from [`NORM_PARAMS`].
33#[derive(Debug, Clone)]
34pub struct ChannelStats {
35    /// Population mean for each of the 34 channels.
36    pub mean: Vec<f64>,
37    /// Population standard deviation for each channel.
38    pub std: Vec<f64>,
39}
40
41impl ChannelStats {
42    /// Build [`ChannelStats`] from the compile-time [`NORM_PARAMS`] table.
43    pub fn from_constants() -> Self {
44        let mean: Vec<f64> = NORM_PARAMS.iter().map(|(m, _)| *m).collect();
45        let std: Vec<f64> = NORM_PARAMS.iter().map(|(_, s)| *s).collect();
46        Self { mean, std }
47    }
48}
49
50// ---------------------------------------------------------------------------
51// Normalise
52// ---------------------------------------------------------------------------
53
54/// Z-score normalise a `(T, C)` raw sensor array in-place.
55///
56/// `data` must have shape `[T, NUM_CHANNELS]`.  The function normalises along
57/// the channel axis using population statistics from [`NORM_PARAMS`].
58///
59/// # Errors
60///
61/// Returns [`SensorLMError::ShapeMismatch`] if the number of columns ≠
62/// `NUM_CHANNELS`.
63pub fn normalize(data: &mut Array2<f64>) -> Result<()> {
64    let (t, c) = (data.nrows(), data.ncols());
65    if c != NUM_CHANNELS {
66        return Err(SensorLMError::ShapeMismatch {
67            expected: vec![t, NUM_CHANNELS],
68            actual: vec![t, c],
69        });
70    }
71    for ch in 0..NUM_CHANNELS {
72        let (mean, std) = NORM_PARAMS[ch];
73        if std == 0.0 {
74            continue;
75        }
76        let mut col = data.column_mut(ch);
77        col.mapv_inplace(|x| (x - mean) / std);
78    }
79    Ok(())
80}
81
82/// Z-score normalise a raw sensor array and return a new array.
83///
84/// See [`normalize`] for details.
85pub fn normalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
86    let mut out = data.to_owned();
87    normalize(&mut out)?;
88    Ok(out)
89}
90
91// ---------------------------------------------------------------------------
92// Denormalise
93// ---------------------------------------------------------------------------
94
95/// Reverse a previous call to [`normalize`].
96///
97/// Converts normalised values back to physical units and clamps channels
98/// listed in [`NON_NEGATIVE_CHANNELS`] to `≥ 0`.
99pub fn denormalize(data: &mut Array2<f64>) -> Result<()> {
100    let (t, c) = (data.nrows(), data.ncols());
101    if c != NUM_CHANNELS {
102        return Err(SensorLMError::ShapeMismatch {
103            expected: vec![t, NUM_CHANNELS],
104            actual: vec![t, c],
105        });
106    }
107    for ch in 0..NUM_CHANNELS {
108        let (mean, std) = NORM_PARAMS[ch];
109        let mut col = data.column_mut(ch);
110        col.mapv_inplace(|z| z * std + mean);
111    }
112    // Clamp non-negative channels.
113    for &ch in NON_NEGATIVE_CHANNELS {
114        let mut col = data.column_mut(ch);
115        col.mapv_inplace(|x| x.max(0.0));
116    }
117    Ok(())
118}
119
120/// Denormalise without mutating: returns a new owned array.
121pub fn denormalized(data: &ArrayView2<f64>) -> Result<Array2<f64>> {
122    let mut out = data.to_owned();
123    denormalize(&mut out)?;
124    Ok(out)
125}
126
127// ---------------------------------------------------------------------------
128// Missingness handling
129// ---------------------------------------------------------------------------
130
131/// Apply a missingness mask, replacing imputed values with `NaN`.
132///
133/// `mask[t, c] == 1` signals that the value at `(t, c)` was imputed (not
134/// observed).  This function sets those positions to `NaN` so they are
135/// excluded from mean / statistics computations in the captioning pipeline.
136///
137/// Both arrays must have the same shape.
138pub fn apply_mask(data: &mut Array2<f64>, mask: &Array2<u8>) -> Result<()> {
139    if data.shape() != mask.shape() {
140        return Err(SensorLMError::ShapeMismatch {
141            expected: data.shape().to_vec(),
142            actual: mask.shape().to_vec(),
143        });
144    }
145    for (d, m) in data.iter_mut().zip(mask.iter()) {
146        if *m == 1 {
147            *d = f64::NAN;
148        }
149    }
150    Ok(())
151}
152
153// ---------------------------------------------------------------------------
154// Downsample
155// ---------------------------------------------------------------------------
156
157/// Average-pool a `(C, T)` array down to `(C, target_t)` time-steps.
158///
159/// Used by the structural caption generator to reduce the 1440 time-steps
160/// to a manageable 36 points (factor 40).
161///
162/// # Panics
163///
164/// Panics if `T` is not divisible by `target_t`.
165pub fn average_downsample_ct(data: &Array2<f64>, target_t: usize) -> Array2<f64> {
166    let (channels, t) = (data.nrows(), data.ncols());
167    assert_eq!(t % target_t, 0, "T must be divisible by target_t");
168    let factor = t / target_t;
169    let mut out = Array2::<f64>::zeros((channels, target_t));
170    for c in 0..channels {
171        for i in 0..target_t {
172            let slice = data.slice(ndarray::s![c, i * factor..(i + 1) * factor]);
173            out[[c, i]] = slice.mean().unwrap_or(0.0);
174        }
175    }
176    out
177}
178
179// ---------------------------------------------------------------------------
180// Compute per-channel stats (used by captioning)
181// ---------------------------------------------------------------------------
182
183/// Compute `(mean, max, min, std)` for every channel in a `(T, C)` array.
184///
185/// NaN values are ignored in all statistics.  Returns a vector of length `C`
186/// where each entry is `(mean, max, min, std)`.  If a channel is entirely NaN
187/// the tuple fields will be NaN.
188pub fn channel_stats(data: &Array2<f64>) -> Vec<(f64, f64, f64, f64)> {
189    let c = data.ncols();
190    (0..c)
191        .map(|ch| {
192            let col: Vec<f64> = data
193                .column(ch)
194                .iter()
195                .copied()
196                .filter(|v| !v.is_nan())
197                .collect();
198            if col.is_empty() {
199                return (f64::NAN, f64::NAN, f64::NAN, f64::NAN);
200            }
201            let n = col.len() as f64;
202            let mean = col.iter().sum::<f64>() / n;
203            let max = col.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
204            let min = col.iter().cloned().fold(f64::INFINITY, f64::min);
205            let var = col.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
206            let std = var.sqrt();
207            (mean, max, min, std)
208        })
209        .collect()
210}
211
212// ---------------------------------------------------------------------------
213// Flat f32 ↔ ndarray helpers (for burn tensor creation)
214// ---------------------------------------------------------------------------
215
216/// Convert a flat `Vec<f32>` (row-major, shape `[T, C]`) into an ndarray
217/// after normalising.
218///
219/// Returns a `(T, C)` [`Array2<f64>`] with z-score normalised values.
220pub fn f32_slice_to_normalised(raw: &[f32], t: usize, c: usize) -> Result<Array2<f64>> {
221    if raw.len() != t * c {
222        return Err(SensorLMError::ShapeMismatch {
223            expected: vec![t * c],
224            actual: vec![raw.len()],
225        });
226    }
227    let data_f64: Vec<f64> = raw.iter().map(|&x| x as f64).collect();
228    let mut arr = Array2::from_shape_vec((t, c), data_f64)
229        .map_err(|e| SensorLMError::DatasetError(e.to_string()))?;
230    normalize(&mut arr)?;
231    Ok(arr)
232}
233
234/// Return the human-readable name for a channel index.
235///
236/// Returns `"unknown"` for out-of-range indices.
237pub fn channel_name(idx: usize) -> &'static str {
238    FEATURE_NAMES.get(idx).copied().unwrap_or("unknown")
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    #[test]
    fn test_roundtrip_normalise() {
        let original = Array2::<f64>::from_elem((10, NUM_CHANNELS), 1.0);
        let mut data = original.clone();
        normalize(&mut data).unwrap();
        denormalize(&mut data).unwrap();
        // 1.0 is non-negative, so clamping never alters it. Every channel with a
        // non-zero std must therefore round-trip back to 1.0. The tolerance
        // scales with |mean| because `(x - mean) / std * std + mean` loses
        // precision relative to the magnitude of `mean`. Zero-std channels are
        // skipped: `normalize` leaves them raw by design.
        for ch in 0..NUM_CHANNELS {
            let (mean, std) = NORM_PARAMS[ch];
            if std == 0.0 {
                continue;
            }
            let tol = 1e-9 * (1.0 + mean.abs());
            for t in 0..original.nrows() {
                let got = data[[t, ch]];
                assert!(
                    (got - 1.0).abs() <= tol,
                    "channel {ch}: round-trip gave {got}, expected 1.0"
                );
            }
        }
    }

    #[test]
    fn test_non_negative_clamp() {
        let mut data = Array2::<f64>::from_elem((5, NUM_CHANNELS), -100.0);
        denormalize(&mut data).unwrap();
        for &ch in NON_NEGATIVE_CHANNELS {
            for t in 0..5 {
                assert!(data[[t, ch]] >= 0.0, "channel {ch} should be >= 0");
            }
        }
    }

    #[test]
    fn test_downsample() {
        let data = Array2::<f64>::ones((NUM_CHANNELS, 1440));
        let ds = average_downsample_ct(&data, 36);
        assert_eq!(ds.shape(), &[NUM_CHANNELS, 36]);
        assert!(ds.iter().all(|&v| (v - 1.0).abs() < 1e-12));
    }

    #[test]
    fn test_downsample_window_means() {
        let data = Array2::<f64>::from_shape_vec(
            (2, 4),
            vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
        )
        .unwrap();
        let ds = average_downsample_ct(&data, 2);
        let expected = Array2::<f64>::from_shape_vec((2, 2), vec![1.5, 3.5, 15.0, 35.0]).unwrap();
        assert_eq!(ds, expected);
    }

    #[test]
    fn test_apply_mask() {
        let mut data = Array2::<f64>::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let mask = Array2::<u8>::from_shape_vec((2, 2), vec![0, 1, 1, 0]).unwrap();
        apply_mask(&mut data, &mask).unwrap();
        assert_eq!(data[[0, 0]], 1.0);
        assert!(data[[0, 1]].is_nan());
        assert!(data[[1, 0]].is_nan());
        assert_eq!(data[[1, 1]], 4.0);

        let wrong_shape = Array2::<u8>::zeros((3, 2));
        assert!(apply_mask(&mut data, &wrong_shape).is_err());
    }

    #[test]
    fn test_channel_stats_ignores_nan() {
        // Column 0 = [1, NaN, 3]; column 1 is entirely NaN.
        let data = Array2::<f64>::from_shape_vec(
            (3, 2),
            vec![1.0, f64::NAN, f64::NAN, f64::NAN, 3.0, f64::NAN],
        )
        .unwrap();
        let stats = channel_stats(&data);
        assert_eq!(stats.len(), 2);
        assert_eq!(stats[0], (2.0, 3.0, 1.0, 1.0));
        let (mean, max, min, std) = stats[1];
        assert!(mean.is_nan() && max.is_nan() && min.is_nan() && std.is_nan());
    }

    #[test]
    fn test_channel_name_out_of_range() {
        assert_eq!(channel_name(usize::MAX), "unknown");
    }

    #[test]
    fn test_f32_slice_length_mismatch() {
        assert!(f32_slice_to_normalised(&[0.0_f32; 3], 2, 2).is_err());
    }
}