1use std::collections::HashMap;
49use std::sync::RwLock;
50use std::time::{SystemTime, UNIX_EPOCH};
51
/// A standing semantic query registered with a [`TriggerIndex`].
///
/// When an event's embedding is similar enough to this trigger's embedding,
/// `TriggerIndex::process_event` reports a match, and the trigger's `action` can
/// then be run with `TriggerIndex::execute_action`.
#[derive(Debug, Clone)]
pub struct SemanticTrigger {
    /// Unique identifier and the key the index stores the trigger under.
    /// Registration rejects an empty ID.
    pub id: String,

    /// Human-readable name (`TriggerBuilder::new` defaults it to `id`).
    pub name: String,

    /// Free-form description.
    pub description: String,

    /// Text describing what the trigger should detect. The matching code in this
    /// file never reads it (it compares `embedding`s only); presumably it is the
    /// text the embedding was computed from — TODO confirm.
    pub query: String,

    /// Vector compared against event embeddings by cosine similarity.
    /// A trigger registered without an embedding never matches.
    pub embedding: Option<Vec<f32>>,

    /// Minimum cosine similarity an event needs for this trigger to match.
    /// The builder and `TriggerIndex::set_threshold` clamp it to `[0.0, 1.0]`.
    pub threshold: f32,

    /// Action run by `TriggerIndex::execute_action` for a match of this trigger.
    pub action: TriggerAction,

    /// Disabled triggers are skipped during matching.
    pub enabled: bool,

    /// Ordering key among the matches of one event: lower values come first.
    pub priority: i32,

    /// Maximum matches per rate-limit window; `None` means unlimited.
    pub max_fires_per_window: Option<usize>,

    /// Rate-limit window length in seconds (`None` means 60). Only consulted
    /// when `max_fires_per_window` is set.
    pub rate_limit_window_secs: Option<u64>,

    /// Free-form labels; not read by the matching code.
    pub tags: Vec<String>,

    /// Free-form key/value data; not read by the matching code.
    pub metadata: HashMap<String, String>,

    /// Creation time in seconds since the UNIX epoch. A value of `0.0` is
    /// replaced with the current time when the trigger is registered.
    pub created_at: f64,
}
101
/// What to do when a trigger matches.
///
/// NOTE(review): `TriggerIndex::execute_action` in this file only formats a
/// description string for each variant. It does not itself notify, route, send
/// requests, spawn agents, or invoke callbacks.
#[derive(Debug, Clone)]
pub enum TriggerAction {
    /// Notify `channel`, optionally using a message `template`.
    Notify {
        channel: String,
        template: Option<String>,
    },

    /// Route to `target`, with optional `context`.
    Route {
        target: String,
        context: Option<String>,
    },

    /// Escalate at `level`, with an optional `reason`.
    Escalate {
        level: EscalationLevel,
        reason: Option<String>,
    },

    /// Spawn an agent of `agent_type`; `config` is not read by `execute_action`.
    SpawnAgent {
        agent_type: String,
        config: HashMap<String, String>,
    },

    /// Log at `level`. When `message` is `None`, the trigger ID is logged instead.
    Log {
        level: LogLevel,
        message: Option<String>,
    },

    /// Call `url` using `method` (presumably an HTTP verb — TODO confirm);
    /// `headers` are not read by `execute_action`.
    Webhook {
        url: String,
        method: String,
        headers: HashMap<String, String>,
    },

    /// Invoke the callback named `function`; `args` are not read by `execute_action`.
    Callback {
        function: String,
        args: HashMap<String, String>,
    },

    /// Run the contained actions in order, stopping at the first one that errors.
    Chain(Vec<TriggerAction>),
}
151
/// Severity of an escalation action.
///
/// Variants are declared from least to most severe. The derived `PartialOrd`/`Ord`
/// therefore let callers compare levels, e.g. `level >= EscalationLevel::High`.
/// `Hash` allows levels to be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EscalationLevel {
    Low,
    Medium,
    High,
    Critical,
}
160
/// Severity of a log action.
///
/// Variants are declared from least to most severe. The derived `PartialOrd`/`Ord`
/// therefore support level filtering such as `level >= LogLevel::Warn`.
/// `Hash` allows levels to be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}
169
/// Something that happened and should be checked against registered triggers.
#[derive(Debug, Clone)]
pub struct TriggerEvent {
    /// Event identifier, copied into every resulting [`TriggerMatch`].
    pub id: String,

