1use tokio::net::TcpStream;
4use tokio::net::UnixStream;
5
6use crate::buffer_pool::PooledBufferSet;
7use crate::conversion::ToParams;
8use crate::error::{Error, Result};
9use crate::handler::{
10 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
11};
12use crate::opts::Opts;
13use crate::protocol::backend::BackendKeyData;
14use crate::protocol::frontend::write_terminate;
15use crate::protocol::types::TransactionStatus;
16use crate::state::StateMachine;
17use crate::state::action::Action;
18use crate::state::connection::ConnectionStateMachine;
19use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
20use crate::state::simple_query::SimpleQueryStateMachine;
21use crate::statement::IntoStatement;
22
23use super::stream::Stream;
24
/// An asynchronous PostgreSQL client connection.
///
/// Owns the transport the protocol is spoken over and the buffers used to encode
/// requests and decode responses. Created with [`Conn::new`] or
/// [`Conn::new_with_stream`].
pub struct Conn {
    /// Transport to the server: TCP or a Unix-domain socket, optionally upgraded to TLS.
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the connection options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data reported by the server during startup, if any; its process
    /// id is exposed through `connection_id`.
    backend_key: Option<BackendKeyData>,
    /// Parameter name/value pairs reported by the server during startup.
    server_params: Vec<(String, String)>,
    /// Transaction status as last reported by the server.
    pub(crate) transaction_status: TransactionStatus,
    /// Set when an operation fails with an error whose `is_connection_broken()` is true.
    pub(crate) is_broken: bool,
    /// Incremented for every statement prepared on this connection; supplies each
    /// statement's index.
    stmt_counter: u64,
    /// Optional user callback for asynchronous server messages.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
36
37impl Conn {
38 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
40 where
41 Error: From<O::Error>,
42 {
43 let opts = opts.try_into()?;
44
45 let stream = if let Some(socket_path) = &opts.socket {
46 Stream::unix(UnixStream::connect(socket_path).await?)
47 } else {
48 if opts.host.is_empty() {
49 return Err(Error::InvalidUsage("host is empty".into()));
50 }
51 let addr = format!("{}:{}", opts.host, opts.port);
52 let tcp = TcpStream::connect(&addr).await?;
53 tcp.set_nodelay(true)?;
54 Stream::tcp(tcp)
55 };
56
57 Self::new_with_stream(stream, opts).await
58 }
59
60 #[allow(unused_mut)]
62 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
63 let mut buffer_set = options.buffer_pool.get_buffer_set();
64 let mut state_machine = ConnectionStateMachine::new(options.clone());
65
66 loop {
68 match state_machine.step(&mut buffer_set)? {
69 Action::WriteAndReadByte => {
70 stream.write_all(&buffer_set.write_buffer).await?;
71 stream.flush().await?;
72 let byte = stream.read_u8().await?;
73 state_machine.set_ssl_response(byte);
74 }
75 Action::ReadMessage => {
76 stream.read_message(&mut buffer_set).await?;
77 }
78 Action::Write => {
79 stream.write_all(&buffer_set.write_buffer).await?;
80 stream.flush().await?;
81 }
82 Action::WriteAndReadMessage => {
83 stream.write_all(&buffer_set.write_buffer).await?;
84 stream.flush().await?;
85 stream.read_message(&mut buffer_set).await?;
86 }
87 Action::TlsHandshake => {
88 #[cfg(feature = "tokio-tls")]
89 {
90 stream = stream.upgrade_to_tls(&options.host).await?;
91 }
92 #[cfg(not(feature = "tokio-tls"))]
93 {
94 return Err(Error::Unsupported(
95 "TLS requested but tokio-tls feature not enabled".into(),
96 ));
97 }
98 }
99 Action::HandleAsyncMessageAndReadMessage(_) => {
100 stream.read_message(&mut buffer_set).await?;
102 }
103 Action::Finished => break,
104 }
105 }
106
107 let conn = Self {
108 stream,
109 buffer_set,
110 backend_key: state_machine.backend_key().cloned(),
111 server_params: state_machine.take_server_params(),
112 transaction_status: state_machine.transaction_status(),
113 is_broken: false,
114 stmt_counter: 0,
115 async_message_handler: None,
116 };
117
118 let conn = if options.prefer_unix_socket && conn.stream.is_tcp_loopback() {
120 conn.try_upgrade_to_unix_socket(&options).await
121 } else {
122 conn
123 };
124
125 Ok(conn)
126 }
127
128 fn try_upgrade_to_unix_socket(
131 mut self,
132 opts: &Opts,
133 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
134 let opts = opts.clone();
135 Box::pin(async move {
136 let mut handler = FirstRowHandler::<(String,)>::new();
138 if self
139 .query("SHOW unix_socket_directories", &mut handler)
140 .await
141 .is_err()
142 {
143 return self;
144 }
145
146 let socket_dir = match handler.into_row() {
147 Some((dirs,)) => {
148 match dirs.split(',').next() {
150 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
151 _ => return self,
152 }
153 }
154 None => return self,
155 };
156
157 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
159
160 let unix_stream = match UnixStream::connect(&socket_path).await {
162 Ok(s) => s,
163 Err(_) => return self,
164 };
165
166 let mut opts_unix = opts.clone();
168 opts_unix.prefer_unix_socket = false;
169
170 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
171 Ok(new_conn) => new_conn,
172 Err(_) => self,
173 }
174 })
175 }
176
177 pub fn backend_key(&self) -> Option<&BackendKeyData> {
179 self.backend_key.as_ref()
180 }
181
182 pub fn connection_id(&self) -> u32 {
186 self.backend_key.as_ref().map_or(0, |k| k.process_id())
187 }
188
189 pub fn server_params(&self) -> &[(String, String)] {
191 &self.server_params
192 }
193
194 pub fn transaction_status(&self) -> TransactionStatus {
196 self.transaction_status
197 }
198
199 pub fn in_transaction(&self) -> bool {
201 self.transaction_status.in_transaction()
202 }
203
204 pub fn is_broken(&self) -> bool {
206 self.is_broken
207 }
208
209 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
216 self.async_message_handler = Some(Box::new(handler));
217 }
218
219 pub fn clear_async_message_handler(&mut self) {
221 self.async_message_handler = None;
222 }
223
224 pub async fn ping(&mut self) -> Result<()> {
226 self.query_drop("").await?;
227 Ok(())
228 }
229
230 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
232 loop {
233 match state_machine.step(&mut self.buffer_set)? {
234 Action::WriteAndReadByte => {
235 return Err(Error::Protocol(
236 "Unexpected WriteAndReadByte in query state machine".into(),
237 ));
238 }
239 Action::ReadMessage => {
240 self.stream.read_message(&mut self.buffer_set).await?;
241 }
242 Action::Write => {
243 self.stream.write_all(&self.buffer_set.write_buffer).await?;
244 self.stream.flush().await?;
245 }
246 Action::WriteAndReadMessage => {
247 self.stream.write_all(&self.buffer_set.write_buffer).await?;
248 self.stream.flush().await?;
249 self.stream.read_message(&mut self.buffer_set).await?;
250 }
251 Action::TlsHandshake => {
252 return Err(Error::Protocol(
253 "Unexpected TlsHandshake in query state machine".into(),
254 ));
255 }
256 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
257 if let Some(ref mut h) = self.async_message_handler {
258 h.handle(async_msg);
259 }
260 self.stream.read_message(&mut self.buffer_set).await?;
262 }
263 Action::Finished => {
264 self.transaction_status = state_machine.transaction_status();
265 break;
266 }
267 }
268 }
269 Ok(())
270 }
271
272 pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
274 let result = self.query_inner(sql, handler).await;
275 if let Err(e) = &result
276 && e.is_connection_broken()
277 {
278 self.is_broken = true;
279 }
280 result
281 }
282
283 async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
284 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
285 self.drive(&mut state_machine).await
286 }
287
288 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
290 let mut handler = DropHandler::new();
291 self.query(sql, &mut handler).await?;
292 Ok(handler.rows_affected())
293 }
294
295 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
297 &mut self,
298 sql: &str,
299 ) -> Result<Vec<T>> {
300 let mut handler = crate::handler::CollectHandler::<T>::new();
301 self.query(sql, &mut handler).await?;
302 Ok(handler.into_rows())
303 }
304
305 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
307 &mut self,
308 sql: &str,
309 ) -> Result<Option<T>> {
310 let mut handler = crate::handler::FirstRowHandler::<T>::new();
311 self.query(sql, &mut handler).await?;
312 Ok(handler.into_row())
313 }
314
315 pub async fn close(mut self) -> Result<()> {
317 self.buffer_set.write_buffer.clear();
318 write_terminate(&mut self.buffer_set.write_buffer);
319 self.stream.write_all(&self.buffer_set.write_buffer).await?;
320 self.stream.flush().await?;
321 Ok(())
322 }
323
324 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
328 self.prepare_typed(query, &[]).await
329 }
330
331 pub async fn prepare_typed(
333 &mut self,
334 query: &str,
335 param_oids: &[u32],
336 ) -> Result<PreparedStatement> {
337 self.stmt_counter += 1;
338 let idx = self.stmt_counter;
339 let result = self.prepare_inner(idx, query, param_oids).await;
340 if let Err(e) = &result
341 && e.is_connection_broken()
342 {
343 self.is_broken = true;
344 }
345 result
346 }
347
348 async fn prepare_inner(
349 &mut self,
350 idx: u64,
351 query: &str,
352 param_oids: &[u32],
353 ) -> Result<PreparedStatement> {
354 let mut handler = DropHandler::new();
355 let mut state_machine = ExtendedQueryStateMachine::prepare(
356 &mut handler,
357 &mut self.buffer_set,
358 idx,
359 query,
360 param_oids,
361 );
362 self.drive(&mut state_machine).await?;
363 state_machine
364 .take_prepared_statement()
365 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
366 }
367
368 pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
374 &mut self,
375 statement: S,
376 params: P,
377 handler: &mut H,
378 ) -> Result<()> {
379 let result = self.exec_inner(&statement, ¶ms, handler).await;
380 if let Err(e) = &result
381 && e.is_connection_broken()
382 {
383 self.is_broken = true;
384 }
385 result
386 }
387
388 async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
389 &mut self,
390 statement: &S,
391 params: &P,
392 handler: &mut H,
393 ) -> Result<()> {
394 let mut state_machine = if statement.needs_parse() {
395 ExtendedQueryStateMachine::execute_sql(
396 handler,
397 &mut self.buffer_set,
398 statement.as_sql().unwrap(),
399 params,
400 )?
401 } else {
402 let stmt = statement.as_prepared().unwrap();
403 ExtendedQueryStateMachine::execute(
404 handler,
405 &mut self.buffer_set,
406 &stmt.wire_name(),
407 &stmt.param_oids,
408 params,
409 )?
410 };
411
412 self.drive(&mut state_machine).await
413 }
414
415 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
419 &mut self,
420 statement: S,
421 params: P,
422 ) -> Result<Option<u64>> {
423 let mut handler = DropHandler::new();
424 self.exec(statement, params, &mut handler).await?;
425 Ok(handler.rows_affected())
426 }
427
428 pub async fn exec_collect<
432 T: for<'a> crate::conversion::FromRow<'a>,
433 S: IntoStatement,
434 P: ToParams,
435 >(
436 &mut self,
437 statement: S,
438 params: P,
439 ) -> Result<Vec<T>> {
440 let mut handler = crate::handler::CollectHandler::<T>::new();
441 self.exec(statement, params, &mut handler).await?;
442 Ok(handler.into_rows())
443 }
444
445 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
476 &mut self,
477 statement: S,
478 params_list: &[P],
479 ) -> Result<()> {
480 self.exec_batch_chunked(statement, params_list, 1000).await
481 }
482
483 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
487 &mut self,
488 statement: S,
489 params_list: &[P],
490 chunk_size: usize,
491 ) -> Result<()> {
492 let result = self
493 .exec_batch_inner(&statement, params_list, chunk_size)
494 .await;
495 if let Err(e) = &result
496 && e.is_connection_broken()
497 {
498 self.is_broken = true;
499 }
500 result
501 }
502
503 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
504 &mut self,
505 statement: &S,
506 params_list: &[P],
507 chunk_size: usize,
508 ) -> Result<()> {
509 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
510 use crate::state::extended::BatchStateMachine;
511
512 if params_list.is_empty() {
513 return Ok(());
514 }
515
516 let chunk_size = chunk_size.max(1);
517 let needs_parse = statement.needs_parse();
518 let sql = statement.as_sql();
519 let prepared = statement.as_prepared();
520
521 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
523 stmt.param_oids.clone()
524 } else {
525 params_list[0].natural_oids()
526 };
527
528 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
530
531 for chunk in params_list.chunks(chunk_size) {
532 self.buffer_set.write_buffer.clear();
533
534 let parse_in_chunk = needs_parse;
536 if parse_in_chunk {
537 write_parse(
538 &mut self.buffer_set.write_buffer,
539 "",
540 sql.unwrap(),
541 ¶m_oids,
542 );
543 }
544
545 for params in chunk {
547 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
548 write_bind(
549 &mut self.buffer_set.write_buffer,
550 "",
551 effective_stmt_name,
552 params,
553 ¶m_oids,
554 )?;
555 write_execute(&mut self.buffer_set.write_buffer, "", 0);
556 }
557
558 write_sync(&mut self.buffer_set.write_buffer);
560
561 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
563 self.drive_batch(&mut state_machine).await?;
564 self.transaction_status = state_machine.transaction_status();
565 }
566
567 Ok(())
568 }
569
570 async fn drive_batch(
572 &mut self,
573 state_machine: &mut crate::state::extended::BatchStateMachine,
574 ) -> Result<()> {
575 use crate::protocol::backend::{ReadyForQuery, msg_type};
576 use crate::state::action::Action;
577
578 loop {
579 let step_result = state_machine.step(&mut self.buffer_set);
580 match step_result {
581 Ok(Action::ReadMessage) => {
582 self.stream.read_message(&mut self.buffer_set).await?;
583 }
584 Ok(Action::WriteAndReadMessage) => {
585 self.stream.write_all(&self.buffer_set.write_buffer).await?;
586 self.stream.flush().await?;
587 self.stream.read_message(&mut self.buffer_set).await?;
588 }
589 Ok(Action::Finished) => {
590 break;
591 }
592 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
593 Err(e) => {
594 loop {
596 self.stream.read_message(&mut self.buffer_set).await?;
597 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
598 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
599 self.transaction_status =
600 ready.transaction_status().unwrap_or_default();
601 break;
602 }
603 }
604 return Err(e);
605 }
606 }
607 }
608 Ok(())
609 }
610
611 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
613 let result = self.close_statement_inner(&stmt.wire_name()).await;
614 if let Err(e) = &result
615 && e.is_connection_broken()
616 {
617 self.is_broken = true;
618 }
619 result
620 }
621
622 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
623 let mut handler = DropHandler::new();
624 let mut state_machine =
625 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
626 self.drive(&mut state_machine).await
627 }
628
629 pub async fn lowlevel_flush(&mut self) -> Result<()> {
637 use crate::protocol::frontend::write_flush;
638
639 self.buffer_set.write_buffer.clear();
640 write_flush(&mut self.buffer_set.write_buffer);
641
642 self.stream.write_all(&self.buffer_set.write_buffer).await?;
643 self.stream.flush().await?;
644 Ok(())
645 }
646
647 pub async fn lowlevel_sync(&mut self) -> Result<()> {
654 let result = self.lowlevel_sync_inner().await;
655 if let Err(e) = &result
656 && e.is_connection_broken()
657 {
658 self.is_broken = true;
659 }
660 result
661 }
662
663 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
664 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
665 use crate::protocol::frontend::write_sync;
666
667 self.buffer_set.write_buffer.clear();
668 write_sync(&mut self.buffer_set.write_buffer);
669
670 self.stream.write_all(&self.buffer_set.write_buffer).await?;
671 self.stream.flush().await?;
672
673 let mut pending_error: Option<Error> = None;
674
675 loop {
676 self.stream.read_message(&mut self.buffer_set).await?;
677 let type_byte = self.buffer_set.type_byte;
678
679 if RawMessage::is_async_type(type_byte) {
680 continue;
681 }
682
683 match type_byte {
684 msg_type::READY_FOR_QUERY => {
685 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
686 self.transaction_status = ready.transaction_status().unwrap_or_default();
687 if let Some(e) = pending_error {
688 return Err(e);
689 }
690 return Ok(());
691 }
692 msg_type::ERROR_RESPONSE => {
693 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
694 pending_error = Some(error.into_error());
695 }
696 _ => {
697 }
699 }
700 }
701 }
702
703 pub async fn lowlevel_bind<P: ToParams>(
713 &mut self,
714 portal: &str,
715 statement_name: &str,
716 params: P,
717 ) -> Result<()> {
718 let result = self
719 .lowlevel_bind_inner(portal, statement_name, ¶ms)
720 .await;
721 if let Err(e) = &result
722 && e.is_connection_broken()
723 {
724 self.is_broken = true;
725 }
726 result
727 }
728
729 async fn lowlevel_bind_inner<P: ToParams>(
730 &mut self,
731 portal: &str,
732 statement_name: &str,
733 params: &P,
734 ) -> Result<()> {
735 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
736 use crate::protocol::frontend::{write_bind, write_flush};
737
738 let param_oids = params.natural_oids();
739 self.buffer_set.write_buffer.clear();
740 write_bind(
741 &mut self.buffer_set.write_buffer,
742 portal,
743 statement_name,
744 params,
745 ¶m_oids,
746 )?;
747 write_flush(&mut self.buffer_set.write_buffer);
748
749 self.stream.write_all(&self.buffer_set.write_buffer).await?;
750 self.stream.flush().await?;
751
752 loop {
753 self.stream.read_message(&mut self.buffer_set).await?;
754 let type_byte = self.buffer_set.type_byte;
755
756 if RawMessage::is_async_type(type_byte) {
757 continue;
758 }
759
760 match type_byte {
761 msg_type::BIND_COMPLETE => {
762 BindComplete::parse(&self.buffer_set.read_buffer)?;
763 return Ok(());
764 }
765 msg_type::ERROR_RESPONSE => {
766 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
767 return Err(error.into_error());
768 }
769 _ => {
770 return Err(Error::Protocol(format!(
771 "Expected BindComplete or ErrorResponse, got '{}'",
772 type_byte as char
773 )));
774 }
775 }
776 }
777 }
778
779 pub async fn lowlevel_execute<H: BinaryHandler>(
792 &mut self,
793 portal: &str,
794 max_rows: u32,
795 handler: &mut H,
796 ) -> Result<bool> {
797 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
798 if let Err(e) = &result
799 && e.is_connection_broken()
800 {
801 self.is_broken = true;
802 }
803 result
804 }
805
806 async fn lowlevel_execute_inner<H: BinaryHandler>(
807 &mut self,
808 portal: &str,
809 max_rows: u32,
810 handler: &mut H,
811 ) -> Result<bool> {
812 use crate::protocol::backend::{
813 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
814 RowDescription, msg_type,
815 };
816 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
817
818 self.buffer_set.write_buffer.clear();
819 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
820 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
821 write_flush(&mut self.buffer_set.write_buffer);
822
823 self.stream.write_all(&self.buffer_set.write_buffer).await?;
824 self.stream.flush().await?;
825
826 let mut column_buffer: Vec<u8> = Vec::new();
827
828 loop {
829 self.stream.read_message(&mut self.buffer_set).await?;
830 let type_byte = self.buffer_set.type_byte;
831
832 if RawMessage::is_async_type(type_byte) {
833 continue;
834 }
835
836 match type_byte {
837 msg_type::ROW_DESCRIPTION => {
838 column_buffer.clear();
839 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
840 let cols = RowDescription::parse(&column_buffer)?;
841 handler.result_start(cols)?;
842 }
843 msg_type::NO_DATA => {
844 NoData::parse(&self.buffer_set.read_buffer)?;
845 }
846 msg_type::DATA_ROW => {
847 let cols = RowDescription::parse(&column_buffer)?;
848 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
849 handler.row(cols, row)?;
850 }
851 msg_type::COMMAND_COMPLETE => {
852 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
853 handler.result_end(complete)?;
854 return Ok(false); }
856 msg_type::PORTAL_SUSPENDED => {
857 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
858 return Ok(true); }
860 msg_type::ERROR_RESPONSE => {
861 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
862 return Err(error.into_error());
863 }
864 _ => {
865 return Err(Error::Protocol(format!(
866 "Unexpected message in execute: '{}'",
867 type_byte as char
868 )));
869 }
870 }
871 }
872 }
873
874 pub async fn exec_iter<S: IntoStatement, P, F, Fut, T>(
904 &mut self,
905 statement: S,
906 params: P,
907 f: F,
908 ) -> Result<T>
909 where
910 P: ToParams,
911 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
912 Fut: std::future::Future<Output = Result<T>>,
913 {
914 let result = self.exec_iter_inner(&statement, ¶ms, f).await;
915 if let Err(e) = &result
916 && e.is_connection_broken()
917 {
918 self.is_broken = true;
919 }
920 result
921 }
922
923 async fn exec_iter_inner<S: IntoStatement, P, F, Fut, T>(
924 &mut self,
925 statement: &S,
926 params: &P,
927 f: F,
928 ) -> Result<T>
929 where
930 P: ToParams,
931 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
932 Fut: std::future::Future<Output = Result<T>>,
933 {
934 let mut state_machine = if let Some(sql) = statement.as_sql() {
936 BindStateMachine::bind_sql(&mut self.buffer_set, sql, params)?
937 } else {
938 let stmt = statement.as_prepared().unwrap();
939 BindStateMachine::bind_prepared(
940 &mut self.buffer_set,
941 &stmt.wire_name(),
942 &stmt.param_oids,
943 params,
944 )?
945 };
946
947 loop {
949 match state_machine.step(&mut self.buffer_set)? {
950 Action::ReadMessage => {
951 self.stream.read_message(&mut self.buffer_set).await?;
952 }
953 Action::Write => {
954 self.stream.write_all(&self.buffer_set.write_buffer).await?;
955 self.stream.flush().await?;
956 }
957 Action::WriteAndReadMessage => {
958 self.stream.write_all(&self.buffer_set.write_buffer).await?;
959 self.stream.flush().await?;
960 self.stream.read_message(&mut self.buffer_set).await?;
961 }
962 Action::Finished => break,
963 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
964 }
965 }
966
967 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
969 let result = f(&mut portal).await;
970
971 let sync_result = portal.conn.lowlevel_sync().await;
973
974 match (result, sync_result) {
976 (Ok(v), Ok(())) => Ok(v),
977 (Err(e), _) => Err(e),
978 (Ok(_), Err(e)) => Err(e),
979 }
980 }
981
982 pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1013 where
1014 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1015 Fut: std::future::Future<Output = Result<T>>,
1016 {
1017 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1018 let result = f(&mut pipeline).await;
1019 pipeline.cleanup().await;
1020 result
1021 }
1022
1023 pub async fn run_transaction<F, R, Fut>(&mut self, f: F) -> Result<R>
1033 where
1034 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1035 Fut: std::future::Future<Output = Result<R>>,
1036 {
1037 if self.in_transaction() {
1038 return Err(Error::InvalidUsage(
1039 "nested transactions are not supported".into(),
1040 ));
1041 }
1042
1043 self.query_drop("BEGIN").await?;
1044
1045 let tx = super::transaction::Transaction::new(self.connection_id());
1046
1047 let result = f(self, tx).await;
1050
1051 if self.in_transaction() {
1053 let rollback_result = self.query_drop("ROLLBACK").await;
1054
1055 if let Err(e) = result {
1057 return Err(e);
1058 }
1059 rollback_result?;
1060 }
1061
1062 result
1063 }
1064}