scouter_drift/spc/
alert.rs

1use crate::error::DriftError;
2use ndarray::s;
3use ndarray::{ArrayView1, ArrayView2, Axis};
4use rayon::iter::IntoParallelIterator;
5use rayon::iter::ParallelIterator;
6use scouter_types::spc::{AlertZone, SpcAlert, SpcAlertRule, SpcAlertType, SpcFeatureAlerts};
7use std::collections::HashSet;
8use std::num::ParseIntError;
9
10// Struct for holding stateful Alert information
11#[derive(Clone)]
12pub struct Alerter {
13    pub alerts: HashSet<SpcAlert>,
14    pub alert_rule: SpcAlertRule,
15}
16
17impl Alerter {
18    // Create a new instance of the Alerter struct
19    //
20    // Sets both alerts (hashset) and alert positions (hashmap)
21    // Alerts is a collection of unique alert types
22    // Alert positions is a hashmap of keys (alert zones) and their corresponding alert start and stop indices
23    // Keys:
24    //  1 - Zone 1 alerts
25    //  2 - Zone 2 alerts
26    //  3 - Zone 3 alerts
27    //  4 - Zone 4 alerts (out of bounds)
28    //  5 - Increasing trend alerts
29    //  6 - Decreasing trend alerts
30    pub fn new(alert_rule: SpcAlertRule) -> Self {
31        Alerter {
32            alerts: HashSet::new(),
33            alert_rule,
34        }
35    }
36
37    // Check if the drift array has a consecutive zone alert for negative or positive values
38    //
39    // drift_array - ArrayView1<f64> - The drift array to check
40    // zone_consecutive_rule - usize - The number of consecutive values to check for
41    // threshold - f64 - The threshold value to check against
42    pub fn check_zone_consecutive(
43        &self,
44        drift_array: &ArrayView1<f64>,
45        zone_consecutive_rule: usize,
46        threshold: f64,
47    ) -> Result<bool, DriftError> {
48        let pos_count = drift_array.iter().filter(|&x| *x >= threshold).count();
49
50        let neg_count = drift_array.iter().filter(|&x| *x <= -threshold).count();
51
52        if pos_count >= zone_consecutive_rule || neg_count >= zone_consecutive_rule {
53            return Ok(true);
54        }
55
56        Ok(false)
57    }
58
59    pub fn check_zone_alternating(
60        &self,
61        drift_array: &ArrayView1<f64>,
62        zone_alt_rule: usize,
63        threshold: f64,
64    ) -> Result<bool, DriftError> {
65        // check for consecutive alternating values
66
67        let mut last_val = 0.0;
68        let mut alt_count = 0;
69
70        for i in 0..drift_array.len() {
71            if drift_array[i] == 0.0 {
72                last_val = 0.0;
73                alt_count = 0;
74                continue;
75            } else if drift_array[i] != last_val
76                && (drift_array[i] >= threshold || drift_array[i] <= -threshold)
77            {
78                alt_count += 1;
79                if alt_count >= zone_alt_rule {
80                    return Ok(true);
81                }
82            } else {
83                last_val = 0.0;
84                alt_count = 0;
85                continue;
86            }
87
88            last_val = drift_array[i];
89        }
90
91        Ok(false)
92    }
93
94    pub fn has_overlap(last_entry: &[usize], start: usize, end: usize) -> Result<bool, DriftError> {
95        let last_start = last_entry[0];
96        let last_end = last_entry[1];
97
98        let has_overlap = last_start <= end && start <= last_end;
99
100        Ok(has_overlap)
101    }
102
103    pub fn check_zone(
104        &mut self,
105        value: f64,
106        idx: usize,
107        drift_array: &ArrayView1<f64>,
108        consecutive_rule: usize,
109        alternating_rule: usize,
110        threshold: f64,
111    ) -> Result<(), DriftError> {
112        // test consecutive first
113        if (value == threshold || value == -threshold)
114            && idx + 1 >= consecutive_rule
115            && consecutive_rule > 0
116        {
117            let start = idx + 1 - consecutive_rule;
118            let consecutive_alert = self.check_zone_consecutive(
119                &drift_array.slice(s![start..=idx]),
120                consecutive_rule,
121                threshold,
122            )?;
123
124            // update alerts
125            if consecutive_alert {
126                self.update_alert(threshold as usize, SpcAlertType::Consecutive)?;
127            }
128        }
129
130        // check alternating
131        if (value == threshold || value == -threshold)
132            && idx + 1 >= alternating_rule
133            && alternating_rule > 0
134        {
135            let start = idx + 1 - alternating_rule;
136            let alternating_alert = self.check_zone_alternating(
137                &drift_array.slice(s![start..=idx]),
138                alternating_rule,
139                threshold,
140            )?;
141
142            // update alerts
143            if alternating_alert {
144                self.update_alert(threshold as usize, SpcAlertType::Alternating)?;
145            }
146        }
147
148        Ok(())
149    }
150
151    pub fn convert_rules_to_vec(&self, rule: &str) -> Result<Vec<i32>, DriftError> {
152        let rule_chars = rule.split(' ');
153
154        let rule_vec = rule_chars
155            .collect::<Vec<&str>>()
156            .into_iter()
157            .map(|ele| ele.parse::<i32>())
158            .collect::<Result<Vec<i32>, ParseIntError>>()?;
159
160        // assert rule_vec.len() == 7
161        let rule_vec_len = rule_vec.len();
162        if rule_vec_len != 8 {
163            return Err(DriftError::SpcRuleLengthError);
164        }
165
166        Ok(rule_vec)
167    }
168
169    pub fn check_process_rule_for_alert(
170        &mut self,
171        drift_array: &ArrayView1<f64>,
172    ) -> Result<(), DriftError> {
173        let rule_vec = self.convert_rules_to_vec(&self.alert_rule.rule)?;
174
175        // iterate over each value in drift array
176        for (idx, value) in drift_array.iter().enumerate() {
177            // iterate over rule vec and step by 2 (consecutive and alternating rules for each zone)
178            for i in (0..=6).step_by(2) {
179                let threshold = match i {
180                    0 => 1,
181                    2 => 2,
182                    4 => 3,
183                    6 => 4,
184                    _ => 0,
185                };
186
187                self.check_zone(
188                    *value,
189                    idx,
190                    drift_array,
191                    rule_vec[i] as usize,
192                    rule_vec[i + 1] as usize,
193                    threshold as f64,
194                )?;
195            }
196        }
197
198        Ok(())
199    }
200
201    pub fn update_alert(
202        &mut self,
203        threshold: usize,
204        alert: SpcAlertType,
205    ) -> Result<(), DriftError> {
206        let alert_zone = match threshold {
207            1 => AlertZone::Zone1,
208            2 => AlertZone::Zone2,
209            3 => AlertZone::Zone3,
210            4 => AlertZone::Zone4,
211            _ => AlertZone::NotApplicable,
212        };
213
214        // skip if the zone is not in the process rule
215        if !self.alert_rule.zones_to_monitor.contains(&alert_zone) {
216            return Ok(());
217        }
218
219        if alert_zone == AlertZone::Zone4 {
220            self.alerts.insert(SpcAlert {
221                zone: alert_zone,
222                kind: SpcAlertType::OutOfBounds,
223            });
224        } else {
225            self.alerts.insert(SpcAlert {
226                zone: alert_zone,
227                kind: alert,
228            });
229        }
230
231        Ok(())
232    }
233
234    pub fn check_trend(&mut self, drift_array: &ArrayView1<f64>) -> Result<(), DriftError> {
235        drift_array.windows(7).into_iter().for_each(|window| {
236            // iterate over array and check if each value is increasing or decreasing
237            let mut increasing = 0;
238            let mut decreasing = 0;
239
240            // iterate through
241            for i in 1..window.len() {
242                if window[i] > window[i - 1] {
243                    increasing += 1;
244                } else if window[i] < window[i - 1] {
245                    decreasing += 1;
246                }
247            }
248
249            if increasing >= 6 || decreasing >= 6 {
250                self.alerts.insert(SpcAlert {
251                    zone: AlertZone::NotApplicable,
252                    kind: SpcAlertType::Trend,
253                });
254            }
255        });
256
257        Ok(())
258    }
259}
260
261impl Default for Alerter {
262    fn default() -> Self {
263        let rule = SpcAlertRule::default();
264        Alerter {
265            alerts: HashSet::new(),
266            alert_rule: rule,
267        }
268    }
269}
270
271pub fn generate_alert(
272    drift_array: &ArrayView1<f64>,
273    rule: &SpcAlertRule,
274) -> Result<HashSet<SpcAlert>, DriftError> {
275    let mut alerter = Alerter::new(rule.clone());
276
277    alerter.check_process_rule_for_alert(&drift_array.view())?;
278
279    alerter.check_trend(&drift_array.view())?;
280
281    Ok(alerter.alerts)
282}
283
284/// Generate alerts for each feature in the drift array
285///
286/// # Arguments
287/// drift_array - ArrayView2<f64> - The drift array to check for alerts (column order should match feature order)
288/// features - Vec<String> - The features to check for alerts (feature order should match drift array column order)
289/// alert_rule - AlertRule - The alert rule to check against
290///
291/// Returns a Result<FeatureAlerts, AlertError>
292///
293pub fn generate_alerts(
294    drift_array: &ArrayView2<f64>,
295    features: &[String],
296    rule: &SpcAlertRule,
297) -> Result<SpcFeatureAlerts, DriftError> {
298    let mut has_alerts: bool = false;
299
300    // check for alerts
301    let alerts = drift_array
302        .axis_iter(Axis(1))
303        .into_par_iter()
304        .map(|col| {
305            // check for alerts and errors
306            generate_alert(&col, rule)
307        })
308        .collect::<Vec<Result<HashSet<SpcAlert>, DriftError>>>();
309
310    // Calculate correlation matrix when there are alerts
311    if alerts
312        .iter()
313        .any(|alert| !alert.as_ref().unwrap().is_empty())
314    {
315        // get correlation matrix
316        has_alerts = true;
317    };
318
319    let mut feature_alerts = SpcFeatureAlerts::new(has_alerts);
320
321    //zip the alerts with the features
322    for (feature, alert) in features.iter().zip(alerts.iter()) {
323        // unwrap the alert, should should have already been checked
324        let alerts = alert.as_ref().unwrap();
325        //let alerts: Vec<SpcAlert> = alerts.iter().cloned().collect();
326
327        feature_alerts.insert_feature_alert(feature, alerts.to_owned());
328    }
329
330    Ok(feature_alerts)
331}
332
#[cfg(test)]
mod tests {

