zero_postgres/state/
simple_query.rs1use crate::buffer_set::BufferSet;
4use crate::error::{Error, Result};
5use crate::handler::SimpleHandler;
6use crate::protocol::backend::{
7 CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, RawMessage, ReadyForQuery,
8 RowDescription, msg_type,
9};
10use crate::protocol::frontend::write_query;
11use crate::protocol::types::TransactionStatus;
12
13use super::StateMachine;
14use super::action::{Action, AsyncMessage};
15
/// Phases of a single simple-query exchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// The `Query` message has not been written yet.
    Initial,
    /// Waiting for the next result: RowDescription, CommandComplete,
    /// EmptyQueryResponse or ReadyForQuery.
    WaitingResponse,
    /// A RowDescription was received; DataRow messages are being forwarded.
    ProcessingRows,
    /// Only ReadyForQuery is accepted (after an ErrorResponse or an EmptyQueryResponse).
    WaitingReady,
    /// ReadyForQuery has been processed; the exchange is over.
    Finished,
}

26pub struct SimpleQueryStateMachine<'a, 'q, H> {
28 state: State,
29 handler: &'a mut H,
30 query: &'q str,
31 transaction_status: TransactionStatus,
32 pending_error: Option<crate::error::ServerError>,
33}
34
35impl<'a, 'q, H: SimpleHandler> SimpleQueryStateMachine<'a, 'q, H> {
36 pub fn new(handler: &'a mut H, query: &'q str) -> Self {
38 Self {
39 state: State::Initial,
40 handler,
41 query,
42 transaction_status: TransactionStatus::Idle,
43 pending_error: None,
44 }
45 }
46
47 fn handle_response(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
48 let type_byte = buffer_set.type_byte;
49 let payload = &buffer_set.read_buffer;
50
51 match type_byte {
52 msg_type::ROW_DESCRIPTION => {
53 buffer_set.column_buffer.clear();
55 buffer_set.column_buffer.extend_from_slice(payload);
56 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
57 self.handler.result_start(cols)?;
58 self.state = State::ProcessingRows;
59 Ok(Action::ReadMessage)
60 }
61 msg_type::COMMAND_COMPLETE => {
62 let complete = CommandComplete::parse(payload)?;
63 self.handler.result_end(complete)?;
64 self.state = State::WaitingResponse;
66 Ok(Action::ReadMessage)
67 }
68 msg_type::EMPTY_QUERY_RESPONSE => {
69 EmptyQueryResponse::parse(payload)?;
70 self.state = State::WaitingReady;
72 Ok(Action::ReadMessage)
73 }
74 msg_type::READY_FOR_QUERY => {
75 let ready = ReadyForQuery::parse(payload)?;
76 self.transaction_status = ready.transaction_status().unwrap_or_default();
77 self.state = State::Finished;
78 Ok(Action::Finished)
79 }
80 _ => Err(Error::LibraryBug(format!(
81 "Unexpected message in query response: '{}'",
82 type_byte as char
83 ))),
84 }
85 }
86
87 fn handle_rows(&mut self, buffer_set: &BufferSet) -> Result<Action> {
88 let type_byte = buffer_set.type_byte;
89 let payload = &buffer_set.read_buffer;
90
91 match type_byte {
92 msg_type::DATA_ROW => {
93 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
94 let row = DataRow::parse(payload)?;
95 self.handler.row(cols, row)?;
96 Ok(Action::ReadMessage)
97 }
98 msg_type::COMMAND_COMPLETE => {
99 let complete = CommandComplete::parse(payload)?;
100 self.handler.result_end(complete)?;
101 self.state = State::WaitingResponse;
103 Ok(Action::ReadMessage)
104 }
105 msg_type::READY_FOR_QUERY => {
106 let ready = ReadyForQuery::parse(payload)?;
107 self.transaction_status = ready.transaction_status().unwrap_or_default();
108 self.state = State::Finished;
109 Ok(Action::Finished)
110 }
111 _ => Err(Error::LibraryBug(format!(
112 "Unexpected message in row processing: '{}'",
113 type_byte as char
114 ))),
115 }
116 }
117
118 fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
119 if buffer_set.type_byte != msg_type::READY_FOR_QUERY {
120 return Err(Error::LibraryBug(format!(
121 "Expected ReadyForQuery, got '{}'",
122 buffer_set.type_byte as char
123 )));
124 }
125
126 let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
127 self.transaction_status = ready.transaction_status().unwrap_or_default();
128 self.state = State::Finished;
129 if let Some(err) = self.pending_error.take() {
130 Ok(Action::Error(err))
131 } else {
132 Ok(Action::Finished)
133 }
134 }
135
136 fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
137 match msg.type_byte {
138 msg_type::NOTICE_RESPONSE => {
139 let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
140 Ok(Action::HandleAsyncMessageAndReadMessage(
141 AsyncMessage::Notice(notice.0),
142 ))
143 }
144 msg_type::PARAMETER_STATUS => {
145 let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
146 Ok(Action::HandleAsyncMessageAndReadMessage(
147 AsyncMessage::ParameterChanged {
148 name: param.name.to_string(),
149 value: param.value.to_string(),
150 },
151 ))
152 }
153 msg_type::NOTIFICATION_RESPONSE => {
154 let notification =
155 crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
156 Ok(Action::HandleAsyncMessageAndReadMessage(
157 AsyncMessage::Notification {
158 pid: notification.pid,
159 channel: notification.channel.to_string(),
160 payload: notification.payload.to_string(),
161 },
162 ))
163 }
164 _ => Err(Error::LibraryBug(format!(
165 "Unknown async message type: '{}'",
166 msg.type_byte as char
167 ))),
168 }
169 }
170}
171
172impl<H: SimpleHandler> StateMachine for SimpleQueryStateMachine<'_, '_, H> {
173 fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
174 if self.state == State::Initial {
176 buffer_set.write_buffer.clear();
177 write_query(&mut buffer_set.write_buffer, self.query);
178 self.state = State::WaitingResponse;
179 return Ok(Action::WriteAndReadMessage);
180 }
181
182 let type_byte = buffer_set.type_byte;
183
184 if RawMessage::is_async_type(type_byte) {
186 let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
187 return self.handle_async_message(&msg);
188 }
189
190 if type_byte == msg_type::ERROR_RESPONSE {
192 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
193 self.pending_error = Some(error.0);
194 self.state = State::WaitingReady;
195 return Ok(Action::ReadMessage);
196 }
197
198 match self.state {
199 State::WaitingResponse => self.handle_response(buffer_set),
200 State::ProcessingRows => self.handle_rows(buffer_set),
201 State::WaitingReady => self.handle_ready(buffer_set),
202 _ => Err(Error::LibraryBug(format!(
203 "Unexpected state {:?}",
204 self.state
205 ))),
206 }
207 }
208
209 fn transaction_status(&self) -> TransactionStatus {
210 self.transaction_status
211 }
212}