1use crate::application::{
7 circuit_breaker::CircuitBreaker,
8 emitter::EmitterConfig,
9 limiter::{LimitDecision, RateLimiter},
10 metrics::Metrics,
11 ports::{Clock, Storage},
12 registry::{EventState, SuppressionRegistry},
13};
14use crate::domain::{policy::Policy, signature::EventSignature};
15use crate::infrastructure::clock::SystemClock;
16use crate::infrastructure::storage::ShardedStorage;
17
18use std::collections::BTreeMap;
19use std::sync::Arc;
20use std::time::Duration;
21use tracing::{Metadata, Subscriber};
22use tracing_subscriber::layer::Filter;
23use tracing_subscriber::{layer::Context, Layer};
24
25#[cfg(feature = "async")]
26use crate::application::emitter::{EmitterHandle, SummaryEmitter};
27
28#[cfg(feature = "async")]
29use crate::domain::summary::SuppressionSummary;
30
31#[cfg(feature = "async")]
32use std::sync::Mutex;
33
/// Callback applied to each [`SuppressionSummary`] produced by the active
/// summary emitter.
///
/// It is shared through an `Arc` and must be `Send + Sync`, so a single
/// formatter can be moved into the emitter's callback.
#[cfg(feature = "async")]
pub type SummaryFormatter = Arc<dyn Fn(&SuppressionSummary) + Sync + Send + 'static>;
40
41#[derive(Debug, Clone, PartialEq, Eq)]
43pub enum BuildError {
44 ZeroMaxSignatures,
46 EmitterConfig(crate::application::emitter::EmitterConfigError),
48}
49
50impl std::fmt::Display for BuildError {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 match self {
53 BuildError::ZeroMaxSignatures => {
54 write!(f, "max_signatures must be greater than 0")
55 }
56 BuildError::EmitterConfig(e) => {
57 write!(f, "emitter configuration error: {}", e)
58 }
59 }
60 }
61}
62
63impl std::error::Error for BuildError {}
64
65impl From<crate::application::emitter::EmitterConfigError> for BuildError {
66 fn from(e: crate::application::emitter::EmitterConfigError) -> Self {
67 BuildError::EmitterConfig(e)
68 }
69}
70
71pub struct TracingRateLimitLayerBuilder {
73 policy: Policy,
74 summary_interval: Duration,
75 clock: Option<Arc<dyn Clock>>,
76 max_signatures: Option<usize>,
77 enable_active_emission: bool,
78 #[cfg(feature = "async")]
79 summary_formatter: Option<SummaryFormatter>,
80}
81
82impl TracingRateLimitLayerBuilder {
83 pub fn with_policy(mut self, policy: Policy) -> Self {
85 self.policy = policy;
86 self
87 }
88
89 pub fn with_summary_interval(mut self, interval: Duration) -> Self {
93 self.summary_interval = interval;
94 self
95 }
96
97 pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
99 self.clock = Some(clock);
100 self
101 }
102
103 pub fn with_max_signatures(mut self, max_signatures: usize) -> Self {
112 self.max_signatures = Some(max_signatures);
113 self
114 }
115
116 pub fn with_unlimited_signatures(mut self) -> Self {
122 self.max_signatures = None;
123 self
124 }
125
126 pub fn with_active_emission(mut self, enabled: bool) -> Self {
147 self.enable_active_emission = enabled;
148 self
149 }
150
151 #[cfg(feature = "async")]
181 pub fn with_summary_formatter(mut self, formatter: SummaryFormatter) -> Self {
182 self.summary_formatter = Some(formatter);
183 self
184 }
185
186 pub fn build(self) -> Result<TracingRateLimitLayer, BuildError> {
191 if let Some(max) = self.max_signatures {
193 if max == 0 {
194 return Err(BuildError::ZeroMaxSignatures);
195 }
196 }
197
198 let metrics = Metrics::new();
200 let circuit_breaker = Arc::new(CircuitBreaker::new());
201
202 let clock = self.clock.unwrap_or_else(|| Arc::new(SystemClock::new()));
203 let storage = if let Some(max) = self.max_signatures {
204 Arc::new(ShardedStorage::with_max_entries(max).with_metrics(metrics.clone()))
205 } else {
206 Arc::new(ShardedStorage::new().with_metrics(metrics.clone()))
207 };
208 let registry = SuppressionRegistry::new(storage, clock, self.policy);
209 let limiter = RateLimiter::new(registry.clone(), metrics.clone(), circuit_breaker);
210
211 let emitter_config = EmitterConfig::new(self.summary_interval)?;
213
214 #[cfg(feature = "async")]
215 let emitter_handle = if self.enable_active_emission {
216 let emitter = SummaryEmitter::new(registry, emitter_config);
217
218 let formatter = self.summary_formatter.unwrap_or_else(|| {
220 Arc::new(|summary: &SuppressionSummary| {
221 tracing::warn!(
222 signature = %summary.signature,
223 count = summary.count,
224 "{}",
225 summary.format_message()
226 );
227 })
228 });
229
230 let handle = emitter.start(
231 move |summaries| {
232 for summary in summaries {
233 formatter(&summary);
234 }
235 },
236 false, );
238 Arc::new(Mutex::new(Some(handle)))
239 } else {
240 Arc::new(Mutex::new(None))
241 };
242
243 Ok(TracingRateLimitLayer {
244 limiter,
245 #[cfg(feature = "async")]
246 emitter_handle,
247 #[cfg(not(feature = "async"))]
248 _emitter_config: emitter_config,
249 })
250 }
251}
252
253#[derive(Clone)]
261pub struct TracingRateLimitLayer<S = Arc<ShardedStorage<EventSignature, EventState>>>
262where
263 S: Storage<EventSignature, EventState> + Clone,
264{
265 limiter: RateLimiter<S>,
266 #[cfg(feature = "async")]
267 emitter_handle: Arc<Mutex<Option<EmitterHandle>>>,
268 #[cfg(not(feature = "async"))]
269 _emitter_config: EmitterConfig,
270}
271
272impl<S> TracingRateLimitLayer<S>
273where
274 S: Storage<EventSignature, EventState> + Clone,
275{
276 fn compute_signature(
278 &self,
279 metadata: &Metadata,
280 _fields: &BTreeMap<String, String>,
281 ) -> EventSignature {
282 let level = metadata.level().as_str();
285 let message = metadata.name();
286 let target = Some(metadata.target());
287
288 let fields = BTreeMap::new();
291
292 EventSignature::new(level, message, &fields, target)
293 }
294
295 pub fn should_allow(&self, signature: EventSignature) -> bool {
297 matches!(self.limiter.check_event(signature), LimitDecision::Allow)
298 }
299
300 pub fn limiter(&self) -> &RateLimiter<S> {
302 &self.limiter
303 }
304
305 pub fn metrics(&self) -> &Metrics {
312 self.limiter.metrics()
313 }
314
315 pub fn signature_count(&self) -> usize {
317 self.limiter.registry().len()
318 }
319
320 pub fn circuit_breaker(&self) -> &Arc<CircuitBreaker> {
326 self.limiter.circuit_breaker()
327 }
328
329 #[cfg(feature = "async")]
357 pub async fn shutdown(&self) -> Result<(), crate::application::emitter::ShutdownError> {
358 let handle = {
360 let mut handle_guard = self.emitter_handle.lock().unwrap();
361 handle_guard.take()
362 };
363
364 if let Some(handle) = handle {
365 handle.shutdown().await?;
366 }
367 Ok(())
368 }
369}
370
371impl TracingRateLimitLayer<Arc<ShardedStorage<EventSignature, EventState>>> {
372 pub fn builder() -> TracingRateLimitLayerBuilder {
381 TracingRateLimitLayerBuilder {
382 policy: Policy::token_bucket(50.0, 1.0)
383 .expect("default policy with 50 capacity and 1/sec refill is always valid"),
384 summary_interval: Duration::from_secs(30),
385 clock: None,
386 max_signatures: Some(10_000),
387 enable_active_emission: false,
388 #[cfg(feature = "async")]
389 summary_formatter: None,
390 }
391 }
392
393 pub fn new() -> Self {
405 Self::builder()
406 .build()
407 .expect("default configuration is always valid")
408 }
409}
410
411impl Default for TracingRateLimitLayer<Arc<ShardedStorage<EventSignature, EventState>>> {
412 fn default() -> Self {
413 Self::new()
414 }
415}
416
417impl<S, Sub> Filter<Sub> for TracingRateLimitLayer<S>
419where
420 S: Storage<EventSignature, EventState> + Clone,
421 Sub: Subscriber,
422{
423 fn enabled(&self, meta: &Metadata<'_>, _cx: &Context<'_, Sub>) -> bool {
424 let fields = BTreeMap::new();
426 let signature = self.compute_signature(meta, &fields);
427 self.should_allow(signature)
428 }
429}
430
/// Makes the type usable wherever a `tracing_subscriber` [`Layer`] is expected.
///
/// No `Layer` methods are overridden, so every callback uses the trait's
/// default implementation.
/// NOTE(review): the rate limiting lives in the `Filter` impl (e.g. used via
/// `.with_filter(layer)`, as the tests do). Installing this value directly
/// with `.with(layer)` presumably filters nothing — confirm that this is the
/// intended usage.
impl<S, Sub> Layer<Sub> for TracingRateLimitLayer<S>
where
    S: Storage<EventSignature, EventState> + Clone + 'static,
    Sub: Subscriber,
{
}
438
// Tests for the builder, the limiter accessors, the per-layer filter
// integration and (behind the `async` feature) the summary emitter.
// The async tests sleep 250 ms against a 100 ms summary interval, so they are
// timing-sensitive.
#[cfg(test)]
mod tests {
    use super::*;
    use tracing::info;
    use tracing_subscriber::layer::SubscriberExt;

