ts_webapi/middleware/
api_key.rs1use alloc::sync::Arc;
4use core::{
5 pin::Pin,
6 task::{Context, Poll, ready},
7};
8
9use http::{HeaderName, HeaderValue, Request, Response};
10use http_body::Body;
11use pin_project_lite::pin_project;
12use tower_layer::Layer;
13use tower_service::Service;
14
15use crate::{ErrorResponse, middleware::response_body::ResponseBody};
16
17#[derive(Debug, Clone)]
19pub struct ApiKeyAuth {
20 pub header: Arc<HeaderName>,
22 pub allowed_keys: Arc<Vec<String>>,
24}
25impl ApiKeyAuth {
26 pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
28 Self {
29 allowed_keys: Arc::new(allowed_keys),
30 header: Arc::new(header),
31 }
32 }
33
34 pub fn authenticate<T>(&self, request: &Request<T>) -> bool {
36 let Some(Ok(api_key)) = request
37 .headers()
38 .get(self.header.as_ref())
39 .map(HeaderValue::to_str)
40 else {
41 return false;
42 };
43
44 self.allowed_keys
45 .iter()
46 .map(String::as_str)
47 .any(|key| key.eq(api_key))
48 }
49}
50
51impl<S> Layer<S> for ApiKeyAuth {
52 type Service = ApiKeyAuthService<S>;
53
54 fn layer(&self, inner: S) -> Self::Service {
55 ApiKeyAuthService {
56 inner,
57 auth: self.clone(),
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
64pub struct ApiKeyAuthService<S> {
65 inner: S,
67 auth: ApiKeyAuth,
69}
70
71impl<S> ApiKeyAuthService<S> {
72 pub fn new(inner: S, auth: ApiKeyAuth) -> Self {
74 Self { inner, auth }
75 }
76}
77
78impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ApiKeyAuthService<S>
79where
80 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
81 ResBody: Body,
82{
83 type Response = Response<ResponseBody<ResBody>>;
84 type Error = S::Error;
85 type Future = ResponseFuture<S::Future>;
86
87 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
88 self.inner.poll_ready(cx)
89 }
90
91 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
92 if !self.auth.authenticate(&request) {
93 ResponseFuture {
94 inner: Kind::Unauthorized,
95 }
96 } else {
97 ResponseFuture {
98 inner: Kind::Ok {
99 future: self.inner.call(request),
100 },
101 }
102 }
103 }
104}
105
pin_project! {
    /// Response future returned by [`ApiKeyAuthService`].
    ///
    /// For an authenticated request, it resolves to the inner service's response
    /// with the body wrapped in `ResponseBody`. Otherwise it resolves to
    /// `ErrorResponse::unauthenticated()`.
    pub struct ResponseFuture<F> {
        // Pinned because the `Ok` state holds the inner service's future.
        #[pin]
        inner: Kind<F>,
    }
}
113
pin_project! {
    // State of a `ResponseFuture`, selected in `ApiKeyAuthService::call`.
    #[project = KindProj]
    enum Kind<F> {
        // Request was authenticated; `future` is the inner service's response future.
        Ok {
            #[pin]
            future: F,
        },
        // Request was rejected; polling yields `ErrorResponse::unauthenticated()`.
        Unauthorized,
    }
}
124
125impl<ResBody, F, E> Future for ResponseFuture<F>
126where
127 ResBody: Body,
128 F: Future<Output = Result<Response<ResBody>, E>>,
129{
130 type Output = Result<Response<ResponseBody<ResBody>>, E>;
131
132 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
133 let response = match self.project().inner.project() {
134 KindProj::Ok { future } => ready!(future.poll(cx))?.map(ResponseBody::new),
135 KindProj::Unauthorized => ErrorResponse::unauthenticated().as_response(),
136 };
137 Poll::Ready(Ok(response))
138 }
139}
140
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use crate::middleware::api_key::ApiKeyAuth;

    /// Header checked by every authenticator built in these tests.
    const HEADER: &str = "x-api-key";

    /// Builds an authenticator that accepts `api-key-1` and `api-key-2`.
    fn api_key_auth() -> ApiKeyAuth {
        ApiKeyAuth::new(
            HeaderName::from_static(HEADER),
            vec!["api-key-1".to_string(), "api-key-2".to_string()],
        )
    }

    /// Inner service that returns the request body unchanged.
    async fn echo(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Sends an empty request through `ApiKeyAuth` + `echo` and returns the status.
    ///
    /// The API-key header is set only when `api_key` is `Some`.
    async fn status_for(api_key: Option<&'static str>) -> StatusCode {
        let service = ServiceBuilder::new().layer(api_key_auth()).service_fn(echo);

        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(key) = api_key {
            request
                .headers_mut()
                .insert(HEADER, HeaderValue::from_static(key));
        }

        // `oneshot` waits for readiness itself; no separate `ready()` is needed.
        service.oneshot(request).await.unwrap().status()
    }

    /// Routes a `GET /` through an axum `Router` layered with `ApiKeyAuth`.
    async fn axum_status_for(api_key: Option<&'static str>) -> StatusCode {
        let router = Router::new()
            .route("/", get(|| async { StatusCode::OK }))
            .layer(api_key_auth());

        let mut builder = Request::builder().uri("/");
        if let Some(key) = api_key {
            builder = builder.header(HEADER, key);
        }

        router
            .oneshot(builder.body(axum::body::Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn axum_compat() {
        assert_eq!(StatusCode::UNAUTHORIZED, axum_status_for(None).await);
    }

    #[tokio::test]
    async fn axum_allows_allowed() {
        assert_eq!(StatusCode::OK, axum_status_for(Some("api-key-1")).await);
    }

    #[tokio::test]
    async fn blocks_no_header() {
        assert_eq!(StatusCode::UNAUTHORIZED, status_for(None).await);
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        assert_eq!(
            StatusCode::UNAUTHORIZED,
            status_for(Some("not-allowed-key")).await
        );
    }

    #[tokio::test]
    async fn blocks_prefix_of_allowed() {
        assert_eq!(StatusCode::UNAUTHORIZED, status_for(Some("api-key")).await);
    }

    #[tokio::test]
    async fn allows_allowed() {
        assert_eq!(StatusCode::OK, status_for(Some("api-key-1")).await);
    }

    #[tokio::test]
    async fn allows_any_configured_key() {
        assert_eq!(StatusCode::OK, status_for(Some("api-key-2")).await);
    }
}