Skip to main content

sqlmodel_mysql/protocol/
mod.rs

1//! MySQL wire protocol implementation.
2//!
3//! MySQL packets have a 4-byte header:
4//! - 3 bytes: payload length (little-endian)
5//! - 1 byte: sequence number
6//!
7//! Maximum packet payload is 2^24 - 1 (16MB - 1). Larger payloads
8//! are split into multiple packets.
9
10pub mod prepared;
11pub mod reader;
12pub mod writer;
13
14pub use prepared::{
15    PreparedStatement, StmtPrepareOk, build_stmt_close_packet, build_stmt_execute_packet,
16    build_stmt_prepare_packet, build_stmt_reset_packet, parse_stmt_prepare_ok,
17};
18pub use reader::PacketReader;
19pub use writer::PacketWriter;
20
/// Largest payload one MySQL packet can carry.
///
/// The length field in the packet header is 24 bits wide, so the ceiling is
/// `2^24 - 1` bytes; anything larger must be split across several packets.
pub const MAX_PACKET_SIZE: usize = (1 << 24) - 1;
23
/// MySQL capability flags (client and server).
///
/// Each constant is one bit of a 32-bit capability mask; they are listed in
/// ascending bit order.
// The full table is kept even though not every flag is referenced by this crate.
#[allow(dead_code)]
pub mod capabilities {
    pub const CLIENT_LONG_PASSWORD: u32 = 0x0000_0001;
    pub const CLIENT_FOUND_ROWS: u32 = 0x0000_0002;
    pub const CLIENT_LONG_FLAG: u32 = 0x0000_0004;
    pub const CLIENT_CONNECT_WITH_DB: u32 = 0x0000_0008;
    pub const CLIENT_NO_SCHEMA: u32 = 0x0000_0010;
    pub const CLIENT_COMPRESS: u32 = 0x0000_0020;
    pub const CLIENT_ODBC: u32 = 0x0000_0040;
    pub const CLIENT_LOCAL_FILES: u32 = 0x0000_0080;
    pub const CLIENT_IGNORE_SPACE: u32 = 0x0000_0100;
    pub const CLIENT_PROTOCOL_41: u32 = 0x0000_0200;
    pub const CLIENT_INTERACTIVE: u32 = 0x0000_0400;
    pub const CLIENT_SSL: u32 = 0x0000_0800;
    pub const CLIENT_IGNORE_SIGPIPE: u32 = 0x0000_1000;
    pub const CLIENT_TRANSACTIONS: u32 = 0x0000_2000;
    pub const CLIENT_RESERVED: u32 = 0x0000_4000;
    pub const CLIENT_SECURE_CONNECTION: u32 = 0x0000_8000;
    pub const CLIENT_MULTI_STATEMENTS: u32 = 0x0001_0000;
    pub const CLIENT_MULTI_RESULTS: u32 = 0x0002_0000;
    pub const CLIENT_PS_MULTI_RESULTS: u32 = 0x0004_0000;
    pub const CLIENT_PLUGIN_AUTH: u32 = 0x0008_0000;
    pub const CLIENT_CONNECT_ATTRS: u32 = 0x0010_0000;
    pub const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: u32 = 0x0020_0000;
    pub const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: u32 = 0x0040_0000;
    pub const CLIENT_SESSION_TRACK: u32 = 0x0080_0000;
    pub const CLIENT_DEPRECATE_EOF: u32 = 0x0100_0000;
    pub const CLIENT_OPTIONAL_RESULTSET_METADATA: u32 = 0x0200_0000;
    pub const CLIENT_ZSTD_COMPRESSION_ALGORITHM: u32 = 0x0400_0000;
    pub const CLIENT_QUERY_ATTRIBUTES: u32 = 0x0800_0000;

