Skip to main content

tds_protocol/
prelogin.rs

1//! TDS pre-login packet handling.
2//!
3//! The pre-login packet is the first message exchanged between client and server
4//! in TDS 7.x connections. It negotiates protocol version, encryption, and other
5//! connection parameters.
6//!
7//! Note: TDS 8.0 (strict mode) does not use pre-login negotiation; TLS is
8//! established before any TDS traffic.
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
/// Option token identifying one entry in the pre-login option table.
///
/// Each table entry starts with one of these bytes and the table is closed by
/// [`PreLoginOption::Terminator`]. The discriminant of every variant is the
/// byte that appears on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
#[non_exhaustive]
pub enum PreLoginOption {
    /// `VERSION`: client TDS version request / server product version.
    Version = 0x00,
    /// `ENCRYPTION`: encryption negotiation byte (see [`EncryptionLevel`]).
    Encryption = 0x01,
    /// `INSTOPT`: instance name, for named instances.
    Instance = 0x02,
    /// `THREADID`: client thread identifier.
    ThreadId = 0x03,
    /// `MARS`: Multiple Active Result Sets flag.
    Mars = 0x04,
    /// `TRACEID`: distributed-tracing identifier.
    TraceId = 0x05,
    /// `FEDAUTHREQUIRED`: federated authentication flag.
    FedAuthRequired = 0x06,
    /// `NONCEOPT`: 32-byte nonce.
    Nonce = 0x07,
    /// `TERMINATOR`: marks the end of the option table.
    Terminator = 0xFF,
}
40
41impl PreLoginOption {
42    /// Create from raw byte value.
43    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
44        match value {
45            0x00 => Ok(Self::Version),
46            0x01 => Ok(Self::Encryption),
47            0x02 => Ok(Self::Instance),
48            0x03 => Ok(Self::ThreadId),
49            0x04 => Ok(Self::Mars),
50            0x05 => Ok(Self::TraceId),
51            0x06 => Ok(Self::FedAuthRequired),
52            0x07 => Ok(Self::Nonce),
53            0xFF => Ok(Self::Terminator),
54            _ => Err(ProtocolError::InvalidPreloginOption(value)),
55        }
56    }
57}
58
/// Encryption level carried by the pre-login `ENCRYPTION` option.
///
/// The discriminant of each variant is the byte sent on the wire. The
/// [`Default`] level is [`EncryptionLevel::Required`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EncryptionLevel {
    /// Encryption is off.
    Off = 0x00,
    /// Encryption is on.
    On = 0x01,
    /// Encryption is not supported.
    NotSupported = 0x02,
    /// Encryption is required.
    #[default]
    Required = 0x03,
    /// Client certificate authentication (TDS 8.0+).
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Map a raw `ENCRYPTION` byte to its level.
    ///
    /// Bytes that match no variant's discriminant fall back to
    /// [`EncryptionLevel::Off`].
    /// NOTE(review): that fallback is weaker than the default `Required`;
    /// confirm it is the intended handling of unknown server responses.
    pub fn from_u8(value: u8) -> Self {
        [
            Self::Off,
            Self::On,
            Self::NotSupported,
            Self::Required,
            Self::ClientCertAuth,
        ]
        .iter()
        .copied()
        .find(|level| *level as u8 == value)
        .unwrap_or(Self::Off)
    }

    /// Return `true` for the levels that mean the connection is encrypted:
    /// `On`, `Required` and `ClientCertAuth`.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        match self {
            Self::On | Self::Required | Self::ClientCertAuth => true,
            Self::Off | Self::NotSupported => false,
        }
    }
}
96
97/// Pre-login message builder and parser.
98///
99/// This struct is used for both client requests and server responses:
100/// - **Client โ†’ Server**: Set `version` to the requested TDS version
101/// - **Server โ†’ Client**: `server_version` contains the SQL Server product version
102///
103/// Note: The VERSION field has different semantics in each direction:
104/// - Client sends: TDS protocol version (e.g., 7.4)
105/// - Server sends: SQL Server product version (e.g., 13.0.6300 for SQL Server 2016)
106#[derive(Debug, Clone, Default)]
107pub struct PreLogin {
108    /// TDS version (client request).
109    ///
110    /// This is the TDS protocol version the client requests. When sending a
111    /// PreLogin, set this to the desired TDS version.
112    pub version: TdsVersion,
113
114    /// SQL Server product version (server response).
115    ///
116    /// When decoding a PreLogin response from the server, this contains the
117    /// SQL Server product version (e.g., 13.0.6300 for SQL Server 2016).
118    /// This is NOT the TDS version - the actual TDS version is negotiated
119    /// in the LOGINACK token after login.
120    pub server_version: Option<SqlServerVersion>,
121
122    /// Encryption level.
123    pub encryption: EncryptionLevel,
124    /// Instance name (for named instances).
125    pub instance: Option<String>,
126    /// Thread ID.
127    pub thread_id: Option<u32>,
128    /// MARS enabled.
129    pub mars: bool,
130    /// Trace ID (Activity ID and Sequence).
131    pub trace_id: Option<TraceId>,
132    /// Federated authentication required.
133    pub fed_auth_required: bool,
134    /// Nonce for encryption.
135    pub nonce: Option<[u8; 32]>,
136}
137
/// Distributed tracing ID carried in the pre-login TRACEID option.
///
/// [`PreLogin::encode`] writes `activity_id` verbatim, then `activity_sequence`
/// in little-endian order, then 16 zero bytes for the connection ID. The option
/// is not parsed by [`PreLogin::decode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TraceId {
    /// Activity ID (GUID bytes, written as-is).
    pub activity_id: [u8; 16],
    /// Activity sequence number.
    pub activity_sequence: u32,
}
146
147impl PreLogin {
148    /// Create a new pre-login message with default values.
149    #[must_use]
150    pub fn new() -> Self {
151        Self {
152            version: TdsVersion::V7_4,
153            server_version: None,
154            encryption: EncryptionLevel::Required,
155            instance: None,
156            thread_id: None,
157            mars: false,
158            trace_id: None,
159            fed_auth_required: false,
160            nonce: None,
161        }
162    }
163
164    /// Set the TDS version.
165    #[must_use]
166    pub fn with_version(mut self, version: TdsVersion) -> Self {
167        self.version = version;
168        self
169    }
170
171    /// Set the encryption level.
172    #[must_use]
173    pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
174        self.encryption = level;
175        self
176    }
177
178    /// Enable MARS.
179    #[must_use]
180    pub fn with_mars(mut self, enabled: bool) -> Self {
181        self.mars = enabled;
182        self
183    }
184
185    /// Set the instance name.
186    #[must_use]
187    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
188        self.instance = Some(instance.into());
189        self
190    }
191
192    /// Advertise federated authentication support (FEDAUTHREQUIRED option).
193    ///
194    /// When set on a client PreLogin, the encoded message carries the
195    /// FEDAUTHREQUIRED option with value 0x01. The server's response echoes
196    /// its own FEDAUTHREQUIRED value in [`PreLogin::fed_auth_required`]; per
197    /// MS-TDS ยง2.2.6.4 the LOGIN7 FEDAUTH feature extension's `fFedAuthEcho`
198    /// bit MUST mirror that response value.
199    #[must_use]
200    pub fn with_fed_auth_required(mut self, required: bool) -> Self {
201        self.fed_auth_required = required;
202        self
203    }
204
205    /// Encode the pre-login message to bytes.
206    #[must_use]
207    pub fn encode(&self) -> Bytes {
208        let mut buf = BytesMut::with_capacity(256);
209
210        // Calculate option data offsets
211        // Each option entry is 5 bytes: type (1) + offset (2) + length (2)
212        // Plus 1 byte for terminator
213        let mut option_count = 3; // Version, Encryption, MARS are always present
214        if self.instance.is_some() {
215            option_count += 1;
216        }
217        if self.thread_id.is_some() {
218            option_count += 1;
219        }
220        if self.trace_id.is_some() {
221            option_count += 1;
222        }
223        if self.fed_auth_required {
224            option_count += 1;
225        }
226        if self.nonce.is_some() {
227            option_count += 1;
228        }
229
230        let header_size = option_count * 5 + 1; // +1 for terminator
231        let mut data_offset = header_size as u16;
232        let mut data_buf = BytesMut::new();
233
234        // VERSION option (6 bytes: 4 bytes version + 2 bytes sub-build)
235        buf.put_u8(PreLoginOption::Version as u8);
236        buf.put_u16(data_offset);
237        buf.put_u16(6);
238        let version_raw = self.version.raw();
239        data_buf.put_u8((version_raw >> 24) as u8);
240        data_buf.put_u8((version_raw >> 16) as u8);
241        data_buf.put_u8((version_raw >> 8) as u8);
242        data_buf.put_u8(version_raw as u8);
243        // Sub-build is always 0 for client-sent PreLogin; server sub-build
244        // lives in server_version after decode.
245        data_buf.put_u16_le(0);
246        data_offset += 6;
247
248        // ENCRYPTION option (1 byte)
249        buf.put_u8(PreLoginOption::Encryption as u8);
250        buf.put_u16(data_offset);
251        buf.put_u16(1);
252        data_buf.put_u8(self.encryption as u8);
253        data_offset += 1;
254
255        // INSTANCE option (if set)
256        if let Some(ref instance) = self.instance {
257            let instance_bytes = instance.as_bytes();
258            let len = instance_bytes.len() as u16 + 1; // +1 for null terminator
259            buf.put_u8(PreLoginOption::Instance as u8);
260            buf.put_u16(data_offset);
261            buf.put_u16(len);
262            data_buf.put_slice(instance_bytes);
263            data_buf.put_u8(0); // null terminator
264            data_offset += len;
265        }
266
267        // THREADID option (if set)
268        if let Some(thread_id) = self.thread_id {
269            buf.put_u8(PreLoginOption::ThreadId as u8);
270            buf.put_u16(data_offset);
271            buf.put_u16(4);
272            data_buf.put_u32(thread_id);
273            data_offset += 4;
274        }
275
276        // MARS option (1 byte)
277        buf.put_u8(PreLoginOption::Mars as u8);
278        buf.put_u16(data_offset);
279        buf.put_u16(1);
280        data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
281        data_offset += 1;
282
283        // TRACEID option (if set)
284        if let Some(ref trace_id) = self.trace_id {
285            buf.put_u8(PreLoginOption::TraceId as u8);
286            buf.put_u16(data_offset);
287            buf.put_u16(36);
288            data_buf.put_slice(&trace_id.activity_id);
289            data_buf.put_u32_le(trace_id.activity_sequence);
290            // Connection ID (16 bytes, typically zeros for client)
291            data_buf.put_slice(&[0u8; 16]);
292            data_offset += 36;
293        }
294
295        // FEDAUTHREQUIRED option (if set)
296        if self.fed_auth_required {
297            buf.put_u8(PreLoginOption::FedAuthRequired as u8);
298            buf.put_u16(data_offset);
299            buf.put_u16(1);
300            data_buf.put_u8(0x01);
301            data_offset += 1;
302        }
303
304        // NONCE option (if set)
305        if let Some(ref nonce) = self.nonce {
306            buf.put_u8(PreLoginOption::Nonce as u8);
307            buf.put_u16(data_offset);
308            buf.put_u16(32);
309            data_buf.put_slice(nonce);
310            let _ = data_offset; // Suppress unused warning
311        }
312
313        // Terminator
314        buf.put_u8(PreLoginOption::Terminator as u8);
315
316        // Append data section
317        buf.put_slice(&data_buf);
318
319        buf.freeze()
320    }
321
322    /// Decode a pre-login response from the server.
323    ///
324    /// Per MS-TDS spec 2.2.6.4, PreLogin message structure:
325    /// - Option headers: each 5 bytes (type:1 + offset:2 + length:2)
326    /// - Terminator: 1 byte (0xFF)
327    /// - Option data: variable length, positioned at offsets specified in headers
328    ///
329    /// Offsets in headers are absolute from the start of the PreLogin packet payload.
330    pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
331        let mut prelogin = Self::default();
332
333        // Parse option headers first, collecting (option_type, offset, length)
334        let mut options = Vec::new();
335        loop {
336            if src.remaining() < 1 {
337                return Err(ProtocolError::UnexpectedEof);
338            }
339
340            let option_type = src.get_u8();
341            if option_type == PreLoginOption::Terminator as u8 {
342                break;
343            }
344
345            if src.remaining() < 4 {
346                return Err(ProtocolError::UnexpectedEof);
347            }
348
349            let offset = src.get_u16();
350            let length = src.get_u16();
351            options.push((PreLoginOption::from_u8(option_type)?, offset, length));
352        }
353
354        // Get remaining data as bytes for random access
355        let data = src.copy_to_bytes(src.remaining());
356
357        // Calculate header size: each option is 5 bytes + 1 byte terminator
358        let header_size = options.len() * 5 + 1;
359
360        for (option, packet_offset, length) in options {
361            let packet_offset = packet_offset as usize;
362            let length = length as usize;
363
364            // Convert absolute packet offset to offset within data buffer
365            // The data buffer starts after the headers, so we subtract header_size
366            if packet_offset < header_size {
367                // Invalid: offset points inside the headers
368                continue;
369            }
370            let data_offset = packet_offset - header_size;
371
372            // Bounds check
373            if data_offset + length > data.len() {
374                continue;
375            }
376
377            match option {
378                PreLoginOption::Version if length >= 4 => {
379                    // Per MS-TDS 2.2.6.4: The server sends its SQL Server product version
380                    // in the VERSION field, NOT the TDS protocol version.
381                    //
382                    // Format: UL_VERSION (4 bytes big-endian) + US_SUBBUILD (2 bytes little-endian)
383                    // UL_VERSION contains: [major][minor][build_hi][build_lo]
384                    //
385                    // For example, SQL Server 2016 sends 13.0.xxxx (major=13, minor=0)
386                    let version_bytes = &data[data_offset..data_offset + 4];
387                    let version_raw = u32::from_be_bytes([
388                        version_bytes[0],
389                        version_bytes[1],
390                        version_bytes[2],
391                        version_bytes[3],
392                    ]);
393
394                    // Extract sub_build if present
395                    let sub_build = if length >= 6 {
396                        let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
397                        u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
398                    } else {
399                        0
400                    };
401
402                    // Populate the new SqlServerVersion field (correct semantics)
403                    prelogin.server_version =
404                        Some(SqlServerVersion::from_raw(version_raw, sub_build));
405
406                    // Also set version for backward compatibility
407                    prelogin.version = TdsVersion::new(version_raw);
408                }
409                PreLoginOption::Encryption if length >= 1 => {
410                    prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
411                }
412                PreLoginOption::Mars if length >= 1 => {
413                    prelogin.mars = data[data_offset] != 0;
414                }
415                PreLoginOption::Instance if length > 0 => {
416                    // Instance name is null-terminated string
417                    let instance_data = &data[data_offset..data_offset + length];
418                    if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
419                        if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
420                            if !s.is_empty() {
421                                prelogin.instance = Some(s.to_string());
422                            }
423                        }
424                    }
425                }
426                PreLoginOption::ThreadId if length >= 4 => {
427                    let bytes = &data[data_offset..data_offset + 4];
428                    prelogin.thread_id =
429                        Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
430                }
431                PreLoginOption::FedAuthRequired if length >= 1 => {
432                    prelogin.fed_auth_required = data[data_offset] != 0;
433                }
434                PreLoginOption::Nonce if length >= 32 => {
435                    let mut nonce = [0u8; 32];
436                    nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
437                    prelogin.nonce = Some(nonce);
438                }
439                _ => {}
440            }
441        }
442
443        Ok(prelogin)
444    }
445}
446
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        assert!(!encoded.is_empty());
        // First byte should be VERSION option type
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
    }

    /// The default client PreLogin has exactly VERSION, ENCRYPTION and MARS,
    /// laid out with offsets measured from the start of the buffer.
    #[test]
    fn test_prelogin_encode_header_layout() {
        let encoded = PreLogin::new().encode();
        // 3 entries * 5 bytes + 1 terminator = 16 header bytes.
        assert_eq!(
            &encoded[..16],
            &[
                0x00, 0x00, 16, 0x00, 6, // VERSION at 16, length 6
                0x01, 0x00, 22, 0x00, 1, // ENCRYPTION at 22, length 1
                0x04, 0x00, 23, 0x00, 1, // MARS at 23, length 1
                0xFF,
            ]
        );
        assert_eq!(encoded.len(), 16 + 6 + 1 + 1);
        // Sub-build is always zero for a client PreLogin.
        assert_eq!(&encoded[20..22], &[0x00, 0x00]);
        assert_eq!(encoded[22], EncryptionLevel::Required as u8);
        assert_eq!(encoded[23], 0x00);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
    }

    /// FEDAUTHREQUIRED (option 0x06) must be emitted with payload 0x01 when
    /// requested and omitted otherwise, and must survive an encode/decode
    /// round trip — the login path reads the decoded flag back as the
    /// LOGIN7 `fFedAuthEcho` source.
    #[test]
    fn test_prelogin_fed_auth_required_roundtrip() {
        let without = PreLogin::new().encode();
        let decoded = PreLogin::decode(without.as_ref()).unwrap();
        assert!(
            !decoded.fed_auth_required,
            "FEDAUTHREQUIRED must default to absent/false"
        );

        let with = PreLogin::new().with_fed_auth_required(true).encode();
        // Option header present: type 0x06 somewhere in the header section.
        let header_end = with.iter().position(|&b| b == 0xFF).unwrap();
        assert!(
            with[..header_end]
                .chunks(5)
                .any(|opt| opt[0] == PreLoginOption::FedAuthRequired as u8),
            "encoded PreLogin must contain a FEDAUTHREQUIRED option header"
        );
        let decoded = PreLogin::decode(with.as_ref()).unwrap();
        assert!(decoded.fed_auth_required);
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        // Create a PreLogin with various options
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        // Encode it
        let encoded = original.encode();

        // Decode it back
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        // Verify the critical fields match
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    /// Optional options (INSTOPT, THREADID, NONCEOPT) survive a round trip;
    /// TRACEID is emitted but deliberately not decoded.
    #[test]
    fn test_prelogin_optional_fields_roundtrip() {
        let mut original = PreLogin::new().with_instance("SQLEXPRESS").with_mars(true);
        original.thread_id = Some(0x0102_0304);
        original.nonce = Some([0xAB; 32]);
        original.trace_id = Some(TraceId {
            activity_id: [0x11; 16],
            activity_sequence: 9,
        });

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("SQLEXPRESS"));
        assert_eq!(decoded.thread_id, Some(0x0102_0304));
        assert_eq!(decoded.nonce, Some([0xAB; 32]));
        assert!(decoded.mars);
        assert!(decoded.trace_id.is_none());
    }

    #[test]
    fn test_prelogin_decode_truncated() {
        // No terminator at all.
        let empty: &[u8] = &[];
        assert!(matches!(
            PreLogin::decode(empty),
            Err(ProtocolError::UnexpectedEof)
        ));

        // Option type present but its offset/length are cut short.
        let partial: &[u8] = &[0x00, 0x00];
        assert!(matches!(
            PreLogin::decode(partial),
            Err(ProtocolError::UnexpectedEof)
        ));
    }

    #[test]
    fn test_prelogin_decode_rejects_unknown_option() {
        let packet: &[u8] = &[0x42, 0x00, 0x06, 0x00, 0x00, 0xFF];
        assert!(matches!(
            PreLogin::decode(packet),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    /// Options whose offset points into the header table or past the end of the
    /// buffer are skipped rather than read from the wrong place.
    #[test]
    fn test_prelogin_decode_skips_out_of_range_offsets() {
        // One ENCRYPTION entry (6-byte header) whose offset (0) lies in the header.
        let inside_header: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x01, 0xFF, 0x00];
        let decoded = PreLogin::decode(inside_header).unwrap();
        assert_eq!(decoded.encryption, EncryptionLevel::default());

        // The same entry pointing past the end of the buffer.
        let past_end: &[u8] = &[0x01, 0x00, 0x40, 0x00, 0x01, 0xFF];
        let decoded = PreLogin::decode(past_end).unwrap();
        assert_eq!(decoded.encryption, EncryptionLevel::default());
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        // Manually construct a PreLogin packet with options in non-standard order
        // to verify offset handling works correctly
        //
        // Structure:
        // - ENCRYPTION header at offset pointing to encryption data
        // - VERSION header at offset pointing to version data
        // - Terminator
        // - Data section

        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();

        // Header section: each option is 5 bytes (type:1 + offset:2 + length:2)
        // We'll have 2 options + terminator = 11 bytes header
        let header_size: u16 = 11;

        // ENCRYPTION option header (put this first to test that we read from correct offset)
        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size); // offset to encryption data
        buf.put_u16(1); // length

        // VERSION option header
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1); // offset to version data (after encryption)
        buf.put_u16(6); // length

        // Terminator
        buf.put_u8(PreLoginOption::Terminator as u8);

        // Data section
        // Encryption data (1 byte): ENCRYPT_ON = 0x01
        buf.put_u8(0x01);

        // Version data (6 bytes): TDS 7.4 = 0x74000004 big-endian + sub-build 0x0000 little-endian
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000); // sub-build

        // Decode
        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        // Verify encryption was read from correct offset (not from index 0)
        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }
}