Skip to main content

uvb_storage_api/
ratelimit.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use thiserror::Error;
5use uvb_core::TenantId;
6
7#[derive(Debug, Error)]
8pub enum RateLimitError {
9    #[error("storage error: {0}")]
10    Storage(String),
11
12    #[error("serialization error: {0}")]
13    Serialization(String),
14}
15
16/// Rate limit scope for different types of operations
17#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
18pub enum RateLimitScope {
19    /// Rate limit per subject (user)
20    Subject {
21        user_id: String,
22        tenant_id: TenantId,
23    },
24
25    /// Rate limit per IP address
26    IpAddress { ip: String },
27
28    /// Rate limit per factor and subject
29    FactorAttempt {
30        user_id: String,
31        tenant_id: TenantId,
32        factor_id: String,
33    },
34
35    /// Rate limit per API endpoint
36    Endpoint { path: String, method: String },
37
38    /// Custom rate limit key
39    Custom { key: String },
40}
41
42impl fmt::Display for RateLimitScope {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        match self {
45            RateLimitScope::Subject { user_id, tenant_id } => {
46                write!(f, "subject:{}:{}", tenant_id, user_id)
47            }
48            RateLimitScope::IpAddress { ip } => write!(f, "ip:{}", ip),
49            RateLimitScope::FactorAttempt {
50                user_id,
51                tenant_id,
52                factor_id,
53            } => write!(f, "factor:{}:{}:{}", tenant_id, user_id, factor_id),
54            RateLimitScope::Endpoint { path, method } => write!(f, "endpoint:{}:{}", method, path),
55            RateLimitScope::Custom { key } => write!(f, "custom:{}", key),
56        }
57    }
58}
59
60/// Rate limit configuration
61#[derive(Clone, Debug, Serialize, Deserialize)]
62pub struct RateLimitConfig {
63    /// Maximum number of attempts allowed
64    pub max_attempts: u32,
65
66    /// Time window in seconds
67    pub window_secs: u64,
68
69    /// Optional penalty duration in seconds after exceeding limit
70    pub penalty_secs: Option<u64>,
71}
72
73impl RateLimitConfig {
74    pub fn new(max_attempts: u32, window_secs: u64) -> Self {
75        Self {
76            max_attempts,
77            window_secs,
78            penalty_secs: None,
79        }
80    }
81
82    pub fn with_penalty(mut self, penalty_secs: u64) -> Self {
83        self.penalty_secs = Some(penalty_secs);
84        self
85    }
86}
87
88/// Rate limit check result
89#[derive(Clone, Debug, Serialize, Deserialize)]
90pub struct RateLimitResult {
91    /// Whether the request is allowed
92    pub allowed: bool,
93
94    /// Current attempt count in the window
95    pub current_attempts: u32,
96
97    /// Maximum attempts allowed
98    pub max_attempts: u32,
99
100    /// Remaining attempts before hitting the limit
101    pub remaining_attempts: u32,
102
103    /// Unix timestamp when the rate limit window resets
104    pub reset_at: i64,
105
106    /// Optional: Unix timestamp when penalty expires (if in penalty period)
107    pub penalty_expires_at: Option<i64>,
108}
109
110impl RateLimitResult {
111    pub fn allowed(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
112        Self {
113            allowed: true,
114            current_attempts,
115            max_attempts,
116            remaining_attempts: max_attempts.saturating_sub(current_attempts),
117            reset_at,
118            penalty_expires_at: None,
119        }
120    }
121
122    pub fn denied(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
123        Self {
124            allowed: false,
125            current_attempts,
126            max_attempts,
127            remaining_attempts: 0,
128            reset_at,
129            penalty_expires_at: None,
130        }
131    }
132
133    pub fn with_penalty(mut self, penalty_expires_at: i64) -> Self {
134        self.penalty_expires_at = Some(penalty_expires_at);
135        self
136    }
137}
138
139/// Trait for rate limiting storage
140///
141/// Implementations should use efficient counters with TTL where possible.
142/// Redis is ideal for this, but in-memory and SQL implementations are also provided.
143#[async_trait]
144pub trait RateLimitStore: Send + Sync {
145    /// Check if an operation is allowed under the rate limit and increment counter
146    ///
147    /// This is an atomic "check and increment" operation:
148    /// - If under the limit: increment and return allowed=true
149    /// - If at or over the limit: return allowed=false
150    /// - Automatically handles window expiration
151    async fn check_and_increment(
152        &self,
153        scope: &RateLimitScope,
154        config: &RateLimitConfig,
155    ) -> Result<RateLimitResult, RateLimitError>;
156
157    /// Get current rate limit status without incrementing
158    async fn check(
159        &self,
160        scope: &RateLimitScope,
161        config: &RateLimitConfig,
162    ) -> Result<RateLimitResult, RateLimitError>;
163
164    /// Reset rate limit for a scope (useful for admin overrides)
165    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError>;
166
167    /// Apply a penalty (temporary ban) for a scope
168    async fn apply_penalty(
169        &self,
170        scope: &RateLimitScope,
171        penalty_secs: u64,
172    ) -> Result<(), RateLimitError>;
173
174    /// Check if a scope is currently in penalty period
175    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError>;
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `RateLimitScope` variant renders its documented prefix and field order.
    #[test]
    fn test_rate_limit_scope_display() {
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        assert_eq!(scope.to_string(), "subject:tenant_a:user_1");

        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        assert_eq!(scope.to_string(), "ip:203.0.113.1");

        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };
        assert_eq!(scope.to_string(), "factor:tenant_a:user_1:totp");

        // Method precedes path in the rendered form.
        let scope = RateLimitScope::Endpoint {
            path: "/login".to_string(),
            method: "POST".to_string(),
        };
        assert_eq!(scope.to_string(), "endpoint:POST:/login");

        let scope = RateLimitScope::Custom {
            key: "bulk_export".to_string(),
        };
        assert_eq!(scope.to_string(), "custom:bulk_export");
    }

    /// Basic allowed/denied construction.
    #[test]
    fn test_rate_limit_result() {
        let result = RateLimitResult::allowed(3, 10, 1234567890);
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 3);
        assert_eq!(result.max_attempts, 10);
        assert_eq!(result.remaining_attempts, 7);
        assert_eq!(result.reset_at, 1234567890);
        assert_eq!(result.penalty_expires_at, None);

        let result = RateLimitResult::denied(10, 10, 1234567890);
        assert!(!result.allowed);
        assert_eq!(result.remaining_attempts, 0);
        assert_eq!(result.penalty_expires_at, None);
    }

    /// `remaining_attempts` must not underflow when the count exceeds the ceiling.
    #[test]
    fn test_rate_limit_result_remaining_saturates() {
        let result = RateLimitResult::allowed(12, 10, 0);
        assert_eq!(result.remaining_attempts, 0);
    }

    /// `with_penalty` records the expiry and leaves the other fields unchanged.
    #[test]
    fn test_rate_limit_result_with_penalty() {
        let result = RateLimitResult::denied(10, 10, 100).with_penalty(200);
        assert!(!result.allowed);
        assert_eq!(result.reset_at, 100);
        assert_eq!(result.penalty_expires_at, Some(200));
    }

    /// The config builder starts without a penalty and `with_penalty` sets one.
    #[test]
    fn test_rate_limit_config_builder() {
        let config = RateLimitConfig::new(5, 60);
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.window_secs, 60);
        assert_eq!(config.penalty_secs, None);

        let config = config.with_penalty(300);
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.window_secs, 60);
        assert_eq!(config.penalty_secs, Some(300));
    }
}