tds_protocol/
prelogin.rs

1//! TDS pre-login packet handling.
2//!
3//! The pre-login packet is the first message exchanged between client and server
4//! in TDS 7.x connections. It negotiates protocol version, encryption, and other
5//! connection parameters.
6//!
7//! Note: TDS 8.0 (strict mode) does not use pre-login negotiation; TLS is
8//! established before any TDS traffic.
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::TdsVersion;
15
/// Option tokens that can appear in a pre-login option header.
///
/// The discriminant of each variant is the token byte used on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PreLoginOption {
    /// Version information (token `0x00`).
    Version = 0x00,
    /// Encryption negotiation (token `0x01`).
    Encryption = 0x01,
    /// Instance name, used with named instances (token `0x02`).
    Instance = 0x02,
    /// Thread ID (token `0x03`).
    ThreadId = 0x03,
    /// MARS (Multiple Active Result Sets) support (token `0x04`).
    Mars = 0x04,
    /// Trace ID for distributed tracing (token `0x05`).
    TraceId = 0x05,
    /// Federated authentication required (token `0x06`).
    FedAuthRequired = 0x06,
    /// Nonce for encryption (token `0x07`).
    Nonce = 0x07,
    /// Terminator marking the end of the option list (token `0xFF`).
    Terminator = 0xFF,
}
39
40impl PreLoginOption {
41    /// Create from raw byte value.
42    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
43        match value {
44            0x00 => Ok(Self::Version),
45            0x01 => Ok(Self::Encryption),
46            0x02 => Ok(Self::Instance),
47            0x03 => Ok(Self::ThreadId),
48            0x04 => Ok(Self::Mars),
49            0x05 => Ok(Self::TraceId),
50            0x06 => Ok(Self::FedAuthRequired),
51            0x07 => Ok(Self::Nonce),
52            0xFF => Ok(Self::Terminator),
53            _ => Err(ProtocolError::InvalidPreloginOption(value)),
54        }
55    }
56}
57
/// Encryption setting exchanged in the pre-login ENCRYPTION option.
///
/// The discriminant of each variant is the byte value carried in the option
/// payload. [`EncryptionLevel::Required`] is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum EncryptionLevel {
    /// Encryption is off (`0x00`).
    Off = 0x00,
    /// Encryption is on (`0x01`).
    On = 0x01,
    /// Encryption is not supported (`0x02`).
    NotSupported = 0x02,
    /// Encryption is required (`0x03`).
    #[default]
    Required = 0x03,
    /// Client certificate authentication, TDS 8.0+ (`0x80`).
    ClientCertAuth = 0x80,
}
74
75impl EncryptionLevel {
76    /// Create from raw byte value.
77    pub fn from_u8(value: u8) -> Self {
78        match value {
79            0x00 => Self::Off,
80            0x01 => Self::On,
81            0x02 => Self::NotSupported,
82            0x03 => Self::Required,
83            0x80 => Self::ClientCertAuth,
84            _ => Self::Off,
85        }
86    }
87
88    /// Check if encryption is required.
89    #[must_use]
90    pub const fn is_required(&self) -> bool {
91        matches!(self, Self::On | Self::Required | Self::ClientCertAuth)
92    }
93}
94
95/// Pre-login message builder and parser.
96#[derive(Debug, Clone, Default)]
97pub struct PreLogin {
98    /// TDS version.
99    pub version: TdsVersion,
100    /// Sub-build version.
101    pub sub_build: u16,
102    /// Encryption level.
103    pub encryption: EncryptionLevel,
104    /// Instance name (for named instances).
105    pub instance: Option<String>,
106    /// Thread ID.
107    pub thread_id: Option<u32>,
108    /// MARS enabled.
109    pub mars: bool,
110    /// Trace ID (Activity ID and Sequence).
111    pub trace_id: Option<TraceId>,
112    /// Federated authentication required.
113    pub fed_auth_required: bool,
114    /// Nonce for encryption.
115    pub nonce: Option<[u8; 32]>,
116}
117
/// Identifier pair used for distributed tracing in the pre-login TRACEID option.
#[derive(Debug, Clone, Copy)]
pub struct TraceId {
    /// Activity ID, stored as the 16 raw bytes of a GUID.
    pub activity_id: [u8; 16],
    /// Sequence number of the activity.
    pub activity_sequence: u32,
}
126
127impl PreLogin {
128    /// Create a new pre-login message with default values.
129    #[must_use]
130    pub fn new() -> Self {
131        Self {
132            version: TdsVersion::V7_4,
133            sub_build: 0,
134            encryption: EncryptionLevel::Required,
135            instance: None,
136            thread_id: None,
137            mars: false,
138            trace_id: None,
139            fed_auth_required: false,
140            nonce: None,
141        }
142    }
143
144    /// Set the TDS version.
145    #[must_use]
146    pub fn with_version(mut self, version: TdsVersion) -> Self {
147        self.version = version;
148        self
149    }
150
151    /// Set the encryption level.
152    #[must_use]
153    pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
154        self.encryption = level;
155        self
156    }
157
158    /// Enable MARS.
159    #[must_use]
160    pub fn with_mars(mut self, enabled: bool) -> Self {
161        self.mars = enabled;
162        self
163    }
164
165    /// Set the instance name.
166    #[must_use]
167    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
168        self.instance = Some(instance.into());
169        self
170    }
171
172    /// Encode the pre-login message to bytes.
173    #[must_use]
174    pub fn encode(&self) -> Bytes {
175        let mut buf = BytesMut::with_capacity(256);
176
177        // Calculate option data offsets
178        // Each option entry is 5 bytes: type (1) + offset (2) + length (2)
179        // Plus 1 byte for terminator
180        let mut option_count = 3; // Version, Encryption, MARS are always present
181        if self.instance.is_some() {
182            option_count += 1;
183        }
184        if self.thread_id.is_some() {
185            option_count += 1;
186        }
187        if self.trace_id.is_some() {
188            option_count += 1;
189        }
190        if self.fed_auth_required {
191            option_count += 1;
192        }
193        if self.nonce.is_some() {
194            option_count += 1;
195        }
196
197        let header_size = option_count * 5 + 1; // +1 for terminator
198        let mut data_offset = header_size as u16;
199        let mut data_buf = BytesMut::new();
200
201        // VERSION option (6 bytes: 4 bytes version + 2 bytes sub-build)
202        buf.put_u8(PreLoginOption::Version as u8);
203        buf.put_u16(data_offset);
204        buf.put_u16(6);
205        let version_raw = self.version.raw();
206        data_buf.put_u8((version_raw >> 24) as u8);
207        data_buf.put_u8((version_raw >> 16) as u8);
208        data_buf.put_u8((version_raw >> 8) as u8);
209        data_buf.put_u8(version_raw as u8);
210        data_buf.put_u16_le(self.sub_build);
211        data_offset += 6;
212
213        // ENCRYPTION option (1 byte)
214        buf.put_u8(PreLoginOption::Encryption as u8);
215        buf.put_u16(data_offset);
216        buf.put_u16(1);
217        data_buf.put_u8(self.encryption as u8);
218        data_offset += 1;
219
220        // INSTANCE option (if set)
221        if let Some(ref instance) = self.instance {
222            let instance_bytes = instance.as_bytes();
223            let len = instance_bytes.len() as u16 + 1; // +1 for null terminator
224            buf.put_u8(PreLoginOption::Instance as u8);
225            buf.put_u16(data_offset);
226            buf.put_u16(len);
227            data_buf.put_slice(instance_bytes);
228            data_buf.put_u8(0); // null terminator
229            data_offset += len;
230        }
231
232        // THREADID option (if set)
233        if let Some(thread_id) = self.thread_id {
234            buf.put_u8(PreLoginOption::ThreadId as u8);
235            buf.put_u16(data_offset);
236            buf.put_u16(4);
237            data_buf.put_u32(thread_id);
238            data_offset += 4;
239        }
240
241        // MARS option (1 byte)
242        buf.put_u8(PreLoginOption::Mars as u8);
243        buf.put_u16(data_offset);
244        buf.put_u16(1);
245        data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
246        data_offset += 1;
247
248        // TRACEID option (if set)
249        if let Some(ref trace_id) = self.trace_id {
250            buf.put_u8(PreLoginOption::TraceId as u8);
251            buf.put_u16(data_offset);
252            buf.put_u16(36);
253            data_buf.put_slice(&trace_id.activity_id);
254            data_buf.put_u32_le(trace_id.activity_sequence);
255            // Connection ID (16 bytes, typically zeros for client)
256            data_buf.put_slice(&[0u8; 16]);
257            data_offset += 36;
258        }
259
260        // FEDAUTHREQUIRED option (if set)
261        if self.fed_auth_required {
262            buf.put_u8(PreLoginOption::FedAuthRequired as u8);
263            buf.put_u16(data_offset);
264            buf.put_u16(1);
265            data_buf.put_u8(0x01);
266            data_offset += 1;
267        }
268
269        // NONCE option (if set)
270        if let Some(ref nonce) = self.nonce {
271            buf.put_u8(PreLoginOption::Nonce as u8);
272            buf.put_u16(data_offset);
273            buf.put_u16(32);
274            data_buf.put_slice(nonce);
275            let _ = data_offset; // Suppress unused warning
276        }
277
278        // Terminator
279        buf.put_u8(PreLoginOption::Terminator as u8);
280
281        // Append data section
282        buf.put_slice(&data_buf);
283
284        buf.freeze()
285    }
286
287    /// Decode a pre-login response from the server.
288    ///
289    /// Per MS-TDS spec 2.2.6.4, PreLogin message structure:
290    /// - Option headers: each 5 bytes (type:1 + offset:2 + length:2)
291    /// - Terminator: 1 byte (0xFF)
292    /// - Option data: variable length, positioned at offsets specified in headers
293    ///
294    /// Offsets in headers are absolute from the start of the PreLogin packet payload.
295    pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
296        let mut prelogin = Self::default();
297
298        // Parse option headers first, collecting (option_type, offset, length)
299        let mut options = Vec::new();
300        loop {
301            if src.remaining() < 1 {
302                return Err(ProtocolError::UnexpectedEof);
303            }
304
305            let option_type = src.get_u8();
306            if option_type == PreLoginOption::Terminator as u8 {
307                break;
308            }
309
310            if src.remaining() < 4 {
311                return Err(ProtocolError::UnexpectedEof);
312            }
313
314            let offset = src.get_u16();
315            let length = src.get_u16();
316            options.push((PreLoginOption::from_u8(option_type)?, offset, length));
317        }
318
319        // Get remaining data as bytes for random access
320        let data = src.copy_to_bytes(src.remaining());
321
322        // Calculate header size: each option is 5 bytes + 1 byte terminator
323        let header_size = options.len() * 5 + 1;
324
325        for (option, packet_offset, length) in options {
326            let packet_offset = packet_offset as usize;
327            let length = length as usize;
328
329            // Convert absolute packet offset to offset within data buffer
330            // The data buffer starts after the headers, so we subtract header_size
331            if packet_offset < header_size {
332                // Invalid: offset points inside the headers
333                continue;
334            }
335            let data_offset = packet_offset - header_size;
336
337            // Bounds check
338            if data_offset + length > data.len() {
339                continue;
340            }
341
342            match option {
343                PreLoginOption::Version if length >= 6 => {
344                    // Per MS-TDS: UL_VERSION is 4 bytes big-endian, US_SUBBUILD is 2 bytes little-endian
345                    let version_bytes = &data[data_offset..data_offset + 4];
346                    let version_raw = u32::from_be_bytes([
347                        version_bytes[0],
348                        version_bytes[1],
349                        version_bytes[2],
350                        version_bytes[3],
351                    ]);
352                    prelogin.version = TdsVersion::new(version_raw);
353
354                    if length >= 6 {
355                        let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
356                        prelogin.sub_build =
357                            u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]]);
358                    }
359                }
360                PreLoginOption::Encryption if length >= 1 => {
361                    prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
362                }
363                PreLoginOption::Mars if length >= 1 => {
364                    prelogin.mars = data[data_offset] != 0;
365                }
366                PreLoginOption::Instance if length > 0 => {
367                    // Instance name is null-terminated string
368                    let instance_data = &data[data_offset..data_offset + length];
369                    if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
370                        if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
371                            if !s.is_empty() {
372                                prelogin.instance = Some(s.to_string());
373                            }
374                        }
375                    }
376                }
377                PreLoginOption::ThreadId if length >= 4 => {
378                    let bytes = &data[data_offset..data_offset + 4];
379                    prelogin.thread_id =
380                        Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
381                }
382                PreLoginOption::FedAuthRequired if length >= 1 => {
383                    prelogin.fed_auth_required = data[data_offset] != 0;
384                }
385                PreLoginOption::Nonce if length >= 32 => {
386                    let mut nonce = [0u8; 32];
387                    nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
388                    prelogin.nonce = Some(nonce);
389                }
390                _ => {}
391            }
392        }
393
394        Ok(prelogin)
395    }
396}
397
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Tests fail loudly on unexpected errors by design.
mod tests {
    use super::*;

