sqlx_core_oldapi/mssql/protocol/
pre_login.rs

1use std::fmt::{self, Display, Formatter};
2
3use bytes::{Buf, Bytes};
4use uuid::Uuid;
5
6use crate::error::Error;
7use crate::io::{Decode, Encode};
8
9/// A message sent by the client to set up context for login. The server responds to a client
10/// `PRELOGIN` message with a message of packet header type `0x04` and the packet data
11/// containing a `PRELOGIN` structure.
12#[derive(Debug, Default)]
13pub(crate) struct PreLogin {
14    pub(crate) version: Version,
15    pub(crate) encryption: Encrypt,
16    pub(crate) instance: Option<String>,
17    pub(crate) thread_id: Option<u32>,
18    pub(crate) trace_id: Option<TraceId>,
19    pub(crate) multiple_active_result_sets: Option<bool>,
20}
21
22impl<'de> Decode<'de> for PreLogin {
23    fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
24        let mut version = None;
25        let mut encryption = None;
26        let mut instance = None;
27        let mut thread_id = None;
28        let mut trace_id = None;
29        let mut multiple_active_result_sets = None;
30
31        let mut offsets = buf.clone();
32
33        loop {
34            let token = offsets.get_u8();
35
36            match PreLoginOptionToken::get(token) {
37                Some(token) => {
38                    let offset = offsets.get_u16() as usize;
39                    let size = offsets.get_u16() as usize;
40                    let mut data = &buf[offset..offset + size];
41
42                    match token {
43                        PreLoginOptionToken::Version => {
44                            let major = data.get_u8();
45                            let minor = data.get_u8();
46                            let build = data.get_u16();
47                            let sub_build = data.get_u16();
48
49                            version = Some(Version {
50                                major,
51                                minor,
52                                build,
53                                sub_build,
54                            });
55                        }
56
57                        PreLoginOptionToken::Encryption => {
58                            encryption = Some(Encrypt::from(data.get_u8()));
59                        }
60
61                        PreLoginOptionToken::Instance => {
62                            // data is null-terminated
63                            instance = Some(String::from_utf8_lossy(&data[..size - 1]).to_string());
64                        }
65                        PreLoginOptionToken::ThreadId => {
66                            thread_id = Some(data.get_u32_le());
67                        }
68                        PreLoginOptionToken::MultipleActiveResultSets => {
69                            multiple_active_result_sets = Some(data.get_u8() != 0);
70                        }
71                        PreLoginOptionToken::TraceId => {
72                            let connection_id = Uuid::from_u128(data.get_u128());
73                            let activity_id = Uuid::from_u128(data.get_u128());
74                            let activity_seq = data.get_u32();
75
76                            trace_id = Some(TraceId {
77                                connection_id,
78                                activity_id,
79                                activity_seq,
80                            });
81                        }
82                    }
83                }
84
85                None if token == 0xff => {
86                    break;
87                }
88
89                None => {
90                    return Err(err_protocol!(
91                        "PRELOGIN: unexpected login option token: 0x{:02?}",
92                        token
93                    ));
94                }
95            }
96        }
97
98        let version =
99            version.ok_or(err_protocol!("PRELOGIN: missing required `version` option"))?;
100
101        let encryption = encryption.ok_or(err_protocol!(
102            "PRELOGIN: missing required `encryption` option"
103        ))?;
104
105        Ok(Self {
106            version,
107            encryption,
108            instance,
109            thread_id,
110            trace_id,
111            multiple_active_result_sets,
112        })
113    }
114}
115
116impl Encode<'_> for PreLogin {
117    fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
118        use PreLoginOptionToken::*;
119
120        // NOTE: Packet headers are written in MssqlStream::write
121
122        // Rules
123        //  PRELOGIN = (*PRELOGIN_OPTION *PL_OPTION_DATA) / SSL_PAYLOAD
124        //  PRELOGIN_OPTION = (PL_OPTION_TOKEN PL_OFFSET PL_OPTION_LENGTH) / TERMINATOR
125
126        // Count the number of set options
127        let num_options = 2
128            + self.instance.as_ref().map_or(0, |_| 1)
129            + self.thread_id.map_or(0, |_| 1)
130            + self.trace_id.as_ref().map_or(0, |_| 1)
131            + self.multiple_active_result_sets.map_or(0, |_| 1);
132
133        // Calculate the length of the option offset block. Each block is 5 bytes and it ends in
134        // a 1 byte terminator.
135        let len_offsets = (num_options * 5) + 1;
136        let mut offsets = buf.len();
137        let mut offset = u16::try_from(len_offsets).unwrap();
138
139        // Reserve a chunk for the offset block and set the final terminator
140        buf.resize(buf.len() + len_offsets, 0);
141        let end_offsets = buf.len() - 1;
142        buf[end_offsets] = 0xff;
143
144        // NOTE: VERSION is a required token, and it MUST be the first token.
145        Version.put(buf, &mut offsets, &mut offset, 6);
146        self.version.encode(buf);
147
148        Encryption.put(buf, &mut offsets, &mut offset, 1);
149        buf.push(u8::from(self.encryption));
150
151        if let Some(name) = &self.instance {
152            Instance.put(
153                buf,
154                &mut offsets,
155                &mut offset,
156                u16::try_from(name.len() + 1).unwrap(),
157            );
158            buf.extend_from_slice(name.as_bytes());
159            buf.push(b'\0');
160        }
161
162        if let Some(id) = self.thread_id {
163            ThreadId.put(buf, &mut offsets, &mut offset, 4);
164            buf.extend_from_slice(&id.to_le_bytes());
165        }
166
167        if let Some(trace) = &self.trace_id {
168            ThreadId.put(buf, &mut offsets, &mut offset, 36);
169            buf.extend_from_slice(trace.connection_id.as_bytes());
170            buf.extend_from_slice(trace.activity_id.as_bytes());
171            buf.extend_from_slice(&trace.activity_seq.to_be_bytes());
172        }
173
174        if let Some(mars) = &self.multiple_active_result_sets {
175            MultipleActiveResultSets.put(buf, &mut offsets, &mut offset, 1);
176            buf.push(*mars as u8);
177        }
178    }
179}
180
181// token value representing the option (PL_OPTION_TOKEN)
182#[derive(Debug, Copy, Clone)]
183#[repr(u8)]
184enum PreLoginOptionToken {
185    Version = 0x00,
186    Encryption = 0x01,
187    Instance = 0x02,
188    ThreadId = 0x03,
189
190    // Multiple Active Result Sets (MARS)
191    MultipleActiveResultSets = 0x04,
192
193    TraceId = 0x05,
194}
195
196impl PreLoginOptionToken {
197    #[allow(clippy::ptr_arg)]
198    fn put(self, buf: &mut Vec<u8>, pos: &mut usize, offset: &mut u16, len: u16) {
199        buf[*pos] = self as u8;
200        *pos += 1;
201
202        buf[*pos..(*pos + 2)].copy_from_slice(&offset.to_be_bytes());
203        *pos += 2;
204
205        buf[*pos..(*pos + 2)].copy_from_slice(&len.to_be_bytes());
206        *pos += 2;
207
208        *offset += len;
209    }
210
211    fn get(b: u8) -> Option<Self> {
212        Some(match b {
213            0x00 => PreLoginOptionToken::Version,
214            0x01 => PreLoginOptionToken::Encryption,
215            0x02 => PreLoginOptionToken::Instance,
216            0x03 => PreLoginOptionToken::ThreadId,
217            0x04 => PreLoginOptionToken::MultipleActiveResultSets,
218            0x05 => PreLoginOptionToken::TraceId,
219
220            _ => {
221                return None;
222            }
223        })
224    }
225}
226
227#[derive(Debug)]
228pub(crate) struct TraceId {
229    // client application trace ID (GUID_CONNID)
230    pub(crate) connection_id: Uuid,
231
232    // client application activity ID (GUID_ActivityID)
233    pub(crate) activity_id: Uuid,
234
235    // client application activity sequence (ActivitySequence)
236    pub(crate) activity_seq: u32,
237}
238
239// Version of the sender (UL_VERSION)
240#[derive(Debug)]
241pub(crate) struct Version {
242    pub(crate) major: u8,
243    pub(crate) minor: u8,
244    pub(crate) build: u16,
245
246    // Sub-build number of the sender (US_SUBBUILD)
247    pub(crate) sub_build: u16,
248}
249
250impl Default for Version {
251    fn default() -> Self {
252        Self {
253            major: env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
254            minor: env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
255            build: env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
256            sub_build: 0,
257        }
258    }
259}
260
261impl Version {
262    fn encode(&self, buf: &mut Vec<u8>) {
263        buf.push(self.major);
264        buf.push(self.minor);
265        buf.extend(&self.build.to_be_bytes());
266        buf.extend(&self.sub_build.to_be_bytes());
267    }
268}
269
270impl Display for Version {
271    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
272        write!(f, "v{}.{}.{}", self.major, self.minor, self.build)
273    }
274}
275
/// Wire encryption negotiated by the client and the server during the Pre-Login
/// handshake (the `ENCRYPTION` option).
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Encrypt {
    /// Encryption is available but off.
    Off = 0x00,

    /// Encryption is available and on.
    #[default]
    On = 0x01,

    /// Encryption is not available.
    NotSupported = 0x02,

    /// Encryption is required.
    Required = 0x03,
}

