1#![allow(clippy::manual_async_fn)]
15#![allow(clippy::result_large_err)]
17
18use std::collections::HashMap;
19use std::future::Future;
20#[cfg(feature = "tls")]
21use std::io::{Read, Write};
22use std::sync::Arc;
23
24use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
25use asupersync::net::TcpStream;
26use asupersync::sync::Mutex;
27use asupersync::{Cx, Outcome};
28
29use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
30use sqlmodel_core::error::{
31 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
32};
33use sqlmodel_core::row::ColumnInfo;
34use sqlmodel_core::{Error, Row, Value};
35
36use crate::auth::ScramClient;
37use crate::config::{PgConfig, SslMode};
38use crate::connection::{ConnectionState, TransactionStatusState};
39use crate::protocol::{
40 BackendMessage, DescribeKind, ErrorFields, FrontendMessage, MessageReader, MessageWriter,
41 PROTOCOL_VERSION,
42};
43use crate::types::{Format, decode_value, encode_value};
44
45#[cfg(feature = "tls")]
46use crate::tls;
47
48enum PgAsyncStream {
49 Plain(TcpStream),
50 #[cfg(feature = "tls")]
51 Tls(AsyncTlsStream),
52 #[cfg(feature = "tls")]
53 Closed,
54}
55
56impl PgAsyncStream {
57 #[cfg(feature = "tls")]
58 async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59 match self {
60 PgAsyncStream::Plain(s) => read_exact_plain_async(s, buf).await,
61 #[cfg(feature = "tls")]
62 PgAsyncStream::Tls(s) => s.read_exact(buf).await,
63 #[cfg(feature = "tls")]
64 PgAsyncStream::Closed => Err(std::io::Error::new(
65 std::io::ErrorKind::NotConnected,
66 "connection closed",
67 )),
68 }
69 }
70
71 async fn read_some(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
72 match self {
73 PgAsyncStream::Plain(s) => read_some_plain_async(s, buf).await,
74 #[cfg(feature = "tls")]
75 PgAsyncStream::Tls(s) => s.read_plain(buf).await,
76 #[cfg(feature = "tls")]
77 PgAsyncStream::Closed => Err(std::io::Error::new(
78 std::io::ErrorKind::NotConnected,
79 "connection closed",
80 )),
81 }
82 }
83
84 async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
85 match self {
86 PgAsyncStream::Plain(s) => write_all_plain_async(s, buf).await,
87 #[cfg(feature = "tls")]
88 PgAsyncStream::Tls(s) => s.write_all(buf).await,
89 #[cfg(feature = "tls")]
90 PgAsyncStream::Closed => Err(std::io::Error::new(
91 std::io::ErrorKind::NotConnected,
92 "connection closed",
93 )),
94 }
95 }
96
97 async fn flush(&mut self) -> std::io::Result<()> {
98 match self {
99 PgAsyncStream::Plain(s) => flush_plain_async(s).await,
100 #[cfg(feature = "tls")]
101 PgAsyncStream::Tls(s) => s.flush().await,
102 #[cfg(feature = "tls")]
103 PgAsyncStream::Closed => Err(std::io::Error::new(
104 std::io::ErrorKind::NotConnected,
105 "connection closed",
106 )),
107 }
108 }
109}
110
/// A TCP socket paired with the rustls client session that encrypts its traffic.
///
/// rustls performs no I/O itself here: ciphertext is moved between `tls` and `tcp`
/// by the methods of this type.
#[cfg(feature = "tls")]
struct AsyncTlsStream {
    /// Underlying socket carrying the TLS records.
    tcp: TcpStream,
    /// rustls client-side session state.
    tls: rustls::ClientConnection,
}
116
#[cfg(feature = "tls")]
impl AsyncTlsStream {
    /// Runs the client side of the TLS handshake over an already connected socket.
    ///
    /// The client configuration is built from `ssl_mode` and the server name from
    /// `host`. The loop alternates between writing every record rustls has queued and
    /// reading more ciphertext from the socket until rustls reports the handshake done.
    ///
    /// # Errors
    ///
    /// Returns a connection error when the TLS configuration or session cannot be
    /// created, when socket I/O fails, when the peer closes the socket mid-handshake,
    /// or when rustls rejects the received data.
    async fn handshake(mut tcp: TcpStream, ssl_mode: SslMode, host: &str) -> Result<Self, Error> {
        let config = tls::build_client_config(ssl_mode)?;
        let server_name = tls::server_name(host)?;
        let mut tls = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;

        while tls.is_handshaking() {
            // Push out everything rustls wants to send before waiting for the peer.
            while tls.wants_write() {
                let mut out = Vec::new();
                tls.write_tls(&mut out)
                    .map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
                if !out.is_empty() {
                    write_all_plain_async(&mut tcp, &out).await.map_err(|e| {
                        Error::Connection(ConnectionError {
                            kind: ConnectionErrorKind::Disconnected,
                            message: format!("TLS handshake write error: {e}"),
                            source: Some(Box::new(e)),
                        })
                    })?;
                }
            }

            if tls.wants_read() {
                let mut buf = [0u8; 8192];
                let n = read_some_plain_async(&mut tcp, &mut buf)
                    .await
                    .map_err(|e| {
                        Error::Connection(ConnectionError {
                            kind: ConnectionErrorKind::Disconnected,
                            message: format!("TLS handshake read error: {e}"),
                            source: Some(Box::new(e)),
                        })
                    })?;
                if n == 0 {
                    return Err(connection_error("Connection closed during TLS handshake"));
                }

                // `read_tls` may accept only part of the chunk, so keep feeding until the
                // whole chunk is consumed; dropping the tail would corrupt the stream.
                let mut pending: &[u8] = &buf[..n];
                while !pending.is_empty() {
                    let consumed = tls.read_tls(&mut pending).map_err(|e| {
                        connection_error(format!("TLS handshake read_tls error: {e}"))
                    })?;
                    tls.process_new_packets()
                        .map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
                    if consumed == 0 {
                        // Defensive: never spin if rustls refuses the remaining bytes.
                        return Err(connection_error(
                            "TLS handshake read_tls error: no progress consuming handshake data",
                        ));
                    }
                }
            }
        }

        Ok(Self { tcp, tls })
    }

