1use thiserror::Error;
2
/// Size in bytes of the fixed header that starts every TDS packet.
pub const PACKET_HEADER_LEN: usize = 8;

/// Largest total packet length, header included, that fits the header's 16-bit length field.
pub const MAX_PACKET_LEN: usize = u16::MAX as usize;
8
/// One-byte packet type code stored in the first byte of a TDS packet header.
///
/// This is an open newtype rather than a closed enum: every `u8` converts into a
/// `PacketType`, so codes without a named constant below are carried through unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketType(u8);

impl PacketType {
    /// Type code `0x01`.
    pub const SQL_BATCH: Self = PacketType(0x01);
    /// Type code `0x03`.
    pub const RPC: Self = PacketType(0x03);
    /// Type code `0x04`.
    pub const TABULAR_RESULT: Self = PacketType(0x04);
    /// Type code `0x10`.
    pub const LOGIN7: Self = PacketType(0x10);
    /// Type code `0x12`.
    pub const PRE_LOGIN: Self = PacketType(0x12);

    /// Returns the raw byte that represents this packet type on the wire.
    pub const fn code(self) -> u8 {
        let Self(raw) = self;
        raw
    }
}

impl From<u8> for PacketType {
    /// Wraps any byte as a packet type; no validation is performed.
    fn from(value: u8) -> Self {
        PacketType(value)
    }
}
36
/// One-byte status value stored in the second byte of a TDS packet header.
///
/// Like the packet type, this is an open newtype: any `u8` converts into a `PacketStatus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketStatus(u8);

impl PacketStatus {
    /// Status `0x00`, used by the encoder for every packet except the last of a message.
    pub const NORMAL: Self = PacketStatus(0x00);
    /// Status `0x01`, used by the encoder for the final packet of a message.
    pub const END_OF_MESSAGE: Self = PacketStatus(0x01);

    /// Returns the raw status byte as it appears on the wire.
    pub const fn code(self) -> u8 {
        let Self(raw) = self;
        raw
    }
}

