tako_rs_plugins/middleware/
api_key_auth.rs1use std::borrow::Cow;
35use std::future::Future;
36use std::pin::Pin;
37use std::sync::Arc;
38
39use http::HeaderValue;
40use http::StatusCode;
41use http::header;
42use subtle::Choice;
43use subtle::ConstantTimeEq;
44use tako_rs_core::body::TakoBody;
45use tako_rs_core::middleware::IntoMiddleware;
46use tako_rs_core::middleware::Next;
47use tako_rs_core::responder::Responder;
48use tako_rs_core::types::Request;
49use tako_rs_core::types::Response;
50
51fn constant_time_contains(input: &[u8], candidates: &[Vec<u8>]) -> bool {
57 let mut found = Choice::from(0u8);
58 for candidate in candidates {
59 found |= input.ct_eq(candidate.as_slice());
60 }
61 bool::from(found)
62}
63
/// Where [`ApiKeyAuth`] looks for the API key in an incoming request.
///
/// The default is the `X-API-Key` header.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ApiKeyLocation {
    /// Read the key from the named request header (surrounding whitespace is trimmed).
    Header(&'static str),
    /// Read the key from the named URL query parameter (form-urlencoded decoded).
    Query(&'static str),
    /// Try the named header first; if it yields no key, fall back to the named
    /// query parameter.
    HeaderOrQuery(&'static str, &'static str),
}

impl Default for ApiKeyLocation {
    /// Returns [`ApiKeyLocation::Header`] with the header name `X-API-Key`.
    fn default() -> Self {
        Self::Header("X-API-Key")
    }
}
80
/// Shared, thread-safe callback that decides whether an extracted API key is valid.
///
/// The callback receives the key as a string slice and returns `true` to accept it.
pub type ApiKeyVerifyFn = Arc<dyn for<'key> Fn(&'key str) -> bool + Send + Sync + 'static>;
106
107pub struct ApiKeyAuth {
108 keys: Option<Vec<Vec<u8>>>,
110 verify: Option<ApiKeyVerifyFn>,
112 location: ApiKeyLocation,
114}
115
116impl ApiKeyAuth {
117 pub fn new(key: impl Into<String>) -> Self {
121 let key: String = key.into();
122 Self {
123 keys: Some(vec![key.into_bytes()]),
124 verify: None,
125 location: ApiKeyLocation::default(),
126 }
127 }
128
129 pub fn from_keys<I>(keys: I) -> Self
131 where
132 I: IntoIterator,
133 I::Item: Into<String>,
134 {
135 Self {
136 keys: Some(
137 keys
138 .into_iter()
139 .map(|k| Into::<String>::into(k).into_bytes())
140 .collect(),
141 ),
142 verify: None,
143 location: ApiKeyLocation::default(),
144 }
145 }
146
147 pub fn with_verify<F>(f: F) -> Self
149 where
150 F: Fn(&str) -> bool + Send + Sync + 'static,
151 {
152 Self {
153 keys: None,
154 verify: Some(Arc::new(f)),
155 location: ApiKeyLocation::default(),
156 }
157 }
158
159 pub fn from_keys_with_verify<I, F>(keys: I, f: F) -> Self
161 where
162 I: IntoIterator,
163 I::Item: Into<String>,
164 F: Fn(&str) -> bool + Send + Sync + 'static,
165 {
166 Self {
167 keys: Some(
168 keys
169 .into_iter()
170 .map(|k| Into::<String>::into(k).into_bytes())
171 .collect(),
172 ),
173 verify: Some(Arc::new(f)),
174 location: ApiKeyLocation::default(),
175 }
176 }
177
178 pub fn location(mut self, location: ApiKeyLocation) -> Self {
180 self.location = location;
181 self
182 }
183
184 pub fn header_name(mut self, name: &'static str) -> Self {
189 self.location = ApiKeyLocation::Header(name);
190 self
191 }
192
193 pub fn query_param(mut self, name: &'static str) -> Self {
198 self.location = ApiKeyLocation::Query(name);
199 self
200 }
201}
202
203fn extract_api_key<'a>(req: &'a Request, location: &ApiKeyLocation) -> Option<Cow<'a, str>> {
205 match location {
206 ApiKeyLocation::Header(name) => req
207 .headers()
208 .get(*name)
209 .and_then(|v| v.to_str().ok())
210 .map(|s| Cow::Borrowed(s.trim())),
211
212 ApiKeyLocation::Query(name) => req.uri().query().and_then(|q| {
213 url::form_urlencoded::parse(q.as_bytes())
214 .find(|(k, _)| k == *name)
215 .map(|(_, v)| v)
216 }),
217
218 ApiKeyLocation::HeaderOrQuery(header, query) => {
219 if let Some(key) = req
221 .headers()
222 .get(*header)
223 .and_then(|v| v.to_str().ok())
224 .map(|s| Cow::Borrowed(s.trim()))
225 {
226 return Some(key);
227 }
228 req.uri().query().and_then(|q| {
230 url::form_urlencoded::parse(q.as_bytes())
231 .find(|(k, _)| k == *query)
232 .map(|(_, v)| v)
233 })
234 }
235 }
236}
237
238impl IntoMiddleware for ApiKeyAuth {
239 fn into_middleware(
241 self,
242 ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
243 + Clone
244 + Send
245 + Sync
246 + 'static {
247 let keys = self.keys.map(Arc::new);
248 let verify = self.verify;
249 let location = self.location;
250 let api_key_authenticate = HeaderValue::from_static("ApiKey");
251
252 move |req: Request, next: Next| {
253 let keys = keys.clone();
254 let verify = verify.clone();
255 let location = location.clone();
256 let api_key_authenticate = api_key_authenticate.clone();
257
258 Box::pin(async move {
259 let Some(api_key) = extract_api_key(&req, &location) else {
261 return http::Response::builder()
262 .status(StatusCode::UNAUTHORIZED)
263 .header(header::WWW_AUTHENTICATE, api_key_authenticate.clone())
264 .body(TakoBody::from("API key is missing"))
265 .unwrap()
266 .into_response();
267 };
268
269 if let Some(set) = &keys
271 && constant_time_contains(api_key.as_bytes(), set)
272 {
273 return next.run(req).await.into_response();
274 }
275
276 if let Some(v) = verify.as_ref()
278 && v(api_key.as_ref())
279 {
280 return next.run(req).await.into_response();
281 }
282
283 http::Response::builder()
285 .status(StatusCode::UNAUTHORIZED)
286 .header(header::WWW_AUTHENTICATE, api_key_authenticate)
287 .body(TakoBody::from("Invalid API key"))
288 .unwrap()
289 .into_response()
290 })
291 }
292 }
293}