1use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25use super::unnamed_portal::UnnamedPortal;
26
/// A synchronous PostgreSQL connection.
///
/// Owns the transport stream, the buffers used to encode and decode protocol
/// messages, and the session state captured during the startup handshake.
pub struct Conn {
    /// Transport to the server (created by `Stream::tcp` / `Stream::unix`, and
    /// possibly replaced by a TLS-upgraded stream during startup).
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the options' buffer pool; reused for
    /// every protocol exchange on this connection.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data reported during startup, if the server sent one;
    /// `connection_id` exposes its process id.
    backend_key: Option<BackendKeyData>,
    /// `(name, value)` server parameters collected during startup.
    server_params: Vec<(String, String)>,
    /// Transaction status as last reported by the server.
    pub(crate) transaction_status: TransactionStatus,
    /// Set when an operation fails with an error whose `is_connection_broken()`
    /// returns true.
    pub(crate) is_broken: bool,
    /// Counter used to generate unique statement and portal names.
    name_counter: u64,
    /// Optional callback for asynchronous server messages received while a
    /// state machine is being driven.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
38
39impl Conn {
40 pub fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
42 where
43 Error: From<O::Error>,
44 {
45 let opts = opts.try_into()?;
46
47 let stream = if let Some(socket_path) = &opts.socket {
48 #[cfg(unix)]
49 {
50 Stream::unix(UnixStream::connect(socket_path)?)
51 }
52 #[cfg(not(unix))]
53 {
54 let _ = socket_path;
55 return Err(Error::Unsupported(
56 "Unix sockets are not supported on this platform".into(),
57 ));
58 }
59 } else {
60 if opts.host.is_empty() {
61 return Err(Error::InvalidUsage("host is empty".into()));
62 }
63 let addr = format!("{}:{}", opts.host, opts.port);
64 let tcp = TcpStream::connect(&addr)?;
65 tcp.set_nodelay(true)?;
66 Stream::tcp(tcp)
67 };
68
69 Self::new_with_stream(stream, opts)
70 }
71
    /// Performs the PostgreSQL startup handshake over an already-connected `stream`.
    ///
    /// Drives a `ConnectionStateMachine` until it reports `Action::Finished`,
    /// carrying out the I/O each step requests, including a TLS upgrade when the
    /// `sync-tls` feature is enabled. The backend key, server parameters and
    /// transaction status collected by the state machine are stored on the
    /// returned `Conn`.
    ///
    /// On Unix, if `options.upgrade_to_unix_socket` is set and the stream is a
    /// TCP loopback connection, the connection then tries to switch to the
    /// server's Unix socket. Any failure of that attempt keeps the TCP connection.
    ///
    /// # Errors
    ///
    /// Propagates I/O and protocol errors from the handshake. Returns
    /// `Error::Unsupported` if the state machine requests a TLS handshake while
    /// the `sync-tls` feature is disabled.
    // `stream` is only reassigned in the `sync-tls` build, so `mut` is unused otherwise.
    #[allow(unused_mut)]
    pub fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
        let mut buffer_set = options.buffer_pool.get_buffer_set();
        let mut state_machine = ConnectionStateMachine::new(options.clone());

        loop {
            match state_machine.step(&mut buffer_set)? {
                // Send the pending request and feed the server's single-byte
                // reply back to the state machine as the SSL response.
                Action::WriteAndReadByte => {
                    stream.write_all(&buffer_set.write_buffer)?;
                    stream.flush()?;
                    let byte = stream.read_u8()?;
                    state_machine.set_ssl_response(byte);
                }
                Action::ReadMessage => {
                    stream.read_message(&mut buffer_set)?;
                }
                Action::Write => {
                    stream.write_all(&buffer_set.write_buffer)?;
                    stream.flush()?;
                }
                Action::WriteAndReadMessage => {
                    stream.write_all(&buffer_set.write_buffer)?;
                    stream.flush()?;
                    stream.read_message(&mut buffer_set)?;
                }
                Action::TlsHandshake => {
                    // Replace the plain stream with a TLS stream; the host name
                    // is handed to the upgrade.
                    #[cfg(feature = "sync-tls")]
                    {
                        stream = stream.upgrade_to_tls(&options.host)?;
                    }
                    #[cfg(not(feature = "sync-tls"))]
                    {
                        return Err(Error::Unsupported(
                            "TLS requested but sync-tls feature not enabled".into(),
                        ));
                    }
                }
                // No async handler exists before the `Conn` is built, so async
                // messages seen during startup are skipped.
                Action::HandleAsyncMessageAndReadMessage(_) => {
                    stream.read_message(&mut buffer_set)?;
                }
                Action::Finished => break,
            }
        }

        let conn = Self {
            stream,
            buffer_set,
            backend_key: state_machine.backend_key().cloned(),
            server_params: state_machine.take_server_params(),
            transaction_status: state_machine.transaction_status(),
            is_broken: false,
            name_counter: 0,
            async_message_handler: None,
        };

        // Optional best-effort switch from loopback TCP to the Unix socket.
        #[cfg(unix)]
        let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
            conn.try_upgrade_to_unix_socket(&options)
        } else {
            conn
        };

