1use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18#[repr(u8)]
19#[non_exhaustive]
20pub enum PreLoginOption {
21 Version = 0x00,
23 Encryption = 0x01,
25 Instance = 0x02,
27 ThreadId = 0x03,
29 Mars = 0x04,
31 TraceId = 0x05,
33 FedAuthRequired = 0x06,
35 Nonce = 0x07,
37 Terminator = 0xFF,
39}
40
41impl PreLoginOption {
42 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
44 match value {
45 0x00 => Ok(Self::Version),
46 0x01 => Ok(Self::Encryption),
47 0x02 => Ok(Self::Instance),
48 0x03 => Ok(Self::ThreadId),
49 0x04 => Ok(Self::Mars),
50 0x05 => Ok(Self::TraceId),
51 0x06 => Ok(Self::FedAuthRequired),
52 0x07 => Ok(Self::Nonce),
53 0xFF => Ok(Self::Terminator),
54 _ => Err(ProtocolError::InvalidPreloginOption(value)),
55 }
56 }
57}
58
/// Value of the PRELOGIN `ENCRYPTION` option.
///
/// The derived `Default` is [`EncryptionLevel::Required`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
#[non_exhaustive]
pub enum EncryptionLevel {
    /// Encryption available but off (`ENCRYPT_OFF`, `0x00`).
    Off = 0x00,
    /// Encryption available and on (`ENCRYPT_ON`, `0x01`).
    On = 0x01,
    /// Encryption not available (`ENCRYPT_NOT_SUP`, `0x02`).
    NotSupported = 0x02,
    /// Encryption required (`ENCRYPT_REQ`, `0x03`).
    #[default]
    Required = 0x03,
    /// Client-certificate authentication (`ENCRYPT_CLIENT_CERT`, `0x80`).
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Bit marking client-certificate authentication. MS-TDS allows it to be
    /// OR-ed with `ENCRYPT_OFF` / `ENCRYPT_ON` / `ENCRYPT_REQ` (e.g. `0x81`, `0x83`).
    const CLIENT_CERT_FLAG: u8 = 0x80;

    /// Decodes a raw `ENCRYPTION` byte.
    ///
    /// Any byte with the `0x80` client-certificate bit set maps to
    /// [`EncryptionLevel::ClientCertAuth`]. This stops combined values such as
    /// `0x81`/`0x83` from being misread as [`EncryptionLevel::Off`], which would
    /// silently downgrade encryption. Other unrecognized bytes map to
    /// [`EncryptionLevel::Off`].
    #[must_use]
    pub fn from_u8(value: u8) -> Self {
        if value & Self::CLIENT_CERT_FLAG != 0 {
            return Self::ClientCertAuth;
        }
        match value {
            0x00 => Self::Off,
            0x01 => Self::On,
            0x02 => Self::NotSupported,
            0x03 => Self::Required,
            _ => Self::Off,
        }
    }

    /// Returns `true` for [`On`](Self::On), [`Required`](Self::Required) and
    /// [`ClientCertAuth`](Self::ClientCertAuth), and `false` otherwise.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        matches!(self, Self::On | Self::Required | Self::ClientCertAuth)
    }
}
96
/// A TDS PRELOGIN message. It is used both for a message built with
/// [`PreLogin::new`] and serialized with `encode`, and for a peer's message
/// parsed with `decode`.
///
/// The derived `Default` is the starting state for `decode`. It differs from
/// [`PreLogin::new`], which sets `version` to `TdsVersion::V7_4`.
/// NOTE(review): what `TdsVersion::default()` yields is not visible in this file.
#[derive(Debug, Clone, Default)]
pub struct PreLogin {
    /// Version carried in the `VERSION` option. `encode` writes it as four
    /// big-endian bytes followed by a zero sub-build. When the peer sends a
    /// `VERSION` option, `decode` overwrites it with `TdsVersion::new` of the
    /// four raw version bytes.
    pub version: TdsVersion,

    /// Version reported by the peer's `VERSION` option (raw value plus the
    /// little-endian sub-build). Only `decode` sets this; `encode` ignores it.
    pub server_version: Option<SqlServerVersion>,

