// salesforce_client/rate_limit.rs
//! Client-side request rate limiting built on the `governor` crate.

use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;

use governor::{Quota, RateLimiter as GovernorRateLimiter};
use tracing::{debug, warn};

use crate::error::{SfError, SfResult};
11
/// Tuning knobs for [`RateLimiter`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Sustained number of requests permitted per second.
    ///
    /// `u32::MAX` disables rate limiting entirely; `0` is treated as `1`.
    pub requests_per_second: u32,

    /// Maximum number of requests that may be issued back-to-back before the
    /// sustained rate applies; `0` is treated as `1`.
    pub burst_size: u32,
}

impl Default for RateLimitConfig {
    /// 4 requests per second with a burst of 10.
    fn default() -> Self {
        Self {
            requests_per_second: 4,
            burst_size: 10,
        }
    }
}

impl RateLimitConfig {
    /// Creates a configuration with the default values (see [`Default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the sustained request rate. Consumes and returns `self` for chaining.
    #[must_use]
    pub fn requests_per_second(mut self, rps: u32) -> Self {
        self.requests_per_second = rps;
        self
    }

    /// Sets the burst capacity. Consumes and returns `self` for chaining.
    #[must_use]
    pub fn burst_size(mut self, size: u32) -> Self {
        self.burst_size = size;
        self
    }

    /// A configuration that disables rate limiting (both fields set to `u32::MAX`).
    #[must_use]
    pub fn unlimited() -> Self {
        Self {
            requests_per_second: u32::MAX,
            burst_size: u32::MAX,
        }
    }
}
59
60pub struct RateLimiter {
62 limiter: Arc<
63 GovernorRateLimiter<
64 governor::state::NotKeyed,
65 governor::state::InMemoryState,
66 governor::clock::DefaultClock,
67 >,
68 >,
69 enabled: bool,
70}
71
72impl RateLimiter {
73 pub fn new(config: RateLimitConfig) -> Self {
75 let enabled = config.requests_per_second < u32::MAX;
76
77 if !enabled {
78 debug!("Rate limiting disabled");
79 return Self {
80 limiter: Arc::new(GovernorRateLimiter::direct(Quota::per_second(
81 NonZeroU32::new(1).unwrap(),
82 ))),
83 enabled: false,
84 };
85 }
86
87 let quota = Quota::per_second(
89 NonZeroU32::new(config.requests_per_second).unwrap_or(NonZeroU32::new(1).unwrap()),
90 )
91 .allow_burst(NonZeroU32::new(config.burst_size).unwrap_or(NonZeroU32::new(1).unwrap()));
92
93 let limiter = GovernorRateLimiter::direct(quota);
94
95 debug!(
96 "Rate limiter initialized: {} req/s, burst {}",
97 config.requests_per_second, config.burst_size
98 );
99
100 Self {
101 limiter: Arc::new(limiter),
102 enabled: true,
103 }
104 }
105
106 pub async fn acquire(&self) -> SfResult<()> {
110 if !self.enabled {
111 return Ok(());
112 }
113
114 self.limiter.until_ready().await;
116 debug!("Rate limit check passed");
117 Ok(())
118 }
119
120 pub fn try_acquire(&self) -> SfResult<()> {
124 if !self.enabled {
125 return Ok(());
126 }
127
128 match self.limiter.check() {
129 Ok(_) => Ok(()),
130 Err(not_until) => {
131 let wait_time = not_until.wait_time_from(governor::clock::Clock::now(
132 &governor::clock::DefaultClock::default(),
133 ));
134
135 warn!("Rate limit exceeded, need to wait {:?}", wait_time);
136
137 Err(SfError::RateLimit {
138 retry_after: Some(wait_time.as_secs()),
139 })
140 }
141 }
142 }
143
144 pub fn status(&self) -> RateLimitStatus {
146 if !self.enabled {
147 return RateLimitStatus {
148 available: true,
149 wait_time: None,
150 };
151 }
152
153 match self.limiter.check() {
154 Ok(_) => RateLimitStatus {
155 available: true,
156 wait_time: None,
157 },
158 Err(not_until) => {
159 let wait_time = not_until.wait_time_from(governor::clock::Clock::now(
160 &governor::clock::DefaultClock::default(),
161 ));
162
163 RateLimitStatus {
164 available: false,
165 wait_time: Some(wait_time),
166 }
167 }
168 }
169 }
170}
171
/// Snapshot returned by `RateLimiter::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitStatus {
    /// `true` if the status check was admitted (or limiting is disabled).
    pub available: bool,

    /// Time until the next permit when `available` is `false`; `None` otherwise.
    pub wait_time: Option<Duration>,
}
181
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new()
            .requests_per_second(10)
            .burst_size(20);

        assert_eq!(config.requests_per_second, 10);
        assert_eq!(config.burst_size, 20);
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();

        assert_eq!(config.requests_per_second, 4);
        assert_eq!(config.burst_size, 10);
    }

    #[tokio::test]
    async fn test_rate_limiter_acquire() {
        let config = RateLimitConfig::new()
            .requests_per_second(100)
            .burst_size(10);

        let limiter = RateLimiter::new(config);

        assert!(limiter.acquire().await.is_ok());
    }

    #[test]
    fn test_rate_limiter_disabled() {
        let config = RateLimitConfig::unlimited();
        let limiter = RateLimiter::new(config);

        assert!(!limiter.enabled);
        assert!(limiter.try_acquire().is_ok());
    }

    #[test]
    fn test_try_acquire_rejects_once_burst_is_spent() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new()
                .requests_per_second(1)
                .burst_size(2),
        );

        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(matches!(
            limiter.try_acquire(),
            Err(SfError::RateLimit { retry_after: Some(_), .. })
        ));
    }

    #[test]
    fn test_zero_config_values_are_clamped() {
        let limiter = RateLimiter::new(
            RateLimitConfig::new()
                .requests_per_second(0)
                .burst_size(0),
        );

        assert!(limiter.enabled);
        assert!(limiter.try_acquire().is_ok());
    }

    #[test]
    fn test_status_when_disabled() {
        let status = RateLimiter::new(RateLimitConfig::unlimited()).status();

        assert!(status.available);
        assert!(status.wait_time.is_none());
    }
}