1use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
/// Option tokens that appear in the pre-login header table.
///
/// Each discriminant is the token byte written to, and read from, the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PreLoginOption {
    /// Version option; encoded here as 6 bytes (4-byte version + 2-byte sub-build).
    Version = 0x00,
    /// Encryption option; one byte holding an [`EncryptionLevel`].
    Encryption = 0x01,
    /// Instance-name option; the name is sent NUL-terminated.
    Instance = 0x02,
    /// Client thread id option (4 bytes).
    ThreadId = 0x03,
    /// MARS option (1 byte; non-zero means enabled when decoding).
    Mars = 0x04,
    /// Trace id option (36 bytes).
    TraceId = 0x05,
    /// Federated-authentication-required option (1 byte).
    FedAuthRequired = 0x06,
    /// Nonce option (32 bytes).
    Nonce = 0x07,
    /// Ends the header table; no offset/length follows this token.
    Terminator = 0xFF,
}
39
40impl PreLoginOption {
41 pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
43 match value {
44 0x00 => Ok(Self::Version),
45 0x01 => Ok(Self::Encryption),
46 0x02 => Ok(Self::Instance),
47 0x03 => Ok(Self::ThreadId),
48 0x04 => Ok(Self::Mars),
49 0x05 => Ok(Self::TraceId),
50 0x06 => Ok(Self::FedAuthRequired),
51 0x07 => Ok(Self::Nonce),
52 0xFF => Ok(Self::Terminator),
53 _ => Err(ProtocolError::InvalidPreloginOption(value)),
54 }
55 }
56}
57
/// Encryption setting carried in the ENCRYPTION pre-login option.
///
/// Each discriminant is the byte sent on the wire. The default is
/// [`EncryptionLevel::Required`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum EncryptionLevel {
    /// Wire value `0x00`.
    Off = 0x00,
    /// Wire value `0x01`.
    On = 0x01,
    /// Wire value `0x02`.
    NotSupported = 0x02,
    /// Wire value `0x03`; the `Default` level.
    #[default]
    Required = 0x03,
    /// Wire value `0x80`.
    ClientCertAuth = 0x80,
}
74
75impl EncryptionLevel {
76 pub fn from_u8(value: u8) -> Self {
78 match value {
79 0x00 => Self::Off,
80 0x01 => Self::On,
81 0x02 => Self::NotSupported,
82 0x03 => Self::Required,
83 0x80 => Self::ClientCertAuth,
84 _ => Self::Off,
85 }
86 }
87
88 #[must_use]
90 pub const fn is_required(&self) -> bool {
91 matches!(self, Self::On | Self::Required | Self::ClientCertAuth)
92 }
93}
94
/// A pre-login message: the set of options exchanged before the login packet.
///
/// Build a request with [`PreLogin::new`] and the `with_*` methods, serialise it
/// with [`PreLogin::encode`], and parse a received payload with [`PreLogin::decode`].
#[derive(Debug, Clone, Default)]
pub struct PreLogin {
    /// Version written into the VERSION option; when decoding, built from the
    /// received 4-byte version value.
    pub version: TdsVersion,

    /// Version and sub-build parsed from a received VERSION option (`None` if
    /// none was decoded). `encode` does not read this field.
    pub server_version: Option<SqlServerVersion>,

    /// Sub-build number written after the version in the VERSION option;
    /// `decode` keeps it in sync with `server_version`.
    #[deprecated(since = "0.5.2", note = "Use server_version.sub_build instead")]
    pub sub_build: u16,

