Skip to main content

zero_postgres/tokio/pipeline/
mod.rs

1//! Async pipeline mode for batching multiple queries.
2//!
3//! Pipeline mode allows sending multiple queries to the server without waiting
4//! for responses, reducing round-trip latency.
5//!
6//! # Example
7//!
8//! ```ignore
9//! // Prepare statements outside the pipeline
10//! let stmts = conn.prepare_batch(&[
11//!     "SELECT id, name FROM users WHERE active = $1",
12//!     "INSERT INTO users (name) VALUES ($1) RETURNING id",
13//! ]).await?;
14//!
15//! let (active, inactive, count) = conn.pipeline(|p| async move {
16//!     // Queue executions
//!     let t1 = p.exec(&stmts[0], (true,))?;
//!     let t2 = p.exec(&stmts[0], (false,))?;
//!     let t3 = p.exec("SELECT COUNT(*) FROM users", ())?;
20//!
21//!     p.sync().await?;
22//!
23//!     // Claim results in order with different methods
24//!     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
25//!     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
26//!     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
27//!
28//!     Ok((active, inactive, count))
29//! }).await?;
30//! ```
31
32use std::collections::VecDeque;
33
34use crate::pipeline::Expectation;
35use crate::pipeline::Ticket;
36
37use crate::conversion::{FromRow, ToParams};
38use crate::error::{Error, Result};
39use crate::handler::ExtendedHandler;
40use crate::protocol::backend::{
41    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
42    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
43};
44use crate::protocol::frontend::{
45    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
46};
47use crate::state::extended::PreparedStatement;
48use crate::statement::{IntoStatement, StatementRef};
49
50use super::conn::Conn;
51
52/// Async pipeline mode for batching multiple queries.
53///
54/// Created by [`Conn::pipeline`].
55pub struct Pipeline<'a> {
56    conn: &'a mut Conn,
57    /// Monotonically increasing counter for queued operations
58    queue_seq: usize,
59    /// Next sequence number to claim
60    claim_seq: usize,
61    /// Whether the pipeline is in aborted state (error occurred)
62    aborted: bool,
63    /// Buffer for column descriptions during row processing
64    column_buffer: Vec<u8>,
65    /// Expected responses queue (exec operations and Sync points)
66    expectations: VecDeque<Expectation>,
67}
68
69impl<'a> Pipeline<'a> {
70    /// Create a new pipeline.
71    ///
72    /// Prefer using [`Conn::pipeline`] which handles cleanup automatically.
73    /// This constructor is available for advanced use cases.
74    #[cfg(feature = "lowlevel")]
75    pub fn new(conn: &'a mut Conn) -> Self {
76        Self::new_inner(conn)
77    }
78
79    /// Create a new pipeline (internal).
80    pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
81        conn.buffer_set.write_buffer.clear();
82        Self {
83            conn,
84            queue_seq: 0,
85            claim_seq: 0,
86            aborted: false,
87            column_buffer: Vec::new(),
88            expectations: VecDeque::new(),
89        }
90    }
91
92    /// Cleanup the pipeline, draining any unclaimed tickets.
93    ///
94    /// This is called automatically by [`Conn::pipeline`].
95    /// Also available with the `lowlevel` feature for manual cleanup.
96    #[cfg(feature = "lowlevel")]
97    pub async fn cleanup(&mut self) {
98        self.cleanup_inner().await;
99    }
100
101    #[cfg(not(feature = "lowlevel"))]
102    pub(crate) async fn cleanup(&mut self) {
103        self.cleanup_inner().await;
104    }
105
106    async fn cleanup_inner(&mut self) {
107        // Nothing to clean up if no operations were queued and no expectations pending
108        if self.queue_seq == 0 && self.expectations.is_empty() {
109            return;
110        }
111
112        // Send sync if we have unflushed operations or no sync is queued yet
113        if !self.conn.buffer_set.write_buffer.is_empty()
114            || !self.expectations.iter().any(|e| *e == Expectation::Sync)
115        {
116            let _ = self.sync().await;
117        }
118
119        // Drain remaining expectations
120        if self.aborted {
121            // In aborted state, server skipped remaining commands - only consume ReadyForQuery(s)
122            while let Some(expectation) = self.expectations.pop_front() {
123                if expectation == Expectation::Sync {
124                    let _ = self.consume_ready_for_query().await;
125                }
126            }
127        } else {
128            // Normal drain: process all expectations
129            while let Some(expectation) = self.expectations.pop_front() {
130                let _ = self.drain_expectation(expectation).await;
131            }
132        }
133
134        // Reset state
135        self.queue_seq = 0;
136        self.claim_seq = 0;
137        self.aborted = false;
138    }
139
140    /// Drain a single expectation.
141    async fn drain_expectation(&mut self, expectation: Expectation) {
142        let mut handler = crate::handler::DropHandler::new();
143        let _ = match expectation {
144            Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
145            Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
146            Expectation::Sync => self.consume_ready_for_query().await,
147        };
148    }
149
150    // ========================================================================
151    // Queue Operations
152    // ========================================================================
153
154    /// Queue a statement execution.
155    ///
156    /// The statement can be either:
157    /// - A `&PreparedStatement` returned from `conn.prepare()` or `conn.prepare_batch()`
158    /// - A raw SQL `&str` for one-shot execution
159    ///
160    /// This method only buffers the command locally - no network I/O occurs until
161    /// `sync()` or `flush()` is called.
162    ///
163    /// # Example
164    ///
165    /// ```ignore
166    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1").await?;
167    ///
168    /// let (r1, r2) = conn.pipeline(|p| async move {
169    ///     let t1 = p.exec(&stmt, (1,))?;
170    ///     let t2 = p.exec("SELECT COUNT(*) FROM users", ())?;
171    ///     p.sync().await?;
172    ///
173    ///     let r1: Vec<(i32, String)> = p.claim_collect(t1).await?;
174    ///     let r2: Option<(i64,)> = p.claim_one(t2).await?;
175    ///     Ok((r1, r2))
176    /// }).await?;
177    /// ```
178    pub fn exec<'s, P: ToParams>(
179        &mut self,
180        statement: &'s (impl IntoStatement + ?Sized),
181        params: P,
182    ) -> Result<Ticket<'s>> {
183        let seq = self.queue_seq;
184        self.queue_seq += 1;
185
186        match statement.statement_ref() {
187            StatementRef::Sql(sql) => {
188                self.exec_sql_inner(sql, &params)?;
189                Ok(Ticket { seq, stmt: None })
190            }
191            StatementRef::Prepared(stmt) => {
192                self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, &params)?;
193                Ok(Ticket {
194                    seq,
195                    stmt: Some(stmt),
196                })
197            }
198        }
199    }
200
201    fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
202        let param_oids = params.natural_oids();
203        let buf = &mut self.conn.buffer_set.write_buffer;
204        write_parse(buf, "", sql, &param_oids);
205        write_bind(buf, "", "", params, &param_oids)?;
206        write_describe_portal(buf, "");
207        write_execute(buf, "", 0);
208        self.expectations.push_back(Expectation::ParseBindExecute);
209        Ok(())
210    }
211
212    fn exec_prepared_inner<P: ToParams>(
213        &mut self,
214        stmt_name: &str,
215        param_oids: &[u32],
216        params: &P,
217    ) -> Result<()> {
218        let buf = &mut self.conn.buffer_set.write_buffer;
219        write_bind(buf, "", stmt_name, params, param_oids)?;
220        // Skip write_describe_portal - use cached RowDescription from PreparedStatement
221        write_execute(buf, "", 0);
222        self.expectations.push_back(Expectation::BindExecute);
223        Ok(())
224    }
225
226    /// Send a FLUSH message to trigger server response.
227    ///
228    /// This forces the server to send all pending responses without establishing
229    /// a transaction boundary. Called automatically by claim methods when needed.
230    pub async fn flush(&mut self) -> Result<()> {
231        if !self.conn.buffer_set.write_buffer.is_empty() {
232            write_flush(&mut self.conn.buffer_set.write_buffer);
233            self.conn
234                .stream
235                .write_all(&self.conn.buffer_set.write_buffer)
236                .await?;
237            self.conn.stream.flush().await?;
238            self.conn.buffer_set.write_buffer.clear();
239        }
240        Ok(())
241    }
242
243    /// Send a SYNC message to establish a transaction boundary.
244    ///
245    /// After calling sync, you must claim all queued operations in order.
246    /// The ReadyForQuery message will be consumed automatically after claiming.
247    pub async fn sync(&mut self) -> Result<()> {
248        let result = self.sync_inner().await;
249        if let Err(e) = &result
250            && e.is_connection_broken()
251        {
252            self.conn.is_broken = true;
253        }
254        result
255    }
256
257    async fn sync_inner(&mut self) -> Result<()> {
258        write_sync(&mut self.conn.buffer_set.write_buffer);
259        self.expectations.push_back(Expectation::Sync);
260        self.conn
261            .stream
262            .write_all(&self.conn.buffer_set.write_buffer)
263            .await?;
264        self.conn.stream.flush().await?;
265        self.conn.buffer_set.write_buffer.clear();
266        Ok(())
267    }
268
269    /// Consume a single ReadyForQuery message.
270    async fn consume_ready_for_query(&mut self) -> Result<()> {
271        loop {
272            self.conn
273                .stream
274                .read_message(&mut self.conn.buffer_set)
275                .await?;
276            let type_byte = self.conn.buffer_set.type_byte;
277
278            if RawMessage::is_async_type(type_byte) {
279                continue;
280            }
281
282            if type_byte == msg_type::ERROR_RESPONSE {
283                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
284                return Err(error.into_error());
285            }
286
287            if type_byte == msg_type::READY_FOR_QUERY {
288                let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
289                self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
290                return Ok(());
291            }
292        }
293    }
294
295    /// Consume all pending Sync expectations from the front of the queue.
296    async fn consume_pending_syncs(&mut self) -> Result<()> {
297        while self.expectations.front() == Some(&Expectation::Sync) {
298            self.expectations.pop_front();
299            self.consume_ready_for_query().await?;
300            // Reset aborted state - after ReadyForQuery, pipeline can continue
301            self.aborted = false;
302        }
303        Ok(())
304    }
305
306    // ========================================================================
307    // Claim Operations
308    // ========================================================================
309
310    /// Claim with a custom handler.
311    ///
312    /// Results must be claimed in the same order they were queued.
313    #[cfg(feature = "lowlevel")]
314    pub async fn claim<H: ExtendedHandler>(
315        &mut self,
316        ticket: Ticket<'_>,
317        handler: &mut H,
318    ) -> Result<()> {
319        self.claim_with_handler(ticket, handler).await
320    }
321
322    async fn claim_with_handler<H: ExtendedHandler>(
323        &mut self,
324        ticket: Ticket<'_>,
325        handler: &mut H,
326    ) -> Result<()> {
327        self.check_sequence(ticket.seq)?;
328
329        // Auto-sync if buffer has unsent data
330        if !self.conn.buffer_set.write_buffer.is_empty() {
331            self.sync().await?;
332        }
333
334        if self.aborted {
335            self.claim_seq += 1;
336            // Pop but don't process the exec expectation (server skipped it)
337            self.expectations.pop_front();
338            self.consume_pending_syncs().await?;
339            return Err(Error::LibraryBug(
340                "pipeline aborted due to earlier error".into(),
341            ));
342        }
343
344        let expectation = self.expectations.pop_front();
345
346        let result = match expectation {
347            Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
348            Some(Expectation::BindExecute) => {
349                self.claim_bind_exec_inner(handler, ticket.stmt).await
350            }
351            Some(Expectation::Sync) => Err(Error::LibraryBug("unexpected Sync expectation".into())),
352            None => Err(Error::LibraryBug("no expectation in queue".into())),
353        };
354
355        if let Err(e) = &result {
356            if e.is_connection_broken() {
357                self.conn.is_broken = true;
358            }
359            self.aborted = true;
360        }
361        self.claim_seq += 1;
362        self.consume_pending_syncs().await?;
363        result
364    }
365
366    /// Claim and collect all rows.
367    ///
368    /// Results must be claimed in the same order they were queued.
369    pub async fn claim_collect<T: for<'b> FromRow<'b>>(
370        &mut self,
371        ticket: Ticket<'_>,
372    ) -> Result<Vec<T>> {
373        let mut handler = crate::handler::CollectHandler::<T>::new();
374        self.claim_with_handler(ticket, &mut handler).await?;
375        Ok(handler.into_rows())
376    }
377
378    /// Claim and return just the first row.
379    ///
380    /// Results must be claimed in the same order they were queued.
381    pub async fn claim_one<T: for<'b> FromRow<'b>>(
382        &mut self,
383        ticket: Ticket<'_>,
384    ) -> Result<Option<T>> {
385        let mut handler = crate::handler::FirstRowHandler::<T>::new();
386        self.claim_with_handler(ticket, &mut handler).await?;
387        Ok(handler.into_row())
388    }
389
390    /// Claim and discard all rows.
391    ///
392    /// Results must be claimed in the same order they were queued.
393    pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
394        let mut handler = crate::handler::DropHandler::new();
395        self.claim_with_handler(ticket, &mut handler).await
396    }
397
398    /// Check that the ticket sequence matches the expected claim sequence.
399    fn check_sequence(&self, seq: usize) -> Result<()> {
400        if seq != self.claim_seq {
401            return Err(Error::InvalidUsage(format!(
402                "claim out of order: expected seq {}, got {}",
403                self.claim_seq, seq
404            )));
405        }
406        Ok(())
407    }
408
409    /// Claim Parse + Bind + Execute (for raw SQL exec() calls).
410    async fn claim_parse_bind_exec_inner<H: ExtendedHandler>(
411        &mut self,
412        handler: &mut H,
413    ) -> Result<()> {
414        // Expect: ParseComplete
415        self.read_next_message().await?;
416        if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
417            return self.unexpected_message("ParseComplete");
418        }
419        ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
420
421        // Expect: BindComplete
422        self.read_next_message().await?;
423        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
424            return self.unexpected_message("BindComplete");
425        }
426        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
427
428        // Now read rows
429        self.claim_rows_inner(handler).await
430    }
431
432    /// Claim Bind + Execute (for prepared statement exec() calls).
433    async fn claim_bind_exec_inner<H: ExtendedHandler>(
434        &mut self,
435        handler: &mut H,
436        stmt: Option<&PreparedStatement>,
437    ) -> Result<()> {
438        // Expect: BindComplete
439        self.read_next_message().await?;
440        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
441            return self.unexpected_message("BindComplete");
442        }
443        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
444
445        // Use cached RowDescription from PreparedStatement (no copy)
446        let row_desc = stmt.and_then(|s| s.row_desc_payload());
447
448        // Now read rows (no RowDescription/NoData expected from server)
449        self.claim_rows_cached_inner(handler, row_desc).await
450    }
451
452    /// Common row reading logic (reads RowDescription from server).
453    async fn claim_rows_inner<H: ExtendedHandler>(&mut self, handler: &mut H) -> Result<()> {
454        // Expect RowDescription or NoData
455        self.read_next_message().await?;
456        let has_rows = match self.conn.buffer_set.type_byte {
457            msg_type::ROW_DESCRIPTION => {
458                self.column_buffer.clear();
459                self.column_buffer
460                    .extend_from_slice(&self.conn.buffer_set.read_buffer);
461                true
462            }
463            msg_type::NO_DATA => {
464                NoData::parse(&self.conn.buffer_set.read_buffer)?;
465                // No rows will follow, but we still need terminal message
466                false
467            }
468            _ => {
469                return Err(Error::LibraryBug(format!(
470                    "expected RowDescription or NoData, got '{}'",
471                    self.conn.buffer_set.type_byte as char
472                )));
473            }
474        };
475
476        // Read data rows until terminal message
477        loop {
478            self.read_next_message().await?;
479            let type_byte = self.conn.buffer_set.type_byte;
480
481            match type_byte {
482                msg_type::DATA_ROW => {
483                    if !has_rows {
484                        return Err(Error::LibraryBug(
485                            "received DataRow but no RowDescription".into(),
486                        ));
487                    }
488                    let cols = RowDescription::parse(&self.column_buffer)?;
489                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
490                    handler.row(cols, row)?;
491                }
492                msg_type::COMMAND_COMPLETE => {
493                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
494                    handler.result_end(cmd)?;
495                    return Ok(());
496                }
497                msg_type::EMPTY_QUERY_RESPONSE => {
498                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
499                    return Ok(());
500                }
501                _ => {
502                    return Err(Error::LibraryBug(format!(
503                        "unexpected message type in pipeline claim: '{}'",
504                        type_byte as char
505                    )));
506                }
507            }
508        }
509    }
510
511    /// Row reading logic with cached RowDescription (no server message expected).
512    async fn claim_rows_cached_inner<H: ExtendedHandler>(
513        &mut self,
514        handler: &mut H,
515        row_desc: Option<&[u8]>,
516    ) -> Result<()> {
517        // Read data rows until terminal message
518        loop {
519            self.read_next_message().await?;
520            let type_byte = self.conn.buffer_set.type_byte;
521
522            match type_byte {
523                msg_type::DATA_ROW => {
524                    let row_desc = row_desc.ok_or_else(|| {
525                        Error::LibraryBug("received DataRow but no RowDescription cached".into())
526                    })?;
527                    let cols = RowDescription::parse(row_desc)?;
528                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
529                    handler.row(cols, row)?;
530                }
531                msg_type::COMMAND_COMPLETE => {
532                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
533                    handler.result_end(cmd)?;
534                    return Ok(());
535                }
536                msg_type::EMPTY_QUERY_RESPONSE => {
537                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
538                    return Ok(());
539                }
540                _ => {
541                    return Err(Error::LibraryBug(format!(
542                        "unexpected message type in pipeline claim: '{}'",
543                        type_byte as char
544                    )));
545                }
546            }
547        }
548    }
549
550    /// Read the next message, skipping async messages and handling errors.
551    async fn read_next_message(&mut self) -> Result<()> {
552        loop {
553            self.conn
554                .stream
555                .read_message(&mut self.conn.buffer_set)
556                .await?;
557            let type_byte = self.conn.buffer_set.type_byte;
558
559            // Handle async messages
560            if RawMessage::is_async_type(type_byte) {
561                continue;
562            }
563
564            // Handle error
565            if type_byte == msg_type::ERROR_RESPONSE {
566                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
567                return Err(error.into_error());
568            }
569
570            return Ok(());
571        }
572    }
573
574    /// Create an error for unexpected message type.
575    fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
576        Err(Error::LibraryBug(format!(
577            "expected {}, got '{}'",
578            expected, self.conn.buffer_set.type_byte as char
579        )))
580    }
581
582    /// Returns the number of operations that have been queued but not yet claimed.
583    pub fn pending_count(&self) -> usize {
584        self.queue_seq - self.claim_seq
585    }
586
587    /// Returns true if the pipeline is in aborted state due to an error.
588    pub fn is_aborted(&self) -> bool {
589        self.aborted
590    }
591}