1use std::net::TcpStream;
4use std::os::unix::net::UnixStream;
5
6use crate::buffer_pool::PooledBufferSet;
7use crate::conversion::ToParams;
8use crate::error::{Error, Result};
9use crate::handler::{
10 AsyncMessageHandler, BinaryHandler, DropHandler, FirstRowHandler, TextHandler,
11};
12use crate::opts::Opts;
13use crate::protocol::backend::BackendKeyData;
14use crate::protocol::frontend::write_terminate;
15use crate::protocol::types::TransactionStatus;
16use crate::state::StateMachine;
17use crate::state::action::Action;
18use crate::state::connection::ConnectionStateMachine;
19use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
20use crate::state::simple_query::SimpleQueryStateMachine;
21use crate::statement::IntoStatement;
22
23use super::stream::Stream;
24use super::unnamed_portal::UnnamedPortal;
25
/// A synchronous PostgreSQL client connection over a TCP or Unix-domain stream.
pub struct Conn {
    /// Underlying transport (TCP or Unix socket, possibly upgraded to TLS during startup).
    pub(crate) stream: Stream,
    /// Read/write buffers obtained from the options' buffer pool and reused across operations.
    pub(crate) buffer_set: PooledBufferSet,
    /// Backend key data received during startup, if the server sent any.
    backend_key: Option<BackendKeyData>,
    /// Server parameters collected by the startup state machine, as `(name, value)` pairs.
    server_params: Vec<(String, String)>,
    /// Transaction status most recently reported by the server.
    pub(crate) transaction_status: TransactionStatus,
    /// Set once an operation fails with an error that `is_connection_broken()` reports as fatal.
    pub(crate) is_broken: bool,
    /// Counter used to generate unique portal and prepared-statement names.
    name_counter: u64,
    /// Optional callback for asynchronous messages received while driving queries.
    async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
}
37
38impl Conn {
39 pub fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
41 where
42 Error: From<O::Error>,
43 {
44 let opts = opts.try_into()?;
45
46 let stream = if let Some(socket_path) = &opts.socket {
47 Stream::unix(UnixStream::connect(socket_path)?)
48 } else {
49 if opts.host.is_empty() {
50 return Err(Error::InvalidUsage("host is empty".into()));
51 }
52 let addr = format!("{}:{}", opts.host, opts.port);
53 let tcp = TcpStream::connect(&addr)?;
54 tcp.set_nodelay(true)?;
55 Stream::tcp(tcp)
56 };
57
58 Self::new_with_stream(stream, opts)
59 }
60
61 #[allow(unused_mut)]
63 pub fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
64 let mut buffer_set = options.buffer_pool.get_buffer_set();
65 let mut state_machine = ConnectionStateMachine::new(options.clone());
66
67 loop {
69 match state_machine.step(&mut buffer_set)? {
70 Action::WriteAndReadByte => {
71 stream.write_all(&buffer_set.write_buffer)?;
72 stream.flush()?;
73 let byte = stream.read_u8()?;
74 state_machine.set_ssl_response(byte);
75 }
76 Action::ReadMessage => {
77 stream.read_message(&mut buffer_set)?;
78 }
79 Action::Write => {
80 stream.write_all(&buffer_set.write_buffer)?;
81 stream.flush()?;
82 }
83 Action::WriteAndReadMessage => {
84 stream.write_all(&buffer_set.write_buffer)?;
85 stream.flush()?;
86 stream.read_message(&mut buffer_set)?;
87 }
88 Action::TlsHandshake => {
89 #[cfg(feature = "sync-tls")]
90 {
91 stream = stream.upgrade_to_tls(&options.host)?;
92 }
93 #[cfg(not(feature = "sync-tls"))]
94 {
95 return Err(Error::Unsupported(
96 "TLS requested but sync-tls feature not enabled".into(),
97 ));
98 }
99 }
100 Action::HandleAsyncMessageAndReadMessage(_) => {
101 stream.read_message(&mut buffer_set)?;
103 }
104 Action::Finished => break,
105 }
106 }
107
108 let conn = Self {
109 stream,
110 buffer_set,
111 backend_key: state_machine.backend_key().cloned(),
112 server_params: state_machine.take_server_params(),
113 transaction_status: state_machine.transaction_status(),
114 is_broken: false,
115 name_counter: 0,
116 async_message_handler: None,
117 };
118
119 let conn = if options.prefer_unix_socket && conn.stream.is_tcp_loopback() {
121 conn.try_upgrade_to_unix_socket(&options)
122 } else {
123 conn
124 };
125
126 Ok(conn)
127 }
128
129 fn try_upgrade_to_unix_socket(mut self, opts: &Opts) -> Self {
132 let mut handler = FirstRowHandler::<(String,)>::new();
134 if self
135 .query("SHOW unix_socket_directories", &mut handler)
136 .is_err()
137 {
138 return self;
139 }
140
141 let socket_dir = match handler.into_row() {
142 Some((dirs,)) => {
143 match dirs.split(',').next() {
145 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
146 _ => return self,
147 }
148 }
149 None => return self,
150 };
151
152 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
154
155 let unix_stream = match UnixStream::connect(&socket_path) {
157 Ok(s) => s,
158 Err(_) => return self,
159 };
160
161 let mut opts_unix = opts.clone();
163 opts_unix.prefer_unix_socket = false;
164
165 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix) {
166 Ok(new_conn) => new_conn,
167 Err(_) => self,
168 }
169 }
170
171 pub fn backend_key(&self) -> Option<&BackendKeyData> {
173 self.backend_key.as_ref()
174 }
175
176 pub fn connection_id(&self) -> u32 {
180 self.backend_key.as_ref().map_or(0, |k| k.process_id())
181 }
182
183 pub fn server_params(&self) -> &[(String, String)] {
185 &self.server_params
186 }
187
188 pub fn transaction_status(&self) -> TransactionStatus {
190 self.transaction_status
191 }
192
193 pub fn in_transaction(&self) -> bool {
195 self.transaction_status.in_transaction()
196 }
197
198 pub fn is_broken(&self) -> bool {
200 self.is_broken
201 }
202
203 pub(crate) fn next_portal_name(&mut self) -> String {
205 self.name_counter += 1;
206 format!("_zero_p_{}", self.name_counter)
207 }
208
209 pub(crate) fn create_named_portal<S: IntoStatement, P: ToParams>(
213 &mut self,
214 portal_name: &str,
215 statement: &S,
216 params: &P,
217 ) -> Result<()> {
218 let mut state_machine = if let Some(sql) = statement.as_sql() {
220 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
221 } else {
222 let stmt = statement.as_prepared().unwrap();
223 BindStateMachine::bind_prepared(
224 &mut self.buffer_set,
225 portal_name,
226 &stmt.wire_name(),
227 &stmt.param_oids,
228 params,
229 )?
230 };
231
232 loop {
234 match state_machine.step(&mut self.buffer_set)? {
235 Action::ReadMessage => {
236 self.stream.read_message(&mut self.buffer_set)?;
237 }
238 Action::Write => {
239 self.stream.write_all(&self.buffer_set.write_buffer)?;
240 self.stream.flush()?;
241 }
242 Action::WriteAndReadMessage => {
243 self.stream.write_all(&self.buffer_set.write_buffer)?;
244 self.stream.flush()?;
245 self.stream.read_message(&mut self.buffer_set)?;
246 }
247 Action::Finished => break,
248 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
249 }
250 }
251
252 Ok(())
253 }
254
255 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
262 self.async_message_handler = Some(Box::new(handler));
263 }
264
265 pub fn clear_async_message_handler(&mut self) {
267 self.async_message_handler = None;
268 }
269
270 pub fn run_pipeline<T, F>(&mut self, f: F) -> Result<T>
291 where
292 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
293 {
294 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
295 let result = f(&mut pipeline);
296 pipeline.cleanup();
297 result
298 }
299
300 pub fn ping(&mut self) -> Result<()> {
302 self.query_drop("")?;
303 Ok(())
304 }
305
306 fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
308 loop {
309 match state_machine.step(&mut self.buffer_set)? {
310 Action::WriteAndReadByte => {
311 return Err(Error::Protocol(
312 "Unexpected WriteAndReadByte in query state machine".into(),
313 ));
314 }
315 Action::ReadMessage => {
316 self.stream.read_message(&mut self.buffer_set)?;
317 }
318 Action::Write => {
319 self.stream.write_all(&self.buffer_set.write_buffer)?;
320 self.stream.flush()?;
321 }
322 Action::WriteAndReadMessage => {
323 self.stream.write_all(&self.buffer_set.write_buffer)?;
324 self.stream.flush()?;
325 self.stream.read_message(&mut self.buffer_set)?;
326 }
327 Action::TlsHandshake => {
328 return Err(Error::Protocol(
329 "Unexpected TlsHandshake in query state machine".into(),
330 ));
331 }
332 Action::HandleAsyncMessageAndReadMessage(ref async_msg) => {
333 if let Some(ref mut h) = self.async_message_handler {
334 h.handle(async_msg);
335 }
336 self.stream.read_message(&mut self.buffer_set)?;
338 }
339 Action::Finished => {
340 self.transaction_status = state_machine.transaction_status();
341 break;
342 }
343 }
344 }
345 Ok(())
346 }
347
348 pub fn query<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
350 let result = self.query_inner(sql, handler);
351 if let Err(e) = &result
352 && e.is_connection_broken()
353 {
354 self.is_broken = true;
355 }
356 result
357 }
358
359 fn query_inner<H: TextHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
360 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
361 self.drive(&mut state_machine)
362 }
363
364 pub fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
366 let mut handler = DropHandler::new();
367 self.query(sql, &mut handler)?;
368 Ok(handler.rows_affected())
369 }
370
371 pub fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
382 &mut self,
383 sql: &str,
384 ) -> Result<Vec<T>> {
385 let mut handler = crate::handler::CollectHandler::<T>::new();
386 self.query(sql, &mut handler)?;
387 Ok(handler.into_rows())
388 }
389
390 pub fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
392 &mut self,
393 sql: &str,
394 ) -> Result<Option<T>> {
395 let mut handler = crate::handler::FirstRowHandler::<T>::new();
396 self.query(sql, &mut handler)?;
397 Ok(handler.into_row())
398 }
399
400 pub fn close(mut self) -> Result<()> {
402 self.buffer_set.write_buffer.clear();
403 write_terminate(&mut self.buffer_set.write_buffer);
404 self.stream.write_all(&self.buffer_set.write_buffer)?;
405 self.stream.flush()?;
406 Ok(())
407 }
408
409 pub fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
413 self.prepare_typed(query, &[])
414 }
415
416 pub fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
433 let mut statements = Vec::with_capacity(queries.len());
434 for query in queries {
435 statements.push(self.prepare(query)?);
436 }
437 Ok(statements)
438 }
439
440 pub fn prepare_typed(&mut self, query: &str, param_oids: &[u32]) -> Result<PreparedStatement> {
442 self.name_counter += 1;
443 let idx = self.name_counter;
444 let result = self.prepare_inner(idx, query, param_oids);
445 if let Err(e) = &result
446 && e.is_connection_broken()
447 {
448 self.is_broken = true;
449 }
450 result
451 }
452
453 fn prepare_inner(
454 &mut self,
455 idx: u64,
456 query: &str,
457 param_oids: &[u32],
458 ) -> Result<PreparedStatement> {
459 let mut handler = DropHandler::new();
460 let mut state_machine = ExtendedQueryStateMachine::prepare(
461 &mut handler,
462 &mut self.buffer_set,
463 idx,
464 query,
465 param_oids,
466 );
467 self.drive(&mut state_machine)?;
468 state_machine
469 .take_prepared_statement()
470 .ok_or_else(|| Error::Protocol("No prepared statement".into()))
471 }
472
473 pub fn exec<S: IntoStatement, P: ToParams, H: BinaryHandler>(
490 &mut self,
491 statement: S,
492 params: P,
493 handler: &mut H,
494 ) -> Result<()> {
495 let result = self.exec_inner(&statement, ¶ms, handler);
496 if let Err(e) = &result
497 && e.is_connection_broken()
498 {
499 self.is_broken = true;
500 }
501 result
502 }
503
504 fn exec_inner<S: IntoStatement, P: ToParams, H: BinaryHandler>(
505 &mut self,
506 statement: &S,
507 params: &P,
508 handler: &mut H,
509 ) -> Result<()> {
510 let mut state_machine = if statement.needs_parse() {
511 ExtendedQueryStateMachine::execute_sql(
512 handler,
513 &mut self.buffer_set,
514 statement.as_sql().unwrap(),
515 params,
516 )?
517 } else {
518 let stmt = statement.as_prepared().unwrap();
519 ExtendedQueryStateMachine::execute(
520 handler,
521 &mut self.buffer_set,
522 &stmt.wire_name(),
523 &stmt.param_oids,
524 params,
525 )?
526 };
527
528 self.drive(&mut state_machine)
529 }
530
531 pub fn exec_drop<S: IntoStatement, P: ToParams>(
535 &mut self,
536 statement: S,
537 params: P,
538 ) -> Result<Option<u64>> {
539 let mut handler = DropHandler::new();
540 self.exec(statement, params, &mut handler)?;
541 Ok(handler.rows_affected())
542 }
543
544 pub fn exec_collect<
558 T: for<'a> crate::conversion::FromRow<'a>,
559 S: IntoStatement,
560 P: ToParams,
561 >(
562 &mut self,
563 statement: S,
564 params: P,
565 ) -> Result<Vec<T>> {
566 let mut handler = crate::handler::CollectHandler::<T>::new();
567 self.exec(statement, params, &mut handler)?;
568 Ok(handler.into_rows())
569 }
570
571 pub fn exec_first<T: for<'a> crate::conversion::FromRow<'a>, S: IntoStatement, P: ToParams>(
585 &mut self,
586 statement: S,
587 params: P,
588 ) -> Result<Option<T>> {
589 let mut handler = crate::handler::FirstRowHandler::<T>::new();
590 self.exec(statement, params, &mut handler)?;
591 Ok(handler.into_row())
592 }
593
594 pub fn exec_batch<S: IntoStatement, P: ToParams>(
625 &mut self,
626 statement: S,
627 params_list: &[P],
628 ) -> Result<()> {
629 self.exec_batch_chunked(statement, params_list, 1000)
630 }
631
632 pub fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
636 &mut self,
637 statement: S,
638 params_list: &[P],
639 chunk_size: usize,
640 ) -> Result<()> {
641 let result = self.exec_batch_inner(&statement, params_list, chunk_size);
642 if let Err(e) = &result
643 && e.is_connection_broken()
644 {
645 self.is_broken = true;
646 }
647 result
648 }
649
650 fn exec_batch_inner<S: IntoStatement, P: ToParams>(
651 &mut self,
652 statement: &S,
653 params_list: &[P],
654 chunk_size: usize,
655 ) -> Result<()> {
656 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
657 use crate::state::extended::BatchStateMachine;
658
659 if params_list.is_empty() {
660 return Ok(());
661 }
662
663 let chunk_size = chunk_size.max(1);
664 let needs_parse = statement.needs_parse();
665 let sql = statement.as_sql();
666 let prepared = statement.as_prepared();
667
668 let param_oids: Vec<u32> = if let Some(stmt) = prepared {
670 stmt.param_oids.clone()
671 } else {
672 params_list[0].natural_oids()
673 };
674
675 let stmt_name = prepared.map(|s| s.wire_name()).unwrap_or_default();
677
678 for chunk in params_list.chunks(chunk_size) {
679 self.buffer_set.write_buffer.clear();
680
681 let parse_in_chunk = needs_parse;
683 if parse_in_chunk {
684 write_parse(
685 &mut self.buffer_set.write_buffer,
686 "",
687 sql.unwrap(),
688 ¶m_oids,
689 );
690 }
691
692 for params in chunk {
694 let effective_stmt_name = if needs_parse { "" } else { &stmt_name };
695 write_bind(
696 &mut self.buffer_set.write_buffer,
697 "",
698 effective_stmt_name,
699 params,
700 ¶m_oids,
701 )?;
702 write_execute(&mut self.buffer_set.write_buffer, "", 0);
703 }
704
705 write_sync(&mut self.buffer_set.write_buffer);
707
708 let mut state_machine = BatchStateMachine::new(parse_in_chunk);
710 self.drive_batch(&mut state_machine)?;
711 self.transaction_status = state_machine.transaction_status();
712 }
713
714 Ok(())
715 }
716
717 fn drive_batch(
719 &mut self,
720 state_machine: &mut crate::state::extended::BatchStateMachine,
721 ) -> Result<()> {
722 use crate::protocol::backend::{ReadyForQuery, msg_type};
723 use crate::state::action::Action;
724
725 loop {
726 let step_result = state_machine.step(&mut self.buffer_set);
727 match step_result {
728 Ok(Action::ReadMessage) => {
729 self.stream.read_message(&mut self.buffer_set)?;
730 }
731 Ok(Action::WriteAndReadMessage) => {
732 self.stream.write_all(&self.buffer_set.write_buffer)?;
733 self.stream.flush()?;
734 self.stream.read_message(&mut self.buffer_set)?;
735 }
736 Ok(Action::Finished) => {
737 break;
738 }
739 Ok(_) => return Err(Error::Protocol("Unexpected action in batch".into())),
740 Err(e) => {
741 loop {
743 self.stream.read_message(&mut self.buffer_set)?;
744 if self.buffer_set.type_byte == msg_type::READY_FOR_QUERY {
745 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
746 self.transaction_status =
747 ready.transaction_status().unwrap_or_default();
748 break;
749 }
750 }
751 return Err(e);
752 }
753 }
754 }
755 Ok(())
756 }
757
758 pub fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
760 let result = self.close_statement_inner(&stmt.wire_name());
761 if let Err(e) = &result
762 && e.is_connection_broken()
763 {
764 self.is_broken = true;
765 }
766 result
767 }
768
769 fn close_statement_inner(&mut self, name: &str) -> Result<()> {
770 let mut handler = DropHandler::new();
771 let mut state_machine =
772 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
773 self.drive(&mut state_machine)
774 }
775
776 pub fn transaction<F, R>(&mut self, f: F) -> Result<R>
786 where
787 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
788 {
789 if self.in_transaction() {
790 return Err(Error::InvalidUsage(
791 "nested transactions are not supported".into(),
792 ));
793 }
794
795 self.query_drop("BEGIN")?;
796
797 let tx = super::transaction::Transaction::new(self.connection_id());
798 let result = f(self, tx);
799
800 if self.in_transaction() {
802 let rollback_result = self.query_drop("ROLLBACK");
803
804 if let Err(e) = result {
806 return Err(e);
807 }
808 rollback_result?;
809 }
810
811 result
812 }
813}
814
815impl Conn {
818 pub fn lowlevel_bind<P: ToParams>(
828 &mut self,
829 portal: &str,
830 statement_name: &str,
831 params: P,
832 ) -> Result<()> {
833 let result = self.lowlevel_bind_inner(portal, statement_name, ¶ms);
834 if let Err(e) = &result
835 && e.is_connection_broken()
836 {
837 self.is_broken = true;
838 }
839 result
840 }
841
842 fn lowlevel_bind_inner<P: ToParams>(
843 &mut self,
844 portal: &str,
845 statement_name: &str,
846 params: &P,
847 ) -> Result<()> {
848 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
849 use crate::protocol::frontend::{write_bind, write_flush};
850
851 let param_oids = params.natural_oids();
852 self.buffer_set.write_buffer.clear();
853 write_bind(
854 &mut self.buffer_set.write_buffer,
855 portal,
856 statement_name,
857 params,
858 ¶m_oids,
859 )?;
860 write_flush(&mut self.buffer_set.write_buffer);
861
862 self.stream.write_all(&self.buffer_set.write_buffer)?;
863 self.stream.flush()?;
864
865 loop {
866 self.stream.read_message(&mut self.buffer_set)?;
867 let type_byte = self.buffer_set.type_byte;
868
869 if RawMessage::is_async_type(type_byte) {
870 continue;
871 }
872
873 match type_byte {
874 msg_type::BIND_COMPLETE => {
875 BindComplete::parse(&self.buffer_set.read_buffer)?;
876 return Ok(());
877 }
878 msg_type::ERROR_RESPONSE => {
879 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
880 return Err(error.into_error());
881 }
882 _ => {
883 return Err(Error::Protocol(format!(
884 "Expected BindComplete or ErrorResponse, got '{}'",
885 type_byte as char
886 )));
887 }
888 }
889 }
890 }
891
892 pub fn lowlevel_execute<H: BinaryHandler>(
905 &mut self,
906 portal: &str,
907 max_rows: u32,
908 handler: &mut H,
909 ) -> Result<bool> {
910 let result = self.lowlevel_execute_inner(portal, max_rows, handler);
911 if let Err(e) = &result
912 && e.is_connection_broken()
913 {
914 self.is_broken = true;
915 }
916 result
917 }
918
919 fn lowlevel_execute_inner<H: BinaryHandler>(
920 &mut self,
921 portal: &str,
922 max_rows: u32,
923 handler: &mut H,
924 ) -> Result<bool> {
925 use crate::protocol::backend::{
926 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
927 RowDescription, msg_type,
928 };
929 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
930
931 self.buffer_set.write_buffer.clear();
932 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
933 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
934 write_flush(&mut self.buffer_set.write_buffer);
935
936 self.stream.write_all(&self.buffer_set.write_buffer)?;
937 self.stream.flush()?;
938
939 let mut column_buffer: Vec<u8> = Vec::new();
940
941 loop {
942 self.stream.read_message(&mut self.buffer_set)?;
943 let type_byte = self.buffer_set.type_byte;
944
945 if RawMessage::is_async_type(type_byte) {
946 continue;
947 }
948
949 match type_byte {
950 msg_type::ROW_DESCRIPTION => {
951 column_buffer.clear();
952 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
953 let cols = RowDescription::parse(&column_buffer)?;
954 handler.result_start(cols)?;
955 }
956 msg_type::NO_DATA => {
957 NoData::parse(&self.buffer_set.read_buffer)?;
958 }
959 msg_type::DATA_ROW => {
960 let cols = RowDescription::parse(&column_buffer)?;
961 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
962 handler.row(cols, row)?;
963 }
964 msg_type::COMMAND_COMPLETE => {
965 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
966 handler.result_end(complete)?;
967 return Ok(false); }
969 msg_type::PORTAL_SUSPENDED => {
970 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
971 return Ok(true); }
973 msg_type::ERROR_RESPONSE => {
974 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
975 return Err(error.into_error());
976 }
977 _ => {
978 return Err(Error::Protocol(format!(
979 "Unexpected message in execute: '{}'",
980 type_byte as char
981 )));
982 }
983 }
984 }
985 }
986
987 pub fn lowlevel_sync(&mut self) -> Result<()> {
994 let result = self.lowlevel_sync_inner();
995 if let Err(e) = &result
996 && e.is_connection_broken()
997 {
998 self.is_broken = true;
999 }
1000 result
1001 }
1002
1003 fn lowlevel_sync_inner(&mut self) -> Result<()> {
1004 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
1005 use crate::protocol::frontend::write_sync;
1006
1007 self.buffer_set.write_buffer.clear();
1008 write_sync(&mut self.buffer_set.write_buffer);
1009
1010 self.stream.write_all(&self.buffer_set.write_buffer)?;
1011 self.stream.flush()?;
1012
1013 let mut pending_error: Option<Error> = None;
1014
1015 loop {
1016 self.stream.read_message(&mut self.buffer_set)?;
1017 let type_byte = self.buffer_set.type_byte;
1018
1019 if RawMessage::is_async_type(type_byte) {
1020 continue;
1021 }
1022
1023 match type_byte {
1024 msg_type::READY_FOR_QUERY => {
1025 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
1026 self.transaction_status = ready.transaction_status().unwrap_or_default();
1027 if let Some(e) = pending_error {
1028 return Err(e);
1029 }
1030 return Ok(());
1031 }
1032 msg_type::ERROR_RESPONSE => {
1033 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1034 pending_error = Some(error.into_error());
1035 }
1036 _ => {
1037 }
1039 }
1040 }
1041 }
1042
1043 pub fn lowlevel_flush(&mut self) -> Result<()> {
1050 use crate::protocol::frontend::write_flush;
1051
1052 self.buffer_set.write_buffer.clear();
1053 write_flush(&mut self.buffer_set.write_buffer);
1054
1055 self.stream.write_all(&self.buffer_set.write_buffer)?;
1056 self.stream.flush()?;
1057 Ok(())
1058 }
1059
1060 pub fn exec_iter<S: IntoStatement, P, F, T>(
1090 &mut self,
1091 statement: S,
1092 params: P,
1093 f: F,
1094 ) -> Result<T>
1095 where
1096 P: ToParams,
1097 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1098 {
1099 let result = self.exec_iter_inner(&statement, ¶ms, f);
1100 if let Err(e) = &result
1101 && e.is_connection_broken()
1102 {
1103 self.is_broken = true;
1104 }
1105 result
1106 }
1107
1108 fn exec_iter_inner<S: IntoStatement, P, F, T>(
1109 &mut self,
1110 statement: &S,
1111 params: &P,
1112 f: F,
1113 ) -> Result<T>
1114 where
1115 P: ToParams,
1116 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1117 {
1118 let mut state_machine = if let Some(sql) = statement.as_sql() {
1120 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1121 } else {
1122 let stmt = statement.as_prepared().unwrap();
1123 BindStateMachine::bind_prepared(
1124 &mut self.buffer_set,
1125 "",
1126 &stmt.wire_name(),
1127 &stmt.param_oids,
1128 params,
1129 )?
1130 };
1131
1132 loop {
1134 match state_machine.step(&mut self.buffer_set)? {
1135 Action::ReadMessage => {
1136 self.stream.read_message(&mut self.buffer_set)?;
1137 }
1138 Action::Write => {
1139 self.stream.write_all(&self.buffer_set.write_buffer)?;
1140 self.stream.flush()?;
1141 }
1142 Action::WriteAndReadMessage => {
1143 self.stream.write_all(&self.buffer_set.write_buffer)?;
1144 self.stream.flush()?;
1145 self.stream.read_message(&mut self.buffer_set)?;
1146 }
1147 Action::Finished => break,
1148 _ => return Err(Error::Protocol("Unexpected action in bind".into())),
1149 }
1150 }
1151
1152 let mut portal = UnnamedPortal { conn: self };
1154 let result = f(&mut portal);
1155
1156 let sync_result = portal.conn.lowlevel_sync();
1158
1159 match (result, sync_result) {
1161 (Ok(v), Ok(())) => Ok(v),
1162 (Err(e), _) => Err(e),
1163 (Ok(_), Err(e)) => Err(e),
1164 }
1165 }
1166
1167 pub fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1169 let result = self.lowlevel_close_portal_inner(portal);
1170 if let Err(e) = &result
1171 && e.is_connection_broken()
1172 {
1173 self.is_broken = true;
1174 }
1175 result
1176 }
1177
1178 fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1179 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1180 use crate::protocol::frontend::{write_close_portal, write_flush};
1181
1182 self.buffer_set.write_buffer.clear();
1183 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1184 write_flush(&mut self.buffer_set.write_buffer);
1185
1186 self.stream.write_all(&self.buffer_set.write_buffer)?;
1187 self.stream.flush()?;
1188
1189 loop {
1190 self.stream.read_message(&mut self.buffer_set)?;
1191 let type_byte = self.buffer_set.type_byte;
1192
1193 if RawMessage::is_async_type(type_byte) {
1194 continue;
1195 }
1196
1197 match type_byte {
1198 msg_type::CLOSE_COMPLETE => {
1199 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1200 return Ok(());
1201 }
1202 msg_type::ERROR_RESPONSE => {
1203 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1204 return Err(error.into_error());
1205 }
1206 _ => {
1207 return Err(Error::Protocol(format!(
1208 "Expected CloseComplete or ErrorResponse, got '{}'",
1209 type_byte as char
1210 )));
1211 }
1212 }
1213 }
1214 }
1215}
1216
1217impl Drop for Conn {
1218 fn drop(&mut self) {
1219 self.buffer_set.write_buffer.clear();
1221 write_terminate(&mut self.buffer_set.write_buffer);
1222 let _ = self.stream.write_all(&self.buffer_set.write_buffer);
1223 let _ = self.stream.flush();
1224 }
1225}