    /// Default client capabilities for modern MySQL connections.
    ///
    /// Protocol 4.1 with secure/plugin authentication, transactions,
    /// multi-statement and multi-result support, and OK packets in place of
    /// EOF packets (`CLIENT_DEPRECATE_EOF`).
    // NOTE(review): CLIENT_CONNECT_WITH_DB is always set here; confirm the
    // handshake writer clears it when no database name is configured.
    pub const DEFAULT_CLIENT_FLAGS: u32 = CLIENT_LONG_PASSWORD
        | CLIENT_CONNECT_WITH_DB
        | CLIENT_PROTOCOL_41
        | CLIENT_TRANSACTIONS
        | CLIENT_SECURE_CONNECTION
        | CLIENT_MULTI_STATEMENTS
        | CLIENT_MULTI_RESULTS
        | CLIENT_PS_MULTI_RESULTS
        | CLIENT_PLUGIN_AUTH
        | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
        | CLIENT_DEPRECATE_EOF;
}
69
/// MySQL command codes (`COM_*`), i.e. the first payload byte of a client
/// command packet. Cast with `as u8` to obtain the wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Command {
    /// `COM_SLEEP` - internal server command
    Sleep = 0x00,
    /// `COM_QUIT` - close the connection
    Quit = 0x01,
    /// `COM_INIT_DB` - change the default database
    InitDb = 0x02,
    /// `COM_QUERY` - text-protocol query
    Query = 0x03,
    /// `COM_FIELD_LIST` - list the columns of a table (deprecated)
    FieldList = 0x04,
    /// `COM_CREATE_DB` - create a database
    CreateDb = 0x05,
    /// `COM_DROP_DB` - drop a database
    DropDb = 0x06,
    /// `COM_REFRESH` - flush tables, caches, etc.
    Refresh = 0x07,
    /// `COM_SHUTDOWN` - shut the server down
    Shutdown = 0x08,
    /// `COM_STATISTICS` - request server statistics
    Statistics = 0x09,
    /// `COM_PROCESS_INFO` - list server processes
    ProcessInfo = 0x0A,
    /// `COM_CONNECT` - internal server command
    Connect = 0x0B,
    /// `COM_PROCESS_KILL` - kill a server process
    ProcessKill = 0x0C,
    /// `COM_DEBUG` - dump debug information
    Debug = 0x0D,
    /// `COM_PING` - check that the server is alive
    Ping = 0x0E,
    /// `COM_TIME` - internal server command
    Time = 0x0F,
    /// `COM_DELAYED_INSERT` - deprecated
    DelayedInsert = 0x10,
    /// `COM_CHANGE_USER` - re-authenticate as another user
    ChangeUser = 0x11,
    /// `COM_BINLOG_DUMP` - stream the binary log
    BinlogDump = 0x12,
    /// `COM_TABLE_DUMP` - dump a table
    TableDump = 0x13,
    /// `COM_CONNECT_OUT` - connect out
    ConnectOut = 0x14,
    /// `COM_REGISTER_SLAVE` - register a replica
    RegisterSlave = 0x15,
    /// `COM_STMT_PREPARE` - prepare a statement
    StmtPrepare = 0x16,
    /// `COM_STMT_EXECUTE` - execute a prepared statement
    StmtExecute = 0x17,
    /// `COM_STMT_SEND_LONG_DATA` - stream parameter data for a prepared statement
    StmtSendLongData = 0x18,
    /// `COM_STMT_CLOSE` - deallocate a prepared statement
    StmtClose = 0x19,
    /// `COM_STMT_RESET` - reset a prepared statement
    StmtReset = 0x1A,
    /// `COM_SET_OPTION` - set a connection option
    SetOption = 0x1B,
    /// `COM_STMT_FETCH` - fetch rows from a cursor
    StmtFetch = 0x1C,
    /// `COM_DAEMON` - internal server command
    Daemon = 0x1D,
    /// `COM_BINLOG_DUMP_GTID` - stream the binary log by GTID
    BinlogDumpGtid = 0x1E,
    /// `COM_RESET_CONNECTION` - reset session state
    ResetConnection = 0x1F,
}
139
/// MySQL server status flags: bits of the 16-bit status word carried by OK
/// and EOF packets.
// The full table is kept even though not every flag is referenced by this crate.
#[allow(dead_code)]
pub mod server_status {
    /// A multi-statement transaction is in progress.
    pub const SERVER_STATUS_IN_TRANS: u16 = 0x0001;
    /// The session is in autocommit mode.
    pub const SERVER_STATUS_AUTOCOMMIT: u16 = 0x0002;
    /// Another result set follows this one.
    pub const SERVER_MORE_RESULTS_EXISTS: u16 = 0x0008;
    /// The query did not use a good index.
    pub const SERVER_STATUS_NO_GOOD_INDEX_USED: u16 = 0x0010;
    /// The query used no index.
    pub const SERVER_STATUS_NO_INDEX_USED: u16 = 0x0020;
    /// A read-only cursor was opened for the statement.
    pub const SERVER_STATUS_CURSOR_EXISTS: u16 = 0x0040;
    /// The cursor has been exhausted.
    pub const SERVER_STATUS_LAST_ROW_SENT: u16 = 0x0080;
    /// A database was dropped.
    pub const SERVER_STATUS_DB_DROPPED: u16 = 0x0100;
    /// `NO_BACKSLASH_ESCAPES` SQL mode is active.
    pub const SERVER_STATUS_NO_BACKSLASH_ESCAPES: u16 = 0x0200;
    /// A re-prepared statement now returns different result metadata.
    pub const SERVER_STATUS_METADATA_CHANGED: u16 = 0x0400;
    /// The query was logged as slow.
    pub const SERVER_QUERY_WAS_SLOW: u16 = 0x0800;
    /// The result set carries prepared-statement OUT parameters.
    pub const SERVER_PS_OUT_PARAMS: u16 = 0x1000;
    /// The active transaction is read-only.
    pub const SERVER_STATUS_IN_TRANS_READONLY: u16 = 0x2000;
    /// Session state changed during the last statement.
    pub const SERVER_SESSION_STATE_CHANGED: u16 = 0x4000;
}
158
/// MySQL character set / collation ids, as sent in a one-byte field.
// The full table is kept even though not every id is referenced by this crate.
#[allow(dead_code)]
pub mod charset {
    /// `latin1_swedish_ci`
    pub const LATIN1_SWEDISH_CI: u8 = 8;
    /// `utf8_general_ci` (3-byte utf8)
    pub const UTF8_GENERAL_CI: u8 = 33;
    /// `utf8mb4_general_ci`
    pub const UTF8MB4_GENERAL_CI: u8 = 45;
    /// `binary`
    pub const BINARY: u8 = 63;
    /// `utf8mb4_unicode_ci`
    pub const UTF8MB4_UNICODE_CI: u8 = 224;
    /// `utf8mb4_0900_ai_ci`
    pub const UTF8MB4_0900_AI_CI: u8 = 255;

