Skip to main content

haystack_core/
auth.rs

1//! SCRAM SHA-256 authentication primitives for the Haystack auth protocol.
2//!
3//! This module implements the cryptographic operations needed for SCRAM
4//! (Salted Challenge Response Authentication Mechanism) with SHA-256 as
5//! specified by the [Project Haystack auth spec](https://project-haystack.org/doc/docHaystack/Auth).
6//!
7//! It provides functions shared by both server and client implementations
8//! for the three-phase handshake: HELLO, SCRAM challenge/response, and
9//! BEARER token issuance.
10
11use base64::Engine;
12use base64::engine::general_purpose::STANDARD as BASE64;
13use hmac::{Hmac, Mac};
14use pbkdf2::pbkdf2_hmac;
15use rand::RngExt;
16use sha2::{Digest, Sha256};
17use subtle::ConstantTimeEq;
18
19type HmacSha256 = Hmac<Sha256>;
20
21/// Default PBKDF2 iteration count for SCRAM SHA-256.
22pub const DEFAULT_ITERATIONS: u32 = 100_000;
23
24// ---------------------------------------------------------------------------
25// Error type
26// ---------------------------------------------------------------------------
27
28/// Errors that can occur during SCRAM authentication.
29#[derive(Debug, thiserror::Error)]
30pub enum AuthError {
31    #[error("invalid credentials")]
32    InvalidCredentials,
33    #[error("invalid auth header: {0}")]
34    InvalidHeader(String),
35    #[error("handshake failed: {0}")]
36    HandshakeFailed(String),
37    #[error("base64 decode error: {0}")]
38    Base64Error(String),
39}
40
41// ---------------------------------------------------------------------------
42// Types
43// ---------------------------------------------------------------------------
44
/// Pre-computed SCRAM credentials for a user (stored server-side).
///
/// Built by [`derive_credentials`]; the plaintext password is not retained.
///
/// `stored_key` and `server_key` are secrets. `server_key` produces the
/// server signature. `stored_key`, combined with an observed client proof,
/// recovers the client key. `Debug` is therefore implemented by hand and
/// prints both fields as `<redacted>`, so credentials can be logged safely.
#[derive(Clone)]
pub struct ScramCredentials {
    /// Salt fed to PBKDF2 (sent to the client in server-first-message).
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count (sent to the client in server-first-message).
    pub iterations: u32,
    /// `SHA-256(ClientKey)`, used to verify client proofs.
    pub stored_key: Vec<u8>,
    /// `HMAC(SaltedPassword, "Server Key")`, used to compute the server signature.
    pub server_key: Vec<u8>,
}

impl std::fmt::Debug for ScramCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramCredentials")
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("stored_key", &format_args!("<redacted>"))
            .field("server_key", &format_args!("<redacted>"))
            .finish()
    }
}
53
/// In-flight SCRAM handshake state held by the server between the
/// server-first-message and client-final-message exchanges.
///
/// `Debug` is implemented by hand. `stored_key` (which verifies client
/// proofs) and the pre-computed `server_signature` (which the client uses to
/// authenticate the server) are printed as `<redacted>`, so handshake state
/// can be logged without exposing them.
#[derive(Clone)]
pub struct ScramHandshake {
    /// Username taken from the HELLO phase.
    pub username: String,
    /// Nonce chosen by the client.
    pub client_nonce: String,
    /// Nonce chosen by the server; the combined nonce is `client_nonce + server_nonce`.
    pub server_nonce: String,
    /// Salt sent to the client in server-first-message.
    pub salt: Vec<u8>,
    /// PBKDF2 iteration count sent to the client in server-first-message.
    pub iterations: u32,
    /// SCRAM AuthMessage that both client proof and server signature cover.
    pub auth_message: String,
    /// `HMAC(ServerKey, AuthMessage)`, returned once the client proof verifies.
    pub server_signature: Vec<u8>,
    /// Stored key from credentials, needed to verify the client proof.
    stored_key: Vec<u8>,
}

