tracing_throttle/domain/
policy.rs

1//! Rate limiting policies for event suppression.
2//!
3//! This module defines the core trait for rate limiting policies and provides
4//! several built-in implementations.
5
6use std::collections::VecDeque;
7use std::time::{Duration, Instant};
8
9#[cfg(feature = "redis-storage")]
10use serde::{Deserialize, Serialize};
11
/// Reason a rate limiting policy configuration was rejected.
///
/// Returned by the fallible policy constructors in this module. Each variant
/// names the parameter that violated a domain rule, and the `Display` impl
/// renders a short human-readable message for it.
///
/// The type is `Hash` as well as `Eq`, so errors can be used as map or set
/// keys (for example to tally rejected configurations by kind).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PolicyError {
    /// Maximum count must be greater than zero
    ZeroMaxCount,
    /// Maximum events must be greater than zero
    ZeroMaxEvents,
    /// Time window duration must be greater than zero
    ZeroWindowDuration,
    /// Bucket capacity must be greater than zero
    ZeroCapacity,
    /// Refill rate must be greater than zero
    ZeroRefillRate,
}
29
30impl std::fmt::Display for PolicyError {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            PolicyError::ZeroMaxCount => write!(f, "max_count must be greater than 0"),
34            PolicyError::ZeroMaxEvents => write!(f, "max_events must be greater than 0"),
35            PolicyError::ZeroWindowDuration => write!(f, "window duration must be greater than 0"),
36            PolicyError::ZeroCapacity => write!(f, "capacity must be greater than 0"),
37            PolicyError::ZeroRefillRate => write!(f, "refill_rate must be greater than 0"),
38        }
39    }
40}
41
// Marker impl: every `PolicyError` variant is fieldless and wraps no
// underlying error, so the default `std::error::Error` methods are sufficient.
impl std::error::Error for PolicyError {}
43
/// Decision made by a rate limiting policy.
///
/// Produced by [`RateLimitPolicy::register_event`] for every event that is
/// registered. The type is a plain two-state value: it is `Copy`, comparable,
/// and hashable, so decisions can be stored, counted, or used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PolicyDecision {
    /// Allow the event to be emitted
    Allow,
    /// Suppress the event (don't emit it)
    Suppress,
}
52
/// Trait for implementing rate limiting policies.
///
/// A policy is fed events one at a time through
/// [`register_event`](RateLimitPolicy::register_event) and answers each one
/// with a [`PolicyDecision`]. Both methods take `&mut self`, so an
/// implementation may keep state between calls. Implementors must be
/// `Send + Sync`.
pub trait RateLimitPolicy: Send + Sync {
    /// Register a new event occurrence and decide whether to allow or suppress it.
    ///
    /// # Arguments
    /// * `timestamp` - When the event occurred. Policies that are not
    ///   time-based may ignore it.
    ///
    /// # Returns
    /// [`PolicyDecision::Allow`] if the event should be emitted,
    /// [`PolicyDecision::Suppress`] if it should be dropped.
    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision;

    /// Reset the policy state.
    ///
    /// Called when starting a new tracking period or when clearing history.
    /// Configuration is expected to be kept; only accumulated history is
    /// discarded.
    fn reset(&mut self);
}
72
/// Count-based rate limiting policy.
///
/// Allows up to N events, then suppresses all subsequent events until the
/// policy is reset. Timestamps are ignored.
///
/// Both fields are plain integers, so the type implements `Eq` in addition
/// to `PartialEq`.
///
/// # Example
/// ```
/// use tracing_throttle::{CountBasedPolicy, RateLimitPolicy};
/// use std::time::Instant;
///
/// let mut policy = CountBasedPolicy::new(3).unwrap();
/// let now = Instant::now();
///
/// // First 3 events allowed
/// assert!(policy.register_event(now).is_allow());
/// assert!(policy.register_event(now).is_allow());
/// assert!(policy.register_event(now).is_allow());
///
/// // 4th and beyond suppressed
/// assert!(policy.register_event(now).is_suppress());
/// assert!(policy.register_event(now).is_suppress());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
pub struct CountBasedPolicy {
    /// Number of events allowed before suppression starts.
    max_count: usize,
    /// Number of events registered since creation or the last reset.
    current_count: usize,
}
100
101impl CountBasedPolicy {
102    /// Create a new count-based policy.
103    ///
104    /// # Arguments
105    /// * `max_count` - Maximum number of events to allow before suppressing (must be > 0)
106    ///
107    /// # Errors
108    /// Returns `PolicyError::ZeroMaxCount` if `max_count` is 0.
109    pub fn new(max_count: usize) -> Result<Self, PolicyError> {
110        if max_count == 0 {
111            return Err(PolicyError::ZeroMaxCount);
112        }
113        Ok(Self {
114            max_count,
115            current_count: 0,
116        })
117    }
118}
119
120impl RateLimitPolicy for CountBasedPolicy {
121    fn register_event(&mut self, _timestamp: Instant) -> PolicyDecision {
122        self.current_count += 1;
123        if self.current_count <= self.max_count {
124            PolicyDecision::Allow
125        } else {
126            PolicyDecision::Suppress
127        }
128    }
129
130    fn reset(&mut self) {
131        self.current_count = 0;
132    }
133}
134
/// Time-window rate limiting policy.
///
/// Allows up to K events within a sliding time window. Events outside the
/// window are automatically expired, which frees room for new ones.
///
/// All fields are `Eq`, so the type implements `Eq` in addition to
/// `PartialEq`.
///
/// # Example
/// ```
/// use tracing_throttle::{TimeWindowPolicy, RateLimitPolicy};
/// use std::time::{Duration, Instant};
///
/// let mut policy = TimeWindowPolicy::new(2, Duration::from_secs(60)).unwrap();
/// let now = Instant::now();
///
/// // First 2 events allowed
/// assert!(policy.register_event(now).is_allow());
/// assert!(policy.register_event(now).is_allow());
///
/// // 3rd event suppressed (within window)
/// assert!(policy.register_event(now).is_suppress());
///
/// // After window expires, events are allowed again
/// let after_window = now + Duration::from_secs(61);
/// assert!(policy.register_event(after_window).is_allow());
/// assert!(policy.register_event(after_window).is_allow());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TimeWindowPolicy {
    /// Maximum number of events retained (and therefore allowed) per window.
    max_events: usize,
    /// Length of the sliding window.
    window_duration: Duration,
    /// Timestamps of the allowed events still inside the window, in the order
    /// they were registered.
    event_timestamps: VecDeque<Instant>,
}
166
#[cfg(feature = "redis-storage")]
impl Serialize for TimeWindowPolicy {
    /// Serialize the policy as a four-field struct for Redis storage.
    ///
    /// # Wire layout
    ///
    /// * `max_events` - the configured limit
    /// * `window_duration_nanos` - the window length in nanoseconds (`u128`)
    /// * `timestamps_nanos` - one `u64` per retained event: its distance in
    ///   nanoseconds from the first retained event, saturated at `u64::MAX`
    /// * `base_timestamp_nanos` - `Some(0)` when at least one event is
    ///   retained, `None` otherwise
    ///
    /// # Why offsets
    ///
    /// `Instant` has no epoch and cannot be written out directly, so only the
    /// spacing between events is stored. Absolute positions are therefore
    /// lost: whoever reads the data back has to choose a new anchor, and the
    /// age of the events at the time of serialization is not recoverable.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        let anchor = self.event_timestamps.front().copied();

