1#![allow(clippy::manual_async_fn)]
15#![allow(clippy::result_large_err)]
17
18use std::collections::HashMap;
19use std::future::Future;
20#[cfg(feature = "tls")]
21use std::io::{Read, Write};
22use std::sync::Arc;
23
24use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
25use asupersync::net::TcpStream;
26use asupersync::sync::Mutex;
27use asupersync::{Cx, Outcome};
28
29use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
30use sqlmodel_core::error::{
31 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
32};
33use sqlmodel_core::row::ColumnInfo;
34use sqlmodel_core::{Error, Row, Value};
35
36use crate::auth::ScramClient;
37use crate::config::{PgConfig, SslMode};
38use crate::connection::{ConnectionState, TransactionStatusState};
39use crate::protocol::{
40 BackendMessage, DescribeKind, ErrorFields, FrontendMessage, MessageReader, MessageWriter,
41 PROTOCOL_VERSION,
42};
43use crate::types::{Format, decode_value, encode_value};
44
45#[cfg(feature = "tls")]
46use crate::tls;
47
/// Transport underlying a [`PgAsyncConnection`].
///
/// A connection is created with a `Plain` stream; when the server accepts an
/// `SSLRequest` the plain socket is moved out and replaced by `Tls`.
enum PgAsyncStream {
    /// Unencrypted TCP socket.
    Plain(TcpStream),
    /// TCP socket wrapped in a rustls client session.
    #[cfg(feature = "tls")]
    Tls(AsyncTlsStream),
    /// Placeholder left in place while the plain stream is moved out for the
    /// TLS upgrade; every I/O method fails with `NotConnected` on it.
    #[cfg(feature = "tls")]
    Closed,
}
55
56impl PgAsyncStream {
57 #[cfg(feature = "tls")]
58 async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
59 match self {
60 PgAsyncStream::Plain(s) => read_exact_plain_async(s, buf).await,
61 #[cfg(feature = "tls")]
62 PgAsyncStream::Tls(s) => s.read_exact(buf).await,
63 #[cfg(feature = "tls")]
64 PgAsyncStream::Closed => Err(std::io::Error::new(
65 std::io::ErrorKind::NotConnected,
66 "connection closed",
67 )),
68 }
69 }
70
71 async fn read_some(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
72 match self {
73 PgAsyncStream::Plain(s) => read_some_plain_async(s, buf).await,
74 #[cfg(feature = "tls")]
75 PgAsyncStream::Tls(s) => s.read_plain(buf).await,
76 #[cfg(feature = "tls")]
77 PgAsyncStream::Closed => Err(std::io::Error::new(
78 std::io::ErrorKind::NotConnected,
79 "connection closed",
80 )),
81 }
82 }
83
84 async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
85 match self {
86 PgAsyncStream::Plain(s) => write_all_plain_async(s, buf).await,
87 #[cfg(feature = "tls")]
88 PgAsyncStream::Tls(s) => s.write_all(buf).await,
89 #[cfg(feature = "tls")]
90 PgAsyncStream::Closed => Err(std::io::Error::new(
91 std::io::ErrorKind::NotConnected,
92 "connection closed",
93 )),
94 }
95 }
96
97 async fn flush(&mut self) -> std::io::Result<()> {
98 match self {
99 PgAsyncStream::Plain(s) => flush_plain_async(s).await,
100 #[cfg(feature = "tls")]
101 PgAsyncStream::Tls(s) => s.flush().await,
102 #[cfg(feature = "tls")]
103 PgAsyncStream::Closed => Err(std::io::Error::new(
104 std::io::ErrorKind::NotConnected,
105 "connection closed",
106 )),
107 }
108 }
109}
110
/// A rustls client session layered over an async TCP stream.
///
/// rustls performs no I/O itself: the methods on this type move encrypted
/// bytes between `tcp` and `tls` using the plain-socket helpers in this file.
#[cfg(feature = "tls")]
struct AsyncTlsStream {
    /// Socket carrying the encrypted TLS records.
    tcp: TcpStream,
    /// TLS state machine; plaintext goes through its `reader()` / `writer()`.
    tls: rustls::ClientConnection,
}
116
#[cfg(feature = "tls")]
impl AsyncTlsStream {
    /// Runs a client-side TLS handshake over `tcp` and returns the session.
    ///
    /// The rustls configuration comes from `tls::build_client_config(ssl_mode)`
    /// and the SNI / verification name from `host`. Records are shuttled
    /// between rustls and the socket until the handshake completes. Any
    /// records rustls queued while finishing the handshake are written before
    /// returning.
    ///
    /// # Errors
    /// Returns a connection error if the configuration or server name is
    /// invalid, the socket fails or closes mid-handshake, or rustls rejects
    /// the peer's messages.
    async fn handshake(mut tcp: TcpStream, ssl_mode: SslMode, host: &str) -> Result<Self, Error> {
        let config = tls::build_client_config(ssl_mode)?;
        let server_name = tls::server_name(host)?;
        let mut tls = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;

        loop {
            // Send everything rustls has queued. Doing this before checking
            // `is_handshaking` also pushes out the final flight produced by the
            // packet that completed the handshake.
            while tls.wants_write() {
                let mut out = Vec::new();
                tls.write_tls(&mut out)
                    .map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
                if !out.is_empty() {
                    write_all_plain_async(&mut tcp, &out).await.map_err(|e| {
                        Error::Connection(ConnectionError {
                            kind: ConnectionErrorKind::Disconnected,
                            message: format!("TLS handshake write error: {e}"),
                            source: Some(Box::new(e)),
                        })
                    })?;
                }
            }

            if !tls.is_handshaking() {
                break;
            }

            if tls.wants_read() {
                let mut buf = [0u8; 8192];
                let n = read_some_plain_async(&mut tcp, &mut buf)
                    .await
                    .map_err(|e| {
                        Error::Connection(ConnectionError {
                            kind: ConnectionErrorKind::Disconnected,
                            message: format!("TLS handshake read error: {e}"),
                            source: Some(Box::new(e)),
                        })
                    })?;
                if n == 0 {
                    return Err(connection_error("Connection closed during TLS handshake"));
                }

                // `read_tls` does one `Read::read` into rustls's internal buffer
                // and may therefore take only part of the slice. Keep feeding
                // (processing in between to free buffer space) until every
                // received byte is consumed; dropping the remainder would
                // corrupt the TLS record stream.
                let mut pending = &buf[..n];
                while !pending.is_empty() {
                    let accepted = tls.read_tls(&mut pending).map_err(|e| {
                        connection_error(format!("TLS handshake read_tls error: {e}"))
                    })?;
                    tls.process_new_packets()
                        .map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
                    if accepted == 0 {
                        return Err(connection_error(
                            "TLS handshake read_tls error: record buffer rejected data",
                        ));
                    }
                }
            }
        }

        Ok(Self { tcp, tls })
    }

