zero_postgres/tokio/
conn.rs

1//! Asynchronous PostgreSQL connection.
2
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11    AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25
/// Asynchronous PostgreSQL connection.
///
/// Created with [`Conn::new`] or [`Conn::new_with_stream`], which run the startup
/// handshake. All subsequent traffic goes over `stream` using the pooled `buffer_set`.
pub struct Conn {
    /// Transport to the server: TCP or Unix socket, possibly upgraded to TLS during startup.
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup (used for cancellation); `None` if not sent.
    backend_key: Option<BackendKeyData>,
    /// Server parameters collected by the startup state machine.
    server_params: Vec<(String, String)>,
    /// Transaction status reported at the end of the most recently completed exchange.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error that indicates the connection is broken.
    pub(crate) is_broken: bool,
    /// Monotonic counter shared by prepared-statement and portal name generation.
    name_counter: u64,
    /// Optional callback for asynchronous server messages (notifications, notices,
    /// parameter changes) received while driving queries.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39    /// Connect to a PostgreSQL server.
40    pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41    where
42        Error: From<O::Error>,
43    {
44        let opts = opts.try_into()?;
45
46        let stream = if let Some(socket_path) = &opts.socket {
47            #[cfg(unix)]
48            {
49                Stream::unix(UnixStream::connect(socket_path).await?)
50            }
51            #[cfg(not(unix))]
52            {
53                let _ = socket_path;
54                return Err(Error::Unsupported(
55                    "Unix sockets are not supported on this platform".into(),
56                ));
57            }
58        } else {
59            if opts.host.is_empty() {
60                return Err(Error::InvalidUsage("host is empty".into()));
61            }
62            let addr = format!("{}:{}", opts.host, opts.port);
63            let tcp = TcpStream::connect(&addr).await?;
64            tcp.set_nodelay(true)?;
65            Stream::tcp(tcp)
66        };
67
68        Self::new_with_stream(stream, opts).await
69    }
70
71    /// Connect using an existing stream.
72    #[allow(unused_mut)]
73    pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
74        let mut buffer_set = options.buffer_pool.get_buffer_set();
75        let mut state_machine = ConnectionStateMachine::new(options.clone());
76
77        // Drive the connection state machine
78        loop {
79            match state_machine.step(&mut buffer_set)? {
80                Action::WriteAndReadByte => {
81                    stream.write_all(&buffer_set.write_buffer).await?;
82                    stream.flush().await?;
83                    let byte = stream.read_u8().await?;
84                    state_machine.set_ssl_response(byte);
85                }
86                Action::ReadMessage => {
87                    stream.read_message(&mut buffer_set).await?;
88                }
89                Action::Write => {
90                    stream.write_all(&buffer_set.write_buffer).await?;
91                    stream.flush().await?;
92                }
93                Action::WriteAndReadMessage => {
94                    stream.write_all(&buffer_set.write_buffer).await?;
95                    stream.flush().await?;
96                    stream.read_message(&mut buffer_set).await?;
97                }
98                Action::TlsHandshake => {
99                    #[cfg(feature = "tokio-tls")]
100                    {
101                        stream = stream.upgrade_to_tls(&options.host).await?;
102                    }
103                    #[cfg(not(feature = "tokio-tls"))]
104                    {
105                        return Err(Error::Unsupported(
106                            "TLS requested but tokio-tls feature not enabled".into(),
107                        ));
108                    }
109                }
110                Action::HandleAsyncMessageAndReadMessage(_) => {
111                    // Ignore async messages during startup, read next message
112                    stream.read_message(&mut buffer_set).await?;
113                }
114                Action::Finished => break,
115            }
116        }
117
118        let conn = Self {
119            stream,
120            buffer_set,
121            backend_key: state_machine.backend_key().cloned(),
122            server_params: state_machine.take_server_params(),
123            transaction_status: state_machine.transaction_status(),
124            is_broken: false,
125            name_counter: 0,
126            async_message_handler: None,
127        };
128
129        // Upgrade to Unix socket if connected via TCP to loopback
130        #[cfg(unix)]
131        let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
132            conn.try_upgrade_to_unix_socket(&options).await
133        } else {
134            conn
135        };
136
137        Ok(conn)
138    }
139
140    /// Try to upgrade to Unix socket connection.
141    /// Returns upgraded conn on success, original conn on failure.
142    #[cfg(unix)]
143    fn try_upgrade_to_unix_socket(
144        mut self,
145        opts: &Opts,
146    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
147        let opts = opts.clone();
148        Box::pin(async move {
149            // Query unix_socket_directories from server
150            let mut handler = FirstRowHandler::<(String,)>::new();
151            if self
152                .query("SHOW unix_socket_directories", &mut handler)
153                .await
154                .is_err()
155            {
156                return self;
157            }
158
159            let socket_dir = match handler.into_row() {
160                Some((dirs,)) => {
161                    // May contain multiple directories, use the first one
162                    match dirs.split(',').next() {
163                        Some(d) if !d.trim().is_empty() => d.trim().to_string(),
164                        _ => return self,
165                    }
166                }
167                None => return self,
168            };
169
170            // Build socket path: {directory}/.s.PGSQL.{port}
171            let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
172
173            // Connect via Unix socket
174            let unix_stream = match UnixStream::connect(&socket_path).await {
175                Ok(s) => s,
176                Err(_) => return self,
177            };
178
179            // Create new connection over Unix socket
180            let mut opts_unix = opts.clone();
181            opts_unix.upgrade_to_unix_socket = false;
182
183            match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
184                Ok(new_conn) => new_conn,
185                Err(_) => self,
186            }
187        })
188    }
189
190    /// Get the backend key data for query cancellation.
191    pub fn backend_key(&self) -> Option<&BackendKeyData> {
192        self.backend_key.as_ref()
193    }
194
195    /// Get the connection ID (backend process ID).
196    ///
197    /// Returns 0 if the backend key data is not available.
198    pub fn connection_id(&self) -> u32 {
199        self.backend_key.as_ref().map_or(0, |k| k.process_id())
200    }
201
202    /// Get server parameters.
203    pub fn server_params(&self) -> &[(String, String)] {
204        &self.server_params
205    }
206
207    /// Get the current transaction status.
208    pub fn transaction_status(&self) -> TransactionStatus {
209        self.transaction_status
210    }
211
212    /// Check if currently in a transaction.
213    pub fn in_transaction(&self) -> bool {
214        self.transaction_status.in_transaction()
215    }
216
217    /// Check if the connection is broken.
218    pub fn is_broken(&self) -> bool {
219        self.is_broken
220    }
221
222    /// Generate the next unique portal name.
223    pub(crate) fn next_portal_name(&mut self) -> String {
224        self.name_counter += 1;
225        format!("_zero_p_{}", self.name_counter)
226    }
227
228    /// Create a named portal by binding a statement.
229    ///
230    /// Used internally by Transaction::exec_portal.
231    pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
232        &mut self,
233        portal_name: &str,
234        statement: &S,
235        params: &P,
236    ) -> Result<()> {
237        // Create bind state machine for named portal
238        let mut state_machine = if let Some(sql) = statement.as_sql() {
239            BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
240        } else {
241            let stmt = statement.as_prepared().unwrap();
242            BindStateMachine::bind_prepared(
243                &mut self.buffer_set,
244                portal_name,
245                &stmt.wire_name(),
246                &stmt.param_oids,
247                params,
248            )?
249        };
250
251        // Drive the state machine to completion (ParseComplete + BindComplete)
252        loop {
253            match state_machine.step(&mut self.buffer_set)? {
254                Action::ReadMessage => {
255                    self.stream.read_message(&mut self.buffer_set).await?;
256                }
257                Action::Write => {
258                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
259                    self.stream.flush().await?;
260                }
261                Action::WriteAndReadMessage => {
262                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
263                    self.stream.flush().await?;
264                    self.stream.read_message(&mut self.buffer_set).await?;
265                }
266                Action::Finished => break,
267                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Set the async message handler.
275    ///
276    /// The handler is called when the server sends asynchronous messages:
277    /// - `Notification` - from LISTEN/NOTIFY
278    /// - `Notice` - warnings and informational messages
279    /// - `ParameterChanged` - server parameter updates
280    pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
281        self.async_message_handler = Some(Box::new(handler));
282    }
283
284    /// Remove the async message handler.
285    pub fn clear_async_message_handler(&mut self) {
286        self.async_message_handler = None;
287    }
288
289    /// Ping the server with an empty query to check connection aliveness.
290    pub async fn ping(&mut self) -> Result<()> {
291        self.query_drop("").await?;
292        Ok(())
293    }
294
295    /// Drive a state machine to completion.
296    async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
297        loop {
298            match state_machine.step(&mut self.buffer_set)? {
299                Action::WriteAndReadByte => {
300                    return Err(Error::Protocol(
301                        "Unexpected WriteAndReadByte in query state machine".into(),
302                    ));
303                }
304                Action::ReadMessage => {
305                    self.stream.read_message(&mut self.buffer_set).await?;
306                }
307                Action::Write => {
308                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
309                    self.stream.flush().await?;
310                }
311                Action::WriteAndReadMessage => {
312                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
313                    self.stream.flush().await?;
314                    self.stream.read_message(&mut self.buffer_set).await?;
315                }
316                Action::TlsHandshake => {
317                    return Err(Error::Protocol(
318                        "Unexpected TlsHandshake in query state machine".into(),
319                    ));
320                }
321                Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
322                    if let Some(ref mut h) = self.async_message_handler {
323                        h.handle(async_msg);
324                    }
325                    // Read next message after handling async message
326                    self.stream.read_message(&mut self.buffer_set).await?;
327                }
328                Action::Finished => {
329                    self.transaction_status = state_machine.transaction_status();
330                    break;
331                }
332            }
333        }
334        Ok(())
335    }
336
337    /// Execute a simple query with a handler.
338    pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
339        let result = self.query_inner(sql, handler).await;
340        if let Err(e) = &result
341            && e.is_connection_broken()
342        {
343            self.is_broken = true;
344        }
345        result
346    }
347
348    async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349        let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
350        self.drive(&mut state_machine).await
351    }
352
353    /// Execute a simple query and discard results.
354    pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
355        let mut handler = DropHandler::new();
356        self.query(sql, &mut handler).await?;
357        Ok(handler.rows_affected())
358    }
359
360    /// Execute a simple query and collect typed rows.
361    pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
362        &mut self,
363        sql: &str,
364    ) -> Result<Vec<T>> {
365        let mut handler = crate::handler::CollectHandler::<T>::new();
366        self.query(sql, &mut handler).await?;
367        Ok(handler.into_rows())
368    }
369
370    /// Execute a simple query and return the first typed row.
371    pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
372        &mut self,
373        sql: &str,
374    ) -> Result<Option<T>> {
375        let mut handler = crate::handler::FirstRowHandler::<T>::new();
376        self.query(sql, &mut handler).await?;
377        Ok(handler.into_row())
378    }
379
380    /// Execute a simple query and call a closure for each row.
381    ///
382    /// # Example
383    ///
384    /// ```ignore
385    /// conn.query_foreach("SELECT id, name FROM users", |row: (i32, String)| {
386    ///     println!("{}: {}", row.0, row.1);
387    /// }).await?;
388    /// ```
389    pub async fn query_foreach<T: for<'a> crate::conversion::FromRow<'a>, F: FnMut(T)>(
390        &mut self,
391        sql: &str,
392        f: F,
393    ) -> Result<()> {
394        let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
395        self.query(sql, &mut handler).await?;
396        Ok(())
397    }
398
399    /// Close the connection gracefully.
400    pub async fn close(mut self) -> Result<()> {
401        self.buffer_set.write_buffer.clear();
402        write_terminate(&mut self.buffer_set.write_buffer);
403        self.stream.write_all(&self.buffer_set.write_buffer).await?;
404        self.stream.flush().await?;
405        Ok(())
406    }
407
408    // === Extended Query Protocol ===
409
410    /// Prepare a statement using the extended query protocol.
411    pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
412        self.prepare_typed(query, &[]).await
413    }
414
415    /// Prepare a statement with explicit parameter types.
416    pub async fn prepare_typed(
417        &mut self,
418        query: &str,
419        param_oids: &[u32],
420    ) -> Result<PreparedStatement> {
421        self.name_counter += 1;
422        let idx = self.name_counter;
423        let result = self.prepare_inner(idx, query, param_oids).await;
424        if let Err(e) = &result
425            && e.is_connection_broken()
426        {
427            self.is_broken = true;
428        }
429        result
430    }
431
432    /// Prepare multiple statements in a single round-trip.
433    ///
434    /// This is more efficient than calling `prepare()` multiple times when you
435    /// need to prepare several statements, as it batches the network communication.
436    ///
437    /// # Example
438    ///
439    /// ```ignore
440    /// let stmts = conn.prepare_batch(&[
441    ///     "SELECT id, name FROM users WHERE id = $1",
442    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
443    ///     "UPDATE users SET name = $1 WHERE id = $2",
444    /// ]).await?;
445    ///
446    /// // Use stmts[0], stmts[1], stmts[2]...
447    /// ```
448    pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
449        if queries.is_empty() {
450            return Ok(Vec::new());
451        }
452
453        let start_idx = self.name_counter + 1;
454        self.name_counter += queries.len() as u64;
455
456        let result = self.prepare_batch_inner(queries, start_idx).await;
457        if let Err(e) = &result
458            && e.is_connection_broken()
459        {
460            self.is_broken = true;
461        }
462        result
463    }
464
465    async fn prepare_batch_inner(
466        &mut self,
467        queries: &[&str],
468        start_idx: u64,
469    ) -> Result<Vec<PreparedStatement>> {
470        use crate::state::batch_prepare::BatchPrepareStateMachine;
471
472        let mut state_machine =
473            BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
474
475        loop {
476            match state_machine.step(&mut self.buffer_set)? {
477                Action::ReadMessage => {
478                    self.stream.read_message(&mut self.buffer_set).await?;
479                }
480                Action::WriteAndReadMessage => {
481                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
482                    self.stream.flush().await?;
483                    self.stream.read_message(&mut self.buffer_set).await?;
484                }
485                Action::Finished => {
486                    self.transaction_status = state_machine.transaction_status();
487                    break;
488                }
489                _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
490            }
491        }
492
493        Ok(state_machine.take_statements())
494    }
495
496    async fn prepare_inner(
497        &mut self,
498        idx: u64,
499        query: &str,
500        param_oids: &[u32],
501    ) -> Result<PreparedStatement> {
502        let mut handler = DropHandler::new();
503        let mut state_machine = ExtendedQueryStateMachine::prepare(
504            &mut handler,
505            &mut self.buffer_set,
506            idx,
507            query,
508            param_oids,
509        );
510        self.drive(&mut state_machine).await?;
511        state_machine
512            .take_prepared_statement()
513            .ok_or_else(|| Error::Protocol("No prepared statement".into()))
514    }
515
516    /// Execute a statement with a handler.
517    ///
518    /// The statement can be either:
519    /// - A `&PreparedStatement` returned from `prepare()`
520    /// - A raw SQL `&str` for one-shot execution
521    pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
522        &mut self,
523        statement: S,
524        params: P,
525        handler: &mut H,
526    ) -> Result<()> {
527        let result = self.exec_inner(&statement, &params, handler).await;
528        if let Err(e) = &result
529            && e.is_connection_broken()
530        {
531            self.is_broken = true;
532        }
533        result
534    }
535
536    async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
537        &mut self,
538        statement: &S,
539        params: &P,
540        handler: &mut H,
541    ) -> Result<()> {
542        let mut state_machine = if statement.needs_parse() {
543            ExtendedQueryStateMachine::execute_sql(
544                handler,
545                &mut self.buffer_set,
546                statement.as_sql().unwrap(),
547                params,
548            )?
549        } else {
550            let stmt = statement.as_prepared().unwrap();
551            ExtendedQueryStateMachine::execute(
552                handler,
553                &mut self.buffer_set,
554                &stmt.wire_name(),
555                &stmt.param_oids,
556                params,
557            )?
558        };
559
560        self.drive(&mut state_machine).await
561    }
562
563    /// Execute a statement and discard results.
564    ///
565    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
566    pub async fn exec_drop<S: IntoStatement, P: ToParams>(
567        &mut self,
568        statement: S,
569        params: P,
570    ) -> Result<Option<u64>> {
571        let mut handler = DropHandler::new();
572        self.exec(statement, params, &mut handler).await?;
573        Ok(handler.rows_affected())
574    }
575
576    /// Execute a statement and collect typed rows.
577    ///
578    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
579    pub async fn exec_collect<
580        T: for<'a> crate::conversion::FromRow<'a>,
581        S: IntoStatement,
582        P: ToParams,
583    >(
584        &mut self,
585        statement: S,
586        params: P,
587    ) -> Result<Vec<T>> {
588        let mut handler = crate::handler::CollectHandler::<T>::new();
589        self.exec(statement, params, &mut handler).await?;
590        Ok(handler.into_rows())
591    }
592
593    /// Execute a statement and return the first typed row.
594    ///
595    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
596    pub async fn exec_first<
597        T: for<'a> crate::conversion::FromRow<'a>,
598        S: IntoStatement,
599        P: ToParams,
600    >(
601        &mut self,
602        statement: S,
603        params: P,
604    ) -> Result<Option<T>> {
605        let mut handler = crate::handler::FirstRowHandler::<T>::new();
606        self.exec(statement, params, &mut handler).await?;
607        Ok(handler.into_row())
608    }
609
610    /// Execute a statement and call a closure for each row.
611    ///
612    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
613    ///
614    /// # Example
615    ///
616    /// ```ignore
617    /// let stmt = conn.prepare("SELECT id, name FROM users").await?;
618    /// conn.exec_foreach(&stmt, (), |row: (i32, String)| {
619    ///     println!("{}: {}", row.0, row.1);
620    /// }).await?;
621    /// ```
622    pub async fn exec_foreach<
623        T: for<'a> crate::conversion::FromRow<'a>,
624        S: IntoStatement,
625        P: ToParams,
626        F: FnMut(T),
627    >(
628        &mut self,
629        statement: S,
630        params: P,
631        f: F,
632    ) -> Result<()> {
633        let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
634        self.exec(statement, params, &mut handler).await?;
635        Ok(())
636    }
637
638    /// Execute a statement with multiple parameter sets in a batch.
639    ///
640    /// This is more efficient than calling `exec_drop` multiple times as it
641    /// batches the network communication. The statement is parsed once (if raw SQL)
642    /// and then bound/executed for each parameter set.
643    ///
644    /// Parameters are processed in chunks (default 1000) to avoid overwhelming
645    /// the server with too many pending operations.
646    ///
647    /// The statement can be either:
648    /// - A `&PreparedStatement` returned from `prepare()`
649    /// - A raw SQL `&str` for one-shot execution
650    ///
651    /// # Example
652    ///
653    /// ```ignore
654    /// // Using prepared statement
655    /// let stmt = conn.prepare("INSERT INTO users (name, age) VALUES ($1, $2)").await?;
656    /// conn.exec_batch(&stmt, &[
657    ///     ("alice", 30),
658    ///     ("bob", 25),
659    ///     ("charlie", 35),
660    /// ]).await?;
661    ///
662    /// // Using raw SQL
663    /// conn.exec_batch("INSERT INTO users (name, age) VALUES ($1, $2)", &[
664    ///     ("alice", 30),
665    ///     ("bob", 25),
666    /// ]).await?;
667    /// ```
668    pub async fn exec_batch<S: IntoStatement, P: ToParams>(
669        &mut self,
670        statement: S,
671        params_list: &[P],
672    ) -> Result<()> {
673        self.exec_batch_chunked(statement, params_list, 1000).await
674    }
675
676    /// Execute a statement with multiple parameter sets in a batch with custom chunk size.
677    ///
678    /// Same as `exec_batch` but allows specifying the chunk size for batching.
679    pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
680        &mut self,
681        statement: S,
682        params_list: &[P],
683        chunk_size: usize,
684    ) -> Result<()> {
685        let result = self
686            .exec_batch_inner(&statement, params_list, chunk_size)
687            .await;
688        if let Err(e) = &result
689            && e.is_connection_broken()
690        {
691            self.is_broken = true;
692        }
693        result
694    }
695
696    async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
697        &mut self,
698        statement: &S,
699        params_list: &[P],
700        chunk_size: usize,
701    ) -> Result<()> {
702        use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
703        use crate::state::extended::BatchStateMachine;
704
705        if params_list.is_empty() {
706            return Ok(());
707        }
708
709        let chunk_size = chunk_size.max(1);
710        let needs_parse = statement.needs_parse();
711        let sql = statement.as_sql();
712        let prepared = statement.as_prepared();
713
714        // Get param OIDs from first params or prepared statement
715        let param_oids: Vec<u32> = if let Some(stmt) = prepared {
716            stmt.param_oids.clone()
717        } else {
718            params_list[0].natural_oids()
719        };
720
721        // Statement name: empty for raw SQL, actual name for prepared
722        let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
723
724        for chunk in params_list.chunks(chunk_size) {
725            self.buffer_set.write_buffer.clear();
726
727            // For raw SQL, send Parse each chunk (reuses unnamed statement)
728            let parse_in_chunk = needs_parse;
729            if parse_in_chunk {
730                write_parse(
731                    &mut self.buffer_set.write_buffer,
732                    "",
733                    sql.unwrap(),
734                    &param_oids,
735                );
736            }
737
738            // Write Bind + Execute for each param set
739            for params in chunk {
740                let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
741                write_bind(
742                    &mut self.buffer_set.write_buffer,
743                    "",
744                    effective_stmt_name,
745                    params,
746                    &param_oids,
747                )?;
748                write_execute(&mut self.buffer_set.write_buffer, "", 0);
749            }
750
751            // Send Sync
752            write_sync(&mut self.buffer_set.write_buffer);
753
754            // Drive state machine
755            let mut state_machine = BatchStateMachine::new(parse_in_chunk);
756            self.drive_batch(&mut state_machine).await?;
757            self.transaction_status = state_machine.transaction_status();
758        }
759
760        Ok(())
761    }
762
763    /// Drive a batch state machine to completion.
764    async fn drive_batch(
765        &mut self,
766        state_machine: &mut crate::state::extended::BatchStateMachine,
767    ) -> Result<()> {
768        use crate::protocol::backend::{ReadyForQuery, msg_type};
769        use crate::state::action::Action;
770
771        loop {
772            let step_result = state_machine.step(&mut self.buffer_set);
773            match step_result {
774                Ok(Action::ReadMessage) => {
775                    self.stream.read_message(&mut self.buffer_set).await?;
776                }
777                Ok(Action::WriteAndReadMessage) => {
778                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
779                    self.stream.flush().await?;
780                    self.stream.read_message(&mut self.buffer_set).await?;
781                }
782                Ok(Action::Finished) => {
783                    break;
784                }
785                Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
786                Err(e) => {
787                    // On error, drain to ReadyForQuery to leave connection in clean state
788                    loop {
789                        self.stream.read_message(&mut self.buffer_set).await?;
790                        if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
791                            let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
792                            self.transaction_status =
793                                ready.transaction_status().unwrap_or_default();
794                            break;
795                        }
796                    }
797                    return Err(e);
798                }
799            }
800        }
801        Ok(())
802    }
803
804    /// Close a prepared statement.
805    pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
806        let result = self.close_statement_inner(&stmt.wire_name()).await;
807        if let Err(e) = &result
808            && e.is_connection_broken()
809        {
810            self.is_broken = true;
811        }
812        result
813    }
814
815    async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
816        let mut handler = DropHandler::new();
817        let mut state_machine =
818            ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
819        self.drive(&mut state_machine).await
820    }
821
822    // === Low-Level Extended Query Protocol ===
823
824    /// Low-level flush: send FLUSH to force server to send pending responses.
825    ///
826    /// Unlike SYNC, FLUSH does not end the transaction or wait for ReadyForQuery.
827    /// It just forces the server to send any pending responses without ending
828    /// the extended query sequence.
829    pub async fn lowlevel_flush(&mut self) -> Result<()> {
830        use crate::protocol::frontend::write_flush;
831
832        self.buffer_set.write_buffer.clear();
833        write_flush(&mut self.buffer_set.write_buffer);
834
835        self.stream.write_all(&self.buffer_set.write_buffer).await?;
836        self.stream.flush().await?;
837        Ok(())
838    }
839
840    /// Low-level sync: send SYNC and receive ReadyForQuery.
841    ///
842    /// This ends an extended query sequence and:
843    /// - Commits implicit transaction if successful
844    /// - Rolls back implicit transaction if failed
845    /// - Updates transaction status
846    pub async fn lowlevel_sync(&mut self) -> Result<()> {
847        let result = self.lowlevel_sync_inner().await;
848        if let Err(e) = &result
849            && e.is_connection_broken()
850        {
851            self.is_broken = true;
852        }
853        result
854    }
855
856    async fn lowlevel_sync_inner(&mut self) -> Result<()> {
857        use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
858        use crate::protocol::frontend::write_sync;
859
860        self.buffer_set.write_buffer.clear();
861        write_sync(&mut self.buffer_set.write_buffer);
862
863        self.stream.write_all(&self.buffer_set.write_buffer).await?;
864        self.stream.flush().await?;
865
866        let mut pending_error: Option<Error> = None;
867
868        loop {
869            self.stream.read_message(&mut self.buffer_set).await?;
870            let type_byte = self.buffer_set.type_byte;
871
872            if RawMessage::is_async_type(type_byte) {
873                continue;
874            }
875
876            match type_byte {
877                msg_type::READY_FOR_QUERY => {
878                    let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
879                    self.transaction_status = ready.transaction_status().unwrap_or_default();
880                    if let Some(e) = pending_error {
881                        return Err(e);
882                    }
883                    return Ok(());
884                }
885                msg_type::ERROR_RESPONSE => {
886                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
887                    pending_error = Some(error.into_error());
888                }
889                _ => {
890                    // Ignore other messages before ReadyForQuery
891                }
892            }
893        }
894    }
895
896    /// Low-level bind: send BIND message and receive BindComplete.
897    ///
898    /// This allows creating named portals. Unlike `exec()`, this does NOT
899    /// send EXECUTE or SYNC - the caller controls when to execute and sync.
900    ///
901    /// # Arguments
902    /// - `portal`: Portal name (empty string "" for unnamed portal)
903    /// - `statement_name`: Prepared statement name
904    /// - `params`: Parameter values
905    pub async fn lowlevel_bind<P: ToParams>(
906        &mut self,
907        portal: &str,
908        statement_name: &str,
909        params: P,
910    ) -> Result<()> {
911        let result = self
912            .lowlevel_bind_inner(portal, statement_name, &params)
913            .await;
914        if let Err(e) = &result
915            && e.is_connection_broken()
916        {
917            self.is_broken = true;
918        }
919        result
920    }
921
922    async fn lowlevel_bind_inner<P: ToParams>(
923        &mut self,
924        portal: &str,
925        statement_name: &str,
926        params: &P,
927    ) -> Result<()> {
928        use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
929        use crate::protocol::frontend::{write_bind, write_flush};
930
931        let param_oids = params.natural_oids();
932        self.buffer_set.write_buffer.clear();
933        write_bind(
934            &mut self.buffer_set.write_buffer,
935            portal,
936            statement_name,
937            params,
938            &param_oids,
939        )?;
940        write_flush(&mut self.buffer_set.write_buffer);
941
942        self.stream.write_all(&self.buffer_set.write_buffer).await?;
943        self.stream.flush().await?;
944
945        loop {
946            self.stream.read_message(&mut self.buffer_set).await?;
947            let type_byte = self.buffer_set.type_byte;
948
949            if RawMessage::is_async_type(type_byte) {
950                continue;
951            }
952
953            match type_byte {
954                msg_type::BIND_COMPLETE => {
955                    BindComplete::parse(&self.buffer_set.read_buffer)?;
956                    return Ok(());
957                }
958                msg_type::ERROR_RESPONSE => {
959                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
960                    return Err(error.into_error());
961                }
962                _ => {
963                    return Err(Error::Protocol(format!(
964                        "Expected BindComplete or ErrorResponse, got '{}'",
965                        type_byte as char
966                    )));
967                }
968            }
969        }
970    }
971
972    /// Low-level execute: send EXECUTE message and receive results.
973    ///
974    /// Executes a previously bound portal. Does NOT send SYNC.
975    ///
976    /// # Arguments
977    /// - `portal`: Portal name (empty string "" for unnamed portal)
978    /// - `max_rows`: Maximum rows to return (0 = unlimited)
979    /// - `handler`: Handler to receive rows
980    ///
981    /// # Returns
982    /// - `Ok(true)` if more rows available (PortalSuspended received)
983    /// - `Ok(false)` if execution completed (CommandComplete received)
984    pub async fn lowlevel_execute<H: BinaryHandler>(
985        &mut self,
986        portal: &str,
987        max_rows: u32,
988        handler: &mut H,
989    ) -> Result<bool> {
990        let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
991        if let Err(e) = &result
992            && e.is_connection_broken()
993        {
994            self.is_broken = true;
995        }
996        result
997    }
998
    /// Implementation of `lowlevel_execute`: Describe + Execute the portal and
    /// feed the replies to `handler` until CommandComplete, PortalSuspended or
    /// ErrorResponse.
    ///
    /// Returns `Ok(true)` on PortalSuspended (more rows remain) and `Ok(false)`
    /// on CommandComplete. Any message type not listed below yields
    /// `Error::Protocol`.
    async fn lowlevel_execute_inner<H: BinaryHandler>(
        &mut self,
        portal: &str,
        max_rows: u32,
        handler: &mut H,
    ) -> Result<bool> {
        use crate::protocol::backend::{
            CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
            RowDescription, msg_type,
        };
        use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};

        // Describe(portal) is sent before Execute so the row description (or
        // NoData) precedes the rows. It is re-sent on every call, including
        // follow-up fetches on a suspended portal. Flush asks for the replies
        // without issuing a Sync.
        self.buffer_set.write_buffer.clear();
        write_describe_portal(&mut self.buffer_set.write_buffer, portal);
        write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
        write_flush(&mut self.buffer_set.write_buffer);

        self.stream.write_all(&self.buffer_set.write_buffer).await?;
        self.stream.flush().await?;

        // Owned copy of the latest RowDescription payload. `read_message`
        // refills `buffer_set.read_buffer` for each message, so the description
        // is copied out to remain parseable while later DataRows are read.
        let mut column_buffer: Vec<u8> = Vec::new();

        loop {
            self.stream.read_message(&mut self.buffer_set).await?;
            let type_byte = self.buffer_set.type_byte;

            // Asynchronous messages are skipped here without being dispatched.
            if RawMessage::is_async_type(type_byte) {
                continue;
            }

            match type_byte {
                msg_type::ROW_DESCRIPTION => {
                    column_buffer.clear();
                    column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
                    let cols = RowDescription::parse(&column_buffer)?;
                    handler.result_start(cols)?;
                }
                msg_type::NO_DATA => {
                    // Portal returns no rows; there is no description to save.
                    NoData::parse(&self.buffer_set.read_buffer)?;
                }
                msg_type::DATA_ROW => {
                    // The saved description is re-parsed for every row.
                    // NOTE(review): if no RowDescription preceded this row,
                    // `column_buffer` is empty. Confirm that `RowDescription::parse`
                    // rejects an empty buffer rather than yielding zero columns.
                    let cols = RowDescription::parse(&column_buffer)?;
                    let row = DataRow::parse(&self.buffer_set.read_buffer)?;
                    handler.row(cols, row)?;
                }
                msg_type::COMMAND_COMPLETE => {
                    let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
                    handler.result_end(complete)?;
                    return Ok(false); // No more rows
                }
                msg_type::PORTAL_SUSPENDED => {
                    // `max_rows` was reached; the portal can be executed again.
                    PortalSuspended::parse(&self.buffer_set.read_buffer)?;
                    return Ok(true); // More rows available
                }
                msg_type::ERROR_RESPONSE => {
                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
                    return Err(error.into_error());
                }
                // NOTE(review): EmptyQueryResponse (sent instead of
                // CommandComplete for an empty query string) is not matched
                // above and would surface as this protocol error. Confirm
                // whether that is intended.
                _ => {
                    return Err(Error::Protocol(format!(
                        "Unexpected message in execute: '{}'",
                        type_byte as char
                    )));
                }
            }
        }
    }
