1use crate::error::{Error, Result};
4use crate::opts::{Opts, SslMode};
5use crate::protocol::backend::{
6 AuthenticationMessage, BackendKeyData, ErrorResponse, NegotiateProtocolVersion,
7 ParameterStatus, RawMessage, ReadyForQuery, msg_type,
8};
9use crate::protocol::frontend::auth::{ScramClient, md5_password};
10use crate::protocol::frontend::{
11 startup::write_ssl_request, write_password, write_sasl_initial_response, write_sasl_response,
12 write_startup,
13};
14use crate::protocol::types::TransactionStatus;
15
16use super::StateMachine;
17use super::action::{Action, AsyncMessage};
18use crate::buffer_set::BufferSet;
19
/// Phases of the connection startup handshake.
///
/// The `*Read` variants mark "a request has been queued for writing; the next
/// step must ask for a message to be read" and are immediately followed by
/// their non-`Read` counterpart, in which the received message is interpreted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Nothing has been sent yet.
    Initial,
    /// An SSL request was written; waiting for the single-byte reply.
    WaitingSslResponse,
    /// The server accepted SSL; waiting for the TLS handshake to complete.
    WaitingTlsHandshake,
    /// The startup message was written; a read must be requested next.
    WaitingAuthRead,
    /// Waiting to interpret the server's first authentication message.
    WaitingAuth,
    /// A SASL response was written; a read must be requested next.
    SaslInProgressRead,
    /// Waiting to interpret a SASL continue/final message.
    SaslInProgress,
    /// A password response was written; a read must be requested next.
    WaitingAuthResultRead,
    /// Waiting to interpret the final authentication result.
    WaitingAuthResult,
    /// Authenticated; consuming startup messages until `ReadyForQuery`.
    WaitingReady,
    /// Startup is complete.
    Finished,
}
35
36pub struct ConnectionStateMachine {
38 state: State,
39 options: Opts,
40 backend_key: Option<BackendKeyData>,
41 server_params: Vec<(String, String)>,
42 transaction_status: TransactionStatus,
43 scram_client: Option<ScramClient>,
44 ssl_response: u8,
46}
47
48impl ConnectionStateMachine {
49 pub fn new(options: Opts) -> Self {
51 Self {
52 state: State::Initial,
53 options,
54 backend_key: None,
55 server_params: Vec::new(),
56 transaction_status: TransactionStatus::Idle,
57 scram_client: None,
58 ssl_response: 0,
59 }
60 }
61
62 pub fn backend_key(&self) -> Option<&BackendKeyData> {
64 self.backend_key.as_ref()
65 }
66
67 pub fn take_server_params(&mut self) -> Vec<(String, String)> {
69 std::mem::take(&mut self.server_params)
70 }
71
72 pub fn set_ssl_response(&mut self, response: u8) {
74 self.ssl_response = response;
75 }
76
77 fn handle_initial(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
78 buffer_set.write_buffer.clear();
79
80 let client_supports_tls = cfg!(any(feature = "sync-tls", feature = "tokio-tls"));
81
82 let send_ssl_request = match self.options.ssl_mode {
83 SslMode::Disable => false,
84 SslMode::Prefer => client_supports_tls,
85 SslMode::Require if !client_supports_tls => {
86 return Err(Error::Unsupported(
87 "SSL required but TLS feature not enabled".into(),
88 ));
89 }
90 SslMode::Require => true,
91 };
92
93 if send_ssl_request {
94 write_ssl_request(&mut buffer_set.write_buffer);
95 self.state = State::WaitingSslResponse;
96 Ok(Action::WriteAndReadByte)
97 } else {
98 self.write_startup_message(&mut buffer_set.write_buffer);
99 self.state = State::WaitingAuthRead;
100 Ok(Action::Write)
101 }
102 }
103
104 fn handle_ssl_response(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
105 match self.ssl_response {
106 b'S' => {
107 self.state = State::WaitingTlsHandshake;
108 Ok(Action::TlsHandshake)
109 }
110 b'N' => {
111 if self.options.ssl_mode == SslMode::Require {
112 return Err(Error::Auth(
113 "SSL required but not supported by server".into(),
114 ));
115 }
116 buffer_set.write_buffer.clear();
118 self.write_startup_message(&mut buffer_set.write_buffer);
119 self.state = State::WaitingAuthRead;
120 Ok(Action::Write)
121 }
122 _ => Err(Error::Protocol(format!(
123 "Unexpected SSL response: {}",
124 self.ssl_response
125 ))),
126 }
127 }
128
129 fn handle_tls_handshake_complete(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
130 buffer_set.write_buffer.clear();
131 self.write_startup_message(&mut buffer_set.write_buffer);
132 self.state = State::WaitingAuthRead;
133 Ok(Action::Write)
134 }
135
136 fn write_startup_message(&self, write_buffer: &mut Vec<u8>) {
137 let mut params: Vec<(&str, &str)> =
138 vec![("user", &self.options.user), ("client_encoding", "UTF8")];
139
140 if let Some(db) = &self.options.database {
141 params.push(("database", db));
142 }
143
144 if let Some(app) = &self.options.application_name {
145 params.push(("application_name", app));
146 }
147
148 for (name, value) in &self.options.params {
149 params.push((name, value));
150 }
151
152 write_startup(write_buffer, ¶ms);
153 }
154
155 fn handle_auth_message(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
156 let type_byte = buffer_set.type_byte;
157
158 if type_byte == msg_type::NEGOTIATE_PROTOCOL_VERSION {
160 let negotiate = NegotiateProtocolVersion::parse(&buffer_set.read_buffer)?;
161 return Err(Error::Protocol(format!(
163 "Server does not support protocol 3.2 (requires PostgreSQL 17+). \
164 Server supports protocol 3.{}. Unrecognized options: {:?}",
165 negotiate.newest_minor_version, negotiate.unrecognized_options
166 )));
167 }
168
169 if type_byte != msg_type::AUTHENTICATION {
170 return Err(Error::Protocol(format!(
171 "Expected Authentication message, got '{}'",
172 type_byte as char
173 )));
174 }
175
176 let auth = AuthenticationMessage::parse(&buffer_set.read_buffer)?;
177
178 match auth {
179 AuthenticationMessage::Ok => {
180 self.state = State::WaitingReady;
181 Ok(Action::ReadMessage)
182 }
183 AuthenticationMessage::CleartextPassword => {
184 let password = self
185 .options
186 .password
187 .as_ref()
188 .ok_or_else(|| Error::Auth("Password required but not provided".into()))?;
189
190 buffer_set.write_buffer.clear();
191 write_password(&mut buffer_set.write_buffer, password);
192 self.state = State::WaitingAuthResultRead;
193 Ok(Action::Write)
194 }
195 AuthenticationMessage::Md5Password { salt } => {
196 let password = self
197 .options
198 .password
199 .as_ref()
200 .ok_or_else(|| Error::Auth("Password required but not provided".into()))?;
201
202 let hashed = md5_password(&self.options.user, password, &salt);
203 buffer_set.write_buffer.clear();
204 write_password(&mut buffer_set.write_buffer, &hashed);
205 self.state = State::WaitingAuthResultRead;
206 Ok(Action::Write)
207 }
208 AuthenticationMessage::Sasl { mechanisms } => {
209 if !mechanisms.contains(&"SCRAM-SHA-256") {
211 return Err(Error::Auth(format!(
212 "No supported SASL mechanism. Server offers: {:?}",
213 mechanisms
214 )));
215 }
216
217 let password = self
218 .options
219 .password
220 .as_ref()
221 .ok_or_else(|| Error::Auth("Password required but not provided".into()))?;
222
223 let scram = ScramClient::new(password);
224 let client_first = scram.client_first_message();
225
226 buffer_set.write_buffer.clear();
227 write_sasl_initial_response(
228 &mut buffer_set.write_buffer,
229 "SCRAM-SHA-256",
230 client_first.as_bytes(),
231 );
232
233 self.scram_client = Some(scram);
234 self.state = State::SaslInProgressRead;
235 Ok(Action::Write)
236 }
237 _ => Err(Error::Unsupported(format!(
238 "Unsupported authentication method: {:?}",
239 auth
240 ))),
241 }
242 }
243
244 fn handle_sasl_message(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
245 let type_byte = buffer_set.type_byte;
246 if type_byte != msg_type::AUTHENTICATION {
247 return Err(Error::Protocol(format!(
248 "Expected Authentication message, got '{}'",
249 type_byte as char
250 )));
251 }
252
253 let auth = AuthenticationMessage::parse(&buffer_set.read_buffer)?;
254
255 match auth {
256 AuthenticationMessage::SaslContinue { data } => {
257 let scram = self
258 .scram_client
259 .as_mut()
260 .ok_or_else(|| Error::Protocol("SCRAM client not initialized".into()))?;
261
262 let server_first = simdutf8::compat::from_utf8(data)
263 .map_err(|e| Error::Auth(format!("Invalid server-first-message: {}", e)))?;
264
265 let client_final = scram
266 .process_server_first(server_first)
267 .map_err(Error::Auth)?;
268
269 buffer_set.write_buffer.clear();
270 write_sasl_response(&mut buffer_set.write_buffer, client_final.as_bytes());
271 self.state = State::SaslInProgressRead;
272 Ok(Action::Write)
273 }
274 AuthenticationMessage::SaslFinal { data } => {
275 let scram = self
276 .scram_client
277 .as_ref()
278 .ok_or_else(|| Error::Protocol("SCRAM client not initialized".into()))?;
279
280 let server_final = simdutf8::compat::from_utf8(data)
281 .map_err(|e| Error::Auth(format!("Invalid server-final-message: {}", e)))?;
282
283 scram
284 .verify_server_final(server_final)
285 .map_err(Error::Auth)?;
286
287 self.state = State::WaitingAuthResult;
288 Ok(Action::ReadMessage)
289 }
290 _ => Err(Error::Protocol(format!(
291 "Unexpected SASL message: {:?}",
292 auth
293 ))),
294 }
295 }
296
297 fn handle_auth_result(&mut self, buffer_set: &BufferSet) -> Result<Action> {
298 let type_byte = buffer_set.type_byte;
299 if type_byte != msg_type::AUTHENTICATION {
300 return Err(Error::Protocol(format!(
301 "Expected AuthenticationOk, got '{}'",
302 type_byte as char
303 )));
304 }
305
306 let auth = AuthenticationMessage::parse(&buffer_set.read_buffer)?;
307
308 match auth {
309 AuthenticationMessage::Ok => {
310 self.state = State::WaitingReady;
311 Ok(Action::ReadMessage)
312 }
313 _ => Err(Error::Auth(format!("Unexpected auth result: {:?}", auth))),
314 }
315 }
316
317 fn handle_ready_message(&mut self, buffer_set: &BufferSet) -> Result<Action> {
318 let type_byte = buffer_set.type_byte;
319 let payload = &buffer_set.read_buffer;
320
321 match type_byte {
322 msg_type::BACKEND_KEY_DATA => {
323 let key = BackendKeyData::parse(payload)?;
324 self.backend_key = Some(key);
325 Ok(Action::ReadMessage)
326 }
327 msg_type::PARAMETER_STATUS => {
328 let param = ParameterStatus::parse(payload)?;
329 self.server_params
330 .push((param.name.to_string(), param.value.to_string()));
331 Ok(Action::ReadMessage)
332 }
333 msg_type::READY_FOR_QUERY => {
334 let ready = ReadyForQuery::parse(payload)?;
335 self.transaction_status = ready.transaction_status().unwrap_or_default();
336 self.state = State::Finished;
337 Ok(Action::Finished)
338 }
339 _ => Err(Error::Protocol(format!(
340 "Unexpected message during startup: '{}'",
341 type_byte as char
342 ))),
343 }
344 }
345
346 fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
347 match msg.type_byte {
348 msg_type::NOTICE_RESPONSE => {
349 let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
350 Ok(Action::HandleAsyncMessageAndReadMessage(
351 AsyncMessage::Notice(notice.0),
352 ))
353 }
354 msg_type::PARAMETER_STATUS => {
355 let param = ParameterStatus::parse(msg.payload)?;
356 Ok(Action::HandleAsyncMessageAndReadMessage(
357 AsyncMessage::ParameterChanged {
358 name: param.name.to_string(),
359 value: param.value.to_string(),
360 },
361 ))
362 }
363 msg_type::NOTIFICATION_RESPONSE => {
364 let notification =
365 crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
366 Ok(Action::HandleAsyncMessageAndReadMessage(
367 AsyncMessage::Notification {
368 pid: notification.pid,
369 channel: notification.channel.to_string(),
370 payload: notification.payload.to_string(),
371 },
372 ))
373 }
374 _ => Err(Error::Protocol(format!(
375 "Unknown async message type: '{}'",
376 msg.type_byte as char
377 ))),
378 }
379 }
380}
381
382impl StateMachine for ConnectionStateMachine {
383 fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
384 match self.state {
386 State::Initial => return self.handle_initial(buffer_set),
387 State::WaitingSslResponse => return self.handle_ssl_response(buffer_set),
388 State::WaitingTlsHandshake => return self.handle_tls_handshake_complete(buffer_set),
389 State::WaitingAuthRead => {
390 self.state = State::WaitingAuth;
391 return Ok(Action::ReadMessage);
392 }
393 State::SaslInProgressRead => {
394 self.state = State::SaslInProgress;
395 return Ok(Action::ReadMessage);
396 }
397 State::WaitingAuthResultRead => {
398 self.state = State::WaitingAuthResult;
399 return Ok(Action::ReadMessage);
400 }
401 _ => {}
402 }
403
404 let type_byte = buffer_set.type_byte;
405
406 if RawMessage::is_async_type(type_byte)
409 && !(self.state == State::WaitingReady && type_byte == msg_type::PARAMETER_STATUS)
410 {
411 let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
412 return self.handle_async_message(&msg);
413 }
414
415 if type_byte == msg_type::ERROR_RESPONSE {
417 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
418 return Err(error.into_error());
419 }
420
421 match self.state {
422 State::WaitingAuth => self.handle_auth_message(buffer_set),
423 State::SaslInProgress => self.handle_sasl_message(buffer_set),
424 State::WaitingAuthResult => self.handle_auth_result(buffer_set),
425 State::WaitingReady => self.handle_ready_message(buffer_set),
426 _ => Err(Error::Protocol(format!(
427 "Unexpected state {:?}",
428 self.state
429 ))),
430 }
431 }
432
433 fn transaction_status(&self) -> TransactionStatus {
434 self.transaction_status
435 }
436}