zero_mysql/protocol/command/
column_definition.rs

1use crate::constant::{ColumnFlags, ColumnType};
2use crate::error::{eyre, Error, Result};
3use crate::protocol::primitive::*;
4use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
5use zerocopy::{FromBytes, Immutable, KnownLayout};
6
/// Borrowed payload of a single column definition packet.
///
/// This is a thin newtype over the raw byte slice: constructing it performs
/// no validation. The bytes are only interpreted when `tail()` is called or
/// when the value is converted into a `ColumnDefinition` via `TryFrom`.
#[derive(Debug, Clone, Copy)]
pub struct ColumnDefinitionBytes<'a>(pub &'a [u8]);
10
11impl<'a> ColumnDefinitionBytes<'a> {
12    /// Get a reference to the fixed-size tail of the column definition
13    ///
14    /// The tail is always the last 12 bytes of the column definition packet
15    pub fn tail(&self) -> Result<&'a ColumnDefinitionTail> {
16        if self.0.len() < 12 {
17            return Err(Error::LibraryBug(eyre!(
18                "column definition too short: {} < 12",
19                self.0.len()
20            )));
21        }
22        let tail_bytes = &self.0[self.0.len() - 12..];
23        Ok(ColumnDefinitionTail::ref_from_bytes(tail_bytes)?)
24    }
25}
26
/// The column definition parsed from `ColumnDefinitionBytes`
///
/// Every field borrows from the packet payload; nothing is copied. The
/// string fields are kept as raw bytes and are not checked for any text
/// encoding. The packet's leading catalog string is parsed but not stored.
#[derive(Debug, Clone)]
pub struct ColumnDefinition<'a> {
    /// Schema name (second length-encoded string in the packet).
    pub schema: &'a [u8],
    /// Table name as seen by the query (third length-encoded string).
    pub table_alias: &'a [u8],
    /// Original table name (fourth length-encoded string).
    pub table_original: &'a [u8],
    /// Column name as seen by the query (fifth length-encoded string).
    pub name_alias: &'a [u8],
    /// Original column name (sixth length-encoded string).
    pub name_original: &'a [u8],
    /// Fixed-size tail that follows the strings and the length marker.
    pub tail: &'a ColumnDefinitionTail,
}
37
38impl<'a> TryFrom<ColumnDefinitionBytes<'a>> for ColumnDefinition<'a> {
39    type Error = Error;
40
41    fn try_from(bytes: ColumnDefinitionBytes<'a>) -> Result<Self> {
42        let data = bytes.0;
43
44        // ─── Variable Length String Fields ───────────────────────────
45        let (_catalog, data) = read_string_lenenc(data)?;
46        let (schema, data) = read_string_lenenc(data)?;
47        let (table_alias, data) = read_string_lenenc(data)?;
48        let (table_original, data) = read_string_lenenc(data)?;
49        let (name_alias, data) = read_string_lenenc(data)?;
50        let (name_original, data) = read_string_lenenc(data)?;
51
52        // ─── Columndefinitiontail ────────────────────────────────────
53        // length is always 0x0c
54        let (_length, data) = read_int_lenenc(data)?;
55        let tail = ColumnDefinitionTail::ref_from_bytes(data)?;
56        Ok(Self {
57            // catalog,
58            schema,
59            table_alias,
60            table_original,
61            name_alias,
62            name_original,
63            tail,
64        })
65    }
66}
67
/// Combined column type and flags
///
/// Produced by `ColumnDefinitionTail::type_and_flags` so that both
/// already-validated values can be passed around as a single `Copy` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ColumnTypeAndFlags {
    /// Validated column type.
    pub column_type: ColumnType,
    /// Validated column flags.
    pub flags: ColumnFlags,
}
74
/// Fixed-size tail of Column Definition packet (12 bytes)
///
/// `repr(C, packed)` keeps the fields in declaration order with no padding,
/// so the struct can be viewed directly over packet bytes through zerocopy's
/// `FromBytes`. Multi-byte fields are stored little-endian and are read
/// through the accessor methods.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
pub struct ColumnDefinitionTail {
    // Character set identifier (2 bytes, little-endian).
    charset: U16LE,
    // Column length (4 bytes, little-endian).
    column_length: U32LE,
    // Raw column type byte; validated by `column_type()`.
    column_type: u8,
    // Raw flag bits (2 bytes, little-endian); validated by `flags()`.
    flags: U16LE,
    // Decimals byte (1 byte).
    decimals: u8,
    // Reserved bytes (2 bytes); no accessor reads them.
    reserved: U16LE,
}
86
87impl ColumnDefinitionTail {
88    pub fn charset(&self) -> u16 {
89        self.charset.get()
90    }
91
92    pub fn column_length(&self) -> u32 {
93        self.column_length.get()
94    }
95
96    /// Returns an error if the column type is unknown
97    pub fn column_type(&self) -> Result<ColumnType> {
98        ColumnType::from_u8(self.column_type).ok_or_else(|| {
99            Error::LibraryBug(eyre!("unknown column type: 0x{:02X}", self.column_type))
100        })
101    }
102
103    pub fn flags(&self) -> Result<ColumnFlags> {
104        ColumnFlags::from_bits(self.flags.get()).ok_or_else(|| {
105            Error::LibraryBug(eyre!("invalid column flags: 0x{:04X}", self.flags.get()))
106        })
107    }
108
109    /// A handy function to get `ColumnTypeAndFlags` for decoding a value
110    pub fn type_and_flags(&self) -> Result<ColumnTypeAndFlags> {
111        Ok(ColumnTypeAndFlags {
112            column_type: self.column_type()?,
113            flags: self.flags()?,
114        })
115    }
116}
117
/// Owns a buffer of column definition packets together with the definitions
/// parsed out of it.
///
/// The parsed definitions borrow from `_packets`; the `'static` lifetime on
/// them is internal only, and `definitions()` hands them out tied to the
/// borrow of `self`.
pub struct ColumnDefinitions {
    // Concatenation of records, each a 4-byte native-endian `u32` payload
    // length followed by the payload. Never read after construction; it is
    // held only to keep the borrowed definitions valid.
    _packets: Vec<u8>,
    // Parsed views into `_packets`.
    definitions: Vec<ColumnDefinition<'static>>,
}
122
123impl ColumnDefinitions {
124    pub fn new(num_columns: usize, packets: Vec<u8>) -> Result<Self> {
125        let definitions = {
126            let mut buf = packets.as_slice();
127            let mut definitions = Vec::with_capacity(num_columns);
128            for _ in 0..num_columns {
129                let len = u32::from_ne_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
130                definitions.push(ColumnDefinition::try_from(ColumnDefinitionBytes(
131                    &buf[4..4 + len],
132                ))?);
133                buf = &buf[4 + len..]; // Advance past the length prefix and payload
134            }
135
136            // Safety: borrowed data is valid for 'static because Self holds packets
137            unsafe {
138                std::mem::transmute::<Vec<ColumnDefinition<'_>>, Vec<ColumnDefinition<'static>>>(
139                    definitions,
140                )
141            }
142        };
143
144        Ok(Self {
145            _packets: packets,
146            definitions,
147        })
148    }
149
150    pub fn definitions<'a>(&'a self) -> &'a [ColumnDefinition<'a>] {
151        self.definitions.as_slice()
152    }
153}
154
/// Unit tests for the tail layout, the tail accessors, `ColumnDefinitionBytes::tail`
/// and the `TryFrom` conversion into `ColumnDefinition`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    /// The packed struct must have no padding so it can overlay packet bytes.
    #[test]
    fn test_column_definition_tail_size() {
        // Verify the struct is exactly 12 bytes as per MySQL protocol
        assert_eq!(size_of::<ColumnDefinitionTail>(), 12);
    }

    /// Parses a tail with no flags set and checks every accessor.
    #[test]
    fn test_column_definition_tail_parsing() {
        // Example data: charset=33 (utf8), length=255, type=253 (VAR_STRING), flags=0, decimals=0, reserved=0
        let data: [u8; 12] = [
            0x21, 0x00, // charset = 33 (0x0021) LE
            0xFF, 0x00, 0x00, 0x00, // column_length = 255 (0x000000FF) LE
            0xFD, // column_type = 253 (VAR_STRING)
            0x00, 0x00, // flags = 0 (0x0000) LE
            0x00, // decimals = 0
            0x00, 0x00, // reserved = 0 (0x0000) LE
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        assert_eq!(tail.charset(), 33);
        assert_eq!(tail.column_length(), 255);
        assert_eq!(tail.decimals, 0);

        // Test conversion methods
        let flags = tail.flags().expect("Failed to parse flags");
        assert!(flags.is_empty());

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_VAR_STRING);

        // Test type_and_flags
        let type_and_flags = tail.type_and_flags().expect("Failed to get type_and_flags");
        assert_eq!(
            type_and_flags.column_type,
            ColumnType::MYSQL_TYPE_VAR_STRING
        );
        assert!(type_and_flags.flags.is_empty());
    }

    /// Checks that individual flag bits are decoded and that unset bits stay unset.
    #[test]
    fn test_column_definition_tail_with_flags() {
        // Example with NOT_NULL and UNSIGNED flags set
        let data: [u8; 12] = [
            0x21, 0x00, // charset = 33
            0xFF, 0x00, 0x00, 0x00, // column_length = 255
            0x01, // column_type = 1 (TINYINT)
            0x21, 0x00, // flags = 0x0021 (NOT_NULL_FLAG | UNSIGNED_FLAG) LE
            0x00, // decimals = 0
            0x00, 0x00, // reserved = 0
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        let flags = tail.flags().expect("Failed to parse flags");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::UNSIGNED_FLAG));
        assert!(!flags.contains(ColumnFlags::AUTO_INCREMENT_FLAG));

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_TINY);

        // Test type_and_flags
        let type_and_flags = tail.type_and_flags().expect("Failed to get type_and_flags");
        assert_eq!(type_and_flags.column_type, ColumnType::MYSQL_TYPE_TINY);
        assert!(type_and_flags.flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::UNSIGNED_FLAG));
    }

    /// Regression test: a flag word containing PART_KEY_FLAG must be accepted
    /// by `flags()`.
    #[test]
    fn test_column_definition_tail_with_part_key_flag() {
        // Test with PART_KEY_FLAG (0x4000) - from actual MySQL response
        // This reproduces the bug: flags = 0x4203 (NOT_NULL | PRI_KEY | AUTO_INCREMENT | PART_KEY)
        let data: [u8; 12] = [
            0x3f, 0x00, // charset = 63 (binary)
            0x0B, 0x00, 0x00, 0x00, // column_length = 11
            0x03, // column_type = 3 (LONG/INT)
            0x03, 0x42, // flags = 0x4203 (NOT_NULL | PRI_KEY | AUTO_INCREMENT | PART_KEY) LE
            0x00, // decimals = 0
            0x00, 0x00, // reserved = 0
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        // Verify the fields
        assert_eq!(tail.charset(), 63);
        assert_eq!(tail.column_length(), 11);
        assert_eq!(tail.decimals, 0);

        // Verify flags can be parsed
        let flags = tail
            .flags()
            .expect("Failed to parse flags with PART_KEY_FLAG");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::PRI_KEY_FLAG));
        assert!(flags.contains(ColumnFlags::AUTO_INCREMENT_FLAG));
        assert!(flags.contains(ColumnFlags::PART_KEY_FLAG));

        // Verify column type
        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_LONG);

        // Test type_and_flags
        let type_and_flags = tail.type_and_flags().expect("Failed to get type_and_flags");
        assert_eq!(type_and_flags.column_type, ColumnType::MYSQL_TYPE_LONG);
        assert!(type_and_flags.flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::PRI_KEY_FLAG));
        assert!(type_and_flags
            .flags
            .contains(ColumnFlags::AUTO_INCREMENT_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::PART_KEY_FLAG));
    }

    /// An unrecognised type byte still overlays fine but `column_type()` must fail.
    #[test]
    fn test_column_definition_tail_invalid_column_type() {
        // Example with invalid column type
        let data: [u8; 12] = [
            0x21, 0x00, // charset = 33
            0xFF, 0x00, 0x00, 0x00, // column_length = 255
            0x50, // column_type = 0x50 (invalid, in the gap)
            0x00, 0x00, // flags = 0
            0x00, // decimals = 0
            0x00, 0x00, // reserved = 0
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        // Should error on unknown column type
        let result = tail.column_type();
        assert!(result.is_err());
    }

    /// `ColumnDefinitionBytes::tail` on a payload that is exactly the 12-byte tail.
    #[test]
    fn test_column_definition_bytes() {
        // Simulate a minimal column definition packet with just the tail
        // In reality, there would be variable-length strings before the tail
        let data: &[u8; 12] = &[
            0x21, 0x00, // charset = 33 (utf8)
            0xFF, 0x00, 0x00, 0x00, // column_length = 255
            0x01, // column_type = 1 (TINYINT)
            0x21, 0x00, // flags = 0x0021 (NOT_NULL_FLAG | UNSIGNED_FLAG)
            0x00, // decimals = 0
            0x00, 0x00, // reserved = 0
        ];

        let col_bytes = ColumnDefinitionBytes(data);
        let tail = col_bytes.tail().expect("Failed to parse tail");

        assert_eq!(tail.charset(), 33);
        assert_eq!(tail.column_length(), 255);
        assert_eq!(tail.decimals, 0);

        let flags = tail.flags().expect("Failed to parse flags");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::UNSIGNED_FLAG));

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_TINY);
    }

    /// A payload shorter than 12 bytes must be rejected by `tail()`.
    #[test]
    fn test_column_definition_bytes_too_short() {
        // Test with data that's too short
        let data: &[u8; 8] = &[0; 8];
        let col_bytes = ColumnDefinitionBytes(data);
        let result = col_bytes.tail();
        assert!(result.is_err());
    }

    /// End-to-end parse of a hand-built packet through `TryFrom`, checking the
    /// string fields and the tail.
    #[test]
    fn test_column_definition_try_from() {
        // Build a complete column definition packet
        let mut packet = Vec::new();

        // catalog (length-encoded string) - "def"
        packet.push(0x03);
        packet.extend_from_slice(b"def");

        // schema (length-encoded string) - "test"
        packet.push(0x04);
        packet.extend_from_slice(b"test");

        // table (length-encoded string) - "users"
        packet.push(0x05);
        packet.extend_from_slice(b"users");

        // org_table (length-encoded string) - "users"
        packet.push(0x05);
        packet.extend_from_slice(b"users");

        // name (length-encoded string) - "id"
        packet.push(0x02);
        packet.extend_from_slice(b"id");

        // org_name (length-encoded string) - "id"
        packet.push(0x02);
        packet.extend_from_slice(b"id");

        // length of fixed fields (0x0c = 12)
        packet.push(0x0c);

        // Fixed tail (12 bytes)
        packet.extend_from_slice(&[
            0x21, 0x00, // charset = 33 (utf8)
            0x0B, 0x00, 0x00, 0x00, // column_length = 11
            0x03, // column_type = 3 (LONG/INT)
            0x03, 0x00, // flags = 0x0003 (NOT_NULL_FLAG | PRI_KEY_FLAG)
            0x00, // decimals = 0
            0x00, 0x00, // reserved = 0
        ]);

        // Parse using TryFrom
        let col_bytes = ColumnDefinitionBytes(&packet);
        let col_def = ColumnDefinition::try_from(col_bytes).expect("Failed to parse");

        // Verify string fields
        // (the catalog string is parsed but not stored, so it is not checked)
        assert_eq!(col_def.schema, b"test");
        assert_eq!(col_def.table_alias, b"users");
        assert_eq!(col_def.table_original, b"users");
        assert_eq!(col_def.name_alias, b"id");
        assert_eq!(col_def.name_original, b"id");

        // Verify tail fields
        assert_eq!(col_def.tail.charset(), 33);
        assert_eq!(col_def.tail.column_length(), 11);
        assert_eq!(col_def.tail.decimals, 0);

        let flags = col_def.tail.flags().expect("Failed to parse flags");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::PRI_KEY_FLAG));

        let col_type = col_def
            .tail
            .column_type()
            .expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_LONG);

        // Test type_and_flags
        let type_and_flags = col_def
            .tail
            .type_and_flags()
            .expect("Failed to get type_and_flags");
        assert_eq!(type_and_flags.column_type, ColumnType::MYSQL_TYPE_LONG);
        assert!(type_and_flags.flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::PRI_KEY_FLAG));
    }
}