// sqlmodel_mysql/async_connection.rs

1//! Async MySQL connection implementation.
2//!
3//! This module implements the async MySQL connection using asupersync's TCP primitives.
4//! It provides the `Connection` trait implementation for integration with sqlmodel-core.
5
6// Allow `impl Future` return types in trait methods - intentional design for async trait compat
7#![allow(clippy::manual_async_fn)]
8// The Error type is intentionally large to carry full context
9#![allow(clippy::result_large_err)]
10
11use std::collections::HashMap;
12use std::future::Future;
13use std::io::{self, Read as StdRead, Write as StdWrite};
14use std::net::TcpStream as StdTcpStream;
15use std::sync::Arc;
16
17use asupersync::io::{AsyncRead, AsyncWrite, ReadBuf};
18use asupersync::net::TcpStream;
19use asupersync::sync::Mutex;
20use asupersync::{Cx, Outcome};
21
22use sqlmodel_core::connection::{Connection, IsolationLevel, PreparedStatement, TransactionOps};
23use sqlmodel_core::error::{
24    ConnectionError, ConnectionErrorKind, ProtocolError, QueryError, QueryErrorKind,
25};
26use sqlmodel_core::{Error, Row, Value};
27
28#[cfg(feature = "console")]
29use sqlmodel_console::{ConsoleAware, SqlModelConsole};
30
31use crate::auth;
32use crate::config::MySqlConfig;
33use crate::connection::{ConnectionState, ServerCapabilities};
34use crate::protocol::{
35    Command, ErrPacket, MAX_PACKET_SIZE, PacketHeader, PacketReader, PacketType, PacketWriter,
36    capabilities, charset, prepared,
37};
38use crate::types::{
39    ColumnDef, FieldType, decode_binary_value_with_len, decode_text_value, interpolate_params,
40};
41
/// Async MySQL connection.
///
/// This connection uses asupersync's TCP stream for non-blocking I/O
/// and implements the `Connection` trait from sqlmodel-core.
pub struct MySqlAsyncConnection {
    /// Underlying transport: a blocking std TCP stream, an asupersync async TCP stream, or
    /// (with the `tls` feature) an async TLS stream. `None` means no stream is attached;
    /// packet I/O then fails with "Connection stream missing".
    stream: Option<ConnectionStream>,
    /// Current connection state (`connect` moves it Connecting -> Authenticating -> Ready)
    state: ConnectionState,
    /// Server capabilities from the handshake (`None` until the server handshake is read)
    server_caps: Option<ServerCapabilities>,
    /// Connection ID reported in the server handshake (0 until it is received)
    connection_id: u32,
    /// Server status flags
    status_flags: u16,
    /// Affected rows from last statement
    affected_rows: u64,
    /// Last insert ID
    last_insert_id: u64,
    /// Number of warnings
    warnings: u16,
    /// Connection configuration
    config: MySqlConfig,
    /// Packet framing sequence ID: after a read it is one past the last received header's
    /// ID; each written packet uses the current value and then increments it.
    sequence_id: u8,
    /// Prepared statement metadata (keyed by statement ID)
    prepared_stmts: HashMap<u32, PreparedStmtMeta>,
    /// Optional console for rich output
    #[cfg(feature = "console")]
    console: Option<Arc<SqlModelConsole>>,
}
73
/// Metadata for a prepared statement.
///
/// Stores the MySQL-specific information needed to execute
/// and decode results from a prepared statement. Instances live in
/// `MySqlAsyncConnection::prepared_stmts`, keyed by the same statement ID.
#[derive(Debug, Clone)]
struct PreparedStmtMeta {
    /// Server-assigned statement ID (stored for potential future use in close/reset)
    // The allow suggests nothing reads this field yet; it duplicates the map key.
    #[allow(dead_code)]
    statement_id: u32,
    /// Parameter column definitions (for type encoding)
    params: Vec<ColumnDef>,
    /// Result column definitions (for binary decoding)
    columns: Vec<ColumnDef>,
}
88
/// Connection stream wrapper for sync/async compatibility.
///
/// Packet reads and writes dispatch on the variant: `Sync` uses blocking `std::io`
/// calls, `Async` polls the asupersync stream, and `Tls` goes through `AsyncTlsStream`.
// NOTE(review): `connect` only ever constructs `Async`; the `dead_code` allow presumably
// silences warnings for variants not constructed in every build configuration — confirm.
#[allow(dead_code)]
enum ConnectionStream {
    /// Standard blocking TCP stream; its reads/writes block the executing task.
    // NOTE(review): previously documented as "for initial connection", but `connect` opens
    // an `Async` stream — confirm where this variant is actually used.
    Sync(StdTcpStream),
    /// Async TCP stream (for async operations)
    Async(TcpStream),
    /// Async TLS stream (for encrypted async operations)
    #[cfg(feature = "tls")]
    Tls(AsyncTlsStream),
}
100
/// Async TLS stream built on rustls + asupersync TcpStream.
///
/// This is intentionally minimal: it provides enough read/write behavior for
/// MySQL packet framing without depending on a tokio/futures I/O ecosystem.
///
/// Ciphertext is read from the socket in chunks, but `rustls::ClientConnection::read_tls`
/// is allowed to accept only part of the bytes offered per call. Whatever it has not yet
/// accepted is kept in `pending` and offered again later, so received TLS data is never
/// dropped.
#[cfg(feature = "tls")]
struct AsyncTlsStream {
    /// TCP transport carrying the TLS records.
    tcp: TcpStream,
    /// rustls client state (handshake, encryption, decryption).
    tls: rustls::ClientConnection,
    /// Ciphertext already read from `tcp` that `tls.read_tls` has not consumed yet.
    pending: Vec<u8>,
}

#[cfg(feature = "tls")]
impl std::fmt::Debug for AsyncTlsStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncTlsStream")
            .field("protocol_version", &self.tls.protocol_version())
            .field("is_handshaking", &self.tls.is_handshaking())
            .field("pending_ciphertext", &self.pending.len())
            .finish_non_exhaustive()
    }
}

#[cfg(feature = "tls")]
impl AsyncTlsStream {
    /// Size of the scratch buffer used for each socket read of TLS ciphertext.
    const READ_CHUNK: usize = 8192;

    /// Run a TLS client handshake over an already-connected TCP stream.
    ///
    /// The server name used for SNI/verification is `tls_config.server_name` when set,
    /// otherwise `host`. The rustls configuration comes from
    /// `crate::tls::build_client_config(tls_config, ssl_mode)`.
    ///
    /// # Errors
    ///
    /// Returns a connection error if the client config or server name is invalid, if a
    /// socket read/write fails or the peer closes the socket mid-handshake, if rustls
    /// rejects the peer's records, or if the handshake can make no further progress.
    async fn handshake(
        mut tcp: TcpStream,
        tls_config: &crate::config::TlsConfig,
        host: &str,
        ssl_mode: crate::config::SslMode,
    ) -> Result<Self, Error> {
        let config = crate::tls::build_client_config(tls_config, ssl_mode)?;

        let sni = tls_config.server_name.as_deref().unwrap_or(host);
        let server_name = sni
            .to_string()
            .try_into()
            .map_err(|e| connection_error(format!("Invalid server name '{sni}': {e}")))?;

        let mut tls = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(|e| connection_error(format!("Failed to create TLS connection: {e}")))?;

        // Ciphertext received from the server but not yet accepted by rustls.
        let mut pending: Vec<u8> = Vec::new();

        // Drive rustls handshake using async reads/writes on the TCP stream.
        while tls.is_handshaking() {
            Self::send_handshake_records(&mut tls, &mut tcp).await?;

            // Nothing left to send and nothing wanted from the peer: the handshake cannot
            // progress, so fail instead of spinning on this loop forever.
            if !tls.wants_read() {
                return Err(connection_error("TLS handshake stalled"));
            }

            // Only touch the socket once every previously received byte was handed over.
            if pending.is_empty() {
                let mut buf = [0u8; Self::READ_CHUNK];
                let n = read_some_async(&mut tcp, &mut buf).await.map_err(|e| {
                    Error::Connection(ConnectionError {
                        kind: ConnectionErrorKind::Disconnected,
                        message: format!("TLS handshake read error: {e}"),
                        source: Some(Box::new(e)),
                    })
                })?;
                if n == 0 {
                    return Err(connection_error("Connection closed during TLS handshake"));
                }
                pending.extend_from_slice(&buf[..n]);
            }

            // `read_tls` may accept only a prefix of `pending`; keep the remainder for the
            // next iteration instead of discarding it.
            let consumed = tls
                .read_tls(&mut std::io::Cursor::new(pending.as_slice()))
                .map_err(|e| connection_error(format!("TLS handshake read_tls error: {e}")))?;
            if consumed == 0 {
                return Err(connection_error("TLS handshake stalled"));
            }
            pending.drain(..consumed);
            tls.process_new_packets()
                .map_err(|e| connection_error(format!("TLS handshake error: {e}")))?;
        }

        // The last handshake step may leave records queued even though `is_handshaking()`
        // is already false; send them now rather than waiting for the first write.
        Self::send_handshake_records(&mut tls, &mut tcp).await?;

        Ok(Self { tcp, tls, pending })
    }

    /// Write every TLS record rustls has queued to `tcp`, mapping failures to
    /// handshake-specific connection errors.
    async fn send_handshake_records(
        tls: &mut rustls::ClientConnection,
        tcp: &mut TcpStream,
    ) -> Result<(), Error> {
        while tls.wants_write() {
            let mut out = Vec::new();
            tls.write_tls(&mut out)
                .map_err(|e| connection_error(format!("TLS handshake write_tls error: {e}")))?;
            if !out.is_empty() {
                write_all_async(tcp, &out).await.map_err(|e| {
                    Error::Connection(ConnectionError {
                        kind: ConnectionErrorKind::Disconnected,
                        message: format!("TLS handshake write error: {e}"),
                        source: Some(Box::new(e)),
                    })
                })?;
            }
        }
        Ok(())
    }

