sqlx_core_oldapi/mssql/protocol/pre_login.rs

1use std::fmt::{self, Display, Formatter};
2
3use bytes::{Buf, Bytes};
4use uuid::Uuid;
5
6use crate::error::Error;
7use crate::io::{Decode, Encode};
8
9/// A message sent by the client to set up context for login. The server responds to a client
10/// `PRELOGIN` message with a message of packet header type `0x04` and the packet data
11/// containing a `PRELOGIN` structure.
12#[derive(Debug, Default)]
13pub(crate) struct PreLogin {
14    pub(crate) version: Version,
15    pub(crate) encryption: Encrypt,
16    pub(crate) instance: Option<String>,
17    pub(crate) thread_id: Option<u32>,
18    pub(crate) trace_id: Option<TraceId>,
19    pub(crate) multiple_active_result_sets: Option<bool>,
20}
21
22impl<'de> Decode<'de> for PreLogin {
23    fn decode_with(buf: Bytes, _: ()) -> Result<Self, Error> {
24        let mut version = None;
25        let mut encryption = None;
26        let mut instance = None;
27        let mut thread_id = None;
28        let mut trace_id = None;
29        let mut multiple_active_result_sets = None;
30
31        let mut offsets = buf.clone();
32
33        loop {
34            let token = offsets.get_u8();
35
36            match PreLoginOptionToken::get(token) {
37                Some(token) => {
38                    let offset = offsets.get_u16() as usize;
39                    let size = offsets.get_u16() as usize;
40                    let mut data = &buf[offset..offset + size];
41
42                    match token {
43                        PreLoginOptionToken::Version => {
44                            let major = data.get_u8();
45                            let minor = data.get_u8();
46                            let build = data.get_u16();
47                            let sub_build = data.get_u16();
48
49                            version = Some(Version {
50                                major,
51                                minor,
52                                build,
53                                sub_build,
54                            });
55                        }
56
57                        PreLoginOptionToken::Encryption => {
58                            encryption = Some(Encrypt::from(data.get_u8()));
59                        }
60
61                        PreLoginOptionToken::Instance => {
62                            // data is null-terminated
63                            instance = Some(String::from_utf8_lossy(&data[..size - 1]).to_string());
64                        }
65                        PreLoginOptionToken::ThreadId => {
66                            thread_id = Some(data.get_u32_le());
67                        }
68                        PreLoginOptionToken::MultipleActiveResultSets => {
69                            multiple_active_result_sets = Some(data.get_u8() != 0);
70                        }
71                        PreLoginOptionToken::TraceId => {
72                            let connection_id = Uuid::from_u128(data.get_u128());
73                            let activity_id = Uuid::from_u128(data.get_u128());
74                            let activity_seq = data.get_u32();
75
76                            trace_id = Some(TraceId {
77                                connection_id,
78                                activity_id,
79                                activity_seq,
80                            });
81                        }
82                    }
83                }
84
85                None if token == 0xff => {
86                    break;
87                }
88
89                None => {
90                    return Err(err_protocol!(
91                        "PRELOGIN: unexpected login option token: 0x{:02?}",
92                        token
93                    )
94                    .into());
95                }
96            }
97        }
98
99        let version =
100            version.ok_or(err_protocol!("PRELOGIN: missing required `version` option"))?;
101
102        let encryption = encryption.ok_or(err_protocol!(
103            "PRELOGIN: missing required `encryption` option"
104        ))?;
105
106        Ok(Self {
107            version,
108            encryption,
109            instance,
110            thread_id,
111            trace_id,
112            multiple_active_result_sets,
113        })
114    }
115}
116
117impl Encode<'_> for PreLogin {
118    fn encode_with(&self, buf: &mut Vec<u8>, _: ()) {
119        use PreLoginOptionToken::*;
120
121        // NOTE: Packet headers are written in MssqlStream::write
122
123        // Rules
124        //  PRELOGIN = (*PRELOGIN_OPTION *PL_OPTION_DATA) / SSL_PAYLOAD
125        //  PRELOGIN_OPTION = (PL_OPTION_TOKEN PL_OFFSET PL_OPTION_LENGTH) / TERMINATOR
126
127        // Count the number of set options
128        let num_options = 2
129            + self.instance.as_ref().map_or(0, |_| 1)
130            + self.thread_id.map_or(0, |_| 1)
131            + self.trace_id.as_ref().map_or(0, |_| 1)
132            + self.multiple_active_result_sets.map_or(0, |_| 1);
133
134        // Calculate the length of the option offset block. Each block is 5 bytes and it ends in
135        // a 1 byte terminator.
136        let len_offsets = (num_options * 5) + 1;
137        let mut offsets = buf.len();
138        let mut offset = u16::try_from(len_offsets).unwrap();
139
140        // Reserve a chunk for the offset block and set the final terminator
141        buf.resize(buf.len() + len_offsets, 0);
142        let end_offsets = buf.len() - 1;
143        buf[end_offsets] = 0xff;
144
145        // NOTE: VERSION is a required token, and it MUST be the first token.
146        Version.put(buf, &mut offsets, &mut offset, 6);
147        self.version.encode(buf);
148
149        Encryption.put(buf, &mut offsets, &mut offset, 1);
150        buf.push(u8::from(self.encryption));
151
152        if let Some(name) = &self.instance {
153            Instance.put(
154                buf,
155                &mut offsets,
156                &mut offset,
157                u16::try_from(name.len() + 1).unwrap(),
158            );
159            buf.extend_from_slice(name.as_bytes());
160            buf.push(b'\0');
161        }
162
163        if let Some(id) = self.thread_id {
164            ThreadId.put(buf, &mut offsets, &mut offset, 4);
165            buf.extend_from_slice(&id.to_le_bytes());
166        }
167
168        if let Some(trace) = &self.trace_id {
169            ThreadId.put(buf, &mut offsets, &mut offset, 36);
170            buf.extend_from_slice(trace.connection_id.as_bytes());
171            buf.extend_from_slice(trace.activity_id.as_bytes());
172            buf.extend_from_slice(&trace.activity_seq.to_be_bytes());
173        }
174
175        if let Some(mars) = &self.multiple_active_result_sets {
176            MultipleActiveResultSets.put(buf, &mut offsets, &mut offset, 1);
177            buf.push(*mars as u8);
178        }
179    }
180}
181
/// Token identifying which option a `PRELOGIN` option header describes (`PL_OPTION_TOKEN`).
#[derive(Debug, Copy, Clone)]
#[repr(u8)]
enum PreLoginOptionToken {
    /// Sender version.
    Version = 0x00,

