Skip to main content

sqlx_sqlserver/protocol/
login.rs

1use crate::MssqlConnectOptions;
2
3use super::packet::{encode_message, PacketFrameError, PacketType};
4use thiserror::Error;
5
/// Size in bytes of the LOGIN7 fixed header that precedes the variable-length data:
/// 36 bytes of scalar fields, nine 4-byte offset/length descriptors, a 6-byte zeroed
/// block, three more 4-byte descriptors, and a trailing 4-byte value.
const LOGIN7_FIXED_LEN: usize = 36 + 9 * 4 + 6 + 3 * 4 + 4;

/// TDS protocol version 7.4, written little-endian into the LOGIN7 header.
const TDS_VERSION_74: u32 = 0x7400_0004;

// Flag bytes written verbatim at header bytes 24..28.
// Bit meanings follow the MS-TDS LOGIN7 definition (confirm against the spec before
// changing): OptionFlags1 sets fUseDB | fDatabase | fSetLang; OptionFlags2 sets
// fLanguage | fODBC; TypeFlags and OptionFlags3 are left clear.
const OPTION_FLAGS_1: u8 = 0b1110_0000;
const OPTION_FLAGS_2: u8 = 0b0000_0011;
const TYPE_FLAGS: u8 = 0b0000_0000;
const OPTION_FLAGS_3: u8 = 0b0000_0000;
13
14/// Builds an unframed TDS LOGIN7 payload from connection options.
15pub fn build_login7_payload(options: &MssqlConnectOptions) -> Result<Vec<u8>, Login7Error> {
16    let mut fields = Login7Fields::new(LOGIN7_FIXED_LEN);
17
18    let hostname = fields.push_text(options.hostname(), false)?;
19    let username = fields.push_text(options.username(), false)?;
20    let password = fields.push_text(options.password().unwrap_or_default(), true)?;
21    let app_name = fields.push_text(options.app_name(), false)?;
22    let server_name = fields.push_text(options.server_name(), false)?;
23    let unused = Login7FieldOffset::empty(fields.next_offset);
24    let client_interface_name = fields.push_text(options.client_interface_name(), false)?;
25    let language = fields.push_text(options.language(), false)?;
26    let database = fields.push_text(options.database(), false)?;
27    let sspi = Login7FieldOffset::empty(fields.next_offset);
28    let attach_db_file = Login7FieldOffset::empty(fields.next_offset);
29    let change_password = Login7FieldOffset::empty(fields.next_offset);
30
31    let total_len = u32::from(fields.next_offset);
32    let mut out = Vec::with_capacity(usize::from(fields.next_offset));
33
34    write_u32_le(&mut out, total_len);
35    write_u32_le(&mut out, TDS_VERSION_74);
36    write_u32_le(&mut out, options.requested_packet_size());
37    write_u32_le(&mut out, options.client_program_version());
38    write_u32_le(&mut out, options.client_pid());
39    write_u32_le(&mut out, 0);
40    out.extend_from_slice(&[OPTION_FLAGS_1, OPTION_FLAGS_2, TYPE_FLAGS, OPTION_FLAGS_3]);
41    write_i32_le(&mut out, 0);
42    write_u32_le(&mut out, 0);
43
44    for offset in [
45        hostname,
46        username,
47        password,
48        app_name,
49        server_name,
50        unused,
51        client_interface_name,
52        language,
53        database,
54    ] {
55        offset.write_to(&mut out);
56    }
57
58    out.extend_from_slice(&[0; 6]);
59    sspi.write_to(&mut out);
60    attach_db_file.write_to(&mut out);
61    change_password.write_to(&mut out);
62    write_u32_le(&mut out, 0);
63
64    debug_assert_eq!(LOGIN7_FIXED_LEN, out.len());
65    out.extend_from_slice(&fields.data);
66
67    Ok(out)
68}
69
70/// Builds framed TDS LOGIN7 packet bytes from connection options.
71pub fn build_login7_packet(options: &MssqlConnectOptions) -> Result<Vec<u8>, Login7Error> {
72    let payload = build_login7_payload(options)?;
73
74    encode_message(
75        PacketType::LOGIN7,
76        &payload,
77        usize::try_from(options.requested_packet_size())
78            .map_err(|_| Login7Error::MessageTooLarge)?,
79    )
80    .map_err(Login7Error::Packet)
81}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq)]
84struct Login7FieldOffset {
85    offset: u16,
86    len_chars: u16,
87}
88
89impl Login7FieldOffset {
90    fn empty(offset: u16) -> Self {
91        Self {
92            offset,
93            len_chars: 0,
94        }
95    }
96
97    fn write_to(self, out: &mut Vec<u8>) {
98        write_u16_le(out, self.offset);
99        write_u16_le(out, self.len_chars);
100    }
101}
102
103struct Login7Fields {
104    data: Vec<u8>,
105    next_offset: u16,
106}
107
108impl Login7Fields {
109    fn new(base_offset: usize) -> Self {
110        Self {
111            data: Vec::new(),
112            next_offset: u16::try_from(base_offset).expect("LOGIN7 fixed header fits in u16"),
113        }
114    }
115
116    fn push_text(
117        &mut self,
118        value: &str,
119        obfuscate: bool,
120    ) -> Result<Login7FieldOffset, Login7Error> {
121        let offset = self.next_offset;
122        let len_chars =
123            u16::try_from(value.encode_utf16().count()).map_err(|_| Login7Error::FieldTooLong)?;
124        let encoded = encode_utf16_le(value, obfuscate);
125        let encoded_len = u16::try_from(encoded.len()).map_err(|_| Login7Error::MessageTooLarge)?;
126
127        self.next_offset = self
128            .next_offset
129            .checked_add(encoded_len)
130            .ok_or(Login7Error::MessageTooLarge)?;
131        self.data.extend_from_slice(&encoded);
132
133        Ok(Login7FieldOffset { offset, len_chars })
134    }
135}
136
/// Encodes `value` as UTF-16 little-endian bytes, two bytes per code unit.
///
/// When `obfuscate` is set, the nibbles of every output byte are swapped and the
/// result is XORed with `0xa5`. This is the transform applied to the LOGIN7 password.
fn encode_utf16_le(value: &str, obfuscate: bool) -> Vec<u8> {
    let plain = value.encode_utf16().flat_map(u16::to_le_bytes);

    if obfuscate {
        plain.map(|byte| byte.rotate_left(4) ^ 0xa5).collect()
    } else {
        plain.collect()
    }
}
152
/// Appends `value` to `out` as two little-endian bytes.
fn write_u16_le(out: &mut Vec<u8>, value: u16) {
    out.extend(value.to_le_bytes());
}
156
/// Appends `value` to `out` as four little-endian bytes.
fn write_u32_le(out: &mut Vec<u8>, value: u32) {
    out.extend(value.to_le_bytes());
}
160
/// Appends `value` to `out` as four little-endian (two's-complement) bytes.
fn write_i32_le(out: &mut Vec<u8>, value: i32) {
    out.extend(value.to_le_bytes());
}
164
165/// Error returned while building a LOGIN7 packet.
166#[derive(Debug, Error, PartialEq, Eq)]
167pub enum Login7Error {
168    /// A text field exceeds the 16-bit LOGIN7 character-count field.
169    #[error("TDS LOGIN7 text field is too long")]
170    FieldTooLong,
171    /// The payload cannot fit in LOGIN7's 16-bit offset fields.
172    #[error("TDS LOGIN7 message is too large")]
173    MessageTooLarge,
174    /// Packet framing failed.
175    #[error(transparent)]
176    Packet(#[from] PacketFrameError),
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::packet::{PacketHeader, PacketStatus, PACKET_HEADER_LEN};

    /// Reads a little-endian `u32` starting at byte `at`.
    fn le_u32(bytes: &[u8], at: usize) -> u32 {
        u32::from_le_bytes(bytes[at..at + 4].try_into().unwrap())
    }

    /// Reads a little-endian `u16` starting at byte `at`.
    fn le_u16(bytes: &[u8], at: usize) -> u16 {
        u16::from_le_bytes(bytes[at..at + 2].try_into().unwrap())
    }

    /// Decodes the `(payload offset, UTF-16 length)` descriptor stored at header byte `at`.
    fn descriptor_at(payload: &[u8], at: usize) -> (usize, usize) {
        (
            usize::from(le_u16(payload, at)),
            usize::from(le_u16(payload, at + 2)),
        )
    }

    /// Returns the bytes a descriptor points at (two bytes per UTF-16 code unit).
    fn descriptor_bytes(payload: &[u8], (start, chars): (usize, usize)) -> &[u8] {
        &payload[start..start + chars * 2]
    }

    #[test]
    fn builds_login7_payload_with_little_endian_fixed_fields() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://alice:secret@example.com/appdb?packet_size=512&client_program_version=42&client_pid=7",
        )
        .unwrap();
        let payload = build_login7_payload(&options).unwrap();

        assert_eq!(payload.len() as u32, le_u32(&payload, 0));
        assert_eq!(TDS_VERSION_74, le_u32(&payload, 4));
        assert_eq!(512, le_u32(&payload, 8));
        assert_eq!(42, le_u32(&payload, 12));
        assert_eq!(7, le_u32(&payload, 16));
        assert_eq!(
            [OPTION_FLAGS_1, OPTION_FLAGS_2, TYPE_FLAGS, OPTION_FLAGS_3],
            payload[24..28]
        );
    }

    #[test]
    fn encodes_variable_fields_as_utf16_with_character_lengths() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://al:pw@example.com/db?hostname=client&app_name=sqlx",
        )
        .unwrap();
        let payload = build_login7_payload(&options).unwrap();

        // (header byte of descriptor, expected descriptor, expected plain UTF-16LE bytes)
        let expected: [(usize, (usize, usize), &[u8]); 4] = [
            (36, (94, 6), b"c\0l\0i\0e\0n\0t\0"),
            (40, (106, 2), b"a\0l\0"),
            (48, (114, 4), b"s\0q\0l\0x\0"),
            (68, (122, 2), b"d\0b\0"),
        ];
        for (header_at, descriptor, text) in expected {
            let actual = descriptor_at(&payload, header_at);
            assert_eq!(descriptor, actual);
            assert_eq!(text, descriptor_bytes(&payload, actual));
        }

        let password = descriptor_at(&payload, 44);
        assert_eq!((110, 2), password);
        let stored = descriptor_bytes(&payload, password);
        assert_eq!(encode_utf16_le("pw", true).as_slice(), stored);
        assert_ne!(b"p\0w\0", stored);
    }

    #[test]
    fn frames_login7_payload_as_login7_packet() {
        let options = MssqlConnectOptions::parse_url(
            "mssql://alice:secret@example.com/master?packet_size=512",
        )
        .unwrap();
        let packet = build_login7_packet(&options).unwrap();
        let header = PacketHeader::decode(&packet[..PACKET_HEADER_LEN]).unwrap();

        assert_eq!(PacketType::LOGIN7, header.packet_type);
        assert_eq!(PacketStatus::END_OF_MESSAGE, header.status);
        assert_eq!(packet.len(), usize::from(header.length));

        // The LOGIN7 length word right after the packet header covers the whole payload.
        let declared = le_u32(&packet, PACKET_HEADER_LEN) as usize;
        assert_eq!(packet.len() - PACKET_HEADER_LEN, declared);
    }

    #[test]
    fn rejects_text_fields_that_do_not_fit_login7_lengths() {
        let mut options = MssqlConnectOptions::new();
        let oversized = "a".repeat(usize::from(u16::MAX) + 1);
        options.set_hostname_for_test(oversized);

        let err = build_login7_payload(&options).unwrap_err();

        assert_eq!(Login7Error::FieldTooLong, err);
    }
}