zero_postgres/tokio/
conn.rs

1//! Asynchronous PostgreSQL connection.
2
3use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11    AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25
26/// Asynchronous PostgreSQL connection.
27pub struct Conn {
28    pub(crate) stream: Stream,
29    pub(crate) buffer_set: PooledBufferSet,
30    backend_key: Option<BackendKeyData>,
31    server_params: Vec<(String, String)>,
32    pub(crate) transaction_status: TransactionStatus,
33    pub(crate) is_broken: bool,
34    name_counter: u64,
35    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
36}
37
38impl Conn {
39    /// Connect to a PostgreSQL server.
40    pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41    where
42        Error: From<O::Error>,
43    {
44        let opts = opts.try_into()?;
45
46        let stream = if let Some(socket_path) = &opts.socket {
47            #[cfg(unix)]
48            {
49                Stream::unix(UnixStream::connect(socket_path).await?)
50            }
51            #[cfg(not(unix))]
52            {
53                let _ = socket_path;
54                return Err(Error::Unsupported(
55                    "Unix sockets are not supported on this platform".into(),
56                ));
57            }
58        } else {
59            if opts.host.is_empty() {
60                return Err(Error::InvalidUsage("host is empty".into()));
61            }
62            let addr = format!("{}:{}", opts.host, opts.port);
63            let tcp = TcpStream::connect(&addr).await?;
64            tcp.set_nodelay(true)?;
65            Stream::tcp(tcp)
66        };
67
68        Self::new_with_stream(stream, opts).await
69    }
70
71    /// Connect using an existing stream.
72    #[allow(unused_mut)]
73    pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
74        let mut buffer_set = options.buffer_pool.get_buffer_set();
75        let mut state_machine = ConnectionStateMachine::new(options.clone());
76
77        // Drive the connection state machine
78        loop {
79            match state_machine.step(&mut buffer_set)? {
80                Action::WriteAndReadByte => {
81                    stream.write_all(&buffer_set.write_buffer).await?;
82                    stream.flush().await?;
83                    let byte = stream.read_u8().await?;
84                    state_machine.set_ssl_response(byte);
85                }
86                Action::ReadMessage => {
87                    stream.read_message(&mut buffer_set).await?;
88                }
89                Action::Write => {
90                    stream.write_all(&buffer_set.write_buffer).await?;
91                    stream.flush().await?;
92                }
93                Action::WriteAndReadMessage => {
94                    stream.write_all(&buffer_set.write_buffer).await?;
95                    stream.flush().await?;
96                    stream.read_message(&mut buffer_set).await?;
97                }
98                Action::TlsHandshake => {
99                    #[cfg(feature = "tokio-tls")]
100                    {
101                        stream = stream.upgrade_to_tls(&options.host).await?;
102                    }
103                    #[cfg(not(feature = "tokio-tls"))]
104                    {
105                        return Err(Error::Unsupported(
106                            "TLS requested but tokio-tls feature not enabled".into(),
107                        ));
108                    }
109                }
110                Action::HandleAsyncMessageAndReadMessage(_) => {
111                    // Ignore async messages during startup, read next message
112                    stream.read_message(&mut buffer_set).await?;
113                }
114                Action::Finished => break,
115            }
116        }
117
118        let conn = Self {
119            stream,
120            buffer_set,
121            backend_key: state_machine.backend_key().cloned(),
122            server_params: state_machine.take_server_params(),
123            transaction_status: state_machine.transaction_status(),
124            is_broken: false,
125            name_counter: 0,
126            async_message_handler: None,
127        };
128
129        // Upgrade to Unix socket if connected via TCP to loopback
130        #[cfg(unix)]
131        let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
132            conn.try_upgrade_to_unix_socket(&options).await
133        } else {
134            conn
135        };
136
137        Ok(conn)
138    }
139
140    /// Try to upgrade to Unix socket connection.
141    /// Returns upgraded conn on success, original conn on failure.
142    #[cfg(unix)]
143    fn try_upgrade_to_unix_socket(
144        mut self,
145        opts: &Opts,
146    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
147        let opts = opts.clone();
148        Box::pin(async move {
149            // Query unix_socket_directories from server
150            let mut handler = FirstRowHandler::<(String,)>::new();
151            if self
152                .query("SHOW unix_socket_directories", &mut handler)
153                .await
154                .is_err()
155            {
156                return self;
157            }
158
159            let socket_dir = match handler.into_row() {
160                Some((dirs,)) => {
161                    // May contain multiple directories, use the first one
162                    match dirs.split(',').next() {
163                        Some(d) if !d.trim().is_empty() => d.trim().to_string(),
164                        _ => return self,
165                    }
166                }
167                None => return self,
168            };
169
170            // Build socket path: {directory}/.s.PGSQL.{port}
171            let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
172
173            // Connect via Unix socket
174            let unix_stream = match UnixStream::connect(&socket_path).await {
175                Ok(s) => s,
176                Err(_) => return self,
177            };
178
179            // Create new connection over Unix socket
180            let mut opts_unix = opts.clone();
181            opts_unix.upgrade_to_unix_socket = false;
182
183            match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
184                Ok(new_conn) => new_conn,
185                Err(_) => self,
186            }
187        })
188    }
189
190    /// Get the backend key data for query cancellation.
191    pub fn backend_key(&self) -> Option<&BackendKeyData> {
192        self.backend_key.as_ref()
193    }
194
195    /// Get the connection ID (backend process ID).
196    ///
197    /// Returns 0 if the backend key data is not available.
198    pub fn connection_id(&self) -> u32 {
199        self.backend_key.as_ref().map_or(0, |k| k.process_id())
200    }
201
202    /// Get server parameters.
203    pub fn server_params(&self) -> &[(String, String)] {
204        &self.server_params
205    }
206
207    /// Get the current transaction status.
208    pub fn transaction_status(&self) -> TransactionStatus {
209        self.transaction_status
210    }
211
212    /// Check if currently in a transaction.
213    pub fn in_transaction(&self) -> bool {
214        self.transaction_status.in_transaction()
215    }
216
217    /// Check if the connection is broken.
218    pub fn is_broken(&self) -> bool {
219        self.is_broken
220    }
221
222    /// Generate the next unique portal name.
223    pub(crate) fn next_portal_name(&mut self) -> String {
224        self.name_counter += 1;
225        format!("_zero_p_{}", self.name_counter)
226    }
227
228    /// Create a named portal by binding a statement.
229    ///
230    /// Used internally by Transaction::exec_portal.
231    pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
232        &mut self,
233        portal_name: &str,
234        statement: &S,
235        params: &P,
236    ) -> Result<()> {
237        // Create bind state machine for named portal
238        let mut state_machine = if let Some(sql) = statement.as_sql() {
239            BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
240        } else {
241            let stmt = statement.as_prepared().unwrap();
242            BindStateMachine::bind_prepared(
243                &mut self.buffer_set,
244                portal_name,
245                &stmt.wire_name(),
246                &stmt.param_oids,
247                params,
248            )?
249        };
250
251        // Drive the state machine to completion (ParseComplete + BindComplete)
252        loop {
253            match state_machine.step(&mut self.buffer_set)? {
254                Action::ReadMessage => {
255                    self.stream.read_message(&mut self.buffer_set).await?;
256                }
257                Action::Write => {
258                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
259                    self.stream.flush().await?;
260                }
261                Action::WriteAndReadMessage => {
262                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
263                    self.stream.flush().await?;
264                    self.stream.read_message(&mut self.buffer_set).await?;
265                }
266                Action::Finished => break,
267                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Set the async message handler.
275    ///
276    /// The handler is called when the server sends asynchronous messages:
277    /// - `Notification` - from LISTEN/NOTIFY
278    /// - `Notice` - warnings and informational messages
279    /// - `ParameterChanged` - server parameter updates
280    pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
281        self.async_message_handler = Some(Box::new(handler));
282    }
283
284    /// Remove the async message handler.
285    pub fn clear_async_message_handler(&mut self) {
286        self.async_message_handler = None;
287    }
288
289    /// Ping the server with an empty query to check connection aliveness.
290    pub async fn ping(&mut self) -> Result<()> {
291        self.query_drop("").await?;
292        Ok(())
293    }
294
295    /// Drive a state machine to completion.
296    async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
297        loop {
298            match state_machine.step(&mut self.buffer_set)? {
299                Action::WriteAndReadByte => {
300                    return Err(Error::Protocol(
301                        "Unexpected WriteAndReadByte in query state machine".into(),
302                    ));
303                }
304                Action::ReadMessage => {
305                    self.stream.read_message(&mut self.buffer_set).await?;
306                }
307                Action::Write => {
308                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
309                    self.stream.flush().await?;
310                }
311                Action::WriteAndReadMessage => {
312                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
313                    self.stream.flush().await?;
314                    self.stream.read_message(&mut self.buffer_set).await?;
315                }
316                Action::TlsHandshake => {
317                    return Err(Error::Protocol(
318                        "Unexpected TlsHandshake in query state machine".into(),
319                    ));
320                }
321                Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
322                    if let Some(ref mut h) = self.async_message_handler {
323                        h.handle(async_msg);
324                    }
325                    // Read next message after handling async message
326                    self.stream.read_message(&mut self.buffer_set).await?;
327                }
328                Action::Finished => {
329                    self.transaction_status = state_machine.transaction_status();
330                    break;
331                }
332            }
333        }
334        Ok(())
335    }
336
337    /// Execute a simple query with a handler.
338    pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
339        let result = self.query_inner(sql, handler).await;
340        if let Err(e) = &result
341            && e.is_connection_broken()
342        {
343            self.is_broken = true;
344        }
345        result
346    }
347
348    async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349        let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
350        self.drive(&mut state_machine).await
351    }
352
353    /// Execute a simple query and discard results.
354    pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
355        let mut handler = DropHandler::new();
356        self.query(sql, &mut handler).await?;
357        Ok(handler.rows_affected())
358    }
359
360    /// Execute a simple query and collect typed rows.
361    pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
362        &mut self,
363        sql: &str,
364    ) -> Result<Vec<T>> {
365        let mut handler = crate::handler::CollectHandler::<T>::new();
366        self.query(sql, &mut handler).await?;
367        Ok(handler.into_rows())
368    }
369
370    /// Execute a simple query and return the first typed row.
371    pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
372        &mut self,
373        sql: &str,
374    ) -> Result<Option<T>> {
375        let mut handler = crate::handler::FirstRowHandler::<T>::new();
376        self.query(sql, &mut handler).await?;
377        Ok(handler.into_row())
378    }
379
380    /// Close the connection gracefully.
381    pub async fn close(mut self) -> Result<()> {
382        self.buffer_set.write_buffer.clear();
383        write_terminate(&mut self.buffer_set.write_buffer);
384        self.stream.write_all(&self.buffer_set.write_buffer).await?;
385        self.stream.flush().await?;
386        Ok(())
387    }
388
389    // === Extended Query Protocol ===
390
391    /// Prepare a statement using the extended query protocol.
392    pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
393        self.prepare_typed(query, &[]).await
394    }
395
396    /// Prepare a statement with explicit parameter types.
397    pub async fn prepare_typed(
398        &mut self,
399        query: &str,
400        param_oids: &[u32],
401    ) -> Result<PreparedStatement> {
402        self.name_counter += 1;
403        let idx = self.name_counter;
404        let result = self.prepare_inner(idx, query, param_oids).await;
405        if let Err(e) = &result
406            && e.is_connection_broken()
407        {
408            self.is_broken = true;
409        }
410        result
411    }
412
413    /// Prepare multiple statements in a single round-trip.
414    ///
415    /// This is more efficient than calling `prepare()` multiple times when you
416    /// need to prepare several statements, as it batches the network communication.
417    ///
418    /// # Example
419    ///
420    /// ```ignore
421    /// let stmts = conn.prepare_batch(&[
422    ///     "SELECT id, name FROM users WHERE id = $1",
423    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
424    ///     "UPDATE users SET name = $1 WHERE id = $2",
425    /// ]).await?;
426    ///
427    /// // Use stmts[0], stmts[1], stmts[2]...
428    /// ```
429    pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
430        if queries.is_empty() {
431            return Ok(Vec::new());
432        }
433
434        let start_idx = self.name_counter + 1;
435        self.name_counter += queries.len() as u64;
436
437        let result = self.prepare_batch_inner(queries, start_idx).await;
438        if let Err(e) = &result
439            && e.is_connection_broken()
440        {
441            self.is_broken = true;
442        }
443        result
444    }
445
446    async fn prepare_batch_inner(
447        &mut self,
448        queries: &[&str],
449        start_idx: u64,
450    ) -> Result<Vec<PreparedStatement>> {
451        use crate::state::batch_prepare::BatchPrepareStateMachine;
452
453        let mut state_machine =
454            BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
455
456        loop {
457            match state_machine.step(&mut self.buffer_set)? {
458                Action::ReadMessage => {
459                    self.stream.read_message(&mut self.buffer_set).await?;
460                }
461                Action::WriteAndReadMessage => {
462                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
463                    self.stream.flush().await?;
464                    self.stream.read_message(&mut self.buffer_set).await?;
465                }
466                Action::Finished => {
467                    self.transaction_status = state_machine.transaction_status();
468                    break;
469                }
470                _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
471            }
472        }
473
474        Ok(state_machine.take_statements())
475    }
476
477    async fn prepare_inner(
478        &mut self,
479        idx: u64,
480        query: &str,
481        param_oids: &[u32],
482    ) -> Result<PreparedStatement> {
483        let mut handler = DropHandler::new();
484        let mut state_machine = ExtendedQueryStateMachine::prepare(
485            &mut handler,
486            &mut self.buffer_set,
487            idx,
488            query,
489            param_oids,
490        );
491        self.drive(&mut state_machine).await?;
492        state_machine
493            .take_prepared_statement()
494            .ok_or_else(|| Error::Protocol("No prepared statement".into()))
495    }
496
497    /// Execute a statement with a handler.
498    ///
499    /// The statement can be either:
500    /// - A `&PreparedStatement` returned from `prepare()`
501    /// - A raw SQL `&str` for one-shot execution
502    pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
503        &mut self,
504        statement: S,
505        params: P,
506        handler: &mut H,
507    ) -> Result<()> {
508        let result = self.exec_inner(&statement, &params, handler).await;
509        if let Err(e) = &result
510            && e.is_connection_broken()
511        {
512            self.is_broken = true;
513        }
514        result
515    }
516
517    async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
518        &mut self,
519        statement: &S,
520        params: &P,
521        handler: &mut H,
522    ) -> Result<()> {
523        let mut state_machine = if statement.needs_parse() {
524            ExtendedQueryStateMachine::execute_sql(
525                handler,
526                &mut self.buffer_set,
527                statement.as_sql().unwrap(),
528                params,
529            )?
530        } else {
531            let stmt = statement.as_prepared().unwrap();
532            ExtendedQueryStateMachine::execute(
533                handler,
534                &mut self.buffer_set,
535                &stmt.wire_name(),
536                &stmt.param_oids,
537                params,
538            )?
539        };
540
541        self.drive(&mut state_machine).await
542    }
543
544    /// Execute a statement and discard results.
545    ///
546    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
547    pub async fn exec_drop<S: IntoStatement, P: ToParams>(
548        &mut self,
549        statement: S,
550        params: P,
551    ) -> Result<Option<u64>> {
552        let mut handler = DropHandler::new();
553        self.exec(statement, params, &mut handler).await?;
554        Ok(handler.rows_affected())
555    }
556
557    /// Execute a statement and collect typed rows.
558    ///
559    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
560    pub async fn exec_collect<
561        T: for<'a> crate::conversion::FromRow<'a>,
562        S: IntoStatement,
563        P: ToParams,
564    >(
565        &mut self,
566        statement: S,
567        params: P,
568    ) -> Result<Vec<T>> {
569        let mut handler = crate::handler::CollectHandler::<T>::new();
570        self.exec(statement, params, &mut handler).await?;
571        Ok(handler.into_rows())
572    }
573
574    /// Execute a statement with multiple parameter sets in a batch.
575    ///
576    /// This is more efficient than calling `exec_drop` multiple times as it
577    /// batches the network communication. The statement is parsed once (if raw SQL)
578    /// and then bound/executed for each parameter set.
579    ///
580    /// Parameters are processed in chunks (default 1000) to avoid overwhelming
581    /// the server with too many pending operations.
582    ///
583    /// The statement can be either:
584    /// - A `&PreparedStatement` returned from `prepare()`
585    /// - A raw SQL `&str` for one-shot execution
586    ///
587    /// # Example
588    ///
589    /// ```ignore
590    /// // Using prepared statement
591    /// let stmt = conn.prepare("INSERT INTO users (name, age) VALUES ($1, $2)").await?;
592    /// conn.exec_batch(&stmt, &[
593    ///     ("alice", 30),
594    ///     ("bob", 25),
595    ///     ("charlie", 35),
596    /// ]).await?;
597    ///
598    /// // Using raw SQL
599    /// conn.exec_batch("INSERT INTO users (name, age) VALUES ($1, $2)", &[
600    ///     ("alice", 30),
601    ///     ("bob", 25),
602    /// ]).await?;
603    /// ```
604    pub async fn exec_batch<S: IntoStatement, P: ToParams>(
605        &mut self,
606        statement: S,
607        params_list: &[P],
608    ) -> Result<()> {
609        self.exec_batch_chunked(statement, params_list, 1000).await
610    }
611
612    /// Execute a statement with multiple parameter sets in a batch with custom chunk size.
613    ///
614    /// Same as `exec_batch` but allows specifying the chunk size for batching.
615    pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
616        &mut self,
617        statement: S,
618        params_list: &[P],
619        chunk_size: usize,
620    ) -> Result<()> {
621        let result = self
622            .exec_batch_inner(&statement, params_list, chunk_size)
623            .await;
624        if let Err(e) = &result
625            && e.is_connection_broken()
626        {
627            self.is_broken = true;
628        }
629        result
630    }
631
632    async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
633        &mut self,
634        statement: &S,
635        params_list: &[P],
636        chunk_size: usize,
637    ) -> Result<()> {
638        use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
639        use crate::state::extended::BatchStateMachine;
640
641        if params_list.is_empty() {
642            return Ok(());
643        }
644
645        let chunk_size = chunk_size.max(1);
646        let needs_parse = statement.needs_parse();
647        let sql = statement.as_sql();
648        let prepared = statement.as_prepared();
649
650        // Get param OIDs from first params or prepared statement
651        let param_oids: Vec<u32> = if let Some(stmt) = prepared {
652            stmt.param_oids.clone()
653        } else {
654            params_list[0].natural_oids()
655        };
656
657        // Statement name: empty for raw SQL, actual name for prepared
658        let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
659
660        for chunk in params_list.chunks(chunk_size) {
661            self.buffer_set.write_buffer.clear();
662
663            // For raw SQL, send Parse each chunk (reuses unnamed statement)
664            let parse_in_chunk = needs_parse;
665            if parse_in_chunk {
666                write_parse(
667                    &mut self.buffer_set.write_buffer,
668                    "",
669                    sql.unwrap(),
670                    &param_oids,
671                );
672            }
673
674            // Write Bind + Execute for each param set
675            for params in chunk {
676                let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
677                write_bind(
678                    &mut self.buffer_set.write_buffer,
679                    "",
680                    effective_stmt_name,
681                    params,
682                    &param_oids,
683                )?;
684                write_execute(&mut self.buffer_set.write_buffer, "", 0);
685            }
686
687            // Send Sync
688            write_sync(&mut self.buffer_set.write_buffer);
689
690            // Drive state machine
691            let mut state_machine = BatchStateMachine::new(parse_in_chunk);
692            self.drive_batch(&mut state_machine).await?;
693            self.transaction_status = state_machine.transaction_status();
694        }
695
696        Ok(())
697    }
698
699    /// Drive a batch state machine to completion.
700    async fn drive_batch(
701        &mut self,
702        state_machine: &mut crate::state::extended::BatchStateMachine,
703    ) -> Result<()> {
704        use crate::protocol::backend::{ReadyForQuery, msg_type};
705        use crate::state::action::Action;
706
707        loop {
708            let step_result = state_machine.step(&mut self.buffer_set);
709            match step_result {
710                Ok(Action::ReadMessage) => {
711                    self.stream.read_message(&mut self.buffer_set).await?;
712                }
713                Ok(Action::WriteAndReadMessage) => {
714                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
715                    self.stream.flush().await?;
716                    self.stream.read_message(&mut self.buffer_set).await?;
717                }
718                Ok(Action::Finished) => {
719                    break;
720                }
721                Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
722                Err(e) => {
723                    // On error, drain to ReadyForQuery to leave connection in clean state
724                    loop {
725                        self.stream.read_message(&mut self.buffer_set).await?;
726                        if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
727                            let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
728                            self.transaction_status =
729                                ready.transaction_status().unwrap_or_default();
730                            break;
731                        }
732                    }
733                    return Err(e);
734                }
735            }
736        }
737        Ok(())
738    }
739
740    /// Close a prepared statement.
741    pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
742        let result = self.close_statement_inner(&stmt.wire_name()).await;
743        if let Err(e) = &result
744            && e.is_connection_broken()
745        {
746            self.is_broken = true;
747        }
748        result
749    }
750
751    async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
752        let mut handler = DropHandler::new();
753        let mut state_machine =
754            ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
755        self.drive(&mut state_machine).await
756    }
757
758    // === Low-Level Extended Query Protocol ===
759
760    /// Low-level flush: send FLUSH to force server to send pending responses.
761    ///
762    /// Unlike SYNC, FLUSH does not end the transaction or wait for ReadyForQuery.
763    /// It just forces the server to send any pending responses without ending
764    /// the extended query sequence.
765    pub async fn lowlevel_flush(&mut self) -> Result<()> {
766        use crate::protocol::frontend::write_flush;
767
768        self.buffer_set.write_buffer.clear();
769        write_flush(&mut self.buffer_set.write_buffer);
770
771        self.stream.write_all(&self.buffer_set.write_buffer).await?;
772        self.stream.flush().await?;
773        Ok(())
774    }
775
776    /// Low-level sync: send SYNC and receive ReadyForQuery.
777    ///
778    /// This ends an extended query sequence and:
779    /// - Commits implicit transaction if successful
780    /// - Rolls back implicit transaction if failed
781    /// - Updates transaction status
782    pub async fn lowlevel_sync(&mut self) -> Result<()> {
783        let result = self.lowlevel_sync_inner().await;
784        if let Err(e) = &result
785            && e.is_connection_broken()
786        {
787            self.is_broken = true;
788        }
789        result
790    }
791
792    async fn lowlevel_sync_inner(&mut self) -> Result<()> {
793        use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
794        use crate::protocol::frontend::write_sync;
795
796        self.buffer_set.write_buffer.clear();
797        write_sync(&mut self.buffer_set.write_buffer);
798
799        self.stream.write_all(&self.buffer_set.write_buffer).await?;
800        self.stream.flush().await?;
801
802        let mut pending_error: Option<Error> = None;
803
804        loop {
805            self.stream.read_message(&mut self.buffer_set).await?;
806            let type_byte = self.buffer_set.type_byte;
807
808            if RawMessage::is_async_type(type_byte) {
809                continue;
810            }
811
812            match type_byte {
813                msg_type::READY_FOR_QUERY => {
814                    let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
815                    self.transaction_status = ready.transaction_status().unwrap_or_default();
816                    if let Some(e) = pending_error {
817                        return Err(e);
818                    }
819                    return Ok(());
820                }
821                msg_type::ERROR_RESPONSE => {
822                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
823                    pending_error = Some(error.into_error());
824                }
825                _ => {
826                    // Ignore other messages before ReadyForQuery
827                }
828            }
829        }
830    }
831
832    /// Low-level bind: send BIND message and receive BindComplete.
833    ///
834    /// This allows creating named portals. Unlike `exec()`, this does NOT
835    /// send EXECUTE or SYNC - the caller controls when to execute and sync.
836    ///
837    /// # Arguments
838    /// - `portal`: Portal name (empty string "" for unnamed portal)
839    /// - `statement_name`: Prepared statement name
840    /// - `params`: Parameter values
841    pub async fn lowlevel_bind<P: ToParams>(
842        &mut self,
843        portal: &str,
844        statement_name: &str,
845        params: P,
846    ) -> Result<()> {
847        let result = self
848            .lowlevel_bind_inner(portal, statement_name, &params)
849            .await;
850        if let Err(e) = &result
851            && e.is_connection_broken()
852        {
853            self.is_broken = true;
854        }
855        result
856    }
857
858    async fn lowlevel_bind_inner<P: ToParams>(
859        &mut self,
860        portal: &str,
861        statement_name: &str,
862        params: &P,
863    ) -> Result<()> {
864        use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
865        use crate::protocol::frontend::{write_bind, write_flush};
866
867        let param_oids = params.natural_oids();
868        self.buffer_set.write_buffer.clear();
869        write_bind(
870            &mut self.buffer_set.write_buffer,
871            portal,
872            statement_name,
873            params,
874            &param_oids,
875        )?;
876        write_flush(&mut self.buffer_set.write_buffer);
877
878        self.stream.write_all(&self.buffer_set.write_buffer).await?;
879        self.stream.flush().await?;
880
881        loop {
882            self.stream.read_message(&mut self.buffer_set).await?;
883            let type_byte = self.buffer_set.type_byte;
884
885            if RawMessage::is_async_type(type_byte) {
886                continue;
887            }
888
889            match type_byte {
890                msg_type::BIND_COMPLETE => {
891                    BindComplete::parse(&self.buffer_set.read_buffer)?;
892                    return Ok(());
893                }
894                msg_type::ERROR_RESPONSE => {
895                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
896                    return Err(error.into_error());
897                }
898                _ => {
899                    return Err(Error::Protocol(format!(
900                        "Expected BindComplete or ErrorResponse, got '{}'",
901                        type_byte as char
902                    )));
903                }
904            }
905        }
906    }
907
908    /// Low-level execute: send EXECUTE message and receive results.
909    ///
910    /// Executes a previously bound portal. Does NOT send SYNC.
911    ///
912    /// # Arguments
913    /// - `portal`: Portal name (empty string "" for unnamed portal)
914    /// - `max_rows`: Maximum rows to return (0 = unlimited)
915    /// - `handler`: Handler to receive rows
916    ///
917    /// # Returns
918    /// - `Ok(true)` if more rows available (PortalSuspended received)
919    /// - `Ok(false)` if execution completed (CommandComplete received)
920    pub async fn lowlevel_execute<H: BinaryHandler>(
921        &mut self,
922        portal: &str,
923        max_rows: u32,
924        handler: &mut H,
925    ) -> Result<bool> {
926        let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
927        if let Err(e) = &result
928            && e.is_connection_broken()
929        {
930            self.is_broken = true;
931        }
932        result
933    }
934
935    async fn lowlevel_execute_inner<H: BinaryHandler>(
936        &mut self,
937        portal: &str,
938        max_rows: u32,
939        handler: &mut H,
940    ) -> Result<bool> {
941        use crate::protocol::backend::{
942            CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
943            RowDescription, msg_type,
944        };
945        use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
946
947        self.buffer_set.write_buffer.clear();
948        write_describe_portal(&mut self.buffer_set.write_buffer, portal);
949        write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
950        write_flush(&mut self.buffer_set.write_buffer);
951
952        self.stream.write_all(&self.buffer_set.write_buffer).await?;
953        self.stream.flush().await?;
954
955        let mut column_buffer: Vec<u8> = Vec::new();
956
957        loop {
958            self.stream.read_message(&mut self.buffer_set).await?;
959            let type_byte = self.buffer_set.type_byte;
960
961            if RawMessage::is_async_type(type_byte) {
962                continue;
963            }
964
965            match type_byte {
966                msg_type::ROW_DESCRIPTION => {
967                    column_buffer.clear();
968                    column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
969                    let cols = RowDescription::parse(&column_buffer)?;
970                    handler.result_start(cols)?;
971                }
972                msg_type::NO_DATA => {
973                    NoData::parse(&self.buffer_set.read_buffer)?;
974                }
975                msg_type::DATA_ROW => {
976                    let cols = RowDescription::parse(&column_buffer)?;
977                    let row = DataRow::parse(&self.buffer_set.read_buffer)?;
978                    handler.row(cols, row)?;
979                }
980                msg_type::COMMAND_COMPLETE => {
981                    let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
982                    handler.result_end(complete)?;
983                    return Ok(false); // No more rows
984                }
985                msg_type::PORTAL_SUSPENDED => {
986                    PortalSuspended::parse(&self.buffer_set.read_buffer)?;
987                    return Ok(true); // More rows available
988                }
989                msg_type::ERROR_RESPONSE => {
990                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
991                    return Err(error.into_error());
992                }
993                _ => {
994                    return Err(Error::Protocol(format!(
995                        "Unexpected message in execute: '{}'",
996                        type_byte as char
997                    )));
998                }
999            }
1000        }
1001    }
1002
1003    /// Execute a statement with iterative row fetching.
1004    ///
1005    /// Creates an unnamed portal and passes it to the closure. The closure can
1006    /// call `portal.fetch(n, handler)` multiple times to retrieve rows in batches.
1007    /// Sync is called after the closure returns to end the implicit transaction.
1008    ///
1009    /// The statement can be either:
1010    /// - A `&PreparedStatement` returned from `prepare()`
1011    /// - A raw SQL `&str` for one-shot execution
1012    ///
1013    /// # Example
1014    /// ```ignore
1015    /// // Using prepared statement
1016    /// let stmt = conn.prepare("SELECT * FROM users").await?;
1017    /// conn.exec_iter(&stmt, (), |portal| async move {
1018    ///     while portal.fetch(100, &mut handler).await? {
1019    ///         // process handler.into_rows()...
1020    ///     }
1021    ///     Ok(())
1022    /// }).await?;
1023    ///
1024    /// // Using raw SQL
1025    /// conn.exec_iter("SELECT * FROM users", (), |portal| async move {
1026    ///     while portal.fetch(100, &mut handler).await? {
1027    ///         // process handler.into_rows()...
1028    ///     }
1029    ///     Ok(())
1030    /// }).await?;
1031    /// ```
1032    pub async fn exec_iter<S: IntoStatement, P, F, Fut, T>(
1033        &mut self,
1034        statement: S,
1035        params: P,
1036        f: F,
1037    ) -> Result<T>
1038    where
1039        P: ToParams,
1040        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1041        Fut: std::future::Future<Output = Result<T>>,
1042    {
1043        let result = self.exec_iter_inner(&statement, &params, f).await;
1044        if let Err(e) = &result
1045            && e.is_connection_broken()
1046        {
1047            self.is_broken = true;
1048        }
1049        result
1050    }
1051
1052    async fn exec_iter_inner<S: IntoStatement, P, F, Fut, T>(
1053        &mut self,
1054        statement: &S,
1055        params: &P,
1056        f: F,
1057    ) -> Result<T>
1058    where
1059        P: ToParams,
1060        F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1061        Fut: std::future::Future<Output = Result<T>>,
1062    {
1063        // Create bind state machine for unnamed portal
1064        let mut state_machine = if let Some(sql) = statement.as_sql() {
1065            BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1066        } else {
1067            let stmt = statement.as_prepared().unwrap();
1068            BindStateMachine::bind_prepared(
1069                &mut self.buffer_set,
1070                "",
1071                &stmt.wire_name(),
1072                &stmt.param_oids,
1073                params,
1074            )?
1075        };
1076
1077        // Drive the state machine to completion (ParseComplete + BindComplete)
1078        loop {
1079            match state_machine.step(&mut self.buffer_set)? {
1080                Action::ReadMessage => {
1081                    self.stream.read_message(&mut self.buffer_set).await?;
1082                }
1083                Action::Write => {
1084                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1085                    self.stream.flush().await?;
1086                }
1087                Action::WriteAndReadMessage => {
1088                    self.stream.write_all(&self.buffer_set.write_buffer).await?;
1089                    self.stream.flush().await?;
1090                    self.stream.read_message(&mut self.buffer_set).await?;
1091                }
1092                Action::Finished => break,
1093                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1094            }
1095        }
1096
1097        // Execute closure with portal handle
1098        let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1099        let result = f(&mut portal).await;
1100
1101        // Always sync to end implicit transaction (even on error)
1102        let sync_result = portal.conn.lowlevel_sync().await;
1103
1104        // Return closure result, or sync error if closure succeeded but sync failed
1105        match (result, sync_result) {
1106            (Ok(v), Ok(())) => Ok(v),
1107            (Err(e), _) => Err(e),
1108            (Ok(_), Err(e)) => Err(e),
1109        }
1110    }
1111
1112    /// Low-level close portal: send Close(Portal) and receive CloseComplete.
1113    pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1114        let result = self.lowlevel_close_portal_inner(portal).await;
1115        if let Err(e) = &result
1116            && e.is_connection_broken()
1117        {
1118            self.is_broken = true;
1119        }
1120        result
1121    }
1122
1123    async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1124        use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1125        use crate::protocol::frontend::{write_close_portal, write_flush};
1126
1127        self.buffer_set.write_buffer.clear();
1128        write_close_portal(&mut self.buffer_set.write_buffer, portal);
1129        write_flush(&mut self.buffer_set.write_buffer);
1130
1131        self.stream.write_all(&self.buffer_set.write_buffer).await?;
1132        self.stream.flush().await?;
1133
1134        loop {
1135            self.stream.read_message(&mut self.buffer_set).await?;
1136            let type_byte = self.buffer_set.type_byte;
1137
1138            if RawMessage::is_async_type(type_byte) {
1139                continue;
1140            }
1141
1142            match type_byte {
1143                msg_type::CLOSE_COMPLETE => {
1144                    CloseComplete::parse(&self.buffer_set.read_buffer)?;
1145                    return Ok(());
1146                }
1147                msg_type::ERROR_RESPONSE => {
1148                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1149                    return Err(error.into_error());
1150                }
1151                _ => {
1152                    return Err(Error::Protocol(format!(
1153                        "Expected CloseComplete or ErrorResponse, got '{}'",
1154                        type_byte as char
1155                    )));
1156                }
1157            }
1158        }
1159    }
1160
1161    /// Run a pipeline of batched queries.
1162    ///
1163    /// Pipeline mode allows sending multiple queries to the server without waiting
1164    /// for responses, reducing round-trip latency.
1165    ///
1166    /// # Example
1167    ///
1168    /// ```ignore
1169    /// // Prepare statements outside the pipeline
1170    /// let stmts = conn.prepare_batch(&[
1171    ///     "SELECT id, name FROM users WHERE active = $1",
1172    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
1173    /// ]).await?;
1174    ///
1175    /// let (active, inactive, count) = conn.run_pipeline(|p| async move {
1176    ///     // Queue executions
1177    ///     let t1 = p.exec(&stmts[0], (true,)).await?;
1178    ///     let t2 = p.exec(&stmts[0], (false,)).await?;
1179    ///     let t3 = p.exec("SELECT COUNT(*) FROM users", ()).await?;
1180    ///
1181    ///     p.sync().await?;
1182    ///
1183    ///     // Claim results in order with different methods
1184    ///     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
1185    ///     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
1186    ///     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
1187    ///
1188    ///     Ok((active, inactive, count))
1189    /// }).await?;
1190    /// ```
1191    pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1192    where
1193        F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1194        Fut: std::future::Future<Output = Result<T>>,
1195    {
1196        let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1197        let result = f(&mut pipeline).await;
1198        pipeline.cleanup().await;
1199        result
1200    }
1201
1202    /// Execute a closure within a transaction.
1203    ///
1204    /// If the closure returns `Ok`, the transaction is committed.
1205    /// If the closure returns `Err` or the transaction is not explicitly
1206    /// committed or rolled back, the transaction is rolled back.
1207    ///
1208    /// # Errors
1209    ///
1210    /// Returns `Error::InvalidUsage` if called while already in a transaction.
1211    pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1212    where
1213        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1214        Fut: std::future::Future<Output = Result<R>>,
1215    {
1216        if self.in_transaction() {
1217            return Err(Error::InvalidUsage(
1218                "nested transactions are not supported".into(),
1219            ));
1220        }
1221
1222        self.query_drop("BEGIN").await?;
1223
1224        let tx = super::transaction::Transaction::new(self.connection_id());
1225
1226        // We need to use unsafe to work around the borrow checker here
1227        // because async closures can't capture &mut self properly
1228        let result = f(self, tx).await;
1229
1230        // If still in a transaction (not committed or rolled back), roll it back
1231        if self.in_transaction() {
1232            let rollback_result = self.query_drop("ROLLBACK").await;
1233
1234            // Return the first error (either from closure or rollback)
1235            if let Err(e) = result {
1236                return Err(e);
1237            }
1238            rollback_result?;
1239        }
1240
1241        result
1242    }
1243}