    /// Encryption mode.
    Encryption = 0x01,

    /// Instance name.
    Instance = 0x02,

    /// Client thread id.
    ThreadId = 0x03,

    /// Multiple Active Result Sets (MARS).
    MultipleActiveResultSets = 0x04,

    /// Client trace identifiers.
    TraceId = 0x05,
}

impl PreLoginOptionToken {
    /// Fills the 5-byte option header that starts at `buf[*pos]`.
    ///
    /// The header is this token, then `*offset` and `len`, each as a big-endian `u16`. The
    /// header region must already exist in `buf`. Afterwards `*pos` points just past the
    /// header and `*offset` has advanced past the `len` bytes of option data it describes.
    fn put(self, buf: &mut Vec<u8>, pos: &mut usize, offset: &mut u16, len: u16) {
        let start = *pos;
        let header = &mut buf[start..start + 5];

        header[0] = self as u8;
        header[1..3].copy_from_slice(&offset.to_be_bytes());
        header[3..5].copy_from_slice(&len.to_be_bytes());

        *pos = start + 5;
        *offset += len;
    }

    /// Maps a raw token byte to its option. Returns `None` for anything unrecognised,
    /// including the `0xFF` terminator.
    fn get(b: u8) -> Option<Self> {
        match b {
            0x00 => Some(Self::Version),
            0x01 => Some(Self::Encryption),
            0x02 => Some(Self::Instance),
            0x03 => Some(Self::ThreadId),
            0x04 => Some(Self::MultipleActiveResultSets),
            0x05 => Some(Self::TraceId),
            _ => None,
        }
    }
}
226
227#[derive(Debug)]
228pub(crate) struct TraceId {
229    // client application trace ID (GUID_CONNID)
230    pub(crate) connection_id: Uuid,
231
232    // client application activity ID (GUID_ActivityID)
233    pub(crate) activity_id: Uuid,
234
235    // client application activity sequence (ActivitySequence)
236    pub(crate) activity_seq: u32,
237}
238
/// Version of the sender (`UL_VERSION`) together with its sub-build number (`US_SUBBUILD`).
#[derive(Debug)]
pub(crate) struct Version {
    /// Major version number.
    pub(crate) major: u8,

    /// Minor version number.
    pub(crate) minor: u8,

