1use tokio::net::TcpStream;
4use tokio::net::UnixStream;
5
6use crate::buffer_pool::PooledBufferSet;
7use crate::conversion::ToParams;
8use crate::error::{Error, Result};
9use crate::handler::{
10 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
11};
12use crate::opts::Opts;
13use crate::protocol::backend::BackendKeyData;
14use crate::protocol::frontend::write_terminate;
15use crate::protocol::types::TransactionStatus;
16use crate::state::StateMachine;
17use crate::state::action::Action;
18use crate::state::connection::ConnectionStateMachine;
19use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
20use crate::state::simple_query::SimpleQueryStateMachine;
21use crate::statement::IntoStatement;
22
23use super::stream::Stream;
24
25pub struct Conn {
27 pub(crate) stream: Stream,
28 pub(crate) buffer_set: PooledBufferSet,
29 backend_key: Option<BackendKeyData>,
30 server_params: Vec<(String, String)>,
31 pub(crate) transaction_status: TransactionStatus,
32 pub(crate) is_broken: bool,
33 name_counter: u64,
34 async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
35}
36
37impl Conn {
38 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
40 where
41 Error: From<O::Error>,
42 {
43 let opts = opts.try_into()?;
44
45 let stream = if let Some(socket_path) = &opts.socket {
46 Stream::unix(UnixStream::connect(socket_path).await?)
47 } else {
48 if opts.host.is_empty() {
49 return Err(Error::InvalidUsage("host is empty".into()));
50 }
51 let addr = format!("{}:{}", opts.host, opts.port);
52 let tcp = TcpStream::connect(&addr).await?;
53 tcp.set_nodelay(true)?;
54 Stream::tcp(tcp)
55 };
56
57 Self::new_with_stream(stream, opts).await
58 }
59
60 #[allow(unused_mut)]
62 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
63 let mut buffer_set = options.buffer_pool.get_buffer_set();
64 let mut state_machine = ConnectionStateMachine::new(options.clone());
65
66 loop {
68 match state_machine.step(&mut buffer_set)? {
69 Action::WriteAndReadByte => {
70 stream.write_all(&buffer_set.write_buffer).await?;
71 stream.flush().await?;
72 let byte = stream.read_u8().await?;
73 state_machine.set_ssl_response(byte);
74 }
75 Action::ReadMessage => {
76 stream.read_message(&mut buffer_set).await?;
77 }
78 Action::Write => {
79 stream.write_all(&buffer_set.write_buffer).await?;
80 stream.flush().await?;
81 }
82 Action::WriteAndReadMessage => {
83 stream.write_all(&buffer_set.write_buffer).await?;
84 stream.flush().await?;
85 stream.read_message(&mut buffer_set).await?;
86 }
87 Action::TlsHandshake => {
88 #[cfg(feature = "tokio-tls")]
89 {
90 stream = stream.upgrade_to_tls(&options.host).await?;
91 }
92 #[cfg(not(feature = "tokio-tls"))]
93 {
94 return Err(Error::Unsupported(
95 "TLS requested but tokio-tls feature not enabled".into(),
96 ));
97 }
98 }
99 Action::HandleAsyncMessageAndReadMessage(_) => {
100 stream.read_message(&mut buffer_set).await?;
102 }
103 Action::Finished => break,
104 }
105 }
106
107 let conn = Self {
108 stream,
109 buffer_set,
110 backend_key: state_machine.backend_key().cloned(),
111 server_params: state_machine.take_server_params(),
112 transaction_status: state_machine.transaction_status(),
113 is_broken: false,
114 name_counter: 0,
115 async_message_handler: None,
116 };
117
118 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
120 conn.try_upgrade_to_unix_socket(&options).await
121 } else {
122 conn
123 };
124
125 Ok(conn)
126 }
127
128 fn try_upgrade_to_unix_socket(
131 mut self,
132 opts: &Opts,
133 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
134 let opts = opts.clone();
135 Box::pin(async move {
136 let mut handler = FirstRowHandler::<(String,)>::new();
138 if self
139 .query("SHOW unix_socket_directories", &mut handler)
140 .await
141 .is_err()
142 {
143 return self;
144 }
145
146 let socket_dir = match handler.into_row() {
147 Some((dirs,)) => {
148 match dirs.split(',').next() {
150 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
151 _ => return self,
152 }
153 }
154 None => return self,
155 };
156
157 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
159
160 let unix_stream = match UnixStream::connect(&socket_path).await {
162 Ok(s) => s,
163 Err(_) => return self,
164 };
165
166 let mut opts_unix = opts.clone();
168 opts_unix.upgrade_to_unix_socket = false;
169
170 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
171 Ok(new_conn) => new_conn,
172 Err(_) => self,
173 }
174 })
175 }
176
177 pub fn backend_key(&self) -> Option<&BackendKeyData> {
179 self.backend_key.as_ref()
180 }
181
182 pub fn connection_id(&self) -> u32 {
186 self.backend_key.as_ref().map_or(0, |k| k.process_id())
187 }
188
189 pub fn server_params(&self) -> &[(String, String)] {
191 &self.server_params
192 }
193
194 pub fn transaction_status(&self) -> TransactionStatus {
196 self.transaction_status
197 }
198
199 pub fn in_transaction(&self) -> bool {
201 self.transaction_status.in_transaction()
202 }
203
204 pub fn is_broken(&self) -> bool {
206 self.is_broken
207 }
208
209 pub(crate) fn next_portal_name(&mut self) -> String {
211 self.name_counter += 1;
212 format!("_zero_p_{}", self.name_counter)
213 }
214
215 pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
219 &mut self,
220 portal_name: &str,
221 statement: &S,
222 params: &P,
223 ) -> Result<()> {
224 let mut state_machine = if let Some(sql) = statement.as_sql() {
226 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
227 } else {
228 let stmt = statement.as_prepared().unwrap();
229 BindStateMachine::bind_prepared(
230 &mut self.buffer_set,
231 portal_name,
232 &stmt.wire_name(),
233 &stmt.param_oids,
234 params,
235 )?
236 };
237
238 loop {
240 match state_machine.step(&mut self.buffer_set)? {
241 Action::ReadMessage => {
242 self.stream.read_message(&mut self.buffer_set).await?;
243 }
244 Action::Write => {
245 self.stream.write_all(&self.buffer_set.write_buffer).await?;
246 self.stream.flush().await?;
247 }
248 Action::WriteAndReadMessage => {
249 self.stream.write_all(&self.buffer_set.write_buffer).await?;
250 self.stream.flush().await?;
251 self.stream.read_message(&mut self.buffer_set).await?;
252 }
253 Action::Finished => break,
254 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
255 }
256 }
257
258 Ok(())
259 }
260
261 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
268 self.async_message_handler = Some(Box::new(handler));
269 }
270
271 pub fn clear_async_message_handler(&mut self) {
273 self.async_message_handler = None;
274 }
275
276 pub async fn ping(&mut self) -> Result<()> {
278 self.query_drop("").await?;
279 Ok(())
280 }
281
282 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
284 loop {
285 match state_machine.step(&mut self.buffer_set)? {
286 Action::WriteAndReadByte => {
287 return Err(Error::Protocol(
288 "Unexpected WriteAndReadByte in query state machine".into(),
289 ));
290 }
291 Action::ReadMessage => {
292 self.stream.read_message(&mut self.buffer_set).await?;
293 }
294 Action::Write => {
295 self.stream.write_all(&self.buffer_set.write_buffer).await?;
296 self.stream.flush().await?;
297 }
298 Action::WriteAndReadMessage => {
299 self.stream.write_all(&self.buffer_set.write_buffer).await?;
300 self.stream.flush().await?;
301 self.stream.read_message(&mut self.buffer_set).await?;
302 }
303 Action::TlsHandshake => {
304 return Err(Error::Protocol(
305 "Unexpected TlsHandshake in query state machine".into(),
306 ));
307 }
308 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
309 if let Some(ref mut h) = self.async_message_handler {
310 h.handle(async_msg);
311 }
312 self.stream.read_message(&mut self.buffer_set).await?;
314 }
315 Action::Finished => {
316 self.transaction_status = state_machine.transaction_status();
317 break;
318 }
319 }
320 }
321 Ok(())
322 }
323
324 pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
326 let result = self.query_inner(sql, handler).await;
327 if let Err(e) = &result
328 && e.is_connection_broken()
329 {
330 self.is_broken = true;
331 }
332 result
333 }
334
335 async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
336 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
337 self.drive(&mut state_machine).await
338 }
339
340 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
342 let mut handler = DropHandler::new();
343 self.query(sql, &mut handler).await?;
344 Ok(handler.rows_affected())
345 }
346
347 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
349 &mut self,
350 sql: &str,
351 ) -> Result<Vec<T>> {
352 let mut handler = crate::handler::CollectHandler::<T>::new();
353 self.query(sql, &mut handler).await?;
354 Ok(handler.into_rows())
355 }
356
357 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
359 &mut self,
360 sql: &str,
361 ) -> Result<Option<T>> {
362 let mut handler = crate::handler::FirstRowHandler::<T>::new();
363 self.query(sql, &mut handler).await?;
364 Ok(handler.into_row())
365 }
366
367 pub async fn close(mut self) -> Result<()> {
369 self.buffer_set.write_buffer.clear();
370 write_terminate(&mut self.buffer_set.write_buffer);
371 self.stream.write_all(&self.buffer_set.write_buffer).await?;
372 self.stream.flush().await?;
373 Ok(())
374 }
375
376 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
380 self.prepare_typed(query, &[]).await
381 }
382
383 pub async fn prepare_typed(
385 &mut self,
386 query: &str,
387 param_oids: &[u32],
388 ) -> Result<PreparedStatement> {
389 self.name_counter += 1;
390 let idx = self.name_counter;
391 let result = self.prepare_inner(idx, query, param_oids).await;
392 if let Err(e) = &result
393 && e.is_connection_broken()
394 {
395 self.is_broken = true;
396 }
397 result
398 }
399
400 pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
417 if queries.is_empty() {
418 return Ok(Vec::new());
419 }
420
421 let start_idx = self.name_counter + 1;
422 self.name_counter += queries.len() as u64;
423
424 let result = self.prepare_batch_inner(queries, start_idx).await;
425 if let Err(e) = &result
426 && e.is_connection_broken()
427 {
428 self.is_broken = true;
429 }
430 result
431 }
432
433 async fn prepare_batch_inner(
434 &mut self,
435 queries: &[&str],
436 start_idx: u64,
437 ) -> Result<Vec<PreparedStatement>> {
438 use crate::state::batch_prepare::BatchPrepareStateMachine;
439
440 let mut state_machine =
441 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
442
443 loop {
444 match state_machine.step(&mut self.buffer_set)? {
445 Action::ReadMessage => {
446 self.stream.read_message(&mut self.buffer_set).await?;
447 }
448 Action::WriteAndReadMessage => {
449 self.stream.write_all(&self.buffer_set.write_buffer).await?;
450 self.stream.flush().await?;
451 self.stream.read_message(&mut self.buffer_set).await?;
452 }
453 Action::Finished => {
454 self.transaction_status = state_machine.transaction_status();
455 break;
456 }
457 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
458 }
459 }
460
461 Ok(state_machine.take_statements())
462 }
463
464 async fn prepare_inner(
465 &mut self,
466 idx: u64,
467 query: &str,
468 param_oids: &[u32],
469 ) -> Result<PreparedStatement> {
470 let mut handler = DropHandler::new();
471 let mut state_machine = ExtendedQueryStateMachine::prepare(
472 &mut handler,
473 &mut self.buffer_set,
474 idx,
475 query,
476 param_oids,
477 );
478 self.drive(&mut state_machine).await?;
479 state_machine
480 .take_prepared_statement()
481 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
482 }
483
484 pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
490 &mut self,
491 statement: S,
492 params: P,
493 handler: &mut H,
494 ) -> Result<()> {
495 let result = self.exec_inner(&statement, ¶ms, handler).await;
496 if let Err(e) = &result
497 && e.is_connection_broken()
498 {
499 self.is_broken = true;
500 }
501 result
502 }
503
504 async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
505 &mut self,
506 statement: &S,
507 params: &P,
508 handler: &mut H,
509 ) -> Result<()> {
510 let mut state_machine = if statement.needs_parse() {
511 ExtendedQueryStateMachine::execute_sql(
512 handler,
513 &mut self.buffer_set,
514 statement.as_sql().unwrap(),
515 params,
516 )?
517 } else {
518 let stmt = statement.as_prepared().unwrap();
519 ExtendedQueryStateMachine::execute(
520 handler,
521 &mut self.buffer_set,
522 &stmt.wire_name(),
523 &stmt.param_oids,
524 params,
525 )?
526 };
527
528 self.drive(&mut state_machine).await
529 }
530
531 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
535 &mut self,
536 statement: S,
537 params: P,
538 ) -> Result<Option<u64>> {
539 let mut handler = DropHandler::new();
540 self.exec(statement, params, &mut handler).await?;
541 Ok(handler.rows_affected())
542 }
543
544 pub async fn exec_collect<
548 T: for<'a> crate::conversion::FromRow<'a>,
549 S: IntoStatement,
550 P: ToParams,
551 >(
552 &mut self,
553 statement: S,
554 params: P,
555 ) -> Result<Vec<T>> {
556 let mut handler = crate::handler::CollectHandler::<T>::new();
557 self.exec(statement, params, &mut handler).await?;
558 Ok(handler.into_rows())
559 }
560
561 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
592 &mut self,
593 statement: S,
594 params_list: &[P],
595 ) -> Result<()> {
596 self.exec_batch_chunked(statement, params_list, 1000).await
597 }
598
599 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
603 &mut self,
604 statement: S,
605 params_list: &[P],
606 chunk_size: usize,
607 ) -> Result<()> {
608 let result = self
609 .exec_batch_inner(&statement, params_list, chunk_size)
610 .await;
611 if let Err(e) = &result
612 && e.is_connection_broken()
613 {
614 self.is_broken = true;
615 }
616 result
617 }
618
619 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
620 &mut self,
621 statement: &S,
622 params_list: &[P],
623 chunk_size: usize,
624 ) -> Result<()> {
625 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
626 use crate::state::extended::BatchStateMachine;
627
628 if params_list.is_empty() {
629 return Ok(());
630 }
631
632 let chunk_size = chunk_size.max(1);
633 let needs_parse = statement.needs_parse();
634 let sql = statement.as_sql();
635 let prepared = statement.as_prepared();
636
637 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
639 stmt.param_oids.clone()
640 } else {
641 params_list[0].natural_oids()
642 };
643
644 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
646
647 for chunk in params_list.chunks(chunk_size) {
648 self.buffer_set.write_buffer.clear();
649
650 let parse_in_chunk = needs_parse;
652 if parse_in_chunk {
653 write_parse(
654 &mut self.buffer_set.write_buffer,
655 "",
656 sql.unwrap(),
657 ¶m_oids,
658 );
659 }
660
661 for params in chunk {
663 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
664 write_bind(
665 &mut self.buffer_set.write_buffer,
666 "",
667 effective_stmt_name,
668 params,
669 ¶m_oids,
670 )?;
671 write_execute(&mut self.buffer_set.write_buffer, "", 0);
672 }
673
674 write_sync(&mut self.buffer_set.write_buffer);
676
677 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
679 self.drive_batch(&mut state_machine).await?;
680 self.transaction_status = state_machine.transaction_status();
681 }
682
683 Ok(())
684 }
685
686 async fn drive_batch(
688 &mut self,
689 state_machine: &mut crate::state::extended::BatchStateMachine,
690 ) -> Result<()> {
691 use crate::protocol::backend::{ReadyForQuery, msg_type};
692 use crate::state::action::Action;
693
694 loop {
695 let step_result = state_machine.step(&mut self.buffer_set);
696 match step_result {
697 Ok(Action::ReadMessage) => {
698 self.stream.read_message(&mut self.buffer_set).await?;
699 }
700 Ok(Action::WriteAndReadMessage) => {
701 self.stream.write_all(&self.buffer_set.write_buffer).await?;
702 self.stream.flush().await?;
703 self.stream.read_message(&mut self.buffer_set).await?;
704 }
705 Ok(Action::Finished) => {
706 break;
707 }
708 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
709 Err(e) => {
710 loop {
712 self.stream.read_message(&mut self.buffer_set).await?;
713 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
714 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
715 self.transaction_status =
716 ready.transaction_status().unwrap_or_default();
717 break;
718 }
719 }
720 return Err(e);
721 }
722 }
723 }
724 Ok(())
725 }
726
727 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
729 let result = self.close_statement_inner(&stmt.wire_name()).await;
730 if let Err(e) = &result
731 && e.is_connection_broken()
732 {
733 self.is_broken = true;
734 }
735 result
736 }
737
738 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
739 let mut handler = DropHandler::new();
740 let mut state_machine =
741 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
742 self.drive(&mut state_machine).await
743 }
744
745 pub async fn lowlevel_flush(&mut self) -> Result<()> {
753 use crate::protocol::frontend::write_flush;
754
755 self.buffer_set.write_buffer.clear();
756 write_flush(&mut self.buffer_set.write_buffer);
757
758 self.stream.write_all(&self.buffer_set.write_buffer).await?;
759 self.stream.flush().await?;
760 Ok(())
761 }
762
763 pub async fn lowlevel_sync(&mut self) -> Result<()> {
770 let result = self.lowlevel_sync_inner().await;
771 if let Err(e) = &result
772 && e.is_connection_broken()
773 {
774 self.is_broken = true;
775 }
776 result
777 }
778
779 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
780 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
781 use crate::protocol::frontend::write_sync;
782
783 self.buffer_set.write_buffer.clear();
784 write_sync(&mut self.buffer_set.write_buffer);
785
786 self.stream.write_all(&self.buffer_set.write_buffer).await?;
787 self.stream.flush().await?;
788
789 let mut pending_error: Option<Error> = None;
790
791 loop {
792 self.stream.read_message(&mut self.buffer_set).await?;
793 let type_byte = self.buffer_set.type_byte;
794
795 if RawMessage::is_async_type(type_byte) {
796 continue;
797 }
798
799 match type_byte {
800 msg_type::READY_FOR_QUERY => {
801 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
802 self.transaction_status = ready.transaction_status().unwrap_or_default();
803 if let Some(e) = pending_error {
804 return Err(e);
805 }
806 return Ok(());
807 }
808 msg_type::ERROR_RESPONSE => {
809 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
810 pending_error = Some(error.into_error());
811 }
812 _ => {
813 }
815 }
816 }
817 }
818
819 pub async fn lowlevel_bind<P: ToParams>(
829 &mut self,
830 portal: &str,
831 statement_name: &str,
832 params: P,
833 ) -> Result<()> {
834 let result = self
835 .lowlevel_bind_inner(portal, statement_name, ¶ms)
836 .await;
837 if let Err(e) = &result
838 && e.is_connection_broken()
839 {
840 self.is_broken = true;
841 }
842 result
843 }
844
845 async fn lowlevel_bind_inner<P: ToParams>(
846 &mut self,
847 portal: &str,
848 statement_name: &str,
849 params: &P,
850 ) -> Result<()> {
851 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
852 use crate::protocol::frontend::{write_bind, write_flush};
853
854 let param_oids = params.natural_oids();
855 self.buffer_set.write_buffer.clear();
856 write_bind(
857 &mut self.buffer_set.write_buffer,
858 portal,
859 statement_name,
860 params,
861 ¶m_oids,
862 )?;
863 write_flush(&mut self.buffer_set.write_buffer);
864
865 self.stream.write_all(&self.buffer_set.write_buffer).await?;
866 self.stream.flush().await?;
867
868 loop {
869 self.stream.read_message(&mut self.buffer_set).await?;
870 let type_byte = self.buffer_set.type_byte;
871
872 if RawMessage::is_async_type(type_byte) {
873 continue;
874 }
875
876 match type_byte {
877 msg_type::BIND_COMPLETE => {
878 BindComplete::parse(&self.buffer_set.read_buffer)?;
879 return Ok(());
880 }
881 msg_type::ERROR_RESPONSE => {
882 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
883 return Err(error.into_error());
884 }
885 _ => {
886 return Err(Error::Protocol(format!(
887 "Expected BindComplete or ErrorResponse, got '{}'",
888 type_byte as char
889 )));
890 }
891 }
892 }
893 }
894
895 pub async fn lowlevel_execute<H: BinaryHandler>(
908 &mut self,
909 portal: &str,
910 max_rows: u32,
911 handler: &mut H,
912 ) -> Result<bool> {
913 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
914 if let Err(e) = &result
915 && e.is_connection_broken()
916 {
917 self.is_broken = true;
918 }
919 result
920 }
921
922 async fn lowlevel_execute_inner<H: BinaryHandler>(
923 &mut self,
924 portal: &str,
925 max_rows: u32,
926 handler: &mut H,
927 ) -> Result<bool> {
928 use crate::protocol::backend::{
929 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
930 RowDescription, msg_type,
931 };
932 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
933
934 self.buffer_set.write_buffer.clear();
935 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
936 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
937 write_flush(&mut self.buffer_set.write_buffer);
938
939 self.stream.write_all(&self.buffer_set.write_buffer).await?;
940 self.stream.flush().await?;
941
942 let mut column_buffer: Vec<u8> = Vec::new();
943
944 loop {
945 self.stream.read_message(&mut self.buffer_set).await?;
946 let type_byte = self.buffer_set.type_byte;
947
948 if RawMessage::is_async_type(type_byte) {
949 continue;
950 }
951
952 match type_byte {
953 msg_type::ROW_DESCRIPTION => {
954 column_buffer.clear();
955 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
956 let cols = RowDescription::parse(&column_buffer)?;
957 handler.result_start(cols)?;
958 }
959 msg_type::NO_DATA => {
960 NoData::parse(&self.buffer_set.read_buffer)?;
961 }
962 msg_type::DATA_ROW => {
963 let cols = RowDescription::parse(&column_buffer)?;
964 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
965 handler.row(cols, row)?;
966 }
967 msg_type::COMMAND_COMPLETE => {
968 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
969 handler.result_end(complete)?;
970 return Ok(false); }
972 msg_type::PORTAL_SUSPENDED => {
973 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
974 return Ok(true); }
976 msg_type::ERROR_RESPONSE => {
977 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
978 return Err(error.into_error());
979 }
980 _ => {
981 return Err(Error::Protocol(format!(
982 "Unexpected message in execute: '{}'",
983 type_byte as char
984 )));
985 }
986 }
987 }
988 }
989
990 pub async fn exec_iter<S: IntoStatement, P, F, Fut, T>(
1020 &mut self,
1021 statement: S,
1022 params: P,
1023 f: F,
1024 ) -> Result<T>
1025 where
1026 P: ToParams,
1027 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1028 Fut: std::future::Future<Output = Result<T>>,
1029 {
1030 let result = self.exec_iter_inner(&statement, ¶ms, f).await;
1031 if let Err(e) = &result
1032 && e.is_connection_broken()
1033 {
1034 self.is_broken = true;
1035 }
1036 result
1037 }
1038
1039 async fn exec_iter_inner<S: IntoStatement, P, F, Fut, T>(
1040 &mut self,
1041 statement: &S,
1042 params: &P,
1043 f: F,
1044 ) -> Result<T>
1045 where
1046 P: ToParams,
1047 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1048 Fut: std::future::Future<Output = Result<T>>,
1049 {
1050 let mut state_machine = if let Some(sql) = statement.as_sql() {
1052 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1053 } else {
1054 let stmt = statement.as_prepared().unwrap();
1055 BindStateMachine::bind_prepared(
1056 &mut self.buffer_set,
1057 "",
1058 &stmt.wire_name(),
1059 &stmt.param_oids,
1060 params,
1061 )?
1062 };
1063
1064 loop {
1066 match state_machine.step(&mut self.buffer_set)? {
1067 Action::ReadMessage => {
1068 self.stream.read_message(&mut self.buffer_set).await?;
1069 }
1070 Action::Write => {
1071 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1072 self.stream.flush().await?;
1073 }
1074 Action::WriteAndReadMessage => {
1075 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1076 self.stream.flush().await?;
1077 self.stream.read_message(&mut self.buffer_set).await?;
1078 }
1079 Action::Finished => break,
1080 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1081 }
1082 }
1083
1084 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1086 let result = f(&mut portal).await;
1087
1088 let sync_result = portal.conn.lowlevel_sync().await;
1090
1091 match (result, sync_result) {
1093 (Ok(v), Ok(())) => Ok(v),
1094 (Err(e), _) => Err(e),
1095 (Ok(_), Err(e)) => Err(e),
1096 }
1097 }
1098
1099 pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1101 let result = self.lowlevel_close_portal_inner(portal).await;
1102 if let Err(e) = &result
1103 && e.is_connection_broken()
1104 {
1105 self.is_broken = true;
1106 }
1107 result
1108 }
1109
1110 async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1111 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1112 use crate::protocol::frontend::{write_close_portal, write_flush};
1113
1114 self.buffer_set.write_buffer.clear();
1115 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1116 write_flush(&mut self.buffer_set.write_buffer);
1117
1118 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1119 self.stream.flush().await?;
1120
1121 loop {
1122 self.stream.read_message(&mut self.buffer_set).await?;
1123 let type_byte = self.buffer_set.type_byte;
1124
1125 if RawMessage::is_async_type(type_byte) {
1126 continue;
1127 }
1128
1129 match type_byte {
1130 msg_type::CLOSE_COMPLETE => {
1131 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1132 return Ok(());
1133 }
1134 msg_type::ERROR_RESPONSE => {
1135 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1136 return Err(error.into_error());
1137 }
1138 _ => {
1139 return Err(Error::Protocol(format!(
1140 "Expected CloseComplete or ErrorResponse, got '{}'",
1141 type_byte as char
1142 )));
1143 }
1144 }
1145 }
1146 }
1147
1148 pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1179 where
1180 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1181 Fut: std::future::Future<Output = Result<T>>,
1182 {
1183 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1184 let result = f(&mut pipeline).await;
1185 pipeline.cleanup().await;
1186 result
1187 }
1188
1189 pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1199 where
1200 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1201 Fut: std::future::Future<Output = Result<R>>,
1202 {
1203 if self.in_transaction() {
1204 return Err(Error::InvalidUsage(
1205 "nested transactions are not supported".into(),
1206 ));
1207 }
1208
1209 self.query_drop("BEGIN").await?;
1210
1211 let tx = super::transaction::Transaction::new(self.connection_id());
1212
1213 let result = f(self, tx).await;
1216
1217 if self.in_transaction() {
1219 let rollback_result = self.query_drop("ROLLBACK").await;
1220
1221 if let Err(e) = result {
1223 return Err(e);
1224 }
1225 rollback_result?;
1226 }
1227
1228 result
1229 }
1230}