1use chrono::{DateTime, Duration, Utc};
12use std::collections::HashMap;
13
14use shape_ast::error::Result;
15use shape_value::{NanTag, ValueWord};
16
/// A single named predicate over one field of an event.
#[derive(Debug, Clone)]
pub struct PatternCondition {
    /// Label for this condition; it is copied into `MatchedEvent::condition_name`
    /// whenever the condition matches an event.
    pub name: String,
    /// Key that is looked up in the event's field map.
    pub field: String,
    /// How the event's field value is compared against `value`.
    pub operator: ComparisonOp,
    /// Reference value the event's field is compared against.
    pub value: ValueWord,
}
29
/// Comparison operators usable in a [`PatternCondition`].
///
/// As implemented by `PatternCondition::evaluate`: `Eq`/`Ne` apply to numeric,
/// string and boolean operands; `Gt`/`Ge`/`Lt`/`Le` apply to numeric operands
/// only; `Contains`/`StartsWith`/`EndsWith` apply to string operands only. Any
/// other operand combination evaluates to `false`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComparisonOp {
    /// Equal (numbers are compared with an `f64::EPSILON` tolerance).
    Eq,
    /// Not equal (numbers are compared with an `f64::EPSILON` tolerance).
    Ne,
    /// Strictly greater than.
    Gt,
    /// Greater than or equal.
    Ge,
    /// Strictly less than.
    Lt,
    /// Less than or equal.
    Le,
    /// The field's string contains the reference string.
    Contains,
    /// The field's string starts with the reference string.
    StartsWith,
    /// The field's string ends with the reference string.
    EndsWith,
}
43
44impl PatternCondition {
45 pub fn new(name: &str, field: &str, operator: ComparisonOp, value: ValueWord) -> Self {
47 Self {
48 name: name.to_string(),
49 field: field.to_string(),
50 operator,
51 value,
52 }
53 }
54
55 pub fn evaluate(&self, fields: &HashMap<String, ValueWord>) -> bool {
57 let Some(field_value) = fields.get(&self.field) else {
58 return false;
59 };
60
61 match self.operator {
62 ComparisonOp::Eq
64 | ComparisonOp::Ne
65 | ComparisonOp::Gt
66 | ComparisonOp::Ge
67 | ComparisonOp::Lt
68 | ComparisonOp::Le => match (field_value.tag(), self.value.tag()) {
69 (NanTag::F64, NanTag::F64)
70 | (NanTag::I48, NanTag::I48)
71 | (NanTag::F64, NanTag::I48)
72 | (NanTag::I48, NanTag::F64) => {
73 if let (Some(a), Some(b)) = (field_value.as_f64(), self.value.as_f64()) {
74 match self.operator {
75 ComparisonOp::Eq => (a - b).abs() < f64::EPSILON,
76 ComparisonOp::Ne => (a - b).abs() >= f64::EPSILON,
77 ComparisonOp::Gt => a > b,
78 ComparisonOp::Ge => a >= b,
79 ComparisonOp::Lt => a < b,
80 ComparisonOp::Le => a <= b,
81 _ => false,
82 }
83 } else {
84 false
85 }
86 }
87 (NanTag::Heap, NanTag::Heap) => {
88 if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
89 match self.operator {
90 ComparisonOp::Eq => a == b,
91 ComparisonOp::Ne => a != b,
92 _ => false,
93 }
94 } else {
95 false
96 }
97 }
98 (NanTag::Bool, NanTag::Bool) => match self.operator {
99 ComparisonOp::Eq => field_value.as_bool() == self.value.as_bool(),
100 ComparisonOp::Ne => field_value.as_bool() != self.value.as_bool(),
101 _ => false,
102 },
103 _ => false,
104 },
105 ComparisonOp::Contains => {
107 if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
108 a.contains(b)
109 } else {
110 false
111 }
112 }
113 ComparisonOp::StartsWith => {
114 if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
115 a.starts_with(b)
116 } else {
117 false
118 }
119 }
120 ComparisonOp::EndsWith => {
121 if let (Some(a), Some(b)) = (field_value.as_str(), self.value.as_str()) {
122 a.ends_with(b)
123 } else {
124 false
125 }
126 }
127 }
128 }
129}
130
/// Composable description of an event pattern.
#[derive(Debug, Clone)]
pub enum PatternSequence {
    /// A single condition evaluated against one event.
    Condition(PatternCondition),
    /// An ordered list of sub-patterns.
    Seq(Vec<PatternSequence>),
    /// The wrapped pattern together with a time window; the matcher turns the
    /// window into a deadline measured from the event that starts the match.
    Within(Box<PatternSequence>, Duration),
    /// A first pattern and the pattern expected after it.
    FollowedBy(Box<PatternSequence>, Box<PatternSequence>),
    /// Wrapper marking the inner pattern as negated.
    Not(Box<PatternSequence>),
    /// Alternatives; the matcher takes the first one that matches an event.
    Or(Vec<PatternSequence>),
    /// Sub-patterns that must all match the same event.
    And(Vec<PatternSequence>),
    /// The wrapped pattern and the number of times it must occur.
    Repeat(Box<PatternSequence>, usize),
}
151
152impl PatternSequence {
153 pub fn condition(name: &str, field: &str, op: ComparisonOp, value: ValueWord) -> Self {
155 PatternSequence::Condition(PatternCondition::new(name, field, op, value))
156 }
157
158 pub fn seq(patterns: Vec<PatternSequence>) -> Self {
160 PatternSequence::Seq(patterns)
161 }
162
163 pub fn within(self, duration: Duration) -> Self {
165 PatternSequence::Within(Box::new(self), duration)
166 }
167
168 pub fn followed_by(self, next: PatternSequence) -> Self {
170 PatternSequence::FollowedBy(Box::new(self), Box::new(next))
171 }
172
173 pub fn not(self) -> Self {
175 PatternSequence::Not(Box::new(self))
176 }
177
178 pub fn or(patterns: Vec<PatternSequence>) -> Self {
180 PatternSequence::Or(patterns)
181 }
182
183 pub fn and(patterns: Vec<PatternSequence>) -> Self {
185 PatternSequence::And(patterns)
186 }
187
188 pub fn repeat(self, times: usize) -> Self {
190 PatternSequence::Repeat(Box::new(self), times)
191 }
192}
193
/// Bookkeeping for one in-flight (partially matched) pattern instance.
#[derive(Debug, Clone)]
struct MatchState {
    /// Index of the pattern in `PatternStateMachine::patterns`.
    pattern_id: usize,
    /// Progress counter; it is 1 once the first event has matched and is
    /// incremented as further steps match.
    position: usize,
    /// Timestamp of the event that started this match.
    start_time: DateTime<Utc>,
    /// If set, the state is discarded when an event arrives after this instant.
    deadline: Option<DateTime<Utc>>,
    /// Events that have contributed to this match so far, in arrival order.
    matched_events: Vec<MatchedEvent>,
}
208
/// One event that satisfied a condition while a pattern was being matched.
#[derive(Debug, Clone)]
pub struct MatchedEvent {
    /// Timestamp passed to `process` together with the event.
    pub timestamp: DateTime<Utc>,
    /// Name of the `PatternCondition` the event satisfied.
    pub condition_name: String,
    /// Copy of the event's full field map.
    pub fields: HashMap<String, ValueWord>,
}
216
/// A fully matched pattern, as returned by `PatternStateMachine::take_matches`.
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Name under which the pattern was registered.
    pub pattern_name: String,
    /// Timestamp of the event that started the match.
    pub start_time: DateTime<Utc>,
    /// Timestamp of the event that completed the match.
    pub end_time: DateTime<Utc>,
    /// Events that contributed to the match, in arrival order.
    pub events: Vec<MatchedEvent>,
}
229
/// A registered pattern: a name plus the sequence that defines it.
#[derive(Debug, Clone)]
pub struct PatternDef {
    /// Name reported in `PatternMatch::pattern_name`.
    pub name: String,
    /// The pattern to match.
    pub sequence: PatternSequence,
}
236
/// Matches registered patterns against a stream of events fed in one at a time.
pub struct PatternStateMachine {
    /// Registered patterns; `MatchState::pattern_id` indexes into this vector.
    patterns: Vec<PatternDef>,
    /// Partially matched pattern instances awaiting further events.
    active_states: Vec<MatchState>,
    /// Completed matches that have not been taken by the caller yet.
    completed_matches: Vec<PatternMatch>,
}
246
/// `Default` delegates to [`PatternStateMachine::new`], yielding an empty machine.
impl Default for PatternStateMachine {
    fn default() -> Self {
        Self::new()
    }
}
252
253impl PatternStateMachine {
254 pub fn new() -> Self {
256 Self {
257 patterns: Vec::new(),
258 active_states: Vec::new(),
259 completed_matches: Vec::new(),
260 }
261 }
262
263 pub fn register(&mut self, name: &str, sequence: PatternSequence) -> &mut Self {
265 self.patterns.push(PatternDef {
266 name: name.to_string(),
267 sequence,
268 });
269 self
270 }
271
272 pub fn process(
274 &mut self,
275 timestamp: DateTime<Utc>,
276 fields: HashMap<String, ValueWord>,
277 ) -> Result<()> {
278 self.active_states
280 .retain(|state| state.deadline.map(|d| timestamp <= d).unwrap_or(true));
281
282 let mut new_states = Vec::new();
284 let mut completed = Vec::new();
285
286 for state in &self.active_states {
287 if let Some((new_state, is_complete)) = self.advance_state(state, timestamp, &fields)? {
288 if is_complete {
289 let pattern = &self.patterns[state.pattern_id];
291 completed.push(PatternMatch {
292 pattern_name: pattern.name.clone(),
293 start_time: state.start_time,
294 end_time: timestamp,
295 events: new_state.matched_events,
296 });
297 } else {
298 new_states.push(new_state);
299 }
300 }
301 }
302
303 for (pattern_id, pattern) in self.patterns.iter().enumerate() {
305 if let Some(state) =
306 self.try_start_match(pattern_id, &pattern.sequence, timestamp, &fields)?
307 {
308 if self.is_pattern_complete(&pattern.sequence, &state) {
310 completed.push(PatternMatch {
311 pattern_name: pattern.name.clone(),
312 start_time: timestamp,
313 end_time: timestamp,
314 events: state.matched_events,
315 });
316 } else {
317 new_states.push(state);
318 }
319 }
320 }
321
322 self.active_states = new_states;
324 self.completed_matches.extend(completed);
325
326 Ok(())
327 }
328
329 fn try_start_match(
331 &self,
332 pattern_id: usize,
333 sequence: &PatternSequence,
334 timestamp: DateTime<Utc>,
335 fields: &HashMap<String, ValueWord>,
336 ) -> Result<Option<MatchState>> {
337 match sequence {
338 PatternSequence::Condition(cond) => {
339 if cond.evaluate(fields) {
340 Ok(Some(MatchState {
341 pattern_id,
342 position: 1, start_time: timestamp,
344 deadline: None,
345 matched_events: vec![MatchedEvent {
346 timestamp,
347 condition_name: cond.name.clone(),
348 fields: fields.clone(),
349 }],
350 }))
351 } else {
352 Ok(None)
353 }
354 }
355 PatternSequence::Seq(patterns) if !patterns.is_empty() => {
356 self.try_start_match(pattern_id, &patterns[0], timestamp, fields)
358 }
359 PatternSequence::Within(inner, duration) => {
360 if let Some(mut state) =
361 self.try_start_match(pattern_id, inner, timestamp, fields)?
362 {
363 state.deadline = Some(timestamp + *duration);
364 Ok(Some(state))
365 } else {
366 Ok(None)
367 }
368 }
369 PatternSequence::Or(patterns) => {
370 for pattern in patterns {
371 if let Some(state) =
372 self.try_start_match(pattern_id, pattern, timestamp, fields)?
373 {
374 return Ok(Some(state));
375 }
376 }
377 Ok(None)
378 }
379 PatternSequence::And(patterns) => {
380 let mut all_matched = true;
382 let mut matched_events = Vec::new();
383
384 for pattern in patterns {
385 if let Some(state) =
386 self.try_start_match(pattern_id, pattern, timestamp, fields)?
387 {
388 matched_events.extend(state.matched_events);
389 } else {
390 all_matched = false;
391 break;
392 }
393 }
394
395 if all_matched && !matched_events.is_empty() {
396 Ok(Some(MatchState {
397 pattern_id,
398 position: 1,
399 start_time: timestamp,
400 deadline: None,
401 matched_events,
402 }))
403 } else {
404 Ok(None)
405 }
406 }
407 _ => Ok(None),
408 }
409 }
410
411 fn advance_state(
413 &self,
414 state: &MatchState,
415 timestamp: DateTime<Utc>,
416 fields: &HashMap<String, ValueWord>,
417 ) -> Result<Option<(MatchState, bool)>> {
418 let pattern = &self.patterns[state.pattern_id];
419
420 match &pattern.sequence {
421 PatternSequence::Seq(patterns) => {
422 if state.position < patterns.len() {
423 if let PatternSequence::Condition(cond) = &patterns[state.position] {
425 if cond.evaluate(fields) {
426 let mut new_state = state.clone();
427 new_state.position += 1;
428 new_state.matched_events.push(MatchedEvent {
429 timestamp,
430 condition_name: cond.name.clone(),
431 fields: fields.clone(),
432 });
433
434 let is_complete = new_state.position >= patterns.len();
435 return Ok(Some((new_state, is_complete)));
436 }
437 }
438 }
439 }
440 PatternSequence::FollowedBy(_, second) => {
441 if state.position == 1 {
443 if let PatternSequence::Condition(cond) = second.as_ref() {
444 if cond.evaluate(fields) {
445 let mut new_state = state.clone();
446 new_state.position = 2;
447 new_state.matched_events.push(MatchedEvent {
448 timestamp,
449 condition_name: cond.name.clone(),
450 fields: fields.clone(),
451 });
452 return Ok(Some((new_state, true)));
453 }
454 }
455 }
456 }
457 PatternSequence::Repeat(inner, times) => {
458 if state.position < *times {
459 if let Some(new_inner_state) =
460 self.try_start_match(state.pattern_id, inner, timestamp, fields)?
461 {
462 let mut new_state = state.clone();
463 new_state.position += 1;
464 new_state
465 .matched_events
466 .extend(new_inner_state.matched_events);
467
468 let is_complete = new_state.position >= *times;
469 return Ok(Some((new_state, is_complete)));
470 }
471 }
472 }
473 _ => {}
474 }
475
476 Ok(Some((state.clone(), false)))
478 }
479
480 fn is_pattern_complete(&self, sequence: &PatternSequence, state: &MatchState) -> bool {
482 match sequence {
483 PatternSequence::Condition(_) => state.position >= 1,
484 PatternSequence::Seq(patterns) => state.position >= patterns.len(),
485 PatternSequence::Within(inner, _) => self.is_pattern_complete(inner, state),
486 PatternSequence::Repeat(_, times) => state.position >= *times,
487 PatternSequence::And(_) | PatternSequence::Or(_) => state.position >= 1,
488 _ => false,
489 }
490 }
491
492 pub fn take_matches(&mut self) -> Vec<PatternMatch> {
494 std::mem::take(&mut self.completed_matches)
495 }
496
497 pub fn active_count(&self) -> usize {
499 self.active_states.len()
500 }
501
502 pub fn reset(&mut self) {
504 self.active_states.clear();
505 self.completed_matches.clear();
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Fixed reference instant used as the first event time in every test.
    fn base_time() -> DateTime<Utc> {
        DateTime::from_timestamp(1000000000, 0).unwrap()
    }

    /// Wraps a string slice in a `ValueWord`.
    fn text(s: &str) -> ValueWord {
        ValueWord::from_string(Arc::new(s.to_string()))
    }

    /// Builds an event with a numeric `value` field and a string `status` field.
    fn make_event(value: f64, status: &str) -> HashMap<String, ValueWord> {
        HashMap::from([
            ("value".to_string(), ValueWord::from_f64(value)),
            ("status".to_string(), text(status)),
        ])
    }

    /// Condition on the numeric `value` field.
    fn value_cond(name: &str, op: ComparisonOp, threshold: f64) -> PatternSequence {
        PatternSequence::condition(name, "value", op, ValueWord::from_f64(threshold))
    }

    /// Two-step sequence: a value below 50 and then a value above 150.
    fn low_then_high() -> PatternSequence {
        PatternSequence::seq(vec![
            value_cond("low", ComparisonOp::Lt, 50.0),
            value_cond("high", ComparisonOp::Gt, 150.0),
        ])
    }

    #[test]
    fn test_single_condition() {
        let mut psm = PatternStateMachine::new();
        psm.register("high_value", value_cond("high", ComparisonOp::Gt, 100.0));

        let t0 = base_time();

        // Below the threshold: nothing is reported.
        psm.process(t0, make_event(50.0, "ok")).unwrap();
        assert!(psm.take_matches().is_empty());

        // Above the threshold: exactly one match for the registered pattern.
        psm.process(t0 + Duration::seconds(1), make_event(150.0, "ok"))
            .unwrap();
        let found = psm.take_matches();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].pattern_name, "high_value");
    }

    #[test]
    fn test_sequence_pattern() {
        let mut psm = PatternStateMachine::new();
        psm.register("spike", low_then_high());

        let t0 = base_time();

        // First step matches; the pattern is now in flight but not complete.
        psm.process(t0, make_event(30.0, "ok")).unwrap();
        assert!(psm.take_matches().is_empty());
        assert_eq!(psm.active_count(), 1);

        // Second step completes the sequence with both events recorded.
        psm.process(t0 + Duration::seconds(1), make_event(200.0, "ok"))
            .unwrap();
        let found = psm.take_matches();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].events.len(), 2);
    }

    #[test]
    fn test_within_constraint() {
        let mut psm = PatternStateMachine::new();
        psm.register("fast_spike", low_then_high().within(Duration::seconds(5)));

        let t0 = base_time();

        psm.process(t0, make_event(30.0, "ok")).unwrap();
        assert_eq!(psm.active_count(), 1);

        // The second event arrives after the 5 second window has elapsed.
        psm.process(t0 + Duration::seconds(10), make_event(200.0, "ok"))
            .unwrap();

        let found = psm.take_matches();
        assert!(found.is_empty());
    }

    #[test]
    fn test_or_pattern() {
        let mut psm = PatternStateMachine::new();
        let alternatives = vec![
            value_cond("high_val", ComparisonOp::Gt, 100.0),
            PatternSequence::condition("alert_status", "status", ComparisonOp::Eq, text("alert")),
        ];
        psm.register("alert_condition", PatternSequence::or(alternatives));

        let t0 = base_time();

        // Matches through the first alternative.
        psm.process(t0, make_event(150.0, "ok")).unwrap();
        assert_eq!(psm.take_matches().len(), 1);

        // Matches through the second alternative.
        psm.process(t0 + Duration::seconds(1), make_event(50.0, "alert"))
            .unwrap();
        assert_eq!(psm.take_matches().len(), 1);
    }

    #[test]
    fn test_string_conditions() {
        let mut psm = PatternStateMachine::new();
        psm.register(
            "status_check",
            PatternSequence::condition(
                "starts_err",
                "status",
                ComparisonOp::StartsWith,
                text("err"),
            ),
        );

        let t0 = base_time();

        psm.process(t0, make_event(0.0, "ok")).unwrap();
        assert!(psm.take_matches().is_empty());

        let failing = make_event(0.0, "error: connection failed");
        psm.process(t0 + Duration::seconds(1), failing).unwrap();
        assert_eq!(psm.take_matches().len(), 1);
    }
}