1use std::collections::HashMap;
23use std::io::{Read, Write};
24use std::net::TcpStream;
25#[cfg(feature = "console")]
26use std::sync::Arc;
27
28use sqlmodel_core::Error;
29use sqlmodel_core::error::{
30 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
31};
32
33#[cfg(feature = "console")]
34use sqlmodel_console::{ConsoleAware, SqlModelConsole};
35
36use crate::auth::ScramClient;
37use crate::config::PgConfig;
38#[cfg(not(feature = "tls"))]
39use crate::config::SslMode;
40use crate::protocol::{
41 BackendMessage, ErrorFields, FrontendMessage, MessageReader, MessageWriter, PROTOCOL_VERSION,
42 TransactionStatus,
43};
44
45#[cfg(feature = "tls")]
46use crate::tls;
47
/// Transport underneath a `PgConnection`: a plain TCP socket or, when the
/// `tls` feature is enabled, a rustls session wrapped around that socket.
enum PgStream {
    /// Unencrypted TCP socket.
    Plain(TcpStream),
    /// TLS session layered over the TCP socket.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ClientConnection, TcpStream>),
    /// Placeholder stored while the plain socket is moved out during the TLS
    /// upgrade; the I/O methods report `NotConnected` for this variant.
    #[cfg(feature = "tls")]
    Closed,
}
55
56impl PgStream {
57 #[cfg(feature = "tls")]
58 fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59 match self {
60 PgStream::Plain(s) => s.read_exact(buf),
61 #[cfg(feature = "tls")]
62 PgStream::Tls(s) => s.read_exact(buf),
63 #[cfg(feature = "tls")]
64 PgStream::Closed => Err(std::io::Error::new(
65 std::io::ErrorKind::NotConnected,
66 "connection closed",
67 )),
68 }
69 }
70
71 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
72 match self {
73 PgStream::Plain(s) => s.read(buf),
74 #[cfg(feature = "tls")]
75 PgStream::Tls(s) => s.read(buf),
76 #[cfg(feature = "tls")]
77 PgStream::Closed => Err(std::io::Error::new(
78 std::io::ErrorKind::NotConnected,
79 "connection closed",
80 )),
81 }
82 }
83
84 fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
85 match self {
86 PgStream::Plain(s) => s.write_all(buf),
87 #[cfg(feature = "tls")]
88 PgStream::Tls(s) => s.write_all(buf),
89 #[cfg(feature = "tls")]
90 PgStream::Closed => Err(std::io::Error::new(
91 std::io::ErrorKind::NotConnected,
92 "connection closed",
93 )),
94 }
95 }
96
97 fn flush(&mut self) -> std::io::Result<()> {
98 match self {
99 PgStream::Plain(s) => s.flush(),
100 #[cfg(feature = "tls")]
101 PgStream::Tls(s) => s.flush(),
102 #[cfg(feature = "tls")]
103 PgStream::Closed => Err(std::io::Error::new(
104 std::io::ErrorKind::NotConnected,
105 "connection closed",
106 )),
107 }
108 }
109}
110
/// Lifecycle state tracked by a `PgConnection`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Set when a read returns zero bytes, i.e. the server closed the socket.
    Disconnected,
    /// Initial state once the TCP socket is open; covers optional SSL
    /// negotiation and sending the startup message.
    Connecting,
    /// Set after the startup message has been sent, while the authentication
    /// exchange and the post-auth startup messages are processed.
    Authenticating,
    /// Set when `ReadyForQuery` arrives; carries the transaction status the
    /// server reported with it.
    Ready(TransactionStatusState),
    /// NOTE(review): not assigned anywhere in this part of the file —
    /// presumably marks a query in flight; confirm against the query code.
    InQuery,
    /// NOTE(review): not assigned anywhere in this part of the file —
    /// presumably marks an open transaction; confirm against the query code.
    InTransaction(TransactionStatusState),
    /// Set after an I/O failure, a protocol decoding failure, or an
    /// `ErrorResponse` received during authentication/startup.
    Error,
    /// Set by `close()` once the `Terminate` message has been sent.
    Closed,
}
131
/// Transaction status as tracked on the connection; converted from the
/// protocol-level `TransactionStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TransactionStatusState {
    /// Counterpart of `TransactionStatus::Idle`; also the `Default` value.
    #[default]
    Idle,
    /// Counterpart of `TransactionStatus::Transaction`.
    InTransaction,
    /// Counterpart of `TransactionStatus::Error`.
    InFailed,
}
143
144impl From<TransactionStatus> for TransactionStatusState {
145 fn from(status: TransactionStatus) -> Self {
146 match status {
147 TransactionStatus::Idle => TransactionStatusState::Idle,
148 TransactionStatus::Transaction => TransactionStatusState::InTransaction,
149 TransactionStatus::Error => TransactionStatusState::InFailed,
150 }
151 }
152}
153
/// A blocking connection to a PostgreSQL server.
///
/// Created with `PgConnection::connect`; dropping the value attempts a
/// best-effort `close()`.
pub struct PgConnection {
    /// Underlying transport (plain TCP or TLS).
    stream: PgStream,
    /// Current lifecycle state.
    state: ConnectionState,
    /// Backend process id taken from `BackendKeyData` (0 until received).
    process_id: i32,
    /// Backend secret key taken from `BackendKeyData` (0 until received).
    secret_key: i32,
    /// Values collected from `ParameterStatus` messages during startup.
    parameters: HashMap<String, String>,
    /// Configuration this connection was opened with.
    config: PgConfig,
    /// Decoder that turns fed bytes into `BackendMessage`s.
    reader: MessageReader,
    /// Encoder that serialises `FrontendMessage`s.
    writer: MessageWriter,
    /// Scratch buffer for socket reads (allocated as 8192 bytes in `connect`).
    read_buf: Vec<u8>,
    /// Optional console handle used by the progress/diagnostic helpers.
    #[cfg(feature = "console")]
    console: Option<Arc<SqlModelConsole>>,
}
187
188impl std::fmt::Debug for PgConnection {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 f.debug_struct("PgConnection")
191 .field("state", &self.state)
192 .field("process_id", &self.process_id)
193 .field("host", &self.config.host)
194 .field("port", &self.config.port)
195 .field("database", &self.config.database)
196 .finish_non_exhaustive()
197 }
198}
199
200impl PgConnection {
201 #[allow(clippy::result_large_err)]
210 pub fn connect(config: PgConfig) -> Result<Self, Error> {
211 let stream = TcpStream::connect_timeout(
213 &config.socket_addr().parse().map_err(|e| {
214 Error::Connection(ConnectionError {
215 kind: ConnectionErrorKind::Connect,
216 message: format!("Invalid socket address: {}", e),
217 source: None,
218 })
219 })?,
220 config.connect_timeout,
221 )
222 .map_err(|e| {
223 let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
224 ConnectionErrorKind::Refused
225 } else {
226 ConnectionErrorKind::Connect
227 };
228 Error::Connection(ConnectionError {
229 kind,
230 message: format!("Failed to connect to {}: {}", config.socket_addr(), e),
231 source: Some(Box::new(e)),
232 })
233 })?;
234
235 stream.set_nodelay(true).ok();
237 stream.set_read_timeout(Some(config.connect_timeout)).ok();
238 stream.set_write_timeout(Some(config.connect_timeout)).ok();
239
240 let mut conn = Self {
241 stream: PgStream::Plain(stream),
242 state: ConnectionState::Connecting,
243 process_id: 0,
244 secret_key: 0,
245 parameters: HashMap::new(),
246 config,
247 reader: MessageReader::new(),
248 writer: MessageWriter::new(),
249 read_buf: vec![0u8; 8192],
250 #[cfg(feature = "console")]
251 console: None,
252 };
253
254 if conn.config.ssl_mode.should_try_ssl() {
256 #[cfg(feature = "tls")]
257 conn.negotiate_ssl()?;
258
259 #[cfg(not(feature = "tls"))]
260 if conn.config.ssl_mode != SslMode::Prefer {
261 return Err(Error::Connection(ConnectionError {
262 kind: ConnectionErrorKind::Ssl,
263 message:
264 "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'"
265 .to_string(),
266 source: None,
267 }));
268 }
269 }
270
271 conn.send_startup()?;
273 conn.state = ConnectionState::Authenticating;
274
275 conn.handle_auth()?;
277
278 conn.read_startup_messages()?;
280
281 Ok(conn)
282 }
283
    /// Returns the current lifecycle state of the connection.
    pub fn state(&self) -> ConnectionState {
        self.state
    }
