sqlx_postgres/connection/
sasl.rs

1use crate::connection::stream::PgStream;
2use crate::error::Error;
3use crate::message::{Authentication, AuthenticationSasl, SaslInitialResponse, SaslResponse};
4use crate::PgConnectOptions;
5use hmac::{Hmac, Mac};
6use rand::Rng;
7use sha2::{Digest, Sha256};
8use stringprep::saslprep;
9
10use base64::prelude::{Engine as _, BASE64_STANDARD};
11
// SCRAM attribute names used below, each written on the wire as "<attr>=<value>".

/// Attribute carrying the nonce ("r=…") in both client messages.
const NONCE_ATTR: &str = "r";
/// Attribute carrying the username in the client-first-message.
const USERNAME_ATTR: &str = "n";
/// Attribute carrying the base64-encoded GS2 header in the client-final-message.
const CHANNEL_ATTR: &str = "c";
/// Attribute carrying the base64-encoded ClientProof in the client-final-message.
const CLIENT_PROOF_ATTR: &str = "p";

/// GS2 header sent with the first message: "n" (no channel binding), empty authzid.
const GS2_HEADER: &str = "n,,";
17
18pub(crate) async fn authenticate(
19    stream: &mut PgStream,
20    options: &PgConnectOptions,
21    data: AuthenticationSasl,
22) -> Result<(), Error> {
23    let mut has_sasl = false;
24    let mut has_sasl_plus = false;
25    let mut unknown = Vec::new();
26
27    for mechanism in data.mechanisms() {
28        match mechanism {
29            "SCRAM-SHA-256" => {
30                has_sasl = true;
31            }
32
33            "SCRAM-SHA-256-PLUS" => {
34                has_sasl_plus = true;
35            }
36
37            _ => {
38                unknown.push(mechanism.to_owned());
39            }
40        }
41    }
42
43    if !has_sasl_plus && !has_sasl {
44        return Err(err_protocol!(
45            "unsupported SASL authentication mechanisms: {}",
46            unknown.join(", ")
47        ));
48    }
49
50    // channel-binding = "c=" base64
51    let mut channel_binding = format!("{CHANNEL_ATTR}=");
52    BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding);
53
54    // "n=" saslname ;; Usernames are prepared using SASLprep.
55    let username = format!("{}={}", USERNAME_ATTR, options.username);
56    let username = match saslprep(&username) {
57        Ok(v) => v,
58        // TODO(danielakhterov): Remove panic when we have proper support for configuration errors
59        Err(_) => panic!("Failed to saslprep username"),
60    };
61
62    // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
63    let nonce = gen_nonce();
64
65    // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
66    let client_first_message_bare = format!("{username},{nonce}");
67
68    let client_first_message = format!("{GS2_HEADER}{client_first_message_bare}");
69
70    stream
71        .send(SaslInitialResponse {
72            response: &client_first_message,
73            plus: false,
74        })
75        .await?;
76
77    let cont = match stream.recv_expect().await? {
78        Authentication::SaslContinue(data) => data,
79
80        auth => {
81            return Err(err_protocol!(
82                "expected SASLContinue but received {:?}",
83                auth
84            ));
85        }
86    };
87
88    // SaltedPassword := Hi(Normalize(password), salt, i)
89    let salted_password = hi(
90        options.password.as_deref().unwrap_or_default(),
91        &cont.salt,
92        cont.iterations,
93    )?;
94
95    // ClientKey := HMAC(SaltedPassword, "Client Key")
96    let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
97    mac.update(b"Client Key");
98
99    let client_key = mac.finalize().into_bytes();
100
101    // StoredKey := H(ClientKey)
102    let stored_key = Sha256::digest(client_key);
103
104    // client-final-message-without-proof
105    let client_final_message_wo_proof = format!(
106        "{channel_binding},r={nonce}",
107        channel_binding = channel_binding,
108        nonce = &cont.nonce
109    );
110
111    // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
112    let auth_message = format!(
113        "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
114        client_first_message_bare = client_first_message_bare,
115        server_first_message = cont.message,
116        client_final_message_wo_proof = client_final_message_wo_proof
117    );
118
119    // ClientSignature := HMAC(StoredKey, AuthMessage)
120    let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
121    mac.update(auth_message.as_bytes());
122
123    let client_signature = mac.finalize().into_bytes();
124
125    // ClientProof := ClientKey XOR ClientSignature
126    let client_proof: Vec<u8> = client_key
127        .iter()
128        .zip(client_signature.iter())
129        .map(|(&a, &b)| a ^ b)
130        .collect();
131
132    // ServerKey := HMAC(SaltedPassword, "Server Key")
133    let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
134    mac.update(b"Server Key");
135
136    let server_key = mac.finalize().into_bytes();
137
138    // ServerSignature := HMAC(ServerKey, AuthMessage)
139    let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
140    mac.update(auth_message.as_bytes());
141
142    // client-final-message = client-final-message-without-proof "," proof
143    let mut client_final_message = format!("{client_final_message_wo_proof},{CLIENT_PROOF_ATTR}=");
144    BASE64_STANDARD.encode_string(client_proof, &mut client_final_message);
145
146    stream.send(SaslResponse(&client_final_message)).await?;
147
148    let data = match stream.recv_expect().await? {
149        Authentication::SaslFinal(data) => data,
150
151        auth => {
152            return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
153        }
154    };
155
156    // authentication is only considered valid if this verification passes
157    mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
158
159    Ok(())
160}
161
162// nonce is a sequence of random printable bytes
163fn gen_nonce() -> String {
164    let mut rng = rand::thread_rng();
165    let count = rng.gen_range(64..128);
166
167    // printable = %x21-2B / %x2D-7E
168    // ;; Printable ASCII except ",".
169    // ;; Note that any "printable" is also
170    // ;; a valid "value".
171    let nonce: String = std::iter::repeat(())
172        .map(|()| {
173            let mut c = rng.gen_range(0x21u8..0x7F);
174
175            while c == 0x2C {
176                c = rng.gen_range(0x21u8..0x7F);
177            }
178
179            c
180        })
181        .take(count)
182        .map(|c| c as char)
183        .collect();
184
185    rng.gen_range(32..128);
186    format!("{NONCE_ATTR}={nonce}")
187}
188
189// Hi(str, salt, i):
190fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
191    let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
192
193    mac.update(salt);
194    mac.update(&1u32.to_be_bytes());
195
196    let mut u = mac.finalize_reset().into_bytes();
197    let mut hi = u;
198
199    for _ in 1..iter_count {
200        mac.update(u.as_slice());
201        u = mac.finalize_reset().into_bytes();
202        hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
203    }
204
205    Ok(hi.into())
206}
207
/// Benchmarks `hi` with PostgreSQL's default SCRAM iteration count (4096) and a
/// random 64-byte alphanumeric salt.
#[cfg(all(test, not(debug_assertions)))]
#[bench]
fn bench_sasl_hi(b: &mut test::Bencher) {
    use test::black_box;

    let mut rng = rand::thread_rng();
    let salt: Vec<u8> = std::iter::repeat_with(|| rng.sample(rand::distributions::Alphanumeric))
        .take(64)
        .collect();

    // Returning the result lets the bencher black-box it so the work is not elided.
    b.iter(|| hi(black_box("secret_password"), black_box(&salt), black_box(4096)));
}