    /// Fills `buf` completely with decrypted bytes.
    ///
    /// # Errors
    ///
    /// Fails with `UnexpectedEof` if the stream ends before `buf` is full, or with
    /// whatever error [`Self::read_plain`] reports.
    async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            let n = self.read_plain(&mut buf[read..]).await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            read += n;
        }
        Ok(())
    }

    /// Reads decrypted bytes into `out`, pulling ciphertext from the socket as needed.
    ///
    /// Returns the number of plaintext bytes stored. `Ok(0)` is returned when `out` is
    /// empty, when rustls no longer wants input, or when the socket reports end of stream.
    ///
    /// # Errors
    ///
    /// Propagates socket errors and reports rustls failures as `std::io::Error`.
    async fn read_plain(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
        // Nothing can be delivered into an empty buffer; do not touch the socket.
        if out.is_empty() {
            return Ok(0);
        }

        loop {
            match self.tls.reader().read(out) {
                Ok(n) if n > 0 => return Ok(n),
                Ok(_) => {}
                // No plaintext buffered yet: fall through and fetch more ciphertext.
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.tls.wants_read() {
                return Ok(0);
            }

            let mut enc = [0u8; 8192];
            let n = read_some_plain_async(&mut self.tcp, &mut enc).await?;
            if n == 0 {
                return Ok(0);
            }

            // Feed the whole chunk: a single `read_tls` call may take only a prefix.
            let mut pending: &[u8] = &enc[..n];
            while !pending.is_empty() {
                let consumed = self.tls.read_tls(&mut pending)?;
                self.tls
                    .process_new_packets()
                    .map_err(|e| std::io::Error::other(format!("TLS error: {e}")))?;
                if consumed == 0 {
                    // Defensive: never spin if rustls refuses the remaining bytes.
                    return Err(std::io::Error::other(
                        "TLS error: no progress consuming ciphertext",
                    ));
                }
            }
        }
    }

    /// Encrypts and sends all of `buf`, flushing after each chunk rustls accepts.
    ///
    /// # Errors
    ///
    /// Fails with `WriteZero` if rustls accepts no bytes, or with any flush error.
    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut written = 0;
        while written < buf.len() {
            let n = self.tls.writer().write(&buf[written..])?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "TLS write zero",
                ));
            }
            written += n;
            self.flush().await?;
        }
        Ok(())
    }

    /// Writes every TLS record rustls has queued to the socket, then flushes the socket.
    async fn flush(&mut self) -> std::io::Result<()> {
        self.tls.writer().flush()?;
        while self.tls.wants_write() {
            let mut out = Vec::new();
            self.tls.write_tls(&mut out)?;
            if !out.is_empty() {
                write_all_plain_async(&mut self.tcp, &out).await?;
            }
        }
        flush_plain_async(&mut self.tcp).await
    }
}
237
/// Reads from `stream` until `buf` is completely filled.
///
/// # Errors
///
/// Fails with `UnexpectedEof` ("connection closed") if a read returns zero bytes
/// before `buf` is full; other read errors are propagated unchanged.
#[cfg(feature = "tls")]
async fn read_exact_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<()> {
    let mut filled = 0usize;
    loop {
        if filled >= buf.len() {
            return Ok(());
        }
        match read_some_plain_async(stream, &mut buf[filled..]).await? {
            0 => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            got => filled += got,
        }
    }
}
253
254async fn read_some_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<usize> {
255 let mut read_buf = ReadBuf::new(buf);
256 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
257 .await?;
258 Ok(read_buf.filled().len())
259}
260
261async fn write_all_plain_async(stream: &mut TcpStream, buf: &[u8]) -> std::io::Result<()> {
262 let mut written = 0;
263 while written < buf.len() {
264 let n = std::future::poll_fn(|cx| {
265 std::pin::Pin::new(&mut *stream).poll_write(cx, &buf[written..])
266 })
267 .await?;
268 if n == 0 {
269 return Err(std::io::Error::new(
270 std::io::ErrorKind::WriteZero,
271 "connection closed",
272 ));
273 }
274 written += n;
275 }
276 Ok(())
277}
278
279async fn flush_plain_async(stream: &mut TcpStream) -> std::io::Result<()> {
280 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
281}
282
283pub struct PgAsyncConnection {
288 stream: PgAsyncStream,
289 state: ConnectionState,
290 process_id: i32,
291 secret_key: i32,
292 parameters: HashMap<String, String>,
293 next_prepared_id: u64,
294 prepared: HashMap<u64, PgPreparedMeta>,
295 config: PgConfig,
296 reader: MessageReader,
297 writer: MessageWriter,
298 read_buf: Vec<u8>,
299}
300
/// Bookkeeping kept for each statement prepared on a connection.
#[derive(Debug, Clone)]
struct PgPreparedMeta {
    /// Name the statement is registered under on the server.
    name: String,
    /// Parameter type OIDs reported by the server for this statement.
    param_type_oids: Vec<u32>,
}
306
307impl std::fmt::Debug for PgAsyncConnection {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 f.debug_struct("PgAsyncConnection")
310 .field("state", &self.state)
311 .field("process_id", &self.process_id)
312 .field("host", &self.config.host)
313 .field("port", &self.config.port)
314 .field("database", &self.config.database)
315 .finish_non_exhaustive()
316 }
317}
318
319impl PgAsyncConnection {
320 pub async fn connect(_cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
322 let addr = config.socket_addr();
323 let socket_addr = match addr.parse() {
324 Ok(a) => a,
325 Err(e) => {
326 return Outcome::Err(Error::Connection(ConnectionError {
327 kind: ConnectionErrorKind::Connect,
328 message: format!("Invalid socket address: {}", e),
329 source: None,
330 }));
331 }
332 };
333
334 let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
335 Ok(s) => s,
336 Err(e) => {
337 let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
338 ConnectionErrorKind::Refused
339 } else {
340 ConnectionErrorKind::Connect
341 };
342 return Outcome::Err(Error::Connection(ConnectionError {
343 kind,
344 message: format!("Failed to connect to {}: {}", addr, e),
345 source: Some(Box::new(e)),
346 }));
347 }
348 };
349
350 stream.set_nodelay(true).ok();
351
352 let mut conn = Self {
353 stream: PgAsyncStream::Plain(stream),
354 state: ConnectionState::Connecting,
355 process_id: 0,
356 secret_key: 0,
357 parameters: HashMap::new(),
358 next_prepared_id: 1,
359 prepared: HashMap::new(),
360 config,
361 reader: MessageReader::new(),
362 writer: MessageWriter::new(),
363 read_buf: vec![0u8; 8192],
364 };
365
366 if conn.config.ssl_mode.should_try_ssl() {
368 #[cfg(feature = "tls")]
369 match conn.negotiate_ssl().await {
370 Outcome::Ok(()) => {}
371 Outcome::Err(e) => return Outcome::Err(e),
372 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
373 Outcome::Panicked(p) => return Outcome::Panicked(p),
374 }
375
376 #[cfg(not(feature = "tls"))]
377 if conn.config.ssl_mode != SslMode::Prefer {
378 return Outcome::Err(connection_error(
379 "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
380 ));
381 }
382 }
383
384 if let Outcome::Err(e) = conn.send_startup().await {
386 return Outcome::Err(e);
387 }
388 conn.state = ConnectionState::Authenticating;
389
390 match conn.handle_auth().await {
391 Outcome::Ok(()) => {}
392 Outcome::Err(e) => return Outcome::Err(e),
393 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
394 Outcome::Panicked(p) => return Outcome::Panicked(p),
395 }
396
397 match conn.read_startup_messages().await {
398 Outcome::Ok(()) => Outcome::Ok(conn),
399 Outcome::Err(e) => Outcome::Err(e),
400 Outcome::Cancelled(r) => Outcome::Cancelled(r),
401 Outcome::Panicked(p) => Outcome::Panicked(p),
402 }
403 }
404
405 pub async fn query_async(
407 &mut self,
408 cx: &Cx,
409 sql: &str,
410 params: &[Value],
411 ) -> Outcome<Vec<Row>, Error> {
412 match self.run_extended(cx, sql, params).await {
413 Outcome::Ok(result) => Outcome::Ok(result.rows),
414 Outcome::Err(e) => Outcome::Err(e),
415 Outcome::Cancelled(r) => Outcome::Cancelled(r),
416 Outcome::Panicked(p) => Outcome::Panicked(p),
417 }
418 }
419
420 pub async fn execute_async(
422 &mut self,
423 cx: &Cx,
424 sql: &str,
425 params: &[Value],
426 ) -> Outcome<u64, Error> {
427 match self.run_extended(cx, sql, params).await {
428 Outcome::Ok(result) => {
429 Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
430 }
431 Outcome::Err(e) => Outcome::Err(e),
432 Outcome::Cancelled(r) => Outcome::Cancelled(r),
433 Outcome::Panicked(p) => Outcome::Panicked(p),
434 }
435 }
436
437 pub async fn insert_async(
443 &mut self,
444 cx: &Cx,
445 sql: &str,
446 params: &[Value],
447 ) -> Outcome<i64, Error> {
448 let result = match self.run_extended(cx, sql, params).await {
449 Outcome::Ok(r) => r,
450 Outcome::Err(e) => return Outcome::Err(e),
451 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
452 Outcome::Panicked(p) => return Outcome::Panicked(p),
453 };
454
455 let Some(row) = result.rows.first() else {
456 return Outcome::Err(query_error_msg(
457 "INSERT did not return an id; add `RETURNING id`",
458 QueryErrorKind::Database,
459 ));
460 };
461 let Some(id_value) = row.get(0) else {
462 return Outcome::Err(query_error_msg(
463 "INSERT result row missing id column",
464 QueryErrorKind::Database,
465 ));
466 };
467 match id_value.as_i64() {
468 Some(v) => Outcome::Ok(v),
469 None => Outcome::Err(query_error_msg(
470 "INSERT returned non-integer id",
471 QueryErrorKind::Database,
472 )),
473 }
474 }
475
476 pub async fn ping_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
478 self.execute_async(cx, "SELECT 1", &[]).await.map(|_| ())
479 }
480
481 pub async fn close_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
483 let _ = self.send_message(cx, &FrontendMessage::Terminate).await;
488 self.state = ConnectionState::Closed;
489 Outcome::Ok(())
490 }
491
492 pub async fn prepare_async(&mut self, cx: &Cx, sql: &str) -> Outcome<PreparedStatement, Error> {
496 let stmt_id = self.next_prepared_id;
497 self.next_prepared_id = self.next_prepared_id.saturating_add(1);
498 let stmt_name = format!("sqlmodel_stmt_{stmt_id}");
499
500 if let Outcome::Err(e) = self
501 .send_message(
502 cx,
503 &FrontendMessage::Parse {
504 name: stmt_name.clone(),
505 query: sql.to_string(),
506 param_types: Vec::new(),
509 },
510 )
511 .await
512 {
513 return Outcome::Err(e);
514 }
515
516 if let Outcome::Err(e) = self
517 .send_message(
518 cx,
519 &FrontendMessage::Describe {
520 kind: DescribeKind::Statement,
521 name: stmt_name.clone(),
522 },
523 )
524 .await
525 {
526 return Outcome::Err(e);
527 }
528
529 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
530 return Outcome::Err(e);
531 }
532
533 let mut param_type_oids: Option<Vec<u32>> = None;
534 let mut columns: Option<Vec<String>> = None;
535
536 loop {
537 let msg = match self.receive_message(cx).await {
538 Outcome::Ok(m) => m,
539 Outcome::Err(e) => return Outcome::Err(e),
540 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
541 Outcome::Panicked(p) => return Outcome::Panicked(p),
542 };
543
544 match msg {
545 BackendMessage::ParseComplete
546 | BackendMessage::BindComplete
547 | BackendMessage::CloseComplete
548 | BackendMessage::NoData
549 | BackendMessage::EmptyQueryResponse => {}
550 BackendMessage::ParameterDescription(oids) => {
551 param_type_oids = Some(oids);
552 }
553 BackendMessage::RowDescription(desc) => {
554 columns = Some(desc.iter().map(|f| f.name.clone()).collect());
555 }
556 BackendMessage::ReadyForQuery(status) => {
557 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
558 break;
559 }
560 BackendMessage::ErrorResponse(e) => {
561 self.state = ConnectionState::Error;
562 return Outcome::Err(error_from_fields(&e));
563 }
564 BackendMessage::NoticeResponse(_notice) => {}
565 other => {
566 return Outcome::Err(protocol_error(format!(
567 "Unexpected message during prepare: {other:?}"
568 )));
569 }
570 }
571 }
572
573 let param_type_oids = param_type_oids.unwrap_or_default();
574 self.prepared.insert(
575 stmt_id,
576 PgPreparedMeta {
577 name: stmt_name,
578 param_type_oids: param_type_oids.clone(),
579 },
580 );
581
582 match columns {
583 Some(cols) => Outcome::Ok(PreparedStatement::with_columns(
584 stmt_id,
585 sql.to_string(),
586 param_type_oids.len(),
587 cols,
588 )),
589 None => Outcome::Ok(PreparedStatement::new(
590 stmt_id,
591 sql.to_string(),
592 param_type_oids.len(),
593 )),
594 }
595 }
596
597 pub async fn query_prepared_async(
598 &mut self,
599 cx: &Cx,
600 stmt: &PreparedStatement,
601 params: &[Value],
602 ) -> Outcome<Vec<Row>, Error> {
603 let meta = match self.prepared.get(&stmt.id()) {
604 Some(m) => m.clone(),
605 None => {
606 return Outcome::Err(query_error_msg(
607 format!("Unknown prepared statement id {}", stmt.id()),
608 QueryErrorKind::Database,
609 ));
610 }
611 };
612
613 if meta.param_type_oids.len() != params.len() {
614 return Outcome::Err(query_error_msg(
615 format!(
616 "Prepared statement expects {} params, got {}",
617 meta.param_type_oids.len(),
618 params.len()
619 ),
620 QueryErrorKind::Database,
621 ));
622 }
623
624 match self.run_prepared(cx, &meta, params).await {
625 Outcome::Ok(result) => Outcome::Ok(result.rows),
626 Outcome::Err(e) => Outcome::Err(e),
627 Outcome::Cancelled(r) => Outcome::Cancelled(r),
628 Outcome::Panicked(p) => Outcome::Panicked(p),
629 }
630 }
631
632 pub async fn execute_prepared_async(
633 &mut self,
634 cx: &Cx,
635 stmt: &PreparedStatement,
636 params: &[Value],
637 ) -> Outcome<u64, Error> {
638 let meta = match self.prepared.get(&stmt.id()) {
639 Some(m) => m.clone(),
640 None => {
641 return Outcome::Err(query_error_msg(
642 format!("Unknown prepared statement id {}", stmt.id()),
643 QueryErrorKind::Database,
644 ));
645 }
646 };
647
648 if meta.param_type_oids.len() != params.len() {
649 return Outcome::Err(query_error_msg(
650 format!(
651 "Prepared statement expects {} params, got {}",
652 meta.param_type_oids.len(),
653 params.len()
654 ),
655 QueryErrorKind::Database,
656 ));
657 }
658
659 match self.run_prepared(cx, &meta, params).await {
660 Outcome::Ok(result) => {
661 Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
662 }
663 Outcome::Err(e) => Outcome::Err(e),
664 Outcome::Cancelled(r) => Outcome::Cancelled(r),
665 Outcome::Panicked(p) => Outcome::Panicked(p),
666 }
667 }
668
669 async fn read_extended_result(&mut self, cx: &Cx) -> Outcome<PgQueryResult, Error> {
672 let mut field_descs: Option<Vec<crate::protocol::FieldDescription>> = None;
674 let mut columns: Option<Arc<ColumnInfo>> = None;
675 let mut rows: Vec<Row> = Vec::new();
676 let mut command_tag: Option<String> = None;
677
678 loop {
679 let msg = match self.receive_message(cx).await {
680 Outcome::Ok(m) => m,
681 Outcome::Err(e) => return Outcome::Err(e),
682 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
683 Outcome::Panicked(p) => return Outcome::Panicked(p),
684 };
685
686 match msg {
687 BackendMessage::ParseComplete
688 | BackendMessage::BindComplete
689 | BackendMessage::CloseComplete
690 | BackendMessage::ParameterDescription(_)
691 | BackendMessage::NoData
692 | BackendMessage::PortalSuspended
693 | BackendMessage::EmptyQueryResponse => {}
694 BackendMessage::RowDescription(desc) => {
695 let names: Vec<String> = desc.iter().map(|f| f.name.clone()).collect();
696 columns = Some(Arc::new(ColumnInfo::new(names)));
697 field_descs = Some(desc);
698 }
699 BackendMessage::DataRow(raw_values) => {
700 let Some(ref desc) = field_descs else {
701 return Outcome::Err(protocol_error(
702 "DataRow received before RowDescription",
703 ));
704 };
705 let Some(ref cols) = columns else {
706 return Outcome::Err(protocol_error("Row column metadata missing"));
707 };
708 if raw_values.len() != desc.len() {
709 return Outcome::Err(protocol_error("DataRow field count mismatch"));
710 }
711
712 let mut values = Vec::with_capacity(raw_values.len());
713 for (i, raw) in raw_values.into_iter().enumerate() {
714 match raw {
715 None => values.push(Value::Null),
716 Some(bytes) => {
717 let field = &desc[i];
718 let format = Format::from_code(field.format);
719 let decoded = match decode_value(
720 field.type_oid,
721 Some(bytes.as_slice()),
722 format,
723 ) {
724 Ok(v) => v,
725 Err(e) => return Outcome::Err(e),
726 };
727 values.push(decoded);
728 }
729 }
730 }
731 rows.push(Row::with_columns(Arc::clone(cols), values));
732 }
733 BackendMessage::CommandComplete(tag) => {
734 command_tag = Some(tag);
735 }
736 BackendMessage::ReadyForQuery(status) => {
737 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
738 break;
739 }
740 BackendMessage::ErrorResponse(e) => {
741 self.state = ConnectionState::Error;
742 return Outcome::Err(error_from_fields(&e));
743 }
744 BackendMessage::NoticeResponse(_notice) => {}
745 _ => {}
746 }
747 }
748
749 Outcome::Ok(PgQueryResult { rows, command_tag })
750 }
751
752 async fn run_extended(
753 &mut self,
754 cx: &Cx,
755 sql: &str,
756 params: &[Value],
757 ) -> Outcome<PgQueryResult, Error> {
758 let mut param_types = Vec::with_capacity(params.len());
760 let mut param_values = Vec::with_capacity(params.len());
761
762 for v in params {
763 if matches!(v, Value::Null) {
764 param_types.push(0);
765 param_values.push(None);
766 continue;
767 }
768 match encode_value(v, Format::Text) {
769 Ok((bytes, oid)) => {
770 param_types.push(oid);
771 param_values.push(Some(bytes));
772 }
773 Err(e) => return Outcome::Err(e),
774 }
775 }
776
777 if let Outcome::Err(e) = self
779 .send_message(
780 cx,
781 &FrontendMessage::Parse {
782 name: String::new(),
783 query: sql.to_string(),
784 param_types,
785 },
786 )
787 .await
788 {
789 return Outcome::Err(e);
790 }
791
792 let param_formats = if params.is_empty() {
793 Vec::new()
794 } else {
795 vec![Format::Text.code()]
796 };
797 if let Outcome::Err(e) = self
798 .send_message(
799 cx,
800 &FrontendMessage::Bind {
801 portal: String::new(),
802 statement: String::new(),
803 param_formats,
804 params: param_values,
805 result_formats: Vec::new(),
807 },
808 )
809 .await
810 {
811 return Outcome::Err(e);
812 }
813
814 if let Outcome::Err(e) = self
815 .send_message(
816 cx,
817 &FrontendMessage::Describe {
818 kind: DescribeKind::Portal,
819 name: String::new(),
820 },
821 )
822 .await
823 {
824 return Outcome::Err(e);
825 }
826
827 if let Outcome::Err(e) = self
828 .send_message(
829 cx,
830 &FrontendMessage::Execute {
831 portal: String::new(),
832 max_rows: 0,
833 },
834 )
835 .await
836 {
837 return Outcome::Err(e);
838 }
839
840 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
841 return Outcome::Err(e);
842 }
843 self.read_extended_result(cx).await
844 }
845
846 async fn run_prepared(
847 &mut self,
848 cx: &Cx,
849 meta: &PgPreparedMeta,
850 params: &[Value],
851 ) -> Outcome<PgQueryResult, Error> {
852 let mut param_values = Vec::with_capacity(params.len());
853
854 for (i, v) in params.iter().enumerate() {
855 if matches!(v, Value::Null) {
856 param_values.push(None);
857 continue;
858 }
859 match encode_value(v, Format::Text) {
860 Ok((bytes, oid)) => {
861 let expected = meta.param_type_oids.get(i).copied().unwrap_or(0);
862 if expected != 0 && expected != oid {
863 return Outcome::Err(query_error_msg(
864 format!(
865 "Prepared statement param {} expects type OID {}, got {}",
866 i + 1,
867 expected,
868 oid
869 ),
870 QueryErrorKind::Database,
871 ));
872 }
873 param_values.push(Some(bytes));
874 }
875 Err(e) => return Outcome::Err(e),
876 }
877 }
878
879 let param_formats = if params.is_empty() {
880 Vec::new()
881 } else {
882 vec![Format::Text.code()]
883 };
884
885 if let Outcome::Err(e) = self
886 .send_message(
887 cx,
888 &FrontendMessage::Bind {
889 portal: String::new(),
890 statement: meta.name.clone(),
891 param_formats,
892 params: param_values,
893 result_formats: Vec::new(),
894 },
895 )
896 .await
897 {
898 return Outcome::Err(e);
899 }
900
901 if let Outcome::Err(e) = self
902 .send_message(
903 cx,
904 &FrontendMessage::Describe {
905 kind: DescribeKind::Portal,
906 name: String::new(),
907 },
908 )
909 .await
910 {
911 return Outcome::Err(e);
912 }
913
914 if let Outcome::Err(e) = self
915 .send_message(
916 cx,
917 &FrontendMessage::Execute {
918 portal: String::new(),
919 max_rows: 0,
920 },
921 )
922 .await
923 {
924 return Outcome::Err(e);
925 }
926
927 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
928 return Outcome::Err(e);
929 }
930
931 self.read_extended_result(cx).await
932 }
933
934 #[cfg(feature = "tls")]
937 async fn negotiate_ssl(&mut self) -> Outcome<(), Error> {
938 if let Outcome::Err(e) = self.send_message_no_cx(&FrontendMessage::SSLRequest).await {
940 return Outcome::Err(e);
941 }
942
943 let mut buf = [0u8; 1];
945 if let Err(e) = self.stream.read_exact(&mut buf).await {
946 return Outcome::Err(Error::Connection(ConnectionError {
947 kind: ConnectionErrorKind::Ssl,
948 message: format!("Failed to read SSL response: {}", e),
949 source: Some(Box::new(e)),
950 }));
951 }
952
953 match buf[0] {
954 b'S' => {
955 #[cfg(feature = "tls")]
956 {
957 let plain = match std::mem::replace(&mut self.stream, PgAsyncStream::Closed) {
958 PgAsyncStream::Plain(s) => s,
959 other => {
960 self.stream = other;
961 return Outcome::Err(connection_error(
962 "TLS upgrade requires a plain TCP stream",
963 ));
964 }
965 };
966
967 let tls_stream = match AsyncTlsStream::handshake(
968 plain,
969 self.config.ssl_mode,
970 &self.config.host,
971 )
972 .await
973 {
974 Ok(s) => s,
975 Err(e) => return Outcome::Err(e),
976 };
977
978 self.stream = PgAsyncStream::Tls(tls_stream);
979 Outcome::Ok(())
980 }
981
982 #[cfg(not(feature = "tls"))]
983 {
984 Outcome::Err(connection_error(
985 "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
986 ))
987 }
988 }
989 b'N' => {
990 if self.config.ssl_mode.is_required() {
991 Outcome::Err(Error::Connection(ConnectionError {
992 kind: ConnectionErrorKind::Ssl,
993 message: "Server does not support SSL".to_string(),
994 source: None,
995 }))
996 } else {
997 Outcome::Ok(())
998 }
999 }
1000 other => Outcome::Err(Error::Connection(ConnectionError {
1001 kind: ConnectionErrorKind::Ssl,
1002 message: format!("Unexpected SSL response: 0x{other:02x}"),
1003 source: None,
1004 })),
1005 }
1006 }
1007
1008 async fn send_startup(&mut self) -> Outcome<(), Error> {
1009 let params = self.config.startup_params();
1010 self.send_message_no_cx(&FrontendMessage::Startup {
1011 version: PROTOCOL_VERSION,
1012 params,
1013 })
1014 .await
1015 }
1016
1017 fn require_auth_value(&self, message: &'static str) -> Outcome<&str, Error> {
1018 match self.config.password.as_deref() {
1020 Some(password) => Outcome::Ok(password),
1021 None => Outcome::Err(auth_error(message)),
1022 }
1023 }
1024
1025 async fn handle_auth(&mut self) -> Outcome<(), Error> {
1026 loop {
1027 let msg = match self.receive_message_no_cx().await {
1028 Outcome::Ok(m) => m,
1029 Outcome::Err(e) => return Outcome::Err(e),
1030 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1031 Outcome::Panicked(p) => return Outcome::Panicked(p),
1032 };
1033
1034 match msg {
1035 BackendMessage::AuthenticationOk => return Outcome::Ok(()),
1036 BackendMessage::AuthenticationCleartextPassword => {
1037 let auth_value = match self
1038 .require_auth_value("Authentication value required but not provided")
1039 {
1040 Outcome::Ok(password) => password,
1041 Outcome::Err(e) => return Outcome::Err(e),
1042 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1043 Outcome::Panicked(p) => return Outcome::Panicked(p),
1044 };
1045 if let Outcome::Err(e) = self
1046 .send_message_no_cx(&FrontendMessage::PasswordMessage(
1047 auth_value.to_string(),
1048 ))
1049 .await
1050 {
1051 return Outcome::Err(e);
1052 }
1053 }
1054 BackendMessage::AuthenticationMD5Password(salt) => {
1055 let auth_value = match self
1056 .require_auth_value("Authentication value required but not provided")
1057 {
1058 Outcome::Ok(password) => password,
1059 Outcome::Err(e) => return Outcome::Err(e),
1060 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1061 Outcome::Panicked(p) => return Outcome::Panicked(p),
1062 };
1063 let hash = md5_password(&self.config.user, auth_value, salt);
1064 if let Outcome::Err(e) = self
1065 .send_message_no_cx(&FrontendMessage::PasswordMessage(hash))
1066 .await
1067 {
1068 return Outcome::Err(e);
1069 }
1070 }
1071 BackendMessage::AuthenticationSASL(mechanisms) => {
1072 if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
1073 match self.scram_auth().await {
1074 Outcome::Ok(()) => {}
1075 Outcome::Err(e) => return Outcome::Err(e),
1076 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1077 Outcome::Panicked(p) => return Outcome::Panicked(p),
1078 }
1079 } else {
1080 return Outcome::Err(auth_error(format!(
1081 "Unsupported SASL mechanisms: {:?}",
1082 mechanisms
1083 )));
1084 }
1085 }
1086 BackendMessage::ErrorResponse(e) => {
1087 self.state = ConnectionState::Error;
1088 return Outcome::Err(error_from_fields(&e));
1089 }
1090 other => {
1091 return Outcome::Err(protocol_error(format!(
1092 "Unexpected message during auth: {other:?}"
1093 )));
1094 }
1095 }
1096 }
1097 }
1098
1099 async fn scram_auth(&mut self) -> Outcome<(), Error> {
1100 let auth_value =
1101 match self.require_auth_value("Authentication value required for SCRAM-SHA-256") {
1102 Outcome::Ok(password) => password,
1103 Outcome::Err(e) => return Outcome::Err(e),
1104 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1105 Outcome::Panicked(p) => return Outcome::Panicked(p),
1106 };
1107
1108 let mut client = ScramClient::new(&self.config.user, auth_value);
1109
1110 let client_first = client.client_first();
1112 if let Outcome::Err(e) = self
1113 .send_message_no_cx(&FrontendMessage::SASLInitialResponse {
1114 mechanism: "SCRAM-SHA-256".to_string(),
1115 data: client_first,
1116 })
1117 .await
1118 {
1119 return Outcome::Err(e);
1120 }
1121
1122 let msg = match self.receive_message_no_cx().await {
1124 Outcome::Ok(m) => m,
1125 Outcome::Err(e) => return Outcome::Err(e),
1126 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1127 Outcome::Panicked(p) => return Outcome::Panicked(p),
1128 };
1129 let server_first_data = match msg {
1130 BackendMessage::AuthenticationSASLContinue(data) => data,
1131 BackendMessage::ErrorResponse(e) => {
1132 self.state = ConnectionState::Error;
1133 return Outcome::Err(error_from_fields(&e));
1134 }
1135 other => {
1136 return Outcome::Err(protocol_error(format!(
1137 "Expected SASL continue, got: {other:?}"
1138 )));
1139 }
1140 };
1141
1142 let client_final = match client.process_server_first(&server_first_data) {
1144 Ok(v) => v,
1145 Err(e) => return Outcome::Err(e),
1146 };
1147 if let Outcome::Err(e) = self
1148 .send_message_no_cx(&FrontendMessage::SASLResponse(client_final))
1149 .await
1150 {
1151 return Outcome::Err(e);
1152 }
1153
1154 let msg = match self.receive_message_no_cx().await {
1156 Outcome::Ok(m) => m,
1157 Outcome::Err(e) => return Outcome::Err(e),
1158 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1159 Outcome::Panicked(p) => return Outcome::Panicked(p),
1160 };
1161 let server_final_data = match msg {
1162 BackendMessage::AuthenticationSASLFinal(data) => data,
1163 BackendMessage::ErrorResponse(e) => {
1164 self.state = ConnectionState::Error;
1165 return Outcome::Err(error_from_fields(&e));
1166 }
1167 other => {
1168 return Outcome::Err(protocol_error(format!(
1169 "Expected SASL final, got: {other:?}"
1170 )));
1171 }
1172 };
1173
1174 if let Err(e) = client.verify_server_final(&server_final_data) {
1175 return Outcome::Err(e);
1176 }
1177
1178 let msg = match self.receive_message_no_cx().await {
1180 Outcome::Ok(m) => m,
1181 Outcome::Err(e) => return Outcome::Err(e),
1182 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1183 Outcome::Panicked(p) => return Outcome::Panicked(p),
1184 };
1185 match msg {
1186 BackendMessage::AuthenticationOk => Outcome::Ok(()),
1187 BackendMessage::ErrorResponse(e) => {
1188 self.state = ConnectionState::Error;
1189 Outcome::Err(error_from_fields(&e))
1190 }
1191 other => Outcome::Err(protocol_error(format!(
1192 "Expected AuthenticationOk, got: {other:?}"
1193 ))),
1194 }
1195 }
1196
1197 async fn read_startup_messages(&mut self) -> Outcome<(), Error> {
1198 loop {
1199 let msg = match self.receive_message_no_cx().await {
1200 Outcome::Ok(m) => m,
1201 Outcome::Err(e) => return Outcome::Err(e),
1202 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1203 Outcome::Panicked(p) => return Outcome::Panicked(p),
1204 };
1205
1206 match msg {
1207 BackendMessage::BackendKeyData {
1208 process_id,
1209 secret_key,
1210 } => {
1211 self.process_id = process_id;
1212 self.secret_key = secret_key;
1213 }
1214 BackendMessage::ParameterStatus { name, value } => {
1215 self.parameters.insert(name, value);
1216 }
1217 BackendMessage::ReadyForQuery(status) => {
1218 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
1219 return Outcome::Ok(());
1220 }
1221 BackendMessage::ErrorResponse(e) => {
1222 self.state = ConnectionState::Error;
1223 return Outcome::Err(error_from_fields(&e));
1224 }
1225 BackendMessage::NoticeResponse(_notice) => {}
1226 other => {
1227 return Outcome::Err(protocol_error(format!(
1228 "Unexpected startup message: {other:?}"
1229 )));
1230 }
1231 }
1232 }
1233 }
1234
1235 async fn send_message(&mut self, cx: &Cx, msg: &FrontendMessage) -> Outcome<(), Error> {
1238 if let Some(reason) = cx.cancel_reason() {
1240 return Outcome::Cancelled(reason);
1241 }
1242 self.send_message_no_cx(msg).await
1243 }
1244
1245 async fn receive_message(&mut self, cx: &Cx) -> Outcome<BackendMessage, Error> {
1246 if let Some(reason) = cx.cancel_reason() {
1247 return Outcome::Cancelled(reason);
1248 }
1249 self.receive_message_no_cx().await
1250 }
1251
    /// Encodes `msg`, writes it to the stream and flushes, without checking cancellation.
    ///
    /// A failed write or flush sets the connection state to `Error` and is returned as a
    /// `Disconnected` connection error that carries the underlying error as its source.
    async fn send_message_no_cx(&mut self, msg: &FrontendMessage) -> Outcome<(), Error> {
        // Copy the encoded bytes into an owned buffer before touching the stream.
        let data = self.writer.write(msg).to_vec();

        if let Err(e) = self.stream.write_all(&data).await {
            self.state = ConnectionState::Error;
            return Outcome::Err(Error::Connection(ConnectionError {
                kind: ConnectionErrorKind::Disconnected,
                message: format!("Failed to write to server: {}", e),
                source: Some(Box::new(e)),
            }));
        }

        // Flush after every message; a flush failure is treated like a write failure.
        if let Err(e) = self.stream.flush().await {
            self.state = ConnectionState::Error;
            return Outcome::Err(Error::Connection(ConnectionError {
                kind: ConnectionErrorKind::Disconnected,
                message: format!("Failed to flush stream: {}", e),
                source: Some(Box::new(e)),
            }));
        }

        Outcome::Ok(())
    }
