zero_postgres/sync/
conn.rs

1//! Synchronous PostgreSQL connection.
2
3use std::net::TcpStream;
4use std::os::unix::net::UnixStream;
5
6use crate::buffer_pool::PooledBufferSet;
7use crate::conversion::ToParams;
8use crate::error::{Error, Result};
9use crate::handler::{
10    AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
11};
12use crate::opts::Opts;
13use crate::protocol::backend::BackendKeyData;
14use crate::protocol::frontend::write_terminate;
15use crate::protocol::types::TransactionStatus;
16use crate::state::StateMachine;
17use crate::state::action::Action;
18use crate::state::connection::ConnectionStateMachine;
19use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
20use crate::state::simple_query::SimpleQueryStateMachine;
21use crate::statement::IntoStatement;
22
23use super::stream::Stream;
24use super::unnamed_portal::UnnamedPortal;
25
/// Synchronous PostgreSQL connection.
///
/// Owns the transport stream and a pooled set of read/write buffers, and
/// tracks session state obtained from the server: the backend key, the
/// server parameters from startup, and the current transaction status.
pub struct Conn {
    /// Underlying transport (TCP or Unix socket; TLS-wrapped when the
    /// `sync-tls` feature performed an upgrade).
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from `Opts::buffer_pool`.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data for query cancellation, if the server sent it.
    backend_key: Option<BackendKeyData>,
    /// `(name, value)` server parameters captured during the startup handshake.
    server_params: Vec<(String, String)>,
    /// Transaction status as last recorded after a completed command.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation failed with an error classified as connection-breaking.
    pub(crate) is_broken: bool,
    /// Incremented on every `prepare_typed` call; the new value is passed to the
    /// prepare state machine as the statement index.
    stmt_counter: u64,
    /// Optional user callback for asynchronous server messages
    /// (notifications, notices, parameter changes).
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39    /// Connect to a PostgreSQL server.
40    pub fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41    where
42        Error: From<O::Error>,
43    {
44        let opts = opts.try_into()?;
45
46        let stream = if let Some(socket_path) = &opts.socket {
47            Stream::unix(UnixStream::connect(socket_path)?)
48        } else {
49            if opts.host.is_empty() {
50                return Err(Error::InvalidUsage("host is empty".into()));
51            }
52            let addr = format!("{}:{}", opts.host, opts.port);
53            let tcp = TcpStream::connect(&addr)?;
54            tcp.set_nodelay(true)?;
55            Stream::tcp(tcp)
56        };
57
58        Self::new_with_stream(stream, opts)
59    }
60
    /// Connect using an existing stream.
    ///
    /// Drives a `ConnectionStateMachine` over `stream` until it reports
    /// `Action::Finished`. Each step performs the I/O the machine asks for,
    /// including the optional SSL request byte exchange and TLS upgrade. The
    /// resulting backend key, server parameters and transaction status are
    /// stored in the new `Conn`.
    ///
    /// If `options.prefer_unix_socket` is set and `stream` is a TCP connection
    /// to a loopback address, a reconnect over the server's Unix-domain socket
    /// is attempted. If that attempt fails, the TCP connection is kept.
    ///
    /// # Errors
    ///
    /// Propagates I/O and protocol errors from the handshake. Returns
    /// `Error::Unsupported` if the state machine requests a TLS handshake while
    /// the `sync-tls` feature is disabled.
    // `stream` is only reassigned when the `sync-tls` feature is enabled, so the
    // `mut` is unused in builds without it.
    #[allow(unused_mut)]
    pub fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
        let mut buffer_set = options.buffer_pool.get_buffer_set();
        let mut state_machine = ConnectionStateMachine::new(options.clone());

        // Drive the connection state machine
        loop {
            match state_machine.step(&mut buffer_set)? {
                Action::WriteAndReadByte => {
                    // The reply to this request is a single raw byte rather than a
                    // framed message; it is handed back as the SSL response.
                    stream.write_all(&buffer_set.write_buffer)?;
                    stream.flush()?;
                    let byte = stream.read_u8()?;
                    state_machine.set_ssl_response(byte);
                }
                Action::ReadMessage => {
                    stream.read_message(&mut buffer_set)?;
                }
                Action::Write => {
                    stream.write_all(&buffer_set.write_buffer)?;
                    stream.flush()?;
                }
                Action::WriteAndReadMessage => {
                    stream.write_all(&buffer_set.write_buffer)?;
                    stream.flush()?;
                    stream.read_message(&mut buffer_set)?;
                }
                Action::TlsHandshake => {
                    // Replace the plain stream with a TLS one. `options.host` is
                    // passed to the upgrade (presumably for server-name checks).
                    #[cfg(feature = "sync-tls")]
                    {
                        stream = stream.upgrade_to_tls(&options.host)?;
                    }
                    #[cfg(not(feature = "sync-tls"))]
                    {
                        return Err(Error::Unsupported(
                            "TLS requested but sync-tls feature not enabled".into(),
                        ));
                    }
                }
                Action::HandleAsyncMessageAndReadMessage(_) => {
                    // Ignore async messages during startup, read next message
                    // (no async handler exists yet; it is `None` below).
                    stream.read_message(&mut buffer_set)?;
                }
                Action::Finished => break,
            }
        }

