1use std::collections::HashMap;
52use std::sync::RwLock;
53use std::time::{SystemTime, UNIX_EPOCH};
54
55#[derive(Debug, Clone)]
61pub struct SemanticTrigger {
62 pub id: String,
64
65 pub name: String,
67
68 pub description: String,
70
71 pub query: String,
73
74 pub embedding: Option<Vec<f32>>,
76
77 pub threshold: f32,
79
80 pub action: TriggerAction,
82
83 pub enabled: bool,
85
86 pub priority: i32,
88
89 pub max_fires_per_window: Option<usize>,
91
92 pub rate_limit_window_secs: Option<u64>,
94
95 pub tags: Vec<String>,
97
98 pub metadata: HashMap<String, String>,
100
101 pub created_at: f64,
103}
104
105#[derive(Debug, Clone)]
107pub enum TriggerAction {
108 Notify {
110 channel: String,
111 template: Option<String>,
112 },
113
114 Route {
116 target: String,
117 context: Option<String>,
118 },
119
120 Escalate {
122 level: EscalationLevel,
123 reason: Option<String>,
124 },
125
126 SpawnAgent {
128 agent_type: String,
129 config: HashMap<String, String>,
130 },
131
132 Log {
134 level: LogLevel,
135 message: Option<String>,
136 },
137
138 Webhook {
140 url: String,
141 method: String,
142 headers: HashMap<String, String>,
143 },
144
145 Callback {
147 function: String,
148 args: HashMap<String, String>,
149 },
150
151 Chain(Vec<TriggerAction>),
153}
154
/// Severity attached to a [`TriggerAction::Escalate`] action.
///
/// Variants are declared in increasing severity, so the derived ordering gives
/// `Low < Medium < High < Critical`. `Hash` is derived so levels can key maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EscalationLevel {
    Low,
    Medium,
    High,
    Critical,
}
163
/// Level used by the [`TriggerAction::Log`] action.
///
/// Variants are declared from least to most severe, so the derived ordering gives
/// `Debug < Info < Warn < Error` (note: this is the opposite direction from the
/// verbosity ordering some logging crates use). `Hash` is derived so levels can
/// key maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}
172
173#[derive(Debug, Clone)]
179pub struct TriggerEvent {
180 pub id: String,
182
183 pub content: String,
185
186 pub embedding: Option<Vec<f32>>,
188
189 pub source: EventSource,
191
192 pub metadata: HashMap<String, String>,
194
195 pub timestamp: f64,
197}
198
/// Origin of a `TriggerEvent`.
///
/// No code in this file branches on the source; variant meanings follow their
/// names. `Hash` is derived alongside `Eq` so sources can key maps/sets (for
/// example, per-source counters).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EventSource {
    UserMessage,
    SystemEvent,
    DataInsert,
    MemoryCompaction,
    ExternalApi,
    AgentAction,
    /// Any other origin, identified by a caller-chosen label.
    Custom(String),
}
217
/// Record of one trigger matching one event.
///
/// `PartialEq` is derived so matches can be compared in tests (no `Eq`, because the
/// struct contains floating-point fields).
#[derive(Debug, Clone, PartialEq)]
pub struct TriggerMatch {
    /// ID of the trigger that matched.
    pub trigger_id: String,

    /// Similarity score between the event and the trigger embeddings.
    pub score: f32,

    /// ID of the event that caused the match.
    pub event_id: String,

    /// Time the match was produced, in seconds since the Unix epoch.
    pub timestamp: f64,

    /// `false` when the match is created; set to `true` once the trigger's action
    /// has been executed for this match.
    pub action_executed: bool,

