1use crate::conversion::ToParams;
4use crate::error::{Error, Result};
5use crate::handler::BinaryHandler;
6use crate::protocol::backend::{
7 BindComplete, CloseComplete, CommandComplete, DataRow, EmptyQueryResponse, ErrorResponse,
8 NoData, ParameterDescription, ParseComplete, PortalSuspended, RawMessage, ReadyForQuery,
9 RowDescription, msg_type,
10};
11use crate::protocol::frontend::{
12 write_bind, write_close_statement, write_describe_portal, write_describe_statement,
13 write_execute, write_parse, write_sync,
14};
15use crate::protocol::types::{Oid, TransactionStatus};
16
17use super::StateMachine;
18use super::action::{Action, AsyncMessage};
19use crate::buffer_set::BufferSet;
20
/// Phases of an `ExtendedQueryStateMachine` run.
///
/// A machine starts in `Initial`, asks the driver to flush its pre-built write
/// buffer, and then moves through the waiting states as backend messages
/// arrive until `ReadyForQuery` puts it in `Finished`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Nothing has been sent yet; the next `step` flushes the write buffer.
    Initial,
    /// Waiting for `ParseComplete`.
    WaitingParse,
    /// Waiting for `BindComplete`.
    WaitingBind,
    /// Waiting for the `ParameterDescription` of a described statement.
    WaitingDescribe,
    /// Waiting for `RowDescription` or `NoData` after `ParameterDescription`.
    WaitingRowDesc,
    /// Receiving the result: row description, data rows and completion.
    ProcessingRows,
    /// Waiting for `ReadyForQuery` (`CloseComplete` is also accepted here).
    WaitingReady,
    /// `ReadyForQuery` has been received; the exchange is over.
    Finished,
}
33
34#[derive(Debug, Clone)]
36pub struct PreparedStatement {
37 pub idx: u64,
39 pub param_oids: Vec<Oid>,
41 row_desc_payload: Option<Vec<u8>>,
43 custom_wire_name: Option<String>,
45}
46
47impl PreparedStatement {
48 pub fn new(
50 idx: u64,
51 param_oids: Vec<Oid>,
52 row_desc_payload: Option<Vec<u8>>,
53 wire_name: String,
54 ) -> Self {
55 Self {
56 idx,
57 param_oids,
58 row_desc_payload,
59 custom_wire_name: Some(wire_name),
60 }
61 }
62
63 pub fn wire_name(&self) -> String {
65 if let Some(name) = &self.custom_wire_name {
66 name.clone()
67 } else {
68 format!("_zero_{}", self.idx)
69 }
70 }
71
72 pub fn parse_columns(&self) -> Option<Result<RowDescription<'_>>> {
76 self.row_desc_payload
77 .as_ref()
78 .map(|bytes| RowDescription::parse(bytes))
79 }
80
81 pub fn row_desc_payload(&self) -> Option<&[u8]> {
85 self.row_desc_payload.as_deref()
86 }
87}
88
/// The request an `ExtendedQueryStateMachine` was built for.
///
/// It decides which state is entered once the write buffer has been flushed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Operation {
    /// Parse + Describe(statement) + Sync for a named statement.
    Prepare,
    /// Bind + Describe(portal) + Execute + Sync against an existing statement.
    Execute,
    /// Parse + Bind + Describe(portal) + Execute + Sync on the unnamed statement.
    ExecuteSql,
    /// Close(statement) + Sync.
    CloseStatement,
}
101
102pub struct ExtendedQueryStateMachine<'a, H> {
104 state: State,
105 handler: &'a mut H,
106 operation: Operation,
107 transaction_status: TransactionStatus,
108 prepared_stmt: Option<PreparedStatement>,
109}
110
111impl<'a, H: BinaryHandler> ExtendedQueryStateMachine<'a, H> {
112 pub fn take_prepared_statement(&mut self) -> Option<PreparedStatement> {
114 self.prepared_stmt.take()
115 }
116
117 pub fn prepare(
121 handler: &'a mut H,
122 buffer_set: &mut BufferSet,
123 idx: u64,
124 query: &str,
125 param_oids: &[Oid],
126 ) -> Self {
127 let stmt_name = format!("_zero_{}", idx);
128 buffer_set.write_buffer.clear();
129 write_parse(&mut buffer_set.write_buffer, &stmt_name, query, param_oids);
130 write_describe_statement(&mut buffer_set.write_buffer, &stmt_name);
131 write_sync(&mut buffer_set.write_buffer);
132
133 Self {
134 state: State::Initial,
135 handler,
136 operation: Operation::Prepare,
137 transaction_status: TransactionStatus::Idle,
138 prepared_stmt: Some(PreparedStatement {
139 idx,
140 param_oids: Vec::new(),
141 row_desc_payload: None,
142 custom_wire_name: None,
143 }),
144 }
145 }
146
147 pub fn execute<P: ToParams>(
154 handler: &'a mut H,
155 buffer_set: &mut BufferSet,
156 statement_name: &str,
157 param_oids: &[Oid],
158 params: &P,
159 ) -> Result<Self> {
160 buffer_set.write_buffer.clear();
161 write_bind(
162 &mut buffer_set.write_buffer,
163 "",
164 statement_name,
165 params,
166 param_oids,
167 )?;
168 write_describe_portal(&mut buffer_set.write_buffer, "");
169 write_execute(&mut buffer_set.write_buffer, "", 0);
170 write_sync(&mut buffer_set.write_buffer);
171
172 Ok(Self {
173 state: State::Initial,
174 handler,
175 operation: Operation::Execute,
176 transaction_status: TransactionStatus::Idle,
177 prepared_stmt: None,
178 })
179 }
180
181 pub fn execute_sql<P: ToParams>(
189 handler: &'a mut H,
190 buffer_set: &mut BufferSet,
191 sql: &str,
192 params: &P,
193 ) -> Result<Self> {
194 let param_oids = params.natural_oids();
195 buffer_set.write_buffer.clear();
196 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
197 write_bind(&mut buffer_set.write_buffer, "", "", params, ¶m_oids)?;
198 write_describe_portal(&mut buffer_set.write_buffer, "");
199 write_execute(&mut buffer_set.write_buffer, "", 0);
200 write_sync(&mut buffer_set.write_buffer);
201
202 Ok(Self {
203 state: State::Initial,
204 handler,
205 operation: Operation::ExecuteSql,
206 transaction_status: TransactionStatus::Idle,
207 prepared_stmt: None,
208 })
209 }
210
211 pub fn close_statement(handler: &'a mut H, buffer_set: &mut BufferSet, name: &str) -> Self {
215 buffer_set.write_buffer.clear();
216 write_close_statement(&mut buffer_set.write_buffer, name);
217 write_sync(&mut buffer_set.write_buffer);
218
219 Self {
220 state: State::Initial,
221 handler,
222 operation: Operation::CloseStatement,
223 transaction_status: TransactionStatus::Idle,
224 prepared_stmt: None,
225 }
226 }
227
228 fn handle_parse(&mut self, buffer_set: &BufferSet) -> Result<Action> {
229 let type_byte = buffer_set.type_byte;
230 if type_byte != msg_type::PARSE_COMPLETE {
231 return Err(Error::Protocol(format!(
232 "Expected ParseComplete, got '{}'",
233 type_byte as char
234 )));
235 }
236
237 ParseComplete::parse(&buffer_set.read_buffer)?;
238 self.state = match self.operation {
241 Operation::ExecuteSql => State::WaitingBind,
242 Operation::Prepare => State::WaitingDescribe,
243 _ => unreachable!("handle_parse called for non-parse operation"),
244 };
245 Ok(Action::ReadMessage)
246 }
247
248 fn handle_describe(&mut self, buffer_set: &BufferSet) -> Result<Action> {
249 let type_byte = buffer_set.type_byte;
250 if type_byte != msg_type::PARAMETER_DESCRIPTION {
251 return Err(Error::Protocol(format!(
252 "Expected ParameterDescription, got '{}'",
253 type_byte as char
254 )));
255 }
256
257 let param_desc = ParameterDescription::parse(&buffer_set.read_buffer)?;
258 if let Some(ref mut stmt) = self.prepared_stmt {
259 stmt.param_oids = param_desc.oids().to_vec();
260 }
261
262 self.state = State::WaitingRowDesc;
263 Ok(Action::ReadMessage)
264 }
265
266 fn handle_row_desc(&mut self, buffer_set: &BufferSet) -> Result<Action> {
267 let type_byte = buffer_set.type_byte;
268
269 match type_byte {
270 msg_type::ROW_DESCRIPTION => {
271 if let Some(ref mut stmt) = self.prepared_stmt {
272 stmt.row_desc_payload = Some(buffer_set.read_buffer.clone());
273 }
274 self.state = State::WaitingReady;
275 Ok(Action::ReadMessage)
276 }
277 msg_type::NO_DATA => {
278 let payload = &buffer_set.read_buffer;
279 NoData::parse(payload)?;
280 self.state = State::WaitingReady;
282 Ok(Action::ReadMessage)
283 }
284 _ => Err(Error::Protocol(format!(
285 "Expected RowDescription or NoData, got '{}'",
286 type_byte as char
287 ))),
288 }
289 }
290
291 fn handle_bind(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
292 let type_byte = buffer_set.type_byte;
293
294 match type_byte {
295 msg_type::BIND_COMPLETE => {
296 BindComplete::parse(&buffer_set.read_buffer)?;
297 self.state = State::ProcessingRows;
298 Ok(Action::ReadMessage)
299 }
300 msg_type::ROW_DESCRIPTION => {
301 buffer_set.column_buffer.clear();
303 buffer_set
304 .column_buffer
305 .extend_from_slice(&buffer_set.read_buffer);
306 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
307 self.handler.result_start(cols)?;
308 self.state = State::ProcessingRows;
309 Ok(Action::ReadMessage)
310 }
311 _ => Err(Error::Protocol(format!(
312 "Expected BindComplete, got '{}'",
313 type_byte as char
314 ))),
315 }
316 }
317
318 fn handle_rows(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
319 let type_byte = buffer_set.type_byte;
320 let payload = &buffer_set.read_buffer;
321
322 match type_byte {
323 msg_type::ROW_DESCRIPTION => {
324 buffer_set.column_buffer.clear();
326 buffer_set.column_buffer.extend_from_slice(payload);
327 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
328 self.handler.result_start(cols)?;
329 Ok(Action::ReadMessage)
330 }
331 msg_type::NO_DATA => {
332 NoData::parse(payload)?;
334 Ok(Action::ReadMessage)
335 }
336 msg_type::DATA_ROW => {
337 let cols = RowDescription::parse(&buffer_set.column_buffer)?;
338 let row = DataRow::parse(payload)?;
339 self.handler.row(cols, row)?;
340 Ok(Action::ReadMessage)
341 }
342 msg_type::COMMAND_COMPLETE => {
343 let complete = CommandComplete::parse(payload)?;
344 self.handler.result_end(complete)?;
345 self.state = State::WaitingReady;
346 Ok(Action::ReadMessage)
347 }
348 msg_type::EMPTY_QUERY_RESPONSE => {
349 EmptyQueryResponse::parse(payload)?;
350 self.state = State::WaitingReady;
352 Ok(Action::ReadMessage)
353 }
354 msg_type::PORTAL_SUSPENDED => {
355 PortalSuspended::parse(payload)?;
356 self.state = State::WaitingReady;
358 Ok(Action::ReadMessage)
359 }
360 msg_type::READY_FOR_QUERY => {
361 let ready = ReadyForQuery::parse(payload)?;
362 self.transaction_status = ready.transaction_status().unwrap_or_default();
363 self.state = State::Finished;
364 Ok(Action::Finished)
365 }
366 _ => Err(Error::Protocol(format!(
367 "Unexpected message in rows: '{}'",
368 type_byte as char
369 ))),
370 }
371 }
372
373 fn handle_ready(&mut self, buffer_set: &BufferSet) -> Result<Action> {
374 let type_byte = buffer_set.type_byte;
375 let payload = &buffer_set.read_buffer;
376
377 match type_byte {
378 msg_type::READY_FOR_QUERY => {
379 let ready = ReadyForQuery::parse(payload)?;
380 self.transaction_status = ready.transaction_status().unwrap_or_default();
381 self.state = State::Finished;
382 Ok(Action::Finished)
383 }
384 msg_type::CLOSE_COMPLETE => {
385 CloseComplete::parse(payload)?;
386 Ok(Action::ReadMessage)
388 }
389 _ => Err(Error::Protocol(format!(
390 "Expected ReadyForQuery, got '{}'",
391 type_byte as char
392 ))),
393 }
394 }
395
396 fn handle_async_message(&self, msg: &RawMessage<'_>) -> Result<Action> {
397 match msg.type_byte {
398 msg_type::NOTICE_RESPONSE => {
399 let notice = crate::protocol::backend::NoticeResponse::parse(msg.payload)?;
400 Ok(Action::HandleAsyncMessageAndReadMessage(
401 AsyncMessage::Notice(notice.0),
402 ))
403 }
404 msg_type::PARAMETER_STATUS => {
405 let param = crate::protocol::backend::auth::ParameterStatus::parse(msg.payload)?;
406 Ok(Action::HandleAsyncMessageAndReadMessage(
407 AsyncMessage::ParameterChanged {
408 name: param.name.to_string(),
409 value: param.value.to_string(),
410 },
411 ))
412 }
413 msg_type::NOTIFICATION_RESPONSE => {
414 let notification =
415 crate::protocol::backend::auth::NotificationResponse::parse(msg.payload)?;
416 Ok(Action::HandleAsyncMessageAndReadMessage(
417 AsyncMessage::Notification {
418 pid: notification.pid,
419 channel: notification.channel.to_string(),
420 payload: notification.payload.to_string(),
421 },
422 ))
423 }
424 _ => Err(Error::Protocol(format!(
425 "Unknown async message type: '{}'",
426 msg.type_byte as char
427 ))),
428 }
429 }
430}
431
432impl<H: BinaryHandler> StateMachine for ExtendedQueryStateMachine<'_, H> {
433 fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
434 if self.state == State::Initial {
436 self.state = match self.operation {
438 Operation::Prepare => State::WaitingParse,
439 Operation::Execute => State::WaitingBind, Operation::ExecuteSql => State::WaitingParse,
441 Operation::CloseStatement => State::WaitingReady,
442 };
443 return Ok(Action::WriteAndReadMessage);
444 }
445
446 let type_byte = buffer_set.type_byte;
447
448 if RawMessage::is_async_type(type_byte) {
450 let msg = RawMessage::new(type_byte, &buffer_set.read_buffer);
451 return self.handle_async_message(&msg);
452 }
453
454 if type_byte == msg_type::ERROR_RESPONSE {
456 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
457 self.state = State::WaitingReady;
459 return Err(error.into_error());
460 }
461
462 match self.state {
463 State::WaitingParse => self.handle_parse(buffer_set),
464 State::WaitingDescribe => self.handle_describe(buffer_set),
465 State::WaitingRowDesc => self.handle_row_desc(buffer_set),
466 State::WaitingBind => self.handle_bind(buffer_set),
467 State::ProcessingRows => self.handle_rows(buffer_set),
468 State::WaitingReady => self.handle_ready(buffer_set),
469 _ => Err(Error::Protocol(format!(
470 "Unexpected state {:?}",
471 self.state
472 ))),
473 }
474 }
475
476 fn transaction_status(&self) -> TransactionStatus {
477 self.transaction_status
478 }
479}
480
481use crate::protocol::frontend::write_flush;
485
/// Phases of a `BindStateMachine`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BindState {
    /// Nothing sent yet; the next `step` flushes the write buffer.
    Initial,
    /// Waiting for `ParseComplete` (only when a Parse was queued).
    WaitingParse,
    /// Waiting for `BindComplete`.
    WaitingBind,
    /// `BindComplete` received.
    Finished,
}

