zero_postgres/state/
simple_query.rs

1//! Simple query protocol state machine.
2
3use crate::buffer_set::BufferSet;
4use crate::error::{Error, Result};
5use crate::handler::SimpleHandler;
6use crate::protocol::backend::{
7    CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse, RawMessage, ReadyForQuery,
8    RowDescription, msg_type,
9};
10use crate::protocol::frontend::write_query;
11use crate::protocol::types::TransactionStatus;
12
13use super::StateMachine;
14use super::action::{Action, AsyncMessage};
15
/// Position of a simple-query exchange within its lifecycle.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum State {
    /// Nothing has been sent yet; the next `step` writes the query.
    Initial,
    /// Query sent; expecting the next result (`RowDescription`,
    /// `CommandComplete`, `EmptyQueryResponse`) or `ReadyForQuery`.
    WaitingResponse,
    /// A `RowDescription` was seen; `DataRow` messages are being streamed.
    ProcessingRows,
    /// Only `ReadyForQuery` is acceptable (entered after an `ErrorResponse`
    /// or an `EmptyQueryResponse`).
    WaitingReady,
    /// `ReadyForQuery` has been consumed; the exchange is over.
    Finished,
}
25
26/// Simple query protocol state machine.
27pub struct SimpleQueryStateMachine<'a, 'q, H> {
28    state: State,
29    handler: &'a mut H,
30    query: &'q str,
31    transaction_status: TransactionStatus,
32}
33
34impl<'a, 'q, H: SimpleHandler> SimpleQueryStateMachine<'a, 'q, H> {
35    /// Create a new simple query state machine.
36    pub fn new(handler: &'a mut H, query: &'q str) -> Self {
37        Self {
38            state: State::Initial,
39            handler,
40            query,
41            transaction_status: TransactionStatus::Idle,
42        }
43    }
44
45    fn handle_response(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
46        let type_byte = buffer_set.type_byte;
47        let payload = &buffer_set.read_buffer;
48
49        match type_byte {
50            msg_type::ROW_DESCRIPTION => {
51                // Store column buffer for later use in row callbacks
52                buffer_set.column_buffer.clear();
53                buffer_set.column_buffer.extend_from_slice(payload);
54                let cols = RowDescription::parse(&buffer_set.column_buffer)?;
55                self.handler.result_start(cols)?;
56                self.state = State::ProcessingRows;
57                Ok(Action::ReadMessage)
58            }
59            msg_type::COMMAND_COMPLETE => {
60                let complete = CommandComplete::parse(payload)?;
61                self.handler.result_end(complete)?;
62                // More commands may follow in a multi-statement query
63                self.state = State::WaitingResponse;
64                Ok(Action::ReadMessage)
65            }
66            msg_type::EMPTY_QUERY_RESPONSE => {
67                EmptyQueryResponse::parse(payload)?;
68                // Empty query string - silently ignore
69                self.state = State::WaitingReady;
70                Ok(Action::ReadMessage)
71            }
72            msg_type::READY_FOR_QUERY => {
73                let ready = ReadyForQuery::parse(payload)?;
74                self.transaction_status = ready.transaction_status().unwrap_or_default();
75                self.state = State::Finished;
76                Ok(Action::Finished)
77            }
78            _ => Err(Error::Protocol(format!(
79                "Unexpected message in query response: '{}'",
80                type_byte as char
81            ))),
82        }
83    }
84
85    fn handle_rows(&mut self, buffer_set: &BufferSet) -> Result<Action> {
86        let type_byte = buffer_set.type_byte;
87        let payload = &buffer_set.read_buffer;
88
89        match type_byte {
90            msg_type::DATA_ROW => {
91                let cols = RowDescription::parse(&buffer_set.column_buffer)?;
92                let row = DataRow::parse(payload)?;
93                self.handler.row(cols, row)?;
94                Ok(Action::ReadMessage)
95            }
96            msg_type::COMMAND_COMPLETE => {
97                let complete = CommandComplete::parse(payload)?;
98                self.handler.result_end(complete)?;
99                // More commands may follow
100                self.state = State::WaitingResponse;
101                Ok(Action::ReadMessage)
102            }
103            msg_type::READY_FOR_QUERY => {
104                let ready = ReadyForQuery::parse(payload)?;
105                self.transaction_status = ready.transaction_status().unwrap_or_default();
106                self.state = State::Finished;
107                Ok(Action::Finished)
108            }
109            _ => Err(Error::Protocol(format!(
110                "Unexpected message in row processing: '{}'",
111                type_byte as char
112            ))),
113        }
114    }
115
116    fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
117        if buffer_set.type_byte != msg_type::READY_FOR_QUERY {
118            return Err(Error::Protocol(format!(
119                "Expected ReadyForQuery, got '{}'",
120                buffer_set.type_byte as char
121            )));
122        }
123
124        let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
125        self.transaction_status = ready.transaction_status().unwrap_or_default();
126        self.state = State::Finished;
127        Ok(Action::Finished)
128    }
129
130    fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
131        match msg.type_byte {
132            msg_type::NOTICE_RESPONSE => {
133                let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
134                Ok(Action::HandleAsyncMessageAndReadMessage(
135                    AsyncMessage::Notice(notice.0),
136                ))
137            }
138            msg_type::PARAMETER_STATUS => {
139                let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
140                Ok(Action::HandleAsyncMessageAndReadMessage(
141                    AsyncMessage::ParameterChanged {
142                        name: param.name.to_string(),
143                        value: param.value.to_string(),
144                    },
145                ))
146            }
147            msg_type::NOTIFICATION_RESPONSE => {
148                let notification =
149                    crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
150                Ok(Action::HandleAsyncMessageAndReadMessage(
151                    AsyncMessage::Notification {
152                        pid: notification.pid,
153                        channel: notification.channel.to_string(),
154                        payload: notification.payload.to_string(),
155                    },
156                ))
157            }
158            _ => Err(Error::Protocol(format!(
159                "Unknown async message type: '{}'",
160                msg.type_byte as char
161            ))),
162        }
163    }
164}
165
166impl<H: SimpleHandler> StateMachine for SimpleQueryStateMachine<'_, '_, H> {
167    fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
168        // Initial state: write query to buffer and request send
169        if self.state == State::Initial {
170            buffer_set.write_buffer.clear();
171            write_query(&mut buffer_set.write_buffer, self.query);
172            self.state = State::WaitingResponse;
173            return Ok(Action::WriteAndReadMessage);
174        }
175
176        let type_byte = buffer_set.type_byte;
177
178        // Handle async messages
179        if RawMessage::is_async_type(type_byte) {
180            let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
181            return self.handle_async_message(&msg);
182        }
183
184        // Handle error response
185        if type_byte == msg_type::ERROR_RESPONSE {
186            let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
187            // After error, we still need to wait for ReadyForQuery
188            self.state = State::WaitingReady;
189            return Err(error.into_error());
190        }
191
192        match self.state {
193            State::WaitingResponse => self.handle_response(buffer_set),
194            State::ProcessingRows => self.handle_rows(buffer_set),
195            State::WaitingReady => self.handle_ready(buffer_set),
196            _ => Err(Error::Protocol(format!(
197                "Unexpected state {:?}",
198                self.state
199            ))),
200        }
201    }
202
203    fn transaction_status(&self) -> TransactionStatus {
204        self.transaction_status
205    }
206}