Skip to main content

tds_protocol/
packet.rs

1//! TDS packet header definitions.
2
3use bitflags::bitflags;
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5
6use crate::error::ProtocolError;
7
/// TDS packet header size in bytes.
pub const PACKET_HEADER_SIZE: usize = 8;

/// Maximum TDS packet size in bytes (64 KiB - 1).
///
/// This is the largest value the header's 16-bit `length` field can hold.
pub const MAX_PACKET_SIZE: usize = u16::MAX as usize;

/// Default TDS packet size in bytes.
pub const DEFAULT_PACKET_SIZE: usize = 4096;

/// TDS packet type.
///
/// Each discriminant is the raw type code carried in the first byte of a
/// [`PacketHeader`], so `packet_type as u8` yields the on-wire value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum PacketType {
    /// SQL batch request.
    SqlBatch = 0x01,
    /// Pre-TDS7 login packet.
    PreTds7Login = 0x02,
    /// Remote procedure call.
    Rpc = 0x03,
    /// Tabular response.
    TabularResult = 0x04,
    /// Attention signal.
    Attention = 0x06,
    /// Bulk load data.
    BulkLoad = 0x07,
    /// Federated authentication token.
    FedAuthToken = 0x08,
    /// Transaction manager request.
    TransactionManager = 0x0E,
    /// TDS7+ login packet.
    Tds7Login = 0x10,
    /// SSPI authentication.
    Sspi = 0x11,
    /// Pre-login packet.
    PreLogin = 0x12,
}

46impl PacketType {
47    /// Create a packet type from a raw byte value.
48    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
49        match value {
50            0x01 => Ok(Self::SqlBatch),
51            0x02 => Ok(Self::PreTds7Login),
52            0x03 => Ok(Self::Rpc),
53            0x04 => Ok(Self::TabularResult),
54            0x06 => Ok(Self::Attention),
55            0x07 => Ok(Self::BulkLoad),
56            0x08 => Ok(Self::FedAuthToken),
57            0x0E => Ok(Self::TransactionManager),
58            0x10 => Ok(Self::Tds7Login),
59            0x11 => Ok(Self::Sspi),
60            0x12 => Ok(Self::PreLogin),
61            _ => Err(ProtocolError::InvalidPacketType(value)),
62        }
63    }
64}
65
66bitflags! {
67    /// TDS packet status flags.
68    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
69    pub struct PacketStatus: u8 {
70        /// Normal packet, more packets to follow.
71        const NORMAL = 0x00;
72        /// End of message (last packet).
73        const END_OF_MESSAGE = 0x01;
74        /// Ignore this event (used for attention acknowledgment).
75        const IGNORE_EVENT = 0x02;
76        /// Reset connection (SQL Server 2000+).
77        const RESET_CONNECTION = 0x08;
78        /// Reset connection but keep transaction state.
79        const RESET_CONNECTION_KEEP_TRANSACTION = 0x10;
80    }
81}
82
83/// TDS packet header.
84///
85/// Every TDS packet begins with an 8-byte header that describes
86/// the packet type, status, and length.
87#[derive(Debug, Clone, Copy, PartialEq, Eq)]
88pub struct PacketHeader {
89    /// Type of packet.
90    pub packet_type: PacketType,
91    /// Status flags.
92    pub status: PacketStatus,
93    /// Total packet length including header.
94    pub length: u16,
95    /// Server process ID (SPID).
96    pub spid: u16,
97    /// Packet sequence number (wraps at 255).
98    pub packet_id: u8,
99    /// Window (unused, should be 0).
100    pub window: u8,
101}
102
103impl PacketHeader {
104    /// Create a new packet header.
105    #[must_use]
106    pub const fn new(packet_type: PacketType, status: PacketStatus, length: u16) -> Self {
107        Self {
108            packet_type,
109            status,
110            length,
111            spid: 0,
112            packet_id: 0,
113            window: 0,
114        }
115    }
116
117    /// Parse a packet header from bytes.
118    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
119        if src.remaining() < PACKET_HEADER_SIZE {
120            return Err(ProtocolError::IncompletePacket {
121                expected: PACKET_HEADER_SIZE,
122                actual: src.remaining(),
123            });
124        }
125
126        let packet_type = PacketType::from_u8(src.get_u8())?;
127        let status_byte = src.get_u8();
128        let status = PacketStatus::from_bits(status_byte)
129            .ok_or(ProtocolError::InvalidPacketStatus(status_byte))?;
130        let length = src.get_u16();
131        let spid = src.get_u16();
132        let packet_id = src.get_u8();
133        let window = src.get_u8();
134
135        Ok(Self {
136            packet_type,
137            status,
138            length,
139            spid,
140            packet_id,
141            window,
142        })
143    }
144
145    /// Encode the packet header to bytes.
146    pub fn encode(&self, dst: &mut impl BufMut) {
147        dst.put_u8(self.packet_type as u8);
148        dst.put_u8(self.status.bits());
149        dst.put_u16(self.length);
150        dst.put_u16(self.spid);
151        dst.put_u8(self.packet_id);
152        dst.put_u8(self.window);
153    }
154
155    /// Encode the packet header to a new `Bytes` buffer.
156    #[must_use]
157    pub fn encode_to_bytes(&self) -> Bytes {
158        let mut buf = BytesMut::with_capacity(PACKET_HEADER_SIZE);
159        self.encode(&mut buf);
160        buf.freeze()
161    }
162
163    /// Get the payload length (total length minus header).
164    #[must_use]
165    pub const fn payload_length(&self) -> usize {
166        self.length.saturating_sub(PACKET_HEADER_SIZE as u16) as usize
167    }
168
169    /// Check if this is the last packet in a message.
170    #[must_use]
171    pub const fn is_end_of_message(&self) -> bool {
172        self.status.contains(PacketStatus::END_OF_MESSAGE)
173    }
174
175    /// Set the packet ID (sequence number).
176    #[must_use]
177    pub const fn with_packet_id(mut self, id: u8) -> Self {
178        self.packet_id = id;
179        self
180    }
181
182    /// Set the SPID.
183    #[must_use]
184    pub const fn with_spid(mut self, spid: u16) -> Self {
185        self.spid = spid;
186        self
187    }
188}
189
190impl Default for PacketHeader {
191    fn default() -> Self {
192        Self {
193            packet_type: PacketType::SqlBatch,
194            status: PacketStatus::END_OF_MESSAGE,
195            length: PACKET_HEADER_SIZE as u16,
196            spid: 0,
197            packet_id: 1,
198            window: 0,
199        }
200    }
201}
202
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: a panic is the desired failure mode.
mod tests {
    use super::*;

