1use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18#[repr(u8)]
19#[non_exhaustive]
20pub enum PreLoginOption {
21 Version = 0x00,
23 Encryption = 0x01,
25 Instance = 0x02,
27 ThreadId = 0x03,
29 Mars = 0x04,
31 TraceId = 0x05,
33 FedAuthRequired = 0x06,
35 Nonce = 0x07,
37 Terminator = 0xFF,
39}
40
41impl PreLoginOption {
42 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
44 match value {
45 0x00 => Ok(Self::Version),
46 0x01 => Ok(Self::Encryption),
47 0x02 => Ok(Self::Instance),
48 0x03 => Ok(Self::ThreadId),
49 0x04 => Ok(Self::Mars),
50 0x05 => Ok(Self::TraceId),
51 0x06 => Ok(Self::FedAuthRequired),
52 0x07 => Ok(Self::Nonce),
53 0xFF => Ok(Self::Terminator),
54 _ => Err(ProtocolError::InvalidPreloginOption(value)),
55 }
56 }
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
61#[repr(u8)]
62#[non_exhaustive]
63pub enum EncryptionLevel {
64 Off = 0x00,
66 On = 0x01,
68 NotSupported = 0x02,
70 #[default]
72 Required = 0x03,
73 ClientCertAuth = 0x80,
75}
76
77impl EncryptionLevel {
78 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
86 match value {
87 0x00 => Ok(Self::Off),
88 0x01 => Ok(Self::On),
89 0x02 => Ok(Self::NotSupported),
90 0x03 => Ok(Self::Required),
91 0x80 => Ok(Self::ClientCertAuth),
92 _ => Err(ProtocolError::InvalidEncryptionLevel(value)),
93 }
94 }
95
96 #[must_use]
98 pub const fn is_required(&self) -> bool {
99 matches!(self, Self::On | Self::Required | Self::ClientCertAuth)
100 }
101}
102
/// Contents of a TDS PRELOGIN message.
///
/// Build one with [`PreLogin::new`] plus the `with_*` methods and serialize it
/// with [`PreLogin::encode`], or parse a received payload with
/// [`PreLogin::decode`].
#[derive(Debug, Clone, Default)]
pub struct PreLogin {
    /// TDS version. `encode` writes it as the first four (big-endian) bytes of
    /// the VERSION option; `decode` sets it from those same four bytes.
    pub version: TdsVersion,

    /// Version that `decode` builds from the VERSION option's four leading
    /// bytes plus its optional 2-byte little-endian sub-build. `encode` does
    /// not write this field.
    pub server_version: Option<SqlServerVersion>,

