1use actix_service::{Service, Transform};
2use actix_web::{
3 dev::{ServiceRequest, ServiceResponse},
4 error::Error as ActixError,
5 http::{header, StatusCode},
6 HttpResponse,
7};
8use futures_util::future::{ready, LocalBoxFuture, Ready};
9use std::{
10 collections::{HashMap, HashSet},
11 rc::Rc,
12 sync::{Arc, Mutex},
13 time::{Duration, Instant},
14};
15use thiserror::Error;
16
17pub type ConditionFunction = Rc<Box<dyn for<'a> Fn(&'a ServiceRequest) -> bool + 'static>>;
18pub type InterceptFunction = Rc<dyn Fn(&ServiceRequest) -> bool>;
19pub type RateLimiters = Arc<Mutex<HashMap<String, (u64, Instant)>>>;
20pub type AllowedOrigins = HashSet<String>;
21
22impl std::fmt::Debug for UnifiedMiddleware {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("UnifiedMiddleware")
25 .field("allowed_origins", &self.allowed_origins)
26 .field("max_requests", &self.max_requests)
27 .field("window_duration", &self.window_duration)
28 .finish()
29 }
30}
31
32pub struct UnifiedMiddleware {
33 pub allowed_origins: AllowedOrigins,
34 pub rate_limiters: RateLimiters,
35 pub max_requests: usize,
36 pub window_duration: Duration,
37 pub intercept_dependencies: InterceptFunction,
38 pub condition: ConditionFunction,
39}
40
41#[derive(Debug, Error)]
42pub enum UnifiedError {
43 #[error("An internal error occurred in the middleware.")]
44 InternalMiddlewareError,
45 #[error("Invalid request.")]
46 InvalidRequest,
47 #[error("Unauthorized access.")]
48 Unauthorized,
49}
50
51impl actix_web::ResponseError for UnifiedError {
52 fn status_code(&self) -> StatusCode {
53 match self {
54 UnifiedError::InternalMiddlewareError => StatusCode::INTERNAL_SERVER_ERROR,
55 UnifiedError::InvalidRequest => StatusCode::BAD_REQUEST,
56 UnifiedError::Unauthorized => StatusCode::UNAUTHORIZED,
57 }
58 }
59
60 fn error_response(&self) -> HttpResponse {
61 HttpResponse::build(self.status_code())
62 .content_type("application/json")
63 .body(format!("{{\"error\": \"{}\"}}", self))
64 }
65}
66
67pub type OptionalConditionFunction =
68 Option<Box<dyn for<'a> Fn(&'a ServiceRequest) -> bool + 'static>>;
69
70impl UnifiedMiddleware {
71 pub fn new(
83 allowed_origins: String,
84 rate_limiters: RateLimiters,
85 max_requests: usize,
86 window_duration: Duration,
87 intercept_dependencies: InterceptFunction,
88 condition: OptionalConditionFunction,
89 ) -> Self {
90 let origins: AllowedOrigins =
91 allowed_origins.split(',').map(|s| s.trim().to_string()).collect();
92
93 let default_condition: Box<dyn for<'a> Fn(&'a ServiceRequest) -> bool + 'static> =
94 Box::new(|_| true);
95
96 Self {
97 allowed_origins: origins,
98 rate_limiters,
99 max_requests,
100 window_duration,
101 intercept_dependencies,
102 condition: Rc::new(condition.unwrap_or(default_condition)),
103 }
104 }
105
106 pub fn simple(
118 allowed_origins: Vec<String>,
119 max_requests: usize,
120 window_duration: Duration,
121 ) -> Self {
122 Self::new(
123 allowed_origins.join(","),
124 Arc::new(Mutex::new(HashMap::new())),
125 max_requests,
126 window_duration,
127 Rc::new(|_| true),
128 Some(Box::new(|_| true)),
129 )
130 }
131}
132
133impl<S, B> Transform<S, ServiceRequest> for UnifiedMiddleware
134where
135 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
136 B: 'static,
137{
138 type Response = ServiceResponse<B>;
139 type Error = ActixError;
140 type Transform = UnifiedMiddlewareService<S>;
141 type InitError = ();
142 type Future = Ready<Result<Self::Transform, Self::InitError>>;
143
144 fn new_transform(&self, service: S) -> Self::Future {
145 ready(Ok(UnifiedMiddlewareService {
146 service: Rc::new(service),
147 allowed_origins: self.allowed_origins.clone(),
148 rate_limiters: self.rate_limiters.clone(),
149 max_requests: self.max_requests,
150 window_duration: self.window_duration,
151 intercept_dependencies: self.intercept_dependencies.clone(),
152 condition: self.condition.clone(),
153 }))
154 }
155}
156
157pub struct UnifiedMiddlewareService<S> {
158 service: Rc<S>,
159 allowed_origins: AllowedOrigins,
160 rate_limiters: RateLimiters,
161 max_requests: usize,
162 window_duration: Duration,
163 intercept_dependencies: InterceptFunction,
164 condition: ConditionFunction,
165}
166
167impl<S, B> Service<ServiceRequest> for UnifiedMiddlewareService<S>
168where
169 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
170 B: 'static,
171{
172 type Response = ServiceResponse<B>;
173 type Error = ActixError;
174 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
175
176 fn poll_ready(
177 &self,
178 cx: &mut std::task::Context<'_>,
179 ) -> std::task::Poll<Result<(), Self::Error>> {
180 self.service.poll_ready(cx)
181 }
182
183 fn call(&self, req: ServiceRequest) -> Self::Future {
184 let service = self.service.clone();
185 let condition = self.condition.clone();
186 let intercept = self.intercept_dependencies.clone();
187 let allowed_origins = self.allowed_origins.clone();
188 let rate_limiters = self.rate_limiters.clone();
189 let max_requests = self.max_requests;
190 let window_duration = self.window_duration;
191
192 Box::pin(async move {
193 if !(*condition)(&req) {
195 return service.call(req).await;
196 }
197
198 if (*intercept)(&req) {
199 check_origin(&req, &allowed_origins)?;
201
202 check_rate_limit(&req, rate_limiters, max_requests, window_duration)?;
204 }
205
206 service.call(req).await
207 })
208 }
209}
210
211fn check_origin(req: &ServiceRequest, allowed_origins: &AllowedOrigins) -> Result<(), ActixError> {
213 if let Some(origin) = req.headers().get(header::ORIGIN) {
214 if let Ok(origin_str) = origin.to_str() {
215 if !allowed_origins.contains(origin_str) && !allowed_origins.contains("*") {
216 return Err(UnifiedError::Unauthorized.into());
217 }
218 }
219 }
220 Ok(())
221}
222
223fn check_rate_limit(
225 req: &ServiceRequest,
226 rate_limiters: RateLimiters,
227 max_requests: usize,
228 window_duration: Duration,
229) -> Result<(), ActixError> {
230 let client_ip = match req.connection_info().realip_remote_addr() {
231 Some(ip) => ip.to_string(),
232 None => "unknown".to_string(),
233 };
234
235 let should_limit =
236 update_rate_limiter(&client_ip, rate_limiters, max_requests, window_duration)?;
237
238 if should_limit {
239 return Err(ActixError::from(UnifiedError::InvalidRequest));
240 }
241
242 Ok(())
243}
244
245fn update_rate_limiter(
247 client_ip: &str,
248 rate_limiters: RateLimiters,
249 max_requests: usize,
250 window_duration: Duration,
251) -> Result<bool, ActixError> {
252 let mut limiters = rate_limiters.lock().map_err(|_| UnifiedError::InternalMiddlewareError)?;
253
254 let now = Instant::now();
255 let entry = limiters.entry(client_ip.to_string()).or_insert_with(|| (0, now));
256
257 if now.duration_since(entry.1) > window_duration {
258 *entry = (1, now);
260 Ok(false)
261 } else {
262 entry.0 += 1;
264 Ok(entry.0 > max_requests as u64)
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{
        dev::Service,
        http::StatusCode,
        test::{init_service, TestRequest},
        web, App, HttpResponse,
    };
    use std::{sync::Arc, time::Duration};

    /// Trivial endpoint used as the wrapped service in every test.
    async fn test_handler() -> HttpResponse {
        HttpResponse::Ok().body("test success")
    }

    /// Builds a middleware with a fresh counter map, checks enabled for every request
    /// and the default condition.
    fn build_middleware(
        origins: &str,
        max_requests: usize,
        window_duration: Duration,
    ) -> UnifiedMiddleware {
        UnifiedMiddleware::new(
            origins.to_string(),
            Arc::new(Mutex::new(HashMap::new())),
            max_requests,
            window_duration,
            Rc::new(|_| true),
            None,
        )
    }

    /// Requests up to the limit succeed; the next one within the window is an error.
    #[actix_web::test]
    async fn test_rate_limiting() {
        let max_requests = 2;
        let middleware = build_middleware("*", max_requests, Duration::from_secs(1));

        let app =
            init_service(App::new().wrap(middleware).route("/test", web::get().to(test_handler)))
                .await;

        for _ in 0..max_requests {
            let within_budget = TestRequest::get().uri("/test").to_request();
            let resp = app.call(within_budget).await.unwrap();
            assert_eq!(resp.status(), StatusCode::OK);
        }

        let over_budget = TestRequest::get().uri("/test").to_request();
        let outcome = app.call(over_budget).await;
        assert!(outcome.is_err());
    }

    /// An allow-listed origin is served; an origin outside the list is an error.
    #[actix_web::test]
    async fn test_allowed_origins() {
        let middleware = build_middleware(
            "https://example.com,https://test.com",
            100,
            Duration::from_secs(60),
        );

        let app =
            init_service(App::new().wrap(middleware).route("/test", web::get().to(test_handler)))
                .await;

        let listed = TestRequest::get()
            .uri("/test")
            .insert_header((header::ORIGIN, "https://example.com"))
            .to_request();
        let resp = app.call(listed).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let unlisted = TestRequest::get()
            .uri("/test")
            .insert_header((header::ORIGIN, "https://unauthorized.com"))
            .to_request();
        let outcome = app.call(unlisted).await;
        assert!(outcome.is_err());
    }

    /// After the window has elapsed a previously limited client is served again.
    #[actix_web::test]
    async fn test_reset_rate_limiting_window() {
        let window_duration = Duration::from_millis(10);
        let middleware = build_middleware("*", 1, window_duration);

        let app =
            init_service(App::new().wrap(middleware).route("/test", web::get().to(test_handler)))
                .await;

        let first = TestRequest::get().uri("/test").to_request();
        let resp = app.call(first).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let second = TestRequest::get().uri("/test").to_request();
        let outcome = app.call(second).await;
        assert!(outcome.is_err());

        // Wait comfortably past the window so the counter resets.
        tokio::time::sleep(window_duration * 2).await;

        let after_reset = TestRequest::get().uri("/test").to_request();
        let resp = app.call(after_reset).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}