impl From<u8> for Encrypt {
    /// Maps an `ENCRYPTION` byte to its variant.
    ///
    /// Any value outside `0x00..=0x03` is read as [`Encrypt::Off`].
    fn from(value: u8) -> Self {
        // Indexed by wire value.
        const BY_WIRE_VALUE: [Encrypt; 4] = [
            Encrypt::Off,
            Encrypt::On,
            Encrypt::NotSupported,
            Encrypt::Required,
        ];

        BY_WIRE_VALUE
            .get(usize::from(value))
            .copied()
            .unwrap_or(Encrypt::Off)
    }
}

impl From<Encrypt> for u8 {
    /// Returns the variant's `ENCRYPTION` wire byte.
    fn from(value: Encrypt) -> Self {
        match value {
            Encrypt::Off => 0x00,
            Encrypt::On => 0x01,
            Encrypt::NotSupported => 0x02,
            Encrypt::Required => 0x03,
        }
    }
}
311
#[test]
fn test_encode_pre_login() {
    let pre_login = PreLogin {
        version: Version {
            major: 9,
            minor: 0,
            build: 0,
            sub_build: 0,
        },
        encryption: Encrypt::On,
        instance: Some(String::new()),
        thread_id: Some(0x0000_0DB8),
        multiple_active_result_sets: Some(true),

        ..Default::default()
    };

    // From v20191101 of MS-TDS documentation
    #[rustfmt::skip]
    let expected: Vec<u8> = vec![
        // option table: token, offset (BE u16), length (BE u16)
        0x00, 0x00, 0x1A, 0x00, 0x06, // VERSION    @ 26, 6 bytes
        0x01, 0x00, 0x20, 0x00, 0x01, // ENCRYPTION @ 32, 1 byte
        0x02, 0x00, 0x21, 0x00, 0x01, // INSTOPT    @ 33, 1 byte
        0x03, 0x00, 0x22, 0x00, 0x04, // THREADID   @ 34, 4 bytes
        0x04, 0x00, 0x26, 0x00, 0x01, // MARS       @ 38, 1 byte
        0xFF,                         // terminator
        // option data
        0x09, 0x00, 0x00, 0x00, 0x00, 0x00, // version 9.0.0, sub-build 0
        0x01,                               // encryption: On
        0x00,                               // instance "" + NUL
        0xB8, 0x0D, 0x00, 0x00,             // thread id 0x0DB8, little-endian
        0x01,                               // MARS enabled
    ];

    let mut buf = Vec::new();
    pre_login.encode(&mut buf);

    assert_eq!(expected, buf);
}
343
#[test]
fn test_decode_pre_login() {
    // A server response carrying only the VERSION and ENCRYPTION options.
    #[rustfmt::skip]
    let buffer = Bytes::from_static(&[
        // option table: token, offset (BE u16), length (BE u16)
        0, 0, 11, 0, 6, // VERSION    @ 11, 6 bytes
        1, 0, 17, 0, 1, // ENCRYPTION @ 17, 1 byte
        255,            // terminator
        // option data
        14, 0, 12, 209, 0, 0, // 14.0, build 0x0CD1 = 3281, sub-build 0
        0,                    // ENCRYPT_OFF
    ]);

    let PreLogin {
        version,
        encryption,
        ..
    } = PreLogin::decode(buffer).unwrap();

    // v14.0.3281
    assert_eq!(
        (version.major, version.minor, version.build, version.sub_build),
        (14, 0, 3281, 0)
    );

    // ENCRYPT_OFF
    assert_eq!(encryption, Encrypt::Off);
}