288
    /// Returns `true` when the state is `ConnectionState::Ready`, regardless of
    /// the transaction status it carries.
    pub fn is_ready(&self) -> bool {
        matches!(self.state, ConnectionState::Ready(_))
    }
293
    /// Returns the backend process id received in `BackendKeyData`
    /// (0 if none was received).
    pub fn process_id(&self) -> i32 {
        self.process_id
    }
298
    /// Returns the backend secret key received in `BackendKeyData`
    /// (0 if none was received).
    pub fn secret_key(&self) -> i32 {
        self.secret_key
    }
303
304 pub fn parameter(&self, name: &str) -> Option<&str> {
306 self.parameters.get(name).map(|s| s.as_str())
307 }
308
    /// Returns every server parameter recorded from `ParameterStatus` messages.
    pub fn parameters(&self) -> &HashMap<String, String> {
        &self.parameters
    }
313
314 #[allow(clippy::result_large_err)]
316 pub fn close(&mut self) -> Result<(), Error> {
317 if matches!(
318 self.state,
319 ConnectionState::Closed | ConnectionState::Disconnected
320 ) {
321 return Ok(());
322 }
323
324 self.send_message(&FrontendMessage::Terminate)?;
326 self.state = ConnectionState::Closed;
327 Ok(())
328 }
329
330 #[allow(clippy::result_large_err)]
333 #[cfg(feature = "tls")]
334 fn negotiate_ssl(&mut self) -> Result<(), Error> {
335 self.send_message(&FrontendMessage::SSLRequest)?;
337
338 let mut buf = [0u8; 1];
340 self.stream.read_exact(&mut buf).map_err(|e| {
341 Error::Connection(ConnectionError {
342 kind: ConnectionErrorKind::Ssl,
343 message: format!("Failed to read SSL response: {}", e),
344 source: Some(Box::new(e)),
345 })
346 })?;
347
348 match buf[0] {
349 b'S' => {
350 #[cfg(feature = "tls")]
352 {
353 let plain = match std::mem::replace(&mut self.stream, PgStream::Closed) {
354 PgStream::Plain(s) => s,
355 other => {
356 self.stream = other;
357 return Err(Error::Connection(ConnectionError {
358 kind: ConnectionErrorKind::Ssl,
359 message: "TLS upgrade requires a plain TCP stream".to_string(),
360 source: None,
361 }));
362 }
363 };
364
365 let config = tls::build_client_config(self.config.ssl_mode)?;
366 let server_name = tls::server_name(&self.config.host)?;
367 let conn =
368 rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
369 .map_err(|e| {
370 Error::Connection(ConnectionError {
371 kind: ConnectionErrorKind::Ssl,
372 message: format!("Failed to create TLS connection: {e}"),
373 source: None,
374 })
375 })?;
376
377 let mut tls_stream = rustls::StreamOwned::new(conn, plain);
378 while tls_stream.conn.is_handshaking() {
379 tls_stream
380 .conn
381 .complete_io(&mut tls_stream.sock)
382 .map_err(|e| {
383 Error::Connection(ConnectionError {
384 kind: ConnectionErrorKind::Ssl,
385 message: format!("TLS handshake failed: {e}"),
386 source: Some(Box::new(e)),
387 })
388 })?;
389 }
390
391 self.stream = PgStream::Tls(tls_stream);
392 Ok(())
393 }
394
395 #[cfg(not(feature = "tls"))]
396 {
397 Err(Error::Connection(ConnectionError {
398 kind: ConnectionErrorKind::Ssl,
399 message:
400 "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'"
401 .to_string(),
402 source: None,
403 }))
404 }
405 }
406 b'N' => {
407 if self.config.ssl_mode.is_required() {
409 return Err(Error::Connection(ConnectionError {
410 kind: ConnectionErrorKind::Ssl,
411 message: "Server does not support SSL".to_string(),
412 source: None,
413 }));
414 }
415 Ok(())
417 }
418 _ => Err(Error::Connection(ConnectionError {
419 kind: ConnectionErrorKind::Ssl,
420 message: format!("Unexpected SSL response: 0x{:02x}", buf[0]),
421 source: None,
422 })),
423 }
424 }
425
426 #[allow(clippy::result_large_err)]
429 fn send_startup(&mut self) -> Result<(), Error> {
430 let params = self.config.startup_params();
431 let msg = FrontendMessage::Startup {
432 version: PROTOCOL_VERSION,
433 params,
434 };
435 self.send_message(&msg)
436 }
437
438 #[allow(clippy::result_large_err)]
441 fn require_auth_value(&self, message: &'static str) -> Result<&str, Error> {
442 self.config
444 .password
445 .as_deref()
446 .ok_or_else(|| auth_error(message))
447 }
448
449 #[allow(clippy::result_large_err)]
450 fn handle_auth(&mut self) -> Result<(), Error> {
451 loop {
452 let msg = self.receive_message()?;
453
454 match msg {
455 BackendMessage::AuthenticationOk => {
456 return Ok(());
457 }
458 BackendMessage::AuthenticationCleartextPassword => {
459 let auth_value =
460 self.require_auth_value("Authentication value required but not provided")?;
461 self.send_message(&FrontendMessage::PasswordMessage(auth_value.to_string()))?;
462 }
463 BackendMessage::AuthenticationMD5Password(salt) => {
464 let auth_value =
465 self.require_auth_value("Authentication value required but not provided")?;
466 let hash = md5_password(&self.config.user, auth_value, salt);
467 self.send_message(&FrontendMessage::PasswordMessage(hash))?;
468 }
469 BackendMessage::AuthenticationSASL(mechanisms) => {
470 if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
471 self.scram_auth()?;
472 } else {
473 return Err(auth_error(format!(
474 "Unsupported SASL mechanisms: {:?}",
475 mechanisms
476 )));
477 }
478 }
479 BackendMessage::ErrorResponse(e) => {
480 self.state = ConnectionState::Error;
481 return Err(error_from_fields(&e));
482 }
483 _ => {
484 return Err(Error::Protocol(ProtocolError {
485 message: format!("Unexpected message during auth: {:?}", msg),
486 raw_data: None,
487 source: None,
488 }));
489 }
490 }
491 }
492 }
493
494 #[allow(clippy::result_large_err)]
495 fn scram_auth(&mut self) -> Result<(), Error> {
496 let auth_value =
497 self.require_auth_value("Authentication value required for SCRAM-SHA-256")?;
498
499 let mut client = ScramClient::new(&self.config.user, auth_value);
500
501 let client_first = client.client_first();
503 self.send_message(&FrontendMessage::SASLInitialResponse {
504 mechanism: "SCRAM-SHA-256".to_string(),
505 data: client_first,
506 })?;
507
508 let msg = self.receive_message()?;
510 let server_first_data = match msg {
511 BackendMessage::AuthenticationSASLContinue(data) => data,
512 BackendMessage::ErrorResponse(e) => {
513 self.state = ConnectionState::Error;
514 return Err(error_from_fields(&e));
515 }
516 _ => {
517 return Err(Error::Protocol(ProtocolError {
518 message: format!("Expected SASL continue, got: {:?}", msg),
519 raw_data: None,
520 source: None,
521 }));
522 }
523 };
524
525 let client_final = client.process_server_first(&server_first_data)?;
527 self.send_message(&FrontendMessage::SASLResponse(client_final))?;
528
529 let msg = self.receive_message()?;
531 let server_final_data = match msg {
532 BackendMessage::AuthenticationSASLFinal(data) => data,
533 BackendMessage::ErrorResponse(e) => {
534 self.state = ConnectionState::Error;
535 return Err(error_from_fields(&e));
536 }
537 _ => {
538 return Err(Error::Protocol(ProtocolError {
539 message: format!("Expected SASL final, got: {:?}", msg),
540 raw_data: None,
541 source: None,
542 }));
543 }
544 };
545
546 client.verify_server_final(&server_final_data)?;
548
549 let msg = self.receive_message()?;
551 match msg {
552 BackendMessage::AuthenticationOk => Ok(()),
553 BackendMessage::ErrorResponse(e) => {
554 self.state = ConnectionState::Error;
555 Err(error_from_fields(&e))
556 }
557 _ => Err(Error::Protocol(ProtocolError {
558 message: format!("Expected AuthenticationOk, got: {:?}", msg),
559 raw_data: None,
560 source: None,
561 })),
562 }
563 }
564
565 #[allow(clippy::result_large_err)]
568 fn read_startup_messages(&mut self) -> Result<(), Error> {
569 loop {
570 let msg = self.receive_message()?;
571
572 match msg {
573 BackendMessage::BackendKeyData {
574 process_id,
575 secret_key,
576 } => {
577 self.process_id = process_id;
578 self.secret_key = secret_key;
579 }
580 BackendMessage::ParameterStatus { name, value } => {
581 self.parameters.insert(name, value);
582 }
583 BackendMessage::ReadyForQuery(status) => {
584 self.state = ConnectionState::Ready(status.into());
585 return Ok(());
586 }
587 BackendMessage::ErrorResponse(e) => {
588 self.state = ConnectionState::Error;
589 return Err(error_from_fields(&e));
590 }
591 BackendMessage::NoticeResponse(_notice) => {
592 }
594 _ => {
595 return Err(Error::Protocol(ProtocolError {
596 message: format!("Unexpected startup message: {:?}", msg),
597 raw_data: None,
598 source: None,
599 }));
600 }
601 }
602 }
603 }
604
605 #[allow(clippy::result_large_err)]
608 fn send_message(&mut self, msg: &FrontendMessage) -> Result<(), Error> {
609 let data = self.writer.write(msg);
610 self.stream.write_all(data).map_err(|e| {
611 self.state = ConnectionState::Error;
612 Error::Io(e)
613 })?;
614 self.stream.flush().map_err(|e| {
615 self.state = ConnectionState::Error;
616 Error::Io(e)
617 })?;
618 Ok(())
619 }
620
621 #[allow(clippy::result_large_err)]
622 fn receive_message(&mut self) -> Result<BackendMessage, Error> {
623 loop {
625 match self.reader.next_message() {
626 Ok(Some(msg)) => return Ok(msg),
627 Ok(None) => {
628 let n = self.stream.read(&mut self.read_buf).map_err(|e| {
630 if e.kind() == std::io::ErrorKind::TimedOut
631 || e.kind() == std::io::ErrorKind::WouldBlock
632 {
633 Error::Timeout
634 } else {
635 self.state = ConnectionState::Error;
636 Error::Connection(ConnectionError {
637 kind: ConnectionErrorKind::Disconnected,
638 message: format!("Failed to read from server: {}", e),
639 source: Some(Box::new(e)),
640 })
641 }
642 })?;
643
644 if n == 0 {
645 self.state = ConnectionState::Disconnected;
646 return Err(Error::Connection(ConnectionError {
647 kind: ConnectionErrorKind::Disconnected,
648 message: "Connection closed by server".to_string(),
649 source: None,
650 }));
651 }
652
653 self.reader.feed(&self.read_buf[..n]).map_err(|e| {
655 Error::Protocol(ProtocolError {
656 message: format!("Protocol error: {}", e),
657 raw_data: None,
658 source: None,
659 })
660 })?;
661 }
662 Err(e) => {
663 self.state = ConnectionState::Error;
664 return Err(Error::Protocol(ProtocolError {
665 message: format!("Protocol error: {}", e),
666 raw_data: None,
667 source: None,
668 }));
669 }
670 }
671 }
672 }
673}
674
impl Drop for PgConnection {
    /// Attempts a `close()` when the connection goes out of scope.
    fn drop(&mut self) {
        // Best effort: nothing useful can be done with a failure during drop,
        // so the result is deliberately discarded.
        let _ = self.close();
    }
}
681
#[cfg(feature = "console")]
/// Lets a `PgConnection` hold an optional shared console handle.
impl ConsoleAware for PgConnection {
    /// Replaces the stored console handle; `None` detaches it.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    /// Returns the stored console handle, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        self.console.as_ref()
    }

    /// Returns `true` when a console handle is stored.
    fn has_console(&self) -> bool {
        self.console.is_some()
    }
}
698
#[cfg(feature = "console")]
/// Stages of establishing a connection, used for console progress output.
///
/// Each stage has a human-readable label available through `description()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStage {
    /// Labelled "Resolving DNS".
    DnsResolve,
    /// Labelled "Connecting (TCP)".
    TcpConnect,
    /// Labelled "Negotiating SSL".
    SslNegotiate,
    /// Labelled "SSL established".
    SslEstablished,
    /// Labelled "Sending startup".
    Startup,
    /// Labelled "Authenticating".
    Authenticating,
    /// Labelled "Authenticated".
    Authenticated,
    /// Labelled "Ready".
    Ready,
}
720
#[cfg(feature = "console")]
impl ConnectionStage {
    /// Returns the fixed human-readable label for this stage.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            ConnectionStage::DnsResolve => "Resolving DNS",
            ConnectionStage::TcpConnect => "Connecting (TCP)",
            ConnectionStage::SslNegotiate => "Negotiating SSL",
            ConnectionStage::SslEstablished => "SSL established",
            ConnectionStage::Startup => "Sending startup",
            ConnectionStage::Authenticating => "Authenticating",
            ConnectionStage::Authenticated => "Authenticated",
            ConnectionStage::Ready => "Ready",
        }
    }
}
738
#[cfg(feature = "console")]
impl PgConnection {
    /// Writes a one-line progress message for `stage` to the attached console.
    ///
    /// The line is a status marker (`[OK]` when `success`, otherwise `[..] `)
    /// followed by a space and the stage description. Does nothing when no
    /// console is attached.
    pub fn emit_progress(&self, stage: ConnectionStage, success: bool) {
        if let Some(console) = self.console.as_ref() {
            let marker = if success { "[OK]" } else { "[..] " };
            console.info(&format!("{} {}", marker, stage.description()));
        }
    }

