1use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25
/// An asynchronous PostgreSQL client connection.
///
/// Owns the transport stream and a set of read/write buffers, and caches
/// per-connection state reported by the server (backend key, server parameters,
/// transaction status).
pub struct Conn {
    /// Underlying transport: TCP, Unix-domain socket, or a TLS-upgraded stream,
    /// depending on how the connection was established.
    pub(crate) stream: Stream,
    /// Read/write buffers, obtained from the connection options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup, if the server sent any.
    backend_key: Option<BackendKeyData>,
    /// Server parameters collected during startup, as `(name, value)` pairs.
    server_params: Vec<(String, String)>,
    /// Transaction status as last reported by the server.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error whose `is_connection_broken()`
    /// returns true; the connection should not be reused afterwards.
    pub(crate) is_broken: bool,
    /// Monotonic counter used to generate unique prepared-statement and portal names.
    name_counter: u64,
    /// Optional user callback for async messages received while queries are driven.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41 where
42 Error: From<O::Error>,
43 {
44 let opts = opts.try_into()?;
45
46 let stream = if let Some(socket_path) = &opts.socket {
47 #[cfg(unix)]
48 {
49 Stream::unix(UnixStream::connect(socket_path).await?)
50 }
51 #[cfg(not(unix))]
52 {
53 let _ = socket_path;
54 return Err(Error::Unsupported(
55 "Unix sockets are not supported on this platform".into(),
56 ));
57 }
58 } else {
59 if opts.host.is_empty() {
60 return Err(Error::InvalidUsage("host is empty".into()));
61 }
62 let addr = format!("{}:{}", opts.host, opts.port);
63 let tcp = TcpStream::connect(&addr).await?;
64 tcp.set_nodelay(true)?;
65 Stream::tcp(tcp)
66 };
67
68 Self::new_with_stream(stream, opts).await
69 }
70
71 #[allow(unused_mut)]
73 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
74 let mut buffer_set = options.buffer_pool.get_buffer_set();
75 let mut state_machine = ConnectionStateMachine::new(options.clone());
76
77 loop {
79 match state_machine.step(&mut buffer_set)? {
80 Action::WriteAndReadByte => {
81 stream.write_all(&buffer_set.write_buffer).await?;
82 stream.flush().await?;
83 let byte = stream.read_u8().await?;
84 state_machine.set_ssl_response(byte);
85 }
86 Action::ReadMessage => {
87 stream.read_message(&mut buffer_set).await?;
88 }
89 Action::Write => {
90 stream.write_all(&buffer_set.write_buffer).await?;
91 stream.flush().await?;
92 }
93 Action::WriteAndReadMessage => {
94 stream.write_all(&buffer_set.write_buffer).await?;
95 stream.flush().await?;
96 stream.read_message(&mut buffer_set).await?;
97 }
98 Action::TlsHandshake => {
99 #[cfg(feature = "tokio-tls")]
100 {
101 stream = stream.upgrade_to_tls(&options.host).await?;
102 }
103 #[cfg(not(feature = "tokio-tls"))]
104 {
105 return Err(Error::Unsupported(
106 "TLS requested but tokio-tls feature not enabled".into(),
107 ));
108 }
109 }
110 Action::HandleAsyncMessageAndReadMessage(_) => {
111 stream.read_message(&mut buffer_set).await?;
113 }
114 Action::Finished => break,
115 }
116 }
117
118 let conn = Self {
119 stream,
120 buffer_set,
121 backend_key: state_machine.backend_key().cloned(),
122 server_params: state_machine.take_server_params(),
123 transaction_status: state_machine.transaction_status(),
124 is_broken: false,
125 name_counter: 0,
126 async_message_handler: None,
127 };
128
129 #[cfg(unix)]
131 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
132 conn.try_upgrade_to_unix_socket(&options).await
133 } else {
134 conn
135 };
136
137 Ok(conn)
138 }
139
140 #[cfg(unix)]
143 fn try_upgrade_to_unix_socket(
144 mut self,
145 opts: &Opts,
146 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
147 let opts = opts.clone();
148 Box::pin(async move {
149 let mut handler = FirstRowHandler::<(String,)>::new();
151 if self
152 .query("SHOW unix_socket_directories", &mut handler)
153 .await
154 .is_err()
155 {
156 return self;
157 }
158
159 let socket_dir = match handler.into_row() {
160 Some((dirs,)) => {
161 match dirs.split(',').next() {
163 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
164 _ => return self,
165 }
166 }
167 None => return self,
168 };
169
170 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
172
173 let unix_stream = match UnixStream::connect(&socket_path).await {
175 Ok(s) => s,
176 Err(_) => return self,
177 };
178
179 let mut opts_unix = opts.clone();
181 opts_unix.upgrade_to_unix_socket = false;
182
183 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
184 Ok(new_conn) => new_conn,
185 Err(_) => self,
186 }
187 })
188 }
189
190 pub fn backend_key(&self) -> Option<&BackendKeyData> {
192 self.backend_key.as_ref()
193 }
194
195 pub fn connection_id(&self) -> u32 {
199 self.backend_key.as_ref().map_or(0, |k| k.process_id())
200 }
201
202 pub fn server_params(&self) -> &[(String, String)] {
204 &self.server_params
205 }
206
207 pub fn transaction_status(&self) -> TransactionStatus {
209 self.transaction_status
210 }
211
212 pub fn in_transaction(&self) -> bool {
214 self.transaction_status.in_transaction()
215 }
216
217 pub fn is_broken(&self) -> bool {
219 self.is_broken
220 }
221
222 pub(crate) fn next_portal_name(&mut self) -> String {
224 self.name_counter += 1;
225 format!("_zero_p_{}", self.name_counter)
226 }
227
228 pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
232 &mut self,
233 portal_name: &str,
234 statement: &S,
235 params: &P,
236 ) -> Result<()> {
237 let mut state_machine = if let Some(sql) = statement.as_sql() {
239 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
240 } else {
241 let stmt = statement.as_prepared().unwrap();
242 BindStateMachine::bind_prepared(
243 &mut self.buffer_set,
244 portal_name,
245 &stmt.wire_name(),
246 &stmt.param_oids,
247 params,
248 )?
249 };
250
251 loop {
253 match state_machine.step(&mut self.buffer_set)? {
254 Action::ReadMessage => {
255 self.stream.read_message(&mut self.buffer_set).await?;
256 }
257 Action::Write => {
258 self.stream.write_all(&self.buffer_set.write_buffer).await?;
259 self.stream.flush().await?;
260 }
261 Action::WriteAndReadMessage => {
262 self.stream.write_all(&self.buffer_set.write_buffer).await?;
263 self.stream.flush().await?;
264 self.stream.read_message(&mut self.buffer_set).await?;
265 }
266 Action::Finished => break,
267 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
268 }
269 }
270
271 Ok(())
272 }
273
274 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
281 self.async_message_handler = Some(Box::new(handler));
282 }
283
284 pub fn clear_async_message_handler(&mut self) {
286 self.async_message_handler = None;
287 }
288
289 pub async fn ping(&mut self) -> Result<()> {
291 self.query_drop("").await?;
292 Ok(())
293 }
294
295 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
297 loop {
298 match state_machine.step(&mut self.buffer_set)? {
299 Action::WriteAndReadByte => {
300 return Err(Error::Protocol(
301 "Unexpected WriteAndReadByte in query state machine".into(),
302 ));
303 }
304 Action::ReadMessage => {
305 self.stream.read_message(&mut self.buffer_set).await?;
306 }
307 Action::Write => {
308 self.stream.write_all(&self.buffer_set.write_buffer).await?;
309 self.stream.flush().await?;
310 }
311 Action::WriteAndReadMessage => {
312 self.stream.write_all(&self.buffer_set.write_buffer).await?;
313 self.stream.flush().await?;
314 self.stream.read_message(&mut self.buffer_set).await?;
315 }
316 Action::TlsHandshake => {
317 return Err(Error::Protocol(
318 "Unexpected TlsHandshake in query state machine".into(),
319 ));
320 }
321 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
322 if let Some(ref mut h) = self.async_message_handler {
323 h.handle(async_msg);
324 }
325 self.stream.read_message(&mut self.buffer_set).await?;
327 }
328 Action::Finished => {
329 self.transaction_status = state_machine.transaction_status();
330 break;
331 }
332 }
333 }
334 Ok(())
335 }
336
337 pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
339 let result = self.query_inner(sql, handler).await;
340 if let Err(e) = &result
341 && e.is_connection_broken()
342 {
343 self.is_broken = true;
344 }
345 result
346 }
347
348 async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
350 self.drive(&mut state_machine).await
351 }
352
353 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
355 let mut handler = DropHandler::new();
356 self.query(sql, &mut handler).await?;
357 Ok(handler.rows_affected())
358 }
359
360 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
362 &mut self,
363 sql: &str,
364 ) -> Result<Vec<T>> {
365 let mut handler = crate::handler::CollectHandler::<T>::new();
366 self.query(sql, &mut handler).await?;
367 Ok(handler.into_rows())
368 }
369
370 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
372 &mut self,
373 sql: &str,
374 ) -> Result<Option<T>> {
375 let mut handler = crate::handler::FirstRowHandler::<T>::new();
376 self.query(sql, &mut handler).await?;
377 Ok(handler.into_row())
378 }
379
380 pub async fn query_foreach<T: for<'a> crate::conversion::FromRow<'a>, F: FnMut(T)>(
390 &mut self,
391 sql: &str,
392 f: F,
393 ) -> Result<()> {
394 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
395 self.query(sql, &mut handler).await?;
396 Ok(())
397 }
398
399 pub async fn close(mut self) -> Result<()> {
401 self.buffer_set.write_buffer.clear();
402 write_terminate(&mut self.buffer_set.write_buffer);
403 self.stream.write_all(&self.buffer_set.write_buffer).await?;
404 self.stream.flush().await?;
405 Ok(())
406 }
407
408 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
412 self.prepare_typed(query, &[]).await
413 }
414
415 pub async fn prepare_typed(
417 &mut self,
418 query: &str,
419 param_oids: &[u32],
420 ) -> Result<PreparedStatement> {
421 self.name_counter += 1;
422 let idx = self.name_counter;
423 let result = self.prepare_inner(idx, query, param_oids).await;
424 if let Err(e) = &result
425 && e.is_connection_broken()
426 {
427 self.is_broken = true;
428 }
429 result
430 }
431
432 pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
449 if queries.is_empty() {
450 return Ok(Vec::new());
451 }
452
453 let start_idx = self.name_counter + 1;
454 self.name_counter += queries.len() as u64;
455
456 let result = self.prepare_batch_inner(queries, start_idx).await;
457 if let Err(e) = &result
458 && e.is_connection_broken()
459 {
460 self.is_broken = true;
461 }
462 result
463 }
464
465 async fn prepare_batch_inner(
466 &mut self,
467 queries: &[&str],
468 start_idx: u64,
469 ) -> Result<Vec<PreparedStatement>> {
470 use crate::state::batch_prepare::BatchPrepareStateMachine;
471
472 let mut state_machine =
473 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
474
475 loop {
476 match state_machine.step(&mut self.buffer_set)? {
477 Action::ReadMessage => {
478 self.stream.read_message(&mut self.buffer_set).await?;
479 }
480 Action::WriteAndReadMessage => {
481 self.stream.write_all(&self.buffer_set.write_buffer).await?;
482 self.stream.flush().await?;
483 self.stream.read_message(&mut self.buffer_set).await?;
484 }
485 Action::Finished => {
486 self.transaction_status = state_machine.transaction_status();
487 break;
488 }
489 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
490 }
491 }
492
493 Ok(state_machine.take_statements())
494 }
495
496 async fn prepare_inner(
497 &mut self,
498 idx: u64,
499 query: &str,
500 param_oids: &[u32],
501 ) -> Result<PreparedStatement> {
502 let mut handler = DropHandler::new();
503 let mut state_machine = ExtendedQueryStateMachine::prepare(
504 &mut handler,
505 &mut self.buffer_set,
506 idx,
507 query,
508 param_oids,
509 );
510 self.drive(&mut state_machine).await?;
511 state_machine
512 .take_prepared_statement()
513 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
514 }
515
516 pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
522 &mut self,
523 statement: S,
524 params: P,
525 handler: &mut H,
526 ) -> Result<()> {
527 let result = self.exec_inner(&statement, ¶ms, handler).await;
528 if let Err(e) = &result
529 && e.is_connection_broken()
530 {
531 self.is_broken = true;
532 }
533 result
534 }
535
536 async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
537 &mut self,
538 statement: &S,
539 params: &P,
540 handler: &mut H,
541 ) -> Result<()> {
542 let mut state_machine = if statement.needs_parse() {
543 ExtendedQueryStateMachine::execute_sql(
544 handler,
545 &mut self.buffer_set,
546 statement.as_sql().unwrap(),
547 params,
548 )?
549 } else {
550 let stmt = statement.as_prepared().unwrap();
551 ExtendedQueryStateMachine::execute(
552 handler,
553 &mut self.buffer_set,
554 &stmt.wire_name(),
555 &stmt.param_oids,
556 params,
557 )?
558 };
559
560 self.drive(&mut state_machine).await
561 }
562
563 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
567 &mut self,
568 statement: S,
569 params: P,
570 ) -> Result<Option<u64>> {
571 let mut handler = DropHandler::new();
572 self.exec(statement, params, &mut handler).await?;
573 Ok(handler.rows_affected())
574 }
575
576 pub async fn exec_collect<
580 T: for<'a> crate::conversion::FromRow<'a>,
581 S: IntoStatement,
582 P: ToParams,
583 >(
584 &mut self,
585 statement: S,
586 params: P,
587 ) -> Result<Vec<T>> {
588 let mut handler = crate::handler::CollectHandler::<T>::new();
589 self.exec(statement, params, &mut handler).await?;
590 Ok(handler.into_rows())
591 }
592
593 pub async fn exec_first<
597 T: for<'a> crate::conversion::FromRow<'a>,
598 S: IntoStatement,
599 P: ToParams,
600 >(
601 &mut self,
602 statement: S,
603 params: P,
604 ) -> Result<Option<T>> {
605 let mut handler = crate::handler::FirstRowHandler::<T>::new();
606 self.exec(statement, params, &mut handler).await?;
607 Ok(handler.into_row())
608 }
609
610 pub async fn exec_foreach<
623 T: for<'a> crate::conversion::FromRow<'a>,
624 S: IntoStatement,
625 P: ToParams,
626 F: FnMut(T),
627 >(
628 &mut self,
629 statement: S,
630 params: P,
631 f: F,
632 ) -> Result<()> {
633 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
634 self.exec(statement, params, &mut handler).await?;
635 Ok(())
636 }
637
638 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
669 &mut self,
670 statement: S,
671 params_list: &[P],
672 ) -> Result<()> {
673 self.exec_batch_chunked(statement, params_list, 1000).await
674 }
675
676 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
680 &mut self,
681 statement: S,
682 params_list: &[P],
683 chunk_size: usize,
684 ) -> Result<()> {
685 let result = self
686 .exec_batch_inner(&statement, params_list, chunk_size)
687 .await;
688 if let Err(e) = &result
689 && e.is_connection_broken()
690 {
691 self.is_broken = true;
692 }
693 result
694 }
695
696 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
697 &mut self,
698 statement: &S,
699 params_list: &[P],
700 chunk_size: usize,
701 ) -> Result<()> {
702 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
703 use crate::state::extended::BatchStateMachine;
704
705 if params_list.is_empty() {
706 return Ok(());
707 }
708
709 let chunk_size = chunk_size.max(1);
710 let needs_parse = statement.needs_parse();
711 let sql = statement.as_sql();
712 let prepared = statement.as_prepared();
713
714 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
716 stmt.param_oids.clone()
717 } else {
718 params_list[0].natural_oids()
719 };
720
721 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
723
724 for chunk in params_list.chunks(chunk_size) {
725 self.buffer_set.write_buffer.clear();
726
727 let parse_in_chunk = needs_parse;
729 if parse_in_chunk {
730 write_parse(
731 &mut self.buffer_set.write_buffer,
732 "",
733 sql.unwrap(),
734 ¶m_oids,
735 );
736 }
737
738 for params in chunk {
740 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
741 write_bind(
742 &mut self.buffer_set.write_buffer,
743 "",
744 effective_stmt_name,
745 params,
746 ¶m_oids,
747 )?;
748 write_execute(&mut self.buffer_set.write_buffer, "", 0);
749 }
750
751 write_sync(&mut self.buffer_set.write_buffer);
753
754 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
756 self.drive_batch(&mut state_machine).await?;
757 self.transaction_status = state_machine.transaction_status();
758 }
759
760 Ok(())
761 }
762
763 async fn drive_batch(
765 &mut self,
766 state_machine: &mut crate::state::extended::BatchStateMachine,
767 ) -> Result<()> {
768 use crate::protocol::backend::{ReadyForQuery, msg_type};
769 use crate::state::action::Action;
770
771 loop {
772 let step_result = state_machine.step(&mut self.buffer_set);
773 match step_result {
774 Ok(Action::ReadMessage) => {
775 self.stream.read_message(&mut self.buffer_set).await?;
776 }
777 Ok(Action::WriteAndReadMessage) => {
778 self.stream.write_all(&self.buffer_set.write_buffer).await?;
779 self.stream.flush().await?;
780 self.stream.read_message(&mut self.buffer_set).await?;
781 }
782 Ok(Action::Finished) => {
783 break;
784 }
785 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
786 Err(e) => {
787 loop {
789 self.stream.read_message(&mut self.buffer_set).await?;
790 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
791 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
792 self.transaction_status =
793 ready.transaction_status().unwrap_or_default();
794 break;
795 }
796 }
797 return Err(e);
798 }
799 }
800 }
801 Ok(())
802 }
803
804 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
806 let result = self.close_statement_inner(&stmt.wire_name()).await;
807 if let Err(e) = &result
808 && e.is_connection_broken()
809 {
810 self.is_broken = true;
811 }
812 result
813 }
814
815 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
816 let mut handler = DropHandler::new();
817 let mut state_machine =
818 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
819 self.drive(&mut state_machine).await
820 }
821
822 pub async fn lowlevel_flush(&mut self) -> Result<()> {
830 use crate::protocol::frontend::write_flush;
831
832 self.buffer_set.write_buffer.clear();
833 write_flush(&mut self.buffer_set.write_buffer);
834
835 self.stream.write_all(&self.buffer_set.write_buffer).await?;
836 self.stream.flush().await?;
837 Ok(())
838 }
839
840 pub async fn lowlevel_sync(&mut self) -> Result<()> {
847 let result = self.lowlevel_sync_inner().await;
848 if let Err(e) = &result
849 && e.is_connection_broken()
850 {
851 self.is_broken = true;
852 }
853 result
854 }
855
856 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
857 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
858 use crate::protocol::frontend::write_sync;
859
860 self.buffer_set.write_buffer.clear();
861 write_sync(&mut self.buffer_set.write_buffer);
862
863 self.stream.write_all(&self.buffer_set.write_buffer).await?;
864 self.stream.flush().await?;
865
866 let mut pending_error: Option<Error> = None;
867
868 loop {
869 self.stream.read_message(&mut self.buffer_set).await?;
870 let type_byte = self.buffer_set.type_byte;
871
872 if RawMessage::is_async_type(type_byte) {
873 continue;
874 }
875
876 match type_byte {
877 msg_type::READY_FOR_QUERY => {
878 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
879 self.transaction_status = ready.transaction_status().unwrap_or_default();
880 if let Some(e) = pending_error {
881 return Err(e);
882 }
883 return Ok(());
884 }
885 msg_type::ERROR_RESPONSE => {
886 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
887 pending_error = Some(error.into_error());
888 }
889 _ => {
890 }
892 }
893 }
894 }
895
896 pub async fn lowlevel_bind<P: ToParams>(
906 &mut self,
907 portal: &str,
908 statement_name: &str,
909 params: P,
910 ) -> Result<()> {
911 let result = self
912 .lowlevel_bind_inner(portal, statement_name, ¶ms)
913 .await;
914 if let Err(e) = &result
915 && e.is_connection_broken()
916 {
917 self.is_broken = true;
918 }
919 result
920 }
921
922 async fn lowlevel_bind_inner<P: ToParams>(
923 &mut self,
924 portal: &str,
925 statement_name: &str,
926 params: &P,
927 ) -> Result<()> {
928 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
929 use crate::protocol::frontend::{write_bind, write_flush};
930
931 let param_oids = params.natural_oids();
932 self.buffer_set.write_buffer.clear();
933 write_bind(
934 &mut self.buffer_set.write_buffer,
935 portal,
936 statement_name,
937 params,
938 ¶m_oids,
939 )?;
940 write_flush(&mut self.buffer_set.write_buffer);
941
942 self.stream.write_all(&self.buffer_set.write_buffer).await?;
943 self.stream.flush().await?;
944
945 loop {
946 self.stream.read_message(&mut self.buffer_set).await?;
947 let type_byte = self.buffer_set.type_byte;
948
949 if RawMessage::is_async_type(type_byte) {
950 continue;
951 }
952
953 match type_byte {
954 msg_type::BIND_COMPLETE => {
955 BindComplete::parse(&self.buffer_set.read_buffer)?;
956 return Ok(());
957 }
958 msg_type::ERROR_RESPONSE => {
959 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
960 return Err(error.into_error());
961 }
962 _ => {
963 return Err(Error::Protocol(format!(
964 "Expected BindComplete or ErrorResponse, got '{}'",
965 type_byte as char
966 )));
967 }
968 }
969 }
970 }
971
972 pub async fn lowlevel_execute<H: BinaryHandler>(
985 &mut self,
986 portal: &str,
987 max_rows: u32,
988 handler: &mut H,
989 ) -> Result<bool> {
990 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
991 if let Err(e) = &result
992 && e.is_connection_broken()
993 {
994 self.is_broken = true;
995 }
996 result
997 }
998
999 async fn lowlevel_execute_inner<H: BinaryHandler>(
1000 &mut self,
1001 portal: &str,
1002 max_rows: u32,
1003 handler: &mut H,
1004 ) -> Result<bool> {
1005 use crate::protocol::backend::{
1006 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1007 RowDescription, msg_type,
1008 };
1009 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1010
1011 self.buffer_set.write_buffer.clear();
1012 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1013 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1014 write_flush(&mut self.buffer_set.write_buffer);
1015
1016 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1017 self.stream.flush().await?;
1018
1019 let mut column_buffer: Vec<u8> = Vec::new();
1020
1021 loop {
1022 self.stream.read_message(&mut self.buffer_set).await?;
1023 let type_byte = self.buffer_set.type_byte;
1024
1025 if RawMessage::is_async_type(type_byte) {
1026 continue;
1027 }
1028
1029 match type_byte {
1030 msg_type::ROW_DESCRIPTION => {
1031 column_buffer.clear();
1032 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1033 let cols = RowDescription::parse(&column_buffer)?;
1034 handler.result_start(cols)?;
1035 }
1036 msg_type::NO_DATA => {
1037 NoData::parse(&self.buffer_set.read_buffer)?;
1038 }
1039 msg_type::DATA_ROW => {
1040 let cols = RowDescription::parse(&column_buffer)?;
1041 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1042 handler.row(cols, row)?;
1043 }
1044 msg_type::COMMAND_COMPLETE => {
1045 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1046 handler.result_end(complete)?;
1047 return Ok(false); }
1049 msg_type::PORTAL_SUSPENDED => {
1050 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1051 return Ok(true); }
1053 msg_type::ERROR_RESPONSE => {
1054 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1055 return Err(error.into_error());
1056 }
1057 _ => {
1058 return Err(Error::Protocol(format!(
1059 "Unexpected message in execute: '{}'",
1060 type_byte as char
1061 )));
1062 }
1063 }
1064 }
1065 }
1066
1067 pub async fn exec_portal<S: IntoStatement, P, F, Fut, T>(
1097 &mut self,
1098 statement: S,
1099 params: P,
1100 f: F,
1101 ) -> Result<T>
1102 where
1103 P: ToParams,
1104 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1105 Fut: std::future::Future<Output = Result<T>>,
1106 {
1107 let result = self.exec_portal_inner(&statement, ¶ms, f).await;
1108 if let Err(e) = &result
1109 && e.is_connection_broken()
1110 {
1111 self.is_broken = true;
1112 }
1113 result
1114 }
1115
1116 async fn exec_portal_inner<S: IntoStatement, P, F, Fut, T>(
1117 &mut self,
1118 statement: &S,
1119 params: &P,
1120 f: F,
1121 ) -> Result<T>
1122 where
1123 P: ToParams,
1124 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1125 Fut: std::future::Future<Output = Result<T>>,
1126 {
1127 let mut state_machine = if let Some(sql) = statement.as_sql() {
1129 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1130 } else {
1131 let stmt = statement.as_prepared().unwrap();
1132 BindStateMachine::bind_prepared(
1133 &mut self.buffer_set,
1134 "",
1135 &stmt.wire_name(),
1136 &stmt.param_oids,
1137 params,
1138 )?
1139 };
1140
1141 loop {
1143 match state_machine.step(&mut self.buffer_set)? {
1144 Action::ReadMessage => {
1145 self.stream.read_message(&mut self.buffer_set).await?;
1146 }
1147 Action::Write => {
1148 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1149 self.stream.flush().await?;
1150 }
1151 Action::WriteAndReadMessage => {
1152 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1153 self.stream.flush().await?;
1154 self.stream.read_message(&mut self.buffer_set).await?;
1155 }
1156 Action::Finished => break,
1157 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1158 }
1159 }
1160
1161 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1163 let result = f(&mut portal).await;
1164
1165 let sync_result = portal.conn.lowlevel_sync().await;
1167
1168 match (result, sync_result) {
1170 (Ok(v), Ok(())) => Ok(v),
1171 (Err(e), _) => Err(e),
1172 (Ok(_), Err(e)) => Err(e),
1173 }
1174 }
1175
1176 pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1178 let result = self.lowlevel_close_portal_inner(portal).await;
1179 if let Err(e) = &result
1180 && e.is_connection_broken()
1181 {
1182 self.is_broken = true;
1183 }
1184 result
1185 }
1186
1187 async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1188 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1189 use crate::protocol::frontend::{write_close_portal, write_flush};
1190
1191 self.buffer_set.write_buffer.clear();
1192 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1193 write_flush(&mut self.buffer_set.write_buffer);
1194
1195 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1196 self.stream.flush().await?;
1197
1198 loop {
1199 self.stream.read_message(&mut self.buffer_set).await?;
1200 let type_byte = self.buffer_set.type_byte;
1201
1202 if RawMessage::is_async_type(type_byte) {
1203 continue;
1204 }
1205
1206 match type_byte {
1207 msg_type::CLOSE_COMPLETE => {
1208 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1209 return Ok(());
1210 }
1211 msg_type::ERROR_RESPONSE => {
1212 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1213 return Err(error.into_error());
1214 }
1215 _ => {
1216 return Err(Error::Protocol(format!(
1217 "Expected CloseComplete or ErrorResponse, got '{}'",
1218 type_byte as char
1219 )));
1220 }
1221 }
1222 }
1223 }
1224
1225 pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1256 where
1257 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1258 Fut: std::future::Future<Output = Result<T>>,
1259 {
1260 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1261 let result = f(&mut pipeline).await;
1262 pipeline.cleanup().await;
1263 result
1264 }
1265
1266 pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1276 where
1277 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1278 Fut: std::future::Future<Output = Result<R>>,
1279 {
1280 if self.in_transaction() {
1281 return Err(Error::InvalidUsage(
1282 "nested transactions are not supported".into(),
1283 ));
1284 }
1285
1286 self.query_drop("BEGIN").await?;
1287
1288 let tx = super::transaction::Transaction::new(self.connection_id());
1289
1290 let result = f(self, tx).await;
1293
1294 if self.in_transaction() {
1296 let rollback_result = self.query_drop("ROLLBACK").await;
1297
1298 if let Err(e) = result {
1300 return Err(e);
1301 }
1302 rollback_result?;
1303 }
1304
1305 result
1306 }
1307}