zero_postgres/state/
batch_prepare.rs1use crate::buffer_set::BufferSet;
6use crate::error::{Error, Result};
7use crate::protocol::backend::{
8 ErrorResponse, NoData, ParameterDescription, ParseComplete, RawMessage, ReadyForQuery, msg_type,
9};
10use crate::protocol::frontend::{write_describe_statement, write_parse, write_sync};
11use crate::protocol::types::TransactionStatus;
12
13use super::action::Action;
14use super::extended::PreparedStatement;
15
/// Phase of a batch-prepare exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Nothing has been handed to the driver yet; the first `step` call leaves this phase.
    Initial,
    /// The request has been handed off and backend messages are being consumed.
    Processing,
    /// A `ReadyForQuery` message was processed; no further messages are expected.
    Finished,
}
23
24pub struct BatchPrepareStateMachine {
29 state: State,
30 statements: Vec<PreparedStatement>,
32 current_stmt: usize,
34 transaction_status: TransactionStatus,
35}
36
37impl BatchPrepareStateMachine {
38 pub fn new(buffer_set: &mut BufferSet, queries: &[&str], start_idx: u64) -> Self {
42 buffer_set.write_buffer.clear();
43
44 let mut statements = Vec::with_capacity(queries.len());
45
46 for (i, query) in queries.iter().enumerate() {
47 let idx = start_idx + i as u64;
48 let stmt_name = format!("_zero_s_{}", idx);
49 write_parse(&mut buffer_set.write_buffer, &stmt_name, query, &[]);
50 write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
51 statements.push(PreparedStatement {
52 idx,
53 param_oids: Vec::new(),
54 row_desc_payload: None,
55 });
56 }
57
58 write_sync(&mut buffer_set.write_buffer);
59
60 Self {
61 state: State::Initial,
62 statements,
63 current_stmt: 0,
64 transaction_status: TransactionStatus::Idle,
65 }
66 }
67
68 pub fn take_statements(&mut self) -> Vec<PreparedStatement> {
70 std::mem::take(&mut self.statements)
71 }
72
73 pub fn transaction_status(&self) -> TransactionStatus {
75 self.transaction_status
76 }
77
78 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
80 if self.state == State::Initial {
82 self.state = State::Processing;
83 return Ok(Action::WriteAndReadMessage);
84 }
85
86 let type_byte = buffer_set.type_byte;
87
88 if RawMessage::is_async_type(type_byte) {
90 return Ok(Action::ReadMessage);
91 }
92
93 if type_byte == msg_type::ERROR_RESPONSE {
95 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
96 return Err(error.into_error());
97 }
98
99 match self.state {
100 State::Processing => match type_byte {
101 msg_type::PARSE_COMPLETE => {
102 ParseComplete::parse(&buffer_set.read_buffer)?;
103 Ok(Action::ReadMessage)
104 }
105 msg_type::PARAMETER_DESCRIPTION => {
106 let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
107 if self.current_stmt < self.statements.len() {
108 self.statements[self.current_stmt].param_oids = param_desc.oids().to_vec();
109 }
110 Ok(Action::ReadMessage)
111 }
112 msg_type::ROW_DESCRIPTION => {
113 if self.current_stmt < self.statements.len() {
114 self.statements[self.current_stmt].row_desc_payload =
115 Some(buffer_set.read_buffer.clone());
116 }
117 self.current_stmt += 1;
118 Ok(Action::ReadMessage)
119 }
120 msg_type::NO_DATA => {
121 NoData::parse(&buffer_set.read_buffer)?;
122 self.current_stmt += 1;
124 Ok(Action::ReadMessage)
125 }
126 msg_type::READY_FOR_QUERY => {
127 let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
128 self.transaction_status = ready.transaction_status().unwrap_or_default();
129 self.state = State::Finished;
130 Ok(Action::Finished)
131 }
132 _ => Err(Error::Protocol(format!(
133 "Unexpected message in batch prepare: '{}'",
134 type_byte as char
135 ))),
136 },
137 _ => Err(Error::Protocol(format!(
138 "Unexpected state {:?}",
139 self.state
140 ))),
141 }
142 }
143}