// sensorlm/data/captioning/structural.rs

//! Level-2 (structural) caption generation.
//!
//! Detects and describes **temporal patterns** within each sensor channel:
//!
//! * **Trends** – increasing / decreasing / stable segments identified by
//!   fitting a linear regression over overlapping windows.
//! * **Anomalies** – significant peaks (spikes) and valleys (drops) found by a
//!   prominence-threshold peak detector.
//!
//! # Algorithm overview
//!
//! ```text
//! 1. Denormalise the (T, C) array.
//! 2. Downsample the time axis: 1440 → 36 points (factor 40, average pooling).
//! 3. For every primary channel in every group:
//!    a. Fit a linear regression over windows of 6, 8 and 12 data points
//!       (50% overlap). Classify each window as increasing / decreasing /
//!       stable using slope and range thresholds.
//!    b. Run a prominence-based peak / valley detector on the downsampled signal.
//! 4. Format each insight with a randomly chosen template, then randomly keep
//!    at most `max_per_category` insights per category.
//! 5. Concatenate all group captions.
//! ```

use std::fmt::Write as _;

use ndarray::{Array2, ArrayView2};
use rand::{seq::SliceRandom, Rng};

use crate::constants::CHANNEL_GROUPS;
use crate::data::captioning::templates::{ANOMALY_TEMPLATES, TREND_TEMPLATES};
use crate::data::preprocessing::{average_downsample_ct, denormalized};
use crate::error::Result;

32/// Generate a level-2 structural caption.
33///
34/// # Arguments
35///
36/// * `x_norm`            – Normalised `(T, C)` sensor array.
37/// * `max_per_category`  – Maximum number of insight sentences per category
38///   (default used in reference: 7).
39/// * `rng`               – Random number generator.
40pub fn generate_structural_caption<R: Rng>(
41    x_norm: &ArrayView2<f64>,
42    max_per_category: usize,
43    rng: &mut R,
44) -> Result<String> {
45    // Denormalise to physical units.
46    let x_phys = denormalized(x_norm)?; // (T=1440, C=34)
47
48    // Transpose to (C, T) for downsampling, then back.
49    let ct: Array2<f64> = x_phys.t().to_owned(); // (C, 1440)
50
51    // Downsample: 1440 → 36 points (factor 40, 40-minute averages).
52    const TARGET_T: usize = 36;
53    const DOWNSAMPLE_SCALE: usize = 40;
54    let ct_ds = average_downsample_ct(&ct, TARGET_T); // (C, 36)
55
56    let mut caption = String::new();
57
58    for group in CHANNEL_GROUPS {
59        let mut insights: Vec<(usize, String)> = Vec::new(); // (original_order, sentence)
60
61        for &(display_name, ch_idx) in group.primary {
62            if ch_idx >= ct_ds.nrows() {
63                continue;
64            }
65            let channel_data: Vec<f64> = ct_ds.row(ch_idx).iter().copied().collect();
66
67            // Trend detection.
68            let trends = identify_trends(&channel_data, DOWNSAMPLE_SCALE);
69            for (start, end, trend_type, _slope, _delta, _seg) in &trends {
70                insights.push((
71                    *start,
72                    describe_trend(display_name, trend_type, *start, *end, rng),
73                ));
74            }
75
76            // Anomaly detection.
77            let peaks_valleys = detect_peaks_valleys(&channel_data, DOWNSAMPLE_SCALE);
78            for (minute, anomaly_type) in &peaks_valleys {
79                insights.push((
80                    *minute,
81                    describe_anomaly(display_name, anomaly_type, *minute, rng),
82                ));
83            }
84        }
85
86        // Randomly subsample if over budget.
87        if insights.len() > max_per_category {
88            insights.shuffle(rng);
89            insights.truncate(max_per_category);
90            // Re-sort by time for readability.
91            insights.sort_by_key(|(t, _)| *t);
92        }
93
94        let category_text: Vec<&str> = insights.iter().map(|(_, s)| s.as_str()).collect();
95        caption.push_str(&format!("{}: {}\n", group.category, category_text.join(" ")));
96    }
97
98    Ok(caption)
99}

// ---------------------------------------------------------------------------
// Trend detection via linear regression
// ---------------------------------------------------------------------------

/// Result of trend detection for a single window.
///
/// Fields: `(start_minute, end_minute, trend_type, slope, magnitude, segment_size)`:
///
/// * `trend_type` is `"increasing"`, `"decreasing"` or `"stable"`;
/// * `slope` is the least-squares slope per downsampled point;
/// * `magnitude` is the absolute change between the first and last point of
///   the window (always `>= 0`), used as the ranking key in `identify_trends`;
/// * `segment_size` is the window length in downsampled points.
type TrendResult = (usize, usize, String, f64, f64, usize);

