zero_postgres/tokio/
conn.rs

1//! Asynchronous PostgreSQL connection.
2
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11    AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25
/// Asynchronous PostgreSQL connection.
///
/// Owns the transport stream and its I/O buffers, plus per-session state that
/// is captured during the startup handshake and refreshed as protocol
/// exchanges complete.
pub struct Conn {
    /// Transport to the server: TCP or Unix-domain socket, optionally upgraded to TLS.
    pub(crate) stream: Stream,
    /// Read/write buffers checked out of the connection options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup (used for query cancellation);
    /// `None` if none was received.
    backend_key: Option<BackendKeyData>,
    /// `(name, value)` server parameters collected during startup. Later
    /// parameter-change messages are routed to the async message handler and
    /// are not applied here.
    server_params: Vec<(String, String)>,
    /// Transaction status from the most recently completed exchange.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error classified as connection-breaking.
    pub(crate) is_broken: bool,
    /// Monotonic counter shared by prepared-statement indices and portal names.
    name_counter: u64,
    /// Optional callback for asynchronous server messages (notifications,
    /// notices, parameter changes).
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39    /// Connect to a PostgreSQL server.
40    pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41    where
42        Error: From<O::Error>,
43    {
44        let opts = opts.try_into()?;
45
46        let stream = if let Some(socket_path) = &opts.socket {
47            #[cfg(unix)]
48            {
49                Stream::unix(UnixStream::connect(socket_path).await?)
50            }
51            #[cfg(not(unix))]
52            {
53                let _ = socket_path;
54                return Err(Error::Unsupported(
55                    "Unix sockets are not supported on this platform".into(),
56                ));
57            }
58        } else {
59            if opts.host.is_empty() {
60                return Err(Error::InvalidUsage("host is empty".into()));
61            }
62            let addr = format!("{}:{}", opts.host, opts.port);
63            let tcp = TcpStream::connect(&addr).await?;
64            tcp.set_nodelay(true)?;
65            Stream::tcp(tcp)
66        };
67
68        Self::new_with_stream(stream, opts).await
69    }
70
    /// Connect using an existing stream.
    ///
    /// Runs the startup handshake over `stream` by driving a
    /// `ConnectionStateMachine` until it reports `Finished`. It then builds the
    /// `Conn` from the negotiated state: backend key, server parameters and
    /// transaction status.
    ///
    /// On Unix, if `options.upgrade_to_unix_socket` is set and `stream` is a
    /// TCP loopback connection, it then attempts to reconnect over a
    /// Unix-domain socket. If that attempt fails, the TCP connection is kept.
    ///
    /// # Errors
    ///
    /// Propagates I/O and state-machine errors from the handshake. Returns
    /// `Error::Unsupported` if the state machine requests a TLS handshake
    /// while the `tokio-tls` feature is disabled.
    // `stream` is only reassigned by the TLS upgrade, which is compiled out when
    // the `tokio-tls` feature is disabled; silence the lint for that configuration.
    #[allow(unused_mut)]
    pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
        let mut buffer_set = options.buffer_pool.get_buffer_set();
        let mut state_machine = ConnectionStateMachine::new(options.clone());

        // Drive the connection state machine
        loop {
            match state_machine.step(&mut buffer_set)? {
                Action::WriteAndReadByte => {
                    // Send the pending request, then read the server's single-byte
                    // reply and hand it to the state machine as the SSL response.
                    stream.write_all(&buffer_set.write_buffer).await?;
                    stream.flush().await?;
                    let byte = stream.read_u8().await?;
                    state_machine.set_ssl_response(byte);
                }
                Action::ReadMessage => {
                    stream.read_message(&mut buffer_set).await?;
                }
                Action::Write => {
                    stream.write_all(&buffer_set.write_buffer).await?;
                    stream.flush().await?;
                }
                Action::WriteAndReadMessage => {
                    stream.write_all(&buffer_set.write_buffer).await?;
                    stream.flush().await?;
                    stream.read_message(&mut buffer_set).await?;
                }
                Action::TlsHandshake => {
                    #[cfg(feature = "tokio-tls")]
                    {
                        // Replace the plain stream with its TLS-wrapped form; the host
                        // is passed along (presumably for certificate/SNI checks).
                        stream = stream.upgrade_to_tls(&options.host).await?;
                    }
                    #[cfg(not(feature = "tokio-tls"))]
                    {
                        return Err(Error::Unsupported(
                            "TLS requested but tokio-tls feature not enabled".into(),
                        ));
                    }
                }
                Action::HandleAsyncMessageAndReadMessage(_) => {
                    // Ignore async messages during startup, read next message
                    stream.read_message(&mut buffer_set).await?;
                }
                Action::Finished => break,
            }
        }

        let conn = Self {
            stream,
            buffer_set,
            backend_key: state_machine.backend_key().cloned(),
            server_params: state_machine.take_server_params(),
            transaction_status: state_machine.transaction_status(),
            is_broken: false,
            name_counter: 0,
            async_message_handler: None,
        };

        // Upgrade to Unix socket if connected via TCP to loopback
        // (best-effort: `try_upgrade_to_unix_socket` returns the original
        // connection on any failure).
        #[cfg(unix)]
        let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
            conn.try_upgrade_to_unix_socket(&options).await
        } else {
            conn
        };

        Ok(conn)
    }