    use scouter_types::spc::SpcAlertRule;

    use super::*;
    use ndarray::arr2;
    use ndarray::Array;

    #[test]
    fn test_alerting_consecutive() {
        let alerter = Alerter::default();

        // Five values in a row at zone 1 trigger a 5-value consecutive rule.
        let values = [0.0, 1.0, 1.0, 1.0, 1.0, 1.0];
        let drift_array = Array::from_vec(values.to_vec());
        let threshold = 1.0;

        let result = alerter
            .check_zone_consecutive(&drift_array.view(), 5, threshold)
            .unwrap();
        assert!(result);

        // A sign flip in the middle prevents the alert.
        let values = [0.0, 1.0, 1.0, -1.0, 1.0, 1.0];
        let drift_array = Array::from_vec(values.to_vec());
        let threshold = 1.0;

        let result = alerter
            .check_zone_consecutive(&drift_array.view(), 5, threshold)
            .unwrap();
        assert!(!result);
    }

    #[test]
    fn test_alerting_alternating() {
        let alerter = Alerter::default();

        // Five alternating zone-1 values trigger a 5-value alternating rule.
        let values = [0.0, 1.0, -1.0, 1.0, -1.0, 1.0];
        let drift_array = Array::from_vec(values.to_vec());
        let threshold = 1.0;

        let result = alerter
            .check_zone_alternating(&drift_array.view(), 5, threshold)
            .unwrap();
        assert!(result);

        // A zero breaks the alternating run.
        let values = [0.0, 1.0, -1.0, 1.0, 0.0, 1.0];
        let drift_array = Array::from_vec(values.to_vec());
        let threshold = 1.0;

        let result = alerter
            .check_zone_alternating(&drift_array.view(), 5, threshold)
            .unwrap();
        assert!(!result);
    }

