sqlx_core/postgres/connection/sasl.rs

1use crate::error::Error;
2use crate::postgres::connection::stream::PgStream;
3use crate::postgres::message::{
4    Authentication, AuthenticationSasl, MessageFormat, SaslInitialResponse, SaslResponse,
5};
6use crate::postgres::PgConnectOptions;
7use hmac::{Hmac, Mac};
8use rand::Rng;
9use sha2::{Digest, Sha256};
10use stringprep::saslprep;
11
/// GS2 header sent at the start of the client-first-message: the client does not use
/// channel binding (`n`) and sends no authorization identity.
const GS2_HEADER: &str = "n,,";

// SCRAM attribute names (RFC 5802, section 5.1). Each is written as `<name>=<value>`.

/// `c=`: base64-encoded GS2 header (channel-binding data).
const CHANNEL_ATTR: &str = "c";
/// `n=`: the user name.
const USERNAME_ATTR: &str = "n";
/// `p=`: the base64-encoded client proof.
const CLIENT_PROOF_ATTR: &str = "p";
/// `r=`: the nonce.
const NONCE_ATTR: &str = "r";
17
18pub(crate) async fn authenticate(
19    stream: &mut PgStream,
20    options: &PgConnectOptions,
21    data: AuthenticationSasl,
22) -> Result<(), Error> {
23    let mut has_sasl = false;
24    let mut has_sasl_plus = false;
25    let mut unknown = Vec::new();
26
27    for mechanism in data.mechanisms() {
28        match mechanism {
29            "SCRAM-SHA-256" => {
30                has_sasl = true;
31            }
32
33            "SCRAM-SHA-256-PLUS" => {
34                has_sasl_plus = true;
35            }
36
37            _ => {
38                unknown.push(mechanism.to_owned());
39            }
40        }
41    }
42
43    if !has_sasl_plus && !has_sasl {
44        return Err(err_protocol!(
45            "unsupported SASL authentication mechanisms: {}",
46            unknown.join(", ")
47        ));
48    }
49
50    // channel-binding = "c=" base64
51    let channel_binding = format!("{}={}", CHANNEL_ATTR, base64::encode(GS2_HEADER));
52
53    // "n=" saslname ;; Usernames are prepared using SASLprep.
54    let username = format!("{}={}", USERNAME_ATTR, options.username);
55    let username = match saslprep(&username) {
56        Ok(v) => v,
57        // TODO(danielakhterov): Remove panic when we have proper support for configuration errors
58        Err(_) => panic!("Failed to saslprep username"),
59    };
60
61    // nonce = "r=" c-nonce [s-nonce] ;; Second part provided by server.
62    let nonce = gen_nonce();
63
64    // client-first-message-bare = [reserved-mext ","] username "," nonce ["," extensions]
65    let client_first_message_bare =
66        format!("{username},{nonce}", username = username, nonce = nonce);
67
68    let client_first_message = format!(
69        "{gs2_header}{client_first_message_bare}",
70        gs2_header = GS2_HEADER,
71        client_first_message_bare = client_first_message_bare
72    );
73
74    stream
75        .send(SaslInitialResponse {
76            response: &client_first_message,
77            plus: false,
78        })
79        .await?;
80
81    let cont = match stream.recv_expect(MessageFormat::Authentication).await? {
82        Authentication::SaslContinue(data) => data,
83
84        auth => {
85            return Err(err_protocol!(
86                "expected SASLContinue but received {:?}",
87                auth
88            ));
89        }
90    };
91
92    // SaltedPassword := Hi(Normalize(password), salt, i)
93    let salted_password = hi(
94        options.password.as_deref().unwrap_or_default(),
95        &cont.salt,
96        cont.iterations,
97    )?;
98
99    // ClientKey := HMAC(SaltedPassword, "Client Key")
100    let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
101    mac.update(b"Client Key");
102
103    let client_key = mac.finalize().into_bytes();
104
105    // StoredKey := H(ClientKey)
106    let stored_key = Sha256::digest(&client_key);
107
108    // client-final-message-without-proof
109    let client_final_message_wo_proof = format!(
110        "{channel_binding},r={nonce}",
111        channel_binding = channel_binding,
112        nonce = &cont.nonce
113    );
114
115    // AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
116    let auth_message = format!(
117        "{client_first_message_bare},{server_first_message},{client_final_message_wo_proof}",
118        client_first_message_bare = client_first_message_bare,
119        server_first_message = cont.message,
120        client_final_message_wo_proof = client_final_message_wo_proof
121    );
122
123    // ClientSignature := HMAC(StoredKey, AuthMessage)
124    let mut mac = Hmac::<Sha256>::new_from_slice(&stored_key).map_err(Error::protocol)?;
125    mac.update(&auth_message.as_bytes());
126
127    let client_signature = mac.finalize().into_bytes();
128
129    // ClientProof := ClientKey XOR ClientSignature
130    let client_proof: Vec<u8> = client_key
131        .iter()
132        .zip(client_signature.iter())
133        .map(|(&a, &b)| a ^ b)
134        .collect();
135
136    // ServerKey := HMAC(SaltedPassword, "Server Key")
137    let mut mac = Hmac::<Sha256>::new_from_slice(&salted_password).map_err(Error::protocol)?;
138    mac.update(b"Server Key");
139
140    let server_key = mac.finalize().into_bytes();
141
142    // ServerSignature := HMAC(ServerKey, AuthMessage)
143    let mut mac = Hmac::<Sha256>::new_from_slice(&server_key).map_err(Error::protocol)?;
144    mac.update(&auth_message.as_bytes());
145
146    // client-final-message = client-final-message-without-proof "," proof
147    let client_final_message = format!(
148        "{client_final_message_wo_proof},{client_proof_attr}={client_proof}",
149        client_final_message_wo_proof = client_final_message_wo_proof,
150        client_proof_attr = CLIENT_PROOF_ATTR,
151        client_proof = base64::encode(&client_proof)
152    );
153
154    stream.send(SaslResponse(&client_final_message)).await?;
155
156    let data = match stream.recv_expect(MessageFormat::Authentication).await? {
157        Authentication::SaslFinal(data) => data,
158
159        auth => {
160            return Err(err_protocol!("expected SASLFinal but received {:?}", auth));
161        }
162    };
163
164    // authentication is only considered valid if this verification passes
165    mac.verify_slice(&data.verifier).map_err(Error::protocol)?;
166
167    Ok(())
168}
169
170// nonce is a sequence of random printable bytes
171fn gen_nonce() -> String {
172    let mut rng = rand::thread_rng();
173    let count = rng.gen_range(64..128);
174
175    // printable = %x21-2B / %x2D-7E
176    // ;; Printable ASCII except ",".
177    // ;; Note that any "printable" is also
178    // ;; a valid "value".
179    let nonce: String = std::iter::repeat(())
180        .map(|()| {
181            let mut c = rng.gen_range(0x21..0x7F) as u8;
182
183            while c == 0x2C {
184                c = rng.gen_range(0x21..0x7F) as u8;
185            }
186
187            c
188        })
189        .take(count)
190        .map(|c| c as char)
191        .collect();
192
193    rng.gen_range(32..128);
194    format!("{}={}", NONCE_ATTR, nonce)
195}
196
197// Hi(str, salt, i):
198fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> {
199    let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
200
201    mac.update(&salt);
202    mac.update(&1u32.to_be_bytes());
203
204    let mut u = mac.finalize().into_bytes();
205    let mut hi = u;
206
207    for _ in 1..iter_count {
208        let mut mac = Hmac::<Sha256>::new_from_slice(s.as_bytes()).map_err(Error::protocol)?;
209        mac.update(u.as_slice());
210        u = mac.finalize().into_bytes();
211        hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect();
212    }
213
214    Ok(hi.into())
215}