zero_postgres/tokio/pipeline/
mod.rs

//! Async pipeline mode for batching multiple queries.
//!
//! Pipeline mode allows sending multiple queries to the server without waiting
//! for responses, reducing round-trip latency.
//!
//! # Example
//!
//! ```ignore
//! // Prepare statements outside the pipeline
//! let stmts = conn.prepare_batch(&[
//!     "SELECT id, name FROM users WHERE active = $1",
//!     "INSERT INTO users (name) VALUES ($1) RETURNING id",
//! ]).await?;
//!
//! let (active, inactive, count) = conn.pipeline(|p| async move {
//!     // Queue executions (`exec` only buffers locally, so it is not awaited)
//!     let t1 = p.exec(&stmts[0], (true,))?;
//!     let t2 = p.exec(&stmts[0], (false,))?;
//!     let t3 = p.exec("SELECT COUNT(*) FROM users", ())?;
//!
//!     p.sync().await?;
//!
//!     // Claim results in order with different methods
//!     let active: Vec<(i32, String)> = p.claim_collect(t1).await?;
//!     let inactive: Option<(i32, String)> = p.claim_one(t2).await?;
//!     let count: Vec<(i64,)> = p.claim_collect(t3).await?;
//!
//!     Ok((active, inactive, count))
//! }).await?;
//! ```
31
32use std::collections::VecDeque;
33
34use crate::pipeline::Expectation;
35use crate::pipeline::Ticket;
36
37use crate::conversion::{FromRow, ToParams};
38use crate::error::{Error, Result};
39use crate::handler::ExtendedHandler;
40use crate::protocol::backend::{
41    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
42    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
43};
44use crate::protocol::frontend::{
45    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
46};
47use crate::state::extended::PreparedStatement;
48use crate::statement::IntoStatement;
49
50use super::conn::Conn;
51
52/// Async pipeline mode for batching multiple queries.
53///
54/// Created by [`Conn::pipeline`].
55pub struct Pipeline<'a> {
56    conn: &'a mut Conn,
57    /// Monotonically increasing counter for queued operations
58    queue_seq: usize,
59    /// Next sequence number to claim
60    claim_seq: usize,
61    /// Whether the pipeline is in aborted state (error occurred)
62    aborted: bool,
63    /// Buffer for column descriptions during row processing
64    column_buffer: Vec<u8>,
65    /// Expected responses queue (exec operations and Sync points)
66    expectations: VecDeque<Expectation>,
67}
68
69impl<'a> Pipeline<'a> {
70    /// Create a new pipeline.
71    ///
72    /// Prefer using [`Conn::pipeline`] which handles cleanup automatically.
73    /// This constructor is available for advanced use cases.
74    #[cfg(feature = "lowlevel")]
75    pub fn new(conn: &'a mut Conn) -> Self {
76        Self::new_inner(conn)
77    }
78
79    /// Create a new pipeline (internal).
80    pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
81        conn.buffer_set.write_buffer.clear();
82        Self {
83            conn,
84            queue_seq: 0,
85            claim_seq: 0,
86            aborted: false,
87            column_buffer: Vec::new(),
88            expectations: VecDeque::new(),
89        }
90    }
91
92    /// Cleanup the pipeline, draining any unclaimed tickets.
93    ///
94    /// This is called automatically by [`Conn::pipeline`].
95    /// Also available with the `lowlevel` feature for manual cleanup.
96    #[cfg(feature = "lowlevel")]
97    pub async fn cleanup(&mut self) {
98        self.cleanup_inner().await;
99    }
100
101    #[cfg(not(feature = "lowlevel"))]
102    pub(crate) async fn cleanup(&mut self) {
103        self.cleanup_inner().await;
104    }
105
106    async fn cleanup_inner(&mut self) {
107        // Nothing to clean up if no operations were queued and no expectations pending
108        if self.queue_seq == 0 && self.expectations.is_empty() {
109            return;
110        }
111
112        // Send sync if we have pending operations that weren't synced
113        if !self.conn.buffer_set.write_buffer.is_empty() {
114            let _ = self.sync().await;
115        } else if !self.expectations.iter().any(|e| *e == Expectation::Sync) {
116            // Buffer was flushed but no sync sent - send one now
117            let _ = self.sync().await;
118        }
119
120        // Drain remaining expectations
121        if self.aborted {
122            // In aborted state, server skipped remaining commands - only consume ReadyForQuery(s)
123            while let Some(expectation) = self.expectations.pop_front() {
124                if expectation == Expectation::Sync {
125                    let _ = self.consume_ready_for_query().await;
126                }
127            }
128        } else {
129            // Normal drain: process all expectations
130            while let Some(expectation) = self.expectations.pop_front() {
131                let _ = self.drain_expectation(expectation).await;
132            }
133        }
134
135        // Reset state
136        self.queue_seq = 0;
137        self.claim_seq = 0;
138        self.aborted = false;
139    }
140
141    /// Drain a single expectation.
142    async fn drain_expectation(&mut self, expectation: Expectation) {
143        let mut handler = crate::handler::DropHandler::new();
144        let _ = match expectation {
145            Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler).await,
146            Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None).await,
147            Expectation::Sync => self.consume_ready_for_query().await,
148        };
149    }
150
151    // ========================================================================
152    // Queue Operations
153    // ========================================================================
154
155    /// Queue a statement execution.
156    ///
157    /// The statement can be either:
158    /// - A `&PreparedStatement` returned from `conn.prepare()` or `conn.prepare_batch()`
159    /// - A raw SQL `&str` for one-shot execution
160    ///
161    /// This method only buffers the command locally - no network I/O occurs until
162    /// `sync()` or `flush()` is called.
163    ///
164    /// # Example
165    ///
166    /// ```ignore
167    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1").await?;
168    ///
169    /// let (r1, r2) = conn.pipeline(|p| async move {
170    ///     let t1 = p.exec(&stmt, (1,))?;
171    ///     let t2 = p.exec("SELECT COUNT(*) FROM users", ())?;
172    ///     p.sync().await?;
173    ///
174    ///     let r1: Vec<(i32, String)> = p.claim_collect(t1).await?;
175    ///     let r2: Option<(i64,)> = p.claim_one(t2).await?;
176    ///     Ok((r1, r2))
177    /// }).await?;
178    /// ```
179    pub fn exec<'s, P: ToParams>(
180        &mut self,
181        statement: &'s (impl IntoStatement + ?Sized),
182        params: P,
183    ) -> Result<Ticket<'s>> {
184        let seq = self.queue_seq;
185        self.queue_seq += 1;
186
187        if statement.needs_parse() {
188            self.exec_sql_inner(statement.as_sql().unwrap(), &params)?;
189            Ok(Ticket { seq, stmt: None })
190        } else {
191            let stmt = statement.as_prepared().unwrap();
192            self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, &params)?;
193            Ok(Ticket {
194                seq,
195                stmt: Some(stmt),
196            })
197        }
198    }
199
200    fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
201        let param_oids = params.natural_oids();
202        let buf = &mut self.conn.buffer_set.write_buffer;
203        write_parse(buf, "", sql, &param_oids);
204        write_bind(buf, "", "", params, &param_oids)?;
205        write_describe_portal(buf, "");
206        write_execute(buf, "", 0);
207        self.expectations.push_back(Expectation::ParseBindExecute);
208        Ok(())
209    }
210
211    fn exec_prepared_inner<P: ToParams>(
212        &mut self,
213        stmt_name: &str,
214        param_oids: &[u32],
215        params: &P,
216    ) -> Result<()> {
217        let buf = &mut self.conn.buffer_set.write_buffer;
218        write_bind(buf, "", stmt_name, params, param_oids)?;
219        // Skip write_describe_portal - use cached RowDescription from PreparedStatement
220        write_execute(buf, "", 0);
221        self.expectations.push_back(Expectation::BindExecute);
222        Ok(())
223    }
224
225    /// Send a FLUSH message to trigger server response.
226    ///
227    /// This forces the server to send all pending responses without establishing
228    /// a transaction boundary. Called automatically by claim methods when needed.
229    pub async fn flush(&mut self) -> Result<()> {
230        if !self.conn.buffer_set.write_buffer.is_empty() {
231            write_flush(&mut self.conn.buffer_set.write_buffer);
232            self.conn
233                .stream
234                .write_all(&self.conn.buffer_set.write_buffer)
235                .await?;
236            self.conn.stream.flush().await?;
237            self.conn.buffer_set.write_buffer.clear();
238        }
239        Ok(())
240    }
241
242    /// Send a SYNC message to establish a transaction boundary.
243    ///
244    /// After calling sync, you must claim all queued operations in order.
245    /// The ReadyForQuery message will be consumed automatically after claiming.
246    pub async fn sync(&mut self) -> Result<()> {
247        let result = self.sync_inner().await;
248        if let Err(e) = &result
249            && e.is_connection_broken()
250        {
251            self.conn.is_broken = true;
252        }
253        result
254    }
255
256    async fn sync_inner(&mut self) -> Result<()> {
257        write_sync(&mut self.conn.buffer_set.write_buffer);
258        self.expectations.push_back(Expectation::Sync);
259        self.conn
260            .stream
261            .write_all(&self.conn.buffer_set.write_buffer)
262            .await?;
263        self.conn.stream.flush().await?;
264        self.conn.buffer_set.write_buffer.clear();
265        Ok(())
266    }
267
268    /// Consume a single ReadyForQuery message.
269    async fn consume_ready_for_query(&mut self) -> Result<()> {
270        loop {
271            self.conn
272                .stream
273                .read_message(&mut self.conn.buffer_set)
274                .await?;
275            let type_byte = self.conn.buffer_set.type_byte;
276
277            if RawMessage::is_async_type(type_byte) {
278                continue;
279            }
280
281            if type_byte == msg_type::ERROR_RESPONSE {
282                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
283                return Err(error.into_error());
284            }
285
286            if type_byte == msg_type::READY_FOR_QUERY {
287                let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
288                self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
289                return Ok(());
290            }
291        }
292    }
293
294    /// Consume all pending Sync expectations from the front of the queue.
295    async fn consume_pending_syncs(&mut self) -> Result<()> {
296        while self.expectations.front() == Some(&Expectation::Sync) {
297            self.expectations.pop_front();
298            self.consume_ready_for_query().await?;
299            // Reset aborted state - after ReadyForQuery, pipeline can continue
300            self.aborted = false;
301        }
302        Ok(())
303    }
304
305    // ========================================================================
306    // Claim Operations
307    // ========================================================================
308
309    /// Claim with a custom handler.
310    ///
311    /// Results must be claimed in the same order they were queued.
312    #[cfg(feature = "lowlevel")]
313    pub async fn claim<H: ExtendedHandler>(
314        &mut self,
315        ticket: Ticket<'_>,
316        handler: &mut H,
317    ) -> Result<()> {
318        self.claim_with_handler(ticket, handler).await
319    }
320
321    async fn claim_with_handler<H: ExtendedHandler>(
322        &mut self,
323        ticket: Ticket<'_>,
324        handler: &mut H,
325    ) -> Result<()> {
326        self.check_sequence(ticket.seq)?;
327
328        // Auto-sync if buffer has unsent data
329        if !self.conn.buffer_set.write_buffer.is_empty() {
330            self.sync().await?;
331        }
332
333        if self.aborted {
334            self.claim_seq += 1;
335            // Pop but don't process the exec expectation (server skipped it)
336            self.expectations.pop_front();
337            self.consume_pending_syncs().await?;
338            return Err(Error::Protocol(
339                "pipeline aborted due to earlier error".into(),
340            ));
341        }
342
343        let expectation = self.expectations.pop_front();
344
345        let result = match expectation {
346            Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler).await,
347            Some(Expectation::BindExecute) => {
348                self.claim_bind_exec_inner(handler, ticket.stmt).await
349            }
350            Some(Expectation::Sync) => Err(Error::Protocol("unexpected Sync expectation".into())),
351            None => Err(Error::Protocol("no expectation in queue".into())),
352        };
353
354        if let Err(e) = &result {
355            if e.is_connection_broken() {
356                self.conn.is_broken = true;
357            }
358            self.aborted = true;
359        }
360        self.claim_seq += 1;
361        self.consume_pending_syncs().await?;
362        result
363    }
364
365    /// Claim and collect all rows.
366    ///
367    /// Results must be claimed in the same order they were queued.
368    pub async fn claim_collect<T: for<'b> FromRow<'b>>(
369        &mut self,
370        ticket: Ticket<'_>,
371    ) -> Result<Vec<T>> {
372        let mut handler = crate::handler::CollectHandler::<T>::new();
373        self.claim_with_handler(ticket, &mut handler).await?;
374        Ok(handler.into_rows())
375    }
376
377    /// Claim and return just the first row.
378    ///
379    /// Results must be claimed in the same order they were queued.
380    pub async fn claim_one<T: for<'b> FromRow<'b>>(
381        &mut self,
382        ticket: Ticket<'_>,
383    ) -> Result<Option<T>> {
384        let mut handler = crate::handler::FirstRowHandler::<T>::new();
385        self.claim_with_handler(ticket, &mut handler).await?;
386        Ok(handler.into_row())
387    }
388
389    /// Claim and discard all rows.
390    ///
391    /// Results must be claimed in the same order they were queued.
392    pub async fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
393        let mut handler = crate::handler::DropHandler::new();
394        self.claim_with_handler(ticket, &mut handler).await
395    }
396
397    /// Check that the ticket sequence matches the expected claim sequence.
398    fn check_sequence(&self, seq: usize) -> Result<()> {
399        if seq != self.claim_seq {
400            return Err(Error::InvalidUsage(format!(
401                "claim out of order: expected seq {}, got {}",
402                self.claim_seq, seq
403            )));
404        }
405        Ok(())
406    }
407
408    /// Claim Parse + Bind + Execute (for raw SQL exec() calls).
409    async fn claim_parse_bind_exec_inner<H: ExtendedHandler>(
410        &mut self,
411        handler: &mut H,
412    ) -> Result<()> {
413        // Expect: ParseComplete
414        self.read_next_message().await?;
415        if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
416            return self.unexpected_message("ParseComplete");
417        }
418        ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
419
420        // Expect: BindComplete
421        self.read_next_message().await?;
422        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
423            return self.unexpected_message("BindComplete");
424        }
425        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
426
427        // Now read rows
428        self.claim_rows_inner(handler).await
429    }
430
431    /// Claim Bind + Execute (for prepared statement exec() calls).
432    async fn claim_bind_exec_inner<H: ExtendedHandler>(
433        &mut self,
434        handler: &mut H,
435        stmt: Option<&PreparedStatement>,
436    ) -> Result<()> {
437        // Expect: BindComplete
438        self.read_next_message().await?;
439        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
440            return self.unexpected_message("BindComplete");
441        }
442        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
443
444        // Use cached RowDescription from PreparedStatement (no copy)
445        let row_desc = stmt.and_then(|s| s.row_desc_payload());
446
447        // Now read rows (no RowDescription/NoData expected from server)
448        self.claim_rows_cached_inner(handler, row_desc).await
449    }
450
451    /// Common row reading logic (reads RowDescription from server).
452    async fn claim_rows_inner<H: ExtendedHandler>(&mut self, handler: &mut H) -> Result<()> {
453        // Expect RowDescription or NoData
454        self.read_next_message().await?;
455        let has_rows = match self.conn.buffer_set.type_byte {
456            msg_type::ROW_DESCRIPTION => {
457                self.column_buffer.clear();
458                self.column_buffer
459                    .extend_from_slice(&self.conn.buffer_set.read_buffer);
460                true
461            }
462            msg_type::NO_DATA => {
463                NoData::parse(&self.conn.buffer_set.read_buffer)?;
464                // No rows will follow, but we still need terminal message
465                false
466            }
467            _ => {
468                return Err(Error::Protocol(format!(
469                    "expected RowDescription or NoData, got '{}'",
470                    self.conn.buffer_set.type_byte as char
471                )));
472            }
473        };
474
475        // Read data rows until terminal message
476        loop {
477            self.read_next_message().await?;
478            let type_byte = self.conn.buffer_set.type_byte;
479
480            match type_byte {
481                msg_type::DATA_ROW => {
482                    if !has_rows {
483                        return Err(Error::Protocol(
484                            "received DataRow but no RowDescription".into(),
485                        ));
486                    }
487                    let cols = RowDescription::parse(&self.column_buffer)?;
488                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
489                    handler.row(cols, row)?;
490                }
491                msg_type::COMMAND_COMPLETE => {
492                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
493                    handler.result_end(cmd)?;
494                    return Ok(());
495                }
496                msg_type::EMPTY_QUERY_RESPONSE => {
497                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
498                    return Ok(());
499                }
500                _ => {
501                    return Err(Error::Protocol(format!(
502                        "unexpected message type in pipeline claim: '{}'",
503                        type_byte as char
504                    )));
505                }
506            }
507        }
508    }
509
510    /// Row reading logic with cached RowDescription (no server message expected).
511    async fn claim_rows_cached_inner<H: ExtendedHandler>(
512        &mut self,
513        handler: &mut H,
514        row_desc: Option<&[u8]>,
515    ) -> Result<()> {
516        // Read data rows until terminal message
517        loop {
518            self.read_next_message().await?;
519            let type_byte = self.conn.buffer_set.type_byte;
520
521            match type_byte {
522                msg_type::DATA_ROW => {
523                    let row_desc = row_desc.ok_or_else(|| {
524                        Error::Protocol("received DataRow but no RowDescription cached".into())
525                    })?;
526                    let cols = RowDescription::parse(row_desc)?;
527                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
528                    handler.row(cols, row)?;
529                }
530                msg_type::COMMAND_COMPLETE => {
531                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
532                    handler.result_end(cmd)?;
533                    return Ok(());
534                }
535                msg_type::EMPTY_QUERY_RESPONSE => {
536                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
537                    return Ok(());
538                }
539                _ => {
540                    return Err(Error::Protocol(format!(
541                        "unexpected message type in pipeline claim: '{}'",
542                        type_byte as char
543                    )));
544                }
545            }
546        }
547    }
548
549    /// Read the next message, skipping async messages and handling errors.
550    async fn read_next_message(&mut self) -> Result<()> {
551        loop {
552            self.conn
553                .stream
554                .read_message(&mut self.conn.buffer_set)
555                .await?;
556            let type_byte = self.conn.buffer_set.type_byte;
557
558            // Handle async messages
559            if RawMessage::is_async_type(type_byte) {
560                continue;
561            }
562
563            // Handle error
564            if type_byte == msg_type::ERROR_RESPONSE {
565                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
566                return Err(error.into_error());
567            }
568
569            return Ok(());
570        }
571    }
572
573    /// Create an error for unexpected message type.
574    fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
575        Err(Error::Protocol(format!(
576            "expected {}, got '{}'",
577            expected, self.conn.buffer_set.type_byte as char
578        )))
579    }
580
581    /// Returns the number of operations that have been queued but not yet claimed.
582    pub fn pending_count(&self) -> usize {
583        self.queue_seq - self.claim_seq
584    }
585
586    /// Returns true if the pipeline is in aborted state due to an error.
587    pub fn is_aborted(&self) -> bool {
588        self.aborted
589    }
590}