        // Distance of every retained event from the first one. An event that
        // is earlier than the anchor saturates to 0; a distance that does not
        // fit into 64 bits saturates to `u64::MAX`.
        let offsets: Vec<u64> = match anchor {
            None => Vec::new(),
            Some(first) => self
                .event_timestamps
                .iter()
                .map(|event| {
                    let nanos = event.saturating_duration_since(first).as_nanos();
                    if nanos > u128::from(u64::MAX) {
                        u64::MAX
                    } else {
                        nanos as u64
                    }
                })
                .collect(),
        };
        let window_nanos = self.window_duration.as_nanos();
        let base_marker = anchor.map(|_| 0u64);

        let mut out = serializer.serialize_struct("TimeWindowPolicy", 4)?;
        out.serialize_field("max_events", &self.max_events)?;
        out.serialize_field("window_duration_nanos", &window_nanos)?;
        out.serialize_field("timestamps_nanos", &offsets)?;
        out.serialize_field("base_timestamp_nanos", &base_marker)?;
        out.end()
    }
}
223
#[cfg(feature = "redis-storage")]
impl<'de> Deserialize<'de> for TimeWindowPolicy {
    /// Rebuild a `TimeWindowPolicy` from the map layout written by its
    /// `Serialize` implementation.
    ///
    /// Required fields are `max_events`, `window_duration_nanos` and
    /// `timestamps_nanos`. `base_timestamp_nanos` is optional; it is read and
    /// discarded.
    ///
    /// # Timestamp handling
    ///
    /// Stored timestamps are offsets, not absolute times. Each offset is added
    /// to `Instant::now()`, so the first stored event lands at the moment of
    /// deserialization and later events land after it. The spacing between
    /// events is preserved, but their age at serialization time is not.
    ///
    /// # Errors
    ///
    /// Fails on unknown, duplicate or missing required fields, and on data
    /// that violates the invariants enforced by `TimeWindowPolicy::new`
    /// (`max_events == 0` or a zero window).
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            MaxEvents,
            WindowDurationNanos,
            TimestampsNanos,
            BaseTimestampNanos,
        }

        struct TimeWindowPolicyVisitor;

        impl<'de> Visitor<'de> for TimeWindowPolicyVisitor {
            type Value = TimeWindowPolicy;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct TimeWindowPolicy")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TimeWindowPolicy, V::Error>
            where
                V: MapAccess<'de>,
            {
                const NANOS_PER_SEC: u128 = 1_000_000_000;

                let mut max_events: Option<usize> = None;
                let mut window_duration_nanos: Option<u128> = None;
                let mut timestamps_nanos: Option<Vec<u64>> = None;
                let mut base_seen = false;

                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::MaxEvents => {
                            if max_events.is_some() {
                                return Err(de::Error::duplicate_field("max_events"));
                            }
                            max_events = Some(map.next_value()?);
                        }
                        Field::WindowDurationNanos => {
                            if window_duration_nanos.is_some() {
                                return Err(de::Error::duplicate_field("window_duration_nanos"));
                            }
                            window_duration_nanos = Some(map.next_value()?);
                        }
                        Field::TimestampsNanos => {
                            if timestamps_nanos.is_some() {
                                return Err(de::Error::duplicate_field("timestamps_nanos"));
                            }
                            timestamps_nanos = Some(map.next_value()?);
                        }
                        Field::BaseTimestampNanos => {
                            if base_seen {
                                return Err(de::Error::duplicate_field("base_timestamp_nanos"));
                            }
                            // The value carries no information (the serializer
                            // only ever writes `Some(0)` or `None`), but it
                            // still has to be consumed from the map.
                            let _: Option<u64> = map.next_value()?;
                            base_seen = true;
                        }
                    }
                }

                let max_events =
                    max_events.ok_or_else(|| de::Error::missing_field("max_events"))?;
                let window_duration_nanos = window_duration_nanos
                    .ok_or_else(|| de::Error::missing_field("window_duration_nanos"))?;
                let timestamps_nanos =
                    timestamps_nanos.ok_or_else(|| de::Error::missing_field("timestamps_nanos"))?;

                // Stored state must satisfy the same rules as `new`; otherwise a
                // bad record would yield a policy the constructor can never build.
                if max_events == 0 {
                    return Err(de::Error::custom(PolicyError::ZeroMaxEvents));
                }

                // Split into seconds + sub-second nanos instead of casting the
                // u128 to u64, which would silently truncate long windows.
                // Values beyond what `Duration` can hold saturate to the maximum.
                let whole_secs = window_duration_nanos / NANOS_PER_SEC;
                let window_duration = if whole_secs > u128::from(u64::MAX) {
                    Duration::MAX
                } else {
                    // The remainder is < 1_000_000_000 and therefore fits in u32.
                    let subsec_nanos = (window_duration_nanos % NANOS_PER_SEC) as u32;
                    Duration::new(whole_secs as u64, subsec_nanos)
                };
                if window_duration.is_zero() {
                    return Err(de::Error::custom(PolicyError::ZeroWindowDuration));
                }

                // Re-anchor the stored offsets to the current time. If adding an
                // offset would overflow `Instant`, fall back to `now` itself.
                // NOTE(review): this places restored events at or after `now`, so
                // none of them expire until a full window has passed after the
                // reload. Anchoring the newest event at `now` instead would be
                // tighter; left unchanged to keep reload behaviour stable.
                let now = Instant::now();
                let event_timestamps: VecDeque<Instant> = timestamps_nanos
                    .into_iter()
                    .map(|nanos| now.checked_add(Duration::from_nanos(nanos)).unwrap_or(now))
                    .collect();

                Ok(TimeWindowPolicy {
                    max_events,
                    window_duration,
                    event_timestamps,
                })
            }
        }

        const FIELDS: &[&str] = &[
            "max_events",
            "window_duration_nanos",
            "timestamps_nanos",
            "base_timestamp_nanos",
        ];
        deserializer.deserialize_struct("TimeWindowPolicy", FIELDS, TimeWindowPolicyVisitor)
    }
}
321
322impl TimeWindowPolicy {
323    /// Create a new time-window policy.
324    ///
325    /// # Arguments
326    /// * `max_events` - Maximum events allowed in the window (must be > 0)
327    /// * `window_duration` - Length of the sliding time window (must be > 0)
328    ///
329    /// # Errors
330    /// Returns `PolicyError::ZeroMaxEvents` if `max_events` is 0.
331    /// Returns `PolicyError::ZeroWindowDuration` if `window_duration` is 0.
332    pub fn new(max_events: usize, window_duration: Duration) -> Result<Self, PolicyError> {
333        if max_events == 0 {
334            return Err(PolicyError::ZeroMaxEvents);
335        }
336        if window_duration.is_zero() {
337            return Err(PolicyError::ZeroWindowDuration);
338        }
339        Ok(Self {
340            max_events,
341            window_duration,
342            event_timestamps: VecDeque::new(),
343        })
344    }
345
346    /// Remove expired events from the window.
347    fn expire_old_events(&mut self, current_time: Instant) {
348        while let Some(&oldest) = self.event_timestamps.front() {
349            if current_time.saturating_duration_since(oldest) > self.window_duration {
350                self.event_timestamps.pop_front();
351            } else {
352                break;
353            }
354        }
355    }
356}
357
358impl RateLimitPolicy for TimeWindowPolicy {
359    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
360        self.expire_old_events(timestamp);
361
362        if self.event_timestamps.len() < self.max_events {
363            self.event_timestamps.push_back(timestamp);
364            PolicyDecision::Allow
365        } else {
366            PolicyDecision::Suppress
367        }
368    }
369
370    fn reset(&mut self) {
371        self.event_timestamps.clear();
372    }
373}
374
/// Exponential backoff policy.
///
/// Allows events at exponentially increasing intervals: 1st, 2nd, 4th, 8th, 16th, etc.
/// Every other event is suppressed. Useful for extremely noisy logs.
/// Timestamps are ignored; only the running event count matters.
///
/// Both fields are plain integers, so the type implements `Eq` in addition
/// to `PartialEq`.
///
/// # Example
/// ```
/// use tracing_throttle::{ExponentialBackoffPolicy, RateLimitPolicy};
/// use std::time::Instant;
///
/// let mut policy = ExponentialBackoffPolicy::new();
/// let now = Instant::now();
///
/// assert!(policy.register_event(now).is_allow());  // 1st
/// assert!(policy.register_event(now).is_allow());  // 2nd
/// assert!(policy.register_event(now).is_suppress()); // 3rd - suppressed
/// assert!(policy.register_event(now).is_allow());  // 4th
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
pub struct ExponentialBackoffPolicy {
    /// Number of events registered since creation or the last reset.
    event_count: u64,
    /// Event number at which the next event will be allowed; doubles after
    /// each allowed event.
    next_allowed: u64,
}
399
400impl ExponentialBackoffPolicy {
401    /// Create a new exponential backoff policy.
402    pub fn new() -> Self {
403        Self {
404            event_count: 0,
405            next_allowed: 1,
406        }
407    }
408}
409
410impl Default for ExponentialBackoffPolicy {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
416impl RateLimitPolicy for ExponentialBackoffPolicy {
417    fn register_event(&mut self, _timestamp: Instant) -> PolicyDecision {
418        self.event_count += 1;
419
420        if self.event_count == self.next_allowed {
421            self.next_allowed = self.next_allowed.saturating_mul(2);
422            PolicyDecision::Allow
423        } else {
424            PolicyDecision::Suppress
425        }
426    }
427
428    fn reset(&mut self) {
429        self.event_count = 0;
430        self.next_allowed = 1;
431    }
432}
433
/// Token bucket rate limiting policy.
///
/// Implements a token bucket algorithm where:
/// - The bucket holds up to `capacity` tokens
/// - Tokens refill at a constant rate (`refill_rate` tokens per second)
/// - Each event consumes 1 token
/// - Events are suppressed when less than one whole token is available
///
/// This policy provides:
/// - **Burst tolerance**: Can handle bursts up to `capacity` events
/// - **Natural recovery**: Tokens automatically refill over time
/// - **Smooth rate limiting**: Sustained load is limited to `refill_rate`
/// - **Forgiveness**: After quiet periods, full capacity is restored
///
/// A freshly created bucket starts full. Refill is computed lazily, from the
/// timestamps passed to `register_event`, rather than by a background timer.
///
/// # Example
/// ```
/// use tracing_throttle::{TokenBucketPolicy, RateLimitPolicy};
/// use std::time::{Duration, Instant};
///
/// // Bucket with capacity 100, refills at 10 tokens/sec
/// let mut policy = TokenBucketPolicy::new(100.0, 10.0).unwrap();
/// let start = Instant::now();
///
/// // Can burst up to 100 events immediately
/// for _ in 0..100 {
///     assert!(policy.register_event(start).is_allow());
/// }
///
/// // 101st event is suppressed (no tokens left)
/// assert!(policy.register_event(start).is_suppress());
///
/// // After 1 second, 10 more tokens available
/// let later = start + Duration::from_secs(1);
/// for _ in 0..10 {
///     assert!(policy.register_event(later).is_allow());
/// }
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct TokenBucketPolicy {
    /// Maximum number of tokens the bucket can hold
    capacity: f64,
    /// Rate at which tokens are added (tokens per second)
    refill_rate: f64,
    /// Current number of tokens in the bucket (fractional while refilling)
    tokens: f64,
    /// Last time the bucket was refilled; `None` until the first event is
    /// registered and again after a reset
    last_refill: Option<Instant>,
}
482
#[cfg(feature = "redis-storage")]
impl Serialize for TokenBucketPolicy {
    /// Serialize the bucket as a four-field struct for Redis storage.
    ///
    /// # Wire layout
    ///
    /// * `capacity`, `refill_rate`, `tokens` - written as-is
    /// * `has_last_refill` - whether a refill timestamp was recorded
    ///
    /// # `last_refill` is not persisted
    ///
    /// `Instant` has no epoch and cannot be written out, so only a boolean
    /// marker is stored in its place. The token count itself is preserved
    /// exactly.
    ///
    /// # Consequence for readers
    ///
    /// A bucket restored without a refill timestamp behaves like one that has
    /// not seen an event yet: the first event after the reload only records
    /// the new timestamp and adds no tokens, so tokens that would have
    /// accrued between the last refill and the reload are not credited.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // Stand-in for the timestamp that cannot be serialized.
        let refill_recorded = self.last_refill.is_some();

        let mut out = serializer.serialize_struct("TokenBucketPolicy", 4)?;
        out.serialize_field("capacity", &self.capacity)?;
        out.serialize_field("refill_rate", &self.refill_rate)?;
        out.serialize_field("tokens", &self.tokens)?;
        out.serialize_field("has_last_refill", &refill_recorded)?;
        out.end()
    }
}
521
#[cfg(feature = "redis-storage")]
impl<'de> Deserialize<'de> for TokenBucketPolicy {
    /// Rebuild a `TokenBucketPolicy` from the map layout written by its
    /// `Serialize` implementation.
    ///
    /// Required fields are `capacity`, `refill_rate` and `tokens`.
    /// `has_last_refill` is optional; it is read and discarded, and the
    /// restored bucket always starts with `last_refill == None`, so the next
    /// registered event records a fresh refill timestamp.
    ///
    /// # Validation
    ///
    /// Stored state is held to the same rules as `TokenBucketPolicy::new`:
    /// `capacity` and `refill_rate` must be greater than zero and not NaN.
    /// The token count is clamped into `0.0..=capacity` (a NaN count becomes
    /// `0.0`), so a corrupted record cannot grant more than one full bucket.
    ///
    /// # Errors
    ///
    /// Fails on unknown, duplicate or missing required fields, and on a
    /// non-positive or NaN `capacity` / `refill_rate`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::{self, MapAccess, Visitor};

