1use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25use super::unnamed_portal::UnnamedPortal;
26
/// A blocking connection to a PostgreSQL server.
///
/// Created by [`Conn::new`] or [`Conn::new_with_stream`]. Public operations that fail with an
/// error for which `is_connection_broken()` returns `true` set the `is_broken` flag, which
/// callers can inspect through [`Conn::is_broken`].
pub struct Conn {
    /// Transport to the server: TCP, a Unix-domain socket, or a TLS-upgraded stream.
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the options' buffer pool.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup, if the server sent any.
    backend_key: Option<BackendKeyData>,
    /// Server parameters collected during startup, as `(name, value)` pairs.
    server_params: Vec<(String, String)>,
    /// Transaction status recorded from the most recent completed exchange.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error classified as breaking the connection.
    pub(crate) is_broken: bool,
    /// Counter used to generate unique prepared-statement indices and portal names.
    name_counter: u64,
    /// Optional callback for asynchronous messages encountered while driving queries.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
38
39impl Conn {
40 pub fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
42 where
43 Error: From<O::Error>,
44 {
45 let opts = opts.try_into()?;
46
47 let stream = if let Some(socket_path) = &opts.socket {
48 #[cfg(unix)]
49 {
50 Stream::unix(UnixStream::connect(socket_path)?)
51 }
52 #[cfg(not(unix))]
53 {
54 let _ = socket_path;
55 return Err(Error::Unsupported(
56 "Unix sockets are not supported on this platform".into(),
57 ));
58 }
59 } else {
60 if opts.host.is_empty() {
61 return Err(Error::InvalidUsage("host is empty".into()));
62 }
63 let addr = format!("{}:{}", opts.host, opts.port);
64 let tcp = TcpStream::connect(&addr)?;
65 tcp.set_nodelay(true)?;
66 Stream::tcp(tcp)
67 };
68
69 Self::new_with_stream(stream, opts)
70 }
71
72 #[allow(unused_mut)]
74 pub fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
75 let mut buffer_set = options.buffer_pool.get_buffer_set();
76 let mut state_machine = ConnectionStateMachine::new(options.clone());
77
78 loop {
80 match state_machine.step(&mut buffer_set)? {
81 Action::WriteAndReadByte => {
82 stream.write_all(&buffer_set.write_buffer)?;
83 stream.flush()?;
84 let byte = stream.read_u8()?;
85 state_machine.set_ssl_response(byte);
86 }
87 Action::ReadMessage => {
88 stream.read_message(&mut buffer_set)?;
89 }
90 Action::Write => {
91 stream.write_all(&buffer_set.write_buffer)?;
92 stream.flush()?;
93 }
94 Action::WriteAndReadMessage => {
95 stream.write_all(&buffer_set.write_buffer)?;
96 stream.flush()?;
97 stream.read_message(&mut buffer_set)?;
98 }
99 Action::TlsHandshake => {
100 #[cfg(feature = "sync-tls")]
101 {
102 stream = stream.upgrade_to_tls(&options.host)?;
103 }
104 #[cfg(not(feature = "sync-tls"))]
105 {
106 return Err(Error::Unsupported(
107 "TLS requested but sync-tls feature not enabled".into(),
108 ));
109 }
110 }
111 Action::HandleAsyncMessageAndReadMessage(_) => {
112 stream.read_message(&mut buffer_set)?;
114 }
115 Action::Finished => break,
116 }
117 }
118
119 let conn = Self {
120 stream,
121 buffer_set,
122 backend_key: state_machine.backend_key().cloned(),
123 server_params: state_machine.take_server_params(),
124 transaction_status: state_machine.transaction_status(),
125 is_broken: false,
126 name_counter: 0,
127 async_message_handler: None,
128 };
129
130 #[cfg(unix)]
132 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
133 conn.try_upgrade_to_unix_socket(&options)
134 } else {
135 conn
136 };
137
138 Ok(conn)
139 }
140
141 #[cfg(unix)]
144 fn try_upgrade_to_unix_socket(mut self, opts: &Opts) -> Self {
145 let mut handler = FirstRowHandler::<(String,)>::new();
147 if self
148 .query("SHOW unix_socket_directories", &mut handler)
149 .is_err()
150 {
151 return self;
152 }
153
154 let socket_dir = match handler.into_row() {
155 Some((dirs,)) => {
156 match dirs.split(',').next() {
158 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
159 _ => return self,
160 }
161 }
162 None => return self,
163 };
164
165 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
167
168 let unix_stream = match UnixStream::connect(&socket_path) {
170 Ok(s) => s,
171 Err(_) => return self,
172 };
173
174 let mut opts_unix = opts.clone();
176 opts_unix.upgrade_to_unix_socket = false;
177
178 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix) {
179 Ok(new_conn) => new_conn,
180 Err(_) => self,
181 }
182 }
183
184 pub fn backend_key(&self) -> Option<&BackendKeyData> {
186 self.backend_key.as_ref()
187 }
188
189 pub fn connection_id(&self) -> u32 {
193 self.backend_key.as_ref().map_or(0, |k| k.process_id())
194 }
195
196 pub fn server_params(&self) -> &[(String, String)] {
198 &self.server_params
199 }
200
201 pub fn transaction_status(&self) -> TransactionStatus {
203 self.transaction_status
204 }
205
206 pub fn in_transaction(&self) -> bool {
208 self.transaction_status.in_transaction()
209 }
210
211 pub fn is_broken(&self) -> bool {
213 self.is_broken
214 }
215
216 pub(crate) fn next_portal_name(&mut self) -> String {
218 self.name_counter += 1;
219 format!("_zero_p_{}", self.name_counter)
220 }
221
222 pub(crate) fn create_named_portal<S: IntoStatement, P: ToParams>(
226 &mut self,
227 portal_name: &str,
228 statement: &S,
229 params: &P,
230 ) -> Result<()> {
231 let mut state_machine = if let Some(sql) = statement.as_sql() {
233 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
234 } else {
235 let stmt = statement.as_prepared().unwrap();
236 BindStateMachine::bind_prepared(
237 &mut self.buffer_set,
238 portal_name,
239 &stmt.wire_name(),
240 &stmt.param_oids,
241 params,
242 )?
243 };
244
245 loop {
247 match state_machine.step(&mut self.buffer_set)? {
248 Action::ReadMessage => {
249 self.stream.read_message(&mut self.buffer_set)?;
250 }
251 Action::Write => {
252 self.stream.write_all(&self.buffer_set.write_buffer)?;
253 self.stream.flush()?;
254 }
255 Action::WriteAndReadMessage => {
256 self.stream.write_all(&self.buffer_set.write_buffer)?;
257 self.stream.flush()?;
258 self.stream.read_message(&mut self.buffer_set)?;
259 }
260 Action::Finished => break,
261 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
262 }
263 }
264
265 Ok(())
266 }
267
268 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
275 self.async_message_handler = Some(Box::new(handler));
276 }
277
278 pub fn clear_async_message_handler(&mut self) {
280 self.async_message_handler = None;
281 }
282
283 pub fn run_pipeline<T, F>(&mut self, f: F) -> Result<T>
304 where
305 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
306 {
307 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
308 let result = f(&mut pipeline);
309 pipeline.cleanup();
310 result
311 }
312
313 pub fn ping(&mut self) -> Result<()> {
315 self.query_drop("")?;
316 Ok(())
317 }
318
319 fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
321 loop {
322 match state_machine.step(&mut self.buffer_set)? {
323 Action::WriteAndReadByte => {
324 return Err(Error::Protocol(
325 "Unexpected WriteAndReadByte in query state machine".into(),
326 ));
327 }
328 Action::ReadMessage => {
329 self.stream.read_message(&mut self.buffer_set)?;
330 }
331 Action::Write => {
332 self.stream.write_all(&self.buffer_set.write_buffer)?;
333 self.stream.flush()?;
334 }
335 Action::WriteAndReadMessage => {
336 self.stream.write_all(&self.buffer_set.write_buffer)?;
337 self.stream.flush()?;
338 self.stream.read_message(&mut self.buffer_set)?;
339 }
340 Action::TlsHandshake => {
341 return Err(Error::Protocol(
342 "Unexpected TlsHandshake in query state machine".into(),
343 ));
344 }
345 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
346 if let Some(ref mut h) = self.async_message_handler {
347 h.handle(async_msg);
348 }
349 self.stream.read_message(&mut self.buffer_set)?;
351 }
352 Action::Finished => {
353 self.transaction_status = state_machine.transaction_status();
354 break;
355 }
356 }
357 }
358 Ok(())
359 }
360
361 pub fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
363 let result = self.query_inner(sql, handler);
364 if let Err(e) = &result
365 && e.is_connection_broken()
366 {
367 self.is_broken = true;
368 }
369 result
370 }
371
372 fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
373 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
374 self.drive(&mut state_machine)
375 }
376
377 pub fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
379 let mut handler = DropHandler::new();
380 self.query(sql, &mut handler)?;
381 Ok(handler.rows_affected())
382 }
383
384 pub fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
395 &mut self,
396 sql: &str,
397 ) -> Result<Vec<T>> {
398 let mut handler = crate::handler::CollectHandler::<T>::new();
399 self.query(sql, &mut handler)?;
400 Ok(handler.into_rows())
401 }
402
403 pub fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
405 &mut self,
406 sql: &str,
407 ) -> Result<Option<T>> {
408 let mut handler = crate::handler::FirstRowHandler::<T>::new();
409 self.query(sql, &mut handler)?;
410 Ok(handler.into_row())
411 }
412
413 pub fn query_foreach<T: for<'a> crate::conversion::FromRow<'a>, F: FnMut(T) -> Result<()>>(
426 &mut self,
427 sql: &str,
428 f: F,
429 ) -> Result<()> {
430 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
431 self.query(sql, &mut handler)?;
432 Ok(())
433 }
434
435 pub fn close(mut self) -> Result<()> {
437 self.buffer_set.write_buffer.clear();
438 write_terminate(&mut self.buffer_set.write_buffer);
439 self.stream.write_all(&self.buffer_set.write_buffer)?;
440 self.stream.flush()?;
441 Ok(())
442 }
443
444 pub fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
448 self.prepare_typed(query, &[])
449 }
450
451 pub fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
468 if queries.is_empty() {
469 return Ok(Vec::new());
470 }
471
472 let start_idx = self.name_counter + 1;
473 self.name_counter += queries.len() as u64;
474
475 let result = self.prepare_batch_inner(queries, start_idx);
476 if let Err(e) = &result
477 && e.is_connection_broken()
478 {
479 self.is_broken = true;
480 }
481 result
482 }
483
484 fn prepare_batch_inner(
485 &mut self,
486 queries: &[&str],
487 start_idx: u64,
488 ) -> Result<Vec<PreparedStatement>> {
489 use crate::state::batch_prepare::BatchPrepareStateMachine;
490
491 let mut state_machine =
492 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
493
494 loop {
495 match state_machine.step(&mut self.buffer_set)? {
496 Action::ReadMessage => {
497 self.stream.read_message(&mut self.buffer_set)?;
498 }
499 Action::WriteAndReadMessage => {
500 self.stream.write_all(&self.buffer_set.write_buffer)?;
501 self.stream.flush()?;
502 self.stream.read_message(&mut self.buffer_set)?;
503 }
504 Action::Finished => {
505 self.transaction_status = state_machine.transaction_status();
506 break;
507 }
508 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
509 }
510 }
511
512 Ok(state_machine.take_statements())
513 }
514
515 pub fn prepare_typed(&mut self, query: &str, param_oids: &[u32]) -> Result<PreparedStatement> {
517 self.name_counter += 1;
518 let idx = self.name_counter;
519 let result = self.prepare_inner(idx, query, param_oids);
520 if let Err(e) = &result
521 && e.is_connection_broken()
522 {
523 self.is_broken = true;
524 }
525 result
526 }
527
528 fn prepare_inner(
529 &mut self,
530 idx: u64,
531 query: &str,
532 param_oids: &[u32],
533 ) -> Result<PreparedStatement> {
534 let mut handler = DropHandler::new();
535 let mut state_machine = ExtendedQueryStateMachine::prepare(
536 &mut handler,
537 &mut self.buffer_set,
538 idx,
539 query,
540 param_oids,
541 );
542 self.drive(&mut state_machine)?;
543 state_machine
544 .take_prepared_statement()
545 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
546 }
547
548 pub fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
565 &mut self,
566 statement: S,
567 params: P,
568 handler: &mut H,
569 ) -> Result<()> {
570 let result = self.exec_inner(&statement, ¶ms, handler);
571 if let Err(e) = &result
572 && e.is_connection_broken()
573 {
574 self.is_broken = true;
575 }
576 result
577 }
578
579 fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
580 &mut self,
581 statement: &S,
582 params: &P,
583 handler: &mut H,
584 ) -> Result<()> {
585 let mut state_machine = if statement.needs_parse() {
586 ExtendedQueryStateMachine::execute_sql(
587 handler,
588 &mut self.buffer_set,
589 statement.as_sql().unwrap(),
590 params,
591 )?
592 } else {
593 let stmt = statement.as_prepared().unwrap();
594 ExtendedQueryStateMachine::execute(
595 handler,
596 &mut self.buffer_set,
597 &stmt.wire_name(),
598 &stmt.param_oids,
599 params,
600 )?
601 };
602
603 self.drive(&mut state_machine)
604 }
605
606 pub fn exec_drop<S: IntoStatement, P: ToParams>(
610 &mut self,
611 statement: S,
612 params: P,
613 ) -> Result<Option<u64>> {
614 let mut handler = DropHandler::new();
615 self.exec(statement, params, &mut handler)?;
616 Ok(handler.rows_affected())
617 }
618
619 pub fn exec_collect<
633 T: for<'a> crate::conversion::FromRow<'a>,
634 S: IntoStatement,
635 P: ToParams,
636 >(
637 &mut self,
638 statement: S,
639 params: P,
640 ) -> Result<Vec<T>> {
641 let mut handler = crate::handler::CollectHandler::<T>::new();
642 self.exec(statement, params, &mut handler)?;
643 Ok(handler.into_rows())
644 }
645
646 pub fn exec_first<T: for<'a> crate::conversion::FromRow<'a>, S: IntoStatement, P: ToParams>(
660 &mut self,
661 statement: S,
662 params: P,
663 ) -> Result<Option<T>> {
664 let mut handler = crate::handler::FirstRowHandler::<T>::new();
665 self.exec(statement, params, &mut handler)?;
666 Ok(handler.into_row())
667 }
668
669 pub fn exec_foreach<
685 T: for<'a> crate::conversion::FromRow<'a>,
686 S: IntoStatement,
687 P: ToParams,
688 F: FnMut(T) -> Result<()>,
689 >(
690 &mut self,
691 statement: S,
692 params: P,
693 f: F,
694 ) -> Result<()> {
695 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
696 self.exec(statement, params, &mut handler)?;
697 Ok(())
698 }
699
700 pub fn exec_batch<S: IntoStatement, P: ToParams>(
731 &mut self,
732 statement: S,
733 params_list: &[P],
734 ) -> Result<()> {
735 self.exec_batch_chunked(statement, params_list, 1000)
736 }
737
738 pub fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
742 &mut self,
743 statement: S,
744 params_list: &[P],
745 chunk_size: usize,
746 ) -> Result<()> {
747 let result = self.exec_batch_inner(&statement, params_list, chunk_size);
748 if let Err(e) = &result
749 && e.is_connection_broken()
750 {
751 self.is_broken = true;
752 }
753 result
754 }
755
756 fn exec_batch_inner<S: IntoStatement, P: ToParams>(
757 &mut self,
758 statement: &S,
759 params_list: &[P],
760 chunk_size: usize,
761 ) -> Result<()> {
762 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
763 use crate::state::extended::BatchStateMachine;
764
765 if params_list.is_empty() {
766 return Ok(());
767 }
768
769 let chunk_size = chunk_size.max(1);
770 let needs_parse = statement.needs_parse();
771 let sql = statement.as_sql();
772 let prepared = statement.as_prepared();
773
774 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
776 stmt.param_oids.clone()
777 } else {
778 params_list[0].natural_oids()
779 };
780
781 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
783
784 for chunk in params_list.chunks(chunk_size) {
785 self.buffer_set.write_buffer.clear();
786
787 let parse_in_chunk = needs_parse;
789 if parse_in_chunk {
790 write_parse(
791 &mut self.buffer_set.write_buffer,
792 "",
793 sql.unwrap(),
794 ¶m_oids,
795 );
796 }
797
798 for params in chunk {
800 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
801 write_bind(
802 &mut self.buffer_set.write_buffer,
803 "",
804 effective_stmt_name,
805 params,
806 ¶m_oids,
807 )?;
808 write_execute(&mut self.buffer_set.write_buffer, "", 0);
809 }
810
811 write_sync(&mut self.buffer_set.write_buffer);
813
814 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
816 self.drive_batch(&mut state_machine)?;
817 self.transaction_status = state_machine.transaction_status();
818 }
819
820 Ok(())
821 }
822
823 fn drive_batch(
825 &mut self,
826 state_machine: &mut crate::state::extended::BatchStateMachine,
827 ) -> Result<()> {
828 use crate::protocol::backend::{ReadyForQuery, msg_type};
829 use crate::state::action::Action;
830
831 loop {
832 let step_result = state_machine.step(&mut self.buffer_set);
833 match step_result {
834 Ok(Action::ReadMessage) => {
835 self.stream.read_message(&mut self.buffer_set)?;
836 }
837 Ok(Action::WriteAndReadMessage) => {
838 self.stream.write_all(&self.buffer_set.write_buffer)?;
839 self.stream.flush()?;
840 self.stream.read_message(&mut self.buffer_set)?;
841 }
842 Ok(Action::Finished) => {
843 break;
844 }
845 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
846 Err(e) => {
847 loop {
849 self.stream.read_message(&mut self.buffer_set)?;
850 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
851 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
852 self.transaction_status =
853 ready.transaction_status().unwrap_or_default();
854 break;
855 }
856 }
857 return Err(e);
858 }
859 }
860 }
861 Ok(())
862 }
863
864 pub fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
866 let result = self.close_statement_inner(&stmt.wire_name());
867 if let Err(e) = &result
868 && e.is_connection_broken()
869 {
870 self.is_broken = true;
871 }
872 result
873 }
874
875 fn close_statement_inner(&mut self, name: &str) -> Result<()> {
876 let mut handler = DropHandler::new();
877 let mut state_machine =
878 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
879 self.drive(&mut state_machine)
880 }
881
882 pub fn tx<F, R>(&mut self, f: F) -> Result<R>
892 where
893 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
894 {
895 if self.in_transaction() {
896 return Err(Error::InvalidUsage(
897 "nested transactions are not supported".into(),
898 ));
899 }
900
901 self.query_drop("BEGIN")?;
902
903 let tx = super::transaction::Transaction::new(self.connection_id());
904 let result = f(self, tx);
905
906 if self.in_transaction() {
908 let rollback_result = self.query_drop("ROLLBACK");
909
910 if let Err(e) = result {
912 return Err(e);
913 }
914 rollback_result?;
915 }
916
917 result
918 }
919}
920
921impl Conn {
924 pub fn lowlevel_bind<P: ToParams>(
934 &mut self,
935 portal: &str,
936 statement_name: &str,
937 params: P,
938 ) -> Result<()> {
939 let result = self.lowlevel_bind_inner(portal, statement_name, ¶ms);
940 if let Err(e) = &result
941 && e.is_connection_broken()
942 {
943 self.is_broken = true;
944 }
945 result
946 }
947
948 fn lowlevel_bind_inner<P: ToParams>(
949 &mut self,
950 portal: &str,
951 statement_name: &str,
952 params: &P,
953 ) -> Result<()> {
954 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
955 use crate::protocol::frontend::{write_bind, write_flush};
956
957 let param_oids = params.natural_oids();
958 self.buffer_set.write_buffer.clear();
959 write_bind(
960 &mut self.buffer_set.write_buffer,
961 portal,
962 statement_name,
963 params,
964 ¶m_oids,
965 )?;
966 write_flush(&mut self.buffer_set.write_buffer);
967
968 self.stream.write_all(&self.buffer_set.write_buffer)?;
969 self.stream.flush()?;
970
971 loop {
972 self.stream.read_message(&mut self.buffer_set)?;
973 let type_byte = self.buffer_set.type_byte;
974
975 if RawMessage::is_async_type(type_byte) {
976 continue;
977 }
978
979 match type_byte {
980 msg_type::BIND_COMPLETE => {
981 BindComplete::parse(&self.buffer_set.read_buffer)?;
982 return Ok(());
983 }
984 msg_type::ERROR_RESPONSE => {
985 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
986 return Err(error.into_error());
987 }
988 _ => {
989 return Err(Error::Protocol(format!(
990 "Expected BindComplete or ErrorResponse, got '{}'",
991 type_byte as char
992 )));
993 }
994 }
995 }
996 }
997
998 pub fn lowlevel_execute<H: BinaryHandler>(
1011 &mut self,
1012 portal: &str,
1013 max_rows: u32,
1014 handler: &mut H,
1015 ) -> Result<bool> {
1016 let result = self.lowlevel_execute_inner(portal, max_rows, handler);
1017 if let Err(e) = &result
1018 && e.is_connection_broken()
1019 {
1020 self.is_broken = true;
1021 }
1022 result
1023 }
1024
1025 fn lowlevel_execute_inner<H: BinaryHandler>(
1026 &mut self,
1027 portal: &str,
1028 max_rows: u32,
1029 handler: &mut H,
1030 ) -> Result<bool> {
1031 use crate::protocol::backend::{
1032 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1033 RowDescription, msg_type,
1034 };
1035 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1036
1037 self.buffer_set.write_buffer.clear();
1038 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1039 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1040 write_flush(&mut self.buffer_set.write_buffer);
1041
1042 self.stream.write_all(&self.buffer_set.write_buffer)?;
1043 self.stream.flush()?;
1044
1045 let mut column_buffer: Vec<u8> = Vec::new();
1046
1047 loop {
1048 self.stream.read_message(&mut self.buffer_set)?;
1049 let type_byte = self.buffer_set.type_byte;
1050
1051 if RawMessage::is_async_type(type_byte) {
1052 continue;
1053 }
1054
1055 match type_byte {
1056 msg_type::ROW_DESCRIPTION => {
1057 column_buffer.clear();
1058 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1059 let cols = RowDescription::parse(&column_buffer)?;
1060 handler.result_start(cols)?;
1061 }
1062 msg_type::NO_DATA => {
1063 NoData::parse(&self.buffer_set.read_buffer)?;
1064 }
1065 msg_type::DATA_ROW => {
1066 let cols = RowDescription::parse(&column_buffer)?;
1067 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1068 handler.row(cols, row)?;
1069 }
1070 msg_type::COMMAND_COMPLETE => {
1071 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1072 handler.result_end(complete)?;
1073 return Ok(false); }
1075 msg_type::PORTAL_SUSPENDED => {
1076 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1077 return Ok(true); }
1079 msg_type::ERROR_RESPONSE => {
1080 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1081 return Err(error.into_error());
1082 }
1083 _ => {
1084 return Err(Error::Protocol(format!(
1085 "Unexpected message in execute: '{}'",
1086 type_byte as char
1087 )));
1088 }
1089 }
1090 }
1091 }
1092
1093 pub fn lowlevel_sync(&mut self) -> Result<()> {
1100 let result = self.lowlevel_sync_inner();
1101 if let Err(e) = &result
1102 && e.is_connection_broken()
1103 {
1104 self.is_broken = true;
1105 }
1106 result
1107 }
1108
1109 fn lowlevel_sync_inner(&mut self) -> Result<()> {
1110 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
1111 use crate::protocol::frontend::write_sync;
1112
1113 self.buffer_set.write_buffer.clear();
1114 write_sync(&mut self.buffer_set.write_buffer);
1115
1116 self.stream.write_all(&self.buffer_set.write_buffer)?;
1117 self.stream.flush()?;
1118
1119 let mut pending_error: Option<Error> = None;
1120
1121 loop {
1122 self.stream.read_message(&mut self.buffer_set)?;
1123 let type_byte = self.buffer_set.type_byte;
1124
1125 if RawMessage::is_async_type(type_byte) {
1126 continue;
1127 }
1128
1129 match type_byte {
1130 msg_type::READY_FOR_QUERY => {
1131 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
1132 self.transaction_status = ready.transaction_status().unwrap_or_default();
1133 if let Some(e) = pending_error {
1134 return Err(e);
1135 }
1136 return Ok(());
1137 }
1138 msg_type::ERROR_RESPONSE => {
1139 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1140 pending_error = Some(error.into_error());
1141 }
1142 _ => {
1143 }
1145 }
1146 }
1147 }
1148
1149 pub fn lowlevel_flush(&mut self) -> Result<()> {
1156 use crate::protocol::frontend::write_flush;
1157
1158 self.buffer_set.write_buffer.clear();
1159 write_flush(&mut self.buffer_set.write_buffer);
1160
1161 self.stream.write_all(&self.buffer_set.write_buffer)?;
1162 self.stream.flush()?;
1163 Ok(())
1164 }
1165
1166 pub fn exec_portal<S: IntoStatement, P, F, T>(
1196 &mut self,
1197 statement: S,
1198 params: P,
1199 f: F,
1200 ) -> Result<T>
1201 where
1202 P: ToParams,
1203 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1204 {
1205 let result = self.exec_portal_inner(&statement, ¶ms, f);
1206 if let Err(e) = &result
1207 && e.is_connection_broken()
1208 {
1209 self.is_broken = true;
1210 }
1211 result
1212 }
1213
1214 fn exec_portal_inner<S: IntoStatement, P, F, T>(
1215 &mut self,
1216 statement: &S,
1217 params: &P,
1218 f: F,
1219 ) -> Result<T>
1220 where
1221 P: ToParams,
1222 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1223 {
1224 let mut state_machine = if let Some(sql) = statement.as_sql() {
1226 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1227 } else {
1228 let stmt = statement.as_prepared().unwrap();
1229 BindStateMachine::bind_prepared(
1230 &mut self.buffer_set,
1231 "",
1232 &stmt.wire_name(),
1233 &stmt.param_oids,
1234 params,
1235 )?
1236 };
1237
1238 loop {
1240 match state_machine.step(&mut self.buffer_set)? {
1241 Action::ReadMessage => {
1242 self.stream.read_message(&mut self.buffer_set)?;
1243 }
1244 Action::Write => {
1245 self.stream.write_all(&self.buffer_set.write_buffer)?;
1246 self.stream.flush()?;
1247 }
1248 Action::WriteAndReadMessage => {
1249 self.stream.write_all(&self.buffer_set.write_buffer)?;
1250 self.stream.flush()?;
1251 self.stream.read_message(&mut self.buffer_set)?;
1252 }
1253 Action::Finished => break,
1254 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1255 }
1256 }
1257
1258 let mut portal = UnnamedPortal { conn: self };
1260 let result = f(&mut portal);
1261
1262 let sync_result = portal.conn.lowlevel_sync();
1264
1265 match (result, sync_result) {
1267 (Ok(v), Ok(())) => Ok(v),
1268 (Err(e), _) => Err(e),
1269 (Ok(_), Err(e)) => Err(e),
1270 }
1271 }
1272
1273 pub fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1275 let result = self.lowlevel_close_portal_inner(portal);
1276 if let Err(e) = &result
1277 && e.is_connection_broken()
1278 {
1279 self.is_broken = true;
1280 }
1281 result
1282 }
1283
1284 fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1285 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1286 use crate::protocol::frontend::{write_close_portal, write_flush};
1287
1288 self.buffer_set.write_buffer.clear();
1289 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1290 write_flush(&mut self.buffer_set.write_buffer);
1291
1292 self.stream.write_all(&self.buffer_set.write_buffer)?;
1293 self.stream.flush()?;
1294
1295 loop {
1296 self.stream.read_message(&mut self.buffer_set)?;
1297 let type_byte = self.buffer_set.type_byte;
1298
1299 if RawMessage::is_async_type(type_byte) {
1300 continue;
1301 }
1302
1303 match type_byte {
1304 msg_type::CLOSE_COMPLETE => {
1305 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1306 return Ok(());
1307 }
1308 msg_type::ERROR_RESPONSE => {
1309 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1310 return Err(error.into_error());
1311 }
1312 _ => {
1313 return Err(Error::Protocol(format!(
1314 "Expected CloseComplete or ErrorResponse, got '{}'",
1315 type_byte as char
1316 )));
1317 }
1318 }
1319 }
1320 }
1321}
1322
1323impl Drop for Conn {
1324 fn drop(&mut self) {
1325 self.buffer_set.write_buffer.clear();
1327 write_terminate(&mut self.buffer_set.write_buffer);
1328 let _ = self.stream.write_all(&self.buffer_set.write_buffer);
1329 let _ = self.stream.flush();
1330 }
1331}