139
140    /// Try to upgrade to Unix socket connection.
141    /// Returns upgraded conn on success, original conn on failure.
142    #[cfg(unix)]
143    fn try_upgrade_to_unix_socket(
144        mut self,
145        opts: &Opts,
146    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
147        let opts = opts.clone();
148        Box::pin(async move {
149            // Query unix_socket_directories from server
150            let mut handler = FirstRowHandler::<(String,)>::new();
151            if self
152                .query("SHOW unix_socket_directories", &mut handler)
153                .await
154                .is_err()
155            {
156                return self;
157            }
158
159            let socket_dir = match handler.into_row() {
160                Some((dirs,)) => {
161                    // May contain multiple directories, use the first one
162                    match dirs.split(',').next() {
163                        Some(d) if !d.trim().is_empty() => d.trim().to_string(),
164                        _ => return self,
165                    }
166                }
167                None => return self,
168            };
169
170            // Build socket path: {directory}/.s.PGSQL.{port}
171            let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
172
173            // Connect via Unix socket
174            let unix_stream = match UnixStream::connect(&socket_path).await {
175                Ok(s) => s,
176                Err(_) => return self,
177            };
178
179            // Create new connection over Unix socket
180            let mut opts_unix = opts.clone();
181            opts_unix.upgrade_to_unix_socket = false;
182
183            match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
184                Ok(new_conn) => new_conn,
185                Err(_) => self,
186            }
187        })
188    }
189
190    /// Get the backend key data for query cancellation.
191    pub fn backend_key(&self) -> Option<&BackendKeyData> {
192        self.backend_key.as_ref()
193    }
194
195    /// Get the connection ID (backend process ID).
196    ///
197    /// Returns 0 if the backend key data is not available.
198    pub fn connection_id(&self) -> u32 {
199        self.backend_key.as_ref().map_or(0, |k| k.process_id())
200    }
201
202    /// Get server parameters.
203    pub fn server_params(&self) -> &[(String, String)] {
204        &self.server_params
205    }
206
207    /// Get the current transaction status.
208    pub fn transaction_status(&self) -> TransactionStatus {
209        self.transaction_status
210    }
211
212    /// Check if currently in a transaction.
213    pub fn in_transaction(&self) -> bool {
214        self.transaction_status.in_transaction()
215    }
216
217    /// Check if the connection is broken.
218    pub fn is_broken(&self) -> bool {
219        self.is_broken
220    }
221
222    /// Generate the next unique portal name.
223    pub(crate) fn next_portal_name(&mut self) -> String {
224        self.name_counter += 1;
225        format!("_zero_p_{}", self.name_counter)
226    }
227
228    /// Create a named portal by binding a statement.
229    ///
230    /// Used internally by Transaction::exec_portal.
231    pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
232        &mut self,
233        portal_name: &str,
234        statement: &S,
235        params: &P,
236    ) -> Result<()> {
237        // Create bind state machine for named portal
238        let mut state_machine = if let Some(sql) = statement.as_sql() {
239            BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
240        } else {
241            let stmt = statement.as_prepared().unwrap();
242            BindStateMachine::bind_prepared(
243                &mut self.buffer_set,
244                portal_name,
245                &stmt.wire_name(),
246                &stmt.param_oids,
247                params,
248            )?
249        };
250
251        // Drive the state machine to completion (ParseComplete + BindComplete)
252        loop {
253            match state_machine.step(&mut self.buffer_set)? {
254                Action::ReadMessage => {
255                    self.stream.read_message(&mut self.buffer_set).await?;
256                }
257                Action::Write => {
258                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
259                    self.stream.flush().await?;
260                }
261                Action::WriteAndReadMessage => {
262                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
263                    self.stream.flush().await?;
264                    self.stream.read_message(&mut self.buffer_set).await?;
265                }
266                Action::Finished => break,
267                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Set the async message handler.
275    ///
276    /// The handler is called when the server sends asynchronous messages:
277    /// - `Notification` - from LISTEN/NOTIFY
278    /// - `Notice` - warnings and informational messages
279    /// - `ParameterChanged` - server parameter updates
280    pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
281        self.async_message_handler = Some(Box::new(handler));
282    }
283
284    /// Remove the async message handler.
285    pub fn clear_async_message_handler(&mut self) {
286        self.async_message_handler = None;
287    }
288
289    /// Ping the server with an empty query to check connection aliveness.
290    pub async fn ping(&mut self) -> Result<()> {
291        self.query_drop("").await?;
292        Ok(())
293    }
294
    /// Drive a state machine to completion.
    ///
    /// Repeatedly calls `state_machine.step` and performs the I/O each
    /// returned [`Action`] requests, until the machine reports `Finished`.
    /// Only then is the connection's transaction status copied from the
    /// machine.
    ///
    /// Asynchronous server messages surfaced by the machine are passed to the
    /// installed async message handler (if any) before the next message is
    /// read.
    ///
    /// # Errors
    ///
    /// Propagates errors from `step` and from stream I/O. Returns
    /// `Error::Protocol` if the machine requests an SSL-byte exchange or a TLS
    /// handshake; those actions are only handled during connection startup.
    async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
        loop {
            match state_machine.step(&mut self.buffer_set)? {
                Action::WriteAndReadByte => {
                    return Err(Error::Protocol(
                        "Unexpected WriteAndReadByte in query state machine".into(),
                    ));
                }
                Action::ReadMessage => {
                    self.stream.read_message(&mut self.buffer_set).await?;
                }
                Action::Write => {
                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
                    self.stream.flush().await?;
                }
                Action::WriteAndReadMessage => {
                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
                    self.stream.flush().await?;
                    self.stream.read_message(&mut self.buffer_set).await?;
                }
                Action::TlsHandshake => {
                    return Err(Error::Protocol(
                        "Unexpected TlsHandshake in query state machine".into(),
                    ));
                }
                Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
                    if let Some(ref mut h) = self.async_message_handler {
                        h.handle(async_msg);
                    }
                    // Read next message after handling async message
                    self.stream.read_message(&mut self.buffer_set).await?;
                }
                Action::Finished => {
                    // Record the status the machine observed at the end of the exchange.
                    self.transaction_status = state_machine.transaction_status();
                    break;
                }
            }
        }
        Ok(())
    }
