1use anyhow::Result;
4use chrono::{DateTime, Utc};
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10use tracing::{debug, warn};
11use uuid::Uuid;
12
/// A single cache-invalidation notification processed by [`CacheInvalidationMiddleware`].
///
/// Events are appended to the middleware's event log and matched against
/// [`InvalidationRule`]s by `entity_type` and `operation`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationEvent {
    /// Unique identifier of this event.
    pub event_id: Uuid,
    /// Kind of change this event reports.
    pub event_type: InvalidationEventType,
    /// Entity kind the event refers to (e.g. `"task"`, `"project"`, `"area"`);
    /// compared exactly against [`InvalidationRule::entity_type`].
    pub entity_type: String,
    /// Identifier of the affected entity, if the event targets one specific entity.
    /// Only events with an id are expanded into cascade events.
    pub entity_id: Option<Uuid>,
    /// Operation name, matched against [`InvalidationRule::operations`].
    pub operation: String,
    /// Creation time; the event log prunes events older than the configured retention.
    pub timestamp: DateTime<Utc>,
    /// Cache types this event targets. The middleware does not filter on this itself;
    /// interpretation is left to each handler's `can_handle`.
    pub affected_caches: Vec<String>,
    /// Free-form extra data attached to the event.
    pub metadata: HashMap<String, serde_json::Value>,
}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28pub enum InvalidationEventType {
29 Created,
31 Updated,
33 Deleted,
35 Completed,
37 BulkOperation,
39 ManualInvalidation,
41 Expired,
43 CascadeInvalidation,
45}
46
47impl std::fmt::Display for InvalidationEventType {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 match self {
50 InvalidationEventType::Created => write!(f, "Created"),
51 InvalidationEventType::Updated => write!(f, "Updated"),
52 InvalidationEventType::Deleted => write!(f, "Deleted"),
53 InvalidationEventType::Completed => write!(f, "Completed"),
54 InvalidationEventType::BulkOperation => write!(f, "BulkOperation"),
55 InvalidationEventType::ManualInvalidation => write!(f, "ManualInvalidation"),
56 InvalidationEventType::Expired => write!(f, "Expired"),
57 InvalidationEventType::CascadeInvalidation => write!(f, "CascadeInvalidation"),
58 }
59 }
60}
61
/// Declares which events a rule applies to and how matching events are invalidated.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationRule {
    /// Identifier of the rule. Rules are registered by `name`, not by this id.
    pub rule_id: Uuid,
    /// Registry key: adding a rule whose name is already registered replaces the old rule.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// The rule only applies to events whose `entity_type` is exactly this value.
    pub entity_type: String,
    /// Event operations the rule applies to; an empty list matches every operation.
    pub operations: Vec<String>,
    /// NOTE(review): not read by `CacheInvalidationMiddleware` in this file.
    /// `InvalidationStrategy::InvalidateSpecific` carries its own cache-type list.
    pub affected_cache_types: Vec<String>,
    /// How handlers are selected once the rule matches.
    pub invalidation_strategy: InvalidationStrategy,
    /// Disabled rules are skipped during matching.
    pub enabled: bool,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Last-modification timestamp.
    pub updated_at: DateTime<Utc>,
}
76
/// How a matching [`InvalidationRule`] chooses the handlers that receive an event.
///
/// In every variant that invalidates handlers, a handler is only called when its
/// `can_handle` accepts the event, and the first handler error aborts the rule.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum InvalidationStrategy {
    /// Every registered handler.
    InvalidateAll,
    /// Only the handlers registered under the listed cache types, in list order.
    InvalidateSpecific(Vec<String>),
    /// Every registered handler, but only when the event carries an `entity_id`.
    InvalidateByEntity,
    /// Every registered handler, but only when the pattern is a substring of the
    /// event's `entity_type` or `operation`.
    InvalidateByPattern(String),
    /// Requests cascade processing of dependent entities instead of invalidating
    /// handlers directly.
    CascadeInvalidation,
}
91
/// Routes [`InvalidationEvent`]s to registered [`CacheInvalidationHandler`]s according to
/// the configured [`InvalidationRule`]s, keeping a bounded event log and statistics.
///
/// All mutable state sits behind `Arc<RwLock<_>>`, which is why every method takes `&self`.
pub struct CacheInvalidationMiddleware {
    /// Registered rules, keyed by rule name.
    rules: Arc<RwLock<HashMap<String, InvalidationRule>>>,
    /// Event log, oldest first; trimmed by `max_events` and `event_retention`.
    events: Arc<RwLock<Vec<InvalidationEvent>>>,
    /// Registered handlers, keyed by each handler's `cache_type()`.
    handlers: Arc<RwLock<HashMap<String, Box<dyn CacheInvalidationHandler + Send + Sync>>>>,
    /// Behavioural configuration.
    config: InvalidationConfig,
    /// Running counters, returned by value from `get_stats`.
    stats: Arc<RwLock<InvalidationStats>>,
}
105
/// Receiver of invalidation events, registered via
/// `CacheInvalidationMiddleware::register_handler`.
pub trait CacheInvalidationHandler {
    /// Invalidates whatever this handler caches for `event`.
    ///
    /// # Errors
    ///
    /// An error stops the current rule from reaching any further handlers. The middleware
    /// then logs it and counts the rule as a failed invalidation; it is not returned to the
    /// caller of `process_event`.
    fn invalidate(&self, event: &InvalidationEvent) -> Result<()>;

