1use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::Rng;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18
/// HMAC over SHA-256; the MAC behind [`hmac_sha256`], which this module uses
/// for the SCRAM `ClientKey`/`ServerKey` derivation and for both signatures.
type HmacSha256 = Hmac<Sha256>;

/// PBKDF2 iteration count intended as the default when deriving credentials.
/// Nothing here applies it implicitly: callers must pass it to
/// [`derive_credentials`] themselves.
pub const DEFAULT_ITERATIONS: u32 = 100_000;
23
/// Errors produced while parsing auth headers or running a SCRAM exchange.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
    /// The client proof did not match the stored credentials.
    #[error("invalid credentials")]
    InvalidCredentials,
    /// An auth header could not be parsed: unknown scheme, missing parameter,
    /// or a username that is not valid UTF-8.
    #[error("invalid auth header: {0}")]
    InvalidHeader(String),
    /// A SCRAM message was malformed or inconsistent with the handshake
    /// (bad attribute prefix, nonce mismatch, non-UTF-8 payload, ...).
    #[error("handshake failed: {0}")]
    HandshakeFailed(String),
    /// A base64-encoded field could not be decoded.
    #[error("base64 decode error: {0}")]
    Base64Error(String),
}
40
/// Server-side SCRAM-SHA-256 verifier for one user, as produced by
/// [`derive_credentials`].
///
/// Only derivatives of the salted password are kept, never the password
/// itself. `stored_key` and `server_key` are still secret:
/// - `stored_key` is what client proofs are checked against.
/// - `server_key` lets its holder compute valid server signatures.
///
/// The [`Debug`] impl therefore redacts both keys, so they cannot leak
/// through `{:?}` logging.
#[derive(Clone)]
pub struct ScramCredentials {
    /// Salt fed to PBKDF2 when deriving the salted password.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count.
    pub iterations: u32,
    /// `SHA-256(HMAC(SaltedPassword, "Client Key"))`.
    pub stored_key: Vec<u8>,
    /// `HMAC(SaltedPassword, "Server Key")`.
    pub server_key: Vec<u8>,
}

impl std::fmt::Debug for ScramCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Key material is deliberately not printed.
        f.debug_struct("ScramCredentials")
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("stored_key", &"<redacted>")
            .field("server_key", &"<redacted>")
            .finish()
    }
}
53
/// Server-side state carried from [`server_first_message`] to
/// [`server_verify_final`] for a single SCRAM exchange.
///
/// `stored_key` is the user's secret verifier. XOR-ing a captured client
/// proof with `HMAC(stored_key, auth_message)` yields the user's `ClientKey`.
/// The field is therefore private and redacted from the [`Debug`] output.
#[derive(Clone)]
pub struct ScramHandshake {
    /// Username the handshake was started for.
    pub username: String,
    /// Nonce supplied by the client.
    pub client_nonce: String,
    /// Nonce generated by the server and appended to `client_nonce`.
    pub server_nonce: String,
    /// Salt advertised to the client.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count advertised to the client.
    pub iterations: u32,
    /// The SCRAM `AuthMessage` both sides sign.
    pub auth_message: String,
    /// Expected `ServerSignature`, handed back once the client proof verifies.
    pub server_signature: Vec<u8>,
    stored_key: Vec<u8>,
}

