tds_protocol/
prelogin.rs

1//! TDS pre-login packet handling.
2//!
3//! The pre-login packet is the first message exchanged between client and server
4//! in TDS 7.x connections. It negotiates protocol version, encryption, and other
5//! connection parameters.
6//!
//! Note: TDS 8.0 (strict mode) does not use pre-login to negotiate encryption;
//! TLS is established before any TDS traffic, so the pre-login exchange itself
//! travels inside the TLS session.
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11
12use crate::error::ProtocolError;
13use crate::prelude::*;
14use crate::version::{SqlServerVersion, TdsVersion};
15
16/// Pre-login option types.
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18#[repr(u8)]
19pub enum PreLoginOption {
20    /// Version information.
21    Version = 0x00,
22    /// Encryption negotiation.
23    Encryption = 0x01,
24    /// Instance name (for named instances).
25    Instance = 0x02,
26    /// Thread ID.
27    ThreadId = 0x03,
28    /// MARS (Multiple Active Result Sets) support.
29    Mars = 0x04,
30    /// Trace ID for distributed tracing.
31    TraceId = 0x05,
32    /// Federated authentication required.
33    FedAuthRequired = 0x06,
34    /// Nonce for encryption.
35    Nonce = 0x07,
36    /// Terminator (end of options).
37    Terminator = 0xFF,
38}
39
40impl PreLoginOption {
41    /// Create from raw byte value.
42    pub fn from_u8(value: u8) -> Result<Self, ProtocolError> {
43        match value {
44            0x00 => Ok(Self::Version),
45            0x01 => Ok(Self::Encryption),
46            0x02 => Ok(Self::Instance),
47            0x03 => Ok(Self::ThreadId),
48            0x04 => Ok(Self::Mars),
49            0x05 => Ok(Self::TraceId),
50            0x06 => Ok(Self::FedAuthRequired),
51            0x07 => Ok(Self::Nonce),
52            0xFF => Ok(Self::Terminator),
53            _ => Err(ProtocolError::InvalidPreloginOption(value)),
54        }
55    }
56}
57
/// Encryption level carried in the ENCRYPTION pre-login option.
///
/// The derived `Default` is [`EncryptionLevel::Required`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum EncryptionLevel {
    /// Encryption is off.
    Off = 0x00,
    /// Encryption is on.
    On = 0x01,
    /// Encryption is not supported.
    NotSupported = 0x02,
    /// Encryption is required.
    #[default]
    Required = 0x03,
    /// Client certificate authentication (TDS 8.0+).
    ClientCertAuth = 0x80,
}

impl EncryptionLevel {
    /// Map a raw ENCRYPTION byte to an [`EncryptionLevel`].
    ///
    /// Any byte that is not one of the known discriminants falls back to
    /// [`EncryptionLevel::Off`].
    pub fn from_u8(value: u8) -> Self {
        const KNOWN: [EncryptionLevel; 5] = [
            EncryptionLevel::Off,
            EncryptionLevel::On,
            EncryptionLevel::NotSupported,
            EncryptionLevel::Required,
            EncryptionLevel::ClientCertAuth,
        ];

        KNOWN
            .iter()
            .copied()
            .find(|level| *level as u8 == value)
            .unwrap_or(Self::Off)
    }

