Skip to main content

sqlx_sqlserver/protocol/
packet.rs

1use thiserror::Error;
2
/// Length in bytes of a TDS packet header.
///
/// The header consists of one type byte, one status byte, a two-byte length,
/// a two-byte server process ID, one packet ID byte and one window byte.
pub const PACKET_HEADER_LEN: usize = 8;

/// Maximum encoded TDS packet length, header included.
///
/// The packet header stores the length as a big-endian `u16`, so no single
/// packet can be longer than `u16::MAX` bytes.
// NOTE(review): servers may negotiate a smaller maximum packet size than the
// wire format allows — confirm against the TDS specification before relying
// on this as the upper bound for a negotiated size.
pub const MAX_PACKET_LEN: usize = u16::MAX as usize;
8
/// Type byte of a TDS packet.
///
/// This is a thin wrapper around the raw `u8`: any byte value can be wrapped
/// via [`From<u8>`], and the associated constants name the packet types this
/// module refers to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketType(u8);

impl PacketType {
    /// SQL batch packet (`0x01`).
    pub const SQL_BATCH: PacketType = PacketType(0x01);
    /// RPC request packet (`0x03`).
    pub const RPC: PacketType = PacketType(0x03);
    /// Tabular result packet (`0x04`).
    pub const TABULAR_RESULT: PacketType = PacketType(0x04);
    /// Login7 packet (`0x10`).
    pub const LOGIN7: PacketType = PacketType(0x10);
    /// Pre-login packet (`0x12`).
    pub const PRE_LOGIN: PacketType = PacketType(0x12);

    /// Returns the raw byte wrapped by this packet type.
    pub const fn code(self) -> u8 {
        let Self(raw) = self;
        raw
    }
}

impl From<u8> for PacketType {
    /// Wraps a raw type byte without validating it against the known types.
    fn from(value: u8) -> Self {
        PacketType(value)
    }
}
36
/// TDS packet status byte.
///
/// The status byte is a bit field: `END_OF_MESSAGE` may be combined with
/// other status bits in the same byte. Comparing two statuses with `==`
/// compares the whole byte, so use [`PacketStatus::contains`] to test for an
/// individual flag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketStatus(u8);

impl PacketStatus {
    /// Normal packet (no status bits set).
    pub const NORMAL: Self = Self(0x00);
    /// Last packet in a message.
    pub const END_OF_MESSAGE: Self = Self(0x01);

    /// Returns the raw status byte.
    pub const fn code(self) -> u8 {
        self.0
    }

    /// Returns `true` when every bit set in `flag` is also set in `self`.
    ///
    /// A `flag` with no bits set (such as [`PacketStatus::NORMAL`]) is
    /// contained in every status.
    pub const fn contains(self, flag: Self) -> bool {
        (self.0 & flag.0) == flag.0
    }
}

