Skip to main content

sentinel_driver/auth/
scram.rs

1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
2use hmac::{Hmac, Mac};
3use sha2::{Digest, Sha256};
4
5use crate::config::ChannelBinding;
6use crate::connection::stream::PgConnection;
7use crate::error::{Error, Result};
8use crate::protocol::backend::BackendMessage;
9use crate::protocol::frontend;
10
/// HMAC instantiated with SHA-256, the PRF used throughout SCRAM-SHA-256.
type HmacSha256 = Hmac<Sha256>;
12
13/// Perform SCRAM-SHA-256 (or SCRAM-SHA-256-PLUS) authentication with the server.
14///
15/// This handles the full 3-message exchange:
16/// 1. Client sends SASLInitialResponse with client-first-message
17/// 2. Server replies with AuthenticationSaslContinue (server-first-message)
18/// 3. Client sends SASLResponse with client-final-message
19/// 4. Server replies with AuthenticationSaslFinal (server-final verification)
20pub(crate) async fn authenticate(
21    conn: &mut PgConnection,
22    password: &str,
23    mechanisms: &[String],
24    channel_binding: ChannelBinding,
25    server_cert_der: Option<&[u8]>,
26) -> Result<()> {
27    // Determine mechanism and GS2 header based on channel binding config
28    let has_plus = mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS");
29    let has_plain = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
30    let is_tls = server_cert_der.is_some();
31
32    let (mechanism, gs2_header) = select_mechanism(channel_binding, is_tls, has_plus, has_plain)?;
33
34    // SASLprep the password (RFC 7613)
35    let prepped_password = saslprep(password)?;
36
37    // Generate client nonce
38    let client_nonce = generate_nonce();
39
40    // Client-first-message-bare: n=,r=<nonce>
41    // We don't send a username in the SCRAM exchange; PG uses the startup user.
42    let client_first_bare = format!("n=,r={client_nonce}");
43    let client_first_message = format!("{gs2_header}{client_first_bare}");
44
45    // Send SASLInitialResponse
46    frontend::sasl_initial_response(conn.write_buf(), mechanism, client_first_message.as_bytes());
47    conn.send().await?;
48
49    // Receive server-first-message
50    let server_first = match conn.recv().await? {
51        BackendMessage::AuthenticationSaslContinue { data } => String::from_utf8(data)
52            .map_err(|e| Error::Auth(format!("invalid server-first-message: {e}")))?,
53        BackendMessage::ErrorResponse { fields } => {
54            return Err(Error::server(
55                fields.severity,
56                fields.code,
57                fields.message,
58                fields.detail,
59                fields.hint,
60                fields.position,
61            ));
62        }
63        other => {
64            return Err(Error::protocol(format!(
65                "expected SaslContinue, got {other:?}"
66            )));
67        }
68    };
69
70    // Parse server-first-message: r=<nonce>,s=<salt>,i=<iterations>
71    let parsed = parse_server_first(&server_first)?;
72
73    // Verify server nonce starts with our client nonce
74    if !parsed.nonce.starts_with(&client_nonce) {
75        return Err(Error::Auth(
76            "server nonce doesn't match client nonce".into(),
77        ));
78    }
79
80    let salt = BASE64
81        .decode(&parsed.salt)
82        .map_err(|e| Error::Auth(format!("invalid salt base64: {e}")))?;
83
84    // Compute SCRAM proof
85    let salted_password = hi(prepped_password.as_bytes(), &salt, parsed.iterations);
86    let client_key = hmac_sha256(&salted_password, b"Client Key");
87    let stored_key = sha256(&client_key);
88    let server_key = hmac_sha256(&salted_password, b"Server Key");
89
90    // Build channel binding data for c= parameter
91    let cbind_input = build_channel_binding_data(gs2_header, server_cert_der);
92    let channel_binding_b64 = BASE64.encode(&cbind_input);
93
94    // client-final-message-without-proof
95    let client_final_without_proof = format!("c={channel_binding_b64},r={}", parsed.nonce);
96
97    // AuthMessage = client-first-bare + "," + server-first + "," + client-final-without-proof
98    let auth_message = format!("{client_first_bare},{server_first},{client_final_without_proof}");
99
100    let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
101    let client_proof: Vec<u8> = client_key
102        .iter()
103        .zip(client_signature.iter())
104        .map(|(a, b)| a ^ b)
105        .collect();
106
107    let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
108
109    // Send client-final-message
110    let client_final = format!(
111        "{client_final_without_proof},p={}",
112        BASE64.encode(&client_proof)
113    );
114
115    frontend::sasl_response(conn.write_buf(), client_final.as_bytes());
116    conn.send().await?;
117
118    // Receive server-final-message
119    match conn.recv().await? {
120        BackendMessage::AuthenticationSaslFinal { data } => {
121            let server_final = String::from_utf8(data)
122                .map_err(|e| Error::Auth(format!("invalid server-final-message: {e}")))?;
123
124            // Verify server signature
125            let expected_verifier = format!("v={}", BASE64.encode(&server_signature));
126            if server_final != expected_verifier {
127                return Err(Error::Auth("server signature verification failed".into()));
128            }
129        }
130        BackendMessage::ErrorResponse { fields } => {
131            return Err(Error::server(
132                fields.severity,
133                fields.code,
134                fields.message,
135                fields.detail,
136                fields.hint,
137                fields.position,
138            ));
139        }
140        other => {
141            return Err(Error::protocol(format!(
142                "expected SaslFinal, got {other:?}"
143            )));
144        }
145    }
146
147    // Wait for AuthenticationOk
148    match conn.recv().await? {
149        BackendMessage::AuthenticationOk => Ok(()),
150        BackendMessage::ErrorResponse { fields } => Err(Error::server(
151            fields.severity,
152            fields.code,
153            fields.message,
154            fields.detail,
155            fields.hint,
156            fields.position,
157        )),
158        other => Err(Error::protocol(format!(
159            "expected AuthenticationOk, got {other:?}"
160        ))),
161    }
162}
163
164/// Select SCRAM mechanism and GS2 header based on channel binding config.
165fn select_mechanism(
166    channel_binding: ChannelBinding,
167    is_tls: bool,
168    has_plus: bool,
169    has_plain: bool,
170) -> Result<(&'static str, &'static str)> {
171    match channel_binding {
172        ChannelBinding::Require => {
173            if !is_tls {
174                return Err(Error::Auth("channel binding requires TLS".into()));
175            }
176            if !has_plus {
177                return Err(Error::Auth(
178                    "server does not support SCRAM-SHA-256-PLUS".into(),
179                ));
180            }
181            Ok(("SCRAM-SHA-256-PLUS", "p=tls-server-end-point,,"))
182        }
183        ChannelBinding::Prefer => {
184            if is_tls && has_plus {
185                Ok(("SCRAM-SHA-256-PLUS", "p=tls-server-end-point,,"))
186            } else if has_plain {
187                // y,, = client supports channel binding but server doesn't advertise it
188                let gs2 = if is_tls { "y,," } else { "n,," };
189                Ok(("SCRAM-SHA-256", gs2))
190            } else {
191                Err(Error::Auth(
192                    "server offered no supported SASL mechanisms".into(),
193                ))
194            }
195        }
196        ChannelBinding::Disable => {
197            if has_plain {
198                Ok(("SCRAM-SHA-256", "n,,"))
199            } else {
200                Err(Error::Auth(
201                    "server offered no supported SASL mechanisms".into(),
202                ))
203            }
204        }
205    }
206}
207
208/// Build the channel binding input bytes: gs2_header + cbind_data.
209///
210/// For `tls-server-end-point`, cbind_data is SHA-256 hash of the server's DER certificate.
211/// For non-PLUS, cbind_data is empty (just the GS2 header).
212fn build_channel_binding_data(gs2_header: &str, server_cert_der: Option<&[u8]>) -> Vec<u8> {
213    let mut data = gs2_header.as_bytes().to_vec();
214    if gs2_header.starts_with("p=tls-server-end-point") {
215        if let Some(cert_der) = server_cert_der {
216            let hash = sha256(cert_der);
217            data.extend_from_slice(&hash);
218        }
219    }
220    data
221}
222
/// Attributes extracted from a SCRAM server-first-message (`r=…,s=…,i=…`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerFirst {
    /// Value of the `r=` attribute; `authenticate` requires it to start with the
    /// client nonce.
    pub nonce: String,
    /// Value of the `s=` attribute, still base64-encoded as sent by the server.
    pub salt: String,
    /// Value of the `i=` attribute: the iteration count passed to `hi`.
    pub iterations: u32,
}
228
229pub fn parse_server_first(msg: &str) -> Result<ServerFirst> {
230    let mut nonce = None;
231    let mut salt = None;
232    let mut iterations = None;
233
234    for part in msg.split(',') {
235        if let Some(val) = part.strip_prefix("r=") {
236            nonce = Some(val.to_string());
237        } else if let Some(val) = part.strip_prefix("s=") {
238            salt = Some(val.to_string());
239        } else if let Some(val) = part.strip_prefix("i=") {
240            iterations = Some(
241                val.parse::<u32>()
242                    .map_err(|_| Error::Auth(format!("invalid iteration count: {val}")))?,
243            );
244        }
245    }
246
247    Ok(ServerFirst {
248        nonce: nonce.ok_or_else(|| Error::Auth("missing nonce in server-first".into()))?,
249        salt: salt.ok_or_else(|| Error::Auth("missing salt in server-first".into()))?,
250        iterations: iterations
251            .ok_or_else(|| Error::Auth("missing iterations in server-first".into()))?,
252    })
253}
254
255/// Hi(password, salt, iterations) — PBKDF2-HMAC-SHA256.
256pub fn hi(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
257    // U1 = HMAC(password, salt + INT(1))
258    let mut salt_with_one = salt.to_vec();
259    salt_with_one.extend_from_slice(&1u32.to_be_bytes());
260
261    let mut u_prev = hmac_sha256(password, &salt_with_one);
262    let mut result = u_prev.clone();
263
264    for _ in 1..iterations {
265        let u_current = hmac_sha256(password, &u_prev);
266        for (r, u) in result.iter_mut().zip(u_current.iter()) {
267            *r ^= u;
268        }
269        u_prev = u_current;
270    }
271
272    result
273}
274
275pub fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
276    #[allow(clippy::expect_used)]
277    let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
278    mac.update(data);
279    mac.finalize().into_bytes().to_vec()
280}
281
282fn sha256(data: &[u8]) -> Vec<u8> {
283    let mut hasher = Sha256::new();
284    hasher.update(data);
285    hasher.finalize().to_vec()
286}
287
288/// SASLprep (RFC 7613) password normalization.
289///
290/// This is what sqlx gets wrong — they skip this step, leading to
291/// authentication failures with non-ASCII passwords.
292pub fn saslprep(input: &str) -> Result<String> {
293    stringprep::saslprep(input)
294        .map(std::borrow::Cow::into_owned)
295        .map_err(|e| Error::Auth(format!("SASLprep failed: {e}")))
296}
297
298/// Generate a random nonce for SCRAM.
299pub fn generate_nonce() -> String {
300    use rand::Rng;
301    let mut rng = rand::thread_rng();
302    let bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
303    BASE64.encode(&bytes)
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_select_mechanism_require_with_tls_and_plus() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Require, true, true, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256-PLUS");
        assert_eq!(gs2, "p=tls-server-end-point,,");
    }

    #[test]
    fn test_select_mechanism_require_without_tls() {
        let err = select_mechanism(ChannelBinding::Require, false, false, true).unwrap_err();
        assert!(err.to_string().contains("channel binding requires TLS"));
    }

    #[test]
    fn test_select_mechanism_require_no_plus() {
        let err = select_mechanism(ChannelBinding::Require, true, false, true).unwrap_err();
        assert!(err
            .to_string()
            .contains("does not support SCRAM-SHA-256-PLUS"));
    }

    #[test]
    fn test_select_mechanism_prefer_with_tls_and_plus() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, true, true, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256-PLUS");
        assert_eq!(gs2, "p=tls-server-end-point,,");
    }

    #[test]
    fn test_select_mechanism_prefer_tls_no_plus() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, true, false, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "y,,");
    }

    #[test]
    fn test_select_mechanism_prefer_no_tls() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, false, false, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "n,,");
    }

    #[test]
    fn test_select_mechanism_disable() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Disable, true, true, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "n,,");
    }

    #[test]
    fn test_build_channel_binding_no_plus() {
        let data = build_channel_binding_data("n,,", None);
        assert_eq!(data, b"n,,");
        assert_eq!(BASE64.encode(&data), "biws");
    }

    #[test]
    fn test_build_channel_binding_with_plus() {
        let fake_cert = b"fake-server-certificate-der";
        let data = build_channel_binding_data("p=tls-server-end-point,,", Some(fake_cert));
        // Should be: gs2_header bytes + sha256(cert)
        let expected_hash = sha256(fake_cert);
        let mut expected = b"p=tls-server-end-point,,".to_vec();
        expected.extend_from_slice(&expected_hash);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_gs2_header_y_flag() {
        // y,, means client supports CB but server didn't advertise PLUS
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, true, false, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "y,,");
        let data = build_channel_binding_data(gs2, Some(b"cert"));
        // y,, should NOT include channel binding data
        assert_eq!(data, b"y,,");
    }

    #[test]
    fn test_select_mechanism_prefer_no_mechanisms() {
        let err = select_mechanism(ChannelBinding::Prefer, false, false, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    #[test]
    fn test_select_mechanism_prefer_tls_no_mechanisms() {
        let err = select_mechanism(ChannelBinding::Prefer, true, false, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    #[test]
    fn test_select_mechanism_disable_no_plain() {
        let err = select_mechanism(ChannelBinding::Disable, true, true, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    #[test]
    fn test_select_mechanism_prefer_no_tls_plus_only() {
        // Server only offers PLUS but client has no TLS — should fail
        let err = select_mechanism(ChannelBinding::Prefer, false, true, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    /// Run `parse_server_first` on input that must be rejected; return the error text.
    fn parse_err(msg: &str) -> String {
        match parse_server_first(msg) {
            Ok(_) => panic!("expected parse_server_first({msg:?}) to fail"),
            Err(e) => e.to_string(),
        }
    }

    #[test]
    fn test_parse_server_first_ok() {
        let parsed = parse_server_first("r=abcDEF,s=c2FsdA==,i=4096").unwrap();
        assert_eq!(parsed.nonce, "abcDEF");
        assert_eq!(parsed.salt, "c2FsdA==");
        assert_eq!(parsed.iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_fields() {
        assert!(parse_err("s=c2FsdA==,i=4096").contains("missing nonce"));
        assert!(parse_err("r=abc,i=4096").contains("missing salt"));
        assert!(parse_err("r=abc,s=c2FsdA==").contains("missing iterations"));
    }

    #[test]
    fn test_parse_server_first_bad_iterations() {
        assert!(parse_err("r=abc,s=c2FsdA==,i=lots").contains("invalid iteration count"));
    }

    #[test]
    fn test_hi_single_iteration_is_first_hmac_block() {
        let mut salt_block = b"salt".to_vec();
        salt_block.extend_from_slice(&1u32.to_be_bytes());
        assert_eq!(hi(b"pencil", b"salt", 1), hmac_sha256(b"pencil", &salt_block));
    }

    #[test]
    fn test_rfc7677_client_proof_and_server_signature() {
        // Worked example from RFC 7677 §3 (user "user", password "pencil").
        let client_first_bare = "n=user,r=rOprNGfwEbeRWgbNEkqO";
        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
                            s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";
        let parsed = parse_server_first(server_first).unwrap();
        assert_eq!(parsed.iterations, 4096);

        let salt = BASE64.decode(&parsed.salt).unwrap();
        let salted = hi(saslprep("pencil").unwrap().as_bytes(), &salt, parsed.iterations);
        let client_key = hmac_sha256(&salted, b"Client Key");
        let stored_key = sha256(&client_key);
        let server_key = hmac_sha256(&salted, b"Server Key");

        let cbind = BASE64.encode(build_channel_binding_data("n,,", None));
        let client_final_without_proof = format!("c={cbind},r={}", parsed.nonce);
        let auth_message =
            format!("{client_first_bare},{server_first},{client_final_without_proof}");

        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
        let proof: Vec<u8> = client_key
            .iter()
            .zip(&client_signature)
            .map(|(a, b)| a ^ b)
            .collect();
        assert_eq!(
            BASE64.encode(&proof),
            "dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
        assert_eq!(
            BASE64.encode(&server_signature),
            "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
        );
    }

    #[test]
    fn test_saslprep_rfc4013_examples() {
        // Soft hyphen is mapped to nothing; ASCII passes through; BEL is prohibited.
        assert_eq!(saslprep("I\u{00AD}X").unwrap(), "IX");
        assert_eq!(saslprep("user").unwrap(), "user");
        assert!(saslprep("\u{0007}").is_err());
    }

    #[test]
    fn test_generate_nonce_shape() {
        let a = generate_nonce();
        let b = generate_nonce();
        assert_eq!(a.len(), 32);
        assert!(!a.contains(','));
        assert_ne!(a, b);
    }
}