    /// Returns `true` for every level except `Off` and `NotSupported`,
    /// i.e. for `On`, `Required` and `ClientCertAuth`.
    #[must_use]
    pub const fn is_required(&self) -> bool {
        !matches!(self, Self::Off | Self::NotSupported)
    }
}
94
/// Pre-login message builder and parser.
///
/// This struct is used for both client requests and server responses:
/// - **Client → Server**: Set `version` to the requested TDS version
/// - **Server → Client**: `server_version` contains the SQL Server product version
///
/// Note: The VERSION field has different semantics in each direction:
/// - Client sends: TDS protocol version (e.g., 7.4)
/// - Server sends: SQL Server product version (e.g., 13.0.6300 for SQL Server 2016)
///
/// `PreLogin::new()` explicitly requests TDS 7.4; the derived `Default` instead
/// uses `TdsVersion`'s default. Both start with `EncryptionLevel::Required`.
#[derive(Debug, Clone, Default)]
pub struct PreLogin {
    /// TDS version (client request).
    ///
    /// This is the TDS protocol version the client requests. When sending a
    /// PreLogin, set this to the desired TDS version.
    ///
    /// When a server response is decoded, this field is ALSO overwritten with
    /// the raw 4-byte VERSION value reinterpreted as a `TdsVersion`, kept for
    /// backward compatibility. For responses, read `server_version` instead.
    pub version: TdsVersion,

    /// SQL Server product version (server response).
    ///
    /// When decoding a PreLogin response from the server, this contains the
    /// SQL Server product version (e.g., 13.0.6300 for SQL Server 2016).
    /// This is NOT the TDS version - the actual TDS version is negotiated
    /// in the LOGINACK token after login.
    pub server_version: Option<SqlServerVersion>,

    /// Sub-build version (legacy, now part of server_version).
    ///
    /// Still written as the little-endian US_SUBBUILD after the VERSION value
    /// when encoding, and still populated when decoding.
    #[deprecated(since = "0.5.2", note = "Use server_version.sub_build instead")]
    pub sub_build: u16,

