1use bytes::{Buf, BufMut, Bytes};
33
34use crate::codec::{read_b_varchar, read_us_varchar};
35use crate::error::ProtocolError;
36use crate::prelude::*;
37use crate::types::TypeId;
38
/// Token-type byte that introduces each token in a TDS token stream.
///
/// Every variant's discriminant is its on-the-wire byte value, so
/// `token as u8` yields the byte to write.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
    /// `COLMETADATA`: result-set column descriptions.
    ColMetaData = 0x81,
    /// `ERROR`: server error message.
    Error = 0xAA,
    /// `INFO`: informational server message.
    Info = 0xAB,
    /// `LOGINACK`: login acknowledgement.
    LoginAck = 0xAD,
    /// `ROW`: one row of column values.
    Row = 0xD1,
    /// `NBCROW`: one row prefixed by a null bitmap.
    NbcRow = 0xD2,
    /// `ENVCHANGE`: environment change notification.
    EnvChange = 0xE3,
    /// `SSPI`: SSPI authentication payload.
    Sspi = 0xED,
    /// `DONE`: completion of a SQL statement.
    Done = 0xFD,
    /// `DONEINPROC`: completion of a statement inside a stored procedure.
    DoneInProc = 0xFF,
    /// `DONEPROC`: completion of a stored procedure.
    DoneProc = 0xFE,
    /// `RETURNSTATUS`: stored-procedure return status.
    ReturnStatus = 0x79,
    /// `RETURNVALUE`: returned parameter value.
    ReturnValue = 0xAC,
    /// `ORDER`: ORDER BY column list.
    Order = 0xA9,
    /// `FEATUREEXTACK`: acknowledged feature extensions.
    FeatureExtAck = 0xAE,
    /// `SESSIONSTATE`: session state update.
    SessionState = 0xE4,
    /// `FEDAUTHINFO`: federated authentication information.
    FedAuthInfo = 0xEE,
    /// `COLINFO` token.
    ColInfo = 0xA5,
    /// `TABNAME` token.
    TabName = 0xA4,
    /// `OFFSET` token.
    Offset = 0x78,
}

impl TokenType {
    /// Looks up the token type for a raw token byte.
    ///
    /// Returns `None` for any byte that is not the discriminant of one of
    /// this enum's variants.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Every known variant; the lookup compares each discriminant to `value`.
        const KNOWN: [TokenType; 20] = [
            TokenType::ColMetaData,
            TokenType::Error,
            TokenType::Info,
            TokenType::LoginAck,
            TokenType::Row,
            TokenType::NbcRow,
            TokenType::EnvChange,
            TokenType::Sspi,
            TokenType::Done,
            TokenType::DoneInProc,
            TokenType::DoneProc,
            TokenType::ReturnStatus,
            TokenType::ReturnValue,
            TokenType::Order,
            TokenType::FeatureExtAck,
            TokenType::SessionState,
            TokenType::FedAuthInfo,
            TokenType::ColInfo,
            TokenType::TabName,
            TokenType::Offset,
        ];
        KNOWN.iter().copied().find(|&candidate| candidate as u8 == value)
    }
}
114
/// A single decoded token from a TDS token stream.
///
/// Each variant holds the decoded body of the token with the matching name in
/// [`TokenType`]. Not every [`TokenType`] has a variant here (for example
/// `ColInfo`, `TabName` and `Offset` do not).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Token {
    /// Column metadata describing the rows that follow.
    ColMetaData(ColMetaData),
    /// A row whose column values are stored in [`RawRow::data`].
    Row(RawRow),
    /// A row that carries a null bitmap.
    NbcRow(NbcRow),
    /// Completion of a SQL statement.
    Done(Done),
    /// Completion of a stored procedure.
    DoneProc(DoneProc),
    /// Completion of a statement executed inside a stored procedure.
    DoneInProc(DoneInProc),
    /// Stored-procedure return status.
    ReturnStatus(i32),
    /// A returned parameter value.
    ReturnValue(ReturnValue),
    /// Server error message.
    Error(ServerError),
    /// Informational server message.
    Info(ServerInfo),
    /// Login acknowledgement.
    LoginAck(LoginAck),
    /// Environment change notification.
    EnvChange(EnvChange),
    /// ORDER BY column list.
    Order(Order),
    /// Acknowledged feature extensions.
    FeatureExtAck(FeatureExtAck),
    /// SSPI authentication payload.
    Sspi(SspiToken),
    /// Session state update.
    SessionState(SessionState),
    /// Federated authentication information.
    FedAuthInfo(FedAuthInfo),
}
157
/// Decoded `COLMETADATA` token: the column descriptions for a result set.
#[derive(Debug, Clone, Default)]
pub struct ColMetaData {
    /// One entry per column, in result-set order. Empty when the server sent
    /// the "no metadata" count ([`ColMetaData::NO_METADATA`]).
    pub columns: Vec<ColumnData>,
    /// Column-encryption key table; only populated by
    /// [`ColMetaData::decode_encrypted`] when column metadata is present.
    pub cek_table: Option<crate::crypto::CekTable>,
}

/// Description of a single result-set column.
#[derive(Debug, Clone)]
pub struct ColumnData {
    /// Column name as sent by the server.
    pub name: String,
    /// Data type decoded from [`ColumnData::col_type`].
    pub type_id: TypeId,
    /// Raw type byte as it appeared on the wire.
    pub col_type: u8,
    /// Column flags; bit `0x0001` marks the column as nullable
    /// (see [`ColumnData::is_nullable`]).
    pub flags: u16,
    /// User type id from the column metadata.
    pub user_type: u32,
    /// Type-specific details (length, precision, scale, collation).
    pub type_info: TypeInfo,
    /// Always Encrypted metadata; `Some` only for columns flagged as encrypted
    /// and decoded via [`ColMetaData::decode_encrypted`].
    pub crypto_metadata: Option<crate::crypto::CryptoMetadata>,
}

/// Type-specific parameters read from a `TYPE_INFO` structure.
///
/// Which fields are populated depends on the type; see `decode_type_info`.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
    /// Declared maximum length as sent on the wire, for length-bearing types.
    /// `Some(0xFFFF)` on `BigVarChar`/`BigVarBinary`/`NVarChar` marks a `MAX`
    /// column whose values are decoded as PLP.
    pub max_length: Option<u32>,
    /// Precision (decimal/numeric types only).
    pub precision: Option<u8>,
    /// Scale (decimal/numeric and time/datetime2/datetimeoffset types).
    pub scale: Option<u8>,
    /// Collation, for character types that carry one.
    pub collation: Option<Collation>,
}
200
/// Five-byte SQL Server collation as carried in `TYPE_INFO`.
///
/// The first four wire bytes are kept, little-endian and unmasked, in `lcid`
/// (in MS-TDS the upper 12 bits of this word hold collation flags and a
/// version, which are not stripped here); the fifth byte is `sort_id`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
    /// First four collation bytes, decoded little-endian and stored as received.
    pub lcid: u32,
    /// Fifth collation byte (SQL sort id). When non-zero it takes precedence
    /// over `lcid` in the encoding lookups.
    pub sort_id: u8,
}

impl Collation {
    /// Builds a collation from its five-byte wire form.
    pub fn from_bytes(bytes: &[u8; 5]) -> Self {
        let [b0, b1, b2, b3, sort_id] = *bytes;
        Self {
            lcid: u32::from_le_bytes([b0, b1, b2, b3]),
            sort_id,
        }
    }

    /// Serializes back to the five-byte wire form; inverse of
    /// [`Collation::from_bytes`].
    pub fn to_bytes(&self) -> [u8; 5] {
        let [b0, b1, b2, b3] = self.lcid.to_le_bytes();
        [b0, b1, b2, b3, self.sort_id]
    }

    /// Text encoding for this collation: looked up by `sort_id` when it is
    /// non-zero, otherwise by `lcid`. `None` if the lookup finds nothing.
    #[cfg(feature = "encoding")]
    pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
        match self.sort_id {
            0 => crate::collation::encoding_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id),
        }
    }

    /// Whether `crate::collation::is_utf8_collation` classifies `lcid` as a
    /// UTF-8 collation.
    #[cfg(feature = "encoding")]
    pub fn is_utf8(&self) -> bool {
        crate::collation::is_utf8_collation(self.lcid)
    }

    /// Windows code page for this collation, resolved by `sort_id` when it is
    /// non-zero and by `lcid` otherwise.
    #[cfg(feature = "encoding")]
    pub fn code_page(&self) -> Option<u16> {
        match self.sort_id {
            0 => crate::collation::code_page_for_lcid(self.lcid),
            sort_id => crate::collation::code_page_for_sort_id(sort_id),
        }
    }

    /// Human-readable encoding name. A non-zero `sort_id` with no known
    /// encoding yields `"unsupported"`; otherwise the name comes from `lcid`.
    #[cfg(feature = "encoding")]
    pub fn encoding_name(&self) -> &'static str {
        match self.sort_id {
            0 => crate::collation::encoding_name_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id)
                .map_or("unsupported", |enc| enc.name()),
        }
    }
}
334
/// Decoded `ROW` token.
#[derive(Debug, Clone)]
pub struct RawRow {
    /// All column values back to back in column order, each in the layout
    /// produced by [`RawRow::decode`] (length prefixes retained).
    pub data: bytes::Bytes,
}

/// Decoded `NBCROW` (null-bitmap row) token.
#[derive(Debug, Clone)]
pub struct NbcRow {
    /// Null bitmap: bit `i % 8` of byte `i / 8` is set when column `i` is NULL.
    pub null_bitmap: Vec<u8>,
    /// Values of the non-NULL columns only, in column order, using the same
    /// per-value layout as [`RawRow::data`].
    pub data: bytes::Bytes,
}

/// Decoded `DONE` token.
#[derive(Debug, Clone, Copy)]
pub struct Done {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command value.
    pub cur_cmd: u16,
    /// Row count; per MS-TDS only meaningful when `status.count` is set.
    pub row_count: u64,
}

/// Status flags shared by `DONE`, `DONEPROC` and `DONEINPROC`.
///
/// Status bits that have no field here are dropped by
/// [`DoneStatus::from_bits`].
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
    /// `DONE_MORE` (0x0001): more results follow.
    pub more: bool,
    /// `DONE_ERROR` (0x0002): an error occurred in the current statement.
    pub error: bool,
    /// `DONE_INXACT` (0x0004): a transaction is in progress.
    pub in_xact: bool,
    /// `DONE_COUNT` (0x0010): the row count is valid.
    pub count: bool,
    /// `DONE_ATTN` (0x0020): acknowledges an attention request.
    pub attn: bool,
    /// `DONE_SRVERROR` (0x0100): a server error occurred.
    pub srverror: bool,
}

/// Decoded `DONEINPROC` token; same fields as [`Done`].
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command value.
    pub cur_cmd: u16,
    /// Row count; per MS-TDS only meaningful when `status.count` is set.
    pub row_count: u64,
}

/// Decoded `DONEPROC` token; same fields as [`Done`].
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command value.
    pub cur_cmd: u16,
    /// Row count; per MS-TDS only meaningful when `status.count` is set.
    pub row_count: u64,
}
401
/// Decoded `RETURNVALUE` token: a value returned for a parameter.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReturnValue {
    /// Parameter ordinal.
    pub param_ordinal: u16,
    /// Parameter name (read as a B_VARCHAR).
    pub param_name: String,
    /// Status byte.
    pub status: u8,
    /// User type id.
    pub user_type: u32,
    /// Flags, with the same layout as [`ColumnData::flags`].
    pub flags: u16,
    /// Raw type byte.
    pub col_type: u8,
    /// Type-specific details read from `TYPE_INFO`.
    pub type_info: TypeInfo,
    /// The value, in the same layout [`RawRow::data`] uses for one column.
    pub value: bytes::Bytes,
}

/// Decoded `ERROR` token.
#[derive(Debug, Clone)]
pub struct ServerError {
    /// Error number.
    pub number: i32,
    /// Error state.
    pub state: u8,
    /// Severity class; see [`ServerError::is_fatal`] and
    /// [`ServerError::is_batch_abort`].
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Name of the server that raised the error.
    pub server: String,
    /// Procedure name (may be empty).
    pub procedure: String,
    /// Line number reported by the server.
    pub line: i32,
}

/// Decoded `INFO` token; same wire layout as [`ServerError`].
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Message number.
    pub number: i32,
    /// Message state.
    pub state: u8,
    /// Severity class.
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Name of the server that sent the message.
    pub server: String,
    /// Procedure name (may be empty).
    pub procedure: String,
    /// Line number reported by the server.
    pub line: i32,
}

/// Decoded `LOGINACK` token.
#[derive(Debug, Clone)]
pub struct LoginAck {
    /// Interface byte sent by the server.
    pub interface: u8,
    /// TDS version reported by the server, as a raw 32-bit value.
    pub tds_version: u32,
    /// Server program name.
    pub prog_name: String,
    /// Server program version, as a raw 32-bit value.
    pub prog_version: u32,
}
474
/// Decoded `ENVCHANGE` token.
#[derive(Debug, Clone)]
pub struct EnvChange {
    /// What changed.
    pub env_type: EnvChangeType,
    /// Value after the change.
    pub new_value: EnvChangeValue,
    /// Value before the change.
    pub old_value: EnvChangeValue,
}

/// `Type` byte of an `ENVCHANGE` token; discriminants are the wire values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EnvChangeType {
    /// Current database.
    Database = 1,
    /// Session language.
    Language = 2,
    /// Character set.
    CharacterSet = 3,
    /// Network packet size.
    PacketSize = 4,
    /// Unicode data sorting local id.
    UnicodeSortingLocalId = 5,
    /// Unicode data sorting comparison flags.
    UnicodeComparisonFlags = 6,
    /// SQL collation.
    SqlCollation = 7,
    /// Transaction started.
    BeginTransaction = 8,
    /// Transaction committed.
    CommitTransaction = 9,
    /// Transaction rolled back.
    RollbackTransaction = 10,
    /// Enlisted in a DTC transaction.
    EnlistDtcTransaction = 11,
    /// Defected from a transaction.
    DefectTransaction = 12,
    /// Real-time log shipping.
    RealTimeLogShipping = 13,
    /// Transaction promoted.
    PromoteTransaction = 15,
    /// Transaction manager address.
    TransactionManagerAddress = 16,
    /// Transaction ended.
    TransactionEnded = 17,
    /// Reset-connection completion acknowledgement.
    ResetConnectionCompletionAck = 18,
    /// User instance started.
    UserInstanceStarted = 19,
    /// Routing information.
    Routing = 20,
}