    /// Key under which the handler is registered. Registering another handler with the
    /// same key replaces this one. It is also the name `InvalidateSpecific` lists refer to.
    fn cache_type(&self) -> &str;

    /// Returns whether `invalidate` should be called for `event`; the middleware skips
    /// handlers for which this is `false`.
    fn can_handle(&self, event: &InvalidationEvent) -> bool;
}
121
/// Tuning knobs for [`CacheInvalidationMiddleware`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvalidationConfig {
    /// Maximum number of events kept in the event log; the oldest are dropped first.
    pub max_events: usize,
    /// Each time an event is stored, events older than this are pruned from the log.
    pub event_retention: Duration,
    /// When true, `process_event` cascades every processed event to its dependent entities.
    pub enable_cascade: bool,
    /// Maximum cascade depth below an originating event.
    /// NOTE(review): check `CacheInvalidationMiddleware` for whether this is enforced.
    pub cascade_depth: u32,
    /// NOTE(review): not read by `CacheInvalidationMiddleware` in this file.
    pub enable_batching: bool,
    /// NOTE(review): not read by `CacheInvalidationMiddleware` in this file.
    pub batch_size: usize,
    /// NOTE(review): not read by `CacheInvalidationMiddleware` in this file.
    pub batch_timeout: Duration,
}
140
141impl Default for InvalidationConfig {
142 fn default() -> Self {
143 Self {
144 max_events: 10000,
145 event_retention: Duration::from_secs(86400), enable_cascade: true,
147 cascade_depth: 3,
148 enable_batching: true,
149 batch_size: 100,
150 batch_timeout: Duration::from_secs(5),
151 }
152 }
153}
154
/// Snapshot of the counters kept by [`CacheInvalidationMiddleware`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct InvalidationStats {
    /// Events processed, including events generated by cascading.
    pub total_events: u64,
    /// Rule applications that completed without error (counted per rule, not per handler).
    pub successful_invalidations: u64,
    /// Rule applications that returned an error.
    pub failed_invalidations: u64,
    /// Dependent events processed through cascading.
    pub cascade_invalidations: u64,
    /// Calls to `manual_invalidate` that completed successfully.
    pub manual_invalidations: u64,
    /// NOTE(review): never incremented by the middleware code in this file.
    pub expired_invalidations: u64,
    /// Running mean of per-event processing time, in milliseconds.
    pub average_processing_time_ms: f64,
    /// Time of the most recent successful rule application.
    pub last_invalidation: Option<DateTime<Utc>>,
}
167
168impl CacheInvalidationMiddleware {
169 #[must_use]
171 pub fn new(config: InvalidationConfig) -> Self {
172 Self {
173 rules: Arc::new(RwLock::new(HashMap::new())),
174 events: Arc::new(RwLock::new(Vec::new())),
175 handlers: Arc::new(RwLock::new(HashMap::new())),
176 config,
177 stats: Arc::new(RwLock::new(InvalidationStats::default())),
178 }
179 }
180
181 #[must_use]
183 pub fn new_default() -> Self {
184 Self::new(InvalidationConfig::default())
185 }
186
187 pub fn register_handler(&self, handler: Box<dyn CacheInvalidationHandler + Send + Sync>) {
189 let mut handlers = self.handlers.write();
190 handlers.insert(handler.cache_type().to_string(), handler);
191 }
192
193 pub fn add_rule(&self, rule: InvalidationRule) {
195 let mut rules = self.rules.write();
196 rules.insert(rule.name.clone(), rule);
197 }
198
199 pub async fn process_event(&self, event: InvalidationEvent) -> Result<()> {
205 let start_time = std::time::Instant::now();
206
207 self.store_event(&event);
209
210 let applicable_rules = self.find_applicable_rules(&event);
212
213 for rule in applicable_rules {
215 if let Err(e) = self.process_rule(&event, &rule).await {
216 warn!("Failed to process invalidation rule {}: {}", rule.name, e);
217 self.record_failed_invalidation();
218 } else {
219 self.record_successful_invalidation();
220 }
221 }
222
223 if self.config.enable_cascade {
225 self.handle_cascade_invalidation(&event).await?;
226 }
227
228 #[allow(clippy::cast_precision_loss)]
230 let processing_time = start_time.elapsed().as_millis().min(u128::from(u64::MAX)) as f64;
231 {
232 let mut stats = self.stats.write();
233 stats.total_events += 1;
234 }
235 self.update_processing_time(processing_time);
236
237 debug!(
238 "Processed invalidation event: {} for entity: {}:{}",
239 event.event_type,
240 event.entity_type,
241 event
242 .entity_id
243 .map_or_else(|| "none".to_string(), |id| id.to_string())
244 );
245
246 Ok(())
247 }
248
249 pub async fn manual_invalidate(
255 &self,
256 entity_type: &str,
257 entity_id: Option<Uuid>,
258 cache_types: Option<Vec<String>>,
259 ) -> Result<()> {
260 let event = InvalidationEvent {
261 event_id: Uuid::new_v4(),
262 event_type: InvalidationEventType::ManualInvalidation,
263 entity_type: entity_type.to_string(),
264 entity_id,
265 operation: "manual_invalidation".to_string(),
266 timestamp: Utc::now(),
267 affected_caches: cache_types.unwrap_or_default(),
268 metadata: HashMap::new(),
269 };
270
271 self.process_event(event).await?;
272 self.record_manual_invalidation();
273 Ok(())
274 }
275
276 #[must_use]
278 pub fn get_stats(&self) -> InvalidationStats {
279 self.stats.read().clone()
280 }
281
282 #[must_use]
284 pub fn get_recent_events(&self, limit: usize) -> Vec<InvalidationEvent> {
285 let events = self.events.read();
286 events.iter().rev().take(limit).cloned().collect()
287 }
288
289 #[must_use]
291 pub fn get_events_by_entity_type(&self, entity_type: &str) -> Vec<InvalidationEvent> {
292 let events = self.events.read();
293 events
294 .iter()
295 .filter(|event| event.entity_type == entity_type)
296 .cloned()
297 .collect()
298 }
299
300 fn store_event(&self, event: &InvalidationEvent) {
302 let mut events = self.events.write();
303 events.push(event.clone());
304
305 if events.len() > self.config.max_events {
307 let excess = events.len() - self.config.max_events;
308 events.drain(0..excess);
309 }
310
311 let cutoff_time = Utc::now()
313 - chrono::Duration::from_std(self.config.event_retention).unwrap_or_default();
314 events.retain(|event| event.timestamp > cutoff_time);
315 }
316
317 fn find_applicable_rules(&self, event: &InvalidationEvent) -> Vec<InvalidationRule> {
319 let rules = self.rules.read();
320 rules
321 .values()
322 .filter(|rule| {
323 rule.enabled
324 && rule.entity_type == event.entity_type
325 && (rule.operations.is_empty() || rule.operations.contains(&event.operation))
326 })
327 .cloned()
328 .collect()
329 }
330
331 async fn process_rule(&self, event: &InvalidationEvent, rule: &InvalidationRule) -> Result<()> {
333 match &rule.invalidation_strategy {
334 InvalidationStrategy::InvalidateAll => {
335 let handlers_guard = self.handlers.read();
337 for handler in handlers_guard.values() {
338 if handler.can_handle(event) {
339 handler.invalidate(event)?;
340 }
341 }
342 }
343 InvalidationStrategy::InvalidateSpecific(cache_types) => {
344 let handlers_guard = self.handlers.read();
346 for cache_type in cache_types {
347 if let Some(handler) = handlers_guard.get(cache_type) {
348 if handler.can_handle(event) {
349 handler.invalidate(event)?;
350 }
351 }
352 }
353 }
354 InvalidationStrategy::InvalidateByEntity => {
355 if let Some(_entity_id) = event.entity_id {
357 let handlers_guard = self.handlers.read();
358 for handler in handlers_guard.values() {
359 if handler.can_handle(event) {
360 handler.invalidate(event)?;
361 }
362 }
363 }
364 }
365 InvalidationStrategy::InvalidateByPattern(pattern) => {
366 let handlers_guard = self.handlers.read();
368 for handler in handlers_guard.values() {
369 if handler.can_handle(event) && Self::matches_pattern(event, pattern) {
370 handler.invalidate(event)?;
371 }
372 }
373 }
374 InvalidationStrategy::CascadeInvalidation => {
375 self.handle_cascade_invalidation(event).await?;
377 }
378 }
379
380 Ok(())
381 }
382
383 async fn handle_cascade_invalidation(&self, event: &InvalidationEvent) -> Result<()> {
385 let dependent_entities = Self::find_dependent_entities(event);
387
388 for dependent_entity in dependent_entities {
389 let dependent_event = InvalidationEvent {
390 event_id: Uuid::new_v4(),
391 event_type: InvalidationEventType::CascadeInvalidation,
392 entity_type: dependent_entity.entity_type.clone(),
393 entity_id: dependent_entity.entity_id,
394 operation: "cascade_invalidation".to_string(),
395 timestamp: Utc::now(),
396 affected_caches: dependent_entity.affected_caches.clone(),
397 metadata: HashMap::new(),
398 };
399
400 Box::pin(self.process_event(dependent_event)).await?;
401 self.record_cascade_invalidation();
402 }
403
404 Ok(())
405 }
406
407 fn find_dependent_entities(event: &InvalidationEvent) -> Vec<DependentEntity> {
409 let mut dependent_entities = Vec::new();
412
413 match event.entity_type.as_str() {
414 "task" => {
415 if let Some(_task_id) = event.entity_id {
417 dependent_entities.push(DependentEntity {
418 entity_type: "project".to_string(),
419 entity_id: None, affected_caches: vec!["l1".to_string(), "l2".to_string()],
421 });
422 dependent_entities.push(DependentEntity {
423 entity_type: "area".to_string(),
424 entity_id: None, affected_caches: vec!["l1".to_string(), "l2".to_string()],
426 });
427 }
428 }
429 "project" => {
430 if let Some(_project_id) = event.entity_id {
432 dependent_entities.push(DependentEntity {
433 entity_type: "task".to_string(),
434 entity_id: None, affected_caches: vec!["l1".to_string(), "l2".to_string()],
436 });
437 }
438 }
439 "area" => {
440 if let Some(_area_id) = event.entity_id {
442 dependent_entities.push(DependentEntity {
443 entity_type: "project".to_string(),
444 entity_id: None,
445 affected_caches: vec!["l1".to_string(), "l2".to_string()],
446 });
447 dependent_entities.push(DependentEntity {
448 entity_type: "task".to_string(),
449 entity_id: None,
450 affected_caches: vec!["l1".to_string(), "l2".to_string()],
451 });
452 }
453 }
454 _ => {
455 }
457 }
458
459 dependent_entities
460 }
461
462 fn matches_pattern(event: &InvalidationEvent, pattern: &str) -> bool {
464 event.entity_type.contains(pattern) || event.operation.contains(pattern)
466 }
467
468 fn record_successful_invalidation(&self) {
470 let mut stats = self.stats.write();
471 stats.successful_invalidations += 1;
472 stats.last_invalidation = Some(Utc::now());
473 }
474
475 fn record_failed_invalidation(&self) {
477 let mut stats = self.stats.write();
478 stats.failed_invalidations += 1;
479 }
480
481 fn record_cascade_invalidation(&self) {
483 let mut stats = self.stats.write();
484 stats.cascade_invalidations += 1;
485 }
486
487 fn record_manual_invalidation(&self) {
489 let mut stats = self.stats.write();
490 stats.manual_invalidations += 1;
491 }
492
493 fn update_processing_time(&self, processing_time: f64) {
495 let mut stats = self.stats.write();
496
497 #[allow(clippy::cast_precision_loss)]
499 let total_events = stats.total_events as f64;
500 stats.average_processing_time_ms =
501 (stats.average_processing_time_ms * (total_events - 1.0) + processing_time)
502 / total_events;
503 }
504}
505
/// An entity to invalidate as a consequence of another event.
///
/// Produced by `CacheInvalidationMiddleware::find_dependent_entities`; each one is turned
/// into a `CascadeInvalidation` event.
#[derive(Debug, Clone)]
struct DependentEntity {
    /// Entity kind of the dependent (becomes the cascade event's `entity_type`).
    entity_type: String,
    /// Specific dependent entity, if known.
    entity_id: Option<Uuid>,
    /// Becomes the cascade event's `affected_caches`.
    affected_caches: Vec<String>,
}
513
/// Cascade invalidation scope.
///
/// NOTE(review): not referenced by the code visible in this file. The variant meanings
/// below are presumed from their names; confirm against the users of this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CascadeInvalidationEvent {
    /// Presumably: invalidate every cache.
    InvalidateAll,
    /// Presumably: invalidate only the listed caches.
    InvalidateSpecific(Vec<String>),
    /// Presumably: invalidate caches at the given level.
    InvalidateByLevel(u32),
}
524
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Test double that records every event it is asked to invalidate.
    struct MockCacheHandler {
        cache_type: String,
        invalidated_events: Arc<RwLock<Vec<InvalidationEvent>>>,
    }

    impl MockCacheHandler {
        fn new(cache_type: &str) -> Self {
            Self {
                cache_type: cache_type.to_string(),
                invalidated_events: Arc::new(RwLock::new(Vec::new())),
            }
        }

        /// Shared handle to the recorded events. It stays readable after the handler
        /// has been boxed and moved into the middleware.
        fn events_handle(&self) -> Arc<RwLock<Vec<InvalidationEvent>>> {
            Arc::clone(&self.invalidated_events)
        }
    }

    impl CacheInvalidationHandler for MockCacheHandler {
        fn invalidate(&self, event: &InvalidationEvent) -> Result<()> {
            self.invalidated_events.write().push(event.clone());
            Ok(())
        }

        fn cache_type(&self) -> &str {
            &self.cache_type
        }

        /// Accepts events that name no caches, or that name this handler's cache.
        fn can_handle(&self, event: &InvalidationEvent) -> bool {
            event.affected_caches.is_empty() || event.affected_caches.contains(&self.cache_type)
        }
    }

    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// Builds an enabled `InvalidateAll` rule.
    fn make_rule(
        name: &str,
        entity_type: &str,
        operations: &[&str],
        cache_types: &[&str],
    ) -> InvalidationRule {
        InvalidationRule {
            rule_id: Uuid::new_v4(),
            name: name.to_string(),
            description: format!("Rule for {entity_type} invalidation"),
            entity_type: entity_type.to_string(),
            operations: strings(operations),
            affected_cache_types: strings(cache_types),
            invalidation_strategy: InvalidationStrategy::InvalidateAll,
            enabled: true,
            created_at: Utc::now(),
            updated_at: Utc::now(),
        }
    }

    /// Builds an event for a concrete entity (an `entity_id` is set, so it can cascade).
    fn make_event(
        event_type: InvalidationEventType,
        entity_type: &str,
        operation: &str,
        affected_caches: &[&str],
    ) -> InvalidationEvent {
        InvalidationEvent {
            event_id: Uuid::new_v4(),
            event_type,
            entity_type: entity_type.to_string(),
            entity_id: Some(Uuid::new_v4()),
            operation: operation.to_string(),
            timestamp: Utc::now(),
            affected_caches: strings(affected_caches),
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_invalidation_middleware_basic() {
        let middleware = CacheInvalidationMiddleware::new_default();

        let l1_handler = MockCacheHandler::new("l1");
        let l2_handler = MockCacheHandler::new("l2");
        let l1_events = l1_handler.events_handle();
        let l2_events = l2_handler.events_handle();
        middleware.register_handler(Box::new(l1_handler));
        middleware.register_handler(Box::new(l2_handler));

        middleware.add_rule(make_rule("task_rule", "task", &["updated"], &["l1", "l2"]));
        middleware.add_rule(make_rule(
            "project_rule",
            "project",
            &["cascade_invalidation"],
            &["l1", "l2"],
        ));
        middleware.add_rule(make_rule(
            "area_rule",
            "area",
            &["cascade_invalidation"],
            &["l1", "l2"],
        ));

        let event = make_event(InvalidationEventType::Updated, "task", "updated", &["l1", "l2"]);
        middleware.process_event(event).await.unwrap();

        // The task event cascades to one project and one area event, each matched by a rule.
        let stats = middleware.get_stats();
        assert_eq!(stats.total_events, 3);
        assert_eq!(stats.successful_invalidations, 3);
        assert_eq!(stats.cascade_invalidations, 2);
        // InvalidateAll reaches both handlers for each of the three events.
        assert_eq!(l1_events.read().len(), 3);
        assert_eq!(l2_events.read().len(), 3);
    }

    #[tokio::test]
    async fn test_manual_invalidation() {
        let middleware = CacheInvalidationMiddleware::new_default();

        middleware.register_handler(Box::new(MockCacheHandler::new("l1")));

        middleware
            .manual_invalidate("task", Some(Uuid::new_v4()), None)
            .await
            .unwrap();

        let stats = middleware.get_stats();
        assert_eq!(stats.manual_invalidations, 1);
        assert_eq!(stats.total_events, 3);

        let task_events = middleware.get_events_by_entity_type("task");
        assert_eq!(task_events.len(), 1);
        assert_eq!(task_events[0].event_type, InvalidationEventType::ManualInvalidation);
    }

    #[tokio::test]
    async fn test_event_storage() {
        let middleware = CacheInvalidationMiddleware::new_default();

        let event = make_event(InvalidationEventType::Created, "task", "created", &[]);
        middleware.store_event(&event);

        let recent_events = middleware.get_recent_events(1);
        assert_eq!(recent_events.len(), 1);
        assert_eq!(recent_events[0].entity_type, "task");
    }

    #[tokio::test]
    async fn test_invalidation_middleware_creation() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let stats = middleware.get_stats();

        assert_eq!(stats.total_events, 0);
        assert_eq!(stats.successful_invalidations, 0);
        assert_eq!(stats.failed_invalidations, 0);
        assert_eq!(stats.manual_invalidations, 0);
    }

    #[tokio::test]
    async fn test_invalidation_middleware_with_config() {
        let config = InvalidationConfig {
            enable_cascade: true,
            max_events: 1000,
            event_retention: Duration::from_secs(3600),
            batch_size: 10,
            batch_timeout: Duration::from_secs(30),
            cascade_depth: 3,
            enable_batching: true,
        };

        let middleware = CacheInvalidationMiddleware::new(config);
        let stats = middleware.get_stats();

        assert_eq!(stats.total_events, 0);
    }

    #[tokio::test]
    async fn test_add_rule() {
        let middleware = CacheInvalidationMiddleware::new_default();
        middleware.add_rule(make_rule("test_rule", "task", &["updated"], &["task_cache"]));

        // An operation the rule does not list leaves it unapplied.
        let created = make_event(InvalidationEventType::Created, "task", "created", &[]);
        middleware.process_event(created).await.unwrap();
        assert_eq!(middleware.get_stats().successful_invalidations, 0);

        // A listed operation applies it exactly once.
        let updated = make_event(InvalidationEventType::Updated, "task", "updated", &[]);
        middleware.process_event(updated).await.unwrap();
        assert_eq!(middleware.get_stats().successful_invalidations, 1);
    }

    #[tokio::test]
    async fn test_register_handler() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let handler = MockCacheHandler::new("test_cache");
        let received = handler.events_handle();
        middleware.register_handler(Box::new(handler));
        middleware.add_rule(make_rule("test_rule", "task", &["updated"], &["test_cache"]));

        let event = make_event(InvalidationEventType::Updated, "task", "updated", &["test_cache"]);
        let event_id = event.event_id;
        middleware.process_event(event).await.unwrap();

        let recorded = received.read();
        assert_eq!(recorded.len(), 1);
        assert_eq!(recorded[0].event_id, event_id);
    }

    #[tokio::test]
    async fn test_process_event_with_handler() {
        let middleware = CacheInvalidationMiddleware::new_default();
        let handler = MockCacheHandler::new("test_cache");
        let l1_handler = MockCacheHandler::new("l1");
        let l2_handler = MockCacheHandler::new("l2");
        let handler_events = handler.events_handle();
        let l1_events = l1_handler.events_handle();
        let l2_events = l2_handler.events_handle();

        middleware.register_handler(Box::new(handler));
        middleware.register_handler(Box::new(l1_handler));
        middleware.register_handler(Box::new(l2_handler));

        middleware.add_rule(make_rule(
            "test_rule",
            "task",
            &["created", "updated"],
            &["test_cache"],
        ));
        middleware.add_rule(make_rule(
            "project_rule",
            "project",
            &["cascade_invalidation"],
            &["l1", "l2"],
        ));
        middleware.add_rule(make_rule(
            "area_rule",
            "area",
            &["cascade_invalidation"],
            &["l1", "l2"],
        ));

        let event = make_event(InvalidationEventType::Created, "task", "created", &["test_cache"]);
        middleware.process_event(event).await.unwrap();

        let stats = middleware.get_stats();
        assert_eq!(stats.total_events, 3);
        assert_eq!(stats.successful_invalidations, 3);
        // The task event only targets test_cache; the cascaded events only target l1/l2.
        assert_eq!(handler_events.read().len(), 1);
        assert_eq!(l1_events.read().len(), 2);
        assert_eq!(l2_events.read().len(), 2);
    }

    #[tokio::test]
    async fn test_get_recent_events() {
        let middleware = CacheInvalidationMiddleware::new_default();

        for i in 0..5 {
            let event = make_event(
                InvalidationEventType::Created,
                &format!("task_{i}"),
                "created",
                &[],
            );
            middleware.store_event(&event);
        }

        // Newest first.
        let recent_events = middleware.get_recent_events(3);
        assert_eq!(recent_events.len(), 3);
        assert_eq!(recent_events[0].entity_type, "task_4");
        assert_eq!(recent_events[2].entity_type, "task_2");

        let all_events = middleware.get_recent_events(10);
        assert_eq!(all_events.len(), 5);
    }
}