    /// Encryption level (defaults to `EncryptionLevel::Required`).
    pub encryption: EncryptionLevel,
    /// Instance name (for named instances).
    ///
    /// Sent NUL-terminated. When decoding, only a non-empty, valid UTF-8 name
    /// before the first NUL byte is kept.
    pub instance: Option<String>,
    /// Thread ID (encoded and decoded big-endian; only sent when set).
    pub thread_id: Option<u32>,
    /// MARS enabled (sent as a single 0x01/0x00 byte).
    pub mars: bool,
    /// Trace ID (Activity ID and Sequence).
    ///
    /// Only ever encoded; decoding a response leaves this as `None`.
    pub trace_id: Option<TraceId>,
    /// Federated authentication required (the option is only sent when `true`).
    pub fed_auth_required: bool,
    /// Nonce for encryption (32 bytes; only sent when set).
    pub nonce: Option<[u8; 32]>,
}
139
/// Distributed tracing ID carried in the TRACEID pre-login option.
///
/// Only the activity part is stored here. `PreLogin::encode` writes
/// `activity_id` as raw bytes and `activity_sequence` little-endian.
///
/// Derives the common value-type traits so trace IDs can be compared, hashed
/// and defaulted (all-zero GUID, sequence 0).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct TraceId {
    /// Activity ID (GUID bytes).
    pub activity_id: [u8; 16],
    /// Activity sequence number.
    pub activity_sequence: u32,
}
148
149impl PreLogin {
150    /// Create a new pre-login message with default values.
151    #[must_use]
152    #[allow(deprecated)] // sub_build is deprecated but we need to initialize it
153    pub fn new() -> Self {
154        Self {
155            version: TdsVersion::V7_4,
156            server_version: None,
157            sub_build: 0,
158            encryption: EncryptionLevel::Required,
159            instance: None,
160            thread_id: None,
161            mars: false,
162            trace_id: None,
163            fed_auth_required: false,
164            nonce: None,
165        }
166    }
167
168    /// Set the TDS version.
169    #[must_use]
170    pub fn with_version(mut self, version: TdsVersion) -> Self {
171        self.version = version;
172        self
173    }
174
175    /// Set the encryption level.
176    #[must_use]
177    pub fn with_encryption(mut self, level: EncryptionLevel) -> Self {
178        self.encryption = level;
179        self
180    }
181
182    /// Enable MARS.
183    #[must_use]
184    pub fn with_mars(mut self, enabled: bool) -> Self {
185        self.mars = enabled;
186        self
187    }
188
189    /// Set the instance name.
190    #[must_use]
191    pub fn with_instance(mut self, instance: impl Into<String>) -> Self {
192        self.instance = Some(instance.into());
193        self
194    }
195
196    /// Encode the pre-login message to bytes.
197    #[must_use]
198    #[allow(deprecated)] // sub_build is deprecated but we still encode it
199    pub fn encode(&self) -> Bytes {
200        let mut buf = BytesMut::with_capacity(256);
201
202        // Calculate option data offsets
203        // Each option entry is 5 bytes: type (1) + offset (2) + length (2)
204        // Plus 1 byte for terminator
205        let mut option_count = 3; // Version, Encryption, MARS are always present
206        if self.instance.is_some() {
207            option_count += 1;
208        }
209        if self.thread_id.is_some() {
210            option_count += 1;
211        }
212        if self.trace_id.is_some() {
213            option_count += 1;
214        }
215        if self.fed_auth_required {
216            option_count += 1;
217        }
218        if self.nonce.is_some() {
219            option_count += 1;
220        }
221
222        let header_size = option_count * 5 + 1; // +1 for terminator
223        let mut data_offset = header_size as u16;
224        let mut data_buf = BytesMut::new();
225
226        // VERSION option (6 bytes: 4 bytes version + 2 bytes sub-build)
227        buf.put_u8(PreLoginOption::Version as u8);
228        buf.put_u16(data_offset);
229        buf.put_u16(6);
230        let version_raw = self.version.raw();
231        data_buf.put_u8((version_raw >> 24) as u8);
232        data_buf.put_u8((version_raw >> 16) as u8);
233        data_buf.put_u8((version_raw >> 8) as u8);
234        data_buf.put_u8(version_raw as u8);
235        data_buf.put_u16_le(self.sub_build);
236        data_offset += 6;
237
238        // ENCRYPTION option (1 byte)
239        buf.put_u8(PreLoginOption::Encryption as u8);
240        buf.put_u16(data_offset);
241        buf.put_u16(1);
242        data_buf.put_u8(self.encryption as u8);
243        data_offset += 1;
244
245        // INSTANCE option (if set)
246        if let Some(ref instance) = self.instance {
247            let instance_bytes = instance.as_bytes();
248            let len = instance_bytes.len() as u16 + 1; // +1 for null terminator
249            buf.put_u8(PreLoginOption::Instance as u8);
250            buf.put_u16(data_offset);
251            buf.put_u16(len);
252            data_buf.put_slice(instance_bytes);
253            data_buf.put_u8(0); // null terminator
254            data_offset += len;
255        }
256
257        // THREADID option (if set)
258        if let Some(thread_id) = self.thread_id {
259            buf.put_u8(PreLoginOption::ThreadId as u8);
260            buf.put_u16(data_offset);
261            buf.put_u16(4);
262            data_buf.put_u32(thread_id);
263            data_offset += 4;
264        }
265
266        // MARS option (1 byte)
267        buf.put_u8(PreLoginOption::Mars as u8);
268        buf.put_u16(data_offset);
269        buf.put_u16(1);
270        data_buf.put_u8(if self.mars { 0x01 } else { 0x00 });
271        data_offset += 1;
272
273        // TRACEID option (if set)
274        if let Some(ref trace_id) = self.trace_id {
275            buf.put_u8(PreLoginOption::TraceId as u8);
276            buf.put_u16(data_offset);
277            buf.put_u16(36);
278            data_buf.put_slice(&trace_id.activity_id);
279            data_buf.put_u32_le(trace_id.activity_sequence);
280            // Connection ID (16 bytes, typically zeros for client)
281            data_buf.put_slice(&[0u8; 16]);
282            data_offset += 36;
283        }
284
285        // FEDAUTHREQUIRED option (if set)
286        if self.fed_auth_required {
287            buf.put_u8(PreLoginOption::FedAuthRequired as u8);
288            buf.put_u16(data_offset);
289            buf.put_u16(1);
290            data_buf.put_u8(0x01);
291            data_offset += 1;
292        }
293
294        // NONCE option (if set)
295        if let Some(ref nonce) = self.nonce {
296            buf.put_u8(PreLoginOption::Nonce as u8);
297            buf.put_u16(data_offset);
298            buf.put_u16(32);
299            data_buf.put_slice(nonce);
300            let _ = data_offset; // Suppress unused warning
301        }
302
303        // Terminator
304        buf.put_u8(PreLoginOption::Terminator as u8);
305
306        // Append data section
307        buf.put_slice(&data_buf);
308
309        buf.freeze()
310    }
311
312    /// Decode a pre-login response from the server.
313    ///
314    /// Per MS-TDS spec 2.2.6.4, PreLogin message structure:
315    /// - Option headers: each 5 bytes (type:1 + offset:2 + length:2)
316    /// - Terminator: 1 byte (0xFF)
317    /// - Option data: variable length, positioned at offsets specified in headers
318    ///
319    /// Offsets in headers are absolute from the start of the PreLogin packet payload.
320    pub fn decode(mut src: impl Buf) -> Result<Self, ProtocolError> {
321        let mut prelogin = Self::default();
322
323        // Parse option headers first, collecting (option_type, offset, length)
324        let mut options = Vec::new();
325        loop {
326            if src.remaining() < 1 {
327                return Err(ProtocolError::UnexpectedEof);
328            }
329
330            let option_type = src.get_u8();
331            if option_type == PreLoginOption::Terminator as u8 {
332                break;
333            }
334
335            if src.remaining() < 4 {
336                return Err(ProtocolError::UnexpectedEof);
337            }
338
339            let offset = src.get_u16();
340            let length = src.get_u16();
341            options.push((PreLoginOption::from_u8(option_type)?, offset, length));
342        }
343
344        // Get remaining data as bytes for random access
345        let data = src.copy_to_bytes(src.remaining());
346
347        // Calculate header size: each option is 5 bytes + 1 byte terminator
348        let header_size = options.len() * 5 + 1;
349
350        for (option, packet_offset, length) in options {
351            let packet_offset = packet_offset as usize;
352            let length = length as usize;
353
354            // Convert absolute packet offset to offset within data buffer
355            // The data buffer starts after the headers, so we subtract header_size
356            if packet_offset < header_size {
357                // Invalid: offset points inside the headers
358                continue;
359            }
360            let data_offset = packet_offset - header_size;
361
362            // Bounds check
363            if data_offset + length > data.len() {
364                continue;
365            }
366
367            #[allow(deprecated)] // We still populate sub_build for backward compatibility
368            match option {
369                PreLoginOption::Version if length >= 4 => {
370                    // Per MS-TDS 2.2.6.4: The server sends its SQL Server product version
371                    // in the VERSION field, NOT the TDS protocol version.
372                    //
373                    // Format: UL_VERSION (4 bytes big-endian) + US_SUBBUILD (2 bytes little-endian)
374                    // UL_VERSION contains: [major][minor][build_hi][build_lo]
375                    //
376                    // For example, SQL Server 2016 sends 13.0.xxxx (major=13, minor=0)
377                    let version_bytes = &data[data_offset..data_offset + 4];
378                    let version_raw = u32::from_be_bytes([
379                        version_bytes[0],
380                        version_bytes[1],
381                        version_bytes[2],
382                        version_bytes[3],
383                    ]);
384
385                    // Extract sub_build if present
386                    let sub_build = if length >= 6 {
387                        let sub_build_bytes = &data[data_offset + 4..data_offset + 6];
388                        u16::from_le_bytes([sub_build_bytes[0], sub_build_bytes[1]])
389                    } else {
390                        0
391                    };
392
393                    // Populate the new SqlServerVersion field (correct semantics)
394                    prelogin.server_version =
395                        Some(SqlServerVersion::from_raw(version_raw, sub_build));
396
397                    // Also set deprecated fields for backward compatibility
398                    prelogin.version = TdsVersion::new(version_raw);
399                    prelogin.sub_build = sub_build;
400                }
401                PreLoginOption::Encryption if length >= 1 => {
402                    prelogin.encryption = EncryptionLevel::from_u8(data[data_offset]);
403                }
404                PreLoginOption::Mars if length >= 1 => {
405                    prelogin.mars = data[data_offset] != 0;
406                }
407                PreLoginOption::Instance if length > 0 => {
408                    // Instance name is null-terminated string
409                    let instance_data = &data[data_offset..data_offset + length];
410                    if let Some(null_pos) = instance_data.iter().position(|&b| b == 0) {
411                        if let Ok(s) = core::str::from_utf8(&instance_data[..null_pos]) {
412                            if !s.is_empty() {
413                                prelogin.instance = Some(s.to_string());
414                            }
415                        }
416                    }
417                }
418                PreLoginOption::ThreadId if length >= 4 => {
419                    let bytes = &data[data_offset..data_offset + 4];
420                    prelogin.thread_id =
421                        Some(u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]));
422                }
423                PreLoginOption::FedAuthRequired if length >= 1 => {
424                    prelogin.fed_auth_required = data[data_offset] != 0;
425                }
426                PreLoginOption::Nonce if length >= 32 => {
427                    let mut nonce = [0u8; 32];
428                    nonce.copy_from_slice(&data[data_offset..data_offset + 32]);
429                    prelogin.nonce = Some(nonce);
430                }
431                _ => {}
432            }
433        }
434
435        Ok(prelogin)
436    }
437}
438
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_prelogin_encode() {
        let prelogin = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::Required);

        let encoded = prelogin.encode();
        // VERSION, ENCRYPTION and MARS headers (3 * 5 bytes) + terminator,
        // followed by 6 + 1 + 1 bytes of option data.
        assert_eq!(encoded.len(), 16 + 8);
        // First byte should be VERSION option type
        assert_eq!(encoded[0], PreLoginOption::Version as u8);
        // Terminator closes the header table
        assert_eq!(encoded[15], PreLoginOption::Terminator as u8);
    }

    #[test]
    fn test_encryption_level() {
        assert!(EncryptionLevel::Required.is_required());
        assert!(EncryptionLevel::On.is_required());
        assert!(EncryptionLevel::ClientCertAuth.is_required());
        assert!(!EncryptionLevel::Off.is_required());
        assert!(!EncryptionLevel::NotSupported.is_required());
        // Unknown bytes fall back to Off
        assert_eq!(EncryptionLevel::from_u8(0x42), EncryptionLevel::Off);
    }

    #[test]
    fn test_prelogin_decode_roundtrip() {
        // Create a PreLogin with various options
        let original = PreLogin::new()
            .with_version(TdsVersion::V7_4)
            .with_encryption(EncryptionLevel::On)
            .with_mars(true);

        // Encode it
        let encoded = original.encode();

        // Decode it back
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        // Verify the critical fields match
        assert_eq!(decoded.version, original.version);
        assert_eq!(decoded.encryption, original.encryption);
        assert_eq!(decoded.mars, original.mars);
    }

    #[test]
    fn test_prelogin_decode_roundtrip_optional_fields() {
        let mut original = PreLogin::new().with_instance("MSSQLSERVER");
        original.thread_id = Some(0xDEAD_BEEF);
        original.fed_auth_required = true;
        original.nonce = Some([0x5A; 32]);

        let encoded = original.encode();
        let decoded = PreLogin::decode(encoded.as_ref()).unwrap();

        assert_eq!(decoded.instance.as_deref(), Some("MSSQLSERVER"));
        assert_eq!(decoded.thread_id, Some(0xDEAD_BEEF));
        assert!(decoded.fed_auth_required);
        assert_eq!(decoded.nonce, Some([0x5A; 32]));
    }

    #[test]
    fn test_prelogin_decode_encryption_offset() {
        // Manually construct a PreLogin packet with options in non-standard order
        // to verify offset handling works correctly
        //
        // Structure:
        // - ENCRYPTION header at offset pointing to encryption data
        // - VERSION header at offset pointing to version data
        // - Terminator
        // - Data section

        use bytes::BufMut;

        let mut buf = bytes::BytesMut::new();

        // Header section: each option is 5 bytes (type:1 + offset:2 + length:2)
        // We'll have 2 options + terminator = 11 bytes header
        let header_size: u16 = 11;

        // ENCRYPTION option header (put this first to test that we read from correct offset)
        buf.put_u8(PreLoginOption::Encryption as u8);
        buf.put_u16(header_size); // offset to encryption data
        buf.put_u16(1); // length

        // VERSION option header
        buf.put_u8(PreLoginOption::Version as u8);
        buf.put_u16(header_size + 1); // offset to version data (after encryption)
        buf.put_u16(6); // length

        // Terminator
        buf.put_u8(PreLoginOption::Terminator as u8);

        // Data section
        // Encryption data (1 byte): ENCRYPT_ON = 0x01
        buf.put_u8(0x01);

        // Version data (6 bytes): TDS 7.4 = 0x74000004 big-endian + sub-build 0x0000 little-endian
        buf.put_u8(0x74);
        buf.put_u8(0x00);
        buf.put_u8(0x00);
        buf.put_u8(0x04);
        buf.put_u16_le(0x0000); // sub-build

        // Decode
        let decoded = PreLogin::decode(buf.freeze().as_ref()).unwrap();

        // Verify encryption was read from correct offset (not from index 0)
        assert_eq!(decoded.encryption, EncryptionLevel::On);
        assert_eq!(decoded.version, TdsVersion::V7_4);
    }

    #[test]
    fn test_prelogin_decode_truncated() {
        let empty: &[u8] = &[];
        assert!(matches!(PreLogin::decode(empty), Err(ProtocolError::UnexpectedEof)));

        // Option token present but its offset/length fields are cut short.
        let short_header: &[u8] = &[0x00, 0x00];
        assert!(matches!(PreLogin::decode(short_header), Err(ProtocolError::UnexpectedEof)));

        // Complete header entry but no terminator.
        let no_terminator: &[u8] = &[0x01, 0x00, 0x06, 0x00, 0x01];
        assert!(matches!(PreLogin::decode(no_terminator), Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_prelogin_decode_unknown_option() {
        let packet: &[u8] = &[0x42, 0x00, 0x06, 0x00, 0x00, 0xFF];
        assert!(matches!(
            PreLogin::decode(packet),
            Err(ProtocolError::InvalidPreloginOption(0x42))
        ));
    }

    #[test]
    fn test_prelogin_decode_skips_bad_offsets() {
        // ENCRYPTION header whose offset (0) points back into the header table.
        let into_headers: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x01, 0xFF, 0x01];
        let decoded = PreLogin::decode(into_headers).unwrap();
        assert_eq!(decoded.encryption, EncryptionLevel::default());

        // ENCRYPTION header claiming 5 bytes when only 1 byte of data follows.
        let past_end: &[u8] = &[0x01, 0x00, 0x06, 0x00, 0x05, 0xFF, 0x01];
        let decoded = PreLogin::decode(past_end).unwrap();
        assert_eq!(decoded.encryption, EncryptionLevel::default());
    }
}
531        assert_eq!(decoded.encryption, EncryptionLevel::On);
532        assert_eq!(decoded.version, TdsVersion::V7_4);
533    }
534}