/// Identify up to 3 non-overlapping trends in `data` using multi-scale
/// sliding-window linear regression.
///
/// Window sizes mirror the reference: 6, 8 and 12 downsampled points, each
/// scanned with 50 % overlap. Each downsampled point represents
/// `downsample_scale` minutes. A window covering points `s..s + seg` is
/// reported as `((s + 1) * downsample_scale, (s + seg) * downsample_scale)`.
///
/// A window is classified as:
/// * `"increasing"` / `"decreasing"` when its slope exceeds the size-specific
///   threshold and its net change exceeds 20 % of the signal's range;
/// * `"stable"` when its slope is below 1 % of the range and its own range is
///   below 10 % of the signal's range.
///
/// Candidates are ranked by magnitude of change, largest first; ties keep scan
/// order. They are accepted greedily unless they overlap an already accepted
/// trend by more than 30 % of the shorter window. The result is empty when
/// `data` is empty or shorter than the smallest window.
fn identify_trends(data: &[f64], downsample_scale: usize) -> Vec<TrendResult> {
    // (window size, slope-threshold multiplier) pairs.
    const WINDOWS: &[(usize, f64)] = &[(6, 1.5), (8, 1.3), (12, 1.0)];
    const MAX_TRENDS: usize = 3;
    const MAX_OVERLAP: f64 = 0.3;

    if data.is_empty() {
        return vec![];
    }

    // (min, max) of a slice; `f64::min` / `f64::max` skip NaN operands.
    let bounds = |xs: &[f64]| {
        xs.iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| (lo.min(v), hi.max(v)))
    };

    let (min_v, max_v) = bounds(data);
    let range = (max_v - min_v).max(1e-9);
    let stable_threshold = 0.01 * range;
    let min_change = 0.2 * range;

    let mut candidates: Vec<TrendResult> = Vec::new();
    for &(seg, scale) in WINDOWS {
        // NOTE(review): the 40.0 divisor is kept as-is; it is not clear from
        // here whether it is tied to the 40-minute downsample factor.
        let slope_thresh = scale * range / 40.0;
        let step = seg / 2; // 50 % overlap
        let mut start = 0;
        while start + seg <= data.len() {
            let window = &data[start..start + seg];
            let slope = linear_regression_slope(window);
            let delta = window[seg - 1] - window[0];
            let start_min = (start + 1) * downsample_scale;
            let end_min = (start + seg) * downsample_scale;

            if slope > slope_thresh && delta > min_change {
                candidates.push((start_min, end_min, "increasing".into(), slope, delta, seg));
            } else if slope < -slope_thresh && -delta > min_change {
                candidates.push((start_min, end_min, "decreasing".into(), slope, -delta, seg));
            } else if slope.abs() < stable_threshold {
                let (lo, hi) = bounds(window);
                if hi - lo < 0.1 * range {
                    // Rank stable windows by their (small) net change, like the
                    // other trend types. Storing the raw signal level here made
                    // flat stretches of large-valued channels outrank and crowd
                    // out genuine trends.
                    candidates.push((start_min, end_min, "stable".into(), slope, delta.abs(), seg));
                }
            }
            start += step;
        }
    }

    // Largest change first. `sort_by` is stable, so ties keep scan order.
    candidates.sort_by(|a, b| b.4.total_cmp(&a.4));

    let mut selected: Vec<TrendResult> = Vec::with_capacity(MAX_TRENDS);
    for cand in candidates {
        if selected.len() == MAX_TRENDS {
            break;
        }
        let clashes = selected
            .iter()
            .any(|kept| overlap_fraction(cand.0, cand.1, kept.0, kept.1) > MAX_OVERLAP);
        if !clashes {
            selected.push(cand);
        }
    }
    selected
}

/// Fraction of the shorter of the ranges `s1..e1` and `s2..e2` that is covered
/// by their intersection.
///
/// Returns `0.0` when either range is empty (or inverted).
fn overlap_fraction(s1: usize, e1: usize, s2: usize, e2: usize) -> f64 {
    let overlap = e1.min(e2).saturating_sub(s1.max(s2)) as f64;
    let shorter = e1.saturating_sub(s1).min(e2.saturating_sub(s2)) as f64;
    if shorter == 0.0 { 0.0 } else { overlap / shorter }
}