impl From<u8> for PacketStatus {
    /// Wraps any byte as a status value; no validation is performed.
    fn from(value: u8) -> Self {
        PacketStatus(value)
    }
}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq)]
61pub struct PacketHeader {
62 pub packet_type: PacketType,
64 pub status: PacketStatus,
66 pub length: u16,
68 pub server_process_id: u16,
70 pub packet_id: u8,
72 pub window: u8,
74}
75
76impl PacketHeader {
77 pub fn new(packet_type: PacketType, status: PacketStatus, length: u16, packet_id: u8) -> Self {
79 Self {
80 packet_type,
81 status,
82 length,
83 server_process_id: 0,
84 packet_id,
85 window: 0,
86 }
87 }
88
89 pub fn encode(self) -> [u8; PACKET_HEADER_LEN] {
91 let length = self.length.to_be_bytes();
92 let server_process_id = self.server_process_id.to_be_bytes();
93
94 [
95 self.packet_type.code(),
96 self.status.code(),
97 length[0],
98 length[1],
99 server_process_id[0],
100 server_process_id[1],
101 self.packet_id,
102 self.window,
103 ]
104 }
105
106 pub fn decode(input: &[u8]) -> Result<Self, PacketHeaderError> {
108 let bytes: &[u8; PACKET_HEADER_LEN] = input
109 .try_into()
110 .map_err(|_| PacketHeaderError::WrongLength(input.len()))?;
111
112 let length = u16::from_be_bytes([bytes[2], bytes[3]]);
113
114 if usize::from(length) < PACKET_HEADER_LEN {
115 return Err(PacketHeaderError::InvalidPacketLength(length));
116 }
117
118 Ok(Self {
119 packet_type: PacketType::from(bytes[0]),
120 status: PacketStatus::from(bytes[1]),
121 length,
122 server_process_id: u16::from_be_bytes([bytes[4], bytes[5]]),
123 packet_id: bytes[6],
124 window: bytes[7],
125 })
126 }
127}
128
/// A complete message reassembled from one or more packets by `try_decode_message`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PacketMessage {
    /// Packet type shared by every packet of the message.
    pub packet_type: PacketType,
    /// Payload bytes of all packets concatenated in order, with the headers removed.
    pub payload: Vec<u8>,
    /// Number of input bytes the message occupied (headers included); any bytes after this
    /// offset were not read.
    pub consumed: usize,
}
139
140pub fn encode_message(
146 packet_type: PacketType,
147 payload: &[u8],
148 packet_size: usize,
149) -> Result<Vec<u8>, PacketFrameError> {
150 if packet_size <= PACKET_HEADER_LEN {
151 return Err(PacketFrameError::InvalidMaxPacketSize(packet_size));
152 }
153
154 if packet_size > MAX_PACKET_LEN {
155 return Err(PacketFrameError::InvalidMaxPacketSize(packet_size));
156 }
157
158 let max_payload_len = packet_size - PACKET_HEADER_LEN;
159 let packet_count = if payload.is_empty() {
160 1
161 } else {
162 payload.len().div_ceil(max_payload_len)
163 };
164
165 let total_len = payload
166 .len()
167 .checked_add(packet_count * PACKET_HEADER_LEN)
168 .ok_or(PacketFrameError::MessageTooLarge)?;
169
170 let mut out = Vec::with_capacity(total_len);
171 let mut packet_id = 1u8;
172
173 if payload.is_empty() {
174 let header = PacketHeader::new(
175 packet_type,
176 PacketStatus::END_OF_MESSAGE,
177 PACKET_HEADER_LEN as u16,
178 packet_id,
179 );
180 out.extend_from_slice(&header.encode());
181 return Ok(out);
182 }
183
184 for chunk in payload.chunks(max_payload_len) {
185 let is_last = out.len() + PACKET_HEADER_LEN + chunk.len() == total_len;
186 let status = if is_last {
187 PacketStatus::END_OF_MESSAGE
188 } else {
189 PacketStatus::NORMAL
190 };
191 let length = u16::try_from(PACKET_HEADER_LEN + chunk.len())
192 .map_err(|_| PacketFrameError::MessageTooLarge)?;
193
194 let header = PacketHeader::new(packet_type, status, length, packet_id);
195 out.extend_from_slice(&header.encode());
196 out.extend_from_slice(chunk);
197 packet_id = packet_id.wrapping_add(1);
198 }
199
200 Ok(out)
201}
202
203pub fn try_decode_message(input: &[u8]) -> Result<Option<PacketMessage>, PacketFrameError> {
209 let mut offset = 0usize;
210 let mut packet_type = None;
211 let mut expected_packet_id = None;
212 let mut payload = Vec::new();
213
214 loop {
215 let Some(header_bytes) = input.get(offset..offset + PACKET_HEADER_LEN) else {
216 return Ok(None);
217 };
218
219 let header = PacketHeader::decode(header_bytes)?;
220
221 if let Some(packet_type) = packet_type {
222 if header.packet_type != packet_type {
223 return Err(PacketFrameError::MismatchedPacketType {
224 expected: packet_type,
225 actual: header.packet_type,
226 });
227 }
228 } else {
229 packet_type = Some(header.packet_type);
230 }
231
232 if let Some(packet_id) = expected_packet_id {
233 if header.packet_id != packet_id {
234 return Err(PacketFrameError::UnexpectedPacketId {
235 expected: packet_id,
236 actual: header.packet_id,
237 });
238 }
239 }
240
241 let packet_len = usize::from(header.length);
242 let packet_end = offset + packet_len;
243 let Some(packet) = input.get(offset + PACKET_HEADER_LEN..packet_end) else {
244 return Ok(None);
245 };
246
247 payload
248 .try_reserve(packet.len())
249 .map_err(|_| PacketFrameError::MessageTooLarge)?;
250 payload.extend_from_slice(packet);
251 offset = packet_end;
252 expected_packet_id = Some(header.packet_id.wrapping_add(1));
253
254 if header.status == PacketStatus::END_OF_MESSAGE {
255 return Ok(Some(PacketMessage {
256 packet_type: packet_type.expect("packet_type is set after decoding a header"),
257 payload,
258 consumed: offset,
259 }));
260 }
261 }
262}
263
/// Errors produced while decoding a single packet header.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum PacketHeaderError {
    /// The input slice was not exactly 8 bytes long; carries the actual slice length.
    #[error("TDS packet header must be 8 bytes, got {0}")]
    WrongLength(usize),
    /// The header's length field is smaller than the header itself; carries that field value.
    #[error("TDS packet length {0} is smaller than the 8-byte header")]
    InvalidPacketLength(u16),
}
274
/// Errors produced while splitting a message into packets or reassembling one from packets.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum PacketFrameError {
    /// A packet header failed to decode; the message is that of the wrapped error.
    #[error(transparent)]
    Header(#[from] PacketHeaderError),
    /// The requested packet size leaves no room for payload or exceeds `MAX_PACKET_LEN`;
    /// carries the rejected size.
    #[error("invalid maximum TDS packet size {0}")]
    InvalidMaxPacketSize(usize),
    /// A later packet of a message has a different type than the first packet.
    #[error("TDS message packet type changed from 0x{expected:02x} to 0x{actual:02x}")]
    MismatchedPacketType {
        /// Type of the first packet of the message.
        expected: PacketType,
        /// Type found in the offending packet.
        actual: PacketType,
    },
    /// A packet id did not follow the previous packet's id by one (with wrap-around).
    #[error("unexpected TDS packet id {actual}, expected {expected}")]
    UnexpectedPacketId {
        /// Id the packet should have carried.
        expected: u8,
        /// Id found in the offending packet.
        actual: u8,
    },
    /// A size computation overflowed or a buffer could not be reserved.
    #[error("TDS message is too large")]
    MessageTooLarge,
}
305
306impl std::fmt::LowerHex for PacketType {
307 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308 std::fmt::LowerHex::fmt(&self.0, f)
309 }
310}
311
// Unit tests for header encoding/decoding and for message framing and reassembly.
#[cfg(test)]
mod tests {
    use super::*;