        Ok(conn)
    }
140
141 #[cfg(unix)]
144 fn try_upgrade_to_unix_socket(mut self, opts: &Opts) -> Self {
145 let mut handler = FirstRowHandler::<(String,)>::new();
147 if self
148 .query("SHOW unix_socket_directories", &mut handler)
149 .is_err()
150 {
151 return self;
152 }
153
154 let socket_dir = match handler.into_row() {
155 Some((dirs,)) => {
156 match dirs.split(',').next() {
158 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
159 _ => return self,
160 }
161 }
162 None => return self,
163 };
164
165 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
167
168 let unix_stream = match UnixStream::connect(&socket_path) {
170 Ok(s) => s,
171 Err(_) => return self,
172 };
173
174 let mut opts_unix = opts.clone();
176 opts_unix.upgrade_to_unix_socket = false;
177
178 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix) {
179 Ok(new_conn) => new_conn,
180 Err(_) => self,
181 }
182 }
183
184 pub fn backend_key(&self) -> Option<&BackendKeyData> {
186 self.backend_key.as_ref()
187 }
188
189 pub fn connection_id(&self) -> u32 {
193 self.backend_key.as_ref().map_or(0, |k| k.process_id())
194 }
195
196 pub fn server_params(&self) -> &[(String, String)] {
198 &self.server_params
199 }
200
201 pub fn transaction_status(&self) -> TransactionStatus {
203 self.transaction_status
204 }
205
206 pub fn in_transaction(&self) -> bool {
208 self.transaction_status.in_transaction()
209 }
210
211 pub fn is_broken(&self) -> bool {
213 self.is_broken
214 }
215
216 pub(crate) fn next_portal_name(&mut self) -> String {
218 self.name_counter += 1;
219 format!("_zero_p_{}", self.name_counter)
220 }
221
222 pub(crate) fn create_named_portal<S: IntoStatement, P: ToParams>(
226 &mut self,
227 portal_name: &str,
228 statement: &S,
229 params: &P,
230 ) -> Result<()> {
231 let mut state_machine = if let Some(sql) = statement.as_sql() {
233 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
234 } else {
235 let stmt = statement.as_prepared().unwrap();
236 BindStateMachine::bind_prepared(
237 &mut self.buffer_set,
238 portal_name,
239 &stmt.wire_name(),
240 &stmt.param_oids,
241 params,
242 )?
243 };
244
245 loop {
247 match state_machine.step(&mut self.buffer_set)? {
248 Action::ReadMessage => {
249 self.stream.read_message(&mut self.buffer_set)?;
250 }
251 Action::Write => {
252 self.stream.write_all(&self.buffer_set.write_buffer)?;
253 self.stream.flush()?;
254 }
255 Action::WriteAndReadMessage => {
256 self.stream.write_all(&self.buffer_set.write_buffer)?;
257 self.stream.flush()?;
258 self.stream.read_message(&mut self.buffer_set)?;
259 }
260 Action::Finished => break,
261 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
262 }
263 }
264
265 Ok(())
266 }
267
268 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
275 self.async_message_handler = Some(Box::new(handler));
276 }
277
278 pub fn clear_async_message_handler(&mut self) {
280 self.async_message_handler = None;
281 }
282
283 pub fn run_pipeline<T, F>(&mut self, f: F) -> Result<T>
304 where
305 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
306 {
307 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
308 let result = f(&mut pipeline);
309 pipeline.cleanup();
310 result
311 }
312
313 pub fn ping(&mut self) -> Result<()> {
315 self.query_drop("")?;
316 Ok(())
317 }
318
    /// Drives a query-phase state machine to completion over this connection.
    ///
    /// Performs the I/O requested by each step. Asynchronous messages are passed
    /// to the installed async handler (if any) before the next message is read.
    /// When the machine reports `Action::Finished`, its transaction status is
    /// copied onto the connection.
    ///
    /// # Errors
    ///
    /// Propagates errors from the state machine and the stream. Returns
    /// `Error::Protocol` for `WriteAndReadByte` and `TlsHandshake`, which are
    /// only handled in the startup driver.
    fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
        loop {
            match state_machine.step(&mut self.buffer_set)? {
                Action::WriteAndReadByte => {
                    return Err(Error::Protocol(
                        "Unexpected WriteAndReadByte in query state machine".into(),
                    ));
                }
                Action::ReadMessage => {
                    self.stream.read_message(&mut self.buffer_set)?;
                }
                Action::Write => {
                    self.stream.write_all(&self.buffer_set.write_buffer)?;
                    self.stream.flush()?;
                }
                Action::WriteAndReadMessage => {
                    self.stream.write_all(&self.buffer_set.write_buffer)?;
                    self.stream.flush()?;
                    self.stream.read_message(&mut self.buffer_set)?;
                }
                Action::TlsHandshake => {
                    return Err(Error::Protocol(
                        "Unexpected TlsHandshake in query state machine".into(),
                    ));
                }
                Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
                    // Give the user's handler a chance to see the message, then
                    // continue with the next message from the server.
                    if let Some(ref mut h) = self.async_message_handler {
                        h.handle(async_msg);
                    }
                    self.stream.read_message(&mut self.buffer_set)?;
                }
                Action::Finished => {
                    self.transaction_status = state_machine.transaction_status();
                    break;
                }
            }
        }
        Ok(())
    }
