1use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
2use hmac::{Hmac, Mac};
3use sha2::{Digest, Sha256};
4
5use crate::config::ChannelBinding;
6use crate::connection::stream::PgConnection;
7use crate::error::{Error, Result};
8use crate::protocol::backend::BackendMessage;
9use crate::protocol::frontend;
10
/// HMAC instantiated over SHA-256; backs `hmac_sha256` and `hi` in this module.
type HmacSha256 = Hmac<Sha256>;
12
13pub(crate) async fn authenticate(
21 conn: &mut PgConnection,
22 password: &str,
23 mechanisms: &[String],
24 channel_binding: ChannelBinding,
25 server_cert_der: Option<&[u8]>,
26) -> Result<()> {
27 let has_plus = mechanisms.iter().any(|m| m == "SCRAM-SHA-256-PLUS");
29 let has_plain = mechanisms.iter().any(|m| m == "SCRAM-SHA-256");
30 let is_tls = server_cert_der.is_some();
31
32 let (mechanism, gs2_header) = select_mechanism(channel_binding, is_tls, has_plus, has_plain)?;
33
34 let prepped_password = saslprep(password)?;
36
37 let client_nonce = generate_nonce();
39
40 let client_first_bare = format!("n=,r={client_nonce}");
43 let client_first_message = format!("{gs2_header}{client_first_bare}");
44
45 frontend::sasl_initial_response(conn.write_buf(), mechanism, client_first_message.as_bytes());
47 conn.send().await?;
48
49 let server_first = match conn.recv().await? {
51 BackendMessage::AuthenticationSaslContinue { data } => String::from_utf8(data)
52 .map_err(|e| Error::Auth(format!("invalid server-first-message: {e}")))?,
53 BackendMessage::ErrorResponse { fields } => {
54 return Err(Error::server(
55 fields.severity,
56 fields.code,
57 fields.message,
58 fields.detail,
59 fields.hint,
60 fields.position,
61 ));
62 }
63 other => {
64 return Err(Error::protocol(format!(
65 "expected SaslContinue, got {other:?}"
66 )));
67 }
68 };
69
70 let parsed = parse_server_first(&server_first)?;
72
73 if !parsed.nonce.starts_with(&client_nonce) {
75 return Err(Error::Auth(
76 "server nonce doesn't match client nonce".into(),
77 ));
78 }
79
80 let salt = BASE64
81 .decode(&parsed.salt)
82 .map_err(|e| Error::Auth(format!("invalid salt base64: {e}")))?;
83
84 let salted_password = hi(prepped_password.as_bytes(), &salt, parsed.iterations);
86 let client_key = hmac_sha256(&salted_password, b"Client Key");
87 let stored_key = sha256(&client_key);
88 let server_key = hmac_sha256(&salted_password, b"Server Key");
89
90 let cbind_input = build_channel_binding_data(gs2_header, server_cert_der);
92 let channel_binding_b64 = BASE64.encode(&cbind_input);
93
94 let client_final_without_proof = format!("c={channel_binding_b64},r={}", parsed.nonce);
96
97 let auth_message = format!("{client_first_bare},{server_first},{client_final_without_proof}");
99
100 let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
101 let client_proof: Vec<u8> = client_key
102 .iter()
103 .zip(client_signature.iter())
104 .map(|(a, b)| a ^ b)
105 .collect();
106
107 let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
108
109 let client_final = format!(
111 "{client_final_without_proof},p={}",
112 BASE64.encode(&client_proof)
113 );
114
115 frontend::sasl_response(conn.write_buf(), client_final.as_bytes());
116 conn.send().await?;
117
118 match conn.recv().await? {
120 BackendMessage::AuthenticationSaslFinal { data } => {
121 let server_final = String::from_utf8(data)
122 .map_err(|e| Error::Auth(format!("invalid server-final-message: {e}")))?;
123
124 let expected_verifier = format!("v={}", BASE64.encode(&server_signature));
126 if server_final != expected_verifier {
127 return Err(Error::Auth("server signature verification failed".into()));
128 }
129 }
130 BackendMessage::ErrorResponse { fields } => {
131 return Err(Error::server(
132 fields.severity,
133 fields.code,
134 fields.message,
135 fields.detail,
136 fields.hint,
137 fields.position,
138 ));
139 }
140 other => {
141 return Err(Error::protocol(format!(
142 "expected SaslFinal, got {other:?}"
143 )));
144 }
145 }
146
147 match conn.recv().await? {
149 BackendMessage::AuthenticationOk => Ok(()),
150 BackendMessage::ErrorResponse { fields } => Err(Error::server(
151 fields.severity,
152 fields.code,
153 fields.message,
154 fields.detail,
155 fields.hint,
156 fields.position,
157 )),
158 other => Err(Error::protocol(format!(
159 "expected AuthenticationOk, got {other:?}"
160 ))),
161 }
162}
163
164fn select_mechanism(
166 channel_binding: ChannelBinding,
167 is_tls: bool,
168 has_plus: bool,
169 has_plain: bool,
170) -> Result<(&'static str, &'static str)> {
171 match channel_binding {
172 ChannelBinding::Require => {
173 if !is_tls {
174 return Err(Error::Auth("channel binding requires TLS".into()));
175 }
176 if !has_plus {
177 return Err(Error::Auth(
178 "server does not support SCRAM-SHA-256-PLUS".into(),
179 ));
180 }
181 Ok(("SCRAM-SHA-256-PLUS", "p=tls-server-end-point,,"))
182 }
183 ChannelBinding::Prefer => {
184 if is_tls && has_plus {
185 Ok(("SCRAM-SHA-256-PLUS", "p=tls-server-end-point,,"))
186 } else if has_plain {
187 let gs2 = if is_tls { "y,," } else { "n,," };
189 Ok(("SCRAM-SHA-256", gs2))
190 } else {
191 Err(Error::Auth(
192 "server offered no supported SASL mechanisms".into(),
193 ))
194 }
195 }
196 ChannelBinding::Disable => {
197 if has_plain {
198 Ok(("SCRAM-SHA-256", "n,,"))
199 } else {
200 Err(Error::Auth(
201 "server offered no supported SASL mechanisms".into(),
202 ))
203 }
204 }
205 }
206}
207
208fn build_channel_binding_data(gs2_header: &str, server_cert_der: Option<&[u8]>) -> Vec<u8> {
213 let mut data = gs2_header.as_bytes().to_vec();
214 if gs2_header.starts_with("p=tls-server-end-point") {
215 if let Some(cert_der) = server_cert_der {
216 let hash = sha256(cert_der);
217 data.extend_from_slice(&hash);
218 }
219 }
220 data
221}
222
/// Attributes extracted from a SCRAM server-first-message by `parse_server_first`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerFirst {
    /// Combined nonce (`r=`); expected to begin with the client's nonce.
    pub nonce: String,
    /// Salt exactly as sent (`s=`), still base64-encoded.
    pub salt: String,
    /// Iteration count (`i=`) for the `Hi()` key-stretching function.
    pub iterations: u32,
}
228
229pub fn parse_server_first(msg: &str) -> Result<ServerFirst> {
230 let mut nonce = None;
231 let mut salt = None;
232 let mut iterations = None;
233
234 for part in msg.split(',') {
235 if let Some(val) = part.strip_prefix("r=") {
236 nonce = Some(val.to_string());
237 } else if let Some(val) = part.strip_prefix("s=") {
238 salt = Some(val.to_string());
239 } else if let Some(val) = part.strip_prefix("i=") {
240 iterations = Some(
241 val.parse::<u32>()
242 .map_err(|_| Error::Auth(format!("invalid iteration count: {val}")))?,
243 );
244 }
245 }
246
247 Ok(ServerFirst {
248 nonce: nonce.ok_or_else(|| Error::Auth("missing nonce in server-first".into()))?,
249 salt: salt.ok_or_else(|| Error::Auth("missing salt in server-first".into()))?,
250 iterations: iterations
251 .ok_or_else(|| Error::Auth("missing iterations in server-first".into()))?,
252 })
253}
254
255pub fn hi(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
257 let mut salt_with_one = salt.to_vec();
259 salt_with_one.extend_from_slice(&1u32.to_be_bytes());
260
261 let mut u_prev = hmac_sha256(password, &salt_with_one);
262 let mut result = u_prev.clone();
263
264 for _ in 1..iterations {
265 let u_current = hmac_sha256(password, &u_prev);
266 for (r, u) in result.iter_mut().zip(u_current.iter()) {
267 *r ^= u;
268 }
269 u_prev = u_current;
270 }
271
272 result
273}
274
275pub fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
276 #[allow(clippy::expect_used)]
277 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
278 mac.update(data);
279 mac.finalize().into_bytes().to_vec()
280}
281
282fn sha256(data: &[u8]) -> Vec<u8> {
283 let mut hasher = Sha256::new();
284 hasher.update(data);
285 hasher.finalize().to_vec()
286}
287
288pub fn saslprep(input: &str) -> Result<String> {
293 stringprep::saslprep(input)
294 .map(std::borrow::Cow::into_owned)
295 .map_err(|e| Error::Auth(format!("SASLprep failed: {e}")))
296}
297
298pub fn generate_nonce() -> String {
300 use rand::Rng;
301 let mut rng = rand::thread_rng();
302 let bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
303 BASE64.encode(&bytes)
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_select_mechanism_require_with_tls_and_plus() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Require, true, true, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256-PLUS");
        assert_eq!(gs2, "p=tls-server-end-point,,");
    }

    #[test]
    fn test_select_mechanism_require_without_tls() {
        let err = select_mechanism(ChannelBinding::Require, false, false, true).unwrap_err();
        assert!(err.to_string().contains("channel binding requires TLS"));
    }

    #[test]
    fn test_select_mechanism_require_no_plus() {
        let err = select_mechanism(ChannelBinding::Require, true, false, true).unwrap_err();
        assert!(err
            .to_string()
            .contains("does not support SCRAM-SHA-256-PLUS"));
    }

    #[test]
    fn test_select_mechanism_prefer_with_tls_and_plus() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, true, true, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256-PLUS");
        assert_eq!(gs2, "p=tls-server-end-point,,");
    }

    #[test]
    fn test_select_mechanism_prefer_tls_no_plus() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, true, false, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "y,,");
    }

    #[test]
    fn test_select_mechanism_prefer_no_tls() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, false, false, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "n,,");
    }

    #[test]
    fn test_select_mechanism_disable() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Disable, true, true, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "n,,");
    }

    #[test]
    fn test_build_channel_binding_no_plus() {
        let data = build_channel_binding_data("n,,", None);
        assert_eq!(data, b"n,,");
        assert_eq!(BASE64.encode(&data), "biws");
    }

    #[test]
    fn test_build_channel_binding_with_plus() {
        let fake_cert = b"fake-server-certificate-der";
        let data = build_channel_binding_data("p=tls-server-end-point,,", Some(fake_cert));
        let expected_hash = sha256(fake_cert);
        let mut expected = b"p=tls-server-end-point,,".to_vec();
        expected.extend_from_slice(&expected_hash);
        assert_eq!(data, expected);
    }

    #[test]
    fn test_gs2_header_y_flag() {
        let (mech, gs2) = select_mechanism(ChannelBinding::Prefer, true, false, true).unwrap();
        assert_eq!(mech, "SCRAM-SHA-256");
        assert_eq!(gs2, "y,,");
        let data = build_channel_binding_data(gs2, Some(b"cert"));
        assert_eq!(data, b"y,,");
    }

    #[test]
    fn test_select_mechanism_prefer_no_mechanisms() {
        let err = select_mechanism(ChannelBinding::Prefer, false, false, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    #[test]
    fn test_select_mechanism_prefer_tls_no_mechanisms() {
        let err = select_mechanism(ChannelBinding::Prefer, true, false, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    #[test]
    fn test_select_mechanism_disable_no_plain() {
        let err = select_mechanism(ChannelBinding::Disable, true, true, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    #[test]
    fn test_select_mechanism_prefer_no_tls_plus_only() {
        let err = select_mechanism(ChannelBinding::Prefer, false, true, false).unwrap_err();
        assert!(err.to_string().contains("no supported SASL mechanisms"));
    }

    // Known-answer tests for the crypto primitives, so a key-derivation regression is
    // caught without a live server.

    /// RFC 7677 §3 example: user "user", password "pencil".
    const RFC7677_CLIENT_FIRST_BARE: &str = "n=user,r=rOprNGfwEbeRWgbNEkqO";
    const RFC7677_SERVER_FIRST: &str = concat!(
        "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,",
        "s=W22ZaJ0SNY7soEsUEjb6gQ==,",
        "i=4096"
    );

    #[test]
    fn test_hmac_sha256_rfc4231_case_2() {
        let mac = hmac_sha256(b"Jefe", b"what do ya want for nothing?");
        let expected: [u8; 32] = [
            0x5b, 0xdc, 0xc1, 0x46, 0xbf, 0x60, 0x75, 0x4e, 0x6a, 0x04, 0x24, 0x26, 0x08, 0x95,
            0x75, 0xc7, 0x5a, 0x00, 0x3f, 0x08, 0x9d, 0x27, 0x39, 0x83, 0x9d, 0xec, 0x58, 0xb9,
            0x64, 0xec, 0x38, 0x43,
        ];
        assert_eq!(mac, expected);
    }

    #[test]
    fn test_parse_server_first_rfc7677() {
        let parsed = parse_server_first(RFC7677_SERVER_FIRST).unwrap();
        assert_eq!(parsed.nonce, "rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0");
        assert_eq!(parsed.salt, "W22ZaJ0SNY7soEsUEjb6gQ==");
        assert_eq!(parsed.iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_salt() {
        match parse_server_first("r=abc,i=4096") {
            Err(err) => assert!(err.to_string().contains("missing salt")),
            Ok(_) => panic!("a server-first without salt must be rejected"),
        }
    }

    #[test]
    fn test_parse_server_first_invalid_iterations() {
        match parse_server_first("r=abc,s=c2FsdA==,i=many") {
            Err(err) => assert!(err.to_string().contains("invalid iteration count: many")),
            Ok(_) => panic!("a non-numeric iteration count must be rejected"),
        }
    }

    #[test]
    fn test_scram_sha256_rfc7677_exchange() {
        let parsed = parse_server_first(RFC7677_SERVER_FIRST).unwrap();
        let salt = BASE64.decode(&parsed.salt).unwrap();

        let salted_password = hi(b"pencil", &salt, parsed.iterations);
        let client_key = hmac_sha256(&salted_password, b"Client Key");
        let stored_key = sha256(&client_key);
        let server_key = hmac_sha256(&salted_password, b"Server Key");

        let client_final_without_proof = format!("c=biws,r={}", parsed.nonce);
        let auth_message = format!(
            "{RFC7677_CLIENT_FIRST_BARE},{RFC7677_SERVER_FIRST},{client_final_without_proof}"
        );

        let client_signature = hmac_sha256(&stored_key, auth_message.as_bytes());
        let client_proof: Vec<u8> = client_key
            .iter()
            .zip(client_signature.iter())
            .map(|(a, b)| a ^ b)
            .collect();
        assert_eq!(
            BASE64.encode(&client_proof),
            "dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        let server_signature = hmac_sha256(&server_key, auth_message.as_bytes());
        assert_eq!(
            BASE64.encode(&server_signature),
            "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
        );
    }

    #[test]
    fn test_saslprep_rfc4013_examples() {
        // SOFT HYPHEN is mapped to nothing; ROMAN NUMERAL NINE is NFKC-normalized.
        assert_eq!(saslprep("I\u{00AD}X").unwrap(), "IX");
        assert_eq!(saslprep("\u{2168}").unwrap(), "IX");
        // BELL is a prohibited control character.
        assert!(saslprep("\u{0007}").is_err());
    }

    #[test]
    fn test_generate_nonce_shape() {
        let nonce = generate_nonce();
        // 24 random bytes encode to exactly 32 base64 characters.
        assert_eq!(nonce.len(), 32);
        // ',' separates SCRAM attributes and must never appear in a nonce.
        assert!(!nonce.contains(','));
        assert_ne!(nonce, generate_nonce());
    }
}