    #[test]
    fn test_convert_rule() {
        let alerter = Alerter::default();
        let vec_of_ints = alerter
            .convert_rules_to_vec(&SpcAlertRule::default().rule)
            .unwrap();
        assert_eq!(vec_of_ints, [8, 16, 4, 8, 2, 4, 1, 1,]);
    }

    #[test]
    fn test_check_rule() {
        let mut alerter = Alerter::default();
        let values = [
            0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, -2.0, 2.0, 0.0, 0.0, 3.0, 3.0,
            3.0, 4.0, 0.0, -4.0, 3.0, -3.0, 3.0, -3.0, 3.0, -3.0,
        ];
        let drift_array = Array::from_vec(values.to_vec());
        alerter
            .check_process_rule_for_alert(&drift_array.view())
            .unwrap();

        // With the default rule this series produces four distinct alerts.
        assert_eq!(alerter.alerts.len(), 4);
    }

    #[test]
    fn test_check_rule_zones_to_monitor() {
        let zones_to_monitor = [AlertZone::Zone1, AlertZone::Zone4].to_vec();
        let process = SpcAlertRule {
            zones_to_monitor,
            ..Default::default()
        };

        let mut alerter = Alerter::new(process);

        let values = [
            0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0, -2.0, 2.0, 0.0, 0.0, 3.0, 3.0,
            3.0, 4.0, 0.0, -4.0, 3.0, -3.0, 3.0, -3.0, 3.0, -3.0,
        ];
        let drift_array = Array::from_vec(values.to_vec());

        alerter
            .check_process_rule_for_alert(&drift_array.view())
            .unwrap();

        // Only alerts for the monitored zones (1 and 4) are kept.
        assert_eq!(alerter.alerts.len(), 2);
    }