    /// Raw event content; not read by the matching code in this file.
    pub content: String,

    /// Embedding of the event. Events without one produce no matches.
    pub embedding: Option<Vec<f32>>,

    /// Where the event came from.
    pub source: EventSource,

    /// Free-form key/value data.
    pub metadata: HashMap<String, String>,

    /// Caller-supplied timestamp. `TriggerIndex::process_event` ignores it and
    /// stamps matches with the current wall-clock time instead.
    pub timestamp: f64,
}
195
/// Origin of a [`TriggerEvent`].
///
/// The matching logic in this file does not inspect it. Presumably it exists
/// for callers to filter or audit events — TODO confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventSource {
    /// A message from a user.
    UserMessage,

    /// An internal system event.
    SystemEvent,

    /// A data insertion.
    DataInsert,

    /// A memory compaction.
    MemoryCompaction,

    /// An external API.
    ExternalApi,

    /// An action taken by an agent.
    AgentAction,

    /// Any other source, identified by a caller-chosen label.
    Custom(String),
}
214
/// A trigger that fired for a specific event.
#[derive(Debug, Clone)]
pub struct TriggerMatch {
    /// ID of the trigger that matched.
    pub trigger_id: String,

    /// Cosine similarity between the event embedding and the trigger embedding.
    pub score: f32,

    /// ID of the event that produced the match.
    pub event_id: String,

    /// When the event was processed, in seconds since the UNIX epoch.
    pub timestamp: f64,

    /// Set to `true` by `TriggerIndex::execute_action` once the action has run.
    pub action_executed: bool,

    /// Description returned by the executed action, if it has run.
    pub execution_result: Option<String>,
}
236
/// Cumulative counters kept by [`TriggerIndex`]; reset by `clear_stats`.
#[derive(Debug, Clone, Default)]
pub struct TriggerStats {
    /// Calls to `process_event`, including events that had no embedding.
    pub events_processed: usize,

    /// Total number of matches returned across all events.
    pub triggers_matched: usize,

    /// Number of successful `execute_action` calls.
    pub actions_executed: usize,

    /// Match count per trigger ID.
    pub matches_by_trigger: HashMap<String, usize>,

    /// Enabled candidates that reached their threshold but were suppressed by
    /// their rate limit.
    pub rate_limited: usize,
}
255
/// In-memory registry of [`SemanticTrigger`]s that matches incoming
/// [`TriggerEvent`]s by cosine similarity of their embeddings.
///
/// All mutable state sits behind `RwLock`s, so every operation takes `&self`.
pub struct TriggerIndex {
    /// Registered triggers, keyed by trigger ID.
    triggers: RwLock<HashMap<String, SemanticTrigger>>,

    /// `(trigger ID, embedding)` pairs that are scanned linearly when matching.
    /// Only triggers registered with an embedding appear here.
    trigger_embeddings: RwLock<Vec<(String, Vec<f32>)>>,

    /// Per-trigger rate-limit state: `(fires in current window, window start)`.
    /// The window start is in seconds since the UNIX epoch.
    rate_limits: RwLock<HashMap<String, (usize, f64)>>,

    /// Most recent matches, oldest first, capped at `max_recent_matches`.
    recent_matches: RwLock<Vec<TriggerMatch>>,

    /// Aggregate counters; see [`TriggerStats`].
    stats: RwLock<TriggerStats>,

