ts_webapi/middleware/
api_key.rs1use alloc::sync::Arc;
3
4use bytes::Bytes;
5use futures_core::future::BoxFuture;
6use http::{HeaderName, Method, Request, Response};
7use http_body_util::Full;
8use tower_http::auth::{AsyncAuthorizeRequest, AsyncRequireAuthorizationLayer};
9
10use crate::ErrorResponse;
11
/// The API key that authenticated the current request.
///
/// After a successful check, [`ApiKeyAuth`] inserts this into the request extensions so that
/// downstream handlers can tell which key was used.
#[derive(Clone)]
pub struct ApiKey(pub String);

// `Debug` is implemented by hand so the secret is never printed, e.g. when request
// extensions or handler state end up in logs.
impl core::fmt::Debug for ApiKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_tuple("ApiKey").field(&"<redacted>").finish()
    }
}
15
16#[derive(Debug, Clone)]
18pub struct ApiKeyAuth {
19 allowed_keys: Arc<Vec<String>>,
21 header: Arc<HeaderName>,
23}
24impl ApiKeyAuth {
25 pub fn new(header: HeaderName, allowed_keys: Vec<String>) -> Self {
27 Self {
28 allowed_keys: Arc::new(allowed_keys),
29 header: Arc::new(header),
30 }
31 }
32
33 pub fn new_layer(
35 header: HeaderName,
36 allowed_keys: Vec<String>,
37 ) -> AsyncRequireAuthorizationLayer<Self> {
38 let api_key_auth = Self::new(header, allowed_keys);
39 AsyncRequireAuthorizationLayer::new(api_key_auth)
40 }
41}
42
43impl<B> AsyncAuthorizeRequest<B> for ApiKeyAuth
44where
45 B: Send + Sync + 'static,
46{
47 type RequestBody = B;
48 type ResponseBody = Full<Bytes>;
49 type Future = BoxFuture<'static, Result<Request<B>, Response<Self::ResponseBody>>>;
50
51 fn authorize(&mut self, request: Request<B>) -> Self::Future {
52 let header = self.header.clone();
53 let allowed_keys = self.allowed_keys.clone();
54 Box::pin(async move {
55 let (parts, body) = request.into_parts();
56
57 if parts.method == Method::OPTIONS {
58 let request = Request::from_parts(parts, body);
59 return Ok(request);
60 }
61
62 let Some(api_key) = parts.headers.get(header.as_ref()) else {
63 return Err(ErrorResponse::unauthenticated().into());
64 };
65
66 let Ok(api_key) = api_key.to_str() else {
67 return Err(ErrorResponse::unauthenticated().into());
68 };
69
70 if !allowed_keys
71 .iter()
72 .map(String::as_str)
73 .any(|key| key.eq(api_key))
74 {
75 return Err(ErrorResponse::unauthenticated().into());
76 }
77
78 let api_key = api_key.to_string();
79 let mut request = Request::from_parts(parts, body);
80 request.extensions_mut().insert(ApiKey(api_key));
81
82 Ok(request)
83 })
84 }
85}
86
#[cfg(test)]
mod test {
    use bytes::Bytes;
    use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode};
    use http_body_util::{BodyExt, Full};
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::BoxError;

    use super::{ApiKey, ApiKeyAuth};

    /// Header the middleware under test reads the key from.
    const HEADER: &str = "x-api-key";
    /// The only key the middleware under test accepts.
    const ALLOWED_KEY: &str = "api-key-1";

    /// Inner service: replies 200 with the attached [`ApiKey`] as the body, or with an
    /// empty body when the middleware attached no key.
    async fn echo_api_key(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, BoxError> {
        let key = req
            .extensions()
            .get::<ApiKey>()
            .map(|api_key| api_key.0.clone())
            .unwrap_or_default();
        Ok(Response::new(Full::new(Bytes::from(key))))
    }

    /// Builds an empty-bodied request with `method`, optionally carrying `api_key` in the
    /// API-key header.
    fn request(method: Method, api_key: Option<&'static str>) -> Request<Full<Bytes>> {
        let mut request = Request::new(Full::new(Bytes::new()));
        *request.method_mut() = method;
        if let Some(key) = api_key {
            request
                .headers_mut()
                .insert(HEADER, HeaderValue::from_static(key));
        }
        request
    }

    /// Sends `request` through an `ApiKeyAuth` layer that allows only [`ALLOWED_KEY`].
    /// Returns the response status together with the collected response body.
    async fn send(request: Request<Full<Bytes>>) -> (StatusCode, Bytes) {
        let layer = ApiKeyAuth::new_layer(
            HeaderName::from_static(HEADER),
            vec![ALLOWED_KEY.to_string()],
        );
        // `oneshot` drives readiness itself, so no separate `ready()` call is needed.
        let response = ServiceBuilder::new()
            .layer(layer)
            .service_fn(echo_api_key)
            .oneshot(request)
            .await
            .unwrap();
        let status = response.status();
        let body = response.into_body().collect().await.unwrap().to_bytes();
        (status, body)
    }

    #[tokio::test]
    async fn blocks_no_header() {
        let (status, _) = send(request(Method::GET, None)).await;
        assert_eq!(StatusCode::UNAUTHORIZED, status);
    }

    #[tokio::test]
    async fn blocks_not_allowed() {
        let (status, _) = send(request(Method::GET, Some("not-allowed-key"))).await;
        assert_eq!(StatusCode::UNAUTHORIZED, status);
    }

    #[tokio::test]
    async fn blocks_near_miss_keys() {
        // Same length with one differing byte, a strict prefix, and a longer extension.
        for key in ["api-key-2", "api-key-", "api-key-10"] {
            let (status, _) = send(request(Method::GET, Some(key))).await;
            assert_eq!(StatusCode::UNAUTHORIZED, status, "key {key:?} must be rejected");
        }
    }

    #[tokio::test]
    async fn allows_allowed() {
        let (status, body) = send(request(Method::GET, Some(ALLOWED_KEY))).await;
        assert_eq!(StatusCode::OK, status);
        // The matched key is exposed to the inner service as an `ApiKey` extension.
        assert_eq!(ALLOWED_KEY.as_bytes(), &body[..]);
    }

    #[tokio::test]
    async fn allows_options_without_key() {
        let (status, body) = send(request(Method::OPTIONS, None)).await;
        assert_eq!(StatusCode::OK, status);
        assert!(
            body.is_empty(),
            "OPTIONS requests must not carry an ApiKey extension"
        );
    }
}