1use std::collections::HashMap;
23use std::io::{Read, Write};
24use std::net::TcpStream;
25#[cfg(feature = "console")]
26use std::sync::Arc;
27
28use sqlmodel_core::Error;
29use sqlmodel_core::error::{
30 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
31};
32
33#[cfg(feature = "console")]
34use sqlmodel_console::{ConsoleAware, SqlModelConsole};
35
36use crate::auth::ScramClient;
37use crate::config::PgConfig;
38use crate::protocol::{
39 BackendMessage, ErrorFields, FrontendMessage, MessageReader, MessageWriter, PROTOCOL_VERSION,
40 TransactionStatus,
41};
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq)]
45pub enum ConnectionState {
46 Disconnected,
48 Connecting,
50 Authenticating,
52 Ready(TransactionStatusState),
54 InQuery,
56 InTransaction(TransactionStatusState),
58 Error,
60 Closed,
62}
63
/// Transaction status carried by [`ConnectionState`].
///
/// Mirrors the protocol-level `TransactionStatus` (see the `From` impl):
/// `Idle`, `Transaction` and `Error` map to `Idle`, `InTransaction` and
/// `InFailed` respectively.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TransactionStatusState {
    /// Not inside a transaction block. This is the default value.
    #[default]
    Idle,
    /// Inside a transaction block.
    InTransaction,
    /// Converted from `TransactionStatus::Error`.
    InFailed,
}
75
76impl From<TransactionStatus> for TransactionStatusState {
77 fn from(status: TransactionStatus) -> Self {
78 match status {
79 TransactionStatus::Idle => TransactionStatusState::Idle,
80 TransactionStatus::Transaction => TransactionStatusState::InTransaction,
81 TransactionStatus::Error => TransactionStatusState::InFailed,
82 }
83 }
84}
85
86pub struct PgConnection {
97 stream: TcpStream,
99 state: ConnectionState,
101 process_id: i32,
103 secret_key: i32,
105 parameters: HashMap<String, String>,
107 config: PgConfig,
109 reader: MessageReader,
111 writer: MessageWriter,
113 read_buf: Vec<u8>,
115 #[cfg(feature = "console")]
117 console: Option<Arc<SqlModelConsole>>,
118}
119
120impl std::fmt::Debug for PgConnection {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("PgConnection")
123 .field("state", &self.state)
124 .field("process_id", &self.process_id)
125 .field("host", &self.config.host)
126 .field("port", &self.config.port)
127 .field("database", &self.config.database)
128 .finish_non_exhaustive()
129 }
130}
131
132impl PgConnection {
133 #[allow(clippy::result_large_err)]
142 pub fn connect(config: PgConfig) -> Result<Self, Error> {
143 let stream = TcpStream::connect_timeout(
145 &config.socket_addr().parse().map_err(|e| {
146 Error::Connection(ConnectionError {
147 kind: ConnectionErrorKind::Connect,
148 message: format!("Invalid socket address: {}", e),
149 source: None,
150 })
151 })?,
152 config.connect_timeout,
153 )
154 .map_err(|e| {
155 let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
156 ConnectionErrorKind::Refused
157 } else {
158 ConnectionErrorKind::Connect
159 };
160 Error::Connection(ConnectionError {
161 kind,
162 message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
163 source: Some(Box::new(e)),
164 })
165 })?;
166
167 stream.set_nodelay(true).ok();
169 stream.set_read_timeout(Some(config.connect_timeout)).ok();
170 stream.set_write_timeout(Some(config.connect_timeout)).ok();
171
172 let mut conn = Self {
173 stream,
174 state: ConnectionState::Connecting,
175 process_id: 0,
176 secret_key: 0,
177 parameters: HashMap::new(),
178 config,
179 reader: MessageReader::new(),
180 writer: MessageWriter::new(),
181 read_buf: vec![0u8; 8192],
182 #[cfg(feature = "console")]
183 console: None,
184 };
185
186 if conn.config.ssl_mode.should_try_ssl() {
188 conn.negotiate_ssl()?;
189 }
190
191 conn.send_startup()?;
193 conn.state = ConnectionState::Authenticating;
194
195 conn.handle_auth()?;
197
198 conn.read_startup_messages()?;
200
201 Ok(conn)
202 }
203
204 pub fn state(&self) -> ConnectionState {
206 self.state
207 }
208
209 pub fn is_ready(&self) -> bool {
211 matches!(self.state, ConnectionState::Ready(_))
212 }
213
214 pub fn process_id(&self) -> i32 {
216 self.process_id
217 }
218
219 pub fn secret_key(&self) -> i32 {
221 self.secret_key
222 }
223
224 pub fn parameter(&self, name: &str) -> Option<&str> {
226 self.parameters.get(name).map(|s| s.as_str())
227 }
228
229 pub fn parameters(&self) -> &HashMap<String, String> {
231 &self.parameters
232 }
233
234 #[allow(clippy::result_large_err)]
236 pub fn close(&mut self) -> Result<(), Error> {
237 if matches!(
238 self.state,
239 ConnectionState::Closed | ConnectionState::Disconnected
240 ) {
241 return Ok(());
242 }
243
244 self.send_message(&FrontendMessage::Terminate)?;
246 self.state = ConnectionState::Closed;
247 Ok(())
248 }
249
250 #[allow(clippy::result_large_err)]
253 fn negotiate_ssl(&mut self) -> Result<(), Error> {
254 self.send_message(&FrontendMessage::SSLRequest)?;
256
257 let mut buf = [0u8; 1];
259 self.stream.read_exact(&mut buf).map_err(|e| {
260 Error::Connection(ConnectionError {
261 kind: ConnectionErrorKind::Ssl,
262 message: format!("Failed to read SSL response: {}", e),
263 source: Some(Box::new(e)),
264 })
265 })?;
266
267 match buf[0] {
268 b'S' => {
269 if self.config.ssl_mode.is_required() {
272 return Err(Error::Connection(ConnectionError {
273 kind: ConnectionErrorKind::Ssl,
274 message: "SSL/TLS not yet implemented".to_string(),
275 source: None,
276 }));
277 }
278 Err(Error::Connection(ConnectionError {
281 kind: ConnectionErrorKind::Ssl,
282 message: "SSL/TLS not yet implemented, reconnect with ssl_mode=disable"
283 .to_string(),
284 source: None,
285 }))
286 }
287 b'N' => {
288 if self.config.ssl_mode.is_required() {
290 return Err(Error::Connection(ConnectionError {
291 kind: ConnectionErrorKind::Ssl,
292 message: "Server does not support SSL".to_string(),
293 source: None,
294 }));
295 }
296 Ok(())
298 }
299 _ => Err(Error::Connection(ConnectionError {
300 kind: ConnectionErrorKind::Ssl,
301 message: format!("Unexpected SSL response: 0x{:02x}", buf[0]),
302 source: None,
303 })),
304 }
305 }
306
307 #[allow(clippy::result_large_err)]
310 fn send_startup(&mut self) -> Result<(), Error> {
311 let params = self.config.startup_params();
312 let msg = FrontendMessage::Startup {
313 version: PROTOCOL_VERSION,
314 params,
315 };
316 self.send_message(&msg)
317 }
318
319 #[allow(clippy::result_large_err)]
322 fn handle_auth(&mut self) -> Result<(), Error> {
323 loop {
324 let msg = self.receive_message()?;
325
326 match msg {
327 BackendMessage::AuthenticationOk => {
328 return Ok(());
329 }
330 BackendMessage::AuthenticationCleartextPassword => {
331 let password = self
332 .config
333 .password
334 .as_ref()
335 .ok_or_else(|| auth_error("Password required but not provided"))?;
336 self.send_message(&FrontendMessage::PasswordMessage(password.clone()))?;
337 }
338 BackendMessage::AuthenticationMD5Password(salt) => {
339 let password = self
340 .config
341 .password
342 .as_ref()
343 .ok_or_else(|| auth_error("Password required but not provided"))?;
344 let hash = md5_password(&self.config.user, password, salt);
345 self.send_message(&FrontendMessage::PasswordMessage(hash))?;
346 }
347 BackendMessage::AuthenticationSASL(mechanisms) => {
348 if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
349 self.scram_auth()?;
350 } else {
351 return Err(auth_error(format!(
352 "Unsupported SASL mechanisms: {:?}",
353 mechanisms
354 )));
355 }
356 }
357 BackendMessage::ErrorResponse(e) => {
358 self.state = ConnectionState::Error;
359 return Err(error_from_fields(&e));
360 }
361 _ => {
362 return Err(Error::Protocol(ProtocolError {
363 message: format!("Unexpected message during auth: {:?}", msg),
364 raw_data: None,
365 source: None,
366 }));
367 }
368 }
369 }
370 }
371
372 #[allow(clippy::result_large_err)]
373 fn scram_auth(&mut self) -> Result<(), Error> {
374 let password = self
375 .config
376 .password
377 .as_ref()
378 .ok_or_else(|| auth_error("Password required for SCRAM-SHA-256"))?;
379
380 let mut client = ScramClient::new(&self.config.user, password);
381
382 let client_first = client.client_first();
384 self.send_message(&FrontendMessage::SASLInitialResponse {
385 mechanism: "SCRAM-SHA-256".to_string(),
386 data: client_first,
387 })?;
388
389 let msg = self.receive_message()?;
391 let server_first_data = match msg {
392 BackendMessage::AuthenticationSASLContinue(data) => data,
393 BackendMessage::ErrorResponse(e) => {
394 self.state = ConnectionState::Error;
395 return Err(error_from_fields(&e));
396 }
397 _ => {
398 return Err(Error::Protocol(ProtocolError {
399 message: format!("Expected SASL continue, got: {:?}", msg),
400 raw_data: None,
401 source: None,
402 }));
403 }
404 };
405
406 let client_final = client.process_server_first(&server_first_data)?;
408 self.send_message(&FrontendMessage::SASLResponse(client_final))?;
409
410 let msg = self.receive_message()?;
412 let server_final_data = match msg {
413 BackendMessage::AuthenticationSASLFinal(data) => data,
414 BackendMessage::ErrorResponse(e) => {
415 self.state = ConnectionState::Error;
416 return Err(error_from_fields(&e));
417 }
418 _ => {
419 return Err(Error::Protocol(ProtocolError {
420 message: format!("Expected SASL final, got: {:?}", msg),
421 raw_data: None,
422 source: None,
423 }));
424 }
425 };
426
427 client.verify_server_final(&server_final_data)?;
429
430 let msg = self.receive_message()?;
432 match msg {
433 BackendMessage::AuthenticationOk => Ok(()),
434 BackendMessage::ErrorResponse(e) => {
435 self.state = ConnectionState::Error;
436 Err(error_from_fields(&e))
437 }
438 _ => Err(Error::Protocol(ProtocolError {
439 message: format!("Expected AuthenticationOk, got: {:?}", msg),
440 raw_data: None,
441 source: None,
442 })),
443 }
444 }
445
446 #[allow(clippy::result_large_err)]
449 fn read_startup_messages(&mut self) -> Result<(), Error> {
450 loop {
451 let msg = self.receive_message()?;
452
453 match msg {
454 BackendMessage::BackendKeyData {
455 process_id,
456 secret_key,
457 } => {
458 self.process_id = process_id;
459 self.secret_key = secret_key;
460 }
461 BackendMessage::ParameterStatus { name, value } => {
462 self.parameters.insert(name, value);
463 }
464 BackendMessage::ReadyForQuery(status) => {
465 self.state = ConnectionState::Ready(status.into());
466 return Ok(());
467 }
468 BackendMessage::ErrorResponse(e) => {
469 self.state = ConnectionState::Error;
470 return Err(error_from_fields(&e));
471 }
472 BackendMessage::NoticeResponse(_notice) => {
473 }
475 _ => {
476 return Err(Error::Protocol(ProtocolError {
477 message: format!("Unexpected startup message: {:?}", msg),
478 raw_data: None,
479 source: None,
480 }));
481 }
482 }
483 }
484 }
485
486 #[allow(clippy::result_large_err)]
489 fn send_message(&mut self, msg: &FrontendMessage) -> Result<(), Error> {
490 let data = self.writer.write(msg);
491 self.stream.write_all(data).map_err(|e| {
492 self.state = ConnectionState::Error;
493 Error::Io(e)
494 })?;
495 self.stream.flush().map_err(|e| {
496 self.state = ConnectionState::Error;
497 Error::Io(e)
498 })?;
499 Ok(())
500 }
501
502 #[allow(clippy::result_large_err)]
503 fn receive_message(&mut self) -> Result<BackendMessage, Error> {
504 loop {
506 match self.reader.next_message() {
507 Ok(Some(msg)) => return Ok(msg),
508 Ok(None) => {
509 let n = self.stream.read(&mut self.read_buf).map_err(|e| {
511 if e.kind() == std::io::ErrorKind::TimedOut
512 || e.kind() == std::io::ErrorKind::WouldBlock
513 {
514 Error::Timeout
515 } else {
516 self.state = ConnectionState::Error;
517 Error::Connection(ConnectionError {
518 kind: ConnectionErrorKind::Disconnected,
519 message: format!("Failed to read from server: {}", e),
520 source: Some(Box::new(e)),
521 })
522 }
523 })?;
524
525 if n == 0 {
526 self.state = ConnectionState::Disconnected;
527 return Err(Error::Connection(ConnectionError {
528 kind: ConnectionErrorKind::Disconnected,
529 message: "Connection closed by server".to_string(),
530 source: None,
531 }));
532 }
533
534 self.reader.feed(&self.read_buf[..n]).map_err(|e| {
536 Error::Protocol(ProtocolError {
537 message: format!("Protocol error: {}", e),
538 raw_data: None,
539 source: None,
540 })
541 })?;
542 }
543 Err(e) => {
544 self.state = ConnectionState::Error;
545 return Err(Error::Protocol(ProtocolError {
546 message: format!("Protocol error: {}", e),
547 raw_data: None,
548 source: None,
549 }));
550 }
551 }
552 }
553 }
554}
555
556impl Drop for PgConnection {
557 fn drop(&mut self) {
558 let _ = self.close();
560 }
561}
562
#[cfg(feature = "console")]
impl ConsoleAware for PgConnection {
    /// Attach a console, or detach the current one by passing `None`.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    /// The attached console, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        let attached = &self.console;
        attached.as_ref()
    }

    /// `true` when a console is attached.
    fn has_console(&self) -> bool {
        let attached = &self.console;
        attached.is_some()
    }
}
579
/// Stages reported through `PgConnection::emit_progress`.
///
/// The human-readable label of each stage comes from
/// `ConnectionStage::description`.
#[cfg(feature = "console")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionStage {
    /// Label: "Resolving DNS".
    DnsResolve,
    /// Label: "Connecting (TCP)".
    TcpConnect,
    /// Label: "Negotiating SSL".
    SslNegotiate,
    /// Label: "SSL established".
    SslEstablished,
    /// Label: "Sending startup".
    Startup,
    /// Label: "Authenticating".
    Authenticating,
    /// Label: "Authenticated".
    Authenticated,
    /// Label: "Ready".
    Ready,
}
601
#[cfg(feature = "console")]
impl ConnectionStage {
    /// Short, fixed, human-readable label for this stage.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            ConnectionStage::DnsResolve => "Resolving DNS",
            ConnectionStage::TcpConnect => "Connecting (TCP)",
            ConnectionStage::SslNegotiate => "Negotiating SSL",
            ConnectionStage::SslEstablished => "SSL established",
            ConnectionStage::Startup => "Sending startup",
            ConnectionStage::Authenticating => "Authenticating",
            ConnectionStage::Authenticated => "Authenticated",
            ConnectionStage::Ready => "Ready",
        }
    }
}
619
#[cfg(feature = "console")]
impl PgConnection {
    /// Report a connection stage on the attached console, prefixed with a
    /// status marker chosen by `success`. Does nothing without a console.
    pub fn emit_progress(&self, stage: ConnectionStage, success: bool) {
        if let Some(console) = self.console.as_ref() {
            let marker = if success { "[OK]" } else { "[..] " };
            console.info(&format!("{} {}", marker, stage.description()));
        }
    }

