Skip to main content

zero_postgres/sync/pipeline/
mod.rs

//! Pipeline mode for batching multiple queries.
//!
//! Pipeline mode allows sending multiple queries to the server without waiting
//! for responses, reducing round-trip latency.
//!
//! # Example
//!
//! ```ignore
//! // Prepare statements outside the pipeline
//! let stmts = conn.prepare_batch(&[
//!     "SELECT id, name FROM users WHERE active = $1",
//!     "INSERT INTO users (name) VALUES ($1) RETURNING id",
//! ])?;
//!
//! let (active, inactive, count) = conn.pipeline(|p| {
//!     // Queue executions
//!     let t1 = p.exec(&stmts[0], (true,))?;
//!     let t2 = p.exec(&stmts[0], (false,))?;
//!     let t3 = p.exec("SELECT COUNT(*) FROM users", ())?;
//!
//!     p.sync()?;
//!
//!     // Claim results in the order they were queued, using different methods
//!     let active: Vec<(i32, String)> = p.claim_collect(t1)?;
//!     let inactive: Option<(i32, String)> = p.claim_one(t2)?;
//!     let count: Vec<(i64,)> = p.claim_collect(t3)?;
//!
//!     Ok((active, inactive, count))
//! })?;
//! ```
31
32use std::collections::VecDeque;
33
34use crate::pipeline::Expectation;
35use crate::pipeline::Ticket;
36
37use crate::conversion::{FromRow, ToParams};
38use crate::error::{Error, Result};
39use crate::handler::ExtendedHandler;
40use crate::protocol::backend::{
41    BindComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, NoData,
42    ParseComplete, RawMessage, ReadyForQuery, RowDescription, msg_type,
43};
44use crate::protocol::frontend::{
45    write_bind, write_describe_portal, write_execute, write_flush, write_parse, write_sync,
46};
47use crate::state::extended::PreparedStatement;
48use crate::statement::{IntoStatement, StatementRef};
49
50use super::conn::Conn;
51
52/// Pipeline mode for batching multiple queries.
53///
54/// Created by [`Conn::pipeline`].
55pub struct Pipeline<'a> {
56    conn: &'a mut Conn,
57    /// Monotonically increasing counter for queued operations
58    queue_seq: usize,
59    /// Next sequence number to claim
60    claim_seq: usize,
61    /// Whether the pipeline is in aborted state (error occurred)
62    aborted: bool,
63    /// Expected responses queue (exec operations and Sync points)
64    expectations: VecDeque<Expectation>,
65}
66
67impl<'a> Pipeline<'a> {
68    /// Create a new pipeline.
69    ///
70    /// Prefer using [`Conn::pipeline`] which handles cleanup automatically.
71    /// This constructor is available for advanced use cases.
72    #[cfg(feature = "lowlevel")]
73    pub fn new(conn: &'a mut Conn) -> Self {
74        Self::new_inner(conn)
75    }
76
77    /// Create a new pipeline (internal).
78    pub(crate) fn new_inner(conn: &'a mut Conn) -> Self {
79        conn.buffer_set.write_buffer.clear();
80        Self {
81            conn,
82            queue_seq: 0,
83            claim_seq: 0,
84            aborted: false,
85            expectations: VecDeque::new(),
86        }
87    }
88
89    /// Cleanup the pipeline, draining any unclaimed tickets.
90    ///
91    /// This is called automatically by [`Conn::pipeline`].
92    /// Also available with the `lowlevel` feature for manual cleanup.
93    #[cfg(feature = "lowlevel")]
94    pub fn cleanup(&mut self) {
95        self.cleanup_inner();
96    }
97
98    #[cfg(not(feature = "lowlevel"))]
99    pub(crate) fn cleanup(&mut self) {
100        self.cleanup_inner();
101    }
102
103    fn cleanup_inner(&mut self) {
104        // Nothing to clean up if no operations were queued and no expectations pending
105        if self.queue_seq == 0 && self.expectations.is_empty() {
106            return;
107        }
108
109        // Send sync if we have unflushed operations or no sync is queued yet
110        if !self.conn.buffer_set.write_buffer.is_empty()
111            || !self.expectations.iter().any(|e| *e == Expectation::Sync)
112        {
113            let _ = self.sync();
114        }
115
116        // Drain remaining expectations
117        if self.aborted {
118            // In aborted state, server skipped remaining commands - only consume ReadyForQuery(s)
119            while let Some(expectation) = self.expectations.pop_front() {
120                if expectation == Expectation::Sync {
121                    let _ = self.consume_ready_for_query();
122                }
123            }
124        } else {
125            // Normal drain: process all expectations
126            while let Some(expectation) = self.expectations.pop_front() {
127                self.drain_expectation(expectation);
128            }
129        }
130
131        // Reset state
132        self.queue_seq = 0;
133        self.claim_seq = 0;
134        self.aborted = false;
135    }
136
137    /// Drain a single expectation.
138    fn drain_expectation(&mut self, expectation: Expectation) {
139        let mut handler = crate::handler::DropHandler::new();
140        let _ = match expectation {
141            Expectation::ParseBindExecute => self.claim_parse_bind_exec_inner(&mut handler),
142            Expectation::BindExecute => self.claim_bind_exec_inner(&mut handler, None),
143            Expectation::Sync => self.consume_ready_for_query(),
144        };
145    }
146
147    // ========================================================================
148    // Queue Operations
149    // ========================================================================
150
151    /// Queue a statement execution.
152    ///
153    /// The statement can be either:
154    /// - A `&PreparedStatement` returned from `conn.prepare()` or `conn.prepare_batch()`
155    /// - A raw SQL `&str` for one-shot execution
156    ///
157    /// This method only buffers the command locally - no network I/O occurs until
158    /// `sync()` or `flush()` is called.
159    ///
160    /// # Example
161    ///
162    /// ```ignore
163    /// let stmt = conn.prepare("SELECT id, name FROM users WHERE id = $1")?;
164    ///
165    /// let (r1, r2) = conn.pipeline(|p| {
166    ///     let t1 = p.exec(&stmt, (1,))?;
167    ///     let t2 = p.exec("SELECT COUNT(*) FROM users", ())?;
168    ///     p.sync()?;
169    ///
170    ///     let r1: Vec<(i32, String)> = p.claim_collect(t1)?;
171    ///     let r2: Option<(i64,)> = p.claim_one(t2)?;
172    ///     Ok((r1, r2))
173    /// })?;
174    /// ```
175    pub fn exec<'s, P: ToParams>(
176        &mut self,
177        statement: &'s (impl IntoStatement + ?Sized),
178        params: P,
179    ) -> Result<Ticket<'s>> {
180        let seq = self.queue_seq;
181        self.queue_seq += 1;
182
183        match statement.statement_ref() {
184            StatementRef::Sql(sql) => {
185                self.exec_sql_inner(sql, &params)?;
186                Ok(Ticket { seq, stmt: None })
187            }
188            StatementRef::Prepared(stmt) => {
189                self.exec_prepared_inner(&stmt.wire_name(), &stmt.param_oids, &params)?;
190                Ok(Ticket {
191                    seq,
192                    stmt: Some(stmt),
193                })
194            }
195        }
196    }
197
198    fn exec_sql_inner<P: ToParams>(&mut self, sql: &str, params: &P) -> Result<()> {
199        let param_oids = params.natural_oids();
200        let buf = &mut self.conn.buffer_set.write_buffer;
201        write_parse(buf, "", sql, &param_oids);
202        write_bind(buf, "", "", params, &param_oids)?;
203        write_describe_portal(buf, "");
204        write_execute(buf, "", 0);
205        self.expectations.push_back(Expectation::ParseBindExecute);
206        Ok(())
207    }
208
209    fn exec_prepared_inner<P: ToParams>(
210        &mut self,
211        stmt_name: &str,
212        param_oids: &[u32],
213        params: &P,
214    ) -> Result<()> {
215        let buf = &mut self.conn.buffer_set.write_buffer;
216        write_bind(buf, "", stmt_name, params, param_oids)?;
217        // Skip write_describe_portal - use cached RowDescription from PreparedStatement
218        write_execute(buf, "", 0);
219        self.expectations.push_back(Expectation::BindExecute);
220        Ok(())
221    }
222
223    /// Send a FLUSH message to trigger server response.
224    ///
225    /// This forces the server to send all pending responses without establishing
226    /// a transaction boundary. Called automatically by claim methods when needed.
227    pub fn flush(&mut self) -> Result<()> {
228        if !self.conn.buffer_set.write_buffer.is_empty() {
229            write_flush(&mut self.conn.buffer_set.write_buffer);
230            self.conn
231                .stream
232                .write_all(&self.conn.buffer_set.write_buffer)?;
233            self.conn.stream.flush()?;
234            self.conn.buffer_set.write_buffer.clear();
235        }
236        Ok(())
237    }
238
239    /// Send a SYNC message to establish a transaction boundary.
240    ///
241    /// After calling sync, you must claim all queued operations in order.
242    /// The ReadyForQuery message will be consumed automatically after claiming.
243    pub fn sync(&mut self) -> Result<()> {
244        let result = self.sync_inner();
245        if let Err(e) = &result
246            && e.is_connection_broken()
247        {
248            self.conn.is_broken = true;
249        }
250        result
251    }
252
253    fn sync_inner(&mut self) -> Result<()> {
254        write_sync(&mut self.conn.buffer_set.write_buffer);
255        self.expectations.push_back(Expectation::Sync);
256        self.conn
257            .stream
258            .write_all(&self.conn.buffer_set.write_buffer)?;
259        self.conn.stream.flush()?;
260        self.conn.buffer_set.write_buffer.clear();
261        Ok(())
262    }
263
264    /// Consume a single ReadyForQuery message.
265    fn consume_ready_for_query(&mut self) -> Result<()> {
266        loop {
267            self.conn.stream.read_message(&mut self.conn.buffer_set)?;
268            let type_byte = self.conn.buffer_set.type_byte;
269
270            if RawMessage::is_async_type(type_byte) {
271                continue;
272            }
273
274            if type_byte == msg_type::ERROR_RESPONSE {
275                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
276                return Err(error.into_error());
277            }
278
279            if type_byte == msg_type::READY_FOR_QUERY {
280                let ready = ReadyForQuery::parse(&self.conn.buffer_set.read_buffer)?;
281                self.conn.transaction_status = ready.transaction_status().unwrap_or_default();
282                return Ok(());
283            }
284        }
285    }
286
287    /// Consume all pending Sync expectations from the front of the queue.
288    fn consume_pending_syncs(&mut self) -> Result<()> {
289        while self.expectations.front() == Some(&Expectation::Sync) {
290            self.expectations.pop_front();
291            self.consume_ready_for_query()?;
292            // Reset aborted state - after ReadyForQuery, pipeline can continue
293            self.aborted = false;
294        }
295        Ok(())
296    }
297
298    // ========================================================================
299    // Claim Operations
300    // ========================================================================
301
302    /// Claim with a custom handler.
303    ///
304    /// Results must be claimed in the same order they were queued.
305    #[cfg(feature = "lowlevel")]
306    pub fn claim<H: ExtendedHandler>(&mut self, ticket: Ticket<'_>, handler: &mut H) -> Result<()> {
307        self.claim_with_handler(ticket, handler)
308    }
309
310    fn claim_with_handler<H: ExtendedHandler>(
311        &mut self,
312        ticket: Ticket<'_>,
313        handler: &mut H,
314    ) -> Result<()> {
315        self.check_sequence(ticket.seq)?;
316
317        // Auto-sync if buffer has unsent data
318        if !self.conn.buffer_set.write_buffer.is_empty() {
319            self.sync()?;
320        }
321
322        if self.aborted {
323            self.claim_seq += 1;
324            // Pop but don't process the exec expectation (server skipped it)
325            self.expectations.pop_front();
326            self.consume_pending_syncs()?;
327            return Err(Error::LibraryBug(
328                "pipeline aborted due to earlier error".into(),
329            ));
330        }
331
332        let expectation = self.expectations.pop_front();
333
334        let result = match expectation {
335            Some(Expectation::ParseBindExecute) => self.claim_parse_bind_exec_inner(handler),
336            Some(Expectation::BindExecute) => self.claim_bind_exec_inner(handler, ticket.stmt),
337            Some(Expectation::Sync) => Err(Error::LibraryBug("unexpected Sync expectation".into())),
338            None => Err(Error::LibraryBug("no expectation in queue".into())),
339        };
340
341        if let Err(e) = &result {
342            if e.is_connection_broken() {
343                self.conn.is_broken = true;
344            }
345            self.aborted = true;
346        }
347        self.claim_seq += 1;
348        self.consume_pending_syncs()?;
349        result
350    }
351
352    /// Claim and collect all rows.
353    ///
354    /// Results must be claimed in the same order they were queued.
355    pub fn claim_collect<T: for<'b> FromRow<'b>>(&mut self, ticket: Ticket<'_>) -> Result<Vec<T>> {
356        let mut handler = crate::handler::CollectHandler::<T>::new();
357        self.claim_with_handler(ticket, &mut handler)?;
358        Ok(handler.into_rows())
359    }
360
361    /// Claim and return just the first row.
362    ///
363    /// Results must be claimed in the same order they were queued.
364    pub fn claim_one<T: for<'b> FromRow<'b>>(&mut self, ticket: Ticket<'_>) -> Result<Option<T>> {
365        let mut handler = crate::handler::FirstRowHandler::<T>::new();
366        self.claim_with_handler(ticket, &mut handler)?;
367        Ok(handler.into_row())
368    }
369
370    /// Claim and discard all rows.
371    ///
372    /// Results must be claimed in the same order they were queued.
373    pub fn claim_drop(&mut self, ticket: Ticket<'_>) -> Result<()> {
374        let mut handler = crate::handler::DropHandler::new();
375        self.claim_with_handler(ticket, &mut handler)
376    }
377
378    /// Check that the ticket sequence matches the expected claim sequence.
379    fn check_sequence(&self, seq: usize) -> Result<()> {
380        if seq != self.claim_seq {
381            return Err(Error::InvalidUsage(format!(
382                "claim out of order: expected seq {}, got {}",
383                self.claim_seq, seq
384            )));
385        }
386        Ok(())
387    }
388
389    /// Claim Parse + Bind + Execute (for raw SQL exec() calls).
390    fn claim_parse_bind_exec_inner<H: ExtendedHandler>(&mut self, handler: &mut H) -> Result<()> {
391        // Expect: ParseComplete
392        self.read_next_message()?;
393        if self.conn.buffer_set.type_byte != msg_type::PARSE_COMPLETE {
394            return self.unexpected_message("ParseComplete");
395        }
396        ParseComplete::parse(&self.conn.buffer_set.read_buffer)?;
397
398        // Expect: BindComplete
399        self.read_next_message()?;
400        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
401            return self.unexpected_message("BindComplete");
402        }
403        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
404
405        // Now read rows
406        self.claim_rows_inner(handler)
407    }
408
409    /// Claim Bind + Execute (for prepared statement exec() calls).
410    fn claim_bind_exec_inner<H: ExtendedHandler>(
411        &mut self,
412        handler: &mut H,
413        stmt: Option<&PreparedStatement>,
414    ) -> Result<()> {
415        // Expect: BindComplete
416        self.read_next_message()?;
417        if self.conn.buffer_set.type_byte != msg_type::BIND_COMPLETE {
418            return self.unexpected_message("BindComplete");
419        }
420        BindComplete::parse(&self.conn.buffer_set.read_buffer)?;
421
422        // Use cached RowDescription from PreparedStatement (no copy)
423        let row_desc = stmt.and_then(|s| s.row_desc_payload());
424
425        // Now read rows (no RowDescription/NoData expected from server)
426        self.claim_rows_cached_inner(handler, row_desc)
427    }
428
429    /// Common row reading logic (reads RowDescription from server).
430    fn claim_rows_inner<H: ExtendedHandler>(&mut self, handler: &mut H) -> Result<()> {
431        // Expect RowDescription or NoData
432        self.read_next_message()?;
433        let has_rows = match self.conn.buffer_set.type_byte {
434            msg_type::ROW_DESCRIPTION => {
435                self.conn.buffer_set.save_column_buffer();
436                true
437            }
438            msg_type::NO_DATA => {
439                NoData::parse(&self.conn.buffer_set.read_buffer)?;
440                // No rows will follow, but we still need terminal message
441                false
442            }
443            _ => {
444                return Err(Error::LibraryBug(format!(
445                    "expected RowDescription or NoData, got '{}'",
446                    self.conn.buffer_set.type_byte as char
447                )));
448            }
449        };
450
451        // Read data rows until terminal message
452        loop {
453            self.read_next_message()?;
454            let type_byte = self.conn.buffer_set.type_byte;
455
456            match type_byte {
457                msg_type::DATA_ROW => {
458                    if !has_rows {
459                        return Err(Error::LibraryBug(
460                            "received DataRow but no RowDescription".into(),
461                        ));
462                    }
463                    let cols = RowDescription::parse(&self.conn.buffer_set.column_buffer)?;
464                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
465                    handler.row(cols, row)?;
466                }
467                msg_type::COMMAND_COMPLETE => {
468                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
469                    handler.result_end(cmd)?;
470                    return Ok(());
471                }
472                msg_type::EMPTY_QUERY_RESPONSE => {
473                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
474                    return Ok(());
475                }
476                _ => {
477                    return Err(Error::LibraryBug(format!(
478                        "unexpected message type in pipeline claim: '{}'",
479                        type_byte as char
480                    )));
481                }
482            }
483        }
484    }
485
486    /// Row reading logic with cached RowDescription (no server message expected).
487    fn claim_rows_cached_inner<H: ExtendedHandler>(
488        &mut self,
489        handler: &mut H,
490        row_desc: Option<&[u8]>,
491    ) -> Result<()> {
492        // Read data rows until terminal message
493        loop {
494            self.read_next_message()?;
495            let type_byte = self.conn.buffer_set.type_byte;
496
497            match type_byte {
498                msg_type::DATA_ROW => {
499                    let row_desc = row_desc.ok_or_else(|| {
500                        Error::LibraryBug("received DataRow but no RowDescription cached".into())
501                    })?;
502                    let cols = RowDescription::parse(row_desc)?;
503                    let row = DataRow::parse(&self.conn.buffer_set.read_buffer)?;
504                    handler.row(cols, row)?;
505                }
506                msg_type::COMMAND_COMPLETE => {
507                    let cmd = CommandComplete::parse(&self.conn.buffer_set.read_buffer)?;
508                    handler.result_end(cmd)?;
509                    return Ok(());
510                }
511                msg_type::EMPTY_QUERY_RESPONSE => {
512                    EmptyQueryResponse::parse(&self.conn.buffer_set.read_buffer)?;
513                    return Ok(());
514                }
515                _ => {
516                    return Err(Error::LibraryBug(format!(
517                        "unexpected message type in pipeline claim: '{}'",
518                        type_byte as char
519                    )));
520                }
521            }
522        }
523    }
524
525    /// Read the next message, skipping async messages and handling errors.
526    fn read_next_message(&mut self) -> Result<()> {
527        loop {
528            self.conn.stream.read_message(&mut self.conn.buffer_set)?;
529            let type_byte = self.conn.buffer_set.type_byte;
530
531            // Handle async messages
532            if RawMessage::is_async_type(type_byte) {
533                continue;
534            }
535
536            // Handle error
537            if type_byte == msg_type::ERROR_RESPONSE {
538                let error = ErrorResponse::parse(&self.conn.buffer_set.read_buffer)?;
539                return Err(error.into_error());
540            }
541
542            return Ok(());
543        }
544    }
545
546    /// Create an error for unexpected message type.
547    fn unexpected_message<T>(&self, expected: &str) -> Result<T> {
548        Err(Error::LibraryBug(format!(
549            "expected {}, got '{}'",
550            expected, self.conn.buffer_set.type_byte as char
551        )))
552    }
553
554    /// Returns the number of operations that have been queued but not yet claimed.
555    pub fn pending_count(&self) -> usize {
556        self.queue_seq - self.claim_seq
557    }
558
559    /// Returns true if the pipeline is in aborted state due to an error.
560    pub fn is_aborted(&self) -> bool {
561        self.aborted
562    }
563}