1use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::version::TdsVersion;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17#[repr(u8)]
18pub enum PreLoginOption {
19 Version = 0x00,
21 Encryption = 0x01,
23 Instance = 0x02,
25 ThreadId = 0x03,
27 Mars = 0x04,
29 TraceId = 0x05,
31 FedAuthRequired = 0x06,
33 Nonce = 0x07,
35 Terminator = 0xFF,
37}
38
39impl PreLoginOption {
40 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
42 match value {
43 0x00 => Ok(Self::Version),
44 0x01 => Ok(Self::Encryption),
45 0x02 => Ok(Self::Instance),
46 0x03 => Ok(Self::ThreadId),
47 0x04 => Ok(Self::Mars),
48 0x05 => Ok(Self::TraceId),
49 0x06 => Ok(Self::FedAuthRequired),
50 0x07 => Ok(Self::Nonce),
51 0xFF => Ok(Self::Terminator),
52 _ => Err(ProtocolError::InvalidPreloginOption(value)),
53 }
54 }
55}
56
/// Value of the PRELOGIN `ENCRYPTION` option.
///
/// The derived `Default` is [`EncryptionLevel::Required`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum EncryptionLevel {
    /// Encryption is available but off (`0x00`).
    Off = 0x00,
    /// Encryption is on (`0x01`).
    On = 0x01,
    /// Encryption is not supported (`0x02`).
    NotSupported = 0x02,
    /// Encryption is required (`0x03`); this is the default.
    #[default]
    Required = 0x03,
    /// Client-certificate authentication bit (`0x80`).
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Decodes a raw `ENCRYPTION` byte.
    ///
    /// MS-TDS allows the client-certificate bit (`0x80`) to be combined with OFF/ON/REQ.
    /// This gives `0x80`, `0x81` and `0x83`. All three decode to [`EncryptionLevel::ClientCertAuth`],
    /// so they are never silently treated as [`EncryptionLevel::Off`]; the enum cannot
    /// represent the combination more precisely. Any other unrecognised byte decodes to
    /// [`EncryptionLevel::Off`].
    pub fn from_u8(value: u8) -> Self {
        match value {
            0x00 => Self::Off,
            0x01 => Self::On,
            0x02 => Self::NotSupported,
            0x03 => Self::Required,
            // Client-cert bit alone or OR-ed with ENCRYPT_ON / ENCRYPT_REQ.
            0x80 | 0x81 | 0x83 => Self::ClientCertAuth,
            _ => Self::Off,
        }
    }

    /// Returns `true` for levels that require an encrypted connection:
    /// `On`, `Required` and `ClientCertAuth`.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        matches!(self, Self::On | Self::Required | Self::ClientCertAuth)
    }
}
93
/// Contents of a TDS PRELOGIN message, either to be sent or as received.
///
/// Build one with [`PreLogin::new`] and the `with_*` methods and serialize it with
/// [`PreLogin::encode`]; parse received bytes with [`PreLogin::decode`], which starts from
/// the derived `Default` and overwrites fields for the options it finds.
#[derive(Debug, Clone, Default)]
pub struct PreLogin {
    /// Protocol version carried in the `Version` option.
    pub version: TdsVersion,
    /// Sub-build number that follows the version bytes in the `Version` option (little-endian).
    pub sub_build: u16,
    /// Value of the `Encryption` option.
    pub encryption: EncryptionLevel,
    /// Instance name; sent NUL-terminated. `None` omits the option when encoding.
    pub instance: Option<String>,
    /// Client thread id (big-endian on the wire). `None` omits the option when encoding.
    pub thread_id: Option<u32>,
    /// MARS flag; always encoded as a single `0x01`/`0x00` byte.
    pub mars: bool,
    /// Trace identifier. `None` omits the option when encoding.
    pub trace_id: Option<TraceId>,
    /// When `true`, the `FedAuthRequired` option is encoded with value `0x01`; when decoding,
    /// any non-zero byte sets it.
    pub fed_auth_required: bool,
    /// 32-byte nonce. `None` omits the option when encoding.
    pub nonce: Option<[u8; 32]>,
}
116
/// Client trace identifier carried in the PRELOGIN `TraceId` option.
///
/// `PreLogin::encode` writes it as a 36-byte payload made of these 20 bytes plus 16 zero bytes.
#[derive(Debug, Clone, Copy)]
pub struct TraceId {
    /// 16-byte activity identifier, copied to the wire verbatim.
    pub activity_id: [u8; 16],
    /// Sequence number within the activity, encoded little-endian.
    pub activity_sequence: u32,
}
125
126impl PreLogin {
127 #[must_use]
129 pub fn new() -> Self {
130 Self {
131 version: TdsVersion::V7_4,
132 sub_build: 0,
133 encryption: EncryptionLevel::Required,
134 instance: None,
135 thread_id: None,
136 mars: false,
137 trace_id: None,
138 fed_auth_required: false,
139 nonce: None,
140 }
141 }
142
143 #[must_use]
145 pub fn with_version(mut self, version: TdsVersion) -> Self {
146 self.version = version;
147 self
148 }
149
150 #[must_use]
152 pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
153 self.encryption = level;
154 self
155 }
156
157 #[must_use]
159 pub fn with_mars(mut self, enabled: bool) -> Self {
160 self.mars = enabled;
161 self
162 }
163
164 #[must_use]
166 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
167 self.instance = Some(instance.into());
168 self
169 }
170
171 #[must_use]
173 pub fn encode(&self) -> Bytes {
174 let mut buf = BytesMut::with_capacity(256);
175
176 let mut option_count = 3; if self.instance.is_some() {
181 option_count += 1;
182 }
183 if self.thread_id.is_some() {
184 option_count += 1;
185 }
186 if self.trace_id.is_some() {
187 option_count += 1;
188 }
189 if self.fed_auth_required {
190 option_count += 1;
191 }
192 if self.nonce.is_some() {
193 option_count += 1;
194 }
195
196 let header_size = option_count * 5 + 1; let mut data_offset = header_size as u16;
198 let mut data_buf = BytesMut::new();
199
200 buf.put_u8(PreLoginOption::Version as u8);
202 buf.put_u16(data_offset);
203 buf.put_u16(6);
204 let version_raw = self.version.raw();
205 data_buf.put_u8((version_raw >> 24) as u8);
206 data_buf.put_u8((version_raw >> 16) as u8);
207 data_buf.put_u8((version_raw >> 8) as u8);
208 data_buf.put_u8(version_raw as u8);
209 data_buf.put_u16_le(self.sub_build);
210 data_offset += 6;
211
212 buf.put_u8(PreLoginOption::Encryption as u8);
214 buf.put_u16(data_offset);
215 buf.put_u16(1);
216 data_buf.put_u8(self.encryption as u8);
217 data_offset += 1;
218
219 if let Some(ref instance) = self.instance {
221 let instance_bytes = instance.as_bytes();
222 let len = instance_bytes.len() as u16 + 1; buf.put_u8(PreLoginOption::Instance as u8);
224 buf.put_u16(data_offset);
225 buf.put_u16(len);
226 data_buf.put_slice(instance_bytes);
227 data_buf.put_u8(0); data_offset += len;
229 }
230
231 if let Some(thread_id) = self.thread_id {
233 buf.put_u8(PreLoginOption::ThreadId as u8);
234 buf.put_u16(data_offset);
235 buf.put_u16(4);
236 data_buf.put_u32(thread_id);
237 data_offset += 4;
238 }
239
240 buf.put_u8(PreLoginOption::Mars as u8);
242 buf.put_u16(data_offset);
243 buf.put_u16(1);
244 data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
245 data_offset += 1;
246
247 if let Some(ref trace_id) = self.trace_id {
249 buf.put_u8(PreLoginOption::TraceId as u8);
250 buf.put_u16(data_offset);
251 buf.put_u16(36);
252 data_buf.put_slice(&trace_id.activity_id);
253 data_buf.put_u32_le(trace_id.activity_sequence);
254 data_buf.put_slice(&[0u8; 16]);
256 data_offset += 36;
257 }
258
259 if self.fed_auth_required {
261 buf.put_u8(PreLoginOption::FedAuthRequired as u8);
262 buf.put_u16(data_offset);
263 buf.put_u16(1);
264 data_buf.put_u8(0x01);
265 data_offset += 1;
266 }
267
268 if let Some(ref nonce) = self.nonce {
270 buf.put_u8(PreLoginOption::Nonce as u8);
271 buf.put_u16(data_offset);
272 buf.put_u16(32);
273 data_buf.put_slice(nonce);
274 let _ = data_offset; }
276
277 buf.put_u8(PreLoginOption::Terminator as u8);
279
280 buf.put_slice(&data_buf);
282
283 buf.freeze()
284 }
285
286 pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
295 let mut prelogin = Self::default();
296
297 let mut options = Vec::new();
299 loop {
300 if src.remaining() < 1 {
301 return Err(ProtocolError::UnexpectedEof);
302 }
303
304 let option_type = src.get_u8();
305 if option_type == PreLoginOption::Terminator as u8 {
306 break;
307 }
308
309 if src.remaining() < 4 {
310 return Err(ProtocolError::UnexpectedEof);
311 }
312
313 let offset = src.get_u16();
314 let length = src.get_u16();
315 options.push((PreLoginOption::from_u8(option_type)?, offset, length));
316 }
317
318 let data = src.copy_to_bytes(src.remaining());
320
321 let header_size = options.len() * 5 + 1;
323
324 for (option, packet_offset, length) in options {
325 let packet_offset = packet_offset as usize;
326 let length = length as usize;
327
328 if packet_offset < header_size {
331 continue;
333 }
334 let data_offset = packet_offset - header_size;
335
336 if data_offset + length > data.len() {
338 continue;
339 }
340
341 match option {
342 PreLoginOption::Version if length >= 6 => {
343 let version_bytes = &data[data_offset..data_offset + 4];
345 let version_raw = u32::from_be_bytes([
346 version_bytes[0],
347 version_bytes[1],
348 version_bytes[2],
349 version_bytes[3],
350 ]);
351 prelogin.version = TdsVersion::new(version_raw);
352
353 if length >= 6 {
354 let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
355 prelogin.sub_build =
356 u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]]);
357 }
358 }
359 PreLoginOption::Encryption if length >= 1 => {
360 prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
361 }
362 PreLoginOption::Mars if length >= 1 => {
363 prelogin.mars = data[data_offset] != 0;
364 }
365 PreLoginOption::Instance if length > 0 => {
366 let instance_data = &data[data_offset..data_offset + length];
368 if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
369 if let Ok(s) = std::str::from_utf8(&instance_data[..null_pos]) {
370 if !s.is_empty() {
371 prelogin.instance = Some(s.to_string());
372 }
373 }
374 }
375 }
376 PreLoginOption::ThreadId if length >= 4 => {
377 let bytes = &data[data_offset..data_offset + 4];
378 prelogin.thread_id =
379 Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
380 }
381 PreLoginOption::FedAuthRequired if length >= 1 => {
382 prelogin.fed_auth_required = data[data_offset] != 0;
383 }
384 PreLoginOption::Nonce if length >= 32 => {
385 let mut nonce = [0u8; 32];
386 nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
387 prelogin.nonce = Some(nonce);
388 }
389 _ => {}
390 }
391 }
392
393 Ok(prelogin)
394 }
395}
396
397#[cfg(not(feature = "std"))]
398use alloc::string::String;
399#[cfg(not(feature = "std"))]
400use alloc::vec::Vec;
401
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        // The option table starts with the VERSION token.
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    #[test]
    fn test_encryption_level_client_cert() {
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert_eq!(EncryptionLevel::from_u8(0x80), EncryptionLevel::ClientCertAuth);
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    #[test]
    fn test_prelogin_optional_fields_roundtrip() {
        let original = PreLogin {
            thread_id: Some(0x0102_0304),
            fed_auth_required: true,
            nonce: Some([0xAB; 32]),
            ..PreLogin::new().with_instance("MSSQLSERVER")
        };

        let decoded = PreLogin::decode(original.encode().as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("MSSQLSERVER"));
        assert_eq!(decoded.thread_id, Some(0x0102_0304));
        assert!(decoded.fed_auth_required);
        assert_eq!(decoded.nonce, Some([0xAB; 32]));
        assert_eq!(decoded.sub_build, original.sub_build);
    }

    #[test]
    fn test_prelogin_decode_truncated() {
        // No terminator at all.
        assert!(matches!(
            PreLogin::decode(&b""[..]),
            Err(ProtocolError::UnexpectedEof)
        ));
        // A VERSION token whose offset/length fields are cut short.
        assert!(matches!(
            PreLogin::decode(&[0x00u8, 0x00][..]),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_prelogin_decode_unknown_option() {
        assert!(matches!(
            PreLogin::decode(&[0x42u8, 0x00, 0x06, 0x00, 0x00, 0xFF][..]),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        use bytes::BufMut;

        // Options deliberately listed in a non-default order (ENCRYPTION before VERSION)
        // to check that decoding follows the table offsets.
        let mut buf = bytes::BytesMut::new();

        // Two entries of 5 bytes each plus the terminator.
        let header_size: u16 = 11;

        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size);
        buf.put_u16(1);
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1);
        buf.put_u16(6);
        buf.put_u8(PreLoginOption::Terminator as u8);

        // ENCRYPTION payload: ON.
        buf.put_u8(0x01);

        // VERSION payload: 0x74000004 followed by sub-build 0.
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000);

        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}