    /// Report the "connected" summary line on the attached console as a
    /// success message. Does nothing without a console.
    pub fn emit_connected(&self) {
        if let Some(console) = self.console.as_ref() {
            console.success(&self.emit_connected_plain());
        }
    }

    /// Build the "connected" summary line without printing it.
    ///
    /// Uses the server-reported `server_version` parameter, or `"unknown"`
    /// when the server has not reported one.
    pub fn emit_connected_plain(&self) -> String {
        let server_version = match self.parameters.get("server_version") {
            Some(version) => version.as_str(),
            None => "unknown",
        };
        format!(
            "Connected to PostgreSQL {} at {}:{}",
            server_version, self.config.host, self.config.port
        )
    }
}
660
661fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
665 use std::fmt::Write;
666
667 let inner = format!("{}{}", password, user);
669 let inner_hash = md5::compute(inner.as_bytes());
670
671 let mut outer_input = format!("{:x}", inner_hash).into_bytes();
672 outer_input.extend_from_slice(&salt);
673 let outer_hash = md5::compute(&outer_input);
674
675 let mut result = String::with_capacity(35);
676 result.push_str("md5");
677 write!(&mut result, "{:x}", outer_hash).unwrap();
678 result
679}
680
681fn auth_error(msg: impl Into<String>) -> Error {
682 Error::Connection(ConnectionError {
683 kind: ConnectionErrorKind::Authentication,
684 message: msg.into(),
685 source: None,
686 })
687}
688
689fn error_from_fields(fields: &ErrorFields) -> Error {
690 let kind = match fields.code.get(..2) {
692 Some("08") => {
693 return Error::Connection(ConnectionError {
695 kind: ConnectionErrorKind::Connect,
696 message: fields.message.clone(),
697 source: None,
698 });
699 }
700 Some("28") => {
701 return Error::Connection(ConnectionError {
703 kind: ConnectionErrorKind::Authentication,
704 message: fields.message.clone(),
705 source: None,
706 });
707 }
708 Some("42") => QueryErrorKind::Syntax, Some("23") => QueryErrorKind::Constraint, Some("40") => {
711 if fields.code == "40001" {
712 QueryErrorKind::Serialization
713 } else {
714 QueryErrorKind::Deadlock
715 }
716 }
717 Some("57") => {
718 if fields.code == "57014" {
719 QueryErrorKind::Cancelled
720 } else {
721 QueryErrorKind::Timeout
722 }
723 }
724 _ => QueryErrorKind::Database,
725 };
726
727 Error::Query(QueryError {
728 kind,
729 sql: None,
730 sqlstate: Some(fields.code.clone()),
731 message: fields.message.clone(),
732 detail: fields.detail.clone(),
733 hint: fields.hint.clone(),
734 position: fields.position.map(|p| p as usize),
735 source: None,
736 })
737}
738
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `ErrorFields` with only the fields these tests care about.
    fn fields_with(severity: &str, code: &str, message: &str) -> ErrorFields {
        ErrorFields {
            severity: severity.to_string(),
            code: code.to_string(),
            message: message.to_string(),
            ..Default::default()
        }
    }

    /// The MD5 response is the literal prefix plus a 32-character hex digest.
    #[test]
    fn test_md5_password() {
        let response = md5_password("postgres", "mysecretpassword", *b"abcd");
        assert!(response.starts_with("md5"));
        assert_eq!(response.len(), 35);
    }

    /// Every protocol transaction status maps to its connection-level twin.
    #[test]
    fn test_transaction_status_conversion() {
        let idle: TransactionStatusState = TransactionStatus::Idle.into();
        assert_eq!(idle, TransactionStatusState::Idle);

        let in_transaction: TransactionStatusState = TransactionStatus::Transaction.into();
        assert_eq!(in_transaction, TransactionStatusState::InTransaction);

        let failed: TransactionStatusState = TransactionStatus::Error.into();
        assert_eq!(failed, TransactionStatusState::InFailed);
    }

    /// Class 23 is a constraint query error; class 28 is an authentication
    /// connection error.
    #[test]
    fn test_error_classification() {
        let unique = fields_with("ERROR", "23505", "unique violation");
        match error_from_fields(&unique) {
            Error::Query(q) => assert!(q.kind == QueryErrorKind::Constraint),
            _ => panic!("expected a query error"),
        }

        let bad_password = fields_with("FATAL", "28P01", "password authentication failed");
        match error_from_fields(&bad_password) {
            Error::Connection(c) => assert!(c.kind == ConnectionErrorKind::Authentication),
            _ => panic!("expected a connection error"),
        }
    }
}