    /// Level carried in the ENCRYPTION option.
    pub encryption: EncryptionLevel,
    /// Instance name, sent NUL-terminated; when decoding, only a non-empty,
    /// valid UTF-8 name before the first NUL is kept.
    pub instance: Option<String>,
    /// Thread id, sent as a big-endian `u32` when `Some`.
    pub thread_id: Option<u32>,
    /// MARS flag.
    pub mars: bool,
    /// Trace id, sent when `Some`; `decode` never populates it.
    pub trace_id: Option<TraceId>,
    /// Federated-auth-required flag; only encoded (as `0x01`) when `true`.
    pub fed_auth_required: bool,
    /// 32-byte nonce, sent when `Some`.
    pub nonce: Option<[u8; 32]>,
}
139
/// Payload for the TRACEID pre-login option.
#[derive(Debug, Clone, Copy)]
pub struct TraceId {
    /// 16-byte activity identifier.
    pub activity_id: [u8; 16],
    /// Activity sequence number, written little-endian.
    pub activity_sequence: u32,
}
148
149impl PreLogin {
150 #[must_use]
152 #[allow(deprecated)] pub fn new() -> Self {
154 Self {
155 version: TdsVersion::V7_4,
156 server_version: None,
157 sub_build: 0,
158 encryption: EncryptionLevel::Required,
159 instance: None,
160 thread_id: None,
161 mars: false,
162 trace_id: None,
163 fed_auth_required: false,
164 nonce: None,
165 }
166 }
167
168 #[must_use]
170 pub fn with_version(mut self, version: TdsVersion) -> Self {
171 self.version = version;
172 self
173 }
174
175 #[must_use]
177 pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
178 self.encryption = level;
179 self
180 }
181
182 #[must_use]
184 pub fn with_mars(mut self, enabled: bool) -> Self {
185 self.mars = enabled;
186 self
187 }
188
189 #[must_use]
191 pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
192 self.instance = Some(instance.into());
193 self
194 }
195
196 #[must_use]
198 #[allow(deprecated)] pub fn encode(&self) -> Bytes {
200 let mut buf = BytesMut::with_capacity(256);
201
202 let mut option_count = 3; if self.instance.is_some() {
207 option_count += 1;
208 }
209 if self.thread_id.is_some() {
210 option_count += 1;
211 }
212 if self.trace_id.is_some() {
213 option_count += 1;
214 }
215 if self.fed_auth_required {
216 option_count += 1;
217 }
218 if self.nonce.is_some() {
219 option_count += 1;
220 }
221
222 let header_size = option_count * 5 + 1; let mut data_offset = header_size as u16;
224 let mut data_buf = BytesMut::new();
225
226 buf.put_u8(PreLoginOption::Version as u8);
228 buf.put_u16(data_offset);
229 buf.put_u16(6);
230 let version_raw = self.version.raw();
231 data_buf.put_u8((version_raw >> 24) as u8);
232 data_buf.put_u8((version_raw >> 16) as u8);
233 data_buf.put_u8((version_raw >> 8) as u8);
234 data_buf.put_u8(version_raw as u8);
235 data_buf.put_u16_le(self.sub_build);
236 data_offset += 6;
237
238 buf.put_u8(PreLoginOption::Encryption as u8);
240 buf.put_u16(data_offset);
241 buf.put_u16(1);
242 data_buf.put_u8(self.encryption as u8);
243 data_offset += 1;
244
245 if let Some(ref instance) = self.instance {
247 let instance_bytes = instance.as_bytes();
248 let len = instance_bytes.len() as u16 + 1; buf.put_u8(PreLoginOption::Instance as u8);
250 buf.put_u16(data_offset);
251 buf.put_u16(len);
252 data_buf.put_slice(instance_bytes);
253 data_buf.put_u8(0); data_offset += len;
255 }
256
257 if let Some(thread_id) = self.thread_id {
259 buf.put_u8(PreLoginOption::ThreadId as u8);
260 buf.put_u16(data_offset);
261 buf.put_u16(4);
262 data_buf.put_u32(thread_id);
263 data_offset += 4;
264 }
265
266 buf.put_u8(PreLoginOption::Mars as u8);
268 buf.put_u16(data_offset);
269 buf.put_u16(1);
270 data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
271 data_offset += 1;
272
273 if let Some(ref trace_id) = self.trace_id {
275 buf.put_u8(PreLoginOption::TraceId as u8);
276 buf.put_u16(data_offset);
277 buf.put_u16(36);
278 data_buf.put_slice(&trace_id.activity_id);
279 data_buf.put_u32_le(trace_id.activity_sequence);
280 data_buf.put_slice(&[0u8; 16]);
282 data_offset += 36;
283 }
284
285 if self.fed_auth_required {
287 buf.put_u8(PreLoginOption::FedAuthRequired as u8);
288 buf.put_u16(data_offset);
289 buf.put_u16(1);
290 data_buf.put_u8(0x01);
291 data_offset += 1;
292 }
293
294 if let Some(ref nonce) = self.nonce {
296 buf.put_u8(PreLoginOption::Nonce as u8);
297 buf.put_u16(data_offset);
298 buf.put_u16(32);
299 data_buf.put_slice(nonce);
300 let _ = data_offset; }
302
303 buf.put_u8(PreLoginOption::Terminator as u8);
305
306 buf.put_slice(&data_buf);
308
309 buf.freeze()
310 }
311
312 pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
321 let mut prelogin = Self::default();
322
323 let mut options = Vec::new();
325 loop {
326 if src.remaining() < 1 {
327 return Err(ProtocolError::UnexpectedEof);
328 }
329
330 let option_type = src.get_u8();
331 if option_type == PreLoginOption::Terminator as u8 {
332 break;
333 }
334
335 if src.remaining() < 4 {
336 return Err(ProtocolError::UnexpectedEof);
337 }
338
339 let offset = src.get_u16();
340 let length = src.get_u16();
341 options.push((PreLoginOption::from_u8(option_type)?, offset, length));
342 }
343
344 let data = src.copy_to_bytes(src.remaining());
346
347 let header_size = options.len() * 5 + 1;
349
350 for (option, packet_offset, length) in options {
351 let packet_offset = packet_offset as usize;
352 let length = length as usize;
353
354 if packet_offset < header_size {
357 continue;
359 }
360 let data_offset = packet_offset - header_size;
361
362 if data_offset + length > data.len() {
364 continue;
365 }
366
367 #[allow(deprecated)] match option {
369 PreLoginOption::Version if length >= 4 => {
370 let version_bytes = &data[data_offset..data_offset + 4];
378 let version_raw = u32::from_be_bytes([
379 version_bytes[0],
380 version_bytes[1],
381 version_bytes[2],
382 version_bytes[3],
383 ]);
384
385 let sub_build = if length >= 6 {
387 let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
388 u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
389 } else {
390 0
391 };
392
393 prelogin.server_version =
395 Some(SqlServerVersion::from_raw(version_raw, sub_build));
396
397 prelogin.version = TdsVersion::new(version_raw);
399 prelogin.sub_build = sub_build;
400 }
401 PreLoginOption::Encryption if length >= 1 => {
402 prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
403 }
404 PreLoginOption::Mars if length >= 1 => {
405 prelogin.mars = data[data_offset] != 0;
406 }
407 PreLoginOption::Instance if length > 0 => {
408 let instance_data = &data[data_offset..data_offset + length];
410 if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
411 if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
412 if !s.is_empty() {
413 prelogin.instance = Some(s.to_string());
414 }
415 }
416 }
417 }
418 PreLoginOption::ThreadId if length >= 4 => {
419 let bytes = &data[data_offset..data_offset + 4];
420 prelogin.thread_id =
421 Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
422 }
423 PreLoginOption::FedAuthRequired if length >= 1 => {
424 prelogin.fed_auth_required = data[data_offset] != 0;
425 }
426 PreLoginOption::Nonce if length >= 32 => {
427 let mut nonce = [0u8; 32];
428 nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
429 prelogin.nonce = Some(nonce);
430 }
431 _ => {}
432 }
433 }
434
435 Ok(prelogin)
436 }
437}
438
#[cfg(test)]
#[allow(clippy::unwrap_used)] // A failing unwrap is the desired test failure.
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let encoded = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required)
            .encode();

        // The header table opens with the VERSION entry.
        assert!(!encoded.is_empty());
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
    }

    #[test]
    fn test_encryption_level() {
        for level in [EncryptionLevel::Required, EncryptionLevel::On] {
            assert!(level.is_required());
        }
        for level in [EncryptionLevel::Off, EncryptionLevel::NotSupported] {
            assert!(!level.is_required());
        }
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        let bytes = original.encode();
        let decoded = PreLogin::decode(bytes.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        // ENCRYPTION is listed before VERSION so the decoder must locate data
        // through each header's offset rather than by position.
        // Header table: 2 entries * 5 bytes + terminator = 11 bytes.
        let packet: [u8; 18] = [
            // ENCRYPTION header: offset 11, length 1
            0x01, 0x00, 0x0B, 0x00, 0x01,
            // VERSION header: offset 12, length 6
            0x00, 0x00, 0x0C, 0x00, 0x06,
            // terminator
            0xFF,
            // ENCRYPTION data: On
            0x01,
            // VERSION data: bytes for TdsVersion::V7_4, then sub-build 0 (LE)
            0x74, 0x00, 0x00, 0x04, 0x00, 0x00,
        ];

        let decoded = PreLogin::decode(&packet[..]).unwrap();

        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}