ts_webapi/middleware/
api_key.rs

//! API key middleware: validates that an API key is present and in the allow list.
2
3use alloc::sync::Arc;
4use core::{
5    pin::Pin,
6    task::{Context, Poll, ready},
7};
8
9use http::{HeaderName, HeaderValue, Request, Response};
10use http_body::Body;
11use pin_project_lite::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::{ErrorResponse, middleware::response_body::ResponseBody};
16
17/// API key validation layer.
18#[derive(Debug, Clone)]
19pub struct ApiKeyAuth {
20    /// The header to find the API key in.
21    pub header: Arc<HeaderName>,
22    /// The list of allowed keys.
23    pub allowed_keys: Arc<Vec<String>>,
24}
25impl ApiKeyAuth {
26    /// Create new API key auth.
27    pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
28        Self {
29            allowed_keys: Arc::new(allowed_keys),
30            header: Arc::new(header),
31        }
32    }
33
34    /// Authenticate a request against this auth
35    pub fn authenticate<T>(&self, request: &Request<T>) -> bool {
36        let Some(Ok(api_key)) = request
37            .headers()
38            .get(self.header.as_ref())
39            .map(HeaderValue::to_str)
40        else {
41            return false;
42        };
43
44        self.allowed_keys
45            .iter()
46            .map(String::as_str)
47            .any(|key| key.eq(api_key))
48    }
49}
50
51impl<S> Layer<S> for ApiKeyAuth {
52    type Service = ApiKeyAuthService<S>;
53
54    fn layer(&self, inner: S) -> Self::Service {
55        ApiKeyAuthService {
56            inner,
57            auth: self.clone(),
58        }
59    }
60}
61
62/// Tower service behind the API key auth middleware.
63#[derive(Debug, Clone)]
64pub struct ApiKeyAuthService<S> {
65    /// Inner service.
66    inner: S,
67    /// The logic layer.
68    auth: ApiKeyAuth,
69}
70
71impl<S> ApiKeyAuthService<S> {
72    /// Create a new service.
73    pub fn new(inner: S, auth: ApiKeyAuth) -> Self {
74        Self { inner, auth }
75    }
76}
77
78impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ApiKeyAuthService<S>
79where
80    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
81    ResBody: Body,
82{
83    type Response = Response<ResponseBody<ResBody>>;
84    type Error = S::Error;
85    type Future = ResponseFuture<S::Future>;
86
87    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88        self.inner.poll_ready(cx)
89    }
90
91    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
92        if !self.auth.authenticate(&request) {
93            ResponseFuture {
94                inner: Kind::Unauthorized,
95            }
96        } else {
97            ResponseFuture {
98                inner: Kind::Ok {
99                    future: self.inner.call(request),
100                },
101            }
102        }
103    }
104}
105
106pin_project! {
107    /// Response future for [`ApiKeyAuthService`].
108    pub struct ResponseFuture<F> {
109        #[pin]
110        inner: Kind<F>,
111    }
112}
113
114pin_project! {
115    #[project = KindProj]
116    enum Kind<F> {
117        Ok {
118            #[pin]
119            future: F,
120        },
121        Unauthorized,
122    }
123}
124
125impl<ResBody, F, E> Future for ResponseFuture<F>
126where
127    ResBody: Body,
128    F: Future<Output = Result<Response<ResBody>, E>>,
129{
130    type Output = Result<Response<ResponseBody<ResBody>>, E>;
131
132    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133        let response = match self.project().inner.project() {
134            KindProj::Ok { future } => ready!(future.poll(cx))?.map(ResponseBody::new),
135            KindProj::Unauthorized => ErrorResponse::unauthenticated().as_response(),
136        };
137        Poll::Ready(Ok(response))
138    }
139}
140
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use crate::middleware::api_key::ApiKeyAuth;

    /// Header the tests carry the API key in.
    const KEY_HEADER: &str = "x-api-key";
    /// The only key on the allow list.
    const VALID_KEY: &str = "api-key-1";

    /// Build the auth layer shared by every test.
    fn auth_layer() -> ApiKeyAuth {
        ApiKeyAuth::new(
            HeaderName::from_static(KEY_HEADER),
            vec![VALID_KEY.to_string()],
        )
    }

    /// Inner service used behind the layer: echoes the request body back.
    async fn echo(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Build an empty-bodied request, optionally carrying `api_key` in the key header.
    fn request_with_key(api_key: Option<&'static str>) -> Request<Full<Bytes>> {
        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(key) = api_key {
            request
                .headers_mut()
                .insert(KEY_HEADER, HeaderValue::from_static(key));
        }
        request
    }

    /// Send `request` through the layered echo service and return the response status.
    async fn status_through_layer(request: Request<Full<Bytes>>) -> StatusCode {
        let mut service = ServiceBuilder::new().layer(auth_layer()).service_fn(echo);

        service
            .ready()
            .await
            .unwrap()
            .oneshot(request)
            .await
            .unwrap()
            .status()
    }

    /// The layer can be attached to an axum router and rejects a key-less request.
    #[tokio::test]
    async fn axum_compat() {
        let router = Router::new()
            .route("/", get(|| async { StatusCode::OK }))
            .layer(auth_layer());

        let request = Request::builder()
            .uri("/")
            .body(axum::body::Body::empty())
            .unwrap();

        let response = router.oneshot(request).await.unwrap();
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    /// A request without the key header is rejected.
    #[tokio::test]
    async fn blocks_no_header() {
        let status = status_through_layer(request_with_key(None)).await;
        assert_eq!(StatusCode::UNAUTHORIZED, status);
    }

    /// A request carrying a key that is not on the allow list is rejected.
    #[tokio::test]
    async fn blocks_not_allowed() {
        let status = status_through_layer(request_with_key(Some("not-allowed-key"))).await;
        assert_eq!(StatusCode::UNAUTHORIZED, status);
    }

    /// A request carrying an allowed key reaches the inner service.
    #[tokio::test]
    async fn allows_allowed() {
        let status = status_through_layer(request_with_key(Some(VALID_KEY))).await;
        assert_eq!(StatusCode::OK, status);
    }
}