zero_postgres/state/
batch_prepare.rs

1//! Batch prepare state machine.
2//!
3//! Used by `prepare_batch` to prepare multiple statements in a single round-trip.
4
5use crate::buffer_set::BufferSet;
6use crate::error::{Error, Result};
7use crate::protocol::backend::{
8    ErrorResponse, NoData, ParameterDescription, ParseComplete, RawMessage, ReadyForQuery, msg_type,
9};
10use crate::protocol::frontend::{write_describe_statement, write_parse, write_sync};
11use crate::protocol::types::TransactionStatus;
12
13use super::action::Action;
14use super::extended::PreparedStatement;
15
/// Lifecycle phase of the batch prepare state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Requests are buffered but not yet sent; the next `step` asks for a write.
    Initial,
    /// Reading the server's responses to the pipelined Parse/Describe/Sync requests.
    Processing,
    /// ReadyForQuery has been received; the batch is complete.
    Finished,
}
23
24/// State machine for batch prepare (Parse + Describe)* + Sync.
25///
26/// Prepares multiple statements in a single round-trip by sending all
27/// Parse and DescribeStatement messages followed by a single Sync.
28pub struct BatchPrepareStateMachine {
29    state: State,
30    /// Statements being prepared
31    statements: Vec<PreparedStatement>,
32    /// Current statement index we're processing responses for
33    current_stmt: usize,
34    transaction_status: TransactionStatus,
35}
36
37impl BatchPrepareStateMachine {
38    /// Create a new batch prepare state machine.
39    ///
40    /// Writes all Parse + DescribeStatement messages followed by Sync to the buffer.
41    pub fn new(buffer_set: &mut BufferSet, queries: &[&str], start_idx: u64) -> Self {
42        buffer_set.write_buffer.clear();
43
44        let mut statements = Vec::with_capacity(queries.len());
45
46        for (i, query) in queries.iter().enumerate() {
47            let idx = start_idx + i as u64;
48            let stmt_name = format!("_zero_s_{}", idx);
49            write_parse(&mut buffer_set.write_buffer, &stmt_name, query, &[]);
50            write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
51            statements.push(PreparedStatement {
52                idx,
53                param_oids: Vec::new(),
54                row_desc_payload: None,
55            });
56        }
57
58        write_sync(&mut buffer_set.write_buffer);
59
60        Self {
61            state: State::Initial,
62            statements,
63            current_stmt: 0,
64            transaction_status: TransactionStatus::Idle,
65        }
66    }
67
68    /// Take the prepared statements after completion.
69    pub fn take_statements(&mut self) -> Vec<PreparedStatement> {
70        std::mem::take(&mut self.statements)
71    }
72
73    /// Get the transaction status after completion.
74    pub fn transaction_status(&self) -> TransactionStatus {
75        self.transaction_status
76    }
77
78    /// Process input and return the next action.
79    pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
80        // Initial state: write buffer was pre-filled by constructor
81        if self.state == State::Initial {
82            self.state = State::Processing;
83            return Ok(Action::WriteAndReadMessage);
84        }
85
86        let type_byte = buffer_set.type_byte;
87
88        // Handle async messages - need to keep reading
89        if RawMessage::is_async_type(type_byte) {
90            return Ok(Action::ReadMessage);
91        }
92
93        // Handle error response
94        if type_byte == msg_type::ERROR_RESPONSE {
95            let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
96            return Err(error.into_error());
97        }
98
99        match self.state {
100            State::Processing => match type_byte {
101                msg_type::PARSE_COMPLETE => {
102                    ParseComplete::parse(&buffer_set.read_buffer)?;
103                    Ok(Action::ReadMessage)
104                }
105                msg_type::PARAMETER_DESCRIPTION => {
106                    let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
107                    if self.current_stmt < self.statements.len() {
108                        self.statements[self.current_stmt].param_oids = param_desc.oids().to_vec();
109                    }
110                    Ok(Action::ReadMessage)
111                }
112                msg_type::ROW_DESCRIPTION => {
113                    if self.current_stmt < self.statements.len() {
114                        self.statements[self.current_stmt].row_desc_payload =
115                            Some(buffer_set.read_buffer.clone());
116                    }
117                    self.current_stmt += 1;
118                    Ok(Action::ReadMessage)
119                }
120                msg_type::NO_DATA => {
121                    NoData::parse(&buffer_set.read_buffer)?;
122                    // Statement doesn't return rows
123                    self.current_stmt += 1;
124                    Ok(Action::ReadMessage)
125                }
126                msg_type::READY_FOR_QUERY => {
127                    let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
128                    self.transaction_status = ready.transaction_status().unwrap_or_default();
129                    self.state = State::Finished;
130                    Ok(Action::Finished)
131                }
132                _ => Err(Error::Protocol(format!(
133                    "Unexpected message in batch prepare: '{}'",
134                    type_byte as char
135                ))),
136            },
137            _ => Err(Error::Protocol(format!(
138                "Unexpected state {:?}",
139                self.state
140            ))),
141        }
142    }
143}