tds_protocol/
packet.rs

1//! TDS packet header definitions.
2
3use bitflags::bitflags;
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5
6use crate::error::ProtocolError;
7
/// Size in bytes of the fixed header that starts every TDS packet.
pub const PACKET_HEADER_SIZE: usize = 8;

/// Largest packet length the 16-bit header `length` field can express (64 KiB - 1).
pub const MAX_PACKET_SIZE: usize = u16::MAX as usize;

/// Default TDS packet size (4 KiB).
pub const DEFAULT_PACKET_SIZE: usize = 4 * 1024;
16
/// Message type carried in the first byte of a TDS packet header.
///
/// Each discriminant is the on-the-wire byte value, so `ty as u8` yields the
/// encoded form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum PacketType {
    /// SQL batch request (`0x01`).
    SqlBatch = 0x01,
    /// Login packet used by pre-TDS 7 clients (`0x02`).
    PreTds7Login = 0x02,
    /// Remote procedure call request (`0x03`).
    Rpc = 0x03,
    /// Tabular result / server response (`0x04`).
    TabularResult = 0x04,
    /// Attention signal (`0x06`).
    Attention = 0x06,
    /// Bulk load data (`0x07`).
    BulkLoad = 0x07,
    /// Federated authentication token (`0x08`).
    FedAuthToken = 0x08,
    /// Transaction manager request (`0x0E`).
    TransactionManager = 0x0E,
    /// Login packet for TDS 7 and later (`0x10`).
    Tds7Login = 0x10,
    /// SSPI authentication message (`0x11`).
    Sspi = 0x11,
    /// Pre-login handshake (`0x12`).
    PreLogin = 0x12,
}
44
45impl PacketType {
46    /// Create a packet type from a raw byte value.
47    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
48        match value {
49            0x01 => Ok(Self::SqlBatch),
50            0x02 => Ok(Self::PreTds7Login),
51            0x03 => Ok(Self::Rpc),
52            0x04 => Ok(Self::TabularResult),
53            0x06 => Ok(Self::Attention),
54            0x07 => Ok(Self::BulkLoad),
55            0x08 => Ok(Self::FedAuthToken),
56            0x0E => Ok(Self::TransactionManager),
57            0x10 => Ok(Self::Tds7Login),
58            0x11 => Ok(Self::Sspi),
59            0x12 => Ok(Self::PreLogin),
60            _ => Err(ProtocolError::InvalidPacketType(value)),
61        }
62    }
63}
64
65bitflags! {
66    /// TDS packet status flags.
67    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
68    pub struct PacketStatus: u8 {
69        /// Normal packet, more packets to follow.
70        const NORMAL = 0x00;
71        /// End of message (last packet).
72        const END_OF_MESSAGE = 0x01;
73        /// Ignore this event (used for attention acknowledgment).
74        const IGNORE_EVENT = 0x02;
75        /// Reset connection (SQL Server 2000+).
76        const RESET_CONNECTION = 0x08;
77        /// Reset connection but keep transaction state.
78        const RESET_CONNECTION_KEEP_TRANSACTION = 0x10;
79    }
80}
81
82/// TDS packet header.
83///
84/// Every TDS packet begins with an 8-byte header that describes
85/// the packet type, status, and length.
86#[derive(Debug, Clone, Copy, PartialEq, Eq)]
87pub struct PacketHeader {
88    /// Type of packet.
89    pub packet_type: PacketType,
90    /// Status flags.
91    pub status: PacketStatus,
92    /// Total packet length including header.
93    pub length: u16,
94    /// Server process ID (SPID).
95    pub spid: u16,
96    /// Packet sequence number (wraps at 255).
97    pub packet_id: u8,
98    /// Window (unused, should be 0).
99    pub window: u8,
100}
101
102impl PacketHeader {
103    /// Create a new packet header.
104    #[must_use]
105    pub const fn new(packet_type: PacketType, status: PacketStatus, length: u16) -> Self {
106        Self {
107            packet_type,
108            status,
109            length,
110            spid: 0,
111            packet_id: 0,
112            window: 0,
113        }
114    }
115
116    /// Parse a packet header from bytes.
117    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
118        if src.remaining() < PACKET_HEADER_SIZE {
119            return Err(ProtocolError::IncompletePacket {
120                expected: PACKET_HEADER_SIZE,
121                actual: src.remaining(),
122            });
123        }
124
125        let packet_type = PacketType::from_u8(src.get_u8())?;
126        let status_byte = src.get_u8();
127        let status = PacketStatus::from_bits(status_byte)
128            .ok_or(ProtocolError::InvalidPacketStatus(status_byte))?;
129        let length = src.get_u16();
130        let spid = src.get_u16();
131        let packet_id = src.get_u8();
132        let window = src.get_u8();
133
134        Ok(Self {
135            packet_type,
136            status,
137            length,
138            spid,
139            packet_id,
140            window,
141        })
142    }
143
144    /// Encode the packet header to bytes.
145    pub fn encode(&self, dst: &mut impl BufMut) {
146        dst.put_u8(self.packet_type as u8);
147        dst.put_u8(self.status.bits());
148        dst.put_u16(self.length);
149        dst.put_u16(self.spid);
150        dst.put_u8(self.packet_id);
151        dst.put_u8(self.window);
152    }
153
154    /// Encode the packet header to a new `Bytes` buffer.
155    #[must_use]
156    pub fn encode_to_bytes(&self) -> Bytes {
157        let mut buf = BytesMut::with_capacity(PACKET_HEADER_SIZE);
158        self.encode(&mut buf);
159        buf.freeze()
160    }
161
162    /// Get the payload length (total length minus header).
163    #[must_use]
164    pub const fn payload_length(&self) -> usize {
165        self.length.saturating_sub(PACKET_HEADER_SIZE as u16) as usize
166    }
167
168    /// Check if this is the last packet in a message.
169    #[must_use]
170    pub const fn is_end_of_message(&self) -> bool {
171        self.status.contains(PacketStatus::END_OF_MESSAGE)
172    }
173
174    /// Set the packet ID (sequence number).
175    #[must_use]
176    pub const fn with_packet_id(mut self, id: u8) -> Self {
177        self.packet_id = id;
178        self
179    }
180
181    /// Set the SPID.
182    #[must_use]
183    pub const fn with_spid(mut self, spid: u16) -> Self {
184        self.spid = spid;
185        self
186    }
187}
188
189impl Default for PacketHeader {
190    fn default() -> Self {
191        Self {
192            packet_type: PacketType::SqlBatch,
193            status: PacketStatus::END_OF_MESSAGE,
194            length: PACKET_HEADER_SIZE as u16,
195            spid: 0,
196            packet_id: 1,
197            window: 0,
198        }
199    }
200}
201
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is the idiomatic way to fail a unit test
mod tests {
    use super::*;