    /// Encoding always produces output that starts with the VERSION header.
    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        // First byte should be VERSION option type
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
    }

    /// `is_required` is true exactly for the levels that result in encryption.
    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    /// Known bytes map to their level; unknown bytes fall back to `Off`.
    #[test]
    fn test_encryption_level_from_u8() {
        assert_eq!(EncryptionLevel::from_u8(0x00), EncryptionLevel::Off);
        assert_eq!(EncryptionLevel::from_u8(0x03), EncryptionLevel::Required);
        assert_eq!(
            EncryptionLevel::from_u8(0x80),
            EncryptionLevel::ClientCertAuth
        );
        assert_eq!(EncryptionLevel::from_u8(0x7F), EncryptionLevel::Off);
    }

    /// Known tokens map to their option; unknown tokens are rejected.
    #[test]
    fn test_prelogin_option_from_u8() {
        assert_eq!(
            PreLoginOption::from_u8(0x04).unwrap(),
            PreLoginOption::Mars
        );
        assert_eq!(
            PreLoginOption::from_u8(0xFF).unwrap(),
            PreLoginOption::Terminator
        );
        assert!(matches!(
            PreLoginOption::from_u8(0x42),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        // Create a PreLogin with various options
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        // Encode it
        let encoded = original.encode();

        // Decode it back
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        // Verify the critical fields match
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    /// Optional options shift every later offset; all of them must survive a round trip.
    #[test]
    fn test_prelogin_decode_roundtrip_optional_fields() {
        let mut original = PreLogin::new()
            .with_encryption(EncryptionLevel::Off)
            .with_instance("SQLEXPRESS")
            .with_mars(true);
        original.sub_build = 0x1234;
        original.thread_id = Some(0xDEAD_BEEF);
        original.fed_auth_required = true;
        original.nonce = Some([0xAB; 32]);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.sub_build, 0x1234);
        assert_eq!(decoded.encryption, EncryptionLevel::Off);
        assert_eq!(decoded.instance.as_deref(), Some("SQLEXPRESS"));
        assert_eq!(decoded.thread_id, Some(0xDEAD_BEEF));
        assert!(decoded.mars);
        assert!(decoded.fed_auth_required);
        assert_eq!(decoded.nonce, Some([0xAB; 32]));
    }

    /// Input that ends before the terminator is reported as `UnexpectedEof`.
    #[test]
    fn test_prelogin_decode_truncated() {
        // No bytes at all.
        let empty: &[u8] = &[];
        assert!(matches!(
            PreLogin::decode(empty),
            Err(ProtocolError::UnexpectedEof)
        ));

        // Option type followed by an incomplete offset/length pair.
        let partial_header: &[u8] = &[0x00, 0x00, 0x06];
        assert!(matches!(
            PreLogin::decode(partial_header),
            Err(ProtocolError::UnexpectedEof)
        ));

        // One complete header but no terminator.
        let no_terminator: &[u8] = &[0x00, 0x00, 0x06, 0x00, 0x06];
        assert!(matches!(
            PreLogin::decode(no_terminator),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    /// Options whose data lies outside the packet are skipped, not treated as errors.
    #[test]
    fn test_prelogin_decode_ignores_out_of_range_option() {
        // MARS header claims data at offset 500 in a 7-byte packet.
        let packet: &[u8] = &[0x04, 0x01, 0xF4, 0x00, 0x01, 0xFF, 0x01];
        let decoded = PreLogin::decode(packet).unwrap();
        assert!(!decoded.mars);

        // MARS header whose offset points back into the header table.
        let packet: &[u8] = &[0x04, 0x00, 0x00, 0x00, 0x01, 0xFF, 0x01];
        let decoded = PreLogin::decode(packet).unwrap();
        assert!(!decoded.mars);
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        // Manually construct a PreLogin packet with options in non-standard order
        // to verify offset handling works correctly
        //
        // Structure:
        // - ENCRYPTION header at offset pointing to encryption data
        // - VERSION header at offset pointing to version data
        // - Terminator
        // - Data section

        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();

        // Header section: each option is 5 bytes (type:1 + offset:2 + length:2)
        // We'll have 2 options + terminator = 11 bytes header
        let header_size: u16 = 11;

        // ENCRYPTION option header (put this first to test that we read from correct offset)
        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size); // offset to encryption data
        buf.put_u16(1); // length

        // VERSION option header
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1); // offset to version data (after encryption)
        buf.put_u16(6); // length

        // Terminator
        buf.put_u8(PreLoginOption::Terminator as u8);

        // Data section
        // Encryption data (1 byte): ENCRYPT_ON = 0x01
        buf.put_u8(0x01);

        // Version data (6 bytes): TDS 7.4 = 0x74000004 big-endian + sub-build 0x0000 little-endian
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000); // sub-build

        // Decode
        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        // Verify encryption was read from correct offset (not from index 0)
        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}