ts_webapi/middleware/
api_key.rs

//! API key middleware: validates that a request carries an API key and that the key is in the allow list.
2
3use alloc::sync::Arc;
4use core::task::{Context, Poll};
5
6use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
7use http_body::Body;
8use tower_layer::Layer;
9use tower_service::Service;
10
11use crate::middleware::futures::DefinedFuture;
12
13/// API key validation layer.
14#[derive(Debug, Clone)]
15pub struct ApiKeyAuth {
16    /// The header to find the API key in.
17    pub header: Arc<HeaderName>,
18    /// The list of allowed keys.
19    pub allowed_keys: Arc<Vec<String>>,
20}
21impl ApiKeyAuth {
22    /// Create new API key auth.
23    pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
24        Self {
25            allowed_keys: Arc::new(allowed_keys),
26            header: Arc::new(header),
27        }
28    }
29
30    /// Try get the API key header from a request.
31    pub fn get_header<'a, T>(&self, request: &'a Request<T>) -> Option<&'a str> {
32        request
33            .headers()
34            .get(self.header.as_ref())
35            .map(HeaderValue::to_str)
36            .transpose()
37            .ok()
38            .flatten()
39    }
40
41    /// Check if a given API key is allowed
42    pub fn is_allowed_key(&self, key: &str) -> bool {
43        self.allowed_keys
44            .iter()
45            .map(String::as_str)
46            .any(|allowed| allowed.eq(key))
47    }
48}
49
50impl<S> Layer<S> for ApiKeyAuth {
51    type Service = ApiKeyAuthService<S>;
52
53    fn layer(&self, inner: S) -> Self::Service {
54        ApiKeyAuthService {
55            inner,
56            auth: self.clone(),
57        }
58    }
59}
60
61/// Tower service behind the API key auth middleware.
62#[derive(Debug, Clone)]
63pub struct ApiKeyAuthService<S> {
64    /// Inner service.
65    inner: S,
66    /// The logic layer.
67    auth: ApiKeyAuth,
68}
69
70impl<S> ApiKeyAuthService<S> {
71    /// Create a new service.
72    pub fn new(inner: S, auth: ApiKeyAuth) -> Self {
73        Self { inner, auth }
74    }
75}
76
77impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ApiKeyAuthService<S>
78where
79    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
80    ResBody: Body + Default,
81{
82    type Response = S::Response;
83    type Error = S::Error;
84    type Future = DefinedFuture<S::Future>;
85
86    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87        self.inner.poll_ready(cx)
88    }
89
90    fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
91        let Some(key) = self.auth.get_header(&request) else {
92            return DefinedFuture::return_status(StatusCode::UNAUTHORIZED);
93        };
94
95        if self.auth.is_allowed_key(key) {
96            DefinedFuture::proceed(self.inner.call(request))
97        } else {
98            DefinedFuture::return_status(StatusCode::FORBIDDEN)
99        }
100    }
101}
102
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use crate::{middleware::api_key::ApiKeyAuth, test::ResponseTestExt};

    /// Header the middleware is configured to read in these tests.
    const HEADER: &str = "x-api-key";

    /// Inner service: echoes the request body back as the response body.
    async fn echo(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Auth layer reading `x-api-key`, with the single allowed key `api-key-1`.
    fn api_key_auth() -> ApiKeyAuth {
        ApiKeyAuth::new(HeaderName::from_static(HEADER), vec!["api-key-1".to_string()])
    }

    /// Send one empty-bodied request through `auth` wrapping [`echo`], with
    /// `api_key` (if any) set as the `x-api-key` header.
    async fn send(auth: ApiKeyAuth, api_key: Option<HeaderValue>) -> Response<Full<Bytes>> {
        let service = ServiceBuilder::new().layer(auth).service_fn(echo);

        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(value) = api_key {
            request.headers_mut().insert(HEADER, value);
        }

        // `oneshot` drives `poll_ready` itself; no separate `ready()` is needed.
        service.oneshot(request).await.unwrap()
    }

    #[tokio::test]
    async fn axum_compat() {
        let router = Router::new()
            .route("/", get(|| async { StatusCode::OK }))
            .layer(api_key_auth());

        router
            .oneshot(
                Request::builder()
                    .uri("/")
                    .body(axum::body::Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap()
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_no_header() {
        send(api_key_auth(), None)
            .await
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_non_ascii_header() {
        // A valid header value that `HeaderValue::to_str` rejects.
        let value = HeaderValue::from_bytes(b"api-key-1\xff").unwrap();
        send(api_key_auth(), Some(value))
            .await
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        send(api_key_auth(), Some(HeaderValue::from_static("not-allowed-key")))
            .await
            .expect_status(StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn blocks_prefix_of_allowed() {
        send(api_key_auth(), Some(HeaderValue::from_static("api-key")))
            .await
            .expect_status(StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn allows_allowed() {
        send(api_key_auth(), Some(HeaderValue::from_static("api-key-1")))
            .await
            .expect_status(StatusCode::OK);
    }

    #[tokio::test]
    async fn allows_any_listed_key() {
        let auth = ApiKeyAuth::new(
            HeaderName::from_static(HEADER),
            vec!["api-key-1".to_string(), "api-key-2".to_string()],
        );

        send(auth, Some(HeaderValue::from_static("api-key-2")))
            .await
            .expect_status(StatusCode::OK);
    }
}