zero_postgres/state/
connection.rs

1//! Connection startup and authentication state machine.
2
3use crate::error::{Error, Result};
4use crate::opts::{Opts, SslMode};
5use crate::protocol::backend::{
6    AuthenticationMessage, BackendKeyData, ErrorResponse, NegotiateProtocolVersion,
7    ParameterStatus, RawMessage, ReadyForQuery, msg_type,
8};
9use crate::protocol::frontend::auth::{ScramClient, md5_password};
10use crate::protocol::frontend::{
11    startup::write_ssl_request, write_password, write_sasl_initial_response, write_sasl_response,
12    write_startup,
13};
14use crate::protocol::types::TransactionStatus;
15
16use super::StateMachine;
17use super::action::{Action, AsyncMessage};
18use crate::buffer_set::BufferSet;
19
/// Phase of the connection startup handshake.
///
/// Several phases come in pairs. A `...Read` variant means the previous step
/// produced a write, so the next step only asks the driver for a message. The
/// matching plain variant then processes the message that was read.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Nothing has been sent yet.
    Initial,
    /// SSLRequest written; waiting for the server's single-byte reply.
    WaitingSslResponse,
    /// Server answered `S`; waiting for the TLS handshake requested via `Action::TlsHandshake`.
    WaitingTlsHandshake,
    /// Startup message written; next step requests the first authentication message.
    WaitingAuthRead,
    /// Processing the server's authentication request.
    WaitingAuth,
    /// SASL message written; next step requests the server's SASL reply.
    SaslInProgressRead,
    /// Processing a SASLContinue / SASLFinal message.
    SaslInProgress,
    /// Password message written; next step requests the authentication result.
    WaitingAuthResultRead,
    /// Expecting AuthenticationOk (after a password message or a verified SASLFinal).
    WaitingAuthResult,
    /// Authenticated; collecting BackendKeyData / ParameterStatus until ReadyForQuery.
    WaitingReady,
    /// ReadyForQuery received; startup is complete.
    Finished,
}
35
36/// Connection startup state machine.
37pub struct ConnectionStateMachine {
38    state: State,
39    options: Opts,
40    backend_key: Option<BackendKeyData>,
41    server_params: Vec<(String, String)>,
42    transaction_status: TransactionStatus,
43    scram_client: Option<ScramClient>,
44    /// SSL response byte, set by driver after ReadByte
45    ssl_response: u8,
46}
47
48impl ConnectionStateMachine {
49    /// Create a new connection state machine.
50    pub fn new(options: Opts) -> Self {
51        Self {
52            state: State::Initial,
53            options,
54            backend_key: None,
55            server_params: Vec::new(),
56            transaction_status: TransactionStatus::Idle,
57            scram_client: None,
58            ssl_response: 0,
59        }
60    }
61
62    /// Get the backend key data (for cancellation).
63    pub fn backend_key(&self) -> Option<&BackendKeyData> {
64        self.backend_key.as_ref()
65    }
66
67    /// Take server parameters.
68    pub fn take_server_params(&mut self) -> Vec<(String, String)> {
69        std::mem::take(&mut self.server_params)
70    }
71
72    /// Set the SSL response byte (called by driver after ReadByte).
73    pub fn set_ssl_response(&mut self, response: u8) {
74        self.ssl_response = response;
75    }
76
77    fn handle_initial(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
78        buffer_set.write_buffer.clear();
79
80        let client_supports_tls = cfg!(any(feature = "sync-tls", feature = "tokio-tls"));
81
82        let send_ssl_request = match self.options.ssl_mode {
83            SslMode::Disable => false,
84            SslMode::Prefer => client_supports_tls,
85            SslMode::Require if !client_supports_tls => {
86                return Err(Error::Unsupported(
87                    "SSL required but TLS feature not enabled".into(),
88                ));
89            }
90            SslMode::Require => true,
91        };
92
93        if send_ssl_request {
94            write_ssl_request(&mut buffer_set.write_buffer);
95            self.state = State::WaitingSslResponse;
96            Ok(Action::WriteAndReadByte)
97        } else {
98            self.write_startup_message(&mut buffer_set.write_buffer);
99            self.state = State::WaitingAuthRead;
100            Ok(Action::Write)
101        }
102    }
103
104    fn handle_ssl_response(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
105        match self.ssl_response {
106            b'S' => {
107                self.state = State::WaitingTlsHandshake;
108                Ok(Action::TlsHandshake)
109            }
110            b'N' => {
111                if self.options.ssl_mode == SslMode::Require {
112                    return Err(Error::Auth(
113                        "SSL required but not supported by server".into(),
114                    ));
115                }
116                // SSL not supported, continue with plain connection
117                buffer_set.write_buffer.clear();
118                self.write_startup_message(&mut buffer_set.write_buffer);
119                self.state = State::WaitingAuthRead;
120                Ok(Action::Write)
121            }
122            _ => Err(Error::Protocol(format!(
123                "Unexpected SSL response: {}",
124                self.ssl_response
125            ))),
126        }
127    }
128
129    fn handle_tls_handshake_complete(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
130        buffer_set.write_buffer.clear();
131        self.write_startup_message(&mut buffer_set.write_buffer);
132        self.state = State::WaitingAuthRead;
133        Ok(Action::Write)
134    }
135
136    fn write_startup_message(&self, write_buffer: &mut Vec<u8>) {
137        let mut params: Vec<(&str, &str)> =
138            vec![("user", &self.options.user), ("client_encoding", "UTF8")];
139
140        if let Some(db) = &self.options.database {
141            params.push(("database", db));
142        }
143
144        if let Some(app) = &self.options.application_name {
145            params.push(("application_name", app));
146        }
147
148        for (name, value) in &self.options.params {
149            params.push((name, value));
150        }
151
152        write_startup(write_buffer, &params);
153    }
154
155    fn handle_auth_message(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
156        let type_byte = buffer_set.type_byte;
157
158        // Handle NegotiateProtocolVersion - server doesn't support our protocol version
159        if type_byte == msg_type::NEGOTIATE_PROTOCOL_VERSION {
160            let negotiate = NegotiateProtocolVersion::parse(&buffer_set.read_buffer)?;
161            // Server sends the newest minor version it supports (0 for 3.0, 1 for 3.1, etc.)
162            return Err(Error::Protocol(format!(
163                "Server does not support protocol 3.2 (requires PostgreSQL 17+). \
164                 Server supports protocol 3.{}. Unrecognized options: {:?}",
165                negotiate.newest_minor_version, negotiate.unrecognized_options
166            )));
167        }
168
169        if type_byte != msg_type::AUTHENTICATION {
170            return Err(Error::Protocol(format!(
171                "Expected Authentication message, got '{}'",
172                type_byte as char
173            )));
174        }
175
176        let auth = AuthenticationMessage::parse(&buffer_set.read_buffer)?;
177
178        match auth {
179            AuthenticationMessage::Ok => {
180                self.state = State::WaitingReady;
181                Ok(Action::ReadMessage)
182            }
183            AuthenticationMessage::CleartextPassword => {
184                let password = self
185                    .options
186                    .password
187                    .as_ref()
188                    .ok_or_else(|| Error::Auth("Password required but not provided".into()))?;
189
190                buffer_set.write_buffer.clear();
191                write_password(&mut buffer_set.write_buffer, password);
192                self.state = State::WaitingAuthResultRead;
193                Ok(Action::Write)
194            }
195            AuthenticationMessage::Md5Password { salt } => {
196                let password = self
197                    .options
198                    .password
199                    .as_ref()
200                    .ok_or_else(|| Error::Auth("Password required but not provided".into()))?;
201
202                let hashed = md5_password(&self.options.user, password, &salt);
203                buffer_set.write_buffer.clear();
204                write_password(&mut buffer_set.write_buffer, &hashed);
205                self.state = State::WaitingAuthResultRead;
206                Ok(Action::Write)
207            }
208            AuthenticationMessage::Sasl { mechanisms } => {
209                // Check if SCRAM-SHA-256 is supported
210                if !mechanisms.contains(&"SCRAM-SHA-256") {
211                    return Err(Error::Auth(format!(
212                        "No supported SASL mechanism. Server offers: {:?}",
213                        mechanisms
214                    )));
215                }
216
217                let password = self
218                    .options
219                    .password
220                    .as_ref()
221                    .ok_or_else(|| Error::Auth("Password required but not provided".into()))?;
222
223                let scram = ScramClient::new(password);
224                let client_first = scram.client_first_message();
225
226                buffer_set.write_buffer.clear();
227                write_sasl_initial_response(
228                    &mut buffer_set.write_buffer,
229                    "SCRAM-SHA-256",
230                    client_first.as_bytes(),
231                );
232
233                self.scram_client = Some(scram);
234                self.state = State::SaslInProgressRead;
235                Ok(Action::Write)
236            }
237            _ => Err(Error::Unsupported(format!(
238                "Unsupported authentication method: {:?}",
239                auth
240            ))),
241        }
242    }
243
244    fn handle_sasl_message(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
245        let type_byte = buffer_set.type_byte;
246        if type_byte != msg_type::AUTHENTICATION {
247            return Err(Error::Protocol(format!(
248                "Expected Authentication message, got '{}'",
249                type_byte as char
250            )));
251        }
252
253        let auth = AuthenticationMessage::parse(&buffer_set.read_buffer)?;
254
255        match auth {
256            AuthenticationMessage::SaslContinue { data } => {
257                let scram = self
258                    .scram_client
259                    .as_mut()
260                    .ok_or_else(|| Error::Protocol("SCRAM client not initialized".into()))?;
261
262                let server_first = simdutf8::compat::from_utf8(data)
263                    .map_err(|e| Error::Auth(format!("Invalid server-first-message: {}", e)))?;
264
265                let client_final = scram
266                    .process_server_first(server_first)
267                    .map_err(Error::Auth)?;
268
269                buffer_set.write_buffer.clear();
270                write_sasl_response(&mut buffer_set.write_buffer, client_final.as_bytes());
271                self.state = State::SaslInProgressRead;
272                Ok(Action::Write)
273            }
274            AuthenticationMessage::SaslFinal { data } => {
275                let scram = self
276                    .scram_client
277                    .as_ref()
278                    .ok_or_else(|| Error::Protocol("SCRAM client not initialized".into()))?;
279
280                let server_final = simdutf8::compat::from_utf8(data)
281                    .map_err(|e| Error::Auth(format!("Invalid server-final-message: {}", e)))?;
282
283                scram
284                    .verify_server_final(server_final)
285                    .map_err(Error::Auth)?;
286
287                self.state = State::WaitingAuthResult;
288                Ok(Action::ReadMessage)
289            }
290            _ => Err(Error::Protocol(format!(
291                "Unexpected SASL message: {:?}",
292                auth
293            ))),
294        }
295    }
296
297    fn handle_auth_result(&mut self, buffer_set: &BufferSet) -> Result<Action> {
298        let type_byte = buffer_set.type_byte;
299        if type_byte != msg_type::AUTHENTICATION {
300            return Err(Error::Protocol(format!(
301                "Expected AuthenticationOk, got '{}'",
302                type_byte as char
303            )));
304        }
305
306        let auth = AuthenticationMessage::parse(&buffer_set.read_buffer)?;
307
308        match auth {
309            AuthenticationMessage::Ok => {
310                self.state = State::WaitingReady;
311                Ok(Action::ReadMessage)
312            }
313            _ => Err(Error::Auth(format!("Unexpected auth result: {:?}", auth))),
314        }
315    }
316
317    fn handle_ready_message(&mut self, buffer_set: &BufferSet) -> Result<Action> {
318        let type_byte = buffer_set.type_byte;
319        let payload = &buffer_set.read_buffer;
320
321        match type_byte {
322            msg_type::BACKEND_KEY_DATA => {
323                let key = BackendKeyData::parse(payload)?;
324                self.backend_key = Some(key);
325                Ok(Action::ReadMessage)
326            }
327            msg_type::PARAMETER_STATUS => {
328                let param = ParameterStatus::parse(payload)?;
329                self.server_params
330                    .push((param.name.to_string(), param.value.to_string()));
331                Ok(Action::ReadMessage)
332            }
333            msg_type::READY_FOR_QUERY => {
334                let ready = ReadyForQuery::parse(payload)?;
335                self.transaction_status = ready.transaction_status().unwrap_or_default();
336                self.state = State::Finished;
337                Ok(Action::Finished)
338            }
339            _ => Err(Error::Protocol(format!(
340                "Unexpected message during startup: '{}'",
341                type_byte as char
342            ))),
343        }
344    }
345
346    fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
347        match msg.type_byte {
348            msg_type::NOTICE_RESPONSE => {
349                let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
350                Ok(Action::HandleAsyncMessageAndReadMessage(
351                    AsyncMessage::Notice(notice.0),
352                ))
353            }
354            msg_type::PARAMETER_STATUS => {
355                let param = ParameterStatus::parse(msg.payload)?;
356                Ok(Action::HandleAsyncMessageAndReadMessage(
357                    AsyncMessage::ParameterChanged {
358                        name: param.name.to_string(),
359                        value: param.value.to_string(),
360                    },
361                ))
362            }
363            msg_type::NOTIFICATION_RESPONSE => {
364                let notification =
365                    crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
366                Ok(Action::HandleAsyncMessageAndReadMessage(
367                    AsyncMessage::Notification {
368                        pid: notification.pid,
369                        channel: notification.channel.to_string(),
370                        payload: notification.payload.to_string(),
371                    },
372                ))
373            }
374            _ => Err(Error::Protocol(format!(
375                "Unknown async message type: '{}'",
376                msg.type_byte as char
377            ))),
378        }
379    }
380}
381
382impl StateMachine for ConnectionStateMachine {
383    fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
384        // Handle states that don't need to read buffer_set
385        match self.state {
386            State::Initial => return self.handle_initial(buffer_set),
387            State::WaitingSslResponse => return self.handle_ssl_response(buffer_set),
388            State::WaitingTlsHandshake => return self.handle_tls_handshake_complete(buffer_set),
389            State::WaitingAuthRead => {
390                self.state = State::WaitingAuth;
391                return Ok(Action::ReadMessage);
392            }
393            State::SaslInProgressRead => {
394                self.state = State::SaslInProgress;
395                return Ok(Action::ReadMessage);
396            }
397            State::WaitingAuthResultRead => {
398                self.state = State::WaitingAuthResult;
399                return Ok(Action::ReadMessage);
400            }
401            _ => {}
402        }
403
404        let type_byte = buffer_set.type_byte;
405
406        // Handle async messages that can arrive at any time
407        // Note: PARAMETER_STATUS during WaitingReady is part of normal startup, not async
408        if RawMessage::is_async_type(type_byte)
409            && !(self.state == State::WaitingReady && type_byte == msg_type::PARAMETER_STATUS)
410        {
411            let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
412            return self.handle_async_message(&msg);
413        }
414
415        // Handle error response
416        if type_byte == msg_type::ERROR_RESPONSE {
417            let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
418            return Err(error.into_error());
419        }
420
421        match self.state {
422            State::WaitingAuth => self.handle_auth_message(buffer_set),
423            State::SaslInProgress => self.handle_sasl_message(buffer_set),
424            State::WaitingAuthResult => self.handle_auth_result(buffer_set),
425            State::WaitingReady => self.handle_ready_message(buffer_set),
426            _ => Err(Error::Protocol(format!(
427                "Unexpected state {:?}",
428                self.state
429            ))),
430        }
431    }
432
433    fn transaction_status(&self) -> TransactionStatus {
434        self.transaction_status
435    }
436}