    /// Value of the `ENCRYPTION` option; `encode` always sends it.
    pub encryption: EncryptionLevel,
    /// `INSTANCE` option, sent NUL-terminated when `Some`. `decode` only sets it
    /// to a non-empty, valid UTF-8 name found before a NUL byte.
    pub instance: Option<String>,
    /// `THREADID` option, sent when `Some`. Written and read big-endian here.
    pub thread_id: Option<u32>,
    /// `MARS` option. Always sent as one byte: `0x01` when enabled, else `0x00`.
    pub mars: bool,
    /// `TRACEID` option, sent when `Some`. `decode` does not parse it.
    pub trace_id: Option<TraceId>,
    /// When `true`, `encode` sends a `FEDAUTHREQUIRED` option with value `0x01`.
    /// When `false`, the option is omitted. `decode` sets it from a non-zero byte.
    pub fed_auth_required: bool,
    /// 32-byte `NONCE` option, sent when `Some`.
    pub nonce: Option<[u8; 32]>,
}
137
/// Client tracing identifiers carried by the PRELOGIN `TRACEID` option.
///
/// Derives the common value traits so trace ids can be compared, hashed and
/// default-constructed (all-zero GUID, sequence 0).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct TraceId {
    /// 16-byte activity GUID, written to the wire exactly as stored.
    pub activity_id: [u8; 16],
    /// Activity sequence number; written to the wire little-endian.
    pub activity_sequence: u32,
}
146
147impl PreLogin {
148 #[must_use]
150 pub fn new() -> Self {
151 Self {
152 version: TdsVersion::V7_4,
153 server_version: None,
154 encryption: EncryptionLevel::Required,
155 instance: None,
156 thread_id: None,
157 mars: false,
158 trace_id: None,
159 fed_auth_required: false,
160 nonce: None,
161 }
162 }
163
164 #[must_use]
166 pub fn with_version(mut self, version: TdsVersion) -> Self {
167 self.version = version;
168 self
169 }
170
171 #[must_use]
173 pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
174 self.encryption = level;
175 self
176 }
177
178 #[must_use]
180 pub fn with_mars(mut self, enabled: bool) -> Self {
181 self.mars = enabled;
182 self
183 }
184
185 #[must_use]
187 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
188 self.instance = Some(instance.into());
189 self
190 }
191
192 #[must_use]
200 pub fn with_fed_auth_required(mut self, required: bool) -> Self {
201 self.fed_auth_required = required;
202 self
203 }
204
205 #[must_use]
207 pub fn encode(&self) -> Bytes {
208 let mut buf = BytesMut::with_capacity(256);
209
210 let mut option_count = 3; if self.instance.is_some() {
215 option_count += 1;
216 }
217 if self.thread_id.is_some() {
218 option_count += 1;
219 }
220 if self.trace_id.is_some() {
221 option_count += 1;
222 }
223 if self.fed_auth_required {
224 option_count += 1;
225 }
226 if self.nonce.is_some() {
227 option_count += 1;
228 }
229
230 let header_size = option_count * 5 + 1; let mut data_offset = header_size as u16;
232 let mut data_buf = BytesMut::new();
233
234 buf.put_u8(PreLoginOption::Version as u8);
236 buf.put_u16(data_offset);
237 buf.put_u16(6);
238 let version_raw = self.version.raw();
239 data_buf.put_u8((version_raw >> 24) as u8);
240 data_buf.put_u8((version_raw >> 16) as u8);
241 data_buf.put_u8((version_raw >> 8) as u8);
242 data_buf.put_u8(version_raw as u8);
243 data_buf.put_u16_le(0);
246 data_offset += 6;
247
248 buf.put_u8(PreLoginOption::Encryption as u8);
250 buf.put_u16(data_offset);
251 buf.put_u16(1);
252 data_buf.put_u8(self.encryption as u8);
253 data_offset += 1;
254
255 if let Some(ref instance) = self.instance {
257 let instance_bytes = instance.as_bytes();
258 let len = instance_bytes.len() as u16 + 1; buf.put_u8(PreLoginOption::Instance as u8);
260 buf.put_u16(data_offset);
261 buf.put_u16(len);
262 data_buf.put_slice(instance_bytes);
263 data_buf.put_u8(0); data_offset += len;
265 }
266
267 if let Some(thread_id) = self.thread_id {
269 buf.put_u8(PreLoginOption::ThreadId as u8);
270 buf.put_u16(data_offset);
271 buf.put_u16(4);
272 data_buf.put_u32(thread_id);
273 data_offset += 4;
274 }
275
276 buf.put_u8(PreLoginOption::Mars as u8);
278 buf.put_u16(data_offset);
279 buf.put_u16(1);
280 data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
281 data_offset += 1;
282
283 if let Some(ref trace_id) = self.trace_id {
285 buf.put_u8(PreLoginOption::TraceId as u8);
286 buf.put_u16(data_offset);
287 buf.put_u16(36);
288 data_buf.put_slice(&trace_id.activity_id);
289 data_buf.put_u32_le(trace_id.activity_sequence);
290 data_buf.put_slice(&[0u8; 16]);
292 data_offset += 36;
293 }
294
295 if self.fed_auth_required {
297 buf.put_u8(PreLoginOption::FedAuthRequired as u8);
298 buf.put_u16(data_offset);
299 buf.put_u16(1);
300 data_buf.put_u8(0x01);
301 data_offset += 1;
302 }
303
304 if let Some(ref nonce) = self.nonce {
306 buf.put_u8(PreLoginOption::Nonce as u8);
307 buf.put_u16(data_offset);
308 buf.put_u16(32);
309 data_buf.put_slice(nonce);
310 let _ = data_offset; }
312
313 buf.put_u8(PreLoginOption::Terminator as u8);
315
316 buf.put_slice(&data_buf);
318
319 buf.freeze()
320 }
321
322 pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
331 let mut prelogin = Self::default();
332
333 let mut options = Vec::new();
335 loop {
336 if src.remaining() < 1 {
337 return Err(ProtocolError::UnexpectedEof);
338 }
339
340 let option_type = src.get_u8();
341 if option_type == PreLoginOption::Terminator as u8 {
342 break;
343 }
344
345 if src.remaining() < 4 {
346 return Err(ProtocolError::UnexpectedEof);
347 }
348
349 let offset = src.get_u16();
350 let length = src.get_u16();
351 options.push((PreLoginOption::from_u8(option_type)?, offset, length));
352 }
353
354 let data = src.copy_to_bytes(src.remaining());
356
357 let header_size = options.len() * 5 + 1;
359
360 for (option, packet_offset, length) in options {
361 let packet_offset = packet_offset as usize;
362 let length = length as usize;
363
364 if packet_offset < header_size {
367 continue;
369 }
370 let data_offset = packet_offset - header_size;
371
372 if data_offset + length > data.len() {
374 continue;
375 }
376
377 match option {
378 PreLoginOption::Version if length >= 4 => {
379 let version_bytes = &data[data_offset..data_offset + 4];
387 let version_raw = u32::from_be_bytes([
388 version_bytes[0],
389 version_bytes[1],
390 version_bytes[2],
391 version_bytes[3],
392 ]);
393
394 let sub_build = if length >= 6 {
396 let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
397 u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
398 } else {
399 0
400 };
401
402 prelogin.server_version =
404 Some(SqlServerVersion::from_raw(version_raw, sub_build));
405
406 prelogin.version = TdsVersion::new(version_raw);
408 }
409 PreLoginOption::Encryption if length >= 1 => {
410 prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
411 }
412 PreLoginOption::Mars if length >= 1 => {
413 prelogin.mars = data[data_offset] != 0;
414 }
415 PreLoginOption::Instance if length > 0 => {
416 let instance_data = &data[data_offset..data_offset + length];
418 if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
419 if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
420 if !s.is_empty() {
421 prelogin.instance = Some(s.to_string());
422 }
423 }
424 }
425 }
426 PreLoginOption::ThreadId if length >= 4 => {
427 let bytes = &data[data_offset..data_offset + 4];
428 prelogin.thread_id =
429 Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
430 }
431 PreLoginOption::FedAuthRequired if length >= 1 => {
432 prelogin.fed_auth_required = data[data_offset] != 0;
433 }
434 PreLoginOption::Nonce if length >= 32 => {
435 let mut nonce = [0u8; 32];
436 nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
437 prelogin.nonce = Some(nonce);
438 }
439 _ => {}
440 }
441 }
442
443 Ok(prelogin)
444 }
445}
446
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
    }

    /// The default message has exactly VERSION, ENCRYPTION and MARS entries whose
    /// offsets point just past the 16-byte option table.
    #[test]
    fn test_prelogin_encode_layout() {
        let encoded = PreLogin::new().encode();
        assert_eq!(encoded.len(), 16 + 6 + 1 + 1);
        assert_eq!(
            &encoded[..16],
            &[
                0x00, 0x00, 0x10, 0x00, 0x06, // VERSION @16, len 6
                0x01, 0x00, 0x16, 0x00, 0x01, // ENCRYPTION @22, len 1
                0x04, 0x00, 0x17, 0x00, 0x01, // MARS @23, len 1
                0xFF, // terminator
            ][..]
        );
        assert_eq!(encoded[22], EncryptionLevel::Required as u8);
        assert_eq!(encoded[23], 0x00);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    #[test]
    fn test_prelogin_fed_auth_required_roundtrip() {
        let without = PreLogin::new().encode();
        let decoded = PreLogin::decode(without.as_ref()).unwrap();
        assert!(
            !decoded.fed_auth_required,
            "FEDAUTHREQUIRED must default to absent/false"
        );

        let with = PreLogin::new().with_fed_auth_required(true).encode();
        let header_end = with.iter().position(|&b| b == 0xFF).unwrap();
        assert!(
            with[..header_end]
                .chunks(5)
                .any(|opt| opt[0] == PreLoginOption::FedAuthRequired as u8),
            "encoded PreLogin must contain a FEDAUTHREQUIRED option header"
        );
        let decoded = PreLogin::decode(with.as_ref()).unwrap();
        assert!(decoded.fed_auth_required);
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    #[test]
    fn test_prelogin_optional_fields_roundtrip() {
        let original = PreLogin {
            thread_id: Some(0x0102_0304),
            nonce: Some([0xAB; 32]),
            ..PreLogin::new().with_instance("MSSQLSERVER")
        };

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("MSSQLSERVER"));
        assert_eq!(decoded.thread_id, Some(0x0102_0304));
        assert_eq!(decoded.nonce, Some([0xAB; 32]));
    }

    #[test]
    fn test_prelogin_empty_instance_is_not_decoded() {
        let encoded = PreLogin::new().with_instance("").encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();
        assert!(decoded.instance.is_none());
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        use bytes::BufMut;

        // Server responses may list ENCRYPTION before VERSION; offsets, not table
        // order, determine where each payload lives.
        let mut buf = bytes::BytesMut::new();

        let header_size: u16 = 11;

        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size);
        buf.put_u16(1);
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1);
        buf.put_u16(6);
        buf.put_u8(PreLoginOption::Terminator as u8);

        buf.put_u8(0x01);

        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000);

        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }

    #[test]
    fn test_prelogin_decode_rejects_unknown_option() {
        let bytes: &[u8] = &[0x42, 0x00, 0x06, 0x00, 0x00, 0xFF];
        assert!(matches!(
            PreLogin::decode(bytes),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    #[test]
    fn test_prelogin_decode_truncated_header() {
        let empty: &[u8] = &[];
        assert!(matches!(
            PreLogin::decode(empty),
            Err(ProtocolError::UnexpectedEof)
        ));

        let short_entry: &[u8] = &[0x00, 0x00, 0x06];
        assert!(matches!(
            PreLogin::decode(short_entry),
            Err(ProtocolError::UnexpectedEof)
        ));

        let no_terminator: &[u8] = &[0x00, 0x00, 0x06, 0x00, 0x06];
        assert!(matches!(
            PreLogin::decode(no_terminator),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_prelogin_decode_ignores_out_of_range_option() {
        // MARS claims data past the end of the packet; it must be skipped, not panic.
        let past_end: &[u8] = &[PreLoginOption::Mars as u8, 0x00, 0x40, 0x00, 0x01, 0xFF, 0x01];
        assert!(!PreLogin::decode(past_end).unwrap().mars);

        // MARS claims data inside the option table itself; also skipped.
        let into_header: &[u8] = &[PreLoginOption::Mars as u8, 0x00, 0x00, 0x00, 0x01, 0xFF, 0x01];
        assert!(!PreLogin::decode(into_header).unwrap().mars);
    }
}