/// Ordinary-least-squares slope of `y` against the indices `0, 1, …, len - 1`.
///
/// Returns `0.0` for fewer than two points, where the slope is undefined.
fn linear_regression_slope(y: &[f64]) -> f64 {
    if y.len() < 2 {
        return 0.0;
    }
    let n = y.len() as f64;
    let x_mean = (n - 1.0) / 2.0;
    let y_mean = y.iter().sum::<f64>() / n;
    let (num, den) = y.iter().enumerate().fold((0.0, 0.0), |(num, den), (i, &yi)| {
        let dx = i as f64 - x_mean;
        (num + dx * (yi - y_mean), den + dx * dx)
    });
    // den > 0 because there are at least two distinct x values.
    num / den
}

// ---------------------------------------------------------------------------
// Peak / valley detection
// ---------------------------------------------------------------------------

/// Detect significant peaks (spikes) and valleys (drops) in `data`.
///
/// Thresholds are relative to the signal's mean and its range `max - min`
/// (floored at 1e-9):
///
/// * prominence must be at least 50 % of the range (peaks and valleys);
/// * a peak must reach `mean + 0.6 * range`;
/// * a valley must not exceed `mean + 0.4 * range`;
/// * events of the same kind must be at least 5 downsampled points apart.
///
/// Returns `(minute, event_type)` pairs, where `minute = (index + 1) *
/// downsample_scale` and `event_type` is `"spike"` or `"drop"`. All spikes come
/// first, then all drops, each in ascending time order. Signals shorter than
/// 3 samples yield no events.
fn detect_peaks_valleys(data: &[f64], downsample_scale: usize) -> Vec<(usize, String)> {
    const PROMINENCE_THRESHOLD: f64 = 0.5;
    const HEIGHT_THRESHOLD: f64 = 0.6;
    const DISTANCE: usize = 5;

    if data.len() < 3 {
        return vec![];
    }

    let (min_v, max_v) = data
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| (lo.min(v), hi.max(v)));
    let mean_v = data.iter().sum::<f64>() / data.len() as f64;
    let range = (max_v - min_v).max(1e-9);

    let prom_thresh = PROMINENCE_THRESHOLD * range;
    let height_thresh = HEIGHT_THRESHOLD * range + mean_v;
    // NOTE(review): on the inverted signal this keeps valleys with
    // `x <= mean + (1 - HEIGHT_THRESHOLD) * range`. That is not the mirror image
    // of the peak filter (`x >= mean + HEIGHT_THRESHOLD * range`); confirm the
    // asymmetry matches the reference implementation.
    let valley_thresh = -(mean_v + (1.0 - HEIGHT_THRESHOLD) * range);

    let to_minute = |idx: usize| (idx + 1) * downsample_scale;

    let mut events: Vec<(usize, String)> =
        find_peaks(data, prom_thresh, Some(height_thresh), DISTANCE)
            .into_iter()
            .map(|p| (to_minute(p), "spike".to_string()))
            .collect();

    // Valleys are peaks of the inverted signal.
    let inverted: Vec<f64> = data.iter().map(|&v| -v).collect();
    events.extend(
        find_peaks(&inverted, prom_thresh, Some(valley_thresh), DISTANCE)
            .into_iter()
            .map(|v| (to_minute(v), "drop".to_string())),
    );

    events
}

/// Find local maxima of `data` that pass height, prominence and distance filters.
///
/// A sample `i` is a candidate when it is strictly greater than both neighbours,
/// so the first and last samples and flat-topped plateaus never qualify.
/// Candidates are then filtered as follows:
///
/// * `height_threshold`: if `Some(h)`, peaks lower than `h` are discarded.
/// * `prominence_threshold`: the peak's prominence must be at least this value.
///   Prominence is computed as in SciPy's `peak_prominences`. From the peak,
///   walk left and right until a strictly higher sample or the signal edge, and
///   take the minimum on each side. The prominence is the peak height minus the
///   higher of those two minima.
/// * `min_distance`: peaks are visited from highest to lowest, and a peak is
///   dropped if it lies fewer than `min_distance` samples from an already kept
///   peak.
///
/// Returns the indices of the kept peaks in ascending order. Inputs with fewer
/// than three samples (including an empty slice) yield no peaks.
fn find_peaks(
    data: &[f64],
    prominence_threshold: f64,
    height_threshold: Option<f64>,
    min_distance: usize,
) -> Vec<usize> {
    // `saturating_sub` keeps the range empty (instead of underflowing) for n == 0.
    let mut candidates: Vec<usize> = (1..data.len().saturating_sub(1))
        .filter(|&i| data[i] > data[i - 1] && data[i] > data[i + 1])
        .filter(|&i| !matches!(height_threshold, Some(h) if data[i] < h))
        .filter(|&i| peak_prominence(data, i) >= prominence_threshold)
        .collect();

    // Enforce the minimum distance greedily, tallest peaks first. The sort is
    // stable, so equal heights are visited left to right.
    candidates.sort_by(|&a, &b| data[b].total_cmp(&data[a]));
    let mut kept: Vec<usize> = Vec::with_capacity(candidates.len());
    for idx in candidates {
        if kept.iter().all(|&k| idx.abs_diff(k) >= min_distance) {
            kept.push(idx);
        }
    }
    kept.sort_unstable();
    kept
}

