zero_postgres/tokio/
conn.rs

1//! Asynchronous PostgreSQL connection.
2
3use tokio::net::TcpStream;
4use tokio::net::UnixStream;
5
6use crate::buffer_pool::PooledBufferSet;
7use crate::conversion::ToParams;
8use crate::error::{Error, Result};
9use crate::handler::{
10    AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
11};
12use crate::opts::Opts;
13use crate::protocol::backend::BackendKeyData;
14use crate::protocol::frontend::write_terminate;
15use crate::protocol::types::TransactionStatus;
16use crate::state::StateMachine;
17use crate::state::action::Action;
18use crate::state::connection::ConnectionStateMachine;
19use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
20use crate::state::simple_query::SimpleQueryStateMachine;
21use crate::statement::IntoStatement;
22
23use super::stream::Stream;
24
25/// Asynchronous PostgreSQL connection.
26pub struct Conn {
27    pub(crate) stream: Stream,
28    pub(crate) buffer_set: PooledBufferSet,
29    backend_key: Option<BackendKeyData>,
30    server_params: Vec<(String, String)>,
31    pub(crate) transaction_status: TransactionStatus,
32    pub(crate) is_broken: bool,
33    name_counter: u64,
34    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
35}
36
37impl Conn {
38    /// Connect to a PostgreSQL server.
39    pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
40    where
41        Error: From<O::Error>,
42    {
43        let opts = opts.try_into()?;
44
45        let stream = if let Some(socket_path) = &opts.socket {
46            Stream::unix(UnixStream::connect(socket_path).await?)
47        } else {
48            if opts.host.is_empty() {
49                return Err(Error::InvalidUsage("host is empty".into()));
50            }
51            let addr = format!("{}:{}", opts.host, opts.port);
52            let tcp = TcpStream::connect(&addr).await?;
53            tcp.set_nodelay(true)?;
54            Stream::tcp(tcp)
55        };
56
57        Self::new_with_stream(stream, opts).await
58    }
59
60    /// Connect using an existing stream.
61    #[allow(unused_mut)]
62    pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
63        let mut buffer_set = options.buffer_pool.get_buffer_set();
64        let mut state_machine = ConnectionStateMachine::new(options.clone());
65
66        // Drive the connection state machine
67        loop {
68            match state_machine.step(&mut buffer_set)? {
69                Action::WriteAndReadByte => {
70                    stream.write_all(&buffer_set.write_buffer).await?;
71                    stream.flush().await?;
72                    let byte = stream.read_u8().await?;
73                    state_machine.set_ssl_response(byte);
74                }
75                Action::ReadMessage => {
76                    stream.read_message(&mut buffer_set).await?;
77                }
78                Action::Write => {
79                    stream.write_all(&buffer_set.write_buffer).await?;
80                    stream.flush().await?;
81                }
82                Action::WriteAndReadMessage => {
83                    stream.write_all(&buffer_set.write_buffer).await?;
84                    stream.flush().await?;
85                    stream.read_message(&mut buffer_set).await?;
86                }
87                Action::TlsHandshake => {
88                    #[cfg(feature = "tokio-tls")]
89                    {
90                        stream = stream.upgrade_to_tls(&options.host).await?;
91                    }
92                    #[cfg(not(feature = "tokio-tls"))]
93                    {
94                        return Err(Error::Unsupported(
95                            "TLS requested but tokio-tls feature not enabled".into(),
96                        ));
97                    }
98                }
99                Action::HandleAsyncMessageAndReadMessage(_) => {
100                    // Ignore async messages during startup, read next message
101                    stream.read_message(&mut buffer_set).await?;
102                }
103                Action::Finished => break,
104            }
105        }
106
107        let conn = Self {
108            stream,
109            buffer_set,
110            backend_key: state_machine.backend_key().cloned(),
111            server_params: state_machine.take_server_params(),
112            transaction_status: state_machine.transaction_status(),
113            is_broken: false,
114            name_counter: 0,
115            async_message_handler: None,
116        };
117
118        // Upgrade to Unix socket if connected via TCP to loopback
119        let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
120            conn.try_upgrade_to_unix_socket(&options).await
121        } else {
122            conn
123        };
124
125        Ok(conn)
126    }
127
128    /// Try to upgrade to Unix socket connection.
129    /// Returns upgraded conn on success, original conn on failure.
130    fn try_upgrade_to_unix_socket(
131        mut self,
132        opts: &Opts,
133    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
134        let opts = opts.clone();
135        Box::pin(async move {
136            // Query unix_socket_directories from server
137            let mut handler = FirstRowHandler::<(String,)>::new();
138            if self
139                .query("SHOW unix_socket_directories", &mut handler)
140                .await
141                .is_err()
142            {
143                return self;
144            }
145
146            let socket_dir = match handler.into_row() {
147                Some((dirs,)) => {
148                    // May contain multiple directories, use the first one
149                    match dirs.split(',').next() {
150                        Some(d) if !d.trim().is_empty() => d.trim().to_string(),
151                        _ => return self,
152                    }
153                }
154                None => return self,
155            };
156
157            // Build socket path: {directory}/.s.PGSQL.{port}
158            let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
159
160            // Connect via Unix socket
161            let unix_stream = match UnixStream::connect(&socket_path).await {
162                Ok(s) => s,
163                Err(_) => return self,
164            };
165
166            // Create new connection over Unix socket
167            let mut opts_unix = opts.clone();
168            opts_unix.upgrade_to_unix_socket = false;
169
170            match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
171                Ok(new_conn) => new_conn,
172                Err(_) => self,
173            }
174        })
175    }
176
177    /// Get the backend key data for query cancellation.
178    pub fn backend_key(&self) -> Option<&BackendKeyData> {
179        self.backend_key.as_ref()
180    }
181
182    /// Get the connection ID (backend process ID).
183    ///
184    /// Returns 0 if the backend key data is not available.
185    pub fn connection_id(&self) -> u32 {
186        self.backend_key.as_ref().map_or(0, |k| k.process_id())
187    }
188
189    /// Get server parameters.
190    pub fn server_params(&self) -> &[(String, String)] {
191        &self.server_params
192    }
193
194    /// Get the current transaction status.
195    pub fn transaction_status(&self) -> TransactionStatus {
196        self.transaction_status
197    }
198
199    /// Check if currently in a transaction.
200    pub fn in_transaction(&self) -> bool {
201        self.transaction_status.in_transaction()
202    }
203
204    /// Check if the connection is broken.
205    pub fn is_broken(&self) -> bool {
206        self.is_broken
207    }
208
209    /// Generate the next unique portal name.
210    pub(crate) fn next_portal_name(&mut self) -> String {
211        self.name_counter += 1;
212        format!("_zero_p_{}", self.name_counter)
213    }
214
215    /// Create a named portal by binding a statement.
216    ///
217    /// Used internally by Transaction::exec_portal.
218    pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
219        &mut self,
220        portal_name: &str,
221        statement: &S,
222        params: &P,
223    ) -> Result<()> {
224        // Create bind state machine for named portal
225        let mut state_machine = if let Some(sql) = statement.as_sql() {
226            BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
227        } else {
228            let stmt = statement.as_prepared().unwrap();
229            BindStateMachine::bind_prepared(
230                &mut self.buffer_set,
231                portal_name,
232                &stmt.wire_name(),
233                &stmt.param_oids,
234                params,
235            )?
236        };
237
238        // Drive the state machine to completion (ParseComplete + BindComplete)
239        loop {
240            match state_machine.step(&mut self.buffer_set)? {
241                Action::ReadMessage => {
242                    self.stream.read_message(&mut self.buffer_set).await?;
243                }
244                Action::Write => {
245                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
246                    self.stream.flush().await?;
247                }
248                Action::WriteAndReadMessage => {
249                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
250                    self.stream.flush().await?;
251                    self.stream.read_message(&mut self.buffer_set).await?;
252                }
253                Action::Finished => break,
254                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
255            }
256        }
257
258        Ok(())
259    }
260
261    /// Set the async message handler.
262    ///
263    /// The handler is called when the server sends asynchronous messages:
264    /// - `Notification` - from LISTEN/NOTIFY
265    /// - `Notice` - warnings and informational messages
266    /// - `ParameterChanged` - server parameter updates
267    pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
268        self.async_message_handler = Some(Box::new(handler));
269    }
270
271    /// Remove the async message handler.
272    pub fn clear_async_message_handler(&mut self) {
273        self.async_message_handler = None;
274    }
275
276    /// Ping the server with an empty query to check connection aliveness.
277    pub async fn ping(&mut self) -> Result<()> {
278        self.query_drop("").await?;
279        Ok(())
280    }
281
282    /// Drive a state machine to completion.
283    async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
284        loop {
285            match state_machine.step(&mut self.buffer_set)? {
286                Action::WriteAndReadByte => {
287                    return Err(Error::Protocol(
288                        "Unexpected WriteAndReadByte in query state machine".into(),
289                    ));
290                }
291                Action::ReadMessage => {
292                    self.stream.read_message(&mut self.buffer_set).await?;
293                }
294                Action::Write => {
295                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
296                    self.stream.flush().await?;
297                }
298                Action::WriteAndReadMessage => {
299                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
300                    self.stream.flush().await?;
301                    self.stream.read_message(&mut self.buffer_set).await?;
302                }
303                Action::TlsHandshake => {
304                    return Err(Error::Protocol(
305                        "Unexpected TlsHandshake in query state machine".into(),
306                    ));
307                }
308                Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
309                    if let Some(ref mut h) = self.async_message_handler {
310                        h.handle(async_msg);
311                    }
312                    // Read next message after handling async message
313                    self.stream.read_message(&mut self.buffer_set).await?;
314                }
315                Action::Finished => {
316                    self.transaction_status = state_machine.transaction_status();
317                    break;
318                }
319            }
320        }
321        Ok(())
322    }
323
324    /// Execute a simple query with a handler.
325    pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
326        let result = self.query_inner(sql, handler).await;
327        if let Err(e) = &result
328            && e.is_connection_broken()
329        {
330            self.is_broken = true;
331        }
332        result
333    }
334
335    async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
336        let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
337        self.drive(&mut state_machine).await
338    }
339
340    /// Execute a simple query and discard results.
341    pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
342        let mut handler = DropHandler::new();
343        self.query(sql, &mut handler).await?;
344        Ok(handler.rows_affected())
345    }
346
347    /// Execute a simple query and collect typed rows.
348    pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
349        &mut self,
350        sql: &str,
351    ) -> Result<Vec<T>> {
352        let mut handler = crate::handler::CollectHandler::<T>::new();
353        self.query(sql, &mut handler).await?;
354        Ok(handler.into_rows())
355    }
356
357    /// Execute a simple query and return the first typed row.
358    pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
359        &mut self,
360        sql: &str,
361    ) -> Result<Option<T>> {
362        let mut handler = crate::handler::FirstRowHandler::<T>::new();
363        self.query(sql, &mut handler).await?;
364        Ok(handler.into_row())
365    }
366
367    /// Close the connection gracefully.
368    pub async fn close(mut self) -> Result<()> {
369        self.buffer_set.write_buffer.clear();
370        write_terminate(&mut self.buffer_set.write_buffer);
371        self.stream.write_all(&self.buffer_set.write_buffer).await?;
372        self.stream.flush().await?;
373        Ok(())
374    }
375
376    // === Extended Query Protocol ===
377
378    /// Prepare a statement using the extended query protocol.
379    pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
380        self.prepare_typed(query, &[]).await
381    }
382
383    /// Prepare a statement with explicit parameter types.
384    pub async fn prepare_typed(
385        &mut self,
386        query: &str,
387        param_oids: &[u32],
388    ) -> Result<PreparedStatement> {
389        self.name_counter += 1;
390        let idx = self.name_counter;
391        let result = self.prepare_inner(idx, query, param_oids).await;
392        if let Err(e) = &result
393            && e.is_connection_broken()
394        {
395            self.is_broken = true;
396        }
397        result
398    }
399
400    /// Prepare multiple statements in a single round-trip.
401    ///
402    /// This is more efficient than calling `prepare()` multiple times when you
403    /// need to prepare several statements, as it batches the network communication.
404    ///
405    /// # Example
406    ///
407    /// ```ignore
408    /// let stmts = conn.prepare_batch(&[
409    ///     "SELECT id, name FROM users WHERE id = $1",
410    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
411    ///     "UPDATE users SET name = $1 WHERE id = $2",
412    /// ]).await?;
413    ///
414    /// // Use stmts[0], stmts[1], stmts[2]...
415    /// ```
416    pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
417        if queries.is_empty() {
418            return Ok(Vec::new());
419        }
420
421        let start_idx = self.name_counter + 1;
422        self.name_counter += queries.len() as u64;
423
424        let result = self.prepare_batch_inner(queries, start_idx).await;
425        if let Err(e) = &result
426            && e.is_connection_broken()
427        {
428            self.is_broken = true;
429        }
430        result
431    }
432
433    async fn prepare_batch_inner(
434        &mut self,
435        queries: &[&str],
436        start_idx: u64,
437    ) -> Result<Vec<PreparedStatement>> {
438        use crate::state::batch_prepare::BatchPrepareStateMachine;
439
440        let mut state_machine =
441            BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
442
443        loop {
444            match state_machine.step(&mut self.buffer_set)? {
445                Action::ReadMessage => {
446                    self.stream.read_message(&mut self.buffer_set).await?;
447                }
448                Action::WriteAndReadMessage => {
449                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
450                    self.stream.flush().await?;
451                    self.stream.read_message(&mut self.buffer_set).await?;
452                }
453                Action::Finished => {
454                    self.transaction_status = state_machine.transaction_status();
455                    break;
456                }
457                _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
458            }
459        }
460
461        Ok(state_machine.take_statements())
462    }
463
464    async fn prepare_inner(
465        &mut self,
466        idx: u64,
467        query: &str,
468        param_oids: &[u32],
469    ) -> Result<PreparedStatement> {
470        let mut handler = DropHandler::new();
471        let mut state_machine = ExtendedQueryStateMachine::prepare(
472            &mut handler,
473            &mut self.buffer_set,
474            idx,
475            query,
476            param_oids,
477        );
478        self.drive(&mut state_machine).await?;
479        state_machine
480            .take_prepared_statement()
481            .ok_or_else(|| Error::Protocol("No prepared statement".into()))
482    }
483
484    /// Execute a statement with a handler.
485    ///
486    /// The statement can be either:
487    /// - A `&PreparedStatement` returned from `prepare()`
488    /// - A raw SQL `&str` for one-shot execution
489    pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
490        &mut self,
491        statement: S,
492        params: P,
493        handler: &mut H,
494    ) -> Result<()> {
495        let result = self.exec_inner(&statement, &params, handler).await;
496        if let Err(e) = &result
497            && e.is_connection_broken()
498        {
499            self.is_broken = true;
500        }
501        result
502    }
503
504    async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
505        &mut self,
506        statement: &S,
507        params: &P,
508        handler: &mut H,
509    ) -> Result<()> {
510        let mut state_machine = if statement.needs_parse() {
511            ExtendedQueryStateMachine::execute_sql(
512                handler,
513                &mut self.buffer_set,
514                statement.as_sql().unwrap(),
515                params,
516            )?
517        } else {
518            let stmt = statement.as_prepared().unwrap();
519            ExtendedQueryStateMachine::execute(
520                handler,
521                &mut self.buffer_set,
522                &stmt.wire_name(),
523                &stmt.param_oids,
524                params,
525            )?
526        };
527
528        self.drive(&mut state_machine).await
529    }
530
531    /// Execute a statement and discard results.
532    ///
533    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
534    pub async fn exec_drop<S: IntoStatement, P: ToParams>(
535        &mut self,
536        statement: S,
537        params: P,
538    ) -> Result<Option<u64>> {
539        let mut handler = DropHandler::new();
540        self.exec(statement, params, &mut handler).await?;
541        Ok(handler.rows_affected())
542    }
543
544    /// Execute a statement and collect typed rows.
545    ///
546    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
547    pub async fn exec_collect<
548        T: for<'a> crate::conversion::FromRow<'a>,
549        S: IntoStatement,
550        P: ToParams,
551    >(
552        &mut self,
553        statement: S,
554        params: P,
555    ) -> Result<Vec<T>> {
556        let mut handler = crate::handler::CollectHandler::<T>::new();
557        self.exec(statement, params, &mut handler).await?;
558        Ok(handler.into_rows())
559    }
560
561    /// Execute a statement with multiple parameter sets in a batch.
562    ///
563    /// This is more efficient than calling `exec_drop` multiple times as it
564    /// batches the network communication. The statement is parsed once (if raw SQL)
565    /// and then bound/executed for each parameter set.
566    ///
567    /// Parameters are processed in chunks (default 1000) to avoid overwhelming
568    /// the server with too many pending operations.
569    ///
570    /// The statement can be either:
571    /// - A `&PreparedStatement` returned from `prepare()`
572    /// - A raw SQL `&str` for one-shot execution
573    ///
574    /// # Example
575    ///
576    /// ```ignore
577    /// // Using prepared statement
578    /// let stmt = conn.prepare("INSERT INTO users (name, age) VALUES ($1, $2)").await?;
579    /// conn.exec_batch(&stmt, &[
580    ///     ("alice", 30),
581    ///     ("bob", 25),
582    ///     ("charlie", 35),
583    /// ]).await?;
584    ///
585    /// // Using raw SQL
586    /// conn.exec_batch("INSERT INTO users (name, age) VALUES ($1, $2)", &[
587    ///     ("alice", 30),
588    ///     ("bob", 25),
589    /// ]).await?;
590    /// ```
591    pub async fn exec_batch<S: IntoStatement, P: ToParams>(
592        &mut self,
593        statement: S,
594        params_list: &[P],
595    ) -> Result<()> {
596        self.exec_batch_chunked(statement, params_list, 1000).await
597    }
598
599    /// Execute a statement with multiple parameter sets in a batch with custom chunk size.
600    ///
601    /// Same as `exec_batch` but allows specifying the chunk size for batching.
602    pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
603        &mut self,
604        statement: S,
605        params_list: &[P],
606        chunk_size: usize,
607    ) -> Result<()> {
608        let result = self
609            .exec_batch_inner(&statement, params_list, chunk_size)
610            .await;
611        if let Err(e) = &result
612            && e.is_connection_broken()
613        {
614            self.is_broken = true;
615        }
616        result
617    }
618
619    async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
620        &mut self,
621        statement: &S,
622        params_list: &[P],
623        chunk_size: usize,
624    ) -> Result<()> {
625        use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
626        use crate::state::extended::BatchStateMachine;
627
628        if params_list.is_empty() {
629            return Ok(());
630        }
631
632        let chunk_size = chunk_size.max(1);
633        let needs_parse = statement.needs_parse();
634        let sql = statement.as_sql();
635        let prepared = statement.as_prepared();
636
637        // Get param OIDs from first params or prepared statement
638        let param_oids: Vec<u32> = if let Some(stmt) = prepared {
639            stmt.param_oids.clone()
640        } else {
641            params_list[0].natural_oids()
642        };
643
644        // Statement name: empty for raw SQL, actual name for prepared
645        let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
646
647        for chunk in params_list.chunks(chunk_size) {
648            self.buffer_set.write_buffer.clear();
649
650            // For raw SQL, send Parse each chunk (reuses unnamed statement)
651            let parse_in_chunk = needs_parse;
652            if parse_in_chunk {
653                write_parse(
654                    &mut self.buffer_set.write_buffer,
655                    "",
656                    sql.unwrap(),
657                    &param_oids,
658                );
659            }
660
661            // Write Bind + Execute for each param set
662            for params in chunk {
663                let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
664                write_bind(
665                    &mut self.buffer_set.write_buffer,
666                    "",
667                    effective_stmt_name,
668                    params,
669                    &param_oids,
670                )?;
671                write_execute(&mut self.buffer_set.write_buffer, "", 0);
672            }
673
674            // Send Sync
675            write_sync(&mut self.buffer_set.write_buffer);
676
677            // Drive state machine
678            let mut state_machine = BatchStateMachine::new(parse_in_chunk);
679            self.drive_batch(&mut state_machine).await?;
680            self.transaction_status = state_machine.transaction_status();
681        }
682
683        Ok(())
684    }
685
686    /// Drive a batch state machine to completion.
687    async fn drive_batch(
688        &mut self,
689        state_machine: &mut crate::state::extended::BatchStateMachine,
690    ) -> Result<()> {
691        use crate::protocol::backend::{ReadyForQuery, msg_type};
692        use crate::state::action::Action;
693
694        loop {
695            let step_result = state_machine.step(&mut self.buffer_set);
696            match step_result {
697                Ok(Action::ReadMessage) => {
698                    self.stream.read_message(&mut self.buffer_set).await?;
699                }
700                Ok(Action::WriteAndReadMessage) => {
701                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
702                    self.stream.flush().await?;
703                    self.stream.read_message(&mut self.buffer_set).await?;
704                }
705                Ok(Action::Finished) => {
706                    break;
707                }
708                Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
709                Err(e) => {
710                    // On error, drain to ReadyForQuery to leave connection in clean state
711                    loop {
712                        self.stream.read_message(&mut self.buffer_set).await?;
713                        if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
714                            let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
715                            self.transaction_status =
716                                ready.transaction_status().unwrap_or_default();
717                            break;
718                        }
719                    }
720                    return Err(e);
721                }
722            }
723        }
724        Ok(())
725    }
726
727    /// Close a prepared statement.
728    pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
729        let result = self.close_statement_inner(&stmt.wire_name()).await;
730        if let Err(e) = &result
731            && e.is_connection_broken()
732        {
733            self.is_broken = true;
734        }
735        result
736    }
737
738    async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
739        let mut handler = DropHandler::new();
740        let mut state_machine =
741            ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
742        self.drive(&mut state_machine).await
743    }
744
745    // === Low-Level Extended Query Protocol ===
746
747    /// Low-level flush: send FLUSH to force server to send pending responses.
748    ///
749    /// Unlike SYNC, FLUSH does not end the transaction or wait for ReadyForQuery.
750    /// It just forces the server to send any pending responses without ending
751    /// the extended query sequence.
752    pub async fn lowlevel_flush(&mut self) -> Result<()> {
753        use crate::protocol::frontend::write_flush;
754
755        self.buffer_set.write_buffer.clear();
756        write_flush(&mut self.buffer_set.write_buffer);
757
758        self.stream.write_all(&self.buffer_set.write_buffer).await?;
759        self.stream.flush().await?;
760        Ok(())
761    }
762
763    /// Low-level sync: send SYNC and receive ReadyForQuery.
764    ///
765    /// This ends an extended query sequence and:
766    /// - Commits implicit transaction if successful
767    /// - Rolls back implicit transaction if failed
768    /// - Updates transaction status
769    pub async fn lowlevel_sync(&mut self) -> Result<()> {
770        let result = self.lowlevel_sync_inner().await;
771        if let Err(e) = &result
772            && e.is_connection_broken()
773        {
774            self.is_broken = true;
775        }
776        result
777    }
778
779    async fn lowlevel_sync_inner(&mut self) -> Result<()> {
780        use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
781        use crate::protocol::frontend::write_sync;
782
783        self.buffer_set.write_buffer.clear();
784        write_sync(&mut self.buffer_set.write_buffer);
785
786        self.stream.write_all(&self.buffer_set.write_buffer).await?;
787        self.stream.flush().await?;
788
789        let mut pending_error: Option<Error> = None;
790
791        loop {
792            self.stream.read_message(&mut self.buffer_set).await?;
793            let type_byte = self.buffer_set.type_byte;
794
795            if RawMessage::is_async_type(type_byte) {
796                continue;
797            }
798
799            match type_byte {
800                msg_type::READY_FOR_QUERY => {
801                    let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
802                    self.transaction_status = ready.transaction_status().unwrap_or_default();
803                    if let Some(e) = pending_error {
804                        return Err(e);
805                    }
806                    return Ok(());
807                }
808                msg_type::ERROR_RESPONSE => {
809                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
810                    pending_error = Some(error.into_error());
811                }
812                _ => {
813                    // Ignore other messages before ReadyForQuery
814                }
815            }
816        }
817    }
818
819    /// Low-level bind: send BIND message and receive BindComplete.
820    ///
821    /// This allows creating named portals. Unlike `exec()`, this does NOT
822    /// send EXECUTE or SYNC - the caller controls when to execute and sync.
823    ///
824    /// # Arguments
825    /// - `portal`: Portal name (empty string "" for unnamed portal)
826    /// - `statement_name`: Prepared statement name
827    /// - `params`: Parameter values
828    pub async fn lowlevel_bind<P: ToParams>(
829        &mut self,
830        portal: &str,
831        statement_name: &str,
832        params: P,
833    ) -> Result<()> {
834        let result = self
835            .lowlevel_bind_inner(portal, statement_name, &params)
836            .await;
837        if let Err(e) = &result
838            && e.is_connection_broken()
839        {
840            self.is_broken = true;
841        }
842        result
843    }
844
845    async fn lowlevel_bind_inner<P: ToParams>(
846        &mut self,
847        portal: &str,
848        statement_name: &str,
849        params: &P,
850    ) -> Result<()> {
851        use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
852        use crate::protocol::frontend::{write_bind, write_flush};
853
854        let param_oids = params.natural_oids();
855        self.buffer_set.write_buffer.clear();
856        write_bind(
857            &mut self.buffer_set.write_buffer,
858            portal,
859            statement_name,
860            params,
861            &param_oids,
862        )?;
863        write_flush(&mut self.buffer_set.write_buffer);
864
865        self.stream.write_all(&self.buffer_set.write_buffer).await?;
866        self.stream.flush().await?;
867
868        loop {
869            self.stream.read_message(&mut self.buffer_set).await?;
870            let type_byte = self.buffer_set.type_byte;
871
872            if RawMessage::is_async_type(type_byte) {
873                continue;
874            }
875
876            match type_byte {
877                msg_type::BIND_COMPLETE => {
878                    BindComplete::parse(&self.buffer_set.read_buffer)?;
879                    return Ok(());
880                }
881                msg_type::ERROR_RESPONSE => {
882                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
883                    return Err(error.into_error());
884                }
885                _ => {
886                    return Err(Error::Protocol(format!(
887                        "Expected BindComplete or ErrorResponse, got '{}'",
888                        type_byte as char
889                    )));
890                }
891            }
892        }
893    }
894
895    /// Low-level execute: send EXECUTE message and receive results.
896    ///
897    /// Executes a previously bound portal. Does NOT send SYNC.
898    ///
899    /// # Arguments
900    /// - `portal`: Portal name (empty string "" for unnamed portal)
901    /// - `max_rows`: Maximum rows to return (0 = unlimited)
902    /// - `handler`: Handler to receive rows
903    ///
904    /// # Returns
905    /// - `Ok(true)` if more rows available (PortalSuspended received)
906    /// - `Ok(false)` if execution completed (CommandComplete received)
907    pub async fn lowlevel_execute<H: BinaryHandler>(
908        &mut self,
909        portal: &str,
910        max_rows: u32,
911        handler: &mut H,
912    ) -> Result<bool> {
913        let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
914        if let Err(e) = &result
915            && e.is_connection_broken()
916        {
917            self.is_broken = true;
918        }
919        result
920    }
921
922    async fn lowlevel_execute_inner<H: BinaryHandler>(
923        &mut self,
924        portal: &str,
925        max_rows: u32,
926        handler: &mut H,
927    ) -> Result<bool> {
928        use crate::protocol::backend::{
929            CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
930            RowDescription, msg_type,
931        };
932        use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
933
934        self.buffer_set.write_buffer.clear();
935        write_describe_portal(&mut self.buffer_set.write_buffer, portal);
936        write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
937        write_flush(&mut self.buffer_set.write_buffer);
938
939        self.stream.write_all(&self.buffer_set.write_buffer).await?;
940        self.stream.flush().await?;
941
942        let mut column_buffer: Vec<u8> = Vec::new();
943
944        loop {
945            self.stream.read_message(&mut self.buffer_set).await?;
946            let type_byte = self.buffer_set.type_byte;
947
948            if RawMessage::is_async_type(type_byte) {
949                continue;
950            }
951
952            match type_byte {
953                msg_type::ROW_DESCRIPTION => {
954                    column_buffer.clear();
955                    column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
956                    let cols = RowDescription::parse(&column_buffer)?;
957                    handler.result_start(cols)?;
958                }
959                msg_type::NO_DATA => {
960                    NoData::parse(&self.buffer_set.read_buffer)?;
961                }
962                msg_type::DATA_ROW => {
963                    let cols = RowDescription::parse(&column_buffer)?;
964                    let row = DataRow::parse(&self.buffer_set.read_buffer)?;
965                    handler.row(cols, row)?;
966                }
967                msg_type::COMMAND_COMPLETE => {
968                    let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
969                    handler.result_end(complete)?;
970                    return Ok(false); // No more rows
971                }
972                msg_type::PORTAL_SUSPENDED => {
973                    PortalSuspended::parse(&self.buffer_set.read_buffer)?;
974                    return Ok(true); // More rows available
975                }
976                msg_type::ERROR_RESPONSE => {
977                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
978                    return Err(error.into_error());
979                }
980                _ => {
981                    return Err(Error::Protocol(format!(
982                        "Unexpected message in execute: '{}'",
983                        type_byte as char
984                    )));
985                }
986            }
987        }
988    }
989
990    /// Execute a statement with iterative row fetching.
991    ///
992    /// Creates an unnamed portal and passes it to the closure. The closure can
993    /// call `portal.fetch(n, handler)` multiple times to retrieve rows in batches.
994    /// Sync is called after the closure returns to end the implicit transaction.
995    ///
996    /// The statement can be either:
997    /// - A `&PreparedStatement` returned from `prepare()`
998    /// - A raw SQL `&str` for one-shot execution
999    ///
1000    /// # Example
1001    /// ```ignore
1002    /// // Using prepared statement
1003    /// let stmt = conn.prepare("SELECT * FROM users").await?;
1004    /// conn.exec_iter(&stmt, (), |portal| async move {
1005    ///     while portal.fetch(100, &mut handler).await? {
1006    ///         // process handler.into_rows()...
1007    ///     }
1008    ///     Ok(())
1009    /// }).await?;
1010    ///
1011    /// // Using raw SQL
1012    /// conn.exec_iter("SELECT * FROM users", (), |portal| async move {
1013    ///     while portal.fetch(100, &mut handler).await? {
1014    ///         // process handler.into_rows()...
1015    ///     }
1016    ///     Ok(())
1017    /// }).await?;
1018    /// ```
1019    pub async fn exec_iter<S: IntoStatement, P, F, Fut, T>(
1020        &mut self,
1021        statement: S,
1022        params: P,
1023        f: F,
1024    ) -> Result<T>
1025    where
1026        P: ToParams,
1027        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1028        Fut: std::future::Future<Output = Result<T>>,
1029    {
1030        let result = self.exec_iter_inner(&statement, &params, f).await;
1031        if let Err(e) = &result
1032            && e.is_connection_broken()
1033        {
1034            self.is_broken = true;
1035        }
1036        result
1037    }
1038
1039    async fn exec_iter_inner<S: IntoStatement, P, F, Fut, T>(
1040        &mut self,
1041        statement: &S,
1042        params: &P,
1043        f: F,
1044    ) -> Result<T>
1045    where
1046        P: ToParams,
1047        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1048        Fut: std::future::Future<Output = Result<T>>,
1049    {
1050        // Create bind state machine for unnamed portal
1051        let mut state_machine = if let Some(sql) = statement.as_sql() {
1052            BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1053        } else {
1054            let stmt = statement.as_prepared().unwrap();
1055            BindStateMachine::bind_prepared(
1056                &mut self.buffer_set,
1057                "",
1058                &stmt.wire_name(),
1059                &stmt.param_oids,
1060                params,
1061            )?
1062        };
1063
1064        // Drive the state machine to completion (ParseComplete + BindComplete)
1065        loop {
1066            match state_machine.step(&mut self.buffer_set)? {
1067                Action::ReadMessage => {
1068                    self.stream.read_message(&mut self.buffer_set).await?;
1069                }
1070                Action::Write => {
1071                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1072                    self.stream.flush().await?;
1073                }
1074                Action::WriteAndReadMessage => {
1075                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1076                    self.stream.flush().await?;
1077                    self.stream.read_message(&mut self.buffer_set).await?;
1078                }
1079                Action::Finished => break,
1080                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1081            }
1082        }
1083
1084        // Execute closure with portal handle
1085        let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1086        let result = f(&mut portal).await;
1087
1088        // Always sync to end implicit transaction (even on error)
1089        let sync_result = portal.conn.lowlevel_sync().await;
1090
1091        // Return closure result, or sync error if closure succeeded but sync failed
1092        match (result, sync_result) {
1093            (Ok(v), Ok(())) => Ok(v),
1094            (Err(e), _) => Err(e),
1095            (Ok(_), Err(e)) => Err(e),
1096        }
1097    }
1098
1099    /// Low-level close portal: send Close(Portal) and receive CloseComplete.
1100    pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1101        let result = self.lowlevel_close_portal_inner(portal).await;
1102        if let Err(e) = &result
1103            && e.is_connection_broken()
1104        {
1105            self.is_broken = true;
1106        }
1107        result
1108    }
1109
1110    async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1111        use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1112        use crate::protocol::frontend::{write_close_portal, write_flush};
1113
1114        self.buffer_set.write_buffer.clear();
1115        write_close_portal(&mut self.buffer_set.write_buffer, portal);
1116        write_flush(&mut self.buffer_set.write_buffer);
1117
1118        self.stream.write_all(&self.buffer_set.write_buffer).await?;
1119        self.stream.flush().await?;
1120
1121        loop {
1122            self.stream.read_message(&mut self.buffer_set).await?;
1123            let type_byte = self.buffer_set.type_byte;
1124
1125            if RawMessage::is_async_type(type_byte) {
1126                continue;
1127            }
1128
1129            match type_byte {
1130                msg_type::CLOSE_COMPLETE => {
1131                    CloseComplete::parse(&self.buffer_set.read_buffer)?;
1132                    return Ok(());
1133                }
1134                msg_type::ERROR_RESPONSE => {
1135                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1136                    return Err(error.into_error());
1137                }
1138                _ => {
1139                    return Err(Error::Protocol(format!(
1140                        "Expected CloseComplete or ErrorResponse, got '{}'",
1141                        type_byte as char
1142                    )));
1143                }
1144            }
1145        }
1146    }
1147
1148    /// Run a pipeline of batched queries.
1149    ///
1150    /// Pipeline mode allows sending multiple queries to the server without waiting
1151    /// for responses, reducing round-trip latency.
1152    ///
1153    /// # Example
1154    ///
1155    /// ```ignore
1156    /// // Prepare statements outside the pipeline
1157    /// let stmts = conn.prepare_batch(&[
1158    ///     "SELECT id, name FROM users WHERE active = $1",
1159    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
1160    /// ]).await?;
1161    ///
1162    /// let (active, inactive, count) = conn.run_pipeline(|p| async move {
1163    ///     // Queue executions
1164    ///     let t1 = p.exec(&stmts[0], (true,)).await?;
1165    ///     let t2 = p.exec(&stmts[0], (false,)).await?;
1166    ///     let t3 = p.exec("SELECT COUNT(*) FROM users", ()).await?;
1167    ///
1168    ///     p.sync().await?;
1169    ///
1170    ///     // Claim results in order with different methods
1171    ///     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
1172    ///     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
1173    ///     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
1174    ///
1175    ///     Ok((active, inactive, count))
1176    /// }).await?;
1177    /// ```
1178    pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1179    where
1180        F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1181        Fut: std::future::Future<Output = Result<T>>,
1182    {
1183        let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1184        let result = f(&mut pipeline).await;
1185        pipeline.cleanup().await;
1186        result
1187    }
1188
1189    /// Execute a closure within a transaction.
1190    ///
1191    /// If the closure returns `Ok`, the transaction is committed.
1192    /// If the closure returns `Err` or the transaction is not explicitly
1193    /// committed or rolled back, the transaction is rolled back.
1194    ///
1195    /// # Errors
1196    ///
1197    /// Returns `Error::InvalidUsage` if called while already in a transaction.
1198    pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1199    where
1200        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1201        Fut: std::future::Future<Output = Result<R>>,
1202    {
1203        if self.in_transaction() {
1204            return Err(Error::InvalidUsage(
1205                "nested transactions are not supported".into(),
1206            ));
1207        }
1208
1209        self.query_drop("BEGIN").await?;
1210
1211        let tx = super::transaction::Transaction::new(self.connection_id());
1212
1213        // We need to use unsafe to work around the borrow checker here
1214        // because async closures can't capture &mut self properly
1215        let result = f(self, tx).await;
1216
1217        // If still in a transaction (not committed or rolled back), roll it back
1218        if self.in_transaction() {
1219            let rollback_result = self.query_drop("ROLLBACK").await;
1220
1221            // Return the first error (either from closure or rollback)
1222            if let Err(e) = result {
1223                return Err(e);
1224            }
1225            rollback_result?;
1226        }
1227
1228        result
1229    }
1230}