impl From<u8> for PacketStatus {
    /// Wraps a raw status byte without validating its bits.
    fn from(value: u8) -> Self {
        Self(value)
    }
}
58
59/// TDS packet header.
60#[derive(Debug, Clone, Copy, PartialEq, Eq)]
61pub struct PacketHeader {
62    /// Packet type.
63    pub packet_type: PacketType,
64    /// Packet status.
65    pub status: PacketStatus,
66    /// Full packet length including the 8-byte header.
67    pub length: u16,
68    /// Server process ID.
69    pub server_process_id: u16,
70    /// Packet sequence ID.
71    pub packet_id: u8,
72    /// TDS window byte. Usually zero.
73    pub window: u8,
74}
75
76impl PacketHeader {
77    /// Creates a packet header for an outgoing client packet.
78    pub fn new(packet_type: PacketType, status: PacketStatus, length: u16, packet_id: u8) -> Self {
79        Self {
80            packet_type,
81            status,
82            length,
83            server_process_id: 0,
84            packet_id,
85            window: 0,
86        }
87    }
88
89    /// Encodes this header to its wire representation.
90    pub fn encode(self) -> [u8; PACKET_HEADER_LEN] {
91        let length = self.length.to_be_bytes();
92        let server_process_id = self.server_process_id.to_be_bytes();
93
94        [
95            self.packet_type.code(),
96            self.status.code(),
97            length[0],
98            length[1],
99            server_process_id[0],
100            server_process_id[1],
101            self.packet_id,
102            self.window,
103        ]
104    }
105
106    /// Decodes a TDS packet header from its wire representation.
107    pub fn decode(input: &[u8]) -> Result<Self, PacketHeaderError> {
108        let bytes: &[u8; PACKET_HEADER_LEN] = input
109            .try_into()
110            .map_err(|_| PacketHeaderError::WrongLength(input.len()))?;
111
112        let length = u16::from_be_bytes([bytes[2], bytes[3]]);
113
114        if usize::from(length) < PACKET_HEADER_LEN {
115            return Err(PacketHeaderError::InvalidPacketLength(length));
116        }
117
118        Ok(Self {
119            packet_type: PacketType::from(bytes[0]),
120            status: PacketStatus::from(bytes[1]),
121            length,
122            server_process_id: u16::from_be_bytes([bytes[4], bytes[5]]),
123            packet_id: bytes[6],
124            window: bytes[7],
125        })
126    }
127}
128
/// A decoded TDS message assembled from one or more packets.
///
/// Produced by `try_decode_message` once a packet that ends the message has
/// been read from the input buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PacketMessage {
    /// Packet type shared by all packets in the message.
    pub packet_type: PacketType,
    /// Concatenated message payload, excluding packet headers.
    pub payload: Vec<u8>,
    /// Number of bytes consumed from the input buffer, headers included. The
    /// caller can drop this many bytes from the front of its receive buffer.
    pub consumed: usize,
}
139
140/// Encodes a message payload into one or more TDS packets.
141///
142/// `packet_size` is the maximum packet length including the 8-byte header. The
143/// helper emits client packet IDs starting at one and sets
144/// `END_OF_MESSAGE` only on the final packet.
145pub fn encode_message(
146    packet_type: PacketType,
147    payload: &[u8],
148    packet_size: usize,
149) -> Result<Vec<u8>, PacketFrameError> {
150    if packet_size <= PACKET_HEADER_LEN {
151        return Err(PacketFrameError::InvalidMaxPacketSize(packet_size));
152    }
153
154    if packet_size > MAX_PACKET_LEN {
155        return Err(PacketFrameError::InvalidMaxPacketSize(packet_size));
156    }
157
158    let max_payload_len = packet_size - PACKET_HEADER_LEN;
159    let packet_count = if payload.is_empty() {
160        1
161    } else {
162        payload.len().div_ceil(max_payload_len)
163    };
164
165    let total_len = payload
166        .len()
167        .checked_add(packet_count * PACKET_HEADER_LEN)
168        .ok_or(PacketFrameError::MessageTooLarge)?;
169
170    let mut out = Vec::with_capacity(total_len);
171    let mut packet_id = 1u8;
172
173    if payload.is_empty() {
174        let header = PacketHeader::new(
175            packet_type,
176            PacketStatus::END_OF_MESSAGE,
177            PACKET_HEADER_LEN as u16,
178            packet_id,
179        );
180        out.extend_from_slice(&header.encode());
181        return Ok(out);
182    }
183
184    for chunk in payload.chunks(max_payload_len) {
185        let is_last = out.len() + PACKET_HEADER_LEN + chunk.len() == total_len;
186        let status = if is_last {
187            PacketStatus::END_OF_MESSAGE
188        } else {
189            PacketStatus::NORMAL
190        };
191        let length = u16::try_from(PACKET_HEADER_LEN + chunk.len())
192            .map_err(|_| PacketFrameError::MessageTooLarge)?;
193
194        let header = PacketHeader::new(packet_type, status, length, packet_id);
195        out.extend_from_slice(&header.encode());
196        out.extend_from_slice(chunk);
197        packet_id = packet_id.wrapping_add(1);
198    }
199
200    Ok(out)
201}
202
203/// Tries to decode one complete TDS message from the front of `input`.
204///
205/// Returns `Ok(None)` when the buffer does not yet contain a full packet or a
206/// packet marked `END_OF_MESSAGE`. On success, `PacketMessage::consumed`
207/// identifies how many bytes can be removed from the caller's receive buffer.
208pub fn try_decode_message(input: &[u8]) -> Result<Option<PacketMessage>, PacketFrameError> {
209    let mut offset = 0usize;
210    let mut packet_type = None;
211    let mut expected_packet_id = None;
212    let mut payload = Vec::new();
213
214    loop {
215        let Some(header_bytes) = input.get(offset..offset + PACKET_HEADER_LEN) else {
216            return Ok(None);
217        };
218
219        let header = PacketHeader::decode(header_bytes)?;
220
221        if let Some(packet_type) = packet_type {
222            if header.packet_type != packet_type {
223                return Err(PacketFrameError::MismatchedPacketType {
224                    expected: packet_type,
225                    actual: header.packet_type,
226                });
227            }
228        } else {
229            packet_type = Some(header.packet_type);
230        }
231
232        if let Some(packet_id) = expected_packet_id {
233            if header.packet_id != packet_id {
234                return Err(PacketFrameError::UnexpectedPacketId {
235                    expected: packet_id,
236                    actual: header.packet_id,
237                });
238            }
239        }
240
241        let packet_len = usize::from(header.length);
242        let packet_end = offset + packet_len;
243        let Some(packet) = input.get(offset + PACKET_HEADER_LEN..packet_end) else {
244            return Ok(None);
245        };
246
247        payload
248            .try_reserve(packet.len())
249            .map_err(|_| PacketFrameError::MessageTooLarge)?;
250        payload.extend_from_slice(packet);
251        offset = packet_end;
252        expected_packet_id = Some(header.packet_id.wrapping_add(1));
253
254        if header.status == PacketStatus::END_OF_MESSAGE {
255            return Ok(Some(PacketMessage {
256                packet_type: packet_type.expect("packet_type is set after decoding a header"),
257                payload,
258                consumed: offset,
259            }));
260        }
261    }
262}
263
/// Error returned while decoding a TDS packet header.
///
/// Produced by `PacketHeader::decode`; it converts into
/// `PacketFrameError::Header` when propagated with `?` from the framing code.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum PacketHeaderError {
    /// The header input did not contain exactly 8 bytes. Carries the actual
    /// input length.
    #[error("TDS packet header must be 8 bytes, got {0}")]
    WrongLength(usize),
    /// The encoded packet length is smaller than the header itself. Carries
    /// the length value read from the header.
    #[error("TDS packet length {0} is smaller than the 8-byte header")]
    InvalidPacketLength(u16),
}
274
/// Error returned while framing or deframing TDS packets.
///
/// Returned by `encode_message` and `try_decode_message`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum PacketFrameError {
    /// Packet header decoding failed. The message is forwarded unchanged from
    /// the wrapped [`PacketHeaderError`].
    #[error(transparent)]
    Header(#[from] PacketHeaderError),
    /// The requested packet size cannot be encoded in a TDS packet header or
    /// leaves no room for payload bytes. Carries the rejected size.
    #[error("invalid maximum TDS packet size {0}")]
    InvalidMaxPacketSize(usize),
    /// A decoded message contained packets with different packet types.
    // The `{..:02x}` specs below rely on `PacketType` implementing `LowerHex`.
    #[error("TDS message packet type changed from 0x{expected:02x} to 0x{actual:02x}")]
    MismatchedPacketType {
        /// Packet type from the first packet.
        expected: PacketType,
        /// Packet type from a later packet in the same message.
        actual: PacketType,
    },
    /// Packet IDs in a multi-packet message were not contiguous.
    #[error("unexpected TDS packet id {actual}, expected {expected}")]
    UnexpectedPacketId {
        /// Expected packet ID (previous packet's ID plus one, wrapping).
        expected: u8,
        /// Packet ID from the header.
        actual: u8,
    },
    /// The message could not fit in memory or in a protocol length field.
    #[error("TDS message is too large")]
    MessageTooLarge,
}
305
306impl std::fmt::LowerHex for PacketType {
307    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
308        std::fmt::LowerHex::fmt(&self.0, f)
309    }
310}
311
// Unit tests for header encoding/decoding and message framing. Several tests
// share the three-packet fixture built by `contiguous_packet_id_message`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encodes_header_with_big_endian_integer_fields() {
        let header = PacketHeader {
            packet_type: PacketType::PRE_LOGIN,
            status: PacketStatus::END_OF_MESSAGE,
            length: 0x1234,
            server_process_id: 0xabcd,
            packet_id: 7,
            window: 0,
        };

        // Length and server process ID must appear high byte first.
        assert_eq!(
            [0x12, 0x01, 0x12, 0x34, 0xab, 0xcd, 0x07, 0x00],
            header.encode()
        );
    }

    #[test]
    fn decodes_header_from_wire_bytes() {
        let header =
            PacketHeader::decode(&[0x04, 0x01, 0x00, 0x08, 0x00, 0x2a, 0x03, 0x00]).unwrap();

        assert_eq!(PacketType::TABULAR_RESULT, header.packet_type);
        assert_eq!(PacketStatus::END_OF_MESSAGE, header.status);
        assert_eq!(8, header.length);
        assert_eq!(42, header.server_process_id);
        assert_eq!(3, header.packet_id);
    }

    #[test]
    fn rejects_header_with_impossible_length() {
        // Encoded length 7 is shorter than the 8-byte header itself.
        let err = PacketHeader::decode(&[0x12, 0x01, 0x00, 0x07, 0, 0, 0, 0]).unwrap_err();

        assert_eq!(PacketHeaderError::InvalidPacketLength(7), err);
    }

    #[test]
    fn encodes_empty_message_as_end_packet() {
        let bytes = encode_message(PacketType::SQL_BATCH, &[], 512).unwrap();

        // A single header-only packet: length 8, packet ID 1, END_OF_MESSAGE.
        assert_eq!(vec![0x01, 0x01, 0x00, 0x08, 0, 0, 1, 0], bytes);
    }

    #[test]
    fn encodes_client_message_across_packet_boundaries_from_packet_id_one() {
        // Packet size 12 leaves 4 payload bytes per packet, so 9 bytes need
        // three packets (4 + 4 + 1).
        let bytes = encode_message(PacketType::PRE_LOGIN, b"abcdefghi", 12).unwrap();

        assert_eq!(
            vec![
                0x12, 0x00, 0x00, 0x0c, 0, 0, 1, 0, b'a', b'b', b'c', b'd', 0x12, 0x00, 0x00, 0x0c,
                0, 0, 2, 0, b'e', b'f', b'g', b'h', 0x12, 0x01, 0x00, 0x09, 0, 0, 3, 0, b'i',
            ],
            bytes
        );
    }

    #[test]
    fn rejects_invalid_max_packet_size() {
        // A packet size equal to the header length leaves no room for payload.
        let err = encode_message(PacketType::PRE_LOGIN, b"abc", PACKET_HEADER_LEN).unwrap_err();

        assert_eq!(
            PacketFrameError::InvalidMaxPacketSize(PACKET_HEADER_LEN),
            err
        );
    }

    #[test]
    fn decodes_single_packet_message_and_reports_consumed_bytes() {
        let mut bytes = encode_message(PacketType::SQL_BATCH, b"select 1", 512).unwrap();
        // Trailing bytes after the message must not be counted as consumed.
        bytes.extend_from_slice(b"next message bytes");

        let message = try_decode_message(&bytes).unwrap().unwrap();

        assert_eq!(PacketType::SQL_BATCH, message.packet_type);
        assert_eq!(b"select 1", message.payload.as_slice());
        assert_eq!(PACKET_HEADER_LEN + b"select 1".len(), message.consumed);
    }

    #[test]
    fn decodes_multi_packet_message_payload() {
        let bytes = contiguous_packet_id_message();
        let message = try_decode_message(&bytes).unwrap().unwrap();

        assert_eq!(PacketType::PRE_LOGIN, message.packet_type);
        assert_eq!(b"abcdefghi", message.payload.as_slice());
        assert_eq!(bytes.len(), message.consumed);
    }

    #[test]
    fn waits_for_complete_packet() {
        let bytes = contiguous_packet_id_message();

        // 15 bytes = the whole first packet (12) plus a truncated second
        // header (3 of 8 bytes).
        assert_eq!(None, try_decode_message(&bytes[..15]).unwrap());
    }

    #[test]
    fn waits_for_end_of_message_packet() {
        let bytes = contiguous_packet_id_message();

        // 12 bytes = exactly the first packet, whose status is NORMAL, so the
        // message is not finished yet.
        assert_eq!(None, try_decode_message(&bytes[..12]).unwrap());
    }

    #[test]
    fn rejects_mismatched_packet_types() {
        let mut bytes = contiguous_packet_id_message();
        // Offset 12 is the type byte of the second packet's header.
        bytes[12] = PacketType::SQL_BATCH.code();

        let err = try_decode_message(&bytes).unwrap_err();

        assert_eq!(
            PacketFrameError::MismatchedPacketType {
                expected: PacketType::PRE_LOGIN,
                actual: PacketType::SQL_BATCH,
            },
            err
        );
    }

    #[test]
    fn rejects_non_contiguous_packet_ids() {
        let mut bytes = contiguous_packet_id_message();
        // Offset 18 is the packet ID byte of the second packet's header
        // (12 + 6).
        bytes[18] = 5;

        let err = try_decode_message(&bytes).unwrap_err();

        assert_eq!(
            PacketFrameError::UnexpectedPacketId {
                expected: 2,
                actual: 5,
            },
            err
        );
    }

    // Three PRE_LOGIN packets with IDs 1, 2, 3 carrying "abcd", "efgh" and "i";
    // only the last one has the END_OF_MESSAGE status. Packets start at byte
    // offsets 0, 12 and 24.
    fn contiguous_packet_id_message() -> Vec<u8> {
        vec![
            0x12, 0x00, 0x00, 0x0c, 0, 0, 1, 0, b'a', b'b', b'c', b'd', 0x12, 0x00, 0x00, 0x0c, 0,
            0, 2, 0, b'e', b'f', b'g', b'h', 0x12, 0x01, 0x00, 0x09, 0, 0, 3, 0, b'i',
        ]
    }
}