1066
1067    /// Execute a statement with iterative row fetching using a portal.
1068    ///
1069    /// Creates an unnamed portal and passes it to the closure. The closure can
1070    /// call `portal.fetch(n, handler)` multiple times to retrieve rows in batches.
1071    /// Sync is called after the closure returns to end the implicit transaction.
1072    ///
1073    /// The statement can be either:
1074    /// - A `&PreparedStatement` returned from `prepare()`
1075    /// - A raw SQL `&str` for one-shot execution
1076    ///
1077    /// # Example
1078    /// ```ignore
1079    /// // Using prepared statement
1080    /// let stmt = conn.prepare("SELECT * FROM users").await?;
1081    /// conn.exec_portal(&stmt, (), |portal| async move {
1082    ///     while portal.fetch(100, &mut handler).await? {
1083    ///         // process handler.into_rows()...
1084    ///     }
1085    ///     Ok(())
1086    /// }).await?;
1087    ///
1088    /// // Using raw SQL
1089    /// conn.exec_portal("SELECT * FROM users", (), |portal| async move {
1090    ///     while portal.fetch(100, &mut handler).await? {
1091    ///         // process handler.into_rows()...
1092    ///     }
1093    ///     Ok(())
1094    /// }).await?;
1095    /// ```
1096    pub async fn exec_portal<S: IntoStatement, P, F, Fut, T>(
1097        &mut self,
1098        statement: S,
1099        params: P,
1100        f: F,
1101    ) -> Result<T>
1102    where
1103        P: ToParams,
1104        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1105        Fut: std::future::Future<Output = Result<T>>,
1106    {
1107        let result = self.exec_portal_inner(&statement, &params, f).await;
1108        if let Err(e) = &result
1109            && e.is_connection_broken()
1110        {
1111            self.is_broken = true;
1112        }
1113        result
1114    }
1115
1116    async fn exec_portal_inner<S: IntoStatement, P, F, Fut, T>(
1117        &mut self,
1118        statement: &S,
1119        params: &P,
1120        f: F,
1121    ) -> Result<T>
1122    where
1123        P: ToParams,
1124        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1125        Fut: std::future::Future<Output = Result<T>>,
1126    {
1127        // Create bind state machine for unnamed portal
1128        let mut state_machine = if let Some(sql) = statement.as_sql() {
1129            BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1130        } else {
1131            let stmt = statement.as_prepared().unwrap();
1132            BindStateMachine::bind_prepared(
1133                &mut self.buffer_set,
1134                "",
1135                &stmt.wire_name(),
1136                &stmt.param_oids,
1137                params,
1138            )?
1139        };
1140
1141        // Drive the state machine to completion (ParseComplete + BindComplete)
1142        loop {
1143            match state_machine.step(&mut self.buffer_set)? {
1144                Action::ReadMessage => {
1145                    self.stream.read_message(&mut self.buffer_set).await?;
1146                }
1147                Action::Write => {
1148                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1149                    self.stream.flush().await?;
1150                }
1151                Action::WriteAndReadMessage => {
1152                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1153                    self.stream.flush().await?;
1154                    self.stream.read_message(&mut self.buffer_set).await?;
1155                }
1156                Action::Finished => break,
1157                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1158            }
1159        }
1160
1161        // Execute closure with portal handle
1162        let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1163        let result = f(&mut portal).await;
1164
1165        // Always sync to end implicit transaction (even on error)
1166        let sync_result = portal.conn.lowlevel_sync().await;
1167
1168        // Return closure result, or sync error if closure succeeded but sync failed
1169        match (result, sync_result) {
1170            (Ok(v), Ok(())) => Ok(v),
1171            (Err(e), _) => Err(e),
1172            (Ok(_), Err(e)) => Err(e),
1173        }
1174    }
1175
1176    /// Low-level close portal: send Close(Portal) and receive CloseComplete.
1177    pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1178        let result = self.lowlevel_close_portal_inner(portal).await;
1179        if let Err(e) = &result
1180            && e.is_connection_broken()
1181        {
1182            self.is_broken = true;
1183        }
1184        result
1185    }
1186
1187    async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1188        use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1189        use crate::protocol::frontend::{write_close_portal, write_flush};
1190
1191        self.buffer_set.write_buffer.clear();
1192        write_close_portal(&mut self.buffer_set.write_buffer, portal);
1193        write_flush(&mut self.buffer_set.write_buffer);
1194
1195        self.stream.write_all(&self.buffer_set.write_buffer).await?;
1196        self.stream.flush().await?;
1197
1198        loop {
1199            self.stream.read_message(&mut self.buffer_set).await?;
1200            let type_byte = self.buffer_set.type_byte;
1201
1202            if RawMessage::is_async_type(type_byte) {
1203                continue;
1204            }
1205
1206            match type_byte {
1207                msg_type::CLOSE_COMPLETE => {
1208                    CloseComplete::parse(&self.buffer_set.read_buffer)?;
1209                    return Ok(());
1210                }
1211                msg_type::ERROR_RESPONSE => {
1212                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1213                    return Err(error.into_error());
1214                }
1215                _ => {
1216                    return Err(Error::Protocol(format!(
1217                        "Expected CloseComplete or ErrorResponse, got '{}'",
1218                        type_byte as char
1219                    )));
1220                }
1221            }
1222        }
1223    }
1224
1225    /// Run a pipeline of batched queries.
1226    ///
1227    /// Pipeline mode allows sending multiple queries to the server without waiting
1228    /// for responses, reducing round-trip latency.
1229    ///
1230    /// # Example
1231    ///
1232    /// ```ignore
1233    /// // Prepare statements outside the pipeline
1234    /// let stmts = conn.prepare_batch(&[
1235    ///     "SELECT id, name FROM users WHERE active = $1",
1236    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
1237    /// ]).await?;
1238    ///
1239    /// let (active, inactive, count) = conn.run_pipeline(|p| async move {
1240    ///     // Queue executions
1241    ///     let t1 = p.exec(&stmts[0], (true,)).await?;
1242    ///     let t2 = p.exec(&stmts[0], (false,)).await?;
1243    ///     let t3 = p.exec("SELECT COUNT(*) FROM users", ()).await?;
1244    ///
1245    ///     p.sync().await?;
1246    ///
1247    ///     // Claim results in order with different methods
1248    ///     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
1249    ///     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
1250    ///     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
1251    ///
1252    ///     Ok((active, inactive, count))
1253    /// }).await?;
1254    /// ```
1255    pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1256    where
1257        F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1258        Fut: std::future::Future<Output = Result<T>>,
1259    {
1260        let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1261        let result = f(&mut pipeline).await;
1262        pipeline.cleanup().await;
1263        result
1264    }
1265
1266    /// Execute a closure within a transaction.
1267    ///
1268    /// If the closure returns `Ok`, the transaction is committed.
1269    /// If the closure returns `Err` or the transaction is not explicitly
1270    /// committed or rolled back, the transaction is rolled back.
1271    ///
1272    /// # Errors
1273    ///
1274    /// Returns `Error::InvalidUsage` if called while already in a transaction.
1275    pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1276    where
1277        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1278        Fut: std::future::Future<Output = Result<R>>,
1279    {
1280        if self.in_transaction() {
1281            return Err(Error::InvalidUsage(
1282                "nested transactions are not supported".into(),
1283            ));
1284        }
1285
1286        self.query_drop("BEGIN").await?;
1287
1288        let tx = super::transaction::Transaction::new(self.connection_id());
1289
1290        // We need to use unsafe to work around the borrow checker here
1291        // because async closures can't capture &mut self properly
1292        let result = f(self, tx).await;
1293
1294        // If still in a transaction (not committed or rolled back), roll it back
1295        if self.in_transaction() {
1296            let rollback_result = self.query_drop("ROLLBACK").await;
1297
1298            // Return the first error (either from closure or rollback)
1299            if let Err(e) = result {
1300                return Err(e);
1301            }
1302            rollback_result?;
1303        }
1304
1305        result
1306    }
1307}