1use crate::conversion::ToParams;
4use crate::error::{Error, Result};
5use crate::handler::ExtendedHandler;
6use crate::protocol::backend::{
7 BindComplete, CloseComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse,
8 NoData, ParameterDescription, ParseComplete, PortalSuspended, RawMessage, ReadyForQuery,
9 RowDescription, msg_type,
10};
11use crate::protocol::frontend::{
12 write_bind, write_close_statement, write_describe_portal, write_describe_statement,
13 write_execute, write_parse, write_sync,
14};
15use crate::protocol::types::{Oid, TransactionStatus};
16
17use super::StateMachine;
18use super::action::{Action, AsyncMessage};
19use crate::buffer_set::BufferSet;
20
/// Position of an [`ExtendedQueryStateMachine`] in the server's reply sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Request written but not yet handed to the driver.
    Initial,
    /// Expecting ParseComplete.
    WaitingParse,
    /// Expecting BindComplete.
    WaitingBind,
    /// Expecting ParameterDescription (prepare only).
    WaitingDescribe,
    /// Expecting RowDescription or NoData (prepare only).
    WaitingRowDesc,
    /// Consuming the portal's description, rows and completion.
    ProcessingRows,
    /// Draining replies until ReadyForQuery.
    WaitingReady,
    /// ReadyForQuery has been received.
    Finished,
}
33
34#[derive(Debug, Clone)]
36pub struct PreparedStatement {
37 pub idx: u64,
39 pub param_oids: Vec<Oid>,
41 pub(crate) row_desc_payload: Option<Vec<u8>>,
43}
44
45impl PreparedStatement {
46 pub fn wire_name(&self) -> String {
48 format!("_zero_s_{}", self.idx)
49 }
50
51 pub fn parse_columns(&self) -> Option<Result<RowDescription<'_>>> {
55 self.row_desc_payload
56 .as_ref()
57 .map(|bytes| RowDescription::parse(bytes))
58 }
59
60 pub fn row_desc_payload(&self) -> Option<&[u8]> {
64 self.row_desc_payload.as_deref()
65 }
66}
67
/// Which request an [`ExtendedQueryStateMachine`] was constructed for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Operation {
    /// Parse + Describe(statement) + Sync.
    Prepare,
    /// Bind + Describe(portal) + Execute + Sync on an existing statement.
    Execute,
    /// Parse + Bind + Describe(portal) + Execute + Sync on the unnamed statement.
    ExecuteSql,
    /// Close(statement) + Sync.
    CloseStatement,
}
80
81pub struct ExtendedQueryStateMachine<'a, H> {
83 state: State,
84 handler: &'a mut H,
85 operation: Operation,
86 transaction_status: TransactionStatus,
87 prepared_stmt: Option<PreparedStatement>,
88 pending_error: Option<crate::error::ServerError>,
89}
90
91impl<'a, H: ExtendedHandler> ExtendedQueryStateMachine<'a, H> {
92 pub fn take_prepared_statement(&mut self) -> Option<PreparedStatement> {
94 self.prepared_stmt.take()
95 }
96
97 pub fn prepare(
101 handler: &'a mut H,
102 buffer_set: &mut BufferSet,
103 idx: u64,
104 query: &str,
105 param_oids: &[Oid],
106 ) -> Self {
107 let stmt_name = format!("_zero_s_{}", idx);
108 buffer_set.write_buffer.clear();
109 write_parse(&mut buffer_set.write_buffer, &stmt_name, query, param_oids);
110 write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
111 write_sync(&mut buffer_set.write_buffer);
112
113 Self {
114 state: State::Initial,
115 handler,
116 operation: Operation::Prepare,
117 transaction_status: TransactionStatus::Idle,
118 prepared_stmt: Some(PreparedStatement {
119 idx,
120 param_oids: Vec::new(),
121 row_desc_payload: None,
122 }),
123 pending_error: None,
124 }
125 }
126
127 pub fn execute<P: ToParams>(
134 handler: &'a mut H,
135 buffer_set: &mut BufferSet,
136 statement_name: &str,
137 param_oids: &[Oid],
138 params: &P,
139 ) -> Result<Self> {
140 buffer_set.write_buffer.clear();
141 write_bind(
142 &mut buffer_set.write_buffer,
143 "",
144 statement_name,
145 params,
146 param_oids,
147 )?;
148 write_describe_portal(&mut buffer_set.write_buffer, "");
149 write_execute(&mut buffer_set.write_buffer, "", 0);
150 write_sync(&mut buffer_set.write_buffer);
151
152 Ok(Self {
153 state: State::Initial,
154 handler,
155 operation: Operation::Execute,
156 transaction_status: TransactionStatus::Idle,
157 prepared_stmt: None,
158 pending_error: None,
159 })
160 }
161
162 pub fn execute_sql<P: ToParams>(
170 handler: &'a mut H,
171 buffer_set: &mut BufferSet,
172 sql: &str,
173 params: &P,
174 ) -> Result<Self> {
175 let param_oids = params.natural_oids();
176 buffer_set.write_buffer.clear();
177 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
178 write_bind(&mut buffer_set.write_buffer, "", "", params, ¶m_oids)?;
179 write_describe_portal(&mut buffer_set.write_buffer, "");
180 write_execute(&mut buffer_set.write_buffer, "", 0);
181 write_sync(&mut buffer_set.write_buffer);
182
183 Ok(Self {
184 state: State::Initial,
185 handler,
186 operation: Operation::ExecuteSql,
187 transaction_status: TransactionStatus::Idle,
188 prepared_stmt: None,
189 pending_error: None,
190 })
191 }
192
193 pub fn close_statement(handler: &'a mut H, buffer_set: &mut BufferSet, name: &str) -> Self {
197 buffer_set.write_buffer.clear();
198 write_close_statement(&mut buffer_set.write_buffer, name);
199 write_sync(&mut buffer_set.write_buffer);
200
201 Self {
202 state: State::Initial,
203 handler,
204 operation: Operation::CloseStatement,
205 transaction_status: TransactionStatus::Idle,
206 prepared_stmt: None,
207 pending_error: None,
208 }
209 }
210
211 fn handle_parse(&mut self, buffer_set: &BufferSet) -> Result<Action> {
212 let type_byte = buffer_set.type_byte;
213 if type_byte != msg_type::PARSE_COMPLETE {
214 return Err(Error::LibraryBug(format!(
215 "Expected ParseComplete, got '{}'",
216 type_byte as char
217 )));
218 }
219
220 ParseComplete::parse(&buffer_set.read_buffer)?;
221 self.state = match self.operation {
224 Operation::ExecuteSql => State::WaitingBind,
225 Operation::Prepare => State::WaitingDescribe,
226 _ => {
227 return Err(Error::LibraryBug(
228 "handle_parse called for non-parse operation".into(),
229 ));
230 }
231 };
232 Ok(Action::ReadMessage)
233 }
234
235 fn handle_describe(&mut self, buffer_set: &BufferSet) -> Result<Action> {
236 let type_byte = buffer_set.type_byte;
237 if type_byte != msg_type::PARAMETER_DESCRIPTION {
238 return Err(Error::LibraryBug(format!(
239 "Expected ParameterDescription, got '{}'",
240 type_byte as char
241 )));
242 }
243
244 let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
245 if let Some(stmt) = &mut self.prepared_stmt {
246 stmt.param_oids = param_desc.oids().to_vec();
247 }
248
249 self.state = State::WaitingRowDesc;
250 Ok(Action::ReadMessage)
251 }
252
253 fn handle_row_desc(&mut self, buffer_set: &BufferSet) -> Result<Action> {
254 let type_byte = buffer_set.type_byte;
255
256 match type_byte {
257 msg_type::ROW_DESCRIPTION => {
258 if let Some(stmt) = &mut self.prepared_stmt {
259 stmt.row_desc_payload = Some(buffer_set.read_buffer.clone());
260 }
261 self.state = State::WaitingReady;
262 Ok(Action::ReadMessage)
263 }
264 msg_type::NO_DATA => {
265 let payload = &buffer_set.read_buffer;
266 NoData::parse(payload)?;
267 self.state = State::WaitingReady;
269 Ok(Action::ReadMessage)
270 }
271 _ => Err(Error::LibraryBug(format!(
272 "Expected RowDescription or NoData, got '{}'",
273 type_byte as char
274 ))),
275 }
276 }
277
278 fn handle_bind(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
279 let type_byte = buffer_set.type_byte;
280
281 match type_byte {
282 msg_type::BIND_COMPLETE => {
283 BindComplete::parse(&buffer_set.read_buffer)?;
284 self.state = State::ProcessingRows;
285 Ok(Action::ReadMessage)
286 }
287 msg_type::ROW_DESCRIPTION => {
288 buffer_set.column_buffer.clear();
290 buffer_set
291 .column_buffer
292 .extend_from_slice(&buffer_set.read_buffer);
293 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
294 self.handler.result_start(cols)?;
295 self.state = State::ProcessingRows;
296 Ok(Action::ReadMessage)
297 }
298 _ => Err(Error::LibraryBug(format!(
299 "Expected BindComplete, got '{}'",
300 type_byte as char
301 ))),
302 }
303 }
304
305 fn handle_rows(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
306 let type_byte = buffer_set.type_byte;
307 let payload = &buffer_set.read_buffer;
308
309 match type_byte {
310 msg_type::ROW_DESCRIPTION => {
311 buffer_set.column_buffer.clear();
313 buffer_set.column_buffer.extend_from_slice(payload);
314 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
315 self.handler.result_start(cols)?;
316 Ok(Action::ReadMessage)
317 }
318 msg_type::NO_DATA => {
319 NoData::parse(payload)?;
321 Ok(Action::ReadMessage)
322 }
323 msg_type::DATA_ROW => {
324 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
325 let row = DataRow::parse(payload)?;
326 self.handler.row(cols, row)?;
327 Ok(Action::ReadMessage)
328 }
329 msg_type::COMMAND_COMPLETE => {
330 let complete = CommandComplete::parse(payload)?;
331 self.handler.result_end(complete)?;
332 self.state = State::WaitingReady;
333 Ok(Action::ReadMessage)
334 }
335 msg_type::EMPTY_QUERY_RESPONSE => {
336 EmptyQueryResponse::parse(payload)?;
337 self.state = State::WaitingReady;
339 Ok(Action::ReadMessage)
340 }
341 msg_type::PORTAL_SUSPENDED => {
342 PortalSuspended::parse(payload)?;
343 self.state = State::WaitingReady;
345 Ok(Action::ReadMessage)
346 }
347 msg_type::READY_FOR_QUERY => {
348 let ready = ReadyForQuery::parse(payload)?;
349 self.transaction_status = ready.transaction_status().unwrap_or_default();
350 self.state = State::Finished;
351 Ok(Action::Finished)
352 }
353 _ => Err(Error::LibraryBug(format!(
354 "Unexpected message in rows: '{}'",
355 type_byte as char
356 ))),
357 }
358 }
359
360 fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
361 let type_byte = buffer_set.type_byte;
362 let payload = &buffer_set.read_buffer;
363
364 match type_byte {
365 msg_type::READY_FOR_QUERY => {
366 let ready = ReadyForQuery::parse(payload)?;
367 self.transaction_status = ready.transaction_status().unwrap_or_default();
368 self.state = State::Finished;
369 if let Some(err) = self.pending_error.take() {
370 Ok(Action::Error(err))
371 } else {
372 Ok(Action::Finished)
373 }
374 }
375 msg_type::CLOSE_COMPLETE => {
376 CloseComplete::parse(payload)?;
377 Ok(Action::ReadMessage)
379 }
380 _ => Err(Error::LibraryBug(format!(
381 "Expected ReadyForQuery, got '{}'",
382 type_byte as char
383 ))),
384 }
385 }
386
387 fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
388 match msg.type_byte {
389 msg_type::NOTICE_RESPONSE => {
390 let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
391 Ok(Action::HandleAsyncMessageAndReadMessage(
392 AsyncMessage::Notice(notice.0),
393 ))
394 }
395 msg_type::PARAMETER_STATUS => {
396 let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
397 Ok(Action::HandleAsyncMessageAndReadMessage(
398 AsyncMessage::ParameterChanged {
399 name: param.name.to_string(),
400 value: param.value.to_string(),
401 },
402 ))
403 }
404 msg_type::NOTIFICATION_RESPONSE => {
405 let notification =
406 crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
407 Ok(Action::HandleAsyncMessageAndReadMessage(
408 AsyncMessage::Notification {
409 pid: notification.pid,
410 channel: notification.channel.to_string(),
411 payload: notification.payload.to_string(),
412 },
413 ))
414 }
415 _ => Err(Error::LibraryBug(format!(
416 "Unknown async message type: '{}'",
417 msg.type_byte as char
418 ))),
419 }
420 }
421}
422
423impl<H: ExtendedHandler> StateMachine for ExtendedQueryStateMachine<'_, H> {
424 fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
425 if self.state == State::Initial {
427 self.state = match self.operation {
429 Operation::Prepare => State::WaitingParse,
430 Operation::Execute => State::WaitingBind, Operation::ExecuteSql => State::WaitingParse,
432 Operation::CloseStatement => State::WaitingReady,
433 };
434 return Ok(Action::WriteAndReadMessage);
435 }
436
437 let type_byte = buffer_set.type_byte;
438
439 if RawMessage::is_async_type(type_byte) {
441 let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
442 return self.handle_async_message(&msg);
443 }
444 if type_byte == msg_type::ERROR_RESPONSE {
446 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
447 self.pending_error = Some(error.0);
448 self.state = State::WaitingReady;
449 return Ok(Action::ReadMessage);
450 }
451
452 match self.state {
453 State::WaitingParse => self.handle_parse(buffer_set),
454 State::WaitingDescribe => self.handle_describe(buffer_set),
455 State::WaitingRowDesc => self.handle_row_desc(buffer_set),
456 State::WaitingBind => self.handle_bind(buffer_set),
457 State::ProcessingRows => self.handle_rows(buffer_set),
458 State::WaitingReady => self.handle_ready(buffer_set),
459 _ => Err(Error::LibraryBug(format!(
460 "Unexpected state {:?}",
461 self.state
462 ))),
463 }
464 }
465
466 fn transaction_status(&self) -> TransactionStatus {
467 self.transaction_status
468 }
469}
470
471use crate::protocol::frontend::write_flush;
475
/// Position of a [`BindStateMachine`] in its reply sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BindState {
    /// Request written but not yet handed to the driver.
    Initial,
    /// Expecting ParseComplete (SQL binds only).
    WaitingParse,
    /// Expecting BindComplete.
    WaitingBind,
    /// BindComplete has been received.
    Finished,
}
484
485pub struct BindStateMachine {
489 state: BindState,
490 needs_parse: bool,
491}
492
493impl BindStateMachine {
494 pub fn bind_prepared<P: ToParams>(
503 buffer_set: &mut BufferSet,
504 portal_name: &str,
505 statement_name: &str,
506 param_oids: &[Oid],
507 params: &P,
508 ) -> Result<Self> {
509 buffer_set.write_buffer.clear();
510 write_bind(
511 &mut buffer_set.write_buffer,
512 portal_name,
513 statement_name,
514 params,
515 param_oids,
516 )?;
517 write_flush(&mut buffer_set.write_buffer);
518
519 Ok(Self {
520 state: BindState::Initial,
521 needs_parse: false,
522 })
523 }
524
525 pub fn bind_sql<P: ToParams>(
534 buffer_set: &mut BufferSet,
535 portal_name: &str,
536 sql: &str,
537 params: &P,
538 ) -> Result<Self> {
539 let param_oids = params.natural_oids();
540 buffer_set.write_buffer.clear();
541 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
542 write_bind(
543 &mut buffer_set.write_buffer,
544 portal_name,
545 "",
546 params,
547 ¶m_oids,
548 )?;
549 write_flush(&mut buffer_set.write_buffer);
550
551 Ok(Self {
552 state: BindState::Initial,
553 needs_parse: true,
554 })
555 }
556
557 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
559 if self.state == BindState::Initial {
561 self.state = if self.needs_parse {
562 BindState::WaitingParse
563 } else {
564 BindState::WaitingBind
565 };
566 return Ok(Action::WriteAndReadMessage);
567 }
568
569 let type_byte = buffer_set.type_byte;
570
571 if RawMessage::is_async_type(type_byte) {
573 return Ok(Action::ReadMessage);
574 }
575
576 if type_byte == msg_type::ERROR_RESPONSE {
578 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
579 return Err(error.into_error());
580 }
581
582 match self.state {
583 BindState::WaitingParse => {
584 if type_byte != msg_type::PARSE_COMPLETE {
585 return Err(Error::LibraryBug(format!(
586 "Expected ParseComplete, got '{}'",
587 type_byte as char
588 )));
589 }
590 ParseComplete::parse(&buffer_set.read_buffer)?;
591 self.state = BindState::WaitingBind;
592 Ok(Action::ReadMessage)
593 }
594 BindState::WaitingBind => {
595 if type_byte != msg_type::BIND_COMPLETE {
596 return Err(Error::LibraryBug(format!(
597 "Expected BindComplete, got '{}'",
598 type_byte as char
599 )));
600 }
601 BindComplete::parse(&buffer_set.read_buffer)?;
602 self.state = BindState::Finished;
603 Ok(Action::Finished)
604 }
605 _ => Err(Error::LibraryBug(format!(
606 "Unexpected state {:?}",
607 self.state
608 ))),
609 }
610 }
611}
612
/// Position of a [`BatchStateMachine`] in its reply sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchState {
    /// Batch written but not yet handed to the driver.
    Initial,
    /// Expecting the ParseComplete for the batch's leading Parse.
    WaitingParse,
    /// Consuming replies until ReadyForQuery.
    Processing,
    /// ReadyForQuery has been received.
    Finished,
}
624
625pub struct BatchStateMachine {
629 state: BatchState,
630 needs_parse: bool,
631 transaction_status: TransactionStatus,
632 pending_error: Option<crate::error::ServerError>,
633}
634
635impl BatchStateMachine {
636 pub fn new(needs_parse: bool) -> Self {
643 Self {
644 state: BatchState::Initial,
645 needs_parse,
646 transaction_status: TransactionStatus::Idle,
647 pending_error: None,
648 }
649 }
650
651 pub fn transaction_status(&self) -> TransactionStatus {
653 self.transaction_status
654 }
655
656 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
658 if self.state == BatchState::Initial {
660 self.state = if self.needs_parse {
661 BatchState::WaitingParse
662 } else {
663 BatchState::Processing
664 };
665 return Ok(Action::WriteAndReadMessage);
666 }
667
668 let type_byte = buffer_set.type_byte;
669
670 if RawMessage::is_async_type(type_byte) {
672 return Ok(Action::ReadMessage);
673 }
674 if type_byte == msg_type::ERROR_RESPONSE {
676 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
677 self.pending_error = Some(error.0);
678 self.state = BatchState::Processing;
679 return Ok(Action::ReadMessage);
680 }
681
682 match self.state {
683 BatchState::WaitingParse => {
684 if type_byte != msg_type::PARSE_COMPLETE {
685 return Err(Error::LibraryBug(format!(
686 "Expected ParseComplete, got '{}'",
687 type_byte as char
688 )));
689 }
690 ParseComplete::parse(&buffer_set.read_buffer)?;
691 self.state = BatchState::Processing;
692 Ok(Action::ReadMessage)
693 }
694 BatchState::Processing => {
695 match type_byte {
696 msg_type::BIND_COMPLETE => {
697 BindComplete::parse(&buffer_set.read_buffer)?;
698 Ok(Action::ReadMessage)
699 }
700 msg_type::NO_DATA => {
701 NoData::parse(&buffer_set.read_buffer)?;
702 Ok(Action::ReadMessage)
703 }
704 msg_type::ROW_DESCRIPTION => {
705 RowDescription::parse(&buffer_set.read_buffer)?;
707 Ok(Action::ReadMessage)
708 }
709 msg_type::DATA_ROW => {
710 Ok(Action::ReadMessage)
712 }
713 msg_type::COMMAND_COMPLETE => {
714 CommandComplete::parse(&buffer_set.read_buffer)?;
715 Ok(Action::ReadMessage)
716 }
717 msg_type::EMPTY_QUERY_RESPONSE => {
718 EmptyQueryResponse::parse(&buffer_set.read_buffer)?;
719 Ok(Action::ReadMessage)
720 }
721 msg_type::READY_FOR_QUERY => {
722 let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
723 self.transaction_status = ready.transaction_status().unwrap_or_default();
724 self.state = BatchState::Finished;
725 if let Some(err) = self.pending_error.take() {
726 Ok(Action::Error(err))
727 } else {
728 Ok(Action::Finished)
729 }
730 }
731 _ => Err(Error::LibraryBug(format!(
732 "Unexpected message in batch: '{}'",
733 type_byte as char
734 ))),
735 }
736 }
737 _ => Err(Error::LibraryBug(format!(
738 "Unexpected state {:?}",
739 self.state
740 ))),
741 }
742 }
743}