    /// Default charset for new connections (utf8mb4).
    // NOTE(review): id 255 is a MySQL 8.0 collation; confirm how older MySQL
    // and MariaDB servers treat it in the handshake.
    pub const DEFAULT_CHARSET: u8 = UTF8MB4_0900_AI_CI;
}
172
/// The 4-byte header that precedes every MySQL packet payload.
///
/// Wire layout: bytes 0..3 hold the payload length as a little-endian 24-bit
/// integer, byte 3 holds the sequence id.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketHeader {
    /// Payload length (3 bytes on the wire, max 16MB - 1)
    pub payload_length: u32,
    /// Sequence number (wraps at 255)
    pub sequence_id: u8,
}

impl PacketHeader {
    /// Total header size in bytes.
    pub const SIZE: usize = 4;

    /// Largest value the 24-bit length field can hold (same value as the
    /// module-level `MAX_PACKET_SIZE`).
    const MAX_PAYLOAD_LENGTH: u32 = 0x00FF_FFFF;

    /// Parse a packet header from its 4 wire bytes.
    pub fn from_bytes(bytes: &[u8; 4]) -> Self {
        let [b0, b1, b2, sequence_id] = *bytes;
        Self {
            // Zero-extend the 24-bit little-endian length to 32 bits.
            payload_length: u32::from_le_bytes([b0, b1, b2, 0]),
            sequence_id,
        }
    }