    /// Fill `buf` completely with decrypted application data.
    ///
    /// # Errors
    ///
    /// `UnexpectedEof` ("connection closed") if the stream ends first, or any TLS/socket
    /// error from `read_plain`.
    async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        let mut read = 0;
        while read < buf.len() {
            let n = self.read_plain(&mut buf[read..]).await?;
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "connection closed",
                ));
            }
            read += n;
        }
        Ok(())
    }

    /// Read decrypted application data into `out`, returning the number of bytes copied.
    ///
    /// Returns `Ok(0)` when no more plaintext can arrive: rustls no longer wants to read,
    /// or the TCP peer closed the socket.
    async fn read_plain(&mut self, out: &mut [u8]) -> io::Result<usize> {
        loop {
            match self.tls.reader().read(out) {
                Ok(n) if n > 0 => return Ok(n),
                Ok(_) => {}
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => return Err(e),
            }

            if !self.tls.wants_read() {
                return Ok(0);
            }

            // Only read from the socket once all previously received ciphertext was handed
            // to rustls.
            if self.pending.is_empty() {
                let mut enc = [0u8; Self::READ_CHUNK];
                let n = read_some_async(&mut self.tcp, &mut enc).await?;
                if n == 0 {
                    return Ok(0);
                }
                self.pending.extend_from_slice(&enc[..n]);
            }

            // `read_tls` may accept only a prefix of `pending`; the rest stays buffered for
            // the next iteration instead of being dropped.
            let consumed = self
                .tls
                .read_tls(&mut std::io::Cursor::new(self.pending.as_slice()))?;
            if consumed == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "TLS layer accepted no data",
                ));
            }
            self.pending.drain(..consumed);
            self.tls
                .process_new_packets()
                .map_err(|e| io::Error::other(format!("TLS error: {e}")))?;
        }
    }

    /// Encrypt and send all of `buf`, flushing after each accepted chunk.
    async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        let mut written = 0;
        while written < buf.len() {
            let n = self.tls.writer().write(&buf[written..])?;
            if n == 0 {
                return Err(io::Error::new(io::ErrorKind::WriteZero, "TLS write zero"));
            }
            written += n;
            self.flush().await?;
        }
        Ok(())
    }

    /// Push every queued TLS record to the socket and flush the TCP stream.
    async fn flush(&mut self) -> io::Result<()> {
        self.tls.writer().flush()?;
        while self.tls.wants_write() {
            let mut out = Vec::new();
            self.tls.write_tls(&mut out)?;
            if !out.is_empty() {
                write_all_async(&mut self.tcp, &out).await?;
            }
        }
        flush_async(&mut self.tcp).await
    }
}
248
/// Perform one async read from `stream` into `buf` and report how many bytes arrived.
///
/// Callers in this module treat a zero-length result as the peer having closed the socket.
#[cfg(feature = "tls")]
async fn read_some_async(stream: &mut TcpStream, buf: &mut [u8]) -> io::Result<usize> {
    let mut target = ReadBuf::new(buf);
    let mut pinned = std::pin::Pin::new(stream);
    std::future::poll_fn(|cx| pinned.as_mut().poll_read(cx, &mut target)).await?;
    let received = target.filled().len();
    Ok(received)
}
256
/// Write the whole of `buf` to `stream`, looping over partial writes.
///
/// # Errors
///
/// Any error from `poll_write`, or `ErrorKind::WriteZero` ("connection closed") if the
/// socket accepts zero bytes.
#[cfg(feature = "tls")]
async fn write_all_async(stream: &mut TcpStream, buf: &[u8]) -> io::Result<()> {
    let mut offset = 0;
    loop {
        let rest = match buf.get(offset..) {
            Some(rest) if !rest.is_empty() => rest,
            _ => return Ok(()),
        };
        let accepted =
            std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_write(cx, rest))
                .await?;
        if accepted == 0 {
            return Err(io::Error::new(io::ErrorKind::WriteZero, "connection closed"));
        }
        offset += accepted;
    }
}
275
/// Flush `stream`, awaiting until the underlying `poll_flush` completes.
#[cfg(feature = "tls")]
async fn flush_async(stream: &mut TcpStream) -> io::Result<()> {
    let mut pinned = std::pin::Pin::new(stream);
    std::future::poll_fn(move |cx| pinned.as_mut().poll_flush(cx)).await
}
280
281impl std::fmt::Debug for MySqlAsyncConnection {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("MySqlAsyncConnection")
284            .field("state", &self.state)
285            .field("connection_id", &self.connection_id)
286            .field("host", &self.config.host)
287            .field("port", &self.config.port)
288            .field("database", &self.config.database)
289            .finish_non_exhaustive()
290    }
291}
292
293impl MySqlAsyncConnection {
294    /// Establish a new async connection to the MySQL server.
295    ///
296    /// This performs the complete connection handshake asynchronously:
297    /// 1. TCP connection
298    /// 2. Receive server handshake
299    /// 3. Send handshake response with authentication
300    /// 4. Handle auth result (possibly auth switch)
301    pub async fn connect(_cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
302        // Use async TCP connect
303        let addr = config.socket_addr();
304        let socket_addr = match addr.parse() {
305            Ok(a) => a,
306            Err(e) => {
307                return Outcome::Err(Error::Connection(ConnectionError {
308                    kind: ConnectionErrorKind::Connect,
309                    message: format!("Invalid socket address: {}", e),
310                    source: None,
311                }));
312            }
313        };
314        let stream = match TcpStream::connect_timeout(socket_addr, config.connect_timeout).await {
315            Ok(s) => s,
316            Err(e) => {
317                let kind = if e.kind() == io::ErrorKind::ConnectionRefused {
318                    ConnectionErrorKind::Refused
319                } else {
320                    ConnectionErrorKind::Connect
321                };
322                return Outcome::Err(Error::Connection(ConnectionError {
323                    kind,
324                    message: format!("Failed to connect to {}: {}", addr, e),
325                    source: Some(Box::new(e)),
326                }));
327            }
328        };
329
330        // Set TCP options
331        stream.set_nodelay(true).ok();
332
333        let mut conn = Self {
334            stream: Some(ConnectionStream::Async(stream)),
335            state: ConnectionState::Connecting,
336            server_caps: None,
337            connection_id: 0,
338            status_flags: 0,
339            affected_rows: 0,
340            last_insert_id: 0,
341            warnings: 0,
342            config,
343            sequence_id: 0,
344            prepared_stmts: HashMap::new(),
345            #[cfg(feature = "console")]
346            console: None,
347        };
348
349        // 2. Receive server handshake
350        match conn.read_handshake_async().await {
351            Outcome::Ok(server_caps) => {
352                conn.connection_id = server_caps.connection_id;
353                conn.server_caps = Some(server_caps);
354                conn.state = ConnectionState::Authenticating;
355            }
356            Outcome::Err(e) => return Outcome::Err(e),
357            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
358            Outcome::Panicked(p) => return Outcome::Panicked(p),
359        }
360
361        // 3. Send handshake response
362        if let Outcome::Err(e) = conn.send_handshake_response_async().await {
363            return Outcome::Err(e);
364        }
365
366        // 4. Handle authentication result
367        if let Outcome::Err(e) = conn.handle_auth_result_async().await {
368            return Outcome::Err(e);
369        }
370
371        conn.state = ConnectionState::Ready;
372        Outcome::Ok(conn)
373    }
374
375    /// Get the current connection state.
376    pub fn state(&self) -> ConnectionState {
377        self.state
378    }
379
380    /// Check if the connection is ready for queries.
381    pub fn is_ready(&self) -> bool {
382        matches!(self.state, ConnectionState::Ready)
383    }
384
385    fn is_secure_transport(&self) -> bool {
386        #[cfg(feature = "tls")]
387        {
388            matches!(self.stream, Some(ConnectionStream::Tls(_)))
389        }
390        #[cfg(not(feature = "tls"))]
391        {
392            false
393        }
394    }
395
396    /// Get the connection ID.
397    pub fn connection_id(&self) -> u32 {
398        self.connection_id
399    }
400
401    /// Get the server version.
402    pub fn server_version(&self) -> Option<&str> {
403        self.server_caps
404            .as_ref()
405            .map(|caps| caps.server_version.as_str())
406    }
407
408    /// Get the number of affected rows from the last statement.
409    pub fn affected_rows(&self) -> u64 {
410        self.affected_rows
411    }
412
413    /// Get the last insert ID.
414    pub fn last_insert_id(&self) -> u64 {
415        self.last_insert_id
416    }
417
418    // === Async I/O methods ===
419
420    /// Read a complete packet from the stream asynchronously.
421    async fn read_packet_async(&mut self) -> Outcome<(Vec<u8>, u8), Error> {
422        // Read header (4 bytes) - must loop since TCP can fragment reads
423        let mut header_buf = [0u8; 4];
424
425        let Some(stream) = self.stream.as_mut() else {
426            return Outcome::Err(connection_error("Connection stream missing"));
427        };
428
429        match stream {
430            ConnectionStream::Async(stream) => {
431                let mut header_read = 0;
432                while header_read < 4 {
433                    let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
434                    match std::future::poll_fn(|cx| {
435                        std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
436                    })
437                    .await
438                    {
439                        Ok(()) => {
440                            let n = read_buf.filled().len();
441                            if n == 0 {
442                                return Outcome::Err(Error::Connection(ConnectionError {
443                                    kind: ConnectionErrorKind::Disconnected,
444                                    message: "Connection closed while reading header".to_string(),
445                                    source: None,
446                                }));
447                            }
448                            header_read += n;
449                        }
450                        Err(e) => {
451                            return Outcome::Err(Error::Connection(ConnectionError {
452                                kind: ConnectionErrorKind::Disconnected,
453                                message: format!("Failed to read packet header: {}", e),
454                                source: Some(Box::new(e)),
455                            }));
456                        }
457                    }
458                }
459            }
460            ConnectionStream::Sync(stream) => {
461                if let Err(e) = stream.read_exact(&mut header_buf) {
462                    return Outcome::Err(Error::Connection(ConnectionError {
463                        kind: ConnectionErrorKind::Disconnected,
464                        message: format!("Failed to read packet header: {}", e),
465                        source: Some(Box::new(e)),
466                    }));
467                }
468            }
469            #[cfg(feature = "tls")]
470            ConnectionStream::Tls(stream) => {
471                if let Err(e) = stream.read_exact(&mut header_buf).await {
472                    return Outcome::Err(Error::Connection(ConnectionError {
473                        kind: ConnectionErrorKind::Disconnected,
474                        message: format!("Failed to read packet header: {e}"),
475                        source: Some(Box::new(e)),
476                    }));
477                }
478            }
479        }
480
481        let header = PacketHeader::from_bytes(&header_buf);
482        let payload_len = header.payload_length as usize;
483        self.sequence_id = header.sequence_id.wrapping_add(1);
484
485        // Read payload
486        let mut payload = vec![0u8; payload_len];
487        if payload_len > 0 {
488            let Some(stream) = self.stream.as_mut() else {
489                return Outcome::Err(connection_error("Connection stream missing"));
490            };
491            match stream {
492                ConnectionStream::Async(stream) => {
493                    let mut total_read = 0;
494                    while total_read < payload_len {
495                        let mut read_buf = ReadBuf::new(&mut payload[total_read..]);
496                        match std::future::poll_fn(|cx| {
497                            std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
498                        })
499                        .await
500                        {
501                            Ok(()) => {
502                                let n = read_buf.filled().len();
503                                if n == 0 {
504                                    return Outcome::Err(Error::Connection(ConnectionError {
505                                        kind: ConnectionErrorKind::Disconnected,
506                                        message: "Connection closed while reading payload"
507                                            .to_string(),
508                                        source: None,
509                                    }));
510                                }
511                                total_read += n;
512                            }
513                            Err(e) => {
514                                return Outcome::Err(Error::Connection(ConnectionError {
515                                    kind: ConnectionErrorKind::Disconnected,
516                                    message: format!("Failed to read packet payload: {}", e),
517                                    source: Some(Box::new(e)),
518                                }));
519                            }
520                        }
521                    }
522                }
523                ConnectionStream::Sync(stream) => {
524                    if let Err(e) = stream.read_exact(&mut payload) {
525                        return Outcome::Err(Error::Connection(ConnectionError {
526                            kind: ConnectionErrorKind::Disconnected,
527                            message: format!("Failed to read packet payload: {}", e),
528                            source: Some(Box::new(e)),
529                        }));
530                    }
531                }
532                #[cfg(feature = "tls")]
533                ConnectionStream::Tls(stream) => {
534                    if let Err(e) = stream.read_exact(&mut payload).await {
535                        return Outcome::Err(Error::Connection(ConnectionError {
536                            kind: ConnectionErrorKind::Disconnected,
537                            message: format!("Failed to read packet payload: {e}"),
538                            source: Some(Box::new(e)),
539                        }));
540                    }
541                }
542            }
543        }
544
545        // Handle multi-packet payloads
546        if payload_len == MAX_PACKET_SIZE {
547            loop {
548                // Read continuation header with loop (TCP can fragment)
549                let mut header_buf = [0u8; 4];
550                let Some(stream) = self.stream.as_mut() else {
551                    return Outcome::Err(connection_error("Connection stream missing"));
552                };
553                match stream {
554                    ConnectionStream::Async(stream) => {
555                        let mut header_read = 0;
556                        while header_read < 4 {
557                            let mut read_buf = ReadBuf::new(&mut header_buf[header_read..]);
558                            match std::future::poll_fn(|cx| {
559                                std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
560                            })
561                            .await
562                            {
563                                Ok(()) => {
564                                    let n = read_buf.filled().len();
565                                    if n == 0 {
566                                        return Outcome::Err(Error::Connection(ConnectionError {
567                                            kind: ConnectionErrorKind::Disconnected,
568                                            message: "Connection closed while reading continuation header".to_string(),
569                                            source: None,
570                                        }));
571                                    }
572                                    header_read += n;
573                                }
574                                Err(e) => {
575                                    return Outcome::Err(Error::Connection(ConnectionError {
576                                        kind: ConnectionErrorKind::Disconnected,
577                                        message: format!(
578                                            "Failed to read continuation header: {}",
579                                            e
580                                        ),
581                                        source: Some(Box::new(e)),
582                                    }));
583                                }
584                            }
585                        }
586                    }
587                    ConnectionStream::Sync(stream) => {
588                        if let Err(e) = stream.read_exact(&mut header_buf) {
589                            return Outcome::Err(Error::Connection(ConnectionError {
590                                kind: ConnectionErrorKind::Disconnected,
591                                message: format!("Failed to read continuation header: {}", e),
592                                source: Some(Box::new(e)),
593                            }));
594                        }
595                    }
596                    #[cfg(feature = "tls")]
597                    ConnectionStream::Tls(stream) => {
598                        if let Err(e) = stream.read_exact(&mut header_buf).await {
599                            return Outcome::Err(Error::Connection(ConnectionError {
600                                kind: ConnectionErrorKind::Disconnected,
601                                message: format!("Failed to read continuation header: {e}"),
602                                source: Some(Box::new(e)),
603                            }));
604                        }
605                    }
606                }
607
608                let cont_header = PacketHeader::from_bytes(&header_buf);
609                let cont_len = cont_header.payload_length as usize;
610                self.sequence_id = cont_header.sequence_id.wrapping_add(1);
611
612                if cont_len > 0 {
613                    let mut cont_payload = vec![0u8; cont_len];
614                    let Some(stream) = self.stream.as_mut() else {
615                        return Outcome::Err(connection_error("Connection stream missing"));
616                    };
617                    match stream {
618                        ConnectionStream::Async(stream) => {
619                            let mut total_read = 0;
620                            while total_read < cont_len {
621                                let mut read_buf = ReadBuf::new(&mut cont_payload[total_read..]);
622                                match std::future::poll_fn(|cx| {
623                                    std::pin::Pin::new(&mut *stream).poll_read(cx, &mut read_buf)
624                                })
625                                .await
626                                {
627                                    Ok(()) => {
628                                        let n = read_buf.filled().len();
629                                        if n == 0 {
630                                            return Outcome::Err(Error::Connection(ConnectionError {
631                                                kind: ConnectionErrorKind::Disconnected,
632                                                message: "Connection closed while reading continuation payload".to_string(),
633                                                source: None,
634                                            }));
635                                        }
636                                        total_read += n;
637                                    }
638                                    Err(e) => {
639                                        return Outcome::Err(Error::Connection(ConnectionError {
640                                            kind: ConnectionErrorKind::Disconnected,
641                                            message: format!(
642                                                "Failed to read continuation payload: {}",
643                                                e
644                                            ),
645                                            source: Some(Box::new(e)),
646                                        }));
647                                    }
648                                }
649                            }
650                        }
651                        ConnectionStream::Sync(stream) => {
652                            if let Err(e) = stream.read_exact(&mut cont_payload) {
653                                return Outcome::Err(Error::Connection(ConnectionError {
654                                    kind: ConnectionErrorKind::Disconnected,
655                                    message: format!("Failed to read continuation payload: {}", e),
656                                    source: Some(Box::new(e)),
657                                }));
658                            }
659                        }
660                        #[cfg(feature = "tls")]
661                        ConnectionStream::Tls(stream) => {
662                            if let Err(e) = stream.read_exact(&mut cont_payload).await {
663                                return Outcome::Err(Error::Connection(ConnectionError {
664                                    kind: ConnectionErrorKind::Disconnected,
665                                    message: format!("Failed to read continuation payload: {e}"),
666                                    source: Some(Box::new(e)),
667                                }));
668                            }
669                        }
670                    }
671                    payload.extend_from_slice(&cont_payload);
672                }
673
674                if cont_len < MAX_PACKET_SIZE {
675                    break;
676                }
677            }
678        }
679
680        Outcome::Ok((payload, header.sequence_id))
681    }
682
683    /// Write a packet to the stream asynchronously.
684    async fn write_packet_async(&mut self, payload: &[u8]) -> Outcome<(), Error> {
685        let writer = PacketWriter::new();
686        let packet = writer.build_packet_from_payload(payload, self.sequence_id);
687        self.sequence_id = self.sequence_id.wrapping_add(1);
688
689        let Some(stream) = self.stream.as_mut() else {
690            return Outcome::Err(connection_error("Connection stream missing"));
691        };
692
693        match stream {
694            ConnectionStream::Async(stream) => {
695                // Loop to handle partial writes (poll_write may return fewer bytes)
696                let mut written = 0;
697                while written < packet.len() {
698                    match std::future::poll_fn(|cx| {
699                        std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
700                    })
701                    .await
702                    {
703                        Ok(n) => {
704                            if n == 0 {
705                                return Outcome::Err(Error::Connection(ConnectionError {
706                                    kind: ConnectionErrorKind::Disconnected,
707                                    message: "Connection closed while writing packet".to_string(),
708                                    source: None,
709                                }));
710                            }
711                            written += n;
712                        }
713                        Err(e) => {
714                            return Outcome::Err(Error::Connection(ConnectionError {
715                                kind: ConnectionErrorKind::Disconnected,
716                                message: format!("Failed to write packet: {}", e),
717                                source: Some(Box::new(e)),
718                            }));
719                        }
720                    }
721                }
722
723                match std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx))
724                    .await
725                {
726                    Ok(()) => {}
727                    Err(e) => {
728                        return Outcome::Err(Error::Connection(ConnectionError {
729                            kind: ConnectionErrorKind::Disconnected,
730                            message: format!("Failed to flush stream: {}", e),
731                            source: Some(Box::new(e)),
732                        }));
733                    }
734                }
735            }
736            ConnectionStream::Sync(stream) => {
737                if let Err(e) = stream.write_all(&packet) {
738                    return Outcome::Err(Error::Connection(ConnectionError {
739                        kind: ConnectionErrorKind::Disconnected,
740                        message: format!("Failed to write packet: {}", e),
741                        source: Some(Box::new(e)),
742                    }));
743                }
744                if let Err(e) = stream.flush() {
745                    return Outcome::Err(Error::Connection(ConnectionError {
746                        kind: ConnectionErrorKind::Disconnected,
747                        message: format!("Failed to flush stream: {}", e),
748                        source: Some(Box::new(e)),
749                    }));
750                }
751            }
752            #[cfg(feature = "tls")]
753            ConnectionStream::Tls(stream) => {
754                if let Err(e) = stream.write_all(&packet).await {
755                    return Outcome::Err(Error::Connection(ConnectionError {
756                        kind: ConnectionErrorKind::Disconnected,
757                        message: format!("Failed to write packet: {e}"),
758                        source: Some(Box::new(e)),
759                    }));
760                }
761                if let Err(e) = stream.flush().await {
762                    return Outcome::Err(Error::Connection(ConnectionError {
763                        kind: ConnectionErrorKind::Disconnected,
764                        message: format!("Failed to flush stream: {e}"),
765                        source: Some(Box::new(e)),
766                    }));
767                }
768            }
769        }
770
771        Outcome::Ok(())
772    }
773
774    // === Handshake methods ===
775
776    /// Read the server handshake packet asynchronously.
777    async fn read_handshake_async(&mut self) -> Outcome<ServerCapabilities, Error> {
778        let (payload, _) = match self.read_packet_async().await {
779            Outcome::Ok(p) => p,
780            Outcome::Err(e) => return Outcome::Err(e),
781            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
782            Outcome::Panicked(p) => return Outcome::Panicked(p),
783        };
784
785        let mut reader = PacketReader::new(&payload);
786
787        // Protocol version
788        let Some(protocol_version) = reader.read_u8() else {
789            return Outcome::Err(protocol_error("Missing protocol version"));
790        };
791
792        if protocol_version != 10 {
793            return Outcome::Err(protocol_error(format!(
794                "Unsupported protocol version: {}",
795                protocol_version
796            )));
797        }
798
799        // Server version (null-terminated string)
800        let Some(server_version) = reader.read_null_string() else {
801            return Outcome::Err(protocol_error("Missing server version"));
802        };
803
804        // Connection ID
805        let Some(connection_id) = reader.read_u32_le() else {
806            return Outcome::Err(protocol_error("Missing connection ID"));
807        };
808
809        // Auth plugin data part 1 (8 bytes)
810        let Some(auth_data_1) = reader.read_bytes(8) else {
811            return Outcome::Err(protocol_error("Missing auth data"));
812        };
813
814        // Filler (1 byte)
815        reader.skip(1);
816
817        // Capability flags (lower 2 bytes)
818        let Some(caps_lower) = reader.read_u16_le() else {
819            return Outcome::Err(protocol_error("Missing capability flags"));
820        };
821
822        // Character set
823        let charset_val = reader.read_u8().unwrap_or(charset::UTF8MB4_0900_AI_CI);
824
825        // Status flags
826        let status_flags = reader.read_u16_le().unwrap_or(0);
827
828        // Capability flags (upper 2 bytes)
829        let caps_upper = reader.read_u16_le().unwrap_or(0);
830        let capabilities_val = u32::from(caps_lower) | (u32::from(caps_upper) << 16);
831
832        // Length of auth-plugin-data (if CLIENT_PLUGIN_AUTH)
833        let auth_data_len = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
834            reader.read_u8().unwrap_or(0) as usize
835        } else {
836            0
837        };
838
839        // Reserved (10 bytes)
840        reader.skip(10);
841
842        // Auth plugin data part 2 (if CLIENT_SECURE_CONNECTION)
843        let mut auth_data = auth_data_1.to_vec();
844        if capabilities_val & capabilities::CLIENT_SECURE_CONNECTION != 0 {
845            let len2 = if auth_data_len > 8 {
846                auth_data_len - 8
847            } else {
848                13 // Default length
849            };
850            if let Some(data2) = reader.read_bytes(len2) {
851                // Remove trailing NUL if present
852                let data2_clean = if data2.last() == Some(&0) {
853                    &data2[..data2.len() - 1]
854                } else {
855                    data2
856                };
857                auth_data.extend_from_slice(data2_clean);
858            }
859        }
860
861        // Auth plugin name (if CLIENT_PLUGIN_AUTH)
862        let auth_plugin = if capabilities_val & capabilities::CLIENT_PLUGIN_AUTH != 0 {
863            reader.read_null_string().unwrap_or_default()
864        } else {
865            auth::plugins::MYSQL_NATIVE_PASSWORD.to_string()
866        };
867
868        Outcome::Ok(ServerCapabilities {
869            capabilities: capabilities_val,
870            protocol_version,
871            server_version,
872            connection_id,
873            auth_plugin,
874            auth_data,
875            charset: charset_val,
876            status_flags,
877        })
878    }
879
880    /// Send the handshake response packet asynchronously.
881    async fn send_handshake_response_async(&mut self) -> Outcome<(), Error> {
882        let Some(server_caps) = self.server_caps.as_ref() else {
883            return Outcome::Err(protocol_error("No server handshake received"));
884        };
885
886        // Grab what we need up-front so we can mutably borrow `self` later.
887        let server_caps_bits = server_caps.capabilities;
888        let auth_plugin = server_caps.auth_plugin.clone();
889        let auth_data = server_caps.auth_data.clone();
890
891        // Determine client capabilities
892        let mut client_caps = self.config.capability_flags() & server_caps_bits;
893        #[cfg(feature = "tls")]
894        if let Outcome::Err(e) = self
895            .maybe_upgrade_tls_async(server_caps_bits, &mut client_caps)
896            .await
897        {
898            return Outcome::Err(e);
899        }
900
901        #[cfg(not(feature = "tls"))]
902        if let Outcome::Err(e) = self.maybe_upgrade_tls(server_caps_bits, &mut client_caps) {
903            return Outcome::Err(e);
904        }
905
906        // Build authentication response
907        let auth_response = self.compute_auth_response(&auth_plugin, &auth_data);
908
909        let mut writer = PacketWriter::new();
910
911        // Client capability flags (4 bytes)
912        writer.write_u32_le(client_caps);
913
914        // Max packet size (4 bytes)
915        writer.write_u32_le(self.config.max_packet_size);
916
917        // Character set (1 byte)
918        writer.write_u8(self.config.charset);
919
920        // Reserved (23 bytes of zeros)
921        writer.write_zeros(23);
922
923        // Username (null-terminated)
924        writer.write_null_string(&self.config.user);
925
926        // Auth response
927        if client_caps & capabilities::CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 {
928            writer.write_lenenc_bytes(&auth_response);
929        } else if client_caps & capabilities::CLIENT_SECURE_CONNECTION != 0 {
930            #[allow(clippy::cast_possible_truncation)]
931            writer.write_u8(auth_response.len() as u8);
932            writer.write_bytes(&auth_response);
933        } else {
934            writer.write_bytes(&auth_response);
935            writer.write_u8(0); // Null terminator
936        }
937
938        // Database (if CLIENT_CONNECT_WITH_DB)
939        if client_caps & capabilities::CLIENT_CONNECT_WITH_DB != 0 {
940            if let Some(ref db) = self.config.database {
941                writer.write_null_string(db);
942            } else {
943                writer.write_u8(0); // Empty string
944            }
945        }
946
947        // Auth plugin name (if CLIENT_PLUGIN_AUTH)
948        if client_caps & capabilities::CLIENT_PLUGIN_AUTH != 0 {
949            writer.write_null_string(&auth_plugin);
950        }
951
952        // Connection attributes (if CLIENT_CONNECT_ATTRS)
953        if client_caps & capabilities::CLIENT_CONNECT_ATTRS != 0
954            && !self.config.attributes.is_empty()
955        {
956            let mut attrs_writer = PacketWriter::new();
957            for (key, value) in &self.config.attributes {
958                attrs_writer.write_lenenc_string(key);
959                attrs_writer.write_lenenc_string(value);
960            }
961            let attrs_data = attrs_writer.into_bytes();
962            writer.write_lenenc_bytes(&attrs_data);
963        }
964
965        self.write_packet_async(writer.as_bytes()).await
966    }
967
968    #[cfg(feature = "tls")]
969    async fn maybe_upgrade_tls_async(
970        &mut self,
971        server_caps: u32,
972        client_caps: &mut u32,
973    ) -> Outcome<(), Error> {
974        let ssl_mode = self.config.ssl_mode;
975
976        if !ssl_mode.should_try_ssl() {
977            *client_caps &= !capabilities::CLIENT_SSL;
978            return Outcome::Ok(());
979        }
980
981        let use_tls = match crate::tls::validate_ssl_mode(ssl_mode, server_caps) {
982            Ok(v) => v,
983            Err(e) => return Outcome::Err(e),
984        };
985
986        if !use_tls {
987            // Preferred but server doesn't support SSL: clear the bit so we don't lie.
988            *client_caps &= !capabilities::CLIENT_SSL;
989            return Outcome::Ok(());
990        }
991
992        if let Err(e) = crate::tls::validate_tls_config(ssl_mode, &self.config.tls_config) {
993            return Outcome::Err(e);
994        }
995
996        // Send SSLRequest packet (sequence id 1), then upgrade to TLS and continue
997        // the normal handshake (sequence id 2) over the encrypted stream.
998        let packet = crate::tls::build_ssl_request_packet(
999            *client_caps,
1000            self.config.max_packet_size,
1001            self.config.charset,
1002            self.sequence_id,
1003        );
1004        if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1005            return Outcome::Err(e);
1006        }
1007        self.sequence_id = self.sequence_id.wrapping_add(1);
1008
1009        let Some(stream) = self.stream.take() else {
1010            return Outcome::Err(connection_error("Connection stream missing"));
1011        };
1012        let ConnectionStream::Async(tcp) = stream else {
1013            return Outcome::Err(connection_error("TLS upgrade requires async TCP stream"));
1014        };
1015
1016        let tls = match AsyncTlsStream::handshake(
1017            tcp,
1018            &self.config.tls_config,
1019            &self.config.host,
1020            ssl_mode,
1021        )
1022        .await
1023        {
1024            Ok(s) => s,
1025            Err(e) => return Outcome::Err(e),
1026        };
1027
1028        self.stream = Some(ConnectionStream::Tls(tls));
1029        Outcome::Ok(())
1030    }
1031
1032    #[cfg(not(feature = "tls"))]
1033    fn maybe_upgrade_tls(&mut self, server_caps: u32, client_caps: &mut u32) -> Outcome<(), Error> {
1034        let ssl_mode = self.config.ssl_mode;
1035
1036        if !ssl_mode.should_try_ssl() {
1037            *client_caps &= !capabilities::CLIENT_SSL;
1038            return Outcome::Ok(());
1039        }
1040
1041        let use_tls = match crate::tls::validate_ssl_mode(ssl_mode, server_caps) {
1042            Ok(v) => v,
1043            Err(e) => return Outcome::Err(e),
1044        };
1045
1046        if !use_tls {
1047            // Preferred but server doesn't support SSL: clear the bit so we don't lie.
1048            *client_caps &= !capabilities::CLIENT_SSL;
1049            return Outcome::Ok(());
1050        }
1051
1052        // Preferred should fall back to plain, required/verify must error.
1053        if ssl_mode == crate::config::SslMode::Preferred {
1054            *client_caps &= !capabilities::CLIENT_SSL;
1055            Outcome::Ok(())
1056        } else {
1057            Outcome::Err(connection_error(
1058                "TLS requested but 'sqlmodel-mysql' was built without feature 'tls'",
1059            ))
1060        }
1061    }
1062
1063    /// Compute authentication response based on the plugin.
1064    fn compute_auth_response(&self, plugin: &str, auth_data: &[u8]) -> Vec<u8> {
1065        let pw = self.config.password_str();
1066
1067        match plugin {
1068            // UBS secret-heuristic false-positive: it matches `PASSWORD\s*=`; the block comment breaks that pattern.
1069            auth::plugins::MYSQL_NATIVE_PASSWORD /* ubs-fp */ => {
1070                auth::mysql_native_password(pw, auth_data)
1071            }
1072            auth::plugins::CACHING_SHA2_PASSWORD /* ubs-fp */ => {
1073                auth::caching_sha2_password(pw, auth_data)
1074            }
1075            auth::plugins::MYSQL_CLEAR_PASSWORD /* ubs-fp */ => {
1076                let mut result = pw.as_bytes().to_vec();
1077                result.push(0);
1078                result
1079            }
1080            _ => auth::mysql_native_password(pw, auth_data),
1081        }
1082    }
1083
1084    /// Handle authentication result asynchronously.
1085    /// Uses a loop to handle auth switches without recursion.
1086    async fn handle_auth_result_async(&mut self) -> Outcome<(), Error> {
1087        // Loop to handle potential auth switches without recursion
1088        loop {
1089            let (payload, _) = match self.read_packet_async().await {
1090                Outcome::Ok(p) => p,
1091                Outcome::Err(e) => return Outcome::Err(e),
1092                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1093                Outcome::Panicked(p) => return Outcome::Panicked(p),
1094            };
1095
1096            if payload.is_empty() {
1097                return Outcome::Err(protocol_error("Empty authentication response"));
1098            }
1099
1100            #[allow(clippy::cast_possible_truncation)] // MySQL packets are max 16MB
1101            match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1102                PacketType::Ok => {
1103                    let mut reader = PacketReader::new(&payload);
1104                    if let Some(ok) = reader.parse_ok_packet() {
1105                        self.status_flags = ok.status_flags;
1106                        self.affected_rows = ok.affected_rows;
1107                    }
1108                    return Outcome::Ok(());
1109                }
1110                PacketType::Error => {
1111                    let mut reader = PacketReader::new(&payload);
1112                    let Some(err) = reader.parse_err_packet() else {
1113                        return Outcome::Err(protocol_error("Invalid error packet"));
1114                    };
1115                    return Outcome::Err(auth_error(format!(
1116                        "Authentication failed: {} ({})",
1117                        err.error_message, err.error_code
1118                    )));
1119                }
1120                PacketType::Eof => {
1121                    // Auth switch request - handle inline to avoid recursion
1122                    let data = &payload[1..];
1123                    let mut reader = PacketReader::new(data);
1124
1125                    let Some(plugin) = reader.read_null_string() else {
1126                        return Outcome::Err(protocol_error("Missing plugin name in auth switch"));
1127                    };
1128
1129                    let auth_data = reader.read_rest();
1130                    let response = self.compute_auth_response(&plugin, auth_data);
1131
1132                    if let Outcome::Err(e) = self.write_packet_async(&response).await {
1133                        return Outcome::Err(e);
1134                    }
1135                    // Continue loop to read next auth result
1136                }
1137                _ => {
1138                    // Handle additional auth data
1139                    match self.handle_additional_auth_async(&payload).await {
1140                        Outcome::Ok(()) => continue,
1141                        Outcome::Err(e) => return Outcome::Err(e),
1142                        Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1143                        Outcome::Panicked(p) => return Outcome::Panicked(p),
1144                    }
1145                }
1146            }
1147        }
1148    }
1149
1150    /// Handle additional auth data asynchronously.
1151    async fn handle_additional_auth_async(&mut self, data: &[u8]) -> Outcome<(), Error> {
1152        if data.is_empty() {
1153            return Outcome::Err(protocol_error("Empty additional auth data"));
1154        }
1155
1156        match data[0] {
1157            auth::caching_sha2::FAST_AUTH_SUCCESS => {
1158                // Server will send the final OK packet next; leave it for the main auth loop.
1159                Outcome::Ok(())
1160            }
1161            auth::caching_sha2::PERFORM_FULL_AUTH => {
1162                let Some(server_caps) = self.server_caps.as_ref() else {
1163                    return Outcome::Err(protocol_error("Missing server capabilities during auth"));
1164                };
1165
1166                let pw = self.config.password_owned();
1167                let seed = server_caps.auth_data.clone();
1168                let server_version = server_caps.server_version.clone();
1169
1170                if self.is_secure_transport() {
1171                    // On a secure transport (TLS), caching_sha2_password allows sending the
1172                    // password as a NUL-terminated string.
1173                    let mut clear = pw.as_bytes().to_vec();
1174                    clear.push(0);
1175                    if let Outcome::Err(e) = self.write_packet_async(&clear).await {
1176                        return Outcome::Err(e);
1177                    }
1178                    Outcome::Ok(())
1179                } else {
1180                    // Request server public key then send RSA-encrypted password.
1181                    if let Outcome::Err(e) = self
1182                        .write_packet_async(&[auth::caching_sha2::REQUEST_PUBLIC_KEY])
1183                        .await
1184                    {
1185                        return Outcome::Err(e);
1186                    }
1187
1188                    let (payload, _) = match self.read_packet_async().await {
1189                        Outcome::Ok(p) => p,
1190                        Outcome::Err(e) => return Outcome::Err(e),
1191                        Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1192                        Outcome::Panicked(p) => return Outcome::Panicked(p),
1193                    };
1194                    if payload.is_empty() {
1195                        return Outcome::Err(protocol_error("Empty public key response"));
1196                    }
1197
1198                    // Some servers wrap the PEM in an AuthMoreData packet (0x01 prefix).
1199                    let public_key = if payload[0] == 0x01 {
1200                        &payload[1..]
1201                    } else {
1202                        &payload[..]
1203                    };
1204
1205                    let use_oaep = mysql_server_uses_oaep(&server_version);
1206                    let encrypted =
1207                        match auth::sha256_password_rsa(&pw, &seed, public_key, use_oaep) {
1208                            Ok(v) => v,
1209                            Err(e) => return Outcome::Err(auth_error(e)),
1210                        };
1211
1212                    if let Outcome::Err(e) = self.write_packet_async(&encrypted).await {
1213                        return Outcome::Err(e);
1214                    }
1215                    Outcome::Ok(())
1216                }
1217            }
1218            _ => Outcome::Err(protocol_error(format!(
1219                "Unknown additional auth response: {:02X}",
1220                data[0]
1221            ))),
1222        }
1223    }
1224
1225    /// Execute a text protocol query asynchronously.
1226    pub async fn query_async(
1227        &mut self,
1228        _cx: &Cx,
1229        sql: &str,
1230        params: &[Value],
1231    ) -> Outcome<Vec<Row>, Error> {
1232        let sql = interpolate_params(sql, params);
1233        if !self.is_ready() && self.state != ConnectionState::InTransaction {
1234            return Outcome::Err(connection_error("Connection not ready for queries"));
1235        }
1236
1237        self.state = ConnectionState::InQuery;
1238        self.sequence_id = 0;
1239
1240        // Send COM_QUERY
1241        let mut writer = PacketWriter::new();
1242        writer.write_u8(Command::Query as u8);
1243        writer.write_bytes(sql.as_bytes());
1244
1245        if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
1246            return Outcome::Err(e);
1247        }
1248
1249        // Read response
1250        let (payload, _) = match self.read_packet_async().await {
1251            Outcome::Ok(p) => p,
1252            Outcome::Err(e) => return Outcome::Err(e),
1253            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1254            Outcome::Panicked(p) => return Outcome::Panicked(p),
1255        };
1256
1257        if payload.is_empty() {
1258            self.state = ConnectionState::Ready;
1259            return Outcome::Err(protocol_error("Empty query response"));
1260        }
1261
1262        #[allow(clippy::cast_possible_truncation)] // MySQL packets are max 16MB
1263        match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1264            PacketType::Ok => {
1265                let mut reader = PacketReader::new(&payload);
1266                if let Some(ok) = reader.parse_ok_packet() {
1267                    self.affected_rows = ok.affected_rows;
1268                    self.last_insert_id = ok.last_insert_id;
1269                    self.status_flags = ok.status_flags;
1270                    self.warnings = ok.warnings;
1271                }
1272                self.state = if self.status_flags
1273                    & crate::protocol::server_status::SERVER_STATUS_IN_TRANS
1274                    != 0
1275                {
1276                    ConnectionState::InTransaction
1277                } else {
1278                    ConnectionState::Ready
1279                };
1280                Outcome::Ok(vec![])
1281            }
1282            PacketType::Error => {
1283                self.state = ConnectionState::Ready;
1284                let mut reader = PacketReader::new(&payload);
1285                let Some(err) = reader.parse_err_packet() else {
1286                    return Outcome::Err(protocol_error("Invalid error packet"));
1287                };
1288                Outcome::Err(query_error(&err))
1289            }
1290            PacketType::LocalInfile => {
1291                self.state = ConnectionState::Ready;
1292                Outcome::Err(query_error_msg("LOCAL INFILE not supported"))
1293            }
1294            _ => self.read_result_set_async(&payload).await,
1295        }
1296    }
1297
1298    /// Read a result set asynchronously.
1299    async fn read_result_set_async(&mut self, first_packet: &[u8]) -> Outcome<Vec<Row>, Error> {
1300        let mut reader = PacketReader::new(first_packet);
1301        #[allow(clippy::cast_possible_truncation)] // Column count fits in usize
1302        let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
1303            return Outcome::Err(protocol_error("Invalid column count"));
1304        };
1305
1306        // Read column definitions
1307        let mut columns = Vec::with_capacity(column_count);
1308        for _ in 0..column_count {
1309            let (payload, _) = match self.read_packet_async().await {
1310                Outcome::Ok(p) => p,
1311                Outcome::Err(e) => return Outcome::Err(e),
1312                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1313                Outcome::Panicked(p) => return Outcome::Panicked(p),
1314            };
1315            match self.parse_column_def(&payload) {
1316                Ok(col) => columns.push(col),
1317                Err(e) => return Outcome::Err(e),
1318            }
1319        }
1320
1321        // Check for EOF packet
1322        let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1323        if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1324            let (payload, _) = match self.read_packet_async().await {
1325                Outcome::Ok(p) => p,
1326                Outcome::Err(e) => return Outcome::Err(e),
1327                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1328                Outcome::Panicked(p) => return Outcome::Panicked(p),
1329            };
1330            if payload.first() == Some(&0xFE) {
1331                // EOF packet - continue to rows
1332            }
1333        }
1334
1335        // Read rows until EOF or OK
1336        let mut rows = Vec::new();
1337        loop {
1338            let (payload, _) = match self.read_packet_async().await {
1339                Outcome::Ok(p) => p,
1340                Outcome::Err(e) => return Outcome::Err(e),
1341                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1342                Outcome::Panicked(p) => return Outcome::Panicked(p),
1343            };
1344
1345            if payload.is_empty() {
1346                break;
1347            }
1348
1349            #[allow(clippy::cast_possible_truncation)] // MySQL packets are max 16MB
1350            match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1351                PacketType::Eof | PacketType::Ok => {
1352                    let mut reader = PacketReader::new(&payload);
1353                    if payload[0] == 0x00 {
1354                        if let Some(ok) = reader.parse_ok_packet() {
1355                            self.status_flags = ok.status_flags;
1356                            self.warnings = ok.warnings;
1357                        }
1358                    } else if payload[0] == 0xFE {
1359                        if let Some(eof) = reader.parse_eof_packet() {
1360                            self.status_flags = eof.status_flags;
1361                            self.warnings = eof.warnings;
1362                        }
1363                    }
1364                    break;
1365                }
1366                PacketType::Error => {
1367                    let mut reader = PacketReader::new(&payload);
1368                    let Some(err) = reader.parse_err_packet() else {
1369                        return Outcome::Err(protocol_error("Invalid error packet"));
1370                    };
1371                    self.state = ConnectionState::Ready;
1372                    return Outcome::Err(query_error(&err));
1373                }
1374                _ => {
1375                    let row = self.parse_text_row(&payload, &columns);
1376                    rows.push(row);
1377                }
1378            }
1379        }
1380
1381        self.state =
1382            if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
1383                ConnectionState::InTransaction
1384            } else {
1385                ConnectionState::Ready
1386            };
1387
1388        Outcome::Ok(rows)
1389    }
1390
1391    /// Parse a column definition packet.
1392    fn parse_column_def(&self, data: &[u8]) -> Result<ColumnDef, Error> {
1393        let mut reader = PacketReader::new(data);
1394
1395        let catalog = reader
1396            .read_lenenc_string()
1397            .ok_or_else(|| protocol_error("Missing catalog"))?;
1398        let schema = reader
1399            .read_lenenc_string()
1400            .ok_or_else(|| protocol_error("Missing schema"))?;
1401        let table = reader
1402            .read_lenenc_string()
1403            .ok_or_else(|| protocol_error("Missing table"))?;
1404        let org_table = reader
1405            .read_lenenc_string()
1406            .ok_or_else(|| protocol_error("Missing org_table"))?;
1407        let name = reader
1408            .read_lenenc_string()
1409            .ok_or_else(|| protocol_error("Missing name"))?;
1410        let org_name = reader
1411            .read_lenenc_string()
1412            .ok_or_else(|| protocol_error("Missing org_name"))?;
1413
1414        let _fixed_len = reader.read_lenenc_int();
1415
1416        let charset_val = reader
1417            .read_u16_le()
1418            .ok_or_else(|| protocol_error("Missing charset"))?;
1419        let column_length = reader
1420            .read_u32_le()
1421            .ok_or_else(|| protocol_error("Missing column_length"))?;
1422        let column_type = FieldType::from_u8(
1423            reader
1424                .read_u8()
1425                .ok_or_else(|| protocol_error("Missing column_type"))?,
1426        );
1427        let flags = reader
1428            .read_u16_le()
1429            .ok_or_else(|| protocol_error("Missing flags"))?;
1430        let decimals = reader
1431            .read_u8()
1432            .ok_or_else(|| protocol_error("Missing decimals"))?;
1433
1434        Ok(ColumnDef {
1435            catalog,
1436            schema,
1437            table,
1438            org_table,
1439            name,
1440            org_name,
1441            charset: charset_val,
1442            column_length,
1443            column_type,
1444            flags,
1445            decimals,
1446        })
1447    }
1448
1449    /// Parse a text protocol row.
1450    fn parse_text_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
1451        let mut reader = PacketReader::new(data);
1452        let mut values = Vec::with_capacity(columns.len());
1453
1454        for col in columns {
1455            if reader.peek() == Some(0xFB) {
1456                reader.skip(1);
1457                values.push(Value::Null);
1458            } else if let Some(data) = reader.read_lenenc_bytes() {
1459                let is_unsigned = col.is_unsigned();
1460                let value = decode_text_value(col.column_type, &data, is_unsigned);
1461                values.push(value);
1462            } else {
1463                values.push(Value::Null);
1464            }
1465        }
1466
1467        let column_names: Vec<String> = columns.iter().map(|c| c.name.clone()).collect();
1468        Row::new(column_names, values)
1469    }
1470
1471    /// Execute a statement asynchronously and return affected rows.
1472    ///
1473    /// This is similar to `query_async` but returns the number of affected rows
1474    /// instead of the result set. Useful for INSERT, UPDATE, DELETE statements.
1475    pub async fn execute_async(
1476        &mut self,
1477        cx: &Cx,
1478        sql: &str,
1479        params: &[Value],
1480    ) -> Outcome<u64, Error> {
1481        // Execute the query
1482        match self.query_async(cx, sql, params).await {
1483            Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
1484            Outcome::Err(e) => Outcome::Err(e),
1485            Outcome::Cancelled(c) => Outcome::Cancelled(c),
1486            Outcome::Panicked(p) => Outcome::Panicked(p),
1487        }
1488    }
1489
1490    /// Prepare a statement for later execution using the binary protocol.
1491    ///
1492    /// This sends COM_STMT_PREPARE to the server and stores the metadata
1493    /// needed for later execution via `query_prepared_async` or `execute_prepared_async`.
1494    pub async fn prepare_async(
1495        &mut self,
1496        _cx: &Cx,
1497        sql: &str,
1498    ) -> Outcome<PreparedStatement, Error> {
1499        if !self.is_ready() && self.state != ConnectionState::InTransaction {
1500            return Outcome::Err(connection_error("Connection not ready for prepare"));
1501        }
1502
1503        self.sequence_id = 0;
1504
1505        // Send COM_STMT_PREPARE
1506        let packet = prepared::build_stmt_prepare_packet(sql, self.sequence_id);
1507        if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1508            return Outcome::Err(e);
1509        }
1510
1511        // Read response
1512        let (payload, _) = match self.read_packet_async().await {
1513            Outcome::Ok(p) => p,
1514            Outcome::Err(e) => return Outcome::Err(e),
1515            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1516            Outcome::Panicked(p) => return Outcome::Panicked(p),
1517        };
1518
1519        // Check for error
1520        if payload.first() == Some(&0xFF) {
1521            let mut reader = PacketReader::new(&payload);
1522            let Some(err) = reader.parse_err_packet() else {
1523                return Outcome::Err(protocol_error("Invalid error packet"));
1524            };
1525            return Outcome::Err(query_error(&err));
1526        }
1527
1528        // Parse COM_STMT_PREPARE_OK
1529        let Some(prep_ok) = prepared::parse_stmt_prepare_ok(&payload) else {
1530            return Outcome::Err(protocol_error("Invalid prepare OK response"));
1531        };
1532
1533        // Read parameter column definitions
1534        let mut param_defs = Vec::with_capacity(prep_ok.num_params as usize);
1535        for _ in 0..prep_ok.num_params {
1536            let (payload, _) = match self.read_packet_async().await {
1537                Outcome::Ok(p) => p,
1538                Outcome::Err(e) => return Outcome::Err(e),
1539                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1540                Outcome::Panicked(p) => return Outcome::Panicked(p),
1541            };
1542            match self.parse_column_def(&payload) {
1543                Ok(col) => param_defs.push(col),
1544                Err(e) => return Outcome::Err(e),
1545            }
1546        }
1547
1548        // Read EOF after params (if not CLIENT_DEPRECATE_EOF)
1549        let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1550        if prep_ok.num_params > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1551            let (payload, _) = match self.read_packet_async().await {
1552                Outcome::Ok(p) => p,
1553                Outcome::Err(e) => return Outcome::Err(e),
1554                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1555                Outcome::Panicked(p) => return Outcome::Panicked(p),
1556            };
1557            if payload.first() != Some(&0xFE) {
1558                return Outcome::Err(protocol_error("Expected EOF after param definitions"));
1559            }
1560        }
1561
1562        // Read column definitions
1563        let mut column_defs = Vec::with_capacity(prep_ok.num_columns as usize);
1564        for _ in 0..prep_ok.num_columns {
1565            let (payload, _) = match self.read_packet_async().await {
1566                Outcome::Ok(p) => p,
1567                Outcome::Err(e) => return Outcome::Err(e),
1568                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1569                Outcome::Panicked(p) => return Outcome::Panicked(p),
1570            };
1571            match self.parse_column_def(&payload) {
1572                Ok(col) => column_defs.push(col),
1573                Err(e) => return Outcome::Err(e),
1574            }
1575        }
1576
1577        // Read EOF after columns (if not CLIENT_DEPRECATE_EOF)
1578        if prep_ok.num_columns > 0 && server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1579            let (payload, _) = match self.read_packet_async().await {
1580                Outcome::Ok(p) => p,
1581                Outcome::Err(e) => return Outcome::Err(e),
1582                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1583                Outcome::Panicked(p) => return Outcome::Panicked(p),
1584            };
1585            if payload.first() != Some(&0xFE) {
1586                return Outcome::Err(protocol_error("Expected EOF after column definitions"));
1587            }
1588        }
1589
1590        // Store metadata
1591        let meta = PreparedStmtMeta {
1592            statement_id: prep_ok.statement_id,
1593            params: param_defs,
1594            columns: column_defs.clone(),
1595        };
1596        self.prepared_stmts.insert(prep_ok.statement_id, meta);
1597
1598        // Return core PreparedStatement
1599        let column_names: Vec<String> = column_defs.iter().map(|c| c.name.clone()).collect();
1600        Outcome::Ok(PreparedStatement::with_columns(
1601            u64::from(prep_ok.statement_id),
1602            sql.to_string(),
1603            prep_ok.num_params as usize,
1604            column_names,
1605        ))
1606    }
1607
1608    /// Execute a prepared statement and return result rows (binary protocol).
1609    pub async fn query_prepared_async(
1610        &mut self,
1611        _cx: &Cx,
1612        stmt: &PreparedStatement,
1613        params: &[Value],
1614    ) -> Outcome<Vec<Row>, Error> {
1615        #[allow(clippy::cast_possible_truncation)] // Statement IDs are u32 in MySQL
1616        let stmt_id = stmt.id() as u32;
1617
1618        // Look up metadata
1619        let Some(meta) = self.prepared_stmts.get(&stmt_id).cloned() else {
1620            return Outcome::Err(connection_error("Unknown prepared statement"));
1621        };
1622
1623        // Verify param count
1624        if params.len() != meta.params.len() {
1625            return Outcome::Err(connection_error(format!(
1626                "Expected {} parameters, got {}",
1627                meta.params.len(),
1628                params.len()
1629            )));
1630        }
1631
1632        if !self.is_ready() && self.state != ConnectionState::InTransaction {
1633            return Outcome::Err(connection_error("Connection not ready for query"));
1634        }
1635
1636        self.state = ConnectionState::InQuery;
1637        self.sequence_id = 0;
1638
1639        // Build and send COM_STMT_EXECUTE
1640        let param_types: Vec<FieldType> = meta.params.iter().map(|c| c.column_type).collect();
1641        let packet = prepared::build_stmt_execute_packet(
1642            stmt_id,
1643            params,
1644            Some(&param_types),
1645            self.sequence_id,
1646        );
1647        if let Outcome::Err(e) = self.write_packet_raw_async(&packet).await {
1648            return Outcome::Err(e);
1649        }
1650
1651        // Read response
1652        let (payload, _) = match self.read_packet_async().await {
1653            Outcome::Ok(p) => p,
1654            Outcome::Err(e) => return Outcome::Err(e),
1655            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1656            Outcome::Panicked(p) => return Outcome::Panicked(p),
1657        };
1658
1659        if payload.is_empty() {
1660            self.state = ConnectionState::Ready;
1661            return Outcome::Err(protocol_error("Empty execute response"));
1662        }
1663
1664        #[allow(clippy::cast_possible_truncation)] // MySQL packets are max 16MB
1665        match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1666            PacketType::Ok => {
1667                // Non-SELECT statement - parse OK packet
1668                let mut reader = PacketReader::new(&payload);
1669                if let Some(ok) = reader.parse_ok_packet() {
1670                    self.affected_rows = ok.affected_rows;
1671                    self.last_insert_id = ok.last_insert_id;
1672                    self.status_flags = ok.status_flags;
1673                    self.warnings = ok.warnings;
1674                }
1675                self.state = ConnectionState::Ready;
1676                Outcome::Ok(vec![])
1677            }
1678            PacketType::Error => {
1679                self.state = ConnectionState::Ready;
1680                let mut reader = PacketReader::new(&payload);
1681                let Some(err) = reader.parse_err_packet() else {
1682                    return Outcome::Err(protocol_error("Invalid error packet"));
1683                };
1684                Outcome::Err(query_error(&err))
1685            }
1686            _ => {
1687                // Result set - read binary protocol rows
1688                self.read_binary_result_set_async(&payload, &meta.columns)
1689                    .await
1690            }
1691        }
1692    }
1693
1694    /// Execute a prepared statement and return affected row count.
1695    pub async fn execute_prepared_async(
1696        &mut self,
1697        cx: &Cx,
1698        stmt: &PreparedStatement,
1699        params: &[Value],
1700    ) -> Outcome<u64, Error> {
1701        match self.query_prepared_async(cx, stmt, params).await {
1702            Outcome::Ok(_) => Outcome::Ok(self.affected_rows),
1703            Outcome::Err(e) => Outcome::Err(e),
1704            Outcome::Cancelled(c) => Outcome::Cancelled(c),
1705            Outcome::Panicked(p) => Outcome::Panicked(p),
1706        }
1707    }
1708
1709    /// Close a prepared statement.
1710    pub async fn close_prepared_async(&mut self, stmt: &PreparedStatement) {
1711        #[allow(clippy::cast_possible_truncation)] // Statement IDs are u32 in MySQL
1712        let stmt_id = stmt.id() as u32;
1713        self.prepared_stmts.remove(&stmt_id);
1714
1715        self.sequence_id = 0;
1716        let packet = prepared::build_stmt_close_packet(stmt_id, self.sequence_id);
1717        // Best effort - no response expected
1718        let _ = self.write_packet_raw_async(&packet).await;
1719    }
1720
1721    /// Read a binary protocol result set.
1722    async fn read_binary_result_set_async(
1723        &mut self,
1724        first_packet: &[u8],
1725        columns: &[ColumnDef],
1726    ) -> Outcome<Vec<Row>, Error> {
1727        // First packet contains column count
1728        let mut reader = PacketReader::new(first_packet);
1729        #[allow(clippy::cast_possible_truncation)] // Column count fits in usize
1730        let Some(column_count) = reader.read_lenenc_int().map(|c| c as usize) else {
1731            return Outcome::Err(protocol_error("Invalid column count"));
1732        };
1733
1734        // The column definitions were already provided from prepare
1735        // But server sends them again in binary result set - we need to read them
1736        let mut result_columns = Vec::with_capacity(column_count);
1737        for _ in 0..column_count {
1738            let (payload, _) = match self.read_packet_async().await {
1739                Outcome::Ok(p) => p,
1740                Outcome::Err(e) => return Outcome::Err(e),
1741                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1742                Outcome::Panicked(p) => return Outcome::Panicked(p),
1743            };
1744            match self.parse_column_def(&payload) {
1745                Ok(col) => result_columns.push(col),
1746                Err(e) => return Outcome::Err(e),
1747            }
1748        }
1749
1750        // Use the columns from the result set if available, otherwise use prepared metadata
1751        let cols = if result_columns.len() == columns.len() {
1752            &result_columns
1753        } else {
1754            columns
1755        };
1756
1757        // Check for EOF packet
1758        let server_caps = self.server_caps.as_ref().map_or(0, |c| c.capabilities);
1759        if server_caps & capabilities::CLIENT_DEPRECATE_EOF == 0 {
1760            let (payload, _) = match self.read_packet_async().await {
1761                Outcome::Ok(p) => p,
1762                Outcome::Err(e) => return Outcome::Err(e),
1763                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1764                Outcome::Panicked(p) => return Outcome::Panicked(p),
1765            };
1766            if payload.first() == Some(&0xFE) {
1767                // EOF packet - continue to rows
1768            }
1769        }
1770
1771        // Read binary rows until EOF or OK
1772        let mut rows = Vec::new();
1773        loop {
1774            let (payload, _) = match self.read_packet_async().await {
1775                Outcome::Ok(p) => p,
1776                Outcome::Err(e) => return Outcome::Err(e),
1777                Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1778                Outcome::Panicked(p) => return Outcome::Panicked(p),
1779            };
1780
1781            if payload.is_empty() {
1782                break;
1783            }
1784
1785            #[allow(clippy::cast_possible_truncation)] // MySQL packets are max 16MB
1786            match PacketType::from_first_byte(payload[0], payload.len() as u32) {
1787                PacketType::Eof | PacketType::Ok => {
1788                    let mut reader = PacketReader::new(&payload);
1789                    if payload[0] == 0x00 {
1790                        if let Some(ok) = reader.parse_ok_packet() {
1791                            self.status_flags = ok.status_flags;
1792                            self.warnings = ok.warnings;
1793                        }
1794                    } else if payload[0] == 0xFE {
1795                        if let Some(eof) = reader.parse_eof_packet() {
1796                            self.status_flags = eof.status_flags;
1797                            self.warnings = eof.warnings;
1798                        }
1799                    }
1800                    break;
1801                }
1802                PacketType::Error => {
1803                    let mut reader = PacketReader::new(&payload);
1804                    let Some(err) = reader.parse_err_packet() else {
1805                        return Outcome::Err(protocol_error("Invalid error packet"));
1806                    };
1807                    self.state = ConnectionState::Ready;
1808                    return Outcome::Err(query_error(&err));
1809                }
1810                _ => {
1811                    let row = self.parse_binary_row(&payload, cols);
1812                    rows.push(row);
1813                }
1814            }
1815        }
1816
1817        self.state =
1818            if self.status_flags & crate::protocol::server_status::SERVER_STATUS_IN_TRANS != 0 {
1819                ConnectionState::InTransaction
1820            } else {
1821                ConnectionState::Ready
1822            };
1823
1824        Outcome::Ok(rows)
1825    }
1826
1827    /// Parse a binary protocol row.
1828    fn parse_binary_row(&self, data: &[u8], columns: &[ColumnDef]) -> Row {
1829        // Binary row format:
1830        // - 0x00 header (1 byte)
1831        // - NULL bitmap ((column_count + 7 + 2) / 8 bytes)
1832        // - Column values (only non-NULL)
1833
1834        let mut values = Vec::with_capacity(columns.len());
1835        let mut column_names = Vec::with_capacity(columns.len());
1836
1837        if data.is_empty() {
1838            return Row::new(column_names, values);
1839        }
1840
1841        // Skip header byte (0x00)
1842        let mut pos = 1;
1843
1844        // NULL bitmap: (column_count + 7 + 2) / 8 bytes
1845        // The +2 offset is for the reserved bits at the beginning
1846        let null_bitmap_len = (columns.len() + 7 + 2) / 8;
1847        if pos + null_bitmap_len > data.len() {
1848            return Row::new(column_names, values);
1849        }
1850        let null_bitmap = &data[pos..pos + null_bitmap_len];
1851        pos += null_bitmap_len;
1852
1853        // Parse column values
1854        for (i, col) in columns.iter().enumerate() {
1855            column_names.push(col.name.clone());
1856
1857            // Check NULL bitmap (bit position is i + 2 due to offset)
1858            let bit_pos = i + 2;
1859            let is_null = (null_bitmap[bit_pos / 8] & (1 << (bit_pos % 8))) != 0;
1860
1861            if is_null {
1862                values.push(Value::Null);
1863            } else {
1864                let is_unsigned = col.flags & 0x20 != 0; // UNSIGNED_FLAG
1865                let (value, consumed) =
1866                    decode_binary_value_with_len(&data[pos..], col.column_type, is_unsigned);
1867                values.push(value);
1868                pos += consumed;
1869            }
1870        }
1871
1872        Row::new(column_names, values)
1873    }
1874
1875    /// Write a pre-built packet (with header already included).
1876    async fn write_packet_raw_async(&mut self, packet: &[u8]) -> Outcome<(), Error> {
1877        let Some(stream) = self.stream.as_mut() else {
1878            return Outcome::Err(connection_error("Connection stream missing"));
1879        };
1880        match stream {
1881            ConnectionStream::Async(stream) => {
1882                let mut written = 0;
1883                while written < packet.len() {
1884                    match std::future::poll_fn(|cx| {
1885                        std::pin::Pin::new(&mut *stream).poll_write(cx, &packet[written..])
1886                    })
1887                    .await
1888                    {
1889                        Ok(n) => written += n,
1890                        Err(e) => {
1891                            return Outcome::Err(Error::Connection(ConnectionError {
1892                                kind: ConnectionErrorKind::Disconnected,
1893                                message: format!("Failed to write packet: {}", e),
1894                                source: Some(Box::new(e)),
1895                            }));
1896                        }
1897                    }
1898                }
1899                // Flush
1900                if let Err(e) =
1901                    std::future::poll_fn(|cx| std::pin::Pin::new(&mut *stream).poll_flush(cx)).await
1902                {
1903                    return Outcome::Err(Error::Connection(ConnectionError {
1904                        kind: ConnectionErrorKind::Disconnected,
1905                        message: format!("Failed to flush: {}", e),
1906                        source: Some(Box::new(e)),
1907                    }));
1908                }
1909                Outcome::Ok(())
1910            }
1911            ConnectionStream::Sync(stream) => {
1912                if let Err(e) = stream.write_all(packet) {
1913                    return Outcome::Err(Error::Connection(ConnectionError {
1914                        kind: ConnectionErrorKind::Disconnected,
1915                        message: format!("Failed to write packet: {}", e),
1916                        source: Some(Box::new(e)),
1917                    }));
1918                }
1919                if let Err(e) = stream.flush() {
1920                    return Outcome::Err(Error::Connection(ConnectionError {
1921                        kind: ConnectionErrorKind::Disconnected,
1922                        message: format!("Failed to flush: {}", e),
1923                        source: Some(Box::new(e)),
1924                    }));
1925                }
1926                Outcome::Ok(())
1927            }
1928            #[cfg(feature = "tls")]
1929            ConnectionStream::Tls(_) => Outcome::Err(connection_error(
1930                "write_packet_raw_async called after TLS upgrade (bug)",
1931            )),
1932        }
1933    }
1934
1935    /// Ping the server asynchronously.
1936    pub async fn ping_async(&mut self, _cx: &Cx) -> Outcome<(), Error> {
1937        self.sequence_id = 0;
1938
1939        let mut writer = PacketWriter::new();
1940        writer.write_u8(Command::Ping as u8);
1941
1942        if let Outcome::Err(e) = self.write_packet_async(writer.as_bytes()).await {
1943            return Outcome::Err(e);
1944        }
1945
1946        let (payload, _) = match self.read_packet_async().await {
1947            Outcome::Ok(p) => p,
1948            Outcome::Err(e) => return Outcome::Err(e),
1949            Outcome::Cancelled(r) => return Outcome::Cancelled(r),
1950            Outcome::Panicked(p) => return Outcome::Panicked(p),
1951        };
1952
1953        if payload.first() == Some(&0x00) {
1954            Outcome::Ok(())
1955        } else {
1956            Outcome::Err(connection_error("Ping failed"))
1957        }
1958    }
1959
1960    /// Close the connection asynchronously.
1961    pub async fn close_async(mut self, _cx: &Cx) -> Result<(), Error> {
1962        if self.state == ConnectionState::Closed {
1963            return Ok(());
1964        }
1965
1966        self.sequence_id = 0;
1967
1968        let mut writer = PacketWriter::new();
1969        writer.write_u8(Command::Quit as u8);
1970
1971        // Best effort - ignore errors on close
1972        let _ = self.write_packet_async(writer.as_bytes()).await;
1973
1974        self.state = ConnectionState::Closed;
1975        Ok(())
1976    }
1977}
1978
1979// === Console integration ===
1980
#[cfg(feature = "console")]
impl ConsoleAware for MySqlAsyncConnection {
    /// Store `console` on the connection, replacing any previous one
    /// (`None` removes it).
    fn set_console(&mut self, console: Option<Arc<SqlModelConsole>>) {
        self.console = console;
    }