        #[derive(Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        enum Field {
            Capacity,
            RefillRate,
            Tokens,
            HasLastRefill,
        }

        struct TokenBucketPolicyVisitor;

        impl<'de> Visitor<'de> for TokenBucketPolicyVisitor {
            type Value = TokenBucketPolicy;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("struct TokenBucketPolicy")
            }

            fn visit_map<V>(self, mut map: V) -> Result<TokenBucketPolicy, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut capacity: Option<f64> = None;
                let mut refill_rate: Option<f64> = None;
                let mut tokens: Option<f64> = None;
                let mut marker_seen = false;

                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::Capacity => {
                            if capacity.is_some() {
                                return Err(de::Error::duplicate_field("capacity"));
                            }
                            capacity = Some(map.next_value()?);
                        }
                        Field::RefillRate => {
                            if refill_rate.is_some() {
                                return Err(de::Error::duplicate_field("refill_rate"));
                            }
                            refill_rate = Some(map.next_value()?);
                        }
                        Field::Tokens => {
                            if tokens.is_some() {
                                return Err(de::Error::duplicate_field("tokens"));
                            }
                            tokens = Some(map.next_value()?);
                        }
                        Field::HasLastRefill => {
                            if marker_seen {
                                return Err(de::Error::duplicate_field("has_last_refill"));
                            }
                            // Informational only: the timestamp it refers to
                            // was never stored, so the value is consumed and
                            // dropped.
                            let _: bool = map.next_value()?;
                            marker_seen = true;
                        }
                    }
                }

                let capacity = capacity.ok_or_else(|| de::Error::missing_field("capacity"))?;
                let refill_rate =
                    refill_rate.ok_or_else(|| de::Error::missing_field("refill_rate"))?;
                let tokens = tokens.ok_or_else(|| de::Error::missing_field("tokens"))?;

                // NaN compares false against everything, so it needs its own check.
                if capacity.is_nan() || capacity <= 0.0 {
                    return Err(de::Error::custom(PolicyError::ZeroCapacity));
                }
                if refill_rate.is_nan() || refill_rate <= 0.0 {
                    return Err(de::Error::custom(PolicyError::ZeroRefillRate));
                }

                // `f64::max` returns the non-NaN operand, so a NaN count turns
                // into 0.0 here; `min` then caps the count at the capacity.
                let tokens = tokens.max(0.0).min(capacity);

                // `last_refill` stays `None`: the next registered event
                // records the timestamp that later refills are measured from.
                Ok(TokenBucketPolicy {
                    capacity,
                    refill_rate,
                    tokens,
                    last_refill: None,
                })
            }
        }

        const FIELDS: &[&str] = &["capacity", "refill_rate", "tokens", "has_last_refill"];
        deserializer.deserialize_struct("TokenBucketPolicy", FIELDS, TokenBucketPolicyVisitor)
    }
}
604
605impl TokenBucketPolicy {
606    /// Create a new token bucket policy.
607    ///
608    /// # Arguments
609    /// * `capacity` - Maximum tokens in the bucket (burst size, must be > 0)
610    /// * `refill_rate` - Tokens added per second (sustained rate, must be > 0)
611    ///
612    /// # Errors
613    /// Returns `PolicyError::ZeroCapacity` if `capacity` is 0 or negative.
614    /// Returns `PolicyError::ZeroRefillRate` if `refill_rate` is 0 or negative.
615    ///
616    /// # Example
617    /// ```
618    /// use tracing_throttle::TokenBucketPolicy;
619    ///
620    /// // 100 token burst, refills at 10/sec
621    /// let policy = TokenBucketPolicy::new(100.0, 10.0).unwrap();
622    /// ```
623    pub fn new(capacity: f64, refill_rate: f64) -> Result<Self, PolicyError> {
624        if capacity <= 0.0 {
625            return Err(PolicyError::ZeroCapacity);
626        }
627        if refill_rate <= 0.0 {
628            return Err(PolicyError::ZeroRefillRate);
629        }
630
631        Ok(Self {
632            capacity,
633            refill_rate,
634            tokens: capacity,
635            last_refill: None,
636        })
637    }
638
639    /// Refill tokens based on elapsed time since last refill.
640    ///
641    /// Handles clock adjustments gracefully - if time goes backwards,
642    /// we simply reset the refill timestamp without adding tokens.
643    fn refill(&mut self, now: Instant) {
644        if let Some(last) = self.last_refill {
645            // Handle time going backwards (NTP adjustments, VM migrations, etc.)
646            if now < last {
647                self.last_refill = Some(now);
648                return;
649            }
650
651            let elapsed = now.duration_since(last).as_secs_f64();
652            let new_tokens = elapsed * self.refill_rate;
653            self.tokens = (self.tokens + new_tokens).min(self.capacity);
654        }
655        self.last_refill = Some(now);
656    }
657}
658
659impl RateLimitPolicy for TokenBucketPolicy {
660    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
661        // Refill tokens based on time elapsed
662        self.refill(timestamp);
663
664        // Check if we have a token available
665        if self.tokens >= 1.0 {
666            self.tokens -= 1.0;
667            PolicyDecision::Allow
668        } else {
669            PolicyDecision::Suppress
670        }
671    }
672
673    fn reset(&mut self) {
674        self.tokens = self.capacity;
675        self.last_refill = None;
676    }
677}
678
679/// Convenience enum for common policy types.
680#[derive(Debug, Clone)]
681#[cfg_attr(feature = "redis-storage", derive(Serialize, Deserialize))]
682pub enum Policy {
683    /// Count-based policy
684    CountBased(CountBasedPolicy),
685    /// Time-window policy
686    TimeWindow(TimeWindowPolicy),
687    /// Exponential backoff policy
688    ExponentialBackoff(ExponentialBackoffPolicy),
689    /// Token bucket policy
690    TokenBucket(TokenBucketPolicy),
691}
692
693impl Policy {
694    /// Create a count-based policy.
695    ///
696    /// # Errors
697    /// Returns `PolicyError::ZeroMaxCount` if `max_count` is 0.
698    pub fn count_based(max_count: usize) -> Result<Self, PolicyError> {
699        Ok(Policy::CountBased(CountBasedPolicy::new(max_count)?))
700    }
701
702    /// Create a time-window policy.
703    ///
704    /// # Errors
705    /// Returns `PolicyError::ZeroMaxEvents` if `max_events` is 0.
706    /// Returns `PolicyError::ZeroWindowDuration` if `window` is 0.
707    pub fn time_window(max_events: usize, window: Duration) -> Result<Self, PolicyError> {
708        Ok(Policy::TimeWindow(TimeWindowPolicy::new(
709            max_events, window,
710        )?))
711    }
712
713    /// Create an exponential backoff policy.
714    ///
715    /// This policy has no configurable parameters and cannot fail.
716    pub fn exponential_backoff() -> Self {
717        Policy::ExponentialBackoff(ExponentialBackoffPolicy::new())
718    }
719
720    /// Create a token bucket policy.
721    ///
722    /// # Arguments
723    /// * `capacity` - Maximum tokens (burst size, must be > 0)
724    /// * `refill_rate` - Tokens per second (sustained rate, must be > 0)
725    ///
726    /// # Errors
727    /// Returns `PolicyError::ZeroCapacity` if `capacity` is 0 or negative.
728    /// Returns `PolicyError::ZeroRefillRate` if `refill_rate` is 0 or negative.
729    ///
730    /// # Example
731    /// ```
732    /// use tracing_throttle::Policy;
733    ///
734    /// // Allow bursts of 100, refill at 10/sec
735    /// let policy = Policy::token_bucket(100.0, 10.0).unwrap();
736    /// ```
737    pub fn token_bucket(capacity: f64, refill_rate: f64) -> Result<Self, PolicyError> {
738        Ok(Policy::TokenBucket(TokenBucketPolicy::new(
739            capacity,
740            refill_rate,
741        )?))
742    }
743}
744
745impl RateLimitPolicy for Policy {
746    fn register_event(&mut self, timestamp: Instant) -> PolicyDecision {
747        match self {
748            Policy::CountBased(p) => p.register_event(timestamp),
749            Policy::TimeWindow(p) => p.register_event(timestamp),
750            Policy::ExponentialBackoff(p) => p.register_event(timestamp),
751            Policy::TokenBucket(p) => p.register_event(timestamp),
752        }
753    }
754
755    fn reset(&mut self) {
756        match self {
757            Policy::CountBased(p) => p.reset(),
758            Policy::TimeWindow(p) => p.reset(),
759            Policy::ExponentialBackoff(p) => p.reset(),
760            Policy::TokenBucket(p) => p.reset(),
761        }
762    }
763}
764
765impl PolicyDecision {
766    /// Check if this decision is Allow.
767    pub fn is_allow(&self) -> bool {
768        matches!(self, PolicyDecision::Allow)
769    }
770
771    /// Check if this decision is Suppress.
772    pub fn is_suppress(&self) -> bool {
773        matches!(self, PolicyDecision::Suppress)
774    }
775}
776
777#[cfg(test)]
778mod tests {
779    use super::*;
780
781    #[test]
782    fn test_count_based_policy() {
783        let mut policy = CountBasedPolicy::new(3).unwrap();
784        let now = Instant::now();
785
786        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
787        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
788        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
789        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
790        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
791
792        policy.reset();
793        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
794    }
795
796    #[test]
797    fn test_time_window_policy() {
798        let mut policy = TimeWindowPolicy::new(2, Duration::from_secs(1)).unwrap();
799        let now = Instant::now();
800
801        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
802        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
803        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
804
805        // After window expires, should allow again
806        let later = now + Duration::from_secs(2);
807        assert_eq!(policy.register_event(later), PolicyDecision::Allow);
808    }
809
810    #[test]
811    fn test_exponential_backoff_policy() {
812        let mut policy = ExponentialBackoffPolicy::new();
813        let now = Instant::now();
814
815        // 1st allowed
816        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
817        // 2nd allowed
818        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
819        // 3rd suppressed
820        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
821        // 4th allowed
822        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
823        // 5th, 6th, 7th suppressed
824        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
825        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
826        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
827        // 8th allowed
828        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
829    }
830
831    #[test]
832    fn test_policy_enum() {
833        let mut policy = Policy::count_based(2).unwrap();
834        let now = Instant::now();
835
836        assert!(policy.register_event(now).is_allow());
837        assert!(policy.register_event(now).is_allow());
838        assert!(policy.register_event(now).is_suppress());
839    }
840
841    // Edge case tests
842    #[test]
843    fn test_count_based_policy_zero_limit() {
844        // Zero limit should be rejected
845        let result = CountBasedPolicy::new(0);
846        assert_eq!(result, Err(PolicyError::ZeroMaxCount));
847    }
848
849    #[test]
850    fn test_count_based_policy_one_limit() {
851        let mut policy = CountBasedPolicy::new(1).unwrap();
852        let now = Instant::now();
853
854        // Only first event allowed
855        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
856        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
857        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
858    }
859
860    #[test]
861    fn test_count_based_policy_reset() {
862        let mut policy = CountBasedPolicy::new(2).unwrap();
863        let now = Instant::now();
864
865        // Use up the limit
866        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
867        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
868        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
869
870        // Reset should restore the limit
871        policy.reset();
872        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
873        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
874        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
875    }
876
877    #[test]
878    fn test_time_window_policy_zero_duration() {
879        // Zero duration should be rejected
880        let result = TimeWindowPolicy::new(2, Duration::from_secs(0));
881        assert_eq!(result, Err(PolicyError::ZeroWindowDuration));
882    }
883
884    #[test]
885    fn test_time_window_policy_rapid_events() {
886        let mut policy = TimeWindowPolicy::new(3, Duration::from_millis(100)).unwrap();
887        let now = Instant::now();
888
889        // Rapid fire events
890        for i in 0..10 {
891            let decision = policy.register_event(now);
892            if i < 3 {
893                assert_eq!(
894                    decision,
895                    PolicyDecision::Allow,
896                    "Event {} should be allowed",
897                    i
898                );
899            } else {
900                assert_eq!(
901                    decision,
902                    PolicyDecision::Suppress,
903                    "Event {} should be suppressed",
904                    i
905                );
906            }
907        }
908    }
909
910    #[test]
911    fn test_time_window_policy_reset() {
912        let mut policy = TimeWindowPolicy::new(2, Duration::from_secs(60)).unwrap();
913        let now = Instant::now();
914
915        // Use up limit
916        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
917        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
918        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
919
920        // Reset should clear the window
921        policy.reset();
922        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
923    }
924
925    #[test]
926    fn test_exponential_backoff_large_count() {
927        let mut policy = ExponentialBackoffPolicy::new();
928        let now = Instant::now();
929
930        let expected_allowed = [0, 1, 3, 7, 15, 31, 63]; // 0-indexed: 1st, 2nd, 4th, 8th, 16th, 32nd, 64th
931
932        for i in 0..100 {
933            let decision = policy.register_event(now);
934            if expected_allowed.contains(&i) {
935                assert_eq!(
936                    decision,
937                    PolicyDecision::Allow,
938                    "Event {} should be allowed",
939                    i + 1
940                );
941            } else {
942                assert_eq!(
943                    decision,
944                    PolicyDecision::Suppress,
945                    "Event {} should be suppressed",
946                    i + 1
947                );
948            }
949        }
950    }
951
952    #[test]
953    fn test_exponential_backoff_reset() {
954        let mut policy = ExponentialBackoffPolicy::new();
955        let now = Instant::now();
956
957        // Progress through first few events
958        assert_eq!(policy.register_event(now), PolicyDecision::Allow); // 1st
959        assert_eq!(policy.register_event(now), PolicyDecision::Allow); // 2nd
960        assert_eq!(policy.register_event(now), PolicyDecision::Suppress); // 3rd
961
962        // Reset should start over
963        policy.reset();
964        assert_eq!(policy.register_event(now), PolicyDecision::Allow); // 1st again
965    }
966
967    // Token Bucket Policy Tests
968    #[test]
969    fn test_token_bucket_basic_consumption() {
970        let mut policy = TokenBucketPolicy::new(3.0, 1.0).unwrap();
971        let now = Instant::now();
972
973        // Should allow up to capacity
974        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
975        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
976        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
977        // Bucket empty, should suppress
978        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
979        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
980    }
981
982    #[test]
983    fn test_token_bucket_refill_over_time() {
984        let mut policy = TokenBucketPolicy::new(10.0, 10.0).unwrap(); // 10 tokens/sec
985        let now = Instant::now();
986
987        // Use all tokens
988        for _ in 0..10 {
989            assert_eq!(policy.register_event(now), PolicyDecision::Allow);
990        }
991        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
992
993        // Wait 0.5 seconds - should get 5 tokens back
994        let later = now + Duration::from_millis(500);
995        for i in 0..5 {
996            assert_eq!(
997                policy.register_event(later),
998                PolicyDecision::Allow,
999                "Event {} should be allowed after refill",
1000                i
1001            );
1002        }
1003        assert_eq!(policy.register_event(later), PolicyDecision::Suppress);
1004    }
1005
1006    #[test]
1007    fn test_token_bucket_burst_tolerance() {
1008        let mut policy = TokenBucketPolicy::new(100.0, 1.0).unwrap();
1009        let now = Instant::now();
1010
1011        // Can burst up to full capacity immediately
1012        for i in 0..100 {
1013            assert_eq!(
1014                policy.register_event(now),
1015                PolicyDecision::Allow,
1016                "Event {} in burst should be allowed",
1017                i
1018            );
1019        }
1020        // Then rate limited
1021        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1022    }
1023
1024    #[test]
1025    fn test_token_bucket_sustained_rate() {
1026        let mut policy = TokenBucketPolicy::new(10.0, 10.0).unwrap(); // 10/sec sustained, 10 capacity
1027        let now = Instant::now();
1028
1029        // Use all tokens
1030        for _ in 0..10 {
1031            assert_eq!(policy.register_event(now), PolicyDecision::Allow);
1032        }
1033        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1034
1035        // Wait 1 second - should get 10 tokens back (capped at capacity)
1036        let later = now + Duration::from_secs(1);
1037        for i in 0..10 {
1038            assert_eq!(
1039                policy.register_event(later),
1040                PolicyDecision::Allow,
1041                "Event {} after 1s should be allowed",
1042                i
1043            );
1044        }
1045        assert_eq!(policy.register_event(later), PolicyDecision::Suppress);
1046
1047        // Wait 0.5 seconds - should get 5 tokens
1048        let even_later = later + Duration::from_millis(500);
1049        for i in 0..5 {
1050            assert_eq!(
1051                policy.register_event(even_later),
1052                PolicyDecision::Allow,
1053                "Event {} after 0.5s should be allowed",
1054                i
1055            );
1056        }
1057        assert_eq!(policy.register_event(even_later), PolicyDecision::Suppress);
1058    }
1059
1060    #[test]
1061    fn test_token_bucket_recovery_after_quiet() {
1062        let mut policy = TokenBucketPolicy::new(5.0, 2.0).unwrap();
1063        let now = Instant::now();
1064
1065        // Use all tokens
1066        for _ in 0..5 {
1067            policy.register_event(now);
1068        }
1069        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1070
1071        // Wait long enough to fully recover
1072        let much_later = now + Duration::from_secs(10);
1073        // Should be back to full capacity (5 tokens)
1074        for i in 0..5 {
1075            assert_eq!(
1076                policy.register_event(much_later),
1077                PolicyDecision::Allow,
1078                "Event {} after recovery should be allowed",
1079                i
1080            );
1081        }
1082        assert_eq!(policy.register_event(much_later), PolicyDecision::Suppress);
1083    }
1084
1085    #[test]
1086    fn test_token_bucket_fractional_refill() {
1087        let mut policy = TokenBucketPolicy::new(10.0, 0.5).unwrap(); // 0.5 tokens/sec
1088        let now = Instant::now();
1089
1090        // Use all tokens
1091        for _ in 0..10 {
1092            policy.register_event(now);
1093        }
1094        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1095
1096        // Wait 3 seconds - should get 1.5 tokens (only 1 usable)
1097        let later = now + Duration::from_secs(3);
1098        assert_eq!(policy.register_event(later), PolicyDecision::Allow);
1099        assert_eq!(policy.register_event(later), PolicyDecision::Suppress); // 0.5 tokens left, not enough
1100
1101        // Wait 1 more second - should have 1.5 tokens now
1102        let even_later = later + Duration::from_secs(1);
1103        assert_eq!(policy.register_event(even_later), PolicyDecision::Allow);
1104        assert_eq!(policy.register_event(even_later), PolicyDecision::Suppress);
1105    }
1106
1107    #[test]
1108    fn test_token_bucket_reset() {
1109        let mut policy = TokenBucketPolicy::new(5.0, 1.0).unwrap();
1110        let now = Instant::now();
1111
1112        // Use all tokens
1113        for _ in 0..5 {
1114            policy.register_event(now);
1115        }
1116        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1117
1118        // Reset should restore full capacity
1119        policy.reset();
1120        for i in 0..5 {
1121            assert_eq!(
1122                policy.register_event(now),
1123                PolicyDecision::Allow,
1124                "Event {} after reset should be allowed",
1125                i
1126            );
1127        }
1128        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1129    }
1130
1131    #[test]
1132    fn test_token_bucket_capacity_cap() {
1133        let mut policy = TokenBucketPolicy::new(5.0, 10.0).unwrap();
1134        let now = Instant::now();
1135
1136        // Use some tokens
1137        for _ in 0..3 {
1138            policy.register_event(now);
1139        }
1140
1141        // Wait long time - tokens should cap at capacity (5), not grow unbounded
1142        let much_later = now + Duration::from_secs(100);
1143        for i in 0..5 {
1144            assert_eq!(
1145                policy.register_event(much_later),
1146                PolicyDecision::Allow,
1147                "Event {} should be allowed (capped at capacity)",
1148                i
1149            );
1150        }
1151        assert_eq!(policy.register_event(much_later), PolicyDecision::Suppress);
1152    }
1153
1154    #[test]
1155    fn test_token_bucket_zero_capacity() {
1156        let result = TokenBucketPolicy::new(0.0, 1.0);
1157        assert_eq!(result, Err(PolicyError::ZeroCapacity));
1158    }
1159
1160    #[test]
1161    fn test_token_bucket_negative_capacity() {
1162        let result = TokenBucketPolicy::new(-5.0, 1.0);
1163        assert_eq!(result, Err(PolicyError::ZeroCapacity));
1164    }
1165
1166    #[test]
1167    fn test_token_bucket_zero_refill_rate() {
1168        let result = TokenBucketPolicy::new(10.0, 0.0);
1169        assert_eq!(result, Err(PolicyError::ZeroRefillRate));
1170    }
1171
1172    #[test]
1173    fn test_token_bucket_negative_refill_rate() {
1174        let result = TokenBucketPolicy::new(10.0, -2.0);
1175        assert_eq!(result, Err(PolicyError::ZeroRefillRate));
1176    }
1177
1178    #[test]
1179    fn test_token_bucket_policy_enum() {
1180        let mut policy = Policy::token_bucket(5.0, 2.0).unwrap();
1181        let now = Instant::now();
1182
1183        // Test via Policy enum
1184        for i in 0..5 {
1185            assert!(
1186                policy.register_event(now).is_allow(),
1187                "Event {} should be allowed",
1188                i
1189            );
1190        }
1191        assert!(policy.register_event(now).is_suppress());
1192
1193        // Test reset via enum
1194        policy.reset();
1195        assert!(policy.register_event(now).is_allow());
1196    }
1197
1198    #[test]
1199    fn test_token_bucket_incremental_refill() {
1200        let mut policy = TokenBucketPolicy::new(1.0, 10.0).unwrap(); // 10 tokens/sec, 1 max
1201        let now = Instant::now();
1202
1203        // Use initial token
1204        assert_eq!(policy.register_event(now), PolicyDecision::Allow);
1205        assert_eq!(policy.register_event(now), PolicyDecision::Suppress);
1206
1207        // Incremental refills - 100ms = 1 token
1208        let t1 = now + Duration::from_millis(100);
1209        assert_eq!(policy.register_event(t1), PolicyDecision::Allow);
1210        assert_eq!(policy.register_event(t1), PolicyDecision::Suppress);
1211
1212        let t2 = t1 + Duration::from_millis(100);
1213        assert_eq!(policy.register_event(t2), PolicyDecision::Allow);
1214        assert_eq!(policy.register_event(t2), PolicyDecision::Suppress);
1215    }
1216
1217    #[test]
1218    fn test_token_bucket_same_timestamp_multiple_events() {
1219        // Regression test: multiple events at the same timestamp should not refill
1220        let mut policy = TokenBucketPolicy::new(5.0, 2.0).unwrap();
1221        let start = Instant::now();
1222
1223        // First burst at t=0: use all 5 tokens
1224        for i in 0..5 {
1225            assert_eq!(
1226                policy.register_event(start),
1227                PolicyDecision::Allow,
1228                "Event {} should be allowed",
1229                i
1230            );
1231        }
1232
1233        // Events 6,7,8 at t=0 should be suppressed (no tokens left)
1234        for i in 5..8 {
1235            assert_eq!(
1236                policy.register_event(start),
1237                PolicyDecision::Suppress,
1238                "Event {} should be suppressed (no tokens)",
1239                i
1240            );
1241        }
1242
1243        // After 1 second, should have refilled 2 tokens
1244        let t1 = start + Duration::from_secs(1);
1245
1246        // Events at t=1s: should allow exactly 2
1247        assert_eq!(
1248            policy.register_event(t1),
1249            PolicyDecision::Allow,
1250            "First event after 1s should be allowed"
1251        );
1252        assert_eq!(
1253            policy.register_event(t1),
1254            PolicyDecision::Allow,
1255            "Second event after 1s should be allowed"
1256        );
1257
1258        // Third event at t=1s should be suppressed (only refilled 2 tokens)
1259        assert_eq!(
1260            policy.register_event(t1),
1261            PolicyDecision::Suppress,
1262            "Third event after 1s should be suppressed (only 2 tokens refilled)"
1263        );
1264
1265        // Fourth and fifth should also be suppressed
1266        assert_eq!(policy.register_event(t1), PolicyDecision::Suppress);
1267        assert_eq!(policy.register_event(t1), PolicyDecision::Suppress);
1268    }
1269
1270    #[test]
1271    fn test_token_bucket_time_goes_backwards() {
1272        let mut policy = TokenBucketPolicy::new(10.0, 5.0).unwrap();
1273        let now = Instant::now();
1274
1275        // Use 5 tokens
1276        for _ in 0..5 {
1277            assert_eq!(policy.register_event(now), PolicyDecision::Allow);
1278        }
1279        // 5 tokens remaining
1280
1281        // Time goes forward - should refill 5 tokens (now have 10)
1282        let future = now + Duration::from_secs(1);
1283        for _ in 0..10 {
1284            assert_eq!(policy.register_event(future), PolicyDecision::Allow);
1285        }
1286        // 0 tokens remaining after using all 10
1287
1288        // Time goes backwards (NTP correction, VM migration, etc.)
1289        // Should NOT panic and should NOT add/remove tokens
1290        let past = now + Duration::from_millis(500);
1291        assert!(past < future, "Test setup: past must be before future");
1292
1293        // Should still have 0 tokens (time went backwards, no refill)
1294        assert_eq!(
1295            policy.register_event(past),
1296            PolicyDecision::Suppress,
1297            "Should suppress when no tokens available after time went backwards"
1298        );
1299
1300        // Time moves forward again normally (1 second after 'past')
1301        let future2 = past + Duration::from_secs(1);
1302        // Should refill 5 tokens based on elapsed time from 'past'
1303        for i in 0..5 {
1304            assert_eq!(
1305                policy.register_event(future2),
1306                PolicyDecision::Allow,
1307                "Token {} should be available after normal time progression",
1308                i
1309            );
1310        }
1311
1312        // 6th should be suppressed
1313        assert_eq!(policy.register_event(future2), PolicyDecision::Suppress);
1314    }
1315
1316    #[test]
1317    fn test_time_window_with_many_events() {
1318        // Fill time window with maximum events
1319        let mut policy = TimeWindowPolicy::new(100, Duration::from_secs(60)).unwrap();
1320        let now = Instant::now();
1321
1322        // Add 100 events
1323        for i in 0..100 {
1324            let timestamp = now + Duration::from_millis(i * 10);
1325            policy.register_event(timestamp);
1326        }
1327
1328        // Verify window is full
1329        assert_eq!(
1330            policy.register_event(now + Duration::from_millis(1000)),
1331            PolicyDecision::Suppress
1332        );
1333
1334        // After window expires, should allow again
1335        let later = now + Duration::from_secs(70);
1336        assert_eq!(policy.register_event(later), PolicyDecision::Allow);
1337    }
1338}