1use tokio::net::TcpStream;
4use tokio::net::UnixStream;
5
6use crate::buffer_pool::PooledBufferSet;
7use crate::conversion::ToParams;
8use crate::error::{Error, Result};
9use crate::handler::{
10 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
11};
12use crate::opts::Opts;
13use crate::protocol::backend::BackendKeyData;
14use crate::protocol::frontend::write_terminate;
15use crate::protocol::types::TransactionStatus;
16use crate::state::StateMachine;
17use crate::state::action::Action;
18use crate::state::connection::ConnectionStateMachine;
19use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
20use crate::state::simple_query::SimpleQueryStateMachine;
21use crate::statement::IntoStatement;
22
23use super::stream::Stream;
24
/// A single client connection together with the state tracked alongside it.
pub struct Conn {
    /// Underlying transport (built via `Stream::unix` or `Stream::tcp` in this file).
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data captured from the startup state machine, if any.
    backend_key: Option<BackendKeyData>,
    /// Server parameters taken from the startup state machine
    /// (presumably `(name, value)` pairs; confirm against the state machine).
    server_params: Vec<(String, String)>,
    /// Transaction status most recently reported by a state machine or a ReadyForQuery message.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error for which `is_connection_broken()` is true.
    pub(crate) is_broken: bool,
    /// Counter used to derive portal names and prepared-statement indices.
    name_counter: u64,
    /// Optional callback invoked for asynchronous messages while driving a state machine.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
36
37impl Conn {
38 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
40 where
41 Error: From<O::Error>,
42 {
43 let opts = opts.try_into()?;
44
45 let stream = if let Some(socket_path) = &opts.socket {
46 Stream::unix(UnixStream::connect(socket_path).await?)
47 } else {
48 if opts.host.is_empty() {
49 return Err(Error::InvalidUsage("host is empty".into()));
50 }
51 let addr = format!("{}:{}", opts.host, opts.port);
52 let tcp = TcpStream::connect(&addr).await?;
53 tcp.set_nodelay(true)?;
54 Stream::tcp(tcp)
55 };
56
57 Self::new_with_stream(stream, opts).await
58 }
59
60 #[allow(unused_mut)]
62 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
63 let mut buffer_set = options.buffer_pool.get_buffer_set();
64 let mut state_machine = ConnectionStateMachine::new(options.clone());
65
66 loop {
68 match state_machine.step(&mut buffer_set)? {
69 Action::WriteAndReadByte => {
70 stream.write_all(&buffer_set.write_buffer).await?;
71 stream.flush().await?;
72 let byte = stream.read_u8().await?;
73 state_machine.set_ssl_response(byte);
74 }
75 Action::ReadMessage => {
76 stream.read_message(&mut buffer_set).await?;
77 }
78 Action::Write => {
79 stream.write_all(&buffer_set.write_buffer).await?;
80 stream.flush().await?;
81 }
82 Action::WriteAndReadMessage => {
83 stream.write_all(&buffer_set.write_buffer).await?;
84 stream.flush().await?;
85 stream.read_message(&mut buffer_set).await?;
86 }
87 Action::TlsHandshake => {
88 #[cfg(feature = "tokio-tls")]
89 {
90 stream = stream.upgrade_to_tls(&options.host).await?;
91 }
92 #[cfg(not(feature = "tokio-tls"))]
93 {
94 return Err(Error::Unsupported(
95 "TLS requested but tokio-tls feature not enabled".into(),
96 ));
97 }
98 }
99 Action::HandleAsyncMessageAndReadMessage(_) => {
100 stream.read_message(&mut buffer_set).await?;
102 }
103 Action::Finished => break,
104 }
105 }
106
107 let conn = Self {
108 stream,
109 buffer_set,
110 backend_key: state_machine.backend_key().cloned(),
111 server_params: state_machine.take_server_params(),
112 transaction_status: state_machine.transaction_status(),
113 is_broken: false,
114 name_counter: 0,
115 async_message_handler: None,
116 };
117
118 let conn = if options.prefer_unix_socket && conn.stream.is_tcp_loopback() {
120 conn.try_upgrade_to_unix_socket(&options).await
121 } else {
122 conn
123 };
124
125 Ok(conn)
126 }
127
128 fn try_upgrade_to_unix_socket(
131 mut self,
132 opts: &Opts,
133 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
134 let opts = opts.clone();
135 Box::pin(async move {
136 let mut handler = FirstRowHandler::<(String,)>::new();
138 if self
139 .query("SHOW unix_socket_directories", &mut handler)
140 .await
141 .is_err()
142 {
143 return self;
144 }
145
146 let socket_dir = match handler.into_row() {
147 Some((dirs,)) => {
148 match dirs.split(',').next() {
150 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
151 _ => return self,
152 }
153 }
154 None => return self,
155 };
156
157 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
159
160 let unix_stream = match UnixStream::connect(&socket_path).await {
162 Ok(s) => s,
163 Err(_) => return self,
164 };
165
166 let mut opts_unix = opts.clone();
168 opts_unix.prefer_unix_socket = false;
169
170 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
171 Ok(new_conn) => new_conn,
172 Err(_) => self,
173 }
174 })
175 }
176
177 pub fn backend_key(&self) -> Option<&BackendKeyData> {
179 self.backend_key.as_ref()
180 }
181
182 pub fn connection_id(&self) -> u32 {
186 self.backend_key.as_ref().map_or(0, |k| k.process_id())
187 }
188
189 pub fn server_params(&self) -> &[(String, String)] {
191 &self.server_params
192 }
193
194 pub fn transaction_status(&self) -> TransactionStatus {
196 self.transaction_status
197 }
198
199 pub fn in_transaction(&self) -> bool {
201 self.transaction_status.in_transaction()
202 }
203
204 pub fn is_broken(&self) -> bool {
206 self.is_broken
207 }
208
209 pub(crate) fn next_portal_name(&mut self) -> String {
211 self.name_counter += 1;
212 format!("_zero_p_{}", self.name_counter)
213 }
214
215 pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
219 &mut self,
220 portal_name: &str,
221 statement: &S,
222 params: &P,
223 ) -> Result<()> {
224 let mut state_machine = if let Some(sql) = statement.as_sql() {
226 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
227 } else {
228 let stmt = statement.as_prepared().unwrap();
229 BindStateMachine::bind_prepared(
230 &mut self.buffer_set,
231 portal_name,
232 &stmt.wire_name(),
233 &stmt.param_oids,
234 params,
235 )?
236 };
237
238 loop {
240 match state_machine.step(&mut self.buffer_set)? {
241 Action::ReadMessage => {
242 self.stream.read_message(&mut self.buffer_set).await?;
243 }
244 Action::Write => {
245 self.stream.write_all(&self.buffer_set.write_buffer).await?;
246 self.stream.flush().await?;
247 }
248 Action::WriteAndReadMessage => {
249 self.stream.write_all(&self.buffer_set.write_buffer).await?;
250 self.stream.flush().await?;
251 self.stream.read_message(&mut self.buffer_set).await?;
252 }
253 Action::Finished => break,
254 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
255 }
256 }
257
258 Ok(())
259 }
260
261 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
268 self.async_message_handler = Some(Box::new(handler));
269 }
270
271 pub fn clear_async_message_handler(&mut self) {
273 self.async_message_handler = None;
274 }
275
276 pub async fn ping(&mut self) -> Result<()> {
278 self.query_drop("").await?;
279 Ok(())
280 }
281
282 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
284 loop {
285 match state_machine.step(&mut self.buffer_set)? {
286 Action::WriteAndReadByte => {
287 return Err(Error::Protocol(
288 "Unexpected WriteAndReadByte in query state machine".into(),
289 ));
290 }
291 Action::ReadMessage => {
292 self.stream.read_message(&mut self.buffer_set).await?;
293 }
294 Action::Write => {
295 self.stream.write_all(&self.buffer_set.write_buffer).await?;
296 self.stream.flush().await?;
297 }
298 Action::WriteAndReadMessage => {
299 self.stream.write_all(&self.buffer_set.write_buffer).await?;
300 self.stream.flush().await?;
301 self.stream.read_message(&mut self.buffer_set).await?;
302 }
303 Action::TlsHandshake => {
304 return Err(Error::Protocol(
305 "Unexpected TlsHandshake in query state machine".into(),
306 ));
307 }
308 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
309 if let Some(ref mut h) = self.async_message_handler {
310 h.handle(async_msg);
311 }
312 self.stream.read_message(&mut self.buffer_set).await?;
314 }
315 Action::Finished => {
316 self.transaction_status = state_machine.transaction_status();
317 break;
318 }
319 }
320 }
321 Ok(())
322 }
323
324 pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
326 let result = self.query_inner(sql, handler).await;
327 if let Err(e) = &result
328 && e.is_connection_broken()
329 {
330 self.is_broken = true;
331 }
332 result
333 }
334
335 async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
336 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
337 self.drive(&mut state_machine).await
338 }
339
340 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
342 let mut handler = DropHandler::new();
343 self.query(sql, &mut handler).await?;
344 Ok(handler.rows_affected())
345 }
346
347 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
349 &mut self,
350 sql: &str,
351 ) -> Result<Vec<T>> {
352 let mut handler = crate::handler::CollectHandler::<T>::new();
353 self.query(sql, &mut handler).await?;
354 Ok(handler.into_rows())
355 }
356
357 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
359 &mut self,
360 sql: &str,
361 ) -> Result<Option<T>> {
362 let mut handler = crate::handler::FirstRowHandler::<T>::new();
363 self.query(sql, &mut handler).await?;
364 Ok(handler.into_row())
365 }
366
367 pub async fn close(mut self) -> Result<()> {
369 self.buffer_set.write_buffer.clear();
370 write_terminate(&mut self.buffer_set.write_buffer);
371 self.stream.write_all(&self.buffer_set.write_buffer).await?;
372 self.stream.flush().await?;
373 Ok(())
374 }
375
376 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
380 self.prepare_typed(query, &[]).await
381 }
382
383 pub async fn prepare_typed(
385 &mut self,
386 query: &str,
387 param_oids: &[u32],
388 ) -> Result<PreparedStatement> {
389 self.name_counter += 1;
390 let idx = self.name_counter;
391 let result = self.prepare_inner(idx, query, param_oids).await;
392 if let Err(e) = &result
393 && e.is_connection_broken()
394 {
395 self.is_broken = true;
396 }
397 result
398 }
399
400 async fn prepare_inner(
401 &mut self,
402 idx: u64,
403 query: &str,
404 param_oids: &[u32],
405 ) -> Result<PreparedStatement> {
406 let mut handler = DropHandler::new();
407 let mut state_machine = ExtendedQueryStateMachine::prepare(
408 &mut handler,
409 &mut self.buffer_set,
410 idx,
411 query,
412 param_oids,
413 );
414 self.drive(&mut state_machine).await?;
415 state_machine
416 .take_prepared_statement()
417 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
418 }
419
420 pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
426 &mut self,
427 statement: S,
428 params: P,
429 handler: &mut H,
430 ) -> Result<()> {
431 let result = self.exec_inner(&statement, ¶ms, handler).await;
432 if let Err(e) = &result
433 && e.is_connection_broken()
434 {
435 self.is_broken = true;
436 }
437 result
438 }
439
440 async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
441 &mut self,
442 statement: &S,
443 params: &P,
444 handler: &mut H,
445 ) -> Result<()> {
446 let mut state_machine = if statement.needs_parse() {
447 ExtendedQueryStateMachine::execute_sql(
448 handler,
449 &mut self.buffer_set,
450 statement.as_sql().unwrap(),
451 params,
452 )?
453 } else {
454 let stmt = statement.as_prepared().unwrap();
455 ExtendedQueryStateMachine::execute(
456 handler,
457 &mut self.buffer_set,
458 &stmt.wire_name(),
459 &stmt.param_oids,
460 params,
461 )?
462 };
463
464 self.drive(&mut state_machine).await
465 }
466
467 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
471 &mut self,
472 statement: S,
473 params: P,
474 ) -> Result<Option<u64>> {
475 let mut handler = DropHandler::new();
476 self.exec(statement, params, &mut handler).await?;
477 Ok(handler.rows_affected())
478 }
479
480 pub async fn exec_collect<
484 T: for<'a> crate::conversion::FromRow<'a>,
485 S: IntoStatement,
486 P: ToParams,
487 >(
488 &mut self,
489 statement: S,
490 params: P,
491 ) -> Result<Vec<T>> {
492 let mut handler = crate::handler::CollectHandler::<T>::new();
493 self.exec(statement, params, &mut handler).await?;
494 Ok(handler.into_rows())
495 }
496
497 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
528 &mut self,
529 statement: S,
530 params_list: &[P],
531 ) -> Result<()> {
532 self.exec_batch_chunked(statement, params_list, 1000).await
533 }
534
535 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
539 &mut self,
540 statement: S,
541 params_list: &[P],
542 chunk_size: usize,
543 ) -> Result<()> {
544 let result = self
545 .exec_batch_inner(&statement, params_list, chunk_size)
546 .await;
547 if let Err(e) = &result
548 && e.is_connection_broken()
549 {
550 self.is_broken = true;
551 }
552 result
553 }
554
555 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
556 &mut self,
557 statement: &S,
558 params_list: &[P],
559 chunk_size: usize,
560 ) -> Result<()> {
561 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
562 use crate::state::extended::BatchStateMachine;
563
564 if params_list.is_empty() {
565 return Ok(());
566 }
567
568 let chunk_size = chunk_size.max(1);
569 let needs_parse = statement.needs_parse();
570 let sql = statement.as_sql();
571 let prepared = statement.as_prepared();
572
573 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
575 stmt.param_oids.clone()
576 } else {
577 params_list[0].natural_oids()
578 };
579
580 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
582
583 for chunk in params_list.chunks(chunk_size) {
584 self.buffer_set.write_buffer.clear();
585
586 let parse_in_chunk = needs_parse;
588 if parse_in_chunk {
589 write_parse(
590 &mut self.buffer_set.write_buffer,
591 "",
592 sql.unwrap(),
593 ¶m_oids,
594 );
595 }
596
597 for params in chunk {
599 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
600 write_bind(
601 &mut self.buffer_set.write_buffer,
602 "",
603 effective_stmt_name,
604 params,
605 ¶m_oids,
606 )?;
607 write_execute(&mut self.buffer_set.write_buffer, "", 0);
608 }
609
610 write_sync(&mut self.buffer_set.write_buffer);
612
613 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
615 self.drive_batch(&mut state_machine).await?;
616 self.transaction_status = state_machine.transaction_status();
617 }
618
619 Ok(())
620 }
621
622 async fn drive_batch(
624 &mut self,
625 state_machine: &mut crate::state::extended::BatchStateMachine,
626 ) -> Result<()> {
627 use crate::protocol::backend::{ReadyForQuery, msg_type};
628 use crate::state::action::Action;
629
630 loop {
631 let step_result = state_machine.step(&mut self.buffer_set);
632 match step_result {
633 Ok(Action::ReadMessage) => {
634 self.stream.read_message(&mut self.buffer_set).await?;
635 }
636 Ok(Action::WriteAndReadMessage) => {
637 self.stream.write_all(&self.buffer_set.write_buffer).await?;
638 self.stream.flush().await?;
639 self.stream.read_message(&mut self.buffer_set).await?;
640 }
641 Ok(Action::Finished) => {
642 break;
643 }
644 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
645 Err(e) => {
646 loop {
648 self.stream.read_message(&mut self.buffer_set).await?;
649 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
650 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
651 self.transaction_status =
652 ready.transaction_status().unwrap_or_default();
653 break;
654 }
655 }
656 return Err(e);
657 }
658 }
659 }
660 Ok(())
661 }
662
663 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
665 let result = self.close_statement_inner(&stmt.wire_name()).await;
666 if let Err(e) = &result
667 && e.is_connection_broken()
668 {
669 self.is_broken = true;
670 }
671 result
672 }
673
674 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
675 let mut handler = DropHandler::new();
676 let mut state_machine =
677 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
678 self.drive(&mut state_machine).await
679 }
680
681 pub async fn lowlevel_flush(&mut self) -> Result<()> {
689 use crate::protocol::frontend::write_flush;
690
691 self.buffer_set.write_buffer.clear();
692 write_flush(&mut self.buffer_set.write_buffer);
693
694 self.stream.write_all(&self.buffer_set.write_buffer).await?;
695 self.stream.flush().await?;
696 Ok(())
697 }
698
699 pub async fn lowlevel_sync(&mut self) -> Result<()> {
706 let result = self.lowlevel_sync_inner().await;
707 if let Err(e) = &result
708 && e.is_connection_broken()
709 {
710 self.is_broken = true;
711 }
712 result
713 }
714
715 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
716 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
717 use crate::protocol::frontend::write_sync;
718
719 self.buffer_set.write_buffer.clear();
720 write_sync(&mut self.buffer_set.write_buffer);
721
722 self.stream.write_all(&self.buffer_set.write_buffer).await?;
723 self.stream.flush().await?;
724
725 let mut pending_error: Option<Error> = None;
726
727 loop {
728 self.stream.read_message(&mut self.buffer_set).await?;
729 let type_byte = self.buffer_set.type_byte;
730
731 if RawMessage::is_async_type(type_byte) {
732 continue;
733 }
734
735 match type_byte {
736 msg_type::READY_FOR_QUERY => {
737 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
738 self.transaction_status = ready.transaction_status().unwrap_or_default();
739 if let Some(e) = pending_error {
740 return Err(e);
741 }
742 return Ok(());
743 }
744 msg_type::ERROR_RESPONSE => {
745 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
746 pending_error = Some(error.into_error());
747 }
748 _ => {
749 }
751 }
752 }
753 }
754
755 pub async fn lowlevel_bind<P: ToParams>(
765 &mut self,
766 portal: &str,
767 statement_name: &str,
768 params: P,
769 ) -> Result<()> {
770 let result = self
771 .lowlevel_bind_inner(portal, statement_name, ¶ms)
772 .await;
773 if let Err(e) = &result
774 && e.is_connection_broken()
775 {
776 self.is_broken = true;
777 }
778 result
779 }
780
781 async fn lowlevel_bind_inner<P: ToParams>(
782 &mut self,
783 portal: &str,
784 statement_name: &str,
785 params: &P,
786 ) -> Result<()> {
787 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
788 use crate::protocol::frontend::{write_bind, write_flush};
789
790 let param_oids = params.natural_oids();
791 self.buffer_set.write_buffer.clear();
792 write_bind(
793 &mut self.buffer_set.write_buffer,
794 portal,
795 statement_name,
796 params,
797 ¶m_oids,
798 )?;
799 write_flush(&mut self.buffer_set.write_buffer);
800
801 self.stream.write_all(&self.buffer_set.write_buffer).await?;
802 self.stream.flush().await?;
803
804 loop {
805 self.stream.read_message(&mut self.buffer_set).await?;
806 let type_byte = self.buffer_set.type_byte;
807
808 if RawMessage::is_async_type(type_byte) {
809 continue;
810 }
811
812 match type_byte {
813 msg_type::BIND_COMPLETE => {
814 BindComplete::parse(&self.buffer_set.read_buffer)?;
815 return Ok(());
816 }
817 msg_type::ERROR_RESPONSE => {
818 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
819 return Err(error.into_error());
820 }
821 _ => {
822 return Err(Error::Protocol(format!(
823 "Expected BindComplete or ErrorResponse, got '{}'",
824 type_byte as char
825 )));
826 }
827 }
828 }
829 }
830
831 pub async fn lowlevel_execute<H: BinaryHandler>(
844 &mut self,
845 portal: &str,
846 max_rows: u32,
847 handler: &mut H,
848 ) -> Result<bool> {
849 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
850 if let Err(e) = &result
851 && e.is_connection_broken()
852 {
853 self.is_broken = true;
854 }
855 result
856 }
857
858 async fn lowlevel_execute_inner<H: BinaryHandler>(
859 &mut self,
860 portal: &str,
861 max_rows: u32,
862 handler: &mut H,
863 ) -> Result<bool> {
864 use crate::protocol::backend::{
865 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
866 RowDescription, msg_type,
867 };
868 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
869
870 self.buffer_set.write_buffer.clear();
871 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
872 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
873 write_flush(&mut self.buffer_set.write_buffer);
874
875 self.stream.write_all(&self.buffer_set.write_buffer).await?;
876 self.stream.flush().await?;
877
878 let mut column_buffer: Vec<u8> = Vec::new();
879
880 loop {
881 self.stream.read_message(&mut self.buffer_set).await?;
882 let type_byte = self.buffer_set.type_byte;
883
884 if RawMessage::is_async_type(type_byte) {
885 continue;
886 }
887
888 match type_byte {
889 msg_type::ROW_DESCRIPTION => {
890 column_buffer.clear();
891 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
892 let cols = RowDescription::parse(&column_buffer)?;
893 handler.result_start(cols)?;
894 }
895 msg_type::NO_DATA => {
896 NoData::parse(&self.buffer_set.read_buffer)?;
897 }
898 msg_type::DATA_ROW => {
899 let cols = RowDescription::parse(&column_buffer)?;
900 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
901 handler.row(cols, row)?;
902 }
903 msg_type::COMMAND_COMPLETE => {
904 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
905 handler.result_end(complete)?;
906 return Ok(false); }
908 msg_type::PORTAL_SUSPENDED => {
909 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
910 return Ok(true); }
912 msg_type::ERROR_RESPONSE => {
913 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
914 return Err(error.into_error());
915 }
916 _ => {
917 return Err(Error::Protocol(format!(
918 "Unexpected message in execute: '{}'",
919 type_byte as char
920 )));
921 }
922 }
923 }
924 }
925
926 pub async fn exec_iter<S: IntoStatement, P, F, Fut, T>(
956 &mut self,
957 statement: S,
958 params: P,
959 f: F,
960 ) -> Result<T>
961 where
962 P: ToParams,
963 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
964 Fut: std::future::Future<Output = Result<T>>,
965 {
966 let result = self.exec_iter_inner(&statement, ¶ms, f).await;
967 if let Err(e) = &result
968 && e.is_connection_broken()
969 {
970 self.is_broken = true;
971 }
972 result
973 }
974
975 async fn exec_iter_inner<S: IntoStatement, P, F, Fut, T>(
976 &mut self,
977 statement: &S,
978 params: &P,
979 f: F,
980 ) -> Result<T>
981 where
982 P: ToParams,
983 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
984 Fut: std::future::Future<Output = Result<T>>,
985 {
986 let mut state_machine = if let Some(sql) = statement.as_sql() {
988 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
989 } else {
990 let stmt = statement.as_prepared().unwrap();
991 BindStateMachine::bind_prepared(
992 &mut self.buffer_set,
993 "",
994 &stmt.wire_name(),
995 &stmt.param_oids,
996 params,
997 )?
998 };
999
1000 loop {
1002 match state_machine.step(&mut self.buffer_set)? {
1003 Action::ReadMessage => {
1004 self.stream.read_message(&mut self.buffer_set).await?;
1005 }
1006 Action::Write => {
1007 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1008 self.stream.flush().await?;
1009 }
1010 Action::WriteAndReadMessage => {
1011 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1012 self.stream.flush().await?;
1013 self.stream.read_message(&mut self.buffer_set).await?;
1014 }
1015 Action::Finished => break,
1016 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1017 }
1018 }
1019
1020 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1022 let result = f(&mut portal).await;
1023
1024 let sync_result = portal.conn.lowlevel_sync().await;
1026
1027 match (result, sync_result) {
1029 (Ok(v), Ok(())) => Ok(v),
1030 (Err(e), _) => Err(e),
1031 (Ok(_), Err(e)) => Err(e),
1032 }
1033 }
1034
1035 pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1037 let result = self.lowlevel_close_portal_inner(portal).await;
1038 if let Err(e) = &result
1039 && e.is_connection_broken()
1040 {
1041 self.is_broken = true;
1042 }
1043 result
1044 }
1045
1046 async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1047 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1048 use crate::protocol::frontend::{write_close_portal, write_flush};
1049
1050 self.buffer_set.write_buffer.clear();
1051 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1052 write_flush(&mut self.buffer_set.write_buffer);
1053
1054 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1055 self.stream.flush().await?;
1056
1057 loop {
1058 self.stream.read_message(&mut self.buffer_set).await?;
1059 let type_byte = self.buffer_set.type_byte;
1060
1061 if RawMessage::is_async_type(type_byte) {
1062 continue;
1063 }
1064
1065 match type_byte {
1066 msg_type::CLOSE_COMPLETE => {
1067 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1068 return Ok(());
1069 }
1070 msg_type::ERROR_RESPONSE => {
1071 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1072 return Err(error.into_error());
1073 }
1074 _ => {
1075 return Err(Error::Protocol(format!(
1076 "Expected CloseComplete or ErrorResponse, got '{}'",
1077 type_byte as char
1078 )));
1079 }
1080 }
1081 }
1082 }
1083
1084 pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1115 where
1116 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1117 Fut: std::future::Future<Output = Result<T>>,
1118 {
1119 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1120 let result = f(&mut pipeline).await;
1121 pipeline.cleanup().await;
1122 result
1123 }
1124
1125 pub async fn transaction<F, R, Fut>(&mut self, f: F) -> Result<R>
1135 where
1136 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1137 Fut: std::future::Future<Output = Result<R>>,
1138 {
1139 if self.in_transaction() {
1140 return Err(Error::InvalidUsage(
1141 "nested transactions are not supported".into(),
1142 ));
1143 }
1144
1145 self.query_drop("BEGIN").await?;
1146
1147 let tx = super::transaction::Transaction::new(self.connection_id());
1148
1149 let result = f(self, tx).await;
1152
1153 if self.in_transaction() {
1155 let rollback_result = self.query_drop("ROLLBACK").await;
1156
1157 if let Err(e) = result {
1159 return Err(e);
1160 }
1161 rollback_result?;
1162 }
1163
1164 result
1165 }
1166}