ts_webapi/middleware/
api_key.rs1use alloc::sync::Arc;
4use core::task::{Context, Poll};
5
6use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
7use http_body::Body;
8use tower_layer::Layer;
9use tower_service::Service;
10
11use crate::middleware::futures::DefinedFuture;
12
13#[derive(Debug, Clone)]
15pub struct ApiKeyAuth {
16 pub header: Arc<HeaderName>,
18 pub allowed_keys: Arc<Vec<String>>,
20}
21impl ApiKeyAuth {
22 pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
24 Self {
25 allowed_keys: Arc::new(allowed_keys),
26 header: Arc::new(header),
27 }
28 }
29
30 pub fn get_header<'a, T>(&self, request: &'a Request<T>) -> Option<&'a str> {
32 request
33 .headers()
34 .get(self.header.as_ref())
35 .map(HeaderValue::to_str)
36 .transpose()
37 .ok()
38 .flatten()
39 }
40
41 pub fn is_allowed_key(&self, key: &str) -> bool {
43 self.allowed_keys
44 .iter()
45 .map(String::as_str)
46 .any(|allowed| allowed.eq(key))
47 }
48}
49
50impl<S> Layer<S> for ApiKeyAuth {
51 type Service = ApiKeyAuthService<S>;
52
53 fn layer(&self, inner: S) -> Self::Service {
54 ApiKeyAuthService {
55 inner,
56 auth: self.clone(),
57 }
58 }
59}
60
61#[derive(Debug, Clone)]
63pub struct ApiKeyAuthService<S> {
64 inner: S,
66 auth: ApiKeyAuth,
68}
69
70impl<S> ApiKeyAuthService<S> {
71 pub fn new(inner: S, auth: ApiKeyAuth) -> Self {
73 Self { inner, auth }
74 }
75}
76
77impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for ApiKeyAuthService<S>
78where
79 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
80 ResBody: Body + Default,
81{
82 type Response = S::Response;
83 type Error = S::Error;
84 type Future = DefinedFuture<S::Future>;
85
86 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
87 self.inner.poll_ready(cx)
88 }
89
90 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
91 let Some(key) = self.auth.get_header(&request) else {
92 return DefinedFuture::return_status(StatusCode::UNAUTHORIZED);
93 };
94
95 if self.auth.is_allowed_key(key) {
96 DefinedFuture::proceed(self.inner.call(request))
97 } else {
98 DefinedFuture::return_status(StatusCode::FORBIDDEN)
99 }
100 }
101}
102
#[cfg(test)]
mod test {
    use axum::{Router, routing::get};
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Request, Response, StatusCode};
    use http_body_util::Full;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use crate::{middleware::api_key::ApiKeyAuth, test::ResponseTestExt};

    /// Header the layer under test reads the key from.
    const HEADER: &str = "x-api-key";
    /// The only key the layer under test accepts.
    const ALLOWED_KEY: &str = "api-key-1";

    /// Builds the layer under test: accepts only [`ALLOWED_KEY`] in [`HEADER`].
    fn api_key_auth() -> ApiKeyAuth {
        ApiKeyAuth::new(
            HeaderName::from_static(HEADER),
            vec![ALLOWED_KEY.to_string()],
        )
    }

    /// Inner service: responds with the request body unchanged.
    async fn echo(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Sends an empty request through the layer wrapped around [`echo`].
    ///
    /// The request carries `key` in [`HEADER`] when one is given. Returns the
    /// resulting response.
    async fn send(key: Option<&'static str>) -> Response<Full<Bytes>> {
        let mut request = Request::new(Full::new(Bytes::new()));
        if let Some(key) = key {
            request
                .headers_mut()
                .insert(HEADER, HeaderValue::from_static(key));
        }

        // `oneshot` drives `poll_ready` itself before calling the service.
        ServiceBuilder::new()
            .layer(api_key_auth())
            .service_fn(echo)
            .oneshot(request)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn axum_compat() {
        let router = Router::new()
            .route("/", get(|| async { StatusCode::OK }))
            .layer(api_key_auth());

        let request = Request::builder()
            .uri("/")
            .body(axum::body::Body::empty())
            .unwrap();

        router
            .oneshot(request)
            .await
            .unwrap()
            .expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_no_header() {
        send(None).await.expect_status(StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        send(Some("not-allowed-key"))
            .await
            .expect_status(StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn allows_allowed() {
        send(Some(ALLOWED_KEY))
            .await
            .expect_status(StatusCode::OK);
    }
}