/// Binds parameters to the unnamed portal without executing it.
///
/// The queued messages end with Flush rather than Sync, so the exchange is not
/// closed by this machine. Presumably the caller sends Execute/Sync later;
/// confirm against the call site.
pub struct BindStateMachine {
    /// Current phase.
    state: BindState,
    /// Whether a Parse precedes the Bind, so `ParseComplete` is expected first.
    needs_parse: bool,
}
502
503impl BindStateMachine {
504 pub fn bind_prepared<P: ToParams>(
510 buffer_set: &mut BufferSet,
511 statement_name: &str,
512 param_oids: &[Oid],
513 params: &P,
514 ) -> Result<Self> {
515 buffer_set.write_buffer.clear();
516 write_bind(
517 &mut buffer_set.write_buffer,
518 "",
519 statement_name,
520 params,
521 param_oids,
522 )?;
523 write_flush(&mut buffer_set.write_buffer);
524
525 Ok(Self {
526 state: BindState::Initial,
527 needs_parse: false,
528 })
529 }
530
531 pub fn bind_sql<P: ToParams>(
537 buffer_set: &mut BufferSet,
538 sql: &str,
539 params: &P,
540 ) -> Result<Self> {
541 let param_oids = params.natural_oids();
542 buffer_set.write_buffer.clear();
543 write_parse(&mut buffer_set.write_buffer, "", sql, ¶m_oids);
544 write_bind(&mut buffer_set.write_buffer, "", "", params, ¶m_oids)?;
545 write_flush(&mut buffer_set.write_buffer);
546
547 Ok(Self {
548 state: BindState::Initial,
549 needs_parse: true,
550 })
551 }
552
553 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
555 if self.state == BindState::Initial {
557 self.state = if self.needs_parse {
558 BindState::WaitingParse
559 } else {
560 BindState::WaitingBind
561 };
562 return Ok(Action::WriteAndReadMessage);
563 }
564
565 let type_byte = buffer_set.type_byte;
566
567 if RawMessage::is_async_type(type_byte) {
569 return Ok(Action::ReadMessage);
570 }
571
572 if type_byte == msg_type::ERROR_RESPONSE {
574 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
575 return Err(error.into_error());
576 }
577
578 match self.state {
579 BindState::WaitingParse => {
580 if type_byte != msg_type::PARSE_COMPLETE {
581 return Err(Error::Protocol(format!(
582 "Expected ParseComplete, got '{}'",
583 type_byte as char
584 )));
585 }
586 ParseComplete::parse(&buffer_set.read_buffer)?;
587 self.state = BindState::WaitingBind;
588 Ok(Action::ReadMessage)
589 }
590 BindState::WaitingBind => {
591 if type_byte != msg_type::BIND_COMPLETE {
592 return Err(Error::Protocol(format!(
593 "Expected BindComplete, got '{}'",
594 type_byte as char
595 )));
596 }
597 BindComplete::parse(&buffer_set.read_buffer)?;
598 self.state = BindState::Finished;
599 Ok(Action::Finished)
600 }
601 _ => Err(Error::Protocol(format!(
602 "Unexpected state {:?}",
603 self.state
604 ))),
605 }
606 }
607}
608
/// Phases of a `BatchStateMachine`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum BatchState {
    /// Nothing sent yet; the next `step` flushes the write buffer.
    Initial,
    /// Waiting for the `ParseComplete` of a leading Parse.
    WaitingParse,
    /// Consuming batch replies until `ReadyForQuery`.
    Processing,
    /// `ReadyForQuery` received.
    Finished,
}
620
621pub struct BatchStateMachine {
625 state: BatchState,
626 needs_parse: bool,
627 transaction_status: TransactionStatus,
628}
629
630impl BatchStateMachine {
631 pub fn new(needs_parse: bool) -> Self {
638 Self {
639 state: BatchState::Initial,
640 needs_parse,
641 transaction_status: TransactionStatus::Idle,
642 }
643 }
644
645 pub fn transaction_status(&self) -> TransactionStatus {
647 self.transaction_status
648 }
649
650 pub fn step(&mut self, buffer_set: &mut BufferSet) -> Result<Action> {
652 if self.state == BatchState::Initial {
654 self.state = if self.needs_parse {
655 BatchState::WaitingParse
656 } else {
657 BatchState::Processing
658 };
659 return Ok(Action::WriteAndReadMessage);
660 }
661
662 let type_byte = buffer_set.type_byte;
663
664 if RawMessage::is_async_type(type_byte) {
666 return Ok(Action::ReadMessage);
667 }
668
669 if type_byte == msg_type::ERROR_RESPONSE {
671 let error = ErrorResponse::parse(&buffer_set.read_buffer)?;
672 self.state = BatchState::Processing;
673 return Err(error.into_error());
674 }
675
676 match self.state {
677 BatchState::WaitingParse => {
678 if type_byte != msg_type::PARSE_COMPLETE {
679 return Err(Error::Protocol(format!(
680 "Expected ParseComplete, got '{}'",
681 type_byte as char
682 )));
683 }
684 ParseComplete::parse(&buffer_set.read_buffer)?;
685 self.state = BatchState::Processing;
686 Ok(Action::ReadMessage)
687 }
688 BatchState::Processing => {
689 match type_byte {
690 msg_type::BIND_COMPLETE => {
691 BindComplete::parse(&buffer_set.read_buffer)?;
692 Ok(Action::ReadMessage)
693 }
694 msg_type::NO_DATA => {
695 NoData::parse(&buffer_set.read_buffer)?;
696 Ok(Action::ReadMessage)
697 }
698 msg_type::ROW_DESCRIPTION => {
699 RowDescription::parse(&buffer_set.read_buffer)?;
701 Ok(Action::ReadMessage)
702 }
703 msg_type::DATA_ROW => {
704 Ok(Action::ReadMessage)
706 }
707 msg_type::COMMAND_COMPLETE => {
708 CommandComplete::parse(&buffer_set.read_buffer)?;
709 Ok(Action::ReadMessage)
710 }
711 msg_type::EMPTY_QUERY_RESPONSE => {
712 EmptyQueryResponse::parse(&buffer_set.read_buffer)?;
713 Ok(Action::ReadMessage)
714 }
715 msg_type::READY_FOR_QUERY => {
716 let ready = ReadyForQuery::parse(&buffer_set.read_buffer)?;
717 self.transaction_status = ready.transaction_status().unwrap_or_default();
718 self.state = BatchState::Finished;
719 Ok(Action::Finished)
720 }
721 _ => Err(Error::Protocol(format!(
722 "Unexpected message in batch: '{}'",
723 type_byte as char
724 ))),
725 }
726 }
727 _ => Err(Error::Protocol(format!(
728 "Unexpected state {:?}",
729 self.state
730 ))),
731 }
732 }
733}