    // Both 16-bit header fields must be serialized most-significant byte first.
    #[test]
    fn encodes_header_with_big_endian_integer_fields() {
        let header = PacketHeader {
            packet_type: PacketType::PRE_LOGIN,
            status: PacketStatus::END_OF_MESSAGE,
            length: 0x1234,
            server_process_id: 0xabcd,
            packet_id: 7,
            window: 0,
        };

        assert_eq!(
            [0x12, 0x01, 0x12, 0x34, 0xab, 0xcd, 0x07, 0x00],
            header.encode()
        );
    }

    // Decoding maps each wire byte back onto the matching header field.
    #[test]
    fn decodes_header_from_wire_bytes() {
        let header =
            PacketHeader::decode(&[0x04, 0x01, 0x00, 0x08, 0x00, 0x2a, 0x03, 0x00]).unwrap();

        assert_eq!(PacketType::TABULAR_RESULT, header.packet_type);
        assert_eq!(PacketStatus::END_OF_MESSAGE, header.status);
        assert_eq!(8, header.length);
        assert_eq!(42, header.server_process_id);
        assert_eq!(3, header.packet_id);
    }

    // A length field of 7 is smaller than the 8-byte header and must be rejected.
    #[test]
    fn rejects_header_with_impossible_length() {
        let err = PacketHeader::decode(&[0x12, 0x01, 0x00, 0x07, 0, 0, 0, 0]).unwrap_err();

        assert_eq!(PacketHeaderError::InvalidPacketLength(7), err);
    }

    // An empty payload still yields one header-only packet flagged end-of-message.
    #[test]
    fn encodes_empty_message_as_end_packet() {
        let bytes = encode_message(PacketType::SQL_BATCH, &[], 512).unwrap();

        assert_eq!(vec![0x01, 0x01, 0x00, 0x08, 0, 0, 1, 0], bytes);
    }

