1#![allow(clippy::manual_async_fn)]
8#![allow(clippy::result_large_err)]
10
11use std::collections::HashMap;
12use std::future::Future;
13use std::io::{self, Read as StdRead, Write as StdWrite};
14use std::net::TcpStream as StdTcpStream;
15use std::sync::Arc;
16
17use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
18use asupersync::net::TcpStream;
19use asupersync::sync::Mutex;
20use asupersync::{Cx, Outcome};
21
22use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
23use sqlmodel_core::error::{
24 ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
25};
26use sqlmodel_core::{Error, Row, Value};
27
28#[cfg(feature = "console")]
29use sqlmodel_console::{ConsoleAware, SqlModelConsole};
30
31use crate::auth;
32use crate::config::MySqlConfig;
33use crate::connection::{ConnectionState, ServerCapabilities};
34use crate::protocol::{
35 Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
36 capabilities, charset, prepared,
37};
38use crate::types::{
39 ColumnDef, FieldType, decode_binary_value_with_len, decode_text_value, interpolate_params,
40};
41
42pub struct MySqlAsyncConnection {
47 stream: ConnectionStream,
49 state: ConnectionState,
51 server_caps: Option<ServerCapabilities>,
53 connection_id: u32,
55 status_flags: u16,
57 affected_rows: u64,
59 last_insert_id: u64,
61 warnings: u16,
63 config: MySqlConfig,
65 sequence_id: u8,
67 prepared_stmts: HashMap<u32, PreparedStmtMeta>,
69 #[cfg(feature = "console")]
71 console: Option<Arc<SqlModelConsole>>,
72}
73
74#[derive(Debug, Clone)]
79struct PreparedStmtMeta {
80 #[allow(dead_code)]
82 statement_id: u32,
83 params: Vec<ColumnDef>,
85 columns: Vec<ColumnDef>,
87}
88
89#[allow(dead_code)]
91enum ConnectionStream {
92 Sync(StdTcpStream),
94 Async(TcpStream),
96}
97
98impl std::fmt::Debug for MySqlAsyncConnection {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("MySqlAsyncConnection")
101 .field("state", &self.state)
102 .field("connection_id", &self.connection_id)
103 .field("host", &self.config.host)
104 .field("port", &self.config.port)
105 .field("database", &self.config.database)
106 .finish_non_exhaustive()
107 }
108}
109
110impl MySqlAsyncConnection {
111 pub async fn connect(_cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
119 let addr = config.socket_addr();
121 let socket_addr = match addr.parse() {
122 Ok(a) => a,
123 Err(e) => {
124 return Outcome::Err(Error::Connection(ConnectionError {
125 kind: ConnectionErrorKind::Connect,
126 message: format!("Invalid socket address: {}", e),
127 source: None,
128 }));
129 }
130 };
131 let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
132 Ok(s) => s,
133 Err(e) => {
134 let kind = if e.kind() == io::ErrorKind::ConnectionRefused {
135 ConnectionErrorKind::Refused
136 } else {
137 ConnectionErrorKind::Connect
138 };
139 return Outcome::Err(Error::Connection(ConnectionError {
140 kind,
141 message: format!("Failed to connect to {}: {}", addr, e),
142 source: Some(Box::new(e)),
143 }));
144 }
145 };
146
147 stream.set_nodelay(true).ok();
149
150 let mut conn = Self {
151 stream: ConnectionStream::Async(stream),
152 state: ConnectionState::Connecting,
153 server_caps: None,
154 connection_id: 0,
155 status_flags: 0,
156 affected_rows: 0,
157 last_insert_id: 0,
158 warnings: 0,
159 config,
160 sequence_id: 0,
161 prepared_stmts: HashMap::new(),
162 #[cfg(feature = "console")]
163 console: None,
164 };
165
166 match conn.read_handshake_async().await {
168 Outcome::Ok(server_caps) => {
169 conn.connection_id = server_caps.connection_id;
170 conn.server_caps = Some(server_caps);
171 conn.state = ConnectionState::Authenticating;
172 }
173 Outcome::Err(e) => return Outcome::Err(e),
174 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
175 Outcome::Panicked(p) => return Outcome::Panicked(p),
176 }
177
178 if let Outcome::Err(e) = conn.send_handshake_response_async().await {
180 return Outcome::Err(e);
181 }
182
183 if let Outcome::Err(e) = conn.handle_auth_result_async().await {
185 return Outcome::Err(e);
186 }
187
188 conn.state = ConnectionState::Ready;
189 Outcome::Ok(conn)
190 }
191
192 pub fn state(&self) -> ConnectionState {
194 self.state
195 }
196
197 pub fn is_ready(&self) -> bool {
199 matches!(self.state, ConnectionState::Ready)
200 }
201
202 pub fn connection_id(&self) -> u32 {
204 self.connection_id
205 }
206
207 pub fn server_version(&self) -> Option<&str> {
209 self.server_caps
210 .as_ref()
211 .map(|caps| caps.server_version.as_str())
212 }
213
214 pub fn affected_rows(&self) -> u64 {
216 self.affected_rows
217 }
218
219 pub fn last_insert_id(&self) -> u64 {
221 self.last_insert_id
222 }
223
224 async fn read_packet_async(&mut self) -> Outcome<(Vec<u8>, u8), Error> {
228 let mut header_buf = [0u8; 4];
230
231 match &mut self.stream {
232 ConnectionStream::Async(stream) => {
233 let mut header_read = 0;
234 while header_read < 4 {
235 let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
236 match std::future::poll_fn(|cx| {
237 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
238 })
239 .await
240 {
241 Ok(()) => {
242 let n = read_buf.filled().len();
243 if n == 0 {
244 return Outcome::Err(Error::Connection(ConnectionError {
245 kind: ConnectionErrorKind::Disconnected,
246 message: "Connection closed while reading header".to_string(),
247 source: None,
248 }));
249 }
250 header_read += n;
251 }
252 Err(e) => {
253 return Outcome::Err(Error::Connection(ConnectionError {
254 kind: ConnectionErrorKind::Disconnected,
255 message: format!("Failed to read packet header: {}", e),
256 source: Some(Box::new(e)),
257 }));
258 }
259 }
260 }
261 }
262 ConnectionStream::Sync(stream) => {
263 if let Err(e) = stream.read_exact(&mut header_buf) {
264 return Outcome::Err(Error::Connection(ConnectionError {
265 kind: ConnectionErrorKind::Disconnected,
266 message: format!("Failed to read packet header: {}", e),
267 source: Some(Box::new(e)),
268 }));
269 }
270 }
271 }
272
273 let header = PacketHeader::from_bytes(&header_buf);
274 let payload_len = header.payload_length as usize;
275 self.sequence_id = header.sequence_id.wrapping_add(1);
276
277 let mut payload = vec![0u8; payload_len];
279 if payload_len > 0 {
280 match &mut self.stream {
281 ConnectionStream::Async(stream) => {
282 let mut total_read = 0;
283 while total_read < payload_len {
284 let mut read_buf = ReadBuf::new(&mut payload[total_read..]);
285 match std::future::poll_fn(|cx| {
286 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
287 })
288 .await
289 {
290 Ok(()) => {
291 let n = read_buf.filled().len();
292 if n == 0 {
293 return Outcome::Err(Error::Connection(ConnectionError {
294 kind: ConnectionErrorKind::Disconnected,
295 message: "Connection closed while reading payload"
296 .to_string(),
297 source: None,
298 }));
299 }
300 total_read += n;
301 }
302 Err(e) => {
303 return Outcome::Err(Error::Connection(ConnectionError {
304 kind: ConnectionErrorKind::Disconnected,
305 message: format!("Failed to read packet payload: {}", e),
306 source: Some(Box::new(e)),
307 }));
308 }
309 }
310 }
311 }
312 ConnectionStream::Sync(stream) => {
313 if let Err(e) = stream.read_exact(&mut payload) {
314 return Outcome::Err(Error::Connection(ConnectionError {
315 kind: ConnectionErrorKind::Disconnected,
316 message: format!("Failed to read packet payload: {}", e),
317 source: Some(Box::new(e)),
318 }));
319 }
320 }
321 }
322 }
323
324 if payload_len == MAX_PACKET_SIZE {
326 loop {
327 let mut header_buf = [0u8; 4];
329 match &mut self.stream {
330 ConnectionStream::Async(stream) => {
331 let mut header_read = 0;
332 while header_read < 4 {
333 let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
334 match std::future::poll_fn(|cx| {
335 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
336 })
337 .await
338 {
339 Ok(()) => {
340 let n = read_buf.filled().len();
341 if n == 0 {
342 return Outcome::Err(Error::Connection(ConnectionError {
343 kind: ConnectionErrorKind::Disconnected,
344 message: "Connection closed while reading continuation header".to_string(),
345 source: None,
346 }));
347 }
348 header_read += n;
349 }
350 Err(e) => {
351 return Outcome::Err(Error::Connection(ConnectionError {
352 kind: ConnectionErrorKind::Disconnected,
353 message: format!(
354 "Failed to read continuation header: {}",
355 e
356 ),
357 source: Some(Box::new(e)),
358 }));
359 }
360 }
361 }
362 }
363 ConnectionStream::Sync(stream) => {
364 if let Err(e) = stream.read_exact(&mut header_buf) {
365 return Outcome::Err(Error::Connection(ConnectionError {
366 kind: ConnectionErrorKind::Disconnected,
367 message: format!("Failed to read continuation header: {}", e),
368 source: Some(Box::new(e)),
369 }));
370 }
371 }
372 }
373
374 let cont_header = PacketHeader::from_bytes(&header_buf);
375 let cont_len = cont_header.payload_length as usize;
376 self.sequence_id = cont_header.sequence_id.wrapping_add(1);
377
378 if cont_len > 0 {
379 let mut cont_payload = vec![0u8; cont_len];
380 match &mut self.stream {
381 ConnectionStream::Async(stream) => {
382 let mut total_read = 0;
383 while total_read < cont_len {
384 let mut read_buf = ReadBuf::new(&mut cont_payload[total_read..]);
385 match std::future::poll_fn(|cx| {
386 std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
387 })
388 .await
389 {
390 Ok(()) => {
391 let n = read_buf.filled().len();
392 if n == 0 {
393 return Outcome::Err(Error::Connection(ConnectionError {
394 kind: ConnectionErrorKind::Disconnected,
395 message: "Connection closed while reading continuation payload".to_string(),
396 source: None,
397 }));
398 }
399 total_read += n;
400 }
401 Err(e) => {
402 return Outcome::Err(Error::Connection(ConnectionError {
403 kind: ConnectionErrorKind::Disconnected,
404 message: format!(
405 "Failed to read continuation payload: {}",
406 e
407 ),
408 source: Some(Box::new(e)),
409 }));
410 }
411 }
412 }
413 }
414 ConnectionStream::Sync(stream) => {
415 if let Err(e) = stream.read_exact(&mut cont_payload) {
416 return Outcome::Err(Error::Connection(ConnectionError {
417 kind: ConnectionErrorKind::Disconnected,
418 message: format!("Failed to read continuation payload: {}", e),
419 source: Some(Box::new(e)),
420 }));
421 }
422 }
423 }
424 payload.extend_from_slice(&cont_payload);
425 }
426
427 if cont_len < MAX_PACKET_SIZE {
428 break;
429 }
430 }
431 }
432
433 Outcome::Ok((payload, header.sequence_id))
434 }
435
436 async fn write_packet_async(&mut self, payload: &[u8]) -> Outcome<(), Error> {
438 let writer = PacketWriter::new();
439 let packet = writer.build_packet_from_payload(payload, self.sequence_id);
440 self.sequence_id = self.sequence_id.wrapping_add(1);
441
442 match &mut self.stream {
443 ConnectionStream::Async(stream) => {
444 let mut written = 0;
446 while written < packet.len() {
447 match std::future::poll_fn(|cx| {
448 std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
449 })
450 .await
451 {
452 Ok(n) => {
453 if n == 0 {
454 return Outcome::Err(Error::Connection(ConnectionError {
455 kind: ConnectionErrorKind::Disconnected,
456 message: "Connection closed while writing packet".to_string(),
457 source: None,
458 }));
459 }
460 written += n;
461 }
462 Err(e) => {
463 return Outcome::Err(Error::Connection(ConnectionError {
464 kind: ConnectionErrorKind::Disconnected,
465 message: format!("Failed to write packet: {}", e),
466 source: Some(Box::new(e)),
467 }));
468 }
469 }
470 }
471
472 match std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx))
473 .await
474 {
475 Ok(()) => {}
476 Err(e) => {
477 return Outcome::Err(Error::Connection(ConnectionError {
478 kind: ConnectionErrorKind::Disconnected,
479 message: format!("Failed to flush stream: {}", e),
480 source: Some(Box::new(e)),
481 }));
482 }
483 }
484 }
485 ConnectionStream::Sync(stream) => {
486 if let Err(e) = stream.write_all(&packet) {
487 return Outcome::Err(Error::Connection(ConnectionError {
488 kind: ConnectionErrorKind::Disconnected,
489 message: format!("Failed to write packet: {}", e),
490 source: Some(Box::new(e)),
491 }));
492 }
493 if let Err(e) = stream.flush() {
494 return Outcome::Err(Error::Connection(ConnectionError {
495 kind: ConnectionErrorKind::Disconnected,
496 message: format!("Failed to flush stream: {}", e),
497 source: Some(Box::new(e)),
498 }));
499 }
500 }
501 }
502
503 Outcome::Ok(())
504 }
505
506 async fn read_handshake_async(&mut self) -> Outcome<ServerCapabilities, Error> {
510 let (payload, _) = match self.read_packet_async().await {
511 Outcome::Ok(p) => p,
512 Outcome::Err(e) => return Outcome::Err(e),
513 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
514 Outcome::Panicked(p) => return Outcome::Panicked(p),
515 };
516
517 let mut reader = PacketReader::new(&payload);
518
519 let Some(protocol_version) = reader.read_u8() else {
521 return Outcome::Err(protocol_error("Missing protocol version"));
522 };
523
524 if protocol_version != 10 {
525 return Outcome::Err(protocol_error(format!(
526 "Unsupported protocol version: {}",
527 protocol_version
528 )));
529 }
530
531 let Some(server_version) = reader.read_null_string() else {
533 return Outcome::Err(protocol_error("Missing server version"));
534 };
535
536 let Some(connection_id) = reader.read_u32_le() else {
538 return Outcome::Err(protocol_error("Missing connection ID"));
539 };
540
541 let Some(auth_data_1) = reader.read_bytes(8) else {
543 return Outcome::Err(protocol_error("Missing auth data"));
544 };
545
546 reader.skip(1);
548
549 let Some(caps_lower) = reader.read_u16_le() else {
551 return Outcome::Err(protocol_error("Missing capability flags"));
552 };
553
554 let charset_val = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
556
557 let status_flags = reader.read_u16_le().unwrap_or(0);
559
560 let caps_upper = reader.read_u16_le().unwrap_or(0);
562 let capabilities_val = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
563
564 let auth_data_len = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
566 reader.read_u8().unwrap_or(0) as usize
567 } else {
568 0
569 };
570
571 reader.skip(10);
573
574 let mut auth_data = auth_data_1.to_vec();
576 if capabilities_val & capabilities::CLIENT_SECURE_CONNECTION != 0 {
577 let len2 = if auth_data_len > 8 {
578 auth_data_len - 8
579 } else {
580 13 };
582 if let Some(data2) = reader.read_bytes(len2) {
583 let data2_clean = if data2.last() == Some(&0) {
585 &data2[..data2.len() - 1]
586 } else {
587 data2
588 };
589 auth_data.extend_from_slice(data2_clean);
590 }
591 }
592
593 let auth_plugin = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
595 reader.read_null_string().unwrap_or_default()
596 } else {
597 auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
598 };
599
600 Outcome::Ok(ServerCapabilities {
601 capabilities: capabilities_val,
602 protocol_version,
603 server_version,
604 connection_id,
605 auth_plugin,
606 auth_data,
607 charset: charset_val,
608 status_flags,
609 })
610 }
611
612 async fn send_handshake_response_async(&mut self) -> Outcome<(), Error> {
614 let Some(server_caps) = self.server_caps.as_ref() else {
615 return Outcome::Err(protocol_error("No server handshake received"));
616 };
617
618 let client_caps = self.config.capability_flags() & server_caps.capabilities;
620
621 let auth_response =
623 self.compute_auth_response(&server_caps.auth_plugin, &server_caps.auth_data);
624
625 let mut writer = PacketWriter::new();
626
627 writer.write_u32_le(client_caps);
629
630 writer.write_u32_le(self.config.max_packet_size);
632
633 writer.write_u8(self.config.charset);
635
636 writer.write_zeros(23);
638
639 writer.write_null_string(&self.config.user);
641
642 if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
644 writer.write_lenenc_bytes(&auth_response);
645 } else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
646 #[allow(clippy::cast_possible_truncation)]
647 writer.write_u8(auth_response.len() as u8);
648 writer.write_bytes(&auth_response);
649 } else {
650 writer.write_bytes(&auth_response);
651 writer.write_u8(0); }
653
654 if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
656 if let Some(ref db) = self.config.database {
657 writer.write_null_string(db);
658 } else {
659 writer.write_u8(0); }
661 }
662
663 if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
665 writer.write_null_string(&server_caps.auth_plugin);
666 }
667
668 if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
670 && !self.config.attributes.is_empty()
671 {
672 let mut attrs_writer = PacketWriter::new();
673 for (key, value) in &self.config.attributes {
674 attrs_writer.write_lenenc_string(key);
675 attrs_writer.write_lenenc_string(value);
676 }
677 let attrs_data = attrs_writer.into_bytes();
678 writer.write_lenenc_bytes(&attrs_data);
679 }
680
681 self.write_packet_async(writer.as_bytes()).await
682 }
683
684 fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
686 let password = self.config.password.as_deref().unwrap_or("");
687
688 match plugin {
689 auth::plugins::MYSQL_NATIVE_PASSWORD => {
690 auth::mysql_native_password(password, auth_data)
691 }
692 auth::plugins::CACHING_SHA2_PASSWORD => {
693 auth::caching_sha2_password(password, auth_data)
694 }
695 auth::plugins::MYSQL_CLEAR_PASSWORD => {
696 let mut result = password.as_bytes().to_vec();
697 result.push(0);
698 result
699 }
700 _ => auth::mysql_native_password(password, auth_data),
701 }
702 }
703
704 async fn handle_auth_result_async(&mut self) -> Outcome<(), Error> {
707 loop {
709 let (payload, _) = match self.read_packet_async().await {
710 Outcome::Ok(p) => p,
711 Outcome::Err(e) => return Outcome::Err(e),
712 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
713 Outcome::Panicked(p) => return Outcome::Panicked(p),
714 };
715
716 if payload.is_empty() {
717 return Outcome::Err(protocol_error("Empty authentication response"));
718 }
719
720 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
722 PacketType::Ok => {
723 let mut reader = PacketReader::new(&payload);
724 if let Some(ok) = reader.parse_ok_packet() {
725 self.status_flags = ok.status_flags;
726 self.affected_rows = ok.affected_rows;
727 }
728 return Outcome::Ok(());
729 }
730 PacketType::Error => {
731 let mut reader = PacketReader::new(&payload);
732 let Some(err) = reader.parse_err_packet() else {
733 return Outcome::Err(protocol_error("Invalid error packet"));
734 };
735 return Outcome::Err(auth_error(format!(
736 "Authentication failed: {} ({})",
737 err.error_message, err.error_code
738 )));
739 }
740 PacketType::Eof => {
741 let data = &payload[1..];
743 let mut reader = PacketReader::new(data);
744
745 let Some(plugin) = reader.read_null_string() else {
746 return Outcome::Err(protocol_error("Missing plugin name in auth switch"));
747 };
748
749 let auth_data = reader.read_rest();
750 let response = self.compute_auth_response(&plugin, auth_data);
751
752 if let Outcome::Err(e) = self.write_packet_async(&response).await {
753 return Outcome::Err(e);
754 }
755 }
757 _ => {
758 return self.handle_additional_auth_async(&payload).await;
760 }
761 }
762 }
763 }
764
765 async fn handle_additional_auth_async(&mut self, data: &[u8]) -> Outcome<(), Error> {
767 if data.is_empty() {
768 return Outcome::Err(protocol_error("Empty additional auth data"));
769 }
770
771 match data[0] {
772 auth::caching_sha2::FAST_AUTH_SUCCESS => {
773 let (payload, _) = match self.read_packet_async().await {
774 Outcome::Ok(p) => p,
775 Outcome::Err(e) => return Outcome::Err(e),
776 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
777 Outcome::Panicked(p) => return Outcome::Panicked(p),
778 };
779 let mut reader = PacketReader::new(&payload);
780 if let Some(ok) = reader.parse_ok_packet() {
781 self.status_flags = ok.status_flags;
782 }
783 Outcome::Ok(())
784 }
785 auth::caching_sha2::PERFORM_FULL_AUTH => Outcome::Err(auth_error(
786 "Full authentication required - please use TLS connection",
787 )),
788 _ => {
789 let mut reader = PacketReader::new(data);
790 if let Some(ok) = reader.parse_ok_packet() {
791 self.status_flags = ok.status_flags;
792 Outcome::Ok(())
793 } else {
794 Outcome::Err(protocol_error(format!(
795 "Unknown auth response: {:02X}",
796 data[0]
797 )))
798 }
799 }
800 }
801 }
802
803 pub async fn query_async(
805 &mut self,
806 _cx: &Cx,
807 sql: &str,
808 params: &[Value],
809 ) -> Outcome<Vec<Row>, Error> {
810 let sql = interpolate_params(sql, params);
811 if !self.is_ready() && self.state != ConnectionState::InTransaction {
812 return Outcome::Err(connection_error("Connection not ready for queries"));
813 }
814
815 self.state = ConnectionState::InQuery;
816 self.sequence_id = 0;
817
818 let mut writer = PacketWriter::new();
820 writer.write_u8(Command::Query as u8);
821 writer.write_bytes(sql.as_bytes());
822
823 if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
824 return Outcome::Err(e);
825 }
826
827 let (payload, _) = match self.read_packet_async().await {
829 Outcome::Ok(p) => p,
830 Outcome::Err(e) => return Outcome::Err(e),
831 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
832 Outcome::Panicked(p) => return Outcome::Panicked(p),
833 };
834
835 if payload.is_empty() {
836 self.state = ConnectionState::Ready;
837 return Outcome::Err(protocol_error("Empty query response"));
838 }
839
840 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
842 PacketType::Ok => {
843 let mut reader = PacketReader::new(&payload);
844 if let Some(ok) = reader.parse_ok_packet() {
845 self.affected_rows = ok.affected_rows;
846 self.last_insert_id = ok.last_insert_id;
847 self.status_flags = ok.status_flags;
848 self.warnings = ok.warnings;
849 }
850 self.state = if self.status_flags
851 & crate::protocol::server_status::SERVER_STATUS_IN_TRANS
852 != 0
853 {
854 ConnectionState::InTransaction
855 } else {
856 ConnectionState::Ready
857 };
858 Outcome::Ok(vec![])
859 }
860 PacketType::Error => {
861 self.state = ConnectionState::Ready;
862 let mut reader = PacketReader::new(&payload);
863 let Some(err) = reader.parse_err_packet() else {
864 return Outcome::Err(protocol_error("Invalid error packet"));
865 };
866 Outcome::Err(query_error(&err))
867 }
868 PacketType::LocalInfile => {
869 self.state = ConnectionState::Ready;
870 Outcome::Err(query_error_msg("LOCAL INFILE not supported"))
871 }
872 _ => self.read_result_set_async(&payload).await,
873 }
874 }
875
876 async fn read_result_set_async(&mut self, first_packet: &[u8]) -> Outcome<Vec<Row>, Error> {
878 let mut reader = PacketReader::new(first_packet);
879 #[allow(clippy::cast_possible_truncation)] let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
881 return Outcome::Err(protocol_error("Invalid column count"));
882 };
883
884 let mut columns = Vec::with_capacity(column_count);
886 for _ in 0..column_count {
887 let (payload, _) = match self.read_packet_async().await {
888 Outcome::Ok(p) => p,
889 Outcome::Err(e) => return Outcome::Err(e),
890 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
891 Outcome::Panicked(p) => return Outcome::Panicked(p),
892 };
893 match self.parse_column_def(&payload) {
894 Ok(col) => columns.push(col),
895 Err(e) => return Outcome::Err(e),
896 }
897 }
898
899 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
901 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
902 let (payload, _) = match self.read_packet_async().await {
903 Outcome::Ok(p) => p,
904 Outcome::Err(e) => return Outcome::Err(e),
905 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
906 Outcome::Panicked(p) => return Outcome::Panicked(p),
907 };
908 if payload.first() == Some(&0xFE) {
909 }
911 }
912
913 let mut rows = Vec::new();
915 loop {
916 let (payload, _) = match self.read_packet_async().await {
917 Outcome::Ok(p) => p,
918 Outcome::Err(e) => return Outcome::Err(e),
919 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
920 Outcome::Panicked(p) => return Outcome::Panicked(p),
921 };
922
923 if payload.is_empty() {
924 break;
925 }
926
927 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
929 PacketType::Eof | PacketType::Ok => {
930 let mut reader = PacketReader::new(&payload);
931 if payload[0] == 0x00 {
932 if let Some(ok) = reader.parse_ok_packet() {
933 self.status_flags = ok.status_flags;
934 self.warnings = ok.warnings;
935 }
936 } else if payload[0] == 0xFE {
937 if let Some(eof) = reader.parse_eof_packet() {
938 self.status_flags = eof.status_flags;
939 self.warnings = eof.warnings;
940 }
941 }
942 break;
943 }
944 PacketType::Error => {
945 let mut reader = PacketReader::new(&payload);
946 let Some(err) = reader.parse_err_packet() else {
947 return Outcome::Err(protocol_error("Invalid error packet"));
948 };
949 self.state = ConnectionState::Ready;
950 return Outcome::Err(query_error(&err));
951 }
952 _ => {
953 let row = self.parse_text_row(&payload, &columns);
954 rows.push(row);
955 }
956 }
957 }
958
959 self.state =
960 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
961 ConnectionState::InTransaction
962 } else {
963 ConnectionState::Ready
964 };
965
966 Outcome::Ok(rows)
967 }
968
969 fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
971 let mut reader = PacketReader::new(data);
972
973 let catalog = reader
974 .read_lenenc_string()
975 .ok_or_else(|| protocol_error("Missing catalog"))?;
976 let schema = reader
977 .read_lenenc_string()
978 .ok_or_else(|| protocol_error("Missing schema"))?;
979 let table = reader
980 .read_lenenc_string()
981 .ok_or_else(|| protocol_error("Missing table"))?;
982 let org_table = reader
983 .read_lenenc_string()
984 .ok_or_else(|| protocol_error("Missing org_table"))?;
985 let name = reader
986 .read_lenenc_string()
987 .ok_or_else(|| protocol_error("Missing name"))?;
988 let org_name = reader
989 .read_lenenc_string()
990 .ok_or_else(|| protocol_error("Missing org_name"))?;
991
992 let _fixed_len = reader.read_lenenc_int();
993
994 let charset_val = reader
995 .read_u16_le()
996 .ok_or_else(|| protocol_error("Missing charset"))?;
997 let column_length = reader
998 .read_u32_le()
999 .ok_or_else(|| protocol_error("Missing column_length"))?;
1000 let column_type = FieldType::from_u8(
1001 reader
1002 .read_u8()
1003 .ok_or_else(|| protocol_error("Missing column_type"))?,
1004 );
1005 let flags = reader
1006 .read_u16_le()
1007 .ok_or_else(|| protocol_error("Missing flags"))?;
1008 let decimals = reader
1009 .read_u8()
1010 .ok_or_else(|| protocol_error("Missing decimals"))?;
1011
1012 Ok(ColumnDef {
1013 catalog,
1014 schema,
1015 table,
1016 org_table,
1017 name,
1018 org_name,
1019 charset: charset_val,
1020 column_length,
1021 column_type,
1022 flags,
1023 decimals,
1024 })
1025 }
1026
1027 fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
1029 let mut reader = PacketReader::new(data);
1030 let mut values = Vec::with_capacity(columns.len());
1031
1032 for col in columns {
1033 if reader.peek() == Some(0xFB) {
1034 reader.skip(1);
1035 values.push(Value::Null);
1036 } else if let Some(data) = reader.read_lenenc_bytes() {
1037 let is_unsigned = col.is_unsigned();
1038 let value = decode_text_value(col.column_type, &data, is_unsigned);
1039 values.push(value);
1040 } else {
1041 values.push(Value::Null);
1042 }
1043 }
1044
1045 let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
1046 Row::new(column_names, values)
1047 }
1048
1049 pub async fn execute_async(
1054 &mut self,
1055 cx: &Cx,
1056 sql: &str,
1057 params: &[Value],
1058 ) -> Outcome<u64, Error> {
1059 match self.query_async(cx, sql, params).await {
1061 Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
1062 Outcome::Err(e) => Outcome::Err(e),
1063 Outcome::Cancelled(c) => Outcome::Cancelled(c),
1064 Outcome::Panicked(p) => Outcome::Panicked(p),
1065 }
1066 }
1067
1068 pub async fn prepare_async(
1073 &mut self,
1074 _cx: &Cx,
1075 sql: &str,
1076 ) -> Outcome<PreparedStatement, Error> {
1077 if !self.is_ready() && self.state != ConnectionState::InTransaction {
1078 return Outcome::Err(connection_error("Connection not ready for prepare"));
1079 }
1080
1081 self.sequence_id = 0;
1082
1083 let packet = prepared::build_stmt_prepare_packet(sql, self.sequence_id);
1085 if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1086 return Outcome::Err(e);
1087 }
1088
1089 let (payload, _) = match self.read_packet_async().await {
1091 Outcome::Ok(p) => p,
1092 Outcome::Err(e) => return Outcome::Err(e),
1093 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1094 Outcome::Panicked(p) => return Outcome::Panicked(p),
1095 };
1096
1097 if payload.first() == Some(&0xFF) {
1099 let mut reader = PacketReader::new(&payload);
1100 let Some(err) = reader.parse_err_packet() else {
1101 return Outcome::Err(protocol_error("Invalid error packet"));
1102 };
1103 return Outcome::Err(query_error(&err));
1104 }
1105
1106 let Some(prep_ok) = prepared::parse_stmt_prepare_ok(&payload) else {
1108 return Outcome::Err(protocol_error("Invalid prepare OK response"));
1109 };
1110
1111 let mut param_defs = Vec::with_capacity(prep_ok.num_params as usize);
1113 for _ in 0..prep_ok.num_params {
1114 let (payload, _) = match self.read_packet_async().await {
1115 Outcome::Ok(p) => p,
1116 Outcome::Err(e) => return Outcome::Err(e),
1117 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1118 Outcome::Panicked(p) => return Outcome::Panicked(p),
1119 };
1120 match self.parse_column_def(&payload) {
1121 Ok(col) => param_defs.push(col),
1122 Err(e) => return Outcome::Err(e),
1123 }
1124 }
1125
1126 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1128 if prep_ok.num_params > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1129 let (payload, _) = match self.read_packet_async().await {
1130 Outcome::Ok(p) => p,
1131 Outcome::Err(e) => return Outcome::Err(e),
1132 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1133 Outcome::Panicked(p) => return Outcome::Panicked(p),
1134 };
1135 if payload.first() != Some(&0xFE) {
1136 return Outcome::Err(protocol_error("Expected EOF after param definitions"));
1137 }
1138 }
1139
1140 let mut column_defs = Vec::with_capacity(prep_ok.num_columns as usize);
1142 for _ in 0..prep_ok.num_columns {
1143 let (payload, _) = match self.read_packet_async().await {
1144 Outcome::Ok(p) => p,
1145 Outcome::Err(e) => return Outcome::Err(e),
1146 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1147 Outcome::Panicked(p) => return Outcome::Panicked(p),
1148 };
1149 match self.parse_column_def(&payload) {
1150 Ok(col) => column_defs.push(col),
1151 Err(e) => return Outcome::Err(e),
1152 }
1153 }
1154
1155 if prep_ok.num_columns > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1157 let (payload, _) = match self.read_packet_async().await {
1158 Outcome::Ok(p) => p,
1159 Outcome::Err(e) => return Outcome::Err(e),
1160 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1161 Outcome::Panicked(p) => return Outcome::Panicked(p),
1162 };
1163 if payload.first() != Some(&0xFE) {
1164 return Outcome::Err(protocol_error("Expected EOF after column definitions"));
1165 }
1166 }
1167
1168 let meta = PreparedStmtMeta {
1170 statement_id: prep_ok.statement_id,
1171 params: param_defs,
1172 columns: column_defs.clone(),
1173 };
1174 self.prepared_stmts.insert(prep_ok.statement_id, meta);
1175
1176 let column_names: Vec<String> = column_defs.iter().map(|c| c.name.clone()).collect();
1178 Outcome::Ok(PreparedStatement::with_columns(
1179 u64::from(prep_ok.statement_id),
1180 sql.to_string(),
1181 prep_ok.num_params as usize,
1182 column_names,
1183 ))
1184 }
1185
1186 pub async fn query_prepared_async(
1188 &mut self,
1189 _cx: &Cx,
1190 stmt: &PreparedStatement,
1191 params: &[Value],
1192 ) -> Outcome<Vec<Row>, Error> {
1193 #[allow(clippy::cast_possible_truncation)] let stmt_id = stmt.id() as u32;
1195
1196 let Some(meta) = self.prepared_stmts.get(&stmt_id).cloned() else {
1198 return Outcome::Err(connection_error("Unknown prepared statement"));
1199 };
1200
1201 if params.len() != meta.params.len() {
1203 return Outcome::Err(connection_error(format!(
1204 "Expected {} parameters, got {}",
1205 meta.params.len(),
1206 params.len()
1207 )));
1208 }
1209
1210 if !self.is_ready() && self.state != ConnectionState::InTransaction {
1211 return Outcome::Err(connection_error("Connection not ready for query"));
1212 }
1213
1214 self.state = ConnectionState::InQuery;
1215 self.sequence_id = 0;
1216
1217 let param_types: Vec<FieldType> = meta.params.iter().map(|c| c.column_type).collect();
1219 let packet = prepared::build_stmt_execute_packet(
1220 stmt_id,
1221 params,
1222 Some(¶m_types),
1223 self.sequence_id,
1224 );
1225 if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1226 return Outcome::Err(e);
1227 }
1228
1229 let (payload, _) = match self.read_packet_async().await {
1231 Outcome::Ok(p) => p,
1232 Outcome::Err(e) => return Outcome::Err(e),
1233 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1234 Outcome::Panicked(p) => return Outcome::Panicked(p),
1235 };
1236
1237 if payload.is_empty() {
1238 self.state = ConnectionState::Ready;
1239 return Outcome::Err(protocol_error("Empty execute response"));
1240 }
1241
1242 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1244 PacketType::Ok => {
1245 let mut reader = PacketReader::new(&payload);
1247 if let Some(ok) = reader.parse_ok_packet() {
1248 self.affected_rows = ok.affected_rows;
1249 self.last_insert_id = ok.last_insert_id;
1250 self.status_flags = ok.status_flags;
1251 self.warnings = ok.warnings;
1252 }
1253 self.state = ConnectionState::Ready;
1254 Outcome::Ok(vec![])
1255 }
1256 PacketType::Error => {
1257 self.state = ConnectionState::Ready;
1258 let mut reader = PacketReader::new(&payload);
1259 let Some(err) = reader.parse_err_packet() else {
1260 return Outcome::Err(protocol_error("Invalid error packet"));
1261 };
1262 Outcome::Err(query_error(&err))
1263 }
1264 _ => {
1265 self.read_binary_result_set_async(&payload, &meta.columns)
1267 .await
1268 }
1269 }
1270 }
1271
1272 pub async fn execute_prepared_async(
1274 &mut self,
1275 cx: &Cx,
1276 stmt: &PreparedStatement,
1277 params: &[Value],
1278 ) -> Outcome<u64, Error> {
1279 match self.query_prepared_async(cx, stmt, params).await {
1280 Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
1281 Outcome::Err(e) => Outcome::Err(e),
1282 Outcome::Cancelled(c) => Outcome::Cancelled(c),
1283 Outcome::Panicked(p) => Outcome::Panicked(p),
1284 }
1285 }
1286
1287 pub async fn close_prepared_async(&mut self, stmt: &PreparedStatement) {
1289 #[allow(clippy::cast_possible_truncation)] let stmt_id = stmt.id() as u32;
1291 self.prepared_stmts.remove(&stmt_id);
1292
1293 self.sequence_id = 0;
1294 let packet = prepared::build_stmt_close_packet(stmt_id, self.sequence_id);
1295 let _ = self.write_packet_raw_async(&packet).await;
1297 }
1298
1299 async fn read_binary_result_set_async(
1301 &mut self,
1302 first_packet: &[u8],
1303 columns: &[ColumnDef],
1304 ) -> Outcome<Vec<Row>, Error> {
1305 let mut reader = PacketReader::new(first_packet);
1307 #[allow(clippy::cast_possible_truncation)] let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
1309 return Outcome::Err(protocol_error("Invalid column count"));
1310 };
1311
1312 let mut result_columns = Vec::with_capacity(column_count);
1315 for _ in 0..column_count {
1316 let (payload, _) = match self.read_packet_async().await {
1317 Outcome::Ok(p) => p,
1318 Outcome::Err(e) => return Outcome::Err(e),
1319 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1320 Outcome::Panicked(p) => return Outcome::Panicked(p),
1321 };
1322 match self.parse_column_def(&payload) {
1323 Ok(col) => result_columns.push(col),
1324 Err(e) => return Outcome::Err(e),
1325 }
1326 }
1327
1328 let cols = if result_columns.len() == columns.len() {
1330 &result_columns
1331 } else {
1332 columns
1333 };
1334
1335 let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1337 if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1338 let (payload, _) = match self.read_packet_async().await {
1339 Outcome::Ok(p) => p,
1340 Outcome::Err(e) => return Outcome::Err(e),
1341 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1342 Outcome::Panicked(p) => return Outcome::Panicked(p),
1343 };
1344 if payload.first() == Some(&0xFE) {
1345 }
1347 }
1348
1349 let mut rows = Vec::new();
1351 loop {
1352 let (payload, _) = match self.read_packet_async().await {
1353 Outcome::Ok(p) => p,
1354 Outcome::Err(e) => return Outcome::Err(e),
1355 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1356 Outcome::Panicked(p) => return Outcome::Panicked(p),
1357 };
1358
1359 if payload.is_empty() {
1360 break;
1361 }
1362
1363 #[allow(clippy::cast_possible_truncation)] match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1365 PacketType::Eof | PacketType::Ok => {
1366 let mut reader = PacketReader::new(&payload);
1367 if payload[0] == 0x00 {
1368 if let Some(ok) = reader.parse_ok_packet() {
1369 self.status_flags = ok.status_flags;
1370 self.warnings = ok.warnings;
1371 }
1372 } else if payload[0] == 0xFE {
1373 if let Some(eof) = reader.parse_eof_packet() {
1374 self.status_flags = eof.status_flags;
1375 self.warnings = eof.warnings;
1376 }
1377 }
1378 break;
1379 }
1380 PacketType::Error => {
1381 let mut reader = PacketReader::new(&payload);
1382 let Some(err) = reader.parse_err_packet() else {
1383 return Outcome::Err(protocol_error("Invalid error packet"));
1384 };
1385 self.state = ConnectionState::Ready;
1386 return Outcome::Err(query_error(&err));
1387 }
1388 _ => {
1389 let row = self.parse_binary_row(&payload, cols);
1390 rows.push(row);
1391 }
1392 }
1393 }
1394
1395 self.state =
1396 if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
1397 ConnectionState::InTransaction
1398 } else {
1399 ConnectionState::Ready
1400 };
1401
1402 Outcome::Ok(rows)
1403 }
1404
1405 fn parse_binary_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
1407 let mut values = Vec::with_capacity(columns.len());
1413 let mut column_names = Vec::with_capacity(columns.len());
1414
1415 if data.is_empty() {
1416 return Row::new(column_names, values);
1417 }
1418
1419 let mut pos = 1;
1421
1422 let null_bitmap_len = (columns.len() + 7 + 2) / 8;
1425 if pos + null_bitmap_len > data.len() {
1426 return Row::new(column_names, values);
1427 }
1428 let null_bitmap = &data[pos..pos + null_bitmap_len];
1429 pos += null_bitmap_len;
1430
1431 for (i, col) in columns.iter().enumerate() {
1433 column_names.push(col.name.clone());
1434
1435 let bit_pos = i + 2;
1437 let is_null = (null_bitmap[bit_pos / 8] & (1 << (bit_pos % 8))) != 0;
1438
1439 if is_null {
1440 values.push(Value::Null);
1441 } else {
1442 let is_unsigned = col.flags & 0x20 != 0; let (value, consumed) =
1444 decode_binary_value_with_len(&data[pos..], col.column_type, is_unsigned);
1445 values.push(value);
1446 pos += consumed;
1447 }
1448 }
1449
1450 Row::new(column_names, values)
1451 }
1452
1453 async fn write_packet_raw_async(&mut self, packet: &[u8]) -> Outcome<(), Error> {
1455 match &mut self.stream {
1456 ConnectionStream::Async(stream) => {
1457 let mut written = 0;
1458 while written < packet.len() {
1459 match std::future::poll_fn(|cx| {
1460 std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
1461 })
1462 .await
1463 {
1464 Ok(n) => written += n,
1465 Err(e) => {
1466 return Outcome::Err(Error::Connection(ConnectionError {
1467 kind: ConnectionErrorKind::Disconnected,
1468 message: format!("Failed to write packet: {}", e),
1469 source: Some(Box::new(e)),
1470 }));
1471 }
1472 }
1473 }
1474 if let Err(e) =
1476 std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
1477 {
1478 return Outcome::Err(Error::Connection(ConnectionError {
1479 kind: ConnectionErrorKind::Disconnected,
1480 message: format!("Failed to flush: {}", e),
1481 source: Some(Box::new(e)),
1482 }));
1483 }
1484 Outcome::Ok(())
1485 }
1486 ConnectionStream::Sync(stream) => {
1487 if let Err(e) = stream.write_all(packet) {
1488 return Outcome::Err(Error::Connection(ConnectionError {
1489 kind: ConnectionErrorKind::Disconnected,
1490 message: format!("Failed to write packet: {}", e),
1491 source: Some(Box::new(e)),
1492 }));
1493 }
1494 if let Err(e) = stream.flush() {
1495 return Outcome::Err(Error::Connection(ConnectionError {
1496 kind: ConnectionErrorKind::Disconnected,
1497 message: format!("Failed to flush: {}", e),
1498 source: Some(Box::new(e)),
1499 }));
1500 }
1501 Outcome::Ok(())
1502 }
1503 }
1504 }
1505
1506 pub async fn ping_async(&mut self, _cx: &Cx) -> Outcome<(), Error> {
1508 self.sequence_id = 0;
1509
1510 let mut writer = PacketWriter::new();
1511 writer.write_u8(Command::Ping as u8);
1512
1513 if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
1514 return Outcome::Err(e);
1515 }
1516
1517 let (payload, _) = match self.read_packet_async().await {
1518 Outcome::Ok(p) => p,
1519 Outcome::Err(e) => return Outcome::Err(e),
1520 Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1521 Outcome::Panicked(p) => return Outcome::Panicked(p),
1522 };
1523
1524 if payload.first() == Some(&0x00) {
1525 Outcome::Ok(())
1526 } else {
1527 Outcome::Err(connection_error("Ping failed"))
1528 }
1529 }
1530
1531 pub async fn close_async(mut self, _cx: &Cx) -> Result<(), Error> {
1533 if self.state == ConnectionState::Closed {
1534 return Ok(());
1535 }
1536
1537 self.sequence_id = 0;
1538
1539 let mut writer = PacketWriter::new();
1540 writer.write_u8(Command::Quit as u8);
1541
1542 let _ = self.write_packet_async(writer.as_bytes()).await;
1544
1545 self.state = ConnectionState::Closed;
1546 Ok(())
1547 }
1548}
1549
1550impl Connection for MySqlAsyncConnection {
1553 type Tx<'conn> = MySqlTransaction<'conn>;
1554
1555 fn query(
1556 &self,
1557 _cx: &Cx,
1558 _sql: &str,
1559 _params: &[Value],
1560 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1561 async move {
1565 Outcome::Err(connection_error(
1568 "Query requires mutable access - use query_async directly",
1569 ))
1570 }
1571 }
1572
1573 fn query_one(
1574 &self,
1575 _cx: &Cx,
1576 _sql: &str,
1577 _params: &[Value],
1578 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1579 async move {
1580 Outcome::Err(connection_error(
1581 "Query requires mutable access - use query_async directly",
1582 ))
1583 }
1584 }
1585
1586 fn execute(
1587 &self,
1588 _cx: &Cx,
1589 _sql: &str,
1590 _params: &[Value],
1591 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1592 async move {
1593 Outcome::Err(connection_error(
1594 "Execute requires mutable access - use query_async directly",
1595 ))
1596 }
1597 }
1598
1599 fn insert(
1600 &self,
1601 _cx: &Cx,
1602 _sql: &str,
1603 _params: &[Value],
1604 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
1605 async move {
1606 Outcome::Err(connection_error(
1607 "Insert requires mutable access - use query_async directly",
1608 ))
1609 }
1610 }
1611
1612 fn batch(
1613 &self,
1614 _cx: &Cx,
1615 _statements: &[(String, Vec<Value>)],
1616 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
1617 async move {
1618 Outcome::Err(connection_error(
1619 "Batch requires mutable access - use query_async directly",
1620 ))
1621 }
1622 }
1623
1624 fn begin(&self, _cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1625 async move {
1626 Outcome::Err(connection_error(
1627 "Begin requires mutable access - use transaction methods directly",
1628 ))
1629 }
1630 }
1631
1632 fn begin_with(
1633 &self,
1634 _cx: &Cx,
1635 _isolation: IsolationLevel,
1636 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
1637 async move {
1638 Outcome::Err(connection_error(
1639 "Begin requires mutable access - use transaction methods directly",
1640 ))
1641 }
1642 }
1643
1644 fn prepare(
1645 &self,
1646 _cx: &Cx,
1647 _sql: &str,
1648 ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
1649 async move {
1650 Outcome::Err(connection_error(
1651 "Prepare not yet implemented for MySQL async",
1652 ))
1653 }
1654 }
1655
1656 fn query_prepared(
1657 &self,
1658 _cx: &Cx,
1659 _stmt: &PreparedStatement,
1660 _params: &[Value],
1661 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1662 async move {
1663 Outcome::Err(connection_error(
1664 "Prepared query not yet implemented for MySQL async",
1665 ))
1666 }
1667 }
1668
1669 fn execute_prepared(
1670 &self,
1671 _cx: &Cx,
1672 _stmt: &PreparedStatement,
1673 _params: &[Value],
1674 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1675 async move {
1676 Outcome::Err(connection_error(
1677 "Prepared execute not yet implemented for MySQL async",
1678 ))
1679 }
1680 }
1681
1682 fn ping(&self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1683 async move {
1684 Outcome::Err(connection_error(
1685 "Ping requires mutable access - use ping_async directly",
1686 ))
1687 }
1688 }
1689
1690 fn close(self, cx: &Cx) -> impl Future<Output = Result<(), Error>> + Send {
1691 async move { self.close_async(cx).await }
1692 }
1693}
1694
1695pub struct MySqlTransaction<'conn> {
1697 #[allow(dead_code)]
1698 conn: &'conn mut MySqlAsyncConnection,
1699}
1700
1701impl<'conn> TransactionOps for MySqlTransaction<'conn> {
1702 fn query(
1703 &self,
1704 _cx: &Cx,
1705 _sql: &str,
1706 _params: &[Value],
1707 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
1708 async move { Outcome::Err(connection_error("Transaction query not yet implemented")) }
1709 }
1710
1711 fn query_one(
1712 &self,
1713 _cx: &Cx,
1714 _sql: &str,
1715 _params: &[Value],
1716 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
1717 async move {
1718 Outcome::Err(connection_error(
1719 "Transaction query_one not yet implemented",
1720 ))
1721 }
1722 }
1723
1724 fn execute(
1725 &self,
1726 _cx: &Cx,
1727 _sql: &str,
1728 _params: &[Value],
1729 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
1730 async move { Outcome::Err(connection_error("Transaction execute not yet implemented")) }
1731 }
1732
1733 fn savepoint(&self, _cx: &Cx, _name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1734 async move {
1735 Outcome::Err(connection_error(
1736 "Transaction savepoint not yet implemented",
1737 ))
1738 }
1739 }
1740
1741 fn rollback_to(
1742 &self,
1743 _cx: &Cx,
1744 _name: &str,
1745 ) -> impl Future<Output = Outcome<(), Error>> + Send {
1746 async move {
1747 Outcome::Err(connection_error(
1748 "Transaction rollback_to not yet implemented",
1749 ))
1750 }
1751 }
1752
1753 fn release(&self, _cx: &Cx, _name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
1754 async move { Outcome::Err(connection_error("Transaction release not yet implemented")) }
1755 }
1756
1757 fn commit(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1758 async move { Outcome::Err(connection_error("Transaction commit not yet implemented")) }
1759 }
1760
1761 fn rollback(self, _cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
1762 async move { Outcome::Err(connection_error("Transaction rollback not yet implemented")) }
1763 }
1764}
1765
#[cfg(feature = "console")]
impl ConsoleAware for MySqlAsyncConnection {
    /// Replaces the attached console; passing `None` detaches it.
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        let _previous = std::mem::replace(&mut self.console, console);
    }

    /// Borrows the attached console, if one is set.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        Option::as_ref(&self.console)
    }
}
1778
1779fn protocol_error(msg: impl Into<String>) -> Error {
1782 Error::Protocol(ProtocolError {
1783 message: msg.into(),
1784 raw_data: None,
1785 source: None,
1786 })
1787}
1788
1789fn auth_error(msg: impl Into<String>) -> Error {
1790 Error::Connection(ConnectionError {
1791 kind: ConnectionErrorKind::Authentication,
1792 message: msg.into(),
1793 source: None,
1794 })
1795}
1796
1797fn connection_error(msg: impl Into<String>) -> Error {
1798 Error::Connection(ConnectionError {
1799 kind: ConnectionErrorKind::Connect,
1800 message: msg.into(),
1801 source: None,
1802 })
1803}
1804
1805fn query_error(err: &ErrPacket) -> Error {
1806 let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
1807 QueryErrorKind::Constraint
1808 } else {
1809 QueryErrorKind::Syntax
1810 };
1811
1812 Error::Query(QueryError {
1813 kind,
1814 message: err.error_message.clone(),
1815 sqlstate: Some(err.sql_state.clone()),
1816 sql: None,
1817 detail: None,
1818 hint: None,
1819 position: None,
1820 source: None,
1821 })
1822}
1823
1824fn query_error_msg(msg: impl Into<String>) -> Error {
1825 Error::Query(QueryError {
1826 kind: QueryErrorKind::Syntax,
1827 message: msg.into(),
1828 sqlstate: None,
1829 sql: None,
1830 detail: None,
1831 hint: None,
1832 position: None,
1833 source: None,
1834 })
1835}
1836
1837fn validate_savepoint_name(name: &str) -> Result<(), Error> {
1845 if name.is_empty() {
1846 return Err(query_error_msg("Savepoint name cannot be empty"));
1847 }
1848 if name.len() > 64 {
1849 return Err(query_error_msg(
1850 "Savepoint name exceeds maximum length of 64 characters",
1851 ));
1852 }
1853 let mut chars = name.chars();
1854 let first = chars.next().unwrap();
1855 if !first.is_ascii_alphabetic() && first != '_' {
1856 return Err(query_error_msg(
1857 "Savepoint name must start with a letter or underscore",
1858 ));
1859 }
1860 for c in chars {
1861 if !c.is_ascii_alphanumeric() && c != '_' && c != '$' {
1862 return Err(query_error_msg(format!(
1863 "Savepoint name contains invalid character: '{}'",
1864 c
1865 )));
1866 }
1867 }
1868 Ok(())
1869}
1870
1871pub struct SharedMySqlConnection {
1888 inner: Arc<Mutex<MySqlAsyncConnection>>,
1889}
1890
1891impl SharedMySqlConnection {
1892 pub fn new(conn: MySqlAsyncConnection) -> Self {
1894 Self {
1895 inner: Arc::new(Mutex::new(conn)),
1896 }
1897 }
1898
1899 pub async fn connect(cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
1901 match MySqlAsyncConnection::connect(cx, config).await {
1902 Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
1903 Outcome::Err(e) => Outcome::Err(e),
1904 Outcome::Cancelled(c) => Outcome::Cancelled(c),
1905 Outcome::Panicked(p) => Outcome::Panicked(p),
1906 }
1907 }
1908
1909 pub fn inner(&self) -> &Arc<Mutex<MySqlAsyncConnection>> {
1911 &self.inner
1912 }
1913}
1914
1915impl Clone for SharedMySqlConnection {
1916 fn clone(&self) -> Self {
1917 Self {
1918 inner: Arc::clone(&self.inner),
1919 }
1920 }
1921}
1922
1923impl std::fmt::Debug for SharedMySqlConnection {
1924 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1925 f.debug_struct("SharedMySqlConnection")
1926 .field("inner", &"Arc<Mutex<MySqlAsyncConnection>>")
1927 .finish()
1928 }
1929}
1930
1931pub struct SharedMySqlTransaction<'conn> {
1950 inner: Arc<Mutex<MySqlAsyncConnection>>,
1951 committed: bool,
1952 _marker: std::marker::PhantomData<&'conn ()>,
1953}
1954
1955impl SharedMySqlConnection {
1956 async fn begin_transaction_impl(
1958 &self,
1959 cx: &Cx,
1960 isolation: Option<IsolationLevel>,
1961 ) -> Outcome<SharedMySqlTransaction<'_>, Error> {
1962 let inner = Arc::clone(&self.inner);
1963
1964 let Ok(mut guard) = inner.lock(cx).await else {
1966 return Outcome::Err(connection_error("Failed to acquire connection lock"));
1967 };
1968
1969 if let Some(level) = isolation {
1971 let isolation_sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
1972 match guard.execute_async(cx, &isolation_sql, &[]).await {
1973 Outcome::Ok(_) => {}
1974 Outcome::Err(e) => return Outcome::Err(e),
1975 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
1976 Outcome::Panicked(p) => return Outcome::Panicked(p),
1977 }
1978 }
1979
1980 match guard.execute_async(cx, "BEGIN", &[]).await {
1982 Outcome::Ok(_) => {}
1983 Outcome::Err(e) => return Outcome::Err(e),
1984 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
1985 Outcome::Panicked(p) => return Outcome::Panicked(p),
1986 }
1987
1988 drop(guard);
1989
1990 Outcome::Ok(SharedMySqlTransaction {
1991 inner,
1992 committed: false,
1993 _marker: std::marker::PhantomData,
1994 })
1995 }
1996}
1997
1998impl Connection for SharedMySqlConnection {
1999 type Tx<'conn>
2000 = SharedMySqlTransaction<'conn>
2001 where
2002 Self: 'conn;
2003
2004 fn query(
2005 &self,
2006 cx: &Cx,
2007 sql: &str,
2008 params: &[Value],
2009 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2010 let inner = Arc::clone(&self.inner);
2011 let sql = sql.to_string();
2012 let params = params.to_vec();
2013 async move {
2014 let Ok(mut guard) = inner.lock(cx).await else {
2015 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2016 };
2017 guard.query_async(cx, &sql, ¶ms).await
2018 }
2019 }
2020
2021 fn query_one(
2022 &self,
2023 cx: &Cx,
2024 sql: &str,
2025 params: &[Value],
2026 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2027 let inner = Arc::clone(&self.inner);
2028 let sql = sql.to_string();
2029 let params = params.to_vec();
2030 async move {
2031 let Ok(mut guard) = inner.lock(cx).await else {
2032 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2033 };
2034 let rows = match guard.query_async(cx, &sql, ¶ms).await {
2035 Outcome::Ok(r) => r,
2036 Outcome::Err(e) => return Outcome::Err(e),
2037 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2038 Outcome::Panicked(p) => return Outcome::Panicked(p),
2039 };
2040 Outcome::Ok(rows.into_iter().next())
2041 }
2042 }
2043
2044 fn execute(
2045 &self,
2046 cx: &Cx,
2047 sql: &str,
2048 params: &[Value],
2049 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2050 let inner = Arc::clone(&self.inner);
2051 let sql = sql.to_string();
2052 let params = params.to_vec();
2053 async move {
2054 let Ok(mut guard) = inner.lock(cx).await else {
2055 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2056 };
2057 guard.execute_async(cx, &sql, ¶ms).await
2058 }
2059 }
2060
2061 fn insert(
2062 &self,
2063 cx: &Cx,
2064 sql: &str,
2065 params: &[Value],
2066 ) -> impl Future<Output = Outcome<i64, Error>> + Send {
2067 let inner = Arc::clone(&self.inner);
2068 let sql = sql.to_string();
2069 let params = params.to_vec();
2070 async move {
2071 let Ok(mut guard) = inner.lock(cx).await else {
2072 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2073 };
2074 match guard.execute_async(cx, &sql, ¶ms).await {
2075 Outcome::Ok(_) => Outcome::Ok(guard.last_insert_id() as i64),
2076 Outcome::Err(e) => Outcome::Err(e),
2077 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2078 Outcome::Panicked(p) => Outcome::Panicked(p),
2079 }
2080 }
2081 }
2082
2083 fn batch(
2084 &self,
2085 cx: &Cx,
2086 statements: &[(String, Vec<Value>)],
2087 ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
2088 let inner = Arc::clone(&self.inner);
2089 let statements = statements.to_vec();
2090 async move {
2091 let Ok(mut guard) = inner.lock(cx).await else {
2092 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2093 };
2094 let mut results = Vec::with_capacity(statements.len());
2095 for (sql, params) in &statements {
2096 match guard.execute_async(cx, sql, params).await {
2097 Outcome::Ok(n) => results.push(n),
2098 Outcome::Err(e) => return Outcome::Err(e),
2099 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2100 Outcome::Panicked(p) => return Outcome::Panicked(p),
2101 }
2102 }
2103 Outcome::Ok(results)
2104 }
2105 }
2106
2107 fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2108 self.begin_transaction_impl(cx, None)
2109 }
2110
2111 fn begin_with(
2112 &self,
2113 cx: &Cx,
2114 isolation: IsolationLevel,
2115 ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2116 self.begin_transaction_impl(cx, Some(isolation))
2117 }
2118
2119 fn prepare(
2120 &self,
2121 cx: &Cx,
2122 sql: &str,
2123 ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
2124 let inner = Arc::clone(&self.inner);
2125 let sql = sql.to_string();
2126 async move {
2127 let Ok(mut guard) = inner.lock(cx).await else {
2128 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2129 };
2130 guard.prepare_async(cx, &sql).await
2131 }
2132 }
2133
2134 fn query_prepared(
2135 &self,
2136 cx: &Cx,
2137 stmt: &PreparedStatement,
2138 params: &[Value],
2139 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2140 let inner = Arc::clone(&self.inner);
2141 let stmt = stmt.clone();
2142 let params = params.to_vec();
2143 async move {
2144 let Ok(mut guard) = inner.lock(cx).await else {
2145 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2146 };
2147 guard.query_prepared_async(cx, &stmt, ¶ms).await
2148 }
2149 }
2150
2151 fn execute_prepared(
2152 &self,
2153 cx: &Cx,
2154 stmt: &PreparedStatement,
2155 params: &[Value],
2156 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2157 let inner = Arc::clone(&self.inner);
2158 let stmt = stmt.clone();
2159 let params = params.to_vec();
2160 async move {
2161 let Ok(mut guard) = inner.lock(cx).await else {
2162 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2163 };
2164 guard.execute_prepared_async(cx, &stmt, ¶ms).await
2165 }
2166 }
2167
2168 fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2169 let inner = Arc::clone(&self.inner);
2170 async move {
2171 let Ok(mut guard) = inner.lock(cx).await else {
2172 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2173 };
2174 guard.ping_async(cx).await
2175 }
2176 }
2177
2178 fn close(self, cx: &Cx) -> impl Future<Output = Result<(), Error>> + Send {
2179 async move {
2180 match Arc::try_unwrap(self.inner) {
2182 Ok(mutex) => {
2183 let conn = mutex.into_inner();
2184 conn.close_async(cx).await
2185 }
2186 Err(_) => {
2187 Err(connection_error(
2189 "Cannot close: other references to connection exist",
2190 ))
2191 }
2192 }
2193 }
2194 }
2195}
2196
2197impl<'conn> TransactionOps for SharedMySqlTransaction<'conn> {
2198 fn query(
2199 &self,
2200 cx: &Cx,
2201 sql: &str,
2202 params: &[Value],
2203 ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2204 let inner = Arc::clone(&self.inner);
2205 let sql = sql.to_string();
2206 let params = params.to_vec();
2207 async move {
2208 let Ok(mut guard) = inner.lock(cx).await else {
2209 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2210 };
2211 guard.query_async(cx, &sql, ¶ms).await
2212 }
2213 }
2214
2215 fn query_one(
2216 &self,
2217 cx: &Cx,
2218 sql: &str,
2219 params: &[Value],
2220 ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2221 let inner = Arc::clone(&self.inner);
2222 let sql = sql.to_string();
2223 let params = params.to_vec();
2224 async move {
2225 let Ok(mut guard) = inner.lock(cx).await else {
2226 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2227 };
2228 let rows = match guard.query_async(cx, &sql, ¶ms).await {
2229 Outcome::Ok(r) => r,
2230 Outcome::Err(e) => return Outcome::Err(e),
2231 Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2232 Outcome::Panicked(p) => return Outcome::Panicked(p),
2233 };
2234 Outcome::Ok(rows.into_iter().next())
2235 }
2236 }
2237
2238 fn execute(
2239 &self,
2240 cx: &Cx,
2241 sql: &str,
2242 params: &[Value],
2243 ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2244 let inner = Arc::clone(&self.inner);
2245 let sql = sql.to_string();
2246 let params = params.to_vec();
2247 async move {
2248 let Ok(mut guard) = inner.lock(cx).await else {
2249 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2250 };
2251 guard.execute_async(cx, &sql, ¶ms).await
2252 }
2253 }
2254
2255 fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2256 let inner = Arc::clone(&self.inner);
2257 let validation_result = validate_savepoint_name(name);
2259 let sql = format!("SAVEPOINT {}", name);
2260 async move {
2261 if let Err(e) = validation_result {
2263 return Outcome::Err(e);
2264 }
2265 let Ok(mut guard) = inner.lock(cx).await else {
2266 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2267 };
2268 match guard.execute_async(cx, &sql, &[]).await {
2269 Outcome::Ok(_) => Outcome::Ok(()),
2270 Outcome::Err(e) => Outcome::Err(e),
2271 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2272 Outcome::Panicked(p) => Outcome::Panicked(p),
2273 }
2274 }
2275 }
2276
2277 fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2278 let inner = Arc::clone(&self.inner);
2279 let validation_result = validate_savepoint_name(name);
2281 let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
2282 async move {
2283 if let Err(e) = validation_result {
2285 return Outcome::Err(e);
2286 }
2287 let Ok(mut guard) = inner.lock(cx).await else {
2288 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2289 };
2290 match guard.execute_async(cx, &sql, &[]).await {
2291 Outcome::Ok(_) => Outcome::Ok(()),
2292 Outcome::Err(e) => Outcome::Err(e),
2293 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2294 Outcome::Panicked(p) => Outcome::Panicked(p),
2295 }
2296 }
2297 }
2298
2299 fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2300 let inner = Arc::clone(&self.inner);
2301 let validation_result = validate_savepoint_name(name);
2303 let sql = format!("RELEASE SAVEPOINT {}", name);
2304 async move {
2305 if let Err(e) = validation_result {
2307 return Outcome::Err(e);
2308 }
2309 let Ok(mut guard) = inner.lock(cx).await else {
2310 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2311 };
2312 match guard.execute_async(cx, &sql, &[]).await {
2313 Outcome::Ok(_) => Outcome::Ok(()),
2314 Outcome::Err(e) => Outcome::Err(e),
2315 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2316 Outcome::Panicked(p) => Outcome::Panicked(p),
2317 }
2318 }
2319 }
2320
2321 #[allow(unused_assignments)]
2324 fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2325 async move {
2326 let Ok(mut guard) = self.inner.lock(cx).await else {
2327 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2328 };
2329 match guard.execute_async(cx, "COMMIT", &[]).await {
2330 Outcome::Ok(_) => {
2331 self.committed = true;
2332 Outcome::Ok(())
2333 }
2334 Outcome::Err(e) => Outcome::Err(e),
2335 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2336 Outcome::Panicked(p) => Outcome::Panicked(p),
2337 }
2338 }
2339 }
2340
2341 fn rollback(self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2342 async move {
2343 let Ok(mut guard) = self.inner.lock(cx).await else {
2344 return Outcome::Err(connection_error("Failed to acquire connection lock"));
2345 };
2346 match guard.execute_async(cx, "ROLLBACK", &[]).await {
2347 Outcome::Ok(_) => Outcome::Ok(()),
2348 Outcome::Err(e) => Outcome::Err(e),
2349 Outcome::Cancelled(c) => Outcome::Cancelled(c),
2350 Outcome::Panicked(p) => Outcome::Panicked(p),
2351 }
2352 }
2353 }
2354}
2355
2356impl<'conn> Drop for SharedMySqlTransaction<'conn> {
2357 fn drop(&mut self) {
2358 if !self.committed {
2359 #[cfg(debug_assertions)]
2367 eprintln!(
2368 "WARNING: SharedMySqlTransaction dropped without commit/rollback. \
2369 The MySQL transaction may still be open."
2370 );
2371 }
2372 }
2373}
2374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
    }

    #[test]
    fn test_error_helpers() {
        let err = protocol_error("test");
        assert!(matches!(err, Error::Protocol(_)));

        let err = auth_error("auth failed");
        assert!(matches!(err, Error::Connection(_)));

        let err = connection_error("conn failed");
        assert!(matches!(err, Error::Connection(_)));
    }

    #[test]
    fn test_error_helper_kinds() {
        // Each helper must tag its error with the documented kind.
        match auth_error("auth failed") {
            Error::Connection(e) => {
                assert!(matches!(e.kind, ConnectionErrorKind::Authentication));
            }
            _ => panic!("auth_error must produce Error::Connection"),
        }
        match connection_error("conn failed") {
            Error::Connection(e) => assert!(matches!(e.kind, ConnectionErrorKind::Connect)),
            _ => panic!("connection_error must produce Error::Connection"),
        }
        match query_error_msg("bad") {
            Error::Query(e) => {
                assert!(matches!(e.kind, QueryErrorKind::Syntax));
                assert_eq!(e.message, "bad");
                assert!(e.sqlstate.is_none());
            }
            _ => panic!("query_error_msg must produce Error::Query"),
        }
    }

    #[test]
    fn test_validate_savepoint_name_valid() {
        assert!(validate_savepoint_name("sp1").is_ok());
        assert!(validate_savepoint_name("_savepoint").is_ok());
        assert!(validate_savepoint_name("SavePoint_123").is_ok());
        assert!(validate_savepoint_name("sp$test").is_ok());
        assert!(validate_savepoint_name("a").is_ok());
        assert!(validate_savepoint_name("_").is_ok());
    }

    #[test]
    fn test_validate_savepoint_name_invalid() {
        assert!(validate_savepoint_name("").is_err());

        assert!(validate_savepoint_name("1savepoint").is_err());

        assert!(validate_savepoint_name("save-point").is_err());
        assert!(validate_savepoint_name("save point").is_err());
        assert!(validate_savepoint_name("save;drop table").is_err());
        assert!(validate_savepoint_name("sp'--").is_err());

        let long_name = "a".repeat(65);
        assert!(validate_savepoint_name(&long_name).is_err());
    }

    #[test]
    fn test_validate_savepoint_name_boundaries() {
        // Exactly 64 characters is the longest accepted name.
        assert!(validate_savepoint_name(&"a".repeat(64)).is_ok());
        // `$` is allowed after the first character but not as the first.
        assert!(validate_savepoint_name("$sp").is_err());
        // Non-ASCII characters are rejected in either position.
        assert!(validate_savepoint_name("é").is_err());
        assert!(validate_savepoint_name("sé").is_err());
    }
}