sqlx_etorreborre_postgres/connection/
sasl.rs1use crate::connection::stream::PgStream;
2use crate::error::Error;
3use crate::message::{
4 Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
5};
6use crate::PgConnectOptions;
7use hmac::{Hmac, Mac};
8use rand::Rng;
9use sha2::{Digest, Sha256};
10use stringprep::saslprep;
11
12use base64::prelude::{Engine as _, BASE64_STANDARD};
13
/// GS2 header sent by the client: no channel binding (`n`), no authorization identity.
const GS2_HEADER: &str = "n,,";

// SCRAM attribute names used when building client messages.

/// Attribute carrying the base64-encoded GS2 header (channel binding).
const CHANNEL_ATTR: &str = "c";
/// Attribute carrying the username.
const USERNAME_ATTR: &str = "n";
/// Attribute carrying the client proof.
const CLIENT_PROOF_ATTR: &str = "p";
/// Attribute carrying the nonce.
const NONCE_ATTR: &str = "r";
19
20pub(crate) async fn authenticate(
21 stream: &mut PgStream,
22 options: &PgConnectOptions,
23 data: AuthenticationSasl,
24) -> Result<(), Error> {
25 let mut has_sasl = false;
26 let mut has_sasl_plus = false;
27 let mut unknown = Vec::new();
28
29 for mechanism in data.mechanisms() {
30 match mechanism {
31 "SCRAM-SHA-256" => {
32 has_sasl = true;
33 }
34
35 "SCRAM-SHA-256-PLUS" => {
36 has_sasl_plus = true;
37 }
38
39 _ => {
40 unknown.push(mechanism.to_owned());
41 }
42 }
43 }
44
45 if !has_sasl_plus && !has_sasl {
46 return Err(err_protocol!(
47 "unsupported SASL authentication mechanisms: {}",
48 unknown.join(", ")
49 ));
50 }
51
52 let mut channel_binding = format!("{CHANNEL_ATTR}=");
54 BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding);
55
56 let username = format!("{}={}", USERNAME_ATTR, options.username);
58 let username = match saslprep(&username) {
59 Ok(v) => v,
60 Err(_) => panic!("Failed to saslprep username"),
62 };
63
64 let nonce = gen_nonce();
66
67 let client_first_message_bare = format!("{username},{nonce}");
69
70 let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
71
72 stream
73 .send(SaslInitialResponse {
74 response: &client_first_message,
75 plus: false,
76 })
77 .await?;
78
79 let cont = match stream.recv_expect(MessageFormat::Authentication).await? {
80 Authentication::SaslContinue(data) => data,
81
82 auth => {
83 return Err(err_protocol!(
84 "expected SASLContinue but received {:?}",
85 auth
86 ));
87 }
88 };
89
90 let salted_password = hi(
92 options.password.as_deref().unwrap_or_default(),
93 &cont.salt,
94 cont.iterations,
95 )?;
96
97 let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
99 mac.update(b"Client Key");
100
101 let client_key = mac.finalize().into_bytes();
102
103 let stored_key = Sha256::digest(&client_key);
105
106 let client_final_message_wo_proof = format!(
108 "{channel_binding},r={nonce}",
109 channel_binding = channel_binding,
110 nonce = &cont.nonce
111 );
112
113 let auth_message = format!(
115 "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
116 client_first_message_bare = client_first_message_bare,
117 server_first_message = cont.message,
118 client_final_message_wo_proof = client_final_message_wo_proof
119 );
120
121 let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
123 mac.update(&auth_message.as_bytes());
124
125 let client_signature = mac.finalize().into_bytes();
126
127 let client_proof: Vec<u8> = client_key
129 .iter()
130 .zip(client_signature.iter())
131 .map(|(&a, &b)| a ^ b)
132 .collect();
133
134 let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
136 mac.update(b"Server Key");
137
138 let server_key = mac.finalize().into_bytes();
139
140 let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
142 mac.update(&auth_message.as_bytes());
143
144 let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
146 BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
147
148 stream.send(SaslResponse(&client_final_message)).await?;
149
150 let data = match stream.recv_expect(MessageFormat::Authentication).await? {
151 Authentication::SaslFinal(data) => data,
152
153 auth => {
154 return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
155 }
156 };
157
158 mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
160
161 Ok(())
162}
163
164fn gen_nonce() -> String {
166 let mut rng = rand::thread_rng();
167 let count = rng.gen_range(64..128);
168
169 let nonce: String = std::iter::repeat(())
174 .map(|()| {
175 let mut c = rng.gen_range(0x21..0x7F) as u8;
176
177 while c == 0x2C {
178 c = rng.gen_range(0x21..0x7F) as u8;
179 }
180
181 c
182 })
183 .take(count)
184 .map(|c| c as char)
185 .collect();
186
187 rng.gen_range(32..128);
188 format!("{NONCE_ATTR}={nonce}")
189}
190
191fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
193 let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
194
195 mac.update(&salt);
196 mac.update(&1u32.to_be_bytes());
197
198 let mut u = mac.finalize_reset().into_bytes();
199 let mut hi = u;
200
201 for _ in 1..iter_count {
202 mac.update(u.as_slice());
203 u = mac.finalize_reset().into_bytes();
204 hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
205 }
206
207 Ok(hi.into())
208}
209
/// Benchmarks `hi` with 4096 iterations over a random 64-byte alphanumeric salt.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_sasl_hi(b: &mut test::Bencher) {
    use test::black_box;

    let mut rng = rand::thread_rng();
    let salt: Vec<u8> = (0..64)
        .map(|_| rng.sample(rand::distributions::Alphanumeric))
        .collect();

    // Return the result (and black-box it) so the derivation cannot be optimized away.
    b.iter(|| {
        black_box(hi(
            black_box("secret_password"),
            black_box(&salt),
            black_box(4096),
        ))
    });
}