Skip to main content

sqlmodel_mysql/protocol/
writer.rs

1//! MySQL packet writing utilities.
2//!
3//! This module provides utilities for writing MySQL protocol data types
4//! including length-encoded integers and strings.
5
6#![allow(clippy::cast_possible_truncation)]
7
8use crate::protocol::{MAX_PACKET_SIZE, PacketHeader};
9
10/// A writer for MySQL protocol data.
11#[derive(Debug, Default)]
12pub struct PacketWriter {
13    buffer: Vec<u8>,
14}
15
16impl PacketWriter {
17    /// Create a new writer with default capacity.
18    pub fn new() -> Self {
19        Self::with_capacity(256)
20    }
21
22    /// Create a new writer with specified capacity.
23    pub fn with_capacity(capacity: usize) -> Self {
24        Self {
25            buffer: Vec::with_capacity(capacity),
26        }
27    }
28
29    /// Get the current buffer length.
30    pub fn len(&self) -> usize {
31        self.buffer.len()
32    }
33
34    /// Check if the buffer is empty.
35    pub fn is_empty(&self) -> bool {
36        self.buffer.is_empty()
37    }
38
39    /// Clear the buffer.
40    pub fn clear(&mut self) {
41        self.buffer.clear();
42    }
43
44    /// Get the buffer as a byte slice.
45    pub fn as_bytes(&self) -> &[u8] {
46        &self.buffer
47    }
48
49    /// Consume the writer and return the buffer.
50    pub fn into_bytes(self) -> Vec<u8> {
51        self.buffer
52    }
53
54    /// Write a single byte.
55    pub fn write_u8(&mut self, value: u8) {
56        self.buffer.push(value);
57    }
58
59    /// Write a u16 (little-endian).
60    pub fn write_u16_le(&mut self, value: u16) {
61        self.buffer.extend_from_slice(&value.to_le_bytes());
62    }
63
64    /// Write a u24 (little-endian, 3 bytes).
65    pub fn write_u24_le(&mut self, value: u32) {
66        self.buffer.push((value & 0xFF) as u8);
67        self.buffer.push(((value >> 8) & 0xFF) as u8);
68        self.buffer.push(((value >> 16) & 0xFF) as u8);
69    }
70
71    /// Write a u32 (little-endian).
72    pub fn write_u32_le(&mut self, value: u32) {
73        self.buffer.extend_from_slice(&value.to_le_bytes());
74    }
75
76    /// Write a u64 (little-endian).
77    pub fn write_u64_le(&mut self, value: u64) {
78        self.buffer.extend_from_slice(&value.to_le_bytes());
79    }
80
81    /// Write a length-encoded integer.
82    ///
83    /// MySQL uses a variable-length integer encoding:
84    /// - 0x00-0xFA: 1-byte value
85    /// - 0xFC + 2 bytes: values up to 2^16
86    /// - 0xFD + 3 bytes: values up to 2^24
87    /// - 0xFE + 8 bytes: values up to 2^64
88    pub fn write_lenenc_int(&mut self, value: u64) {
89        if value < 251 {
90            self.write_u8(value as u8);
91        } else if value < 0x10000 {
92            self.write_u8(0xFC);
93            self.write_u16_le(value as u16);
94        } else if value < 0x0100_0000 {
95            self.write_u8(0xFD);
96            self.write_u24_le(value as u32);
97        } else {
98            self.write_u8(0xFE);
99            self.write_u64_le(value);
100        }
101    }
102
103    /// Write a length-encoded string.
104    pub fn write_lenenc_string(&mut self, s: &str) {
105        self.write_lenenc_int(s.len() as u64);
106        self.buffer.extend_from_slice(s.as_bytes());
107    }
108
109    /// Write a length-encoded byte slice.
110    pub fn write_lenenc_bytes(&mut self, data: &[u8]) {
111        self.write_lenenc_int(data.len() as u64);
112        self.buffer.extend_from_slice(data);
113    }
114
115    /// Write a null-terminated string.
116    pub fn write_null_string(&mut self, s: &str) {
117        self.buffer.extend_from_slice(s.as_bytes());
118        self.buffer.push(0);
119    }
120
121    /// Write a fixed-length string, padding with zeros if necessary.
122    pub fn write_fixed_string(&mut self, s: &str, len: usize) {
123        let bytes = s.as_bytes();
124        if bytes.len() >= len {
125            self.buffer.extend_from_slice(&bytes[..len]);
126        } else {
127            self.buffer.extend_from_slice(bytes);
128            self.buffer.resize(self.buffer.len() + len - bytes.len(), 0);
129        }
130    }
131
132    /// Write raw bytes.
133    pub fn write_bytes(&mut self, data: &[u8]) {
134        self.buffer.extend_from_slice(data);
135    }
136
137    /// Write zeros (padding).
138    pub fn write_zeros(&mut self, count: usize) {
139        self.buffer.resize(self.buffer.len() + count, 0);
140    }
141
142    /// Build a complete packet with header and payload.
143    ///
144    /// This handles splitting large payloads into multiple packets
145    /// if needed (payloads over 16MB - 1).
146    pub fn build_packet(&self, sequence_id: u8) -> Vec<u8> {
147        self.build_packet_from_payload(&self.buffer, sequence_id)
148    }
149
150    /// Build a packet from a given payload.
151    pub fn build_packet_from_payload(&self, payload: &[u8], mut sequence_id: u8) -> Vec<u8> {
152        let mut result = Vec::with_capacity(payload.len() + 4);
153
154        if payload.len() <= MAX_PACKET_SIZE {
155            // Single packet
156            let header = PacketHeader {
157                payload_length: payload.len() as u32,
158                sequence_id,
159            };
160            result.extend_from_slice(&header.to_bytes());
161            result.extend_from_slice(payload);
162        } else {
163            // Split into multiple packets
164            let mut offset = 0;
165            while offset < payload.len() {
166                let chunk_len = (payload.len() - offset).min(MAX_PACKET_SIZE);
167                let header = PacketHeader {
168                    payload_length: chunk_len as u32,
169                    sequence_id,
170                };
171                result.extend_from_slice(&header.to_bytes());
172                result.extend_from_slice(&payload[offset..offset + chunk_len]);
173                offset += chunk_len;
174                sequence_id = sequence_id.wrapping_add(1);
175
176                // If we wrote exactly MAX_PACKET_SIZE, we need an empty packet
177                // to signal the end of the payload
178                if chunk_len == MAX_PACKET_SIZE && offset == payload.len() {
179                    let header = PacketHeader {
180                        payload_length: 0,
181                        sequence_id,
182                    };
183                    result.extend_from_slice(&header.to_bytes());
184                }
185            }
186        }
187
188        result
189    }
190}
191
192/// Helper to build a command packet.
193pub fn build_command_packet(command: u8, payload: &[u8], sequence_id: u8) -> Vec<u8> {
194    let mut writer = PacketWriter::with_capacity(1 + payload.len());
195    writer.write_u8(command);
196    writer.write_bytes(payload);
197    writer.build_packet(sequence_id)
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected 4-byte packet header: 3-byte little-endian length, then the
    /// sequence id (same layout the header assertions below spell out by hand).
    fn header(len: usize, seq: u8) -> [u8; 4] {
        let len = len as u32;
        [len as u8, (len >> 8) as u8, (len >> 16) as u8, seq]
    }

    #[test]
    fn test_write_u8() {
        let mut writer = PacketWriter::new();
        writer.write_u8(0x42);
        assert_eq!(writer.as_bytes(), &[0x42]);
    }

    #[test]
    fn test_write_u16_le() {
        let mut writer = PacketWriter::new();
        writer.write_u16_le(0x1234);
        assert_eq!(writer.as_bytes(), &[0x34, 0x12]);
    }

    #[test]
    fn test_write_u24_le() {
        let mut writer = PacketWriter::new();
        writer.write_u24_le(0x0012_3456);
        assert_eq!(writer.as_bytes(), &[0x56, 0x34, 0x12]);
    }

    #[test]
    fn test_write_u32_le() {
        let mut writer = PacketWriter::new();
        writer.write_u32_le(0x1234_5678);
        assert_eq!(writer.as_bytes(), &[0x78, 0x56, 0x34, 0x12]);
    }

    #[test]
    fn test_write_u64_le() {
        let mut writer = PacketWriter::new();
        writer.write_u64_le(0x0807_0605_0403_0201);
        assert_eq!(
            writer.as_bytes(),
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );
    }

    #[test]
    fn test_write_lenenc_int() {
        // 1-byte value
        let mut writer = PacketWriter::new();
        writer.write_lenenc_int(0x42);
        assert_eq!(writer.as_bytes(), &[0x42]);

        // 2-byte value
        let mut writer = PacketWriter::new();
        writer.write_lenenc_int(0x1234);
        assert_eq!(writer.as_bytes(), &[0xFC, 0x34, 0x12]);

        // 3-byte value
        let mut writer = PacketWriter::new();
        writer.write_lenenc_int(0x0012_3456);
        assert_eq!(writer.as_bytes(), &[0xFD, 0x56, 0x34, 0x12]);

        // 8-byte value
        let mut writer = PacketWriter::new();
        writer.write_lenenc_int(0x0807_0605_0403_0201);
        assert_eq!(
            writer.as_bytes(),
            &[0xFE, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );
    }

    #[test]
    fn test_write_lenenc_int_boundaries() {
        // Each case sits on the edge between two encodings, where an
        // off-by-one in the range checks would show up.
        let cases: &[(u64, &[u8])] = &[
            (250, &[250]),
            (251, &[0xFC, 0xFB, 0x00]),
            (0xFFFF, &[0xFC, 0xFF, 0xFF]),
            (0x1_0000, &[0xFD, 0x00, 0x00, 0x01]),
            (0xFF_FFFF, &[0xFD, 0xFF, 0xFF, 0xFF]),
            (0x100_0000, &[0xFE, 0, 0, 0, 1, 0, 0, 0, 0]),
        ];
        for &(value, expected) in cases {
            let mut writer = PacketWriter::new();
            writer.write_lenenc_int(value);
            assert_eq!(writer.as_bytes(), expected, "value {value:#x}");
        }
    }

    #[test]
    fn test_write_null_string() {
        let mut writer = PacketWriter::new();
        writer.write_null_string("hello");
        assert_eq!(writer.as_bytes(), b"hello\0");
    }

    #[test]
    fn test_write_lenenc_string() {
        let mut writer = PacketWriter::new();
        writer.write_lenenc_string("hello");
        assert_eq!(writer.as_bytes(), &[0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_write_lenenc_bytes() {
        let mut writer = PacketWriter::new();
        writer.write_lenenc_bytes(&[0xDE, 0xAD]);
        writer.write_lenenc_bytes(&[]);
        assert_eq!(writer.as_bytes(), &[0x02, 0xDE, 0xAD, 0x00]);
    }

    #[test]
    fn test_write_fixed_string() {
        // String shorter than length
        let mut writer = PacketWriter::new();
        writer.write_fixed_string("hi", 5);
        assert_eq!(writer.as_bytes(), &[b'h', b'i', 0, 0, 0]);

        // String exactly matches length
        let mut writer = PacketWriter::new();
        writer.write_fixed_string("hello", 5);
        assert_eq!(writer.as_bytes(), b"hello");

        // String longer than length (truncated)
        let mut writer = PacketWriter::new();
        writer.write_fixed_string("hello world", 5);
        assert_eq!(writer.as_bytes(), b"hello");

        // Empty string is all padding
        let mut writer = PacketWriter::new();
        writer.write_fixed_string("", 3);
        assert_eq!(writer.as_bytes(), &[0, 0, 0]);
    }

    #[test]
    fn test_write_zeros_and_clear() {
        let mut writer = PacketWriter::new();
        assert!(writer.is_empty());
        writer.write_u8(1);
        writer.write_zeros(3);
        assert_eq!(writer.len(), 4);
        assert_eq!(writer.as_bytes(), &[1, 0, 0, 0]);
        writer.clear();
        assert!(writer.is_empty());
        assert!(writer.into_bytes().is_empty());
    }

    #[test]
    fn test_build_packet() {
        let mut writer = PacketWriter::new();
        writer.write_bytes(b"hello");
        let packet = writer.build_packet(1);
        // Header: 05 00 00 01 + payload: hello
        assert_eq!(&packet[..4], &[0x05, 0x00, 0x00, 0x01]);
        assert_eq!(&packet[4..], b"hello");
    }

    #[test]
    fn test_build_packet_empty_payload() {
        let packet = PacketWriter::new().build_packet(3);
        assert_eq!(packet, vec![0x00, 0x00, 0x00, 0x03]);
    }

    #[test]
    fn test_build_packet_splits_oversized_payload() {
        // One byte more than fits in a single packet: expect a full-size packet
        // followed by a 1-byte packet carrying the next sequence id.
        let payload = vec![0xAB_u8; MAX_PACKET_SIZE + 1];
        let packet = PacketWriter::new().build_packet_from_payload(&payload, 7);

        assert_eq!(packet.len(), payload.len() + 8);
        assert_eq!(&packet[..4], &header(MAX_PACKET_SIZE, 7));
        assert!(packet[4..4 + MAX_PACKET_SIZE].iter().all(|&b| b == 0xAB));

        let second = 4 + MAX_PACKET_SIZE;
        assert_eq!(&packet[second..second + 4], &header(1, 8));
        assert_eq!(packet[second + 4], 0xAB);
    }

    #[test]
    fn test_build_command_packet() {
        let packet = build_command_packet(0x03, b"SELECT 1", 0);
        // Header: 09 00 00 00 + command: 03 + payload: SELECT 1
        assert_eq!(&packet[..4], &[0x09, 0x00, 0x00, 0x00]);
        assert_eq!(packet[4], 0x03);
        assert_eq!(&packet[5..], b"SELECT 1");
    }
}
314        assert_eq!(&packet[..4], &[0x09, 0x00, 0x00, 0x00]);
315        assert_eq!(packet[4], 0x03);
316        assert_eq!(&packet[5..], b"SELECT 1");
317    }
318}