1use tokio::net::TcpStream;
4#[cfg(unix)]
5use tokio::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, DropHandler, ExtendedHandler, FirstRowHandler, SimpleHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::{IntoStatement, StatementRef};
23
24use super::stream::Stream;
25
/// An asynchronous PostgreSQL client connection.
///
/// Created with [`Conn::new`] or [`Conn::new_with_stream`].
pub struct Conn {
    /// Transport to the server: TCP, Unix socket, or a TLS-upgraded stream.
    pub(crate) stream: Stream,
    /// Read/write buffers taken from the options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data sent by the server during startup, if any.
    backend_key: Option<BackendKeyData>,
    /// Parameter pairs reported by the server during startup.
    server_params: Vec<(String, String)>,
    /// Transaction status as most recently reported by the server.
    pub(crate) transaction_status: TransactionStatus,
    /// Marked `true` when an operation fails with an error for which
    /// `Error::is_connection_broken` returns `true`.
    pub(crate) is_broken: bool,
    /// Counter used to derive unique prepared-statement and portal names.
    name_counter: u64,
    /// Optional handler for asynchronous server messages received while a query is
    /// being driven.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39 pub async fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41 where
42 Error: From<O::Error>,
43 {
44 let opts = opts.try_into()?;
45
46 let stream = if let Some(socket_path) = &opts.socket {
47 #[cfg(unix)]
48 {
49 Stream::unix(UnixStream::connect(socket_path).await?)
50 }
51 #[cfg(not(unix))]
52 {
53 let _ = socket_path;
54 return Err(Error::Unsupported(
55 "Unix sockets are not supported on this platform".into(),
56 ));
57 }
58 } else {
59 if opts.host.is_empty() {
60 return Err(Error::InvalidUsage("host is empty".into()));
61 }
62 let addr = format!("{}:{}", opts.host, opts.port);
63 let tcp = TcpStream::connect(&addr).await?;
64 tcp.set_nodelay(true)?;
65 Stream::tcp(tcp)
66 };
67
68 Self::new_with_stream(stream, opts).await
69 }
70
71 pub async fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
73 let mut buffer_set = options.buffer_pool.get_buffer_set();
74 let mut state_machine = ConnectionStateMachine::new(options.clone());
75
76 loop {
78 match state_machine.step(&mut buffer_set)? {
79 Action::WriteAndReadByte => {
80 stream.write_all(&buffer_set.write_buffer).await?;
81 stream.flush().await?;
82 let byte = stream.read_u8().await?;
83 state_machine.set_ssl_response(byte);
84 }
85 Action::ReadMessage => {
86 stream.read_message(&mut buffer_set).await?;
87 }
88 Action::Write => {
89 stream.write_all(&buffer_set.write_buffer).await?;
90 stream.flush().await?;
91 }
92 Action::WriteAndReadMessage => {
93 stream.write_all(&buffer_set.write_buffer).await?;
94 stream.flush().await?;
95 stream.read_message(&mut buffer_set).await?;
96 }
97 Action::TlsHandshake => {
98 #[cfg(feature = "tokio-tls")]
99 {
100 stream = stream.upgrade_to_tls(&options.host).await?;
101 }
102 #[cfg(not(feature = "tokio-tls"))]
103 {
104 return Err(Error::Unsupported(
105 "TLS requested but tokio-tls feature not enabled".into(),
106 ));
107 }
108 }
109 Action::HandleAsyncMessageAndReadMessage(_) => {
110 stream.read_message(&mut buffer_set).await?;
112 }
113 Action::Error(_) => {
114 return Err(Error::LibraryBug(
115 "unexpected server error during connection startup".into(),
116 ));
117 }
118 Action::Finished => break,
119 }
120 }
121
122 let conn = Self {
123 stream,
124 buffer_set,
125 backend_key: state_machine.backend_key().cloned(),
126 server_params: state_machine.take_server_params(),
127 transaction_status: state_machine.transaction_status(),
128 is_broken: false,
129 name_counter: 0,
130 async_message_handler: None,
131 };
132
133 #[cfg(unix)]
135 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
136 conn.try_upgrade_to_unix_socket(&options).await
137 } else {
138 conn
139 };
140
141 Ok(conn)
142 }
143
144 #[cfg(unix)]
147 fn try_upgrade_to_unix_socket(
148 mut self,
149 opts: &Opts,
150 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Self> + Send + '_>> {
151 let opts = opts.clone();
152 Box::pin(async move {
153 let mut handler = FirstRowHandler::<(String,)>::new();
155 if self
156 .query("SHOW unix_socket_directories", &mut handler)
157 .await
158 .is_err()
159 {
160 return self;
161 }
162
163 let socket_dir = match handler.into_row() {
164 Some((dirs,)) => {
165 match dirs.split(',').next() {
167 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
168 _ => return self,
169 }
170 }
171 None => return self,
172 };
173
174 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
176
177 let unix_stream = match UnixStream::connect(&socket_path).await {
179 Ok(s) => s,
180 Err(_) => return self,
181 };
182
183 let mut opts_unix = opts.clone();
185 opts_unix.upgrade_to_unix_socket = false;
186
187 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix).await {
188 Ok(new_conn) => new_conn,
189 Err(_) => self,
190 }
191 })
192 }
193
194 pub fn backend_key(&self) -> Option<&BackendKeyData> {
196 self.backend_key.as_ref()
197 }
198
199 pub fn connection_id(&self) -> u32 {
203 self.backend_key.as_ref().map_or(0, |k| k.process_id())
204 }
205
206 pub fn server_params(&self) -> &[(String, String)] {
208 &self.server_params
209 }
210
211 pub fn transaction_status(&self) -> TransactionStatus {
213 self.transaction_status
214 }
215
216 pub fn in_transaction(&self) -> bool {
218 self.transaction_status.in_transaction()
219 }
220
221 pub fn is_broken(&self) -> bool {
223 self.is_broken
224 }
225
226 pub(crate) fn next_portal_name(&mut self) -> String {
228 self.name_counter += 1;
229 format!("_zero_p_{}", self.name_counter)
230 }
231
232 pub(crate) async fn create_named_portal<S: IntoStatement, P: ToParams>(
236 &mut self,
237 portal_name: &str,
238 statement: &S,
239 params: &P,
240 ) -> Result<()> {
241 let mut state_machine = match statement.statement_ref() {
243 StatementRef::Sql(sql) => {
244 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
245 }
246 StatementRef::Prepared(stmt) => BindStateMachine::bind_prepared(
247 &mut self.buffer_set,
248 portal_name,
249 &stmt.wire_name(),
250 &stmt.param_oids,
251 params,
252 )?,
253 };
254
255 loop {
257 match state_machine.step(&mut self.buffer_set)? {
258 Action::ReadMessage => {
259 self.stream.read_message(&mut self.buffer_set).await?;
260 }
261 Action::Write => {
262 self.stream.write_all(&self.buffer_set.write_buffer).await?;
263 self.stream.flush().await?;
264 }
265 Action::WriteAndReadMessage => {
266 self.stream.write_all(&self.buffer_set.write_buffer).await?;
267 self.stream.flush().await?;
268 self.stream.read_message(&mut self.buffer_set).await?;
269 }
270 Action::Finished => break,
271 _ => return Err(Error::LibraryBug("Unexpected action in bind".into())),
272 }
273 }
274
275 Ok(())
276 }
277
278 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
285 self.async_message_handler = Some(Box::new(handler));
286 }
287
288 pub fn clear_async_message_handler(&mut self) {
290 self.async_message_handler = None;
291 }
292
293 pub async fn ping(&mut self) -> Result<()> {
295 self.query_drop("").await?;
296 Ok(())
297 }
298
299 async fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
301 loop {
302 let action = state_machine.step(&mut self.buffer_set)?;
303
304 match action {
305 Action::WriteAndReadByte => {
306 return Err(Error::LibraryBug(
307 "Unexpected WriteAndReadByte in query state machine".into(),
308 ));
309 }
310 Action::ReadMessage => {
311 self.stream.read_message(&mut self.buffer_set).await?;
312 }
313 Action::Write => {
314 self.stream.write_all(&self.buffer_set.write_buffer).await?;
315 self.stream.flush().await?;
316 }
317 Action::WriteAndReadMessage => {
318 self.stream.write_all(&self.buffer_set.write_buffer).await?;
319 self.stream.flush().await?;
320 self.stream.read_message(&mut self.buffer_set).await?;
321 }
322 Action::TlsHandshake => {
323 return Err(Error::LibraryBug(
324 "Unexpected TlsHandshake in query state machine".into(),
325 ));
326 }
327 Action::HandleAsyncMessageAndReadMessage(async_msg) => {
328 if let Some(h) = &mut self.async_message_handler {
329 h.handle(&async_msg);
330 }
331 self.stream.read_message(&mut self.buffer_set).await?;
333 }
334 Action::Error(server_error) => {
335 self.transaction_status = state_machine.transaction_status();
336 return Err(Error::Server(server_error));
337 }
338 Action::Finished => {
339 self.transaction_status = state_machine.transaction_status();
340 break;
341 }
342 }
343 }
344 Ok(())
345 }
346
347 pub async fn query<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
349 let result = self.query_inner(sql, handler).await;
350 if let Err(e) = &result
351 && e.is_connection_broken()
352 {
353 self.is_broken = true;
354 }
355 result
356 }
357
358 async fn query_inner<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
359 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
360 self.drive(&mut state_machine).await
361 }
362
363 pub async fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
365 let mut handler = DropHandler::new();
366 self.query(sql, &mut handler).await?;
367 Ok(handler.rows_affected())
368 }
369
370 pub async fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
372 &mut self,
373 sql: &str,
374 ) -> Result<Vec<T>> {
375 let mut handler = crate::handler::CollectHandler::<T>::new();
376 self.query(sql, &mut handler).await?;
377 Ok(handler.into_rows())
378 }
379
380 pub async fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
382 &mut self,
383 sql: &str,
384 ) -> Result<Option<T>> {
385 let mut handler = crate::handler::FirstRowHandler::<T>::new();
386 self.query(sql, &mut handler).await?;
387 Ok(handler.into_row())
388 }
389
390 pub async fn query_foreach<
403 T: for<'a> crate::conversion::FromRow<'a>,
404 F: FnMut(T) -> Result<()>,
405 >(
406 &mut self,
407 sql: &str,
408 f: F,
409 ) -> Result<()> {
410 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
411 self.query(sql, &mut handler).await?;
412 Ok(())
413 }
414
415 pub async fn close(mut self) -> Result<()> {
417 self.buffer_set.write_buffer.clear();
418 write_terminate(&mut self.buffer_set.write_buffer);
419 self.stream.write_all(&self.buffer_set.write_buffer).await?;
420 self.stream.flush().await?;
421 Ok(())
422 }
423
424 pub async fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
428 self.prepare_typed(query, &[]).await
429 }
430
431 pub async fn prepare_typed(
433 &mut self,
434 query: &str,
435 param_oids: &[u32],
436 ) -> Result<PreparedStatement> {
437 self.name_counter += 1;
438 let idx = self.name_counter;
439 let result = self.prepare_inner(idx, query, param_oids).await;
440 if let Err(e) = &result
441 && e.is_connection_broken()
442 {
443 self.is_broken = true;
444 }
445 result
446 }
447
448 pub async fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
465 if queries.is_empty() {
466 return Ok(Vec::new());
467 }
468
469 let start_idx = self.name_counter + 1;
470 self.name_counter += queries.len() as u64;
471
472 let result = self.prepare_batch_inner(queries, start_idx).await;
473 if let Err(e) = &result
474 && e.is_connection_broken()
475 {
476 self.is_broken = true;
477 }
478 result
479 }
480
481 async fn prepare_batch_inner(
482 &mut self,
483 queries: &[&str],
484 start_idx: u64,
485 ) -> Result<Vec<PreparedStatement>> {
486 use crate::state::batch_prepare::BatchPrepareStateMachine;
487
488 let mut state_machine =
489 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
490
491 loop {
492 match state_machine.step(&mut self.buffer_set)? {
493 Action::ReadMessage => {
494 self.stream.read_message(&mut self.buffer_set).await?;
495 }
496 Action::WriteAndReadMessage => {
497 self.stream.write_all(&self.buffer_set.write_buffer).await?;
498 self.stream.flush().await?;
499 self.stream.read_message(&mut self.buffer_set).await?;
500 }
501 Action::Finished => {
502 self.transaction_status = state_machine.transaction_status();
503 break;
504 }
505 _ => {
506 return Err(Error::LibraryBug(
507 "Unexpected action in batch prepare".into(),
508 ));
509 }
510 }
511 }
512
513 Ok(state_machine.take_statements())
514 }
515
516 async fn prepare_inner(
517 &mut self,
518 idx: u64,
519 query: &str,
520 param_oids: &[u32],
521 ) -> Result<PreparedStatement> {
522 let mut handler = DropHandler::new();
523 let mut state_machine = ExtendedQueryStateMachine::prepare(
524 &mut handler,
525 &mut self.buffer_set,
526 idx,
527 query,
528 param_oids,
529 );
530 self.drive(&mut state_machine).await?;
531 state_machine
532 .take_prepared_statement()
533 .ok_or_else(|| Error::LibraryBug("No prepared statement".into()))
534 }
535
536 pub async fn exec<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
542 &mut self,
543 statement: S,
544 params: P,
545 handler: &mut H,
546 ) -> Result<()> {
547 let result = self.exec_inner(&statement, ¶ms, handler).await;
548 if let Err(e) = &result
549 && e.is_connection_broken()
550 {
551 self.is_broken = true;
552 }
553 result
554 }
555
556 async fn exec_inner<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
557 &mut self,
558 statement: &S,
559 params: &P,
560 handler: &mut H,
561 ) -> Result<()> {
562 let mut state_machine = match statement.statement_ref() {
563 StatementRef::Sql(sql) => {
564 ExtendedQueryStateMachine::execute_sql(handler, &mut self.buffer_set, sql, params)?
565 }
566 StatementRef::Prepared(stmt) => ExtendedQueryStateMachine::execute(
567 handler,
568 &mut self.buffer_set,
569 &stmt.wire_name(),
570 &stmt.param_oids,
571 params,
572 )?,
573 };
574
575 self.drive(&mut state_machine).await
576 }
577
578 pub async fn exec_drop<S: IntoStatement, P: ToParams>(
582 &mut self,
583 statement: S,
584 params: P,
585 ) -> Result<Option<u64>> {
586 let mut handler = DropHandler::new();
587 self.exec(statement, params, &mut handler).await?;
588 Ok(handler.rows_affected())
589 }
590
591 pub async fn exec_collect<
595 T: for<'a> crate::conversion::FromRow<'a>,
596 S: IntoStatement,
597 P: ToParams,
598 >(
599 &mut self,
600 statement: S,
601 params: P,
602 ) -> Result<Vec<T>> {
603 let mut handler = crate::handler::CollectHandler::<T>::new();
604 self.exec(statement, params, &mut handler).await?;
605 Ok(handler.into_rows())
606 }
607
608 pub async fn exec_first<
612 T: for<'a> crate::conversion::FromRow<'a>,
613 S: IntoStatement,
614 P: ToParams,
615 >(
616 &mut self,
617 statement: S,
618 params: P,
619 ) -> Result<Option<T>> {
620 let mut handler = crate::handler::FirstRowHandler::<T>::new();
621 self.exec(statement, params, &mut handler).await?;
622 Ok(handler.into_row())
623 }
624
625 pub async fn exec_foreach<
641 T: for<'a> crate::conversion::FromRow<'a>,
642 S: IntoStatement,
643 P: ToParams,
644 F: FnMut(T) -> Result<()>,
645 >(
646 &mut self,
647 statement: S,
648 params: P,
649 f: F,
650 ) -> Result<()> {
651 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
652 self.exec(statement, params, &mut handler).await?;
653 Ok(())
654 }
655
656 pub async fn exec_batch<S: IntoStatement, P: ToParams>(
687 &mut self,
688 statement: S,
689 params_list: &[P],
690 ) -> Result<()> {
691 self.exec_batch_chunked(statement, params_list, 1000).await
692 }
693
694 pub async fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
698 &mut self,
699 statement: S,
700 params_list: &[P],
701 chunk_size: usize,
702 ) -> Result<()> {
703 let result = self
704 .exec_batch_inner(&statement, params_list, chunk_size)
705 .await;
706 if let Err(e) = &result
707 && e.is_connection_broken()
708 {
709 self.is_broken = true;
710 }
711 result
712 }
713
714 async fn exec_batch_inner<S: IntoStatement, P: ToParams>(
715 &mut self,
716 statement: &S,
717 params_list: &[P],
718 chunk_size: usize,
719 ) -> Result<()> {
720 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
721 use crate::state::extended::BatchStateMachine;
722
723 if params_list.is_empty() {
724 return Ok(());
725 }
726
727 let chunk_size = chunk_size.max(1);
728 let stmt_ref = statement.statement_ref();
729
730 let (param_oids, stmt_name) = match stmt_ref {
731 StatementRef::Sql(_) => (params_list[0].natural_oids(), String::new()),
732 StatementRef::Prepared(stmt) => (stmt.param_oids.clone(), stmt.wire_name()),
733 };
734
735 for chunk in params_list.chunks(chunk_size) {
736 self.buffer_set.write_buffer.clear();
737
738 if let StatementRef::Sql(sql) = stmt_ref {
740 write_parse(&mut self.buffer_set.write_buffer, "", sql, ¶m_oids);
741 }
742
743 for params in chunk {
745 let effective_stmt_name = if matches!(stmt_ref, StatementRef::Sql(_)) {
746 ""
747 } else {
748 &stmt_name
749 };
750 write_bind(
751 &mut self.buffer_set.write_buffer,
752 "",
753 effective_stmt_name,
754 params,
755 ¶m_oids,
756 )?;
757 write_execute(&mut self.buffer_set.write_buffer, "", 0);
758 }
759
760 write_sync(&mut self.buffer_set.write_buffer);
762
763 let mut state_machine =
765 BatchStateMachine::new(matches!(stmt_ref, StatementRef::Sql(_)));
766 self.drive_batch(&mut state_machine).await?;
767 self.transaction_status = state_machine.transaction_status();
768 }
769
770 Ok(())
771 }
772
773 async fn drive_batch(
775 &mut self,
776 state_machine: &mut crate::state::extended::BatchStateMachine,
777 ) -> Result<()> {
778 use crate::state::action::Action;
779
780 loop {
781 let step_result = state_machine.step(&mut self.buffer_set);
782 match step_result {
783 Ok(Action::ReadMessage) => {
784 self.stream.read_message(&mut self.buffer_set).await?;
785 }
786 Ok(Action::WriteAndReadMessage) => {
787 self.stream.write_all(&self.buffer_set.write_buffer).await?;
788 self.stream.flush().await?;
789 self.stream.read_message(&mut self.buffer_set).await?;
790 }
791 Ok(Action::Finished) => {
792 break;
793 }
794 Ok(Action::Error(server_error)) => {
795 self.transaction_status = state_machine.transaction_status();
796 return Err(Error::Server(server_error));
797 }
798 Ok(_) => return Err(Error::LibraryBug("Unexpected action in batch".into())),
799 Err(e) => return Err(e),
800 }
801 }
802 Ok(())
803 }
804
805 pub async fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
807 let result = self.close_statement_inner(&stmt.wire_name()).await;
808 if let Err(e) = &result
809 && e.is_connection_broken()
810 {
811 self.is_broken = true;
812 }
813 result
814 }
815
816 async fn close_statement_inner(&mut self, name: &str) -> Result<()> {
817 let mut handler = DropHandler::new();
818 let mut state_machine =
819 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
820 self.drive(&mut state_machine).await
821 }
822
823 pub async fn lowlevel_flush(&mut self) -> Result<()> {
831 use crate::protocol::frontend::write_flush;
832
833 self.buffer_set.write_buffer.clear();
834 write_flush(&mut self.buffer_set.write_buffer);
835
836 self.stream.write_all(&self.buffer_set.write_buffer).await?;
837 self.stream.flush().await?;
838 Ok(())
839 }
840
841 pub async fn lowlevel_sync(&mut self) -> Result<()> {
848 let result = self.lowlevel_sync_inner().await;
849 if let Err(e) = &result
850 && e.is_connection_broken()
851 {
852 self.is_broken = true;
853 }
854 result
855 }
856
857 async fn lowlevel_sync_inner(&mut self) -> Result<()> {
858 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
859 use crate::protocol::frontend::write_sync;
860
861 self.buffer_set.write_buffer.clear();
862 write_sync(&mut self.buffer_set.write_buffer);
863
864 self.stream.write_all(&self.buffer_set.write_buffer).await?;
865 self.stream.flush().await?;
866
867 let mut pending_error: Option<Error> = None;
868
869 loop {
870 self.stream.read_message(&mut self.buffer_set).await?;
871 let type_byte = self.buffer_set.type_byte;
872
873 if RawMessage::is_async_type(type_byte) {
874 continue;
875 }
876
877 match type_byte {
878 msg_type::READY_FOR_QUERY => {
879 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
880 self.transaction_status = ready.transaction_status().unwrap_or_default();
881 if let Some(e) = pending_error {
882 return Err(e);
883 }
884 return Ok(());
885 }
886 msg_type::ERROR_RESPONSE => {
887 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
888 pending_error = Some(error.into_error());
889 }
890 _ => {
891 }
893 }
894 }
895 }
896
897 pub async fn lowlevel_bind<P: ToParams>(
907 &mut self,
908 portal: &str,
909 statement_name: &str,
910 params: P,
911 ) -> Result<()> {
912 let result = self
913 .lowlevel_bind_inner(portal, statement_name, ¶ms)
914 .await;
915 if let Err(e) = &result
916 && e.is_connection_broken()
917 {
918 self.is_broken = true;
919 }
920 result
921 }
922
923 async fn lowlevel_bind_inner<P: ToParams>(
924 &mut self,
925 portal: &str,
926 statement_name: &str,
927 params: &P,
928 ) -> Result<()> {
929 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
930 use crate::protocol::frontend::{write_bind, write_flush};
931
932 let param_oids = params.natural_oids();
933 self.buffer_set.write_buffer.clear();
934 write_bind(
935 &mut self.buffer_set.write_buffer,
936 portal,
937 statement_name,
938 params,
939 ¶m_oids,
940 )?;
941 write_flush(&mut self.buffer_set.write_buffer);
942
943 self.stream.write_all(&self.buffer_set.write_buffer).await?;
944 self.stream.flush().await?;
945
946 loop {
947 self.stream.read_message(&mut self.buffer_set).await?;
948 let type_byte = self.buffer_set.type_byte;
949
950 if RawMessage::is_async_type(type_byte) {
951 continue;
952 }
953
954 match type_byte {
955 msg_type::BIND_COMPLETE => {
956 BindComplete::parse(&self.buffer_set.read_buffer)?;
957 return Ok(());
958 }
959 msg_type::ERROR_RESPONSE => {
960 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
961 return Err(error.into_error());
962 }
963 _ => {
964 return Err(Error::LibraryBug(format!(
965 "Expected BindComplete or ErrorResponse, got '{}'",
966 type_byte as char
967 )));
968 }
969 }
970 }
971 }
972
973 pub async fn lowlevel_execute<H: ExtendedHandler>(
986 &mut self,
987 portal: &str,
988 max_rows: u32,
989 handler: &mut H,
990 ) -> Result<bool> {
991 let result = self.lowlevel_execute_inner(portal, max_rows, handler).await;
992 if let Err(e) = &result
993 && e.is_connection_broken()
994 {
995 self.is_broken = true;
996 }
997 result
998 }
999
1000 async fn lowlevel_execute_inner<H: ExtendedHandler>(
1001 &mut self,
1002 portal: &str,
1003 max_rows: u32,
1004 handler: &mut H,
1005 ) -> Result<bool> {
1006 use crate::protocol::backend::{
1007 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1008 RowDescription, msg_type,
1009 };
1010 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1011
1012 self.buffer_set.write_buffer.clear();
1013 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1014 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1015 write_flush(&mut self.buffer_set.write_buffer);
1016
1017 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1018 self.stream.flush().await?;
1019
1020 let mut column_buffer: Vec<u8> = Vec::new();
1021
1022 loop {
1023 self.stream.read_message(&mut self.buffer_set).await?;
1024 let type_byte = self.buffer_set.type_byte;
1025
1026 if RawMessage::is_async_type(type_byte) {
1027 continue;
1028 }
1029
1030 match type_byte {
1031 msg_type::ROW_DESCRIPTION => {
1032 column_buffer.clear();
1033 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1034 let cols = RowDescription::parse(&column_buffer)?;
1035 handler.result_start(cols)?;
1036 }
1037 msg_type::NO_DATA => {
1038 NoData::parse(&self.buffer_set.read_buffer)?;
1039 }
1040 msg_type::DATA_ROW => {
1041 let cols = RowDescription::parse(&column_buffer)?;
1042 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1043 handler.row(cols, row)?;
1044 }
1045 msg_type::COMMAND_COMPLETE => {
1046 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1047 handler.result_end(complete)?;
1048 return Ok(false); }
1050 msg_type::PORTAL_SUSPENDED => {
1051 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1052 return Ok(true); }
1054 msg_type::ERROR_RESPONSE => {
1055 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1056 return Err(error.into_error());
1057 }
1058 _ => {
1059 return Err(Error::LibraryBug(format!(
1060 "Unexpected message in execute: '{}'",
1061 type_byte as char
1062 )));
1063 }
1064 }
1065 }
1066 }
1067
1068 pub async fn exec_portal<S: IntoStatement, P, F, T>(
1098 &mut self,
1099 statement: S,
1100 params: P,
1101 f: F,
1102 ) -> Result<T>
1103 where
1104 P: ToParams,
1105 F: AsyncFnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Result<T>,
1106 {
1107 let result = self.exec_portal_inner(&statement, ¶ms, f).await;
1108 if let Err(e) = &result
1109 && e.is_connection_broken()
1110 {
1111 self.is_broken = true;
1112 }
1113 result
1114 }
1115
1116 async fn exec_portal_inner<S: IntoStatement, P, F, T>(
1117 &mut self,
1118 statement: &S,
1119 params: &P,
1120 f: F,
1121 ) -> Result<T>
1122 where
1123 P: ToParams,
1124 F: AsyncFnOnce(&mut super::unnamed_portal::UnnamedPortal<'_>) -> Result<T>,
1125 {
1126 let mut state_machine = match statement.statement_ref() {
1128 StatementRef::Sql(sql) => {
1129 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1130 }
1131 StatementRef::Prepared(stmt) => BindStateMachine::bind_prepared(
1132 &mut self.buffer_set,
1133 "",
1134 &stmt.wire_name(),
1135 &stmt.param_oids,
1136 params,
1137 )?,
1138 };
1139
1140 loop {
1142 match state_machine.step(&mut self.buffer_set)? {
1143 Action::ReadMessage => {
1144 self.stream.read_message(&mut self.buffer_set).await?;
1145 }
1146 Action::Write => {
1147 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1148 self.stream.flush().await?;
1149 }
1150 Action::WriteAndReadMessage => {
1151 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1152 self.stream.flush().await?;
1153 self.stream.read_message(&mut self.buffer_set).await?;
1154 }
1155 Action::Finished => break,
1156 _ => return Err(Error::LibraryBug("Unexpected action in bind".into())),
1157 }
1158 }
1159
1160 let mut portal = super::unnamed_portal::UnnamedPortal { conn: self };
1162 let result = f(&mut portal).await;
1163
1164 let sync_result = portal.conn.lowlevel_sync().await;
1166
1167 match (result, sync_result) {
1169 (Ok(v), Ok(())) => Ok(v),
1170 (Err(e), _) => Err(e),
1171 (Ok(_), Err(e)) => Err(e),
1172 }
1173 }
1174
1175 pub async fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1177 let result = self.lowlevel_close_portal_inner(portal).await;
1178 if let Err(e) = &result
1179 && e.is_connection_broken()
1180 {
1181 self.is_broken = true;
1182 }
1183 result
1184 }
1185
1186 async fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1187 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1188 use crate::protocol::frontend::{write_close_portal, write_flush};
1189
1190 self.buffer_set.write_buffer.clear();
1191 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1192 write_flush(&mut self.buffer_set.write_buffer);
1193
1194 self.stream.write_all(&self.buffer_set.write_buffer).await?;
1195 self.stream.flush().await?;
1196
1197 loop {
1198 self.stream.read_message(&mut self.buffer_set).await?;
1199 let type_byte = self.buffer_set.type_byte;
1200
1201 if RawMessage::is_async_type(type_byte) {
1202 continue;
1203 }
1204
1205 match type_byte {
1206 msg_type::CLOSE_COMPLETE => {
1207 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1208 return Ok(());
1209 }
1210 msg_type::ERROR_RESPONSE => {
1211 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1212 return Err(error.into_error());
1213 }
1214 _ => {
1215 return Err(Error::LibraryBug(format!(
1216 "Expected CloseComplete or ErrorResponse, got '{}'",
1217 type_byte as char
1218 )));
1219 }
1220 }
1221 }
1222 }
1223
1224 pub async fn pipeline<T, F>(&mut self, f: F) -> Result<T>
1255 where
1256 F: AsyncFnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
1257 {
1258 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
1259 let result = f(&mut pipeline).await;
1260 pipeline.cleanup().await;
1261 result
1262 }
1263
1264 pub async fn transaction<F, R>(&mut self, f: F) -> Result<R>
1274 where
1275 F: AsyncFnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
1276 {
1277 if self.in_transaction() {
1278 return Err(Error::InvalidUsage(
1279 "nested transactions are not supported".into(),
1280 ));
1281 }
1282
1283 self.query_drop("BEGIN").await?;
1284
1285 let tx = super::transaction::Transaction::new(self.connection_id());
1286
1287 let result = f(self, tx).await;
1288
1289 if self.in_transaction() {
1291 match &result {
1292 Ok(_) => {
1293 self.query_drop("COMMIT").await?;
1295 }
1296 Err(_) => {
1297 let _ = self.query_drop("ROLLBACK").await;
1300 }
1301 }
1302 }
1303
1304 result
1305 }
1306}