    /// Capacity of `recent_matches` (1000 when built with `TriggerIndex::new`).
    max_recent_matches: usize,
}
280
281impl TriggerIndex {
282 pub fn new() -> Self {
284 Self {
285 triggers: RwLock::new(HashMap::new()),
286 trigger_embeddings: RwLock::new(Vec::new()),
287 rate_limits: RwLock::new(HashMap::new()),
288 recent_matches: RwLock::new(Vec::new()),
289 stats: RwLock::new(TriggerStats::default()),
290 max_recent_matches: 1000,
291 }
292 }
293
294 pub fn register_trigger(&self, mut trigger: SemanticTrigger) -> Result<(), TriggerError> {
296 if trigger.id.is_empty() {
297 return Err(TriggerError::InvalidTrigger("ID cannot be empty".to_string()));
298 }
299
300 if trigger.created_at == 0.0 {
302 trigger.created_at = SystemTime::now()
303 .duration_since(UNIX_EPOCH)
304 .unwrap_or_default()
305 .as_secs_f64();
306 }
307
308 {
310 let mut triggers = self.triggers.write().unwrap();
311 triggers.insert(trigger.id.clone(), trigger.clone());
312 }
313
314 if let Some(embedding) = &trigger.embedding {
316 let mut embeddings = self.trigger_embeddings.write().unwrap();
317 embeddings.push((trigger.id.clone(), embedding.clone()));
318 }
319
320 Ok(())
321 }
322
323 pub fn remove_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
325 let removed = {
326 let mut triggers = self.triggers.write().unwrap();
327 triggers.remove(trigger_id)
328 };
329
330 if removed.is_some() {
331 let mut embeddings = self.trigger_embeddings.write().unwrap();
332 embeddings.retain(|(id, _)| id != trigger_id);
333 }
334
335 removed
336 }
337
338 pub fn set_enabled(&self, trigger_id: &str, enabled: bool) -> bool {
340 let mut triggers = self.triggers.write().unwrap();
341 if let Some(trigger) = triggers.get_mut(trigger_id) {
342 trigger.enabled = enabled;
343 true
344 } else {
345 false
346 }
347 }
348
349 pub fn set_threshold(&self, trigger_id: &str, threshold: f32) -> bool {
351 let mut triggers = self.triggers.write().unwrap();
352 if let Some(trigger) = triggers.get_mut(trigger_id) {
353 trigger.threshold = threshold.clamp(0.0, 1.0);
354 true
355 } else {
356 false
357 }
358 }
359
360 pub fn process_event(&self, event: &TriggerEvent) -> Vec<TriggerMatch> {
362 let mut matches = Vec::new();
363 let now = SystemTime::now()
364 .duration_since(UNIX_EPOCH)
365 .unwrap_or_default()
366 .as_secs_f64();
367
368 {
370 let mut stats = self.stats.write().unwrap();
371 stats.events_processed += 1;
372 }
373
374 let event_embedding = match &event.embedding {
376 Some(emb) => emb.clone(),
377 None => {
378 return matches;
380 }
381 };
382
383 let candidates = self.find_candidates(&event_embedding, 10);
385
386 let triggers = self.triggers.read().unwrap();
387
388 for (trigger_id, score) in candidates {
389 if let Some(trigger) = triggers.get(&trigger_id) {
390 if !trigger.enabled {
392 continue;
393 }
394
395 if score < trigger.threshold {
397 continue;
398 }
399
400 if !self.check_rate_limit(&trigger_id, trigger, now) {
402 let mut stats = self.stats.write().unwrap();
403 stats.rate_limited += 1;
404 continue;
405 }
406
407 let trigger_match = TriggerMatch {
409 trigger_id: trigger_id.clone(),
410 score,
411 event_id: event.id.clone(),
412 timestamp: now,
413 action_executed: false,
414 execution_result: None,
415 };
416
417 matches.push(trigger_match);
418
419 {
421 let mut stats = self.stats.write().unwrap();
422 stats.triggers_matched += 1;
423 *stats.matches_by_trigger.entry(trigger_id.clone()).or_insert(0) += 1;
424 }
425 }
426 }
427
428 matches.sort_by(|a, b| {
430 let trigger_a = triggers.get(&a.trigger_id);
431 let trigger_b = triggers.get(&b.trigger_id);
432
433 match (trigger_a, trigger_b) {
434 (Some(ta), Some(tb)) => {
435 ta.priority.cmp(&tb.priority)
436 .then_with(|| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal))
437 }
438 _ => std::cmp::Ordering::Equal,
439 }
440 });
441
442 {
444 let mut recent = self.recent_matches.write().unwrap();
445 for m in &matches {
446 recent.push(m.clone());
447 }
448 while recent.len() > self.max_recent_matches {
450 recent.remove(0);
451 }
452 }
453
454 matches
455 }
456
457 fn find_candidates(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
459 let embeddings = self.trigger_embeddings.read().unwrap();
460
461 if embeddings.is_empty() {
462 return Vec::new();
463 }
464
465 let mut candidates: Vec<(String, f32)> = embeddings
467 .iter()
468 .map(|(id, emb)| {
469 let score = cosine_similarity(query, emb);
470 (id.clone(), score)
471 })
472 .collect();
473
474 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
476
477 candidates.truncate(k);
478 candidates
479 }
480
481 fn check_rate_limit(&self, trigger_id: &str, trigger: &SemanticTrigger, now: f64) -> bool {
483 let max_fires = match trigger.max_fires_per_window {
484 Some(max) => max,
485 None => return true, };
487
488 let window_secs = trigger.rate_limit_window_secs.unwrap_or(60);
489
490 let mut rate_limits = self.rate_limits.write().unwrap();
491 let entry = rate_limits.entry(trigger_id.to_string()).or_insert((0, now));
492
493 if now - entry.1 > window_secs as f64 {
495 entry.0 = 1;
496 entry.1 = now;
497 return true;
498 }
499
500 if entry.0 < max_fires {
502 entry.0 += 1;
503 return true;
504 }
505
506 false
507 }
508
509 pub fn execute_action(&self, trigger_match: &mut TriggerMatch) -> Result<(), TriggerError> {
511 let triggers = self.triggers.read().unwrap();
512 let trigger = triggers.get(&trigger_match.trigger_id)
513 .ok_or_else(|| TriggerError::TriggerNotFound(trigger_match.trigger_id.clone()))?;
514
515 let result = self.execute_action_impl(&trigger.action, trigger_match)?;
517
518 trigger_match.action_executed = true;
519 trigger_match.execution_result = Some(result);
520
521 {
523 let mut stats = self.stats.write().unwrap();
524 stats.actions_executed += 1;
525 }
526
527 Ok(())
528 }
529
530 fn execute_action_impl(&self, action: &TriggerAction, trigger_match: &TriggerMatch) -> Result<String, TriggerError> {
532 match action {
533 TriggerAction::Notify { channel, template } => {
534 Ok(format!("Notified channel '{}' (template: {:?})", channel, template))
536 }
537
538 TriggerAction::Route { target, context } => {
539 Ok(format!("Routed to '{}' (context: {:?})", target, context))
540 }
541
542 TriggerAction::Escalate { level, reason } => {
543 Ok(format!("Escalated at level {:?} (reason: {:?})", level, reason))
544 }
545
546 TriggerAction::SpawnAgent { agent_type, config: _ } => {
547 Ok(format!("Spawned agent of type '{}'", agent_type))
548 }
549
550 TriggerAction::Log { level, message } => {
551 let msg = message.as_deref().unwrap_or(&trigger_match.trigger_id);
552 Ok(format!("Logged at {:?}: {}", level, msg))
553 }
554
555 TriggerAction::Webhook { url, method, headers: _ } => {
556 Ok(format!("Called webhook {} {}", method, url))
558 }
559
560 TriggerAction::Callback { function, args: _ } => {
561 Ok(format!("Called callback function '{}'", function))
562 }
563
564 TriggerAction::Chain(actions) => {
565 let mut results = Vec::new();
566 for sub_action in actions {
567 let result = self.execute_action_impl(sub_action, trigger_match)?;
568 results.push(result);
569 }
570 Ok(format!("Chain executed: [{}]", results.join(", ")))
571 }
572 }
573 }
574
575 pub fn list_triggers(&self) -> Vec<SemanticTrigger> {
577 self.triggers.read().unwrap().values().cloned().collect()
578 }
579
580 pub fn get_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
582 self.triggers.read().unwrap().get(trigger_id).cloned()
583 }
584
585 pub fn recent_matches(&self, limit: usize) -> Vec<TriggerMatch> {
587 let matches = self.recent_matches.read().unwrap();
588 matches.iter().rev().take(limit).cloned().collect()
589 }
590
591 pub fn stats(&self) -> TriggerStats {
593 self.stats.read().unwrap().clone()
594 }
595
596 pub fn clear_stats(&self) {
598 let mut stats = self.stats.write().unwrap();
599 *stats = TriggerStats::default();
600 }
601}
602
603impl Default for TriggerIndex {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
/// Fluent builder for [`SemanticTrigger`]. `TriggerBuilder::new` documents the
/// defaults, and `build` returns the finished trigger.
pub struct TriggerBuilder {
    /// The trigger under construction.
    trigger: SemanticTrigger,
}
617
618impl TriggerBuilder {
619 pub fn new(id: &str, query: &str) -> Self {
621 Self {
622 trigger: SemanticTrigger {
623 id: id.to_string(),
624 name: id.to_string(),
625 description: String::new(),
626 query: query.to_string(),
627 embedding: None,
628 threshold: 0.8,
629 action: TriggerAction::Log {
630 level: LogLevel::Info,
631 message: None,
632 },
633 enabled: true,
634 priority: 0,
635 max_fires_per_window: None,
636 rate_limit_window_secs: None,
637 tags: Vec::new(),
638 metadata: HashMap::new(),
639 created_at: 0.0,
640 },
641 }
642 }
643
644 pub fn name(mut self, name: &str) -> Self {
646 self.trigger.name = name.to_string();
647 self
648 }
649
650 pub fn description(mut self, description: &str) -> Self {
652 self.trigger.description = description.to_string();
653 self
654 }
655
656 pub fn embedding(mut self, embedding: Vec<f32>) -> Self {
658 self.trigger.embedding = Some(embedding);
659 self
660 }
661
662 pub fn threshold(mut self, threshold: f32) -> Self {
664 self.trigger.threshold = threshold.clamp(0.0, 1.0);
665 self
666 }
667
668 pub fn action(mut self, action: TriggerAction) -> Self {
670 self.trigger.action = action;
671 self
672 }
673
674 pub fn notify(mut self, channel: &str) -> Self {
676 self.trigger.action = TriggerAction::Notify {
677 channel: channel.to_string(),
678 template: None,
679 };
680 self
681 }
682
683 pub fn route(mut self, target: &str) -> Self {
685 self.trigger.action = TriggerAction::Route {
686 target: target.to_string(),
687 context: None,
688 };
689 self
690 }
691
692 pub fn escalate(mut self, level: EscalationLevel) -> Self {
694 self.trigger.action = TriggerAction::Escalate {
695 level,
696 reason: None,
697 };
698 self
699 }
700
701 pub fn priority(mut self, priority: i32) -> Self {
703 self.trigger.priority = priority;
704 self
705 }
706
707 pub fn rate_limit(mut self, max_fires: usize, window_secs: u64) -> Self {
709 self.trigger.max_fires_per_window = Some(max_fires);
710 self.trigger.rate_limit_window_secs = Some(window_secs);
711 self
712 }
713
714 pub fn tag(mut self, tag: &str) -> Self {
716 self.trigger.tags.push(tag.to_string());
717 self
718 }
719
720 pub fn enabled(mut self, enabled: bool) -> Self {
722 self.trigger.enabled = enabled;
723 self
724 }
725
726 pub fn build(self) -> SemanticTrigger {
728 self.trigger
729 }
730}
731
/// Errors returned by [`TriggerIndex`] operations.
#[derive(Debug, Clone)]
pub enum TriggerError {
    /// The trigger failed validation, e.g. an empty ID passed to `register_trigger`.
    InvalidTrigger(String),

    /// No trigger is registered under the contained ID.
    TriggerNotFound(String),

    /// An action failed. Not produced by any code visible in this file.
    ActionFailed(String),

    /// A trigger exceeded its rate limit. Not produced by any code visible in
    /// this file; `process_event` silently skips rate-limited triggers instead.
    RateLimitExceeded(String),

    /// An embedding problem. Not produced by any code visible in this file.
    EmbeddingError(String),
}
750
751impl std::fmt::Display for TriggerError {
752 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
753 match self {
754 Self::InvalidTrigger(msg) => write!(f, "Invalid trigger: {}", msg),
755 Self::TriggerNotFound(id) => write!(f, "Trigger not found: {}", id),
756 Self::ActionFailed(msg) => write!(f, "Action failed: {}", msg),
757 Self::RateLimitExceeded(id) => write!(f, "Rate limit exceeded for trigger: {}", id),
758 Self::EmbeddingError(msg) => write!(f, "Embedding error: {}", msg),
759 }
760 }
761}
762
// Makes `TriggerError` a standard error type (usable with `?` into
// `Box<dyn Error>`). Its message comes from the `Display` impl.
impl std::error::Error for TriggerError {}
764
/// Cosine similarity `a·b / (|a| |b|)` of two equal-length vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has a
/// magnitude below `1e-10`, which avoids dividing by (nearly) zero. NaN
/// components are not filtered and propagate into the result.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    // Keep the `<` form: a NaN magnitude must fall through to the division.
    if mag_a < 1e-10 || mag_b < 1e-10 {
        0.0
    } else {
        dot / (mag_a * mag_b)
    }
}
785
786pub fn create_notify_trigger(
792 id: &str,
793 query: &str,
794 channel: &str,
795 embedding: Vec<f32>,
796) -> SemanticTrigger {
797 TriggerBuilder::new(id, query)
798 .embedding(embedding)
799 .notify(channel)
800 .build()
801}
802
803pub fn create_escalation_trigger(
805 id: &str,
806 query: &str,
807 level: EscalationLevel,
808 embedding: Vec<f32>,
809) -> SemanticTrigger {
810 TriggerBuilder::new(id, query)
811 .embedding(embedding)
812 .escalate(level)
813 .priority(-1) .build()
815}
816
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 128-dimensional pseudo-embedding with components in
    /// `[-0.5, 0.49]`. An event and a trigger built from the same seed get the
    /// same vector, so their cosine similarity is (about) 1.0.
    fn mock_embedding(seed: u64) -> Vec<f32> {
        (0..128)
            .map(|i| ((i as u64 + seed) % 100) as f32 / 100.0 - 0.5)
            .collect()
    }

    // A registered trigger shows up in list_triggers under its ID.
    #[test]
    fn test_trigger_registration() {
        let index = TriggerIndex::new();

        let trigger = TriggerBuilder::new("privacy_concern", "user mentions privacy concerns")
            .embedding(mock_embedding(1))
            .threshold(0.75)
            .escalate(EscalationLevel::High)
            .build();

        index.register_trigger(trigger).unwrap();

        let triggers = index.list_triggers();
        assert_eq!(triggers.len(), 1);
        assert_eq!(triggers[0].id, "privacy_concern");
    }

    // An event with the same embedding as the trigger matches above the threshold.
    #[test]
    fn test_trigger_matching() {
        let index = TriggerIndex::new();

        let trigger = TriggerBuilder::new("security_alert", "security vulnerability")
            .embedding(mock_embedding(1))
            .threshold(0.5)
            .notify("security-team")
            .build();

        index.register_trigger(trigger).unwrap();

        let event = TriggerEvent {
            id: "event_1".to_string(),
            content: "possible security issue detected".to_string(),
            // Same seed as the trigger, so the similarity is ~1.0.
            embedding: Some(mock_embedding(1)),
            source: EventSource::SystemEvent,
            metadata: HashMap::new(),
            timestamp: 0.0,
        };

        let matches = index.process_event(&event);

        assert!(!matches.is_empty());
        assert_eq!(matches[0].trigger_id, "security_alert");
        assert!(matches[0].score > 0.5);
    }

    // A disabled trigger does not match even an identical embedding.
    #[test]
    fn test_trigger_disable() {
        let index = TriggerIndex::new();

        let trigger = TriggerBuilder::new("test_trigger", "test")
            .embedding(mock_embedding(1))
            .threshold(0.5)
            .build();

        index.register_trigger(trigger).unwrap();

        index.set_enabled("test_trigger", false);

        let event = TriggerEvent {
            id: "event_1".to_string(),
            content: "test".to_string(),
            embedding: Some(mock_embedding(1)),
            source: EventSource::UserMessage,
            metadata: HashMap::new(),
            timestamp: 0.0,
        };

        let matches = index.process_event(&event);

        assert!(matches.is_empty());
    }

    // With a budget of 2 fires per 60 s, the third event inside the window is
    // suppressed and counted in stats.rate_limited.
    #[test]
    fn test_rate_limiting() {
        let index = TriggerIndex::new();

        let trigger = TriggerBuilder::new("rate_limited", "test")
            .embedding(mock_embedding(1))
            .threshold(0.5)
            .rate_limit(2, 60)
            .build();

        index.register_trigger(trigger).unwrap();

        let event = TriggerEvent {
            id: "event_1".to_string(),
            content: "test".to_string(),
            embedding: Some(mock_embedding(1)),
            source: EventSource::UserMessage,
            metadata: HashMap::new(),
            timestamp: 0.0,
        };

        // First two calls fit the budget.
        let m1 = index.process_event(&event);
        let m2 = index.process_event(&event);

        // Third call in the same window exceeds it.
        let m3 = index.process_event(&event);

        assert!(!m1.is_empty());
        assert!(!m2.is_empty());
        assert!(m3.is_empty());

        let stats = index.stats();
        assert!(stats.rate_limited >= 1);
    }

    // execute_action marks the match as executed and stores a result string.
    #[test]
    fn test_action_execution() {
        let index = TriggerIndex::new();

        let trigger = TriggerBuilder::new("log_trigger", "test")
            .embedding(mock_embedding(1))
            .threshold(0.5)
            .action(TriggerAction::Log {
                level: LogLevel::Info,
                message: Some("Test message".to_string()),
            })
            .build();

        index.register_trigger(trigger).unwrap();

        let event = TriggerEvent {
            id: "event_1".to_string(),
            content: "test".to_string(),
            embedding: Some(mock_embedding(1)),
            source: EventSource::UserMessage,
            metadata: HashMap::new(),
            timestamp: 0.0,
        };

        let mut matches = index.process_event(&event);

        assert!(!matches.is_empty());

        index.execute_action(&mut matches[0]).unwrap();

        assert!(matches[0].action_executed);
        assert!(matches[0].execution_result.is_some());
    }

    // Identical vectors score ~1.0; orthogonal vectors score ~0.0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.01);

        let c = vec![0.0, 1.0, 0.0];
        let sim2 = cosine_similarity(&a, &c);
        assert!(sim2.abs() < 0.01);
    }

    // Builder setters are reflected in the built trigger.
    #[test]
    fn test_trigger_builder() {
        let trigger = TriggerBuilder::new("test", "test query")
            .name("Test Trigger")
            .description("A test trigger")
            .threshold(0.85)
            .priority(5)
            .tag("test")
            .tag("example")
            .notify("test-channel")
            .rate_limit(10, 300)
            .build();

        assert_eq!(trigger.id, "test");
        assert_eq!(trigger.name, "Test Trigger");
        assert_eq!(trigger.threshold, 0.85);
        assert_eq!(trigger.priority, 5);
        assert_eq!(trigger.tags.len(), 2);
        assert_eq!(trigger.max_fires_per_window, Some(10));
    }
}