1use std::net::TcpStream;
4#[cfg(unix)]
5use std::os::unix::net::UnixStream;
6
7use crate::buffer_pool::PooledBufferSet;
8use crate::conversion::ToParams;
9use crate::error::{Error, Result};
10use crate::handler::{
11 AsyncMessageHandler, DropHandler, ExtendedHandler, FirstRowHandler, SimpleHandler,
12};
13use crate::opts::Opts;
14use crate::protocol::backend::BackendKeyData;
15use crate::protocol::frontend::write_terminate;
16use crate::protocol::types::TransactionStatus;
17use crate::state::StateMachine;
18use crate::state::action::Action;
19use crate::state::connection::ConnectionStateMachine;
20use crate::state::extended::{BindStateMachine, ExtendedQueryStateMachine, PreparedStatement};
21use crate::state::simple_query::SimpleQueryStateMachine;
22use crate::statement::{IntoStatement, StatementRef};
23
24use super::stream::Stream;
25use super::unnamed_portal::UnnamedPortal;
26
27pub struct Conn {
29 pub(crate) stream: Stream,
30 pub(crate) buffer_set: PooledBufferSet,
31 backend_key: Option<BackendKeyData>,
32 server_params: Vec<(String, String)>,
33 pub(crate) transaction_status: TransactionStatus,
34 pub(crate) is_broken: bool,
35 name_counter: u64,
36 async_message_handler: Option<Box<dyn AsyncMessageHandler>>,
37}
38
39impl Conn {
40 pub fn new<O: TryInto<Opts>>(opts: O) -> Result<Self>
42 where
43 Error: From<O::Error>,
44 {
45 let opts = opts.try_into()?;
46
47 let stream = if let Some(socket_path) = &opts.socket {
48 #[cfg(unix)]
49 {
50 Stream::unix(UnixStream::connect(socket_path)?)
51 }
52 #[cfg(not(unix))]
53 {
54 let _ = socket_path;
55 return Err(Error::Unsupported(
56 "Unix sockets are not supported on this platform".into(),
57 ));
58 }
59 } else {
60 if opts.host.is_empty() {
61 return Err(Error::InvalidUsage("host is empty".into()));
62 }
63 let addr = format!("{}:{}", opts.host, opts.port);
64 let tcp = TcpStream::connect(&addr)?;
65 tcp.set_nodelay(true)?;
66 Stream::tcp(tcp)
67 };
68
69 Self::new_with_stream(stream, opts)
70 }
71
72 pub fn new_with_stream(mut stream: Stream, options: Opts) -> Result<Self> {
74 let mut buffer_set = options.buffer_pool.get_buffer_set();
75 let mut state_machine = ConnectionStateMachine::new(options.clone());
76
77 loop {
79 match state_machine.step(&mut buffer_set)? {
80 Action::WriteAndReadByte => {
81 stream.write_all(&buffer_set.write_buffer)?;
82 stream.flush()?;
83 let byte = stream.read_u8()?;
84 state_machine.set_ssl_response(byte);
85 }
86 Action::ReadMessage => {
87 stream.read_message(&mut buffer_set)?;
88 }
89 Action::Write => {
90 stream.write_all(&buffer_set.write_buffer)?;
91 stream.flush()?;
92 }
93 Action::WriteAndReadMessage => {
94 stream.write_all(&buffer_set.write_buffer)?;
95 stream.flush()?;
96 stream.read_message(&mut buffer_set)?;
97 }
98 Action::TlsHandshake => {
99 #[cfg(feature = "sync-tls")]
100 {
101 stream = stream.upgrade_to_tls(&options.host)?;
102 }
103 #[cfg(not(feature = "sync-tls"))]
104 {
105 return Err(Error::Unsupported(
106 "TLS requested but sync-tls feature not enabled".into(),
107 ));
108 }
109 }
110 Action::HandleAsyncMessageAndReadMessage(_) => {
111 stream.read_message(&mut buffer_set)?;
113 }
114 Action::Error(_) => {
115 return Err(Error::LibraryBug(
116 "unexpected server error during connection startup".into(),
117 ));
118 }
119 Action::Finished => break,
120 }
121 }
122
123 let conn = Self {
124 stream,
125 buffer_set,
126 backend_key: state_machine.backend_key().cloned(),
127 server_params: state_machine.take_server_params(),
128 transaction_status: state_machine.transaction_status(),
129 is_broken: false,
130 name_counter: 0,
131 async_message_handler: None,
132 };
133
134 #[cfg(unix)]
136 let conn = if options.upgrade_to_unix_socket && conn.stream.is_tcp_loopback() {
137 conn.try_upgrade_to_unix_socket(&options)
138 } else {
139 conn
140 };
141
142 Ok(conn)
143 }
144
145 #[cfg(unix)]
148 fn try_upgrade_to_unix_socket(mut self, opts: &Opts) -> Self {
149 let mut handler = FirstRowHandler::<(String,)>::new();
151 if self
152 .query("SHOW unix_socket_directories", &mut handler)
153 .is_err()
154 {
155 return self;
156 }
157
158 let socket_dir = match handler.into_row() {
159 Some((dirs,)) => {
160 match dirs.split(',').next() {
162 Some(d) if !d.trim().is_empty() => d.trim().to_string(),
163 _ => return self,
164 }
165 }
166 None => return self,
167 };
168
169 let socket_path = format!("{}/.s.PGSQL.{}", socket_dir, opts.port);
171
172 let unix_stream = match UnixStream::connect(&socket_path) {
174 Ok(s) => s,
175 Err(_) => return self,
176 };
177
178 let mut opts_unix = opts.clone();
180 opts_unix.upgrade_to_unix_socket = false;
181
182 match Self::new_with_stream(Stream::unix(unix_stream), opts_unix) {
183 Ok(new_conn) => new_conn,
184 Err(_) => self,
185 }
186 }
187
188 pub fn backend_key(&self) -> Option<&BackendKeyData> {
190 self.backend_key.as_ref()
191 }
192
193 pub fn connection_id(&self) -> u32 {
197 self.backend_key.as_ref().map_or(0, |k| k.process_id())
198 }
199
200 pub fn server_params(&self) -> &[(String, String)] {
202 &self.server_params
203 }
204
205 pub fn transaction_status(&self) -> TransactionStatus {
207 self.transaction_status
208 }
209
210 pub fn in_transaction(&self) -> bool {
212 self.transaction_status.in_transaction()
213 }
214
215 pub fn is_broken(&self) -> bool {
217 self.is_broken
218 }
219
220 pub(crate) fn next_portal_name(&mut self) -> String {
222 self.name_counter += 1;
223 format!("_zero_p_{}", self.name_counter)
224 }
225
226 pub(crate) fn create_named_portal<S: IntoStatement, P: ToParams>(
230 &mut self,
231 portal_name: &str,
232 statement: &S,
233 params: &P,
234 ) -> Result<()> {
235 let mut state_machine = match statement.statement_ref() {
237 StatementRef::Sql(sql) => {
238 BindStateMachine::bind_sql(&mut self.buffer_set, portal_name, sql, params)?
239 }
240 StatementRef::Prepared(stmt) => BindStateMachine::bind_prepared(
241 &mut self.buffer_set,
242 portal_name,
243 &stmt.wire_name(),
244 &stmt.param_oids,
245 params,
246 )?,
247 };
248
249 loop {
251 match state_machine.step(&mut self.buffer_set)? {
252 Action::ReadMessage => {
253 self.stream.read_message(&mut self.buffer_set)?;
254 }
255 Action::Write => {
256 self.stream.write_all(&self.buffer_set.write_buffer)?;
257 self.stream.flush()?;
258 }
259 Action::WriteAndReadMessage => {
260 self.stream.write_all(&self.buffer_set.write_buffer)?;
261 self.stream.flush()?;
262 self.stream.read_message(&mut self.buffer_set)?;
263 }
264 Action::Finished => break,
265 _ => return Err(Error::LibraryBug("Unexpected action in bind".into())),
266 }
267 }
268
269 Ok(())
270 }
271
272 pub fn set_async_message_handler<H: AsyncMessageHandler + 'static>(&mut self, handler: H) {
279 self.async_message_handler = Some(Box::new(handler));
280 }
281
282 pub fn clear_async_message_handler(&mut self) {
284 self.async_message_handler = None;
285 }
286
287 pub fn pipeline<T, F>(&mut self, f: F) -> Result<T>
308 where
309 F: FnOnce(&mut super::pipeline::Pipeline<'_>) -> Result<T>,
310 {
311 let mut pipeline = super::pipeline::Pipeline::new_inner(self);
312 let result = f(&mut pipeline);
313 pipeline.cleanup();
314 result
315 }
316
317 pub fn ping(&mut self) -> Result<()> {
319 self.query_drop("")?;
320 Ok(())
321 }
322
323 fn drive<S: StateMachine>(&mut self, state_machine: &mut S) -> Result<()> {
325 loop {
326 let action = state_machine.step(&mut self.buffer_set)?;
327
328 match action {
329 Action::WriteAndReadByte => {
330 return Err(Error::LibraryBug(
331 "Unexpected WriteAndReadByte in query state machine".into(),
332 ));
333 }
334 Action::ReadMessage => {
335 self.stream.read_message(&mut self.buffer_set)?;
336 }
337 Action::Write => {
338 self.stream.write_all(&self.buffer_set.write_buffer)?;
339 self.stream.flush()?;
340 }
341 Action::WriteAndReadMessage => {
342 self.stream.write_all(&self.buffer_set.write_buffer)?;
343 self.stream.flush()?;
344 self.stream.read_message(&mut self.buffer_set)?;
345 }
346 Action::TlsHandshake => {
347 return Err(Error::LibraryBug(
348 "Unexpected TlsHandshake in query state machine".into(),
349 ));
350 }
351 Action::HandleAsyncMessageAndReadMessage(async_msg) => {
352 if let Some(h) = &mut self.async_message_handler {
353 h.handle(&async_msg);
354 }
355 self.stream.read_message(&mut self.buffer_set)?;
357 }
358 Action::Error(server_error) => {
359 self.transaction_status = state_machine.transaction_status();
360 return Err(Error::Server(server_error));
361 }
362 Action::Finished => {
363 self.transaction_status = state_machine.transaction_status();
364 break;
365 }
366 }
367 }
368 Ok(())
369 }
370
371 pub fn query<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
373 let result = self.query_inner(sql, handler);
374 if let Err(e) = &result
375 && e.is_connection_broken()
376 {
377 self.is_broken = true;
378 }
379 result
380 }
381
382 fn query_inner<H: SimpleHandler>(&mut self, sql: &str, handler: &mut H) -> Result<()> {
383 let mut state_machine = SimpleQueryStateMachine::new(handler, sql);
384 self.drive(&mut state_machine)
385 }
386
387 pub fn query_drop(&mut self, sql: &str) -> Result<Option<u64>> {
389 let mut handler = DropHandler::new();
390 self.query(sql, &mut handler)?;
391 Ok(handler.rows_affected())
392 }
393
394 pub fn query_collect<T: for<'a> crate::conversion::FromRow<'a>>(
405 &mut self,
406 sql: &str,
407 ) -> Result<Vec<T>> {
408 let mut handler = crate::handler::CollectHandler::<T>::new();
409 self.query(sql, &mut handler)?;
410 Ok(handler.into_rows())
411 }
412
413 pub fn query_first<T: for<'a> crate::conversion::FromRow<'a>>(
415 &mut self,
416 sql: &str,
417 ) -> Result<Option<T>> {
418 let mut handler = crate::handler::FirstRowHandler::<T>::new();
419 self.query(sql, &mut handler)?;
420 Ok(handler.into_row())
421 }
422
423 pub fn query_foreach<T: for<'a> crate::conversion::FromRow<'a>, F: FnMut(T) -> Result<()>>(
436 &mut self,
437 sql: &str,
438 f: F,
439 ) -> Result<()> {
440 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
441 self.query(sql, &mut handler)?;
442 Ok(())
443 }
444
445 pub fn close(mut self) -> Result<()> {
447 self.buffer_set.write_buffer.clear();
448 write_terminate(&mut self.buffer_set.write_buffer);
449 self.stream.write_all(&self.buffer_set.write_buffer)?;
450 self.stream.flush()?;
451 Ok(())
452 }
453
454 pub fn prepare(&mut self, query: &str) -> Result<PreparedStatement> {
458 self.prepare_typed(query, &[])
459 }
460
461 pub fn prepare_batch(&mut self, queries: &[&str]) -> Result<Vec<PreparedStatement>> {
478 if queries.is_empty() {
479 return Ok(Vec::new());
480 }
481
482 let start_idx = self.name_counter + 1;
483 self.name_counter += queries.len() as u64;
484
485 let result = self.prepare_batch_inner(queries, start_idx);
486 if let Err(e) = &result
487 && e.is_connection_broken()
488 {
489 self.is_broken = true;
490 }
491 result
492 }
493
494 fn prepare_batch_inner(
495 &mut self,
496 queries: &[&str],
497 start_idx: u64,
498 ) -> Result<Vec<PreparedStatement>> {
499 use crate::state::batch_prepare::BatchPrepareStateMachine;
500
501 let mut state_machine =
502 BatchPrepareStateMachine::new(&mut self.buffer_set, queries, start_idx);
503
504 loop {
505 match state_machine.step(&mut self.buffer_set)? {
506 Action::ReadMessage => {
507 self.stream.read_message(&mut self.buffer_set)?;
508 }
509 Action::WriteAndReadMessage => {
510 self.stream.write_all(&self.buffer_set.write_buffer)?;
511 self.stream.flush()?;
512 self.stream.read_message(&mut self.buffer_set)?;
513 }
514 Action::Finished => {
515 self.transaction_status = state_machine.transaction_status();
516 break;
517 }
518 _ => {
519 return Err(Error::LibraryBug(
520 "Unexpected action in batch prepare".into(),
521 ));
522 }
523 }
524 }
525
526 Ok(state_machine.take_statements())
527 }
528
529 pub fn prepare_typed(&mut self, query: &str, param_oids: &[u32]) -> Result<PreparedStatement> {
531 self.name_counter += 1;
532 let idx = self.name_counter;
533 let result = self.prepare_inner(idx, query, param_oids);
534 if let Err(e) = &result
535 && e.is_connection_broken()
536 {
537 self.is_broken = true;
538 }
539 result
540 }
541
542 fn prepare_inner(
543 &mut self,
544 idx: u64,
545 query: &str,
546 param_oids: &[u32],
547 ) -> Result<PreparedStatement> {
548 let mut handler = DropHandler::new();
549 let mut state_machine = ExtendedQueryStateMachine::prepare(
550 &mut handler,
551 &mut self.buffer_set,
552 idx,
553 query,
554 param_oids,
555 );
556 self.drive(&mut state_machine)?;
557 state_machine
558 .take_prepared_statement()
559 .ok_or_else(|| Error::LibraryBug("No prepared statement".into()))
560 }
561
562 pub fn exec<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
579 &mut self,
580 statement: S,
581 params: P,
582 handler: &mut H,
583 ) -> Result<()> {
584 let result = self.exec_inner(&statement, ¶ms, handler);
585 if let Err(e) = &result
586 && e.is_connection_broken()
587 {
588 self.is_broken = true;
589 }
590 result
591 }
592
593 fn exec_inner<S: IntoStatement, P: ToParams, H: ExtendedHandler>(
594 &mut self,
595 statement: &S,
596 params: &P,
597 handler: &mut H,
598 ) -> Result<()> {
599 let mut state_machine = match statement.statement_ref() {
600 StatementRef::Sql(sql) => {
601 ExtendedQueryStateMachine::execute_sql(handler, &mut self.buffer_set, sql, params)?
602 }
603 StatementRef::Prepared(stmt) => ExtendedQueryStateMachine::execute(
604 handler,
605 &mut self.buffer_set,
606 &stmt.wire_name(),
607 &stmt.param_oids,
608 params,
609 )?,
610 };
611
612 self.drive(&mut state_machine)
613 }
614
615 pub fn exec_drop<S: IntoStatement, P: ToParams>(
619 &mut self,
620 statement: S,
621 params: P,
622 ) -> Result<Option<u64>> {
623 let mut handler = DropHandler::new();
624 self.exec(statement, params, &mut handler)?;
625 Ok(handler.rows_affected())
626 }
627
628 pub fn exec_collect<
642 T: for<'a> crate::conversion::FromRow<'a>,
643 S: IntoStatement,
644 P: ToParams,
645 >(
646 &mut self,
647 statement: S,
648 params: P,
649 ) -> Result<Vec<T>> {
650 let mut handler = crate::handler::CollectHandler::<T>::new();
651 self.exec(statement, params, &mut handler)?;
652 Ok(handler.into_rows())
653 }
654
655 pub fn exec_first<T: for<'a> crate::conversion::FromRow<'a>, S: IntoStatement, P: ToParams>(
669 &mut self,
670 statement: S,
671 params: P,
672 ) -> Result<Option<T>> {
673 let mut handler = crate::handler::FirstRowHandler::<T>::new();
674 self.exec(statement, params, &mut handler)?;
675 Ok(handler.into_row())
676 }
677
678 pub fn exec_foreach<
694 T: for<'a> crate::conversion::FromRow<'a>,
695 S: IntoStatement,
696 P: ToParams,
697 F: FnMut(T) -> Result<()>,
698 >(
699 &mut self,
700 statement: S,
701 params: P,
702 f: F,
703 ) -> Result<()> {
704 let mut handler = crate::handler::ForEachHandler::<T, F>::new(f);
705 self.exec(statement, params, &mut handler)?;
706 Ok(())
707 }
708
709 pub fn exec_foreach_ref<
742 T: for<'a> crate::conversion::ref_row::RefFromRow<'a>,
743 S: IntoStatement,
744 P: ToParams,
745 F: for<'a> FnMut(&'a T) -> Result<()>,
746 >(
747 &mut self,
748 statement: S,
749 params: P,
750 f: F,
751 ) -> Result<()> {
752 let mut handler = crate::handler::ForEachRefHandler::<T, F>::new(f);
753 self.exec(statement, params, &mut handler)?;
754 Ok(())
755 }
756
757 pub fn exec_batch<S: IntoStatement, P: ToParams>(
788 &mut self,
789 statement: S,
790 params_list: &[P],
791 ) -> Result<()> {
792 self.exec_batch_chunked(statement, params_list, 1000)
793 }
794
795 pub fn exec_batch_chunked<S: IntoStatement, P: ToParams>(
799 &mut self,
800 statement: S,
801 params_list: &[P],
802 chunk_size: usize,
803 ) -> Result<()> {
804 let result = self.exec_batch_inner(&statement, params_list, chunk_size);
805 if let Err(e) = &result
806 && e.is_connection_broken()
807 {
808 self.is_broken = true;
809 }
810 result
811 }
812
813 fn exec_batch_inner<S: IntoStatement, P: ToParams>(
814 &mut self,
815 statement: &S,
816 params_list: &[P],
817 chunk_size: usize,
818 ) -> Result<()> {
819 use crate::protocol::frontend::{write_bind, write_execute, write_parse, write_sync};
820 use crate::state::extended::BatchStateMachine;
821
822 if params_list.is_empty() {
823 return Ok(());
824 }
825
826 let chunk_size = chunk_size.max(1);
827 let stmt_ref = statement.statement_ref();
828
829 let (param_oids, stmt_name) = match stmt_ref {
830 StatementRef::Sql(_) => (params_list[0].natural_oids(), String::new()),
831 StatementRef::Prepared(stmt) => (stmt.param_oids.clone(), stmt.wire_name()),
832 };
833
834 for chunk in params_list.chunks(chunk_size) {
835 self.buffer_set.write_buffer.clear();
836
837 if let StatementRef::Sql(sql) = stmt_ref {
839 write_parse(&mut self.buffer_set.write_buffer, "", sql, ¶m_oids);
840 }
841
842 for params in chunk {
844 let effective_stmt_name = if matches!(stmt_ref, StatementRef::Sql(_)) {
845 ""
846 } else {
847 &stmt_name
848 };
849 write_bind(
850 &mut self.buffer_set.write_buffer,
851 "",
852 effective_stmt_name,
853 params,
854 ¶m_oids,
855 )?;
856 write_execute(&mut self.buffer_set.write_buffer, "", 0);
857 }
858
859 write_sync(&mut self.buffer_set.write_buffer);
861
862 let mut state_machine =
864 BatchStateMachine::new(matches!(stmt_ref, StatementRef::Sql(_)));
865 self.drive_batch(&mut state_machine)?;
866 self.transaction_status = state_machine.transaction_status();
867 }
868
869 Ok(())
870 }
871
872 fn drive_batch(
874 &mut self,
875 state_machine: &mut crate::state::extended::BatchStateMachine,
876 ) -> Result<()> {
877 use crate::state::action::Action;
878
879 loop {
880 let step_result = state_machine.step(&mut self.buffer_set);
881 match step_result {
882 Ok(Action::ReadMessage) => {
883 self.stream.read_message(&mut self.buffer_set)?;
884 }
885 Ok(Action::WriteAndReadMessage) => {
886 self.stream.write_all(&self.buffer_set.write_buffer)?;
887 self.stream.flush()?;
888 self.stream.read_message(&mut self.buffer_set)?;
889 }
890 Ok(Action::Finished) => {
891 break;
892 }
893 Ok(Action::Error(server_error)) => {
894 self.transaction_status = state_machine.transaction_status();
895 return Err(Error::Server(server_error));
896 }
897 Ok(_) => return Err(Error::LibraryBug("Unexpected action in batch".into())),
898 Err(e) => return Err(e),
899 }
900 }
901 Ok(())
902 }
903
904 pub fn close_statement(&mut self, stmt: &PreparedStatement) -> Result<()> {
906 let result = self.close_statement_inner(&stmt.wire_name());
907 if let Err(e) = &result
908 && e.is_connection_broken()
909 {
910 self.is_broken = true;
911 }
912 result
913 }
914
915 fn close_statement_inner(&mut self, name: &str) -> Result<()> {
916 let mut handler = DropHandler::new();
917 let mut state_machine =
918 ExtendedQueryStateMachine::close_statement(&mut handler, &mut self.buffer_set, name);
919 self.drive(&mut state_machine)
920 }
921
922 pub fn transaction<F, R>(&mut self, f: F) -> Result<R>
932 where
933 F: FnOnce(&mut Conn, super::transaction::Transaction) -> Result<R>,
934 {
935 if self.in_transaction() {
936 return Err(Error::InvalidUsage(
937 "nested transactions are not supported".into(),
938 ));
939 }
940
941 self.query_drop("BEGIN")?;
942
943 let tx = super::transaction::Transaction::new(self.connection_id());
944 let result = f(self, tx);
945
946 if self.in_transaction() {
948 match &result {
949 Ok(_) => {
950 self.query_drop("COMMIT")?;
952 }
953 Err(_) => {
954 let _ = self.query_drop("ROLLBACK");
957 }
958 }
959 }
960
961 result
962 }
963
964 pub fn lowlevel_bind<P: ToParams>(
976 &mut self,
977 portal: &str,
978 statement_name: &str,
979 params: P,
980 ) -> Result<()> {
981 let result = self.lowlevel_bind_inner(portal, statement_name, ¶ms);
982 if let Err(e) = &result
983 && e.is_connection_broken()
984 {
985 self.is_broken = true;
986 }
987 result
988 }
989
990 fn lowlevel_bind_inner<P: ToParams>(
991 &mut self,
992 portal: &str,
993 statement_name: &str,
994 params: &P,
995 ) -> Result<()> {
996 use crate::protocol::backend::{BindComplete, ErrorResponse, RawMessage, msg_type};
997 use crate::protocol::frontend::{write_bind, write_flush};
998
999 let param_oids = params.natural_oids();
1000 self.buffer_set.write_buffer.clear();
1001 write_bind(
1002 &mut self.buffer_set.write_buffer,
1003 portal,
1004 statement_name,
1005 params,
1006 ¶m_oids,
1007 )?;
1008 write_flush(&mut self.buffer_set.write_buffer);
1009
1010 self.stream.write_all(&self.buffer_set.write_buffer)?;
1011 self.stream.flush()?;
1012
1013 loop {
1014 self.stream.read_message(&mut self.buffer_set)?;
1015 let type_byte = self.buffer_set.type_byte;
1016
1017 if RawMessage::is_async_type(type_byte) {
1018 continue;
1019 }
1020
1021 match type_byte {
1022 msg_type::BIND_COMPLETE => {
1023 BindComplete::parse(&self.buffer_set.read_buffer)?;
1024 return Ok(());
1025 }
1026 msg_type::ERROR_RESPONSE => {
1027 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1028 return Err(error.into_error());
1029 }
1030 _ => {
1031 return Err(Error::LibraryBug(format!(
1032 "Expected BindComplete or ErrorResponse, got '{}'",
1033 type_byte as char
1034 )));
1035 }
1036 }
1037 }
1038 }
1039
1040 pub fn lowlevel_execute<H: ExtendedHandler>(
1053 &mut self,
1054 portal: &str,
1055 max_rows: u32,
1056 handler: &mut H,
1057 ) -> Result<bool> {
1058 let result = self.lowlevel_execute_inner(portal, max_rows, handler);
1059 if let Err(e) = &result
1060 && e.is_connection_broken()
1061 {
1062 self.is_broken = true;
1063 }
1064 result
1065 }
1066
1067 fn lowlevel_execute_inner<H: ExtendedHandler>(
1068 &mut self,
1069 portal: &str,
1070 max_rows: u32,
1071 handler: &mut H,
1072 ) -> Result<bool> {
1073 use crate::protocol::backend::{
1074 CommandComplete, DataRow, ErrorResponse, NoData, PortalSuspended, RawMessage,
1075 RowDescription, msg_type,
1076 };
1077 use crate::protocol::frontend::{write_describe_portal, write_execute, write_flush};
1078
1079 self.buffer_set.write_buffer.clear();
1080 write_describe_portal(&mut self.buffer_set.write_buffer, portal);
1081 write_execute(&mut self.buffer_set.write_buffer, portal, max_rows);
1082 write_flush(&mut self.buffer_set.write_buffer);
1083
1084 self.stream.write_all(&self.buffer_set.write_buffer)?;
1085 self.stream.flush()?;
1086
1087 let mut column_buffer: Vec<u8> = Vec::new();
1088
1089 loop {
1090 self.stream.read_message(&mut self.buffer_set)?;
1091 let type_byte = self.buffer_set.type_byte;
1092
1093 if RawMessage::is_async_type(type_byte) {
1094 continue;
1095 }
1096
1097 match type_byte {
1098 msg_type::ROW_DESCRIPTION => {
1099 column_buffer.clear();
1100 column_buffer.extend_from_slice(&self.buffer_set.read_buffer);
1101 let cols = RowDescription::parse(&column_buffer)?;
1102 handler.result_start(cols)?;
1103 }
1104 msg_type::NO_DATA => {
1105 NoData::parse(&self.buffer_set.read_buffer)?;
1106 }
1107 msg_type::DATA_ROW => {
1108 let cols = RowDescription::parse(&column_buffer)?;
1109 let row = DataRow::parse(&self.buffer_set.read_buffer)?;
1110 handler.row(cols, row)?;
1111 }
1112 msg_type::COMMAND_COMPLETE => {
1113 let complete = CommandComplete::parse(&self.buffer_set.read_buffer)?;
1114 handler.result_end(complete)?;
1115 return Ok(false); }
1117 msg_type::PORTAL_SUSPENDED => {
1118 PortalSuspended::parse(&self.buffer_set.read_buffer)?;
1119 return Ok(true); }
1121 msg_type::ERROR_RESPONSE => {
1122 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1123 return Err(error.into_error());
1124 }
1125 _ => {
1126 return Err(Error::LibraryBug(format!(
1127 "Unexpected message in execute: '{}'",
1128 type_byte as char
1129 )));
1130 }
1131 }
1132 }
1133 }
1134
1135 pub fn lowlevel_sync(&mut self) -> Result<()> {
1142 let result = self.lowlevel_sync_inner();
1143 if let Err(e) = &result
1144 && e.is_connection_broken()
1145 {
1146 self.is_broken = true;
1147 }
1148 result
1149 }
1150
1151 fn lowlevel_sync_inner(&mut self) -> Result<()> {
1152 use crate::protocol::backend::{ErrorResponse, RawMessage, ReadyForQuery, msg_type};
1153 use crate::protocol::frontend::write_sync;
1154
1155 self.buffer_set.write_buffer.clear();
1156 write_sync(&mut self.buffer_set.write_buffer);
1157
1158 self.stream.write_all(&self.buffer_set.write_buffer)?;
1159 self.stream.flush()?;
1160
1161 let mut pending_error: Option<Error> = None;
1162
1163 loop {
1164 self.stream.read_message(&mut self.buffer_set)?;
1165 let type_byte = self.buffer_set.type_byte;
1166
1167 if RawMessage::is_async_type(type_byte) {
1168 continue;
1169 }
1170
1171 match type_byte {
1172 msg_type::READY_FOR_QUERY => {
1173 let ready = ReadyForQuery::parse(&self.buffer_set.read_buffer)?;
1174 self.transaction_status = ready.transaction_status().unwrap_or_default();
1175 if let Some(e) = pending_error {
1176 return Err(e);
1177 }
1178 return Ok(());
1179 }
1180 msg_type::ERROR_RESPONSE => {
1181 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1182 pending_error = Some(error.into_error());
1183 }
1184 _ => {
1185 }
1187 }
1188 }
1189 }
1190
1191 pub fn lowlevel_flush(&mut self) -> Result<()> {
1198 use crate::protocol::frontend::write_flush;
1199
1200 self.buffer_set.write_buffer.clear();
1201 write_flush(&mut self.buffer_set.write_buffer);
1202
1203 self.stream.write_all(&self.buffer_set.write_buffer)?;
1204 self.stream.flush()?;
1205 Ok(())
1206 }
1207
1208 pub fn exec_portal<S: IntoStatement, P, F, T>(
1238 &mut self,
1239 statement: S,
1240 params: P,
1241 f: F,
1242 ) -> Result<T>
1243 where
1244 P: ToParams,
1245 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1246 {
1247 let result = self.exec_portal_inner(&statement, ¶ms, f);
1248 if let Err(e) = &result
1249 && e.is_connection_broken()
1250 {
1251 self.is_broken = true;
1252 }
1253 result
1254 }
1255
1256 fn exec_portal_inner<S: IntoStatement, P, F, T>(
1257 &mut self,
1258 statement: &S,
1259 params: &P,
1260 f: F,
1261 ) -> Result<T>
1262 where
1263 P: ToParams,
1264 F: FnOnce(&mut UnnamedPortal<'_>) -> Result<T>,
1265 {
1266 let mut state_machine = match statement.statement_ref() {
1268 StatementRef::Sql(sql) => {
1269 BindStateMachine::bind_sql(&mut self.buffer_set, "", sql, params)?
1270 }
1271 StatementRef::Prepared(stmt) => BindStateMachine::bind_prepared(
1272 &mut self.buffer_set,
1273 "",
1274 &stmt.wire_name(),
1275 &stmt.param_oids,
1276 params,
1277 )?,
1278 };
1279
1280 loop {
1282 match state_machine.step(&mut self.buffer_set)? {
1283 Action::ReadMessage => {
1284 self.stream.read_message(&mut self.buffer_set)?;
1285 }
1286 Action::Write => {
1287 self.stream.write_all(&self.buffer_set.write_buffer)?;
1288 self.stream.flush()?;
1289 }
1290 Action::WriteAndReadMessage => {
1291 self.stream.write_all(&self.buffer_set.write_buffer)?;
1292 self.stream.flush()?;
1293 self.stream.read_message(&mut self.buffer_set)?;
1294 }
1295 Action::Finished => break,
1296 _ => return Err(Error::LibraryBug("Unexpected action in bind".into())),
1297 }
1298 }
1299
1300 let mut portal = UnnamedPortal { conn: self };
1302 let result = f(&mut portal);
1303
1304 let sync_result = portal.conn.lowlevel_sync();
1306
1307 match (result, sync_result) {
1309 (Ok(v), Ok(())) => Ok(v),
1310 (Err(e), _) => Err(e),
1311 (Ok(_), Err(e)) => Err(e),
1312 }
1313 }
1314
1315 pub fn lowlevel_close_portal(&mut self, portal: &str) -> Result<()> {
1317 let result = self.lowlevel_close_portal_inner(portal);
1318 if let Err(e) = &result
1319 && e.is_connection_broken()
1320 {
1321 self.is_broken = true;
1322 }
1323 result
1324 }
1325
1326 fn lowlevel_close_portal_inner(&mut self, portal: &str) -> Result<()> {
1327 use crate::protocol::backend::{CloseComplete, ErrorResponse, RawMessage, msg_type};
1328 use crate::protocol::frontend::{write_close_portal, write_flush};
1329
1330 self.buffer_set.write_buffer.clear();
1331 write_close_portal(&mut self.buffer_set.write_buffer, portal);
1332 write_flush(&mut self.buffer_set.write_buffer);
1333
1334 self.stream.write_all(&self.buffer_set.write_buffer)?;
1335 self.stream.flush()?;
1336
1337 loop {
1338 self.stream.read_message(&mut self.buffer_set)?;
1339 let type_byte = self.buffer_set.type_byte;
1340
1341 if RawMessage::is_async_type(type_byte) {
1342 continue;
1343 }
1344
1345 match type_byte {
1346 msg_type::CLOSE_COMPLETE => {
1347 CloseComplete::parse(&self.buffer_set.read_buffer)?;
1348 return Ok(());
1349 }
1350 msg_type::ERROR_RESPONSE => {
1351 let error = ErrorResponse::parse(&self.buffer_set.read_buffer)?;
1352 return Err(error.into_error());
1353 }
1354 _ => {
1355 return Err(Error::LibraryBug(format!(
1356 "Expected CloseComplete or ErrorResponse, got '{}'",
1357 type_byte as char
1358 )));
1359 }
1360 }
1361 }
1362 }
1363}
1364
1365impl Drop for Conn {
1366 fn drop(&mut self) {
1367 self.buffer_set.write_buffer.clear();
1369 write_terminate(&mut self.buffer_set.write_buffer);
1370 let _ = self.stream.write_all(&self.buffer_set.write_buffer);
1371 let _ = self.stream.flush();
1372 }
1373}