    #[test]
    fn test_header_roundtrip() {
        let header = PacketHeader {
            packet_type: PacketType::SqlBatch,
            status: PacketStatus::END_OF_MESSAGE,
            length: 100,
            spid: 54,
            packet_id: 1,
            window: 0,
        };

        let bytes = header.encode_to_bytes();
        assert_eq!(bytes.len(), PACKET_HEADER_SIZE);

        let mut cursor = bytes.as_ref();
        let decoded = PacketHeader::decode(&mut cursor).unwrap();
        assert_eq!(header, decoded);
    }

    #[test]
    fn test_reset_flags_roundtrip() {
        let status = PacketStatus::END_OF_MESSAGE | PacketStatus::RESET_CONNECTION;
        let header = PacketHeader::new(PacketType::Rpc, status, 8);

        let bytes = header.encode_to_bytes();
        let mut cursor = bytes.as_ref();
        let decoded = PacketHeader::decode(&mut cursor).unwrap();
        assert_eq!(decoded, header);
        assert!(decoded.status.contains(PacketStatus::RESET_CONNECTION));
    }

    #[test]
    fn test_encode_is_big_endian_in_wire_order() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 0x0102)
            .with_spid(0x0304)
            .with_packet_id(0x05);

        let bytes = header.encode_to_bytes();
        assert_eq!(bytes.to_vec(), vec![0x01, 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00]);
    }

    #[test]
    fn test_decode_incomplete_header() {
        let buf = [0x01u8, 0x01, 0x00];
        let mut cursor = &buf[..];

        match PacketHeader::decode(&mut cursor) {
            Err(ProtocolError::IncompletePacket { expected, actual, .. }) => {
                assert_eq!(expected, PACKET_HEADER_SIZE);
                assert_eq!(actual, 3);
            }
            _ => panic!("expected ProtocolError::IncompletePacket"),
        }
        // Nothing is consumed when the header is incomplete.
        assert_eq!(cursor.len(), 3);
    }

    #[test]
    fn test_payload_length() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 100);
        assert_eq!(header.payload_length(), 92);
    }

    #[test]
    fn test_payload_length_saturates() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::NORMAL, 4);
        assert_eq!(header.payload_length(), 0);
        assert!(!header.is_end_of_message());
    }

    #[test]
    fn test_default_header() {
        let header = PacketHeader::default();
        assert_eq!(header.packet_type, PacketType::SqlBatch);
        assert!(header.is_end_of_message());
        assert_eq!(header.payload_length(), 0);
        assert_eq!(header.packet_id, 1);
    }

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::from_u8(0x01).unwrap(), PacketType::SqlBatch);
        assert_eq!(PacketType::from_u8(0x12).unwrap(), PacketType::PreLogin);
        assert!(PacketType::from_u8(0xFF).is_err());
        assert!(PacketType::from_u8(0x05).is_err());

        for code in [0x01u8, 0x02, 0x03, 0x04, 0x06, 0x07, 0x08, 0x0E, 0x10, 0x11, 0x12] {
            assert_eq!(PacketType::from_u8(code).unwrap() as u8, code);
        }
    }
}
239}