Skip to main content

shape_runtime/
pattern_state_machine.rs

1//! Generic pattern state machine for sequence detection
2//!
3//! Provides pattern matching over event streams:
4//! - Sequential pattern matching
5//! - Temporal constraints (WITHIN)
6//! - Logical operators (AND, OR, NOT, FOLLOWED_BY)
7//! - State-based pattern tracking
8//!
9//! This module is industry-agnostic and works with any events.
10
11use chrono::{DateTime, Duration, Utc};
12use std::collections::HashMap;
13
14use shape_ast::error::Result;
15use shape_value::{NanTag, ValueWord};
16
17/// A condition for pattern matching
18#[derive(Debug, Clone)]
19pub struct PatternCondition {
20    /// Unique name for this condition
21    pub name: String,
22    /// Field to evaluate
23    pub field: String,
24    /// Comparison operator
25    pub operator: ComparisonOp,
26    /// Value to compare against
27    pub value: ValueWord,
28}
29
/// Operators a pattern condition can apply to a field value.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComparisonOp {
    /// Equal to
    Eq,
    /// Not equal to
    Ne,
    /// Strictly greater than
    Gt,
    /// Greater than or equal to
    Ge,
    /// Strictly less than
    Lt,
    /// Less than or equal to
    Le,
    /// Field text contains the operand as a substring
    Contains,
    /// Field text begins with the operand
    StartsWith,
    /// Field text ends with the operand
    EndsWith,
}
43
44impl PatternCondition {
45    /// Create a new pattern condition
46    pub fn new(name: &str, field: &str, operator: ComparisonOp, value: ValueWord) -> Self {
47        Self {
48            name: name.to_string(),
49            field: field.to_string(),
50            operator,
51            value,
52        }
53    }
54
55    /// Evaluate this condition against event fields
56    pub fn evaluate(&self, fields: &HashMap<String, ValueWord>) -> bool {
57        let Some(field_value) = fields.get(&self.field) else {
58            return false;
59        };
60
61        match self.operator {
62            // Numeric comparisons
63            ComparisonOp::Eq
64            | ComparisonOp::Ne
65            | ComparisonOp::Gt
66            | ComparisonOp::Ge
67            | ComparisonOp::Lt
68            | ComparisonOp::Le => match (field_value.tag(), self.value.tag()) {
69                (NanTag::F64, NanTag::F64)
70                | (NanTag::I48, NanTag::I48)
71                | (NanTag::F64, NanTag::I48)
72                | (NanTag::I48, NanTag::F64) => {
73                    if let (Some(a), Some(b)) = (field_value.as_f64(), self.value.as_f64()) {
74                        match self.operator {
75                            ComparisonOp::Eq => (a - b).abs() < f64::EPSILON,
76                            ComparisonOp::Ne => (a - b).abs() >= f64::EPSILON,
77                            ComparisonOp::Gt => a > b,
78                            ComparisonOp::Ge => a >= b,
79                            ComparisonOp::Lt => a < b,
80                            ComparisonOp::Le => a <= b,
81                            _ => false,
82                        }
83                    } else {
84                        false
85                    }
86                }
87                (NanTag::Heap, NanTag::Heap) => {
88                    if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
89                        match self.operator {
90                            ComparisonOp::Eq => a == b,
91                            ComparisonOp::Ne => a != b,
92                            _ => false,
93                        }
94                    } else {
95                        false
96                    }
97                }
98                (NanTag::Bool, NanTag::Bool) => match self.operator {
99                    ComparisonOp::Eq => field_value.as_bool() == self.value.as_bool(),
100                    ComparisonOp::Ne => field_value.as_bool() != self.value.as_bool(),
101                    _ => false,
102                },
103                _ => false,
104            },
105            // String-specific operations
106            ComparisonOp::Contains => {
107                if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
108                    a.contains(b)
109                } else {
110                    false
111                }
112            }
113            ComparisonOp::StartsWith => {
114                if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
115                    a.starts_with(b)
116                } else {
117                    false
118                }
119            }
120            ComparisonOp::EndsWith => {
121                if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
122                    a.ends_with(b)
123                } else {
124                    false
125                }
126            }
127        }
128    }
129}
130
131/// Pattern sequence operators
132#[derive(Debug, Clone)]
133pub enum PatternSequence {
134    /// Single condition
135    Condition(PatternCondition),
136    /// Sequence of patterns (must occur in order)
137    Seq(Vec<PatternSequence>),
138    /// Pattern must complete within duration
139    Within(Box<PatternSequence>, Duration),
140    /// One pattern followed by another
141    FollowedBy(Box<PatternSequence>, Box<PatternSequence>),
142    /// Pattern must NOT occur
143    Not(Box<PatternSequence>),
144    /// Any of these patterns
145    Or(Vec<PatternSequence>),
146    /// All of these patterns (any order)
147    And(Vec<PatternSequence>),
148    /// Pattern repeated N times
149    Repeat(Box<PatternSequence>, usize),
150}
151
152impl PatternSequence {
153    /// Create a single condition pattern
154    pub fn condition(name: &str, field: &str, op: ComparisonOp, value: ValueWord) -> Self {
155        PatternSequence::Condition(PatternCondition::new(name, field, op, value))
156    }
157
158    /// Create a sequence of patterns
159    pub fn seq(patterns: Vec<PatternSequence>) -> Self {
160        PatternSequence::Seq(patterns)
161    }
162
163    /// Add a time constraint
164    pub fn within(self, duration: Duration) -> Self {
165        PatternSequence::Within(Box::new(self), duration)
166    }
167
168    /// Create a followed-by pattern
169    pub fn followed_by(self, next: PatternSequence) -> Self {
170        PatternSequence::FollowedBy(Box::new(self), Box::new(next))
171    }
172
173    /// Negate a pattern
174    pub fn not(self) -> Self {
175        PatternSequence::Not(Box::new(self))
176    }
177
178    /// Create an OR of patterns
179    pub fn or(patterns: Vec<PatternSequence>) -> Self {
180        PatternSequence::Or(patterns)
181    }
182
183    /// Create an AND of patterns
184    pub fn and(patterns: Vec<PatternSequence>) -> Self {
185        PatternSequence::And(patterns)
186    }
187
188    /// Repeat pattern N times
189    pub fn repeat(self, times: usize) -> Self {
190        PatternSequence::Repeat(Box::new(self), times)
191    }
192}
193
194/// State of a pattern match in progress
195#[derive(Debug, Clone)]
196struct MatchState {
197    /// Pattern being matched
198    pattern_id: usize,
199    /// Current position in sequence
200    position: usize,
201    /// When matching started
202    start_time: DateTime<Utc>,
203    /// Deadline for WITHIN constraints
204    deadline: Option<DateTime<Utc>>,
205    /// Events matched so far
206    matched_events: Vec<MatchedEvent>,
207}
208
209/// A matched event in a pattern
210#[derive(Debug, Clone)]
211pub struct MatchedEvent {
212    pub timestamp: DateTime<Utc>,
213    pub condition_name: String,
214    pub fields: HashMap<String, ValueWord>,
215}
216
217/// A completed pattern match
218#[derive(Debug, Clone)]
219pub struct PatternMatch {
220    /// Pattern name
221    pub pattern_name: String,
222    /// When the match started
223    pub start_time: DateTime<Utc>,
224    /// When the match completed
225    pub end_time: DateTime<Utc>,
226    /// Events that made up the match
227    pub events: Vec<MatchedEvent>,
228}
229
230/// Pattern definition with name
231#[derive(Debug, Clone)]
232pub struct PatternDef {
233    pub name: String,
234    pub sequence: PatternSequence,
235}
236
237/// Generic pattern state machine for event sequence detection
238pub struct PatternStateMachine {
239    /// Registered patterns
240    patterns: Vec<PatternDef>,
241    /// Active match states
242    active_states: Vec<MatchState>,
243    /// Completed matches
244    completed_matches: Vec<PatternMatch>,
245}
246
247impl Default for PatternStateMachine {
248    fn default() -> Self {
249        Self::new()
250    }
251}
252
253impl PatternStateMachine {
254    /// Create a new pattern state machine
255    pub fn new() -> Self {
256        Self {
257            patterns: Vec::new(),
258            active_states: Vec::new(),
259            completed_matches: Vec::new(),
260        }
261    }
262
263    /// Register a pattern
264    pub fn register(&mut self, name: &str, sequence: PatternSequence) -> &mut Self {
265        self.patterns.push(PatternDef {
266            name: name.to_string(),
267            sequence,
268        });
269        self
270    }
271
272    /// Process an event
273    pub fn process(
274        &mut self,
275        timestamp: DateTime<Utc>,
276        fields: HashMap<String, ValueWord>,
277    ) -> Result<()> {
278        // Remove expired states
279        self.active_states
280            .retain(|state| state.deadline.map(|d| timestamp <= d).unwrap_or(true));
281
282        // Try to advance existing states
283        let mut new_states = Vec::new();
284        let mut completed = Vec::new();
285
286        for state in &self.active_states {
287            if let Some((new_state, is_complete)) = self.advance_state(state, timestamp, &fields)? {
288                if is_complete {
289                    // Pattern completed
290                    let pattern = &self.patterns[state.pattern_id];
291                    completed.push(PatternMatch {
292                        pattern_name: pattern.name.clone(),
293                        start_time: state.start_time,
294                        end_time: timestamp,
295                        events: new_state.matched_events,
296                    });
297                } else {
298                    new_states.push(new_state);
299                }
300            }
301        }
302
303        // Try to start new pattern matches
304        for (pattern_id, pattern) in self.patterns.iter().enumerate() {
305            if let Some(state) =
306                self.try_start_match(pattern_id, &pattern.sequence, timestamp, &fields)?
307            {
308                // Check if it's already complete (single condition pattern)
309                if self.is_pattern_complete(&pattern.sequence, &state) {
310                    completed.push(PatternMatch {
311                        pattern_name: pattern.name.clone(),
312                        start_time: timestamp,
313                        end_time: timestamp,
314                        events: state.matched_events,
315                    });
316                } else {
317                    new_states.push(state);
318                }
319            }
320        }
321
322        // Update states
323        self.active_states = new_states;
324        self.completed_matches.extend(completed);
325
326        Ok(())
327    }
328
329    /// Try to start a new pattern match
330    fn try_start_match(
331        &self,
332        pattern_id: usize,
333        sequence: &PatternSequence,
334        timestamp: DateTime<Utc>,
335        fields: &HashMap<String, ValueWord>,
336    ) -> Result<Option<MatchState>> {
337        match sequence {
338            PatternSequence::Condition(cond) => {
339                if cond.evaluate(fields) {
340                    Ok(Some(MatchState {
341                        pattern_id,
342                        position: 1, // Completed first (and only) condition
343                        start_time: timestamp,
344                        deadline: None,
345                        matched_events: vec![MatchedEvent {
346                            timestamp,
347                            condition_name: cond.name.clone(),
348                            fields: fields.clone(),
349                        }],
350                    }))
351                } else {
352                    Ok(None)
353                }
354            }
355            PatternSequence::Seq(patterns) if !patterns.is_empty() => {
356                // Try to match first pattern in sequence
357                self.try_start_match(pattern_id, &patterns[0], timestamp, fields)
358            }
359            PatternSequence::Within(inner, duration) => {
360                if let Some(mut state) =
361                    self.try_start_match(pattern_id, inner, timestamp, fields)?
362                {
363                    state.deadline = Some(timestamp + *duration);
364                    Ok(Some(state))
365                } else {
366                    Ok(None)
367                }
368            }
369            PatternSequence::Or(patterns) => {
370                for pattern in patterns {
371                    if let Some(state) =
372                        self.try_start_match(pattern_id, pattern, timestamp, fields)?
373                    {
374                        return Ok(Some(state));
375                    }
376                }
377                Ok(None)
378            }
379            PatternSequence::And(patterns) => {
380                // For AND, all conditions must match the same event
381                let mut all_matched = true;
382                let mut matched_events = Vec::new();
383
384                for pattern in patterns {
385                    if let Some(state) =
386                        self.try_start_match(pattern_id, pattern, timestamp, fields)?
387                    {
388                        matched_events.extend(state.matched_events);
389                    } else {
390                        all_matched = false;
391                        break;
392                    }
393                }
394
395                if all_matched && !matched_events.is_empty() {
396                    Ok(Some(MatchState {
397                        pattern_id,
398                        position: 1,
399                        start_time: timestamp,
400                        deadline: None,
401                        matched_events,
402                    }))
403                } else {
404                    Ok(None)
405                }
406            }
407            _ => Ok(None),
408        }
409    }
410
411    /// Advance an existing match state
412    fn advance_state(
413        &self,
414        state: &MatchState,
415        timestamp: DateTime<Utc>,
416        fields: &HashMap<String, ValueWord>,
417    ) -> Result<Option<(MatchState, bool)>> {
418        let pattern = &self.patterns[state.pattern_id];
419
420        match &pattern.sequence {
421            PatternSequence::Seq(patterns) => {
422                if state.position < patterns.len() {
423                    // Try to match next pattern in sequence
424                    if let PatternSequence::Condition(cond) = &patterns[state.position] {
425                        if cond.evaluate(fields) {
426                            let mut new_state = state.clone();
427                            new_state.position += 1;
428                            new_state.matched_events.push(MatchedEvent {
429                                timestamp,
430                                condition_name: cond.name.clone(),
431                                fields: fields.clone(),
432                            });
433
434                            let is_complete = new_state.position >= patterns.len();
435                            return Ok(Some((new_state, is_complete)));
436                        }
437                    }
438                }
439            }
440            PatternSequence::FollowedBy(_, second) => {
441                // If we're past the first pattern, try matching the second
442                if state.position == 1 {
443                    if let PatternSequence::Condition(cond) = second.as_ref() {
444                        if cond.evaluate(fields) {
445                            let mut new_state = state.clone();
446                            new_state.position = 2;
447                            new_state.matched_events.push(MatchedEvent {
448                                timestamp,
449                                condition_name: cond.name.clone(),
450                                fields: fields.clone(),
451                            });
452                            return Ok(Some((new_state, true)));
453                        }
454                    }
455                }
456            }
457            PatternSequence::Repeat(inner, times) => {
458                if state.position < *times {
459                    if let Some(new_inner_state) =
460                        self.try_start_match(state.pattern_id, inner, timestamp, fields)?
461                    {
462                        let mut new_state = state.clone();
463                        new_state.position += 1;
464                        new_state
465                            .matched_events
466                            .extend(new_inner_state.matched_events);
467
468                        let is_complete = new_state.position >= *times;
469                        return Ok(Some((new_state, is_complete)));
470                    }
471                }
472            }
473            _ => {}
474        }
475
476        // No advancement, keep current state
477        Ok(Some((state.clone(), false)))
478    }
479
480    /// Check if a pattern is complete
481    fn is_pattern_complete(&self, sequence: &PatternSequence, state: &MatchState) -> bool {
482        match sequence {
483            PatternSequence::Condition(_) => state.position >= 1,
484            PatternSequence::Seq(patterns) => state.position >= patterns.len(),
485            PatternSequence::Within(inner, _) => self.is_pattern_complete(inner, state),
486            PatternSequence::Repeat(_, times) => state.position >= *times,
487            PatternSequence::And(_) | PatternSequence::Or(_) => state.position >= 1,
488            _ => false,
489        }
490    }
491
492    /// Take completed matches
493    pub fn take_matches(&mut self) -> Vec<PatternMatch> {
494        std::mem::take(&mut self.completed_matches)
495    }
496
497    /// Get count of active match states
498    pub fn active_count(&self) -> usize {
499        self.active_states.len()
500    }
501
502    /// Reset all state
503    pub fn reset(&mut self) {
504        self.active_states.clear();
505        self.completed_matches.clear();
506    }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Wrap a string slice as a string `ValueWord`.
    fn text(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    /// Event with a numeric `value` field and a string `status` field.
    fn make_event(value: f64, status: &str) -> HashMap<String, ValueWord> {
        HashMap::from([
            ("value".to_string(), ValueWord::from_f64(value)),
            ("status".to_string(), text(status)),
        ])
    }

    /// Fixed reference instant shared by all tests.
    fn base_time() -> DateTime<Utc> {
        DateTime::from_timestamp(1000000000, 0).unwrap()
    }

    /// Condition comparing the `value` field against a numeric threshold.
    fn value_cond(name: &str, op: ComparisonOp, threshold: f64) -> PatternSequence {
        PatternSequence::condition(name, "value", op, ValueWord::from_f64(threshold))
    }

    /// Sequence: `value` below 50, then `value` above 150.
    fn low_then_high() -> PatternSequence {
        PatternSequence::seq(vec![
            value_cond("low", ComparisonOp::Lt, 50.0),
            value_cond("high", ComparisonOp::Gt, 150.0),
        ])
    }

    #[test]
    fn test_single_condition() {
        let mut psm = PatternStateMachine::new();
        psm.register("high_value", value_cond("high", ComparisonOp::Gt, 100.0));

        let base = base_time();

        // Below the threshold: nothing is reported.
        psm.process(base, make_event(50.0, "ok")).unwrap();
        assert!(psm.take_matches().is_empty());

        // Above the threshold: exactly one match.
        psm.process(base + Duration::seconds(1), make_event(150.0, "ok"))
            .unwrap();
        let matches = psm.take_matches();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pattern_name, "high_value");
    }

    #[test]
    fn test_sequence_pattern() {
        let mut psm = PatternStateMachine::new();

        // Pattern: value goes from low to high
        psm.register("spike", low_then_high());

        let base = base_time();

        // The low value opens a match that waits for the high value.
        psm.process(base, make_event(30.0, "ok")).unwrap();
        assert!(psm.take_matches().is_empty());
        assert_eq!(psm.active_count(), 1);

        // The high value finishes it.
        psm.process(base + Duration::seconds(1), make_event(200.0, "ok"))
            .unwrap();
        let matches = psm.take_matches();
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].events.len(), 2);
    }

    #[test]
    fn test_within_constraint() {
        let mut psm = PatternStateMachine::new();

        // Pattern must complete within 5 seconds
        psm.register("fast_spike", low_then_high().within(Duration::seconds(5)));

        let base = base_time();

        // The low value opens a match.
        psm.process(base, make_event(30.0, "ok")).unwrap();
        assert_eq!(psm.active_count(), 1);

        // The high value arrives 10 seconds later, after the deadline.
        psm.process(base + Duration::seconds(10), make_event(200.0, "ok"))
            .unwrap();

        // The expired state must not produce a match.
        let matches = psm.take_matches();
        assert!(matches.is_empty());
    }

    #[test]
    fn test_or_pattern() {
        let mut psm = PatternStateMachine::new();

        // Pattern: either high value OR status is "alert"
        let alternatives = vec![
            value_cond("high_val", ComparisonOp::Gt, 100.0),
            PatternSequence::condition(
                "alert_status",
                "status",
                ComparisonOp::Eq,
                text("alert"),
            ),
        ];
        psm.register("alert_condition", PatternSequence::or(alternatives));

        let base = base_time();

        // Match via value
        psm.process(base, make_event(150.0, "ok")).unwrap();
        assert_eq!(psm.take_matches().len(), 1);

        // Match via status
        psm.process(base + Duration::seconds(1), make_event(50.0, "alert"))
            .unwrap();
        assert_eq!(psm.take_matches().len(), 1);
    }

    #[test]
    fn test_string_conditions() {
        let mut psm = PatternStateMachine::new();

        psm.register(
            "status_check",
            PatternSequence::condition(
                "starts_err",
                "status",
                ComparisonOp::StartsWith,
                text("err"),
            ),
        );

        let base = base_time();

        // "ok" does not start with "err".
        psm.process(base, make_event(0.0, "ok")).unwrap();
        assert!(psm.take_matches().is_empty());

        // A status beginning with "err" matches.
        let failing = make_event(0.0, "error: connection failed");
        psm.process(base + Duration::seconds(1), failing).unwrap();
        assert_eq!(psm.take_matches().len(), 1);
    }
}