    /// Encode the header to its 4 wire bytes.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `payload_length` does not fit in 24 bits.
    /// Release builds keep only the low 24 bits, which would desynchronise the
    /// stream. Callers must split larger payloads into multiple packets first.
    pub fn to_bytes(&self) -> [u8; 4] {
        debug_assert!(
            self.payload_length <= Self::MAX_PAYLOAD_LENGTH,
            "payload_length {} does not fit in the 24-bit MySQL length field",
            self.payload_length
        );
        let [b0, b1, b2, _] = self.payload_length.to_le_bytes();
        [b0, b1, b2, self.sequence_id]
    }
}
207
/// Kind of server response packet, as classified from its first payload byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketType {
    /// OK packet (first byte 0x00)
    Ok,
    /// Error packet (first byte 0xFF)
    Error,
    /// EOF packet (first byte 0xFE with a short payload) - deprecated in CLIENT_DEPRECATE_EOF
    Eof,
    /// Local infile request (first byte 0xFB)
    LocalInfile,
    /// Anything else (result set row, etc.)
    Data,
}

impl PacketType {
    /// A 0xFE packet counts as EOF only when its payload is shorter than this.
    const EOF_PAYLOAD_LIMIT: u32 = 9;

    /// Classify a packet from its first payload byte and its payload length.
    ///
    /// `payload_len` only matters for 0xFE: shorter than 9 bytes means EOF,
    /// otherwise the packet is treated as data.
    ///
    /// NOTE(review): inside a result set, row packets may legitimately begin
    /// with 0x00, 0xFB or 0xFE, so this classification presumably only applies
    /// to the first packet of a command response. Confirm against callers.
    pub fn from_first_byte(byte: u8, payload_len: u32) -> Self {
        match (byte, payload_len) {
            (0x00, _) => Self::Ok,
            (0xFF, _) => Self::Error,
            (0xFE, len) if len < Self::EOF_PAYLOAD_LIMIT => Self::Eof,
            (0xFB, _) => Self::LocalInfile,
            _ => Self::Data,
        }
    }
}
236
/// Decoded contents of a server OK packet.
#[derive(Debug, Clone)]
pub struct OkPacket {
    /// Affected-row count reported by the server
    pub affected_rows: u64,
    /// Last generated insert id reported by the server
    pub last_insert_id: u64,
    /// Server status bitmask (see the `server_status` constants)
    pub status_flags: u16,
    /// Warning count
    pub warnings: u16,
    /// Human-readable info string; empty when the server sent none
    pub info: String,
}
251
/// Decoded contents of a server ERR packet.
#[derive(Debug, Clone)]
pub struct ErrPacket {
    /// Server error code
    pub error_code: u16,
    /// SQL state (5 characters)
    pub sql_state: String,
    /// Error message
    pub error_message: String,
}

