sensorlm/data/captioning/
statistical.rs1use ndarray::{Array2, ArrayView2};
26use rand::{seq::SliceRandom, Rng};
27
28use crate::constants::CHANNEL_GROUPS;
29use crate::data::preprocessing::{channel_stats, denormalized};
30use crate::data::captioning::templates::LOW_LEVEL_TEMPLATES;
31use crate::error::Result;
32
33pub fn generate_statistical_caption<R: Rng>(
48 x_norm: &ArrayView2<f64>,
49 mask: Option<&Array2<u8>>,
50 rng: &mut R,
51) -> Result<String> {
52 use crate::data::preprocessing::apply_mask;
53
54 let mut x_phys = denormalized(x_norm)?;
56
57 if let Some(m) = mask {
59 apply_mask(&mut x_phys, m)?;
60 }
61
62 let stats = channel_stats(&x_phys); let mut parts = Vec::new();
67
68 for group in CHANNEL_GROUPS {
69 let mut group_parts = Vec::new();
70
71 for &(display_name, ch_idx) in group.primary {
73 let (mean, max, min, std) = stats[ch_idx];
74 if [mean, max, min, std].iter().any(|v| v.is_nan()) {
75 continue;
76 }
77 group_parts.push(describe_low_level(display_name, mean, max, min, std, rng));
78 }
79
80 if group.random_k > 0 && !group.random.is_empty() {
82 let sample: Vec<_> = group
83 .random
84 .choose_multiple(rng, group.random_k)
85 .collect();
86 for &&(display_name, ch_idx) in &sample {
87 let (mean, max, min, std) = stats[ch_idx];
88 if [mean, max, min, std].iter().any(|v| v.is_nan()) {
89 continue;
90 }
91 group_parts.push(describe_low_level(display_name, mean, max, min, std, rng));
92 }
93 }
94
95 if !group_parts.is_empty() {
96 parts.push(format!("For {}, {}\n", group.category, group_parts.join(" ")));
97 }
98 }
99
100 Ok(parts.concat())
101}
102
103fn describe_low_level<R: Rng>(
109 name: &str,
110 mean_val: f64,
111 max_val: f64,
112 min_val: f64,
113 std_val: f64,
114 rng: &mut R,
115) -> String {
116 let tmpl = LOW_LEVEL_TEMPLATES.choose(rng).copied().unwrap_or(LOW_LEVEL_TEMPLATES[0]);
117 tmpl.replace("{name}", name)
118 .replace("{mean_val:.1}", &format!("{mean_val:.1}"))
119 .replace("{max_val:.1}", &format!("{max_val:.1}"))
120 .replace("{min_val:.1}", &format!("{min_val:.1}"))
121 .replace("{std_val:.1}", &format!("{std_val:.1}"))
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::NUM_CHANNELS;
    use ndarray::Array2;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// All-zero normalized input of shape `(1440, NUM_CHANNELS)` shared by the tests below.
    fn zero_window() -> Array2<f64> {
        Array2::<f64>::zeros((1440, NUM_CHANNELS))
    }

    #[test]
    fn test_statistical_caption_runs() {
        let x = zero_window();
        let mut rng = StdRng::seed_from_u64(42);
        let cap = generate_statistical_caption(&x.view(), None, &mut rng).unwrap();
        assert!(!cap.is_empty(), "Caption must be non-empty");
        assert!(cap.contains("Heart"), "Caption must mention Heart group");
    }

    #[test]
    fn test_statistical_caption_group_line_format() {
        let x = zero_window();
        let mut rng = StdRng::seed_from_u64(42);
        let cap = generate_statistical_caption(&x.view(), None, &mut rng).unwrap();
        // Every group contributes `"For {category}, ...\n"`, so a non-empty caption starts
        // with "For " and ends with a newline.
        assert!(cap.starts_with("For "), "Caption must start with a group line");
        assert!(cap.ends_with('\n'), "Caption must end with a newline");
    }

    #[test]
    fn test_statistical_caption_is_reproducible_for_same_seed() {
        let x = zero_window();
        // All randomness flows through the caller's RNG, so equal seeds must give equal text.
        let caption_with_seed = |seed: u64| {
            let mut rng = StdRng::seed_from_u64(seed);
            generate_statistical_caption(&x.view(), None, &mut rng).unwrap()
        };
        assert_eq!(caption_with_seed(7), caption_with_seed(7));
    }
}