zero_postgres/protocol/frontend/
auth.rs1use crate::protocol::codec::MessageBuilder;
4
5pub fn write_password(buf: &mut Vec<u8>, password: &str) {
7 let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
8 msg.write_cstr(password);
9 msg.finish();
10}
11
12pub fn md5_password(username: &str, password: &str, salt: &[u8; 4]) -> String {
16 use md5::{Digest, Md5};
17
18 let first_hash = {
20 let mut hasher = Md5::new();
21 hasher.update(password.as_bytes());
22 hasher.update(username.as_bytes());
23 hasher.finalize()
24 };
25 let first_hash_hex = format!("{:x}", first_hash);
26
27 let second_hash = {
29 let mut hasher = Md5::new();
30 hasher.update(first_hash_hex.as_bytes());
31 hasher.update(salt);
32 hasher.finalize()
33 };
34
35 format!("md5{:x}", second_hash)
36}
37
38pub fn write_sasl_initial_response(buf: &mut Vec<u8>, mechanism: &str, initial_response: &[u8]) {
43 let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
44 msg.write_cstr(mechanism);
45 msg.write_i32(initial_response.len() as i32);
46 msg.write_bytes(initial_response);
47 msg.finish();
48}
49
50pub fn write_sasl_response(buf: &mut Vec<u8>, response: &[u8]) {
54 let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
55 msg.write_bytes(response);
56 msg.finish();
57}
58
59pub struct ScramClient {
61 nonce: String,
63 channel_binding: String,
65 password: String,
67 server_first: Option<String>,
69 auth_message: Option<String>,
71 salted_password: Option<Vec<u8>>,
73}
74
75impl ScramClient {
76 pub fn new(password: &str) -> Self {
78 use rand::Rng;
79
80 let mut nonce_bytes = [0u8; 24];
82 rand::rng().fill(&mut nonce_bytes);
83 let nonce = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes);
84
85 Self {
86 nonce,
87 channel_binding: "n,,".to_string(), password: password.to_string(),
89 server_first: None,
90 auth_message: None,
91 salted_password: None,
92 }
93 }
94
95 pub fn new_with_channel_binding(password: &str, channel_binding_data: &[u8]) -> Self {
97 use rand::Rng;
98
99 let mut nonce_bytes = [0u8; 24];
100 rand::rng().fill(&mut nonce_bytes);
101 let nonce = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes);
102
103 let cb_data = base64::Engine::encode(
105 &base64::engine::general_purpose::STANDARD,
106 channel_binding_data,
107 );
108
109 Self {
110 nonce,
111 channel_binding: format!("p=tls-server-end-point,,{}", cb_data),
112 password: password.to_string(),
113 server_first: None,
114 auth_message: None,
115 salted_password: None,
116 }
117 }
118
119 pub fn client_first_message(&self) -> String {
121 format!("{}n=,r={}", self.channel_binding, self.nonce)
124 }
125
126 fn client_first_message_bare(&self) -> String {
128 format!("n=,r={}", self.nonce)
129 }
130
131 pub fn process_server_first(&mut self, server_first: &str) -> Result<String, String> {
133 use base64::Engine;
134 use hmac::{Hmac, Mac};
135 use pbkdf2::pbkdf2_hmac;
136 use sha2::{Digest, Sha256};
137
138 self.server_first = Some(server_first.to_string());
139
140 let mut combined_nonce = None;
142 let mut salt_b64 = None;
143 let mut iterations = None;
144
145 for part in server_first.split(',') {
146 if let Some(value) = part.strip_prefix("r=") {
147 combined_nonce = Some(value);
148 } else if let Some(value) = part.strip_prefix("s=") {
149 salt_b64 = Some(value);
150 } else if let Some(value) = part.strip_prefix("i=") {
151 iterations = value.parse().ok();
152 }
153 }
154
155 let combined_nonce = combined_nonce.ok_or("Missing nonce in server-first-message")?;
156 let salt_b64 = salt_b64.ok_or("Missing salt in server-first-message")?;
157 let iterations: u32 = iterations.ok_or("Missing iterations in server-first-message")?;
158
159 if !combined_nonce.starts_with(&self.nonce) {
161 return Err("Server nonce doesn't start with client nonce".to_string());
162 }
163
164 let salt = base64::engine::general_purpose::STANDARD
166 .decode(salt_b64)
167 .map_err(|e| format!("Invalid salt: {}", e))?;
168
169 let mut salted_password = vec![0u8; 32];
171 pbkdf2_hmac::<Sha256>(
172 self.password.as_bytes(),
173 &salt,
174 iterations,
175 &mut salted_password,
176 );
177
178 self.salted_password = Some(salted_password.clone());
179
180 let client_key = {
182 let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(&salted_password)
183 .map_err(|e| format!("HMAC error: {}", e))?;
184 mac.update(b"Client Key");
185 mac.finalize().into_bytes()
186 };
187
188 let stored_key = Sha256::digest(client_key);
190
191 let channel_binding_b64 =
193 base64::engine::general_purpose::STANDARD.encode(self.channel_binding.as_bytes());
194
195 let client_final_without_proof = format!("c={},r={}", channel_binding_b64, combined_nonce);
197
198 let auth_message = format!(
200 "{},{},{}",
201 self.client_first_message_bare(),
202 server_first,
203 client_final_without_proof
204 );
205 self.auth_message = Some(auth_message.clone());
206
207 let client_signature = {
209 let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(&stored_key)
210 .map_err(|e| format!("HMAC error: {}", e))?;
211 mac.update(auth_message.as_bytes());
212 mac.finalize().into_bytes()
213 };
214
215 let mut client_proof = [0u8; 32];
217 for i in 0..32 {
218 client_proof[i] = client_key[i] ^ client_signature[i];
219 }
220
221 let proof_b64 = base64::engine::general_purpose::STANDARD.encode(client_proof);
222
223 Ok(format!("{},p={}", client_final_without_proof, proof_b64))
225 }
226
227 pub fn verify_server_final(&self, server_final: &str) -> Result<(), String> {
229 use base64::Engine;
230 use hmac::{Hmac, Mac};
231
232 let server_signature_b64 = server_final
234 .strip_prefix("v=")
235 .ok_or("Invalid server-final-message format")?;
236
237 let server_signature = base64::engine::general_purpose::STANDARD
238 .decode(server_signature_b64)
239 .map_err(|e| format!("Invalid server signature: {}", e))?;
240
241 let salted_password = self
243 .salted_password
244 .as_ref()
245 .ok_or("Missing salted password")?;
246 let auth_message = self.auth_message.as_ref().ok_or("Missing auth message")?;
247
248 let server_key = {
250 let mut mac = <Hmac<sha2::Sha256> as Mac>::new_from_slice(salted_password)
251 .map_err(|e| format!("HMAC error: {}", e))?;
252 mac.update(b"Server Key");
253 mac.finalize().into_bytes()
254 };
255
256 let expected_signature = {
258 let mut mac = <Hmac<sha2::Sha256> as Mac>::new_from_slice(&server_key)
259 .map_err(|e| format!("HMAC error: {}", e))?;
260 mac.update(auth_message.as_bytes());
261 mac.finalize().into_bytes()
262 };
263
264 if server_signature.as_slice() != expected_signature.as_slice() {
265 return Err("Server signature verification failed".to_string());
266 }
267
268 Ok(())
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the client nonce from a client-first-message sent without channel binding.
    fn client_nonce(client: &ScramClient) -> String {
        client
            .client_first_message()
            .strip_prefix("n,,n=,r=")
            .expect("client-first-message without channel binding starts with n,,n=,r=")
            .to_string()
    }

    #[test]
    fn test_md5_password() {
        let result = md5_password("postgres", "password", &[0x01, 0x02, 0x03, 0x04]);
        // "md5" followed by a 32-character lower-case hex digest.
        assert!(result.starts_with("md5"));
        assert_eq!(result.len(), 35);
        assert!(result[3..]
            .chars()
            .all(|c| matches!(c, '0'..='9' | 'a'..='f')));
        // Deterministic for identical inputs, and sensitive to the salt.
        assert_eq!(
            result,
            md5_password("postgres", "password", &[0x01, 0x02, 0x03, 0x04])
        );
        assert_ne!(
            result,
            md5_password("postgres", "password", &[0x04, 0x03, 0x02, 0x01])
        );
    }

    #[test]
    fn test_password_message() {
        let mut buf = Vec::new();
        write_password(&mut buf, "secret");

        assert_eq!(buf[0], b'p');
        assert!(buf.ends_with(&[0]));
    }

    #[test]
    fn test_scram_client_first_message_without_channel_binding() {
        let client = ScramClient::new("pencil");
        assert!(!client_nonce(&client).is_empty());
    }

    #[test]
    fn test_scram_rejects_foreign_nonce() {
        let mut client = ScramClient::new("pencil");
        let result = client.process_server_first("r=not-our-nonce,s=QSXCR+Q6sek8bf92,i=4096");
        assert!(result.is_err());
    }

    #[test]
    fn test_scram_rejects_missing_salt() {
        let mut client = ScramClient::new("pencil");
        let nonce = client_nonce(&client);
        let result = client.process_server_first(&format!("r={}srv,i=4096", nonce));
        assert!(result.is_err());
    }

    #[test]
    fn test_scram_client_final_without_channel_binding() {
        let mut client = ScramClient::new("pencil");
        let nonce = client_nonce(&client);
        let server_first = format!("r={}srvnonce,s=QSXCR+Q6sek8bf92,i=4096", nonce);

        let client_final = client
            .process_server_first(&server_first)
            .expect("well-formed server-first-message");
        // "biws" is base64 of the GS2 header "n,,".
        assert!(client_final.starts_with(&format!("c=biws,r={}srvnonce,p=", nonce)));

        // A bogus server signature must not verify.
        assert!(client.verify_server_final("v=AAAA").is_err());
    }
}