impl ErrPacket {
    /// `ER_DUP_ENTRY`: duplicate entry for a unique key.
    const ER_DUP_ENTRY: u16 = 1062;
    /// `ER_DUP_ENTRY_WITH_KEY_NAME`: duplicate entry, reported with the key name.
    const ER_DUP_ENTRY_WITH_KEY_NAME: u16 = 1586;
    /// `ER_NO_REFERENCED_ROW`: child row has no parent (legacy form, no constraint details).
    const ER_NO_REFERENCED_ROW: u16 = 1216;
    /// `ER_ROW_IS_REFERENCED`: parent row still referenced (legacy form, no constraint details).
    const ER_ROW_IS_REFERENCED: u16 = 1217;
    /// `ER_ROW_IS_REFERENCED_2`: cannot delete or update a referenced parent row.
    const ER_ROW_IS_REFERENCED_2: u16 = 1451;
    /// `ER_NO_REFERENCED_ROW_2`: cannot add or update a child row without a parent.
    const ER_NO_REFERENCED_ROW_2: u16 = 1452;

    /// Check if this is a unique constraint violation.
    ///
    /// Matches `ER_DUP_ENTRY` (1062) and `ER_DUP_ENTRY_WITH_KEY_NAME` (1586).
    pub fn is_duplicate_key(&self) -> bool {
        matches!(
            self.error_code,
            Self::ER_DUP_ENTRY | Self::ER_DUP_ENTRY_WITH_KEY_NAME
        )
    }

    /// Check if this is a foreign key constraint violation.
    ///
    /// Matches both the detailed codes (1451, 1452) and the legacy codes
    /// (1216, 1217) that the server uses when it omits constraint details.
    pub fn is_foreign_key_violation(&self) -> bool {
        matches!(
            self.error_code,
            Self::ER_NO_REFERENCED_ROW
                | Self::ER_ROW_IS_REFERENCED
                | Self::ER_ROW_IS_REFERENCED_2
                | Self::ER_NO_REFERENCED_ROW_2
        )
    }
}
276
/// Decoded contents of a legacy EOF packet (deprecated in newer MySQL
/// versions, and not used when `CLIENT_DEPRECATE_EOF` is in effect).
#[derive(Debug, Clone, Copy)]
pub struct EofPacket {
    /// Warning count
    pub warnings: u16,
    /// Server status bitmask (see the `server_status` constants)
    pub status_flags: u16,
}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_packet_header_roundtrip() {
        let original = PacketHeader {
            payload_length: 0x0012_3456,
            sequence_id: 7,
        };
        let decoded = PacketHeader::from_bytes(&original.to_bytes());
        assert_eq!(decoded.payload_length, original.payload_length);
        assert_eq!(decoded.sequence_id, original.sequence_id);
    }

    #[test]
    fn test_packet_header_max_size() {
        // try_from instead of `as`: the conversion is lossless, so prove it.
        let payload_length =
            u32::try_from(MAX_PACKET_SIZE).expect("MAX_PACKET_SIZE fits in u32");
        let encoded = PacketHeader {
            payload_length,
            sequence_id: 255,
        }
        .to_bytes();
        assert_eq!(encoded, [0xFF, 0xFF, 0xFF, 255]);
    }

    #[test]
    fn test_packet_type_detection() {
        let cases = [
            (0x00, 10, PacketType::Ok),
            (0xFF, 10, PacketType::Error),
            (0xFE, 5, PacketType::Eof),
            (0xFE, 100, PacketType::Data),
            (0xFB, 10, PacketType::LocalInfile),
            (0x42, 10, PacketType::Data),
        ];
        for (first_byte, payload_len, expected) in cases.iter().copied() {
            assert_eq!(
                PacketType::from_first_byte(first_byte, payload_len),
                expected,
                "first byte {:#04x}, payload length {}",
                first_byte,
                payload_len
            );
        }
    }

    #[test]
    fn test_err_packet_error_types() {
        let err = |error_code: u16, error_message: &str| ErrPacket {
            error_code,
            sql_state: "23000".to_string(),
            error_message: error_message.to_string(),
        };

        let dup = err(1062, "Duplicate entry");
        assert!(dup.is_duplicate_key());
        assert!(!dup.is_foreign_key_violation());

        let fk = err(1451, "Cannot delete");
        assert!(!fk.is_duplicate_key());
        assert!(fk.is_foreign_key_violation());
    }
}