1use bytes::{BufMut, Bytes, BytesMut};
19
20use crate::codec::write_utf16_string;
21use crate::prelude::*;
22use crate::version::TdsVersion;
23
/// Size in bytes of the fixed-length portion of an encoded LOGIN7 message.
///
/// [`Login7::encode`] uses this value as the offset of the first
/// variable-length field and adds it to the size of the variable data to
/// obtain the total message length.
pub const LOGIN7_HEADER_SIZE: usize = 94;
26
/// First option-flags byte of the LOGIN7 header.
///
/// Every field is a plain boolean; [`OptionFlags1::to_byte`] maps each one to
/// the bit mask noted on the field.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags1 {
    /// Sets bit `0x01` when true.
    pub byte_order_be: bool,
    /// Sets bit `0x02` when true.
    pub char_ebcdic: bool,
    /// Not encoded: `to_byte` never sets a bit for this field.
    // NOTE(review): presumably IEEE 754 is the all-zero value of bits 2-3 in
    // MS-TDS, which would make this correct — confirm against the spec.
    pub float_ieee: bool,
    /// Sets bit `0x10` when true.
    pub dump_load_off: bool,
    /// Sets bit `0x20` when true.
    pub use_db_notify: bool,
    /// Sets bit `0x40` when true.
    pub database_fatal: bool,
    /// Sets bit `0x80` when true.
    pub set_lang_warn: bool,
}

impl OptionFlags1 {
    /// Packs the flags into a single byte.
    ///
    /// Bits `0x04` and `0x08` are always left clear.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        // (flag, mask) pairs; the masks of all enabled flags are OR-ed together.
        let masks = [
            (self.byte_order_be, 0x01u8),
            (self.char_ebcdic, 0x02),
            (self.dump_load_off, 0x10),
            (self.use_db_notify, 0x20),
            (self.database_fatal, 0x40),
            (self.set_lang_warn, 0x80),
        ];

        masks
            .iter()
            .filter(|&&(enabled, _)| enabled)
            .fold(0u8, |byte, &(_, mask)| byte | mask)
    }
}
83
/// Second option-flags byte of the LOGIN7 header.
///
/// [`OptionFlags2::to_byte`] maps each field to the bit position noted on it.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags2 {
    /// Bit 0 (`0x01`).
    pub language_fatal: bool,
    /// Bit 1 (`0x02`).
    pub odbc: bool,
    /// Bit 2 (`0x04`).
    pub tran_boundary: bool,
    /// Bit 3 (`0x08`).
    pub cache_connect: bool,
    /// Three-bit value stored in bits 4-6; higher bits of this field are
    /// discarded when encoding.
    pub user_type: u8,
    /// Bit 7 (`0x80`).
    pub integrated_security: bool,
}

impl OptionFlags2 {
    /// Packs the flags into a single byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        // Each boolean becomes 0 or 1 and is shifted to its bit position.
        let language_fatal = u8::from(self.language_fatal);
        let odbc = u8::from(self.odbc) << 1;
        let tran_boundary = u8::from(self.tran_boundary) << 2;
        let cache_connect = u8::from(self.cache_connect) << 3;
        // Only the low three bits of `user_type` fit in the field.
        let user_type = (self.user_type & 0x07) << 4;
        let integrated_security = u8::from(self.integrated_security) << 7;

        language_fatal | odbc | tran_boundary | cache_connect | user_type | integrated_security
    }
}
125
/// Type-flags byte of the LOGIN7 header.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeFlags {
    /// Four-bit value stored in the low nibble; higher bits of this field are
    /// discarded when encoding.
    pub sql_type: u8,
    /// Sets bit `0x10` when true.
    pub oledb: bool,
    /// Sets bit `0x20` when true.
    pub read_only_intent: bool,
}

impl TypeFlags {
    /// Packs the flags into a single byte.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let sql_type = self.sql_type & 0x0F;
        let oledb = if self.oledb { 0x10 } else { 0x00 };
        let read_only_intent = if self.read_only_intent { 0x20 } else { 0x00 };