    /// Fills `buf` with decrypted bytes.
    ///
    /// Fails with `UnexpectedEof` if the stream ends before `buf` is full.
    async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            let n = self.read_plain(&mut buf[read..]).await?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            read += n;
        }
        Ok(())
    }

    /// Reads decrypted application data into `out`, pulling encrypted bytes
    /// from the socket as needed.
    ///
    /// Returns `Ok(0)` once rustls no longer wants input or the socket reports
    /// EOF.
    async fn read_plain(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
        loop {
            match self.tls.reader().read(out) {
                Ok(n) if n > 0 => return Ok(n),
                Ok(_) => {}
                // No buffered plaintext yet: fall through and read the socket.
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.tls.wants_read() {
                return Ok(0);
            }

            let mut enc = [0u8; 8192];
            let n = read_some_plain_async(&mut self.tcp, &mut enc).await?;
            if n == 0 {
                return Ok(0);
            }

            // A single `read_tls` may accept only part of the slice; feed all of
            // it so no received ciphertext is lost.
            let mut pending = &enc[..n];
            while !pending.is_empty() {
                let accepted = self.tls.read_tls(&mut pending)?;
                self.tls
                    .process_new_packets()
                    .map_err(|e| std::io::Error::other(format!("TLS error: {e}")))?;
                if accepted == 0 {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "TLS record buffer rejected data",
                    ));
                }
            }
        }
    }

    /// Encrypts and sends all of `buf`, flushing the resulting records to the
    /// socket after each chunk rustls accepts.
    async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        let mut written = 0;
        while written < buf.len() {
            let n = self.tls.writer().write(&buf[written..])?;
            if n == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::WriteZero,
                    "TLS write zero",
                ));
            }
            written += n;
            self.flush().await?;
        }
        Ok(())
    }

    /// Flushes rustls's plaintext writer, writes every pending TLS record to
    /// the socket, then flushes the socket.
    async fn flush(&mut self) -> std::io::Result<()> {
        self.tls.writer().flush()?;
        while self.tls.wants_write() {
            let mut out = Vec::new();
            self.tls.write_tls(&mut out)?;
            if !out.is_empty() {
                write_all_plain_async(&mut self.tcp, &out).await?;
            }
        }
        flush_plain_async(&mut self.tcp).await
    }
}
237
/// Reads from `stream` until `buf` is completely filled.
///
/// Fails with `UnexpectedEof` ("connection closed") if the socket reports EOF
/// before `buf` is full.
#[cfg(feature = "tls")]
async fn read_exact_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<()> {
    let total = buf.len();
    let mut filled = 0usize;
    while filled != total {
        match read_some_plain_async(stream, &mut buf[filled..]).await? {
            0 => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            got => filled += got,
        }
    }
    Ok(())
}
253
254async fn read_some_plain_async(stream: &mut TcpStream, buf: &mut [u8]) -> std::io::Result<usize> {
255 let mut read_buf = ReadBuf::new(buf);
256 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf))
257 .await?;
258 Ok(read_buf.filled().len())
259}
260
261async fn write_all_plain_async(stream: &mut TcpStream, buf: &[u8]) -> std::io::Result<()> {
262 let mut written = 0;
263 while written < buf.len() {
264 let n = std::future::poll_fn(|cx| {
265 std::pin::Pin::new(&mut *stream).poll_write(cx, &buf[written..])
266 })
267 .await?;
268 if n == 0 {
269 return Err(std::io::Error::new(
270 std::io::ErrorKind::WriteZero,
271 "connection closed",
272 ));
273 }
274 written += n;
275 }
276 Ok(())
277}
278
279async fn flush_plain_async(stream: &mut TcpStream) -> std::io::Result<()> {
280 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
281}
282
/// Asynchronous PostgreSQL connection that speaks the frontend/backend
/// protocol directly over a TCP stream, optionally upgraded to TLS.
pub struct PgAsyncConnection {
    /// Transport; starts plain and may be upgraded to TLS during `connect`.
    stream: PgAsyncStream,
    /// Protocol state, updated from `ReadyForQuery` and set to `Error` on failures.
    state: ConnectionState,
    /// Backend process id reported in `BackendKeyData`.
    process_id: i32,
    /// Secret key reported in `BackendKeyData` (presumably for cancel
    /// requests — not used in the code shown here).
    secret_key: i32,
    /// Server parameters reported via `ParameterStatus` during startup.
    parameters: HashMap<String, String>,
    /// Id handed to the next statement created by `prepare_async`.
    next_prepared_id: u64,
    /// Metadata of statements prepared on this connection, keyed by id.
    prepared: HashMap<u64, PgPreparedMeta>,
    /// Configuration the connection was opened with.
    config: PgConfig,
    /// Backend-message decoder (presumably used by the receive path, which is
    /// not shown here — confirm).
    reader: MessageReader,
    /// Encoder that serializes `FrontendMessage`s before they are written.
    writer: MessageWriter,
    /// Scratch buffer (8 KiB at construction), presumably for socket reads.
    read_buf: Vec<u8>,
}
300
/// Client-side record of a statement prepared on the server.
#[derive(Debug, Clone)]
struct PgPreparedMeta {
    /// Server-side statement name (`sqlmodel_stmt_<id>`).
    name: String,
    /// Parameter type OIDs from the server's `ParameterDescription`.
    param_type_oids: Vec<u32>,
}
306
307impl std::fmt::Debug for PgAsyncConnection {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 f.debug_struct("PgAsyncConnection")
310 .field("state", &self.state)
311 .field("process_id", &self.process_id)
312 .field("host", &self.config.host)
313 .field("port", &self.config.port)
314 .field("database", &self.config.database)
315 .finish_non_exhaustive()
316 }
317}
318
319impl PgAsyncConnection {
    /// Opens a connection to the configured server and runs the PostgreSQL
    /// startup sequence.
    ///
    /// Steps, in order:
    /// 1. parse `config.socket_addr()` as a socket address and connect with
    ///    `config.connect_timeout`; `TCP_NODELAY` is enabled best-effort;
    /// 2. if `ssl_mode.should_try_ssl()`, negotiate TLS. This requires the
    ///    `tls` feature; without it only `SslMode::Prefer` continues in
    ///    plaintext;
    /// 3. send the startup message and complete authentication;
    /// 4. read `BackendKeyData` / `ParameterStatus` until `ReadyForQuery`.
    ///
    /// The `Cx` argument is unused, so this call does not observe cancellation.
    ///
    /// # Errors
    /// Returns `ConnectionErrorKind::Connect` for an unparseable address or a
    /// failed connect, and `ConnectionErrorKind::Refused` when the server
    /// refuses the TCP connection. Any TLS, authentication or protocol error
    /// from the later steps is also returned.
    ///
    /// NOTE(review): the address goes through `str::parse`, so a host name
    /// (rather than an IP literal) is presumably rejected unless
    /// `socket_addr()` already resolves it — confirm.
    pub async fn connect(_cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
        let addr = config.socket_addr();
        let socket_addr = match addr.parse() {
            Ok(a) => a,
            Err(e) => {
                return Outcome::Err(Error::Connection(ConnectionError {
                    kind: ConnectionErrorKind::Connect,
                    message: format!("Invalid socket address: {}", e),
                    source: None,
                }));
            }
        };

        let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
            Ok(s) => s,
            Err(e) => {
                // Distinguish "server actively refused" from other connect failures.
                let kind = if e.kind() == std::io::ErrorKind::ConnectionRefused {
                    ConnectionErrorKind::Refused
                } else {
                    ConnectionErrorKind::Connect
                };
                return Outcome::Err(Error::Connection(ConnectionError {
                    kind,
                    message: format!("Failed to connect to {}: {}", addr, e),
                    source: Some(Box::new(e)),
                }));
            }
        };

        // Best-effort: a failure to disable Nagle's algorithm is ignored.
        stream.set_nodelay(true).ok();

        let mut conn = Self {
            stream: PgAsyncStream::Plain(stream),
            state: ConnectionState::Connecting,
            process_id: 0,
            secret_key: 0,
            parameters: HashMap::new(),
            next_prepared_id: 1,
            prepared: HashMap::new(),
            config,
            reader: MessageReader::new(),
            writer: MessageWriter::new(),
            read_buf: vec![0u8; 8192],
        };

        if conn.config.ssl_mode.should_try_ssl() {
            // With TLS support compiled in, negotiate (and possibly upgrade) now.
            #[cfg(feature = "tls")]
            match conn.negotiate_ssl().await {
                Outcome::Ok(()) => {}
                Outcome::Err(e) => return Outcome::Err(e),
                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
                Outcome::Panicked(p) => return Outcome::Panicked(p),
            }

            // Without TLS support, only `Prefer` may fall back to plaintext.
            #[cfg(not(feature = "tls"))]
            if conn.config.ssl_mode != SslMode::Prefer {
                return Outcome::Err(connection_error(
                    "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
                ));
            }
        }

        if let Outcome::Err(e) = conn.send_startup().await {
            return Outcome::Err(e);
        }
        conn.state = ConnectionState::Authenticating;

        match conn.handle_auth().await {
            Outcome::Ok(()) => {}
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        }

        // Consume post-auth messages; the connection is returned once the
        // server reports ReadyForQuery.
        match conn.read_startup_messages().await {
            Outcome::Ok(()) => Outcome::Ok(conn),
            Outcome::Err(e) => Outcome::Err(e),
            Outcome::Cancelled(r) => Outcome::Cancelled(r),
            Outcome::Panicked(p) => Outcome::Panicked(p),
        }
    }