    /// Borrow the console currently stored on the connection, if any.
    fn console(&self) -> Option<&Arc<SqlModelConsole>> {
        Option::as_ref(&self.console)
    }
}
1991
1992// === Helper functions ===
1993
1994fn protocol_error(msg: impl Into<String>) -> Error {
1995    Error::Protocol(ProtocolError {
1996        message: msg.into(),
1997        raw_data: None,
1998        source: None,
1999    })
2000}
2001
2002fn auth_error(msg: impl Into<String>) -> Error {
2003    Error::Connection(ConnectionError {
2004        kind: ConnectionErrorKind::Authentication,
2005        message: msg.into(),
2006        source: None,
2007    })
2008}
2009
2010fn connection_error(msg: impl Into<String>) -> Error {
2011    Error::Connection(ConnectionError {
2012        kind: ConnectionErrorKind::Connect,
2013        message: msg.into(),
2014        source: None,
2015    })
2016}
2017
/// Decide whether the server expects OAEP padding for the RSA-encrypted
/// `caching_sha2_password` exchange (MySQL 8.0.5 and later do).
///
/// Only the leading run of ASCII digits and dots in `server_version` is
/// examined (so suffixes like `-log` are ignored). Empty segments are
/// skipped, and a missing minor or patch number counts as 0. If no major
/// version can be parsed, OAEP is assumed.
fn mysql_server_uses_oaep(server_version: &str) -> bool {
    let numeric_len = server_version
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(server_version.len());
    let mut parts = server_version[..numeric_len]
        .split('.')
        .filter(|part| !part.is_empty())
        .map(|part| part.parse::<u64>().ok());

    let Some(Some(major)) = parts.next() else {
        return true;
    };
    let minor = parts.next().flatten().unwrap_or(0);
    let patch = parts.next().flatten().unwrap_or(0);

    (major, minor, patch) >= (8, 0, 5)
}
2035
2036fn query_error(err: &ErrPacket) -> Error {
2037    let kind = if err.is_duplicate_key() || err.is_foreign_key_violation() {
2038        QueryErrorKind::Constraint
2039    } else {
2040        QueryErrorKind::Syntax
2041    };
2042
2043    Error::Query(QueryError {
2044        kind,
2045        message: err.error_message.clone(),
2046        sqlstate: Some(err.sql_state.clone()),
2047        sql: None,
2048        detail: None,
2049        hint: None,
2050        position: None,
2051        source: None,
2052    })
2053}
2054
2055fn query_error_msg(msg: impl Into<String>) -> Error {
2056    Error::Query(QueryError {
2057        kind: QueryErrorKind::Syntax,
2058        message: msg.into(),
2059        sqlstate: None,
2060        sql: None,
2061        detail: None,
2062        hint: None,
2063        position: None,
2064        source: None,
2065    })
2066}
2067
2068/// Validate a savepoint name to prevent SQL injection.
2069///
2070/// MySQL identifiers must:
2071/// - Not be empty
2072/// - Start with a letter or underscore
2073/// - Contain only letters, digits, underscores, or dollar signs
2074/// - Be at most 64 characters
2075fn validate_savepoint_name(name: &str) -> Result<(), Error> {
2076    if name.is_empty() {
2077        return Err(query_error_msg("Savepoint name cannot be empty"));
2078    }
2079    if name.len() > 64 {
2080        return Err(query_error_msg(
2081            "Savepoint name exceeds maximum length of 64 characters",
2082        ));
2083    }
2084    let mut chars = name.chars();
2085    let Some(first) = chars.next() else {
2086        // Defensive: `is_empty()` was checked above.
2087        return Err(query_error_msg("Savepoint name cannot be empty"));
2088    };
2089    if !first.is_ascii_alphabetic() && first != '_' {
2090        return Err(query_error_msg(
2091            "Savepoint name must start with a letter or underscore",
2092        ));
2093    }
2094    for c in chars {
2095        if !c.is_ascii_alphanumeric() && c != '_' && c != '$' {
2096            return Err(query_error_msg(format!(
2097                "Savepoint name contains invalid character: '{}'",
2098                c
2099            )));
2100        }
2101    }
2102    Ok(())
2103}
2104
2105// === Shared connection wrapper ===
2106
2107/// A thread-safe, shared MySQL connection with interior mutability.
2108///
2109/// This wrapper allows the `Connection` trait to be implemented properly
2110/// by wrapping the raw `MySqlAsyncConnection` in an async mutex.
2111///
2112/// # Example
2113///
2114/// ```ignore
2115/// let conn = MySqlAsyncConnection::connect(&cx, config).await?;
2116/// let shared = SharedMySqlConnection::new(conn);
2117///
2118/// // Now you can use &shared with the Connection trait
2119/// let rows = shared.query(&cx, "SELECT * FROM users", &[]).await?;
2120/// ```
2121pub struct SharedMySqlConnection {
2122    inner: Arc<Mutex<MySqlAsyncConnection>>,
2123}
2124
2125impl SharedMySqlConnection {
2126    /// Create a new shared connection from a raw connection.
2127    pub fn new(conn: MySqlAsyncConnection) -> Self {
2128        Self {
2129            inner: Arc::new(Mutex::new(conn)),
2130        }
2131    }
2132
2133    /// Create a new shared connection by connecting to the server.
2134    pub async fn connect(cx: &Cx, config: MySqlConfig) -> Outcome<Self, Error> {
2135        match MySqlAsyncConnection::connect(cx, config).await {
2136            Outcome::Ok(conn) => Outcome::Ok(Self::new(conn)),
2137            Outcome::Err(e) => Outcome::Err(e),
2138            Outcome::Cancelled(c) => Outcome::Cancelled(c),
2139            Outcome::Panicked(p) => Outcome::Panicked(p),
2140        }
2141    }
2142
2143    /// Get the inner Arc for cloning.
2144    pub fn inner(&self) -> &Arc<Mutex<MySqlAsyncConnection>> {
2145        &self.inner
2146    }
2147}
2148
2149impl Clone for SharedMySqlConnection {
2150    fn clone(&self) -> Self {
2151        Self {
2152            inner: Arc::clone(&self.inner),
2153        }
2154    }
2155}
2156
2157impl std::fmt::Debug for SharedMySqlConnection {
2158    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2159        f.debug_struct("SharedMySqlConnection")
2160            .field("inner", &"Arc<Mutex<MySqlAsyncConnection>>")
2161            .finish()
2162    }
2163}
2164
2165/// Transaction type for SharedMySqlConnection.
2166///
2167/// This transaction holds a clone of the Arc to the connection and executes
2168/// transaction operations by acquiring the mutex lock for each operation.
2169/// The transaction must be committed or rolled back explicitly.
2170///
2171/// # Warning: Uncommitted Transactions
2172///
2173/// If a transaction is dropped without calling `commit()` or `rollback()`,
2174/// the underlying MySQL transaction will remain open until the connection
2175/// is closed or a new transaction is started. This is because Rust's `Drop`
2176/// trait cannot perform async operations.
2177///
2178/// **Always explicitly call `commit()` or `rollback()` before dropping.**
2179///
2180/// Note: The lifetime parameter is required by the Connection trait but the
2181/// actual implementation holds an owned Arc, so the transaction can outlive
2182/// the reference to SharedMySqlConnection if needed.
2183pub struct SharedMySqlTransaction<'conn> {
2184    inner: Arc<Mutex<MySqlAsyncConnection>>,
2185    committed: bool,
2186    _marker: std::marker::PhantomData<&'conn ()>,
2187}
2188
2189impl SharedMySqlConnection {
2190    /// Internal implementation for beginning a transaction.
2191    async fn begin_transaction_impl(
2192        &self,
2193        cx: &Cx,
2194        isolation: Option<IsolationLevel>,
2195    ) -> Outcome<SharedMySqlTransaction<'_>, Error> {
2196        let inner = Arc::clone(&self.inner);
2197
2198        // Acquire lock
2199        let Ok(mut guard) = inner.lock(cx).await else {
2200            return Outcome::Err(connection_error("Failed to acquire connection lock"));
2201        };
2202
2203        // Set isolation level if specified
2204        if let Some(level) = isolation {
2205            let isolation_sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
2206            match guard.execute_async(cx, &isolation_sql, &[]).await {
2207                Outcome::Ok(_) => {}
2208                Outcome::Err(e) => return Outcome::Err(e),
2209                Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2210                Outcome::Panicked(p) => return Outcome::Panicked(p),
2211            }
2212        }
2213
2214        // Start transaction
2215        match guard.execute_async(cx, "BEGIN", &[]).await {
2216            Outcome::Ok(_) => {}
2217            Outcome::Err(e) => return Outcome::Err(e),
2218            Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2219            Outcome::Panicked(p) => return Outcome::Panicked(p),
2220        }
2221
2222        drop(guard);
2223
2224        Outcome::Ok(SharedMySqlTransaction {
2225            inner,
2226            committed: false,
2227            _marker: std::marker::PhantomData,
2228        })
2229    }
2230}
2231
2232impl Connection for SharedMySqlConnection {
2233    type Tx<'conn>
2234        = SharedMySqlTransaction<'conn>
2235    where
2236        Self: 'conn;
2237
    /// Report the SQL dialect of this connection: always `Dialect::Mysql`.
    fn dialect(&self) -> sqlmodel_core::Dialect {
        sqlmodel_core::Dialect::Mysql
    }
2241
2242    fn query(
2243        &self,
2244        cx: &Cx,
2245        sql: &str,
2246        params: &[Value],
2247    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2248        let inner = Arc::clone(&self.inner);
2249        let sql = sql.to_string();
2250        let params = params.to_vec();
2251        async move {
2252            let Ok(mut guard) = inner.lock(cx).await else {
2253                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2254            };
2255            guard.query_async(cx, &sql, &params).await
2256        }
2257    }
2258
2259    fn query_one(
2260        &self,
2261        cx: &Cx,
2262        sql: &str,
2263        params: &[Value],
2264    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2265        let inner = Arc::clone(&self.inner);
2266        let sql = sql.to_string();
2267        let params = params.to_vec();
2268        async move {
2269            let Ok(mut guard) = inner.lock(cx).await else {
2270                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2271            };
2272            let rows = match guard.query_async(cx, &sql, &params).await {
2273                Outcome::Ok(r) => r,
2274                Outcome::Err(e) => return Outcome::Err(e),
2275                Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2276                Outcome::Panicked(p) => return Outcome::Panicked(p),
2277            };
2278            Outcome::Ok(rows.into_iter().next())
2279        }
2280    }
2281
2282    fn execute(
2283        &self,
2284        cx: &Cx,
2285        sql: &str,
2286        params: &[Value],
2287    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2288        let inner = Arc::clone(&self.inner);
2289        let sql = sql.to_string();
2290        let params = params.to_vec();
2291        async move {
2292            let Ok(mut guard) = inner.lock(cx).await else {
2293                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2294            };
2295            guard.execute_async(cx, &sql, &params).await
2296        }
2297    }
2298
2299    fn insert(
2300        &self,
2301        cx: &Cx,
2302        sql: &str,
2303        params: &[Value],
2304    ) -> impl Future<Output = Outcome<i64, Error>> + Send {
2305        let inner = Arc::clone(&self.inner);
2306        let sql = sql.to_string();
2307        let params = params.to_vec();
2308        async move {
2309            let Ok(mut guard) = inner.lock(cx).await else {
2310                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2311            };
2312            match guard.execute_async(cx, &sql, &params).await {
2313                Outcome::Ok(_) => Outcome::Ok(guard.last_insert_id() as i64),
2314                Outcome::Err(e) => Outcome::Err(e),
2315                Outcome::Cancelled(c) => Outcome::Cancelled(c),
2316                Outcome::Panicked(p) => Outcome::Panicked(p),
2317            }
2318        }
2319    }
2320
2321    fn batch(
2322        &self,
2323        cx: &Cx,
2324        statements: &[(String, Vec<Value>)],
2325    ) -> impl Future<Output = Outcome<Vec<u64>, Error>> + Send {
2326        let inner = Arc::clone(&self.inner);
2327        let statements = statements.to_vec();
2328        async move {
2329            let Ok(mut guard) = inner.lock(cx).await else {
2330                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2331            };
2332            let mut results = Vec::with_capacity(statements.len());
2333            for (sql, params) in &statements {
2334                match guard.execute_async(cx, sql, params).await {
2335                    Outcome::Ok(n) => results.push(n),
2336                    Outcome::Err(e) => return Outcome::Err(e),
2337                    Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2338                    Outcome::Panicked(p) => return Outcome::Panicked(p),
2339                }
2340            }
2341            Outcome::Ok(results)
2342        }
2343    }
2344
2345    fn begin(&self, cx: &Cx) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2346        self.begin_transaction_impl(cx, None)
2347    }
2348
2349    fn begin_with(
2350        &self,
2351        cx: &Cx,
2352        isolation: IsolationLevel,
2353    ) -> impl Future<Output = Outcome<Self::Tx<'_>, Error>> + Send {
2354        self.begin_transaction_impl(cx, Some(isolation))
2355    }
2356
2357    fn prepare(
2358        &self,
2359        cx: &Cx,
2360        sql: &str,
2361    ) -> impl Future<Output = Outcome<PreparedStatement, Error>> + Send {
2362        let inner = Arc::clone(&self.inner);
2363        let sql = sql.to_string();
2364        async move {
2365            let Ok(mut guard) = inner.lock(cx).await else {
2366                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2367            };
2368            guard.prepare_async(cx, &sql).await
2369        }
2370    }
2371
2372    fn query_prepared(
2373        &self,
2374        cx: &Cx,
2375        stmt: &PreparedStatement,
2376        params: &[Value],
2377    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2378        let inner = Arc::clone(&self.inner);
2379        let stmt = stmt.clone();
2380        let params = params.to_vec();
2381        async move {
2382            let Ok(mut guard) = inner.lock(cx).await else {
2383                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2384            };
2385            guard.query_prepared_async(cx, &stmt, &params).await
2386        }
2387    }
2388
2389    fn execute_prepared(
2390        &self,
2391        cx: &Cx,
2392        stmt: &PreparedStatement,
2393        params: &[Value],
2394    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2395        let inner = Arc::clone(&self.inner);
2396        let stmt = stmt.clone();
2397        let params = params.to_vec();
2398        async move {
2399            let Ok(mut guard) = inner.lock(cx).await else {
2400                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2401            };
2402            guard.execute_prepared_async(cx, &stmt, &params).await
2403        }
2404    }
2405
2406    fn ping(&self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2407        let inner = Arc::clone(&self.inner);
2408        async move {
2409            let Ok(mut guard) = inner.lock(cx).await else {
2410                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2411            };
2412            guard.ping_async(cx).await
2413        }
2414    }
2415
2416    fn close(self, cx: &Cx) -> impl Future<Output = Result<(), Error>> + Send {
2417        async move {
2418            // Try to get exclusive access - if we have the only Arc, we can close
2419            match Arc::try_unwrap(self.inner) {
2420                Ok(mutex) => {
2421                    let conn = mutex.into_inner();
2422                    conn.close_async(cx).await
2423                }
2424                Err(_) => {
2425                    // Other references exist, can't close
2426                    Err(connection_error(
2427                        "Cannot close: other references to connection exist",
2428                    ))
2429                }
2430            }
2431        }
2432    }
2433}
2434
2435impl<'conn> TransactionOps for SharedMySqlTransaction<'conn> {
2436    fn query(
2437        &self,
2438        cx: &Cx,
2439        sql: &str,
2440        params: &[Value],
2441    ) -> impl Future<Output = Outcome<Vec<Row>, Error>> + Send {
2442        let inner = Arc::clone(&self.inner);
2443        let sql = sql.to_string();
2444        let params = params.to_vec();
2445        async move {
2446            let Ok(mut guard) = inner.lock(cx).await else {
2447                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2448            };
2449            guard.query_async(cx, &sql, &params).await
2450        }
2451    }
2452
2453    fn query_one(
2454        &self,
2455        cx: &Cx,
2456        sql: &str,
2457        params: &[Value],
2458    ) -> impl Future<Output = Outcome<Option<Row>, Error>> + Send {
2459        let inner = Arc::clone(&self.inner);
2460        let sql = sql.to_string();
2461        let params = params.to_vec();
2462        async move {
2463            let Ok(mut guard) = inner.lock(cx).await else {
2464                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2465            };
2466            let rows = match guard.query_async(cx, &sql, &params).await {
2467                Outcome::Ok(r) => r,
2468                Outcome::Err(e) => return Outcome::Err(e),
2469                Outcome::Cancelled(c) => return Outcome::Cancelled(c),
2470                Outcome::Panicked(p) => return Outcome::Panicked(p),
2471            };
2472            Outcome::Ok(rows.into_iter().next())
2473        }
2474    }
2475
2476    fn execute(
2477        &self,
2478        cx: &Cx,
2479        sql: &str,
2480        params: &[Value],
2481    ) -> impl Future<Output = Outcome<u64, Error>> + Send {
2482        let inner = Arc::clone(&self.inner);
2483        let sql = sql.to_string();
2484        let params = params.to_vec();
2485        async move {
2486            let Ok(mut guard) = inner.lock(cx).await else {
2487                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2488            };
2489            guard.execute_async(cx, &sql, &params).await
2490        }
2491    }
2492
2493    fn savepoint(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2494        let inner = Arc::clone(&self.inner);
2495        // Validate name before building SQL to prevent injection
2496        let validation_result = validate_savepoint_name(name);
2497        let sql = format!("SAVEPOINT {}", name);
2498        async move {
2499            // Return validation error if name was invalid
2500            if let Err(e) = validation_result {
2501                return Outcome::Err(e);
2502            }
2503            let Ok(mut guard) = inner.lock(cx).await else {
2504                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2505            };
2506            match guard.execute_async(cx, &sql, &[]).await {
2507                Outcome::Ok(_) => Outcome::Ok(()),
2508                Outcome::Err(e) => Outcome::Err(e),
2509                Outcome::Cancelled(c) => Outcome::Cancelled(c),
2510                Outcome::Panicked(p) => Outcome::Panicked(p),
2511            }
2512        }
2513    }
2514
2515    fn rollback_to(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2516        let inner = Arc::clone(&self.inner);
2517        // Validate name before building SQL to prevent injection
2518        let validation_result = validate_savepoint_name(name);
2519        let sql = format!("ROLLBACK TO SAVEPOINT {}", name);
2520        async move {
2521            // Return validation error if name was invalid
2522            if let Err(e) = validation_result {
2523                return Outcome::Err(e);
2524            }
2525            let Ok(mut guard) = inner.lock(cx).await else {
2526                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2527            };
2528            match guard.execute_async(cx, &sql, &[]).await {
2529                Outcome::Ok(_) => Outcome::Ok(()),
2530                Outcome::Err(e) => Outcome::Err(e),
2531                Outcome::Cancelled(c) => Outcome::Cancelled(c),
2532                Outcome::Panicked(p) => Outcome::Panicked(p),
2533            }
2534        }
2535    }
2536
2537    fn release(&self, cx: &Cx, name: &str) -> impl Future<Output = Outcome<(), Error>> + Send {
2538        let inner = Arc::clone(&self.inner);
2539        // Validate name before building SQL to prevent injection
2540        let validation_result = validate_savepoint_name(name);
2541        let sql = format!("RELEASE SAVEPOINT {}", name);
2542        async move {
2543            // Return validation error if name was invalid
2544            if let Err(e) = validation_result {
2545                return Outcome::Err(e);
2546            }
2547            let Ok(mut guard) = inner.lock(cx).await else {
2548                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2549            };
2550            match guard.execute_async(cx, &sql, &[]).await {
2551                Outcome::Ok(_) => Outcome::Ok(()),
2552                Outcome::Err(e) => Outcome::Err(e),
2553                Outcome::Cancelled(c) => Outcome::Cancelled(c),
2554                Outcome::Panicked(p) => Outcome::Panicked(p),
2555            }
2556        }
2557    }
2558
2559    // Note: clippy incorrectly flags `self.committed = true` as unused, but
2560    // the Drop impl reads this field to determine if rollback logging is needed.
2561    #[allow(unused_assignments)]
2562    fn commit(mut self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2563        async move {
2564            let Ok(mut guard) = self.inner.lock(cx).await else {
2565                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2566            };
2567            match guard.execute_async(cx, "COMMIT", &[]).await {
2568                Outcome::Ok(_) => {
2569                    self.committed = true;
2570                    Outcome::Ok(())
2571                }
2572                Outcome::Err(e) => Outcome::Err(e),
2573                Outcome::Cancelled(c) => Outcome::Cancelled(c),
2574                Outcome::Panicked(p) => Outcome::Panicked(p),
2575            }
2576        }
2577    }
2578
2579    fn rollback(self, cx: &Cx) -> impl Future<Output = Outcome<(), Error>> + Send {
2580        async move {
2581            let Ok(mut guard) = self.inner.lock(cx).await else {
2582                return Outcome::Err(connection_error("Failed to acquire connection lock"));
2583            };
2584            match guard.execute_async(cx, "ROLLBACK", &[]).await {
2585                Outcome::Ok(_) => Outcome::Ok(()),
2586                Outcome::Err(e) => Outcome::Err(e),
2587                Outcome::Cancelled(c) => Outcome::Cancelled(c),
2588                Outcome::Panicked(p) => Outcome::Panicked(p),
2589            }
2590        }
2591    }
2592}
2593
2594impl<'conn> Drop for SharedMySqlTransaction<'conn> {
2595    fn drop(&mut self) {
2596        if !self.committed {
2597            // WARNING: Transaction was dropped without commit() or rollback()!
2598            // We cannot do async work in Drop, so the MySQL transaction will
2599            // remain open until the connection is closed or a new transaction
2600            // is started. This may cause unexpected behavior.
2601            //
2602            // To fix: Always call tx.commit(cx).await or tx.rollback(cx).await
2603            // before the transaction goes out of scope.
2604            #[cfg(debug_assertions)]
2605            eprintln!(
2606                "WARNING: SharedMySqlTransaction dropped without commit/rollback. \
2607                 The MySQL transaction may still be open."
2608            );
2609        }
2610    }
2611}
2612
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_connection_state() {
        // `ConnectionState` supports equality comparison.
        let disconnected = ConnectionState::Disconnected;
        assert_eq!(disconnected, ConnectionState::Disconnected);
    }

    #[test]
    fn test_error_helpers() {
        // Each helper maps onto the expected `Error` variant.
        assert!(matches!(protocol_error("test"), Error::Protocol(_)));
        assert!(matches!(auth_error("auth failed"), Error::Connection(_)));
        assert!(matches!(connection_error("conn failed"), Error::Connection(_)));
    }

    #[test]
    fn test_validate_savepoint_name_valid() {
        let accepted = ["sp1", "_savepoint", "SavePoint_123", "sp$test", "a", "_"];
        for name in accepted {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    #[test]
    fn test_validate_savepoint_name_invalid() {
        // Longer than the 64-character limit.
        let too_long = "a".repeat(65);
        let rejected = [
            // empty
            "",
            // leading digit
            "1savepoint",
            // characters outside the identifier set
            "save-point",
            "save point",
            "save;drop table",
            "sp'--",
            too_long.as_str(),
        ];
        for name in rejected {
            assert!(
                validate_savepoint_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }
}