Skip to main content

tako_rs_plugins/middleware/
api_key_auth.rs

1//! API Key authentication middleware for simple token-based access control.
2//!
3//! This module provides middleware for validating API keys from HTTP headers or query
4//! parameters. It supports multiple key sources, custom header names, and dynamic
5//! key verification functions for flexible authentication strategies.
6//!
7//! # Examples
8//!
9//! ```rust,ignore
10//! use tako::middleware::api_key_auth::{ApiKeyAuth, ApiKeyLocation};
11//! use tako::middleware::IntoMiddleware;
12//!
13//! // Single API key from header
14//! let auth = ApiKeyAuth::new("secret-api-key");
15//! let middleware = auth.into_middleware();
16//!
17//! // Multiple valid keys
18//! let multi_auth = ApiKeyAuth::from_keys(["key1", "key2", "admin-key"]);
19//!
20//! // Custom header name
21//! let custom_auth = ApiKeyAuth::new("secret")
22//!     .header_name("X-Custom-Key");
23//!
24//! // From query parameter
25//! let query_auth = ApiKeyAuth::new("secret")
26//!     .location(ApiKeyLocation::Query("api_key"));
27//!
28//! // Dynamic verification
29//! let dynamic_auth = ApiKeyAuth::with_verify(|key| {
30//!     key.starts_with("valid_")
31//! });
32//! ```
33
34use std::borrow::Cow;
35use std::future::Future;
36use std::pin::Pin;
37use std::sync::Arc;
38
39use http::HeaderValue;
40use http::StatusCode;
41use http::header;
42use subtle::Choice;
43use subtle::ConstantTimeEq;
44use tako_rs_core::body::TakoBody;
45use tako_rs_core::middleware::IntoMiddleware;
46use tako_rs_core::middleware::Next;
47use tako_rs_core::responder::Responder;
48use tako_rs_core::types::Request;
49use tako_rs_core::types::Response;
50
51/// Constant-time match against a list of candidate keys.
52///
53/// Iterates the full list every call; per-byte comparison uses `subtle::ConstantTimeEq`
54/// so equal-length matches do not leak via wall-clock. Length mismatches still return
55/// faster than equal-length compares — clients learn key length but not contents.
56fn constant_time_contains(input: &[u8], candidates: &[Vec<u8>]) -> bool {
57  let mut found = Choice::from(0u8);
58  for candidate in candidates {
59    found |= input.ct_eq(candidate.as_slice());
60  }
61  bool::from(found)
62}
63
/// Location where the API key should be extracted from.
///
/// The default is [`ApiKeyLocation::Header`] with the name `X-API-Key`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ApiKeyLocation {
  /// Extract from the HTTP header with the given name.
  Header(&'static str),
  /// Extract from the query parameter with the given name.
  Query(&'static str),
  /// Try the header (first field) first, then the query parameter (second field).
  HeaderOrQuery(&'static str, &'static str),
}

impl Default for ApiKeyLocation {
  /// Returns `ApiKeyLocation::Header("X-API-Key")`.
  fn default() -> Self {
    Self::Header("X-API-Key")
  }
}
80
81/// API Key authentication middleware configuration.
82///
83/// `ApiKeyAuth` provides flexible configuration for API key authentication,
84/// supporting static keys, dynamic verification, and multiple extraction locations.
85///
86/// # Examples
87///
88/// ```rust
89/// use tako::middleware::api_key_auth::{ApiKeyAuth, ApiKeyLocation};
90///
91/// // Simple static key
92/// let auth = ApiKeyAuth::new("my-secret-key");
93///
94/// // Multiple keys with custom location
95/// let auth = ApiKeyAuth::from_keys(["key1", "key2"])
96///     .location(ApiKeyLocation::Query("apikey"));
97///
98/// // Dynamic verification
99/// let auth = ApiKeyAuth::with_verify(|key| {
100///     // Lookup in database, validate format, etc.
101///     key.len() == 32 && key.chars().all(|c| c.is_ascii_hexdigit())
102/// });
103/// ```
104/// Custom verification closure for [`ApiKeyAuth`].
105pub type ApiKeyVerifyFn = Arc<dyn Fn(&str) -> bool + Send + Sync + 'static>;
106
107pub struct ApiKeyAuth {
108  /// Static API keys (raw bytes, scanned in constant time).
109  keys: Option<Vec<Vec<u8>>>,
110  /// Custom verification function for dynamic key validation.
111  verify: Option<ApiKeyVerifyFn>,
112  /// Location to extract the API key from.
113  location: ApiKeyLocation,
114}
115
116impl ApiKeyAuth {
117  /// Creates authentication middleware with a single static API key.
118  ///
119  /// By default, the key is extracted from the `X-API-Key` header.
120  pub fn new(key: impl Into<String>) -> Self {
121    let key: String = key.into();
122    Self {
123      keys: Some(vec![key.into_bytes()]),
124      verify: None,
125      location: ApiKeyLocation::default(),
126    }
127  }
128
129  /// Creates authentication middleware with multiple static API keys.
130  pub fn from_keys<I>(keys: I) -> Self
131  where
132    I: IntoIterator,
133    I::Item: Into<String>,
134  {
135    Self {
136      keys: Some(
137        keys
138          .into_iter()
139          .map(|k| Into::<String>::into(k).into_bytes())
140          .collect(),
141      ),
142      verify: None,
143      location: ApiKeyLocation::default(),
144    }
145  }
146
147  /// Creates authentication middleware with a custom verification function.
148  pub fn with_verify<F>(f: F) -> Self
149  where
150    F: Fn(&str) -> bool + Send + Sync + 'static,
151  {
152    Self {
153      keys: None,
154      verify: Some(Arc::new(f)),
155      location: ApiKeyLocation::default(),
156    }
157  }
158
159  /// Creates authentication with both static keys and custom verification.
160  pub fn from_keys_with_verify<I, F>(keys: I, f: F) -> Self
161  where
162    I: IntoIterator,
163    I::Item: Into<String>,
164    F: Fn(&str) -> bool + Send + Sync + 'static,
165  {
166    Self {
167      keys: Some(
168        keys
169          .into_iter()
170          .map(|k| Into::<String>::into(k).into_bytes())
171          .collect(),
172      ),
173      verify: Some(Arc::new(f)),
174      location: ApiKeyLocation::default(),
175    }
176  }
177
178  /// Sets the location where the API key should be extracted from.
179  pub fn location(mut self, location: ApiKeyLocation) -> Self {
180    self.location = location;
181    self
182  }
183
184  /// Sets a custom header name for API key extraction.
185  ///
186  /// This is a convenience method equivalent to
187  /// `.location(ApiKeyLocation::Header(name))`.
188  pub fn header_name(mut self, name: &'static str) -> Self {
189    self.location = ApiKeyLocation::Header(name);
190    self
191  }
192
193  /// Sets a query parameter name for API key extraction.
194  ///
195  /// This is a convenience method equivalent to
196  /// `.location(ApiKeyLocation::Query(name))`.
197  pub fn query_param(mut self, name: &'static str) -> Self {
198    self.location = ApiKeyLocation::Query(name);
199    self
200  }
201}
202
203/// Extracts API key from request based on location configuration.
204fn extract_api_key<'a>(req: &'a Request, location: &ApiKeyLocation) -> Option<Cow<'a, str>> {
205  match location {
206    ApiKeyLocation::Header(name) => req
207      .headers()
208      .get(*name)
209      .and_then(|v| v.to_str().ok())
210      .map(|s| Cow::Borrowed(s.trim())),
211
212    ApiKeyLocation::Query(name) => req.uri().query().and_then(|q| {
213      url::form_urlencoded::parse(q.as_bytes())
214        .find(|(k, _)| k == *name)
215        .map(|(_, v)| v)
216    }),
217
218    ApiKeyLocation::HeaderOrQuery(header, query) => {
219      // Try header first
220      if let Some(key) = req
221        .headers()
222        .get(*header)
223        .and_then(|v| v.to_str().ok())
224        .map(|s| Cow::Borrowed(s.trim()))
225      {
226        return Some(key);
227      }
228      // Fall back to query parameter
229      req.uri().query().and_then(|q| {
230        url::form_urlencoded::parse(q.as_bytes())
231          .find(|(k, _)| k == *query)
232          .map(|(_, v)| v)
233      })
234    }
235  }
236}
237
238impl IntoMiddleware for ApiKeyAuth {
239  /// Converts the API key authentication configuration into middleware.
240  fn into_middleware(
241    self,
242  ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
243  + Clone
244  + Send
245  + Sync
246  + 'static {
247    let keys = self.keys.map(Arc::new);
248    let verify = self.verify;
249    let location = self.location;
250    let api_key_authenticate = HeaderValue::from_static("ApiKey");
251
252    move |req: Request, next: Next| {
253      let keys = keys.clone();
254      let verify = verify.clone();
255      let location = location.clone();
256      let api_key_authenticate = api_key_authenticate.clone();
257
258      Box::pin(async move {
259        // Extract API key from configured location
260        let Some(api_key) = extract_api_key(&req, &location) else {
261          return http::Response::builder()
262            .status(StatusCode::UNAUTHORIZED)
263            .header(header::WWW_AUTHENTICATE, api_key_authenticate.clone())
264            .body(TakoBody::from("API key is missing"))
265            .unwrap()
266            .into_response();
267        };
268
269        // Validate against static keys (constant-time scan)
270        if let Some(set) = &keys
271          && constant_time_contains(api_key.as_bytes(), set)
272        {
273          return next.run(req).await.into_response();
274        }
275
276        // Validate using custom verification function
277        if let Some(v) = verify.as_ref()
278          && v(api_key.as_ref())
279        {
280          return next.run(req).await.into_response();
281        }
282
283        // Return 401 Unauthorized for invalid keys
284        http::Response::builder()
285          .status(StatusCode::UNAUTHORIZED)
286          .header(header::WWW_AUTHENTICATE, api_key_authenticate)
287          .body(TakoBody::from("Invalid API key"))
288          .unwrap()
289          .into_response()
290      })
291    }
292  }
293}