1use crate::constant::{ColumnFlags, ColumnType};
2use crate::error::{eyre, Error, Result};
3use crate::protocol::primitive::*;
4use zerocopy::byteorder::little_endian::{U16 as U16LE, U32 as U32LE};
5use zerocopy::{FromBytes, Immutable, KnownLayout};
6
/// Raw payload of a single column-definition packet, borrowed from a larger buffer.
///
/// This is a thin, copyable wrapper; decoding happens via [`ColumnDefinitionBytes::tail`]
/// or `ColumnDefinition::try_from`.
#[derive(Copy, Clone, Debug)]
pub struct ColumnDefinitionBytes<'buf>(pub &'buf [u8]);
10
11impl<'a> ColumnDefinitionBytes<'a> {
12 pub fn tail(&self) -> Result<&'a ColumnDefinitionTail> {
16 if self.0.len() < 12 {
17 return Err(Error::LibraryBug(eyre!(
18 "column definition too short: {} < 12",
19 self.0.len()
20 )));
21 }
22 let tail_bytes = &self.0[self.0.len() - 12..];
23 Ok(ColumnDefinitionTail::ref_from_bytes(tail_bytes)?)
24 }
25}
26
27#[derive(Debug, Clone)]
29pub struct ColumnDefinition<'a> {
30 pub schema: &'a [u8],
31 pub table_alias: &'a [u8],
32 pub table_original: &'a [u8],
33 pub name_alias: &'a [u8],
34 pub name_original: &'a [u8],
35 pub tail: &'a ColumnDefinitionTail,
36}
37
38impl<'a> TryFrom<ColumnDefinitionBytes<'a>> for ColumnDefinition<'a> {
39 type Error = Error;
40
41 fn try_from(bytes: ColumnDefinitionBytes<'a>) -> Result<Self> {
42 let data = bytes.0;
43
44 let (_catalog, data) = read_string_lenenc(data)?;
46 let (schema, data) = read_string_lenenc(data)?;
47 let (table_alias, data) = read_string_lenenc(data)?;
48 let (table_original, data) = read_string_lenenc(data)?;
49 let (name_alias, data) = read_string_lenenc(data)?;
50 let (name_original, data) = read_string_lenenc(data)?;
51
52 let (_length, data) = read_int_lenenc(data)?;
55 let tail = ColumnDefinitionTail::ref_from_bytes(data)?;
56 Ok(Self {
57 schema,
59 table_alias,
60 table_original,
61 name_alias,
62 name_original,
63 tail,
64 })
65 }
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub struct ColumnTypeAndFlags {
71 pub column_type: ColumnType,
72 pub flags: ColumnFlags,
73}
74
/// Fixed-size trailer of a column-definition packet, read in place via zerocopy.
///
/// `repr(C, packed)` keeps the fields in declaration order with no padding, so the
/// struct maps byte-for-byte onto the 12-byte wire trailer. Multi-byte fields are
/// little-endian. Do not reorder or retype these fields.
#[repr(C, packed)]
#[derive(Debug, Clone, Copy, FromBytes, KnownLayout, Immutable)]
pub struct ColumnDefinitionTail {
    // Character set id (bytes 0..2).
    charset: U16LE,
    // Declared column length (bytes 2..6).
    column_length: U32LE,
    // Raw column type byte, decoded by `column_type()` (byte 6).
    column_type: u8,
    // Raw flag bits, decoded by `flags()` (bytes 7..9).
    flags: U16LE,
    // Decimal digit count (byte 9).
    decimals: u8,
    // Trailing 2-byte filler; not read by any accessor in this file (bytes 10..12).
    reserved: U16LE,
}
86
87impl ColumnDefinitionTail {
88 pub fn charset(&self) -> u16 {
89 self.charset.get()
90 }
91
92 pub fn column_length(&self) -> u32 {
93 self.column_length.get()
94 }
95
96 pub fn column_type(&self) -> Result<ColumnType> {
98 ColumnType::from_u8(self.column_type).ok_or_else(|| {
99 Error::LibraryBug(eyre!("unknown column type: 0x{:02X}", self.column_type))
100 })
101 }
102
103 pub fn flags(&self) -> Result<ColumnFlags> {
104 ColumnFlags::from_bits(self.flags.get()).ok_or_else(|| {
105 Error::LibraryBug(eyre!("invalid column flags: 0x{:04X}", self.flags.get()))
106 })
107 }
108
109 pub fn type_and_flags(&self) -> Result<ColumnTypeAndFlags> {
111 Ok(ColumnTypeAndFlags {
112 column_type: self.column_type()?,
113 flags: self.flags()?,
114 })
115 }
116}
117
118pub struct ColumnDefinitions {
119 _packets: Vec<u8>, definitions: Vec<ColumnDefinition<'static>>,
121}
122
123impl ColumnDefinitions {
124 pub fn new(num_columns: usize, packets: Vec<u8>) -> Result<Self> {
125 let definitions = {
126 let mut buf = packets.as_slice();
127 let mut definitions = Vec::with_capacity(num_columns);
128 for _ in 0..num_columns {
129 let len = u32::from_ne_bytes([buf[0], buf[1], buf[2], buf[3]]) as usize;
130 definitions.push(ColumnDefinition::try_from(ColumnDefinitionBytes(
131 &buf[4..4 + len],
132 ))?);
133 buf = &buf[4 + len..]; }
135
136 unsafe {
138 std::mem::transmute::<Vec<ColumnDefinition<'_>>, Vec<ColumnDefinition<'static>>>(
139 definitions,
140 )
141 }
142 };
143
144 Ok(Self {
145 _packets: packets,
146 definitions,
147 })
148 }
149
150 pub fn definitions<'a>(&'a self) -> &'a [ColumnDefinition<'a>] {
151 self.definitions.as_slice()
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::size_of;

    // The packed trailer must map exactly onto the 12-byte wire layout.
    #[test]
    fn test_column_definition_tail_size() {
        assert_eq!(size_of::<ColumnDefinitionTail>(), 12);
    }

    // Plain VAR_STRING column with no flags.
    #[test]
    fn test_column_definition_tail_parsing() {
        let data: [u8; 12] = [
            0x21, 0x00, // charset = 33 (LE u16)
            0xFF, 0x00, 0x00, 0x00, // column_length = 255 (LE u32)
            0xFD, // column_type byte (asserted below as MYSQL_TYPE_VAR_STRING)
            0x00, 0x00, // flags = 0x0000 (LE u16)
            0x00, // decimals = 0
            0x00, 0x00, // reserved
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        assert_eq!(tail.charset(), 33);
        assert_eq!(tail.column_length(), 255);
        assert_eq!(tail.decimals, 0);

        let flags = tail.flags().expect("Failed to parse flags");
        assert!(flags.is_empty());

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_VAR_STRING);

        // The combined accessor must agree with the individual ones.
        let type_and_flags = tail.type_and_flags().expect("Failed to get type_and_flags");
        assert_eq!(
            type_and_flags.column_type,
            ColumnType::MYSQL_TYPE_VAR_STRING
        );
        assert!(type_and_flags.flags.is_empty());
    }

    // TINY column with two flag bits set.
    #[test]
    fn test_column_definition_tail_with_flags() {
        let data: [u8; 12] = [
            0x21, 0x00, // charset = 33
            0xFF, 0x00, 0x00, 0x00, // column_length = 255
            0x01, // column_type byte (asserted below as MYSQL_TYPE_TINY)
            0x21, 0x00, // flags = 0x0021 (asserted below: NOT_NULL and UNSIGNED)
            0x00, // decimals = 0
            0x00, 0x00, // reserved
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        let flags = tail.flags().expect("Failed to parse flags");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::UNSIGNED_FLAG));
        assert!(!flags.contains(ColumnFlags::AUTO_INCREMENT_FLAG));

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_TINY);

        let type_and_flags = tail.type_and_flags().expect("Failed to get type_and_flags");
        assert_eq!(type_and_flags.column_type, ColumnType::MYSQL_TYPE_TINY);
        assert!(type_and_flags.flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::UNSIGNED_FLAG));
    }

    // Flags spanning both bytes, including the high-byte PART_KEY bit, must be
    // accepted by `flags()`.
    #[test]
    fn test_column_definition_tail_with_part_key_flag() {
        let data: [u8; 12] = [
            0x3f, 0x00, // charset = 63
            0x0B, 0x00, 0x00, 0x00, // column_length = 11
            0x03, // column_type byte (asserted below as MYSQL_TYPE_LONG)
            0x03, 0x42, // flags = 0x4203 (LE); asserted below to contain
            //             NOT_NULL, PRI_KEY, AUTO_INCREMENT and PART_KEY
            0x00, // decimals = 0
            0x00, 0x00, // reserved
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        assert_eq!(tail.charset(), 63);
        assert_eq!(tail.column_length(), 11);
        assert_eq!(tail.decimals, 0);

        let flags = tail
            .flags()
            .expect("Failed to parse flags with PART_KEY_FLAG");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::PRI_KEY_FLAG));
        assert!(flags.contains(ColumnFlags::AUTO_INCREMENT_FLAG));
        assert!(flags.contains(ColumnFlags::PART_KEY_FLAG));

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_LONG);

        let type_and_flags = tail.type_and_flags().expect("Failed to get type_and_flags");
        assert_eq!(type_and_flags.column_type, ColumnType::MYSQL_TYPE_LONG);
        assert!(type_and_flags.flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::PRI_KEY_FLAG));
        assert!(type_and_flags
            .flags
            .contains(ColumnFlags::AUTO_INCREMENT_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::PART_KEY_FLAG));
    }

    // A type byte that `ColumnType::from_u8` does not recognise must yield an error.
    #[test]
    fn test_column_definition_tail_invalid_column_type() {
        let data: [u8; 12] = [
            0x21, 0x00, // charset = 33
            0xFF, 0x00, 0x00, 0x00, // column_length = 255
            0x50, // column_type byte expected to be unknown
            0x00, 0x00, // flags = 0
            0x00, // decimals = 0
            0x00, 0x00, // reserved
        ];

        let tail = ColumnDefinitionTail::ref_from_bytes(&data).expect("Failed to parse");

        let result = tail.column_type();
        assert!(result.is_err());
    }

    // `ColumnDefinitionBytes::tail` reads the last 12 bytes; here the whole buffer is the tail.
    #[test]
    fn test_column_definition_bytes() {
        let data: &[u8; 12] = &[
            0x21, 0x00, // charset = 33
            0xFF, 0x00, 0x00, 0x00, // column_length = 255
            0x01, // column_type byte (asserted below as MYSQL_TYPE_TINY)
            0x21, 0x00, // flags = 0x0021
            0x00, // decimals = 0
            0x00, 0x00, // reserved
        ];

        let col_bytes = ColumnDefinitionBytes(data);
        let tail = col_bytes.tail().expect("Failed to parse tail");

        assert_eq!(tail.charset(), 33);
        assert_eq!(tail.column_length(), 255);
        assert_eq!(tail.decimals, 0);

        let flags = tail.flags().expect("Failed to parse flags");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::UNSIGNED_FLAG));

        let col_type = tail.column_type().expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_TINY);
    }

    // Fewer than 12 bytes must be rejected rather than panicking.
    #[test]
    fn test_column_definition_bytes_too_short() {
        let data: &[u8; 8] = &[0; 8];
        let col_bytes = ColumnDefinitionBytes(data);
        let result = col_bytes.tail();
        assert!(result.is_err());
    }

    // Full packet: six length-encoded strings (each a 1-byte length followed by
    // the bytes), the fixed-length marker, then the 12-byte tail.
    #[test]
    fn test_column_definition_try_from() {
        let mut packet = Vec::new();

        // catalog
        packet.push(0x03);
        packet.extend_from_slice(b"def");

        // schema
        packet.push(0x04);
        packet.extend_from_slice(b"test");

        // table alias
        packet.push(0x05);
        packet.extend_from_slice(b"users");

        // original table
        packet.push(0x05);
        packet.extend_from_slice(b"users");

        // column alias
        packet.push(0x02);
        packet.extend_from_slice(b"id");

        // original column
        packet.push(0x02);
        packet.extend_from_slice(b"id");

        // length of the fixed-size section (12)
        packet.push(0x0c);

        packet.extend_from_slice(&[
            0x21, 0x00, // charset = 33
            0x0B, 0x00, 0x00, 0x00, // column_length = 11
            0x03, // column_type byte (asserted below as MYSQL_TYPE_LONG)
            0x03, 0x00, // flags = 0x0003 (asserted below: NOT_NULL and PRI_KEY)
            0x00, // decimals = 0
            0x00, 0x00, // reserved
        ]);

        let col_bytes = ColumnDefinitionBytes(&packet);
        let col_def = ColumnDefinition::try_from(col_bytes).expect("Failed to parse");

        assert_eq!(col_def.schema, b"test");
        assert_eq!(col_def.table_alias, b"users");
        assert_eq!(col_def.table_original, b"users");
        assert_eq!(col_def.name_alias, b"id");
        assert_eq!(col_def.name_original, b"id");

        assert_eq!(col_def.tail.charset(), 33);
        assert_eq!(col_def.tail.column_length(), 11);
        assert_eq!(col_def.tail.decimals, 0);

        let flags = col_def.tail.flags().expect("Failed to parse flags");
        assert!(flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(flags.contains(ColumnFlags::PRI_KEY_FLAG));

        let col_type = col_def
            .tail
            .column_type()
            .expect("Failed to parse column type");
        assert_eq!(col_type, ColumnType::MYSQL_TYPE_LONG);

        let type_and_flags = col_def
            .tail
            .type_and_flags()
            .expect("Failed to get type_and_flags");
        assert_eq!(type_and_flags.column_type, ColumnType::MYSQL_TYPE_LONG);
        assert!(type_and_flags.flags.contains(ColumnFlags::NOT_NULL_FLAG));
        assert!(type_and_flags.flags.contains(ColumnFlags::PRI_KEY_FLAG));
    }
}