// zero_postgres/sync/pipeline/mod.rs

//! Pipeline mode for batching multiple queries.
//!
//! Pipeline mode allows sending multiple queries to the server without waiting
//! for responses, reducing round-trip latency.
//!
//! # Example
//!
//! ```ignore
//! // Prepare statements outside the pipeline
//! let stmts = conn.prepare_batch(&[
//!     "SELECT id, name FROM users WHERE active = $1",
//!     "INSERT INTO users (name) VALUES ($1) RETURNING id",
//! ])?;
//!
//! let (active, inactive, count) = conn.run_pipeline(|p| {
//!     // Queue executions
//!     let t1 = p.exec(&stmts[0], (true,))?;
//!     let t2 = p.exec(&stmts[0], (false,))?;
//!     let t3 = p.exec("SELECT COUNT(*) FROM users", ())?;
//!
//!     p.sync()?;
//!
//!     // Claim results in order with different methods
//!     let active: Vec<(i32, String)> = p.claim_collect(t1)?;
//!     let inactive: Option<(i32, String)> = p.claim_one(t2)?;
//!     let count: Vec<(i64,)> = p.claim_collect(t3)?;
//!
//!     Ok((active, inactive, count))
//! })?;
//! ```

32use crate::pipeline::Expectation;
33use crate::pipeline::Ticket;
34
35use crate::conversion::{FromRow, ToParams};
36use crate::error::{Error, Result};
37use crate::handler::BinaryHandler;
38use crate::protocol::backend::{
39    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
40    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
41};
42use crate::protocol::frontend::{
43    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
44};
45use crate::state::extended::PreparedStatement;
46use crate::statement::IntoStatement;
47
48use super::conn::Conn;
49
50/// Pipeline mode for batching multiple queries.
51///
52/// Created by [`Conn::run_pipeline`].
53pub struct Pipeline<'a> {
54    conn: &'a mut Conn,
55    /// Monotonically increasing counter for queued operations
56    queue_seq: usize,
57    /// Next sequence number to claim
58    claim_seq: usize,
59    /// Whether the pipeline is in aborted state (error occurred)
60    aborted: bool,
61    /// Buffer for column descriptions during row processing
62    column_buffer: Vec<u8>,
63    /// Expected responses for each queued operation
64    expectations: Vec<Expectation>,
65}
66
67impl<'a> Pipeline<'a> {
68    /// Create a new pipeline.
69    ///
70    /// Prefer using [`Conn::run_pipeline`] which handles cleanup automatically.
71    /// This constructor is available for advanced use cases.
72    #[cfg(feature = "lowlevel")]
73    pub fn new(conn: &'a mut Conn) -> Self {
74        Self::new_inner(conn)
75    }
76
77    /// Create a new pipeline (internal).
78    pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
79        conn.buffer_set.write_buffer.clear();
80        Self {
81            conn,
82            queue_seq: 0,
83            claim_seq: 0,
84            aborted: false,
85            column_buffer: Vec::new(),
86            expectations: Vec::new(),
87        }
88    }
89
90    /// Cleanup the pipeline, draining any unclaimed tickets.
91    ///
92    /// This is called automatically by [`Conn::run_pipeline`].
93    /// Also available with the `lowlevel` feature for manual cleanup.
94    #[cfg(feature = "lowlevel")]
95    pub fn cleanup(&mut self) {
96        self.cleanup_inner();
97    }
98
99    #[cfg(not(feature = "lowlevel"))]
100    pub(crate) fn cleanup(&mut self) {
101        self.cleanup_inner();
102    }
103
104    fn cleanup_inner(&mut self) {
105        if self.queue_seq == self.claim_seq {
106            return;
107        }
108
109        // Send sync if we have pending operations
110        if !self.conn.buffer_set.write_buffer.is_empty() {
111            let _ = self.sync();
112        }
113
114        // Drain remaining tickets
115        while self.claim_seq < self.queue_seq {
116            let _ = self.drain_one();
117            self.claim_seq += 1;
118        }
119
120        // Consume ReadyForQuery
121        let _ = self.finish();
122    }
123
124    /// Drain one ticket's worth of messages.
125    fn drain_one(&mut self) {
126        let Some(expectation) = self.expectations.get(self.claim_seq).copied() else {
127            return;
128        };
129        let mut handler = crate::handler::DropHandler::new();
130
131        let _ = match expectation {
132            Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler),
133            // When draining, we don't have the statement ref, but we also don't need row desc
134            // since we're just dropping the results
135            Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None),
136        };
137    }
138
139    // ========================================================================
140    // Queue Operations
141    // ========================================================================
142
143    /// Queue a statement execution.
144    ///
145    /// The statement can be either:
146    /// - A `&PreparedStatement` returned from `conn.prepare()` or `conn.prepare_batch()`
147    /// - A raw SQL `&str` for one-shot execution
148    ///
149    /// This method only buffers the command locally - no network I/O occurs until
150    /// `sync()` or `flush()` is called.
151    ///
152    /// # Example
153    ///
154    /// ```ignore
155    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1")?;
156    ///
157    /// let (r1, r2) = conn.run_pipeline(|p| {
158    ///     let t1 = p.exec(&stmt, (1,))?;
159    ///     let t2 = p.exec("SELECT COUNT(*) FROM users", ())?;
160    ///     p.sync()?;
161    ///
162    ///     let r1: Vec<(i32, String)> = p.claim_collect(t1)?;
163    ///     let r2: Option<(i64,)> = p.claim_one(t2)?;
164    ///     Ok((r1, r2))
165    /// })?;
166    /// ```
167    pub fn exec<'s, P: ToParams>(
168        &mut self,
169        statement: &'s (impl IntoStatement + ?Sized),
170        params: P,
171    ) -> Result<Ticket<'s>> {
172        let seq = self.queue_seq;
173        self.queue_seq += 1;
174
175        if statement.needs_parse() {
176            self.exec_sql_inner(statement.as_sql().unwrap(), &params)?;
177            Ok(Ticket { seq, stmt: None })
178        } else {
179            let stmt = statement.as_prepared().unwrap();
180            self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, &params)?;
181            Ok(Ticket {
182                seq,
183                stmt: Some(stmt),
184            })
185        }
186    }
187
188    fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
189        let param_oids = params.natural_oids();
190        let buf = &mut self.conn.buffer_set.write_buffer;
191        write_parse(buf, "", sql, &param_oids);
192        write_bind(buf, "", "", params, &param_oids)?;
193        write_describe_portal(buf, "");
194        write_execute(buf, "", 0);
195        self.expectations.push(Expectation::ParseBindExecute);
196        Ok(())
197    }
198
199    fn exec_prepared_inner<P: ToParams>(
200        &mut self,
201        stmt_name: &str,
202        param_oids: &[u32],
203        params: &P,
204    ) -> Result<()> {
205        let buf = &mut self.conn.buffer_set.write_buffer;
206        write_bind(buf, "", stmt_name, params, param_oids)?;
207        // Skip write_describe_portal - use cached RowDescription from PreparedStatement
208        write_execute(buf, "", 0);
209        self.expectations.push(Expectation::BindExecute);
210        Ok(())
211    }
212
213    /// Send a FLUSH message to trigger server response.
214    ///
215    /// This forces the server to send all pending responses without establishing
216    /// a transaction boundary. Called automatically by claim methods when needed.
217    pub fn flush(&mut self) -> Result<()> {
218        if !self.conn.buffer_set.write_buffer.is_empty() {
219            write_flush(&mut self.conn.buffer_set.write_buffer);
220            self.conn
221                .stream
222                .write_all(&self.conn.buffer_set.write_buffer)?;
223            self.conn.stream.flush()?;
224            self.conn.buffer_set.write_buffer.clear();
225        }
226        Ok(())
227    }
228
229    /// Send a SYNC message to establish a transaction boundary.
230    ///
231    /// After calling sync, you must claim all queued operations in order.
232    /// The final ReadyForQuery message will be consumed when all operations
233    /// are claimed.
234    pub fn sync(&mut self) -> Result<()> {
235        let result = self.sync_inner();
236        if let Err(e) = &result
237            && e.is_connection_broken()
238        {
239            self.conn.is_broken = true;
240        }
241        result
242    }
243
244    fn sync_inner(&mut self) -> Result<()> {
245        write_sync(&mut self.conn.buffer_set.write_buffer);
246        self.conn
247            .stream
248            .write_all(&self.conn.buffer_set.write_buffer)?;
249        self.conn.stream.flush()?;
250        self.conn.buffer_set.write_buffer.clear();
251        Ok(())
252    }
253
254    /// Wait for ReadyForQuery after all operations are claimed.
255    fn finish(&mut self) -> Result<()> {
256        // Wait for ReadyForQuery
257        loop {
258            self.conn.stream.read_message(&mut self.conn.buffer_set)?;
259            let type_byte = self.conn.buffer_set.type_byte;
260
261            // Handle async messages
262            if RawMessage::is_async_type(type_byte) {
263                continue;
264            }
265
266            // Handle error
267            if type_byte == msg_type::ERROR_RESPONSE {
268                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
269                return Err(error.into_error());
270            }
271
272            if type_byte == msg_type::READY_FOR_QUERY {
273                let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
274                self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
275                // Reset pipeline state
276                self.queue_seq = 0;
277                self.claim_seq = 0;
278                self.expectations.clear();
279                self.aborted = false;
280                return Ok(());
281            }
282        }
283    }
284
285    // ========================================================================
286    // Claim Operations
287    // ========================================================================
288
289    /// Claim with a custom handler.
290    ///
291    /// Results must be claimed in the same order they were queued.
292    #[cfg(feature = "lowlevel")]
293    pub fn claim<H: BinaryHandler>(&mut self, ticket: Ticket<'_>, handler: &mut H) -> Result<()> {
294        self.claim_with_handler(ticket, handler)
295    }
296
297    fn claim_with_handler<H: BinaryHandler>(
298        &mut self,
299        ticket: Ticket<'_>,
300        handler: &mut H,
301    ) -> Result<()> {
302        self.check_sequence(ticket.seq)?;
303        self.flush()?;
304
305        if self.aborted {
306            self.claim_seq += 1;
307            self.maybe_finish()?;
308            return Err(Error::Protocol(
309                "pipeline aborted due to earlier error".into(),
310            ));
311        }
312
313        let expectation = self.expectations.get(ticket.seq).copied();
314
315        let result = match expectation {
316            Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler),
317            Some(Expectation::BindExecute) => self.claim_bind_exec_inner(handler, ticket.stmt),
318            None => Err(Error::Protocol("unexpected expectation type".into())),
319        };
320
321        if let Err(e) = &result {
322            if e.is_connection_broken() {
323                self.conn.is_broken = true;
324            }
325            self.aborted = true;
326        }
327        self.claim_seq += 1;
328        self.maybe_finish()?;
329        result
330    }
331
332    /// Claim and collect all rows.
333    ///
334    /// Results must be claimed in the same order they were queued.
335    pub fn claim_collect<T: for<'b> FromRow<'b>>(&mut self, ticket: Ticket<'_>) -> Result<Vec<T>> {
336        let mut handler = crate::handler::CollectHandler::<T>::new();
337        self.claim_with_handler(ticket, &mut handler)?;
338        Ok(handler.into_rows())
339    }
340
341    /// Claim and return just the first row.
342    ///
343    /// Results must be claimed in the same order they were queued.
344    pub fn claim_one<T: for<'b> FromRow<'b>>(&mut self, ticket: Ticket<'_>) -> Result<Option<T>> {
345        let mut handler = crate::handler::FirstRowHandler::<T>::new();
346        self.claim_with_handler(ticket, &mut handler)?;
347        Ok(handler.into_row())
348    }
349
350    /// Claim and discard all rows.
351    ///
352    /// Results must be claimed in the same order they were queued.
353    pub fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
354        let mut handler = crate::handler::DropHandler::new();
355        self.claim_with_handler(ticket, &mut handler)
356    }
357
358    /// Check that the ticket sequence matches the expected claim sequence.
359    fn check_sequence(&self, seq: usize) -> Result<()> {
360        if seq != self.claim_seq {
361            return Err(Error::InvalidUsage(format!(
362                "claim out of order: expected seq {}, got {}",
363                self.claim_seq, seq
364            )));
365        }
366        Ok(())
367    }
368
369    /// Check if all operations are claimed and consume ReadyForQuery if so.
370    fn maybe_finish(&mut self) -> Result<()> {
371        if self.claim_seq == self.queue_seq {
372            self.finish()?;
373        }
374        Ok(())
375    }
376
377    /// Claim Parse + Bind + Execute (for raw SQL exec() calls).
378    fn claim_parse_bind_exec_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
379        // Expect: ParseComplete
380        self.read_next_message()?;
381        if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
382            return self.unexpected_message("ParseComplete");
383        }
384        ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
385
386        // Expect: BindComplete
387        self.read_next_message()?;
388        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
389            return self.unexpected_message("BindComplete");
390        }
391        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
392
393        // Now read rows
394        self.claim_rows_inner(handler)
395    }
396
397    /// Claim Bind + Execute (for prepared statement exec() calls).
398    fn claim_bind_exec_inner<H: BinaryHandler>(
399        &mut self,
400        handler: &mut H,
401        stmt: Option<&PreparedStatement>,
402    ) -> Result<()> {
403        // Expect: BindComplete
404        self.read_next_message()?;
405        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
406            return self.unexpected_message("BindComplete");
407        }
408        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
409
410        // Use cached RowDescription from PreparedStatement (no copy)
411        let row_desc = stmt.and_then(|s| s.row_desc_payload());
412
413        // Now read rows (no RowDescription/NoData expected from server)
414        self.claim_rows_cached_inner(handler, row_desc)
415    }
416
417    /// Common row reading logic (reads RowDescription from server).
418    fn claim_rows_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
419        // Expect RowDescription or NoData
420        self.read_next_message()?;
421        let has_rows = match self.conn.buffer_set.type_byte {
422            msg_type::ROW_DESCRIPTION => {
423                self.column_buffer.clear();
424                self.column_buffer
425                    .extend_from_slice(&self.conn.buffer_set.read_buffer);
426                true
427            }
428            msg_type::NO_DATA => {
429                NoData::parse(&self.conn.buffer_set.read_buffer)?;
430                // No rows will follow, but we still need terminal message
431                false
432            }
433            _ => {
434                return Err(Error::Protocol(format!(
435                    "expected RowDescription or NoData, got '{}'",
436                    self.conn.buffer_set.type_byte as char
437                )));
438            }
439        };
440
441        // Read data rows until terminal message
442        loop {
443            self.read_next_message()?;
444            let type_byte = self.conn.buffer_set.type_byte;
445
446            match type_byte {
447                msg_type::DATA_ROW => {
448                    if !has_rows {
449                        return Err(Error::Protocol(
450                            "received DataRow but no RowDescription".into(),
451                        ));
452                    }
453                    let cols = RowDescription::parse(&self.column_buffer)?;
454                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
455                    handler.row(cols, row)?;
456                }
457                msg_type::COMMAND_COMPLETE => {
458                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
459                    handler.result_end(cmd)?;
460                    return Ok(());
461                }
462                msg_type::EMPTY_QUERY_RESPONSE => {
463                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
464                    return Ok(());
465                }
466                _ => {
467                    return Err(Error::Protocol(format!(
468                        "unexpected message type in pipeline claim: '{}'",
469                        type_byte as char
470                    )));
471                }
472            }
473        }
474    }
475
476    /// Row reading logic with cached RowDescription (no server message expected).
477    fn claim_rows_cached_inner<H: BinaryHandler>(
478        &mut self,
479        handler: &mut H,
480        row_desc: Option<&[u8]>,
481    ) -> Result<()> {
482        // Read data rows until terminal message
483        loop {
484            self.read_next_message()?;
485            let type_byte = self.conn.buffer_set.type_byte;
486
487            match type_byte {
488                msg_type::DATA_ROW => {
489                    let row_desc = row_desc.ok_or_else(|| {
490                        Error::Protocol("received DataRow but no RowDescription cached".into())
491                    })?;
492                    let cols = RowDescription::parse(row_desc)?;
493                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
494                    handler.row(cols, row)?;
495                }
496                msg_type::COMMAND_COMPLETE => {
497                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
498                    handler.result_end(cmd)?;
499                    return Ok(());
500                }
501                msg_type::EMPTY_QUERY_RESPONSE => {
502                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
503                    return Ok(());
504                }
505                _ => {
506                    return Err(Error::Protocol(format!(
507                        "unexpected message type in pipeline claim: '{}'",
508                        type_byte as char
509                    )));
510                }
511            }
512        }
513    }
514
515    /// Read the next message, skipping async messages and handling errors.
516    fn read_next_message(&mut self) -> Result<()> {
517        loop {
518            self.conn.stream.read_message(&mut self.conn.buffer_set)?;
519            let type_byte = self.conn.buffer_set.type_byte;
520
521            // Handle async messages
522            if RawMessage::is_async_type(type_byte) {
523                continue;
524            }
525
526            // Handle error
527            if type_byte == msg_type::ERROR_RESPONSE {
528                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
529                return Err(error.into_error());
530            }
531
532            return Ok(());
533        }
534    }
535
536    /// Create an error for unexpected message type.
537    fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
538        Err(Error::Protocol(format!(
539            "expected {}, got '{}'",
540            expected, self.conn.buffer_set.type_byte as char
541        )))
542    }
543
544    /// Returns the number of operations that have been queued but not yet claimed.
545    pub fn pending_count(&self) -> usize {
546        self.queue_seq - self.claim_seq
547    }
548
549    /// Returns true if the pipeline is in aborted state due to an error.
550    pub fn is_aborted(&self) -> bool {
551        self.aborted
552    }
553}