    // The builder accepts a custom policy and interval and starts empty.
    #[test]
    fn test_layer_builder() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(50).unwrap())
            .with_summary_interval(Duration::from_secs(60))
            .build()
            .unwrap();

        assert!(layer.limiter().registry().is_empty());
    }

    // `Default` yields a layer with an empty registry.
    #[test]
    fn test_layer_default() {
        let layer = TracingRateLimitLayer::default();
        assert!(layer.limiter().registry().is_empty());
    }

    // NOTE(review): despite its name, this only checks that
    // `EventSignature::simple` is deterministic; `compute_signature` is never
    // called and `_layer` is unused.
    #[test]
    fn test_signature_computation() {
        let _layer = TracingRateLimitLayer::new();

        let sig1 = EventSignature::simple("INFO", "test_event");
        let sig2 = EventSignature::simple("INFO", "test_event");

        assert_eq!(sig1, sig2);
    }

    // With a count-based budget of 2, the third check of a signature is denied.
    #[test]
    fn test_basic_rate_limiting() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(2).unwrap())
            .build()
            .unwrap();

        let sig = EventSignature::simple("INFO", "test_message");

        assert!(layer.should_allow(sig));
        assert!(layer.should_allow(sig));

        assert!(!layer.should_allow(sig));
    }

    // End-to-end through `with_filter`: ten `info!` calls from one callsite
    // register a single signature. `layer_for_check` is a clone, and clones
    // share limiter state with the installed filter.
    #[test]
    fn test_layer_integration() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(3).unwrap())
            .build()
            .unwrap();

        let layer_for_check = layer.clone();

        let subscriber = tracing_subscriber::registry()
            .with(tracing_subscriber::fmt::layer().with_filter(layer));

        tracing::subscriber::with_default(subscriber, || {
            for _ in 0..10 {
                info!("test event");
            }
        });

        assert_eq!(layer_for_check.limiter().registry().len(), 1);
    }

    // A budget of 3 admits exactly 3 of 10 checks of the same signature.
    #[test]
    fn test_layer_suppression_logic() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(3).unwrap())
            .build()
            .unwrap();

        let sig = EventSignature::simple("INFO", "test");

        let mut allowed_count = 0;
        for _ in 0..10 {
            if layer.should_allow(sig) {
                allowed_count += 1;
            }
        }

        assert_eq!(allowed_count, 3);
    }

    // A zero summary interval surfaces the emitter's configuration error.
    #[test]
    fn test_builder_zero_summary_interval() {
        let result = TracingRateLimitLayer::builder()
            .with_summary_interval(Duration::from_secs(0))
            .build();

        assert!(matches!(
            result,
            Err(BuildError::EmitterConfig(
                crate::application::emitter::EmitterConfigError::ZeroSummaryInterval
            ))
        ));
    }

    // A signature cap of zero is rejected.
    #[test]
    fn test_builder_zero_max_signatures() {
        let result = TracingRateLimitLayer::builder()
            .with_max_signatures(0)
            .build();

        assert!(matches!(result, Err(BuildError::ZeroMaxSignatures)));
    }

    // A positive signature cap builds successfully.
    #[test]
    fn test_builder_valid_max_signatures() {
        let layer = TracingRateLimitLayer::builder()
            .with_max_signatures(100)
            .build()
            .unwrap();

        assert!(layer.limiter().registry().is_empty());
    }

    // Allowed and suppressed counters track each decision.
    #[test]
    fn test_metrics_tracking() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(2).unwrap())
            .build()
            .unwrap();

        let sig = EventSignature::simple("INFO", "test");

        assert_eq!(layer.metrics().events_allowed(), 0);
        assert_eq!(layer.metrics().events_suppressed(), 0);

        assert!(layer.should_allow(sig));
        assert!(layer.should_allow(sig));

        assert_eq!(layer.metrics().events_allowed(), 2);
        assert_eq!(layer.metrics().events_suppressed(), 0);

        assert!(!layer.should_allow(sig));

        assert_eq!(layer.metrics().events_allowed(), 2);
        assert_eq!(layer.metrics().events_suppressed(), 1);
    }

    // Snapshot totals and suppression rate: 3 allowed + 2 suppressed = 0.4 rate.
    #[test]
    fn test_metrics_snapshot() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(3).unwrap())
            .build()
            .unwrap();

        let sig = EventSignature::simple("INFO", "test");

        for _ in 0..5 {
            layer.should_allow(sig);
        }

        let snapshot = layer.metrics().snapshot();
        assert_eq!(snapshot.events_allowed, 3);
        assert_eq!(snapshot.events_suppressed, 2);
        assert_eq!(snapshot.total_events(), 5);
        assert!((snapshot.suppression_rate() - 0.4).abs() < f64::EPSILON);
    }

    // Repeating a known signature does not add a registry entry.
    #[test]
    fn test_signature_count() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(2).unwrap())
            .build()
            .unwrap();

        assert_eq!(layer.signature_count(), 0);

        let sig1 = EventSignature::simple("INFO", "test1");
        let sig2 = EventSignature::simple("INFO", "test2");

        layer.should_allow(sig1);
        assert_eq!(layer.signature_count(), 1);

        layer.should_allow(sig2);
        assert_eq!(layer.signature_count(), 2);

        layer.should_allow(sig1);
        assert_eq!(layer.signature_count(), 2);
    }

    // With a cap of 3, a fourth distinct signature evicts one entry and the
    // eviction is counted.
    #[test]
    fn test_metrics_with_eviction() {
        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(1).unwrap())
            .with_max_signatures(3)
            .build()
            .unwrap();

        for i in 0..3 {
            let sig = EventSignature::simple("INFO", &format!("test{}", i));
            layer.should_allow(sig);
        }

        assert_eq!(layer.signature_count(), 3);
        assert_eq!(layer.metrics().signatures_evicted(), 0);

        let sig = EventSignature::simple("INFO", "test3");
        layer.should_allow(sig);

        assert_eq!(layer.signature_count(), 3);
        assert_eq!(layer.metrics().signatures_evicted(), 1);
    }

    // Ordinary suppression does not trip the circuit breaker.
    #[test]
    fn test_circuit_breaker_observability() {
        use crate::application::circuit_breaker::CircuitState;

        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(2).unwrap())
            .build()
            .unwrap();

        let cb = layer.circuit_breaker();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.consecutive_failures(), 0);

        let sig = EventSignature::simple("INFO", "test");
        layer.should_allow(sig);
        layer.should_allow(sig);
        layer.should_allow(sig);

        assert_eq!(cb.state(), CircuitState::Closed);
    }

    // The layer is assembled by hand to inject a breaker with a failure
    // threshold of 2. Once the breaker opens, every event is admitted
    // (fail-open) even though the signature's budget is already exhausted.
    #[test]
    fn test_circuit_breaker_fail_open_integration() {
        use crate::application::circuit_breaker::{
            CircuitBreaker, CircuitBreakerConfig, CircuitState,
        };
        use std::time::Duration;

        let cb_config = CircuitBreakerConfig {
            failure_threshold: 2,
            recovery_timeout: Duration::from_secs(1),
        };
        let circuit_breaker = Arc::new(CircuitBreaker::with_config(cb_config));

        let storage = Arc::new(ShardedStorage::new());
        let clock = Arc::new(SystemClock::new());
        let policy = Policy::count_based(2).unwrap();
        let registry = SuppressionRegistry::new(storage, clock, policy);
        let metrics = Metrics::new();
        let limiter = RateLimiter::new(registry, metrics, circuit_breaker.clone());

        let layer = TracingRateLimitLayer {
            limiter,
            #[cfg(feature = "async")]
            emitter_handle: Arc::new(Mutex::new(None)),
            #[cfg(not(feature = "async"))]
            _emitter_config: crate::application::emitter::EmitterConfig::new(Duration::from_secs(
                30,
            ))
            .unwrap(),
        };

        let sig = EventSignature::simple("INFO", "test");

        assert!(layer.should_allow(sig));
        assert!(layer.should_allow(sig));
        assert!(!layer.should_allow(sig));

        assert_eq!(circuit_breaker.state(), CircuitState::Closed);

        circuit_breaker.record_failure();
        circuit_breaker.record_failure();

        assert_eq!(circuit_breaker.state(), CircuitState::Open);

        assert!(layer.should_allow(sig));
        assert!(layer.should_allow(sig));
        assert!(layer.should_allow(sig));

        let snapshot = layer.metrics().snapshot();
        assert!(snapshot.events_allowed >= 5);
    }

    // Drives `SummaryEmitter` directly on a registry with recorded
    // suppressions and expects at least one summary within 250 ms.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_active_emission_integration() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Duration;

        let emission_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&emission_count);

        let storage = Arc::new(ShardedStorage::new());
        let clock = Arc::new(SystemClock::new());
        let policy = Policy::count_based(2).unwrap();
        let registry = SuppressionRegistry::new(storage, clock, policy);

        let emitter_config = EmitterConfig::new(Duration::from_millis(100)).unwrap();
        let emitter = SummaryEmitter::new(registry.clone(), emitter_config);

        let handle = emitter.start(
            move |summaries| {
                count_clone.fetch_add(summaries.len(), Ordering::SeqCst);
            },
            false,
        );

        let sig = EventSignature::simple("INFO", "test_message");
        for _ in 0..10 {
            registry.with_event_state(sig, |state, now| {
                state.counter.record_suppression(now);
            });
        }

        tokio::time::sleep(Duration::from_millis(250)).await;

        let count = emission_count.load(Ordering::SeqCst);
        assert!(
            count > 0,
            "Expected at least one suppression summary to be emitted, got {}",
            count
        );

        handle.shutdown().await.expect("shutdown failed");
    }

    // Active emission is off by default, so no summaries should appear.
    // NOTE(review): no tracing events pass through the subscriber here
    // (`should_allow` is called directly), and the sleep runs after the
    // `with_default` scope has ended. The mock presumably could not observe an
    // emitted summary even if emission were on — consider strengthening.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_active_emission_disabled() {
        use crate::infrastructure::mocks::layer::MockCaptureLayer;
        use std::time::Duration;

        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(2).unwrap())
            .with_summary_interval(Duration::from_millis(100))
            .build()
            .unwrap();

        let mock = MockCaptureLayer::new();
        let mock_clone = mock.clone();

        let subscriber = tracing_subscriber::registry()
            .with(mock)
            .with(tracing_subscriber::fmt::layer().with_filter(layer.clone()));

        tracing::subscriber::with_default(subscriber, || {
            let sig = EventSignature::simple("INFO", "test_message");
            for _ in 0..10 {
                layer.should_allow(sig);
            }
        });

        tokio::time::sleep(Duration::from_millis(250)).await;

        let events = mock_clone.get_captured();
        let summary_count = events
            .iter()
            .filter(|e| e.message.contains("suppressed"))
            .count();

        assert_eq!(
            summary_count, 0,
            "Should not emit summaries when active emission is disabled"
        );

        layer.shutdown().await.expect("shutdown failed");
    }

    // Shutdown succeeds when no emitter was ever started.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_shutdown_without_emission() {
        let layer = TracingRateLimitLayer::new();

        layer
            .shutdown()
            .await
            .expect("shutdown should succeed when emitter not running");
    }

    // A custom formatter installed through the builder receives summaries.
    // With a budget of 2 over 10 checks, 8 suppressions are expected.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_custom_summary_formatter() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Duration;

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);

        let last_count = Arc::new(AtomicUsize::new(0));
        let last_count_clone = Arc::clone(&last_count);

        let layer = TracingRateLimitLayer::builder()
            .with_policy(Policy::count_based(2).unwrap())
            .with_active_emission(true)
            .with_summary_interval(Duration::from_millis(100))
            .with_summary_formatter(Arc::new(move |summary| {
                count_clone.fetch_add(1, Ordering::SeqCst);
                last_count_clone.store(summary.count, Ordering::SeqCst);
                tracing::info!(
                    sig = %summary.signature,
                    suppressed = summary.count,
                    "Custom format"
                );
            }))
            .build()
            .unwrap();

        let sig = EventSignature::simple("INFO", "test_message");
        for _ in 0..10 {
            layer.should_allow(sig);
        }

        tokio::time::sleep(Duration::from_millis(250)).await;

        let calls = call_count.load(Ordering::SeqCst);
        assert!(calls > 0, "Custom formatter should have been called");

        let count = last_count.load(Ordering::SeqCst);
        assert!(
            count >= 8,
            "Expected at least 8 suppressions, got {}",
            count
        );

        layer.shutdown().await.expect("shutdown failed");
    }

    // NOTE(review): despite its name, this test passes its own callback to
    // `SummaryEmitter::start` and never goes through the builder. It therefore
    // does not exercise the builder's default WARN formatter; the setup
    // duplicates `test_active_emission_integration`.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_default_formatter_used() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::time::Duration;

        let emission_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&emission_count);

        let storage = Arc::new(ShardedStorage::new());
        let clock = Arc::new(SystemClock::new());
        let policy = Policy::count_based(2).unwrap();
        let registry = SuppressionRegistry::new(storage, clock, policy);

        let emitter_config = EmitterConfig::new(Duration::from_millis(100)).unwrap();
        let emitter = SummaryEmitter::new(registry.clone(), emitter_config);

        let handle = emitter.start(
            move |summaries| {
                count_clone.fetch_add(summaries.len(), Ordering::SeqCst);
            },
            false,
        );

        let sig = EventSignature::simple("INFO", "test_message");
        for _ in 0..10 {
            registry.with_event_state(sig, |state, now| {
                state.counter.record_suppression(now);
            });
        }

        tokio::time::sleep(Duration::from_millis(250)).await;

        let count = emission_count.load(Ordering::SeqCst);
        assert!(count > 0, "Default formatter should have emitted summaries");

        handle.shutdown().await.expect("shutdown failed");
    }
}