        let conn = Self {
            stream,
            buffer_set,
            backend_key: state_machine.backend_key().cloned(),
            server_params: state_machine.take_server_params(),
            transaction_status: state_machine.transaction_status(),
            is_broken: false,
            stmt_counter: 0,
            async_message_handler: None,
        };

        // Upgrade to Unix socket if connected via TCP to loopback
        // (best-effort: the original connection is returned on any failure).
        let conn = if options.prefer_unix_socket && conn.stream.is_tcp_loopback() {
            conn.try_upgrade_to_unix_socket(&options)
        } else {
            conn
        };

        Ok(conn)
    }
128
129    /// Try to upgrade to Unix socket connection.
130    /// Returns upgraded conn on success, original conn on failure.
131    fn try_upgrade_to_unix_socket(mut self, opts: &Opts) -> Self {
132        // Query unix_socket_directories from server
133        let mut handler = FirstRowHandler::<(String,)>::new();
134        if self
135            .query("SHOW unix_socket_directories", &mut handler)
136            .is_err()
137        {
138            return self;
139        }
140
141        let socket_dir = match handler.into_row() {
142            Some((dirs,)) => {
143                // May contain multiple directories, use the first one
144                match dirs.split(',').next() {
145                    Some(d) if !d.trim().is_empty() => d.trim().to_string(),
146                    _ => return self,
147                }
148            }
149            None => return self,
150        };
151
152        // Build socket path: {directory}/.s.PGSQL.{port}
153        let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
154
155        // Connect via Unix socket
156        let unix_stream = match UnixStream::connect(&socket_path) {
157            Ok(s) => s,
158            Err(_) => return self,
159        };
160
161        // Create new connection over Unix socket
162        let mut opts_unix = opts.clone();
163        opts_unix.prefer_unix_socket = false;
164
165        match Self::new_with_stream(Stream::unix(unix_stream), opts_unix) {
166            Ok(new_conn) => new_conn,
167            Err(_) => self,
168        }
169    }
170
171    /// Get the backend key data for query cancellation.
172    pub fn backend_key(&self) -> Option<&BackendKeyData> {
173        self.backend_key.as_ref()
174    }
175
176    /// Get the connection ID (backend process ID).
177    ///
178    /// Returns 0 if the backend key data is not available.
179    pub fn connection_id(&self) -> u32 {
180        self.backend_key.as_ref().map_or(0, |k| k.process_id())
181    }
182
183    /// Get server parameters.
184    pub fn server_params(&self) -> &[(String, String)] {
185        &self.server_params
186    }
187
188    /// Get the current transaction status.
189    pub fn transaction_status(&self) -> TransactionStatus {
190        self.transaction_status
191    }
192
193    /// Check if currently in a transaction.
194    pub fn in_transaction(&self) -> bool {
195        self.transaction_status.in_transaction()
196    }
197
198    /// Check if the connection is broken.
199    pub fn is_broken(&self) -> bool {
200        self.is_broken
201    }
202
203    /// Set the async message handler.
204    ///
205    /// The handler is called when the server sends asynchronous messages:
206    /// - `Notification` - from LISTEN/NOTIFY
207    /// - `Notice` - warnings and informational messages
208    /// - `ParameterChanged` - server parameter updates
209    pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
210        self.async_message_handler = Some(Box::new(handler));
211    }
212
213    /// Remove the async message handler.
214    pub fn clear_async_message_handler(&mut self) {
215        self.async_message_handler = None;
216    }
217
218    /// Run a pipeline of batched queries.
219    ///
220    /// This provides automatic cleanup of the pipeline on exit, ensuring
221    /// the connection is left in a valid state even if the closure fails.
222    ///
223    /// # Example
224    ///
225    /// ```ignore
226    /// let stmt = conn.prepare("INSERT INTO users (name) VALUES ($1) RETURNING id")?;
227    ///
228    /// let (id1, id2) = conn.run_pipeline(|p| {
229    ///     let t1 = p.exec(&stmt, ("alice",))?;
230    ///     let t2 = p.exec(&stmt, ("bob",))?;
231    ///     p.sync()?;
232    ///
233    ///     let id1: Option<(i32,)> = p.claim_one(t1)?;
234    ///     let id2: Option<(i32,)> = p.claim_one(t2)?;
235    ///     Ok((id1, id2))
236    /// })?;
237    /// ```
238    pub fn run_pipeline<T, F>(&mut self, f: F) -> Result<T>
239    where
240        F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
241    {
242        let mut pipeline = super::pipeline::Pipeline::new_inner(self);
243        let result = f(&mut pipeline);
244        pipeline.cleanup();
245        result
246    }
247
248    /// Ping the server with an empty query to check connection aliveness.
249    pub fn ping(&mut self) -> Result<()> {
250        self.query_drop("")?;
251        Ok(())
252    }
253
    /// Drive a state machine to completion.
    ///
    /// Repeatedly calls `state_machine.step` and performs the I/O each returned
    /// `Action` requests, until the machine reports `Action::Finished`. At that
    /// point the machine's transaction status is stored in
    /// `self.transaction_status`. Async messages are passed to the installed
    /// async message handler (if any) before the next message is read.
    ///
    /// # Errors
    ///
    /// Propagates errors from `step` and from stream I/O; in those cases
    /// `transaction_status` is not updated. Returns `Error::Protocol` if the
    /// machine requests a startup-only action (`WriteAndReadByte` or
    /// `TlsHandshake`).
    fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
        loop {
            match state_machine.step(&mut self.buffer_set)? {
                // Only the startup handshake exchanges a raw byte.
                Action::WriteAndReadByte => {
                    return Err(Error::Protocol(
                        "Unexpected WriteAndReadByte in query state machine".into(),
                    ));
                }
                Action::ReadMessage => {
                    self.stream.read_message(&mut self.buffer_set)?;
                }
                Action::Write => {
                    self.stream.write_all(&self.buffer_set.write_buffer)?;
                    self.stream.flush()?;
                }
                Action::WriteAndReadMessage => {
                    self.stream.write_all(&self.buffer_set.write_buffer)?;
                    self.stream.flush()?;
                    self.stream.read_message(&mut self.buffer_set)?;
                }
                // TLS can only be negotiated during startup.
                Action::TlsHandshake => {
                    return Err(Error::Protocol(
                        "Unexpected TlsHandshake in query state machine".into(),
                    ));
                }
                Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
                    // Without an installed handler the message is simply dropped.
                    if let Some(ref mut h) = self.async_message_handler {
                        h.handle(async_msg);
                    }
                    // Read next message after handling async message
                    self.stream.read_message(&mut self.buffer_set)?;
                }
                Action::Finished => {
                    self.transaction_status = state_machine.transaction_status();
                    break;
                }
            }
        }
        Ok(())
    }