404
405 pub async fn query_async(
407 &mut self,
408 cx: &Cx,
409 sql: &str,
410 params: &[Value],
411 ) -> Outcome<Vec<Row>, Error> {
412 match self.run_extended(cx, sql, params).await {
413 Outcome::Ok(result) => Outcome::Ok(result.rows),
414 Outcome::Err(e) => Outcome::Err(e),
415 Outcome::Cancelled(r) => Outcome::Cancelled(r),
416 Outcome::Panicked(p) => Outcome::Panicked(p),
417 }
418 }
419
420 pub async fn execute_async(
422 &mut self,
423 cx: &Cx,
424 sql: &str,
425 params: &[Value],
426 ) -> Outcome<u64, Error> {
427 match self.run_extended(cx, sql, params).await {
428 Outcome::Ok(result) => {
429 Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
430 }
431 Outcome::Err(e) => Outcome::Err(e),
432 Outcome::Cancelled(r) => Outcome::Cancelled(r),
433 Outcome::Panicked(p) => Outcome::Panicked(p),
434 }
435 }
436
437 pub async fn insert_async(
443 &mut self,
444 cx: &Cx,
445 sql: &str,
446 params: &[Value],
447 ) -> Outcome<i64, Error> {
448 let result = match self.run_extended(cx, sql, params).await {
449 Outcome::Ok(r) => r,
450 Outcome::Err(e) => return Outcome::Err(e),
451 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
452 Outcome::Panicked(p) => return Outcome::Panicked(p),
453 };
454
455 let Some(row) = result.rows.first() else {
456 return Outcome::Err(query_error_msg(
457 "INSERT did not return an id; add `RETURNING id`",
458 QueryErrorKind::Database,
459 ));
460 };
461 let Some(id_value) = row.get(0) else {
462 return Outcome::Err(query_error_msg(
463 "INSERT result row missing id column",
464 QueryErrorKind::Database,
465 ));
466 };
467 match id_value.as_i64() {
468 Some(v) => Outcome::Ok(v),
469 None => Outcome::Err(query_error_msg(
470 "INSERT returned non-integer id",
471 QueryErrorKind::Database,
472 )),
473 }
474 }
475
476 pub async fn ping_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
478 self.execute_async(cx, "SELECT 1", &[]).await.map(|_| ())
479 }
480
481 pub async fn close_async(&mut self, cx: &Cx) -> Outcome<(), Error> {
483 let _ = self.send_message(cx, &FrontendMessage::Terminate).await;
488 self.state = ConnectionState::Closed;
489 Outcome::Ok(())
490 }
491
492 pub async fn prepare_async(&mut self, cx: &Cx, sql: &str) -> Outcome<PreparedStatement, Error> {
496 let stmt_id = self.next_prepared_id;
497 self.next_prepared_id = self.next_prepared_id.saturating_add(1);
498 let stmt_name = format!("sqlmodel_stmt_{stmt_id}");
499
500 if let Outcome::Err(e) = self
501 .send_message(
502 cx,
503 &FrontendMessage::Parse {
504 name: stmt_name.clone(),
505 query: sql.to_string(),
506 param_types: Vec::new(),
509 },
510 )
511 .await
512 {
513 return Outcome::Err(e);
514 }
515
516 if let Outcome::Err(e) = self
517 .send_message(
518 cx,
519 &FrontendMessage::Describe {
520 kind: DescribeKind::Statement,
521 name: stmt_name.clone(),
522 },
523 )
524 .await
525 {
526 return Outcome::Err(e);
527 }
528
529 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
530 return Outcome::Err(e);
531 }
532
533 let mut param_type_oids: Option<Vec<u32>> = None;
534 let mut columns: Option<Vec<String>> = None;
535
536 loop {
537 let msg = match self.receive_message(cx).await {
538 Outcome::Ok(m) => m,
539 Outcome::Err(e) => return Outcome::Err(e),
540 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
541 Outcome::Panicked(p) => return Outcome::Panicked(p),
542 };
543
544 match msg {
545 BackendMessage::ParseComplete
546 | BackendMessage::BindComplete
547 | BackendMessage::CloseComplete
548 | BackendMessage::NoData
549 | BackendMessage::EmptyQueryResponse => {}
550 BackendMessage::ParameterDescription(oids) => {
551 param_type_oids = Some(oids);
552 }
553 BackendMessage::RowDescription(desc) => {
554 columns = Some(desc.iter().map(|f| f.name.clone()).collect());
555 }
556 BackendMessage::ReadyForQuery(status) => {
557 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
558 break;
559 }
560 BackendMessage::ErrorResponse(e) => {
561 self.state = ConnectionState::Error;
562 return Outcome::Err(error_from_fields(&e));
563 }
564 BackendMessage::NoticeResponse(_notice) => {}
565 other => {
566 return Outcome::Err(protocol_error(format!(
567 "Unexpected message during prepare: {other:?}"
568 )));
569 }
570 }
571 }
572
573 let param_type_oids = param_type_oids.unwrap_or_default();
574 self.prepared.insert(
575 stmt_id,
576 PgPreparedMeta {
577 name: stmt_name,
578 param_type_oids: param_type_oids.clone(),
579 },
580 );
581
582 match columns {
583 Some(cols) => Outcome::Ok(PreparedStatement::with_columns(
584 stmt_id,
585 sql.to_string(),
586 param_type_oids.len(),
587 cols,
588 )),
589 None => Outcome::Ok(PreparedStatement::new(
590 stmt_id,
591 sql.to_string(),
592 param_type_oids.len(),
593 )),
594 }
595 }
596
597 pub async fn query_prepared_async(
598 &mut self,
599 cx: &Cx,
600 stmt: &PreparedStatement,
601 params: &[Value],
602 ) -> Outcome<Vec<Row>, Error> {
603 let meta = match self.prepared.get(&stmt.id()) {
604 Some(m) => m.clone(),
605 None => {
606 return Outcome::Err(query_error_msg(
607 format!("Unknown prepared statement id {}", stmt.id()),
608 QueryErrorKind::Database,
609 ));
610 }
611 };
612
613 if meta.param_type_oids.len() != params.len() {
614 return Outcome::Err(query_error_msg(
615 format!(
616 "Prepared statement expects {} params, got {}",
617 meta.param_type_oids.len(),
618 params.len()
619 ),
620 QueryErrorKind::Database,
621 ));
622 }
623
624 match self.run_prepared(cx, &meta, params).await {
625 Outcome::Ok(result) => Outcome::Ok(result.rows),
626 Outcome::Err(e) => Outcome::Err(e),
627 Outcome::Cancelled(r) => Outcome::Cancelled(r),
628 Outcome::Panicked(p) => Outcome::Panicked(p),
629 }
630 }
631
632 pub async fn execute_prepared_async(
633 &mut self,
634 cx: &Cx,
635 stmt: &PreparedStatement,
636 params: &[Value],
637 ) -> Outcome<u64, Error> {
638 let meta = match self.prepared.get(&stmt.id()) {
639 Some(m) => m.clone(),
640 None => {
641 return Outcome::Err(query_error_msg(
642 format!("Unknown prepared statement id {}", stmt.id()),
643 QueryErrorKind::Database,
644 ));
645 }
646 };
647
648 if meta.param_type_oids.len() != params.len() {
649 return Outcome::Err(query_error_msg(
650 format!(
651 "Prepared statement expects {} params, got {}",
652 meta.param_type_oids.len(),
653 params.len()
654 ),
655 QueryErrorKind::Database,
656 ));
657 }
658
659 match self.run_prepared(cx, &meta, params).await {
660 Outcome::Ok(result) => {
661 Outcome::Ok(parse_rows_affected(result.command_tag.as_deref()).unwrap_or(0))
662 }
663 Outcome::Err(e) => Outcome::Err(e),
664 Outcome::Cancelled(r) => Outcome::Cancelled(r),
665 Outcome::Panicked(p) => Outcome::Panicked(p),
666 }
667 }
668
669 async fn read_extended_result(&mut self, cx: &Cx) -> Outcome<PgQueryResult, Error> {
672 let mut field_descs: Option<Vec<crate::protocol::FieldDescription>> = None;
674 let mut columns: Option<Arc<ColumnInfo>> = None;
675 let mut rows: Vec<Row> = Vec::new();
676 let mut command_tag: Option<String> = None;
677
678 loop {
679 let msg = match self.receive_message(cx).await {
680 Outcome::Ok(m) => m,
681 Outcome::Err(e) => return Outcome::Err(e),
682 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
683 Outcome::Panicked(p) => return Outcome::Panicked(p),
684 };
685
686 match msg {
687 BackendMessage::ParseComplete
688 | BackendMessage::BindComplete
689 | BackendMessage::CloseComplete
690 | BackendMessage::ParameterDescription(_)
691 | BackendMessage::NoData
692 | BackendMessage::PortalSuspended
693 | BackendMessage::EmptyQueryResponse => {}
694 BackendMessage::RowDescription(desc) => {
695 let names: Vec<String> = desc.iter().map(|f| f.name.clone()).collect();
696 columns = Some(Arc::new(ColumnInfo::new(names)));
697 field_descs = Some(desc);
698 }
699 BackendMessage::DataRow(raw_values) => {
700 let Some(ref desc) = field_descs else {
701 return Outcome::Err(protocol_error(
702 "DataRow received before RowDescription",
703 ));
704 };
705 let Some(ref cols) = columns else {
706 return Outcome::Err(protocol_error("Row column metadata missing"));
707 };
708 if raw_values.len() != desc.len() {
709 return Outcome::Err(protocol_error("DataRow field count mismatch"));
710 }
711
712 let mut values = Vec::with_capacity(raw_values.len());
713 for (i, raw) in raw_values.into_iter().enumerate() {
714 match raw {
715 None => values.push(Value::Null),
716 Some(bytes) => {
717 let field = &desc[i];
718 let format = Format::from_code(field.format);
719 let decoded = match decode_value(
720 field.type_oid,
721 Some(bytes.as_slice()),
722 format,
723 ) {
724 Ok(v) => v,
725 Err(e) => return Outcome::Err(e),
726 };
727 values.push(decoded);
728 }
729 }
730 }
731 rows.push(Row::with_columns(Arc::clone(cols), values));
732 }
733 BackendMessage::CommandComplete(tag) => {
734 command_tag = Some(tag);
735 }
736 BackendMessage::ReadyForQuery(status) => {
737 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
738 break;
739 }
740 BackendMessage::ErrorResponse(e) => {
741 self.state = ConnectionState::Error;
742 return Outcome::Err(error_from_fields(&e));
743 }
744 BackendMessage::NoticeResponse(_notice) => {}
745 _ => {}
746 }
747 }
748
749 Outcome::Ok(PgQueryResult { rows, command_tag })
750 }
751
752 async fn run_extended(
753 &mut self,
754 cx: &Cx,
755 sql: &str,
756 params: &[Value],
757 ) -> Outcome<PgQueryResult, Error> {
758 let mut param_types = Vec::with_capacity(params.len());
760 let mut param_values = Vec::with_capacity(params.len());
761
762 for v in params {
763 if matches!(v, Value::Null) {
764 param_types.push(0);
765 param_values.push(None);
766 continue;
767 }
768 match encode_value(v, Format::Text) {
769 Ok((bytes, oid)) => {
770 param_types.push(oid);
771 param_values.push(Some(bytes));
772 }
773 Err(e) => return Outcome::Err(e),
774 }
775 }
776
777 if let Outcome::Err(e) = self
779 .send_message(
780 cx,
781 &FrontendMessage::Parse {
782 name: String::new(),
783 query: sql.to_string(),
784 param_types,
785 },
786 )
787 .await
788 {
789 return Outcome::Err(e);
790 }
791
792 let param_formats = if params.is_empty() {
793 Vec::new()
794 } else {
795 vec![Format::Text.code()]
796 };
797 if let Outcome::Err(e) = self
798 .send_message(
799 cx,
800 &FrontendMessage::Bind {
801 portal: String::new(),
802 statement: String::new(),
803 param_formats,
804 params: param_values,
805 result_formats: Vec::new(),
807 },
808 )
809 .await
810 {
811 return Outcome::Err(e);
812 }
813
814 if let Outcome::Err(e) = self
815 .send_message(
816 cx,
817 &FrontendMessage::Describe {
818 kind: DescribeKind::Portal,
819 name: String::new(),
820 },
821 )
822 .await
823 {
824 return Outcome::Err(e);
825 }
826
827 if let Outcome::Err(e) = self
828 .send_message(
829 cx,
830 &FrontendMessage::Execute {
831 portal: String::new(),
832 max_rows: 0,
833 },
834 )
835 .await
836 {
837 return Outcome::Err(e);
838 }
839
840 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
841 return Outcome::Err(e);
842 }
843 self.read_extended_result(cx).await
844 }
845
846 async fn run_prepared(
847 &mut self,
848 cx: &Cx,
849 meta: &PgPreparedMeta,
850 params: &[Value],
851 ) -> Outcome<PgQueryResult, Error> {
852 let mut param_values = Vec::with_capacity(params.len());
853
854 for (i, v) in params.iter().enumerate() {
855 if matches!(v, Value::Null) {
856 param_values.push(None);
857 continue;
858 }
859 match encode_value(v, Format::Text) {
860 Ok((bytes, oid)) => {
861 let expected = meta.param_type_oids.get(i).copied().unwrap_or(0);
862 if expected != 0 && expected != oid {
863 return Outcome::Err(query_error_msg(
864 format!(
865 "Prepared statement param {} expects type OID {}, got {}",
866 i + 1,
867 expected,
868 oid
869 ),
870 QueryErrorKind::Database,
871 ));
872 }
873 param_values.push(Some(bytes));
874 }
875 Err(e) => return Outcome::Err(e),
876 }
877 }
878
879 let param_formats = if params.is_empty() {
880 Vec::new()
881 } else {
882 vec![Format::Text.code()]
883 };
884
885 if let Outcome::Err(e) = self
886 .send_message(
887 cx,
888 &FrontendMessage::Bind {
889 portal: String::new(),
890 statement: meta.name.clone(),
891 param_formats,
892 params: param_values,
893 result_formats: Vec::new(),
894 },
895 )
896 .await
897 {
898 return Outcome::Err(e);
899 }
900
901 if let Outcome::Err(e) = self
902 .send_message(
903 cx,
904 &FrontendMessage::Describe {
905 kind: DescribeKind::Portal,
906 name: String::new(),
907 },
908 )
909 .await
910 {
911 return Outcome::Err(e);
912 }
913
914 if let Outcome::Err(e) = self
915 .send_message(
916 cx,
917 &FrontendMessage::Execute {
918 portal: String::new(),
919 max_rows: 0,
920 },
921 )
922 .await
923 {
924 return Outcome::Err(e);
925 }
926
927 if let Outcome::Err(e) = self.send_message(cx, &FrontendMessage::Sync).await {
928 return Outcome::Err(e);
929 }
930
931 self.read_extended_result(cx).await
932 }
933
934 #[cfg(feature = "tls")]
937 async fn negotiate_ssl(&mut self) -> Outcome<(), Error> {
938 if let Outcome::Err(e) = self.send_message_no_cx(&FrontendMessage::SSLRequest).await {
940 return Outcome::Err(e);
941 }
942
943 let mut buf = [0u8; 1];
945 if let Err(e) = self.stream.read_exact(&mut buf).await {
946 return Outcome::Err(Error::Connection(ConnectionError {
947 kind: ConnectionErrorKind::Ssl,
948 message: format!("Failed to read SSL response: {}", e),
949 source: Some(Box::new(e)),
950 }));
951 }
952
953 match buf[0] {
954 b'S' => {
955 #[cfg(feature = "tls")]
956 {
957 let plain = match std::mem::replace(&mut self.stream, PgAsyncStream::Closed) {
958 PgAsyncStream::Plain(s) => s,
959 other => {
960 self.stream = other;
961 return Outcome::Err(connection_error(
962 "TLS upgrade requires a plain TCP stream",
963 ));
964 }
965 };
966
967 let tls_stream = match AsyncTlsStream::handshake(
968 plain,
969 self.config.ssl_mode,
970 &self.config.host,
971 )
972 .await
973 {
974 Ok(s) => s,
975 Err(e) => return Outcome::Err(e),
976 };
977
978 self.stream = PgAsyncStream::Tls(tls_stream);
979 Outcome::Ok(())
980 }
981
982 #[cfg(not(feature = "tls"))]
983 {
984 Outcome::Err(connection_error(
985 "TLS requested but 'sqlmodel-postgres' was built without feature 'tls'",
986 ))
987 }
988 }
989 b'N' => {
990 if self.config.ssl_mode.is_required() {
991 Outcome::Err(Error::Connection(ConnectionError {
992 kind: ConnectionErrorKind::Ssl,
993 message: "Server does not support SSL".to_string(),
994 source: None,
995 }))
996 } else {
997 Outcome::Ok(())
998 }
999 }
1000 other => Outcome::Err(Error::Connection(ConnectionError {
1001 kind: ConnectionErrorKind::Ssl,
1002 message: format!("Unexpected SSL response: 0x{other:02x}"),
1003 source: None,
1004 })),
1005 }
1006 }
1007
1008 async fn send_startup(&mut self) -> Outcome<(), Error> {
1009 let params = self.config.startup_params();
1010 self.send_message_no_cx(&FrontendMessage::Startup {
1011 version: PROTOCOL_VERSION,
1012 params,
1013 })
1014 .await
1015 }
1016
1017 fn require_auth_value(&self, message: &'static str) -> Outcome<&str, Error> {
1018 match self.config.password.as_deref() {
1020 Some(password) => Outcome::Ok(password),
1021 None => Outcome::Err(auth_error(message)),
1022 }
1023 }
1024
1025 async fn handle_auth(&mut self) -> Outcome<(), Error> {
1026 loop {
1027 let msg = match self.receive_message_no_cx().await {
1028 Outcome::Ok(m) => m,
1029 Outcome::Err(e) => return Outcome::Err(e),
1030 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1031 Outcome::Panicked(p) => return Outcome::Panicked(p),
1032 };
1033
1034 match msg {
1035 BackendMessage::AuthenticationOk => return Outcome::Ok(()),
1036 BackendMessage::AuthenticationCleartextPassword => {
1037 let auth_value = match self
1038 .require_auth_value("Authentication value required but not provided")
1039 {
1040 Outcome::Ok(password) => password,
1041 Outcome::Err(e) => return Outcome::Err(e),
1042 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1043 Outcome::Panicked(p) => return Outcome::Panicked(p),
1044 };
1045 if let Outcome::Err(e) = self
1046 .send_message_no_cx(&FrontendMessage::PasswordMessage(
1047 auth_value.to_string(),
1048 ))
1049 .await
1050 {
1051 return Outcome::Err(e);
1052 }
1053 }
1054 BackendMessage::AuthenticationMD5Password(salt) => {
1055 let auth_value = match self
1056 .require_auth_value("Authentication value required but not provided")
1057 {
1058 Outcome::Ok(password) => password,
1059 Outcome::Err(e) => return Outcome::Err(e),
1060 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1061 Outcome::Panicked(p) => return Outcome::Panicked(p),
1062 };
1063 let hash = md5_password(&self.config.user, auth_value, salt);
1064 if let Outcome::Err(e) = self
1065 .send_message_no_cx(&FrontendMessage::PasswordMessage(hash))
1066 .await
1067 {
1068 return Outcome::Err(e);
1069 }
1070 }
1071 BackendMessage::AuthenticationSASL(mechanisms) => {
1072 if mechanisms.contains(&"SCRAM-SHA-256".to_string()) {
1073 match self.scram_auth().await {
1074 Outcome::Ok(()) => {}
1075 Outcome::Err(e) => return Outcome::Err(e),
1076 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1077 Outcome::Panicked(p) => return Outcome::Panicked(p),
1078 }
1079 } else {
1080 return Outcome::Err(auth_error(format!(
1081 "Unsupported SASL mechanisms: {:?}",
1082 mechanisms
1083 )));
1084 }
1085 }
1086 BackendMessage::ErrorResponse(e) => {
1087 self.state = ConnectionState::Error;
1088 return Outcome::Err(error_from_fields(&e));
1089 }
1090 other => {
1091 return Outcome::Err(protocol_error(format!(
1092 "Unexpected message during auth: {other:?}"
1093 )));
1094 }
1095 }
1096 }
1097 }
1098
    /// Performs the SASL `SCRAM-SHA-256` exchange.
    ///
    /// Message order:
    /// 1. `SASLInitialResponse` (client-first);
    /// 2. `AuthenticationSASLContinue` (server-first);
    /// 3. `SASLResponse` (client-final);
    /// 4. `AuthenticationSASLFinal` (server-final, verified);
    /// 5. `AuthenticationOk`.
    ///
    /// The final `AuthenticationOk` is consumed here, so authentication is
    /// complete when this returns `Ok`.
    ///
    /// # Errors
    /// Returns an auth error if no password is configured, and any error from
    /// `ScramClient` while processing or verifying server messages. An
    /// `ErrorResponse` returns the server's error and marks the connection as
    /// errored; any other unexpected message yields a protocol error.
    async fn scram_auth(&mut self) -> Outcome<(), Error> {
        let auth_value =
            match self.require_auth_value("Authentication value required for SCRAM-SHA-256") {
                Outcome::Ok(password) => password,
                Outcome::Err(e) => return Outcome::Err(e),
                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
                Outcome::Panicked(p) => return Outcome::Panicked(p),
            };

        let mut client = ScramClient::new(&self.config.user, auth_value);

        // Step 1: client-first message.
        let client_first = client.client_first();
        if let Outcome::Err(e) = self
            .send_message_no_cx(&FrontendMessage::SASLInitialResponse {
                mechanism: "SCRAM-SHA-256".to_string(),
                data: client_first,
            })
            .await
        {
            return Outcome::Err(e);
        }

        // Step 2: server-first message.
        let msg = match self.receive_message_no_cx().await {
            Outcome::Ok(m) => m,
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        };
        let server_first_data = match msg {
            BackendMessage::AuthenticationSASLContinue(data) => data,
            BackendMessage::ErrorResponse(e) => {
                self.state = ConnectionState::Error;
                return Outcome::Err(error_from_fields(&e));
            }
            other => {
                return Outcome::Err(protocol_error(format!(
                    "Expected SASL continue, got: {other:?}"
                )));
            }
        };

        // Step 3: client-final message derived from the server-first data.
        let client_final = match client.process_server_first(&server_first_data) {
            Ok(v) => v,
            Err(e) => return Outcome::Err(e),
        };
        if let Outcome::Err(e) = self
            .send_message_no_cx(&FrontendMessage::SASLResponse(client_final))
            .await
        {
            return Outcome::Err(e);
        }

        // Step 4: server-final message.
        let msg = match self.receive_message_no_cx().await {
            Outcome::Ok(m) => m,
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        };
        let server_final_data = match msg {
            BackendMessage::AuthenticationSASLFinal(data) => data,
            BackendMessage::ErrorResponse(e) => {
                self.state = ConnectionState::Error;
                return Outcome::Err(error_from_fields(&e));
            }
            other => {
                return Outcome::Err(protocol_error(format!(
                    "Expected SASL final, got: {other:?}"
                )));
            }
        };

        // Check the server's final message before trusting the login.
        if let Err(e) = client.verify_server_final(&server_final_data) {
            return Outcome::Err(e);
        }

        // Step 5: the server confirms authentication.
        let msg = match self.receive_message_no_cx().await {
            Outcome::Ok(m) => m,
            Outcome::Err(e) => return Outcome::Err(e),
            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
            Outcome::Panicked(p) => return Outcome::Panicked(p),
        };
        match msg {
            BackendMessage::AuthenticationOk => Outcome::Ok(()),
            BackendMessage::ErrorResponse(e) => {
                self.state = ConnectionState::Error;
                Outcome::Err(error_from_fields(&e))
            }
            other => Outcome::Err(protocol_error(format!(
                "Expected AuthenticationOk, got: {other:?}"
            ))),
        }
    }
