Skip to main content

systemprompt_api/services/middleware/
ip_ban.rs

1use axum::extract::Request;
2use axum::middleware::Next;
3use axum::response::{IntoResponse, Response};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use systemprompt_models::api::ApiError;
8use systemprompt_users::BannedIpRepository;
9use tracing::warn;
10
11#[derive(Clone, Copy, Debug)]
12pub struct IpBanMiddleware;
13
14impl IpBanMiddleware {
15    fn extract_ip(request: &Request) -> Option<String> {
16        request
17            .headers()
18            .get("x-forwarded-for")
19            .and_then(|v| v.to_str().ok())
20            .and_then(|s| s.split(',').next())
21            .map(|s| s.trim().to_string())
22            .or_else(|| {
23                request
24                    .headers()
25                    .get("x-real-ip")
26                    .and_then(|v| v.to_str().ok())
27                    .map(ToString::to_string)
28            })
29            .or_else(|| {
30                request
31                    .headers()
32                    .get("cf-connecting-ip")
33                    .and_then(|v| v.to_str().ok())
34                    .map(ToString::to_string)
35            })
36    }
37}
38
39pub async fn ip_ban_middleware(
40    request: Request,
41    next: Next,
42    banned_ip_repo: Arc<BannedIpRepository>,
43) -> Response {
44    let ip_address = IpBanMiddleware::extract_ip(&request);
45
46    if let Some(ip) = &ip_address {
47        match banned_ip_repo.is_banned(ip).await {
48            Ok(true) => {
49                warn!(ip = %ip, path = %request.uri().path(), "Blocked request from banned IP");
50                let api_error = ApiError::forbidden("Access denied");
51                let mut response = api_error.into_response();
52                response.headers_mut().insert(
53                    "X-Blocked-Reason",
54                    http::HeaderValue::from_static("ip-banned"),
55                );
56                return response;
57            },
58            Ok(false) => {},
59            Err(e) => {
60                tracing::error!(error = %e, ip = %ip, "Failed to check IP ban status");
61            },
62        }
63    }
64
65    next.run(request).await
66}
67
68pub fn ip_ban_layer(
69    banned_ip_repo: Arc<BannedIpRepository>,
70) -> axum::middleware::FromFnLayer<
71    impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send>> + Clone + Send,
72    (),
73    Request,
74> {
75    axum::middleware::from_fn(move |req: Request, next: Next| {
76        let repo = banned_ip_repo.clone();
77        let fut: Pin<Box<dyn Future<Output = Response> + Send>> =
78            Box::pin(async move { ip_ban_middleware(req, next, repo).await });
79        fut
80    })
81}