impl std::fmt::Debug for ScramHandshake {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramHandshake")
            .field("username", &self.username)
            .field("client_nonce", &self.client_nonce)
            .field("server_nonce", &self.server_nonce)
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("auth_message", &self.auth_message)
            .field("server_signature", &self.server_signature)
            // Never print the verifier itself.
            .field("stored_key", &"<redacted>")
            .finish()
    }
}
68
/// A parsed auth header value, one variant per scheme accepted by
/// [`parse_auth_header`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthHeader {
    /// `HELLO username=<base64>`: the opening message carrying the username.
    Hello {
        /// Username after base64 decoding; guaranteed valid UTF-8.
        username: String,
    },
    /// `SCRAM handshakeToken=<token>, data=<data>`: one SCRAM step.
    Scram {
        /// Token identifying the handshake in progress (presumably the value
        /// the server issued via `format_www_authenticate`; confirm with callers).
        handshake_token: String,
        /// SCRAM message payload, left base64-encoded by the parser.
        data: String,
    },
    /// `BEARER authToken=<token>`: presents an auth token.
    Bearer {
        /// Token value with surrounding whitespace trimmed.
        auth_token: String,
    },
}
83
84fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
90 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
91 mac.update(msg);
92 mac.finalize().into_bytes().to_vec()
93}
94
95fn sha256(data: &[u8]) -> Vec<u8> {
97 let mut hasher = Sha256::new();
98 hasher.update(data);
99 hasher.finalize().to_vec()
100}
101
/// XORs two byte slices element-wise into a new vector.
///
/// # Panics
/// Panics if `a` and `b` differ in length.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    assert_eq!(a.len(), b.len(), "XOR operands must be the same length");
    let mut out = a.to_vec();
    for (dst, src) in out.iter_mut().zip(b) {
        *dst ^= *src;
    }
    out
}
107
108fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
110 let mut salted_password = vec![0u8; 32];
111 pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
112 salted_password
113}
114
115fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
117 let client_key = hmac_sha256(salted_password, b"Client Key");
118 let stored_key = sha256(&client_key);
119 let server_key = hmac_sha256(salted_password, b"Server Key");
120 (client_key, stored_key, server_key)
121}
122
123fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
125 let trimmed = segment.trim();
126 trimmed.strip_prefix(prefix).ok_or_else(|| {
127 AuthError::HandshakeFailed(format!(
128 "expected prefix '{}' but got '{}'",
129 prefix, trimmed
130 ))
131 })
132}
133
/// Builds the SCRAM `client-first-message-bare`: `n=<saslname>,r=<nonce>`.
///
/// RFC 5802 §5.1 requires each `=` in the username to be sent as `=3D` and
/// each `,` as `=2C`. Without that escaping, such a username would break the
/// comma-separated attribute syntax and produce an `AuthMessage` that
/// compliant peers would not reproduce. Both client and server build their
/// `AuthMessage` through this helper, so they stay in agreement.
///
/// NOTE(review): SASLprep normalisation of the username is not applied.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    let mut saslname = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => saslname.push_str("=3D"),
            ',' => saslname.push_str("=2C"),
            other => saslname.push(other),
        }
    }
    format!("n={},r={}", saslname, client_nonce)
}
138
139pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
147 let salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
148 let (_client_key, stored_key, server_key) = derive_keys(&salted_password);
149 ScramCredentials {
150 salt: salt.to_vec(),
151 iterations,
152 stored_key,
153 server_key,
154 }
155}
156
157pub fn generate_nonce() -> String {
159 let mut bytes = [0u8; 18];
160 rand::rng().fill(&mut bytes);
161 BASE64.encode(bytes)
162}
163
164pub fn client_first_message(username: &str) -> (String, String) {
171 let client_nonce = generate_nonce();
172 let bare = make_client_first_bare(username, &client_nonce);
173 let full = format!("n,,{}", bare);
174 let encoded = BASE64.encode(full.as_bytes());
175 (client_nonce, encoded)
176}
177
178pub fn server_first_message(
186 username: &str,
187 client_nonce_b64: &str,
188 credentials: &ScramCredentials,
189) -> (ScramHandshake, String) {
190 let server_nonce = generate_nonce();
191 let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
192 let salt_b64 = BASE64.encode(&credentials.salt);
193
194 let server_first_msg = format!(
196 "r={},s={},i={}",
197 combined_nonce, salt_b64, credentials.iterations
198 );
199
200 let cfmb = make_client_first_bare(username, client_nonce_b64);
202
203 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
205
206 let auth_message = format!(
208 "{},{},{}",
209 cfmb, server_first_msg, client_final_without_proof
210 );
211
212 let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
214
215 let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
216
217 let handshake = ScramHandshake {
218 username: username.to_string(),
219 client_nonce: client_nonce_b64.to_string(),
220 server_nonce,
221 salt: credentials.salt.clone(),
222 iterations: credentials.iterations,
223 auth_message,
224 server_signature,
225 stored_key: credentials.stored_key.clone(),
226 };
227
228 (handshake, server_first_b64)
229}
230
231pub fn client_final_message(
240 password: &str,
241 client_nonce: &str,
242 server_first_b64: &str,
243 username: &str,
244) -> Result<(String, Vec<u8>), AuthError> {
245 let server_first_bytes = BASE64
247 .decode(server_first_b64)
248 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
249 let server_first_msg = String::from_utf8(server_first_bytes)
250 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
251
252 let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
254 if parts.len() != 3 {
255 return Err(AuthError::HandshakeFailed(
256 "invalid server-first-message format".to_string(),
257 ));
258 }
259
260 let combined_nonce = parse_scram_param(parts[0], "r=")?;
261 let salt_b64 = parse_scram_param(parts[1], "s=")?;
262 let iterations_str = parse_scram_param(parts[2], "i=")?;
263
264 if !combined_nonce.starts_with(client_nonce) {
266 return Err(AuthError::HandshakeFailed(
267 "combined nonce does not start with client nonce".to_string(),
268 ));
269 }
270
271 let salt = BASE64
272 .decode(salt_b64)
273 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
274 let iterations: u32 = iterations_str
275 .parse()
276 .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
277
278 let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
280 let (client_key, stored_key, server_key) = derive_keys(&salted_password);
281
282 let cfmb = make_client_first_bare(username, client_nonce);
284 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
285 let auth_message = format!(
286 "{},{},{}",
287 cfmb, server_first_msg, client_final_without_proof
288 );
289
290 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
292 let client_proof = xor_bytes(&client_key, &client_signature);
294 let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
296
297 let proof_b64 = BASE64.encode(&client_proof);
299 let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
300 let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
301
302 Ok((client_final_b64, server_signature))
303}
304
305pub fn server_verify_final(
311 handshake: &ScramHandshake,
312 client_final_b64: &str,
313) -> Result<Vec<u8>, AuthError> {
314 let client_final_bytes = BASE64
316 .decode(client_final_b64)
317 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
318 let client_final_msg = String::from_utf8(client_final_bytes)
319 .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
320
321 let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
323 if parts.len() != 3 {
324 return Err(AuthError::HandshakeFailed(
325 "invalid client-final-message format".to_string(),
326 ));
327 }
328
329 let channel_binding = parse_scram_param(parts[0], "c=")?;
331 if channel_binding != "biws" {
332 return Err(AuthError::HandshakeFailed(
333 "unexpected channel binding".to_string(),
334 ));
335 }
336
337 let combined_nonce = parse_scram_param(parts[1], "r=")?;
339 let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
340 if combined_nonce != expected_combined {
341 return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
342 }
343
344 let proof_b64 = parse_scram_param(parts[2], "p=")?;
346 let client_proof = BASE64
347 .decode(proof_b64)
348 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
349
350 let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
355 let recovered_client_key = xor_bytes(&client_proof, &client_signature);
356 let recovered_stored_key = sha256(&recovered_client_key);
357
358 if recovered_stored_key
359 .ct_eq(&handshake.stored_key)
360 .unwrap_u8()
361 == 0
362 {
363 return Err(AuthError::InvalidCredentials);
364 }
365
366 Ok(handshake.server_signature.clone())
368}
369
370pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
377 let header = header.trim();
378
379 if let Some(rest) = header.strip_prefix("HELLO ") {
380 let username_b64 = rest
381 .trim()
382 .strip_prefix("username=")
383 .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
384 let username_bytes = BASE64
385 .decode(username_b64.trim())
386 .map_err(|e| AuthError::Base64Error(e.to_string()))?;
387 let username = String::from_utf8(username_bytes)
388 .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
389 Ok(AuthHeader::Hello { username })
390 } else if let Some(rest) = header.strip_prefix("SCRAM ") {
391 let mut handshake_token = None;
392 let mut data = None;
393 for part in rest.split(',') {
394 let part = part.trim();
395 if let Some(val) = part.strip_prefix("handshakeToken=") {
396 handshake_token = Some(val.trim().to_string());
397 } else if let Some(val) = part.strip_prefix("data=") {
398 data = Some(val.trim().to_string());
399 }
400 }
401 let handshake_token = handshake_token
402 .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
403 let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
404 Ok(AuthHeader::Scram {
405 handshake_token,
406 data,
407 })
408 } else if let Some(rest) = header.strip_prefix("BEARER ") {
409 let token = rest
410 .trim()
411 .strip_prefix("authToken=")
412 .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
413 Ok(AuthHeader::Bearer {
414 auth_token: token.trim().to_string(),
415 })
416 } else {
417 Err(AuthError::InvalidHeader(format!(
418 "unrecognized auth scheme: {}",
419 header
420 )))
421 }
422}
423
/// Formats the SCRAM challenge for a `WWW-Authenticate` header:
/// `SCRAM handshakeToken=<token>, hash=<hash>, data=<data_b64>`.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    let params = [
        ("handshakeToken", handshake_token),
        ("hash", hash),
        ("data", data_b64),
    ];
    let rendered: Vec<String> = params
        .iter()
        .map(|(name, value)| format!("{}={}", name, value))
        .collect();
    format!("SCRAM {}", rendered.join(", "))
}
433
/// Formats `authToken=<token>, data=<data_b64>`. In this module's tests it
/// carries the base64 server-final message after a successful exchange.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    let mut out = String::from("authToken=");
    out.push_str(auth_token);
    out.push_str(", data=");
    out.push_str(data_b64);
    out
}
440
// Unit tests for the SCRAM helpers and header parsing/formatting, plus two
// end-to-end client/server handshakes.
#[cfg(test)]
mod tests {
    use super::*;