/// Old or new value carried by an [`EnvChange`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvChangeValue {
    /// Textual value.
    String(String),
    /// Binary value.
    Binary(bytes::Bytes),
    /// Routing target.
    Routing {
        /// Target host name.
        host: String,
        /// Target port.
        port: u16,
    },
}
547
/// Decoded `ORDER` token.
#[derive(Debug, Clone)]
pub struct Order {
    /// Column numbers listed in the token.
    pub columns: Vec<u16>,
}

/// Decoded `FEATUREEXTACK` token.
#[derive(Debug, Clone)]
pub struct FeatureExtAck {
    /// Acknowledged features, in the order they were sent.
    pub features: Vec<FeatureAck>,
}

/// One acknowledged feature inside a [`FeatureExtAck`].
#[derive(Debug, Clone)]
pub struct FeatureAck {
    /// Feature id byte.
    pub feature_id: u8,
    /// Feature-specific acknowledgement data.
    pub data: bytes::Bytes,
}

/// Decoded `SSPI` token.
#[derive(Debug, Clone)]
pub struct SspiToken {
    /// Raw SSPI payload.
    pub data: bytes::Bytes,
}

/// Decoded `SESSIONSTATE` token.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Bytes that followed the 4-byte length prefix (see
    /// [`SessionState::decode`]).
    pub data: bytes::Bytes,
}

