// Source file: tako_rs_plugins/middleware/basic_auth.rs

//! Basic HTTP authentication middleware for securing web application endpoints.
//!
//! This module provides middleware implementing RFC 7617 Basic HTTP Authentication.
//! It supports static user credentials, a dynamic verification callback, or both,
//! allowing flexible authentication strategies. The middleware validates credentials
//! from the `Authorization` header and forwards authenticated requests to the next
//! handler; all other requests receive `401 Unauthorized` with a `WWW-Authenticate`
//! challenge.
//!
//! # Examples
//!
//! ```rust,ignore
//! use tako::middleware::basic_auth::BasicAuth;
//! use tako::middleware::IntoMiddleware;
//!
//! // Single user.
//! let auth = BasicAuth::single("admin", "password");
//! let mw = auth.into_middleware();
//!
//! // Multiple users with custom realm.
//! let multi = BasicAuth::multiple([
//!     ("alice", "secret1"),
//!     ("bob", "secret2"),
//! ]).realm("Admin Area");
//!
//! // Dynamic verify callback — returns `bool`, not a user object.
//! let dynamic = BasicAuth::with_verify(|username, password| {
//!     username == "user" && password == "pass"
//! });
//! ```

31use std::collections::HashMap;
32use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35
36use base64::Engine;
37use http::HeaderValue;
38use http::StatusCode;
39use http::header;
40use subtle::ConstantTimeEq;
41use tako_rs_core::body::TakoBody;
42use tako_rs_core::middleware::IntoMiddleware;
43use tako_rs_core::middleware::Next;
44use tako_rs_core::responder::Responder;
45use tako_rs_core::types::BuildHasher;
46use tako_rs_core::types::Request;
47use tako_rs_core::types::Response;
48
49/// Basic HTTP authentication middleware configuration.
50///
51/// `BasicAuth` provides flexible configuration for HTTP Basic authentication using either
52/// static user credentials, dynamic verification functions, or both. The middleware
53/// validates credentials from the Authorization header and can inject authenticated
54/// user objects into request extensions for downstream handlers.
55///
56/// # Type Parameters
57///
58/// * `U` - User object type returned by verification functions
59/// * `F` - Verification function type that takes username/password and returns `Option<U>`
60///
61/// # Examples
62///
63/// ```rust
64/// use tako::middleware::basic_auth::BasicAuth;
65/// use std::collections::HashMap;
66///
67/// // Simple static authentication
68/// let auth = BasicAuth::<(), _>::single("admin", "secret");
69///
70/// // Multiple static users
71/// let multi = BasicAuth::<(), _>::multiple([
72///     ("user1", "pass1"),
73///     ("user2", "pass2"),
74/// ]);
75///
76/// // Custom verification with user data
77/// #[derive(Clone)]
78/// struct UserInfo { id: u32, role: String }
79///
80/// let custom = BasicAuth::with_verify(|user, pass| {
81///     // Verify against database, LDAP, etc.
82///     if user == "admin" && pass == "secret" {
83///         Some(UserInfo { id: 1, role: "admin".to_string() })
84///     } else {
85///         None
86///     }
87/// });
88/// ```
89/// Custom verification closure for [`BasicAuth`].
90pub type BasicAuthVerifyFn = Arc<dyn Fn(&str, &str) -> bool + Send + Sync + 'static>;
91
92pub struct BasicAuth {
93  /// Static user credentials map (username -> password).
94  users: Option<Arc<HashMap<String, String, BuildHasher>>>,
95  /// Custom verification function for dynamic authentication.
96  verify: Option<BasicAuthVerifyFn>,
97  /// Authentication realm for WWW-Authenticate header.
98  realm: &'static str,
99}
100
101impl BasicAuth {
102  /// Creates authentication middleware with a single static user credential.
103  pub fn single(user: impl Into<String>, pass: impl Into<String>) -> Self {
104    Self::multiple(std::iter::once((user, pass)))
105  }
106
107  /// Creates authentication middleware with multiple static user credentials.
108  pub fn multiple<I, T, P>(pairs: I) -> Self
109  where
110    I: IntoIterator<Item = (T, P)>,
111    T: Into<String>,
112    P: Into<String>,
113  {
114    Self {
115      users: Some(Arc::new(
116        pairs
117          .into_iter()
118          .map(|(u, p)| (u.into(), p.into()))
119          .collect(),
120      )),
121      verify: None,
122      realm: "Restricted",
123    }
124  }
125
126  /// Creates authentication middleware with a custom verification function.
127  pub fn with_verify<F>(cb: F) -> Self
128  where
129    F: Fn(&str, &str) -> bool + Send + Sync + 'static,
130  {
131    Self {
132      users: None,
133      verify: Some(Arc::new(cb)),
134      realm: "Restricted",
135    }
136  }
137
138  /// Creates authentication middleware with both static credentials and custom verification.
139  pub fn users_with_verify<I, S, F>(pairs: I, cb: F) -> Self
140  where
141    I: IntoIterator<Item = (S, S)>,
142    S: Into<String>,
143    F: Fn(&str, &str) -> bool + Send + Sync + 'static,
144  {
145    Self {
146      users: Some(Arc::new(
147        pairs
148          .into_iter()
149          .map(|(u, p)| (u.into(), p.into()))
150          .collect(),
151      )),
152      verify: Some(Arc::new(cb)),
153      realm: "Restricted",
154    }
155  }
156
157  /// Sets the authentication realm for the WWW-Authenticate header.
158  pub fn realm(mut self, r: &'static str) -> Self {
159    self.realm = r;
160    self
161  }
162}
163
164impl IntoMiddleware for BasicAuth {
165  /// Converts the authentication configuration into middleware.
166  fn into_middleware(
167    self,
168  ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
169  + Clone
170  + Send
171  + Sync
172  + 'static {
173    let users = self.users;
174    let verify = self.verify;
175    let realm = self.realm;
176    // `HeaderValue::from_str` rejects non-visible-ASCII bytes and
177    // embedded `"` characters; a developer who hands us a realm with
178    // such bytes would otherwise panic the middleware setup (cold but
179    // user-controlled value). Strip the realm if it cannot be encoded
180    // and fall back to a bare `Basic` challenge — RFC 7617 §2.1 makes
181    // the realm parameter optional.
182    let www_authenticate = HeaderValue::from_str(&format!("Basic realm=\"{realm}\""))
183      .unwrap_or_else(|_| HeaderValue::from_static("Basic"));
184
185    move |req: Request, next: Next| {
186      let users = users.clone();
187      let verify = verify.clone();
188      let www_authenticate = www_authenticate.clone();
189
190      Box::pin(async move {
191        // Extract Basic credentials from Authorization header. RFC 7235
192        // §2.1 makes the auth-scheme token case-insensitive.
193        let creds = req
194          .headers()
195          .get(header::AUTHORIZATION)
196          .and_then(|h| h.to_str().ok())
197          .and_then(|h| {
198            let (scheme, rest) = h.trim_start().split_once(' ')?;
199            scheme.eq_ignore_ascii_case("Basic").then(|| rest.trim())
200          })
201          .and_then(|b64| base64::engine::general_purpose::STANDARD.decode(b64).ok());
202
203        match creds {
204          Some(raw) => {
205            let Some(decoded) = std::str::from_utf8(&raw).ok() else {
206              let mut res = Response::new(TakoBody::empty());
207              *res.status_mut() = StatusCode::UNAUTHORIZED;
208              res
209                .headers_mut()
210                .append(header::WWW_AUTHENTICATE, www_authenticate.clone());
211              return res;
212            };
213            let Some((u, p)) = decoded.split_once(':') else {
214              let mut res = Response::new(TakoBody::empty());
215              *res.status_mut() = StatusCode::UNAUTHORIZED;
216              res
217                .headers_mut()
218                .append(header::WWW_AUTHENTICATE, www_authenticate.clone());
219              return res;
220            };
221
222            // Check static user credentials first. Scan every entry and
223            // constant-time-compare both the username and the password so
224            // that neither (a) the time-to-401 leaks whether the username
225            // exists, nor (b) the password compare itself short-circuits on
226            // first-byte mismatch.
227            let mut authed = false;
228            if let Some(map) = users.as_ref() {
229              for (known_user, known_pw) in map.iter() {
230                let user_match = constant_time_eq(known_user.as_bytes(), u.as_bytes());
231                let pw_match = constant_time_eq(known_pw.as_bytes(), p.as_bytes());
232                authed |= user_match & pw_match;
233              }
234            }
235            if authed {
236              return next.run(req).await.into_response();
237            }
238
239            // Use custom verification function if available
240            if let Some(cb) = &verify
241              && cb(u, p)
242            {
243              return next.run(req).await.into_response();
244            }
245          }
246          None => {
247            return http::Response::builder()
248              .status(StatusCode::UNAUTHORIZED)
249              .header(header::WWW_AUTHENTICATE, www_authenticate.clone())
250              .body(TakoBody::from("Missing credentials"))
251              .unwrap()
252              .into_response();
253          }
254        }
255
256        // Return 401 Unauthorized for invalid credentials
257        let mut res = Response::new(TakoBody::empty());
258        *res.status_mut() = StatusCode::UNAUTHORIZED;
259        res
260          .headers_mut()
261          .append(header::WWW_AUTHENTICATE, www_authenticate);
262        res
263      })
264    }
265  }
266}
267
268fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
269  // Length mismatch must short-circuit because `ct_eq` requires equal-length
270  // slices. Leaking the length of credentials is mild — actual entropy comes
271  // from value, not byte-count.
272  if a.len() != b.len() {
273    return false;
274  }
275  a.ct_eq(b).into()
276}