    // Derivation yields 32-byte keys, is deterministic for identical inputs,
    // and changes when the password changes.
    #[test]
    fn test_derive_credentials() {
        let password = "pencil";
        let salt = b"random-salt-value";
        let iterations = 4096;

        let creds = derive_credentials(password, salt, iterations);

        assert_eq!(creds.salt, salt.to_vec());
        assert_eq!(creds.iterations, iterations);
        assert_eq!(creds.stored_key.len(), 32);
        assert_eq!(creds.server_key.len(), 32);

        // Same inputs -> same keys.
        let creds2 = derive_credentials(password, salt, iterations);
        assert_eq!(creds.stored_key, creds2.stored_key);
        assert_eq!(creds.server_key, creds2.server_key);

        // Different password -> different keys.
        let creds3 = derive_credentials("other", salt, iterations);
        assert_ne!(creds.stored_key, creds3.stored_key);
        assert_ne!(creds.server_key, creds3.server_key);
    }

    // Nonces differ between calls and decode back to 18 raw bytes.
    #[test]
    fn test_generate_nonce() {
        let n1 = generate_nonce();
        let n2 = generate_nonce();

        assert_ne!(n1, n2);

        let decoded1 = BASE64.decode(&n1).expect("nonce must be valid base64");
        assert_eq!(decoded1.len(), 18);

        let decoded2 = BASE64.decode(&n2).expect("nonce must be valid base64");
        assert_eq!(decoded2.len(), 18);
    }

