1use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, DropHandler, ExtendedHandler, FirstRowHandler, SimpleHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::IntoStatement;
23
24use super::stream::Stream;
25use super::unnamed_portal::UnnamedPortal;
26
27pub struct Conn {
29 pub(crate) stream: Stream,
30 pub(crate) buffer_set: PooledBufferSet,
31 backend_key: Option<BackendKeyData>,
32 server_params: Vec<(String, String)>,
33 pub(crate) transaction_status: TransactionStatus,
34 pub(crate) is_broken: bool,
35 name_counter: u64,
36 async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
37}
38
39impl Conn {
40 pub fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
42 where
43 Error: From<O::Error>,
44 {
45 let opts = opts.try_into()?;
46
47 let stream = if let Some(socket_path) = &opts.socket {
48 #[cfg(unix)]
49 {
50 Stream::unix(UnixStream::connect(socket_path)?)
51 }
52 #[cfg(not(unix))]
53 {
54 let _ = socket_path;
55 return Err(Error::Unsupported(
56 "Unix sockets are not supported on this platform".into(),
57 ));
58 }
59 } else {
60 if opts.host.is_empty() {
61 return Err(Error::InvalidUsage("host is empty".into()));
62 }
63 let addr = format!("{}:{}", opts.host, opts.port);
64 let tcp = TcpStream::connect(&addr)?;
65 tcp.set_nodelay(true)?;
66 Stream::tcp(tcp)
67 };
68
69 Self::new_with_stream(stream, opts)
70 }
71
72 #[allow(unused_mut)]
74 pub fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
75 let mut buffer_set = options.buffer_pool.get_buffer_set();
76 let mut state_machine = ConnectionStateMachine::new(options.clone());
77
78 loop {
80 match state_machine.step(&mut buffer_set)? {
81 Action::WriteAndReadByte => {
82 stream.write_all(&buffer_set.write_buffer)?;
83 stream.flush()?;
84 let byte = stream.read_u8()?;
85 state_machine.set_ssl_response(byte);
86 }
87 Action::ReadMessage => {
88 stream.read_message(&mut buffer_set)?;
89 }
90 Action::Write => {
91 stream.write_all(&buffer_set.write_buffer)?;
92 stream.flush()?;
93 }
94 Action::WriteAndReadMessage => {
95 stream.write_all(&buffer_set.write_buffer)?;
96 stream.flush()?;
97 stream.read_message(&mut buffer_set)?;
98 }
99 Action::TlsHandshake => {
100 #[cfg(feature = "sync-tls")]
101 {
102 stream = stream.upgrade_to_tls(&options.host)?;
103 }
104 #[cfg(not(feature = "sync-tls"))]
105 {
106 return Err(Error::Unsupported(
107 "TLS requested but sync-tls feature not enabled".into(),
108 ));
109 }
110 }
111 Action::HandleAsyncMessageAndReadMessage(_) => {
112 stream.read_message(&mut buffer_set)?;
114 }
115 Action::Finished => break,
116 }
117 }
118
119 let conn = Self {
120 stream,
121 buffer_set,
122 backend_key: state_machine.backend_key().cloned(),
123 server_params: state_machine.take_server_params(),
124 transaction_status: state_machine.transaction_status(),
125 is_broken: false,
126 name_counter: 0,
127 async_message_handler: None,
128 };
129
130 #[cfg(unix)]
132 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
133 conn.try_upgrade_to_unix_socket(&options)
134 } else {
135 conn
136 };
137
138 Ok(conn)
139 }
140
141 #[cfg(unix)]
144 fn try_upgrade_to_unix_socket(mut self, opts: &Opts) -> Self {
145 let mut handler = FirstRowHandler::<(String,)>::new();
147 if self
148 .query("SHOW unix_socket_directories", &mut handler)
149 .is_err()
150 {
151 return self;
152 }
153
154 let socket_dir = match handler.into_row() {
155 Some((dirs,)) => {
156 match dirs.split(',').next() {
158 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
159 _ => return self,
160 }
161 }
162 None => return self,
163 };
164
165 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
167
168 let unix_stream = match UnixStream::connect(&socket_path) {
170 Ok(s) => s,
171 Err(_) => return self,
172 };
173
174 let mut opts_unix = opts.clone();
176 opts_unix.upgrade_to_unix_socket = false;
177
178 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix) {
179 Ok(new_conn) => new_conn,
180 Err(_) => self,
181 }
182 }
183
184 pub fn backend_key(&self) -> Option<&BackendKeyData> {
186 self.backend_key.as_ref()
187 }
188
189 pub fn connection_id(&self) -> u32 {
193 self.backend_key.as_ref().map_or(0, |k| k.process_id())
194 }
195
196 pub fn server_params(&self) -> &[(String, String)] {
198 &self.server_params
199 }
200
201 pub fn transaction_status(&self) -> TransactionStatus {
203 self.transaction_status
204 }
205
206 pub fn in_transaction(&self) -> bool {
208 self.transaction_status.in_transaction()
209 }
210
211 pub fn is_broken(&self) -> bool {
213 self.is_broken
214 }
215
216 pub(crate) fn next_portal_name(&mut self) -> String {
218 self.name_counter += 1;
219 format!("_zero_p_{}", self.name_counter)
220 }
221
222 pub(crate) fn create_named_portal<S: IntoStatement, P: ToParams>(
226 &mut self,
227 portal_name: &str,
228 statement: &S,
229 params: &P,
230 ) -> Result<()> {
231 let mut state_machine = if let Some(sql) = statement.as_sql() {
233 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
234 } else {
235 let stmt = statement.as_prepared().unwrap();
236 BindStateMachine::bind_prepared(
237 &mut self.buffer_set,
238 portal_name,
239 &stmt.wire_name(),
240 &stmt.param_oids,
241 params,
242 )?
243 };
244
245 loop {
247 match state_machine.step(&mut self.buffer_set)? {
248 Action::ReadMessage => {
249 self.stream.read_message(&mut self.buffer_set)?;
250 }
251 Action::Write => {
252 self.stream.write_all(&self.buffer_set.write_buffer)?;
253 self.stream.flush()?;
254 }
255 Action::WriteAndReadMessage => {
256 self.stream.write_all(&self.buffer_set.write_buffer)?;
257 self.stream.flush()?;
258 self.stream.read_message(&mut self.buffer_set)?;
259 }
260 Action::Finished => break,
261 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
262 }
263 }
264
265 Ok(())
266 }
267
268 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
275 self.async_message_handler = Some(Box::new(handler));
276 }
277
278 pub fn clear_async_message_handler(&mut self) {
280 self.async_message_handler = None;
281 }
282
283 pub fn pipeline<T, F>(&mut self, f: F) -> Result<T>
304 where
305 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
306 {
307 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
308 let result = f(&mut pipeline);
309 pipeline.cleanup();
310 result
311 }
312
313 pub fn ping(&mut self) -> Result<()> {
315 self.query_drop("")?;
316 Ok(())
317 }
318
319 fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
321 loop {
322 match state_machine.step(&mut self.buffer_set)? {
323 Action::WriteAndReadByte => {
324 return Err(Error::Protocol(
325 "Unexpected WriteAndReadByte in query state machine".into(),
326 ));
327 }
328 Action::ReadMessage => {
329 self.stream.read_message(&mut self.buffer_set)?;
330 }
331 Action::Write => {
332 self.stream.write_all(&self.buffer_set.write_buffer)?;
333 self.stream.flush()?;
334 }
335 Action::WriteAndReadMessage => {
336 self.stream.write_all(&self.buffer_set.write_buffer)?;
337 self.stream.flush()?;
338 self.stream.read_message(&mut self.buffer_set)?;
339 }
340 Action::TlsHandshake => {
341 return Err(Error::Protocol(
342 "Unexpected TlsHandshake in query state machine".into(),
343 ));
344 }
345 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
346 if let Some(ref mut h) = self.async_message_handler {
347 h.handle(async_msg);
348 }
349 self.stream.read_message(&mut self.buffer_set)?;
351 }
352 Action::Finished => {
353 self.transaction_status = state_machine.transaction_status();
354 break;
355 }
356 }
357 }
358 Ok(())
359 }
360
361 pub fn query<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
363 let result = self.query_inner(sql, handler);
364 if let Err(e) = &result
365 && e.is_connection_broken()
366 {
367 self.is_broken = true;
368 }
369 result
370 }
371
372 fn query_inner<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
373 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
374 self.drive(&mut state_machine)
375 }
376
377 pub fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
379 let mut handler = DropHandler::new();
380 self.query(sql, &mut handler)?;
381 Ok(handler.rows_affected())
382 }
383
384 pub fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
395 &mut self,
396 sql: &str,
397 ) -> Result<Vec<T>> {
398 let mut handler = crate::handler::CollectHandler::<T>::new();
399 self.query(sql, &mut handler)?;
400 Ok(handler.into_rows())
401 }
402
403 pub fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
405 &mut self,
406 sql: &str,
407 ) -> Result<Option<T>> {
408 let mut handler = crate::handler::FirstRowHandler::<T>::new();
409 self.query(sql, &mut handler)?;
410 Ok(handler.into_row())
411 }
412
413 pub fn query_foreach<T: for<'a> crate::conversion::FromRow<'a>, F: FnMut(T) -> Result<()>>(
426 &mut self,
427 sql: &str,
428 f: F,
429 ) -> Result<()> {
430 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
431 self.query(sql, &mut handler)?;
432 Ok(())
433 }
434
435 pub fn close(mut self) -> Result<()> {
437 self.buffer_set.write_buffer.clear();
438 write_terminate(&mut self.buffer_set.write_buffer);
439 self.stream.write_all(&self.buffer_set.write_buffer)?;
440 self.stream.flush()?;
441 Ok(())
442 }
443
444 pub fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
448 self.prepare_typed(query, &[])
449 }
450
451 pub fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
468 if queries.is_empty() {
469 return Ok(Vec::new());
470 }
471
472 let start_idx = self.name_counter + 1;
473 self.name_counter += queries.len() as u64;
474
475 let result = self.prepare_batch_inner(queries, start_idx);
476 if let Err(e) = &result
477 && e.is_connection_broken()
478 {
479 self.is_broken = true;
480 }
481 result
482 }
483
484 fn prepare_batch_inner(
485 &mut self,
486 queries: &[&str],
487 start_idx: u64,
488 ) -> Result<Vec<PreparedStatement>> {
489 use crate::state::batch_prepare::BatchPrepareStateMachine;
490
491 let mut state_machine =
492 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
493
494 loop {
495 match state_machine.step(&mut self.buffer_set)? {
496 Action::ReadMessage => {
497 self.stream.read_message(&mut self.buffer_set)?;
498 }
499 Action::WriteAndReadMessage => {
500 self.stream.write_all(&self.buffer_set.write_buffer)?;
501 self.stream.flush()?;
502 self.stream.read_message(&mut self.buffer_set)?;
503 }
504 Action::Finished => {
505 self.transaction_status = state_machine.transaction_status();
506 break;
507 }
508 _ => return Err(Error::Protocol("Unexpected action in batch prepare".into())),
509 }
510 }
511
512 Ok(state_machine.take_statements())
513 }
514
515 pub fn prepare_typed(&mut self, query: &str, param_oids: &[u32]) -> Result<PreparedStatement> {
517 self.name_counter += 1;
518 let idx = self.name_counter;
519 let result = self.prepare_inner(idx, query, param_oids);
520 if let Err(e) = &result
521 && e.is_connection_broken()
522 {
523 self.is_broken = true;
524 }
525 result
526 }
527
528 fn prepare_inner(
529 &mut self,
530 idx: u64,
531 query: &str,
532 param_oids: &[u32],
533 ) -> Result<PreparedStatement> {
534 let mut handler = DropHandler::new();
535 let mut state_machine = ExtendedQueryStateMachine::prepare(
536 &mut handler,
537 &mut self.buffer_set,
538 idx,
539 query,
540 param_oids,
541 );
542 self.drive(&mut state_machine)?;
543 state_machine
544 .take_prepared_statement()
545 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
546 }
547
548 pub fn exec<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
565 &mut self,
566 statement: S,
567 params: P,
568 handler: &mut H,
569 ) -> Result<()> {
570 let result = self.exec_inner(&statement, ¶ms, handler);
571 if let Err(e) = &result
572 && e.is_connection_broken()
573 {
574 self.is_broken = true;
575 }
576 result
577 }
578
579 fn exec_inner<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
580 &mut self,
581 statement: &S,
582 params: &P,
583 handler: &mut H,
584 ) -> Result<()> {
585 let mut state_machine = if statement.needs_parse() {
586 ExtendedQueryStateMachine::execute_sql(
587 handler,
588 &mut self.buffer_set,
589 statement.as_sql().unwrap(),
590 params,
591 )?
592 } else {
593 let stmt = statement.as_prepared().unwrap();
594 ExtendedQueryStateMachine::execute(
595 handler,
596 &mut self.buffer_set,
597 &stmt.wire_name(),
598 &stmt.param_oids,
599 params,
600 )?
601 };
602
603 self.drive(&mut state_machine)
604 }
605
606 pub fn exec_drop<S: IntoStatement, P: ToParams>(
610 &mut self,
611 statement: S,
612 params: P,
613 ) -> Result<Option<u64>> {
614 let mut handler = DropHandler::new();
615 self.exec(statement, params, &mut handler)?;
616 Ok(handler.rows_affected())
617 }
618
619 pub fn exec_collect<
633 T: for<'a> crate::conversion::FromRow<'a>,
634 S: IntoStatement,
635 P: ToParams,
636 >(
637 &mut self,
638 statement: S,
639 params: P,
640 ) -> Result<Vec<T>> {
641 let mut handler = crate::handler::CollectHandler::<T>::new();
642 self.exec(statement, params, &mut handler)?;
643 Ok(handler.into_rows())
644 }
645
646 pub fn exec_first<T: for<'a> crate::conversion::FromRow<'a>, S: IntoStatement, P: ToParams>(
660 &mut self,
661 statement: S,
662 params: P,
663 ) -> Result<Option<T>> {
664 let mut handler = crate::handler::FirstRowHandler::<T>::new();
665 self.exec(statement, params, &mut handler)?;
666 Ok(handler.into_row())
667 }
668
669 pub fn exec_foreach<
685 T: for<'a> crate::conversion::FromRow<'a>,
686 S: IntoStatement,
687 P: ToParams,
688 F: FnMut(T) -> Result<()>,
689 >(
690 &mut self,
691 statement: S,
692 params: P,
693 f: F,
694 ) -> Result<()> {
695 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
696 self.exec(statement, params, &mut handler)?;
697 Ok(())
698 }
699
700 pub fn exec_batch<S: IntoStatement, P: ToParams>(
731 &mut self,
732 statement: S,
733 params_list: &[P],
734 ) -> Result<()> {
735 self.exec_batch_chunked(statement, params_list, 1000)
736 }
737
738 pub fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
742 &mut self,
743 statement: S,
744 params_list: &[P],
745 chunk_size: usize,
746 ) -> Result<()> {
747 let result = self.exec_batch_inner(&statement, params_list, chunk_size);
748 if let Err(e) = &result
749 && e.is_connection_broken()
750 {
751 self.is_broken = true;
752 }
753 result
754 }
755
756 fn exec_batch_inner<S: IntoStatement, P: ToParams>(
757 &mut self,
758 statement: &S,
759 params_list: &[P],
760 chunk_size: usize,
761 ) -> Result<()> {
762 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
763 use crate::state::extended::BatchStateMachine;
764
765 if params_list.is_empty() {
766 return Ok(());
767 }
768
769 let chunk_size = chunk_size.max(1);
770 let needs_parse = statement.needs_parse();
771 let sql = statement.as_sql();
772 let prepared = statement.as_prepared();
773
774 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
776 stmt.param_oids.clone()
777 } else {
778 params_list[0].natural_oids()
779 };
780
781 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
783
784 for chunk in params_list.chunks(chunk_size) {
785 self.buffer_set.write_buffer.clear();
786
787 let parse_in_chunk = needs_parse;
789 if parse_in_chunk {
790 write_parse(
791 &mut self.buffer_set.write_buffer,
792 "",
793 sql.unwrap(),
794 ¶m_oids,
795 );
796 }
797
798 for params in chunk {
800 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
801 write_bind(
802 &mut self.buffer_set.write_buffer,
803 "",
804 effective_stmt_name,
805 params,
806 ¶m_oids,
807 )?;
808 write_execute(&mut self.buffer_set.write_buffer, "", 0);
809 }
810
811 write_sync(&mut self.buffer_set.write_buffer);
813
814 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
816 self.drive_batch(&mut state_machine)?;
817 self.transaction_status = state_machine.transaction_status();
818 }
819
820 Ok(())
821 }
822
823 fn drive_batch(
825 &mut self,
826 state_machine: &mut crate::state::extended::BatchStateMachine,
827 ) -> Result<()> {
828 use crate::protocol::backend::{ReadyForQuery, msg_type};
829 use crate::state::action::Action;
830
831 loop {
832 let step_result = state_machine.step(&mut self.buffer_set);
833 match step_result {
834 Ok(Action::ReadMessage) => {
835 self.stream.read_message(&mut self.buffer_set)?;
836 }
837 Ok(Action::WriteAndReadMessage) => {
838 self.stream.write_all(&self.buffer_set.write_buffer)?;
839 self.stream.flush()?;
840 self.stream.read_message(&mut self.buffer_set)?;
841 }
842 Ok(Action::Finished) => {
843 break;
844 }
845 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
846 Err(e) => {
847 loop {
849 self.stream.read_message(&mut self.buffer_set)?;
850 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
851 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
852 self.transaction_status =
853 ready.transaction_status().unwrap_or_default();
854 break;
855 }
856 }
857 return Err(e);
858 }
859 }
860 }
861 Ok(())
862 }
863
864 pub fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
866 let result = self.close_statement_inner(&stmt.wire_name());
867 if let Err(e) = &result
868 && e.is_connection_broken()
869 {
870 self.is_broken = true;
871 }
872 result
873 }
874
875 fn close_statement_inner(&mut self, name: &str) -> Result<()> {
876 let mut handler = DropHandler::new();
877 let mut state_machine =
878 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
879 self.drive(&mut state_machine)
880 }
881
882 pub fn transaction<F, R>(&mut self, f: F) -> Result<R>
892 where
893 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
894 {
895 if self.in_transaction() {
896 return Err(Error::InvalidUsage(
897 "nested transactions are not supported".into(),
898 ));
899 }
900
901 self.query_drop("BEGIN")?;
902
903 let tx = super::transaction::Transaction::new(self.connection_id());
904 let result = f(self, tx);
905
906 if self.in_transaction() {
908 match &result {
909 Ok(_) => {
910 self.query_drop("COMMIT")?;
912 }
913 Err(_) => {
914 let _ = self.query_drop("ROLLBACK");
917 }
918 }
919 }
920
921 result
922 }
923}
924
925impl Conn {
928 pub fn lowlevel_bind<P: ToParams>(
938 &mut self,
939 portal: &str,
940 statement_name: &str,
941 params: P,
942 ) -> Result<()> {
943 let result = self.lowlevel_bind_inner(portal, statement_name, ¶ms);
944 if let Err(e) = &result
945 && e.is_connection_broken()
946 {
947 self.is_broken = true;
948 }
949 result
950 }
951
952 fn lowlevel_bind_inner<P: ToParams>(
953 &mut self,
954 portal: &str,
955 statement_name: &str,
956 params: &P,
957 ) -> Result<()> {
958 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
959 use crate::protocol::frontend::{write_bind, write_flush};
960
961 let param_oids = params.natural_oids();
962 self.buffer_set.write_buffer.clear();
963 write_bind(
964 &mut self.buffer_set.write_buffer,
965 portal,
966 statement_name,
967 params,
968 ¶m_oids,
969 )?;
970 write_flush(&mut self.buffer_set.write_buffer);
971
972 self.stream.write_all(&self.buffer_set.write_buffer)?;
973 self.stream.flush()?;
974
975 loop {
976 self.stream.read_message(&mut self.buffer_set)?;
977 let type_byte = self.buffer_set.type_byte;
978
979 if RawMessage::is_async_type(type_byte) {
980 continue;
981 }
982
983 match type_byte {
984 msg_type::BIND_COMPLETE => {
985 BindComplete::parse(&self.buffer_set.read_buffer)?;
986 return Ok(());
987 }
988 msg_type::ERROR_RESPONSE => {
989 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
990 return Err(error.into_error());
991 }
992 _ => {
993 return Err(Error::Protocol(format!(
994 "Expected BindComplete or ErrorResponse, got '{}'",
995 type_byte as char
996 )));
997 }
998 }
999 }
1000 }
1001
1002 pub fn lowlevel_execute<H: ExtendedHandler>(
1015 &mut self,
1016 portal: &str,
1017 max_rows: u32,
1018 handler: &mut H,
1019 ) -> Result<bool> {
1020 let result = self.lowlevel_execute_inner(portal, max_rows, handler);
1021 if let Err(e) = &result
1022 && e.is_connection_broken()
1023 {
1024 self.is_broken = true;
1025 }
1026 result
1027 }
1028
1029 fn lowlevel_execute_inner<H: ExtendedHandler>(
1030 &mut self,
1031 portal: &str,
1032 max_rows: u32,
1033 handler: &mut H,
1034 ) -> Result<bool> {
1035 use crate::protocol::backend::{
1036 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1037 RowDescription, msg_type,
1038 };
1039 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1040
1041 self.buffer_set.write_buffer.clear();
1042 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1043 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1044 write_flush(&mut self.buffer_set.write_buffer);
1045
1046 self.stream.write_all(&self.buffer_set.write_buffer)?;
1047 self.stream.flush()?;
1048
1049 let mut column_buffer: Vec<u8> = Vec::new();
1050
1051 loop {
1052 self.stream.read_message(&mut self.buffer_set)?;
1053 let type_byte = self.buffer_set.type_byte;
1054
1055 if RawMessage::is_async_type(type_byte) {
1056 continue;
1057 }
1058
1059 match type_byte {
1060 msg_type::ROW_DESCRIPTION => {
1061 column_buffer.clear();
1062 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1063 let cols = RowDescription::parse(&column_buffer)?;
1064 handler.result_start(cols)?;
1065 }
1066 msg_type::NO_DATA => {
1067 NoData::parse(&self.buffer_set.read_buffer)?;
1068 }
1069 msg_type::DATA_ROW => {
1070 let cols = RowDescription::parse(&column_buffer)?;
1071 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1072 handler.row(cols, row)?;
1073 }
1074 msg_type::COMMAND_COMPLETE => {
1075 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1076 handler.result_end(complete)?;
1077 return Ok(false); }
1079 msg_type::PORTAL_SUSPENDED => {
1080 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1081 return Ok(true); }
1083 msg_type::ERROR_RESPONSE => {
1084 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1085 return Err(error.into_error());
1086 }
1087 _ => {
1088 return Err(Error::Protocol(format!(
1089 "Unexpected message in execute: '{}'",
1090 type_byte as char
1091 )));
1092 }
1093 }
1094 }
1095 }
1096
1097 pub fn lowlevel_sync(&mut self) -> Result<()> {
1104 let result = self.lowlevel_sync_inner();
1105 if let Err(e) = &result
1106 && e.is_connection_broken()
1107 {
1108 self.is_broken = true;
1109 }
1110 result
1111 }
1112
1113 fn lowlevel_sync_inner(&mut self) -> Result<()> {
1114 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
1115 use crate::protocol::frontend::write_sync;
1116
1117 self.buffer_set.write_buffer.clear();
1118 write_sync(&mut self.buffer_set.write_buffer);
1119
1120 self.stream.write_all(&self.buffer_set.write_buffer)?;
1121 self.stream.flush()?;
1122
1123 let mut pending_error: Option<Error> = None;
1124
1125 loop {
1126 self.stream.read_message(&mut self.buffer_set)?;
1127 let type_byte = self.buffer_set.type_byte;
1128
1129 if RawMessage::is_async_type(type_byte) {
1130 continue;
1131 }
1132
1133 match type_byte {
1134 msg_type::READY_FOR_QUERY => {
1135 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
1136 self.transaction_status = ready.transaction_status().unwrap_or_default();
1137 if let Some(e) = pending_error {
1138 return Err(e);
1139 }
1140 return Ok(());
1141 }
1142 msg_type::ERROR_RESPONSE => {
1143 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1144 pending_error = Some(error.into_error());
1145 }
1146 _ => {
1147 }
1149 }
1150 }
1151 }
1152
1153 pub fn lowlevel_flush(&mut self) -> Result<()> {
1160 use crate::protocol::frontend::write_flush;
1161
1162 self.buffer_set.write_buffer.clear();
1163 write_flush(&mut self.buffer_set.write_buffer);
1164
1165 self.stream.write_all(&self.buffer_set.write_buffer)?;
1166 self.stream.flush()?;
1167 Ok(())
1168 }
1169
1170 pub fn exec_portal<S: IntoStatement, P, F, T>(
1200 &mut self,
1201 statement: S,
1202 params: P,
1203 f: F,
1204 ) -> Result<T>
1205 where
1206 P: ToParams,
1207 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1208 {
1209 let result = self.exec_portal_inner(&statement, ¶ms, f);
1210 if let Err(e) = &result
1211 && e.is_connection_broken()
1212 {
1213 self.is_broken = true;
1214 }
1215 result
1216 }
1217
1218 fn exec_portal_inner<S: IntoStatement, P, F, T>(
1219 &mut self,
1220 statement: &S,
1221 params: &P,
1222 f: F,
1223 ) -> Result<T>
1224 where
1225 P: ToParams,
1226 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1227 {
1228 let mut state_machine = if let Some(sql) = statement.as_sql() {
1230 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1231 } else {
1232 let stmt = statement.as_prepared().unwrap();
1233 BindStateMachine::bind_prepared(
1234 &mut self.buffer_set,
1235 "",
1236 &stmt.wire_name(),
1237 &stmt.param_oids,
1238 params,
1239 )?
1240 };
1241
1242 loop {
1244 match state_machine.step(&mut self.buffer_set)? {
1245 Action::ReadMessage => {
1246 self.stream.read_message(&mut self.buffer_set)?;
1247 }
1248 Action::Write => {
1249 self.stream.write_all(&self.buffer_set.write_buffer)?;
1250 self.stream.flush()?;
1251 }
1252 Action::WriteAndReadMessage => {
1253 self.stream.write_all(&self.buffer_set.write_buffer)?;
1254 self.stream.flush()?;
1255 self.stream.read_message(&mut self.buffer_set)?;
1256 }
1257 Action::Finished => break,
1258 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1259 }
1260 }
1261
1262 let mut portal = UnnamedPortal { conn: self };
1264 let result = f(&mut portal);
1265
1266 let sync_result = portal.conn.lowlevel_sync();
1268
1269 match (result, sync_result) {
1271 (Ok(v), Ok(())) => Ok(v),
1272 (Err(e), _) => Err(e),
1273 (Ok(_), Err(e)) => Err(e),
1274 }
1275 }
1276
1277 pub fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1279 let result = self.lowlevel_close_portal_inner(portal);
1280 if let Err(e) = &result
1281 && e.is_connection_broken()
1282 {
1283 self.is_broken = true;
1284 }
1285 result
1286 }
1287
1288 fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1289 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1290 use crate::protocol::frontend::{write_close_portal, write_flush};
1291
1292 self.buffer_set.write_buffer.clear();
1293 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1294 write_flush(&mut self.buffer_set.write_buffer);
1295
1296 self.stream.write_all(&self.buffer_set.write_buffer)?;
1297 self.stream.flush()?;
1298
1299 loop {
1300 self.stream.read_message(&mut self.buffer_set)?;
1301 let type_byte = self.buffer_set.type_byte;
1302
1303 if RawMessage::is_async_type(type_byte) {
1304 continue;
1305 }
1306
1307 match type_byte {
1308 msg_type::CLOSE_COMPLETE => {
1309 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1310 return Ok(());
1311 }
1312 msg_type::ERROR_RESPONSE => {
1313 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1314 return Err(error.into_error());
1315 }
1316 _ => {
1317 return Err(Error::Protocol(format!(
1318 "Expected CloseComplete or ErrorResponse, got '{}'",
1319 type_byte as char
1320 )));
1321 }
1322 }
1323 }
1324 }
1325}
1326
1327impl Drop for Conn {
1328 fn drop(&mut self) {
1329 self.buffer_set.write_buffer.clear();
1331 write_terminate(&mut self.buffer_set.write_buffer);
1332 let _ = self.stream.write_all(&self.buffer_set.write_buffer);
1333 let _ = self.stream.flush();
1334 }
1335}