// sensorlm/data/captioning/statistical.rs

//! Level-1 (statistical) caption generation.
//!
//! Produces a text description of the **mean, maximum, minimum, and standard
//! deviation** of sensor channels, grouped by physiological category: every
//! primary channel of a category plus a random sample of its other channels.
//!
//! # Pipeline
//!
//! 1. Denormalise the input `(T × C)` array back to physical units.
//! 2. Optionally apply the missingness mask (set imputed values to NaN).
//! 3. Compute per-channel statistics once on the resulting data.
//! 4. For each physiological group in [`crate::constants::CHANNEL_GROUPS`]:
//!    a. Describe the group's primary channels.
//!    b. Sample `random_k` additional channels from the group's random pool
//!       and describe them as well.
//!    c. Format each channel description using a randomly selected template.
//!    Channels whose statistics are missing are skipped, and a group with no
//!    describable channel contributes nothing.
//! 5. Concatenate all group descriptions into a single string.
//!
//! # Example output
//!
//! Each group is emitted on a single line; the lines below are wrapped for
//! readability.
//!
//! ```text
//! For Heart, heart rate mean, max, min, std are 72.1, 95.3, 58.4, 8.2.
//! hrv rr exhibits a mean of 820.4, with range 960.0 to 680.0 and a
//! standard deviation of 45.1.
//! For Activity, steps mean, max, min, std are 3.2, 12.0, 0.0, 2.1. ...
//! ```

use ndarray::{Array2, ArrayView2};
use rand::{seq::SliceRandom, Rng};

use crate::constants::CHANNEL_GROUPS;
use crate::data::captioning::templates::LOW_LEVEL_TEMPLATES;
use crate::data::preprocessing::{apply_mask, channel_stats, denormalized};
use crate::error::Result;

33/// Generate a level-1 statistical caption.
34///
35/// # Arguments
36///
37/// * `x_norm`  – Normalised sensor array, shape `(T, C)`.
38/// * `mask`    – Optional missingness mask, shape `(T, C)`.
39///   `mask[t, c] == 1` means the value was imputed; set to `None` to include
40///   all values.
41/// * `rng`     – Random number generator used to pick templates and random
42///   channel subsets.
43///
44/// # Returns
45///
46/// A multi-line string suitable for use as the `low_level_caption` text pair.
47pub fn generate_statistical_caption<R: Rng>(
48    x_norm: &ArrayView2<f64>,
49    mask: Option<&Array2<u8>>,
50    rng: &mut R,
51) -> Result<String> {
52    use crate::data::preprocessing::apply_mask;
53
54    // Step 1 – Denormalise.
55    let mut x_phys = denormalized(x_norm)?;
56
57    // Step 2 – Optionally blank out imputed values.
58    if let Some(m) = mask {
59        apply_mask(&mut x_phys, m)?;
60    }
61
62    // Step 3 – Compute per-channel stats on the denormalised data.
63    let stats = channel_stats(&x_phys); // Vec<(mean, max, min, std)>
64
65    // Step 4 – Build caption.
66    let mut parts = Vec::new();
67
68    for group in CHANNEL_GROUPS {
69        let mut group_parts = Vec::new();
70
71        // Primary channels (always included).
72        for &(display_name, ch_idx) in group.primary {
73            let (mean, max, min, std) = stats[ch_idx];
74            if [mean, max, min, std].iter().any(|v| v.is_nan()) {
75                continue;
76            }
77            group_parts.push(describe_low_level(display_name, mean, max, min, std, rng));
78        }
79
80        // Random channels (sampled).
81        if group.random_k > 0 && !group.random.is_empty() {
82            let sample: Vec<_> = group
83                .random
84                .choose_multiple(rng, group.random_k)
85                .collect();
86            for &&(display_name, ch_idx) in &sample {
87                let (mean, max, min, std) = stats[ch_idx];
88                if [mean, max, min, std].iter().any(|v| v.is_nan()) {
89                    continue;
90                }
91                group_parts.push(describe_low_level(display_name, mean, max, min, std, rng));
92            }
93        }
94
95        if !group_parts.is_empty() {
96            parts.push(format!("For {}, {}\n", group.category, group_parts.join(" ")));
97        }
98    }
99
100    Ok(parts.concat())
101}
102
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

107/// Pick a random low-level template and fill in the placeholders.
108fn describe_low_level<R: Rng>(
109    name: &str,
110    mean_val: f64,
111    max_val: f64,
112    min_val: f64,
113    std_val: f64,
114    rng: &mut R,
115) -> String {
116    let tmpl = LOW_LEVEL_TEMPLATES.choose(rng).copied().unwrap_or(LOW_LEVEL_TEMPLATES[0]);
117    tmpl.replace("{name}", name)
118        .replace("{mean_val:.1}", &format!("{mean_val:.1}"))
119        .replace("{max_val:.1}", &format!("{max_val:.1}"))
120        .replace("{min_val:.1}", &format!("{min_val:.1}"))
121        .replace("{std_val:.1}", &format!("{std_val:.1}"))
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use crate::constants::NUM_CHANNELS;

    /// A full-day all-zero normalised input must yield a non-empty caption
    /// that covers the Heart group.
    #[test]
    fn test_statistical_caption_runs() {
        let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
        let mut rng = StdRng::seed_from_u64(42);
        let cap = generate_statistical_caption(&x.view(), None, &mut rng)
            .expect("caption generation should succeed on a well-shaped input");
        assert!(!cap.is_empty(), "Caption must be non-empty");
        assert!(cap.contains("Heart"), "Caption must mention Heart group");
    }

    /// All randomness (template choice and channel sampling) flows through the
    /// supplied RNG, so the same seed must reproduce the same caption.
    #[test]
    fn test_statistical_caption_is_reproducible_for_same_seed() {
        let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
        let first =
            generate_statistical_caption(&x.view(), None, &mut StdRng::seed_from_u64(7))
                .expect("caption generation should succeed");
        let second =
            generate_statistical_caption(&x.view(), None, &mut StdRng::seed_from_u64(7))
                .expect("caption generation should succeed");
        assert_eq!(first, second, "Same seed must give the same caption");
    }
}