Skip to main content

sqlmodel_postgres/
async_connection.rs

1//! Async PostgreSQL connection implementation.
2//!
3//! This module implements an async PostgreSQL connection using asupersync's TCP
4//! primitives. It provides a shared wrapper that implements `sqlmodel-core`'s
5//! [`Connection`] trait.
6//!
7//! The implementation currently focuses on:
8//! - Async connect + authentication (cleartext, MD5, SCRAM-SHA-256)
9//! - Extended query protocol for parameterized queries
10//! - Row decoding via the postgres type registry (OID + text/binary format)
11//! - Basic transaction support (BEGIN/COMMIT/ROLLBACK + savepoints)
12
13// Allow `impl Future` return types in trait methods - intentional for async trait compat
14#![allow(clippy::manual_async_fn)]
15// The Error type is intentionally large to carry full context
16#![allow(clippy::result_large_err)]
17
18use std::collections::HashMap;
19use std::future::Future;
20#[cfg(feature = "tls")]
21use std::io::{Read, Write};
22use std::sync::Arc;
23
24use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
25use asupersync::net::TcpStream;
26use asupersync::sync::Mutex;
27use asupersync::{Cx, Outcome};
28
29use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
30use sqlmodel_core::error::{
31    ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
32};
33use sqlmodel_core::row::ColumnInfo;
34use sqlmodel_core::{Error, Row, Value};
35
36use crate::auth::ScramClient;
37use crate::config::{PgConfig, SslMode};
38use crate::connection::{ConnectionState, TransactionStatusState};
39use crate::protocol::{
40    BackendMessage, DescribeKind, ErrorFields, FrontendMessage, MessageReader, MessageWriter,
41    PROTOCOL_VERSION,
42};
43use crate::types::{Format, decode_value, encode_value};
44
45#[cfg(feature = "tls")]
46use crate::tls;
47
/// Transport underneath a [`PgAsyncConnection`]: a plain TCP socket or, when the crate is
/// built with the `tls` feature, a rustls session layered over one.
enum PgAsyncStream {
    /// Unencrypted TCP stream.
    Plain(TcpStream),
    /// TLS session driven manually over a TCP stream.
    #[cfg(feature = "tls")]
    Tls(AsyncTlsStream),
    /// No usable transport; every I/O method fails with `ErrorKind::NotConnected`.
    // NOTE(review): presumably a placeholder while the TCP stream is moved out during the TLS
    // upgrade in `negotiate_ssl` — confirm against that method.
    #[cfg(feature = "tls")]
    Closed,
}
55
56impl PgAsyncStream {
57    #[cfg(feature = "tls")]
58    async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59        match self {
60            PgAsyncStream::Plain(s) => read_exact_plain_async(s, buf).await,
61            #[cfg(feature = "tls")]
62            PgAsyncStream::Tls(s) => s.read_exact(buf).await,
63            #[cfg(feature = "tls")]
64            PgAsyncStream::Closed => Err(std::io::Error::new(
65                std::io::ErrorKind::NotConnected,
66                "connection closed",
67            )),
68        }
69    }
70
71    async fn read_some(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
72        match self {
73            PgAsyncStream::Plain(s) => read_some_plain_async(s, buf).await,
74            #[cfg(feature = "tls")]
75            PgAsyncStream::Tls(s) => s.read_plain(buf).await,
76            #[cfg(feature = "tls")]
77            PgAsyncStream::Closed => Err(std::io::Error::new(
78                std::io::ErrorKind::NotConnected,
79                "connection closed",
80            )),
81        }
82    }
83
84    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
85        match self {
86            PgAsyncStream::Plain(s) => write_all_plain_async(s, buf).await,
87            #[cfg(feature = "tls")]
88            PgAsyncStream::Tls(s) => s.write_all(buf).await,
89            #[cfg(feature = "tls")]
90            PgAsyncStream::Closed => Err(std::io::Error::new(
91                std::io::ErrorKind::NotConnected,
92                "connection closed",
93            )),
94        }
95    }
96
97    async fn flush(&mut self) -> std::io::Result<()> {
98        match self {
99            PgAsyncStream::Plain(s) => flush_plain_async(s).await,
100            #[cfg(feature = "tls")]
101            PgAsyncStream::Tls(s) => s.flush().await,
102            #[cfg(feature = "tls")]
103            PgAsyncStream::Closed => Err(std::io::Error::new(
104                std::io::ErrorKind::NotConnected,
105                "connection closed",
106            )),
107        }
108    }
109}
110
/// A rustls client session driven by hand over an asupersync TCP stream: ciphertext is
/// moved between `tcp` and `tls` explicitly by the methods on this type.
#[cfg(feature = "tls")]
struct AsyncTlsStream {
    /// Socket carrying the encrypted TLS records.
    tcp: TcpStream,
    /// rustls client state machine (encrypts outgoing data, decrypts incoming records).
    tls: rustls::ClientConnection,
}
116
#[cfg(feature = "tls")]
impl AsyncTlsStream {
    /// Run a TLS client handshake over `tcp` and return the established session.
    ///
    /// `ssl_mode` is passed to `tls::build_client_config` and `host` to
    /// `tls::server_name`. Handshake bytes that rustls still has queued when the handshake
    /// completes are written out before returning.
    ///
    /// # Errors
    /// Returns a connection error when:
    /// - the TLS configuration or server name is invalid;
    /// - the socket fails, or the peer closes it mid-handshake;
    /// - the handshake stalls;
    /// - rustls rejects the peer's messages.
    async fn handshake(mut tcp: TcpStream, ssl_mode: SslMode, host: &str) -> Result<Self, Error> {
        let config = tls::build_client_config(ssl_mode)?;
        let server_name = tls::server_name(host)?;
        let mut tls = rustls::ClientConnection::new(Arc::new(config), server_name)
            .map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;

        while tls.is_handshaking() {
            while tls.wants_write() {
                let mut out = Vec::new();
                tls.write_tls(&mut out)
                    .map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
                if !out.is_empty() {
                    write_all_plain_async(&mut tcp, &out)
                        .await
                        .map_err(|e| Self::disconnected("TLS handshake write error", e))?;
                }
            }

            if !tls.wants_read() {
                // Still handshaking, with nothing to send and no input wanted: another
                // iteration would spin forever.
                return Err(connection_error("TLS handshake stalled"));
            }

            let mut buf = [0u8; 8192];
            let n = read_some_plain_async(&mut tcp, &mut buf)
                .await
                .map_err(|e| Self::disconnected("TLS handshake read error", e))?;
            if n == 0 {
                return Err(connection_error("Connection closed during TLS handshake"));
            }

            // `read_tls` may consume only part of its input per call (rustls reads in
            // bounded chunks). Keep feeding until every received byte is inside the TLS
            // engine; dropping the remainder would corrupt the record stream.
            let mut pending = &buf[..n];
            while !pending.is_empty() {
                let consumed = tls.read_tls(&mut pending).map_err(|e| {
                    connection_error(format!("TLS handshake read_tls error: {e}"))
                })?;
                if consumed == 0 {
                    return Err(connection_error(
                        "TLS handshake read_tls error: no progress",
                    ));
                }
                tls.process_new_packets()
                    .map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
            }
        }

        let mut stream = Self { tcp, tls };
        // The final handshake flight (e.g. the client Finished) may still be queued after
        // `is_handshaking()` turns false; send it now instead of on the first write.
        stream
            .flush()
            .await
            .map_err(|e| Self::disconnected("TLS handshake write error", e))?;
        Ok(stream)
    }

    /// Wrap a socket failure seen during the handshake as a `Disconnected` connection error.
    fn disconnected(context: &str, e: std::io::Error) -> Error {
        Error::Connection(ConnectionError {
            kind: ConnectionErrorKind::Disconnected,
            message: format!("{context}: {e}"),
            source: Some(Box::new(e)),
        })
    }

