tako_rs_plugins/middleware/basic_auth.rs
1//! Basic HTTP authentication middleware for securing web application endpoints.
2//!
3//! This module provides middleware for implementing RFC 7617 Basic HTTP Authentication.
//! It supports static user credentials, a custom verification callback, or both. Static
//! credentials are checked first, and the callback is consulted only when they do not match.
//! The middleware validates credentials from the Authorization header and forwards accepted
//! requests to the next handler. Everything else gets `401 Unauthorized` plus a
//! `WWW-Authenticate` challenge.
8//!
9//! # Examples
10//!
11//! ```rust,ignore
12//! use tako::middleware::basic_auth::BasicAuth;
13//! use tako::middleware::IntoMiddleware;
14//!
15//! // Single user.
16//! let auth = BasicAuth::single("admin", "password");
17//! let mw = auth.into_middleware();
18//!
19//! // Multiple users with custom realm.
20//! let multi = BasicAuth::multiple([
21//! ("alice", "secret1"),
22//! ("bob", "secret2"),
23//! ]).realm("Admin Area");
24//!
25//! // Dynamic verify callback — returns `bool`, not a user object.
26//! let dynamic = BasicAuth::with_verify(|username, password| {
27//! username == "user" && password == "pass"
28//! });
29//! ```
30
31use std::collections::HashMap;
32use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35
36use base64::Engine;
37use http::HeaderValue;
38use http::StatusCode;
39use http::header;
40use subtle::ConstantTimeEq;
41use tako_rs_core::body::TakoBody;
42use tako_rs_core::middleware::IntoMiddleware;
43use tako_rs_core::middleware::Next;
44use tako_rs_core::responder::Responder;
45use tako_rs_core::types::BuildHasher;
46use tako_rs_core::types::Request;
47use tako_rs_core::types::Response;
48
49/// Basic HTTP authentication middleware configuration.
50///
51/// `BasicAuth` provides flexible configuration for HTTP Basic authentication using either
52/// static user credentials, dynamic verification functions, or both. The middleware
53/// validates credentials from the Authorization header and can inject authenticated
54/// user objects into request extensions for downstream handlers.
55///
56/// # Type Parameters
57///
58/// * `U` - User object type returned by verification functions
59/// * `F` - Verification function type that takes username/password and returns `Option<U>`
60///
61/// # Examples
62///
63/// ```rust
64/// use tako::middleware::basic_auth::BasicAuth;
65/// use std::collections::HashMap;
66///
67/// // Simple static authentication
68/// let auth = BasicAuth::<(), _>::single("admin", "secret");
69///
70/// // Multiple static users
71/// let multi = BasicAuth::<(), _>::multiple([
72/// ("user1", "pass1"),
73/// ("user2", "pass2"),
74/// ]);
75///
76/// // Custom verification with user data
77/// #[derive(Clone)]
78/// struct UserInfo { id: u32, role: String }
79///
80/// let custom = BasicAuth::with_verify(|user, pass| {
81/// // Verify against database, LDAP, etc.
82/// if user == "admin" && pass == "secret" {
83/// Some(UserInfo { id: 1, role: "admin".to_string() })
84/// } else {
85/// None
86/// }
87/// });
88/// ```
89/// Custom verification closure for [`BasicAuth`].
90pub type BasicAuthVerifyFn = Arc<dyn Fn(&str, &str) -> bool + Send + Sync + 'static>;
91
92pub struct BasicAuth {
93 /// Static user credentials map (username -> password).
94 users: Option<Arc<HashMap<String, String, BuildHasher>>>,
95 /// Custom verification function for dynamic authentication.
96 verify: Option<BasicAuthVerifyFn>,
97 /// Authentication realm for WWW-Authenticate header.
98 realm: &'static str,
99}
100
101impl BasicAuth {
102 /// Creates authentication middleware with a single static user credential.
103 pub fn single(user: impl Into<String>, pass: impl Into<String>) -> Self {
104 Self::multiple(std::iter::once((user, pass)))
105 }
106
107 /// Creates authentication middleware with multiple static user credentials.
108 pub fn multiple<I, T, P>(pairs: I) -> Self
109 where
110 I: IntoIterator<Item = (T, P)>,
111 T: Into<String>,
112 P: Into<String>,
113 {
114 Self {
115 users: Some(Arc::new(
116 pairs
117 .into_iter()
118 .map(|(u, p)| (u.into(), p.into()))
119 .collect(),
120 )),
121 verify: None,
122 realm: "Restricted",
123 }
124 }
125
126 /// Creates authentication middleware with a custom verification function.
127 pub fn with_verify<F>(cb: F) -> Self
128 where
129 F: Fn(&str, &str) -> bool + Send + Sync + 'static,
130 {
131 Self {
132 users: None,
133 verify: Some(Arc::new(cb)),
134 realm: "Restricted",
135 }
136 }
137
138 /// Creates authentication middleware with both static credentials and custom verification.
139 pub fn users_with_verify<I, S, F>(pairs: I, cb: F) -> Self
140 where
141 I: IntoIterator<Item = (S, S)>,
142 S: Into<String>,
143 F: Fn(&str, &str) -> bool + Send + Sync + 'static,
144 {
145 Self {
146 users: Some(Arc::new(
147 pairs
148 .into_iter()
149 .map(|(u, p)| (u.into(), p.into()))
150 .collect(),
151 )),
152 verify: Some(Arc::new(cb)),
153 realm: "Restricted",
154 }
155 }
156
157 /// Sets the authentication realm for the WWW-Authenticate header.
158 pub fn realm(mut self, r: &'static str) -> Self {
159 self.realm = r;
160 self
161 }
162}
163
164impl IntoMiddleware for BasicAuth {
165 /// Converts the authentication configuration into middleware.
166 fn into_middleware(
167 self,
168 ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
169 + Clone
170 + Send
171 + Sync
172 + 'static {
173 let users = self.users;
174 let verify = self.verify;
175 let realm = self.realm;
176 // `HeaderValue::from_str` rejects non-visible-ASCII bytes and
177 // embedded `"` characters; a developer who hands us a realm with
178 // such bytes would otherwise panic the middleware setup (cold but
179 // user-controlled value). Strip the realm if it cannot be encoded
180 // and fall back to a bare `Basic` challenge — RFC 7617 §2.1 makes
181 // the realm parameter optional.
182 let www_authenticate = HeaderValue::from_str(&format!("Basic realm=\"{realm}\""))
183 .unwrap_or_else(|_| HeaderValue::from_static("Basic"));
184
185 move |req: Request, next: Next| {
186 let users = users.clone();
187 let verify = verify.clone();
188 let www_authenticate = www_authenticate.clone();
189
190 Box::pin(async move {
191 // Extract Basic credentials from Authorization header. RFC 7235
192 // §2.1 makes the auth-scheme token case-insensitive.
193 let creds = req
194 .headers()
195 .get(header::AUTHORIZATION)
196 .and_then(|h| h.to_str().ok())
197 .and_then(|h| {
198 let (scheme, rest) = h.trim_start().split_once(' ')?;
199 scheme.eq_ignore_ascii_case("Basic").then(|| rest.trim())
200 })
201 .and_then(|b64| base64::engine::general_purpose::STANDARD.decode(b64).ok());
202
203 match creds {
204 Some(raw) => {
205 let Some(decoded) = std::str::from_utf8(&raw).ok() else {
206 let mut res = Response::new(TakoBody::empty());
207 *res.status_mut() = StatusCode::UNAUTHORIZED;
208 res
209 .headers_mut()
210 .append(header::WWW_AUTHENTICATE, www_authenticate.clone());
211 return res;
212 };
213 let Some((u, p)) = decoded.split_once(':') else {
214 let mut res = Response::new(TakoBody::empty());
215 *res.status_mut() = StatusCode::UNAUTHORIZED;
216 res
217 .headers_mut()
218 .append(header::WWW_AUTHENTICATE, www_authenticate.clone());
219 return res;
220 };
221
222 // Check static user credentials first. Scan every entry and
223 // constant-time-compare both the username and the password so
224 // that neither (a) the time-to-401 leaks whether the username
225 // exists, nor (b) the password compare itself short-circuits on
226 // first-byte mismatch.
227 let mut authed = false;
228 if let Some(map) = users.as_ref() {
229 for (known_user, known_pw) in map.iter() {
230 let user_match = constant_time_eq(known_user.as_bytes(), u.as_bytes());
231 let pw_match = constant_time_eq(known_pw.as_bytes(), p.as_bytes());
232 authed |= user_match & pw_match;
233 }
234 }
235 if authed {
236 return next.run(req).await.into_response();
237 }
238
239 // Use custom verification function if available
240 if let Some(cb) = &verify
241 && cb(u, p)
242 {
243 return next.run(req).await.into_response();
244 }
245 }
246 None => {
247 return http::Response::builder()
248 .status(StatusCode::UNAUTHORIZED)
249 .header(header::WWW_AUTHENTICATE, www_authenticate.clone())
250 .body(TakoBody::from("Missing credentials"))
251 .unwrap()
252 .into_response();
253 }
254 }
255
256 // Return 401 Unauthorized for invalid credentials
257 let mut res = Response::new(TakoBody::empty());
258 *res.status_mut() = StatusCode::UNAUTHORIZED;
259 res
260 .headers_mut()
261 .append(header::WWW_AUTHENTICATE, www_authenticate);
262 res
263 })
264 }
265 }
266}
267
268fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
269 // Length mismatch must short-circuit because `ct_eq` requires equal-length
270 // slices. Leaking the length of credentials is mild — actual entropy comes
271 // from value, not byte-count.
272 if a.len() != b.len() {
273 return false;
274 }
275 a.ct_eq(b).into()
276}