sqlx_core/postgres/connection/
sasl.rs1use crate::error::Error;
2use crate::postgres::connection::stream::PgStream;
3use crate::postgres::message::{
4 Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
5};
6use crate::postgres::PgConnectOptions;
7use hmac::{Hmac, Mac};
8use rand::Rng;
9use sha2::{Digest, Sha256};
10use stringprep::saslprep;
11
/// GS2 header prefixed to the client-first message. It is also the value that
/// is base64-encoded into the `c=` attribute of the client-final message.
///
/// Per RFC 5802 the leading `n` flag states that the client does not use
/// channel binding; the empty field between the commas is the (absent) authzid.
const GS2_HEADER: &str = "n,,";

/// Attribute name of the channel-binding data in the client-final message.
const CHANNEL_ATTR: &str = "c";

/// Attribute name of the user name in the client-first message.
const USERNAME_ATTR: &str = "n";

/// Attribute name of the base64-encoded client proof in the client-final message.
const CLIENT_PROOF_ATTR: &str = "p";

/// Attribute name of the nonce.
const NONCE_ATTR: &str = "r";
17
18pub(crate) async fn authenticate(
19 stream: &mut PgStream,
20 options: &PgConnectOptions,
21 data: AuthenticationSasl,
22) -> Result<(), Error> {
23 let mut has_sasl = false;
24 let mut has_sasl_plus = false;
25 let mut unknown = Vec::new();
26
27 for mechanism in data.mechanisms() {
28 match mechanism {
29 "SCRAM-SHA-256" => {
30 has_sasl = true;
31 }
32
33 "SCRAM-SHA-256-PLUS" => {
34 has_sasl_plus = true;
35 }
36
37 _ => {
38 unknown.push(mechanism.to_owned());
39 }
40 }
41 }
42
43 if !has_sasl_plus && !has_sasl {
44 return Err(err_protocol!(
45 "unsupported SASL authentication mechanisms: {}",
46 unknown.join(", ")
47 ));
48 }
49
50 let channel_binding = format!("{}={}", CHANNEL_ATTR, base64::encode(GS2_HEADER));
52
53 let username = format!("{}={}", USERNAME_ATTR, options.username);
55 let username = match saslprep(&username) {
56 Ok(v) => v,
57 Err(_) => panic!("Failed to saslprep username"),
59 };
60
61 let nonce = gen_nonce();
63
64 let client_first_message_bare =
66 format!("{username},{nonce}", username = username, nonce = nonce);
67
68 let client_first_message = format!(
69 "{gs2_header}{client_first_message_bare}",
70 gs2_header = GS2_HEADER,
71 client_first_message_bare = client_first_message_bare
72 );
73
74 stream
75 .send(SaslInitialResponse {
76 response: &client_first_message,
77 plus: false,
78 })
79 .await?;
80
81 let cont = match stream.recv_expect(MessageFormat::Authentication).await? {
82 Authentication::SaslContinue(data) => data,
83
84 auth => {
85 return Err(err_protocol!(
86 "expected SASLContinue but received {:?}",
87 auth
88 ));
89 }
90 };
91
92 let salted_password = hi(
94 options.password.as_deref().unwrap_or_default(),
95 &cont.salt,
96 cont.iterations,
97 )?;
98
99 let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
101 mac.update(b"Client Key");
102
103 let client_key = mac.finalize().into_bytes();
104
105 let stored_key = Sha256::digest(&client_key);
107
108 let client_final_message_wo_proof = format!(
110 "{channel_binding},r={nonce}",
111 channel_binding = channel_binding,
112 nonce = &cont.nonce
113 );
114
115 let auth_message = format!(
117 "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
118 client_first_message_bare = client_first_message_bare,
119 server_first_message = cont.message,
120 client_final_message_wo_proof = client_final_message_wo_proof
121 );
122
123 let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
125 mac.update(&auth_message.as_bytes());
126
127 let client_signature = mac.finalize().into_bytes();
128
129 let client_proof: Vec<u8> = client_key
131 .iter()
132 .zip(client_signature.iter())
133 .map(|(&a, &b)| a ^ b)
134 .collect();
135
136 let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
138 mac.update(b"Server Key");
139
140 let server_key = mac.finalize().into_bytes();
141
142 let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
144 mac.update(&auth_message.as_bytes());
145
146 let client_final_message = format!(
148 "{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
149 client_final_message_wo_proof = client_final_message_wo_proof,
150 client_proof_attr = CLIENT_PROOF_ATTR,
151 client_proof = base64::encode(&client_proof)
152 );
153
154 stream.send(SaslResponse(&client_final_message)).await?;
155
156 let data = match stream.recv_expect(MessageFormat::Authentication).await? {
157 Authentication::SaslFinal(data) => data,
158
159 auth => {
160 return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
161 }
162 };
163
164 mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
166
167 Ok(())
168}
169
170fn gen_nonce() -> String {
172 let mut rng = rand::thread_rng();
173 let count = rng.gen_range(64..128);
174
175 let nonce: String = std::iter::repeat(())
180 .map(|()| {
181 let mut c = rng.gen_range(0x21..0x7F) as u8;
182
183 while c == 0x2C {
184 c = rng.gen_range(0x21..0x7F) as u8;
185 }
186
187 c
188 })
189 .take(count)
190 .map(|c| c as char)
191 .collect();
192
193 rng.gen_range(32..128);
194 format!("{}={}", NONCE_ATTR, nonce)
195}
196
197fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
199 let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
200
201 mac.update(&salt);
202 mac.update(&1u32.to_be_bytes());
203
204 let mut u = mac.finalize().into_bytes();
205 let mut hi = u;
206
207 for _ in 1..iter_count {
208 let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
209 mac.update(u.as_slice());
210 u = mac.finalize().into_bytes();
211 hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
212 }
213
214 Ok(hi.into())
215}