    #[test]
    fn test_check_trend() {
        let mut alerter = Alerter::default();
        let values = [
            0.0, 0.0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.2, 0.3, 0.4,
            0.5, 0.6, 0.7,
        ];
        let drift_samples = Array::from_vec(values.to_vec());

        alerter.check_trend(&drift_samples.view()).unwrap();

        // The only alert that can be raised here is the trend alert.
        let alert = alerter.alerts.iter().next().unwrap();

        assert_eq!(alert.zone, AlertZone::NotApplicable);
        assert_eq!(alert.kind, SpcAlertType::Trend);
    }

    #[test]
    fn test_generate_process_alerts() {
        // 14 observations x 4 features; column order matches `features` below.
        let drift_array = arr2(&[
            [0.0, 0.0, 4.0, 4.0],
            [0.0, 1.0, 1.0, 1.0],
            [1.0, 0.0, -1.0, -1.0],
            [0.0, 1.1, 2.0, 2.0],
            [2.0, 0.0, -2.0, -2.0],
            [0.0, 0.0, 1.0, 1.0],
            [0.0, 2.1, 1.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [2.0, 1.0, 1.0, 1.0],
            [0.0, 1.0, 1.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [0.0, 2.1, 1.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
            [1.0, 0.0, 1.0, 1.0],
        ]);

        assert_eq!(drift_array.shape(), &[14, 4]);

        let features = vec![
            "feature1".to_string(),
            "feature2".to_string(),
            "feature3".to_string(),
            "feature4".to_string(),
        ];

        let rule = SpcAlertRule::default();

        let alerts = generate_alerts(&drift_array.view(), &features, &rule).unwrap();

        let feature1 = alerts.features.get("feature1").unwrap();
        let feature2 = alerts.features.get("feature2").unwrap();
        let feature3 = alerts.features.get("feature3").unwrap();
        let feature4 = alerts.features.get("feature4").unwrap();

        // Features 1 and 2 have no alerts.
        assert_eq!(feature1.alerts.len(), 0);
        assert_eq!(feature2.alerts.len(), 0);

        // Features 3 and 4 each have two alerts.
        assert_eq!(feature3.alerts.len(), 2);
        assert_eq!(feature4.alerts.len(), 2);
    }
}