    /// Build number.
    pub(crate) build: u16,

    /// Sub-build number of the sender (`US_SUBBUILD`).
    pub(crate) sub_build: u16,
}
249
250impl Default for Version {
251    fn default() -> Self {
252        Self {
253            major: env!("CARGO_PKG_VERSION_MAJOR").parse().unwrap(),
254            minor: env!("CARGO_PKG_VERSION_MINOR").parse().unwrap(),
255            build: env!("CARGO_PKG_VERSION_PATCH").parse().unwrap(),
256            sub_build: 0,
257        }
258    }
259}
260
261impl Version {
262    fn encode(&self, buf: &mut Vec<u8>) {
263        buf.push(self.major);
264        buf.push(self.minor);
265        buf.extend(&self.build.to_be_bytes());
266        buf.extend(&self.sub_build.to_be_bytes());
267    }
268}
269
270impl Display for Version {
271    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
272        write!(f, "v{}.{}.{}", self.major, self.minor, self.build)
273    }
274}
275
/// Wire-encryption mode negotiated by the client and the server during the Pre-Login
/// handshake.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Encrypt {
    /// Encryption is available but off.
    Off = 0x00,

    /// Encryption is available and on. This is the default.
    #[default]
    On = 0x01,

    /// Encryption is not available.
    NotSupported = 0x02,

    /// Encryption is required.
    Required = 0x03,
}

impl From<u8> for Encrypt {
    /// Maps a raw encryption byte to an [`Encrypt`].
    ///
    /// Any byte outside `0x00..=0x03` falls back to [`Encrypt::Off`].
    /// NOTE(review): unrecognised values decode as `Off`; confirm callers treat that safely.
    fn from(value: u8) -> Self {
        match value {
            0x01 => Self::On,
            0x02 => Self::NotSupported,
            0x03 => Self::Required,
            _ => Self::Off,
        }
    }
}

impl From<Encrypt> for u8 {
    /// Returns the variant's discriminant, which is the byte written on the wire.
    fn from(value: Encrypt) -> Self {
        value as Self
    }
}
311
#[test]
fn test_encode_pre_login() {
    // Example client PRELOGIN from v20191101 of the MS-TDS documentation.
    let pre_login = PreLogin {
        version: Version {
            major: 9,
            minor: 0,
            build: 0,
            sub_build: 0,
        },
        encryption: Encrypt::On,
        instance: Some(String::new()),
        thread_id: Some(0x0000_0DB8),
        multiple_active_result_sets: Some(true),

        ..Default::default()
    };

    let mut buf = Vec::new();
    pre_login.encode(&mut buf);

    #[rustfmt::skip]
    let expected: Vec<u8> = vec![
        // option headers: token, offset (BE u16), length (BE u16)
        0x00, 0x00, 0x1A, 0x00, 0x06, // version
        0x01, 0x00, 0x20, 0x00, 0x01, // encryption
        0x02, 0x00, 0x21, 0x00, 0x01, // instance
        0x03, 0x00, 0x22, 0x00, 0x04, // thread id
        0x04, 0x00, 0x26, 0x00, 0x01, // MARS
        0xFF,                         // terminator
        // option data
        0x09, 0x00, 0x00, 0x00, 0x00, 0x00, // version 9.0.0, sub-build 0
        0x01,                               // encryption on
        0x00,                               // empty instance name + NUL
        0xB8, 0x0D, 0x00, 0x00,             // thread id (LE)
        0x01,                               // MARS enabled
    ];

    assert_eq!(expected, buf);
}
343
#[test]
fn test_decode_pre_login() {
    // Server response containing only the version and encryption options.
    #[rustfmt::skip]
    let buffer = Bytes::from_static(&[
        0x00, 0x00, 0x0B, 0x00, 0x06,       // version: offset 11, 6 bytes
        0x01, 0x00, 0x11, 0x00, 0x01,       // encryption: offset 17, 1 byte
        0xFF,                               // terminator
        0x0E, 0x00, 0x0C, 0xD1, 0x00, 0x00, // v14.0.3281, sub-build 0
        0x00,                               // ENCRYPT_OFF
    ]);

    let PreLogin {
        version,
        encryption,
        ..
    } = PreLogin::decode(buffer).unwrap();

    // v14.0.3281
    assert_eq!(
        (version.major, version.minor, version.build, version.sub_build),
        (14, 0, 3281, 0)
    );

    // ENCRYPT_OFF
    assert_eq!(encryption, Encrypt::Off);
}