295
296    /// Execute a simple query with a handler.
297    pub fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
298        let result = self.query_inner(sql, handler);
299        if let Err(e) = &result
300            && e.is_connection_broken()
301        {
302            self.is_broken = true;
303        }
304        result
305    }
306
307    fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
308        let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
309        self.drive(&mut state_machine)
310    }
311
312    /// Execute a simple query and discard results.
313    pub fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
314        let mut handler = DropHandler::new();
315        self.query(sql, &mut handler)?;
316        Ok(handler.rows_affected())
317    }
318
319    /// Execute a simple query and collect typed rows.
320    ///
321    /// # Example
322    ///
323    /// ```ignore
324    /// let rows: Vec<(i32, String)> = conn.query_typed("SELECT id, name FROM users")?;
325    /// for (id, name) in rows {
326    ///     println!("{}: {}", id, name);
327    /// }
328    /// ```
329    pub fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
330        &mut self,
331        sql: &str,
332    ) -> Result<Vec<T>> {
333        let mut handler = crate::handler::CollectHandler::<T>::new();
334        self.query(sql, &mut handler)?;
335        Ok(handler.into_rows())
336    }
337
338    /// Execute a simple query and return the first typed row.
339    pub fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
340        &mut self,
341        sql: &str,
342    ) -> Result<Option<T>> {
343        let mut handler = crate::handler::FirstRowHandler::<T>::new();
344        self.query(sql, &mut handler)?;
345        Ok(handler.into_row())
346    }
347
348    /// Close the connection gracefully.
349    pub fn close(mut self) -> Result<()> {
350        self.buffer_set.write_buffer.clear();
351        write_terminate(&mut self.buffer_set.write_buffer);
352        self.stream.write_all(&self.buffer_set.write_buffer)?;
353        self.stream.flush()?;
354        Ok(())
355    }
356
357    // === Extended Query Protocol ===
358
359    /// Prepare a statement using the extended query protocol.
360    pub fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
361        self.prepare_typed(query, &[])
362    }
363
364    /// Prepare multiple statements in a single round-trip.
365    ///
366    /// This is more efficient than calling `prepare()` multiple times when you
367    /// need to prepare several statements, as it batches the network communication.
368    ///
369    /// # Example
370    ///
371    /// ```ignore
372    /// let stmts = conn.prepare_batch(&[
373    ///     "SELECT id, name FROM users WHERE id = $1",
374    ///     "INSERT INTO users (name) VALUES ($1) RETURNING id",
375    ///     "UPDATE users SET name = $1 WHERE id = $2",
376    /// ])?;
377    ///
378    /// // Use stmts[0], stmts[1], stmts[2]...
379    /// ```
380    pub fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
381        let mut statements = Vec::with_capacity(queries.len());
382        for query in queries {
383            statements.push(self.prepare(query)?);
384        }
385        Ok(statements)
386    }
387
388    /// Prepare a statement with explicit parameter types.
389    pub fn prepare_typed(&mut self, query: &str, param_oids: &[u32]) -> Result<PreparedStatement> {
390        self.stmt_counter += 1;
391        let idx = self.stmt_counter;
392        let result = self.prepare_inner(idx, query, param_oids);
393        if let Err(e) = &result
394            && e.is_connection_broken()
395        {
396            self.is_broken = true;
397        }
398        result
399    }
400
401    fn prepare_inner(
402        &mut self,
403        idx: u64,
404        query: &str,
405        param_oids: &[u32],
406    ) -> Result<PreparedStatement> {
407        let mut handler = DropHandler::new();
408        let mut state_machine = ExtendedQueryStateMachine::prepare(
409            &mut handler,
410            &mut self.buffer_set,
411            idx,
412            query,
413            param_oids,
414        );
415        self.drive(&mut state_machine)?;
416        state_machine
417            .take_prepared_statement()
418            .ok_or_else(|| Error::Protocol("No prepared statement".into()))
419    }
420
421    /// Execute a statement with a handler.
422    ///
423    /// The statement can be either:
424    /// - A `&PreparedStatement` returned from `prepare()`
425    /// - A raw SQL `&str` for one-shot execution
426    ///
427    /// # Examples
428    ///
429    /// ```ignore
430    /// // Using prepared statement
431    /// let stmt = conn.prepare("SELECT $1::int")?;
432    /// conn.exec(&stmt, (42,), &mut handler)?;
433    ///
434    /// // Using raw SQL
435    /// conn.exec("SELECT $1::int", (42,), &mut handler)?;
436    /// ```
437    pub fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
438        &mut self,
439        statement: S,
440        params: P,
441        handler: &mut H,
442    ) -> Result<()> {
443        let result = self.exec_inner(&statement, &params, handler);
444        if let Err(e) = &result
445            && e.is_connection_broken()
446        {
447            self.is_broken = true;
448        }
449        result
450    }
451
452    fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
453        &mut self,
454        statement: &S,
455        params: &P,
456        handler: &mut H,
457    ) -> Result<()> {
458        let mut state_machine = if statement.needs_parse() {
459            ExtendedQueryStateMachine::execute_sql(
460                handler,
461                &mut self.buffer_set,
462                statement.as_sql().unwrap(),
463                params,
464            )?
465        } else {
466            let stmt = statement.as_prepared().unwrap();
467            ExtendedQueryStateMachine::execute(
468                handler,
469                &mut self.buffer_set,
470                &stmt.wire_name(),
471                &stmt.param_oids,
472                params,
473            )?
474        };
475
476        self.drive(&mut state_machine)
477    }
478
479    /// Execute a statement and discard results.
480    ///
481    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
482    pub fn exec_drop<S: IntoStatement, P: ToParams>(
483        &mut self,
484        statement: S,
485        params: P,
486    ) -> Result<Option<u64>> {
487        let mut handler = DropHandler::new();
488        self.exec(statement, params, &mut handler)?;
489        Ok(handler.rows_affected())
490    }
491
492    /// Execute a statement and collect typed rows.
493    ///
494    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
495    ///
496    /// # Example
497    ///
498    /// ```ignore
499    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1")?;
500    /// let rows: Vec<(i32, String)> = conn.exec_collect(&stmt, (42,))?;
501    ///
502    /// // Or with raw SQL:
503    /// let rows: Vec<(i32, String)> = conn.exec_collect("SELECT id, name FROM users", ())?;
504    /// ```
505    pub fn exec_collect<
506        T: for<'a> crate::conversion::FromRow<'a>,
507        S: IntoStatement,
508        P: ToParams,
509    >(
510        &mut self,
511        statement: S,
512        params: P,
513    ) -> Result<Vec<T>> {
514        let mut handler = crate::handler::CollectHandler::<T>::new();
515        self.exec(statement, params, &mut handler)?;
516        Ok(handler.into_rows())
517    }
518
519    /// Execute a statement and return the first typed row.
520    ///
521    /// The statement can be either a `&PreparedStatement` or a raw SQL `&str`.
522    ///
523    /// # Example
524    ///
525    /// ```ignore
526    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1")?;
527    /// let row: Option<(i32, String)> = conn.exec_first(&stmt, (42,))?;
528    ///
529    /// // Or with raw SQL:
530    /// let row: Option<(i32, String)> = conn.exec_first("SELECT id, name FROM users LIMIT 1", ())?;
531    /// ```
532    pub fn exec_first<T: for<'a> crate::conversion::FromRow<'a>, S: IntoStatement, P: ToParams>(
533        &mut self,
534        statement: S,
535        params: P,
536    ) -> Result<Option<T>> {
537        let mut handler = crate::handler::FirstRowHandler::<T>::new();
538        self.exec(statement, params, &mut handler)?;
539        Ok(handler.into_row())
540    }
541
542    /// Execute a statement with multiple parameter sets in a batch.
543    ///
544    /// This is more efficient than calling `exec_drop` multiple times as it
545    /// batches the network communication. The statement is parsed once (if raw SQL)
546    /// and then bound/executed for each parameter set.
547    ///
548    /// Parameters are processed in chunks (default 1000) to avoid overwhelming
549    /// the server with too many pending operations.
550    ///
551    /// The statement can be either:
552    /// - A `&PreparedStatement` returned from `prepare()`
553    /// - A raw SQL `&str` for one-shot execution
554    ///
555    /// # Example
556    ///
557    /// ```ignore
558    /// // Using prepared statement
559    /// let stmt = conn.prepare("INSERT INTO users (name, age) VALUES ($1, $2)")?;
560    /// conn.exec_batch(&stmt, &[
561    ///     ("alice", 30),
562    ///     ("bob", 25),
563    ///     ("charlie", 35),
564    /// ])?;
565    ///
566    /// // Using raw SQL
567    /// conn.exec_batch("INSERT INTO users (name, age) VALUES ($1, $2)", &[
568    ///     ("alice", 30),
569    ///     ("bob", 25),
570    /// ])?;
571    /// ```
572    pub fn exec_batch<S: IntoStatement, P: ToParams>(
573        &mut self,
574        statement: S,
575        params_list: &[P],
576    ) -> Result<()> {
577        self.exec_batch_chunked(statement, params_list, 1000)
578    }
579
580    /// Execute a statement with multiple parameter sets in a batch with custom chunk size.
581    ///
582    /// Same as `exec_batch` but allows specifying the chunk size for batching.
583    pub fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
584        &mut self,
585        statement: S,
586        params_list: &[P],
587        chunk_size: usize,
588    ) -> Result<()> {
589        let result = self.exec_batch_inner(&statement, params_list, chunk_size);
590        if let Err(e) = &result
591            && e.is_connection_broken()
592        {
593            self.is_broken = true;
594        }
595        result
596    }
597
598    fn exec_batch_inner<S: IntoStatement, P: ToParams>(
599        &mut self,
600        statement: &S,
601        params_list: &[P],
602        chunk_size: usize,
603    ) -> Result<()> {
604        use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
605        use crate::state::extended::BatchStateMachine;
606
607        if params_list.is_empty() {
608            return Ok(());
609        }
610
611        let chunk_size = chunk_size.max(1);
612        let needs_parse = statement.needs_parse();
613        let sql = statement.as_sql();
614        let prepared = statement.as_prepared();
615
616        // Get param OIDs from first params or prepared statement
617        let param_oids: Vec<u32> = if let Some(stmt) = prepared {
618            stmt.param_oids.clone()
619        } else {
620            params_list[0].natural_oids()
621        };
622
623        // Statement name: empty for raw SQL, actual name for prepared
624        let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
625
626        for chunk in params_list.chunks(chunk_size) {
627            self.buffer_set.write_buffer.clear();
628
629            // For raw SQL, send Parse each chunk (reuses unnamed statement)
630            let parse_in_chunk = needs_parse;
631            if parse_in_chunk {
632                write_parse(
633                    &mut self.buffer_set.write_buffer,
634                    "",
635                    sql.unwrap(),
636                    &param_oids,
637                );
638            }
639
640            // Write Bind + Execute for each param set
641            for params in chunk {
642                let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
643                write_bind(
644                    &mut self.buffer_set.write_buffer,
645                    "",
646                    effective_stmt_name,
647                    params,
648                    &param_oids,
649                )?;
650                write_execute(&mut self.buffer_set.write_buffer, "", 0);
651            }
652
653            // Send Sync
654            write_sync(&mut self.buffer_set.write_buffer);
655
656            // Drive state machine
657            let mut state_machine = BatchStateMachine::new(parse_in_chunk);
658            self.drive_batch(&mut state_machine)?;
659            self.transaction_status = state_machine.transaction_status();
660        }
661
662        Ok(())
663    }
664
665    /// Drive a batch state machine to completion.
666    fn drive_batch(
667        &mut self,
668        state_machine: &mut crate::state::extended::BatchStateMachine,
669    ) -> Result<()> {
670        use crate::protocol::backend::{ReadyForQuery, msg_type};
671        use crate::state::action::Action;
672
673        loop {
674            let step_result = state_machine.step(&mut self.buffer_set);
675            match step_result {
676                Ok(Action::ReadMessage) => {
677                    self.stream.read_message(&mut self.buffer_set)?;
678                }
679                Ok(Action::WriteAndReadMessage) => {
680                    self.stream.write_all(&self.buffer_set.write_buffer)?;
681                    self.stream.flush()?;
682                    self.stream.read_message(&mut self.buffer_set)?;
683                }
684                Ok(Action::Finished) => {
685                    break;
686                }
687                Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
688                Err(e) => {
689                    // On error, drain to ReadyForQuery to leave connection in clean state
690                    loop {
691                        self.stream.read_message(&mut self.buffer_set)?;
692                        if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
693                            let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
694                            self.transaction_status =
695                                ready.transaction_status().unwrap_or_default();
696                            break;
697                        }
698                    }
699                    return Err(e);
700                }
701            }
702        }
703        Ok(())
704    }
705
706    /// Close a prepared statement.
707    pub fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
708        let result = self.close_statement_inner(&stmt.wire_name());
709        if let Err(e) = &result
710            && e.is_connection_broken()
711        {
712            self.is_broken = true;
713        }
714        result
715    }
716
717    fn close_statement_inner(&mut self, name: &str) -> Result<()> {
718        let mut handler = DropHandler::new();
719        let mut state_machine =
720            ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
721        self.drive(&mut state_machine)
722    }
723
724    /// Execute a closure within a transaction.
725    ///
726    /// If the closure returns `Ok`, the transaction is committed.
727    /// If the closure returns `Err` or the transaction is not explicitly
728    /// committed or rolled back, the transaction is rolled back.
729    ///
730    /// # Errors
731    ///
732    /// Returns `Error::InvalidUsage` if called while already in a transaction.
733    pub fn run_transaction<F, R>(&mut self, f: F) -> Result<R>
734    where
735        F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
736    {
737        if self.in_transaction() {
738            return Err(Error::InvalidUsage(
739                "nested transactions are not supported".into(),
740            ));
741        }
742
743        self.query_drop("BEGIN")?;
744
745        let tx = super::transaction::Transaction::new(self.connection_id());
746        let result = f(self, tx);
747
748        // If still in a transaction (not committed or rolled back), roll it back
749        if self.in_transaction() {
750            let rollback_result = self.query_drop("ROLLBACK");
751
752            // Return the first error (either from closure or rollback)
753            if let Err(e) = result {
754                return Err(e);
755            }
756            rollback_result?;
757        }
758
759        result
760    }
761}
762
763// === Low-level Extended Query Protocol ===
764
765impl Conn {
766    /// Low-level bind: send BIND message and receive BindComplete.
767    ///
768    /// This allows creating named portals. Unlike `exec()`, this does NOT
769    /// send EXECUTE or SYNC - the caller controls when to execute and sync.
770    ///
771    /// # Arguments
772    /// - `portal`: Portal name (empty string "" for unnamed portal)
773    /// - `statement_name`: Prepared statement name
774    /// - `params`: Parameter values
775    pub fn lowlevel_bind<P: ToParams>(
776        &mut self,
777        portal: &str,
778        statement_name: &str,
779        params: P,
780    ) -> Result<()> {
781        let result = self.lowlevel_bind_inner(portal, statement_name, &params);
782        if let Err(e) = &result
783            && e.is_connection_broken()
784        {
785            self.is_broken = true;
786        }
787        result
788    }
789
790    fn lowlevel_bind_inner<P: ToParams>(
791        &mut self,
792        portal: &str,
793        statement_name: &str,
794        params: &P,
795    ) -> Result<()> {
796        use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
797        use crate::protocol::frontend::{write_bind, write_flush};
798
799        let param_oids = params.natural_oids();
800        self.buffer_set.write_buffer.clear();
801        write_bind(
802            &mut self.buffer_set.write_buffer,
803            portal,
804            statement_name,
805            params,
806            &param_oids,
807        )?;
808        write_flush(&mut self.buffer_set.write_buffer);
809
810        self.stream.write_all(&self.buffer_set.write_buffer)?;
811        self.stream.flush()?;
812
813        loop {
814            self.stream.read_message(&mut self.buffer_set)?;
815            let type_byte = self.buffer_set.type_byte;
816
817            if RawMessage::is_async_type(type_byte) {
818                continue;
819            }
820
821            match type_byte {
822                msg_type::BIND_COMPLETE => {
823                    BindComplete::parse(&self.buffer_set.read_buffer)?;
824                    return Ok(());
825                }
826                msg_type::ERROR_RESPONSE => {
827                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
828                    return Err(error.into_error());
829                }
830                _ => {
831                    return Err(Error::Protocol(format!(
832                        "Expected BindComplete or ErrorResponse, got '{}'",
833                        type_byte as char
834                    )));
835                }
836            }
837        }
838    }
839
840    /// Low-level execute: send EXECUTE message and receive results.
841    ///
842    /// Executes a previously bound portal. Does NOT send SYNC.
843    ///
844    /// # Arguments
845    /// - `portal`: Portal name (empty string "" for unnamed portal)
846    /// - `max_rows`: Maximum rows to return (0 = unlimited)
847    /// - `handler`: Handler to receive rows
848    ///
849    /// # Returns
850    /// - `Ok(true)` if more rows available (PortalSuspended received)
851    /// - `Ok(false)` if execution completed (CommandComplete received)
852    pub fn lowlevel_execute<H: BinaryHandler>(
853        &mut self,
854        portal: &str,
855        max_rows: u32,
856        handler: &mut H,
857    ) -> Result<bool> {
858        let result = self.lowlevel_execute_inner(portal, max_rows, handler);
859        if let Err(e) = &result
860            && e.is_connection_broken()
861        {
862            self.is_broken = true;
863        }
864        result
865    }
866
867    fn lowlevel_execute_inner<H: BinaryHandler>(
868        &mut self,
869        portal: &str,
870        max_rows: u32,
871        handler: &mut H,
872    ) -> Result<bool> {
873        use crate::protocol::backend::{
874            CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
875            RowDescription, msg_type,
876        };
877        use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
878
879        self.buffer_set.write_buffer.clear();
880        write_describe_portal(&mut self.buffer_set.write_buffer, portal);
881        write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
882        write_flush(&mut self.buffer_set.write_buffer);
883
884        self.stream.write_all(&self.buffer_set.write_buffer)?;
885        self.stream.flush()?;
886
887        let mut column_buffer: Vec<u8> = Vec::new();
888
889        loop {
890            self.stream.read_message(&mut self.buffer_set)?;
891            let type_byte = self.buffer_set.type_byte;
892
893            if RawMessage::is_async_type(type_byte) {
894                continue;
895            }
896
897            match type_byte {
898                msg_type::ROW_DESCRIPTION => {
899                    column_buffer.clear();
900                    column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
901                    let cols = RowDescription::parse(&column_buffer)?;
902                    handler.result_start(cols)?;
903                }
904                msg_type::NO_DATA => {
905                    NoData::parse(&self.buffer_set.read_buffer)?;
906                }
907                msg_type::DATA_ROW => {
908                    let cols = RowDescription::parse(&column_buffer)?;
909                    let row = DataRow::parse(&self.buffer_set.read_buffer)?;
910                    handler.row(cols, row)?;
911                }
912                msg_type::COMMAND_COMPLETE => {
913                    let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
914                    handler.result_end(complete)?;
915                    return Ok(false); // No more rows
916                }
917                msg_type::PORTAL_SUSPENDED => {
918                    PortalSuspended::parse(&self.buffer_set.read_buffer)?;
919                    return Ok(true); // More rows available
920                }
921                msg_type::ERROR_RESPONSE => {
922                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
923                    return Err(error.into_error());
924                }
925                _ => {
926                    return Err(Error::Protocol(format!(
927                        "Unexpected message in execute: '{}'",
928                        type_byte as char
929                    )));
930                }
931            }
932        }
933    }
934
935    /// Low-level sync: send SYNC and receive ReadyForQuery.
936    ///
937    /// This ends an extended query sequence and:
938    /// - Commits implicit transaction if successful
939    /// - Rolls back implicit transaction if failed
940    /// - Updates transaction status
941    pub fn lowlevel_sync(&mut self) -> Result<()> {
942        let result = self.lowlevel_sync_inner();
943        if let Err(e) = &result
944            && e.is_connection_broken()
945        {
946            self.is_broken = true;
947        }
948        result
949    }
950
951    fn lowlevel_sync_inner(&mut self) -> Result<()> {
952        use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
953        use crate::protocol::frontend::write_sync;
954
955        self.buffer_set.write_buffer.clear();
956        write_sync(&mut self.buffer_set.write_buffer);
957
958        self.stream.write_all(&self.buffer_set.write_buffer)?;
959        self.stream.flush()?;
960
961        let mut pending_error: Option<Error> = None;
962
963        loop {
964            self.stream.read_message(&mut self.buffer_set)?;
965            let type_byte = self.buffer_set.type_byte;
966
967            if RawMessage::is_async_type(type_byte) {
968                continue;
969            }
970
971            match type_byte {
972                msg_type::READY_FOR_QUERY => {
973                    let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
974                    self.transaction_status = ready.transaction_status().unwrap_or_default();
975                    if let Some(e) = pending_error {
976                        return Err(e);
977                    }
978                    return Ok(());
979                }
980                msg_type::ERROR_RESPONSE => {
981                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
982                    pending_error = Some(error.into_error());
983                }
984                _ => {
985                    // Ignore other messages before ReadyForQuery
986                }
987            }
988        }
989    }
990
991    /// Low-level flush: send FLUSH to force server to send pending responses.
992    ///
993    /// Unlike SYNC, FLUSH does not end the transaction or wait for ReadyForQuery.
994    /// It just forces the server to send any pending responses (like ParseComplete,
995    /// BindComplete, RowDescription, DataRow, etc.) without ending the extended
996    /// query sequence.
997    pub fn lowlevel_flush(&mut self) -> Result<()> {
998        use crate::protocol::frontend::write_flush;
999
1000        self.buffer_set.write_buffer.clear();
1001        write_flush(&mut self.buffer_set.write_buffer);
1002
1003        self.stream.write_all(&self.buffer_set.write_buffer)?;
1004        self.stream.flush()?;
1005        Ok(())
1006    }
1007
1008    /// Execute a statement with iterative row fetching.
1009    ///
1010    /// Creates an unnamed portal and passes it to the closure. The closure can
1011    /// call `portal.fetch(n, handler)` multiple times to retrieve rows in batches.
1012    /// Sync is called after the closure returns to end the implicit transaction.
1013    ///
1014    /// The statement can be either:
1015    /// - A `&PreparedStatement` returned from `prepare()`
1016    /// - A raw SQL `&str` for one-shot execution
1017    ///
1018    /// # Example
1019    /// ```ignore
1020    /// // Using prepared statement
1021    /// let stmt = conn.prepare("SELECT * FROM users")?;
1022    /// conn.exec_iter(&stmt, (), |portal| {
1023    ///     while portal.fetch(100, &mut handler)? {
1024    ///         // process handler.into_rows()...
1025    ///     }
1026    ///     Ok(())
1027    /// })?;
1028    ///
1029    /// // Using raw SQL
1030    /// conn.exec_iter("SELECT * FROM users", (), |portal| {
1031    ///     while portal.fetch(100, &mut handler)? {
1032    ///         // process handler.into_rows()...
1033    ///     }
1034    ///     Ok(())
1035    /// })?;
1036    /// ```
1037    pub fn exec_iter<S: IntoStatement, P, F, T>(
1038        &mut self,
1039        statement: S,
1040        params: P,
1041        f: F,
1042    ) -> Result<T>
1043    where
1044        P: ToParams,
1045        F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1046    {
1047        let result = self.exec_iter_inner(&statement, &params, f);
1048        if let Err(e) = &result
1049            && e.is_connection_broken()
1050        {
1051            self.is_broken = true;
1052        }
1053        result
1054    }
1055
1056    fn exec_iter_inner<S: IntoStatement, P, F, T>(
1057        &mut self,
1058        statement: &S,
1059        params: &P,
1060        f: F,
1061    ) -> Result<T>
1062    where
1063        P: ToParams,
1064        F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1065    {
1066        // Create bind state machine
1067        let mut state_machine = if let Some(sql) = statement.as_sql() {
1068            BindStateMachine::bind_sql(&mut self.buffer_set, sql, params)?
1069        } else {
1070            let stmt = statement.as_prepared().unwrap();
1071            BindStateMachine::bind_prepared(
1072                &mut self.buffer_set,
1073                &stmt.wire_name(),
1074                &stmt.param_oids,
1075                params,
1076            )?
1077        };
1078
1079        // Drive the state machine to completion (ParseComplete + BindComplete)
1080        loop {
1081            match state_machine.step(&mut self.buffer_set)? {
1082                Action::ReadMessage => {
1083                    self.stream.read_message(&mut self.buffer_set)?;
1084                }
1085                Action::Write => {
1086                    self.stream.write_all(&self.buffer_set.write_buffer)?;
1087                    self.stream.flush()?;
1088                }
1089                Action::WriteAndReadMessage => {
1090                    self.stream.write_all(&self.buffer_set.write_buffer)?;
1091                    self.stream.flush()?;
1092                    self.stream.read_message(&mut self.buffer_set)?;
1093                }
1094                Action::Finished => break,
1095                _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1096            }
1097        }
1098
1099        // Execute closure with portal handle
1100        let mut portal = UnnamedPortal { conn: self };
1101        let result = f(&mut portal);
1102
1103        // Always sync to end implicit transaction (even on error)
1104        let sync_result = portal.conn.lowlevel_sync();
1105
1106        // Return closure result, or sync error if closure succeeded but sync failed
1107        match (result, sync_result) {
1108            (Ok(v), Ok(())) => Ok(v),
1109            (Err(e), _) => Err(e),
1110            (Ok(_), Err(e)) => Err(e),
1111        }
1112    }
1113
1114    /// Low-level close portal: send Close(Portal) and receive CloseComplete.
1115    pub fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1116        let result = self.lowlevel_close_portal_inner(portal);
1117        if let Err(e) = &result
1118            && e.is_connection_broken()
1119        {
1120            self.is_broken = true;
1121        }
1122        result
1123    }
1124
1125    fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1126        use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1127        use crate::protocol::frontend::{write_close_portal, write_flush};
1128
1129        self.buffer_set.write_buffer.clear();
1130        write_close_portal(&mut self.buffer_set.write_buffer, portal);
1131        write_flush(&mut self.buffer_set.write_buffer);
1132
1133        self.stream.write_all(&self.buffer_set.write_buffer)?;
1134        self.stream.flush()?;
1135
1136        loop {
1137            self.stream.read_message(&mut self.buffer_set)?;
1138            let type_byte = self.buffer_set.type_byte;
1139
1140            if RawMessage::is_async_type(type_byte) {
1141                continue;
1142            }
1143
1144            match type_byte {
1145                msg_type::CLOSE_COMPLETE => {
1146                    CloseComplete::parse(&self.buffer_set.read_buffer)?;
1147                    return Ok(());
1148                }
1149                msg_type::ERROR_RESPONSE => {
1150                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1151                    return Err(error.into_error());
1152                }
1153                _ => {
1154                    return Err(Error::Protocol(format!(
1155                        "Expected CloseComplete or ErrorResponse, got '{}'",
1156                        type_byte as char
1157                    )));
1158                }
1159            }
1160        }
1161    }
1162}
1163
1164impl Drop for Conn {
1165    fn drop(&mut self) {
1166        // Try to send Terminate message, ignore errors
1167        self.buffer_set.write_buffer.clear();
1168        write_terminate(&mut self.buffer_set.write_buffer);
1169        let _ = self.stream.write_all(&self.buffer_set.write_buffer);
1170        let _ = self.stream.flush();
1171    }
1172}