    /// Summary string returned by the executed action, if any.
    pub execution_result: Option<String>,
}
239
/// Counters describing the activity of a trigger index.
///
/// `PartialEq`/`Eq` are derived so snapshots can be compared, e.g. against
/// `TriggerStats::default()` after a reset.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TriggerStats {
    /// Number of events handed to the index, including events without an embedding.
    pub events_processed: usize,

    /// Total number of trigger matches produced.
    pub triggers_matched: usize,

    /// Number of actions successfully executed.
    pub actions_executed: usize,

    /// Match count per trigger ID.
    pub matches_by_trigger: HashMap<String, usize>,

    /// Number of would-be matches suppressed by rate limiting.
    pub rate_limited: usize,
}
258
259pub struct TriggerIndex {
265 triggers: RwLock<HashMap<String, SemanticTrigger>>,
267
268 trigger_embeddings: RwLock<Vec<(String, Vec<f32>)>>,
270
271 rate_limits: RwLock<HashMap<String, (usize, f64)>>,
273
274 recent_matches: RwLock<Vec<TriggerMatch>>,
276
277 stats: RwLock<TriggerStats>,
279
280 max_recent_matches: usize,
282}
283
284impl TriggerIndex {
285 pub fn new() -> Self {
287 Self {
288 triggers: RwLock::new(HashMap::new()),
289 trigger_embeddings: RwLock::new(Vec::new()),
290 rate_limits: RwLock::new(HashMap::new()),
291 recent_matches: RwLock::new(Vec::new()),
292 stats: RwLock::new(TriggerStats::default()),
293 max_recent_matches: 1000,
294 }
295 }
296
297 pub fn register_trigger(&self, mut trigger: SemanticTrigger) -> Result<(), TriggerError> {
299 if trigger.id.is_empty() {
300 return Err(TriggerError::InvalidTrigger("ID cannot be empty".to_string()));
301 }
302
303 if trigger.created_at == 0.0 {
305 trigger.created_at = SystemTime::now()
306 .duration_since(UNIX_EPOCH)
307 .unwrap_or_default()
308 .as_secs_f64();
309 }
310
311 {
313 let mut triggers = self.triggers.write().unwrap();
314 triggers.insert(trigger.id.clone(), trigger.clone());
315 }
316
317 if let Some(embedding) = &trigger.embedding {
319 let mut embeddings = self.trigger_embeddings.write().unwrap();
320 embeddings.push((trigger.id.clone(), embedding.clone()));
321 }
322
323 Ok(())
324 }
325
326 pub fn remove_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
328 let removed = {
329 let mut triggers = self.triggers.write().unwrap();
330 triggers.remove(trigger_id)
331 };
332
333 if removed.is_some() {
334 let mut embeddings = self.trigger_embeddings.write().unwrap();
335 embeddings.retain(|(id, _)| id != trigger_id);
336 }
337
338 removed
339 }
340
341 pub fn set_enabled(&self, trigger_id: &str, enabled: bool) -> bool {
343 let mut triggers = self.triggers.write().unwrap();
344 if let Some(trigger) = triggers.get_mut(trigger_id) {
345 trigger.enabled = enabled;
346 true
347 } else {
348 false
349 }
350 }
351
352 pub fn set_threshold(&self, trigger_id: &str, threshold: f32) -> bool {
354 let mut triggers = self.triggers.write().unwrap();
355 if let Some(trigger) = triggers.get_mut(trigger_id) {
356 trigger.threshold = threshold.clamp(0.0, 1.0);
357 true
358 } else {
359 false
360 }
361 }
362
363 pub fn process_event(&self, event: &TriggerEvent) -> Vec<TriggerMatch> {
365 let mut matches = Vec::new();
366 let now = SystemTime::now()
367 .duration_since(UNIX_EPOCH)
368 .unwrap_or_default()
369 .as_secs_f64();
370
371 {
373 let mut stats = self.stats.write().unwrap();
374 stats.events_processed += 1;
375 }
376
377 let event_embedding = match &event.embedding {
379 Some(emb) => emb.clone(),
380 None => {
381 return matches;
383 }
384 };
385
386 let candidates = self.find_candidates(&event_embedding, 10);
388
389 let triggers = self.triggers.read().unwrap();
390
391 for (trigger_id, score) in candidates {
392 if let Some(trigger) = triggers.get(&trigger_id) {
393 if !trigger.enabled {
395 continue;
396 }
397
398 if score < trigger.threshold {
400 continue;
401 }
402
403 if !self.check_rate_limit(&trigger_id, trigger, now) {
405 let mut stats = self.stats.write().unwrap();
406 stats.rate_limited += 1;
407 continue;
408 }
409
410 let trigger_match = TriggerMatch {
412 trigger_id: trigger_id.clone(),
413 score,
414 event_id: event.id.clone(),
415 timestamp: now,
416 action_executed: false,
417 execution_result: None,
418 };
419
420 matches.push(trigger_match);
421
422 {
424 let mut stats = self.stats.write().unwrap();
425 stats.triggers_matched += 1;
426 *stats.matches_by_trigger.entry(trigger_id.clone()).or_insert(0) += 1;
427 }
428 }
429 }
430
431 matches.sort_by(|a, b| {
433 let trigger_a = triggers.get(&a.trigger_id);
434 let trigger_b = triggers.get(&b.trigger_id);
435
436 match (trigger_a, trigger_b) {
437 (Some(ta), Some(tb)) => {
438 ta.priority.cmp(&tb.priority)
439 .then_with(|| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal))
440 }
441 _ => std::cmp::Ordering::Equal,
442 }
443 });
444
445 {
447 let mut recent = self.recent_matches.write().unwrap();
448 for m in &matches {
449 recent.push(m.clone());
450 }
451 while recent.len() > self.max_recent_matches {
453 recent.remove(0);
454 }
455 }
456
457 matches
458 }
459
460 fn find_candidates(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
462 let embeddings = self.trigger_embeddings.read().unwrap();
463
464 if embeddings.is_empty() {
465 return Vec::new();
466 }
467
468 let mut candidates: Vec<(String, f32)> = embeddings
470 .iter()
471 .map(|(id, emb)| {
472 let score = cosine_similarity(query, emb);
473 (id.clone(), score)
474 })
475 .collect();
476
477 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
479
480 candidates.truncate(k);
481 candidates
482 }
483
484 fn check_rate_limit(&self, trigger_id: &str, trigger: &SemanticTrigger, now: f64) -> bool {
486 let max_fires = match trigger.max_fires_per_window {
487 Some(max) => max,
488 None => return true, };
490
491 let window_secs = trigger.rate_limit_window_secs.unwrap_or(60);
492
493 let mut rate_limits = self.rate_limits.write().unwrap();
494 let entry = rate_limits.entry(trigger_id.to_string()).or_insert((0, now));
495
496 if now - entry.1 > window_secs as f64 {
498 entry.0 = 1;
499 entry.1 = now;
500 return true;
501 }
502
503 if entry.0 < max_fires {
505 entry.0 += 1;
506 return true;
507 }
508
509 false
510 }
511
512 pub fn execute_action(&self, trigger_match: &mut TriggerMatch) -> Result<(), TriggerError> {
514 let triggers = self.triggers.read().unwrap();
515 let trigger = triggers.get(&trigger_match.trigger_id)
516 .ok_or_else(|| TriggerError::TriggerNotFound(trigger_match.trigger_id.clone()))?;
517
518 let result = self.execute_action_impl(&trigger.action, trigger_match)?;
520
521 trigger_match.action_executed = true;
522 trigger_match.execution_result = Some(result);
523
524 {
526 let mut stats = self.stats.write().unwrap();
527 stats.actions_executed += 1;
528 }
529
530 Ok(())
531 }
532
533 fn execute_action_impl(&self, action: &TriggerAction, trigger_match: &TriggerMatch) -> Result<String, TriggerError> {
535 match action {
536 TriggerAction::Notify { channel, template } => {
537 Ok(format!("Notified channel '{}' (template: {:?})", channel, template))
539 }
540
541 TriggerAction::Route { target, context } => {
542 Ok(format!("Routed to '{}' (context: {:?})", target, context))
543 }
544
545 TriggerAction::Escalate { level, reason } => {
546 Ok(format!("Escalated at level {:?} (reason: {:?})", level, reason))
547 }
548
549 TriggerAction::SpawnAgent { agent_type, config: _ } => {
550 Ok(format!("Spawned agent of type '{}'", agent_type))
551 }
552
553 TriggerAction::Log { level, message } => {
554 let msg = message.as_deref().unwrap_or(&trigger_match.trigger_id);
555 Ok(format!("Logged at {:?}: {}", level, msg))
556 }
557
558 TriggerAction::Webhook { url, method, headers: _ } => {
559 Ok(format!("Called webhook {} {}", method, url))
561 }
562
563 TriggerAction::Callback { function, args: _ } => {
564 Ok(format!("Called callback function '{}'", function))
565 }
566
567 TriggerAction::Chain(actions) => {
568 let mut results = Vec::new();
569 for sub_action in actions {
570 let result = self.execute_action_impl(sub_action, trigger_match)?;
571 results.push(result);
572 }
573 Ok(format!("Chain executed: [{}]", results.join(", ")))
574 }
575 }
576 }
577
578 pub fn list_triggers(&self) -> Vec<SemanticTrigger> {
580 self.triggers.read().unwrap().values().cloned().collect()
581 }
582
583 pub fn get_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
585 self.triggers.read().unwrap().get(trigger_id).cloned()
586 }
587
588 pub fn recent_matches(&self, limit: usize) -> Vec<TriggerMatch> {
590 let matches = self.recent_matches.read().unwrap();
591 matches.iter().rev().take(limit).cloned().collect()
592 }
593
594 pub fn stats(&self) -> TriggerStats {
596 self.stats.read().unwrap().clone()
597 }
598
599 pub fn clear_stats(&self) {
601 let mut stats = self.stats.write().unwrap();
602 *stats = TriggerStats::default();
603 }
604}
605
606impl Default for TriggerIndex {
607 fn default() -> Self {
608 Self::new()
609 }
610}
611
612pub struct TriggerBuilder {
618 trigger: SemanticTrigger,
619}
620
621impl TriggerBuilder {
622 pub fn new(id: &str, query: &str) -> Self {
624 Self {
625 trigger: SemanticTrigger {
626 id: id.to_string(),
627 name: id.to_string(),
628 description: String::new(),
629 query: query.to_string(),
630 embedding: None,
631 threshold: 0.8,
632 action: TriggerAction::Log {
633 level: LogLevel::Info,
634 message: None,
635 },
636 enabled: true,
637 priority: 0,
638 max_fires_per_window: None,
639 rate_limit_window_secs: None,
640 tags: Vec::new(),
641 metadata: HashMap::new(),
642 created_at: 0.0,
643 },
644 }
645 }
646
647 pub fn name(mut self, name: &str) -> Self {
649 self.trigger.name = name.to_string();
650 self
651 }
652
653 pub fn description(mut self, description: &str) -> Self {
655 self.trigger.description = description.to_string();
656 self
657 }
658
659 pub fn embedding(mut self, embedding: Vec<f32>) -> Self {
661 self.trigger.embedding = Some(embedding);
662 self
663 }
664
665 pub fn threshold(mut self, threshold: f32) -> Self {
667 self.trigger.threshold = threshold.clamp(0.0, 1.0);
668 self
669 }
670
671 pub fn action(mut self, action: TriggerAction) -> Self {
673 self.trigger.action = action;
674 self
675 }
676
677 pub fn notify(mut self, channel: &str) -> Self {
679 self.trigger.action = TriggerAction::Notify {
680 channel: channel.to_string(),
681 template: None,
682 };
683 self
684 }
685
686 pub fn route(mut self, target: &str) -> Self {
688 self.trigger.action = TriggerAction::Route {
689 target: target.to_string(),
690 context: None,
691 };
692 self
693 }
694
695 pub fn escalate(mut self, level: EscalationLevel) -> Self {
697 self.trigger.action = TriggerAction::Escalate {
698 level,
699 reason: None,
700 };
701 self
702 }
703
704 pub fn priority(mut self, priority: i32) -> Self {
706 self.trigger.priority = priority;
707 self
708 }
709
710 pub fn rate_limit(mut self, max_fires: usize, window_secs: u64) -> Self {
712 self.trigger.max_fires_per_window = Some(max_fires);
713 self.trigger.rate_limit_window_secs = Some(window_secs);
714 self
715 }
716
717 pub fn tag(mut self, tag: &str) -> Self {
719 self.trigger.tags.push(tag.to_string());
720 self
721 }
722
723 pub fn enabled(mut self, enabled: bool) -> Self {
725 self.trigger.enabled = enabled;
726 self
727 }
728
729 pub fn build(self) -> SemanticTrigger {
731 self.trigger
732 }
733}
734
/// Errors reported by the trigger subsystem.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors directly
/// (e.g. `assert_eq!(result, Err(TriggerError::TriggerNotFound(..)))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TriggerError {
    /// The trigger definition is invalid; the payload explains why.
    InvalidTrigger(String),
    /// No trigger is registered under the contained ID.
    TriggerNotFound(String),
    /// Executing an action failed; the payload carries the message.
    ActionFailed(String),
    /// The trigger with the contained ID exceeded its rate limit.
    RateLimitExceeded(String),
    /// An embedding could not be produced or used; the payload carries the message.
    EmbeddingError(String),
}
753
754impl std::fmt::Display for TriggerError {
755 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
756 match self {
757 Self::InvalidTrigger(msg) => write!(f, "Invalid trigger: {}", msg),
758 Self::TriggerNotFound(id) => write!(f, "Trigger not found: {}", id),
759 Self::ActionFailed(msg) => write!(f, "Action failed: {}", msg),
760 Self::RateLimitExceeded(id) => write!(f, "Rate limit exceeded for trigger: {}", id),
761 Self::EmbeddingError(msg) => write!(f, "Embedding error: {}", msg),
762 }
763 }
764}
765
766impl std::error::Error for TriggerError {}
767
/// Cosine similarity of two vectors, always a finite value in `-1.0..=1.0`.
///
/// Returns `0.0` when the vectors differ in length, are empty, when either has a
/// (near-)zero norm, or when the inputs contain NaN/infinite components.
///
/// Sums are accumulated in `f64`: squaring large but finite `f32` components
/// overflows `f32` to infinity (which used to produce NaN), and the wider
/// accumulator also reduces rounding error for long vectors.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Norms below this are treated as zero vectors, which have no direction.
    const MIN_NORM: f64 = 1e-10;

    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Single pass over both vectors: dot product and both squared norms.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    let norm_a = sq_norm_a.sqrt();
    let norm_b = sq_norm_b.sqrt();

    if norm_a < MIN_NORM || norm_b < MIN_NORM {
        return 0.0;
    }

    let similarity = dot / (norm_a * norm_b);
    if similarity.is_finite() {
        // Rounding can push the ratio marginally outside the mathematical range.
        similarity.clamp(-1.0, 1.0) as f32
    } else {
        // NaN or infinite components in the input.
        0.0
    }
}
788
789pub fn create_notify_trigger(
795 id: &str,
796 query: &str,
797 channel: &str,
798 embedding: Vec<f32>,
799) -> SemanticTrigger {
800 TriggerBuilder::new(id, query)
801 .embedding(embedding)
802 .notify(channel)
803 .build()
804}
805
806pub fn create_escalation_trigger(
808 id: &str,
809 query: &str,
810 level: EscalationLevel,
811 embedding: Vec<f32>,
812) -> SemanticTrigger {
813 TriggerBuilder::new(id, query)
814 .embedding(embedding)
815 .escalate(level)
816 .priority(-1) .build()
818}
819
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 128-dimensional pseudo-embedding; equal seeds give equal vectors.
    fn mock_embedding(seed: u64) -> Vec<f32> {
        let mut values = Vec::with_capacity(128);
        for i in 0..128u64 {
            values.push(((i + seed) % 100) as f32 / 100.0 - 0.5);
        }
        values
    }

    /// Event `event_1` carrying the seed-1 embedding, i.e. identical to the
    /// embedding every trigger in these tests is registered with.
    fn seed_one_event(content: &str, source: EventSource) -> TriggerEvent {
        TriggerEvent {
            id: "event_1".to_string(),
            content: content.to_string(),
            embedding: Some(mock_embedding(1)),
            source,
            metadata: HashMap::new(),
            timestamp: 0.0,
        }
    }

    /// Fresh index with `trigger` already registered.
    fn index_with(trigger: SemanticTrigger) -> TriggerIndex {
        let index = TriggerIndex::new();
        index.register_trigger(trigger).unwrap();
        index
    }

    #[test]
    fn test_trigger_registration() {
        let index = index_with(
            TriggerBuilder::new("privacy_concern", "user mentions privacy concerns")
                .embedding(mock_embedding(1))
                .threshold(0.75)
                .escalate(EscalationLevel::High)
                .build(),
        );

        let triggers = index.list_triggers();
        assert_eq!(triggers.len(), 1);
        assert_eq!(triggers[0].id, "privacy_concern");
    }

    #[test]
    fn test_trigger_matching() {
        let index = index_with(
            TriggerBuilder::new("security_alert", "security vulnerability")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .notify("security-team")
                .build(),
        );

        // Same embedding as the trigger, so the similarity is (almost exactly) 1.
        let event = seed_one_event("possible security issue detected", EventSource::SystemEvent);
        let matches = index.process_event(&event);

        assert!(!matches.is_empty());
        assert_eq!(matches[0].trigger_id, "security_alert");
        assert!(matches[0].score > 0.5);
    }

    #[test]
    fn test_trigger_disable() {
        let index = index_with(
            TriggerBuilder::new("test_trigger", "test")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .build(),
        );

        index.set_enabled("test_trigger", false);

        let event = seed_one_event("test", EventSource::UserMessage);
        let matches = index.process_event(&event);

        // A disabled trigger must not match even a perfectly similar event.
        assert!(matches.is_empty());
    }

    #[test]
    fn test_rate_limiting() {
        let index = index_with(
            TriggerBuilder::new("rate_limited", "test")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .rate_limit(2, 60)
                .build(),
        );

        let event = seed_one_event("test", EventSource::UserMessage);

        // Budget is two fires per 60 seconds: the third attempt is suppressed.
        let m1 = index.process_event(&event);
        let m2 = index.process_event(&event);
        let m3 = index.process_event(&event);

        assert!(!m1.is_empty());
        assert!(!m2.is_empty());
        assert!(m3.is_empty());

        let stats = index.stats();
        assert!(stats.rate_limited >= 1);
    }

    #[test]
    fn test_action_execution() {
        let log_action = TriggerAction::Log {
            level: LogLevel::Info,
            message: Some("Test message".to_string()),
        };
        let index = index_with(
            TriggerBuilder::new("log_trigger", "test")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .action(log_action)
                .build(),
        );

        let event = seed_one_event("test", EventSource::UserMessage);
        let mut matches = index.process_event(&event);

        assert!(!matches.is_empty());

        index.execute_action(&mut matches[0]).unwrap();

        assert!(matches[0].action_executed);
        assert!(matches[0].execution_result.is_some());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        // Identical vectors: similarity 1.
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.01);

        // Orthogonal vectors: similarity 0.
        let c = vec![0.0, 1.0, 0.0];
        let sim2 = cosine_similarity(&a, &c);
        assert!(sim2.abs() < 0.01);
    }

    #[test]
    fn test_trigger_builder() {
        let trigger = TriggerBuilder::new("test", "test query")
            .name("Test Trigger")
            .description("A test trigger")
            .threshold(0.85)
            .priority(5)
            .tag("test")
            .tag("example")
            .notify("test-channel")
            .rate_limit(10, 300)
            .build();

        assert_eq!(trigger.id, "test");
        assert_eq!(trigger.name, "Test Trigger");
        assert_eq!(trigger.threshold, 0.85);
        assert_eq!(trigger.priority, 5);
        assert_eq!(trigger.tags.len(), 2);
        assert_eq!(trigger.max_fires_per_window, Some(10));
    }
}