    /// Every packet type, used to check the discriminant <-> `from_u8` mapping.
    const ALL_TYPES: [PacketType; 11] = [
        PacketType::SqlBatch,
        PacketType::PreTds7Login,
        PacketType::Rpc,
        PacketType::TabularResult,
        PacketType::Attention,
        PacketType::BulkLoad,
        PacketType::FedAuthToken,
        PacketType::TransactionManager,
        PacketType::Tds7Login,
        PacketType::Sspi,
        PacketType::PreLogin,
    ];

    #[test]
    fn test_header_roundtrip() {
        let header = PacketHeader {
            packet_type: PacketType::SqlBatch,
            status: PacketStatus::END_OF_MESSAGE,
            length: 100,
            spid: 54,
            packet_id: 1,
            window: 0,
        };

        let bytes = header.encode_to_bytes();
        assert_eq!(bytes.len(), PACKET_HEADER_SIZE);

        let mut cursor = &bytes[..];
        let decoded = PacketHeader::decode(&mut cursor).unwrap();
        assert_eq!(header, decoded);
        assert!(cursor.is_empty());
    }

    #[test]
    fn test_encode_wire_layout_is_big_endian() {
        let header = PacketHeader::new(PacketType::PreLogin, PacketStatus::END_OF_MESSAGE, 0x0102)
            .with_spid(0x0304)
            .with_packet_id(5);
        let bytes = header.encode_to_bytes();
        assert_eq!(
            &bytes[..],
            &[0x12u8, 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00][..]
        );
    }

    #[test]
    fn test_payload_length() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 100);
        assert_eq!(header.payload_length(), 92);

        // A bogus length shorter than the header saturates to zero.
        let short = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 3);
        assert_eq!(short.payload_length(), 0);
    }

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::from_u8(0x01).unwrap(), PacketType::SqlBatch);
        assert_eq!(PacketType::from_u8(0x12).unwrap(), PacketType::PreLogin);
        assert!(PacketType::from_u8(0xFF).is_err());
    }

    #[test]
    fn test_packet_type_discriminants_roundtrip() {
        for ty in ALL_TYPES.iter().copied() {
            assert_eq!(PacketType::from_u8(ty as u8).unwrap(), ty);
        }
        // 0x05 is a gap in the numbering and must be rejected with the raw byte.
        assert!(matches!(
            PacketType::from_u8(0x05),
            Err(ProtocolError::InvalidPacketType(0x05))
        ));
    }

    #[test]
    fn test_decode_incomplete_does_not_consume() {
        let data = [0x01u8, 0x01, 0x00];
        let mut cursor = &data[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::IncompletePacket {
                expected: 8,
                actual: 3,
                ..
            }
        ));
        assert_eq!(cursor.len(), 3);
    }

    #[test]
    fn test_decode_rejects_unknown_status_bits() {
        let data = [0x01u8, 0x40, 0x00, 0x08, 0x00, 0x00, 0x01, 0x00];
        let mut cursor = &data[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidPacketStatus(0x40)));
    }

    #[test]
    fn test_end_of_message_flag() {
        let last = PacketHeader::new(
            PacketType::Rpc,
            PacketStatus::END_OF_MESSAGE | PacketStatus::RESET_CONNECTION,
            8,
        );
        assert!(last.is_end_of_message());

        let more = PacketHeader::new(PacketType::Rpc, PacketStatus::NORMAL, 8);
        assert!(!more.is_end_of_message());
    }

    #[test]
    fn test_default_header() {
        let header = PacketHeader::default();
        assert_eq!(header.packet_type, PacketType::SqlBatch);
        assert!(header.is_end_of_message());
        assert_eq!(header.length as usize, PACKET_HEADER_SIZE);
        assert_eq!(header.payload_length(), 0);
        assert_eq!(header.packet_id, 1);
        assert_eq!(header.spid, 0);
        assert_eq!(header.window, 0);
    }
}
238}