360
361 pub fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
363 let result = self.query_inner(sql, handler);
364 if let Err(e) = &result
365 && e.is_connection_broken()
366 {
367 self.is_broken = true;
368 }
369 result
370 }
371
372 fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
373 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
374 self.drive(&mut state_machine)
375 }
376
377 pub fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
379 let mut handler = DropHandler::new();
380 self.query(sql, &mut handler)?;
381 Ok(handler.rows_affected())
382 }
383
384 pub fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
395 &mut self,
396 sql: &str,
397 ) -> Result<Vec<T>> {
398 let mut handler = crate::handler::CollectHandler::<T>::new();
399 self.query(sql, &mut handler)?;
400 Ok(handler.into_rows())
401 }
402
403 pub fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
405 &mut self,
406 sql: &str,
407 ) -> Result<Option<T>> {
408 let mut handler = crate::handler::FirstRowHandler::<T>::new();
409 self.query(sql, &mut handler)?;
410 Ok(handler.into_row())
411 }
412
413 pub fn query_foreach<T: for<'a> crate::conversion::FromRow<'a>, F: FnMut(T)>(
423 &mut self,
424 sql: &str,
425 f: F,
426 ) -> Result<()> {
427 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
428 self.query(sql, &mut handler)?;
429 Ok(())
430 }
431
432 pub fn close(mut self) -> Result<()> {
434 self.buffer_set.write_buffer.clear();
435 write_terminate(&mut self.buffer_set.write_buffer);
436 self.stream.write_all(&self.buffer_set.write_buffer)?;
437 self.stream.flush()?;
438 Ok(())
439 }
440
441 pub fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
445 self.prepare_typed(query, &[])
446 }
447
448 pub fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
465 if queries.is_empty() {
466 return Ok(Vec::new());
467 }
468
469 let start_idx = self.name_counter + 1;
470 self.name_counter += queries.len() as u64;
471
472 let result = self.prepare_batch_inner(queries, start_idx);
473 if let Err(e) = &result
474 && e.is_connection_broken()
475 {
476 self.is_broken = true;
477 }
478 result
479 }
480
481 fn prepare_batch_inner(
482 &mut self,
483 queries: &[&str],
484 start_idx: u64,
485 ) -> Result<Vec<PreparedStatement>> {
486 use crate::state::batch_prepare::BatchPrepareStateMachine;
487
488 let mut state_machine =
489 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
490
491 loop {
492 match state_machine.step(&mut self.buffer_set)? {
493 Action::ReadMessage => {
494 self.stream.read_message(&mut self.buffer_set)?;
495 }
496 Action::WriteAndReadMessage => {
497 self.stream.write_all(&self.buffer_set.write_buffer)?;
498 self.stream.flush()?;
499 self.stream.read_message(&mut self.buffer_set)?;
500 }
501 Action::Finished => {
502 self.transaction_status = state_machine.transaction_status();
503 break;
504 }
505 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
506 }
507 }
508
509 Ok(state_machine.take_statements())
510 }
511
512 pub fn prepare_typed(&mut self, query: &str, param_oids: &[u32]) -> Result<PreparedStatement> {
514 self.name_counter += 1;
515 let idx = self.name_counter;
516 let result = self.prepare_inner(idx, query, param_oids);
517 if let Err(e) = &result
518 && e.is_connection_broken()
519 {
520 self.is_broken = true;
521 }
522 result
523 }
524
525 fn prepare_inner(
526 &mut self,
527 idx: u64,
528 query: &str,
529 param_oids: &[u32],
530 ) -> Result<PreparedStatement> {
531 let mut handler = DropHandler::new();
532 let mut state_machine = ExtendedQueryStateMachine::prepare(
533 &mut handler,
534 &mut self.buffer_set,
535 idx,
536 query,
537 param_oids,
538 );
539 self.drive(&mut state_machine)?;
540 state_machine
541 .take_prepared_statement()
542 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
543 }
544
545 pub fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
562 &mut self,
563 statement: S,
564 params: P,
565 handler: &mut H,
566 ) -> Result<()> {
567 let result = self.exec_inner(&statement, ¶ms, handler);
568 if let Err(e) = &result
569 && e.is_connection_broken()
570 {
571 self.is_broken = true;
572 }
573 result
574 }
575
576 fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
577 &mut self,
578 statement: &S,
579 params: &P,
580 handler: &mut H,
581 ) -> Result<()> {
582 let mut state_machine = if statement.needs_parse() {
583 ExtendedQueryStateMachine::execute_sql(
584 handler,
585 &mut self.buffer_set,
586 statement.as_sql().unwrap(),
587 params,
588 )?
589 } else {
590 let stmt = statement.as_prepared().unwrap();
591 ExtendedQueryStateMachine::execute(
592 handler,
593 &mut self.buffer_set,
594 &stmt.wire_name(),
595 &stmt.param_oids,
596 params,
597 )?
598 };
599
600 self.drive(&mut state_machine)
601 }
602
603 pub fn exec_drop<S: IntoStatement, P: ToParams>(
607 &mut self,
608 statement: S,
609 params: P,
610 ) -> Result<Option<u64>> {
611 let mut handler = DropHandler::new();
612 self.exec(statement, params, &mut handler)?;
613 Ok(handler.rows_affected())
614 }
615
616 pub fn exec_collect<
630 T: for<'a> crate::conversion::FromRow<'a>,
631 S: IntoStatement,
632 P: ToParams,
633 >(
634 &mut self,
635 statement: S,
636 params: P,
637 ) -> Result<Vec<T>> {
638 let mut handler = crate::handler::CollectHandler::<T>::new();
639 self.exec(statement, params, &mut handler)?;
640 Ok(handler.into_rows())
641 }
642
643 pub fn exec_first<T: for<'a> crate::conversion::FromRow<'a>, S: IntoStatement, P: ToParams>(
657 &mut self,
658 statement: S,
659 params: P,
660 ) -> Result<Option<T>> {
661 let mut handler = crate::handler::FirstRowHandler::<T>::new();
662 self.exec(statement, params, &mut handler)?;
663 Ok(handler.into_row())
664 }
665
666 pub fn exec_foreach<
679 T: for<'a> crate::conversion::FromRow<'a>,
680 S: IntoStatement,
681 P: ToParams,
682 F: FnMut(T),
683 >(
684 &mut self,
685 statement: S,
686 params: P,
687 f: F,
688 ) -> Result<()> {
689 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
690 self.exec(statement, params, &mut handler)?;
691 Ok(())
692 }
693
694 pub fn exec_batch<S: IntoStatement, P: ToParams>(
725 &mut self,
726 statement: S,
727 params_list: &[P],
728 ) -> Result<()> {
729 self.exec_batch_chunked(statement, params_list, 1000)
730 }
731
732 pub fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
736 &mut self,
737 statement: S,
738 params_list: &[P],
739 chunk_size: usize,
740 ) -> Result<()> {
741 let result = self.exec_batch_inner(&statement, params_list, chunk_size);
742 if let Err(e) = &result
743 && e.is_connection_broken()
744 {
745 self.is_broken = true;
746 }
747 result
748 }
749
750 fn exec_batch_inner<S: IntoStatement, P: ToParams>(
751 &mut self,
752 statement: &S,
753 params_list: &[P],
754 chunk_size: usize,
755 ) -> Result<()> {
756 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
757 use crate::state::extended::BatchStateMachine;
758
759 if params_list.is_empty() {
760 return Ok(());
761 }
762
763 let chunk_size = chunk_size.max(1);
764 let needs_parse = statement.needs_parse();
765 let sql = statement.as_sql();
766 let prepared = statement.as_prepared();
767
768 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
770 stmt.param_oids.clone()
771 } else {
772 params_list[0].natural_oids()
773 };
774
775 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
777
778 for chunk in params_list.chunks(chunk_size) {
779 self.buffer_set.write_buffer.clear();
780
781 let parse_in_chunk = needs_parse;
783 if parse_in_chunk {
784 write_parse(
785 &mut self.buffer_set.write_buffer,
786 "",
787 sql.unwrap(),
788 ¶m_oids,
789 );
790 }
791
792 for params in chunk {
794 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
795 write_bind(
796 &mut self.buffer_set.write_buffer,
797 "",
798 effective_stmt_name,
799 params,
800 ¶m_oids,
801 )?;
802 write_execute(&mut self.buffer_set.write_buffer, "", 0);
803 }
804
805 write_sync(&mut self.buffer_set.write_buffer);
807
808 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
810 self.drive_batch(&mut state_machine)?;
811 self.transaction_status = state_machine.transaction_status();
812 }
813
814 Ok(())
815 }
816
817 fn drive_batch(
819 &mut self,
820 state_machine: &mut crate::state::extended::BatchStateMachine,
821 ) -> Result<()> {
822 use crate::protocol::backend::{ReadyForQuery, msg_type};
823 use crate::state::action::Action;
824
825 loop {
826 let step_result = state_machine.step(&mut self.buffer_set);
827 match step_result {
828 Ok(Action::ReadMessage) => {
829 self.stream.read_message(&mut self.buffer_set)?;
830 }
831 Ok(Action::WriteAndReadMessage) => {
832 self.stream.write_all(&self.buffer_set.write_buffer)?;
833 self.stream.flush()?;
834 self.stream.read_message(&mut self.buffer_set)?;
835 }
836 Ok(Action::Finished) => {
837 break;
838 }
839 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
840 Err(e) => {
841 loop {
843 self.stream.read_message(&mut self.buffer_set)?;
844 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
845 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
846 self.transaction_status =
847 ready.transaction_status().unwrap_or_default();
848 break;
849 }
850 }
851 return Err(e);
852 }
853 }
854 }
855 Ok(())
856 }
857
858 pub fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
860 let result = self.close_statement_inner(&stmt.wire_name());
861 if let Err(e) = &result
862 && e.is_connection_broken()
863 {
864 self.is_broken = true;
865 }
866 result
867 }
868
869 fn close_statement_inner(&mut self, name: &str) -> Result<()> {
870 let mut handler = DropHandler::new();
871 let mut state_machine =
872 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
873 self.drive(&mut state_machine)
874 }
875
876 pub fn tx<F, R>(&mut self, f: F) -> Result<R>
886 where
887 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
888 {
889 if self.in_transaction() {
890 return Err(Error::InvalidUsage(
891 "nested transactions are not supported".into(),
892 ));
893 }
894
895 self.query_drop("BEGIN")?;
896
897 let tx = super::transaction::Transaction::new(self.connection_id());
898 let result = f(self, tx);
899
900 if self.in_transaction() {
902 let rollback_result = self.query_drop("ROLLBACK");
903
904 if let Err(e) = result {
906 return Err(e);
907 }
908 rollback_result?;
909 }
910
911 result
912 }
913}
914
915impl Conn {
918 pub fn lowlevel_bind<P: ToParams>(
928 &mut self,
929 portal: &str,
930 statement_name: &str,
931 params: P,
932 ) -> Result<()> {
933 let result = self.lowlevel_bind_inner(portal, statement_name, ¶ms);
934 if let Err(e) = &result
935 && e.is_connection_broken()
936 {
937 self.is_broken = true;
938 }
939 result
940 }
941
942 fn lowlevel_bind_inner<P: ToParams>(
943 &mut self,
944 portal: &str,
945 statement_name: &str,
946 params: &P,
947 ) -> Result<()> {
948 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
949 use crate::protocol::frontend::{write_bind, write_flush};
950
951 let param_oids = params.natural_oids();
952 self.buffer_set.write_buffer.clear();
953 write_bind(
954 &mut self.buffer_set.write_buffer,
955 portal,
956 statement_name,
957 params,
958 ¶m_oids,
959 )?;
960 write_flush(&mut self.buffer_set.write_buffer);
961
962 self.stream.write_all(&self.buffer_set.write_buffer)?;
963 self.stream.flush()?;
964
965 loop {
966 self.stream.read_message(&mut self.buffer_set)?;
967 let type_byte = self.buffer_set.type_byte;
968
969 if RawMessage::is_async_type(type_byte) {
970 continue;
971 }
972
973 match type_byte {
974 msg_type::BIND_COMPLETE => {
975 BindComplete::parse(&self.buffer_set.read_buffer)?;
976 return Ok(());
977 }
978 msg_type::ERROR_RESPONSE => {
979 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
980 return Err(error.into_error());
981 }
982 _ => {
983 return Err(Error::Protocol(format!(
984 "Expected BindComplete or ErrorResponse, got '{}'",
985 type_byte as char
986 )));
987 }
988 }
989 }
990 }
991
992 pub fn lowlevel_execute<H: BinaryHandler>(
1005 &mut self,
1006 portal: &str,
1007 max_rows: u32,
1008 handler: &mut H,
1009 ) -> Result<bool> {
1010 let result = self.lowlevel_execute_inner(portal, max_rows, handler);
1011 if let Err(e) = &result
1012 && e.is_connection_broken()
1013 {
1014 self.is_broken = true;
1015 }
1016 result
1017 }
1018
    /// Sends Describe(portal) + Execute(portal, max_rows) + Flush and feeds the
    /// responses to `handler`.
    ///
    /// Returns `Ok(false)` on CommandComplete and `Ok(true)` on PortalSuspended.
    /// Asynchronous messages are skipped and are not passed to the async handler.
    ///
    /// # Errors
    ///
    /// Returns the server error on ErrorResponse, `Error::Protocol` on any other
    /// unexpected message, and propagates parse/handler/I/O errors.
    fn lowlevel_execute_inner<H: BinaryHandler>(
        &mut self,
        portal: &str,
        max_rows: u32,
        handler: &mut H,
    ) -> Result<bool> {
        use crate::protocol::backend::{
            CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
            RowDescription, msg_type,
        };
        use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};

        self.buffer_set.write_buffer.clear();
        write_describe_portal(&mut self.buffer_set.write_buffer, portal);
        write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
        write_flush(&mut self.buffer_set.write_buffer);

        self.stream.write_all(&self.buffer_set.write_buffer)?;
        self.stream.flush()?;

        // Owned copy of the RowDescription bytes: every `read_message` call
        // refills `read_buffer`, so column metadata must be kept separately.
        let mut column_buffer: Vec<u8> = Vec::new();

        loop {
            self.stream.read_message(&mut self.buffer_set)?;
            let type_byte = self.buffer_set.type_byte;

            if RawMessage::is_async_type(type_byte) {
                continue;
            }

            match type_byte {
                msg_type::ROW_DESCRIPTION => {
                    column_buffer.clear();
                    column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
                    let cols = RowDescription::parse(&column_buffer)?;
                    handler.result_start(cols)?;
                }
                msg_type::NO_DATA => {
                    NoData::parse(&self.buffer_set.read_buffer)?;
                }
                msg_type::DATA_ROW => {
                    // Column metadata is re-parsed from the saved bytes for each row.
                    let cols = RowDescription::parse(&column_buffer)?;
                    let row = DataRow::parse(&self.buffer_set.read_buffer)?;
                    handler.row(cols, row)?;
                }
                msg_type::COMMAND_COMPLETE => {
                    let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
                    handler.result_end(complete)?;
                    // Portal ran to completion.
                    return Ok(false);
                }
                msg_type::PORTAL_SUSPENDED => {
                    PortalSuspended::parse(&self.buffer_set.read_buffer)?;
                    // `max_rows` reached; more rows may be fetched later.
                    return Ok(true);
                }
                msg_type::ERROR_RESPONSE => {
                    let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
                    return Err(error.into_error());
                }
                _ => {
                    return Err(Error::Protocol(format!(
                        "Unexpected message in execute: '{}'",
                        type_byte as char
                    )));
                }
            }
        }
    }