1196
1197 async fn read_startup_messages(&mut self) -> Outcome<(), Error> {
1198 loop {
1199 let msg = match self.receive_message_no_cx().await {
1200 Outcome::Ok(m) => m,
1201 Outcome::Err(e) => return Outcome::Err(e),
1202 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1203 Outcome::Panicked(p) => return Outcome::Panicked(p),
1204 };
1205
1206 match msg {
1207 BackendMessage::BackendKeyData {
1208 process_id,
1209 secret_key,
1210 } => {
1211 self.process_id = process_id;
1212 self.secret_key = secret_key;
1213 }
1214 BackendMessage::ParameterStatus { name, value } => {
1215 self.parameters.insert(name, value);
1216 }
1217 BackendMessage::ReadyForQuery(status) => {
1218 self.state = ConnectionState::Ready(TransactionStatusState::from(status));
1219 return Outcome::Ok(());
1220 }
1221 BackendMessage::ErrorResponse(e) => {
1222 self.state = ConnectionState::Error;
1223 return Outcome::Err(error_from_fields(&e));
1224 }
1225 BackendMessage::NoticeResponse(_notice) => {}
1226 other => {
1227 return Outcome::Err(protocol_error(format!(
1228 "Unexpected startup message: {other:?}"
1229 )));
1230 }
1231 }
1232 }
1233 }
1234
    /// Sends `msg` to the server unless `cx` has already been cancelled.
    ///
    /// If `cx.cancel_reason()` returns a reason, nothing is written and
    /// `Outcome::Cancelled` is returned. Otherwise this delegates to
    /// `send_message_no_cx`. Cancellation is checked once, before the write
    /// starts.
    async fn send_message(&mut self, cx: &Cx, msg: &FrontendMessage) -> Outcome<(), Error> {
        // A plain `if let` (no `else`) drops the scrutinee before the `.await` below.
        if let Some(reason) = cx.cancel_reason() {
            return Outcome::Cancelled(reason);
        }
        self.send_message_no_cx(msg).await
    }
