zero_postgres/tokio/pipeline/
mod.rs

//! Async pipeline mode for batching multiple queries.
//!
//! Pipeline mode allows sending multiple queries to the server without waiting
//! for responses, reducing round-trip latency.
//!
//! # Example
//!
//! ```ignore
//! // Prepare statements outside the pipeline
//! let stmts = conn.prepare_batch(&[
//!     "SELECT id, name FROM users WHERE active = $1",
//!     "INSERT INTO users (name) VALUES ($1) RETURNING id",
//! ]).await?;
//!
//! let (active, inactive, count) = conn.run_pipeline(|p| async move {
//!     // Queue executions (`exec` only buffers locally, so it is not awaited)
//!     let t1 = p.exec(&stmts[0], (true,))?;
//!     let t2 = p.exec(&stmts[0], (false,))?;
//!     let t3 = p.exec("SELECT COUNT(*) FROM users", ())?;
//!
//!     p.sync().await?;
//!
//!     // Claim results in order with different methods
//!     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
//!     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
//!     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
//!
//!     Ok((active, inactive, count))
//! }).await?;
//! ```
31
32use crate::pipeline::Expectation;
33use crate::pipeline::Ticket;
34
35use crate::conversion::{FromRow, ToParams};
36use crate::error::{Error, Result};
37use crate::handler::BinaryHandler;
38use crate::protocol::backend::{
39    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
40    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
41};
42use crate::protocol::frontend::{
43    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
44};
45use crate::state::extended::PreparedStatement;
46use crate::statement::IntoStatement;
47
48use super::conn::Conn;
49
50/// Async pipeline mode for batching multiple queries.
51///
52/// Created by [`Conn::run_pipeline`].
53pub struct Pipeline<'a> {
54    conn: &'a mut Conn,
55    /// Monotonically increasing counter for queued operations
56    queue_seq: usize,
57    /// Next sequence number to claim
58    claim_seq: usize,
59    /// Whether the pipeline is in aborted state (error occurred)
60    aborted: bool,
61    /// Buffer for column descriptions during row processing
62    column_buffer: Vec<u8>,
63    /// Expected responses for each queued operation
64    expectations: Vec<Expectation>,
65}
66
67impl<'a> Pipeline<'a> {
68    /// Create a new pipeline.
69    ///
70    /// Prefer using [`Conn::run_pipeline`] which handles cleanup automatically.
71    /// This constructor is available for advanced use cases.
72    #[cfg(feature = "lowlevel")]
73    pub fn new(conn: &'a mut Conn) -> Self {
74        Self::new_inner(conn)
75    }
76
77    /// Create a new pipeline (internal).
78    pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
79        conn.buffer_set.write_buffer.clear();
80        Self {
81            conn,
82            queue_seq: 0,
83            claim_seq: 0,
84            aborted: false,
85            column_buffer: Vec::new(),
86            expectations: Vec::new(),
87        }
88    }
89
90    /// Cleanup the pipeline, draining any unclaimed tickets.
91    ///
92    /// This is called automatically by [`Conn::run_pipeline`].
93    /// Also available with the `lowlevel` feature for manual cleanup.
94    #[cfg(feature = "lowlevel")]
95    pub async fn cleanup(&mut self) {
96        self.cleanup_inner().await;
97    }
98
99    #[cfg(not(feature = "lowlevel"))]
100    pub(crate) async fn cleanup(&mut self) {
101        self.cleanup_inner().await;
102    }
103
104    async fn cleanup_inner(&mut self) {
105        if self.queue_seq == self.claim_seq {
106            return;
107        }
108
109        // Send sync if we have pending operations
110        if !self.conn.buffer_set.write_buffer.is_empty() {
111            let _ = self.sync().await;
112        }
113
114        // Drain remaining tickets
115        while self.claim_seq < self.queue_seq {
116            let _ = self.drain_one().await;
117            self.claim_seq += 1;
118        }
119
120        // Consume ReadyForQuery
121        let _ = self.finish().await;
122    }
123
124    /// Drain one ticket's worth of messages.
125    async fn drain_one(&mut self) {
126        let Some(expectation) = self.expectations.get(self.claim_seq).copied() else {
127            return;
128        };
129        let mut handler = crate::handler::DropHandler::new();
130
131        let _ = match expectation {
132            Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
133            // When draining, we don't have the statement ref, but we also don't need row desc
134            // since we're just dropping the results
135            Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
136        };
137    }
138
139    // ========================================================================
140    // Queue Operations
141    // ========================================================================
142
143    /// Queue a statement execution.
144    ///
145    /// The statement can be either:
146    /// - A `&PreparedStatement` returned from `conn.prepare()` or `conn.prepare_batch()`
147    /// - A raw SQL `&str` for one-shot execution
148    ///
149    /// This method only buffers the command locally - no network I/O occurs until
150    /// `sync()` or `flush()` is called.
151    ///
152    /// # Example
153    ///
154    /// ```ignore
155    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1").await?;
156    ///
157    /// let (r1, r2) = conn.run_pipeline(|p| async move {
158    ///     let t1 = p.exec(&stmt, (1,))?;
159    ///     let t2 = p.exec("SELECT COUNT(*) FROM users", ())?;
160    ///     p.sync().await?;
161    ///
162    ///     let r1: Vec<(i32, String)> = p.claim_collect(t1).await?;
163    ///     let r2: Option<(i64,)> = p.claim_one(t2).await?;
164    ///     Ok((r1, r2))
165    /// }).await?;
166    /// ```
167    pub fn exec<'s, P: ToParams>(
168        &mut self,
169        statement: &'s (impl IntoStatement + ?Sized),
170        params: P,
171    ) -> Result<Ticket<'s>> {
172        let seq = self.queue_seq;
173        self.queue_seq += 1;
174
175        if statement.needs_parse() {
176            self.exec_sql_inner(statement.as_sql().unwrap(), &params)?;
177            Ok(Ticket { seq, stmt: None })
178        } else {
179            let stmt = statement.as_prepared().unwrap();
180            self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, &params)?;
181            Ok(Ticket {
182                seq,
183                stmt: Some(stmt),
184            })
185        }
186    }
187
188    fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
189        let param_oids = params.natural_oids();
190        let buf = &mut self.conn.buffer_set.write_buffer;
191        write_parse(buf, "", sql, &param_oids);
192        write_bind(buf, "", "", params, &param_oids)?;
193        write_describe_portal(buf, "");
194        write_execute(buf, "", 0);
195        self.expectations.push(Expectation::ParseBindExecute);
196        Ok(())
197    }
198
199    fn exec_prepared_inner<P: ToParams>(
200        &mut self,
201        stmt_name: &str,
202        param_oids: &[u32],
203        params: &P,
204    ) -> Result<()> {
205        let buf = &mut self.conn.buffer_set.write_buffer;
206        write_bind(buf, "", stmt_name, params, param_oids)?;
207        // Skip write_describe_portal - use cached RowDescription from PreparedStatement
208        write_execute(buf, "", 0);
209        self.expectations.push(Expectation::BindExecute);
210        Ok(())
211    }
212
213    /// Send a FLUSH message to trigger server response.
214    ///
215    /// This forces the server to send all pending responses without establishing
216    /// a transaction boundary. Called automatically by claim methods when needed.
217    pub async fn flush(&mut self) -> Result<()> {
218        if !self.conn.buffer_set.write_buffer.is_empty() {
219            write_flush(&mut self.conn.buffer_set.write_buffer);
220            self.conn
221                .stream
222                .write_all(&self.conn.buffer_set.write_buffer)
223                .await?;
224            self.conn.stream.flush().await?;
225            self.conn.buffer_set.write_buffer.clear();
226        }
227        Ok(())
228    }
229
230    /// Send a SYNC message to establish a transaction boundary.
231    ///
232    /// After calling sync, you must claim all queued operations in order.
233    /// The final ReadyForQuery message will be consumed when all operations
234    /// are claimed.
235    pub async fn sync(&mut self) -> Result<()> {
236        let result = self.sync_inner().await;
237        if let Err(e) = &result
238            && e.is_connection_broken()
239        {
240            self.conn.is_broken = true;
241        }
242        result
243    }
244
245    async fn sync_inner(&mut self) -> Result<()> {
246        write_sync(&mut self.conn.buffer_set.write_buffer);
247        self.conn
248            .stream
249            .write_all(&self.conn.buffer_set.write_buffer)
250            .await?;
251        self.conn.stream.flush().await?;
252        self.conn.buffer_set.write_buffer.clear();
253        Ok(())
254    }
255
256    /// Wait for ReadyForQuery after all operations are claimed.
257    async fn finish(&mut self) -> Result<()> {
258        // Wait for ReadyForQuery
259        loop {
260            self.conn
261                .stream
262                .read_message(&mut self.conn.buffer_set)
263                .await?;
264            let type_byte = self.conn.buffer_set.type_byte;
265
266            // Handle async messages
267            if RawMessage::is_async_type(type_byte) {
268                continue;
269            }
270
271            // Handle error
272            if type_byte == msg_type::ERROR_RESPONSE {
273                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
274                return Err(error.into_error());
275            }
276
277            if type_byte == msg_type::READY_FOR_QUERY {
278                let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
279                self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
280                // Reset pipeline state
281                self.queue_seq = 0;
282                self.claim_seq = 0;
283                self.expectations.clear();
284                self.aborted = false;
285                return Ok(());
286            }
287        }
288    }
289
290    // ========================================================================
291    // Claim Operations
292    // ========================================================================
293
294    /// Claim with a custom handler.
295    ///
296    /// Results must be claimed in the same order they were queued.
297    #[cfg(feature = "lowlevel")]
298    pub async fn claim<H: BinaryHandler>(
299        &mut self,
300        ticket: Ticket<'_>,
301        handler: &mut H,
302    ) -> Result<()> {
303        self.claim_with_handler(ticket, handler).await
304    }
305
306    async fn claim_with_handler<H: BinaryHandler>(
307        &mut self,
308        ticket: Ticket<'_>,
309        handler: &mut H,
310    ) -> Result<()> {
311        self.check_sequence(ticket.seq)?;
312        self.flush().await?;
313
314        if self.aborted {
315            self.claim_seq += 1;
316            self.maybe_finish().await?;
317            return Err(Error::Protocol(
318                "pipeline aborted due to earlier error".into(),
319            ));
320        }
321
322        let expectation = self.expectations.get(ticket.seq).copied();
323
324        let result = match expectation {
325            Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
326            Some(Expectation::BindExecute) => {
327                self.claim_bind_exec_inner(handler, ticket.stmt).await
328            }
329            None => Err(Error::Protocol("unexpected expectation type".into())),
330        };
331
332        if let Err(e) = &result {
333            if e.is_connection_broken() {
334                self.conn.is_broken = true;
335            }
336            self.aborted = true;
337        }
338        self.claim_seq += 1;
339        self.maybe_finish().await?;
340        result
341    }
342
343    /// Claim and collect all rows.
344    ///
345    /// Results must be claimed in the same order they were queued.
346    pub async fn claim_collect<T: for<'b> FromRow<'b>>(
347        &mut self,
348        ticket: Ticket<'_>,
349    ) -> Result<Vec<T>> {
350        let mut handler = crate::handler::CollectHandler::<T>::new();
351        self.claim_with_handler(ticket, &mut handler).await?;
352        Ok(handler.into_rows())
353    }
354
355    /// Claim and return just the first row.
356    ///
357    /// Results must be claimed in the same order they were queued.
358    pub async fn claim_one<T: for<'b> FromRow<'b>>(
359        &mut self,
360        ticket: Ticket<'_>,
361    ) -> Result<Option<T>> {
362        let mut handler = crate::handler::FirstRowHandler::<T>::new();
363        self.claim_with_handler(ticket, &mut handler).await?;
364        Ok(handler.into_row())
365    }
366
367    /// Claim and discard all rows.
368    ///
369    /// Results must be claimed in the same order they were queued.
370    pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
371        let mut handler = crate::handler::DropHandler::new();
372        self.claim_with_handler(ticket, &mut handler).await
373    }
374
375    /// Check that the ticket sequence matches the expected claim sequence.
376    fn check_sequence(&self, seq: usize) -> Result<()> {
377        if seq != self.claim_seq {
378            return Err(Error::InvalidUsage(format!(
379                "claim out of order: expected seq {}, got {}",
380                self.claim_seq, seq
381            )));
382        }
383        Ok(())
384    }
385
386    /// Check if all operations are claimed and consume ReadyForQuery if so.
387    async fn maybe_finish(&mut self) -> Result<()> {
388        if self.claim_seq == self.queue_seq {
389            self.finish().await?;
390        }
391        Ok(())
392    }
393
394    /// Claim Parse + Bind + Execute (for raw SQL exec() calls).
395    async fn claim_parse_bind_exec_inner<H: BinaryHandler>(
396        &mut self,
397        handler: &mut H,
398    ) -> Result<()> {
399        // Expect: ParseComplete
400        self.read_next_message().await?;
401        if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
402            return self.unexpected_message("ParseComplete");
403        }
404        ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
405
406        // Expect: BindComplete
407        self.read_next_message().await?;
408        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
409            return self.unexpected_message("BindComplete");
410        }
411        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
412
413        // Now read rows
414        self.claim_rows_inner(handler).await
415    }
416
417    /// Claim Bind + Execute (for prepared statement exec() calls).
418    async fn claim_bind_exec_inner<H: BinaryHandler>(
419        &mut self,
420        handler: &mut H,
421        stmt: Option<&PreparedStatement>,
422    ) -> Result<()> {
423        // Expect: BindComplete
424        self.read_next_message().await?;
425        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
426            return self.unexpected_message("BindComplete");
427        }
428        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
429
430        // Use cached RowDescription from PreparedStatement (no copy)
431        let row_desc = stmt.and_then(|s| s.row_desc_payload());
432
433        // Now read rows (no RowDescription/NoData expected from server)
434        self.claim_rows_cached_inner(handler, row_desc).await
435    }
436
437    /// Common row reading logic (reads RowDescription from server).
438    async fn claim_rows_inner<H: BinaryHandler>(&mut self, handler: &mut H) -> Result<()> {
439        // Expect RowDescription or NoData
440        self.read_next_message().await?;
441        let has_rows = match self.conn.buffer_set.type_byte {
442            msg_type::ROW_DESCRIPTION => {
443                self.column_buffer.clear();
444                self.column_buffer
445                    .extend_from_slice(&self.conn.buffer_set.read_buffer);
446                true
447            }
448            msg_type::NO_DATA => {
449                NoData::parse(&self.conn.buffer_set.read_buffer)?;
450                // No rows will follow, but we still need terminal message
451                false
452            }
453            _ => {
454                return Err(Error::Protocol(format!(
455                    "expected RowDescription or NoData, got '{}'",
456                    self.conn.buffer_set.type_byte as char
457                )));
458            }
459        };
460
461        // Read data rows until terminal message
462        loop {
463            self.read_next_message().await?;
464            let type_byte = self.conn.buffer_set.type_byte;
465
466            match type_byte {
467                msg_type::DATA_ROW => {
468                    if !has_rows {
469                        return Err(Error::Protocol(
470                            "received DataRow but no RowDescription".into(),
471                        ));
472                    }
473                    let cols = RowDescription::parse(&self.column_buffer)?;
474                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
475                    handler.row(cols, row)?;
476                }
477                msg_type::COMMAND_COMPLETE => {
478                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
479                    handler.result_end(cmd)?;
480                    return Ok(());
481                }
482                msg_type::EMPTY_QUERY_RESPONSE => {
483                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
484                    return Ok(());
485                }
486                _ => {
487                    return Err(Error::Protocol(format!(
488                        "unexpected message type in pipeline claim: '{}'",
489                        type_byte as char
490                    )));
491                }
492            }
493        }
494    }
495
496    /// Row reading logic with cached RowDescription (no server message expected).
497    async fn claim_rows_cached_inner<H: BinaryHandler>(
498        &mut self,
499        handler: &mut H,
500        row_desc: Option<&[u8]>,
501    ) -> Result<()> {
502        // Read data rows until terminal message
503        loop {
504            self.read_next_message().await?;
505            let type_byte = self.conn.buffer_set.type_byte;
506
507            match type_byte {
508                msg_type::DATA_ROW => {
509                    let row_desc = row_desc.ok_or_else(|| {
510                        Error::Protocol("received DataRow but no RowDescription cached".into())
511                    })?;
512                    let cols = RowDescription::parse(row_desc)?;
513                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
514                    handler.row(cols, row)?;
515                }
516                msg_type::COMMAND_COMPLETE => {
517                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
518                    handler.result_end(cmd)?;
519                    return Ok(());
520                }
521                msg_type::EMPTY_QUERY_RESPONSE => {
522                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
523                    return Ok(());
524                }
525                _ => {
526                    return Err(Error::Protocol(format!(
527                        "unexpected message type in pipeline claim: '{}'",
528                        type_byte as char
529                    )));
530                }
531            }
532        }
533    }
534
535    /// Read the next message, skipping async messages and handling errors.
536    async fn read_next_message(&mut self) -> Result<()> {
537        loop {
538            self.conn
539                .stream
540                .read_message(&mut self.conn.buffer_set)
541                .await?;
542            let type_byte = self.conn.buffer_set.type_byte;
543
544            // Handle async messages
545            if RawMessage::is_async_type(type_byte) {
546                continue;
547            }
548
549            // Handle error
550            if type_byte == msg_type::ERROR_RESPONSE {
551                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
552                return Err(error.into_error());
553            }
554
555            return Ok(());
556        }
557    }
558
559    /// Create an error for unexpected message type.
560    fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
561        Err(Error::Protocol(format!(
562            "expected {}, got '{}'",
563            expected, self.conn.buffer_set.type_byte as char
564        )))
565    }
566
567    /// Returns the number of operations that have been queued but not yet claimed.
568    pub fn pending_count(&self) -> usize {
569        self.queue_seq - self.claim_seq
570    }
571
572    /// Returns true if the pipeline is in aborted state due to an error.
573    pub fn is_aborted(&self) -> bool {
574        self.aborted
575    }
576}