sqlx_etorreborre_postgres/connection/
sasl.rs

1use crate::connection::stream::PgStream;
2use crate::error::Error;
3use crate::message::{
4    Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
5};
6use crate::PgConnectOptions;
7use hmac::{Hmac, Mac};
8use rand::Rng;
9use sha2::{Digest, Sha256};
10use stringprep::saslprep;
11
12use base64::prelude::{Engine as _, BASE64_STANDARD};
13
// GS2 header: prefixed to the client-first-message-bare to form the
// client-first-message, and base64-encoded as the value of the `c=` attribute.
const GS2_HEADER: &str = "n,,";
// Attribute name for the channel-binding field (`c=<base64>`).
const CHANNEL_ATTR: &str = "c";
// Attribute name for the username field (`n=<saslname>`).
const USERNAME_ATTR: &str = "n";
// Attribute name for the client proof field (`p=<base64>`).
const CLIENT_PROOF_ATTR: &str = "p";
// Attribute name for the nonce field (`r=<nonce>`).
const NONCE_ATTR: &str = "r";
19
20pub(crate) async fn authenticate(
21    stream: &mut PgStream,
22    options: &PgConnectOptions,
23    data: AuthenticationSasl,
24) -> Result<(), Error> {
25    let mut has_sasl = false;
26    let mut has_sasl_plus = false;
27    let mut unknown = Vec::new();
28
29    for mechanism in data.mechanisms() {
30        match mechanism {
31            "SCRAM-SHA-256" => {
32                has_sasl = true;
33            }
34
35            "SCRAM-SHA-256-PLUS" => {
36                has_sasl_plus = true;
37            }
38
39            _ => {
40                unknown.push(mechanism.to_owned());
41            }
42        }
43    }
44
45    if !has_sasl_plus && !has_sasl {
46        return Err(err_protocol!(
47            "unsupported SASL authentication mechanisms: {}",
48            unknown.join(", ")
49        ));
50    }
51
52    // channel-binding = "c=" base64
53    let mut channel_binding = format!("{CHANNEL_ATTR}=");
54    BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding);
55
56    // "n=" saslname ;; Usernames are prepared using SASLprep.
57    let username = format!("{}={}", USERNAME_ATTR, options.username);
58    let username = match saslprep(&username) {
59        Ok(v) => v,
60        // TODO(danielakhterov): Remove panic when we have proper support for configuration errors
61        Err(_) => panic!("Failed to saslprep username"),
62    };
63
64    // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
65    let nonce = gen_nonce();
66
67    // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
68    let client_first_message_bare = format!("{username},{nonce}");
69
70    let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
71
72    stream
73        .send(SaslInitialResponse {
74            response: &client_first_message,
75            plus: false,
76        })
77        .await?;
78
79    let cont = match stream.recv_expect(MessageFormat::Authentication).await? {
80        Authentication::SaslContinue(data) => data,
81
82        auth => {
83            return Err(err_protocol!(
84                "expected SASLContinue but received {:?}",
85                auth
86            ));
87        }
88    };
89
90    // SaltedPassword := Hi(Normalize(password), salt, i)
91    let salted_password = hi(
92        options.password.as_deref().unwrap_or_default(),
93        &cont.salt,
94        cont.iterations,
95    )?;
96
97    // ClientKey := HMAC(SaltedPassword, "Client Key")
98    let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
99    mac.update(b"Client Key");
100
101    let client_key = mac.finalize().into_bytes();
102
103    // StoredKey := H(ClientKey)
104    let stored_key = Sha256::digest(&client_key);
105
106    // client-final-message-without-proof
107    let client_final_message_wo_proof = format!(
108        "{channel_binding},r={nonce}",
109        channel_binding = channel_binding,
110        nonce = &cont.nonce
111    );
112
113    // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
114    let auth_message = format!(
115        "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
116        client_first_message_bare = client_first_message_bare,
117        server_first_message = cont.message,
118        client_final_message_wo_proof = client_final_message_wo_proof
119    );
120
121    // ClientSignature := HMAC(StoredKey, AuthMessage)
122    let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
123    mac.update(&auth_message.as_bytes());
124
125    let client_signature = mac.finalize().into_bytes();
126
127    // ClientProof := ClientKey XOR ClientSignature
128    let client_proof: Vec<u8> = client_key
129        .iter()
130        .zip(client_signature.iter())
131        .map(|(&a, &b)| a ^ b)
132        .collect();
133
134    // ServerKey := HMAC(SaltedPassword, "Server Key")
135    let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
136    mac.update(b"Server Key");
137
138    let server_key = mac.finalize().into_bytes();
139
140    // ServerSignature := HMAC(ServerKey, AuthMessage)
141    let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
142    mac.update(&auth_message.as_bytes());
143
144    // client-final-message = client-final-message-without-proof "," proof
145    let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
146    BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
147
148    stream.send(SaslResponse(&client_final_message)).await?;
149
150    let data = match stream.recv_expect(MessageFormat::Authentication).await? {
151        Authentication::SaslFinal(data) => data,
152
153        auth => {
154            return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
155        }
156    };
157
158    // authentication is only considered valid if this verification passes
159    mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
160
161    Ok(())
162}
163
164// nonce is a sequence of random printable bytes
165fn gen_nonce() -> String {
166    let mut rng = rand::thread_rng();
167    let count = rng.gen_range(64..128);
168
169    // printable = %x21-2B / %x2D-7E
170    // ;; Printable ASCII except ",".
171    // ;; Note that any "printable" is also
172    // ;; a valid "value".
173    let nonce: String = std::iter::repeat(())
174        .map(|()| {
175            let mut c = rng.gen_range(0x21..0x7F) as u8;
176
177            while c == 0x2C {
178                c = rng.gen_range(0x21..0x7F) as u8;
179            }
180
181            c
182        })
183        .take(count)
184        .map(|c| c as char)
185        .collect();
186
187    rng.gen_range(32..128);
188    format!("{NONCE_ATTR}={nonce}")
189}
190
191// Hi(str, salt, i):
192fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
193    let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
194
195    mac.update(&salt);
196    mac.update(&1u32.to_be_bytes());
197
198    let mut u = mac.finalize_reset().into_bytes();
199    let mut hi = u;
200
201    for _ in 1..iter_count {
202        mac.update(u.as_slice());
203        u = mac.finalize_reset().into_bytes();
204        hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
205    }
206
207    Ok(hi.into())
208}
209
/// Benchmarks `hi` with a fixed password, a random 64-byte alphanumeric salt
/// and 4096 iterations. Only compiled for test builds without debug assertions.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_sasl_hi(b: &mut test::Bencher) {
    use test::black_box;

    let mut rng = rand::thread_rng();
    let salt: Vec<u8> = std::iter::repeat(())
        .map(|()| rng.sample(rand::distributions::Alphanumeric))
        .take(64)
        .collect();

    // Inputs go through `black_box` so they are opaque to the optimizer; the
    // result is returned from the closure instead of being discarded.
    b.iter(|| {
        hi(
            black_box("secret_password"),
            black_box(&salt),
            black_box(4096),
        )
    });
}
227}