impl std::fmt::Debug for ScramHandshake {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ScramHandshake")
            .field("username", &self.username)
            .field("client_nonce", &self.client_nonce)
            .field("server_nonce", &self.server_nonce)
            .field("salt", &self.salt)
            .field("iterations", &self.iterations)
            .field("auth_message", &self.auth_message)
            .field("server_signature", &format_args!("<redacted>"))
            .field("stored_key", &format_args!("<redacted>"))
            .finish()
    }
}
68
/// A parsed Haystack `Authorization` header, one variant per auth scheme.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthHeader {
    /// `HELLO username=...` with an optional `data=` parameter.
    Hello {
        /// The username, already base64-decoded into UTF-8 text.
        username: String,
        /// Base64-encoded client-first-message (contains the client nonce).
        data: Option<String>,
    },
    /// `SCRAM handshakeToken=..., data=...`.
    Scram {
        /// Value of the `handshakeToken=` parameter.
        handshake_token: String,
        /// Value of the `data=` parameter, left base64-encoded.
        data: String,
    },
    /// `BEARER authToken=...`.
    Bearer {
        /// Value of the `authToken=` parameter.
        auth_token: String,
    },
}
85
86// ---------------------------------------------------------------------------
87// Internal helpers
88// ---------------------------------------------------------------------------
89
90/// Compute HMAC-SHA-256(key, msg).
91fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
92    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts keys of any size");
93    mac.update(msg);
94    mac.finalize().into_bytes().to_vec()
95}
96
97/// Compute SHA-256(data).
98fn sha256(data: &[u8]) -> Vec<u8> {
99    let mut hasher = Sha256::new();
100    hasher.update(data);
101    hasher.finalize().to_vec()
102}
103
/// Byte-wise XOR of two slices of the same length.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length. Callers must check lengths first
/// when either operand comes from untrusted input.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    assert_eq!(a.len(), b.len(), "XOR operands must be the same length");
    let mut out = a.to_vec();
    for (dst, src) in out.iter_mut().zip(b) {
        *dst ^= src;
    }
    out
}
109
110/// PBKDF2-HMAC-SHA-256 key derivation, producing a 32-byte salted password.
111fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
112    let mut salted_password = vec![0u8; 32];
113    pbkdf2_hmac::<Sha256>(password, salt, iterations, &mut salted_password);
114    salted_password
115}
116
117/// Derive (ClientKey, StoredKey, ServerKey) from a salted password.
118fn derive_keys(salted_password: &[u8]) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
119    let client_key = hmac_sha256(salted_password, b"Client Key");
120    let stored_key = sha256(&client_key);
121    let server_key = hmac_sha256(salted_password, b"Server Key");
122    (client_key, stored_key, server_key)
123}
124
125/// Parse a `key=value` parameter from a SCRAM message segment.
126fn parse_scram_param<'a>(segment: &'a str, prefix: &str) -> Result<&'a str, AuthError> {
127    let trimmed = segment.trim();
128    trimmed.strip_prefix(prefix).ok_or_else(|| {
129        AuthError::HandshakeFailed(format!(
130            "expected prefix '{}' but got '{}'",
131            prefix, trimmed
132        ))
133    })
134}
135
/// Build the client-first-message-bare: `n=<saslname>,r=<client_nonce>`.
///
/// Per RFC 5802 §5.1, the characters `,` and `=` in the username are sent as
/// `=2C` and `=3D`. Without this escaping, a username containing `,r=` would
/// be ambiguous to anything that splits the message on `,` (such as
/// [`extract_client_nonce`]). Usernames without those characters are emitted
/// unchanged.
fn make_client_first_bare(username: &str, client_nonce: &str) -> String {
    let mut bare = String::with_capacity(username.len() + client_nonce.len() + 5);
    bare.push_str("n=");
    for ch in username.chars() {
        match ch {
            ',' => bare.push_str("=2C"),
            '=' => bare.push_str("=3D"),
            other => bare.push(other),
        }
    }
    bare.push_str(",r=");
    bare.push_str(client_nonce);
    bare
}
140
141// ---------------------------------------------------------------------------
142// Public API
143// ---------------------------------------------------------------------------
144
145/// Derive SCRAM credentials from a password (for user creation/storage).
146///
147/// Uses PBKDF2-HMAC-SHA-256 with the given salt and iteration count.
148pub fn derive_credentials(password: &str, salt: &[u8], iterations: u32) -> ScramCredentials {
149    let salted_password = pbkdf2_sha256(password.as_bytes(), salt, iterations);
150    let (_client_key, stored_key, server_key) = derive_keys(&salted_password);
151    ScramCredentials {
152        salt: salt.to_vec(),
153        iterations,
154        stored_key,
155        server_key,
156    }
157}
158
159/// Generate a random nonce string (base64-encoded 18 random bytes).
160pub fn generate_nonce() -> String {
161    let mut bytes = [0u8; 18];
162    rand::rng().fill(&mut bytes);
163    BASE64.encode(bytes)
164}
165
166/// Client-side: Create the client-first-message data (base64-encoded).
167///
168/// Returns `(client_nonce, client_first_data_base64)`.
169///
170/// The client-first-message-bare is `n=<username>,r=<client_nonce>`.
171/// The full message prepends the GS2 header `n,,` (no channel binding).
172pub fn client_first_message(username: &str) -> (String, String) {
173    let client_nonce = generate_nonce();
174    let bare = make_client_first_bare(username, &client_nonce);
175    let full = format!("n,,{}", bare);
176    let encoded = BASE64.encode(full.as_bytes());
177    (client_nonce, encoded)
178}
179
180/// Server-side: Create the server-first-message data and handshake state.
181///
182/// `username` is taken from the HELLO phase. `client_nonce_b64` is the raw
183/// client nonce (as returned by [`client_first_message`]). `credentials` are
184/// the pre-computed SCRAM credentials for this user.
185///
186/// Returns `(handshake_state, server_first_data_base64)`.
187pub fn server_first_message(
188    username: &str,
189    client_nonce_b64: &str,
190    credentials: &ScramCredentials,
191) -> (ScramHandshake, String) {
192    let server_nonce = generate_nonce();
193    let combined_nonce = format!("{}{}", client_nonce_b64, server_nonce);
194    let salt_b64 = BASE64.encode(&credentials.salt);
195
196    // server-first-message: r=<combined>,s=<salt_b64>,i=<iterations>
197    let server_first_msg = format!(
198        "r={},s={},i={}",
199        combined_nonce, salt_b64, credentials.iterations
200    );
201
202    // client-first-message-bare (includes username per SCRAM spec)
203    let cfmb = make_client_first_bare(username, client_nonce_b64);
204
205    // client-final-message-without-proof (anticipated)
206    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
207
208    // AuthMessage = client-first-bare "," server-first-msg "," client-final-without-proof
209    let auth_message = format!(
210        "{},{},{}",
211        cfmb, server_first_msg, client_final_without_proof
212    );
213
214    // Pre-compute server signature
215    let server_signature = hmac_sha256(&credentials.server_key, auth_message.as_bytes());
216
217    let server_first_b64 = BASE64.encode(server_first_msg.as_bytes());
218
219    let handshake = ScramHandshake {
220        username: username.to_string(),
221        client_nonce: client_nonce_b64.to_string(),
222        server_nonce,
223        salt: credentials.salt.clone(),
224        iterations: credentials.iterations,
225        auth_message,
226        server_signature,
227        stored_key: credentials.stored_key.clone(),
228    };
229
230    (handshake, server_first_b64)
231}
232
233/// Client-side: Process server-first-message, produce client-final-message.
234///
235/// `username` is the same value originally passed to [`client_first_message`].
236/// `password` is the user's plaintext password. `client_nonce` is the nonce
237/// returned by [`client_first_message`]. `server_first_b64` is the base64
238/// server-first-message data received from the server.
239///
240/// Returns `(client_final_data_base64, expected_server_signature)`.
241pub fn client_final_message(
242    password: &str,
243    client_nonce: &str,
244    server_first_b64: &str,
245    username: &str,
246) -> Result<(String, Vec<u8>), AuthError> {
247    // Decode and parse server-first-message
248    let server_first_bytes = BASE64
249        .decode(server_first_b64)
250        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
251    let server_first_msg = String::from_utf8(server_first_bytes)
252        .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
253
254    // Expected format: r=<combined_nonce>,s=<salt_b64>,i=<iterations>
255    let parts: Vec<&str> = server_first_msg.splitn(3, ',').collect();
256    if parts.len() != 3 {
257        return Err(AuthError::HandshakeFailed(
258            "invalid server-first-message format".to_string(),
259        ));
260    }
261
262    let combined_nonce = parse_scram_param(parts[0], "r=")?;
263    let salt_b64 = parse_scram_param(parts[1], "s=")?;
264    let iterations_str = parse_scram_param(parts[2], "i=")?;
265
266    // The combined nonce must start with our client nonce
267    if !combined_nonce.starts_with(client_nonce) {
268        return Err(AuthError::HandshakeFailed(
269            "combined nonce does not start with client nonce".to_string(),
270        ));
271    }
272
273    let salt = BASE64
274        .decode(salt_b64)
275        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
276    let iterations: u32 = iterations_str
277        .parse()
278        .map_err(|e: std::num::ParseIntError| AuthError::HandshakeFailed(e.to_string()))?;
279
280    // Key derivation
281    let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
282    let (client_key, stored_key, server_key) = derive_keys(&salted_password);
283
284    // Build AuthMessage
285    let cfmb = make_client_first_bare(username, client_nonce);
286    let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
287    let auth_message = format!(
288        "{},{},{}",
289        cfmb, server_first_msg, client_final_without_proof
290    );
291
292    // ClientSignature = HMAC(StoredKey, AuthMessage)
293    let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
294    // ClientProof = ClientKey XOR ClientSignature
295    let client_proof = xor_bytes(&client_key, &client_signature);
296    // ServerSignature = HMAC(ServerKey, AuthMessage)
297    let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
298
299    // client-final-message: c=biws,r=<combined>,p=<proof_b64>
300    let proof_b64 = BASE64.encode(&client_proof);
301    let client_final_msg = format!("{},p={}", client_final_without_proof, proof_b64);
302    let client_final_b64 = BASE64.encode(client_final_msg.as_bytes());
303
304    Ok((client_final_b64, server_signature))
305}
306
307/// Server-side: Verify client-final-message and produce server signature.
308///
309/// Decodes the client-final-message, verifies the client proof against the
310/// stored key in the handshake state, and returns the server signature for
311/// the client to verify (sent as the `v=` field in server-final-message).
312pub fn server_verify_final(
313    handshake: &ScramHandshake,
314    client_final_b64: &str,
315) -> Result<Vec<u8>, AuthError> {
316    // Decode client-final-message
317    let client_final_bytes = BASE64
318        .decode(client_final_b64)
319        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
320    let client_final_msg = String::from_utf8(client_final_bytes)
321        .map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
322
323    // Expected format: c=biws,r=<combined_nonce>,p=<proof_b64>
324    let parts: Vec<&str> = client_final_msg.splitn(3, ',').collect();
325    if parts.len() != 3 {
326        return Err(AuthError::HandshakeFailed(
327            "invalid client-final-message format".to_string(),
328        ));
329    }
330
331    // Validate channel binding
332    let channel_binding = parse_scram_param(parts[0], "c=")?;
333    if channel_binding != "biws" {
334        return Err(AuthError::HandshakeFailed(
335            "unexpected channel binding".to_string(),
336        ));
337    }
338
339    // Validate combined nonce
340    let combined_nonce = parse_scram_param(parts[1], "r=")?;
341    let expected_combined = format!("{}{}", handshake.client_nonce, handshake.server_nonce);
342    if !bool::from(
343        combined_nonce
344            .as_bytes()
345            .ct_eq(expected_combined.as_bytes()),
346    ) {
347        return Err(AuthError::HandshakeFailed("nonce mismatch".to_string()));
348    }
349
350    // Extract and decode client proof
351    let proof_b64 = parse_scram_param(parts[2], "p=")?;
352    let client_proof = BASE64
353        .decode(proof_b64)
354        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
355
356    // Verify the proof per RFC 5802:
357    //   ClientSignature = HMAC(StoredKey, AuthMessage)
358    //   RecoveredClientKey = ClientProof XOR ClientSignature
359    //   Check: SHA-256(RecoveredClientKey) == StoredKey
360    let client_signature = hmac_sha256(&handshake.stored_key, handshake.auth_message.as_bytes());
361    let recovered_client_key = xor_bytes(&client_proof, &client_signature);
362    let recovered_stored_key = sha256(&recovered_client_key);
363
364    if recovered_stored_key
365        .ct_eq(&handshake.stored_key)
366        .unwrap_u8()
367        == 0
368    {
369        return Err(AuthError::InvalidCredentials);
370    }
371
372    // Proof verified -- return server signature for the client to verify
373    Ok(handshake.server_signature.clone())
374}
375
376/// Extract the client nonce from a base64-encoded client-first-message.
377///
378/// The client-first-message format is `n,,n=<username>,r=<client_nonce>`.
379/// Returns the raw nonce string.
380pub fn extract_client_nonce(client_first_b64: &str) -> Result<String, AuthError> {
381    let bytes = BASE64
382        .decode(client_first_b64)
383        .map_err(|e| AuthError::Base64Error(e.to_string()))?;
384    let msg = String::from_utf8(bytes).map_err(|e| AuthError::HandshakeFailed(e.to_string()))?;
385    // Strip GS2 header "n,," prefix
386    let bare = msg
387        .strip_prefix("n,,")
388        .ok_or_else(|| AuthError::HandshakeFailed("missing GS2 header in client-first".into()))?;
389    // Parse n=<user>,r=<nonce>
390    for part in bare.split(',') {
391        if let Some(nonce) = part.strip_prefix("r=") {
392            return Ok(nonce.to_string());
393        }
394    }
395    Err(AuthError::HandshakeFailed(
396        "missing r= nonce in client-first-message".into(),
397    ))
398}
399
400/// Parse a Haystack `Authorization` header value.
401///
402/// Supported formats:
403/// - `HELLO username=<base64(username)>`
404/// - `SCRAM handshakeToken=<token>, data=<data>`
405/// - `BEARER authToken=<token>`
406pub fn parse_auth_header(header: &str) -> Result<AuthHeader, AuthError> {
407    let header = header.trim();
408
409    if let Some(rest) = header.strip_prefix("HELLO ") {
410        let mut username_b64_val = None;
411        let mut data_val = None;
412        for part in rest.split(',') {
413            let part = part.trim();
414            if let Some(val) = part.strip_prefix("username=") {
415                username_b64_val = Some(val.trim().to_string());
416            } else if let Some(val) = part.strip_prefix("data=") {
417                data_val = Some(val.trim().to_string());
418            }
419        }
420        let username_b64 = username_b64_val
421            .ok_or_else(|| AuthError::InvalidHeader("missing username= in HELLO".into()))?;
422        let username_bytes = BASE64
423            .decode(&username_b64)
424            .map_err(|e| AuthError::Base64Error(e.to_string()))?;
425        let username = String::from_utf8(username_bytes)
426            .map_err(|e| AuthError::InvalidHeader(e.to_string()))?;
427        Ok(AuthHeader::Hello {
428            username,
429            data: data_val,
430        })
431    } else if let Some(rest) = header.strip_prefix("SCRAM ") {
432        let mut handshake_token = None;
433        let mut data = None;
434        for part in rest.split(',') {
435            let part = part.trim();
436            if let Some(val) = part.strip_prefix("handshakeToken=") {
437                handshake_token = Some(val.trim().to_string());
438            } else if let Some(val) = part.strip_prefix("data=") {
439                data = Some(val.trim().to_string());
440            }
441        }
442        let handshake_token = handshake_token
443            .ok_or_else(|| AuthError::InvalidHeader("missing handshakeToken= in SCRAM".into()))?;
444        let data = data.ok_or_else(|| AuthError::InvalidHeader("missing data= in SCRAM".into()))?;
445        Ok(AuthHeader::Scram {
446            handshake_token,
447            data,
448        })
449    } else if let Some(rest) = header.strip_prefix("BEARER ") {
450        let token = rest
451            .trim()
452            .strip_prefix("authToken=")
453            .ok_or_else(|| AuthError::InvalidHeader("missing authToken= in BEARER".into()))?;
454        Ok(AuthHeader::Bearer {
455            auth_token: token.trim().to_string(),
456        })
457    } else {
458        Err(AuthError::InvalidHeader(format!(
459            "unrecognized auth scheme: {}",
460            header
461        )))
462    }
463}
464
/// Build the `WWW-Authenticate` header value for a SCRAM challenge.
///
/// Shape: `SCRAM handshakeToken=<token>, hash=<hash>, data=<data_b64>`.
/// Arguments are inserted verbatim, with no escaping or encoding.
pub fn format_www_authenticate(handshake_token: &str, hash: &str, data_b64: &str) -> String {
    format!("SCRAM handshakeToken={handshake_token}, hash={hash}, data={data_b64}")
}
474
/// Build the `Authentication-Info` header value that hands out the auth token.
///
/// Shape: `authToken=<token>, data=<data_b64>`. Arguments are inserted verbatim.
pub fn format_auth_info(auth_token: &str, data_b64: &str) -> String {
    format!("authToken={auth_token}, data={data_b64}")
}
481
482// ---------------------------------------------------------------------------
483// Tests
484// ---------------------------------------------------------------------------
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Credential derivation is deterministic and password-sensitive.
    #[test]
    fn test_derive_credentials() {
        let password = "pencil";
        let salt = b"random-salt-value";
        let iterations = 4096;

        let creds = derive_credentials(password, salt, iterations);

        // Fields are populated correctly
        assert_eq!(creds.salt, salt.to_vec());
        assert_eq!(creds.iterations, iterations);
        assert_eq!(creds.stored_key.len(), 32); // SHA-256 output length
        assert_eq!(creds.server_key.len(), 32);

        // Deterministic: same inputs produce same outputs
        let creds2 = derive_credentials(password, salt, iterations);
        assert_eq!(creds.stored_key, creds2.stored_key);
        assert_eq!(creds.server_key, creds2.server_key);

        // Different password yields different credentials
        let creds3 = derive_credentials("other", salt, iterations);
        assert_ne!(creds.stored_key, creds3.stored_key);
        assert_ne!(creds.server_key, creds3.server_key);
    }

    /// Nonces are unique base64 encodings of 18 bytes.
    #[test]
    fn test_generate_nonce() {
        let n1 = generate_nonce();
        let n2 = generate_nonce();

        // Each call produces a unique nonce
        assert_ne!(n1, n2);

        // Valid base64 encoding of 18 bytes
        let decoded1 = BASE64.decode(&n1).expect("nonce must be valid base64");
        assert_eq!(decoded1.len(), 18);

        let decoded2 = BASE64.decode(&n2).expect("nonce must be valid base64");
        assert_eq!(decoded2.len(), 18);
    }

    #[test]
    fn test_parse_auth_header_hello() {
        let username = "user";
        let username_b64 = BASE64.encode(username.as_bytes());
        let header = format!("HELLO username={}", username_b64);

        let parsed = parse_auth_header(&header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Hello {
                username: "user".to_string(),
                data: None,
            }
        );
    }

    /// The optional `data=` parameter of HELLO is passed through trimmed.
    #[test]
    fn test_parse_auth_header_hello_with_data() {
        let header = format!("HELLO username={}, data=abc123", BASE64.encode("user"));
        let parsed = parse_auth_header(&header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Hello {
                username: "user".to_string(),
                data: Some("abc123".to_string()),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_scram() {
        let header = "SCRAM handshakeToken=abc123, data=c29tZWRhdGE=";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Scram {
                handshake_token: "abc123".to_string(),
                data: "c29tZWRhdGE=".to_string(),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_bearer() {
        let header = "BEARER authToken=mytoken123";
        let parsed = parse_auth_header(header).unwrap();
        assert_eq!(
            parsed,
            AuthHeader::Bearer {
                auth_token: "mytoken123".to_string(),
            }
        );
    }

    #[test]
    fn test_parse_auth_header_invalid() {
        // Unknown scheme
        assert!(parse_auth_header("UNKNOWN foo=bar").is_err());
        // HELLO missing username=
        assert!(parse_auth_header("HELLO foo=bar").is_err());
        // SCRAM missing data=
        assert!(parse_auth_header("SCRAM handshakeToken=abc").is_err());
        // BEARER missing authToken=
        assert!(parse_auth_header("BEARER token=abc").is_err());
        // Empty
        assert!(parse_auth_header("").is_err());
    }

    /// The server recovers exactly the nonce the client generated, and
    /// malformed client-first messages are rejected.
    #[test]
    fn test_extract_client_nonce() {
        let (nonce, client_first_b64) = client_first_message("user");
        assert_eq!(extract_client_nonce(&client_first_b64).unwrap(), nonce);

        // Missing GS2 header
        let no_gs2 = BASE64.encode("n=user,r=abc");
        assert!(matches!(
            extract_client_nonce(&no_gs2),
            Err(AuthError::HandshakeFailed(_))
        ));
        // Missing r= attribute
        let no_nonce = BASE64.encode("n,,n=user");
        assert!(matches!(
            extract_client_nonce(&no_nonce),
            Err(AuthError::HandshakeFailed(_))
        ));
        // Not base64 at all
        assert!(matches!(
            extract_client_nonce("!!!"),
            Err(AuthError::Base64Error(_))
        ));
    }

    #[test]
    fn test_full_handshake() {
        // Simulate the complete HELLO -> SCRAM -> BEARER flow.
        let username = "testuser";
        let password = "s3cret";
        let salt = b"test-salt-12345";
        let iterations = 4096;

        // --- Server: pre-compute credentials (user registration) ---
        let credentials = derive_credentials(password, salt, iterations);

        // --- Client: HELLO phase ---
        let username_b64 = BASE64.encode(username.as_bytes());
        let hello_header = format!("HELLO username={}", username_b64);
        let parsed = parse_auth_header(&hello_header).unwrap();
        match &parsed {
            AuthHeader::Hello { username: u, .. } => assert_eq!(u, username),
            _ => panic!("expected Hello variant"),
        }

        // --- Client: generate client-first-message ---
        let (client_nonce, client_first_b64) = client_first_message(username);

        // --- Server: recover the client nonce from the wire message ---
        let server_seen_nonce = extract_client_nonce(&client_first_b64).unwrap();
        assert_eq!(server_seen_nonce, client_nonce);

        // --- Server: generate server-first-message ---
        let (handshake, server_first_b64) =
            server_first_message(username, &server_seen_nonce, &credentials);

        // --- Server: format WWW-Authenticate header ---
        let www_auth = format_www_authenticate("handshake-token-xyz", "SHA-256", &server_first_b64);
        assert!(www_auth.contains("SCRAM"));
        assert!(www_auth.contains("SHA-256"));
        assert!(www_auth.contains("handshake-token-xyz"));

        // --- Client: process server-first, produce client-final ---
        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        // --- Server: verify client-final ---
        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();

        // Server signature should match what the client expects
        assert_eq!(server_sig, expected_server_sig);

        // --- Server: format Authentication-Info header ---
        let server_final_msg = format!("v={}", BASE64.encode(&server_sig));
        let server_final_b64 = BASE64.encode(server_final_msg.as_bytes());
        let auth_info = format_auth_info("auth-token-abc", &server_final_b64);
        assert!(auth_info.contains("authToken=auth-token-abc"));

        // --- Client: verify server signature from server-final ---
        let server_final_decoded = BASE64.decode(&server_final_b64).unwrap();
        let server_final_str = String::from_utf8(server_final_decoded).unwrap();
        let sig_b64 = server_final_str.strip_prefix("v=").unwrap();
        let received_server_sig = BASE64.decode(sig_b64).unwrap();
        assert_eq!(received_server_sig, expected_server_sig);
    }

    #[test]
    fn test_client_server_roundtrip() {
        // Full roundtrip using the public API functions.
        let username = "admin";
        let password = "correcthorsebatterystaple";
        let salt = b"unique-salt-value";
        let iterations = DEFAULT_ITERATIONS;

        // 1. Server: create credentials during user registration
        let credentials = derive_credentials(password, salt, iterations);

        // 2. Client: create client-first-message
        let (client_nonce, client_first_b64) = client_first_message(username);

        // Verify client-first is valid base64 and well-formed
        let client_first_decoded = BASE64.decode(&client_first_b64).unwrap();
        let client_first_str = String::from_utf8(client_first_decoded).unwrap();
        assert!(client_first_str.starts_with("n,,"));
        assert!(client_first_str.contains(&format!("r={}", client_nonce)));

        // 3. Server: create server-first-message
        let (handshake, server_first_b64) =
            server_first_message(username, &client_nonce, &credentials);

        // Verify server-first contains expected SCRAM fields
        let server_first_decoded = BASE64.decode(&server_first_b64).unwrap();
        let server_first_str = String::from_utf8(server_first_decoded).unwrap();
        assert!(server_first_str.starts_with("r="));
        assert!(server_first_str.contains(",s="));
        assert!(server_first_str.contains(",i="));
        assert!(server_first_str.contains(&client_nonce));

        // 4. Client: create client-final-message
        let (client_final_b64, expected_server_sig) =
            client_final_message(password, &client_nonce, &server_first_b64, username).unwrap();

        // Verify client-final structure
        let client_final_decoded = BASE64.decode(&client_final_b64).unwrap();
        let client_final_str = String::from_utf8(client_final_decoded).unwrap();
        assert!(client_final_str.starts_with("c=biws,"));
        assert!(client_final_str.contains(",p="));

        // 5. Server: verify and get server signature
        let server_sig = server_verify_final(&handshake, &client_final_b64).unwrap();
        assert_eq!(server_sig, expected_server_sig);

        // 6. Wrong password: server rejects the proof
        let (wrong_final_b64, _) =
            client_final_message("wrongpassword", &client_nonce, &server_first_b64, username)
                .unwrap();
        let result = server_verify_final(&handshake, &wrong_final_b64);
        assert!(result.is_err());
        match result {
            Err(AuthError::InvalidCredentials) => {} // expected
            other => panic!("expected InvalidCredentials, got {:?}", other),
        }
    }

    /// The server rejects a client-final-message whose nonce or channel
    /// binding was altered, or which is missing fields.
    #[test]
    fn test_server_verify_final_rejects_tampered_message() {
        let creds = derive_credentials("pw", b"salt", 16);
        let (client_nonce, _) = client_first_message("user");
        let (handshake, server_first_b64) = server_first_message("user", &client_nonce, &creds);
        let (client_final_b64, _) =
            client_final_message("pw", &client_nonce, &server_first_b64, "user").unwrap();
        let client_final =
            String::from_utf8(BASE64.decode(&client_final_b64).unwrap()).unwrap();

        // Tampered combined nonce
        let bad_nonce = client_final.replacen(",r=", ",r=x", 1);
        assert!(matches!(
            server_verify_final(&handshake, &BASE64.encode(&bad_nonce)),
            Err(AuthError::HandshakeFailed(_))
        ));

        // Unexpected channel binding
        let bad_binding = client_final.replacen("c=biws", "c=eSws", 1);
        assert!(matches!(
            server_verify_final(&handshake, &BASE64.encode(&bad_binding)),
            Err(AuthError::HandshakeFailed(_))
        ));

        // Too few fields
        assert!(matches!(
            server_verify_final(&handshake, &BASE64.encode("c=biws")),
            Err(AuthError::HandshakeFailed(_))
        ));

        // Not base64
        assert!(matches!(
            server_verify_final(&handshake, "!!!"),
            Err(AuthError::Base64Error(_))
        ));
    }

    /// The client refuses a server-first-message whose nonce was not built
    /// from the client's own nonce.
    #[test]
    fn test_client_final_rejects_foreign_nonce() {
        let creds = derive_credentials("pw", b"salt", 16);
        let (_, server_first_b64) = server_first_message("user", "someone-else", &creds);
        let (my_nonce, _) = client_first_message("user");
        assert!(matches!(
            client_final_message("pw", &my_nonce, &server_first_b64, "user"),
            Err(AuthError::HandshakeFailed(_))
        ));
        assert!(matches!(
            client_final_message("pw", &my_nonce, "!!!", "user"),
            Err(AuthError::Base64Error(_))
        ));
    }

    #[test]
    fn test_format_www_authenticate() {
        let result = format_www_authenticate("tok123", "SHA-256", "c29tZQ==");
        assert_eq!(
            result,
            "SCRAM handshakeToken=tok123, hash=SHA-256, data=c29tZQ=="
        );
    }

    #[test]
    fn test_format_auth_info() {
        let result = format_auth_info("auth-tok", "ZGF0YQ==");
        assert_eq!(result, "authToken=auth-tok, data=ZGF0YQ==");
    }
}