    // With a 12-byte packet size (4 payload bytes per packet), nine payload bytes span three
    // packets with ids 1, 2, 3; only the last is flagged end-of-message.
    #[test]
    fn encodes_client_message_across_packet_boundaries_from_packet_id_one() {
        let bytes = encode_message(PacketType::PRE_LOGIN, b"abcdefghi", 12).unwrap();

        assert_eq!(
            vec![
                0x12, 0x00, 0x00, 0x0c, 0, 0, 1, 0, b'a', b'b', b'c', b'd', 0x12, 0x00, 0x00, 0x0c,
                0, 0, 2, 0, b'e', b'f', b'g', b'h', 0x12, 0x01, 0x00, 0x09, 0, 0, 3, 0, b'i',
            ],
            bytes
        );
    }

    // A packet size equal to the header length leaves no room for payload.
    #[test]
    fn rejects_invalid_max_packet_size() {
        let err = encode_message(PacketType::PRE_LOGIN, b"abc", PACKET_HEADER_LEN).unwrap_err();

        assert_eq!(
            PacketFrameError::InvalidMaxPacketSize(PACKET_HEADER_LEN),
            err
        );
    }

    // Trailing bytes after the message are ignored; `consumed` covers only the message itself.
    #[test]
    fn decodes_single_packet_message_and_reports_consumed_bytes() {
        let mut bytes = encode_message(PacketType::SQL_BATCH, b"select 1", 512).unwrap();
        bytes.extend_from_slice(b"next message bytes");

        let message = try_decode_message(&bytes).unwrap().unwrap();

        assert_eq!(PacketType::SQL_BATCH, message.packet_type);
        assert_eq!(b"select 1", message.payload.as_slice());
        assert_eq!(PACKET_HEADER_LEN + b"select 1".len(), message.consumed);
    }

    // Payloads of all packets are concatenated in order.
    #[test]
    fn decodes_multi_packet_message_payload() {
        let bytes = contiguous_packet_id_message();
        let message = try_decode_message(&bytes).unwrap().unwrap();

        assert_eq!(PacketType::PRE_LOGIN, message.packet_type);
        assert_eq!(b"abcdefghi", message.payload.as_slice());
        assert_eq!(bytes.len(), message.consumed);
    }

    // Input cut in the middle of the second packet's header: not decodable yet.
    #[test]
    fn waits_for_complete_packet() {
        let bytes = contiguous_packet_id_message();

        assert_eq!(None, try_decode_message(&bytes[..15]).unwrap());
    }

    // Input holding only the first (non-final) packet: the message is not complete yet.
    #[test]
    fn waits_for_end_of_message_packet() {
        let bytes = contiguous_packet_id_message();

        assert_eq!(None, try_decode_message(&bytes[..12]).unwrap());
    }

    // Byte 12 is the type byte of the second packet.
    #[test]
    fn rejects_mismatched_packet_types() {
        let mut bytes = contiguous_packet_id_message();
        bytes[12] = PacketType::SQL_BATCH.code();

        let err = try_decode_message(&bytes).unwrap_err();

        assert_eq!(
            PacketFrameError::MismatchedPacketType {
                expected: PacketType::PRE_LOGIN,
                actual: PacketType::SQL_BATCH,
            },
            err
        );
    }

    // Byte 18 is the packet id of the second packet.
    #[test]
    fn rejects_non_contiguous_packet_ids() {
        let mut bytes = contiguous_packet_id_message();
        bytes[18] = 5;

        let err = try_decode_message(&bytes).unwrap_err();

        assert_eq!(
            PacketFrameError::UnexpectedPacketId {
                expected: 2,
                actual: 5,
            },
            err
        );
    }

    // Fixture: three PRE_LOGIN packets (ids 1, 2, 3) carrying "abcd", "efgh" and "i"; only the
    // last one is flagged end-of-message.
    fn contiguous_packet_id_message() -> Vec<u8> {
        vec![
            0x12, 0x00, 0x00, 0x0c, 0, 0, 1, 0, b'a', b'b', b'c', b'd', 0x12, 0x00, 0x00, 0x0c, 0,
            0, 2, 0, b'e', b'f', b'g', b'h', 0x12, 0x01, 0x00, 0x09, 0, 0, 3, 0, b'i',
        ]
    }
}