zero_postgres/protocol/frontend/
auth.rs

1//! Authentication messages.
2
3use crate::protocol::codec::MessageBuilder;
4
5/// Write a PasswordMessage (cleartext or MD5 hashed password).
6pub fn write_password(buf: &mut Vec<u8>, password: &str) {
7    let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
8    msg.write_cstr(password);
9    msg.finish();
10}
11
12/// Compute MD5 password hash.
13///
14/// PostgreSQL MD5 password format: "md5" + md5(md5(password + username) + salt)
15pub fn md5_password(username: &str, password: &str, salt: &[u8; 4]) -> String {
16    use md5::{Digest, Md5};
17
18    // First hash: md5(password + username)
19    let first_hash = {
20        let mut hasher = Md5::new();
21        hasher.update(password.as_bytes());
22        hasher.update(username.as_bytes());
23        hasher.finalize()
24    };
25    let first_hash_hex = format!("{:x}", first_hash);
26
27    // Second hash: md5(first_hash_hex + salt)
28    let second_hash = {
29        let mut hasher = Md5::new();
30        hasher.update(first_hash_hex.as_bytes());
31        hasher.update(salt);
32        hasher.finalize()
33    };
34
35    format!("md5{:x}", second_hash)
36}
37
38/// Write a SASLInitialResponse message.
39///
40/// mechanism: SASL mechanism name (e.g., "SCRAM-SHA-256")
41/// initial_response: Client-first-message for SCRAM
42pub fn write_sasl_initial_response(buf: &mut Vec<u8>, mechanism: &str, initial_response: &[u8]) {
43    let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
44    msg.write_cstr(mechanism);
45    msg.write_i32(initial_response.len() as i32);
46    msg.write_bytes(initial_response);
47    msg.finish();
48}
49
50/// Write a SASLResponse message.
51///
52/// response: Client-final-message for SCRAM
53pub fn write_sasl_response(buf: &mut Vec<u8>, response: &[u8]) {
54    let mut msg = MessageBuilder::new(buf, super::msg_type::PASSWORD);
55    msg.write_bytes(response);
56    msg.finish();
57}
58
59/// SCRAM-SHA-256 client implementation.
60pub struct ScramClient {
61    /// Client nonce
62    nonce: String,
63    /// Channel binding flag
64    channel_binding: String,
65    /// Password
66    password: String,
67    /// Server-first-message (stored for later)
68    server_first: Option<String>,
69    /// Auth message for signature verification
70    auth_message: Option<String>,
71    /// Salted password for server signature verification
72    salted_password: Option<Vec<u8>>,
73}
74
75impl ScramClient {
76    /// Create a new SCRAM client.
77    pub fn new(password: &str) -> Self {
78        use rand::Rng;
79
80        // Generate 24-byte random nonce, base64 encoded
81        let mut nonce_bytes = [0u8; 24];
82        rand::rng().fill(&mut nonce_bytes);
83        let nonce = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes);
84
85        Self {
86            nonce,
87            channel_binding: "n,,".to_string(), // No channel binding
88            password: password.to_string(),
89            server_first: None,
90            auth_message: None,
91            salted_password: None,
92        }
93    }
94
95    /// Create a new SCRAM client with channel binding.
96    pub fn new_with_channel_binding(password: &str, channel_binding_data: &[u8]) -> Self {
97        use rand::Rng;
98
99        let mut nonce_bytes = [0u8; 24];
100        rand::rng().fill(&mut nonce_bytes);
101        let nonce = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, nonce_bytes);
102
103        // p=tls-server-end-point,,
104        let cb_data = base64::Engine::encode(
105            &base64::engine::general_purpose::STANDARD,
106            channel_binding_data,
107        );
108
109        Self {
110            nonce,
111            channel_binding: format!("p=tls-server-end-point,,{}", cb_data),
112            password: password.to_string(),
113            server_first: None,
114            auth_message: None,
115            salted_password: None,
116        }
117    }
118
119    /// Generate the client-first-message.
120    pub fn client_first_message(&self) -> String {
121        // n,,n=,r=<nonce>
122        // Note: username is empty because PostgreSQL ignores it in SCRAM
123        format!("{}n=,r={}", self.channel_binding, self.nonce)
124    }
125
126    /// Get the bare client-first-message (without channel binding prefix).
127    fn client_first_message_bare(&self) -> String {
128        format!("n=,r={}", self.nonce)
129    }
130
131    /// Process server-first-message and generate client-final-message.
132    pub fn process_server_first(&mut self, server_first: &str) -> Result<String, String> {
133        use base64::Engine;
134        use hmac::{Hmac, Mac};
135        use pbkdf2::pbkdf2_hmac;
136        use sha2::{Digest, Sha256};
137
138        self.server_first = Some(server_first.to_string());
139
140        // Parse server-first-message: r=<nonce>,s=<salt>,i=<iterations>
141        let mut combined_nonce = None;
142        let mut salt_b64 = None;
143        let mut iterations = None;
144
145        for part in server_first.split(',') {
146            if let Some(value) = part.strip_prefix("r=") {
147                combined_nonce = Some(value);
148            } else if let Some(value) = part.strip_prefix("s=") {
149                salt_b64 = Some(value);
150            } else if let Some(value) = part.strip_prefix("i=") {
151                iterations = value.parse().ok();
152            }
153        }
154
155        let combined_nonce = combined_nonce.ok_or("Missing nonce in server-first-message")?;
156        let salt_b64 = salt_b64.ok_or("Missing salt in server-first-message")?;
157        let iterations: u32 = iterations.ok_or("Missing iterations in server-first-message")?;
158
159        // Verify nonce starts with our client nonce
160        if !combined_nonce.starts_with(&self.nonce) {
161            return Err("Server nonce doesn't start with client nonce".to_string());
162        }
163
164        // Decode salt
165        let salt = base64::engine::general_purpose::STANDARD
166            .decode(salt_b64)
167            .map_err(|e| format!("Invalid salt: {}", e))?;
168
169        // Compute SaltedPassword = Hi(Normalize(password), salt, iterations)
170        let mut salted_password = vec![0u8; 32];
171        pbkdf2_hmac::<Sha256>(
172            self.password.as_bytes(),
173            &salt,
174            iterations,
175            &mut salted_password,
176        );
177
178        self.salted_password = Some(salted_password.clone());
179
180        // ClientKey = HMAC(SaltedPassword, "Client Key")
181        let client_key = {
182            let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(&salted_password)
183                .map_err(|e| format!("HMAC error: {}", e))?;
184            mac.update(b"Client Key");
185            mac.finalize().into_bytes()
186        };
187
188        // StoredKey = H(ClientKey)
189        let stored_key = Sha256::digest(client_key);
190
191        // channel-binding = base64(channel-binding-flag)
192        let channel_binding_b64 =
193            base64::engine::general_purpose::STANDARD.encode(self.channel_binding.as_bytes());
194
195        // client-final-message-without-proof = c=<channel-binding>,r=<nonce>
196        let client_final_without_proof = format!("c={},r={}", channel_binding_b64, combined_nonce);
197
198        // AuthMessage = client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof
199        let auth_message = format!(
200            "{},{},{}",
201            self.client_first_message_bare(),
202            server_first,
203            client_final_without_proof
204        );
205        self.auth_message = Some(auth_message.clone());
206
207        // ClientSignature = HMAC(StoredKey, AuthMessage)
208        let client_signature = {
209            let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(&stored_key)
210                .map_err(|e| format!("HMAC error: {}", e))?;
211            mac.update(auth_message.as_bytes());
212            mac.finalize().into_bytes()
213        };
214
215        // ClientProof = ClientKey XOR ClientSignature
216        let mut client_proof = [0u8; 32];
217        for i in 0..32 {
218            client_proof[i] = client_key[i] ^ client_signature[i];
219        }
220
221        let proof_b64 = base64::engine::general_purpose::STANDARD.encode(client_proof);
222
223        // client-final-message = client-final-message-without-proof + ",p=" + base64(ClientProof)
224        Ok(format!("{},p={}", client_final_without_proof, proof_b64))
225    }
226
227    /// Verify server-final-message.
228    pub fn verify_server_final(&self, server_final: &str) -> Result<(), String> {
229        use base64::Engine;
230        use hmac::{Hmac, Mac};
231
232        // Parse server-final-message: v=<server-signature>
233        let server_signature_b64 = server_final
234            .strip_prefix("v=")
235            .ok_or("Invalid server-final-message format")?;
236
237        let server_signature = base64::engine::general_purpose::STANDARD
238            .decode(server_signature_b64)
239            .map_err(|e| format!("Invalid server signature: {}", e))?;
240
241        // Compute expected ServerSignature
242        let salted_password = self
243            .salted_password
244            .as_ref()
245            .ok_or("Missing salted password")?;
246        let auth_message = self.auth_message.as_ref().ok_or("Missing auth message")?;
247
248        // ServerKey = HMAC(SaltedPassword, "Server Key")
249        let server_key = {
250            let mut mac = <Hmac<sha2::Sha256> as Mac>::new_from_slice(salted_password)
251                .map_err(|e| format!("HMAC error: {}", e))?;
252            mac.update(b"Server Key");
253            mac.finalize().into_bytes()
254        };
255
256        // ServerSignature = HMAC(ServerKey, AuthMessage)
257        let expected_signature = {
258            let mut mac = <Hmac<sha2::Sha256> as Mac>::new_from_slice(&server_key)
259                .map_err(|e| format!("HMAC error: {}", e))?;
260            mac.update(auth_message.as_bytes());
261            mac.finalize().into_bytes()
262        };
263
264        if server_signature.as_slice() != expected_signature.as_slice() {
265            return Err("Server signature verification failed".to_string());
266        }
267
268        Ok(())
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_md5_password() {
        // Shape and sensitivity checks only; no reference digest is pinned here.
        let salt = [0x01, 0x02, 0x03, 0x04];
        let result = md5_password("postgres", "password", &salt);

        assert!(result.starts_with("md5"));
        assert_eq!(result.len(), 35); // "md5" + 32 hex chars
        assert!(result[3..]
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f')));

        // Same inputs give the same response; changing salt or password changes it.
        assert_eq!(result, md5_password("postgres", "password", &salt));
        assert_ne!(
            result,
            md5_password("postgres", "password", &[0x04, 0x03, 0x02, 0x01])
        );
        assert_ne!(result, md5_password("postgres", "other", &salt));
    }

    #[test]
    fn test_password_message() {
        let mut buf = Vec::new();
        write_password(&mut buf, "secret");

        assert_eq!(buf[0], b'p');
        // Check that the password is the null-terminated tail of the message
        assert!(buf.ends_with(b"secret\0"));
    }

    #[test]
    fn test_scram_client_first_message() {
        let client = ScramClient::new("secret");
        let first = client.client_first_message();

        assert!(first.starts_with("n,,n=,r="));
        // 24 random bytes encode to exactly 32 base64 characters.
        assert_eq!(first.len(), "n,,n=,r=".len() + 32);
    }

    #[test]
    fn test_scram_rejects_foreign_nonce() {
        let mut client = ScramClient::new("secret");
        // The nonce here is shorter than any client nonce, so it can never match.
        let result = client.process_server_first("r=not-our-nonce,s=c2FsdA==,i=4096");

        assert!(result.is_err());
    }
}
293}