/// Decoded `FEDAUTHINFO` token.
#[derive(Debug, Clone)]
pub struct FedAuthInfo {
    /// Security token service URL.
    pub sts_url: String,
    /// Service principal name.
    pub spn: String,
}
593
594pub(crate) fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
602 if src.remaining() < 5 {
603 return Err(ProtocolError::UnexpectedEof);
604 }
605 let lcid = src.get_u32_le();
607 let sort_id = src.get_u8();
608 Ok(Collation { lcid, sort_id })
609}
610
611pub(crate) fn decode_type_info(
615 src: &mut impl Buf,
616 type_id: TypeId,
617 col_type: u8,
618) -> Result<TypeInfo, ProtocolError> {
619 match type_id {
620 TypeId::Null => Ok(TypeInfo::default()),
622 TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
623 TypeId::Int2 => Ok(TypeInfo::default()),
624 TypeId::Int4 => Ok(TypeInfo::default()),
625 TypeId::Int8 => Ok(TypeInfo::default()),
626 TypeId::Float4 => Ok(TypeInfo::default()),
627 TypeId::Float8 => Ok(TypeInfo::default()),
628 TypeId::Money => Ok(TypeInfo::default()),
629 TypeId::Money4 => Ok(TypeInfo::default()),
630 TypeId::DateTime => Ok(TypeInfo::default()),
631 TypeId::DateTime4 => Ok(TypeInfo::default()),
632
633 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
635 if src.remaining() < 1 {
636 return Err(ProtocolError::UnexpectedEof);
637 }
638 let max_length = src.get_u8() as u32;
639 Ok(TypeInfo {
640 max_length: Some(max_length),
641 ..Default::default()
642 })
643 }
644
645 TypeId::Guid => {
647 if src.remaining() < 1 {
648 return Err(ProtocolError::UnexpectedEof);
649 }
650 let max_length = src.get_u8() as u32;
651 Ok(TypeInfo {
652 max_length: Some(max_length),
653 ..Default::default()
654 })
655 }
656
657 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
659 if src.remaining() < 3 {
660 return Err(ProtocolError::UnexpectedEof);
661 }
662 let max_length = src.get_u8() as u32;
663 let precision = src.get_u8();
664 let scale = src.get_u8();
665 Ok(TypeInfo {
666 max_length: Some(max_length),
667 precision: Some(precision),
668 scale: Some(scale),
669 ..Default::default()
670 })
671 }
672
673 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
675 if src.remaining() < 1 {
676 return Err(ProtocolError::UnexpectedEof);
677 }
678 let max_length = src.get_u8() as u32;
679 Ok(TypeInfo {
680 max_length: Some(max_length),
681 ..Default::default()
682 })
683 }
684
685 TypeId::BigVarChar | TypeId::BigChar => {
687 if src.remaining() < 7 {
688 return Err(ProtocolError::UnexpectedEof);
690 }
691 let max_length = src.get_u16_le() as u32;
692 let collation = decode_collation(src)?;
693 Ok(TypeInfo {
694 max_length: Some(max_length),
695 collation: Some(collation),
696 ..Default::default()
697 })
698 }
699
700 TypeId::BigVarBinary | TypeId::BigBinary => {
702 if src.remaining() < 2 {
703 return Err(ProtocolError::UnexpectedEof);
704 }
705 let max_length = src.get_u16_le() as u32;
706 Ok(TypeInfo {
707 max_length: Some(max_length),
708 ..Default::default()
709 })
710 }
711
712 TypeId::NChar | TypeId::NVarChar => {
714 if src.remaining() < 7 {
715 return Err(ProtocolError::UnexpectedEof);
717 }
718 let max_length = src.get_u16_le() as u32;
719 let collation = decode_collation(src)?;
720 Ok(TypeInfo {
721 max_length: Some(max_length),
722 collation: Some(collation),
723 ..Default::default()
724 })
725 }
726
727 TypeId::Date => Ok(TypeInfo::default()),
729
730 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
732 if src.remaining() < 1 {
733 return Err(ProtocolError::UnexpectedEof);
734 }
735 let scale = src.get_u8();
736 Ok(TypeInfo {
737 scale: Some(scale),
738 ..Default::default()
739 })
740 }
741
742 TypeId::Text | TypeId::NText | TypeId::Image => {
744 if src.remaining() < 4 {
746 return Err(ProtocolError::UnexpectedEof);
747 }
748 let max_length = src.get_u32_le();
749
750 let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
752 if src.remaining() < 5 {
753 return Err(ProtocolError::UnexpectedEof);
754 }
755 Some(decode_collation(src)?)
756 } else {
757 None
758 };
759
760 if src.remaining() < 1 {
763 return Err(ProtocolError::UnexpectedEof);
764 }
765 let num_parts = src.get_u8();
766 for _ in 0..num_parts {
767 let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
769 }
770
771 Ok(TypeInfo {
772 max_length: Some(max_length),
773 collation,
774 ..Default::default()
775 })
776 }
777
778 TypeId::Xml => {
780 if src.remaining() < 1 {
781 return Err(ProtocolError::UnexpectedEof);
782 }
783 let schema_present = src.get_u8();
784
785 if schema_present != 0 {
786 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; }
793
794 Ok(TypeInfo::default())
795 }
796
797 TypeId::Udt => {
799 if src.remaining() < 2 {
801 return Err(ProtocolError::UnexpectedEof);
802 }
803 let max_length = src.get_u16_le() as u32;
804
805 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; Ok(TypeInfo {
814 max_length: Some(max_length),
815 ..Default::default()
816 })
817 }
818
819 TypeId::Tvp => {
821 Err(ProtocolError::InvalidTokenType(col_type))
824 }
825
826 TypeId::Variant => {
828 if src.remaining() < 4 {
829 return Err(ProtocolError::UnexpectedEof);
830 }
831 let max_length = src.get_u32_le();
832 Ok(TypeInfo {
833 max_length: Some(max_length),
834 ..Default::default()
835 })
836 }
837 }
838}
839
840impl ColMetaData {
841 pub const NO_METADATA: u16 = 0xFFFF;
843
844 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
846 if src.remaining() < 2 {
847 return Err(ProtocolError::UnexpectedEof);
848 }
849
850 let column_count = src.get_u16_le();
851
852 if column_count == Self::NO_METADATA {
854 return Ok(Self {
855 columns: Vec::new(),
856 cek_table: None,
857 });
858 }
859
860 let mut columns = Vec::with_capacity(column_count as usize);
861
862 for _ in 0..column_count {
863 let column = Self::decode_column(src)?;
864 columns.push(column);
865 }
866
867 Ok(Self {
868 columns,
869 cek_table: None,
870 })
871 }
872
873 fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
875 if src.remaining() < 7 {
877 return Err(ProtocolError::UnexpectedEof);
878 }
879
880 let user_type = src.get_u32_le();
881 let flags = src.get_u16_le();
882 let col_type = src.get_u8();
883
884 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
888
889 let type_info = decode_type_info(src, type_id, col_type)?;
891
892 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
894
895 Ok(ColumnData {
896 name,
897 type_id,
898 col_type,
899 flags,
900 user_type,
901 type_info,
902 crypto_metadata: None,
903 })
904 }
905
906 pub fn decode_encrypted(src: &mut impl Buf) -> Result<Self, ProtocolError> {
919 if src.remaining() < 2 {
920 return Err(ProtocolError::UnexpectedEof);
921 }
922
923 let column_count = src.get_u16_le();
924
925 if column_count == Self::NO_METADATA {
926 return Ok(Self {
927 columns: Vec::new(),
928 cek_table: None,
929 });
930 }
931
932 let cek_table = crate::crypto::CekTable::decode(src)?;
934
935 let mut columns = Vec::with_capacity(column_count as usize);
936
937 for _ in 0..column_count {
938 let column = Self::decode_column_encrypted(src)?;
939 columns.push(column);
940 }
941
942 Ok(Self {
943 columns,
944 cek_table: Some(cek_table),
945 })
946 }
947
948 fn decode_column_encrypted(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
952 if src.remaining() < 7 {
953 return Err(ProtocolError::UnexpectedEof);
954 }
955
956 let user_type = src.get_u32_le();
957 let flags = src.get_u16_le();
958 let col_type = src.get_u8();
959
960 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
961
962 let type_info = decode_type_info(src, type_id, col_type)?;
964
965 let crypto_metadata = if crate::crypto::is_column_encrypted(flags) {
967 Some(crate::crypto::CryptoMetadata::decode(src)?)
968 } else {
969 None
970 };
971
972 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
974
975 Ok(ColumnData {
976 name,
977 type_id,
978 col_type,
979 flags,
980 user_type,
981 type_info,
982 crypto_metadata,
983 })
984 }
985
986 #[must_use]
988 pub fn column_count(&self) -> usize {
989 self.columns.len()
990 }
991
992 #[must_use]
994 pub fn is_empty(&self) -> bool {
995 self.columns.is_empty()
996 }
997}
998
999impl ColumnData {
1000 #[must_use]
1002 pub fn is_nullable(&self) -> bool {
1003 (self.flags & 0x0001) != 0
1004 }
1005
1006 #[must_use]
1010 pub fn fixed_size(&self) -> Option<usize> {
1011 match self.type_id {
1012 TypeId::Null => Some(0),
1013 TypeId::Int1 | TypeId::Bit => Some(1),
1014 TypeId::Int2 => Some(2),
1015 TypeId::Int4 => Some(4),
1016 TypeId::Int8 => Some(8),
1017 TypeId::Float4 => Some(4),
1018 TypeId::Float8 => Some(8),
1019 TypeId::Money => Some(8),
1020 TypeId::Money4 => Some(4),
1021 TypeId::DateTime => Some(8),
1022 TypeId::DateTime4 => Some(4),
1023 TypeId::Date => Some(3),
1024 _ => None,
1025 }
1026 }
1027}
1028
1029impl RawRow {
1034 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1039 let mut data = bytes::BytesMut::new();
1040
1041 for col in &metadata.columns {
1042 Self::decode_column_value(src, col, &mut data)?;
1043 }
1044
1045 Ok(Self {
1046 data: data.freeze(),
1047 })
1048 }
1049
1050 fn decode_column_value(
1052 src: &mut impl Buf,
1053 col: &ColumnData,
1054 dst: &mut bytes::BytesMut,
1055 ) -> Result<(), ProtocolError> {
1056 match col.type_id {
1057 TypeId::Null => {
1059 }
1061 TypeId::Int1 | TypeId::Bit => {
1062 if src.remaining() < 1 {
1063 return Err(ProtocolError::UnexpectedEof);
1064 }
1065 dst.extend_from_slice(&[src.get_u8()]);
1066 }
1067 TypeId::Int2 => {
1068 if src.remaining() < 2 {
1069 return Err(ProtocolError::UnexpectedEof);
1070 }
1071 dst.extend_from_slice(&src.get_u16_le().to_le_bytes());
1072 }
1073 TypeId::Int4 => {
1074 if src.remaining() < 4 {
1075 return Err(ProtocolError::UnexpectedEof);
1076 }
1077 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1078 }
1079 TypeId::Int8 => {
1080 if src.remaining() < 8 {
1081 return Err(ProtocolError::UnexpectedEof);
1082 }
1083 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1084 }
1085 TypeId::Float4 => {
1086 if src.remaining() < 4 {
1087 return Err(ProtocolError::UnexpectedEof);
1088 }
1089 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1090 }
1091 TypeId::Float8 => {
1092 if src.remaining() < 8 {
1093 return Err(ProtocolError::UnexpectedEof);
1094 }
1095 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1096 }
1097 TypeId::Money => {
1098 if src.remaining() < 8 {
1099 return Err(ProtocolError::UnexpectedEof);
1100 }
1101 let hi = src.get_u32_le();
1102 let lo = src.get_u32_le();
1103 dst.extend_from_slice(&hi.to_le_bytes());
1104 dst.extend_from_slice(&lo.to_le_bytes());
1105 }
1106 TypeId::Money4 => {
1107 if src.remaining() < 4 {
1108 return Err(ProtocolError::UnexpectedEof);
1109 }
1110 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1111 }
1112 TypeId::DateTime => {
1113 if src.remaining() < 8 {
1114 return Err(ProtocolError::UnexpectedEof);
1115 }
1116 let days = src.get_u32_le();
1117 let time = src.get_u32_le();
1118 dst.extend_from_slice(&days.to_le_bytes());
1119 dst.extend_from_slice(&time.to_le_bytes());
1120 }
1121 TypeId::DateTime4 => {
1122 if src.remaining() < 4 {
1123 return Err(ProtocolError::UnexpectedEof);
1124 }
1125 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1126 }
1127 TypeId::Date => {
1129 Self::decode_bytelen_type(src, dst)?;
1130 }
1131
1132 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
1134 Self::decode_bytelen_type(src, dst)?;
1135 }
1136
1137 TypeId::Guid => {
1138 Self::decode_bytelen_type(src, dst)?;
1139 }
1140
1141 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1142 Self::decode_bytelen_type(src, dst)?;
1143 }
1144
1145 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
1147 Self::decode_bytelen_type(src, dst)?;
1148 }
1149
1150 TypeId::BigVarChar | TypeId::BigVarBinary => {
1152 if col.type_info.max_length == Some(0xFFFF) {
1154 Self::decode_plp_type(src, dst)?;
1155 } else {
1156 Self::decode_ushortlen_type(src, dst)?;
1157 }
1158 }
1159
1160 TypeId::BigChar | TypeId::BigBinary => {
1162 Self::decode_ushortlen_type(src, dst)?;
1163 }
1164
1165 TypeId::NVarChar => {
1167 if col.type_info.max_length == Some(0xFFFF) {
1169 Self::decode_plp_type(src, dst)?;
1170 } else {
1171 Self::decode_ushortlen_type(src, dst)?;
1172 }
1173 }
1174
1175 TypeId::NChar => {
1177 Self::decode_ushortlen_type(src, dst)?;
1178 }
1179
1180 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
1182 Self::decode_bytelen_type(src, dst)?;
1183 }
1184
1185 TypeId::Text | TypeId::NText | TypeId::Image => {
1187 Self::decode_textptr_type(src, dst)?;
1188 }
1189
1190 TypeId::Xml => {
1192 Self::decode_plp_type(src, dst)?;
1193 }
1194
1195 TypeId::Variant => {
1197 Self::decode_intlen_type(src, dst)?;
1198 }
1199
1200 TypeId::Udt => {
1201 Self::decode_plp_type(src, dst)?;
1203 }
1204
1205 TypeId::Tvp => {
1206 return Err(ProtocolError::InvalidTokenType(col.col_type));
1208 }
1209 }
1210
1211 Ok(())
1212 }
1213
1214 fn decode_bytelen_type(
1216 src: &mut impl Buf,
1217 dst: &mut bytes::BytesMut,
1218 ) -> Result<(), ProtocolError> {
1219 if src.remaining() < 1 {
1220 return Err(ProtocolError::UnexpectedEof);
1221 }
1222 let len = src.get_u8() as usize;
1223 if len == 0xFF {
1224 dst.extend_from_slice(&[0xFF]);
1226 } else if len == 0 {
1227 dst.extend_from_slice(&[0x00]);
1229 } else {
1230 if src.remaining() < len {
1231 return Err(ProtocolError::UnexpectedEof);
1232 }
1233 dst.extend_from_slice(&[len as u8]);
1234 for _ in 0..len {
1235 dst.extend_from_slice(&[src.get_u8()]);
1236 }
1237 }
1238 Ok(())
1239 }
1240
1241 fn decode_ushortlen_type(
1243 src: &mut impl Buf,
1244 dst: &mut bytes::BytesMut,
1245 ) -> Result<(), ProtocolError> {
1246 if src.remaining() < 2 {
1247 return Err(ProtocolError::UnexpectedEof);
1248 }
1249 let len = src.get_u16_le() as usize;
1250 if len == 0xFFFF {
1251 dst.extend_from_slice(&0xFFFFu16.to_le_bytes());
1253 } else if len == 0 {
1254 dst.extend_from_slice(&0u16.to_le_bytes());
1256 } else {
1257 if src.remaining() < len {
1258 return Err(ProtocolError::UnexpectedEof);
1259 }
1260 dst.extend_from_slice(&(len as u16).to_le_bytes());
1261 for _ in 0..len {
1262 dst.extend_from_slice(&[src.get_u8()]);
1263 }
1264 }
1265 Ok(())
1266 }
1267
1268 fn decode_intlen_type(
1270 src: &mut impl Buf,
1271 dst: &mut bytes::BytesMut,
1272 ) -> Result<(), ProtocolError> {
1273 if src.remaining() < 4 {
1274 return Err(ProtocolError::UnexpectedEof);
1275 }
1276 let len = src.get_u32_le() as usize;
1277 if len == 0xFFFFFFFF {
1278 dst.extend_from_slice(&0xFFFFFFFFu32.to_le_bytes());
1280 } else if len == 0 {
1281 dst.extend_from_slice(&0u32.to_le_bytes());
1283 } else {
1284 if src.remaining() < len {
1285 return Err(ProtocolError::UnexpectedEof);
1286 }
1287 dst.extend_from_slice(&(len as u32).to_le_bytes());
1288 for _ in 0..len {
1289 dst.extend_from_slice(&[src.get_u8()]);
1290 }
1291 }
1292 Ok(())
1293 }
1294
1295 fn decode_textptr_type(
1310 src: &mut impl Buf,
1311 dst: &mut bytes::BytesMut,
1312 ) -> Result<(), ProtocolError> {
1313 if src.remaining() < 1 {
1314 return Err(ProtocolError::UnexpectedEof);
1315 }
1316
1317 let textptr_len = src.get_u8() as usize;
1318
1319 if textptr_len == 0 {
1320 dst.extend_from_slice(&0xFFFFFFFFFFFFFFFFu64.to_le_bytes());
1322 return Ok(());
1323 }
1324
1325 if src.remaining() < textptr_len {
1327 return Err(ProtocolError::UnexpectedEof);
1328 }
1329 src.advance(textptr_len);
1330
1331 if src.remaining() < 8 {
1333 return Err(ProtocolError::UnexpectedEof);
1334 }
1335 src.advance(8);
1336
1337 if src.remaining() < 4 {
1339 return Err(ProtocolError::UnexpectedEof);
1340 }
1341 let data_len = src.get_u32_le() as usize;
1342
1343 if src.remaining() < data_len {
1344 return Err(ProtocolError::UnexpectedEof);
1345 }
1346
1347 dst.extend_from_slice(&(data_len as u64).to_le_bytes());
1353 dst.extend_from_slice(&(data_len as u32).to_le_bytes());
1354 for _ in 0..data_len {
1355 dst.extend_from_slice(&[src.get_u8()]);
1356 }
1357 dst.extend_from_slice(&0u32.to_le_bytes()); Ok(())
1360 }
1361
1362 fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
1368 if src.remaining() < 8 {
1369 return Err(ProtocolError::UnexpectedEof);
1370 }
1371
1372 let total_len = src.get_u64_le();
1373
1374 dst.extend_from_slice(&total_len.to_le_bytes());
1376
1377 if total_len == 0xFFFFFFFFFFFFFFFF {
1378 return Ok(());
1380 }
1381
1382 loop {
1384 if src.remaining() < 4 {
1385 return Err(ProtocolError::UnexpectedEof);
1386 }
1387 let chunk_len = src.get_u32_le() as usize;
1388 dst.extend_from_slice(&(chunk_len as u32).to_le_bytes());
1389
1390 if chunk_len == 0 {
1391 break;
1393 }
1394
1395 if src.remaining() < chunk_len {
1396 return Err(ProtocolError::UnexpectedEof);
1397 }
1398
1399 for _ in 0..chunk_len {
1400 dst.extend_from_slice(&[src.get_u8()]);
1401 }
1402 }
1403
1404 Ok(())
1405 }
1406}
1407
1408impl NbcRow {
1413 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1418 let col_count = metadata.columns.len();
1419 let bitmap_len = col_count.div_ceil(8);
1420
1421 if src.remaining() < bitmap_len {
1422 return Err(ProtocolError::UnexpectedEof);
1423 }
1424
1425 let mut null_bitmap = vec![0u8; bitmap_len];
1427 for byte in &mut null_bitmap {
1428 *byte = src.get_u8();
1429 }
1430
1431 let mut data = bytes::BytesMut::new();
1433
1434 for (i, col) in metadata.columns.iter().enumerate() {
1435 let byte_idx = i / 8;
1436 let bit_idx = i % 8;
1437 let is_null = (null_bitmap[byte_idx] & (1 << bit_idx)) != 0;
1438
1439 if !is_null {
1440 RawRow::decode_column_value(src, col, &mut data)?;
1443 }
1444 }
1445
1446 Ok(Self {
1447 null_bitmap,
1448 data: data.freeze(),
1449 })
1450 }
1451
1452 #[must_use]
1454 pub fn is_null(&self, column_index: usize) -> bool {
1455 let byte_idx = column_index / 8;
1456 let bit_idx = column_index % 8;
1457 if byte_idx < self.null_bitmap.len() {
1458 (self.null_bitmap[byte_idx] & (1 << bit_idx)) != 0
1459 } else {
1460 true }
1462 }
1463}
1464
1465impl ReturnValue {
1470 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1472 if src.remaining() < 2 {
1479 return Err(ProtocolError::UnexpectedEof);
1480 }
1481 let param_ordinal = src.get_u16_le();
1482
1483 let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1485
1486 if src.remaining() < 1 {
1488 return Err(ProtocolError::UnexpectedEof);
1489 }
1490 let status = src.get_u8();
1491
1492 if src.remaining() < 7 {
1494 return Err(ProtocolError::UnexpectedEof);
1495 }
1496 let user_type = src.get_u32_le();
1497 let flags = src.get_u16_le();
1498 let col_type = src.get_u8();
1499
1500 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
1501
1502 let type_info = decode_type_info(src, type_id, col_type)?;
1504
1505 let mut value_buf = bytes::BytesMut::new();
1507
1508 let temp_col = ColumnData {
1510 name: String::new(),
1511 type_id,
1512 col_type,
1513 flags,
1514 user_type,
1515 type_info: type_info.clone(),
1516 crypto_metadata: None,
1517 };
1518
1519 RawRow::decode_column_value(src, &temp_col, &mut value_buf)?;
1520
1521 Ok(Self {
1522 param_ordinal,
1523 param_name,
1524 status,
1525 user_type,
1526 flags,
1527 col_type,
1528 type_info,
1529 value: value_buf.freeze(),
1530 })
1531 }
1532}
1533
1534impl SessionState {
1539 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1541 if src.remaining() < 4 {
1542 return Err(ProtocolError::UnexpectedEof);
1543 }
1544
1545 let length = src.get_u32_le() as usize;
1546
1547 if src.remaining() < length {
1548 return Err(ProtocolError::IncompletePacket {
1549 expected: length,
1550 actual: src.remaining(),
1551 });
1552 }
1553
1554 let data = src.copy_to_bytes(length);
1555
1556 Ok(Self { data })
1557 }
1558}
1559
/// Bit masks for the `Status` field of `DONE`, `DONEPROC` and `DONEINPROC`.
mod done_status_bits {
    /// More results follow.
    pub const DONE_MORE: u16 = 1 << 0;
    /// An error occurred in the current statement.
    pub const DONE_ERROR: u16 = 1 << 1;
    /// A transaction is in progress.
    pub const DONE_INXACT: u16 = 1 << 2;
    /// The row count field is valid.
    pub const DONE_COUNT: u16 = 1 << 4;
    /// Acknowledges an attention request.
    pub const DONE_ATTN: u16 = 1 << 5;
    /// A server error occurred.
    pub const DONE_SRVERROR: u16 = 1 << 8;
}
1573
1574impl DoneStatus {
1575 #[must_use]
1577 pub fn from_bits(bits: u16) -> Self {
1578 use done_status_bits::*;
1579 Self {
1580 more: (bits & DONE_MORE) != 0,
1581 error: (bits & DONE_ERROR) != 0,
1582 in_xact: (bits & DONE_INXACT) != 0,
1583 count: (bits & DONE_COUNT) != 0,
1584 attn: (bits & DONE_ATTN) != 0,
1585 srverror: (bits & DONE_SRVERROR) != 0,
1586 }
1587 }
1588
1589 #[must_use]
1591 pub fn to_bits(&self) -> u16 {
1592 use done_status_bits::*;
1593 let mut bits = 0u16;
1594 if self.more {
1595 bits |= DONE_MORE;
1596 }
1597 if self.error {
1598 bits |= DONE_ERROR;
1599 }
1600 if self.in_xact {
1601 bits |= DONE_INXACT;
1602 }
1603 if self.count {
1604 bits |= DONE_COUNT;
1605 }
1606 if self.attn {
1607 bits |= DONE_ATTN;
1608 }
1609 if self.srverror {
1610 bits |= DONE_SRVERROR;
1611 }
1612 bits
1613 }
1614}
1615
1616impl Done {
1617 pub const SIZE: usize = 12; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1622 if src.remaining() < Self::SIZE {
1623 return Err(ProtocolError::IncompletePacket {
1624 expected: Self::SIZE,
1625 actual: src.remaining(),
1626 });
1627 }
1628
1629 let status = DoneStatus::from_bits(src.get_u16_le());
1630 let cur_cmd = src.get_u16_le();
1631 let row_count = src.get_u64_le();
1632
1633 Ok(Self {
1634 status,
1635 cur_cmd,
1636 row_count,
1637 })
1638 }
1639
1640 pub fn encode(&self, dst: &mut impl BufMut) {
1642 dst.put_u8(TokenType::Done as u8);
1643 dst.put_u16_le(self.status.to_bits());
1644 dst.put_u16_le(self.cur_cmd);
1645 dst.put_u64_le(self.row_count);
1646 }
1647
1648 #[must_use]
1650 pub const fn has_more(&self) -> bool {
1651 self.status.more
1652 }
1653
1654 #[must_use]
1656 pub const fn has_error(&self) -> bool {
1657 self.status.error
1658 }
1659
1660 #[must_use]
1662 pub const fn has_count(&self) -> bool {
1663 self.status.count
1664 }
1665}
1666
1667impl DoneProc {
1668 pub const SIZE: usize = 12;
1670
1671 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1673 if src.remaining() < Self::SIZE {
1674 return Err(ProtocolError::IncompletePacket {
1675 expected: Self::SIZE,
1676 actual: src.remaining(),
1677 });
1678 }
1679
1680 let status = DoneStatus::from_bits(src.get_u16_le());
1681 let cur_cmd = src.get_u16_le();
1682 let row_count = src.get_u64_le();
1683
1684 Ok(Self {
1685 status,
1686 cur_cmd,
1687 row_count,
1688 })
1689 }
1690
1691 pub fn encode(&self, dst: &mut impl BufMut) {
1693 dst.put_u8(TokenType::DoneProc as u8);
1694 dst.put_u16_le(self.status.to_bits());
1695 dst.put_u16_le(self.cur_cmd);
1696 dst.put_u64_le(self.row_count);
1697 }
1698}
1699
1700impl DoneInProc {
1701 pub const SIZE: usize = 12;
1703
1704 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1706 if src.remaining() < Self::SIZE {
1707 return Err(ProtocolError::IncompletePacket {
1708 expected: Self::SIZE,
1709 actual: src.remaining(),
1710 });
1711 }
1712
1713 let status = DoneStatus::from_bits(src.get_u16_le());
1714 let cur_cmd = src.get_u16_le();
1715 let row_count = src.get_u64_le();
1716
1717 Ok(Self {
1718 status,
1719 cur_cmd,
1720 row_count,
1721 })
1722 }
1723
1724 pub fn encode(&self, dst: &mut impl BufMut) {
1726 dst.put_u8(TokenType::DoneInProc as u8);
1727 dst.put_u16_le(self.status.to_bits());
1728 dst.put_u16_le(self.cur_cmd);
1729 dst.put_u64_le(self.row_count);
1730 }
1731}
1732
1733impl ServerError {
1734 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1736 if src.remaining() < 2 {
1739 return Err(ProtocolError::UnexpectedEof);
1740 }
1741
1742 let _length = src.get_u16_le();
1743
1744 if src.remaining() < 6 {
1745 return Err(ProtocolError::UnexpectedEof);
1746 }
1747
1748 let number = src.get_i32_le();
1749 let state = src.get_u8();
1750 let class = src.get_u8();
1751
1752 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1753 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1754 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1755
1756 if src.remaining() < 4 {
1757 return Err(ProtocolError::UnexpectedEof);
1758 }
1759 let line = src.get_i32_le();
1760
1761 Ok(Self {
1762 number,
1763 state,
1764 class,
1765 message,
1766 server,
1767 procedure,
1768 line,
1769 })
1770 }
1771
1772 #[must_use]
1774 pub const fn is_fatal(&self) -> bool {
1775 self.class >= 20
1776 }
1777
1778 #[must_use]
1780 pub const fn is_batch_abort(&self) -> bool {
1781 self.class >= 16
1782 }
1783}
1784
1785impl ServerInfo {
1786 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1790 if src.remaining() < 2 {
1791 return Err(ProtocolError::UnexpectedEof);
1792 }
1793
1794 let _length = src.get_u16_le();
1795
1796 if src.remaining() < 6 {
1797 return Err(ProtocolError::UnexpectedEof);
1798 }
1799
1800 let number = src.get_i32_le();
1801 let state = src.get_u8();
1802 let class = src.get_u8();
1803
1804 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1805 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1806 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1807
1808 if src.remaining() < 4 {
1809 return Err(ProtocolError::UnexpectedEof);
1810 }
1811 let line = src.get_i32_le();
1812
1813 Ok(Self {
1814 number,
1815 state,
1816 class,
1817 message,
1818 server,
1819 procedure,
1820 line,
1821 })
1822 }
1823}
1824
1825impl LoginAck {
1826 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1828 if src.remaining() < 2 {
1830 return Err(ProtocolError::UnexpectedEof);
1831 }
1832
1833 let _length = src.get_u16_le();
1834
1835 if src.remaining() < 5 {
1836 return Err(ProtocolError::UnexpectedEof);
1837 }
1838
1839 let interface = src.get_u8();
1840 let tds_version = src.get_u32_le();
1841 let prog_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1842
1843 if src.remaining() < 4 {
1844 return Err(ProtocolError::UnexpectedEof);
1845 }
1846 let prog_version = src.get_u32_le();
1847
1848 Ok(Self {
1849 interface,
1850 tds_version,
1851 prog_name,
1852 prog_version,
1853 })
1854 }
1855
1856 #[must_use]
1858 pub fn tds_version(&self) -> crate::version::TdsVersion {
1859 crate::version::TdsVersion::new(self.tds_version)
1860 }
1861}
1862
1863impl EnvChangeType {
1864 pub fn from_u8(value: u8) -> Option<Self> {
1866 match value {
1867 1 => Some(Self::Database),
1868 2 => Some(Self::Language),
1869 3 => Some(Self::CharacterSet),
1870 4 => Some(Self::PacketSize),
1871 5 => Some(Self::UnicodeSortingLocalId),
1872 6 => Some(Self::UnicodeComparisonFlags),
1873 7 => Some(Self::SqlCollation),
1874 8 => Some(Self::BeginTransaction),
1875 9 => Some(Self::CommitTransaction),
1876 10 => Some(Self::RollbackTransaction),
1877 11 => Some(Self::EnlistDtcTransaction),
1878 12 => Some(Self::DefectTransaction),
1879 13 => Some(Self::RealTimeLogShipping),
1880 15 => Some(Self::PromoteTransaction),
1881 16 => Some(Self::TransactionManagerAddress),
1882 17 => Some(Self::TransactionEnded),
1883 18 => Some(Self::ResetConnectionCompletionAck),
1884 19 => Some(Self::UserInstanceStarted),
1885 20 => Some(Self::Routing),
1886 _ => None,
1887 }
1888 }
1889}
1890
1891impl EnvChange {
1892 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1894 if src.remaining() < 3 {
1895 return Err(ProtocolError::UnexpectedEof);
1896 }
1897
1898 let length = src.get_u16_le() as usize;
1899 if length == 0 {
1900 return Err(ProtocolError::UnexpectedEof);
1903 }
1904 if src.remaining() < length {
1905 return Err(ProtocolError::IncompletePacket {
1906 expected: length,
1907 actual: src.remaining(),
1908 });
1909 }
1910
1911 let mut frame = src.copy_to_bytes(length);
1920 let src = &mut frame;
1921
1922 let env_type_byte = src.get_u8();
1923 let env_type = EnvChangeType::from_u8(env_type_byte)
1924 .ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
1925
1926 let (new_value, old_value) = match env_type {
1927 EnvChangeType::Routing => {
1928 let new_value = Self::decode_routing_value(src)?;
1930 let old_value = EnvChangeValue::Binary(Bytes::new());
1931 (new_value, old_value)
1932 }
1933 EnvChangeType::BeginTransaction
1934 | EnvChangeType::CommitTransaction
1935 | EnvChangeType::RollbackTransaction
1936 | EnvChangeType::EnlistDtcTransaction
1937 | EnvChangeType::SqlCollation => {
1938 let new_len = if src.has_remaining() {
1947 src.get_u8() as usize
1948 } else {
1949 0
1950 };
1951 let new_value = if new_len > 0 && src.remaining() >= new_len {
1952 EnvChangeValue::Binary(src.copy_to_bytes(new_len))
1953 } else {
1954 EnvChangeValue::Binary(Bytes::new())
1955 };
1956
1957 let old_len = if src.has_remaining() {
1958 src.get_u8() as usize
1959 } else {
1960 0
1961 };
1962 let old_value = if old_len > 0 && src.remaining() >= old_len {
1963 EnvChangeValue::Binary(src.copy_to_bytes(old_len))
1964 } else {
1965 EnvChangeValue::Binary(Bytes::new())
1966 };
1967
1968 (new_value, old_value)
1969 }
1970 _ => {
1971 let new_value = read_b_varchar(src)
1973 .map(EnvChangeValue::String)
1974 .unwrap_or(EnvChangeValue::String(String::new()));
1975
1976 let old_value = read_b_varchar(src)
1977 .map(EnvChangeValue::String)
1978 .unwrap_or(EnvChangeValue::String(String::new()));
1979
1980 (new_value, old_value)
1981 }
1982 };
1983
1984 Ok(Self {
1990 env_type,
1991 new_value,
1992 old_value,
1993 })
1994 }
1995
1996 fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
1997 if src.remaining() < 2 {
1999 return Err(ProtocolError::UnexpectedEof);
2000 }
2001
2002 let _routing_len = src.get_u16_le();
2003
2004 if src.remaining() < 5 {
2005 return Err(ProtocolError::UnexpectedEof);
2006 }
2007
2008 let _protocol = src.get_u8();
2009 let port = src.get_u16_le();
2010 let server_len = src.get_u16_le() as usize;
2011
2012 if src.remaining() < server_len * 2 {
2014 return Err(ProtocolError::UnexpectedEof);
2015 }
2016
2017 let mut chars = Vec::with_capacity(server_len);
2018 for _ in 0..server_len {
2019 chars.push(src.get_u16_le());
2020 }
2021
2022 let host = String::from_utf16(&chars).map_err(|_| {
2023 ProtocolError::StringEncoding(
2024 #[cfg(feature = "std")]
2025 "invalid UTF-16 in routing hostname".to_string(),
2026 #[cfg(not(feature = "std"))]
2027 "invalid UTF-16 in routing hostname",
2028 )
2029 })?;
2030
2031 Ok(EnvChangeValue::Routing { host, port })
2032 }
2033
2034 #[must_use]
2036 pub fn is_routing(&self) -> bool {
2037 self.env_type == EnvChangeType::Routing
2038 }
2039
2040 #[must_use]
2042 pub fn routing_info(&self) -> Option<(&str, u16)> {
2043 if let EnvChangeValue::Routing { host, port } = &self.new_value {
2044 Some((host, *port))
2045 } else {
2046 None
2047 }
2048 }
2049
2050 #[must_use]
2052 pub fn new_database(&self) -> Option<&str> {
2053 if self.env_type == EnvChangeType::Database {
2054 if let EnvChangeValue::String(s) = &self.new_value {
2055 return Some(s);
2056 }
2057 }
2058 None
2059 }
2060}
2061
2062impl Order {
2063 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2065 if src.remaining() < 2 {
2066 return Err(ProtocolError::UnexpectedEof);
2067 }
2068
2069 let length = src.get_u16_le() as usize;
2070 let column_count = length / 2;
2071
2072 if src.remaining() < length {
2073 return Err(ProtocolError::IncompletePacket {
2074 expected: length,
2075 actual: src.remaining(),
2076 });
2077 }
2078
2079 let mut columns = Vec::with_capacity(column_count);
2080 for _ in 0..column_count {
2081 columns.push(src.get_u16_le());
2082 }
2083
2084 Ok(Self { columns })
2085 }
2086}
2087
2088impl FeatureExtAck {
2089 pub const TERMINATOR: u8 = 0xFF;
2091
2092 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2094 let mut features = Vec::new();
2095
2096 loop {
2097 if !src.has_remaining() {
2098 return Err(ProtocolError::UnexpectedEof);
2099 }
2100
2101 let feature_id = src.get_u8();
2102 if feature_id == Self::TERMINATOR {
2103 break;
2104 }
2105
2106 if src.remaining() < 4 {
2107 return Err(ProtocolError::UnexpectedEof);
2108 }
2109
2110 let data_len = src.get_u32_le() as usize;
2111
2112 if src.remaining() < data_len {
2113 return Err(ProtocolError::IncompletePacket {
2114 expected: data_len,
2115 actual: src.remaining(),
2116 });
2117 }
2118
2119 let data = src.copy_to_bytes(data_len);
2120 features.push(FeatureAck { feature_id, data });
2121 }
2122
2123 Ok(Self { features })
2124 }
2125}
2126
2127impl SspiToken {
2128 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2130 if src.remaining() < 2 {
2131 return Err(ProtocolError::UnexpectedEof);
2132 }
2133
2134 let length = src.get_u16_le() as usize;
2135
2136 if src.remaining() < length {
2137 return Err(ProtocolError::IncompletePacket {
2138 expected: length,
2139 actual: src.remaining(),
2140 });
2141 }
2142
2143 let data = src.copy_to_bytes(length);
2144 Ok(Self { data })
2145 }
2146}
2147
2148impl FedAuthInfo {
2149 const ID_STSURL: u8 = 0x01;
2151 const ID_SPN: u8 = 0x02;
2154 const OPT_HEADER_LEN: usize = 9;
2156
2157 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2169 if src.remaining() < 4 {
2170 return Err(ProtocolError::UnexpectedEof);
2171 }
2172 let token_len = src.get_u32_le() as usize;
2173 if src.remaining() < token_len {
2174 return Err(ProtocolError::UnexpectedEof);
2175 }
2176
2177 let region = src.copy_to_bytes(token_len);
2180 if region.len() < 4 {
2181 return Err(ProtocolError::UnexpectedEof);
2182 }
2183 let count = u32::from_le_bytes([region[0], region[1], region[2], region[3]]) as usize;
2184
2185 let headers_end = count
2189 .checked_mul(Self::OPT_HEADER_LEN)
2190 .and_then(|n| n.checked_add(4))
2191 .ok_or(ProtocolError::UnexpectedEof)?;
2192 if headers_end > region.len() {
2193 return Err(ProtocolError::UnexpectedEof);
2194 }
2195
2196 let mut sts_url = String::new();
2197 let mut spn = String::new();
2198
2199 for i in 0..count {
2200 let h = 4 + i * Self::OPT_HEADER_LEN;
2201 let info_id = region[h];
2202 let data_len =
2203 u32::from_le_bytes([region[h + 1], region[h + 2], region[h + 3], region[h + 4]])
2204 as usize;
2205 let data_off =
2206 u32::from_le_bytes([region[h + 5], region[h + 6], region[h + 7], region[h + 8]])
2207 as usize;
2208
2209 if info_id != Self::ID_SPN && info_id != Self::ID_STSURL {
2212 continue;
2213 }
2214
2215 let data_end = data_off
2216 .checked_add(data_len)
2217 .ok_or(ProtocolError::UnexpectedEof)?;
2218 if data_end > region.len() {
2219 return Err(ProtocolError::UnexpectedEof);
2220 }
2221 if data_len % 2 != 0 {
2222 return Err(ProtocolError::StringEncoding(
2223 #[cfg(feature = "std")]
2224 "FEDAUTHINFO option data has odd length, not UTF-16".to_string(),
2225 #[cfg(not(feature = "std"))]
2226 "FEDAUTHINFO option data has odd length, not UTF-16",
2227 ));
2228 }
2229
2230 let chars: Vec<u16> = region[data_off..data_end]
2231 .chunks_exact(2)
2232 .map(|b| u16::from_le_bytes([b[0], b[1]]))
2233 .collect();
2234 let value = String::from_utf16(&chars).map_err(|_| {
2235 ProtocolError::StringEncoding(
2236 #[cfg(feature = "std")]
2237 "invalid UTF-16 in FEDAUTHINFO option".to_string(),
2238 #[cfg(not(feature = "std"))]
2239 "invalid UTF-16 in FEDAUTHINFO option",
2240 )
2241 })?;
2242
2243 if info_id == Self::ID_SPN {
2244 spn = value;
2245 } else {
2246 sts_url = value;
2247 }
2248 }
2249
2250 Ok(Self { sts_url, spn })
2251 }
2252}
2253
/// Cursor-style decoder over a buffer of TDS token-stream bytes.
pub struct TokenParser {
    /// Token-stream bytes being decoded.
    data: Bytes,
    /// Offset into `data` of the next unread byte.
    position: usize,
    /// When set, COLMETADATA tokens are decoded with `ColMetaData::decode_encrypted`.
    encryption_enabled: bool,
}
2301
2302impl TokenParser {
2303 #[must_use]
2305 pub fn new(data: Bytes) -> Self {
2306 Self {
2307 data,
2308 position: 0,
2309 encryption_enabled: false,
2310 }
2311 }
2312
2313 #[must_use]
2318 pub fn with_encryption(mut self, enabled: bool) -> Self {
2319 self.encryption_enabled = enabled;
2320 self
2321 }
2322
2323 #[must_use]
2325 pub fn remaining(&self) -> usize {
2326 self.data.len().saturating_sub(self.position)
2327 }
2328
2329 #[must_use]
2331 pub fn has_remaining(&self) -> bool {
2332 self.position < self.data.len()
2333 }
2334
2335 #[must_use]
2337 pub fn peek_token_type(&self) -> Option<TokenType> {
2338 if self.position < self.data.len() {
2339 TokenType::from_u8(self.data[self.position])
2340 } else {
2341 None
2342 }
2343 }
2344
2345 pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
2353 self.next_token_with_metadata(None)
2354 }
2355
2356 pub fn next_token_with_metadata(
2363 &mut self,
2364 metadata: Option<&ColMetaData>,
2365 ) -> Result<Option<Token>, ProtocolError> {
2366 if !self.has_remaining() {
2367 return Ok(None);
2368 }
2369
2370 let mut buf = &self.data[self.position..];
2371 let start_pos = self.position;
2372
2373 let token_type_byte = buf.get_u8();
2374 let token_type = TokenType::from_u8(token_type_byte);
2375
2376 let token = match token_type {
2377 Some(TokenType::Done) => {
2378 let done = Done::decode(&mut buf)?;
2379 Token::Done(done)
2380 }
2381 Some(TokenType::DoneProc) => {
2382 let done = DoneProc::decode(&mut buf)?;
2383 Token::DoneProc(done)
2384 }
2385 Some(TokenType::DoneInProc) => {
2386 let done = DoneInProc::decode(&mut buf)?;
2387 Token::DoneInProc(done)
2388 }
2389 Some(TokenType::Error) => {
2390 let error = ServerError::decode(&mut buf)?;
2391 Token::Error(error)
2392 }
2393 Some(TokenType::Info) => {
2394 let info = ServerInfo::decode(&mut buf)?;
2395 Token::Info(info)
2396 }
2397 Some(TokenType::LoginAck) => {
2398 let login_ack = LoginAck::decode(&mut buf)?;
2399 Token::LoginAck(login_ack)
2400 }
2401 Some(TokenType::EnvChange) => {
2402 let env_change = EnvChange::decode(&mut buf)?;
2403 Token::EnvChange(env_change)
2404 }
2405 Some(TokenType::Order) => {
2406 let order = Order::decode(&mut buf)?;
2407 Token::Order(order)
2408 }
2409 Some(TokenType::FeatureExtAck) => {
2410 let ack = FeatureExtAck::decode(&mut buf)?;
2411 Token::FeatureExtAck(ack)
2412 }
2413 Some(TokenType::Sspi) => {
2414 let sspi = SspiToken::decode(&mut buf)?;
2415 Token::Sspi(sspi)
2416 }
2417 Some(TokenType::FedAuthInfo) => {
2418 let info = FedAuthInfo::decode(&mut buf)?;
2419 Token::FedAuthInfo(info)
2420 }
2421 Some(TokenType::ReturnStatus) => {
2422 if buf.remaining() < 4 {
2423 return Err(ProtocolError::UnexpectedEof);
2424 }
2425 let status = buf.get_i32_le();
2426 Token::ReturnStatus(status)
2427 }
2428 Some(TokenType::ColMetaData) => {
2429 let col_meta = if self.encryption_enabled {
2430 ColMetaData::decode_encrypted(&mut buf)?
2431 } else {
2432 ColMetaData::decode(&mut buf)?
2433 };
2434 Token::ColMetaData(col_meta)
2435 }
2436 Some(TokenType::Row) => {
2437 let meta = metadata.ok_or_else(|| {
2438 ProtocolError::StringEncoding(
2439 #[cfg(feature = "std")]
2440 "Row token requires column metadata".to_string(),
2441 #[cfg(not(feature = "std"))]
2442 "Row token requires column metadata",
2443 )
2444 })?;
2445 let row = RawRow::decode(&mut buf, meta)?;
2446 Token::Row(row)
2447 }
2448 Some(TokenType::NbcRow) => {
2449 let meta = metadata.ok_or_else(|| {
2450 ProtocolError::StringEncoding(
2451 #[cfg(feature = "std")]
2452 "NbcRow token requires column metadata".to_string(),
2453 #[cfg(not(feature = "std"))]
2454 "NbcRow token requires column metadata",
2455 )
2456 })?;
2457 let row = NbcRow::decode(&mut buf, meta)?;
2458 Token::NbcRow(row)
2459 }
2460 Some(TokenType::ReturnValue) => {
2461 let ret_val = ReturnValue::decode(&mut buf)?;
2462 Token::ReturnValue(ret_val)
2463 }
2464 Some(TokenType::SessionState) => {
2465 let session = SessionState::decode(&mut buf)?;
2466 Token::SessionState(session)
2467 }
2468 Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
2469 if buf.remaining() < 2 {
2472 return Err(ProtocolError::UnexpectedEof);
2473 }
2474 let length = buf.get_u16_le() as usize;
2475 if buf.remaining() < length {
2476 return Err(ProtocolError::IncompletePacket {
2477 expected: length,
2478 actual: buf.remaining(),
2479 });
2480 }
2481 buf.advance(length);
2483 self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
2485 return self.next_token_with_metadata(metadata);
2486 }
2487 None => {
2488 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2489 }
2490 };
2491
2492 let consumed = self.data.len() - start_pos - buf.remaining();
2494 self.position = start_pos + consumed;
2495
2496 Ok(Some(token))
2497 }
2498
2499 pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
2503 if !self.has_remaining() {
2504 return Ok(());
2505 }
2506
2507 let token_type_byte = self.data[self.position];
2508 let token_type = TokenType::from_u8(token_type_byte);
2509
2510 let skip_amount = match token_type {
2512 Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
2514 1 + Done::SIZE }
2516 Some(TokenType::ReturnStatus) => {
2517 1 + 4 }
2519 Some(TokenType::Error)
2521 | Some(TokenType::Info)
2522 | Some(TokenType::LoginAck)
2523 | Some(TokenType::EnvChange)
2524 | Some(TokenType::Order)
2525 | Some(TokenType::Sspi)
2526 | Some(TokenType::ColInfo)
2527 | Some(TokenType::TabName)
2528 | Some(TokenType::Offset)
2529 | Some(TokenType::ReturnValue) => {
2530 if self.remaining() < 3 {
2531 return Err(ProtocolError::UnexpectedEof);
2532 }
2533 let length = u16::from_le_bytes([
2534 self.data[self.position + 1],
2535 self.data[self.position + 2],
2536 ]) as usize;
2537 1 + 2 + length }
2539 Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
2541 if self.remaining() < 5 {
2542 return Err(ProtocolError::UnexpectedEof);
2543 }
2544 let length = u32::from_le_bytes([
2545 self.data[self.position + 1],
2546 self.data[self.position + 2],
2547 self.data[self.position + 3],
2548 self.data[self.position + 4],
2549 ]) as usize;
2550 1 + 4 + length
2551 }
2552 Some(TokenType::FeatureExtAck) => {
2554 let mut buf = &self.data[self.position + 1..];
2556 let _ = FeatureExtAck::decode(&mut buf)?;
2557 self.data.len() - self.position - buf.remaining()
2558 }
2559 Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
2561 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2562 }
2563 None => {
2564 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2565 }
2566 };
2567
2568 if self.remaining() < skip_amount {
2569 return Err(ProtocolError::UnexpectedEof);
2570 }
2571
2572 self.position += skip_amount;
2573 Ok(())
2574 }
2575
2576 #[must_use]
2578 pub fn position(&self) -> usize {
2579 self.position
2580 }
2581
2582 pub fn reset(&mut self) {
2584 self.position = 0;
2585 }
2586}
2587
2588#[cfg(test)]
2593#[allow(clippy::unwrap_used, clippy::panic)]
2594mod tests {
2595 use super::*;
2596 use bytes::BytesMut;
2597
2598 #[test]
2599 fn test_done_roundtrip() {
2600 let done = Done {
2601 status: DoneStatus {
2602 more: false,
2603 error: false,
2604 in_xact: false,
2605 count: true,
2606 attn: false,
2607 srverror: false,
2608 },
2609 cur_cmd: 193, row_count: 42,
2611 };
2612
2613 let mut buf = BytesMut::new();
2614 done.encode(&mut buf);
2615
2616 let mut cursor = &buf[1..];
2618 let decoded = Done::decode(&mut cursor).unwrap();
2619
2620 assert_eq!(decoded.status.count, done.status.count);
2621 assert_eq!(decoded.cur_cmd, done.cur_cmd);
2622 assert_eq!(decoded.row_count, done.row_count);
2623 }
2624
2625 #[test]
2626 fn test_done_status_bits() {
2627 let status = DoneStatus {
2628 more: true,
2629 error: true,
2630 in_xact: true,
2631 count: true,
2632 attn: false,
2633 srverror: false,
2634 };
2635
2636 let bits = status.to_bits();
2637 let restored = DoneStatus::from_bits(bits);
2638
2639 assert_eq!(status.more, restored.more);
2640 assert_eq!(status.error, restored.error);
2641 assert_eq!(status.in_xact, restored.in_xact);
2642 assert_eq!(status.count, restored.count);
2643 }
2644
2645 #[test]
2646 fn test_token_parser_done() {
2647 let data = Bytes::from_static(&[
2649 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
2654
2655 let mut parser = TokenParser::new(data);
2656 let token = parser.next_token().unwrap().unwrap();
2657
2658 match token {
2659 Token::Done(done) => {
2660 assert!(done.status.count);
2661 assert!(!done.status.more);
2662 assert_eq!(done.cur_cmd, 193);
2663 assert_eq!(done.row_count, 5);
2664 }
2665 _ => panic!("Expected Done token"),
2666 }
2667
2668 assert!(parser.next_token().unwrap().is_none());
2670 }
2671
2672 #[test]
2673 fn test_env_change_type_from_u8() {
2674 assert_eq!(EnvChangeType::from_u8(1), Some(EnvChangeType::Database));
2675 assert_eq!(EnvChangeType::from_u8(20), Some(EnvChangeType::Routing));
2676 assert_eq!(EnvChangeType::from_u8(100), None);
2677 }
2678
2679 #[test]
2686 fn test_env_change_routing_consumes_declared_length() {
2687 let host = "redirect.example";
2688 let host_utf16: Vec<u16> = host.encode_utf16().collect();
2689
2690 let mut data = BytesMut::new();
2691 let routing_len = 1 + 2 + 2 + host_utf16.len() * 2;
2693 let env_len = 1 + 2 + routing_len + 2;
2696 data.put_u16_le(env_len as u16);
2697 data.put_u8(20); data.put_u16_le(routing_len as u16);
2699 data.put_u8(0); data.put_u16_le(11000); data.put_u16_le(host_utf16.len() as u16);
2702 for c in &host_utf16 {
2703 data.put_u16_le(*c);
2704 }
2705 data.put_u16_le(0); data.put_u8(0xFD);
2708
2709 let mut buf: &[u8] = &data;
2710 let env = EnvChange::decode(&mut buf).unwrap();
2711 assert_eq!(env.routing_info(), Some((host, 11000)));
2712 assert_eq!(
2713 buf,
2714 &[0xFD],
2715 "decode must consume exactly the declared ENVCHANGE frame"
2716 );
2717 }
2718
2719 fn put_b_varchar(buf: &mut BytesMut, s: &str) {
2720 let utf16: Vec<u16> = s.encode_utf16().collect();
2721 buf.put_u8(utf16.len() as u8);
2722 for c in utf16 {
2723 buf.put_u16_le(c);
2724 }
2725 }
2726
2727 fn put_us_varchar(buf: &mut BytesMut, s: &str) {
2728 let utf16: Vec<u16> = s.encode_utf16().collect();
2729 buf.put_u16_le(utf16.len() as u16);
2730 for c in utf16 {
2731 buf.put_u16_le(c);
2732 }
2733 }
2734
2735 #[test]
2742 fn test_udt_info_metadata_uses_b_varchar_names() {
2743 let mut data = BytesMut::new();
2744 data.put_u16_le(0xFFFF); put_b_varchar(&mut data, "master");
2746 put_b_varchar(&mut data, "dbo");
2747 put_b_varchar(&mut data, "hierarchyid");
2748 put_us_varchar(
2749 &mut data,
2750 "Microsoft.SqlServer.Types.SqlHierarchyId, Microsoft.SqlServer.Types",
2751 );
2752 data.put_u8(0xFD);
2754
2755 let mut buf: &[u8] = &data;
2756 let info = decode_type_info(&mut buf, TypeId::Udt, TypeId::Udt as u8).unwrap();
2757 assert_eq!(info.max_length, Some(0xFFFF));
2758 assert_eq!(
2759 buf,
2760 &[0xFD],
2761 "decode must consume exactly the UDT_INFO frame"
2762 );
2763 }
2764
2765 #[test]
2769 fn test_xml_info_schema_bound_uses_b_varchar_names() {
2770 let mut data = BytesMut::new();
2771 data.put_u8(1); put_b_varchar(&mut data, "master");
2773 put_b_varchar(&mut data, "dbo");
2774 put_us_varchar(&mut data, "MyXmlSchemaCollection");
2775 data.put_u8(0xFD);
2776
2777 let mut buf: &[u8] = &data;
2778 decode_type_info(&mut buf, TypeId::Xml, TypeId::Xml as u8).unwrap();
2779 assert_eq!(
2780 buf,
2781 &[0xFD],
2782 "decode must consume exactly the XML_INFO frame"
2783 );
2784 }
2785
2786 #[test]
2787 fn hostile_env_change_binary_truncated_is_not_panic() {
2788 let data = [0x01, 0x00, 0x08];
2793 let mut buf: &[u8] = &data;
2794 let env = EnvChange::decode(&mut buf).unwrap();
2795 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2796 }
2797
2798 #[test]
2801 fn hostile_env_change_under_declared_cannot_steal_following_bytes() {
2802 let mut data = BytesMut::new();
2807 data.put_u16_le(1); data.put_u8(0x08); let following: &[u8] = &[0x08, 1, 2, 3, 4, 5, 6, 7, 8, 0x00];
2810 data.extend_from_slice(following);
2811
2812 let mut buf: &[u8] = &data;
2813 let env = EnvChange::decode(&mut buf).unwrap();
2814 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2815 match &env.new_value {
2816 EnvChangeValue::Binary(b) => {
2817 assert!(
2818 b.is_empty(),
2819 "under-declared frame yields the lenient empty value"
2820 );
2821 }
2822 other => panic!("expected empty Binary value, got {other:?}"),
2823 }
2824 assert_eq!(
2825 buf, following,
2826 "bytes beyond the declared frame belong to the next token"
2827 );
2828 }
2829
2830 #[test]
2833 fn hostile_env_change_zero_length_frame_errors() {
2834 let data = [0x00, 0x00, 0xFD];
2835 let mut buf: &[u8] = &data;
2836 assert!(EnvChange::decode(&mut buf).is_err());
2837 }
2838
2839 #[test]
2840 fn test_colmetadata_no_columns() {
2841 let data = Bytes::from_static(&[0xFF, 0xFF]);
2843 let mut cursor: &[u8] = &data;
2844 let meta = ColMetaData::decode(&mut cursor).unwrap();
2845 assert!(meta.is_empty());
2846 assert_eq!(meta.column_count(), 0);
2847 }
2848
2849 #[test]
2850 fn test_colmetadata_single_int_column() {
2851 let mut data = BytesMut::new();
2854 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut cursor: &[u8] = &data;
2863 let meta = ColMetaData::decode(&mut cursor).unwrap();
2864
2865 assert_eq!(meta.column_count(), 1);
2866 assert_eq!(meta.columns[0].name, "id");
2867 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
2868 assert!(meta.columns[0].is_nullable());
2869 }
2870
2871 #[test]
2872 fn test_colmetadata_nvarchar_column() {
2873 let mut data = BytesMut::new();
2875 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0xE7]); data.extend_from_slice(&[0x64, 0x00]); data.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); data.extend_from_slice(&[0x04]); data.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
2885
2886 let mut cursor: &[u8] = &data;
2887 let meta = ColMetaData::decode(&mut cursor).unwrap();
2888
2889 assert_eq!(meta.column_count(), 1);
2890 assert_eq!(meta.columns[0].name, "name");
2891 assert_eq!(meta.columns[0].type_id, TypeId::NVarChar);
2892 assert_eq!(meta.columns[0].type_info.max_length, Some(100));
2893 assert!(meta.columns[0].type_info.collation.is_some());
2894 }
2895
2896 #[test]
2897 fn test_raw_row_decode_int() {
2898 let metadata = ColMetaData {
2900 cek_table: None,
2901 columns: vec![ColumnData {
2902 name: "id".to_string(),
2903 type_id: TypeId::Int4,
2904 col_type: 0x38,
2905 flags: 0,
2906 user_type: 0,
2907 type_info: TypeInfo::default(),
2908 crypto_metadata: None,
2909 }],
2910 };
2911
2912 let data = Bytes::from_static(&[0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
2915 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2916
2917 assert_eq!(row.data.len(), 4);
2919 assert_eq!(&row.data[..], &[0x2A, 0x00, 0x00, 0x00]);
2920 }
2921
2922 #[test]
2923 fn test_raw_row_decode_nullable_int() {
2924 let metadata = ColMetaData {
2926 cek_table: None,
2927 columns: vec![ColumnData {
2928 name: "id".to_string(),
2929 type_id: TypeId::IntN,
2930 col_type: 0x26,
2931 flags: 0x01, user_type: 0,
2933 type_info: TypeInfo {
2934 max_length: Some(4),
2935 ..Default::default()
2936 },
2937 crypto_metadata: None,
2938 }],
2939 };
2940
2941 let data = Bytes::from_static(&[0x04, 0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
2944 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2945
2946 assert_eq!(row.data.len(), 5);
2947 assert_eq!(row.data[0], 4); assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
2949 }
2950
2951 #[test]
2952 fn test_raw_row_decode_null_value() {
2953 let metadata = ColMetaData {
2955 cek_table: None,
2956 columns: vec![ColumnData {
2957 name: "id".to_string(),
2958 type_id: TypeId::IntN,
2959 col_type: 0x26,
2960 flags: 0x01, user_type: 0,
2962 type_info: TypeInfo {
2963 max_length: Some(4),
2964 ..Default::default()
2965 },
2966 crypto_metadata: None,
2967 }],
2968 };
2969
2970 let data = Bytes::from_static(&[0xFF]);
2972 let mut cursor: &[u8] = &data;
2973 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2974
2975 assert_eq!(row.data.len(), 1);
2976 assert_eq!(row.data[0], 0xFF); }
2978
2979 #[test]
2980 fn test_nbcrow_null_bitmap() {
2981 let row = NbcRow {
2982 null_bitmap: vec![0b00000101], data: Bytes::new(),
2984 };
2985
2986 assert!(row.is_null(0));
2987 assert!(!row.is_null(1));
2988 assert!(row.is_null(2));
2989 assert!(!row.is_null(3));
2990 }
2991
2992 #[test]
2993 fn test_token_parser_colmetadata() {
2994 let mut data = BytesMut::new();
2996 data.extend_from_slice(&[0x81]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut parser = TokenParser::new(data.freeze());
3005 let token = parser.next_token().unwrap().unwrap();
3006
3007 match token {
3008 Token::ColMetaData(meta) => {
3009 assert_eq!(meta.column_count(), 1);
3010 assert_eq!(meta.columns[0].name, "id");
3011 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
3012 }
3013 _ => panic!("Expected ColMetaData token"),
3014 }
3015 }
3016
3017 #[test]
3018 fn test_token_parser_row_with_metadata() {
3019 let metadata = ColMetaData {
3021 cek_table: None,
3022 columns: vec![ColumnData {
3023 name: "id".to_string(),
3024 type_id: TypeId::Int4,
3025 col_type: 0x38,
3026 flags: 0,
3027 user_type: 0,
3028 type_info: TypeInfo::default(),
3029 crypto_metadata: None,
3030 }],
3031 };
3032
3033 let mut data = BytesMut::new();
3035 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3039 let token = parser
3040 .next_token_with_metadata(Some(&metadata))
3041 .unwrap()
3042 .unwrap();
3043
3044 match token {
3045 Token::Row(row) => {
3046 assert_eq!(row.data.len(), 4);
3047 }
3048 _ => panic!("Expected Row token"),
3049 }
3050 }
3051
3052 #[test]
3053 fn test_token_parser_row_without_metadata_fails() {
3054 let mut data = BytesMut::new();
3056 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3060 let result = parser.next_token(); assert!(result.is_err());
3063 }
3064
3065 #[test]
3066 fn test_token_parser_peek() {
3067 let data = Bytes::from_static(&[
3068 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3073
3074 let parser = TokenParser::new(data);
3075 assert_eq!(parser.peek_token_type(), Some(TokenType::Done));
3076 }
3077
3078 #[test]
3079 fn test_column_data_fixed_size() {
3080 let col = ColumnData {
3081 name: String::new(),
3082 type_id: TypeId::Int4,
3083 col_type: 0x38,
3084 flags: 0,
3085 user_type: 0,
3086 type_info: TypeInfo::default(),
3087 crypto_metadata: None,
3088 };
3089 assert_eq!(col.fixed_size(), Some(4));
3090
3091 let col2 = ColumnData {
3092 name: String::new(),
3093 type_id: TypeId::NVarChar,
3094 col_type: 0xE7,
3095 flags: 0,
3096 user_type: 0,
3097 type_info: TypeInfo::default(),
3098 crypto_metadata: None,
3099 };
3100 assert_eq!(col2.fixed_size(), None);
3101 }
3102
3103 #[test]
3111 fn test_decode_nvarchar_then_intn_roundtrip() {
3112 let mut wire_data = BytesMut::new();
3117
3118 let word = "World";
3121 let utf16: Vec<u16> = word.encode_utf16().collect();
3122 wire_data.put_u16_le((utf16.len() * 2) as u16); for code_unit in &utf16 {
3124 wire_data.put_u16_le(*code_unit);
3125 }
3126
3127 wire_data.put_u8(4); wire_data.put_i32_le(42);
3130
3131 let metadata = ColMetaData {
3133 cek_table: None,
3134 columns: vec![
3135 ColumnData {
3136 name: "greeting".to_string(),
3137 type_id: TypeId::NVarChar,
3138 col_type: 0xE7,
3139 flags: 0x01,
3140 user_type: 0,
3141 type_info: TypeInfo {
3142 max_length: Some(10), precision: None,
3144 scale: None,
3145 collation: None,
3146 },
3147 crypto_metadata: None,
3148 },
3149 ColumnData {
3150 name: "number".to_string(),
3151 type_id: TypeId::IntN,
3152 col_type: 0x26,
3153 flags: 0x01,
3154 user_type: 0,
3155 type_info: TypeInfo {
3156 max_length: Some(4),
3157 precision: None,
3158 scale: None,
3159 collation: None,
3160 },
3161 crypto_metadata: None,
3162 },
3163 ],
3164 };
3165
3166 let mut wire_cursor = wire_data.freeze();
3168 let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3169
3170 assert_eq!(
3172 wire_cursor.remaining(),
3173 0,
3174 "wire data should be fully consumed"
3175 );
3176
3177 let mut stored_cursor: &[u8] = &raw_row.data;
3179
3180 assert!(
3183 stored_cursor.remaining() >= 2,
3184 "need at least 2 bytes for length"
3185 );
3186 let len0 = stored_cursor.get_u16_le() as usize;
3187 assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
3188 assert!(
3189 stored_cursor.remaining() >= len0,
3190 "need {len0} bytes for data"
3191 );
3192
3193 let mut utf16_read = Vec::new();
3195 for _ in 0..(len0 / 2) {
3196 utf16_read.push(stored_cursor.get_u16_le());
3197 }
3198 let string0 = String::from_utf16(&utf16_read).unwrap();
3199 assert_eq!(string0, "World", "column 0 should be 'World'");
3200
3201 assert!(
3204 stored_cursor.remaining() >= 1,
3205 "need at least 1 byte for length"
3206 );
3207 let len1 = stored_cursor.get_u8();
3208 assert_eq!(len1, 4, "IntN length should be 4");
3209 assert!(stored_cursor.remaining() >= 4, "need 4 bytes for INT data");
3210 let int1 = stored_cursor.get_i32_le();
3211 assert_eq!(int1, 42, "column 1 should be 42");
3212
3213 assert_eq!(
3215 stored_cursor.remaining(),
3216 0,
3217 "stored data should be fully consumed"
3218 );
3219 }
3220
3221 #[test]
3222 fn test_decode_nvarchar_max_then_intn_roundtrip() {
3223 let mut wire_data = BytesMut::new();
3227
3228 let word = "Hello";
3231 let utf16: Vec<u16> = word.encode_utf16().collect();
3232 let byte_len = (utf16.len() * 2) as u64;
3233
3234 wire_data.put_u64_le(byte_len); wire_data.put_u32_le(byte_len as u32); for code_unit in &utf16 {
3237 wire_data.put_u16_le(*code_unit);
3238 }
3239 wire_data.put_u32_le(0); wire_data.put_u8(4);
3243 wire_data.put_i32_le(99);
3244
3245 let metadata = ColMetaData {
3247 cek_table: None,
3248 columns: vec![
3249 ColumnData {
3250 name: "text".to_string(),
3251 type_id: TypeId::NVarChar,
3252 col_type: 0xE7,
3253 flags: 0x01,
3254 user_type: 0,
3255 type_info: TypeInfo {
3256 max_length: Some(0xFFFF), precision: None,
3258 scale: None,
3259 collation: None,
3260 },
3261 crypto_metadata: None,
3262 },
3263 ColumnData {
3264 name: "num".to_string(),
3265 type_id: TypeId::IntN,
3266 col_type: 0x26,
3267 flags: 0x01,
3268 user_type: 0,
3269 type_info: TypeInfo {
3270 max_length: Some(4),
3271 precision: None,
3272 scale: None,
3273 collation: None,
3274 },
3275 crypto_metadata: None,
3276 },
3277 ],
3278 };
3279
3280 let mut wire_cursor = wire_data.freeze();
3282 let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3283
3284 assert_eq!(
3286 wire_cursor.remaining(),
3287 0,
3288 "wire data should be fully consumed"
3289 );
3290
3291 let mut stored_cursor: &[u8] = &raw_row.data;
3293
3294 let total_len = stored_cursor.get_u64_le();
3296 assert_eq!(total_len, 10, "PLP total length should be 10");
3297
3298 let chunk_len = stored_cursor.get_u32_le();
3299 assert_eq!(chunk_len, 10, "PLP chunk length should be 10");
3300
3301 let mut utf16_read = Vec::new();
3302 for _ in 0..(chunk_len / 2) {
3303 utf16_read.push(stored_cursor.get_u16_le());
3304 }
3305 let string0 = String::from_utf16(&utf16_read).unwrap();
3306 assert_eq!(string0, "Hello", "column 0 should be 'Hello'");
3307
3308 let terminator = stored_cursor.get_u32_le();
3309 assert_eq!(terminator, 0, "PLP should end with 0");
3310
3311 let len1 = stored_cursor.get_u8();
3313 assert_eq!(len1, 4);
3314 let int1 = stored_cursor.get_i32_le();
3315 assert_eq!(int1, 99, "column 1 should be 99");
3316
3317 assert_eq!(
3319 stored_cursor.remaining(),
3320 0,
3321 "stored data should be fully consumed"
3322 );
3323 }
3324
3325 #[test]
3330 fn test_return_status_via_parser() {
3331 let data = Bytes::from_static(&[
3333 0x79, 0x00, 0x00, 0x00, 0x00, ]);
3336
3337 let mut parser = TokenParser::new(data);
3338 let token = parser.next_token().unwrap().unwrap();
3339
3340 match token {
3341 Token::ReturnStatus(status) => {
3342 assert_eq!(status, 0);
3343 }
3344 _ => panic!("Expected ReturnStatus token, got {token:?}"),
3345 }
3346
3347 assert!(parser.next_token().unwrap().is_none());
3348 }
3349
3350 #[test]
3351 fn test_return_status_nonzero() {
3352 let mut buf = BytesMut::new();
3354 buf.put_u8(0x79); buf.put_i32_le(-6);
3356
3357 let mut parser = TokenParser::new(buf.freeze());
3358 let token = parser.next_token().unwrap().unwrap();
3359
3360 match token {
3361 Token::ReturnStatus(status) => {
3362 assert_eq!(status, -6);
3363 }
3364 _ => panic!("Expected ReturnStatus token"),
3365 }
3366 }
3367
3368 #[test]
3373 fn test_done_proc_roundtrip() {
3374 let done = DoneProc {
3375 status: DoneStatus {
3376 more: false,
3377 error: false,
3378 in_xact: false,
3379 count: true,
3380 attn: false,
3381 srverror: false,
3382 },
3383 cur_cmd: 0x00C6, row_count: 100,
3385 };
3386
3387 let mut buf = BytesMut::new();
3388 done.encode(&mut buf);
3389
3390 assert_eq!(buf[0], 0xFE);
3392
3393 let mut cursor = &buf[1..];
3395 let decoded = DoneProc::decode(&mut cursor).unwrap();
3396
3397 assert!(decoded.status.count);
3398 assert!(!decoded.status.more);
3399 assert!(!decoded.status.error);
3400 assert_eq!(decoded.cur_cmd, 0x00C6);
3401 assert_eq!(decoded.row_count, 100);
3402 }
3403
3404 #[test]
3405 fn test_done_proc_via_parser() {
3406 let data = Bytes::from_static(&[
3407 0xFE, 0x00, 0x00, 0xC6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3412
3413 let mut parser = TokenParser::new(data);
3414 let token = parser.next_token().unwrap().unwrap();
3415
3416 match token {
3417 Token::DoneProc(done) => {
3418 assert!(!done.status.count);
3419 assert!(!done.status.more);
3420 assert_eq!(done.cur_cmd, 198);
3421 assert_eq!(done.row_count, 0);
3422 }
3423 _ => panic!("Expected DoneProc token"),
3424 }
3425 }
3426
3427 #[test]
3428 fn test_done_proc_with_error_flag() {
3429 let mut buf = BytesMut::new();
3430 buf.put_u8(0xFE); buf.put_u16_le(0x0002); buf.put_u16_le(0x00C6); buf.put_u64_le(0); let mut parser = TokenParser::new(buf.freeze());
3436 let token = parser.next_token().unwrap().unwrap();
3437
3438 match token {
3439 Token::DoneProc(done) => {
3440 assert!(done.status.error);
3441 assert!(!done.status.count);
3442 assert!(!done.status.more);
3443 }
3444 _ => panic!("Expected DoneProc token"),
3445 }
3446 }
3447
3448 #[test]
3453 fn test_done_in_proc_roundtrip() {
3454 let done = DoneInProc {
3455 status: DoneStatus {
3456 more: true,
3457 error: false,
3458 in_xact: false,
3459 count: true,
3460 attn: false,
3461 srverror: false,
3462 },
3463 cur_cmd: 193, row_count: 7,
3465 };
3466
3467 let mut buf = BytesMut::new();
3468 done.encode(&mut buf);
3469
3470 assert_eq!(buf[0], 0xFF);
3471
3472 let mut cursor = &buf[1..];
3473 let decoded = DoneInProc::decode(&mut cursor).unwrap();
3474
3475 assert!(decoded.status.more);
3476 assert!(decoded.status.count);
3477 assert!(!decoded.status.error);
3478 assert_eq!(decoded.cur_cmd, 193);
3479 assert_eq!(decoded.row_count, 7);
3480 }
3481
3482 #[test]
3483 fn test_done_in_proc_via_parser() {
3484 let data = Bytes::from_static(&[
3485 0xFF, 0x11, 0x00, 0xC1, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3490
3491 let mut parser = TokenParser::new(data);
3492 let token = parser.next_token().unwrap().unwrap();
3493
3494 match token {
3495 Token::DoneInProc(done) => {
3496 assert!(done.status.more);
3497 assert!(done.status.count);
3498 assert_eq!(done.cur_cmd, 193);
3499 assert_eq!(done.row_count, 3);
3500 }
3501 _ => panic!("Expected DoneInProc token"),
3502 }
3503 }
3504
3505 #[test]
3510 fn test_server_error_decode() {
3511 let mut buf = BytesMut::new();
3514
3515 let msg_utf16: Vec<u16> = "Invalid column name 'foo'.".encode_utf16().collect();
3517 let srv_utf16: Vec<u16> = "SQLDB01".encode_utf16().collect();
3518 let proc_utf16: Vec<u16> = "".encode_utf16().collect();
3519
3520 let length: u16 = (4
3526 + 1
3527 + 1
3528 + 2
3529 + (msg_utf16.len() * 2)
3530 + 1
3531 + (srv_utf16.len() * 2)
3532 + 1
3533 + (proc_utf16.len() * 2)
3534 + 4) as u16;
3535
3536 buf.put_u16_le(length);
3537 buf.put_i32_le(207); buf.put_u8(1); buf.put_u8(16); buf.put_u16_le(msg_utf16.len() as u16);
3543 for &c in &msg_utf16 {
3544 buf.put_u16_le(c);
3545 }
3546
3547 buf.put_u8(srv_utf16.len() as u8);
3549 for &c in &srv_utf16 {
3550 buf.put_u16_le(c);
3551 }
3552
3553 buf.put_u8(proc_utf16.len() as u8);
3555
3556 buf.put_i32_le(42);
3558
3559 let mut cursor = buf.freeze();
3560 let error = ServerError::decode(&mut cursor).unwrap();
3561
3562 assert_eq!(error.number, 207);
3563 assert_eq!(error.state, 1);
3564 assert_eq!(error.class, 16);
3565 assert_eq!(error.message, "Invalid column name 'foo'.");
3566 assert_eq!(error.server, "SQLDB01");
3567 assert_eq!(error.procedure, "");
3568 assert_eq!(error.line, 42);
3569 }
3570
3571 #[test]
3572 fn test_server_error_severity_helpers() {
3573 let fatal = ServerError {
3574 number: 4014,
3575 state: 1,
3576 class: 20,
3577 message: "Fatal error".to_string(),
3578 server: String::new(),
3579 procedure: String::new(),
3580 line: 0,
3581 };
3582 assert!(fatal.is_fatal());
3583 assert!(fatal.is_batch_abort());
3584
3585 let batch_abort = ServerError {
3586 number: 547,
3587 state: 0,
3588 class: 16,
3589 message: "Constraint violation".to_string(),
3590 server: String::new(),
3591 procedure: String::new(),
3592 line: 1,
3593 };
3594 assert!(!batch_abort.is_fatal());
3595 assert!(batch_abort.is_batch_abort());
3596
3597 let informational = ServerError {
3598 number: 5701,
3599 state: 2,
3600 class: 10,
3601 message: "Changed db context".to_string(),
3602 server: String::new(),
3603 procedure: String::new(),
3604 line: 0,
3605 };
3606 assert!(!informational.is_fatal());
3607 assert!(!informational.is_batch_abort());
3608 }
3609
3610 #[test]
3611 fn test_server_error_via_parser() {
3612 let mut buf = BytesMut::new();
3614 buf.put_u8(0xAA); let msg_utf16: Vec<u16> = "Syntax error".encode_utf16().collect();
3617 let srv_utf16: Vec<u16> = "SRV".encode_utf16().collect();
3618 let proc_utf16: Vec<u16> = "sp_test".encode_utf16().collect();
3619
3620 let length: u16 = (4
3621 + 1
3622 + 1
3623 + 2
3624 + (msg_utf16.len() * 2)
3625 + 1
3626 + (srv_utf16.len() * 2)
3627 + 1
3628 + (proc_utf16.len() * 2)
3629 + 4) as u16;
3630
3631 buf.put_u16_le(length);
3632 buf.put_i32_le(102); buf.put_u8(1);
3634 buf.put_u8(15);
3635
3636 buf.put_u16_le(msg_utf16.len() as u16);
3637 for &c in &msg_utf16 {
3638 buf.put_u16_le(c);
3639 }
3640 buf.put_u8(srv_utf16.len() as u8);
3641 for &c in &srv_utf16 {
3642 buf.put_u16_le(c);
3643 }
3644 buf.put_u8(proc_utf16.len() as u8);
3645 for &c in &proc_utf16 {
3646 buf.put_u16_le(c);
3647 }
3648 buf.put_i32_le(5);
3649
3650 let mut parser = TokenParser::new(buf.freeze());
3651 let token = parser.next_token().unwrap().unwrap();
3652
3653 match token {
3654 Token::Error(err) => {
3655 assert_eq!(err.number, 102);
3656 assert_eq!(err.class, 15);
3657 assert_eq!(err.message, "Syntax error");
3658 assert_eq!(err.server, "SRV");
3659 assert_eq!(err.procedure, "sp_test");
3660 assert_eq!(err.line, 5);
3661 }
3662 _ => panic!("Expected Error token"),
3663 }
3664 }
3665
3666 fn build_return_value_intn(
3673 ordinal: u16,
3674 name: &str,
3675 status: u8,
3676 value: Option<i32>,
3677 ) -> BytesMut {
3678 let mut inner = BytesMut::new();
3679
3680 inner.put_u16_le(ordinal);
3682
3683 let name_utf16: Vec<u16> = name.encode_utf16().collect();
3685 inner.put_u8(name_utf16.len() as u8);
3686 for &c in &name_utf16 {
3687 inner.put_u16_le(c);
3688 }
3689
3690 inner.put_u8(status);
3692
3693 inner.put_u32_le(0);
3695
3696 inner.put_u16_le(0x0001); inner.put_u8(0x26);
3701
3702 inner.put_u8(4);
3704
3705 match value {
3707 Some(v) => {
3708 inner.put_u8(4); inner.put_i32_le(v);
3710 }
3711 None => {
3712 inner.put_u8(0); }
3714 }
3715
3716 inner
3719 }
3720
3721 #[test]
3722 fn test_return_value_int_output() {
3723 let buf = build_return_value_intn(1, "@result", 0x01, Some(42));
3724 let mut cursor = buf.freeze();
3725 let rv = ReturnValue::decode(&mut cursor).unwrap();
3726
3727 assert_eq!(rv.param_ordinal, 1);
3728 assert_eq!(rv.param_name, "@result");
3729 assert_eq!(rv.status, 0x01); assert_eq!(rv.col_type, 0x26); assert_eq!(rv.type_info.max_length, Some(4));
3732 assert_eq!(rv.value.len(), 5);
3734 assert_eq!(rv.value[0], 4);
3735 assert_eq!(
3736 i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3737 42
3738 );
3739 }
3740
3741 #[test]
3742 fn test_return_value_null_output() {
3743 let buf = build_return_value_intn(2, "@count", 0x01, None);
3744 let mut cursor = buf.freeze();
3745 let rv = ReturnValue::decode(&mut cursor).unwrap();
3746
3747 assert_eq!(rv.param_ordinal, 2);
3748 assert_eq!(rv.param_name, "@count");
3749 assert_eq!(rv.status, 0x01);
3750 assert_eq!(rv.col_type, 0x26);
3751 assert_eq!(rv.value.len(), 1);
3753 assert_eq!(rv.value[0], 0);
3754 }
3755
3756 #[test]
3757 fn test_return_value_udf_status() {
3758 let buf = build_return_value_intn(0, "@RETURN_VALUE", 0x02, Some(-1));
3760 let mut cursor = buf.freeze();
3761 let rv = ReturnValue::decode(&mut cursor).unwrap();
3762
3763 assert_eq!(rv.param_ordinal, 0);
3764 assert_eq!(rv.param_name, "@RETURN_VALUE");
3765 assert_eq!(rv.status, 0x02); assert_eq!(rv.value[0], 4);
3767 assert_eq!(
3768 i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3769 -1
3770 );
3771 }
3772
3773 #[test]
3774 fn test_return_value_nvarchar_output() {
3775 let mut inner = BytesMut::new();
3777
3778 inner.put_u16_le(1);
3780
3781 let name_utf16: Vec<u16> = "@name".encode_utf16().collect();
3783 inner.put_u8(name_utf16.len() as u8);
3784 for &c in &name_utf16 {
3785 inner.put_u16_le(c);
3786 }
3787
3788 inner.put_u8(0x01);
3790 inner.put_u32_le(0);
3792 inner.put_u16_le(0x0001);
3794 inner.put_u8(0xE7);
3796 inner.put_u16_le(200); inner.put_u32_le(0x0904D000); inner.put_u8(0x34); let val_utf16: Vec<u16> = "Hello".encode_utf16().collect();
3803 let byte_len = (val_utf16.len() * 2) as u16;
3804 inner.put_u16_le(byte_len);
3805 for &c in &val_utf16 {
3806 inner.put_u16_le(c);
3807 }
3808
3809 let mut cursor = inner.freeze();
3810 let rv = ReturnValue::decode(&mut cursor).unwrap();
3811
3812 assert_eq!(rv.param_ordinal, 1);
3813 assert_eq!(rv.param_name, "@name");
3814 assert_eq!(rv.status, 0x01);
3815 assert_eq!(rv.col_type, 0xE7); assert_eq!(rv.type_info.max_length, Some(200));
3817 assert!(rv.type_info.collation.is_some());
3818
3819 assert_eq!(rv.value.len(), 12); let val_len = u16::from_le_bytes([rv.value[0], rv.value[1]]);
3822 assert_eq!(val_len, 10);
3823 }
3824
3825 #[test]
3826 fn test_return_value_via_parser() {
3827 let mut data = BytesMut::new();
3829 data.put_u8(0xAC); data.extend_from_slice(&build_return_value_intn(0, "@out", 0x01, Some(99)));
3831
3832 let mut parser = TokenParser::new(data.freeze());
3833 let token = parser.next_token().unwrap().unwrap();
3834
3835 match token {
3836 Token::ReturnValue(rv) => {
3837 assert_eq!(rv.param_name, "@out");
3838 assert_eq!(rv.param_ordinal, 0);
3839 assert_eq!(rv.status, 0x01);
3840 assert_eq!(rv.col_type, 0x26);
3841 }
3842 _ => panic!("Expected ReturnValue token"),
3843 }
3844 }
3845
3846 #[test]
3851 fn test_multi_token_stored_proc_response() {
3852 let mut data = BytesMut::new();
3855
3856 data.put_u8(0xFF); data.put_u16_le(0x0010); data.put_u16_le(0x00C1); data.put_u64_le(3); data.put_u8(0x79); data.put_i32_le(0);
3865
3866 data.put_u8(0xFE); data.put_u16_le(0x0000); data.put_u16_le(0x00C6); data.put_u64_le(0);
3871
3872 let mut parser = TokenParser::new(data.freeze());
3873
3874 let t1 = parser.next_token().unwrap().unwrap();
3876 match t1 {
3877 Token::DoneInProc(done) => {
3878 assert!(done.status.count);
3879 assert_eq!(done.row_count, 3);
3880 assert_eq!(done.cur_cmd, 193);
3881 }
3882 _ => panic!("Expected DoneInProc, got {t1:?}"),
3883 }
3884
3885 let t2 = parser.next_token().unwrap().unwrap();
3887 match t2 {
3888 Token::ReturnStatus(status) => {
3889 assert_eq!(status, 0);
3890 }
3891 _ => panic!("Expected ReturnStatus, got {t2:?}"),
3892 }
3893
3894 let t3 = parser.next_token().unwrap().unwrap();
3896 match t3 {
3897 Token::DoneProc(done) => {
3898 assert!(!done.status.count);
3899 assert!(!done.status.more);
3900 assert_eq!(done.cur_cmd, 198);
3901 }
3902 _ => panic!("Expected DoneProc, got {t3:?}"),
3903 }
3904
3905 assert!(parser.next_token().unwrap().is_none());
3907 }
3908
3909 #[test]
3910 fn test_multi_token_error_in_stream() {
3911 let mut data = BytesMut::new();
3913
3914 data.put_u8(0xAA);
3916
3917 let msg_utf16: Vec<u16> = "Deadlock".encode_utf16().collect();
3918 let srv_utf16: Vec<u16> = "DB1".encode_utf16().collect();
3919
3920 let length: u16 = (4 + 1 + 1
3921 + 2 + (msg_utf16.len() * 2)
3922 + 1 + (srv_utf16.len() * 2)
3923 + 1 + 4) as u16;
3925
3926 data.put_u16_le(length);
3927 data.put_i32_le(1205); data.put_u8(51); data.put_u8(13); data.put_u16_le(msg_utf16.len() as u16);
3932 for &c in &msg_utf16 {
3933 data.put_u16_le(c);
3934 }
3935 data.put_u8(srv_utf16.len() as u8);
3936 for &c in &srv_utf16 {
3937 data.put_u16_le(c);
3938 }
3939 data.put_u8(0); data.put_i32_le(0);
3941
3942 data.put_u8(0xFD);
3944 data.put_u16_le(0x0002); data.put_u16_le(0x00C1); data.put_u64_le(0);
3947
3948 let mut parser = TokenParser::new(data.freeze());
3949
3950 let t1 = parser.next_token().unwrap().unwrap();
3952 match t1 {
3953 Token::Error(err) => {
3954 assert_eq!(err.number, 1205);
3955 assert_eq!(err.class, 13);
3956 assert_eq!(err.message, "Deadlock");
3957 assert_eq!(err.server, "DB1");
3958 }
3959 _ => panic!("Expected Error token, got {t1:?}"),
3960 }
3961
3962 let t2 = parser.next_token().unwrap().unwrap();
3964 match t2 {
3965 Token::Done(done) => {
3966 assert!(done.status.error);
3967 assert!(!done.status.count);
3968 }
3969 _ => panic!("Expected Done token, got {t2:?}"),
3970 }
3971
3972 assert!(parser.next_token().unwrap().is_none());
3973 }
3974
3975 #[test]
3976 fn test_multi_token_proc_with_return_value() {
3977 let mut data = BytesMut::new();
3979
3980 data.put_u8(0xAC);
3982 data.extend_from_slice(&build_return_value_intn(1, "@result", 0x01, Some(42)));
3983
3984 data.put_u8(0x79);
3986 data.put_i32_le(0);
3987
3988 data.put_u8(0xFE);
3990 data.put_u16_le(0x0000);
3991 data.put_u16_le(0x00C6);
3992 data.put_u64_le(0);
3993
3994 let mut parser = TokenParser::new(data.freeze());
3995
3996 let t1 = parser.next_token().unwrap().unwrap();
3997 match t1 {
3998 Token::ReturnValue(rv) => {
3999 assert_eq!(rv.param_name, "@result");
4000 assert_eq!(rv.param_ordinal, 1);
4001 }
4002 _ => panic!("Expected ReturnValue, got {t1:?}"),
4003 }
4004
4005 let t2 = parser.next_token().unwrap().unwrap();
4006 assert!(matches!(t2, Token::ReturnStatus(0)));
4007
4008 let t3 = parser.next_token().unwrap().unwrap();
4009 assert!(matches!(t3, Token::DoneProc(_)));
4010
4011 assert!(parser.next_token().unwrap().is_none());
4012 }
4013
4014 #[test]
4019 fn test_return_status_truncated() {
4020 let data = Bytes::from_static(&[0x79, 0x01, 0x02, 0x03]);
4022 let mut parser = TokenParser::new(data);
4023 assert!(parser.next_token().is_err());
4024 }
4025
4026 #[test]
4027 fn test_done_proc_truncated() {
4028 let data = Bytes::from_static(&[0xFE, 0x00, 0x00, 0xC1, 0x00, 0x01, 0x00, 0x00, 0x00]);
4030 let mut parser = TokenParser::new(data);
4031 assert!(parser.next_token().is_err());
4032 }
4033
4034 #[test]
4035 fn test_server_error_truncated() {
4036 let data = Bytes::from_static(&[0xAA, 0x20, 0x00]);
4038 let mut parser = TokenParser::new(data);
4039 assert!(parser.next_token().is_err());
4040 }
4041
4042 fn build_fed_auth_info_token(options: &[(u8, &str)]) -> Vec<u8> {
4051 let headers_end = 4 + options.len() * 9;
4052 let mut data_block = Vec::new();
4053 let mut headers = Vec::new();
4054 for (id, value) in options {
4055 let encoded: Vec<u8> = value.encode_utf16().flat_map(u16::to_le_bytes).collect();
4056 let offset = headers_end + data_block.len();
4057 headers.push(*id);
4058 headers.extend_from_slice(&u32::try_from(encoded.len()).unwrap().to_le_bytes());
4059 headers.extend_from_slice(&u32::try_from(offset).unwrap().to_le_bytes());
4060 data_block.extend_from_slice(&encoded);
4061 }
4062
4063 let token_len = 4 + headers.len() + data_block.len();
4064 let mut out = vec![0xEE];
4065 out.extend_from_slice(&u32::try_from(token_len).unwrap().to_le_bytes());
4066 out.extend_from_slice(&u32::try_from(options.len()).unwrap().to_le_bytes());
4067 out.extend_from_slice(&headers);
4068 out.extend_from_slice(&data_block);
4069 out
4070 }
4071
4072 #[test]
4073 fn test_fed_auth_info_decodes_spec_layout() {
4074 const STS: &str = "https://login.microsoftonline.com/common";
4075 const SPN: &str = "https://database.windows.net/";
4076 let token = build_fed_auth_info_token(&[(0x01, STS), (0x02, SPN)]);
4080
4081 let mut parser = TokenParser::new(Bytes::from(token));
4082 let parsed = parser.next_token().unwrap().unwrap();
4083 let Token::FedAuthInfo(info) = parsed else {
4084 panic!("expected FedAuthInfo, got {parsed:?}");
4085 };
4086 assert_eq!(info.sts_url, STS);
4087 assert_eq!(info.spn, SPN);
4088 assert!(parser.next_token().unwrap().is_none(), "exact consumption");
4089 }
4090
4091 #[test]
4092 fn test_fed_auth_info_preserves_following_tokens() {
4093 let mut stream = build_fed_auth_info_token(&[
4097 (0x01, "https://sts.example/"),
4098 (0x02, "https://db.example/"),
4099 ]);
4100 stream.push(0xFD); stream.extend_from_slice(&0u16.to_le_bytes()); stream.extend_from_slice(&0u16.to_le_bytes()); stream.extend_from_slice(&0u64.to_le_bytes()); let mut parser = TokenParser::new(Bytes::from(stream));
4106 assert!(matches!(
4107 parser.next_token().unwrap(),
4108 Some(Token::FedAuthInfo(_))
4109 ));
4110 assert!(
4111 matches!(parser.next_token().unwrap(), Some(Token::Done(_))),
4112 "DONE after FEDAUTHINFO must not be swallowed"
4113 );
4114 assert!(parser.next_token().unwrap().is_none());
4115 }
4116
4117 #[test]
4118 fn test_fed_auth_info_unknown_ids_ignored() {
4119 let token =
4121 build_fed_auth_info_token(&[(0x7F, "ignore-me"), (0x01, "https://sts.example/")]);
4122 let mut parser = TokenParser::new(Bytes::from(token));
4123 let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
4124 panic!("expected FedAuthInfo");
4125 };
4126 assert_eq!(info.sts_url, "https://sts.example/");
4127 assert_eq!(info.spn, "");
4128 }
4129
4130 #[test]
4131 fn test_fed_auth_info_hostile_inputs_error() {
4132 let mut truncated = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4134 truncated.truncate(truncated.len() - 4);
4135 assert!(
4136 TokenParser::new(Bytes::from(truncated))
4137 .next_token()
4138 .is_err()
4139 );
4140
4141 let mut bad_count = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4144 bad_count[5..9].copy_from_slice(&u32::MAX.to_le_bytes());
4145 assert!(
4146 TokenParser::new(Bytes::from(bad_count))
4147 .next_token()
4148 .is_err()
4149 );
4150
4151 let mut bad_offset = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4153 bad_offset[14..18].copy_from_slice(&u32::MAX.to_le_bytes());
4154 assert!(
4155 TokenParser::new(Bytes::from(bad_offset))
4156 .next_token()
4157 .is_err()
4158 );
4159
4160 let mut odd_len = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4162 odd_len[10..14].copy_from_slice(&3u32.to_le_bytes());
4163 assert!(TokenParser::new(Bytes::from(odd_len)).next_token().is_err());
4164 }
4165
4166 #[test]
4167 fn test_fed_auth_info_parse_and_skip_agree() {
4168 let token = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4171 let total = token.len();
4172
4173 let mut parser = TokenParser::new(Bytes::from(token.clone()));
4174 parser.next_token().unwrap();
4175 assert_eq!(parser.position(), total, "decode consumption");
4176
4177 let mut skipper = TokenParser::new(Bytes::from(token));
4178 skipper.skip_token().unwrap();
4179 assert_eq!(skipper.position(), total, "skip consumption");
4180 }
4181
    /// Decodes a FEDAUTHINFO token given as a raw byte capture (described by
    /// the test name as coming from Azure) rather than built by
    /// `build_fed_auth_info_token`.
    ///
    /// The capture lists the SPN option (id 0x02) before the STS URL option
    /// (id 0x01), so this checks that options are matched by id, not position.
    /// The tenant GUID in the STS URL is all zeros (presumably redacted).
    #[test]
    fn test_fed_auth_info_captured_from_azure() {
        // Layout (little-endian):
        //   0xEE                      FEDAUTHINFO token byte
        //   CC 00 00 00               token length = 204
        //   02 00 00 00               option count = 2
        //   02 | 3A000000 | 16000000  id 0x02 (SPN): 58 bytes at offset 22
        //   01 | 7C000000 | 50000000  id 0x01 (STS): 124 bytes at offset 80
        //   UTF-16LE data for both strings follows. Offsets count from the
        //   start of the option-count field (22 = 4 + 2 * 9).
        const CAPTURED: &[u8] = &[
            0xEE, 0xCC, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x3A, 0x00, 0x00, 0x00,
            0x16, 0x00, 0x00, 0x00, 0x01, 0x7C, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68,
            0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F, 0x00, 0x2F,
            0x00, 0x64, 0x00, 0x61, 0x00, 0x74, 0x00, 0x61, 0x00, 0x62, 0x00, 0x61, 0x00, 0x73,
            0x00, 0x65, 0x00, 0x2E, 0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F,
            0x00, 0x77, 0x00, 0x73, 0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F,
            0x00, 0x68, 0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F,
            0x00, 0x2F, 0x00, 0x6C, 0x00, 0x6F, 0x00, 0x67, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x2E,
            0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F, 0x00, 0x77, 0x00, 0x73,
            0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F, 0x00, 0x30, 0x00, 0x30,
            0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D,
            0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30,
            0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
            0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
            0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00,
        ];

        let mut parser = TokenParser::new(Bytes::from_static(CAPTURED));
        let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
            panic!("expected FedAuthInfo");
        };
        // Id 0x01 maps to sts_url and id 0x02 to spn, regardless of their order on the wire.
        assert_eq!(
            info.sts_url,
            "https://login.windows.net/00000000-0000-0000-0000-000000000000"
        );
        assert_eq!(info.spn, "https://database.windows.net/");
        // The declared token length (204) must be honoured byte-for-byte.
        assert!(
            parser.next_token().unwrap().is_none(),
            "the captured token must be consumed exactly"
        );
    }
4226}