1086
1087 pub fn lowlevel_sync(&mut self) -> Result<()> {
1094 let result = self.lowlevel_sync_inner();
1095 if let Err(e) = &result
1096 && e.is_connection_broken()
1097 {
1098 self.is_broken = true;
1099 }
1100 result
1101 }
1102
1103 fn lowlevel_sync_inner(&mut self) -> Result<()> {
1104 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
1105 use crate::protocol::frontend::write_sync;
1106
1107 self.buffer_set.write_buffer.clear();
1108 write_sync(&mut self.buffer_set.write_buffer);
1109
1110 self.stream.write_all(&self.buffer_set.write_buffer)?;
1111 self.stream.flush()?;
1112
1113 let mut pending_error: Option<Error> = None;
1114
1115 loop {
1116 self.stream.read_message(&mut self.buffer_set)?;
1117 let type_byte = self.buffer_set.type_byte;
1118
1119 if RawMessage::is_async_type(type_byte) {
1120 continue;
1121 }
1122
1123 match type_byte {
1124 msg_type::READY_FOR_QUERY => {
1125 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
1126 self.transaction_status = ready.transaction_status().unwrap_or_default();
1127 if let Some(e) = pending_error {
1128 return Err(e);
1129 }
1130 return Ok(());
1131 }
1132 msg_type::ERROR_RESPONSE => {
1133 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1134 pending_error = Some(error.into_error());
1135 }
1136 _ => {
1137 }
1139 }
1140 }
1141 }
1142
1143 pub fn lowlevel_flush(&mut self) -> Result<()> {
1150 use crate::protocol::frontend::write_flush;
1151
1152 self.buffer_set.write_buffer.clear();
1153 write_flush(&mut self.buffer_set.write_buffer);
1154
1155 self.stream.write_all(&self.buffer_set.write_buffer)?;
1156 self.stream.flush()?;
1157 Ok(())
1158 }
1159
1160 pub fn exec_portal<S: IntoStatement, P, F, T>(
1190 &mut self,
1191 statement: S,
1192 params: P,
1193 f: F,
1194 ) -> Result<T>
1195 where
1196 P: ToParams,
1197 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1198 {
1199 let result = self.exec_portal_inner(&statement, ¶ms, f);
1200 if let Err(e) = &result
1201 && e.is_connection_broken()
1202 {
1203 self.is_broken = true;
1204 }
1205 result
1206 }
1207
1208 fn exec_portal_inner<S: IntoStatement, P, F, T>(
1209 &mut self,
1210 statement: &S,
1211 params: &P,
1212 f: F,
1213 ) -> Result<T>
1214 where
1215 P: ToParams,
1216 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1217 {
1218 let mut state_machine = if let Some(sql) = statement.as_sql() {
1220 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1221 } else {
1222 let stmt = statement.as_prepared().unwrap();
1223 BindStateMachine::bind_prepared(
1224 &mut self.buffer_set,
1225 "",
1226 &stmt.wire_name(),
1227 &stmt.param_oids,
1228 params,
1229 )?
1230 };
1231
1232 loop {
1234 match state_machine.step(&mut self.buffer_set)? {
1235 Action::ReadMessage => {
1236 self.stream.read_message(&mut self.buffer_set)?;
1237 }
1238 Action::Write => {
1239 self.stream.write_all(&self.buffer_set.write_buffer)?;
1240 self.stream.flush()?;
1241 }
1242 Action::WriteAndReadMessage => {
1243 self.stream.write_all(&self.buffer_set.write_buffer)?;
1244 self.stream.flush()?;
1245 self.stream.read_message(&mut self.buffer_set)?;
1246 }
1247 Action::Finished => break,
1248 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1249 }
1250 }
1251
1252 let mut portal = UnnamedPortal { conn: self };
1254 let result = f(&mut portal);
1255
1256 let sync_result = portal.conn.lowlevel_sync();
1258
1259 match (result, sync_result) {
1261 (Ok(v), Ok(())) => Ok(v),
1262 (Err(e), _) => Err(e),
1263 (Ok(_), Err(e)) => Err(e),
1264 }
1265 }
1266
1267 pub fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1269 let result = self.lowlevel_close_portal_inner(portal);
1270 if let Err(e) = &result
1271 && e.is_connection_broken()
1272 {
1273 self.is_broken = true;
1274 }
1275 result
1276 }
1277
1278 fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1279 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1280 use crate::protocol::frontend::{write_close_portal, write_flush};
1281
1282 self.buffer_set.write_buffer.clear();
1283 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1284 write_flush(&mut self.buffer_set.write_buffer);
1285
1286 self.stream.write_all(&self.buffer_set.write_buffer)?;
1287 self.stream.flush()?;
1288
1289 loop {
1290 self.stream.read_message(&mut self.buffer_set)?;
1291 let type_byte = self.buffer_set.type_byte;
1292
1293 if RawMessage::is_async_type(type_byte) {
1294 continue;
1295 }
1296
1297 match type_byte {
1298 msg_type::CLOSE_COMPLETE => {
1299 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1300 return Ok(());
1301 }
1302 msg_type::ERROR_RESPONSE => {
1303 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1304 return Err(error.into_error());
1305 }
1306 _ => {
1307 return Err(Error::Protocol(format!(
1308 "Expected CloseComplete or ErrorResponse, got '{}'",
1309 type_byte as char
1310 )));
1311 }
1312 }
1313 }
1314 }
1315}
1316
1317impl Drop for Conn {
1318 fn drop(&mut self) {
1319 self.buffer_set.write_buffer.clear();
1321 write_terminate(&mut self.buffer_set.write_buffer);
1322 let _ = self.stream.write_all(&self.buffer_set.write_buffer);
1323 let _ = self.stream.flush();
1324 }
1325}