1275
    /// Returns the next backend message, reading from the stream until one is available.
    ///
    /// Cancellation is not checked here. The connection state is updated before an error is
    /// returned: `Error` for decode and read failures, `Disconnected` when a read yields zero
    /// bytes. Read errors of kind `TimedOut` or `WouldBlock` are reported as `Error::Timeout`;
    /// all other read errors become a `Disconnected` connection error.
    async fn receive_message_no_cx(&mut self) -> Outcome<BackendMessage, Error> {
        loop {
            // Try the bytes already handed to the reader before reading from the stream.
            match self.reader.next_message() {
                Ok(Some(msg)) => return Outcome::Ok(msg),
                // No message available yet (presumably an incomplete frame); read more below.
                Ok(None) => {}
                Err(e) => {
                    self.state = ConnectionState::Error;
                    return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
                }
            }

            let n = match self.stream.read_some(&mut self.read_buf).await {
                Ok(n) => n,
                Err(e) => {
                    self.state = ConnectionState::Error;
                    return Outcome::Err(match e.kind() {
                        std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock => {
                            Error::Timeout
                        }
                        _ => Error::Connection(ConnectionError {
                            kind: ConnectionErrorKind::Disconnected,
                            message: format!("Failed to read from server: {}", e),
                            source: Some(Box::new(e)),
                        }),
                    });
                }
            };

            // A zero-length read is treated as the server having closed the connection.
            if n == 0 {
                self.state = ConnectionState::Disconnected;
                return Outcome::Err(Error::Connection(ConnectionError {
                    kind: ConnectionErrorKind::Disconnected,
                    message: "Connection closed by server".to_string(),
                    source: None,
                }));
            }

            // Hand only the bytes actually read to the reader, then retry decoding.
            self.reader.push(&self.read_buf[..n]);
        }
    }