    /// Fill `buf` with decrypted bytes, failing with `UnexpectedEof` if the session ends first.
    async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            let n = self.read_plain(&mut buf[read..]).await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            read += n;
        }
        Ok(())
    }

    /// Return some decrypted bytes in `out`, pulling more ciphertext from the socket as
    /// needed.
    ///
    /// Returns `Ok(0)` when the socket reaches EOF or rustls wants no more input.
    async fn read_plain(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
        loop {
            match self.tls.reader().read(out) {
                Ok(n) if n > 0 => return Ok(n),
                Ok(_) => {}
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.tls.wants_read() {
                return Ok(0);
            }

            let mut enc = [0u8; 8192];
            let n = read_some_plain_async(&mut self.tcp, &mut enc).await?;
            if n == 0 {
                return Ok(0);
            }

            // Feed every received byte: `read_tls` may accept only a prefix per call.
            let mut pending = &enc[..n];
            while !pending.is_empty() {
                if self.tls.read_tls(&mut pending)? == 0 {
                    return Err(std::io::Error::other(
                        "TLS error: read_tls made no progress",
                    ));
                }
                self.tls
                    .process_new_packets()
                    .map_err(|e| std::io::Error::other(format!("TLS error: {e}")))?;
            }
        }
    }

    /// Encrypt all of `buf` and send it, flushing after each accepted chunk.
    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut written = 0;
        while written < buf.len() {
            let n = self.tls.writer().write(&buf[written..])?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "TLS write zero",
                ));
            }
            written += n;
            self.flush().await?;
        }
        Ok(())
    }

    /// Send every TLS record rustls has queued, then flush the TCP stream.
    async fn flush(&mut self) -> std::io::Result<()> {
        self.tls.writer().flush()?;
        while self.tls.wants_write() {
            let mut out = Vec::new();
            self.tls.write_tls(&mut out)?;
            if !out.is_empty() {
                write_all_plain_async(&mut self.tcp, &out).await?;
            }
        }
        flush_plain_async(&mut self.tcp).await
    }
}
237
/// Read from `stream` until `buf` is full.
///
/// # Errors
/// `UnexpectedEof` if the peer closes before `buf` is filled; other I/O errors propagate.
#[cfg(feature = "tls")]
async fn read_exact_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<()> {
    let mut filled = 0usize;
    while filled < buf.len() {
        match read_some_plain_async(stream, &mut buf[filled..]).await? {
            0 => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            got => filled += got,
        }
    }
    Ok(())
}
253
254async fn read_some_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<usize> {
255    let mut read_buf = ReadBuf::new(buf);
256    std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
257        .await?;
258    Ok(read_buf.filled().len())
259}
260
261async fn write_all_plain_async(stream: &mut TcpStream, buf: &[u8]) -> std::io::Result<()> {
262    let mut written = 0;
263    while written < buf.len() {
264        let n = std::future::poll_fn(|cx| {
265            std::pin::Pin::new(&mut *stream).poll_write(cx, &buf[written..])
266        })
267        .await?;
268        if n == 0 {
269            return Err(std::io::Error::new(
270                std::io::ErrorKind::WriteZero,
271                "connection closed",
272            ));
273        }
274        written += n;
275    }
276    Ok(())
277}
278
279async fn flush_plain_async(stream: &mut TcpStream) -> std::io::Result<()> {
280    std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
281}
282
/// Async PostgreSQL connection.
///
/// This connection uses asupersync's TCP stream for non-blocking I/O (optionally wrapped in
/// TLS) and supports the extended query protocol for parameter binding.
pub struct PgAsyncConnection {
    /// Transport: plain TCP or, with the `tls` feature, a TLS session.
    stream: PgAsyncStream,
    /// Protocol state; after each request it is updated from the `ReadyForQuery` status.
    state: ConnectionState,
    /// Backend process id (0 until set).
    // NOTE(review): presumably filled from BackendKeyData during startup — confirm in
    // `read_startup_messages`.
    process_id: i32,
    /// Cancellation secret paired with `process_id` (0 until set).
    secret_key: i32,
    /// Server-reported parameters.
    // NOTE(review): presumably populated from ParameterStatus messages — confirm.
    parameters: HashMap<String, String>,
    /// Id handed to the next `prepare_async` call (starts at 1).
    next_prepared_id: u64,
    /// Metadata of statements prepared on this connection, keyed by statement id.
    prepared: HashMap<u64, PgPreparedMeta>,
    /// Configuration this connection was opened with.
    config: PgConfig,
    /// Incremental decoder for backend messages.
    reader: MessageReader,
    /// Encoder for frontend messages.
    writer: MessageWriter,
    /// Scratch buffer for socket reads (allocated with 8192 bytes at connect time).
    read_buf: Vec<u8>,
}
300
/// Client-side bookkeeping for a statement created by `prepare_async`.
#[derive(Debug, Clone)]
struct PgPreparedMeta {
    /// Server-side statement name (`sqlmodel_stmt_<id>`).
    name: String,
    /// Parameter type OIDs from the server's ParameterDescription. Its length is the number
    /// of parameters callers must supply; 0 entries mean "unspecified".
    param_type_oids: Vec<u32>,
}
306
307impl std::fmt::Debug for PgAsyncConnection {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        f.debug_struct("PgAsyncConnection")
310            .field("state", &self.state)
311            .field("process_id", &self.process_id)
312            .field("host", &self.config.host)
313            .field("port", &self.config.port)
314            .field("database", &self.config.database)
315            .finish_non_exhaustive()
316    }
317}
318
319impl PgAsyncConnection {
320    /// Establish a new async connection to the PostgreSQL server.
321    pub async fn connect(_cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
322        let addr = config.socket_addr();
323        let socket_addr = match addr.parse() {
324            Ok(a) => a,
325            Err(e) => {
326                return Outcome::Err(Error::Connection(ConnectionError {
327                    kind: ConnectionErrorKind::Connect,
328                    message: format!("Invalid socket address: {}", e),
329                    source: None,
330                }));
331            }
332        };
333
334        let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
335            Ok(s) => s,
336            Err(e) => {
337                let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
338                    ConnectionErrorKind::Refused
339                } else {
340                    ConnectionErrorKind::Connect
341                };
342                return Outcome::Err(Error::Connection(ConnectionError {
343                    kind,
344                    message: format!("Failed to connect to {}: {}", addr, e),
345                    source: Some(Box::new(e)),
346                }));
347            }
348        };
349
350        stream.set_nodelay(true).ok();
351
352        let mut conn = Self {
353            stream: PgAsyncStream::Plain(stream),
354            state: ConnectionState::Connecting,
355            process_id: 0,
356            secret_key: 0,
357            parameters: HashMap::new(),
358            next_prepared_id: 1,
359            prepared: HashMap::new(),
360            config,
361            reader: MessageReader::new(),
362            writer: MessageWriter::new(),
363            read_buf: vec![0u8; 8192],
364        };
365
366        // SSL negotiation (feature-gated TLS)
367        if conn.config.ssl_mode.should_try_ssl() {
368            #[cfg(feature = "tls")]
369            match conn.negotiate_ssl().await {
370                Outcome::Ok(()) => {}
371                Outcome::Err(e) => return Outcome::Err(e),
372                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
373                Outcome::Panicked(p) => return Outcome::Panicked(p),
374            }
375
376            #[cfg(not(feature = "tls"))]
377            if conn.config.ssl_mode != SslMode::Prefer {
378                return Outcome::Err(connection_error(
379                    "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
380                ));
381            }
382        }
383
384        // Startup + authentication
385        if let Outcome::Err(e) = conn.send_startup().await {
386            return Outcome::Err(e);
387        }
388        conn.state = ConnectionState::Authenticating;
389
390        match conn.handle_auth().await {
391            Outcome::Ok(()) => {}
392            Outcome::Err(e) => return Outcome::Err(e),
393            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
394            Outcome::Panicked(p) => return Outcome::Panicked(p),
395        }
396
397        match conn.read_startup_messages().await {
398            Outcome::Ok(()) => Outcome::Ok(conn),
399            Outcome::Err(e) => Outcome::Err(e),
400            Outcome::Cancelled(r) => Outcome::Cancelled(r),
401            Outcome::Panicked(p) => Outcome::Panicked(p),
402        }
403    }
404
405    /// Run a parameterized query and return all rows.
406    pub async fn query_async(
407        &mut self,
408        cx: &Cx,
409        sql: &str,
410        params: &[Value],
411    ) -> Outcome<Vec<Row>, Error> {
412        match self.run_extended(cx, sql, params).await {
413            Outcome::Ok(result) => Outcome::Ok(result.rows),
414            Outcome::Err(e) => Outcome::Err(e),
415            Outcome::Cancelled(r) => Outcome::Cancelled(r),
416            Outcome::Panicked(p) => Outcome::Panicked(p),
417        }
418    }
419
420    /// Execute a statement and return rows affected.
421    pub async fn execute_async(
422        &mut self,
423        cx: &Cx,
424        sql: &str,
425        params: &[Value],
426    ) -> Outcome<u64, Error> {
427        match self.run_extended(cx, sql, params).await {
428            Outcome::Ok(result) => {
429                Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
430            }
431            Outcome::Err(e) => Outcome::Err(e),
432            Outcome::Cancelled(r) => Outcome::Cancelled(r),
433            Outcome::Panicked(p) => Outcome::Panicked(p),
434        }
435    }
436
437    /// Execute an INSERT and return the inserted id.
438    ///
439    /// PostgreSQL requires `RETURNING` to retrieve generated IDs. This method
440    /// expects the SQL to return a single-row, single-column result set
441    /// containing an integer id.
442    pub async fn insert_async(
443        &mut self,
444        cx: &Cx,
445        sql: &str,
446        params: &[Value],
447    ) -> Outcome<i64, Error> {
448        let result = match self.run_extended(cx, sql, params).await {
449            Outcome::Ok(r) => r,
450            Outcome::Err(e) => return Outcome::Err(e),
451            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
452            Outcome::Panicked(p) => return Outcome::Panicked(p),
453        };
454
455        let Some(row) = result.rows.first() else {
456            return Outcome::Err(query_error_msg(
457                "INSERT did not return an id; add `RETURNING id`",
458                QueryErrorKind::Database,
459            ));
460        };
461        let Some(id_value) = row.get(0) else {
462            return Outcome::Err(query_error_msg(
463                "INSERT result row missing id column",
464                QueryErrorKind::Database,
465            ));
466        };
467        match id_value.as_i64() {
468            Some(v) => Outcome::Ok(v),
469            None => Outcome::Err(query_error_msg(
470                "INSERT returned non-integer id",
471                QueryErrorKind::Database,
472            )),
473        }
474    }
475
476    /// Ping the server.
477    pub async fn ping_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
478        self.execute_async(cx, "SELECT 1", &[]).await.map(|_| ())
479    }
480
481    /// Close the connection.
482    pub async fn close_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
483        // Best-effort terminate. If this fails, the drop will close the socket.
484        //
485        // Note: server-side prepared statements are released when the connection terminates;
486        // explicit Close/DEALLOCATE is not required for correctness here.
487        let _ = self.send_message(cx, &FrontendMessage::Terminate).await;
488        self.state = ConnectionState::Closed;
489        Outcome::Ok(())
490    }
491
492    // ==================== Prepared statements ====================
493
494    /// Prepare a server-side statement and return a reusable handle.
495    pub async fn prepare_async(&mut self, cx: &Cx, sql: &str) -> Outcome<PreparedStatement, Error> {
496        let stmt_id = self.next_prepared_id;
497        self.next_prepared_id = self.next_prepared_id.saturating_add(1);
498        let stmt_name = format!("sqlmodel_stmt_{stmt_id}");
499
500        if let Outcome::Err(e) = self
501            .send_message(
502                cx,
503                &FrontendMessage::Parse {
504                    name: stmt_name.clone(),
505                    query: sql.to_string(),
506                    // Let PostgreSQL infer types where possible; ambiguous queries will error
507                    // and should add explicit casts.
508                    param_types: Vec::new(),
509                },
510            )
511            .await
512        {
513            return Outcome::Err(e);
514        }
515
516        if let Outcome::Err(e) = self
517            .send_message(
518                cx,
519                &FrontendMessage::Describe {
520                    kind: DescribeKind::Statement,
521                    name: stmt_name.clone(),
522                },
523            )
524            .await
525        {
526            return Outcome::Err(e);
527        }
528
529        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
530            return Outcome::Err(e);
531        }
532
533        let mut param_type_oids: Option<Vec<u32>> = None;
534        let mut columns: Option<Vec<String>> = None;
535
536        loop {
537            let msg = match self.receive_message(cx).await {
538                Outcome::Ok(m) => m,
539                Outcome::Err(e) => return Outcome::Err(e),
540                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
541                Outcome::Panicked(p) => return Outcome::Panicked(p),
542            };
543
544            match msg {
545                BackendMessage::ParseComplete
546                | BackendMessage::BindComplete
547                | BackendMessage::CloseComplete
548                | BackendMessage::NoData
549                | BackendMessage::EmptyQueryResponse => {}
550                BackendMessage::ParameterDescription(oids) => {
551                    param_type_oids = Some(oids);
552                }
553                BackendMessage::RowDescription(desc) => {
554                    columns = Some(desc.iter().map(|f| f.name.clone()).collect());
555                }
556                BackendMessage::ReadyForQuery(status) => {
557                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
558                    break;
559                }
560                BackendMessage::ErrorResponse(e) => {
561                    self.state = ConnectionState::Error;
562                    return Outcome::Err(error_from_fields(&e));
563                }
564                BackendMessage::NoticeResponse(_notice) => {}
565                other => {
566                    return Outcome::Err(protocol_error(format!(
567                        "Unexpected message during prepare: {other:?}"
568                    )));
569                }
570            }
571        }
572
573        let param_type_oids = param_type_oids.unwrap_or_default();
574        self.prepared.insert(
575            stmt_id,
576            PgPreparedMeta {
577                name: stmt_name,
578                param_type_oids: param_type_oids.clone(),
579            },
580        );
581
582        match columns {
583            Some(cols) => Outcome::Ok(PreparedStatement::with_columns(
584                stmt_id,
585                sql.to_string(),
586                param_type_oids.len(),
587                cols,
588            )),
589            None => Outcome::Ok(PreparedStatement::new(
590                stmt_id,
591                sql.to_string(),
592                param_type_oids.len(),
593            )),
594        }
595    }
596
597    pub async fn query_prepared_async(
598        &mut self,
599        cx: &Cx,
600        stmt: &PreparedStatement,
601        params: &[Value],
602    ) -> Outcome<Vec<Row>, Error> {
603        let meta = match self.prepared.get(&stmt.id()) {
604            Some(m) => m.clone(),
605            None => {
606                return Outcome::Err(query_error_msg(
607                    format!("Unknown prepared statement id {}", stmt.id()),
608                    QueryErrorKind::Database,
609                ));
610            }
611        };
612
613        if meta.param_type_oids.len() != params.len() {
614            return Outcome::Err(query_error_msg(
615                format!(
616                    "Prepared statement expects {} params, got {}",
617                    meta.param_type_oids.len(),
618                    params.len()
619                ),
620                QueryErrorKind::Database,
621            ));
622        }
623
624        match self.run_prepared(cx, &meta, params).await {
625            Outcome::Ok(result) => Outcome::Ok(result.rows),
626            Outcome::Err(e) => Outcome::Err(e),
627            Outcome::Cancelled(r) => Outcome::Cancelled(r),
628            Outcome::Panicked(p) => Outcome::Panicked(p),
629        }
630    }
631
632    pub async fn execute_prepared_async(
633        &mut self,
634        cx: &Cx,
635        stmt: &PreparedStatement,
636        params: &[Value],
637    ) -> Outcome<u64, Error> {
638        let meta = match self.prepared.get(&stmt.id()) {
639            Some(m) => m.clone(),
640            None => {
641                return Outcome::Err(query_error_msg(
642                    format!("Unknown prepared statement id {}", stmt.id()),
643                    QueryErrorKind::Database,
644                ));
645            }
646        };
647
648        if meta.param_type_oids.len() != params.len() {
649            return Outcome::Err(query_error_msg(
650                format!(
651                    "Prepared statement expects {} params, got {}",
652                    meta.param_type_oids.len(),
653                    params.len()
654                ),
655                QueryErrorKind::Database,
656            ));
657        }
658
659        match self.run_prepared(cx, &meta, params).await {
660            Outcome::Ok(result) => {
661                Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
662            }
663            Outcome::Err(e) => Outcome::Err(e),
664            Outcome::Cancelled(r) => Outcome::Cancelled(r),
665            Outcome::Panicked(p) => Outcome::Panicked(p),
666        }
667    }
668
669    // ==================== Protocol: extended query ====================
670
671    async fn read_extended_result(&mut self, cx: &Cx) -> Outcome<PgQueryResult, Error> {
672        // Read responses until ReadyForQuery
673        let mut field_descs: Option<Vec<crate::protocol::FieldDescription>> = None;
674        let mut columns: Option<Arc<ColumnInfo>> = None;
675        let mut rows: Vec<Row> = Vec::new();
676        let mut command_tag: Option<String> = None;
677
678        loop {
679            let msg = match self.receive_message(cx).await {
680                Outcome::Ok(m) => m,
681                Outcome::Err(e) => return Outcome::Err(e),
682                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
683                Outcome::Panicked(p) => return Outcome::Panicked(p),
684            };
685
686            match msg {
687                BackendMessage::ParseComplete
688                | BackendMessage::BindComplete
689                | BackendMessage::CloseComplete
690                | BackendMessage::ParameterDescription(_)
691                | BackendMessage::NoData
692                | BackendMessage::PortalSuspended
693                | BackendMessage::EmptyQueryResponse => {}
694                BackendMessage::RowDescription(desc) => {
695                    let names: Vec<String> = desc.iter().map(|f| f.name.clone()).collect();
696                    columns = Some(Arc::new(ColumnInfo::new(names)));
697                    field_descs = Some(desc);
698                }
699                BackendMessage::DataRow(raw_values) => {
700                    let Some(ref desc) = field_descs else {
701                        return Outcome::Err(protocol_error(
702                            "DataRow received before RowDescription",
703                        ));
704                    };
705                    let Some(ref cols) = columns else {
706                        return Outcome::Err(protocol_error("Row column metadata missing"));
707                    };
708                    if raw_values.len() != desc.len() {
709                        return Outcome::Err(protocol_error("DataRow field count mismatch"));
710                    }
711
712                    let mut values = Vec::with_capacity(raw_values.len());
713                    for (i, raw) in raw_values.into_iter().enumerate() {
714                        match raw {
715                            None => values.push(Value::Null),
716                            Some(bytes) => {
717                                let field = &desc[i];
718                                let format = Format::from_code(field.format);
719                                let decoded = match decode_value(
720                                    field.type_oid,
721                                    Some(bytes.as_slice()),
722                                    format,
723                                ) {
724                                    Ok(v) => v,
725                                    Err(e) => return Outcome::Err(e),
726                                };
727                                values.push(decoded);
728                            }
729                        }
730                    }
731                    rows.push(Row::with_columns(Arc::clone(cols), values));
732                }
733                BackendMessage::CommandComplete(tag) => {
734                    command_tag = Some(tag);
735                }
736                BackendMessage::ReadyForQuery(status) => {
737                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
738                    break;
739                }
740                BackendMessage::ErrorResponse(e) => {
741                    self.state = ConnectionState::Error;
742                    return Outcome::Err(error_from_fields(&e));
743                }
744                BackendMessage::NoticeResponse(_notice) => {}
745                _ => {}
746            }
747        }
748
749        Outcome::Ok(PgQueryResult { rows, command_tag })
750    }
751
752    async fn run_extended(
753        &mut self,
754        cx: &Cx,
755        sql: &str,
756        params: &[Value],
757    ) -> Outcome<PgQueryResult, Error> {
758        // Encode parameters
759        let mut param_types = Vec::with_capacity(params.len());
760        let mut param_values = Vec::with_capacity(params.len());
761
762        for v in params {
763            if matches!(v, Value::Null) {
764                param_types.push(0);
765                param_values.push(None);
766                continue;
767            }
768            match encode_value(v, Format::Text) {
769                Ok((bytes, oid)) => {
770                    param_types.push(oid);
771                    param_values.push(Some(bytes));
772                }
773                Err(e) => return Outcome::Err(e),
774            }
775        }
776
777        // Parse + bind unnamed statement/portal
778        if let Outcome::Err(e) = self
779            .send_message(
780                cx,
781                &FrontendMessage::Parse {
782                    name: String::new(),
783                    query: sql.to_string(),
784                    param_types,
785                },
786            )
787            .await
788        {
789            return Outcome::Err(e);
790        }
791
792        let param_formats = if params.is_empty() {
793            Vec::new()
794        } else {
795            vec![Format::Text.code()]
796        };
797        if let Outcome::Err(e) = self
798            .send_message(
799                cx,
800                &FrontendMessage::Bind {
801                    portal: String::new(),
802                    statement: String::new(),
803                    param_formats,
804                    params: param_values,
805                    // Default result formats (text) when empty.
806                    result_formats: Vec::new(),
807                },
808            )
809            .await
810        {
811            return Outcome::Err(e);
812        }
813
814        if let Outcome::Err(e) = self
815            .send_message(
816                cx,
817                &FrontendMessage::Describe {
818                    kind: DescribeKind::Portal,
819                    name: String::new(),
820                },
821            )
822            .await
823        {
824            return Outcome::Err(e);
825        }
826
827        if let Outcome::Err(e) = self
828            .send_message(
829                cx,
830                &FrontendMessage::Execute {
831                    portal: String::new(),
832                    max_rows: 0,
833                },
834            )
835            .await
836        {
837            return Outcome::Err(e);
838        }
839
840        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
841            return Outcome::Err(e);
842        }
843        self.read_extended_result(cx).await
844    }
845
846    async fn run_prepared(
847        &mut self,
848        cx: &Cx,
849        meta: &PgPreparedMeta,
850        params: &[Value],
851    ) -> Outcome<PgQueryResult, Error> {
852        let mut param_values = Vec::with_capacity(params.len());
853
854        for (i, v) in params.iter().enumerate() {
855            if matches!(v, Value::Null) {
856                param_values.push(None);
857                continue;
858            }
859            match encode_value(v, Format::Text) {
860                Ok((bytes, oid)) => {
861                    let expected = meta.param_type_oids.get(i).copied().unwrap_or(0);
862                    if expected != 0 && expected != oid {
863                        return Outcome::Err(query_error_msg(
864                            format!(
865                                "Prepared statement param {} expects type OID {}, got {}",
866                                i + 1,
867                                expected,
868                                oid
869                            ),
870                            QueryErrorKind::Database,
871                        ));
872                    }
873                    param_values.push(Some(bytes));
874                }
875                Err(e) => return Outcome::Err(e),
876            }
877        }
878
879        let param_formats = if params.is_empty() {
880            Vec::new()
881        } else {
882            vec![Format::Text.code()]
883        };
884
885        if let Outcome::Err(e) = self
886            .send_message(
887                cx,
888                &FrontendMessage::Bind {
889                    portal: String::new(),
890                    statement: meta.name.clone(),
891                    param_formats,
892                    params: param_values,
893                    result_formats: Vec::new(),
894                },
895            )
896            .await
897        {
898            return Outcome::Err(e);
899        }
900
901        if let Outcome::Err(e) = self
902            .send_message(
903                cx,
904                &FrontendMessage::Describe {
905                    kind: DescribeKind::Portal,
906                    name: String::new(),
907                },
908            )
909            .await
910        {
911            return Outcome::Err(e);
912        }
913
914        if let Outcome::Err(e) = self
915            .send_message(
916                cx,
917                &FrontendMessage::Execute {
918                    portal: String::new(),
919                    max_rows: 0,
920                },
921            )
922            .await
923        {
924            return Outcome::Err(e);
925        }
926
927        if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
928            return Outcome::Err(e);
929        }
930
931        self.read_extended_result(cx).await
932    }
933
934    // ==================== Startup + auth ====================
935
936    #[cfg(feature = "tls")]
937    async fn negotiate_ssl(&mut self) -> Outcome<(), Error> {
938        // Send SSL request
939        if let Outcome::Err(e) = self.send_message_no_cx(&FrontendMessage::SSLRequest).await {
940            return Outcome::Err(e);
941        }
942
943        // Read single-byte response
944        let mut buf = [0u8; 1];
945        if let Err(e) = self.stream.read_exact(&mut buf).await {
946            return Outcome::Err(Error::Connection(ConnectionError {
947                kind: ConnectionErrorKind::Ssl,
948                message: format!("Failed to read SSL response: {}", e),
949                source: Some(Box::new(e)),
950            }));
951        }
952
953        match buf[0] {
954            b'S' => {
955                #[cfg(feature = "tls")]
956                {
957                    let plain = match std::mem::replace(&mut self.stream, PgAsyncStream::Closed) {
958                        PgAsyncStream::Plain(s) => s,
959                        other => {
960                            self.stream = other;
961                            return Outcome::Err(connection_error(
962                                "TLS upgrade requires a plain TCP stream",
963                            ));
964                        }
965                    };
966
967                    let tls_stream = match AsyncTlsStream::handshake(
968                        plain,
969                        self.config.ssl_mode,
970                        &self.config.host,
971                    )
972                    .await
973                    {
974                        Ok(s) => s,
975                        Err(e) => return Outcome::Err(e),
976                    };
977
978                    self.stream = PgAsyncStream::Tls(tls_stream);
979                    Outcome::Ok(())
980                }
981
982                #[cfg(not(feature = "tls"))]
983                {
984                    Outcome::Err(connection_error(
985                        "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
986                    ))
987                }
988            }
989            b'N' => {
990                if self.config.ssl_mode.is_required() {
991                    Outcome::Err(Error::Connection(ConnectionError {
992                        kind: ConnectionErrorKind::Ssl,
993                        message: "Server does not support SSL".to_string(),
994                        source: None,
995                    }))
996                } else {
997                    Outcome::Ok(())
998                }
999            }
1000            other => Outcome::Err(Error::Connection(ConnectionError {
1001                kind: ConnectionErrorKind::Ssl,
1002                message: format!("Unexpected SSL response: 0x{other:02x}"),
1003                source: None,
1004            })),
1005        }
1006    }
1007
1008    async fn send_startup(&mut self) -> Outcome<(), Error> {
1009        let params = self.config.startup_params();
1010        self.send_message_no_cx(&FrontendMessage::Startup {
1011            version: PROTOCOL_VERSION,
1012            params,
1013        })
1014        .await
1015    }
1016
1017    fn require_auth_value(&self, message: &'static str) -> Outcome<&str, Error> {
1018        // NOTE: Auth values are sourced from runtime config, not hardcoded.
1019        match self.config.password.as_deref() {
1020            Some(password) => Outcome::Ok(password),
1021            None => Outcome::Err(auth_error(message)),
1022        }
1023    }
1024
1025    async fn handle_auth(&mut self) -> Outcome<(), Error> {
1026        loop {
1027            let msg = match self.receive_message_no_cx().await {
1028                Outcome::Ok(m) => m,
1029                Outcome::Err(e) => return Outcome::Err(e),
1030                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1031                Outcome::Panicked(p) => return Outcome::Panicked(p),
1032            };
1033
1034            match msg {
1035                BackendMessage::AuthenticationOk => return Outcome::Ok(()),
1036                BackendMessage::AuthenticationCleartextPassword => {
1037                    let auth_value = match self
1038                        .require_auth_value("Authentication value required but not provided")
1039                    {
1040                        Outcome::Ok(password) => password,
1041                        Outcome::Err(e) => return Outcome::Err(e),
1042                        Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1043                        Outcome::Panicked(p) => return Outcome::Panicked(p),
1044                    };
1045                    if let Outcome::Err(e) = self
1046                        .send_message_no_cx(&FrontendMessage::PasswordMessage(
1047                            auth_value.to_string(),
1048                        ))
1049                        .await
1050                    {
1051                        return Outcome::Err(e);
1052                    }
1053                }
1054                BackendMessage::AuthenticationMD5Password(salt) => {
1055                    let auth_value = match self
1056                        .require_auth_value("Authentication value required but not provided")
1057                    {
1058                        Outcome::Ok(password) => password,
1059                        Outcome::Err(e) => return Outcome::Err(e),
1060                        Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1061                        Outcome::Panicked(p) => return Outcome::Panicked(p),
1062                    };
1063                    let hash = md5_password(&self.config.user, auth_value, salt);
1064                    if let Outcome::Err(e) = self
1065                        .send_message_no_cx(&FrontendMessage::PasswordMessage(hash))
1066                        .await
1067                    {
1068                        return Outcome::Err(e);
1069                    }
1070                }
1071                BackendMessage::AuthenticationSASL(mechanisms) => {
1072                    if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
1073                        match self.scram_auth().await {
1074                            Outcome::Ok(()) => {}
1075                            Outcome::Err(e) => return Outcome::Err(e),
1076                            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1077                            Outcome::Panicked(p) => return Outcome::Panicked(p),
1078                        }
1079                    } else {
1080                        return Outcome::Err(auth_error(format!(
1081                            "Unsupported SASL mechanisms: {:?}",
1082                            mechanisms
1083                        )));
1084                    }
1085                }
1086                BackendMessage::ErrorResponse(e) => {
1087                    self.state = ConnectionState::Error;
1088                    return Outcome::Err(error_from_fields(&e));
1089                }
1090                other => {
1091                    return Outcome::Err(protocol_error(format!(
1092                        "Unexpected message during auth: {other:?}"
1093                    )));
1094                }
1095            }
1096        }
1097    }
1098
1099    async fn scram_auth(&mut self) -> Outcome<(), Error> {
1100        let auth_value =
1101            match self.require_auth_value("Authentication value required for SCRAM-SHA-256") {
1102                Outcome::Ok(password) => password,
1103                Outcome::Err(e) => return Outcome::Err(e),
1104                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1105                Outcome::Panicked(p) => return Outcome::Panicked(p),
1106            };
1107
1108        let mut client = ScramClient::new(&self.config.user, auth_value);
1109
1110        // Client-first
1111        let client_first = client.client_first();
1112        if let Outcome::Err(e) = self
1113            .send_message_no_cx(&FrontendMessage::SASLInitialResponse {
1114                mechanism: "SCRAM-SHA-256".to_string(),
1115                data: client_first,
1116            })
1117            .await
1118        {
1119            return Outcome::Err(e);
1120        }
1121
1122        // Server-first
1123        let msg = match self.receive_message_no_cx().await {
1124            Outcome::Ok(m) => m,
1125            Outcome::Err(e) => return Outcome::Err(e),
1126            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1127            Outcome::Panicked(p) => return Outcome::Panicked(p),
1128        };
1129        let server_first_data = match msg {
1130            BackendMessage::AuthenticationSASLContinue(data) => data,
1131            BackendMessage::ErrorResponse(e) => {
1132                self.state = ConnectionState::Error;
1133                return Outcome::Err(error_from_fields(&e));
1134            }
1135            other => {
1136                return Outcome::Err(protocol_error(format!(
1137                    "Expected SASL continue, got: {other:?}"
1138                )));
1139            }
1140        };
1141
1142        // Client-final
1143        let client_final = match client.process_server_first(&server_first_data) {
1144            Ok(v) => v,
1145            Err(e) => return Outcome::Err(e),
1146        };
1147        if let Outcome::Err(e) = self
1148            .send_message_no_cx(&FrontendMessage::SASLResponse(client_final))
1149            .await
1150        {
1151            return Outcome::Err(e);
1152        }
1153
1154        // Server-final
1155        let msg = match self.receive_message_no_cx().await {
1156            Outcome::Ok(m) => m,
1157            Outcome::Err(e) => return Outcome::Err(e),
1158            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1159            Outcome::Panicked(p) => return Outcome::Panicked(p),
1160        };
1161        let server_final_data = match msg {
1162            BackendMessage::AuthenticationSASLFinal(data) => data,
1163            BackendMessage::ErrorResponse(e) => {
1164                self.state = ConnectionState::Error;
1165                return Outcome::Err(error_from_fields(&e));
1166            }
1167            other => {
1168                return Outcome::Err(protocol_error(format!(
1169                    "Expected SASL final, got: {other:?}"
1170                )));
1171            }
1172        };
1173
1174        if let Err(e) = client.verify_server_final(&server_final_data) {
1175            return Outcome::Err(e);
1176        }
1177
1178        // Wait for AuthenticationOk
1179        let msg = match self.receive_message_no_cx().await {
1180            Outcome::Ok(m) => m,
1181            Outcome::Err(e) => return Outcome::Err(e),
1182            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1183            Outcome::Panicked(p) => return Outcome::Panicked(p),
1184        };
1185        match msg {
1186            BackendMessage::AuthenticationOk => Outcome::Ok(()),
1187            BackendMessage::ErrorResponse(e) => {
1188                self.state = ConnectionState::Error;
1189                Outcome::Err(error_from_fields(&e))
1190            }
1191            other => Outcome::Err(protocol_error(format!(
1192                "Expected AuthenticationOk, got: {other:?}"
1193            ))),
1194        }
1195    }
1196
1197    async fn read_startup_messages(&mut self) -> Outcome<(), Error> {
1198        loop {
1199            let msg = match self.receive_message_no_cx().await {
1200                Outcome::Ok(m) => m,
1201                Outcome::Err(e) => return Outcome::Err(e),
1202                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1203                Outcome::Panicked(p) => return Outcome::Panicked(p),
1204            };
1205
1206            match msg {
1207                BackendMessage::BackendKeyData {
1208                    process_id,
1209                    secret_key,
1210                } => {
1211                    self.process_id = process_id;
1212                    self.secret_key = secret_key;
1213                }
1214                BackendMessage::ParameterStatus { name, value } => {
1215                    self.parameters.insert(name, value);
1216                }
1217                BackendMessage::ReadyForQuery(status) => {
1218                    self.state = ConnectionState::Ready(TransactionStatusState::from(status));
1219                    return Outcome::Ok(());
1220                }
1221                BackendMessage::ErrorResponse(e) => {
1222                    self.state = ConnectionState::Error;
1223                    return Outcome::Err(error_from_fields(&e));
1224                }
1225                BackendMessage::NoticeResponse(_notice) => {}
1226                other => {
1227                    return Outcome::Err(protocol_error(format!(
1228                        "Unexpected startup message: {other:?}"
1229                    )));
1230                }
1231            }
1232        }
1233    }
1234
1235    // ==================== I/O ====================
1236
    /// Cancellation-aware wrapper around `send_message_no_cx`.
    ///
    /// If `cx` already carries a cancellation reason, returns
    /// `Outcome::Cancelled` without writing anything. Otherwise it behaves
    /// exactly like `send_message_no_cx`.
    async fn send_message(&mut self, cx: &Cx, msg: &FrontendMessage) -> Outcome<(), Error> {
        // If cancelled, propagate early.
        if let Some(reason) = cx.cancel_reason() {
            return Outcome::Cancelled(reason);
        }
        self.send_message_no_cx(msg).await
    }