    /// Value of the ENCRYPTION option (one byte on the wire).
    pub encryption: EncryptionLevel,
    /// Instance name for the INSTANCE option; sent NUL-terminated.
    pub instance: Option<String>,
    /// Value of the THREADID option, a big-endian `u32` on the wire.
    pub thread_id: Option<u32>,
    /// MARS flag: `encode` sends `0x01`/`0x00`, `decode` treats any non-zero
    /// byte as `true`.
    pub mars: bool,
    /// Trace data for the TRACEID option; when set, `encode` emits a 36-byte
    /// option.
    pub trace_id: Option<TraceId>,
    /// When `true`, `encode` adds a FEDAUTHREQUIRED option with value `0x01`;
    /// `decode` treats any non-zero byte as `true`.
    pub fed_auth_required: bool,
    /// 32-byte value for the NONCE option.
    pub nonce: Option<[u8; 32]>,
}
143
/// Activity-tracing data carried by the TRACEID PRELOGIN option.
#[derive(Debug, Clone, Copy)]
pub struct TraceId {
    /// 16-byte activity identifier (presumably raw GUID bytes — confirm with
    /// callers), copied verbatim onto the wire.
    pub activity_id: [u8; 16],
    /// Sequence number for the activity, written little-endian right after
    /// `activity_id`.
    pub activity_sequence: u32,
}
152
153impl PreLogin {
154 #[must_use]
156 pub fn new() -> Self {
157 Self {
158 version: TdsVersion::V7_4,
159 server_version: None,
160 encryption: EncryptionLevel::Required,
161 instance: None,
162 thread_id: None,
163 mars: false,
164 trace_id: None,
165 fed_auth_required: false,
166 nonce: None,
167 }
168 }
169
170 #[must_use]
172 pub fn with_version(mut self, version: TdsVersion) -> Self {
173 self.version = version;
174 self
175 }
176
177 #[must_use]
179 pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
180 self.encryption = level;
181 self
182 }
183
184 #[must_use]
186 pub fn with_mars(mut self, enabled: bool) -> Self {
187 self.mars = enabled;
188 self
189 }
190
191 #[must_use]
193 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
194 self.instance = Some(instance.into());
195 self
196 }
197
198 #[must_use]
206 pub fn with_fed_auth_required(mut self, required: bool) -> Self {
207 self.fed_auth_required = required;
208 self
209 }
210
211 #[must_use]
213 pub fn encode(&self) -> Bytes {
214 let mut buf = BytesMut::with_capacity(256);
215
216 let mut option_count = 3; if self.instance.is_some() {
221 option_count += 1;
222 }
223 if self.thread_id.is_some() {
224 option_count += 1;
225 }
226 if self.trace_id.is_some() {
227 option_count += 1;
228 }
229 if self.fed_auth_required {
230 option_count += 1;
231 }
232 if self.nonce.is_some() {
233 option_count += 1;
234 }
235
236 let header_size = option_count * 5 + 1; let mut data_offset = header_size as u16;
238 let mut data_buf = BytesMut::new();
239
240 buf.put_u8(PreLoginOption::Version as u8);
242 buf.put_u16(data_offset);
243 buf.put_u16(6);
244 let version_raw = self.version.raw();
245 data_buf.put_u8((version_raw >> 24) as u8);
246 data_buf.put_u8((version_raw >> 16) as u8);
247 data_buf.put_u8((version_raw >> 8) as u8);
248 data_buf.put_u8(version_raw as u8);
249 data_buf.put_u16_le(0);
252 data_offset += 6;
253
254 buf.put_u8(PreLoginOption::Encryption as u8);
256 buf.put_u16(data_offset);
257 buf.put_u16(1);
258 data_buf.put_u8(self.encryption as u8);
259 data_offset += 1;
260
261 if let Some(ref instance) = self.instance {
263 let instance_bytes = instance.as_bytes();
264 let len = instance_bytes.len() as u16 + 1; buf.put_u8(PreLoginOption::Instance as u8);
266 buf.put_u16(data_offset);
267 buf.put_u16(len);
268 data_buf.put_slice(instance_bytes);
269 data_buf.put_u8(0); data_offset += len;
271 }
272
273 if let Some(thread_id) = self.thread_id {
275 buf.put_u8(PreLoginOption::ThreadId as u8);
276 buf.put_u16(data_offset);
277 buf.put_u16(4);
278 data_buf.put_u32(thread_id);
279 data_offset += 4;
280 }
281
282 buf.put_u8(PreLoginOption::Mars as u8);
284 buf.put_u16(data_offset);
285 buf.put_u16(1);
286 data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
287 data_offset += 1;
288
289 if let Some(ref trace_id) = self.trace_id {
291 buf.put_u8(PreLoginOption::TraceId as u8);
292 buf.put_u16(data_offset);
293 buf.put_u16(36);
294 data_buf.put_slice(&trace_id.activity_id);
295 data_buf.put_u32_le(trace_id.activity_sequence);
296 data_buf.put_slice(&[0u8; 16]);
298 data_offset += 36;
299 }
300
301 if self.fed_auth_required {
303 buf.put_u8(PreLoginOption::FedAuthRequired as u8);
304 buf.put_u16(data_offset);
305 buf.put_u16(1);
306 data_buf.put_u8(0x01);
307 data_offset += 1;
308 }
309
310 if let Some(ref nonce) = self.nonce {
312 buf.put_u8(PreLoginOption::Nonce as u8);
313 buf.put_u16(data_offset);
314 buf.put_u16(32);
315 data_buf.put_slice(nonce);
316 let _ = data_offset; }
318
319 buf.put_u8(PreLoginOption::Terminator as u8);
321
322 buf.put_slice(&data_buf);
324
325 buf.freeze()
326 }
327
328 pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
337 let mut prelogin = Self::default();
338
339 let mut options = Vec::new();
341 loop {
342 if src.remaining() < 1 {
343 return Err(ProtocolError::UnexpectedEof);
344 }
345
346 let option_type = src.get_u8();
347 if option_type == PreLoginOption::Terminator as u8 {
348 break;
349 }
350
351 if src.remaining() < 4 {
352 return Err(ProtocolError::UnexpectedEof);
353 }
354
355 let offset = src.get_u16();
356 let length = src.get_u16();
357 options.push((PreLoginOption::from_u8(option_type)?, offset, length));
358 }
359
360 let data = src.copy_to_bytes(src.remaining());
362
363 let header_size = options.len() * 5 + 1;
365
366 for (option, packet_offset, length) in options {
367 let packet_offset = packet_offset as usize;
368 let length = length as usize;
369
370 if packet_offset < header_size {
373 continue;
375 }
376 let data_offset = packet_offset - header_size;
377
378 if data_offset + length > data.len() {
380 continue;
381 }
382
383 match option {
384 PreLoginOption::Version if length >= 4 => {
385 let version_bytes = &data[data_offset..data_offset + 4];
393 let version_raw = u32::from_be_bytes([
394 version_bytes[0],
395 version_bytes[1],
396 version_bytes[2],
397 version_bytes[3],
398 ]);
399
400 let sub_build = if length >= 6 {
402 let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
403 u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
404 } else {
405 0
406 };
407
408 prelogin.server_version =
410 Some(SqlServerVersion::from_raw(version_raw, sub_build));
411
412 prelogin.version = TdsVersion::new(version_raw);
414 }
415 PreLoginOption::Encryption if length >= 1 => {
416 prelogin.encryption = EncryptionLevel::from_u8(data[data_offset])?;
417 }
418 PreLoginOption::Mars if length >= 1 => {
419 prelogin.mars = data[data_offset] != 0;
420 }
421 PreLoginOption::Instance if length > 0 => {
422 let instance_data = &data[data_offset..data_offset + length];
424 if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
425 if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
426 if !s.is_empty() {
427 prelogin.instance = Some(s.to_string());
428 }
429 }
430 }
431 }
432 PreLoginOption::ThreadId if length >= 4 => {
433 let bytes = &data[data_offset..data_offset + 4];
434 prelogin.thread_id =
435 Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
436 }
437 PreLoginOption::FedAuthRequired if length >= 1 => {
438 prelogin.fed_auth_required = data[data_offset] != 0;
439 }
440 PreLoginOption::Nonce if length >= 32 => {
441 let mut nonce = [0u8; 32];
442 nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
443 prelogin.nonce = Some(nonce);
444 }
445 _ => {}
446 }
447 }
448
449 Ok(prelogin)
450 }
451}
452
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();

        // Option table for the three mandatory options, then the terminator.
        // Offsets are measured from the start of the payload.
        let expected_table: [u8; 16] = [
            0x00, 0x00, 0x10, 0x00, 0x06, // VERSION: offset 16, length 6
            0x01, 0x00, 0x16, 0x00, 0x01, // ENCRYPTION: offset 22, length 1
            0x04, 0x00, 0x17, 0x00, 0x01, // MARS: offset 23, length 1
            0xFF,
        ];
        assert_eq!(encoded.len(), 24);
        assert_eq!(&encoded[..16], &expected_table[..]);
        // After the 4 version bytes: sub-build (0), encryption byte, MARS flag.
        assert_eq!(
            &encoded[20..],
            &[0x00, 0x00, EncryptionLevel::Required as u8, 0x00][..]
        );
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    #[test]
    fn test_encryption_level_from_u8_roundtrip() {
        for level in [
            EncryptionLevel::Off,
            EncryptionLevel::On,
            EncryptionLevel::NotSupported,
            EncryptionLevel::Required,
            EncryptionLevel::ClientCertAuth,
        ] {
            assert_eq!(EncryptionLevel::from_u8(level as u8).unwrap(), level);
        }
    }

    #[test]
    fn test_prelogin_option_from_u8() {
        assert_eq!(PreLoginOption::from_u8(0x04).unwrap(), PreLoginOption::Mars);
        assert_eq!(
            PreLoginOption::from_u8(0xFF).unwrap(),
            PreLoginOption::Terminator
        );
        assert!(matches!(
            PreLoginOption::from_u8(0x08),
            Err(ProtocolError::InvalidPreloginOption(0x08))
        ));
    }

    #[test]
    fn test_prelogin_fed_auth_required_roundtrip() {
        let without = PreLogin::new().encode();
        let decoded = PreLogin::decode(without.as_ref()).unwrap();
        assert!(
            !decoded.fed_auth_required,
            "FEDAUTHREQUIRED must default to absent/false"
        );

        let with = PreLogin::new().with_fed_auth_required(true).encode();
        let header_end = with.iter().position(|&b| b == 0xFF).unwrap();
        assert!(
            with[..header_end]
                .chunks(5)
                .any(|opt| opt[0] == PreLoginOption::FedAuthRequired as u8),
            "encoded PreLogin must contain a FEDAUTHREQUIRED option header"
        );
        let decoded = PreLogin::decode(with.as_ref()).unwrap();
        assert!(decoded.fed_auth_required);
    }

    #[test]
    fn test_prelogin_optional_fields_roundtrip() {
        let mut original = PreLogin::new().with_instance("MSSQLSERVER");
        original.thread_id = Some(0xDEAD_BEEF);
        original.nonce = Some([0xA5; 32]);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("MSSQLSERVER"));
        assert_eq!(decoded.thread_id, Some(0xDEAD_BEEF));
        assert_eq!(decoded.nonce, Some([0xA5; 32]));
    }

    #[test]
    fn test_prelogin_decode_rejects_unknown_encryption_byte() {
        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();
        // One ENCRYPTION entry (5 bytes) + terminator => data starts at 6.
        let header_size: u16 = 6;
        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size);
        buf.put_u16(1);
        buf.put_u8(PreLoginOption::Terminator as u8);
        buf.put_u8(0x42);

        let result = PreLogin::decode(buf.freeze().as_ref());
        assert!(
            matches!(result, Err(ProtocolError::InvalidEncryptionLevel(0x42))),
            "an unknown encryption byte must be rejected as \
             InvalidEncryptionLevel(0x42), not read as Off; got {result:?}"
        );
    }

    #[test]
    fn test_prelogin_decode_rejects_unknown_option_token() {
        let payload: &[u8] = &[0x42, 0x00, 0x06, 0x00, 0x00, 0xFF];
        assert!(matches!(
            PreLogin::decode(payload),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    #[test]
    fn test_prelogin_decode_truncated_table() {
        let empty: &[u8] = &[];
        assert!(matches!(
            PreLogin::decode(empty),
            Err(ProtocolError::UnexpectedEof)
        ));

        // Token plus a single offset byte, and no terminator.
        let truncated: &[u8] = &[0x00, 0x00];
        assert!(matches!(
            PreLogin::decode(truncated),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_prelogin_decode_skips_out_of_range_option() {
        // MARS entry whose offset (64) points past the end of the payload.
        let payload: &[u8] = &[0x04, 0x00, 0x40, 0x00, 0x01, 0xFF, 0x01];
        let decoded = PreLogin::decode(payload).unwrap();
        assert!(!decoded.mars);
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        use bytes::BufMut;

        // Options listed out of data order: ENCRYPTION's data comes first even
        // though decoding must locate each option purely by its offset.
        let mut buf = bytes::BytesMut::new();
        // Two entries (10 bytes) + terminator => data starts at 11.
        let header_size: u16 = 11;

        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size);
        buf.put_u16(1);
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1);
        buf.put_u16(6);
        buf.put_u8(PreLoginOption::Terminator as u8);

        // ENCRYPTION data.
        buf.put_u8(0x01);

        // VERSION data: 0x74000004 followed by a zero sub-build.
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000);

        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}