336
337    /// Execute a simple query with a handler.
338    pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
339        let result = self.query_inner(sql, handler).await;
340        if let Err(e) = &result
341            && e.is_connection_broken()
342        {
343            self.is_broken = true;
344        }
345        result
346    }
347
348    async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349        let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
350        self.drive(&mut state_machine).await
351    }
352
353    /// Execute a simple query and discard results.
354    pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
355        let mut handler = DropHandler::new();
356        self.query(sql, &mut handler).await?;
357        Ok(handler.rows_affected())
358    }
359
360    /// Execute a simple query and collect typed rows.
361    pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
362        &mut self,
363        sql: &str,
364    ) -> Result<Vec<T>> {
365        let mut handler = crate::handler::CollectHandler::<T>::new();
366        self.query(sql, &mut handler).await?;
367        Ok(handler.into_rows())
368    }
369
370    /// Execute a simple query and return the first typed row.
371    pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
372        &mut self,
373        sql: &str,
374    ) -> Result<Option<T>> {
375        let mut handler = crate::handler::FirstRowHandler::<T>::new();
376        self.query(sql, &mut handler).await?;
377        Ok(handler.into_row())
378    }
379
380    /// Execute a simple query and call a closure for each row.
381    ///
382    /// # Example
383    ///
384    /// ```ignore
385    /// conn.query_foreach("SELECT id, name FROM users", |row: (i32, String)| {
386    ///     println!("{}: {}", row.0, row.1);
387    ///     Ok(())
388    /// }).await?;
389    /// ```
390    ///
391    /// The closure can return an error to stop iteration early.
392    pub async fn query_foreach<
393        T: for<'a> crate::conversion::FromRow<'a>,
394        F: FnMut(T) -> Result<()>,
395    >(
396        &mut self,
397        sql: &str,
398        f: F,
399    ) -> Result<()> {
400        let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
401        self.query(sql, &mut handler).await?;
402        Ok(())
403    }
404
405    /// Close the connection gracefully.
406    pub async fn close(mut self) -> Result<()> {
407        self.buffer_set.write_buffer.clear();
408        write_terminate(&mut self.buffer_set.write_buffer);
409        self.stream.write_all(&self.buffer_set.write_buffer).await?;
410        self.stream.flush().await?;
411        Ok(())
412    }
413
414    // === Extended Query Protocol ===
415
416    /// Prepare a statement using the extended query protocol.
417    pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
418        self.prepare_typed(query, &[]).await
419    }
420
421    /// Prepare a statement with explicit parameter types.
422    pub async fn prepare_typed(
423        &mut self,
424        query: &str,
425        param_oids: &[u32],
426    ) -> Result<PreparedStatement> {
427        self.name_counter += 1;
428        let idx = self.name_counter;
429        let result = self.prepare_inner(idx, query, param_oids).await;
430        if let Err(e) = &result
431            && e.is_connection_broken()
432        {
433            self.is_broken = true;
434        }
435        result
436    }
437
438    /// Prepare multiple statements in a single round-trip.
439    ///
440    /// This is more efficient than calling `prepare()` multiple times when you
441    /// need to prepare several statements, as it batches the network communication.
442    ///
443    /// # Example
444    ///
445    /// ```ignore
446    /// let stmts = conn.prepare_batch(&[
447    ///     "SELECT id, name FROM users WHERE id = $1",
448    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
449    ///     "UPDATE users SET name = $1 WHERE id = $2",
450    /// ]).await?;
451    ///
452    /// // Use stmts[0], stmts[1], stmts[2]...
453    /// ```
454    pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
455        if queries.is_empty() {
456            return Ok(Vec::new());
457        }
458
459        let start_idx = self.name_counter + 1;
460        self.name_counter += queries.len() as u64;
461
462        let result = self.prepare_batch_inner(queries, start_idx).await;
463        if let Err(e) = &result
464            && e.is_connection_broken()
465        {
466            self.is_broken = true;
467        }
468        result
469    }
470
471    async fn prepare_batch_inner(
472        &mut self,
473        queries: &[&str],
474        start_idx: u64,
475    ) -> Result<Vec<PreparedStatement>> {
476        use crate::state::batch_prepare::BatchPrepareStateMachine;
477
478        let mut state_machine =
479            BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
480
481        loop {
482            match state_machine.step(&mut self.buffer_set)? {
483                Action::ReadMessage => {
484                    self.stream.read_message(&mut self.buffer_set).await?;
485                }
486                Action::WriteAndReadMessage => {
487                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
488                    self.stream.flush().await?;
489                    self.stream.read_message(&mut self.buffer_set).await?;
490                }
491                Action::Finished => {
492                    self.transaction_status = state_machine.transaction_status();
493                    break;
494                }
495                _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
496            }
497        }
498
499        Ok(state_machine.take_statements())
500    }
501
502    async fn prepare_inner(
503        &mut self,
504        idx: u64,
505        query: &str,
506        param_oids: &[u32],
507    ) -> Result<PreparedStatement> {
508        let mut handler = DropHandler::new();
509        let mut state_machine = ExtendedQueryStateMachine::prepare(
510            &mut handler,
511            &mut self.buffer_set,
512            idx,
513            query,
514            param_oids,
515        );
516        self.drive(&mut state_machine).await?;
517        state_machine
518            .take_prepared_statement()
519            .ok_or_else(|| Error::Protocol("No prepared statement".into()))
520    }
521
522    /// Execute a statement with a handler.
523    ///
524    /// The statement can be either:
525    /// - A `&PreparedStatement` returned from `prepare()`
526    /// - A raw SQL `&str` for one-shot execution
527    pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
528        &mut self,
529        statement: S,
530        params: P,
531        handler: &mut H,
532    ) -> Result<()> {
533        let result = self.exec_inner(&statement, &params, handler).await;
534        if let Err(e) = &result
535            && e.is_connection_broken()
536        {
537            self.is_broken = true;
538        }
539        result
540    }
541
542    async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
543        &mut self,
544        statement: &S,
545        params: &P,
546        handler: &mut H,
547    ) -> Result<()> {
548        let mut state_machine = if statement.needs_parse() {
549            ExtendedQueryStateMachine::execute_sql(
550                handler,
551                &mut self.buffer_set,
552                statement.as_sql().unwrap(),
553                params,
554            )?
555        } else {
556            let stmt = statement.as_prepared().unwrap();
557            ExtendedQueryStateMachine::execute(
558                handler,
559                &mut self.buffer_set,
560                &stmt.wire_name(),
561                &stmt.param_oids,
562                params,
563            )?
564        };
565
566        self.drive(&mut state_machine).await
567    }
568
569    /// Execute a statement and discard results.
570    ///
571    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
572    pub async fn exec_drop<S: IntoStatement, P: ToParams>(
573        &mut self,
574        statement: S,
575        params: P,
576    ) -> Result<Option<u64>> {
577        let mut handler = DropHandler::new();
578        self.exec(statement, params, &mut handler).await?;
579        Ok(handler.rows_affected())
580    }
581
582    /// Execute a statement and collect typed rows.
583    ///
584    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
585    pub async fn exec_collect<
586        T: for<'a> crate::conversion::FromRow<'a>,
587        S: IntoStatement,
588        P: ToParams,
589    >(
590        &mut self,
591        statement: S,
592        params: P,
593    ) -> Result<Vec<T>> {
594        let mut handler = crate::handler::CollectHandler::<T>::new();
595        self.exec(statement, params, &mut handler).await?;
596        Ok(handler.into_rows())
597    }
598
599    /// Execute a statement and return the first typed row.
600    ///
601    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
602    pub async fn exec_first<
603        T: for<'a> crate::conversion::FromRow<'a>,
604        S: IntoStatement,
605        P: ToParams,
606    >(
607        &mut self,
608        statement: S,
609        params: P,
610    ) -> Result<Option<T>> {
611        let mut handler = crate::handler::FirstRowHandler::<T>::new();
612        self.exec(statement, params, &mut handler).await?;
613        Ok(handler.into_row())
614    }
615
616    /// Execute a statement and call a closure for each row.
617    ///
618    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
619    ///
620    /// # Example
621    ///
622    /// ```ignore
623    /// let stmt = conn.prepare("SELECT id, name FROM users").await?;
624    /// conn.exec_foreach(&stmt, (), |row: (i32, String)| {
625    ///     println!("{}: {}", row.0, row.1);
626    ///     Ok(())
627    /// }).await?;
628    /// ```
629    ///
630    /// The closure can return an error to stop iteration early.
631    pub async fn exec_foreach<
632        T: for<'a> crate::conversion::FromRow<'a>,
633        S: IntoStatement,
634        P: ToParams,
635        F: FnMut(T) -> Result<()>,
636    >(
637        &mut self,
638        statement: S,
639        params: P,
640        f: F,
641    ) -> Result<()> {
642        let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
643        self.exec(statement, params, &mut handler).await?;
644        Ok(())
645    }
646
647    /// Execute a statement with multiple parameter sets in a batch.
648    ///
649    /// This is more efficient than calling `exec_drop` multiple times as it
650    /// batches the network communication. The statement is parsed once (if raw SQL)
651    /// and then bound/executed for each parameter set.
652    ///
653    /// Parameters are processed in chunks (default 1000) to avoid overwhelming
654    /// the server with too many pending operations.
655    ///
656    /// The statement can be either:
657    /// - A `&PreparedStatement` returned from `prepare()`
658    /// - A raw SQL `&str` for one-shot execution
659    ///
660    /// # Example
661    ///
662    /// ```ignore
663    /// // Using prepared statement
664    /// let stmt = conn.prepare("INSERT INTO users (name, age) VALUES ($1, $2)").await?;
665    /// conn.exec_batch(&stmt, &[
666    ///     ("alice", 30),
667    ///     ("bob", 25),
668    ///     ("charlie", 35),
669    /// ]).await?;
670    ///
671    /// // Using raw SQL
672    /// conn.exec_batch("INSERT INTO users (name, age) VALUES ($1, $2)", &[
673    ///     ("alice", 30),
674    ///     ("bob", 25),
675    /// ]).await?;
676    /// ```
677    pub async fn exec_batch<S: IntoStatement, P: ToParams>(
678        &mut self,
679        statement: S,
680        params_list: &[P],
681    ) -> Result<()> {
682        self.exec_batch_chunked(statement, params_list, 1000).await
683    }
684
685    /// Execute a statement with multiple parameter sets in a batch with custom chunk size.
686    ///
687    /// Same as `exec_batch` but allows specifying the chunk size for batching.
688    pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
689        &mut self,
690        statement: S,
691        params_list: &[P],
692        chunk_size: usize,
693    ) -> Result<()> {
694        let result = self
695            .exec_batch_inner(&statement, params_list, chunk_size)
696            .await;
697        if let Err(e) = &result
698            && e.is_connection_broken()
699        {
700            self.is_broken = true;
701        }
702        result
703    }
704
705    async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
706        &mut self,
707        statement: &S,
708        params_list: &[P],
709        chunk_size: usize,
710    ) -> Result<()> {
711        use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
712        use crate::state::extended::BatchStateMachine;
713
714        if params_list.is_empty() {
715            return Ok(());
716        }
717
718        let chunk_size = chunk_size.max(1);
719        let needs_parse = statement.needs_parse();
720        let sql = statement.as_sql();
721        let prepared = statement.as_prepared();
722
723        // Get param OIDs from first params or prepared statement
724        let param_oids: Vec<u32> = if let Some(stmt) = prepared {
725            stmt.param_oids.clone()
726        } else {
727            params_list[0].natural_oids()
728        };
729
730        // Statement name: empty for raw SQL, actual name for prepared
731        let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
732
733        for chunk in params_list.chunks(chunk_size) {
734            self.buffer_set.write_buffer.clear();
735
736            // For raw SQL, send Parse each chunk (reuses unnamed statement)
737            let parse_in_chunk = needs_parse;
738            if parse_in_chunk {
739                write_parse(
740                    &mut self.buffer_set.write_buffer,
741                    "",
742                    sql.unwrap(),
743                    &param_oids,
744                );
745            }
746
747            // Write Bind + Execute for each param set
748            for params in chunk {
749                let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
750                write_bind(
751                    &mut self.buffer_set.write_buffer,
752                    "",
753                    effective_stmt_name,
754                    params,
755                    &param_oids,
756                )?;
757                write_execute(&mut self.buffer_set.write_buffer, "", 0);
758            }
759
760            // Send Sync
761            write_sync(&mut self.buffer_set.write_buffer);
762
763            // Drive state machine
764            let mut state_machine = BatchStateMachine::new(parse_in_chunk);
765            self.drive_batch(&mut state_machine).await?;
766            self.transaction_status = state_machine.transaction_status();
767        }
768
769        Ok(())
770    }
771
772    /// Drive a batch state machine to completion.
773    async fn drive_batch(
774        &mut self,
775        state_machine: &mut crate::state::extended::BatchStateMachine,
776    ) -> Result<()> {
777        use crate::protocol::backend::{ReadyForQuery, msg_type};
778        use crate::state::action::Action;
779
780        loop {
781            let step_result = state_machine.step(&mut self.buffer_set);
782            match step_result {
783                Ok(Action::ReadMessage) => {
784                    self.stream.read_message(&mut self.buffer_set).await?;
785                }
786                Ok(Action::WriteAndReadMessage) => {
787                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
788                    self.stream.flush().await?;
789                    self.stream.read_message(&mut self.buffer_set).await?;
790                }
791                Ok(Action::Finished) => {
792                    break;
793                }
794                Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
795                Err(e) => {
796                    // On error, drain to ReadyForQuery to leave connection in clean state
797                    loop {
798                        self.stream.read_message(&mut self.buffer_set).await?;
799                        if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
800                            let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
801                            self.transaction_status =
802                                ready.transaction_status().unwrap_or_default();
803                            break;
804                        }
805                    }
806                    return Err(e);
807                }
808            }
809        }
810        Ok(())
811    }
812
813    /// Close a prepared statement.
814    pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
815        let result = self.close_statement_inner(&stmt.wire_name()).await;
816        if let Err(e) = &result
817            && e.is_connection_broken()
818        {
819            self.is_broken = true;
820        }
821        result
822    }
823
824    async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
825        let mut handler = DropHandler::new();
826        let mut state_machine =
827            ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
828        self.drive(&mut state_machine).await
829    }
830
831    // === Low-Level Extended Query Protocol ===
832
833    /// Low-level flush: send FLUSH to force server to send pending responses.
834    ///
835    /// Unlike SYNC, FLUSH does not end the transaction or wait for ReadyForQuery.
836    /// It just forces the server to send any pending responses without ending
837    /// the extended query sequence.
838    pub async fn lowlevel_flush(&mut self) -> Result<()> {
839        use crate::protocol::frontend::write_flush;
840
841        self.buffer_set.write_buffer.clear();
842        write_flush(&mut self.buffer_set.write_buffer);
843
844        self.stream.write_all(&self.buffer_set.write_buffer).await?;
845        self.stream.flush().await?;
846        Ok(())
847    }
848
849    /// Low-level sync: send SYNC and receive ReadyForQuery.
850    ///
851    /// This ends an extended query sequence and:
852    /// - Commits implicit transaction if successful
853    /// - Rolls back implicit transaction if failed
854    /// - Updates transaction status
855    pub async fn lowlevel_sync(&mut self) -> Result<()> {
856        let result = self.lowlevel_sync_inner().await;
857        if let Err(e) = &result
858            && e.is_connection_broken()
859        {
860            self.is_broken = true;
861        }
862        result
863    }
864
865    async fn lowlevel_sync_inner(&mut self) -> Result<()> {
866        use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
867        use crate::protocol::frontend::write_sync;
868
869        self.buffer_set.write_buffer.clear();
870        write_sync(&mut self.buffer_set.write_buffer);
871
872        self.stream.write_all(&self.buffer_set.write_buffer).await?;
873        self.stream.flush().await?;
874
875        let mut pending_error: Option<Error> = None;
876
877        loop {
878            self.stream.read_message(&mut self.buffer_set).await?;
879            let type_byte = self.buffer_set.type_byte;
880
881            if RawMessage::is_async_type(type_byte) {
882                continue;
883            }
884
885            match type_byte {
886                msg_type::READY_FOR_QUERY => {
887                    let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
888                    self.transaction_status = ready.transaction_status().unwrap_or_default();
889                    if let Some(e) = pending_error {
890                        return Err(e);
891                    }
892                    return Ok(());
893                }
894                msg_type::ERROR_RESPONSE => {
895                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
896                    pending_error = Some(error.into_error());
897                }
898                _ => {
899                    // Ignore other messages before ReadyForQuery
900                }
901            }
902        }
903    }
904
905    /// Low-level bind: send BIND message and receive BindComplete.
906    ///
907    /// This allows creating named portals. Unlike `exec()`, this does NOT
908    /// send EXECUTE or SYNC - the caller controls when to execute and sync.
909    ///
910    /// # Arguments
911    /// - `portal`: Portal name (empty string "" for unnamed portal)
912    /// - `statement_name`: Prepared statement name
913    /// - `params`: Parameter values
914    pub async fn lowlevel_bind<P: ToParams>(
915        &mut self,
916        portal: &str,
917        statement_name: &str,
918        params: P,
919    ) -> Result<()> {
920        let result = self
921            .lowlevel_bind_inner(portal, statement_name, &params)
922            .await;
923        if let Err(e) = &result
924            && e.is_connection_broken()
925        {
926            self.is_broken = true;
927        }
928        result
929    }
930
931    async fn lowlevel_bind_inner<P: ToParams>(
932        &mut self,
933        portal: &str,
934        statement_name: &str,
935        params: &P,
936    ) -> Result<()> {
937        use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
938        use crate::protocol::frontend::{write_bind, write_flush};
939
940        let param_oids = params.natural_oids();
941        self.buffer_set.write_buffer.clear();
942        write_bind(
943            &mut self.buffer_set.write_buffer,
944            portal,
945            statement_name,
946            params,
947            &param_oids,
948        )?;
949        write_flush(&mut self.buffer_set.write_buffer);
950
951        self.stream.write_all(&self.buffer_set.write_buffer).await?;
952        self.stream.flush().await?;
953
954        loop {
955            self.stream.read_message(&mut self.buffer_set).await?;
956            let type_byte = self.buffer_set.type_byte;
957
958            if RawMessage::is_async_type(type_byte) {
959                continue;
960            }
961
962            match type_byte {
963                msg_type::BIND_COMPLETE => {
964                    BindComplete::parse(&self.buffer_set.read_buffer)?;
965                    return Ok(());
966                }
967                msg_type::ERROR_RESPONSE => {
968                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
969                    return Err(error.into_error());
970                }
971                _ => {
972                    return Err(Error::Protocol(format!(
973                        "Expected BindComplete or ErrorResponse, got '{}'",
974                        type_byte as char
975                    )));
976                }
977            }
978        }
979    }
980
981    /// Low-level execute: send EXECUTE message and receive results.
982    ///
983    /// Executes a previously bound portal. Does NOT send SYNC.
984    ///
985    /// # Arguments
986    /// - `portal`: Portal name (empty string "" for unnamed portal)
987    /// - `max_rows`: Maximum rows to return (0 = unlimited)
988    /// - `handler`: Handler to receive rows
989    ///
990    /// # Returns
991    /// - `Ok(true)` if more rows available (PortalSuspended received)
992    /// - `Ok(false)` if execution completed (CommandComplete received)
993    pub async fn lowlevel_execute<H: BinaryHandler>(
994        &mut self,
995        portal: &str,
996        max_rows: u32,
997        handler: &mut H,
998    ) -> Result<bool> {
999        let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
1000        if let Err(e) = &result
1001            && e.is_connection_broken()
1002        {
1003            self.is_broken = true;
1004        }
1005        result
1006    }
1007
1008    async fn lowlevel_execute_inner<H: BinaryHandler>(
1009        &mut self,
1010        portal: &str,
1011        max_rows: u32,
1012        handler: &mut H,
1013    ) -> Result<bool> {
1014        use crate::protocol::backend::{
1015            CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1016            RowDescription, msg_type,
1017        };
1018        use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1019
1020        self.buffer_set.write_buffer.clear();
1021        write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1022        write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1023        write_flush(&mut self.buffer_set.write_buffer);
1024
1025        self.stream.write_all(&self.buffer_set.write_buffer).await?;
1026        self.stream.flush().await?;
1027
1028        let mut column_buffer: Vec<u8> = Vec::new();
1029
1030        loop {
1031            self.stream.read_message(&mut self.buffer_set).await?;
1032            let type_byte = self.buffer_set.type_byte;
1033
1034            if RawMessage::is_async_type(type_byte) {
1035                continue;
1036            }
1037
1038            match type_byte {
1039                msg_type::ROW_DESCRIPTION => {
1040                    column_buffer.clear();
1041                    column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1042                    let cols = RowDescription::parse(&column_buffer)?;
1043                    handler.result_start(cols)?;
1044                }
1045                msg_type::NO_DATA => {
1046                    NoData::parse(&self.buffer_set.read_buffer)?;
1047                }
1048                msg_type::DATA_ROW => {
1049                    let cols = RowDescription::parse(&column_buffer)?;
1050                    let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1051                    handler.row(cols, row)?;
1052                }
1053                msg_type::COMMAND_COMPLETE => {
1054                    let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1055                    handler.result_end(complete)?;
1056                    return Ok(false); // No more rows
1057                }
1058                msg_type::PORTAL_SUSPENDED => {
1059                    PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1060                    return Ok(true); // More rows available
1061                }
1062                msg_type::ERROR_RESPONSE => {
1063                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1064                    return Err(error.into_error());
1065                }
1066                _ => {
1067                    return Err(Error::Protocol(format!(
1068                        "Unexpected message in execute: '{}'",
1069                        type_byte as char
1070                    )));
1071                }
1072            }
1073        }
1074    }
1075
1076    /// Execute a statement with iterative row fetching using a portal.
1077    ///
1078    /// Creates an unnamed portal and passes it to the closure. The closure can
1079    /// call `portal.fetch(n, handler)` multiple times to retrieve rows in batches.
1080    /// Sync is called after the closure returns to end the implicit transaction.
1081    ///
1082    /// The statement can be either:
1083    /// - A `&PreparedStatement` returned from `prepare()`
1084    /// - A raw SQL `&str` for one-shot execution
1085    ///
1086    /// # Example
1087    /// ```ignore
1088    /// // Using prepared statement
1089    /// let stmt = conn.prepare("SELECT * FROM users").await?;
1090    /// conn.exec_portal(&stmt, (), |portal| async move {
1091    ///     while portal.fetch(100, &mut handler).await? {
1092    ///         // process handler.into_rows()...
1093    ///     }
1094    ///     Ok(())
1095    /// }).await?;
1096    ///
1097    /// // Using raw SQL
1098    /// conn.exec_portal("SELECT * FROM users", (), |portal| async move {
1099    ///     while portal.fetch(100, &mut handler).await? {
1100    ///         // process handler.into_rows()...
1101    ///     }
1102    ///     Ok(())
1103    /// }).await?;
1104    /// ```
1105    pub async fn exec_portal<S: IntoStatement, P, F, Fut, T>(
1106        &mut self,
1107        statement: S,
1108        params: P,
1109        f: F,
1110    ) -> Result<T>
1111    where
1112        P: ToParams,
1113        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1114        Fut: std::future::Future<Output = Result<T>>,
1115    {
1116        let result = self.exec_portal_inner(&statement, &params, f).await;
1117        if let Err(e) = &result
1118            && e.is_connection_broken()
1119        {
1120            self.is_broken = true;
1121        }
1122        result
1123    }
1124
1125    async fn exec_portal_inner<S: IntoStatement, P, F, Fut, T>(
1126        &mut self,
1127        statement: &S,
1128        params: &P,
1129        f: F,
1130    ) -> Result<T>
1131    where
1132        P: ToParams,
1133        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1134        Fut: std::future::Future<Output = Result<T>>,
1135    {
1136        // Create bind state machine for unnamed portal
1137        let mut state_machine = if let Some(sql) = statement.as_sql() {
1138            BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1139        } else {
1140            let stmt = statement.as_prepared().unwrap();
1141            BindStateMachine::bind_prepared(
1142                &mut self.buffer_set,
1143                "",
1144                &stmt.wire_name(),
1145                &stmt.param_oids,
1146                params,
1147            )?
1148        };
1149
1150        // Drive the state machine to completion (ParseComplete + BindComplete)
1151        loop {
1152            match state_machine.step(&mut self.buffer_set)? {
1153                Action::ReadMessage => {
1154                    self.stream.read_message(&mut self.buffer_set).await?;
1155                }
1156                Action::Write => {
1157                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1158                    self.stream.flush().await?;
1159                }
1160                Action::WriteAndReadMessage => {
1161                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1162                    self.stream.flush().await?;
1163                    self.stream.read_message(&mut self.buffer_set).await?;
1164                }
1165                Action::Finished => break,
1166                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1167            }
1168        }
1169
1170        // Execute closure with portal handle
1171        let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1172        let result = f(&mut portal).await;
1173
1174        // Always sync to end implicit transaction (even on error)
1175        let sync_result = portal.conn.lowlevel_sync().await;
1176
1177        // Return closure result, or sync error if closure succeeded but sync failed
1178        match (result, sync_result) {
1179            (Ok(v), Ok(())) => Ok(v),
1180            (Err(e), _) => Err(e),
1181            (Ok(_), Err(e)) => Err(e),
1182        }
1183    }
1184
1185    /// Low-level close portal: send Close(Portal) and receive CloseComplete.
1186    pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1187        let result = self.lowlevel_close_portal_inner(portal).await;
1188        if let Err(e) = &result
1189            && e.is_connection_broken()
1190        {
1191            self.is_broken = true;
1192        }
1193        result
1194    }
1195
1196    async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1197        use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1198        use crate::protocol::frontend::{write_close_portal, write_flush};
1199
1200        self.buffer_set.write_buffer.clear();
1201        write_close_portal(&mut self.buffer_set.write_buffer, portal);
1202        write_flush(&mut self.buffer_set.write_buffer);
1203
1204        self.stream.write_all(&self.buffer_set.write_buffer).await?;
1205        self.stream.flush().await?;
1206
1207        loop {
1208            self.stream.read_message(&mut self.buffer_set).await?;
1209            let type_byte = self.buffer_set.type_byte;
1210
1211            if RawMessage::is_async_type(type_byte) {
1212                continue;
1213            }
1214
1215            match type_byte {
1216                msg_type::CLOSE_COMPLETE => {
1217                    CloseComplete::parse(&self.buffer_set.read_buffer)?;
1218                    return Ok(());
1219                }
1220                msg_type::ERROR_RESPONSE => {
1221                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1222                    return Err(error.into_error());
1223                }
1224                _ => {
1225                    return Err(Error::Protocol(format!(
1226                        "Expected CloseComplete or ErrorResponse, got '{}'",
1227                        type_byte as char
1228                    )));
1229                }
1230            }
1231        }
1232    }
1233
1234    /// Run a pipeline of batched queries.
1235    ///
1236    /// Pipeline mode allows sending multiple queries to the server without waiting
1237    /// for responses, reducing round-trip latency.
1238    ///
1239    /// # Example
1240    ///
1241    /// ```ignore
1242    /// // Prepare statements outside the pipeline
1243    /// let stmts = conn.prepare_batch(&[
1244    ///     "SELECT id, name FROM users WHERE active = $1",
1245    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
1246    /// ]).await?;
1247    ///
1248    /// let (active, inactive, count) = conn.run_pipeline(|p| async move {
1249    ///     // Queue executions
1250    ///     let t1 = p.exec(&stmts[0], (true,)).await?;
1251    ///     let t2 = p.exec(&stmts[0], (false,)).await?;
1252    ///     let t3 = p.exec("SELECT COUNT(*) FROM users", ()).await?;
1253    ///
1254    ///     p.sync().await?;
1255    ///
1256    ///     // Claim results in order with different methods
1257    ///     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
1258    ///     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
1259    ///     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
1260    ///
1261    ///     Ok((active, inactive, count))
1262    /// }).await?;
1263    /// ```
1264    pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1265    where
1266        F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1267        Fut: std::future::Future<Output = Result<T>>,
1268    {
1269        let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1270        let result = f(&mut pipeline).await;
1271        pipeline.cleanup().await;
1272        result
1273    }
1274
1275    /// Execute a closure within a transaction.
1276    ///
1277    /// If the closure returns `Ok`, the transaction is committed.
1278    /// If the closure returns `Err` or the transaction is not explicitly
1279    /// committed or rolled back, the transaction is rolled back.
1280    ///
1281    /// # Errors
1282    ///
1283    /// Returns `Error::InvalidUsage` if called while already in a transaction.
1284    pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1285    where
1286        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1287        Fut: std::future::Future<Output = Result<R>>,
1288    {
1289        if self.in_transaction() {
1290            return Err(Error::InvalidUsage(
1291                "nested transactions are not supported".into(),
1292            ));
1293        }
1294
1295        self.query_drop("BEGIN").await?;
1296
1297        let tx = super::transaction::Transaction::new(self.connection_id());
1298
1299        // We need to use unsafe to work around the borrow checker here
1300        // because async closures can't capture &mut self properly
1301        let result = f(self, tx).await;
1302
1303        // If still in a transaction (not committed or rolled back), roll it back
1304        if self.in_transaction() {
1305            let rollback_result = self.query_drop("ROLLBACK").await;
1306
1307            // Return the first error (either from closure or rollback)
1308            if let Err(e) = result {
1309                return Err(e);
1310            }
1311            rollback_result?;
1312        }
1313
1314        result
1315    }
1316}