1323}
1324
/// Handle to a single [`PgAsyncConnection`] held behind `Arc<Mutex<_>>`.
///
/// Cloning the handle clones the `Arc`, so all clones operate on the same connection.
pub struct SharedPgConnection {
    /// The shared connection; each operation locks it before talking to the server.
    inner: Arc<Mutex<PgAsyncConnection>>,
}
1329
1330impl SharedPgConnection {
1331 pub fn new(conn: PgAsyncConnection) -> Self {
1332 Self {
1333 inner: Arc::new(Mutex::new(conn)),
1334 }
1335 }
1336
1337 pub async fn connect(cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
1338 match PgAsyncConnection::connect(cx, config).await {
1339 Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
1340 Outcome::Err(e) => Outcome::Err(e),
1341 Outcome::Cancelled(r) => Outcome::Cancelled(r),
1342 Outcome::Panicked(p) => Outcome::Panicked(p),
1343 }
1344 }
1345
1346 pub fn inner(&self) -> &Arc<Mutex<PgAsyncConnection>> {
1347 &self.inner
1348 }
1349
1350 async fn begin_transaction_impl(
1351 &self,
1352 cx: &Cx,
1353 isolation: Option<IsolationLevel>,
1354 ) -> Outcome<SharedPgTransaction<'_>, Error> {
1355 let inner = Arc::clone(&self.inner);
1356 let Ok(mut guard) = inner.lock(cx).await else {
1357 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1358 };
1359
1360 if let Some(level) = isolation {
1361 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
1362 match guard.execute_async(cx, &sql, &[]).await {
1363 Outcome::Ok(_) => {}
1364 Outcome::Err(e) => return Outcome::Err(e),
1365 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1366 Outcome::Panicked(p) => return Outcome::Panicked(p),
1367 }
1368 }
1369
1370 match guard.execute_async(cx, "BEGIN", &[]).await {
1371 Outcome::Ok(_) => {}
1372 Outcome::Err(e) => return Outcome::Err(e),
1373 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1374 Outcome::Panicked(p) => return Outcome::Panicked(p),
1375 }
1376
1377 drop(guard);
1378 Outcome::Ok(SharedPgTransaction {
1379 inner,
1380 committed: false,
1381 _marker: std::marker::PhantomData,
1382 })
1383 }
1384}
1385
1386impl Clone for SharedPgConnection {
1387 fn clone(&self) -> Self {
1388 Self {
1389 inner: Arc::clone(&self.inner),
1390 }
1391 }
1392}
1393
1394impl std::fmt::Debug for SharedPgConnection {
1395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1396 f.debug_struct("SharedPgConnection")
1397 .field("inner", &"Arc<Mutex<PgAsyncConnection>>")
1398 .finish()
1399 }
1400}
1401
/// Transaction handle that runs its statements on the same shared connection that opened it.
pub struct SharedPgTransaction<'conn> {
    /// The connection shared with the handle that started the transaction.
    inner: Arc<Mutex<PgAsyncConnection>>,
    /// Set after `commit` or `rollback` succeeds; `Drop` reads it to decide whether to warn.
    committed: bool,
    /// Carries the `'conn` lifetime parameter; holds no data.
    _marker: std::marker::PhantomData<&'conn ()>,
}
1407
1408impl<'conn> Drop for SharedPgTransaction<'conn> {
1409 fn drop(&mut self) {
1410 if !self.committed {
1411 #[cfg(debug_assertions)]
1416 eprintln!(
1417 "WARNING: SharedPgTransaction dropped without commit/rollback. \
1418 The PostgreSQL transaction may still be open."
1419 );
1420 }
1421 }
1422}
1423
1424impl Connection for SharedPgConnection {
1425 type Tx<'conn>
1426 = SharedPgTransaction<'conn>
1427 where
1428 Self: 'conn;
1429
1430 fn dialect(&self) -> sqlmodel_core::Dialect {
1431 sqlmodel_core::Dialect::Postgres
1432 }
1433
1434 fn query(
1435 &self,
1436 cx: &Cx,
1437 sql: &str,
1438 params: &[Value],
1439 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1440 let inner = Arc::clone(&self.inner);
1441 let sql = sql.to_string();
1442 let params = params.to_vec();
1443 async move {
1444 let Ok(mut guard) = inner.lock(cx).await else {
1445 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1446 };
1447 guard.query_async(cx, &sql, ¶ms).await
1448 }
1449 }
1450
1451 fn query_one(
1452 &self,
1453 cx: &Cx,
1454 sql: &str,
1455 params: &[Value],
1456 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1457 let inner = Arc::clone(&self.inner);
1458 let sql = sql.to_string();
1459 let params = params.to_vec();
1460 async move {
1461 let Ok(mut guard) = inner.lock(cx).await else {
1462 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1463 };
1464 let rows = match guard.query_async(cx, &sql, ¶ms).await {
1465 Outcome::Ok(r) => r,
1466 Outcome::Err(e) => return Outcome::Err(e),
1467 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1468 Outcome::Panicked(p) => return Outcome::Panicked(p),
1469 };
1470 Outcome::Ok(rows.into_iter().next())
1471 }
1472 }
1473
1474 fn execute(
1475 &self,
1476 cx: &Cx,
1477 sql: &str,
1478 params: &[Value],
1479 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1480 let inner = Arc::clone(&self.inner);
1481 let sql = sql.to_string();
1482 let params = params.to_vec();
1483 async move {
1484 let Ok(mut guard) = inner.lock(cx).await else {
1485 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1486 };
1487 guard.execute_async(cx, &sql, ¶ms).await
1488 }
1489 }
1490
1491 fn insert(
1492 &self,
1493 cx: &Cx,
1494 sql: &str,
1495 params: &[Value],
1496 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
1497 let inner = Arc::clone(&self.inner);
1498 let sql = sql.to_string();
1499 let params = params.to_vec();
1500 async move {
1501 let Ok(mut guard) = inner.lock(cx).await else {
1502 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1503 };
1504 guard.insert_async(cx, &sql, ¶ms).await
1505 }
1506 }
1507
1508 fn batch(
1509 &self,
1510 cx: &Cx,
1511 statements: &[(String, Vec<Value>)],
1512 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
1513 let inner = Arc::clone(&self.inner);
1514 let statements = statements.to_vec();
1515 async move {
1516 let Ok(mut guard) = inner.lock(cx).await else {
1517 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1518 };
1519 let mut results = Vec::with_capacity(statements.len());
1520 for (sql, params) in &statements {
1521 match guard.execute_async(cx, sql, params).await {
1522 Outcome::Ok(n) => results.push(n),
1523 Outcome::Err(e) => return Outcome::Err(e),
1524 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1525 Outcome::Panicked(p) => return Outcome::Panicked(p),
1526 }
1527 }
1528 Outcome::Ok(results)
1529 }
1530 }
1531
1532 fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1533 self.begin_with(cx, IsolationLevel::default())
1534 }
1535
1536 fn begin_with(
1537 &self,
1538 cx: &Cx,
1539 isolation: IsolationLevel,
1540 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1541 self.begin_transaction_impl(cx, Some(isolation))
1542 }
1543
1544 fn prepare(
1545 &self,
1546 cx: &Cx,
1547 sql: &str,
1548 ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
1549 let inner = Arc::clone(&self.inner);
1550 let sql = sql.to_string();
1551 async move {
1552 let Ok(mut guard) = inner.lock(cx).await else {
1553 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1554 };
1555 guard.prepare_async(cx, &sql).await
1556 }
1557 }
1558
1559 fn query_prepared(
1560 &self,
1561 cx: &Cx,
1562 stmt: &PreparedStatement,
1563 params: &[Value],
1564 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1565 let inner = Arc::clone(&self.inner);
1566 let stmt = stmt.clone();
1567 let params = params.to_vec();
1568 async move {
1569 let Ok(mut guard) = inner.lock(cx).await else {
1570 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1571 };
1572 guard.query_prepared_async(cx, &stmt, ¶ms).await
1573 }
1574 }
1575
1576 fn execute_prepared(
1577 &self,
1578 cx: &Cx,
1579 stmt: &PreparedStatement,
1580 params: &[Value],
1581 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1582 let inner = Arc::clone(&self.inner);
1583 let stmt = stmt.clone();
1584 let params = params.to_vec();
1585 async move {
1586 let Ok(mut guard) = inner.lock(cx).await else {
1587 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1588 };
1589 guard.execute_prepared_async(cx, &stmt, ¶ms).await
1590 }
1591 }
1592
1593 fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1594 let inner = Arc::clone(&self.inner);
1595 async move {
1596 let Ok(mut guard) = inner.lock(cx).await else {
1597 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1598 };
1599 guard.ping_async(cx).await
1600 }
1601 }
1602
1603 async fn close(self, cx: &Cx) -> sqlmodel_core::Result<()> {
1604 let Ok(mut guard) = self.inner.lock(cx).await else {
1605 return Err(connection_error("Failed to acquire connection lock"));
1606 };
1607 match guard.close_async(cx).await {
1608 Outcome::Ok(()) => Ok(()),
1609 Outcome::Err(e) => Err(e),
1610 Outcome::Cancelled(r) => Err(Error::Query(QueryError {
1611 kind: QueryErrorKind::Cancelled,
1612 message: format!("Cancelled: {r:?}"),
1613 sqlstate: None,
1614 sql: None,
1615 detail: None,
1616 hint: None,
1617 position: None,
1618 source: None,
1619 })),
1620 Outcome::Panicked(p) => Err(Error::Protocol(ProtocolError {
1621 message: format!("Panicked: {p:?}"),
1622 raw_data: None,
1623 source: None,
1624 })),
1625 }
1626 }
1627}
1628
1629impl<'conn> TransactionOps for SharedPgTransaction<'conn> {
1630 fn query(
1631 &self,
1632 cx: &Cx,
1633 sql: &str,
1634 params: &[Value],
1635 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1636 let inner = Arc::clone(&self.inner);
1637 let sql = sql.to_string();
1638 let params = params.to_vec();
1639 async move {
1640 let Ok(mut guard) = inner.lock(cx).await else {
1641 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1642 };
1643 guard.query_async(cx, &sql, ¶ms).await
1644 }
1645 }
1646
1647 fn query_one(
1648 &self,
1649 cx: &Cx,
1650 sql: &str,
1651 params: &[Value],
1652 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1653 let inner = Arc::clone(&self.inner);
1654 let sql = sql.to_string();
1655 let params = params.to_vec();
1656 async move {
1657 let Ok(mut guard) = inner.lock(cx).await else {
1658 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1659 };
1660 let rows = match guard.query_async(cx, &sql, ¶ms).await {
1661 Outcome::Ok(r) => r,
1662 Outcome::Err(e) => return Outcome::Err(e),
1663 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1664 Outcome::Panicked(p) => return Outcome::Panicked(p),
1665 };
1666 Outcome::Ok(rows.into_iter().next())
1667 }
1668 }
1669
1670 fn execute(
1671 &self,
1672 cx: &Cx,
1673 sql: &str,
1674 params: &[Value],
1675 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1676 let inner = Arc::clone(&self.inner);
1677 let sql = sql.to_string();
1678 let params = params.to_vec();
1679 async move {
1680 let Ok(mut guard) = inner.lock(cx).await else {
1681 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1682 };
1683 guard.execute_async(cx, &sql, ¶ms).await
1684 }
1685 }
1686
1687 fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1688 let inner = Arc::clone(&self.inner);
1689 let name = name.to_string();
1690 async move {
1691 if let Err(e) = validate_savepoint_name(&name) {
1692 return Outcome::Err(e);
1693 }
1694 let sql = format!("SAVEPOINT {}", name);
1695 let Ok(mut guard) = inner.lock(cx).await else {
1696 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1697 };
1698 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1699 }
1700 }
1701
1702 fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1703 let inner = Arc::clone(&self.inner);
1704 let name = name.to_string();
1705 async move {
1706 if let Err(e) = validate_savepoint_name(&name) {
1707 return Outcome::Err(e);
1708 }
1709 let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
1710 let Ok(mut guard) = inner.lock(cx).await else {
1711 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1712 };
1713 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1714 }
1715 }
1716
1717 fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1718 let inner = Arc::clone(&self.inner);
1719 let name = name.to_string();
1720 async move {
1721 if let Err(e) = validate_savepoint_name(&name) {
1722 return Outcome::Err(e);
1723 }
1724 let sql = format!("RELEASE SAVEPOINT {}", name);
1725 let Ok(mut guard) = inner.lock(cx).await else {
1726 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1727 };
1728 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1729 }
1730 }
1731
1732 #[allow(unused_assignments)]
1734 fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1735 let inner = Arc::clone(&self.inner);
1736 async move {
1737 let Ok(mut guard) = inner.lock(cx).await else {
1738 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1739 };
1740 let result = guard.execute_async(cx, "COMMIT", &[]).await;
1741 if matches!(result, Outcome::Ok(_)) {
1742 self.committed = true;
1743 }
1744 result.map(|_| ())
1745 }
1746 }
1747
1748 #[allow(unused_assignments)]
1749 fn rollback(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1750 let inner = Arc::clone(&self.inner);
1751 async move {
1752 let Ok(mut guard) = inner.lock(cx).await else {
1753 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1754 };
1755 let result = guard.execute_async(cx, "ROLLBACK", &[]).await;
1756 if matches!(result, Outcome::Ok(_)) {
1757 self.committed = true;
1758 }
1759 result.map(|_| ())
1760 }
1761 }
1762}
1763
/// Rows and command tag gathered for a single query.
struct PgQueryResult {
    /// Rows collected for the query.
    rows: Vec<Row>,
    /// Command tag, when one was received.
    /// NOTE(review): presumably the tag from `CommandComplete` — confirm where this is filled.
    command_tag: Option<String>,
}
1770
1771fn connection_error(msg: impl Into<String>) -> Error {
1772 Error::Connection(ConnectionError {
1773 kind: ConnectionErrorKind::Connect,
1774 message: msg.into(),
1775 source: None,
1776 })
1777}
1778
1779fn auth_error(msg: impl Into<String>) -> Error {
1780 Error::Connection(ConnectionError {
1781 kind: ConnectionErrorKind::Authentication,
1782 message: msg.into(),
1783 source: None,
1784 })
1785}
1786
1787fn protocol_error(msg: impl Into<String>) -> Error {
1788 Error::Protocol(ProtocolError {
1789 message: msg.into(),
1790 raw_data: None,
1791 source: None,
1792 })
1793}
1794
1795fn query_error_msg(msg: impl Into<String>, kind: QueryErrorKind) -> Error {
1796 Error::Query(QueryError {
1797 kind,
1798 message: msg.into(),
1799 sqlstate: None,
1800 sql: None,
1801 detail: None,
1802 hint: None,
1803 position: None,
1804 source: None,
1805 })
1806}
1807
1808fn error_from_fields(fields: &ErrorFields) -> Error {
1809 let kind = match fields.code.get(..2) {
1810 Some("08") => {
1811 return Error::Connection(ConnectionError {
1812 kind: ConnectionErrorKind::Connect,
1813 message: fields.message.clone(),
1814 source: None,
1815 });
1816 }
1817 Some("28") => {
1818 return Error::Connection(ConnectionError {
1819 kind: ConnectionErrorKind::Authentication,
1820 message: fields.message.clone(),
1821 source: None,
1822 });
1823 }
1824 Some("42") => QueryErrorKind::Syntax,
1825 Some("23") => QueryErrorKind::Constraint,
1826 Some("40") => {
1827 if fields.code == "40001" {
1828 QueryErrorKind::Serialization
1829 } else {
1830 QueryErrorKind::Deadlock
1831 }
1832 }
1833 Some("57") => {
1834 if fields.code == "57014" {
1835 QueryErrorKind::Cancelled
1836 } else {
1837 QueryErrorKind::Timeout
1838 }
1839 }
1840 _ => QueryErrorKind::Database,
1841 };
1842
1843 Error::Query(QueryError {
1844 kind,
1845 sql: None,
1846 sqlstate: Some(fields.code.clone()),
1847 message: fields.message.clone(),
1848 detail: fields.detail.clone(),
1849 hint: fields.hint.clone(),
1850 position: fields.position.map(|p| p as usize),
1851 source: None,
1852 })
1853}
1854
/// Extracts the trailing row count from a command tag such as `"INSERT 0 5"` or `"UPDATE 3"`.
///
/// Returns `None` when there is no tag, the tag is empty or whitespace-only, or its last
/// whitespace-separated token is not a valid `u64` (for example `"BEGIN"`).
fn parse_rows_affected(tag: Option<&str>) -> Option<u64> {
    // Only the last token matters, so walk from the back instead of collecting every token.
    tag?.split_whitespace().next_back()?.parse().ok()
}
1860
1861fn validate_savepoint_name(name: &str) -> sqlmodel_core::Result<()> {
1863 if name.is_empty() {
1864 return Err(query_error_msg(
1865 "Savepoint name cannot be empty",
1866 QueryErrorKind::Syntax,
1867 ));
1868 }
1869 if name.len() > 63 {
1870 return Err(query_error_msg(
1871 "Savepoint name exceeds maximum length of 63 characters",
1872 QueryErrorKind::Syntax,
1873 ));
1874 }
1875 let mut chars = name.chars();
1876 let Some(first) = chars.next() else {
1877 return Err(query_error_msg(
1878 "Savepoint name cannot be empty",
1879 QueryErrorKind::Syntax,
1880 ));
1881 };
1882 if !first.is_ascii_alphabetic() && first != '_' {
1883 return Err(query_error_msg(
1884 "Savepoint name must start with a letter or underscore",
1885 QueryErrorKind::Syntax,
1886 ));
1887 }
1888 for c in chars {
1889 if !c.is_ascii_alphanumeric() && c != '_' {
1890 return Err(query_error_msg(
1891 format!("Savepoint name contains invalid character: '{c}'"),
1892 QueryErrorKind::Syntax,
1893 ));
1894 }
1895 }
1896 Ok(())
1897}
1898
1899fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
1900 use std::fmt::Write;
1901
1902 let inner = format!("{password}{user}");
1903 let inner_hash = md5::compute(inner.as_bytes());
1904
1905 let mut outer_input = format!("{inner_hash:x}").into_bytes();
1906 outer_input.extend_from_slice(&salt);
1907 let outer_hash = md5::compute(&outer_input);
1908
1909 let mut result = String::with_capacity(35);
1910 result.push_str("md5");
1911 write!(&mut result, "{outer_hash:x}").unwrap();
1912 result
1913}
1914
1915