    /// Reports the successful connection on the attached console.
    ///
    /// The text is the same as `emit_connected_plain` produces. Does nothing
    /// when no console is attached.
    pub fn emit_connected(&self) {
        if let Some(console) = self.console.as_ref() {
            console.success(&self.emit_connected_plain());
        }
    }

    /// Builds the "connected" summary line without writing it anywhere.
    ///
    /// Uses the `server_version` parameter reported by the server, or
    /// `unknown` when it was not reported, plus the configured host and port.
    pub fn emit_connected_plain(&self) -> String {
        let version = match self.parameters.get("server_version") {
            Some(reported) => reported.as_str(),
            None => "unknown",
        };
        format!(
            "Connected to PostgreSQL {} at {}:{}",
            version, self.config.host, self.config.port
        )
    }
}
779
780fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
784 use std::fmt::Write;
785
786 let inner = format!("{}{}", password, user);
788 let inner_hash = md5::compute(inner.as_bytes());
789
790 let mut outer_input = format!("{:x}", inner_hash).into_bytes();
791 outer_input.extend_from_slice(&salt);
792 let outer_hash = md5::compute(&outer_input);
793
794 let mut result = String::with_capacity(35);
795 result.push_str("md5");
796 write!(&mut result, "{:x}", outer_hash).unwrap();
797 result
798}
799
800fn auth_error(msg: impl Into<String>) -> Error {
801 Error::Connection(ConnectionError {
802 kind: ConnectionErrorKind::Authentication,
803 message: msg.into(),
804 source: None,
805 })
806}
807
808fn error_from_fields(fields: &ErrorFields) -> Error {
809 let kind = match fields.code.get(..2) {
811 Some("08") => {
812 return Error::Connection(ConnectionError {
814 kind: ConnectionErrorKind::Connect,
815 message: fields.message.clone(),
816 source: None,
817 });
818 }
819 Some("28") => {
820 return Error::Connection(ConnectionError {
822 kind: ConnectionErrorKind::Authentication,
823 message: fields.message.clone(),
824 source: None,
825 });
826 }
827 Some("42") => QueryErrorKind::Syntax, Some("23") => QueryErrorKind::Constraint, Some("40") => {
830 if fields.code == "40001" {
831 QueryErrorKind::Serialization
832 } else {
833 QueryErrorKind::Deadlock
834 }
835 }
836 Some("57") => {
837 if fields.code == "57014" {
838 QueryErrorKind::Cancelled
839 } else {
840 QueryErrorKind::Timeout
841 }
842 }
843 _ => QueryErrorKind::Database,
844 };
845
846 Error::Query(QueryError {
847 kind,
848 sql: None,
849 sqlstate: Some(fields.code.clone()),
850 message: fields.message.clone(),
851 detail: fields.detail.clone(),
852 hint: fields.hint.clone(),
853 position: fields.position.map(|p| p as usize),
854 source: None,
855 })
856}
857
#[cfg(test)]
mod tests {
    use super::*;