/// Prominence of the sample at `peak` (see `find_peaks` for the definition).
fn peak_prominence(data: &[f64], peak: usize) -> f64 {
    let height = data[peak];
    let left_base = data[..peak]
        .iter()
        .rev()
        .take_while(|&&v| v <= height)
        .fold(height, |lo, &v| lo.min(v));
    let right_base = data[peak + 1..]
        .iter()
        .take_while(|&&v| v <= height)
        .fold(height, |lo, &v| lo.min(v));
    height - left_base.max(right_base)
}
287
288// ---------------------------------------------------------------------------
289// Formatting helpers
290// ---------------------------------------------------------------------------
291
292fn describe_trend<R: Rng>(
293    sensor_name: &str,
294    trend_type: &str,
295    start: usize,
296    end: usize,
297    rng: &mut R,
298) -> String {
299    let tmpl = TREND_TEMPLATES
300        .choose(rng)
301        .copied()
302        .unwrap_or(TREND_TEMPLATES[0]);
303    tmpl.replace("{sensor_name}", sensor_name)
304        .replace("{trend_type}", trend_type)
305        .replace("{start}", &start.to_string())
306        .replace("{end}", &end.to_string())
307}
308
309fn describe_anomaly<R: Rng>(
310    sensor_name: &str,
311    anomaly: &str,
312    time: usize,
313    rng: &mut R,
314) -> String {
315    let tmpl = ANOMALY_TEMPLATES
316        .choose(rng)
317        .copied()
318        .unwrap_or(ANOMALY_TEMPLATES[0]);
319    tmpl.replace("{sensor_name}", sensor_name)
320        .replace("{anomaly}", anomaly)
321        .replace("{time}", &time.to_string())
322}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use rand::{rngs::StdRng, SeedableRng};
    use crate::constants::NUM_CHANNELS;

    #[test]
    fn test_structural_caption_runs() {
        let x = Array2::<f64>::zeros((1440, NUM_CHANNELS));
        let mut rng = StdRng::seed_from_u64(7);
        let cap = generate_structural_caption(&x.view(), 7, &mut rng).unwrap();
        assert!(!cap.is_empty());
        // Every group line is newline-terminated.
        assert!(cap.ends_with('\n'));
    }

    #[test]
    fn test_linreg_slope() {
        let y: Vec<f64> = (0..10).map(|i| i as f64).collect();
        let slope = linear_regression_slope(&y);
        assert!((slope - 1.0).abs() < 1e-9, "slope should be 1.0, got {slope}");
    }

    #[test]
    fn test_linreg_slope_degenerate() {
        assert_eq!(linear_regression_slope(&[]), 0.0);
        assert_eq!(linear_regression_slope(&[3.0]), 0.0);
        assert_eq!(linear_regression_slope(&[2.0, 2.0, 2.0]), 0.0);
    }

    #[test]
    fn test_overlap_fraction() {
        assert_eq!(overlap_fraction(0, 10, 5, 20), 0.5);
        assert_eq!(overlap_fraction(0, 10, 10, 20), 0.0);
        assert_eq!(overlap_fraction(0, 100, 20, 40), 1.0);
    }

    #[test]
    fn test_find_peaks() {
        let data = vec![0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let peaks = find_peaks(&data, 0.5, None, 2);
        assert_eq!(peaks.len(), 2);
    }

    #[test]
    fn test_find_peaks_min_distance_keeps_tallest() {
        let data = [0.0, 1.0, 0.0, 2.0, 0.0];
        assert_eq!(find_peaks(&data, 0.5, None, 3), vec![3]);
    }

    #[test]
    fn test_identify_trends_ramp() {
        let ramp: Vec<f64> = (0..36).map(f64::from).collect();
        let trends = identify_trends(&ramp, 40);
        assert!(!trends.is_empty() && trends.len() <= 3);
        assert!(trends.iter().all(|t| t.2 == "increasing"));
    }
}