    // HELLO: the username is base64-decoded.
    #[test]
    fn test_parse_auth_header_hello() {
        let username = "user";
        let username_b64 = BASE64.encode(username.as_bytes());
        let header = format!("HELLO username={}", username_b64);

        let parsed = parse_auth_header(&header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Hello {
                username: "user".to_string(),
            }
        );
    }

    // SCRAM: both parameters are extracted; `data` stays base64-encoded.
    #[test]
    fn test_parse_auth_header_scram() {
        let header = "SCRAM handshakeToken=abc123, data=c29tZWRhdGE=";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Scram {
                handshake_token: "abc123".to_string(),
                data: "c29tZWRhdGE=".to_string(),
            }
        );
    }

    // BEARER: the token value is extracted verbatim.
    #[test]
    fn test_parse_auth_header_bearer() {
        let header = "BEARER authToken=mytoken123";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Bearer {
                auth_token: "mytoken123".to_string(),
            }
        );
    }

    // Unknown scheme, missing required parameters, and empty input all fail.
    #[test]
    fn test_parse_auth_header_invalid() {
        assert!(parse_auth_header("UNKNOWN foo=bar").is_err());
        assert!(parse_auth_header("HELLO foo=bar").is_err());
        assert!(parse_auth_header("SCRAM handshakeToken=abc").is_err());
        assert!(parse_auth_header("BEARER token=abc").is_err());
        assert!(parse_auth_header("").is_err());
    }

    // End-to-end: HELLO -> server-first (challenge) -> client-final ->
    // server verification -> server-final, with the client checking the
    // server signature at the end.
    #[test]
    fn test_full_handshake() {
        let username = "testuser";
        let password = "s3cret";
        let salt = b"test-salt-12345";
        let iterations = 4096;

        // Server-side credentials provisioned ahead of time.
        let credentials = derive_credentials(password, salt, iterations);

        let username_b64 = BASE64.encode(username.as_bytes());
        let hello_header = format!("HELLO username={}", username_b64);
        let parsed = parse_auth_header(&hello_header).unwrap();
        match &parsed {
            AuthHeader::Hello { username: u } => assert_eq!(u, username),
            _ => panic!("expected Hello variant"),
        }

        let (client_nonce, _client_first_b64) = client_first_message(username);

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        let www_auth = format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
        assert!(www_auth.contains("handshake-token-xyz"));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();

        // Both sides must agree on the server signature.
        assert_eq!(server_sig, expected_server_sig);

        // Server-final message "v=<sig>" round-trips through base64.
        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        let server_final_decoded = BASE64.decode(&server_final_b64).unwrap();
        let server_final_str = String::from_utf8(server_final_decoded).unwrap();
        let sig_b64 = server_final_str.strip_prefix("v=").unwrap();
        let received_server_sig = BASE64.decode(sig_b64).unwrap();
        assert_eq!(received_server_sig, expected_server_sig);
    }

    // Checks the wire format of each message and that a wrong password is
    // rejected with InvalidCredentials.
    // NOTE(review): runs PBKDF2 with DEFAULT_ITERATIONS three times, which may
    // be slow in debug builds.
    #[test]
    fn test_client_server_roundtrip() {
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let salt = b"unique-salt-value";
        let iterations = DEFAULT_ITERATIONS;

        let credentials = derive_credentials(password, salt, iterations);

        let (client_nonce, client_first_b64) = client_first_message(username);

        // client-first: GS2 header "n,," then the bare message with our nonce.
        let client_first_decoded = BASE64.decode(&client_first_b64).unwrap();
        let client_first_str = String::from_utf8(client_first_decoded).unwrap();
        assert!(client_first_str.starts_with("n,,"));
        assert!(client_first_str.contains(&format!("r={}", client_nonce)));

        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        // server-first: r=, s=, i= attributes, nonce extending the client's.
        let server_first_decoded = BASE64.decode(&server_first_b64).unwrap();
        let server_first_str = String::from_utf8(server_first_decoded).unwrap();
        assert!(server_first_str.starts_with("r="));
        assert!(server_first_str.contains(",s="));
        assert!(server_first_str.contains(",i="));
        assert!(server_first_str.contains(&client_nonce));

        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        // client-final: channel binding "biws" first, proof attribute present.
        let client_final_decoded = BASE64.decode(&client_final_b64).unwrap();
        let client_final_str = String::from_utf8(client_final_decoded).unwrap();
        assert!(client_final_str.starts_with("c=biws,"));
        assert!(client_final_str.contains(",p="));

        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // A proof computed from the wrong password must be rejected.
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        let result = server_verify_final(&handshake, &wrong_final_b64);
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidCredentials) => {}
            other => panic!("expected InvalidCredentials, got {:?}", other),
        }
    }

    // Exact challenge-header format.
    #[test]
    fn test_format_www_authenticate() {
        let result = format_www_authenticate("tok123", "SHA-256", "c29tZQ==");
        assert_eq!(
            result,
            "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ=="
        );
    }

    // Exact auth-info format.
    #[test]
    fn test_format_auth_info() {
        let result = format_auth_info("auth-tok", "ZGF0YQ==");
        assert_eq!(result, "authToken=auth-tok, data=ZGF0YQ==");
    }
}