    /// The MD5 response is the `md5` prefix followed by 32 hex digits.
    #[test]
    fn test_md5_password() {
        let response = md5_password("postgres", "mysecretpassword", *b"abcd");
        assert!(response.starts_with("md5"));
        assert_eq!(response.len(), 35);
    }

    /// Every protocol-level status maps onto its connection-level counterpart.
    #[test]
    fn test_transaction_status_conversion() {
        let idle: TransactionStatusState = TransactionStatus::Idle.into();
        assert_eq!(idle, TransactionStatusState::Idle);

        let in_tx: TransactionStatusState = TransactionStatus::Transaction.into();
        assert_eq!(in_tx, TransactionStatusState::InTransaction);

        let failed: TransactionStatusState = TransactionStatus::Error.into();
        assert_eq!(failed, TransactionStatusState::InFailed);
    }

    /// SQLSTATE classes are routed to the expected error variants.
    #[test]
    fn test_error_classification() {
        // Class 23: integrity constraint violation -> query error.
        let unique_violation = ErrorFields {
            severity: "ERROR".to_string(),
            code: "23505".to_string(),
            message: "unique violation".to_string(),
            ..Default::default()
        };
        assert!(matches!(
            error_from_fields(&unique_violation),
            Error::Query(q) if q.kind == QueryErrorKind::Constraint
        ));

        // Class 28: invalid authorization -> connection error.
        let bad_password = ErrorFields {
            severity: "FATAL".to_string(),
            code: "28P01".to_string(),
            message: "password authentication failed".to_string(),
            ..Default::default()
        };
        assert!(matches!(
            error_from_fields(&bad_password),
            Error::Connection(c) if c.kind == ConnectionErrorKind::Authentication
        ));
    }
}