1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::fmt;
4use thiserror::Error;
5use uvb_core::TenantId;
6
/// Error type returned by every [`RateLimitStore`] operation.
#[derive(Debug, Error)]
pub enum RateLimitError {
    /// Storage-layer failure. The `String` is a free-form description of the
    /// cause (presumably supplied by the store implementation — confirm there).
    #[error("storage error: {0}")]
    Storage(String),

    /// Serialization failure. The `String` is a free-form description of the
    /// cause; presumably raised when a store encodes/decodes its state.
    #[error("serialization error: {0}")]
    Serialization(String),
}
15
/// Identifies what a rate limit is counted against; passed as `scope` to every
/// [`RateLimitStore`] method.
///
/// Its `Display` rendering is a colon-separated string with a variant-specific
/// prefix (`subject:`, `ip:`, `factor:`, `endpoint:`, `custom:`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum RateLimitScope {
    /// A single user within a tenant.
    Subject {
        /// Identifier of the user.
        user_id: String,
        /// Tenant the user is scoped to.
        tenant_id: TenantId,
    },

    /// A client IP address, held as an uninterpreted string (no parsing or
    /// validation happens in this type).
    IpAddress { ip: String },

    /// Attempts against one factor of one user within a tenant.
    FactorAttempt {
        /// Identifier of the user.
        user_id: String,
        /// Tenant the user is scoped to.
        tenant_id: TenantId,
        /// Identifier of the factor (the tests in this module use `"totp"`).
        factor_id: String,
    },

    /// An endpoint identified by `path` and `method` (presumably an HTTP
    /// route and verb — confirm with callers).
    Endpoint { path: String, method: String },

    /// An arbitrary caller-chosen key, for limits not covered by the other
    /// variants.
    Custom { key: String },
}
41
42impl fmt::Display for RateLimitScope {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 match self {
45 RateLimitScope::Subject { user_id, tenant_id } => {
46 write!(f, "subject:{}:{}", tenant_id, user_id)
47 }
48 RateLimitScope::IpAddress { ip } => write!(f, "ip:{}", ip),
49 RateLimitScope::FactorAttempt {
50 user_id,
51 tenant_id,
52 factor_id,
53 } => write!(f, "factor:{}:{}:{}", tenant_id, user_id, factor_id),
54 RateLimitScope::Endpoint { path, method } => write!(f, "endpoint:{}:{}", method, path),
55 RateLimitScope::Custom { key } => write!(f, "custom:{}", key),
56 }
57 }
58}
59
/// Limit parameters handed to a [`RateLimitStore`] alongside a scope.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum number of attempts allowed within one window.
    pub max_attempts: u32,

    /// Length of the counting window, in seconds (per the field name).
    pub window_secs: u64,

    /// Optional penalty duration in seconds; `None` (the value set by
    /// `RateLimitConfig::new`) means no penalty is configured.
    pub penalty_secs: Option<u64>,
}
72
73impl RateLimitConfig {
74 pub fn new(max_attempts: u32, window_secs: u64) -> Self {
75 Self {
76 max_attempts,
77 window_secs,
78 penalty_secs: None,
79 }
80 }
81
82 pub fn with_penalty(mut self, penalty_secs: u64) -> Self {
83 self.penalty_secs = Some(penalty_secs);
84 self
85 }
86}
87
/// Outcome of a rate-limit check, as returned by [`RateLimitStore`] methods.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RateLimitResult {
    /// Whether the attempt is permitted.
    pub allowed: bool,

    /// Number of attempts counted so far in the current window.
    pub current_attempts: u32,

    /// Limit the attempts were measured against (copied from the config).
    pub max_attempts: u32,

    /// Attempts left before the limit. `RateLimitResult::allowed` computes it
    /// as `max_attempts - current_attempts`, saturating at 0, and
    /// `RateLimitResult::denied` sets it to 0.
    pub remaining_attempts: u32,

    /// Timestamp at which the window resets.
    /// NOTE(review): the unit and epoch are not established here; the tests use
    /// `1234567890`, which looks like Unix seconds — confirm with the stores.
    pub reset_at: i64,

    /// Timestamp at which an active penalty expires, or `None` if none was
    /// attached (same unit as `reset_at`, presumably).
    pub penalty_expires_at: Option<i64>,
}
109
110impl RateLimitResult {
111 pub fn allowed(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
112 Self {
113 allowed: true,
114 current_attempts,
115 max_attempts,
116 remaining_attempts: max_attempts.saturating_sub(current_attempts),
117 reset_at,
118 penalty_expires_at: None,
119 }
120 }
121
122 pub fn denied(current_attempts: u32, max_attempts: u32, reset_at: i64) -> Self {
123 Self {
124 allowed: false,
125 current_attempts,
126 max_attempts,
127 remaining_attempts: 0,
128 reset_at,
129 penalty_expires_at: None,
130 }
131 }
132
133 pub fn with_penalty(mut self, penalty_expires_at: i64) -> Self {
134 self.penalty_expires_at = Some(penalty_expires_at);
135 self
136 }
137}
138
/// Backend that stores rate-limit counters and penalties per [`RateLimitScope`].
///
/// Every method is async (via `#[async_trait]`) and fails with
/// [`RateLimitError`]. The `Send + Sync` supertraits let an implementation be
/// shared across threads.
#[async_trait]
pub trait RateLimitStore: Send + Sync {
    /// Checks `scope` against `config` and records an attempt, returning the
    /// resulting state.
    ///
    /// NOTE(review): the contract is only implied by the name. Confirm in the
    /// implementations whether the check and increment are atomic, and whether a
    /// denied attempt is still counted.
    async fn check_and_increment(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError>;

    /// Reports the current state of `scope` under `config`. Per the name, this
    /// is presumably without recording an attempt — confirm in implementations.
    async fn check(
        &self,
        scope: &RateLimitScope,
        config: &RateLimitConfig,
    ) -> Result<RateLimitResult, RateLimitError>;

    /// Resets stored state for `scope`.
    /// NOTE(review): confirm whether this also clears an active penalty.
    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError>;

    /// Places `scope` under a penalty lasting `penalty_secs`. The unit is
    /// presumably seconds, matching `RateLimitConfig::penalty_secs`.
    async fn apply_penalty(
        &self,
        scope: &RateLimitScope,
        penalty_secs: u64,
    ) -> Result<(), RateLimitError>;

    /// Returns whether `scope` is currently under a penalty (per the name).
    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError>;
}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the `Display` rendering of every `RateLimitScope` variant.
    #[test]
    fn test_rate_limit_scope_display() {
        let scope = RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        };
        assert_eq!(scope.to_string(), "subject:tenant_a:user_1");

        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        assert_eq!(scope.to_string(), "ip:203.0.113.1");

        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };
        assert_eq!(scope.to_string(), "factor:tenant_a:user_1:totp");

        // Method is rendered before path; pin that ordering.
        let scope = RateLimitScope::Endpoint {
            path: "/login".to_string(),
            method: "POST".to_string(),
        };
        assert_eq!(scope.to_string(), "endpoint:POST:/login");

        let scope = RateLimitScope::Custom {
            key: "signup".to_string(),
        };
        assert_eq!(scope.to_string(), "custom:signup");
    }

    /// Covers the result constructors, remaining-attempt saturation and the
    /// penalty builder.
    #[test]
    fn test_rate_limit_result() {
        let result = RateLimitResult::allowed(3, 10, 1234567890);
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 3);
        assert_eq!(result.remaining_attempts, 7);
        assert_eq!(result.penalty_expires_at, None);

        // Over-count must saturate rather than underflow.
        let result = RateLimitResult::allowed(12, 10, 1234567890);
        assert_eq!(result.remaining_attempts, 0);

        let result = RateLimitResult::denied(10, 10, 1234567890);
        assert!(!result.allowed);
        assert_eq!(result.remaining_attempts, 0);
        assert_eq!(result.penalty_expires_at, None);

        let result = RateLimitResult::denied(10, 10, 1234567890).with_penalty(1234567990);
        assert!(!result.allowed);
        assert_eq!(result.penalty_expires_at, Some(1234567990));
    }

    /// Covers `RateLimitConfig::new` defaults and the penalty builder.
    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new(5, 60);
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.window_secs, 60);
        assert_eq!(config.penalty_secs, None);

        let config = config.with_penalty(300);
        assert_eq!(config.penalty_secs, Some(300));
        assert_eq!(config.max_attempts, 5);
    }
}