        sql_type | oledb | read_only_intent
    }
}
152
/// Third option-flags byte of the LOGIN7 header.
///
/// [`OptionFlags3::to_byte`] maps each field to the bit mask noted on it.
/// The masks follow the OptionFlags3 layout of MS-TDS LOGIN7
/// (fChangePassword, fSendYukonBinaryXML, fUserInstance,
/// fUnknownCollationHandling, fExtension, from the least significant bit up).
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionFlags3 {
    /// Bit 0 (`0x01`).
    pub change_password: bool,
    /// Bit 2 (`0x04`).
    pub user_instance: bool,
    /// Bit 1 (`0x02`).
    pub send_yukon_binary_xml: bool,
    /// Bit 3 (`0x08`).
    pub unknown_collation_handling: bool,
    /// Bit 4 (`0x10`).
    pub extension: bool,
}

impl OptionFlags3 {
    /// Packs the flags into a single byte.
    ///
    /// The three most significant bits are always left clear.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        let mut flags = 0u8;
        if self.change_password {
            flags |= 0x01;
        }
        // fSendYukonBinaryXML precedes fUserInstance in the wire layout, even
        // though the struct declares the fields in the opposite order.
        if self.send_yukon_binary_xml {
            flags |= 0x02;
        }
        if self.user_instance {
            flags |= 0x04;
        }
        if self.unknown_collation_handling {
            flags |= 0x08;
        }
        if self.extension {
            flags |= 0x10;
        }
        flags
    }
}
191
/// Identifier byte of an entry in the LOGIN7 feature-extension block.
///
/// The discriminant is the value written on the wire: [`Login7::encode`]
/// emits `feature_id as u8` in front of each feature's data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum FeatureId {
    /// Feature identifier `0x01`.
    SessionRecovery = 0x01,
    /// Feature identifier `0x02`.
    FedAuth = 0x02,
    /// Feature identifier `0x04`.
    ColumnEncryption = 0x04,
    /// Feature identifier `0x05`.
    GlobalTransactions = 0x05,
    /// Feature identifier `0x08`.
    AzureSqlSupport = 0x08,
    /// Feature identifier `0x09`.
    DataClassification = 0x09,
    /// Feature identifier `0x0A`.
    Utf8Support = 0x0A,
    /// Feature identifier `0x0B`.
    AzureSqlDnsCaching = 0x0B,
    /// End-of-list marker (`0xFF`) written after the last feature entry.
    Terminator = 0xFF,
}
216
/// Contents of a LOGIN7 message, serialized by [`Login7::encode`].
///
/// String fields are written to the variable-length section of the message;
/// the fixed header records an offset and a length (in UTF-16 code units) for
/// each of them.
#[derive(Debug, Clone)]
pub struct Login7 {
    /// TDS version; its `raw()` value is written to the header.
    pub tds_version: TdsVersion,
    /// Packet size written to the header (4096 by default).
    pub packet_size: u32,
    /// Client program version written to the header (0 by default).
    pub client_prog_version: u32,
    /// Client process id written to the header. Defaults to
    /// `std::process::id()` when the `std` feature is enabled, otherwise 0.
    pub client_pid: u32,
    /// Connection id written to the header (0 by default).
    pub connection_id: u32,
    /// First option-flags byte.
    pub option_flags1: OptionFlags1,
    /// Second option-flags byte.
    pub option_flags2: OptionFlags2,
    /// Type-flags byte.
    pub type_flags: TypeFlags,
    /// Third option-flags byte; its `extension` flag controls whether
    /// `features` are encoded.
    pub option_flags3: OptionFlags3,
    /// Client time zone value, written as a little-endian `i32`.
    pub client_timezone: i32,
    /// Client LCID (`0x0409` by default).
    pub client_lcid: u32,
    /// Client host name.
    pub hostname: String,
    /// Login user name.
    pub username: String,
    /// Login password; obfuscated byte-by-byte before being written.
    pub password: String,
    /// Application name (`"rust-mssql-driver"` by default).
    pub app_name: String,
    /// Server name.
    pub server_name: String,
    /// Value of the slot shared with the extension pointer. It is only
    /// written when `option_flags3.extension` is false.
    pub unused: String,
    /// Client library name (`"rust-mssql-driver"` by default).
    pub library_name: String,
    /// Language name.
    pub language: String,
    /// Database name.
    pub database: String,
    /// Six bytes copied verbatim into the header.
    pub client_id: [u8; 6],
    /// SSPI bytes copied verbatim into the variable-length section.
    pub sspi_data: Vec<u8>,
    /// Database file name to attach.
    pub attach_db_file: String,
    /// New password; obfuscated the same way as `password`.
    pub new_password: String,
    /// Feature-extension entries. They are encoded only when
    /// `option_flags3.extension` is true.
    pub features: Vec<FeatureExtension>,
}
271
/// One entry of the LOGIN7 feature-extension block.
///
/// [`Login7::encode`] writes each entry as the identifier byte, the data
/// length as a little-endian `u32`, and then the data bytes.
#[derive(Debug, Clone)]
pub struct FeatureExtension {
    /// Identifier written as the first byte of the entry.
    pub feature_id: FeatureId,
    /// Payload bytes of the entry.
    pub data: Bytes,
}
280
281impl Default for Login7 {
282 fn default() -> Self {
283 #[cfg(feature = "std")]
284 let client_pid = std::process::id();
285 #[cfg(not(feature = "std"))]
286 let client_pid = 0;
287
288 Self {
289 tds_version: TdsVersion::V7_4,
290 packet_size: 4096,
291 client_prog_version: 0,
292 client_pid,
293 connection_id: 0,
294 option_flags1: OptionFlags1 {
296 use_db_notify: true,
297 database_fatal: true,
298 ..Default::default()
299 },
300 option_flags2: OptionFlags2 {
301 language_fatal: true,
302 odbc: true,
303 ..Default::default()
304 },
305 type_flags: TypeFlags::default(), option_flags3: OptionFlags3 {
307 unknown_collation_handling: true,
308 ..Default::default()
309 },
310 client_timezone: 0,
311 client_lcid: 0x0409, hostname: String::new(),
313 username: String::new(),
314 password: String::new(),
315 app_name: String::from("rust-mssql-driver"),
316 server_name: String::new(),
317 unused: String::new(),
318 library_name: String::from("rust-mssql-driver"),
319 language: String::new(),
320 database: String::new(),
321 client_id: [0u8; 6],
322 sspi_data: Vec::new(),
323 attach_db_file: String::new(),
324 new_password: String::new(),
325 features: Vec::new(),
326 }
327 }
328}
329
330impl Login7 {
331 #[must_use]
333 pub fn new() -> Self {
334 Self::default()
335 }
336
337 #[must_use]
339 pub fn with_tds_version(mut self, version: TdsVersion) -> Self {
340 self.tds_version = version;
341 self
342 }
343
344 #[must_use]
346 pub fn with_sql_auth(
347 mut self,
348 username: impl Into<String>,
349 password: impl Into<String>,
350 ) -> Self {
351 self.username = username.into();
352 self.password = password.into();
353 self.option_flags2.integrated_security = false;
354 self
355 }
356
357 #[must_use]
359 pub fn with_integrated_auth(mut self, sspi_data: Vec<u8>) -> Self {
360 self.sspi_data = sspi_data;
361 self.option_flags2.integrated_security = true;
362 self
363 }
364
365 #[must_use]
367 pub fn with_database(mut self, database: impl Into<String>) -> Self {
368 self.database = database.into();
369 self
370 }
371
372 #[must_use]
374 pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self {
375 self.hostname = hostname.into();
376 self
377 }
378
379 #[must_use]
381 pub fn with_app_name(mut self, app_name: impl Into<String>) -> Self {
382 self.app_name = app_name.into();
383 self
384 }
385
386 #[must_use]
388 pub fn with_server_name(mut self, server_name: impl Into<String>) -> Self {
389 self.server_name = server_name.into();
390 self
391 }
392
393 #[must_use]
398 pub fn with_language(mut self, language: impl Into<String>) -> Self {
399 self.language = language.into();
400 self
401 }
402
403 #[must_use]
405 pub fn with_packet_size(mut self, packet_size: u32) -> Self {
406 self.packet_size = packet_size;
407 self
408 }
409
410 #[must_use]
412 pub fn with_read_only_intent(mut self, read_only: bool) -> Self {
413 self.type_flags.read_only_intent = read_only;
414 self
415 }
416
417 #[must_use]
419 pub fn with_feature(mut self, feature: FeatureExtension) -> Self {
420 self.option_flags3.extension = true;
421 self.features.push(feature);
422 self
423 }
424
425 #[must_use]
427 pub fn encode(&self) -> Bytes {
428 let mut buf = BytesMut::with_capacity(512);
429
430 let mut offset = LOGIN7_HEADER_SIZE as u16;
433
434 let hostname_len = self.hostname.encode_utf16().count() as u16;
436 let username_len = self.username.encode_utf16().count() as u16;
437 let password_len = self.password.encode_utf16().count() as u16;
438 let app_name_len = self.app_name.encode_utf16().count() as u16;
439 let server_name_len = self.server_name.encode_utf16().count() as u16;
440 let unused_len = self.unused.encode_utf16().count() as u16;
441 let library_name_len = self.library_name.encode_utf16().count() as u16;
442 let language_len = self.language.encode_utf16().count() as u16;
443 let database_len = self.database.encode_utf16().count() as u16;
444 let sspi_len = self.sspi_data.len() as u16;
445 let attach_db_len = self.attach_db_file.encode_utf16().count() as u16;
446 let new_password_len = self.new_password.encode_utf16().count() as u16;
447
448 let mut var_data = BytesMut::new();
450
451 let hostname_offset = offset;
453 write_utf16_string(&mut var_data, &self.hostname);
454 offset += hostname_len * 2;
455
456 let username_offset = offset;
458 write_utf16_string(&mut var_data, &self.username);
459 offset += username_len * 2;
460
461 let password_offset = offset;
463 Self::write_obfuscated_password(&mut var_data, &self.password);
464 offset += password_len * 2;
465
466 let app_name_offset = offset;
468 write_utf16_string(&mut var_data, &self.app_name);
469 offset += app_name_len * 2;
470
471 let server_name_offset = offset;
473 write_utf16_string(&mut var_data, &self.server_name);
474 offset += server_name_len * 2;
475
476 let extension_offset = if self.option_flags3.extension {
493 let feature_data_offset = offset
496 + 4 + library_name_len * 2
498 + language_len * 2
499 + database_len * 2
500 + sspi_len
501 + attach_db_len * 2
502 + new_password_len * 2;
503 let pointer_offset = offset;
504 var_data.put_u32_le(feature_data_offset as u32);
507 offset += 4;
508 pointer_offset
509 } else {
510 let unused_offset = offset;
511 write_utf16_string(&mut var_data, &self.unused);
512 offset += unused_len * 2;
513 unused_offset
514 };
515
516 let library_name_offset = offset;
518 write_utf16_string(&mut var_data, &self.library_name);
519 offset += library_name_len * 2;
520
521 let language_offset = offset;
523 write_utf16_string(&mut var_data, &self.language);
524 offset += language_len * 2;
525
526 let database_offset = offset;
528 write_utf16_string(&mut var_data, &self.database);
529 offset += database_len * 2;
530
531 let sspi_offset = offset;
536 var_data.put_slice(&self.sspi_data);
537 offset += sspi_len;
538
539 let attach_db_offset = offset;
541 write_utf16_string(&mut var_data, &self.attach_db_file);
542 offset += attach_db_len * 2;
543
544 let new_password_offset = offset;
546 if !self.new_password.is_empty() {
547 Self::write_obfuscated_password(&mut var_data, &self.new_password);
548 }
549 #[allow(unused_assignments)]
550 {
551 offset += new_password_len * 2;
552 }
553
554 if self.option_flags3.extension {
556 for feature in &self.features {
557 var_data.put_u8(feature.feature_id as u8);
558 var_data.put_u32_le(feature.data.len() as u32);
559 var_data.put_slice(&feature.data);
560 }
561 var_data.put_u8(FeatureId::Terminator as u8);
562 }
563
564 let total_length = LOGIN7_HEADER_SIZE + var_data.len();
566
567 buf.put_u32_le(total_length as u32); buf.put_u32_le(self.tds_version.raw()); buf.put_u32_le(self.packet_size); buf.put_u32_le(self.client_prog_version); buf.put_u32_le(self.client_pid); buf.put_u32_le(self.connection_id); buf.put_u8(self.option_flags1.to_byte());
577 buf.put_u8(self.option_flags2.to_byte());
578 buf.put_u8(self.type_flags.to_byte());
579 buf.put_u8(self.option_flags3.to_byte());
580
581 buf.put_i32_le(self.client_timezone); buf.put_u32_le(self.client_lcid); buf.put_u16_le(hostname_offset);
586 buf.put_u16_le(hostname_len);
587 buf.put_u16_le(username_offset);
588 buf.put_u16_le(username_len);
589 buf.put_u16_le(password_offset);
590 buf.put_u16_le(password_len);
591 buf.put_u16_le(app_name_offset);
592 buf.put_u16_le(app_name_len);
593 buf.put_u16_le(server_name_offset);
594 buf.put_u16_le(server_name_len);
595
596 if self.option_flags3.extension {
598 buf.put_u16_le(extension_offset as u16);
599 buf.put_u16_le(4); } else {
601 buf.put_u16_le(extension_offset as u16);
602 buf.put_u16_le(unused_len);
603 }
604
605 buf.put_u16_le(library_name_offset);
606 buf.put_u16_le(library_name_len);
607 buf.put_u16_le(language_offset);
608 buf.put_u16_le(language_len);
609 buf.put_u16_le(database_offset);
610 buf.put_u16_le(database_len);
611
612 buf.put_slice(&self.client_id);
614
615 buf.put_u16_le(sspi_offset);
616 buf.put_u16_le(sspi_len);
617 buf.put_u16_le(attach_db_offset);
618 buf.put_u16_le(attach_db_len);
619 buf.put_u16_le(new_password_offset);
620 buf.put_u16_le(new_password_len);
621
622 buf.put_u32_le(0);
624
625 buf.put_slice(&var_data);
627
628 buf.freeze()
629 }
630
631 fn write_obfuscated_password(dst: &mut impl BufMut, password: &str) {
636 for c in password.encode_utf16() {
637 let low = (c & 0xFF) as u8;
638 let high = ((c >> 8) & 0xFF) as u8;
639
640 let low_enc = low.rotate_right(4) ^ 0xA5;
643 let high_enc = high.rotate_right(4) ^ 0xA5;
644
645 dst.put_u8(low_enc);
646 dst.put_u8(high_enc);
647 }
648 }
649}
650
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `Login7::new` yields the documented defaults.
    #[test]
    fn test_login7_default() {
        let login = Login7::new();
        assert_eq!(login.tds_version, TdsVersion::V7_4);
        assert_eq!(login.packet_size, 4096);
        assert!(login.option_flags2.odbc);
    }

    /// An encoded request is at least one header long and carries the TDS
    /// version in bytes 4..8.
    #[test]
    fn test_login7_encode() {
        let login = Login7::new()
            .with_hostname("TESTHOST")
            .with_sql_auth("testuser", "testpass")
            .with_database("testdb")
            .with_app_name("TestApp");

        let encoded = login.encode();

        assert!(encoded.len() >= LOGIN7_HEADER_SIZE);

        let tds_version = u32::from_le_bytes([encoded[4], encoded[5], encoded[6], encoded[7]]);
        assert_eq!(tds_version, TdsVersion::V7_4.raw());
    }

    /// Checks the obfuscation of a single ASCII character.
    #[test]
    fn test_password_obfuscation() {
        let mut buf = BytesMut::new();
        Login7::write_obfuscated_password(&mut buf, "a");

        assert_eq!(buf.len(), 2);
        // 'a' is U+0061: low byte 0x61 -> nibble swap 0x16 -> XOR 0xA5 = 0xB3,
        // high byte 0x00 -> nibble swap 0x00 -> XOR 0xA5 = 0xA5.
        assert_eq!(buf[0], 0xB3);
        assert_eq!(buf[1], 0xA5);
    }

    /// With a feature attached, the extension slot must hold a 4-byte pointer
    /// that in turn locates the feature block.
    #[test]
    fn test_login7_feature_extension_pointer_indirection() {
        let login = Login7::new()
            .with_hostname("HOST")
            .with_sql_auth("u", "p")
            .with_database("db")
            .with_app_name("app")
            .with_feature(FeatureExtension {
                feature_id: FeatureId::ColumnEncryption,
                data: Bytes::from_static(&[0x01]),
            });

        let encoded = login.encode();
        assert!(encoded.len() >= LOGIN7_HEADER_SIZE);

        // Byte 27 is the OptionFlags3 byte (24 bytes of u32 fields, then the
        // four flag bytes).
        assert_eq!(
            encoded[27] & 0x10,
            0x10,
            "option_flags3.extension bit must be set"
        );

        // The (offset, length) table starts at byte 36; the extension slot is
        // the sixth pair, after five 4-byte pairs.
        const OFFSET_TABLE_START: usize = 36;
        const EXTENSION_SLOT: usize = OFFSET_TABLE_START + 5 * 4;
        let ib_extension =
            u16::from_le_bytes([encoded[EXTENSION_SLOT], encoded[EXTENSION_SLOT + 1]]) as usize;
        let cb_extension =
            u16::from_le_bytes([encoded[EXTENSION_SLOT + 2], encoded[EXTENSION_SLOT + 3]]);
        assert_eq!(cb_extension, 4, "cbExtension must be 4 per MS-TDS §2.2.6.4");

        assert!(
            ib_extension + 4 <= encoded.len(),
            "ibExtension out of bounds"
        );

        let feature_ext_offset = u32::from_le_bytes([
            encoded[ib_extension],
            encoded[ib_extension + 1],
            encoded[ib_extension + 2],
            encoded[ib_extension + 3],
        ]) as usize;
        // The block read below spans seven bytes: id (1) + length (4) +
        // payload (1) + terminator (1), i.e. indices +0 through +6.
        assert!(
            feature_ext_offset + 7 <= encoded.len(),
            "FeatureExt offset {feature_ext_offset} out of bounds (packet is {} bytes)",
            encoded.len()
        );
        assert_eq!(
            encoded[feature_ext_offset], 0x04,
            "first byte of FeatureExt block should be FeatureId::ColumnEncryption (0x04)"
        );
        let data_len = u32::from_le_bytes([
            encoded[feature_ext_offset + 1],
            encoded[feature_ext_offset + 2],
            encoded[feature_ext_offset + 3],
            encoded[feature_ext_offset + 4],
        ]);
        assert_eq!(data_len, 1, "ColumnEncryption version payload is 1 byte");
        assert_eq!(
            encoded[feature_ext_offset + 5],
            0x01,
            "ColumnEncryption payload is version byte 0x01"
        );
        assert_eq!(
            encoded[feature_ext_offset + 6],
            0xFF,
            "FeatureExt stream terminator 0xFF must follow"
        );

        assert!(
            ib_extension < feature_ext_offset,
            "ibExtension ({ib_extension}) must point at the u32 pointer, \
             which lives before FeatureExt data ({feature_ext_offset})"
        );
    }

    /// Spot-checks the byte encoding of the option-flag groups.
    #[test]
    fn test_option_flags() {
        let flags1 = OptionFlags1::default();
        assert_eq!(flags1.to_byte(), 0x00);

        let flags2 = OptionFlags2 {
            odbc: true,
            integrated_security: true,
            ..Default::default()
        };
        assert_eq!(flags2.to_byte(), 0x82);

        let flags3 = OptionFlags3 {
            extension: true,
            ..Default::default()
        };
        assert_eq!(flags3.to_byte(), 0x10);
    }
}