1use bitflags::bitflags;
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5
6use crate::error::ProtocolError;
7
/// Size in bytes of the fixed packet header:
/// type (1) + status (1) + length (2) + SPID (2) + packet ID (1) + window (1).
pub const PACKET_HEADER_SIZE: usize = 1 + 1 + 2 + 2 + 1 + 1;

/// Largest packet size expressible in the header's 16-bit `length` field.
pub const MAX_PACKET_SIZE: usize = u16::MAX as usize;

/// Default packet size, in bytes.
pub const DEFAULT_PACKET_SIZE: usize = 4096;
16
/// TDS packet type, carried in the first byte of every packet header.
///
/// The discriminants are the on-the-wire byte values, so `packet_type as u8`
/// yields the encoded form.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PacketType {
    /// SQL batch (`0x01`).
    SqlBatch = 0x01,
    /// Pre-TDS 7 login (`0x02`).
    PreTds7Login = 0x02,
    /// Remote procedure call (`0x03`).
    Rpc = 0x03,
    /// Tabular result (`0x04`).
    TabularResult = 0x04,
    /// Attention signal (`0x06`).
    Attention = 0x06,
    /// Bulk load data (`0x07`).
    BulkLoad = 0x07,
    /// Federated authentication token (`0x08`).
    FedAuthToken = 0x08,
    /// Transaction manager request (`0x0E`).
    TransactionManager = 0x0E,
    /// TDS 7 login (`0x10`).
    Tds7Login = 0x10,
    /// SSPI message (`0x11`).
    Sspi = 0x11,
    /// Pre-login (`0x12`).
    PreLogin = 0x12,
}
44
45impl PacketType {
46 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
48 match value {
49 0x01 => Ok(Self::SqlBatch),
50 0x02 => Ok(Self::PreTds7Login),
51 0x03 => Ok(Self::Rpc),
52 0x04 => Ok(Self::TabularResult),
53 0x06 => Ok(Self::Attention),
54 0x07 => Ok(Self::BulkLoad),
55 0x08 => Ok(Self::FedAuthToken),
56 0x0E => Ok(Self::TransactionManager),
57 0x10 => Ok(Self::Tds7Login),
58 0x11 => Ok(Self::Sspi),
59 0x12 => Ok(Self::PreLogin),
60 _ => Err(ProtocolError::InvalidPacketType(value)),
61 }
62 }
63}
64
65bitflags! {
66 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
68 pub struct PacketStatus: u8 {
69 const NORMAL = 0x00;
71 const END_OF_MESSAGE = 0x01;
73 const IGNORE_EVENT = 0x02;
75 const RESET_CONNECTION = 0x08;
77 const RESET_CONNECTION_KEEP_TRANSACTION = 0x10;
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq)]
87pub struct PacketHeader {
88 pub packet_type: PacketType,
90 pub status: PacketStatus,
92 pub length: u16,
94 pub spid: u16,
96 pub packet_id: u8,
98 pub window: u8,
100}
101
102impl PacketHeader {
103 #[must_use]
105 pub const fn new(packet_type: PacketType, status: PacketStatus, length: u16) -> Self {
106 Self {
107 packet_type,
108 status,
109 length,
110 spid: 0,
111 packet_id: 0,
112 window: 0,
113 }
114 }
115
116 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
118 if src.remaining() < PACKET_HEADER_SIZE {
119 return Err(ProtocolError::IncompletePacket {
120 expected: PACKET_HEADER_SIZE,
121 actual: src.remaining(),
122 });
123 }
124
125 let packet_type = PacketType::from_u8(src.get_u8())?;
126 let status_byte = src.get_u8();
127 let status = PacketStatus::from_bits(status_byte)
128 .ok_or(ProtocolError::InvalidPacketStatus(status_byte))?;
129 let length = src.get_u16();
130 let spid = src.get_u16();
131 let packet_id = src.get_u8();
132 let window = src.get_u8();
133
134 Ok(Self {
135 packet_type,
136 status,
137 length,
138 spid,
139 packet_id,
140 window,
141 })
142 }
143
144 pub fn encode(&self, dst: &mut impl BufMut) {
146 dst.put_u8(self.packet_type as u8);
147 dst.put_u8(self.status.bits());
148 dst.put_u16(self.length);
149 dst.put_u16(self.spid);
150 dst.put_u8(self.packet_id);
151 dst.put_u8(self.window);
152 }
153
154 #[must_use]
156 pub fn encode_to_bytes(&self) -> Bytes {
157 let mut buf = BytesMut::with_capacity(PACKET_HEADER_SIZE);
158 self.encode(&mut buf);
159 buf.freeze()
160 }
161
162 #[must_use]
164 pub const fn payload_length(&self) -> usize {
165 self.length.saturating_sub(PACKET_HEADER_SIZE as u16) as usize
166 }
167
168 #[must_use]
170 pub const fn is_end_of_message(&self) -> bool {
171 self.status.contains(PacketStatus::END_OF_MESSAGE)
172 }
173
174 #[must_use]
176 pub const fn with_packet_id(mut self, id: u8) -> Self {
177 self.packet_id = id;
178 self
179 }
180
181 #[must_use]
183 pub const fn with_spid(mut self, spid: u16) -> Self {
184 self.spid = spid;
185 self
186 }
187}
188
189impl Default for PacketHeader {
190 fn default() -> Self {
191 Self {
192 packet_type: PacketType::SqlBatch,
193 status: PacketStatus::END_OF_MESSAGE,
194 length: PACKET_HEADER_SIZE as u16,
195 spid: 0,
196 packet_id: 1,
197 window: 0,
198 }
199 }
200}
201
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_header_roundtrip() {
        let header = PacketHeader {
            packet_type: PacketType::SqlBatch,
            status: PacketStatus::END_OF_MESSAGE,
            length: 100,
            spid: 54,
            packet_id: 1,
            window: 0,
        };

        let bytes = header.encode_to_bytes();
        assert_eq!(bytes.len(), PACKET_HEADER_SIZE);

        let mut cursor = &bytes[..];
        let decoded = PacketHeader::decode(&mut cursor).unwrap();
        assert_eq!(header, decoded);
    }

    #[test]
    fn test_encode_is_big_endian() {
        let header =
            PacketHeader::new(PacketType::PreLogin, PacketStatus::END_OF_MESSAGE, 0x0102)
                .with_spid(0x0304)
                .with_packet_id(5);
        let bytes = header.encode_to_bytes();
        assert_eq!(
            &bytes[..],
            &[0x12u8, 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x00][..]
        );
    }

    #[test]
    fn test_payload_length() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::END_OF_MESSAGE, 100);
        assert_eq!(header.payload_length(), 92);
    }

    #[test]
    fn test_payload_length_saturates_below_header_size() {
        let header = PacketHeader::new(PacketType::SqlBatch, PacketStatus::NORMAL, 3);
        assert_eq!(header.payload_length(), 0);
    }

    #[test]
    fn test_default_header() {
        let header = PacketHeader::default();
        assert_eq!(header.packet_type, PacketType::SqlBatch);
        assert!(header.is_end_of_message());
        assert_eq!(header.packet_id, 1);
        assert_eq!(header.spid, 0);
        assert_eq!(header.payload_length(), 0);
    }

    #[test]
    fn test_packet_type_from_u8() {
        assert_eq!(PacketType::from_u8(0x01).unwrap(), PacketType::SqlBatch);
        assert_eq!(PacketType::from_u8(0x12).unwrap(), PacketType::PreLogin);
        assert!(PacketType::from_u8(0xFF).is_err());
    }

    #[test]
    fn test_decode_incomplete_header() {
        let bytes = [0x01u8, 0x01, 0x00];
        let mut cursor = &bytes[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::IncompletePacket {
                expected: PACKET_HEADER_SIZE,
                actual: 3,
                ..
            }
        ));
        assert_eq!(cursor.len(), 3, "a short buffer must not be consumed");
    }

    #[test]
    fn test_decode_rejects_unknown_status_bits() {
        let bytes = [0x01u8, 0x04, 0x00, 0x08, 0x00, 0x00, 0x01, 0x00];
        let mut cursor = &bytes[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidPacketStatus(0x04)));
    }

    #[test]
    fn test_decode_rejects_unknown_packet_type() {
        let bytes = [0xFFu8, 0x01, 0x00, 0x08, 0x00, 0x00, 0x01, 0x00];
        let mut cursor = &bytes[..];
        let err = PacketHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidPacketType(0xFF)));
    }
}