sqlx_postgres/connection/
sasl.rs1use crate::connection::stream::PgStream;
2use crate::error::Error;
3use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
4use crate::PgConnectOptions;
5use hmac::{Hmac, Mac};
6use rand::Rng;
7use sha2::{Digest, Sha256};
8use stringprep::saslprep;
9
10use base64::prelude::{Engine as _, BASE64_STANDARD};
11
/// GS2 header that prefixes the client-first-message and is base64-encoded into the
/// channel-binding attribute of the client-final-message.
const GS2_HEADER: &str = "n,,";

// SCRAM message attribute names (each is written as `<name>=<value>`).

/// Username attribute of the client-first-message.
const USERNAME_ATTR: &str = "n";
/// Nonce attribute (client nonce, later extended by the server).
const NONCE_ATTR: &str = "r";
/// Channel-binding attribute of the client-final-message.
const CHANNEL_ATTR: &str = "c";
/// Base64-encoded client proof attribute of the client-final-message.
const CLIENT_PROOF_ATTR: &str = "p";
17
18pub(crate) async fn authenticate(
19 stream: &mut PgStream,
20 options: &PgConnectOptions,
21 data: AuthenticationSasl,
22) -> Result<(), Error> {
23 let mut has_sasl = false;
24 let mut has_sasl_plus = false;
25 let mut unknown = Vec::new();
26
27 for mechanism in data.mechanisms() {
28 match mechanism {
29 "SCRAM-SHA-256" => {
30 has_sasl = true;
31 }
32
33 "SCRAM-SHA-256-PLUS" => {
34 has_sasl_plus = true;
35 }
36
37 _ => {
38 unknown.push(mechanism.to_owned());
39 }
40 }
41 }
42
43 if !has_sasl_plus && !has_sasl {
44 return Err(err_protocol!(
45 "unsupported SASL authentication mechanisms: {}",
46 unknown.join(", ")
47 ));
48 }
49
50 let mut channel_binding = format!("{CHANNEL_ATTR}=");
52 BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding);
53
54 let username = format!("{}={}", USERNAME_ATTR, options.username);
56 let username = match saslprep(&username) {
57 Ok(v) => v,
58 Err(_) => panic!("Failed to saslprep username"),
60 };
61
62 let nonce = gen_nonce();
64
65 let client_first_message_bare = format!("{username},{nonce}");
67
68 let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
69
70 stream
71 .send(SaslInitialResponse {
72 response: &client_first_message,
73 plus: false,
74 })
75 .await?;
76
77 let cont = match stream.recv_expect().await? {
78 Authentication::SaslContinue(data) => data,
79
80 auth => {
81 return Err(err_protocol!(
82 "expected SASLContinue but received {:?}",
83 auth
84 ));
85 }
86 };
87
88 let salted_password = hi(
90 options.password.as_deref().unwrap_or_default(),
91 &cont.salt,
92 cont.iterations,
93 )?;
94
95 let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
97 mac.update(b"Client Key");
98
99 let client_key = mac.finalize().into_bytes();
100
101 let stored_key = Sha256::digest(client_key);
103
104 let client_final_message_wo_proof = format!(
106 "{channel_binding},r={nonce}",
107 channel_binding = channel_binding,
108 nonce = &cont.nonce
109 );
110
111 let auth_message = format!(
113 "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
114 client_first_message_bare = client_first_message_bare,
115 server_first_message = cont.message,
116 client_final_message_wo_proof = client_final_message_wo_proof
117 );
118
119 let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
121 mac.update(auth_message.as_bytes());
122
123 let client_signature = mac.finalize().into_bytes();
124
125 let client_proof: Vec<u8> = client_key
127 .iter()
128 .zip(client_signature.iter())
129 .map(|(&a, &b)| a ^ b)
130 .collect();
131
132 let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
134 mac.update(b"Server Key");
135
136 let server_key = mac.finalize().into_bytes();
137
138 let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
140 mac.update(auth_message.as_bytes());
141
142 let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
144 BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
145
146 stream.send(SaslResponse(&client_final_message)).await?;
147
148 let data = match stream.recv_expect().await? {
149 Authentication::SaslFinal(data) => data,
150
151 auth => {
152 return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
153 }
154 };
155
156 mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
158
159 Ok(())
160}
161
162fn gen_nonce() -> String {
164 let mut rng = rand::thread_rng();
165 let count = rng.gen_range(64..128);
166
167 let nonce: String = std::iter::repeat(())
172 .map(|()| {
173 let mut c = rng.gen_range(0x21u8..0x7F);
174
175 while c == 0x2C {
176 c = rng.gen_range(0x21u8..0x7F);
177 }
178
179 c
180 })
181 .take(count)
182 .map(|c| c as char)
183 .collect();
184
185 rng.gen_range(32..128);
186 format!("{NONCE_ATTR}={nonce}")
187}
188
189fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
191 let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
192
193 mac.update(salt);
194 mac.update(&1u32.to_be_bytes());
195
196 let mut u = mac.finalize_reset().into_bytes();
197 let mut hi = u;
198
199 for _ in 1..iter_count {
200 mac.update(u.as_slice());
201 u = mac.finalize_reset().into_bytes();
202 hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
203 }
204
205 Ok(hi.into())
206}
207
/// Benchmarks `hi` with a 4096-iteration count and a random 64-byte alphanumeric salt.
/// It is only compiled for optimized test builds.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_sasl_hi(b: &mut test::Bencher) {
    use test::black_box;

    // Build the salt once, outside the timed closure.
    let mut rng = rand::thread_rng();
    let salt: Vec<u8> = (0..64)
        .map(|_| rng.sample(rand::distributions::Alphanumeric))
        .collect();

    b.iter(|| {
        // Black-box the result as well as the inputs so the optimizer cannot drop the call.
        black_box(hi(
            black_box("secret_password"),
            black_box(&salt),
            black_box(4096),
        ))
    });
}