this/events/operators/
rate_limit.rs1use crate::config::events::RateLimitConfig;
15use crate::events::context::FlowContext;
16use crate::events::operators::deduplicate::parse_duration;
17use crate::events::operators::{OpResult, PipelineOperator};
18use anyhow::Result;
19use async_trait::async_trait;
20use std::sync::Arc;
21use std::time::{Duration, Instant};
22use tokio::sync::Mutex;
23
/// Token bucket that admits events at a bounded average rate.
///
/// The bucket starts full with `max_tokens` tokens. Each admitted event
/// consumes one token. Tokens are replenished lazily on every check, in
/// proportion to the time elapsed since the previous check, and never exceed
/// `max_tokens`.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; fractional amounts accumulate between checks.
    tokens: f64,
    /// Capacity of the bucket, i.e. the largest burst that can be admitted at once.
    max_tokens: f64,
    /// Replenishment speed in tokens per second (`f64::INFINITY` for a zero period).
    refill_rate: f64,
    /// Time of the most recent refill, used to compute elapsed time.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket that allows `max_tokens` events per `period`.
    ///
    /// A zero `period` means the bucket is refilled to capacity on every check,
    /// so only `max_tokens` matters: any positive capacity admits every event and
    /// a capacity of 0 admits none.
    fn new(max_tokens: u32, period: Duration) -> Self {
        let max = f64::from(max_tokens);
        // Dividing by a zero period would produce `inf` (or `NaN` when `max` is
        // also 0); represent "instant refill" explicitly instead.
        let refill_rate = if period.is_zero() {
            f64::INFINITY
        } else {
            max / period.as_secs_f64()
        };
        Self {
            tokens: max,
            max_tokens: max,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Refills the bucket, then takes one token if available.
    ///
    /// Returns `true` when a token was consumed, `false` when the bucket holds
    /// less than one whole token.
    fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Adds the tokens earned since the last refill, capped at `max_tokens`.
    fn refill(&mut self) {
        let now = Instant::now();
        if self.refill_rate.is_infinite() {
            // Instant refill. Computing `elapsed * inf` here would yield NaN
            // whenever no time has passed, so set the bucket to full directly.
            self.tokens = self.max_tokens;
        } else {
            let elapsed = now.duration_since(self.last_refill).as_secs_f64();
            self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        }
        self.last_refill = now;
    }
}
69
70#[derive(Debug)]
72pub struct RateLimitOp {
73 strategy: String,
75
76 bucket: Arc<Mutex<TokenBucket>>,
78}
79
80impl RateLimitOp {
81 pub fn from_config(config: &RateLimitConfig) -> Result<Self> {
83 let period = parse_duration(&config.per)?;
84 Ok(Self {
85 strategy: config.strategy.clone(),
86 bucket: Arc::new(Mutex::new(TokenBucket::new(config.max, period))),
87 })
88 }
89
90 #[cfg(test)]
92 fn with_params(max: u32, period: Duration) -> Self {
93 Self {
94 strategy: "drop".to_string(),
95 bucket: Arc::new(Mutex::new(TokenBucket::new(max, period))),
96 }
97 }
98}
99
100#[async_trait]
101impl PipelineOperator for RateLimitOp {
102 async fn execute(&self, _ctx: &mut FlowContext) -> Result<OpResult> {
103 let mut bucket = self.bucket.lock().await;
104 if bucket.try_consume() {
105 Ok(OpResult::Continue)
106 } else {
107 match self.strategy.as_str() {
108 "queue" => {
109 tracing::debug!("rate_limit: event queued (falling back to drop)");
112 Ok(OpResult::Drop)
113 }
114 _ => {
115 Ok(OpResult::Drop)
117 }
118 }
119 }
120 }
121
122 fn name(&self) -> &str {
123 "rate_limit"
124 }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::events::{EntityEvent, FrameworkEvent};
    use crate::core::link::LinkEntity;
    use crate::core::service::LinkService;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;
    use uuid::Uuid;

    /// Link service stub; the rate limiter never touches links, so every
    /// method is unimplemented.
    struct MockLinkService;

    #[async_trait]
    impl LinkService for MockLinkService {
        async fn create(&self, _: LinkEntity) -> Result<LinkEntity> {
            unimplemented!()
        }
        async fn get(&self, _: &Uuid) -> Result<Option<LinkEntity>> {
            unimplemented!()
        }
        async fn list(&self) -> Result<Vec<LinkEntity>> {
            unimplemented!()
        }
        async fn find_by_source(
            &self,
            _: &Uuid,
            _: Option<&str>,
            _: Option<&str>,
        ) -> Result<Vec<LinkEntity>> {
            unimplemented!()
        }
        async fn find_by_target(
            &self,
            _: &Uuid,
            _: Option<&str>,
            _: Option<&str>,
        ) -> Result<Vec<LinkEntity>> {
            unimplemented!()
        }
        async fn update(&self, _: &Uuid, _: LinkEntity) -> Result<LinkEntity> {
            unimplemented!()
        }
        async fn delete(&self, _: &Uuid) -> Result<()> {
            unimplemented!()
        }
        async fn delete_by_entity(&self, _: &Uuid) -> Result<()> {
            unimplemented!()
        }
    }

    /// Builds a context around a fresh "user created" entity event.
    fn make_context() -> FlowContext {
        let event = FrameworkEvent::Entity(EntityEvent::Created {
            entity_type: "user".to_string(),
            entity_id: Uuid::new_v4(),
            data: json!({}),
        });
        let links: Arc<dyn LinkService> = Arc::new(MockLinkService);
        FlowContext::new(event, links, HashMap::new())
    }

    /// Executes `op` once against a fresh context and returns its verdict.
    async fn run_once(op: &RateLimitOp) -> OpResult {
        let mut ctx = make_context();
        op.execute(&mut ctx).await.unwrap()
    }

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let op = RateLimitOp::with_params(3, Duration::from_secs(1));

        for _ in 0..3 {
            assert!(matches!(run_once(&op).await, OpResult::Continue));
        }
    }

    #[tokio::test]
    async fn test_rate_limit_drops_over_limit() {
        let op = RateLimitOp::with_params(2, Duration::from_secs(1));

        // The first two events fit in the bucket's capacity.
        for _ in 0..2 {
            assert!(matches!(run_once(&op).await, OpResult::Continue));
        }

        // The third one finds the bucket empty.
        assert!(matches!(run_once(&op).await, OpResult::Drop));
    }

    #[tokio::test]
    async fn test_rate_limit_refills_after_period() {
        let op = RateLimitOp::with_params(2, Duration::from_millis(50));

        // Drain the bucket.
        for _ in 0..2 {
            run_once(&op).await;
        }
        assert!(matches!(run_once(&op).await, OpResult::Drop));

        // Waiting longer than the full period restores capacity.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(matches!(run_once(&op).await, OpResult::Continue));
    }

    #[tokio::test]
    async fn test_rate_limit_partial_refill() {
        // 2 tokens per 100ms => one token roughly every 50ms.
        let op = RateLimitOp::with_params(2, Duration::from_millis(100));

        // Drain the bucket.
        for _ in 0..2 {
            run_once(&op).await;
        }

        // About 55ms earns just over one token.
        tokio::time::sleep(Duration::from_millis(55)).await;
        assert!(matches!(run_once(&op).await, OpResult::Continue));

        // The leftover fraction is not enough for another event.
        assert!(matches!(run_once(&op).await, OpResult::Drop));
    }
}