1244
    /// Receives the next backend message unless `cx` has already been cancelled.
    ///
    /// If `cx.cancel_reason()` returns a reason, no read is attempted and
    /// `Outcome::Cancelled` is returned. Otherwise this delegates to
    /// `receive_message_no_cx`. Cancellation is checked once, before reading.
    async fn receive_message(&mut self, cx: &Cx) -> Outcome<BackendMessage, Error> {
        // A plain `if let` (no `else`) drops the scrutinee before the `.await` below.
        if let Some(reason) = cx.cancel_reason() {
            return Outcome::Cancelled(reason);
        }
        self.receive_message_no_cx().await
    }
1251
1252 async fn send_message_no_cx(&mut self, msg: &FrontendMessage) -> Outcome<(), Error> {
1253 let data = self.writer.write(msg).to_vec();
1254
1255 if let Err(e) = self.stream.write_all(&data).await {
1256 self.state = ConnectionState::Error;
1257 return Outcome::Err(Error::Connection(ConnectionError {
1258 kind: ConnectionErrorKind::Disconnected,
1259 message: format!("Failed to write to server: {}", e),
1260 source: Some(Box::new(e)),
1261 }));
1262 }
1263
1264 if let Err(e) = self.stream.flush().await {
1265 self.state = ConnectionState::Error;
1266 return Outcome::Err(Error::Connection(ConnectionError {
1267 kind: ConnectionErrorKind::Disconnected,
1268 message: format!("Failed to flush stream: {}", e),
1269 source: Some(Box::new(e)),
1270 }));
1271 }
1272
1273 Outcome::Ok(())
1274 }
1275
1276 async fn receive_message_no_cx(&mut self) -> Outcome<BackendMessage, Error> {
1277 loop {
1278 match self.reader.next_message() {
1279 Ok(Some(msg)) => return Outcome::Ok(msg),
1280 Ok(None) => {}
1281 Err(e) => {
1282 self.state = ConnectionState::Error;
1283 return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
1284 }
1285 }
1286
1287 let n = match self.stream.read_some(&mut self.read_buf).await {
1288 Ok(n) => n,
1289 Err(e) => {
1290 self.state = ConnectionState::Error;
1291 return Outcome::Err(match e.kind() {
1292 std::io::ErrorKind::TimedOut | std::io::ErrorKind::WouldBlock => {
1293 Error::Timeout
1294 }
1295 _ => Error::Connection(ConnectionError {
1296 kind: ConnectionErrorKind::Disconnected,
1297 message: format!("Failed to read from server: {}", e),
1298 source: Some(Box::new(e)),
1299 }),
1300 });
1301 }
1302 };
1303
1304 if n == 0 {
1305 self.state = ConnectionState::Disconnected;
1306 return Outcome::Err(Error::Connection(ConnectionError {
1307 kind: ConnectionErrorKind::Disconnected,
1308 message: "Connection closed by server".to_string(),
1309 source: None,
1310 }));
1311 }
1312
1313 if let Err(e) = self.reader.feed(&self.read_buf[..n]) {
1314 self.state = ConnectionState::Error;
1315 return Outcome::Err(protocol_error(format!("Protocol error: {}", e)));
1316 }
1317 }
1318 }
1319}
1320
1321pub struct SharedPgConnection {
1323 inner: Arc<Mutex<PgAsyncConnection>>,
1324}
1325
1326impl SharedPgConnection {
1327 pub fn new(conn: PgAsyncConnection) -> Self {
1328 Self {
1329 inner: Arc::new(Mutex::new(conn)),
1330 }
1331 }
1332
1333 pub async fn connect(cx: &Cx, config: PgConfig) -> Outcome<Self, Error> {
1334 match PgAsyncConnection::connect(cx, config).await {
1335 Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
1336 Outcome::Err(e) => Outcome::Err(e),
1337 Outcome::Cancelled(r) => Outcome::Cancelled(r),
1338 Outcome::Panicked(p) => Outcome::Panicked(p),
1339 }
1340 }
1341
1342 pub fn inner(&self) -> &Arc<Mutex<PgAsyncConnection>> {
1343 &self.inner
1344 }
1345
1346 async fn begin_transaction_impl(
1347 &self,
1348 cx: &Cx,
1349 isolation: Option<IsolationLevel>,
1350 ) -> Outcome<SharedPgTransaction<'_>, Error> {
1351 let inner = Arc::clone(&self.inner);
1352 let Ok(mut guard) = inner.lock(cx).await else {
1353 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1354 };
1355
1356 if let Some(level) = isolation {
1357 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
1358 match guard.execute_async(cx, &sql, &[]).await {
1359 Outcome::Ok(_) => {}
1360 Outcome::Err(e) => return Outcome::Err(e),
1361 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1362 Outcome::Panicked(p) => return Outcome::Panicked(p),
1363 }
1364 }
1365
1366 match guard.execute_async(cx, "BEGIN", &[]).await {
1367 Outcome::Ok(_) => {}
1368 Outcome::Err(e) => return Outcome::Err(e),
1369 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1370 Outcome::Panicked(p) => return Outcome::Panicked(p),
1371 }
1372
1373 drop(guard);
1374 Outcome::Ok(SharedPgTransaction {
1375 inner,
1376 committed: false,
1377 _marker: std::marker::PhantomData,
1378 })
1379 }
1380}
1381
1382impl Clone for SharedPgConnection {
1383 fn clone(&self) -> Self {
1384 Self {
1385 inner: Arc::clone(&self.inner),
1386 }
1387 }
1388}
1389
1390impl std::fmt::Debug for SharedPgConnection {
1391 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1392 f.debug_struct("SharedPgConnection")
1393 .field("inner", &"Arc<Mutex<PgAsyncConnection>>")
1394 .finish()
1395 }
1396}
1397
1398pub struct SharedPgTransaction<'conn> {
1399 inner: Arc<Mutex<PgAsyncConnection>>,
1400 committed: bool,
1401 _marker: std::marker::PhantomData<&'conn ()>,
1402}
1403
1404impl<'conn> Drop for SharedPgTransaction<'conn> {
1405 fn drop(&mut self) {
1406 if !self.committed {
1407 #[cfg(debug_assertions)]
1412 eprintln!(
1413 "WARNING: SharedPgTransaction dropped without commit/rollback. \
1414 The PostgreSQL transaction may still be open."
1415 );
1416 }
1417 }
1418}
1419
1420impl Connection for SharedPgConnection {
1421 type Tx<'conn>
1422 = SharedPgTransaction<'conn>
1423 where
1424 Self: 'conn;
1425
1426 fn dialect(&self) -> sqlmodel_core::Dialect {
1427 sqlmodel_core::Dialect::Postgres
1428 }
1429
1430 fn query(
1431 &self,
1432 cx: &Cx,
1433 sql: &str,
1434 params: &[Value],
1435 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1436 let inner = Arc::clone(&self.inner);
1437 let sql = sql.to_string();
1438 let params = params.to_vec();
1439 async move {
1440 let Ok(mut guard) = inner.lock(cx).await else {
1441 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1442 };
1443 guard.query_async(cx, &sql, ¶ms).await
1444 }
1445 }
1446
1447 fn query_one(
1448 &self,
1449 cx: &Cx,
1450 sql: &str,
1451 params: &[Value],
1452 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1453 let inner = Arc::clone(&self.inner);
1454 let sql = sql.to_string();
1455 let params = params.to_vec();
1456 async move {
1457 let Ok(mut guard) = inner.lock(cx).await else {
1458 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1459 };
1460 let rows = match guard.query_async(cx, &sql, ¶ms).await {
1461 Outcome::Ok(r) => r,
1462 Outcome::Err(e) => return Outcome::Err(e),
1463 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1464 Outcome::Panicked(p) => return Outcome::Panicked(p),
1465 };
1466 Outcome::Ok(rows.into_iter().next())
1467 }
1468 }
1469
1470 fn execute(
1471 &self,
1472 cx: &Cx,
1473 sql: &str,
1474 params: &[Value],
1475 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1476 let inner = Arc::clone(&self.inner);
1477 let sql = sql.to_string();
1478 let params = params.to_vec();
1479 async move {
1480 let Ok(mut guard) = inner.lock(cx).await else {
1481 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1482 };
1483 guard.execute_async(cx, &sql, ¶ms).await
1484 }
1485 }
1486
1487 fn insert(
1488 &self,
1489 cx: &Cx,
1490 sql: &str,
1491 params: &[Value],
1492 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
1493 let inner = Arc::clone(&self.inner);
1494 let sql = sql.to_string();
1495 let params = params.to_vec();
1496 async move {
1497 let Ok(mut guard) = inner.lock(cx).await else {
1498 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1499 };
1500 guard.insert_async(cx, &sql, ¶ms).await
1501 }
1502 }
1503
1504 fn batch(
1505 &self,
1506 cx: &Cx,
1507 statements: &[(String, Vec<Value>)],
1508 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
1509 let inner = Arc::clone(&self.inner);
1510 let statements = statements.to_vec();
1511 async move {
1512 let Ok(mut guard) = inner.lock(cx).await else {
1513 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1514 };
1515 let mut results = Vec::with_capacity(statements.len());
1516 for (sql, params) in &statements {
1517 match guard.execute_async(cx, sql, params).await {
1518 Outcome::Ok(n) => results.push(n),
1519 Outcome::Err(e) => return Outcome::Err(e),
1520 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1521 Outcome::Panicked(p) => return Outcome::Panicked(p),
1522 }
1523 }
1524 Outcome::Ok(results)
1525 }
1526 }
1527
1528 fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1529 self.begin_with(cx, IsolationLevel::default())
1530 }
1531
1532 fn begin_with(
1533 &self,
1534 cx: &Cx,
1535 isolation: IsolationLevel,
1536 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1537 self.begin_transaction_impl(cx, Some(isolation))
1538 }
1539
1540 fn prepare(
1541 &self,
1542 cx: &Cx,
1543 sql: &str,
1544 ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
1545 let inner = Arc::clone(&self.inner);
1546 let sql = sql.to_string();
1547 async move {
1548 let Ok(mut guard) = inner.lock(cx).await else {
1549 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1550 };
1551 guard.prepare_async(cx, &sql).await
1552 }
1553 }
1554
1555 fn query_prepared(
1556 &self,
1557 cx: &Cx,
1558 stmt: &PreparedStatement,
1559 params: &[Value],
1560 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1561 let inner = Arc::clone(&self.inner);
1562 let stmt = stmt.clone();
1563 let params = params.to_vec();
1564 async move {
1565 let Ok(mut guard) = inner.lock(cx).await else {
1566 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1567 };
1568 guard.query_prepared_async(cx, &stmt, ¶ms).await
1569 }
1570 }
1571
1572 fn execute_prepared(
1573 &self,
1574 cx: &Cx,
1575 stmt: &PreparedStatement,
1576 params: &[Value],
1577 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1578 let inner = Arc::clone(&self.inner);
1579 let stmt = stmt.clone();
1580 let params = params.to_vec();
1581 async move {
1582 let Ok(mut guard) = inner.lock(cx).await else {
1583 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1584 };
1585 guard.execute_prepared_async(cx, &stmt, ¶ms).await
1586 }
1587 }
1588
1589 fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1590 let inner = Arc::clone(&self.inner);
1591 async move {
1592 let Ok(mut guard) = inner.lock(cx).await else {
1593 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1594 };
1595 guard.ping_async(cx).await
1596 }
1597 }
1598
1599 async fn close(self, cx: &Cx) -> sqlmodel_core::Result<()> {
1600 let Ok(mut guard) = self.inner.lock(cx).await else {
1601 return Err(connection_error("Failed to acquire connection lock"));
1602 };
1603 match guard.close_async(cx).await {
1604 Outcome::Ok(()) => Ok(()),
1605 Outcome::Err(e) => Err(e),
1606 Outcome::Cancelled(r) => Err(Error::Query(QueryError {
1607 kind: QueryErrorKind::Cancelled,
1608 message: format!("Cancelled: {r:?}"),
1609 sqlstate: None,
1610 sql: None,
1611 detail: None,
1612 hint: None,
1613 position: None,
1614 source: None,
1615 })),
1616 Outcome::Panicked(p) => Err(Error::Protocol(ProtocolError {
1617 message: format!("Panicked: {p:?}"),
1618 raw_data: None,
1619 source: None,
1620 })),
1621 }
1622 }
1623}
1624
1625impl<'conn> TransactionOps for SharedPgTransaction<'conn> {
1626 fn query(
1627 &self,
1628 cx: &Cx,
1629 sql: &str,
1630 params: &[Value],
1631 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1632 let inner = Arc::clone(&self.inner);
1633 let sql = sql.to_string();
1634 let params = params.to_vec();
1635 async move {
1636 let Ok(mut guard) = inner.lock(cx).await else {
1637 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1638 };
1639 guard.query_async(cx, &sql, ¶ms).await
1640 }
1641 }
1642
1643 fn query_one(
1644 &self,
1645 cx: &Cx,
1646 sql: &str,
1647 params: &[Value],
1648 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1649 let inner = Arc::clone(&self.inner);
1650 let sql = sql.to_string();
1651 let params = params.to_vec();
1652 async move {
1653 let Ok(mut guard) = inner.lock(cx).await else {
1654 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1655 };
1656 let rows = match guard.query_async(cx, &sql, ¶ms).await {
1657 Outcome::Ok(r) => r,
1658 Outcome::Err(e) => return Outcome::Err(e),
1659 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1660 Outcome::Panicked(p) => return Outcome::Panicked(p),
1661 };
1662 Outcome::Ok(rows.into_iter().next())
1663 }
1664 }
1665
1666 fn execute(
1667 &self,
1668 cx: &Cx,
1669 sql: &str,
1670 params: &[Value],
1671 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1672 let inner = Arc::clone(&self.inner);
1673 let sql = sql.to_string();
1674 let params = params.to_vec();
1675 async move {
1676 let Ok(mut guard) = inner.lock(cx).await else {
1677 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1678 };
1679 guard.execute_async(cx, &sql, ¶ms).await
1680 }
1681 }
1682
1683 fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1684 let inner = Arc::clone(&self.inner);
1685 let name = name.to_string();
1686 async move {
1687 if let Err(e) = validate_savepoint_name(&name) {
1688 return Outcome::Err(e);
1689 }
1690 let sql = format!("SAVEPOINT {}", name);
1691 let Ok(mut guard) = inner.lock(cx).await else {
1692 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1693 };
1694 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1695 }
1696 }
1697
1698 fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1699 let inner = Arc::clone(&self.inner);
1700 let name = name.to_string();
1701 async move {
1702 if let Err(e) = validate_savepoint_name(&name) {
1703 return Outcome::Err(e);
1704 }
1705 let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
1706 let Ok(mut guard) = inner.lock(cx).await else {
1707 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1708 };
1709 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1710 }
1711 }
1712
1713 fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1714 let inner = Arc::clone(&self.inner);
1715 let name = name.to_string();
1716 async move {
1717 if let Err(e) = validate_savepoint_name(&name) {
1718 return Outcome::Err(e);
1719 }
1720 let sql = format!("RELEASE SAVEPOINT {}", name);
1721 let Ok(mut guard) = inner.lock(cx).await else {
1722 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1723 };
1724 guard.execute_async(cx, &sql, &[]).await.map(|_| ())
1725 }
1726 }
1727
1728 #[allow(unused_assignments)]
1730 fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1731 let inner = Arc::clone(&self.inner);
1732 async move {
1733 let Ok(mut guard) = inner.lock(cx).await else {
1734 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1735 };
1736 let result = guard.execute_async(cx, "COMMIT", &[]).await;
1737 if matches!(result, Outcome::Ok(_)) {
1738 self.committed = true;
1739 }
1740 result.map(|_| ())
1741 }
1742 }
1743
1744 #[allow(unused_assignments)]
1745 fn rollback(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1746 let inner = Arc::clone(&self.inner);
1747 async move {
1748 let Ok(mut guard) = inner.lock(cx).await else {
1749 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1750 };
1751 let result = guard.execute_async(cx, "ROLLBACK", &[]).await;
1752 if matches!(result, Outcome::Ok(_)) {
1753 self.committed = true;
1754 }
1755 result.map(|_| ())
1756 }
1757 }
1758}
1759
/// Rows and command tag collected for a single query.
struct PgQueryResult {
    /// Rows returned by the query.
    rows: Vec<Row>,
    /// Command tag text, if one was received. Presumably the `CommandComplete`
    /// tag (e.g. `INSERT 0 1`), the format `parse_rows_affected` parses —
    /// TODO confirm at the construction site.
    command_tag: Option<String>,
}
1766
1767fn connection_error(msg: impl Into<String>) -> Error {
1768 Error::Connection(ConnectionError {
1769 kind: ConnectionErrorKind::Connect,
1770 message: msg.into(),
1771 source: None,
1772 })
1773}
1774
1775fn auth_error(msg: impl Into<String>) -> Error {
1776 Error::Connection(ConnectionError {
1777 kind: ConnectionErrorKind::Authentication,
1778 message: msg.into(),
1779 source: None,
1780 })
1781}
1782
1783fn protocol_error(msg: impl Into<String>) -> Error {
1784 Error::Protocol(ProtocolError {
1785 message: msg.into(),
1786 raw_data: None,
1787 source: None,
1788 })
1789}
1790
1791fn query_error_msg(msg: impl Into<String>, kind: QueryErrorKind) -> Error {
1792 Error::Query(QueryError {
1793 kind,
1794 message: msg.into(),
1795 sqlstate: None,
1796 sql: None,
1797 detail: None,
1798 hint: None,
1799 position: None,
1800 source: None,
1801 })
1802}
1803
1804fn error_from_fields(fields: &ErrorFields) -> Error {
1805 let kind = match fields.code.get(..2) {
1806 Some("08") => {
1807 return Error::Connection(ConnectionError {
1808 kind: ConnectionErrorKind::Connect,
1809 message: fields.message.clone(),
1810 source: None,
1811 });
1812 }
1813 Some("28") => {
1814 return Error::Connection(ConnectionError {
1815 kind: ConnectionErrorKind::Authentication,
1816 message: fields.message.clone(),
1817 source: None,
1818 });
1819 }
1820 Some("42") => QueryErrorKind::Syntax,
1821 Some("23") => QueryErrorKind::Constraint,
1822 Some("40") => {
1823 if fields.code == "40001" {
1824 QueryErrorKind::Serialization
1825 } else {
1826 QueryErrorKind::Deadlock
1827 }
1828 }
1829 Some("57") => {
1830 if fields.code == "57014" {
1831 QueryErrorKind::Cancelled
1832 } else {
1833 QueryErrorKind::Timeout
1834 }
1835 }
1836 _ => QueryErrorKind::Database,
1837 };
1838
1839 Error::Query(QueryError {
1840 kind,
1841 sql: None,
1842 sqlstate: Some(fields.code.clone()),
1843 message: fields.message.clone(),
1844 detail: fields.detail.clone(),
1845 hint: fields.hint.clone(),
1846 position: fields.position.map(|p| p as usize),
1847 source: None,
1848 })
1849}
1850
/// Extracts the affected-row count from a PostgreSQL command tag.
///
/// The count is the last whitespace-separated token of tags such as
/// `INSERT 0 5`, `UPDATE 3` or `DELETE 0`. Returns `None` when there is no
/// tag, the tag is blank, or its last token is not an unsigned integer
/// (e.g. `BEGIN`).
fn parse_rows_affected(tag: Option<&str>) -> Option<u64> {
    // `next_back` reads the final token directly; no Vec of tokens is built.
    tag?.split_whitespace().next_back()?.parse::<u64>().ok()
}
1856
1857fn validate_savepoint_name(name: &str) -> sqlmodel_core::Result<()> {
1859 if name.is_empty() {
1860 return Err(query_error_msg(
1861 "Savepoint name cannot be empty",
1862 QueryErrorKind::Syntax,
1863 ));
1864 }
1865 if name.len() > 63 {
1866 return Err(query_error_msg(
1867 "Savepoint name exceeds maximum length of 63 characters",
1868 QueryErrorKind::Syntax,
1869 ));
1870 }
1871 let mut chars = name.chars();
1872 let Some(first) = chars.next() else {
1873 return Err(query_error_msg(
1874 "Savepoint name cannot be empty",
1875 QueryErrorKind::Syntax,
1876 ));
1877 };
1878 if !first.is_ascii_alphabetic() && first != '_' {
1879 return Err(query_error_msg(
1880 "Savepoint name must start with a letter or underscore",
1881 QueryErrorKind::Syntax,
1882 ));
1883 }
1884 for c in chars {
1885 if !c.is_ascii_alphanumeric() && c != '_' {
1886 return Err(query_error_msg(
1887 format!("Savepoint name contains invalid character: '{c}'"),
1888 QueryErrorKind::Syntax,
1889 ));
1890 }
1891 }
1892 Ok(())
1893}
1894
1895fn md5_password(user: &str, password: &str, salt: [u8; 4]) -> String {
1896 use std::fmt::Write;
1897
1898 let inner = format!("{password}{user}");
1899 let inner_hash = md5::compute(inner.as_bytes());
1900
1901 let mut outer_input = format!("{inner_hash:x}").into_bytes();
1902 outer_input.extend_from_slice(&salt);
1903 let outer_hash = md5::compute(&outer_input);
1904
1905 let mut result = String::with_capacity(35);
1906 result.push_str("md5");
1907 write!(&mut result, "{outer_hash:x}").unwrap();
1908 result
1909}
1910
1911