1244
    /// Cancellation-aware wrapper around `receive_message_no_cx`.
    ///
    /// Cancellation is checked only once, before waiting. A cancellation that
    /// arrives while the read is in progress is not observed here.
    async fn receive_message(&mut self, cx: &Cx) -> Outcome<BackendMessage, Error> {
        if let Some(reason) = cx.cancel_reason() {
            return Outcome::Cancelled(reason);
        }
        self.receive_message_no_cx().await
    }
1251
1252    async fn send_message_no_cx(&mut self, msg: &FrontendMessage) -> Outcome<(), Error> {
1253        let data = self.writer.write(msg).to_vec();
1254
1255        if let Err(e) = self.stream.write_all(&data).await {
1256            self.state = ConnectionState::Error;
1257            return Outcome::Err(Error::Connection(ConnectionError {
1258                kind: ConnectionErrorKind::Disconnected,
1259                message: format!("Failed to write to server: {}", e),
1260                source: Some(Box::new(e)),
1261            }));
1262        }
1263
1264        if let Err(e) = self.stream.flush().await {
1265            self.state = ConnectionState::Error;
1266            return Outcome::Err(Error::Connection(ConnectionError {
1267                kind: ConnectionErrorKind::Disconnected,
1268                message: format!("Failed to flush stream: {}", e),
1269                source: Some(Box::new(e)),
1270            }));
1271        }
1272
1273        Outcome::Ok(())
1274    }
1275
1276    async fn receive_message_no_cx(&mut self) -> Outcome<BackendMessage, Error> {
1277        loop {
1278            match self.reader.next_message() {
1279                Ok(Some(msg)) => return Outcome::Ok(msg),
1280                Ok(None) => {}
1281                Err(e) => {
1282                    self.state = ConnectionState::Error;
1283                    return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
1284                }
1285            }
1286
1287            let n = match self.stream.read_some(&mut self.read_buf).await {
1288                Ok(n) => n,
1289                Err(e) => {
1290                    self.state = ConnectionState::Error;
1291                    return Outcome::Err(match e.kind() {
1292                        std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock => {
1293                            Error::Timeout
1294                        }
1295                        _ => Error::Connection(ConnectionError {
1296                            kind: ConnectionErrorKind::Disconnected,
1297                            message: format!("Failed to read from server: {}", e),
1298                            source: Some(Box::new(e)),
1299                        }),
1300                    });
1301                }
1302            };
1303
1304            if n == 0 {
1305                self.state = ConnectionState::Disconnected;
1306                return Outcome::Err(Error::Connection(ConnectionError {
1307                    kind: ConnectionErrorKind::Disconnected,
1308                    message: "Connection closed by server".to_string(),
1309                    source: None,
1310                }));
1311            }
1312
1313            // Only append raw bytes; let next_message() at the top of the
1314            // loop handle parsing.  The old code called feed() here, which
1315            // parsed *and consumed* all complete messages from the buffer and
1316            // returned them in a Vec — but the caller only checked for Err,
1317            // silently discarding the Ok(messages).  On the next iteration
1318            // next_message() would see an empty buffer and block forever on
1319            // the socket read.  (See issue #9.)
1320            self.reader.push(&self.read_buf[..n]);
1321        }
1322    }
1323}
1324
1325/// Shared, cloneable PostgreSQL connection with interior mutability.
1326pub struct SharedPgConnection {
1327    inner: Arc<Mutex<PgAsyncConnection>>,
1328}
1329
1330impl SharedPgConnection {
1331    pub fn new(conn: PgAsyncConnection) -> Self {
1332        Self {
1333            inner: Arc::new(Mutex::new(conn)),
1334        }
1335    }
1336
1337    pub async fn connect(cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
1338        match PgAsyncConnection::connect(cx, config).await {
1339            Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
1340            Outcome::Err(e) => Outcome::Err(e),
1341            Outcome::Cancelled(r) => Outcome::Cancelled(r),
1342            Outcome::Panicked(p) => Outcome::Panicked(p),
1343        }
1344    }
1345
1346    pub fn inner(&self) -> &Arc<Mutex<PgAsyncConnection>> {
1347        &self.inner
1348    }
1349
1350    async fn begin_transaction_impl(
1351        &self,
1352        cx: &Cx,
1353        isolation: Option<IsolationLevel>,
1354    ) -> Outcome<SharedPgTransaction<'_>, Error> {
1355        let inner = Arc::clone(&self.inner);
1356        let Ok(mut guard) = inner.lock(cx).await else {
1357            return Outcome::Err(connection_error("Failed to acquire connection lock"));
1358        };
1359
1360        if let Some(level) = isolation {
1361            let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
1362            match guard.execute_async(cx, &sql, &[]).await {
1363                Outcome::Ok(_) => {}
1364                Outcome::Err(e) => return Outcome::Err(e),
1365                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1366                Outcome::Panicked(p) => return Outcome::Panicked(p),
1367            }
1368        }
1369
1370        match guard.execute_async(cx, "BEGIN", &[]).await {
1371            Outcome::Ok(_) => {}
1372            Outcome::Err(e) => return Outcome::Err(e),
1373            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1374            Outcome::Panicked(p) => return Outcome::Panicked(p),
1375        }
1376
1377        drop(guard);
1378        Outcome::Ok(SharedPgTransaction {
1379            inner,
1380            committed: false,
1381            _marker: std::marker::PhantomData,
1382        })
1383    }
1384}
1385
1386impl Clone for SharedPgConnection {
1387    fn clone(&self) -> Self {
1388        Self {
1389            inner: Arc::clone(&self.inner),
1390        }
1391    }
1392}
1393
1394impl std::fmt::Debug for SharedPgConnection {
1395    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1396        f.debug_struct("SharedPgConnection")
1397            .field("inner", &"Arc<Mutex<PgAsyncConnection>>")
1398            .finish()
1399    }
1400}
1401
1402pub struct SharedPgTransaction<'conn> {
1403    inner: Arc<Mutex<PgAsyncConnection>>,
1404    committed: bool,
1405    _marker: std::marker::PhantomData<&'conn ()>,
1406}
1407
1408impl<'conn> Drop for SharedPgTransaction<'conn> {
1409    fn drop(&mut self) {
1410        if !self.committed {
1411            // WARNING: Transaction was dropped without commit() or rollback()!
1412            // We cannot do async work in Drop, so the PostgreSQL transaction will
1413            // remain open until the connection is closed or a new transaction
1414            // is started.
1415            #[cfg(debug_assertions)]
1416            eprintln!(
1417                "WARNING: SharedPgTransaction dropped without commit/rollback. \
1418                 The PostgreSQL transaction may still be open."
1419            );
1420        }
1421    }
1422}
1423
1424impl Connection for SharedPgConnection {
1425    type Tx<'conn>
1426        = SharedPgTransaction<'conn>
1427    where
1428        Self: 'conn;
1429
1430    fn dialect(&self) -> sqlmodel_core::Dialect {
1431        sqlmodel_core::Dialect::Postgres
1432    }
1433
1434    fn query(
1435        &self,
1436        cx: &Cx,
1437        sql: &str,
1438        params: &[Value],
1439    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1440        let inner = Arc::clone(&self.inner);
1441        let sql = sql.to_string();
1442        let params = params.to_vec();
1443        async move {
1444            let Ok(mut guard) = inner.lock(cx).await else {
1445                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1446            };
1447            guard.query_async(cx, &sql, &params).await
1448        }
1449    }
1450
1451    fn query_one(
1452        &self,
1453        cx: &Cx,
1454        sql: &str,
1455        params: &[Value],
1456    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1457        let inner = Arc::clone(&self.inner);
1458        let sql = sql.to_string();
1459        let params = params.to_vec();
1460        async move {
1461            let Ok(mut guard) = inner.lock(cx).await else {
1462                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1463            };
1464            let rows = match guard.query_async(cx, &sql, &params).await {
1465                Outcome::Ok(r) => r,
1466                Outcome::Err(e) => return Outcome::Err(e),
1467                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1468                Outcome::Panicked(p) => return Outcome::Panicked(p),
1469            };
1470            Outcome::Ok(rows.into_iter().next())
1471        }
1472    }
1473
1474    fn execute(
1475        &self,
1476        cx: &Cx,
1477        sql: &str,
1478        params: &[Value],
1479    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1480        let inner = Arc::clone(&self.inner);
1481        let sql = sql.to_string();
1482        let params = params.to_vec();
1483        async move {
1484            let Ok(mut guard) = inner.lock(cx).await else {
1485                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1486            };
1487            guard.execute_async(cx, &sql, &params).await
1488        }
1489    }
1490
1491    fn insert(
1492        &self,
1493        cx: &Cx,
1494        sql: &str,
1495        params: &[Value],
1496    ) -> impl Future<Output = Outcome<i64, Error>> + Send {
1497        let inner = Arc::clone(&self.inner);
1498        let sql = sql.to_string();
1499        let params = params.to_vec();
1500        async move {
1501            let Ok(mut guard) = inner.lock(cx).await else {
1502                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1503            };
1504            guard.insert_async(cx, &sql, &params).await
1505        }
1506    }
1507
1508    fn batch(
1509        &self,
1510        cx: &Cx,
1511        statements: &[(String, Vec<Value>)],
1512    ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
1513        let inner = Arc::clone(&self.inner);
1514        let statements = statements.to_vec();
1515        async move {
1516            let Ok(mut guard) = inner.lock(cx).await else {
1517                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1518            };
1519            let mut results = Vec::with_capacity(statements.len());
1520            for (sql, params) in &statements {
1521                match guard.execute_async(cx, sql, params).await {
1522                    Outcome::Ok(n) => results.push(n),
1523                    Outcome::Err(e) => return Outcome::Err(e),
1524                    Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1525                    Outcome::Panicked(p) => return Outcome::Panicked(p),
1526                }
1527            }
1528            Outcome::Ok(results)
1529        }
1530    }
1531
1532    fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1533        self.begin_with(cx, IsolationLevel::default())
1534    }
1535
1536    fn begin_with(
1537        &self,
1538        cx: &Cx,
1539        isolation: IsolationLevel,
1540    ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1541        self.begin_transaction_impl(cx, Some(isolation))
1542    }
1543
1544    fn prepare(
1545        &self,
1546        cx: &Cx,
1547        sql: &str,
1548    ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
1549        let inner = Arc::clone(&self.inner);
1550        let sql = sql.to_string();
1551        async move {
1552            let Ok(mut guard) = inner.lock(cx).await else {
1553                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1554            };
1555            guard.prepare_async(cx, &sql).await
1556        }
1557    }
1558
1559    fn query_prepared(
1560        &self,
1561        cx: &Cx,
1562        stmt: &PreparedStatement,
1563        params: &[Value],
1564    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1565        let inner = Arc::clone(&self.inner);
1566        let stmt = stmt.clone();
1567        let params = params.to_vec();
1568        async move {
1569            let Ok(mut guard) = inner.lock(cx).await else {
1570                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1571            };
1572            guard.query_prepared_async(cx, &stmt, &params).await
1573        }
1574    }
1575
1576    fn execute_prepared(
1577        &self,
1578        cx: &Cx,
1579        stmt: &PreparedStatement,
1580        params: &[Value],
1581    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1582        let inner = Arc::clone(&self.inner);
1583        let stmt = stmt.clone();
1584        let params = params.to_vec();
1585        async move {
1586            let Ok(mut guard) = inner.lock(cx).await else {
1587                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1588            };
1589            guard.execute_prepared_async(cx, &stmt, &params).await
1590        }
1591    }
1592
1593    fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1594        let inner = Arc::clone(&self.inner);
1595        async move {
1596            let Ok(mut guard) = inner.lock(cx).await else {
1597                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1598            };
1599            guard.ping_async(cx).await
1600        }
1601    }
1602
1603    async fn close(self, cx: &Cx) -> sqlmodel_core::Result<()> {
1604        let Ok(mut guard) = self.inner.lock(cx).await else {
1605            return Err(connection_error("Failed to acquire connection lock"));
1606        };
1607        match guard.close_async(cx).await {
1608            Outcome::Ok(()) => Ok(()),
1609            Outcome::Err(e) => Err(e),
1610            Outcome::Cancelled(r) => Err(Error::Query(QueryError {
1611                kind: QueryErrorKind::Cancelled,
1612                message: format!("Cancelled: {r:?}"),
1613                sqlstate: None,
1614                sql: None,
1615                detail: None,
1616                hint: None,
1617                position: None,
1618                source: None,
1619            })),
1620            Outcome::Panicked(p) => Err(Error::Protocol(ProtocolError {
1621                message: format!("Panicked: {p:?}"),
1622                raw_data: None,
1623                source: None,
1624            })),
1625        }
1626    }
1627}
1628
1629impl<'conn> TransactionOps for SharedPgTransaction<'conn> {
1630    fn query(
1631        &self,
1632        cx: &Cx,
1633        sql: &str,
1634        params: &[Value],
1635    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1636        let inner = Arc::clone(&self.inner);
1637        let sql = sql.to_string();
1638        let params = params.to_vec();
1639        async move {
1640            let Ok(mut guard) = inner.lock(cx).await else {
1641                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1642            };
1643            guard.query_async(cx, &sql, &params).await
1644        }
1645    }
1646
1647    fn query_one(
1648        &self,
1649        cx: &Cx,
1650        sql: &str,
1651        params: &[Value],
1652    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1653        let inner = Arc::clone(&self.inner);
1654        let sql = sql.to_string();
1655        let params = params.to_vec();
1656        async move {
1657            let Ok(mut guard) = inner.lock(cx).await else {
1658                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1659            };
1660            let rows = match guard.query_async(cx, &sql, &params).await {
1661                Outcome::Ok(r) => r,
1662                Outcome::Err(e) => return Outcome::Err(e),
1663                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1664                Outcome::Panicked(p) => return Outcome::Panicked(p),
1665            };
1666            Outcome::Ok(rows.into_iter().next())
1667        }
1668    }
1669
1670    fn execute(
1671        &self,
1672        cx: &Cx,
1673        sql: &str,
1674        params: &[Value],
1675    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1676        let inner = Arc::clone(&self.inner);
1677        let sql = sql.to_string();
1678        let params = params.to_vec();
1679        async move {
1680            let Ok(mut guard) = inner.lock(cx).await else {
1681                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1682            };
1683            guard.execute_async(cx, &sql, &params).await
1684        }
1685    }
1686
1687    fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1688        let inner = Arc::clone(&self.inner);
1689        let name = name.to_string();
1690        async move {
1691            if let Err(e) = validate_savepoint_name(&name) {
1692                return Outcome::Err(e);
1693            }
1694            let sql = format!("SAVEPOINT {}", name);
1695            let Ok(mut guard) = inner.lock(cx).await else {
1696                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1697            };
1698            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1699        }
1700    }
1701
1702    fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1703        let inner = Arc::clone(&self.inner);
1704        let name = name.to_string();
1705        async move {
1706            if let Err(e) = validate_savepoint_name(&name) {
1707                return Outcome::Err(e);
1708            }
1709            let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
1710            let Ok(mut guard) = inner.lock(cx).await else {
1711                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1712            };
1713            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1714        }
1715    }
1716
1717    fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1718        let inner = Arc::clone(&self.inner);
1719        let name = name.to_string();
1720        async move {
1721            if let Err(e) = validate_savepoint_name(&name) {
1722                return Outcome::Err(e);
1723            }
1724            let sql = format!("RELEASE SAVEPOINT {}", name);
1725            let Ok(mut guard) = inner.lock(cx).await else {
1726                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1727            };
1728            guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1729        }
1730    }
1731
1732    // Note: clippy sometimes flags `self.committed = true` as unused, but Drop reads it.
1733    #[allow(unused_assignments)]
1734    fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1735        let inner = Arc::clone(&self.inner);
1736        async move {
1737            let Ok(mut guard) = inner.lock(cx).await else {
1738                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1739            };
1740            let result = guard.execute_async(cx, "COMMIT", &[]).await;
1741            if matches!(result, Outcome::Ok(_)) {
1742                self.committed = true;
1743            }
1744            result.map(|_| ())
1745        }
1746    }
1747
1748    #[allow(unused_assignments)]
1749    fn rollback(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1750        let inner = Arc::clone(&self.inner);
1751        async move {
1752            let Ok(mut guard) = inner.lock(cx).await else {
1753                return Outcome::Err(connection_error("Failed to acquire connection lock"));
1754            };
1755            let result = guard.execute_async(cx, "ROLLBACK", &[]).await;
1756            if matches!(result, Outcome::Ok(_)) {
1757                self.committed = true;
1758            }
1759            result.map(|_| ())
1760        }
1761    }
1762}
1763
1764// ==================== Helpers ====================
1765
/// Rows plus the command tag collected from running a query.
///
/// NOTE(review): construction happens outside this chunk; presumably one value
/// is produced per executed statement — confirm at the construction site.
struct PgQueryResult {
    /// Rows produced by the statement.
    rows: Vec<Row>,
    /// Command tag text reported by the server, if one was received.
    /// Presumably this is the input to `parse_rows_affected` — verify at the use site.
    command_tag: Option<String>,
}
1770
1771fn connection_error(msg: impl Into<String>) -> Error {
1772    Error::Connection(ConnectionError {
1773        kind: ConnectionErrorKind::Connect,
1774        message: msg.into(),
1775        source: None,
1776    })
1777}
1778
1779fn auth_error(msg: impl Into<String>) -> Error {
1780    Error::Connection(ConnectionError {
1781        kind: ConnectionErrorKind::Authentication,
1782        message: msg.into(),
1783        source: None,
1784    })
1785}
1786
1787fn protocol_error(msg: impl Into<String>) -> Error {
1788    Error::Protocol(ProtocolError {
1789        message: msg.into(),
1790        raw_data: None,
1791        source: None,
1792    })
1793}
1794
1795fn query_error_msg(msg: impl Into<String>, kind: QueryErrorKind) -> Error {
1796    Error::Query(QueryError {
1797        kind,
1798        message: msg.into(),
1799        sqlstate: None,
1800        sql: None,
1801        detail: None,
1802        hint: None,
1803        position: None,
1804        source: None,
1805    })
1806}
1807
1808fn error_from_fields(fields: &ErrorFields) -> Error {
1809    let kind = match fields.code.get(..2) {
1810        Some("08") => {
1811            return Error::Connection(ConnectionError {
1812                kind: ConnectionErrorKind::Connect,
1813                message: fields.message.clone(),
1814                source: None,
1815            });
1816        }
1817        Some("28") => {
1818            return Error::Connection(ConnectionError {
1819                kind: ConnectionErrorKind::Authentication,
1820                message: fields.message.clone(),
1821                source: None,
1822            });
1823        }
1824        Some("42") => QueryErrorKind::Syntax,
1825        Some("23") => QueryErrorKind::Constraint,
1826        Some("40") => {
1827            if fields.code == "40001" {
1828                QueryErrorKind::Serialization
1829            } else {
1830                QueryErrorKind::Deadlock
1831            }
1832        }
1833        Some("57") => {
1834            if fields.code == "57014" {
1835                QueryErrorKind::Cancelled
1836            } else {
1837                QueryErrorKind::Timeout
1838            }
1839        }
1840        _ => QueryErrorKind::Database,
1841    };
1842
1843    Error::Query(QueryError {
1844        kind,
1845        sql: None,
1846        sqlstate: Some(fields.code.clone()),
1847        message: fields.message.clone(),
1848        detail: fields.detail.clone(),
1849        hint: fields.hint.clone(),
1850        position: fields.position.map(|p| p as usize),
1851        source: None,
1852    })
1853}
1854
/// Extract the affected-row count from a command tag.
///
/// Returns the last whitespace-separated word of `tag` parsed as a `u64`. For
/// example, `"INSERT 0 5"` yields `Some(5)` and `"UPDATE 3"` yields `Some(3)`.
/// The result is `None` when:
/// - `tag` is `None` or blank;
/// - the last word is not a non-negative integer (e.g. `"CREATE TABLE"`).
fn parse_rows_affected(tag: Option<&str>) -> Option<u64> {
    // `split_whitespace` is double-ended, so the last word is reachable
    // directly, without collecting every word into a temporary `Vec`.
    tag?.split_whitespace().next_back()?.parse::<u64>().ok()
}
1860
1861/// Validate a savepoint name to reduce SQL injection risk.
1862fn validate_savepoint_name(name: &str) -> sqlmodel_core::Result<()> {
1863    if name.is_empty() {
1864        return Err(query_error_msg(
1865            "Savepoint name cannot be empty",
1866            QueryErrorKind::Syntax,
1867        ));
1868    }
1869    if name.len() > 63 {
1870        return Err(query_error_msg(
1871            "Savepoint name exceeds maximum length of 63 characters",
1872            QueryErrorKind::Syntax,
1873        ));
1874    }
1875    let mut chars = name.chars();
1876    let Some(first) = chars.next() else {
1877        return Err(query_error_msg(
1878            "Savepoint name cannot be empty",
1879            QueryErrorKind::Syntax,
1880        ));
1881    };
1882    if !first.is_ascii_alphabetic() && first != '_' {
1883        return Err(query_error_msg(
1884            "Savepoint name must start with a letter or underscore",
1885            QueryErrorKind::Syntax,
1886        ));
1887    }
1888    for c in chars {
1889        if !c.is_ascii_alphanumeric() && c != '_' {
1890            return Err(query_error_msg(
1891                format!("Savepoint name contains invalid character: '{c}'"),
1892                QueryErrorKind::Syntax,
1893            ));
1894        }
1895    }
1896    Ok(())
1897}
1898
1899fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
1900    use std::fmt::Write;
1901
1902    let inner = format!("{password}{user}");
1903    let inner_hash = md5::compute(inner.as_bytes());
1904
1905    let mut outer_input = format!("{inner_hash:x}").into_bytes();
1906    outer_input.extend_from_slice(&salt);
1907    let outer_hash = md5::compute(&outer_input);
1908
1909    let mut result = String::with_capacity(35);
1910    result.push_str("md5");
1911    write!(&mut result, "{outer_hash:x}").unwrap();
1912    result
1913}
1914
1915// Note: read/write helpers are implemented above on PgAsyncStream.