ts_webapi/middleware/
api_key.rs

//! API key middleware that validates an API key is present and in the allow list.
2use alloc::sync::Arc;
3
4use bytes::Bytes;
5use futures_core::future::BoxFuture;
6use http::{HeaderName, Method, Request, Response};
7use http_body_util::Full;
8use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
9
10use crate::ErrorResponse;
11
/// The validated API key, inserted into the request extensions by [`ApiKeyAuth`] once a
/// request has been authorized, so downstream handlers can see which key was used.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ApiKey(pub String);
15
16/// API key middleware, validates an API key is present and in the allow list.
17#[derive(Debug, Clone)]
18pub struct ApiKeyAuth {
19    /// The allow list of API keys.
20    allowed_keys: Arc<Vec<String>>,
21    /// The header to find the API key.
22    header: Arc<HeaderName>,
23}
24impl ApiKeyAuth {
25    /// Create new API key authorization layer
26    pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
27        Self {
28            allowed_keys: Arc::new(allowed_keys),
29            header: Arc::new(header),
30        }
31    }
32
33    /// Create a new [`AsyncRequireAuthorizationLayer`] containing self.
34    pub fn new_layer(
35        header: HeaderName,
36        allowed_keys: Vec<String>,
37    ) -> AsyncRequireAuthorizationLayer<Self> {
38        let api_key_auth = Self::new(header, allowed_keys);
39        AsyncRequireAuthorizationLayer::new(api_key_auth)
40    }
41}
42
43impl<B> AsyncAuthorizeRequest<B> for ApiKeyAuth
44where
45    B: Send + Sync + 'static,
46{
47    type RequestBody = B;
48    type ResponseBody = Full<Bytes>;
49    type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
50
51    fn authorize(&mut self, request: Request<B>) -> Self::Future {
52        let header = self.header.clone();
53        let allowed_keys = self.allowed_keys.clone();
54        Box::pin(async move {
55            let (parts, body) = request.into_parts();
56
57            if parts.method == Method::OPTIONS {
58                let request = Request::from_parts(parts, body);
59                return Ok(request);
60            }
61
62            let Some(api_key) = parts.headers.get(header.as_ref()) else {
63                return Err(ErrorResponse::unauthenticated().into());
64            };
65
66            let Ok(api_key) = api_key.to_str() else {
67                return Err(ErrorResponse::unauthenticated().into());
68            };
69
70            if !allowed_keys
71                .iter()
72                .map(String::as_str)
73                .any(|key| key.eq(api_key))
74            {
75                return Err(ErrorResponse::unauthenticated().into());
76            }
77
78            let api_key = api_key.to_string();
79            let mut request = Request::from_parts(parts, body);
80            request.extensions_mut().insert(ApiKey(api_key));
81
82            Ok(request)
83        })
84    }
85}
86
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use super::{ApiKey, ApiKeyAuth};

    /// Header the middleware under test reads the key from.
    const HEADER: &str = "x-api-key";
    /// The only key the middleware under test accepts.
    const ALLOWED_KEY: &str = "api-key-1";

    /// Inner service: responds `200 OK` with the attached [`ApiKey`] (empty if none) as the
    /// body, so tests can check what the middleware inserted into the request extensions.
    async fn echo_api_key(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        let key = req
            .extensions()
            .get::<ApiKey>()
            .map(|api_key| api_key.0.clone())
            .unwrap_or_default();
        Ok(Response::new(Full::new(Bytes::from(key))))
    }

    /// Send `request` through an [`ApiKeyAuth`] layer wrapping [`echo_api_key`].
    async fn send(request: Request<Full<Bytes>>) -> Response<Full<Bytes>> {
        let api_key_auth = ApiKeyAuth::new_layer(
            HeaderName::from_static(HEADER),
            vec![ALLOWED_KEY.to_string()],
        );
        ServiceBuilder::new()
            .layer(api_key_auth)
            .service_fn(echo_api_key)
            // `oneshot` drives readiness itself, so no separate `ready()` call is needed.
            .oneshot(request)
            .await
            .unwrap()
    }

    /// Build a default (`GET`) request with an empty body and an optional API key header.
    fn request_with_key(key: Option<HeaderValue>) -> Request<Full<Bytes>> {
        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(key) = key {
            request.headers_mut().insert(HEADER, key);
        }
        request
    }

    /// Collect a response body into a `String`.
    async fn body_string(response: Response<Full<Bytes>>) -> String {
        let bytes = response.into_body().collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn blocks_no_header() {
        let response = send(request_with_key(None)).await;
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        let key = HeaderValue::from_static("not-allowed-key");
        let response = send(request_with_key(Some(key))).await;
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    #[tokio::test]
    async fn blocks_prefix_of_allowed_key() {
        let key = HeaderValue::from_static("api-key");
        let response = send(request_with_key(Some(key))).await;
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    #[tokio::test]
    async fn blocks_non_visible_ascii_header_value() {
        // Obs-text bytes form a legal header value but fail `HeaderValue::to_str`.
        let key = HeaderValue::from_bytes(b"\xffapi-key-1").unwrap();
        let response = send(request_with_key(Some(key))).await;
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    #[tokio::test]
    async fn allows_allowed() {
        let key = HeaderValue::from_static(ALLOWED_KEY);
        let response = send(request_with_key(Some(key))).await;
        assert_eq!(StatusCode::OK, response.status());
        // The accepted key must be exposed to the inner service as an `ApiKey` extension.
        assert_eq!(ALLOWED_KEY, body_string(response).await);
    }

    #[tokio::test]
    async fn allows_options_without_key() {
        let mut request = request_with_key(None);
        *request.method_mut() = Method::OPTIONS;
        let response = send(request).await;
        assert_eq!(StatusCode::OK, response.status());
        // `OPTIONS` bypasses authentication, so no `ApiKey` extension is attached.
        assert_eq!("", body_string(response).await);
    }
}