1use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25
/// An asynchronous connection to a PostgreSQL server.
///
/// Owns the transport stream and its I/O buffers, plus the session state
/// captured during startup and updated as statements run.
pub struct Conn {
    /// Underlying transport: TCP or Unix socket, possibly upgraded to TLS
    /// during the startup handshake.
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup, if the server sent any.
    backend_key: Option<BackendKeyData>,
    /// Server parameter (name, value) pairs collected during startup.
    server_params: Vec<(String, String)>,
    /// Transaction status from the most recent server report seen.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error whose
    /// `is_connection_broken()` is true; the connection should not be reused.
    pub(crate) is_broken: bool,
    /// Monotonic counter used to generate unique statement/portal names.
    name_counter: u64,
    /// Optional callback for asynchronous messages seen while driving queries.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41 where
42 Error: From<O::Error>,
43 {
44 let opts = opts.try_into()?;
45
46 let stream = if let Some(socket_path) = &opts.socket {
47 #[cfg(unix)]
48 {
49 Stream::unix(UnixStream::connect(socket_path).await?)
50 }
51 #[cfg(not(unix))]
52 {
53 let _ = socket_path;
54 return Err(Error::Unsupported(
55 "Unix sockets are not supported on this platform".into(),
56 ));
57 }
58 } else {
59 if opts.host.is_empty() {
60 return Err(Error::InvalidUsage("host is empty".into()));
61 }
62 let addr = format!("{}:{}", opts.host, opts.port);
63 let tcp = TcpStream::connect(&addr).await?;
64 tcp.set_nodelay(true)?;
65 Stream::tcp(tcp)
66 };
67
68 Self::new_with_stream(stream, opts).await
69 }
70
71 #[allow(unused_mut)]
73 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
74 let mut buffer_set = options.buffer_pool.get_buffer_set();
75 let mut state_machine = ConnectionStateMachine::new(options.clone());
76
77 loop {
79 match state_machine.step(&mut buffer_set)? {
80 Action::WriteAndReadByte => {
81 stream.write_all(&buffer_set.write_buffer).await?;
82 stream.flush().await?;
83 let byte = stream.read_u8().await?;
84 state_machine.set_ssl_response(byte);
85 }
86 Action::ReadMessage => {
87 stream.read_message(&mut buffer_set).await?;
88 }
89 Action::Write => {
90 stream.write_all(&buffer_set.write_buffer).await?;
91 stream.flush().await?;
92 }
93 Action::WriteAndReadMessage => {
94 stream.write_all(&buffer_set.write_buffer).await?;
95 stream.flush().await?;
96 stream.read_message(&mut buffer_set).await?;
97 }
98 Action::TlsHandshake => {
99 #[cfg(feature = "tokio-tls")]
100 {
101 stream = stream.upgrade_to_tls(&options.host).await?;
102 }
103 #[cfg(not(feature = "tokio-tls"))]
104 {
105 return Err(Error::Unsupported(
106 "TLS requested but tokio-tls feature not enabled".into(),
107 ));
108 }
109 }
110 Action::HandleAsyncMessageAndReadMessage(_) => {
111 stream.read_message(&mut buffer_set).await?;
113 }
114 Action::Finished => break,
115 }
116 }
117
118 let conn = Self {
119 stream,
120 buffer_set,
121 backend_key: state_machine.backend_key().cloned(),
122 server_params: state_machine.take_server_params(),
123 transaction_status: state_machine.transaction_status(),
124 is_broken: false,
125 name_counter: 0,
126 async_message_handler: None,
127 };
128
129 #[cfg(unix)]
131 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
132 conn.try_upgrade_to_unix_socket(&options).await
133 } else {
134 conn
135 };
136
137 Ok(conn)
138 }
139
140 #[cfg(unix)]
143 fn try_upgrade_to_unix_socket(
144 mut self,
145 opts: &Opts,
146 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
147 let opts = opts.clone();
148 Box::pin(async move {
149 let mut handler = FirstRowHandler::<(String,)>::new();
151 if self
152 .query("SHOW unix_socket_directories", &mut handler)
153 .await
154 .is_err()
155 {
156 return self;
157 }
158
159 let socket_dir = match handler.into_row() {
160 Some((dirs,)) => {
161 match dirs.split(',').next() {
163 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
164 _ => return self,
165 }
166 }
167 None => return self,
168 };
169
170 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
172
173 let unix_stream = match UnixStream::connect(&socket_path).await {
175 Ok(s) => s,
176 Err(_) => return self,
177 };
178
179 let mut opts_unix = opts.clone();
181 opts_unix.upgrade_to_unix_socket = false;
182
183 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
184 Ok(new_conn) => new_conn,
185 Err(_) => self,
186 }
187 })
188 }
189
190 pub fn backend_key(&self) -> Option<&BackendKeyData> {
192 self.backend_key.as_ref()
193 }
194
195 pub fn connection_id(&self) -> u32 {
199 self.backend_key.as_ref().map_or(0, |k| k.process_id())
200 }
201
202 pub fn server_params(&self) -> &[(String, String)] {
204 &self.server_params
205 }
206
207 pub fn transaction_status(&self) -> TransactionStatus {
209 self.transaction_status
210 }
211
212 pub fn in_transaction(&self) -> bool {
214 self.transaction_status.in_transaction()
215 }
216
217 pub fn is_broken(&self) -> bool {
219 self.is_broken
220 }
221
222 pub(crate) fn next_portal_name(&mut self) -> String {
224 self.name_counter += 1;
225 format!("_zero_p_{}", self.name_counter)
226 }
227
228 pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
232 &mut self,
233 portal_name: &str,
234 statement: &S,
235 params: &P,
236 ) -> Result<()> {
237 let mut state_machine = if let Some(sql) = statement.as_sql() {
239 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
240 } else {
241 let stmt = statement.as_prepared().unwrap();
242 BindStateMachine::bind_prepared(
243 &mut self.buffer_set,
244 portal_name,
245 &stmt.wire_name(),
246 &stmt.param_oids,
247 params,
248 )?
249 };
250
251 loop {
253 match state_machine.step(&mut self.buffer_set)? {
254 Action::ReadMessage => {
255 self.stream.read_message(&mut self.buffer_set).await?;
256 }
257 Action::Write => {
258 self.stream.write_all(&self.buffer_set.write_buffer).await?;
259 self.stream.flush().await?;
260 }
261 Action::WriteAndReadMessage => {
262 self.stream.write_all(&self.buffer_set.write_buffer).await?;
263 self.stream.flush().await?;
264 self.stream.read_message(&mut self.buffer_set).await?;
265 }
266 Action::Finished => break,
267 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
268 }
269 }
270
271 Ok(())
272 }
273
274 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
281 self.async_message_handler = Some(Box::new(handler));
282 }
283
284 pub fn clear_async_message_handler(&mut self) {
286 self.async_message_handler = None;
287 }
288
289 pub async fn ping(&mut self) -> Result<()> {
291 self.query_drop("").await?;
292 Ok(())
293 }
294
295 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
297 loop {
298 match state_machine.step(&mut self.buffer_set)? {
299 Action::WriteAndReadByte => {
300 return Err(Error::Protocol(
301 "Unexpected WriteAndReadByte in query state machine".into(),
302 ));
303 }
304 Action::ReadMessage => {
305 self.stream.read_message(&mut self.buffer_set).await?;
306 }
307 Action::Write => {
308 self.stream.write_all(&self.buffer_set.write_buffer).await?;
309 self.stream.flush().await?;
310 }
311 Action::WriteAndReadMessage => {
312 self.stream.write_all(&self.buffer_set.write_buffer).await?;
313 self.stream.flush().await?;
314 self.stream.read_message(&mut self.buffer_set).await?;
315 }
316 Action::TlsHandshake => {
317 return Err(Error::Protocol(
318 "Unexpected TlsHandshake in query state machine".into(),
319 ));
320 }
321 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
322 if let Some(ref mut h) = self.async_message_handler {
323 h.handle(async_msg);
324 }
325 self.stream.read_message(&mut self.buffer_set).await?;
327 }
328 Action::Finished => {
329 self.transaction_status = state_machine.transaction_status();
330 break;
331 }
332 }
333 }
334 Ok(())
335 }
336
337 pub async fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
339 let result = self.query_inner(sql, handler).await;
340 if let Err(e) = &result
341 && e.is_connection_broken()
342 {
343 self.is_broken = true;
344 }
345 result
346 }
347
348 async fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
350 self.drive(&mut state_machine).await
351 }
352
353 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
355 let mut handler = DropHandler::new();
356 self.query(sql, &mut handler).await?;
357 Ok(handler.rows_affected())
358 }
359
360 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
362 &mut self,
363 sql: &str,
364 ) -> Result<Vec<T>> {
365 let mut handler = crate::handler::CollectHandler::<T>::new();
366 self.query(sql, &mut handler).await?;
367 Ok(handler.into_rows())
368 }
369
370 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
372 &mut self,
373 sql: &str,
374 ) -> Result<Option<T>> {
375 let mut handler = crate::handler::FirstRowHandler::<T>::new();
376 self.query(sql, &mut handler).await?;
377 Ok(handler.into_row())
378 }
379
380 pub async fn query_foreach<
393 T: for<'a> crate::conversion::FromRow<'a>,
394 F: FnMut(T) -> Result<()>,
395 >(
396 &mut self,
397 sql: &str,
398 f: F,
399 ) -> Result<()> {
400 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
401 self.query(sql, &mut handler).await?;
402 Ok(())
403 }
404
405 pub async fn close(mut self) -> Result<()> {
407 self.buffer_set.write_buffer.clear();
408 write_terminate(&mut self.buffer_set.write_buffer);
409 self.stream.write_all(&self.buffer_set.write_buffer).await?;
410 self.stream.flush().await?;
411 Ok(())
412 }
413
414 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
418 self.prepare_typed(query, &[]).await
419 }
420
421 pub async fn prepare_typed(
423 &mut self,
424 query: &str,
425 param_oids: &[u32],
426 ) -> Result<PreparedStatement> {
427 self.name_counter += 1;
428 let idx = self.name_counter;
429 let result = self.prepare_inner(idx, query, param_oids).await;
430 if let Err(e) = &result
431 && e.is_connection_broken()
432 {
433 self.is_broken = true;
434 }
435 result
436 }
437
438 pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
455 if queries.is_empty() {
456 return Ok(Vec::new());
457 }
458
459 let start_idx = self.name_counter + 1;
460 self.name_counter += queries.len() as u64;
461
462 let result = self.prepare_batch_inner(queries, start_idx).await;
463 if let Err(e) = &result
464 && e.is_connection_broken()
465 {
466 self.is_broken = true;
467 }
468 result
469 }
470
471 async fn prepare_batch_inner(
472 &mut self,
473 queries: &[&str],
474 start_idx: u64,
475 ) -> Result<Vec<PreparedStatement>> {
476 use crate::state::batch_prepare::BatchPrepareStateMachine;
477
478 let mut state_machine =
479 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
480
481 loop {
482 match state_machine.step(&mut self.buffer_set)? {
483 Action::ReadMessage => {
484 self.stream.read_message(&mut self.buffer_set).await?;
485 }
486 Action::WriteAndReadMessage => {
487 self.stream.write_all(&self.buffer_set.write_buffer).await?;
488 self.stream.flush().await?;
489 self.stream.read_message(&mut self.buffer_set).await?;
490 }
491 Action::Finished => {
492 self.transaction_status = state_machine.transaction_status();
493 break;
494 }
495 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
496 }
497 }
498
499 Ok(state_machine.take_statements())
500 }
501
502 async fn prepare_inner(
503 &mut self,
504 idx: u64,
505 query: &str,
506 param_oids: &[u32],
507 ) -> Result<PreparedStatement> {
508 let mut handler = DropHandler::new();
509 let mut state_machine = ExtendedQueryStateMachine::prepare(
510 &mut handler,
511 &mut self.buffer_set,
512 idx,
513 query,
514 param_oids,
515 );
516 self.drive(&mut state_machine).await?;
517 state_machine
518 .take_prepared_statement()
519 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
520 }
521
522 pub async fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
528 &mut self,
529 statement: S,
530 params: P,
531 handler: &mut H,
532 ) -> Result<()> {
533 let result = self.exec_inner(&statement, ¶ms, handler).await;
534 if let Err(e) = &result
535 && e.is_connection_broken()
536 {
537 self.is_broken = true;
538 }
539 result
540 }
541
542 async fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
543 &mut self,
544 statement: &S,
545 params: &P,
546 handler: &mut H,
547 ) -> Result<()> {
548 let mut state_machine = if statement.needs_parse() {
549 ExtendedQueryStateMachine::execute_sql(
550 handler,
551 &mut self.buffer_set,
552 statement.as_sql().unwrap(),
553 params,
554 )?
555 } else {
556 let stmt = statement.as_prepared().unwrap();
557 ExtendedQueryStateMachine::execute(
558 handler,
559 &mut self.buffer_set,
560 &stmt.wire_name(),
561 &stmt.param_oids,
562 params,
563 )?
564 };
565
566 self.drive(&mut state_machine).await
567 }
568
569 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
573 &mut self,
574 statement: S,
575 params: P,
576 ) -> Result<Option<u64>> {
577 let mut handler = DropHandler::new();
578 self.exec(statement, params, &mut handler).await?;
579 Ok(handler.rows_affected())
580 }
581
582 pub async fn exec_collect<
586 T: for<'a> crate::conversion::FromRow<'a>,
587 S: IntoStatement,
588 P: ToParams,
589 >(
590 &mut self,
591 statement: S,
592 params: P,
593 ) -> Result<Vec<T>> {
594 let mut handler = crate::handler::CollectHandler::<T>::new();
595 self.exec(statement, params, &mut handler).await?;
596 Ok(handler.into_rows())
597 }
598
599 pub async fn exec_first<
603 T: for<'a> crate::conversion::FromRow<'a>,
604 S: IntoStatement,
605 P: ToParams,
606 >(
607 &mut self,
608 statement: S,
609 params: P,
610 ) -> Result<Option<T>> {
611 let mut handler = crate::handler::FirstRowHandler::<T>::new();
612 self.exec(statement, params, &mut handler).await?;
613 Ok(handler.into_row())
614 }
615
616 pub async fn exec_foreach<
632 T: for<'a> crate::conversion::FromRow<'a>,
633 S: IntoStatement,
634 P: ToParams,
635 F: FnMut(T) -> Result<()>,
636 >(
637 &mut self,
638 statement: S,
639 params: P,
640 f: F,
641 ) -> Result<()> {
642 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
643 self.exec(statement, params, &mut handler).await?;
644 Ok(())
645 }
646
647 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
678 &mut self,
679 statement: S,
680 params_list: &[P],
681 ) -> Result<()> {
682 self.exec_batch_chunked(statement, params_list, 1000).await
683 }
684
685 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
689 &mut self,
690 statement: S,
691 params_list: &[P],
692 chunk_size: usize,
693 ) -> Result<()> {
694 let result = self
695 .exec_batch_inner(&statement, params_list, chunk_size)
696 .await;
697 if let Err(e) = &result
698 && e.is_connection_broken()
699 {
700 self.is_broken = true;
701 }
702 result
703 }
704
705 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
706 &mut self,
707 statement: &S,
708 params_list: &[P],
709 chunk_size: usize,
710 ) -> Result<()> {
711 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
712 use crate::state::extended::BatchStateMachine;
713
714 if params_list.is_empty() {
715 return Ok(());
716 }
717
718 let chunk_size = chunk_size.max(1);
719 let needs_parse = statement.needs_parse();
720 let sql = statement.as_sql();
721 let prepared = statement.as_prepared();
722
723 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
725 stmt.param_oids.clone()
726 } else {
727 params_list[0].natural_oids()
728 };
729
730 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
732
733 for chunk in params_list.chunks(chunk_size) {
734 self.buffer_set.write_buffer.clear();
735
736 let parse_in_chunk = needs_parse;
738 if parse_in_chunk {
739 write_parse(
740 &mut self.buffer_set.write_buffer,
741 "",
742 sql.unwrap(),
743 ¶m_oids,
744 );
745 }
746
747 for params in chunk {
749 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
750 write_bind(
751 &mut self.buffer_set.write_buffer,
752 "",
753 effective_stmt_name,
754 params,
755 ¶m_oids,
756 )?;
757 write_execute(&mut self.buffer_set.write_buffer, "", 0);
758 }
759
760 write_sync(&mut self.buffer_set.write_buffer);
762
763 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
765 self.drive_batch(&mut state_machine).await?;
766 self.transaction_status = state_machine.transaction_status();
767 }
768
769 Ok(())
770 }
771
772 async fn drive_batch(
774 &mut self,
775 state_machine: &mut crate::state::extended::BatchStateMachine,
776 ) -> Result<()> {
777 use crate::protocol::backend::{ReadyForQuery, msg_type};
778 use crate::state::action::Action;
779
780 loop {
781 let step_result = state_machine.step(&mut self.buffer_set);
782 match step_result {
783 Ok(Action::ReadMessage) => {
784 self.stream.read_message(&mut self.buffer_set).await?;
785 }
786 Ok(Action::WriteAndReadMessage) => {
787 self.stream.write_all(&self.buffer_set.write_buffer).await?;
788 self.stream.flush().await?;
789 self.stream.read_message(&mut self.buffer_set).await?;
790 }
791 Ok(Action::Finished) => {
792 break;
793 }
794 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
795 Err(e) => {
796 loop {
798 self.stream.read_message(&mut self.buffer_set).await?;
799 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
800 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
801 self.transaction_status =
802 ready.transaction_status().unwrap_or_default();
803 break;
804 }
805 }
806 return Err(e);
807 }
808 }
809 }
810 Ok(())
811 }
812
813 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
815 let result = self.close_statement_inner(&stmt.wire_name()).await;
816 if let Err(e) = &result
817 && e.is_connection_broken()
818 {
819 self.is_broken = true;
820 }
821 result
822 }
823
824 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
825 let mut handler = DropHandler::new();
826 let mut state_machine =
827 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
828 self.drive(&mut state_machine).await
829 }
830
831 pub async fn lowlevel_flush(&mut self) -> Result<()> {
839 use crate::protocol::frontend::write_flush;
840
841 self.buffer_set.write_buffer.clear();
842 write_flush(&mut self.buffer_set.write_buffer);
843
844 self.stream.write_all(&self.buffer_set.write_buffer).await?;
845 self.stream.flush().await?;
846 Ok(())
847 }
848
849 pub async fn lowlevel_sync(&mut self) -> Result<()> {
856 let result = self.lowlevel_sync_inner().await;
857 if let Err(e) = &result
858 && e.is_connection_broken()
859 {
860 self.is_broken = true;
861 }
862 result
863 }
864
865 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
866 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
867 use crate::protocol::frontend::write_sync;
868
869 self.buffer_set.write_buffer.clear();
870 write_sync(&mut self.buffer_set.write_buffer);
871
872 self.stream.write_all(&self.buffer_set.write_buffer).await?;
873 self.stream.flush().await?;
874
875 let mut pending_error: Option<Error> = None;
876
877 loop {
878 self.stream.read_message(&mut self.buffer_set).await?;
879 let type_byte = self.buffer_set.type_byte;
880
881 if RawMessage::is_async_type(type_byte) {
882 continue;
883 }
884
885 match type_byte {
886 msg_type::READY_FOR_QUERY => {
887 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
888 self.transaction_status = ready.transaction_status().unwrap_or_default();
889 if let Some(e) = pending_error {
890 return Err(e);
891 }
892 return Ok(());
893 }
894 msg_type::ERROR_RESPONSE => {
895 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
896 pending_error = Some(error.into_error());
897 }
898 _ => {
899 }
901 }
902 }
903 }
904
905 pub async fn lowlevel_bind<P: ToParams>(
915 &mut self,
916 portal: &str,
917 statement_name: &str,
918 params: P,
919 ) -> Result<()> {
920 let result = self
921 .lowlevel_bind_inner(portal, statement_name, ¶ms)
922 .await;
923 if let Err(e) = &result
924 && e.is_connection_broken()
925 {
926 self.is_broken = true;
927 }
928 result
929 }
930
931 async fn lowlevel_bind_inner<P: ToParams>(
932 &mut self,
933 portal: &str,
934 statement_name: &str,
935 params: &P,
936 ) -> Result<()> {
937 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
938 use crate::protocol::frontend::{write_bind, write_flush};
939
940 let param_oids = params.natural_oids();
941 self.buffer_set.write_buffer.clear();
942 write_bind(
943 &mut self.buffer_set.write_buffer,
944 portal,
945 statement_name,
946 params,
947 ¶m_oids,
948 )?;
949 write_flush(&mut self.buffer_set.write_buffer);
950
951 self.stream.write_all(&self.buffer_set.write_buffer).await?;
952 self.stream.flush().await?;
953
954 loop {
955 self.stream.read_message(&mut self.buffer_set).await?;
956 let type_byte = self.buffer_set.type_byte;
957
958 if RawMessage::is_async_type(type_byte) {
959 continue;
960 }
961
962 match type_byte {
963 msg_type::BIND_COMPLETE => {
964 BindComplete::parse(&self.buffer_set.read_buffer)?;
965 return Ok(());
966 }
967 msg_type::ERROR_RESPONSE => {
968 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
969 return Err(error.into_error());
970 }
971 _ => {
972 return Err(Error::Protocol(format!(
973 "Expected BindComplete or ErrorResponse, got '{}'",
974 type_byte as char
975 )));
976 }
977 }
978 }
979 }
980
981 pub async fn lowlevel_execute<H: BinaryHandler>(
994 &mut self,
995 portal: &str,
996 max_rows: u32,
997 handler: &mut H,
998 ) -> Result<bool> {
999 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
1000 if let Err(e) = &result
1001 && e.is_connection_broken()
1002 {
1003 self.is_broken = true;
1004 }
1005 result
1006 }
1007
1008 async fn lowlevel_execute_inner<H: BinaryHandler>(
1009 &mut self,
1010 portal: &str,
1011 max_rows: u32,
1012 handler: &mut H,
1013 ) -> Result<bool> {
1014 use crate::protocol::backend::{
1015 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1016 RowDescription, msg_type,
1017 };
1018 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1019
1020 self.buffer_set.write_buffer.clear();
1021 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1022 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1023 write_flush(&mut self.buffer_set.write_buffer);
1024
1025 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1026 self.stream.flush().await?;
1027
1028 let mut column_buffer: Vec<u8> = Vec::new();
1029
1030 loop {
1031 self.stream.read_message(&mut self.buffer_set).await?;
1032 let type_byte = self.buffer_set.type_byte;
1033
1034 if RawMessage::is_async_type(type_byte) {
1035 continue;
1036 }
1037
1038 match type_byte {
1039 msg_type::ROW_DESCRIPTION => {
1040 column_buffer.clear();
1041 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1042 let cols = RowDescription::parse(&column_buffer)?;
1043 handler.result_start(cols)?;
1044 }
1045 msg_type::NO_DATA => {
1046 NoData::parse(&self.buffer_set.read_buffer)?;
1047 }
1048 msg_type::DATA_ROW => {
1049 let cols = RowDescription::parse(&column_buffer)?;
1050 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1051 handler.row(cols, row)?;
1052 }
1053 msg_type::COMMAND_COMPLETE => {
1054 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1055 handler.result_end(complete)?;
1056 return Ok(false); }
1058 msg_type::PORTAL_SUSPENDED => {
1059 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1060 return Ok(true); }
1062 msg_type::ERROR_RESPONSE => {
1063 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1064 return Err(error.into_error());
1065 }
1066 _ => {
1067 return Err(Error::Protocol(format!(
1068 "Unexpected message in execute: '{}'",
1069 type_byte as char
1070 )));
1071 }
1072 }
1073 }
1074 }
1075
1076 pub async fn exec_portal<S: IntoStatement, P, F, Fut, T>(
1106 &mut self,
1107 statement: S,
1108 params: P,
1109 f: F,
1110 ) -> Result<T>
1111 where
1112 P: ToParams,
1113 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1114 Fut: std::future::Future<Output = Result<T>>,
1115 {
1116 let result = self.exec_portal_inner(&statement, ¶ms, f).await;
1117 if let Err(e) = &result
1118 && e.is_connection_broken()
1119 {
1120 self.is_broken = true;
1121 }
1122 result
1123 }
1124
1125 async fn exec_portal_inner<S: IntoStatement, P, F, Fut, T>(
1126 &mut self,
1127 statement: &S,
1128 params: &P,
1129 f: F,
1130 ) -> Result<T>
1131 where
1132 P: ToParams,
1133 F: FnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Fut,
1134 Fut: std::future::Future<Output = Result<T>>,
1135 {
1136 let mut state_machine = if let Some(sql) = statement.as_sql() {
1138 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1139 } else {
1140 let stmt = statement.as_prepared().unwrap();
1141 BindStateMachine::bind_prepared(
1142 &mut self.buffer_set,
1143 "",
1144 &stmt.wire_name(),
1145 &stmt.param_oids,
1146 params,
1147 )?
1148 };
1149
1150 loop {
1152 match state_machine.step(&mut self.buffer_set)? {
1153 Action::ReadMessage => {
1154 self.stream.read_message(&mut self.buffer_set).await?;
1155 }
1156 Action::Write => {
1157 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1158 self.stream.flush().await?;
1159 }
1160 Action::WriteAndReadMessage => {
1161 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1162 self.stream.flush().await?;
1163 self.stream.read_message(&mut self.buffer_set).await?;
1164 }
1165 Action::Finished => break,
1166 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1167 }
1168 }
1169
1170 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1172 let result = f(&mut portal).await;
1173
1174 let sync_result = portal.conn.lowlevel_sync().await;
1176
1177 match (result, sync_result) {
1179 (Ok(v), Ok(())) => Ok(v),
1180 (Err(e), _) => Err(e),
1181 (Ok(_), Err(e)) => Err(e),
1182 }
1183 }
1184
1185 pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1187 let result = self.lowlevel_close_portal_inner(portal).await;
1188 if let Err(e) = &result
1189 && e.is_connection_broken()
1190 {
1191 self.is_broken = true;
1192 }
1193 result
1194 }
1195
1196 async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1197 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1198 use crate::protocol::frontend::{write_close_portal, write_flush};
1199
1200 self.buffer_set.write_buffer.clear();
1201 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1202 write_flush(&mut self.buffer_set.write_buffer);
1203
1204 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1205 self.stream.flush().await?;
1206
1207 loop {
1208 self.stream.read_message(&mut self.buffer_set).await?;
1209 let type_byte = self.buffer_set.type_byte;
1210
1211 if RawMessage::is_async_type(type_byte) {
1212 continue;
1213 }
1214
1215 match type_byte {
1216 msg_type::CLOSE_COMPLETE => {
1217 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1218 return Ok(());
1219 }
1220 msg_type::ERROR_RESPONSE => {
1221 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1222 return Err(error.into_error());
1223 }
1224 _ => {
1225 return Err(Error::Protocol(format!(
1226 "Expected CloseComplete or ErrorResponse, got '{}'",
1227 type_byte as char
1228 )));
1229 }
1230 }
1231 }
1232 }
1233
1234 pub async fn run_pipeline<T, F, Fut>(&mut self, f: F) -> Result<T>
1265 where
1266 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Fut,
1267 Fut: std::future::Future<Output = Result<T>>,
1268 {
1269 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1270 let result = f(&mut pipeline).await;
1271 pipeline.cleanup().await;
1272 result
1273 }
1274
1275 pub async fn tx<F, R, Fut>(&mut self, f: F) -> Result<R>
1285 where
1286 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Fut,
1287 Fut: std::future::Future<Output = Result<R>>,
1288 {
1289 if self.in_transaction() {
1290 return Err(Error::InvalidUsage(
1291 "nested transactions are not supported".into(),
1292 ));
1293 }
1294
1295 self.query_drop("BEGIN").await?;
1296
1297 let tx = super::transaction::Transaction::new(self.connection_id());
1298
1299 let result = f(self, tx).await;
1302
1303 if self.in_transaction() {
1305 let rollback_result = self.query_drop("ROLLBACK").await;
1306
1307 if let Err(e) = result {
1309 return Err(e);
1310 }
1311 rollback_result?;
1312 }
1313
1314 result
1315 }
1316}