1use bytes::{Buf, BufMut, Bytes};
33
34use crate::codec::{read_b_varchar, read_us_varchar};
35use crate::error::ProtocolError;
36use crate::prelude::*;
37use crate::types::TypeId;
38
/// Token type byte that introduces each token in a tabular result stream.
///
/// The explicit discriminants are the byte values used on the wire: encoders
/// emit them with `TokenType::X as u8`, and [`TokenType::from_u8`] maps a raw
/// byte back to its variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
    ColMetaData = 0x81,
    Error = 0xAA,
    Info = 0xAB,
    LoginAck = 0xAD,
    Row = 0xD1,
    NbcRow = 0xD2,
    EnvChange = 0xE3,
    Sspi = 0xED,
    Done = 0xFD,
    DoneInProc = 0xFF,
    DoneProc = 0xFE,
    ReturnStatus = 0x79,
    ReturnValue = 0xAC,
    Order = 0xA9,
    FeatureExtAck = 0xAE,
    SessionState = 0xE4,
    FedAuthInfo = 0xEE,
    ColInfo = 0xA5,
    TabName = 0xA4,
    Offset = 0x78,
}

impl TokenType {
    /// Maps a raw token type byte to its `TokenType`.
    ///
    /// Returns `None` when the byte is not a known token type.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Every known variant; the lookup compares each discriminant to `value`.
        const KNOWN: [TokenType; 20] = [
            TokenType::ColMetaData,
            TokenType::Error,
            TokenType::Info,
            TokenType::LoginAck,
            TokenType::Row,
            TokenType::NbcRow,
            TokenType::EnvChange,
            TokenType::Sspi,
            TokenType::Done,
            TokenType::DoneInProc,
            TokenType::DoneProc,
            TokenType::ReturnStatus,
            TokenType::ReturnValue,
            TokenType::Order,
            TokenType::FeatureExtAck,
            TokenType::SessionState,
            TokenType::FedAuthInfo,
            TokenType::ColInfo,
            TokenType::TabName,
            TokenType::Offset,
        ];

        KNOWN.iter().copied().find(|&kind| kind as u8 == value)
    }
}
114
/// A single decoded token from a server response stream.
///
/// Variant names mirror those of [`TokenType`]. Each variant carries the decoded
/// body of that token; `ReturnStatus` carries its value directly as an `i32`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Token {
    /// Column metadata for the rows that follow.
    ColMetaData(ColMetaData),
    /// A row in which every column value is present in the stream.
    Row(RawRow),
    /// A row preceded by a null bitmap; null columns carry no value bytes.
    NbcRow(NbcRow),
    /// DONE token body.
    Done(Done),
    /// DONEPROC token body.
    DoneProc(DoneProc),
    /// DONEINPROC token body.
    DoneInProc(DoneInProc),
    /// RETURNSTATUS value.
    ReturnStatus(i32),
    /// RETURNVALUE token body (an output parameter).
    ReturnValue(ReturnValue),
    /// ERROR token body.
    Error(ServerError),
    /// INFO token body.
    Info(ServerInfo),
    /// LOGINACK token body.
    LoginAck(LoginAck),
    /// ENVCHANGE token body.
    EnvChange(EnvChange),
    /// ORDER token body.
    Order(Order),
    /// FEATUREEXTACK token body.
    FeatureExtAck(FeatureExtAck),
    /// SSPI token body.
    Sspi(SspiToken),
    /// SESSIONSTATE token body.
    SessionState(SessionState),
    /// FEDAUTHINFO token body.
    FedAuthInfo(FedAuthInfo),
}
157
/// Decoded COLMETADATA token: the column descriptions for the rows that follow.
#[derive(Debug, Clone, Default)]
pub struct ColMetaData {
    /// Column descriptors, in the order they appeared in the stream.
    pub columns: Vec<ColumnData>,
    /// Column encryption key table; only populated by `ColMetaData::decode_encrypted`.
    pub cek_table: Option<crate::crypto::CekTable>,
}

/// Metadata for a single result column.
#[derive(Debug, Clone)]
pub struct ColumnData {
    /// Column name.
    pub name: String,
    /// Data type resolved from `col_type`.
    pub type_id: TypeId,
    /// Raw type byte as read from the stream.
    pub col_type: u8,
    /// Column flags; bit `0x0001` marks the column as nullable (see `is_nullable`).
    pub flags: u16,
    /// User type identifier, read ahead of the flags.
    pub user_type: u32,
    /// Type-specific parameters (length, precision/scale, collation).
    pub type_info: TypeInfo,
    /// Always Encrypted metadata. Set only by the encrypted decode path, and only
    /// for columns whose flags mark them as encrypted.
    pub crypto_metadata: Option<crate::crypto::CryptoMetadata>,
}

/// Type-specific parameters from a column's TYPE_INFO.
///
/// A field is `None` when the column's type does not carry it.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
    /// Declared maximum length. For the variable char/binary types, `0xFFFF`
    /// selects chunked (PLP) value encoding when rows are decoded.
    pub max_length: Option<u32>,
    /// Precision for decimal/numeric types.
    pub precision: Option<u8>,
    /// Scale for decimal/numeric types and for time/datetime2/datetimeoffset.
    pub scale: Option<u8>,
    /// Collation for character and text types.
    pub collation: Option<Collation>,
}

/// A 5-byte SQL Server collation.
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
    /// First four bytes of the collation, read as a little-endian `u32`.
    /// NOTE(review): presumably the LCID plus flag/version bits; confirm before
    /// treating this as a bare LCID.
    pub lcid: u32,
    /// Fifth byte of the collation. A non-zero value takes precedence over
    /// `lcid` when resolving an encoding.
    pub sort_id: u8,
}
238
239impl Collation {
240 pub fn from_bytes(bytes: &[u8; 5]) -> Self {
244 Self {
245 lcid: u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
246 sort_id: bytes[4],
247 }
248 }
249
250 pub fn to_bytes(&self) -> [u8; 5] {
254 let b = self.lcid.to_le_bytes();
255 [b[0], b[1], b[2], b[3], self.sort_id]
256 }
257
258 #[cfg(feature = "encoding")]
284 pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
285 if self.sort_id != 0 {
289 return crate::collation::encoding_for_sort_id(self.sort_id);
290 }
291 crate::collation::encoding_for_lcid(self.lcid)
292 }
293
294 #[cfg(feature = "encoding")]
299 pub fn is_utf8(&self) -> bool {
300 crate::collation::is_utf8_collation(self.lcid)
301 }
302
303 #[cfg(feature = "encoding")]
311 pub fn code_page(&self) -> Option<u16> {
312 if self.sort_id != 0 {
315 return crate::collation::code_page_for_sort_id(self.sort_id);
316 }
317 crate::collation::code_page_for_lcid(self.lcid)
318 }
319
320 #[cfg(feature = "encoding")]
324 pub fn encoding_name(&self) -> &'static str {
325 if self.sort_id != 0 {
326 return match crate::collation::encoding_for_sort_id(self.sort_id) {
327 Some(enc) => enc.name(),
328 None => "unsupported",
329 };
330 }
331 crate::collation::encoding_name_for_lcid(self.lcid)
332 }
333}
334
/// Decoded ROW token: every column value copied into one buffer.
#[derive(Debug, Clone)]
pub struct RawRow {
    /// Column values in column order, each in the length-prefixed layout
    /// produced by `RawRow::decode`.
    pub data: bytes::Bytes,
}

/// Decoded NBCROW token: a null bitmap plus the values of the non-null columns.
#[derive(Debug, Clone)]
pub struct NbcRow {
    /// One bit per column (bit `i % 8` of byte `i / 8`); a set bit means column `i` is null.
    pub null_bitmap: Vec<u8>,
    /// Values of the non-null columns only, in column order, in the same layout
    /// as [`RawRow::data`].
    pub data: bytes::Bytes,
}
350
/// Decoded DONE token body.
#[derive(Debug, Clone, Copy)]
pub struct Done {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command identifier, as sent by the server.
    pub cur_cmd: u16,
    /// Row count. Presumably meaningful only when `status.count` is set; confirm.
    pub row_count: u64,
}

/// Status flags carried by DONE, DONEPROC and DONEINPROC tokens.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
    /// Bit `0x0001` (DONE_MORE).
    pub more: bool,
    /// Bit `0x0002` (DONE_ERROR).
    pub error: bool,
    /// Bit `0x0004` (DONE_INXACT).
    pub in_xact: bool,
    /// Bit `0x0010` (DONE_COUNT).
    pub count: bool,
    /// Bit `0x0020` (DONE_ATTN).
    pub attn: bool,
    /// Bit `0x0100` (DONE_SRVERROR).
    pub srverror: bool,
}

/// Decoded DONEINPROC token body; same layout as [`Done`].
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command identifier, as sent by the server.
    pub cur_cmd: u16,
    /// Row count. Presumably meaningful only when `status.count` is set; confirm.
    pub row_count: u64,
}

/// Decoded DONEPROC token body; same layout as [`Done`].
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
    /// Status flags.
    pub status: DoneStatus,
    /// Current command identifier, as sent by the server.
    pub cur_cmd: u16,
    /// Row count. Presumably meaningful only when `status.count` is set; confirm.
    pub row_count: u64,
}
401
/// Decoded RETURNVALUE token: the value of an output parameter.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReturnValue {
    /// Parameter ordinal, as sent by the server.
    pub param_ordinal: u16,
    /// Parameter name.
    pub param_name: String,
    /// Status byte, as sent by the server.
    pub status: u8,
    /// User type identifier.
    pub user_type: u32,
    /// Parameter flags.
    pub flags: u16,
    /// Raw type byte.
    pub col_type: u8,
    /// Type-specific parameters decoded from TYPE_INFO.
    pub type_info: TypeInfo,
    /// The value, in the same layout as a single column inside [`RawRow::data`].
    pub value: bytes::Bytes,
}

/// Decoded ERROR token.
#[derive(Debug, Clone)]
pub struct ServerError {
    /// Error number.
    pub number: i32,
    /// Error state byte.
    pub state: u8,
    /// Severity class. `is_fatal` treats values >= 20 as fatal, and
    /// `is_batch_abort` treats values >= 16 as batch-aborting.
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Server name.
    pub server: String,
    /// Procedure name; may be empty.
    pub procedure: String,
    /// Line number.
    pub line: i32,
}

/// Decoded INFO token; same fields as [`ServerError`].
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Message number.
    pub number: i32,
    /// State byte.
    pub state: u8,
    /// Severity class.
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Server name.
    pub server: String,
    /// Procedure name; may be empty.
    pub procedure: String,
    /// Line number.
    pub line: i32,
}

/// Decoded LOGINACK token.
#[derive(Debug, Clone)]
pub struct LoginAck {
    /// Interface byte.
    pub interface: u8,
    /// Raw TDS version word; `LoginAck::tds_version()` wraps it in a `TdsVersion`.
    pub tds_version: u32,
    /// Server program name.
    pub prog_name: String,
    /// Raw server program version word.
    pub prog_version: u32,
}
474
/// Decoded ENVCHANGE token.
#[derive(Debug, Clone)]
pub struct EnvChange {
    /// Which environment setting changed.
    pub env_type: EnvChangeType,
    /// Value after the change.
    pub new_value: EnvChangeValue,
    /// Value before the change.
    pub old_value: EnvChangeValue,
}

/// Kind of environment change.
///
/// NOTE(review): the discriminants appear to be the ENVCHANGE type codes from
/// MS-TDS (value 14 is intentionally absent); confirm against the decoder.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EnvChangeType {
    Database = 1,
    Language = 2,
    CharacterSet = 3,
    PacketSize = 4,
    UnicodeSortingLocalId = 5,
    UnicodeComparisonFlags = 6,
    SqlCollation = 7,
    BeginTransaction = 8,
    CommitTransaction = 9,
    RollbackTransaction = 10,
    EnlistDtcTransaction = 11,
    DefectTransaction = 12,
    RealTimeLogShipping = 13,
    PromoteTransaction = 15,
    TransactionManagerAddress = 16,
    TransactionEnded = 17,
    ResetConnectionCompletionAck = 18,
    UserInstanceStarted = 19,
    Routing = 20,
}

/// Payload of an environment change.
///
/// NOTE(review): which variant is used presumably depends on the
/// `EnvChangeType`; confirm in the ENVCHANGE decoder.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvChangeValue {
    /// Textual value.
    String(String),
    /// Raw binary value.
    Binary(bytes::Bytes),
    /// Routing target.
    Routing {
        /// Target host name.
        host: String,
        /// Target port.
        port: u16,
    },
}
547
/// Decoded ORDER token.
#[derive(Debug, Clone)]
pub struct Order {
    /// Column numbers listed by the token.
    /// NOTE(review): presumably 1-based column ordinals; confirm.
    pub columns: Vec<u16>,
}

/// Decoded FEATUREEXTACK token: one entry per acknowledged feature.
#[derive(Debug, Clone)]
pub struct FeatureExtAck {
    /// Acknowledged features, in stream order.
    pub features: Vec<FeatureAck>,
}

/// A single feature acknowledgement.
#[derive(Debug, Clone)]
pub struct FeatureAck {
    /// Feature identifier byte.
    pub feature_id: u8,
    /// Raw feature-specific acknowledgement data.
    pub data: bytes::Bytes,
}

/// Decoded SSPI token: an opaque authentication payload.
#[derive(Debug, Clone)]
pub struct SspiToken {
    /// Raw SSPI bytes.
    pub data: bytes::Bytes,
}

/// Decoded SESSIONSTATE token.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Raw session-state payload, copied verbatim by `SessionState::decode`.
    pub data: bytes::Bytes,
}

/// Decoded FEDAUTHINFO token.
#[derive(Debug, Clone)]
pub struct FedAuthInfo {
    /// STS URL supplied by the server.
    pub sts_url: String,
    /// Service principal name supplied by the server.
    pub spn: String,
}
593
594pub(crate) fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
602 if src.remaining() < 5 {
603 return Err(ProtocolError::UnexpectedEof);
604 }
605 let lcid = src.get_u32_le();
607 let sort_id = src.get_u8();
608 Ok(Collation { lcid, sort_id })
609}
610
611pub(crate) fn decode_type_info(
615 src: &mut impl Buf,
616 type_id: TypeId,
617 col_type: u8,
618) -> Result<TypeInfo, ProtocolError> {
619 match type_id {
620 TypeId::Null => Ok(TypeInfo::default()),
622 TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
623 TypeId::Int2 => Ok(TypeInfo::default()),
624 TypeId::Int4 => Ok(TypeInfo::default()),
625 TypeId::Int8 => Ok(TypeInfo::default()),
626 TypeId::Float4 => Ok(TypeInfo::default()),
627 TypeId::Float8 => Ok(TypeInfo::default()),
628 TypeId::Money => Ok(TypeInfo::default()),
629 TypeId::Money4 => Ok(TypeInfo::default()),
630 TypeId::DateTime => Ok(TypeInfo::default()),
631 TypeId::DateTime4 => Ok(TypeInfo::default()),
632
633 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
635 if src.remaining() < 1 {
636 return Err(ProtocolError::UnexpectedEof);
637 }
638 let max_length = src.get_u8() as u32;
639 Ok(TypeInfo {
640 max_length: Some(max_length),
641 ..Default::default()
642 })
643 }
644
645 TypeId::Guid => {
647 if src.remaining() < 1 {
648 return Err(ProtocolError::UnexpectedEof);
649 }
650 let max_length = src.get_u8() as u32;
651 Ok(TypeInfo {
652 max_length: Some(max_length),
653 ..Default::default()
654 })
655 }
656
657 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
659 if src.remaining() < 3 {
660 return Err(ProtocolError::UnexpectedEof);
661 }
662 let max_length = src.get_u8() as u32;
663 let precision = src.get_u8();
664 let scale = src.get_u8();
665 Ok(TypeInfo {
666 max_length: Some(max_length),
667 precision: Some(precision),
668 scale: Some(scale),
669 ..Default::default()
670 })
671 }
672
673 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
675 if src.remaining() < 1 {
676 return Err(ProtocolError::UnexpectedEof);
677 }
678 let max_length = src.get_u8() as u32;
679 Ok(TypeInfo {
680 max_length: Some(max_length),
681 ..Default::default()
682 })
683 }
684
685 TypeId::BigVarChar | TypeId::BigChar => {
687 if src.remaining() < 7 {
688 return Err(ProtocolError::UnexpectedEof);
690 }
691 let max_length = src.get_u16_le() as u32;
692 let collation = decode_collation(src)?;
693 Ok(TypeInfo {
694 max_length: Some(max_length),
695 collation: Some(collation),
696 ..Default::default()
697 })
698 }
699
700 TypeId::BigVarBinary | TypeId::BigBinary => {
702 if src.remaining() < 2 {
703 return Err(ProtocolError::UnexpectedEof);
704 }
705 let max_length = src.get_u16_le() as u32;
706 Ok(TypeInfo {
707 max_length: Some(max_length),
708 ..Default::default()
709 })
710 }
711
712 TypeId::NChar | TypeId::NVarChar => {
714 if src.remaining() < 7 {
715 return Err(ProtocolError::UnexpectedEof);
717 }
718 let max_length = src.get_u16_le() as u32;
719 let collation = decode_collation(src)?;
720 Ok(TypeInfo {
721 max_length: Some(max_length),
722 collation: Some(collation),
723 ..Default::default()
724 })
725 }
726
727 TypeId::Date => Ok(TypeInfo::default()),
729
730 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
732 if src.remaining() < 1 {
733 return Err(ProtocolError::UnexpectedEof);
734 }
735 let scale = src.get_u8();
736 Ok(TypeInfo {
737 scale: Some(scale),
738 ..Default::default()
739 })
740 }
741
742 TypeId::Text | TypeId::NText | TypeId::Image => {
744 if src.remaining() < 4 {
746 return Err(ProtocolError::UnexpectedEof);
747 }
748 let max_length = src.get_u32_le();
749
750 let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
752 if src.remaining() < 5 {
753 return Err(ProtocolError::UnexpectedEof);
754 }
755 Some(decode_collation(src)?)
756 } else {
757 None
758 };
759
760 if src.remaining() < 1 {
763 return Err(ProtocolError::UnexpectedEof);
764 }
765 let num_parts = src.get_u8();
766 for _ in 0..num_parts {
767 let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
769 }
770
771 Ok(TypeInfo {
772 max_length: Some(max_length),
773 collation,
774 ..Default::default()
775 })
776 }
777
778 TypeId::Xml => {
780 if src.remaining() < 1 {
781 return Err(ProtocolError::UnexpectedEof);
782 }
783 let schema_present = src.get_u8();
784
785 if schema_present != 0 {
786 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; }
793
794 Ok(TypeInfo::default())
795 }
796
797 TypeId::Udt => {
799 if src.remaining() < 2 {
801 return Err(ProtocolError::UnexpectedEof);
802 }
803 let max_length = src.get_u16_le() as u32;
804
805 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; Ok(TypeInfo {
814 max_length: Some(max_length),
815 ..Default::default()
816 })
817 }
818
819 TypeId::Tvp => {
821 Err(ProtocolError::InvalidTokenType(col_type))
824 }
825
826 TypeId::Variant => {
828 if src.remaining() < 4 {
829 return Err(ProtocolError::UnexpectedEof);
830 }
831 let max_length = src.get_u32_le();
832 Ok(TypeInfo {
833 max_length: Some(max_length),
834 ..Default::default()
835 })
836 }
837 }
838}
839
840impl ColMetaData {
841 pub const NO_METADATA: u16 = 0xFFFF;
843
844 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
846 if src.remaining() < 2 {
847 return Err(ProtocolError::UnexpectedEof);
848 }
849
850 let column_count = src.get_u16_le();
851
852 if column_count == Self::NO_METADATA {
854 return Ok(Self {
855 columns: Vec::new(),
856 cek_table: None,
857 });
858 }
859
860 let mut columns = Vec::with_capacity(column_count as usize);
861
862 for _ in 0..column_count {
863 let column = Self::decode_column(src)?;
864 columns.push(column);
865 }
866
867 Ok(Self {
868 columns,
869 cek_table: None,
870 })
871 }
872
873 fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
875 if src.remaining() < 7 {
877 return Err(ProtocolError::UnexpectedEof);
878 }
879
880 let user_type = src.get_u32_le();
881 let flags = src.get_u16_le();
882 let col_type = src.get_u8();
883
884 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
888
889 let type_info = decode_type_info(src, type_id, col_type)?;
891
892 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
894
895 Ok(ColumnData {
896 name,
897 type_id,
898 col_type,
899 flags,
900 user_type,
901 type_info,
902 crypto_metadata: None,
903 })
904 }
905
906 pub fn decode_encrypted(src: &mut impl Buf) -> Result<Self, ProtocolError> {
919 if src.remaining() < 2 {
920 return Err(ProtocolError::UnexpectedEof);
921 }
922
923 let column_count = src.get_u16_le();
924
925 if column_count == Self::NO_METADATA {
926 return Ok(Self {
927 columns: Vec::new(),
928 cek_table: None,
929 });
930 }
931
932 let cek_table = crate::crypto::CekTable::decode(src)?;
934
935 let mut columns = Vec::with_capacity(column_count as usize);
936
937 for _ in 0..column_count {
938 let column = Self::decode_column_encrypted(src)?;
939 columns.push(column);
940 }
941
942 Ok(Self {
943 columns,
944 cek_table: Some(cek_table),
945 })
946 }
947
948 fn decode_column_encrypted(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
952 if src.remaining() < 7 {
953 return Err(ProtocolError::UnexpectedEof);
954 }
955
956 let user_type = src.get_u32_le();
957 let flags = src.get_u16_le();
958 let col_type = src.get_u8();
959
960 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
961
962 let type_info = decode_type_info(src, type_id, col_type)?;
964
965 let crypto_metadata = if crate::crypto::is_column_encrypted(flags) {
967 Some(crate::crypto::CryptoMetadata::decode(src)?)
968 } else {
969 None
970 };
971
972 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
974
975 Ok(ColumnData {
976 name,
977 type_id,
978 col_type,
979 flags,
980 user_type,
981 type_info,
982 crypto_metadata,
983 })
984 }
985
986 #[must_use]
988 pub fn column_count(&self) -> usize {
989 self.columns.len()
990 }
991
992 #[must_use]
994 pub fn is_empty(&self) -> bool {
995 self.columns.is_empty()
996 }
997}
998
999impl ColumnData {
1000 #[must_use]
1002 pub fn is_nullable(&self) -> bool {
1003 (self.flags & 0x0001) != 0
1004 }
1005
1006 #[must_use]
1010 pub fn fixed_size(&self) -> Option<usize> {
1011 match self.type_id {
1012 TypeId::Null => Some(0),
1013 TypeId::Int1 | TypeId::Bit => Some(1),
1014 TypeId::Int2 => Some(2),
1015 TypeId::Int4 => Some(4),
1016 TypeId::Int8 => Some(8),
1017 TypeId::Float4 => Some(4),
1018 TypeId::Float8 => Some(8),
1019 TypeId::Money => Some(8),
1020 TypeId::Money4 => Some(4),
1021 TypeId::DateTime => Some(8),
1022 TypeId::DateTime4 => Some(4),
1023 TypeId::Date => Some(3),
1024 _ => None,
1025 }
1026 }
1027}
1028
1029impl RawRow {
1034 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1039 let mut data = bytes::BytesMut::new();
1040
1041 for col in &metadata.columns {
1042 Self::decode_column_value(src, col, &mut data)?;
1043 }
1044
1045 Ok(Self {
1046 data: data.freeze(),
1047 })
1048 }
1049
1050 fn decode_column_value(
1052 src: &mut impl Buf,
1053 col: &ColumnData,
1054 dst: &mut bytes::BytesMut,
1055 ) -> Result<(), ProtocolError> {
1056 match col.type_id {
1057 TypeId::Null => {
1059 }
1061 TypeId::Int1 | TypeId::Bit => {
1062 if src.remaining() < 1 {
1063 return Err(ProtocolError::UnexpectedEof);
1064 }
1065 dst.extend_from_slice(&[src.get_u8()]);
1066 }
1067 TypeId::Int2 => {
1068 if src.remaining() < 2 {
1069 return Err(ProtocolError::UnexpectedEof);
1070 }
1071 dst.extend_from_slice(&src.get_u16_le().to_le_bytes());
1072 }
1073 TypeId::Int4 => {
1074 if src.remaining() < 4 {
1075 return Err(ProtocolError::UnexpectedEof);
1076 }
1077 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1078 }
1079 TypeId::Int8 => {
1080 if src.remaining() < 8 {
1081 return Err(ProtocolError::UnexpectedEof);
1082 }
1083 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1084 }
1085 TypeId::Float4 => {
1086 if src.remaining() < 4 {
1087 return Err(ProtocolError::UnexpectedEof);
1088 }
1089 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1090 }
1091 TypeId::Float8 => {
1092 if src.remaining() < 8 {
1093 return Err(ProtocolError::UnexpectedEof);
1094 }
1095 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1096 }
1097 TypeId::Money => {
1098 if src.remaining() < 8 {
1099 return Err(ProtocolError::UnexpectedEof);
1100 }
1101 let hi = src.get_u32_le();
1102 let lo = src.get_u32_le();
1103 dst.extend_from_slice(&hi.to_le_bytes());
1104 dst.extend_from_slice(&lo.to_le_bytes());
1105 }
1106 TypeId::Money4 => {
1107 if src.remaining() < 4 {
1108 return Err(ProtocolError::UnexpectedEof);
1109 }
1110 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1111 }
1112 TypeId::DateTime => {
1113 if src.remaining() < 8 {
1114 return Err(ProtocolError::UnexpectedEof);
1115 }
1116 let days = src.get_u32_le();
1117 let time = src.get_u32_le();
1118 dst.extend_from_slice(&days.to_le_bytes());
1119 dst.extend_from_slice(&time.to_le_bytes());
1120 }
1121 TypeId::DateTime4 => {
1122 if src.remaining() < 4 {
1123 return Err(ProtocolError::UnexpectedEof);
1124 }
1125 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1126 }
1127 TypeId::Date => {
1129 Self::decode_bytelen_type(src, dst)?;
1130 }
1131
1132 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
1134 Self::decode_bytelen_type(src, dst)?;
1135 }
1136
1137 TypeId::Guid => {
1138 Self::decode_bytelen_type(src, dst)?;
1139 }
1140
1141 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1142 Self::decode_bytelen_type(src, dst)?;
1143 }
1144
1145 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
1147 Self::decode_bytelen_type(src, dst)?;
1148 }
1149
1150 TypeId::BigVarChar | TypeId::BigVarBinary => {
1152 if col.type_info.max_length == Some(0xFFFF) {
1154 Self::decode_plp_type(src, dst)?;
1155 } else {
1156 Self::decode_ushortlen_type(src, dst)?;
1157 }
1158 }
1159
1160 TypeId::BigChar | TypeId::BigBinary => {
1162 Self::decode_ushortlen_type(src, dst)?;
1163 }
1164
1165 TypeId::NVarChar => {
1167 if col.type_info.max_length == Some(0xFFFF) {
1169 Self::decode_plp_type(src, dst)?;
1170 } else {
1171 Self::decode_ushortlen_type(src, dst)?;
1172 }
1173 }
1174
1175 TypeId::NChar => {
1177 Self::decode_ushortlen_type(src, dst)?;
1178 }
1179
1180 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
1182 Self::decode_bytelen_type(src, dst)?;
1183 }
1184
1185 TypeId::Text | TypeId::NText | TypeId::Image => {
1187 Self::decode_textptr_type(src, dst)?;
1188 }
1189
1190 TypeId::Xml => {
1192 Self::decode_plp_type(src, dst)?;
1193 }
1194
1195 TypeId::Variant => {
1197 Self::decode_intlen_type(src, dst)?;
1198 }
1199
1200 TypeId::Udt => {
1201 Self::decode_plp_type(src, dst)?;
1203 }
1204
1205 TypeId::Tvp => {
1206 return Err(ProtocolError::InvalidTokenType(col.col_type));
1208 }
1209 }
1210
1211 Ok(())
1212 }
1213
1214 fn decode_bytelen_type(
1216 src: &mut impl Buf,
1217 dst: &mut bytes::BytesMut,
1218 ) -> Result<(), ProtocolError> {
1219 if src.remaining() < 1 {
1220 return Err(ProtocolError::UnexpectedEof);
1221 }
1222 let len = src.get_u8() as usize;
1223 if len == 0xFF {
1224 dst.extend_from_slice(&[0xFF]);
1226 } else if len == 0 {
1227 dst.extend_from_slice(&[0x00]);
1229 } else {
1230 if src.remaining() < len {
1231 return Err(ProtocolError::UnexpectedEof);
1232 }
1233 dst.extend_from_slice(&[len as u8]);
1234 for _ in 0..len {
1235 dst.extend_from_slice(&[src.get_u8()]);
1236 }
1237 }
1238 Ok(())
1239 }
1240
1241 fn decode_ushortlen_type(
1243 src: &mut impl Buf,
1244 dst: &mut bytes::BytesMut,
1245 ) -> Result<(), ProtocolError> {
1246 if src.remaining() < 2 {
1247 return Err(ProtocolError::UnexpectedEof);
1248 }
1249 let len = src.get_u16_le() as usize;
1250 if len == 0xFFFF {
1251 dst.extend_from_slice(&0xFFFFu16.to_le_bytes());
1253 } else if len == 0 {
1254 dst.extend_from_slice(&0u16.to_le_bytes());
1256 } else {
1257 if src.remaining() < len {
1258 return Err(ProtocolError::UnexpectedEof);
1259 }
1260 dst.extend_from_slice(&(len as u16).to_le_bytes());
1261 for _ in 0..len {
1262 dst.extend_from_slice(&[src.get_u8()]);
1263 }
1264 }
1265 Ok(())
1266 }
1267
1268 fn decode_intlen_type(
1270 src: &mut impl Buf,
1271 dst: &mut bytes::BytesMut,
1272 ) -> Result<(), ProtocolError> {
1273 if src.remaining() < 4 {
1274 return Err(ProtocolError::UnexpectedEof);
1275 }
1276 let len = src.get_u32_le() as usize;
1277 if len == 0xFFFFFFFF {
1278 dst.extend_from_slice(&0xFFFFFFFFu32.to_le_bytes());
1280 } else if len == 0 {
1281 dst.extend_from_slice(&0u32.to_le_bytes());
1283 } else {
1284 if src.remaining() < len {
1285 return Err(ProtocolError::UnexpectedEof);
1286 }
1287 dst.extend_from_slice(&(len as u32).to_le_bytes());
1288 for _ in 0..len {
1289 dst.extend_from_slice(&[src.get_u8()]);
1290 }
1291 }
1292 Ok(())
1293 }
1294
1295 fn decode_textptr_type(
1310 src: &mut impl Buf,
1311 dst: &mut bytes::BytesMut,
1312 ) -> Result<(), ProtocolError> {
1313 if src.remaining() < 1 {
1314 return Err(ProtocolError::UnexpectedEof);
1315 }
1316
1317 let textptr_len = src.get_u8() as usize;
1318
1319 if textptr_len == 0 {
1320 dst.extend_from_slice(&0xFFFFFFFFFFFFFFFFu64.to_le_bytes());
1322 return Ok(());
1323 }
1324
1325 if src.remaining() < textptr_len {
1327 return Err(ProtocolError::UnexpectedEof);
1328 }
1329 src.advance(textptr_len);
1330
1331 if src.remaining() < 8 {
1333 return Err(ProtocolError::UnexpectedEof);
1334 }
1335 src.advance(8);
1336
1337 if src.remaining() < 4 {
1339 return Err(ProtocolError::UnexpectedEof);
1340 }
1341 let data_len = src.get_u32_le() as usize;
1342
1343 if src.remaining() < data_len {
1344 return Err(ProtocolError::UnexpectedEof);
1345 }
1346
1347 dst.extend_from_slice(&(data_len as u64).to_le_bytes());
1353 dst.extend_from_slice(&(data_len as u32).to_le_bytes());
1354 for _ in 0..data_len {
1355 dst.extend_from_slice(&[src.get_u8()]);
1356 }
1357 dst.extend_from_slice(&0u32.to_le_bytes()); Ok(())
1360 }
1361
1362 fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
1368 if src.remaining() < 8 {
1369 return Err(ProtocolError::UnexpectedEof);
1370 }
1371
1372 let total_len = src.get_u64_le();
1373
1374 dst.extend_from_slice(&total_len.to_le_bytes());
1376
1377 if total_len == 0xFFFFFFFFFFFFFFFF {
1378 return Ok(());
1380 }
1381
1382 loop {
1384 if src.remaining() < 4 {
1385 return Err(ProtocolError::UnexpectedEof);
1386 }
1387 let chunk_len = src.get_u32_le() as usize;
1388 dst.extend_from_slice(&(chunk_len as u32).to_le_bytes());
1389
1390 if chunk_len == 0 {
1391 break;
1393 }
1394
1395 if src.remaining() < chunk_len {
1396 return Err(ProtocolError::UnexpectedEof);
1397 }
1398
1399 for _ in 0..chunk_len {
1400 dst.extend_from_slice(&[src.get_u8()]);
1401 }
1402 }
1403
1404 Ok(())
1405 }
1406}
1407
1408impl NbcRow {
1413 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1418 let col_count = metadata.columns.len();
1419 let bitmap_len = col_count.div_ceil(8);
1420
1421 if src.remaining() < bitmap_len {
1422 return Err(ProtocolError::UnexpectedEof);
1423 }
1424
1425 let mut null_bitmap = vec![0u8; bitmap_len];
1427 for byte in &mut null_bitmap {
1428 *byte = src.get_u8();
1429 }
1430
1431 let mut data = bytes::BytesMut::new();
1433
1434 for (i, col) in metadata.columns.iter().enumerate() {
1435 let byte_idx = i / 8;
1436 let bit_idx = i % 8;
1437 let is_null = (null_bitmap[byte_idx] & (1 << bit_idx)) != 0;
1438
1439 if !is_null {
1440 RawRow::decode_column_value(src, col, &mut data)?;
1443 }
1444 }
1445
1446 Ok(Self {
1447 null_bitmap,
1448 data: data.freeze(),
1449 })
1450 }
1451
1452 #[must_use]
1454 pub fn is_null(&self, column_index: usize) -> bool {
1455 let byte_idx = column_index / 8;
1456 let bit_idx = column_index % 8;
1457 if byte_idx < self.null_bitmap.len() {
1458 (self.null_bitmap[byte_idx] & (1 << bit_idx)) != 0
1459 } else {
1460 true }
1462 }
1463}
1464
1465impl ReturnValue {
1470 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1472 if src.remaining() < 2 {
1479 return Err(ProtocolError::UnexpectedEof);
1480 }
1481 let param_ordinal = src.get_u16_le();
1482
1483 let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1485
1486 if src.remaining() < 1 {
1488 return Err(ProtocolError::UnexpectedEof);
1489 }
1490 let status = src.get_u8();
1491
1492 if src.remaining() < 7 {
1494 return Err(ProtocolError::UnexpectedEof);
1495 }
1496 let user_type = src.get_u32_le();
1497 let flags = src.get_u16_le();
1498 let col_type = src.get_u8();
1499
1500 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
1501
1502 let type_info = decode_type_info(src, type_id, col_type)?;
1504
1505 let mut value_buf = bytes::BytesMut::new();
1507
1508 let temp_col = ColumnData {
1510 name: String::new(),
1511 type_id,
1512 col_type,
1513 flags,
1514 user_type,
1515 type_info: type_info.clone(),
1516 crypto_metadata: None,
1517 };
1518
1519 RawRow::decode_column_value(src, &temp_col, &mut value_buf)?;
1520
1521 Ok(Self {
1522 param_ordinal,
1523 param_name,
1524 status,
1525 user_type,
1526 flags,
1527 col_type,
1528 type_info,
1529 value: value_buf.freeze(),
1530 })
1531 }
1532}
1533
1534impl SessionState {
1539 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1541 if src.remaining() < 4 {
1542 return Err(ProtocolError::UnexpectedEof);
1543 }
1544
1545 let length = src.get_u32_le() as usize;
1546
1547 if src.remaining() < length {
1548 return Err(ProtocolError::IncompletePacket {
1549 expected: length,
1550 actual: src.remaining(),
1551 });
1552 }
1553
1554 let data = src.copy_to_bytes(length);
1555
1556 Ok(Self { data })
1557 }
1558}
1559
/// Bit masks for the status word of DONE, DONEPROC and DONEINPROC tokens.
mod done_status_bits {
    /// More results follow.
    pub const DONE_MORE: u16 = 1 << 0;
    /// An error occurred in the current command.
    pub const DONE_ERROR: u16 = 1 << 1;
    /// A transaction is in progress.
    pub const DONE_INXACT: u16 = 1 << 2;
    /// The row count field is valid.
    pub const DONE_COUNT: u16 = 1 << 4;
    /// Acknowledges an attention (cancel) request.
    pub const DONE_ATTN: u16 = 1 << 5;
    /// A severe server error occurred.
    pub const DONE_SRVERROR: u16 = 1 << 8;
}
1573
1574impl DoneStatus {
1575 #[must_use]
1577 pub fn from_bits(bits: u16) -> Self {
1578 use done_status_bits::*;
1579 Self {
1580 more: (bits & DONE_MORE) != 0,
1581 error: (bits & DONE_ERROR) != 0,
1582 in_xact: (bits & DONE_INXACT) != 0,
1583 count: (bits & DONE_COUNT) != 0,
1584 attn: (bits & DONE_ATTN) != 0,
1585 srverror: (bits & DONE_SRVERROR) != 0,
1586 }
1587 }
1588
1589 #[must_use]
1591 pub fn to_bits(&self) -> u16 {
1592 use done_status_bits::*;
1593 let mut bits = 0u16;
1594 if self.more {
1595 bits |= DONE_MORE;
1596 }
1597 if self.error {
1598 bits |= DONE_ERROR;
1599 }
1600 if self.in_xact {
1601 bits |= DONE_INXACT;
1602 }
1603 if self.count {
1604 bits |= DONE_COUNT;
1605 }
1606 if self.attn {
1607 bits |= DONE_ATTN;
1608 }
1609 if self.srverror {
1610 bits |= DONE_SRVERROR;
1611 }
1612 bits
1613 }
1614}
1615
1616impl Done {
1617 pub const SIZE: usize = 12; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1622 if src.remaining() < Self::SIZE {
1623 return Err(ProtocolError::IncompletePacket {
1624 expected: Self::SIZE,
1625 actual: src.remaining(),
1626 });
1627 }
1628
1629 let status = DoneStatus::from_bits(src.get_u16_le());
1630 let cur_cmd = src.get_u16_le();
1631 let row_count = src.get_u64_le();
1632
1633 Ok(Self {
1634 status,
1635 cur_cmd,
1636 row_count,
1637 })
1638 }
1639
1640 pub fn encode(&self, dst: &mut impl BufMut) {
1642 dst.put_u8(TokenType::Done as u8);
1643 dst.put_u16_le(self.status.to_bits());
1644 dst.put_u16_le(self.cur_cmd);
1645 dst.put_u64_le(self.row_count);
1646 }
1647
1648 #[must_use]
1650 pub const fn has_more(&self) -> bool {
1651 self.status.more
1652 }
1653
1654 #[must_use]
1656 pub const fn has_error(&self) -> bool {
1657 self.status.error
1658 }
1659
1660 #[must_use]
1662 pub const fn has_count(&self) -> bool {
1663 self.status.count
1664 }
1665}
1666
1667impl DoneProc {
1668 pub const SIZE: usize = 12;
1670
1671 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1673 if src.remaining() < Self::SIZE {
1674 return Err(ProtocolError::IncompletePacket {
1675 expected: Self::SIZE,
1676 actual: src.remaining(),
1677 });
1678 }
1679
1680 let status = DoneStatus::from_bits(src.get_u16_le());
1681 let cur_cmd = src.get_u16_le();
1682 let row_count = src.get_u64_le();
1683
1684 Ok(Self {
1685 status,
1686 cur_cmd,
1687 row_count,
1688 })
1689 }
1690
1691 pub fn encode(&self, dst: &mut impl BufMut) {
1693 dst.put_u8(TokenType::DoneProc as u8);
1694 dst.put_u16_le(self.status.to_bits());
1695 dst.put_u16_le(self.cur_cmd);
1696 dst.put_u64_le(self.row_count);
1697 }
1698}
1699
1700impl DoneInProc {
1701 pub const SIZE: usize = 12;
1703
1704 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1706 if src.remaining() < Self::SIZE {
1707 return Err(ProtocolError::IncompletePacket {
1708 expected: Self::SIZE,
1709 actual: src.remaining(),
1710 });
1711 }
1712
1713 let status = DoneStatus::from_bits(src.get_u16_le());
1714 let cur_cmd = src.get_u16_le();
1715 let row_count = src.get_u64_le();
1716
1717 Ok(Self {
1718 status,
1719 cur_cmd,
1720 row_count,
1721 })
1722 }
1723
1724 pub fn encode(&self, dst: &mut impl BufMut) {
1726 dst.put_u8(TokenType::DoneInProc as u8);
1727 dst.put_u16_le(self.status.to_bits());
1728 dst.put_u16_le(self.cur_cmd);
1729 dst.put_u64_le(self.row_count);
1730 }
1731}
1732
1733impl ServerError {
1734 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1736 if src.remaining() < 2 {
1739 return Err(ProtocolError::UnexpectedEof);
1740 }
1741
1742 let _length = src.get_u16_le();
1743
1744 if src.remaining() < 6 {
1745 return Err(ProtocolError::UnexpectedEof);
1746 }
1747
1748 let number = src.get_i32_le();
1749 let state = src.get_u8();
1750 let class = src.get_u8();
1751
1752 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1753 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1754 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1755
1756 if src.remaining() < 4 {
1757 return Err(ProtocolError::UnexpectedEof);
1758 }
1759 let line = src.get_i32_le();
1760
1761 Ok(Self {
1762 number,
1763 state,
1764 class,
1765 message,
1766 server,
1767 procedure,
1768 line,
1769 })
1770 }
1771
1772 #[must_use]
1774 pub const fn is_fatal(&self) -> bool {
1775 self.class >= 20
1776 }
1777
1778 #[must_use]
1780 pub const fn is_batch_abort(&self) -> bool {
1781 self.class >= 16
1782 }
1783}
1784
1785impl ServerInfo {
1786 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1790 if src.remaining() < 2 {
1791 return Err(ProtocolError::UnexpectedEof);
1792 }
1793
1794 let _length = src.get_u16_le();
1795
1796 if src.remaining() < 6 {
1797 return Err(ProtocolError::UnexpectedEof);
1798 }
1799
1800 let number = src.get_i32_le();
1801 let state = src.get_u8();
1802 let class = src.get_u8();
1803
1804 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1805 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1806 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1807
1808 if src.remaining() < 4 {
1809 return Err(ProtocolError::UnexpectedEof);
1810 }
1811 let line = src.get_i32_le();
1812
1813 Ok(Self {
1814 number,
1815 state,
1816 class,
1817 message,
1818 server,
1819 procedure,
1820 line,
1821 })
1822 }
1823}
1824
1825impl LoginAck {
1826 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1828 if src.remaining() < 2 {
1830 return Err(ProtocolError::UnexpectedEof);
1831 }
1832
1833 let _length = src.get_u16_le();
1834
1835 if src.remaining() < 5 {
1836 return Err(ProtocolError::UnexpectedEof);
1837 }
1838
1839 let interface = src.get_u8();
1840 let tds_version = src.get_u32_le();
1841 let prog_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1842
1843 if src.remaining() < 4 {
1844 return Err(ProtocolError::UnexpectedEof);
1845 }
1846 let prog_version = src.get_u32_le();
1847
1848 Ok(Self {
1849 interface,
1850 tds_version,
1851 prog_name,
1852 prog_version,
1853 })
1854 }
1855
1856 #[must_use]
1858 pub fn tds_version(&self) -> crate::version::TdsVersion {
1859 crate::version::TdsVersion::new(self.tds_version)
1860 }
1861}
1862
1863impl EnvChangeType {
1864 pub fn from_u8(value: u8) -> Option<Self> {
1866 match value {
1867 1 => Some(Self::Database),
1868 2 => Some(Self::Language),
1869 3 => Some(Self::CharacterSet),
1870 4 => Some(Self::PacketSize),
1871 5 => Some(Self::UnicodeSortingLocalId),
1872 6 => Some(Self::UnicodeComparisonFlags),
1873 7 => Some(Self::SqlCollation),
1874 8 => Some(Self::BeginTransaction),
1875 9 => Some(Self::CommitTransaction),
1876 10 => Some(Self::RollbackTransaction),
1877 11 => Some(Self::EnlistDtcTransaction),
1878 12 => Some(Self::DefectTransaction),
1879 13 => Some(Self::RealTimeLogShipping),
1880 15 => Some(Self::PromoteTransaction),
1881 16 => Some(Self::TransactionManagerAddress),
1882 17 => Some(Self::TransactionEnded),
1883 18 => Some(Self::ResetConnectionCompletionAck),
1884 19 => Some(Self::UserInstanceStarted),
1885 20 => Some(Self::Routing),
1886 _ => None,
1887 }
1888 }
1889}
1890
1891impl EnvChange {
1892 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1894 if src.remaining() < 3 {
1895 return Err(ProtocolError::UnexpectedEof);
1896 }
1897
1898 let length = src.get_u16_le() as usize;
1899 if length == 0 {
1900 return Err(ProtocolError::UnexpectedEof);
1903 }
1904 if src.remaining() < length {
1905 return Err(ProtocolError::IncompletePacket {
1906 expected: length,
1907 actual: src.remaining(),
1908 });
1909 }
1910
1911 let mut frame = src.copy_to_bytes(length);
1920 let src = &mut frame;
1921
1922 let env_type_byte = src.get_u8();
1923 let env_type = EnvChangeType::from_u8(env_type_byte)
1924 .ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
1925
1926 let (new_value, old_value) = match env_type {
1927 EnvChangeType::Routing => {
1928 let new_value = Self::decode_routing_value(src)?;
1930 let old_value = EnvChangeValue::Binary(Bytes::new());
1931 (new_value, old_value)
1932 }
1933 EnvChangeType::BeginTransaction
1934 | EnvChangeType::CommitTransaction
1935 | EnvChangeType::RollbackTransaction
1936 | EnvChangeType::EnlistDtcTransaction
1937 | EnvChangeType::SqlCollation => {
1938 let new_len = if src.has_remaining() {
1947 src.get_u8() as usize
1948 } else {
1949 0
1950 };
1951 let new_value = if new_len > 0 && src.remaining() >= new_len {
1952 EnvChangeValue::Binary(src.copy_to_bytes(new_len))
1953 } else {
1954 EnvChangeValue::Binary(Bytes::new())
1955 };
1956
1957 let old_len = if src.has_remaining() {
1958 src.get_u8() as usize
1959 } else {
1960 0
1961 };
1962 let old_value = if old_len > 0 && src.remaining() >= old_len {
1963 EnvChangeValue::Binary(src.copy_to_bytes(old_len))
1964 } else {
1965 EnvChangeValue::Binary(Bytes::new())
1966 };
1967
1968 (new_value, old_value)
1969 }
1970 _ => {
1971 let new_value = read_b_varchar(src)
1973 .map(EnvChangeValue::String)
1974 .unwrap_or(EnvChangeValue::String(String::new()));
1975
1976 let old_value = read_b_varchar(src)
1977 .map(EnvChangeValue::String)
1978 .unwrap_or(EnvChangeValue::String(String::new()));
1979
1980 (new_value, old_value)
1981 }
1982 };
1983
1984 Ok(Self {
1990 env_type,
1991 new_value,
1992 old_value,
1993 })
1994 }
1995
1996 fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
1997 if src.remaining() < 2 {
1999 return Err(ProtocolError::UnexpectedEof);
2000 }
2001
2002 let _routing_len = src.get_u16_le();
2003
2004 if src.remaining() < 5 {
2005 return Err(ProtocolError::UnexpectedEof);
2006 }
2007
2008 let _protocol = src.get_u8();
2009 let port = src.get_u16_le();
2010 let server_len = src.get_u16_le() as usize;
2011
2012 if src.remaining() < server_len * 2 {
2014 return Err(ProtocolError::UnexpectedEof);
2015 }
2016
2017 let mut chars = Vec::with_capacity(server_len);
2018 for _ in 0..server_len {
2019 chars.push(src.get_u16_le());
2020 }
2021
2022 let host = String::from_utf16(&chars).map_err(|_| {
2023 ProtocolError::StringEncoding(
2024 #[cfg(feature = "std")]
2025 "invalid UTF-16 in routing hostname".to_string(),
2026 #[cfg(not(feature = "std"))]
2027 "invalid UTF-16 in routing hostname",
2028 )
2029 })?;
2030
2031 Ok(EnvChangeValue::Routing { host, port })
2032 }
2033
2034 #[must_use]
2036 pub fn is_routing(&self) -> bool {
2037 self.env_type == EnvChangeType::Routing
2038 }
2039
2040 #[must_use]
2042 pub fn routing_info(&self) -> Option<(&str, u16)> {
2043 if let EnvChangeValue::Routing { host, port } = &self.new_value {
2044 Some((host, *port))
2045 } else {
2046 None
2047 }
2048 }
2049
2050 #[must_use]
2052 pub fn new_database(&self) -> Option<&str> {
2053 if self.env_type == EnvChangeType::Database {
2054 if let EnvChangeValue::String(s) = &self.new_value {
2055 return Some(s);
2056 }
2057 }
2058 None
2059 }
2060}
2061
2062impl Order {
2063 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2065 if src.remaining() < 2 {
2066 return Err(ProtocolError::UnexpectedEof);
2067 }
2068
2069 let length = src.get_u16_le() as usize;
2070 let column_count = length / 2;
2071
2072 if src.remaining() < length {
2073 return Err(ProtocolError::IncompletePacket {
2074 expected: length,
2075 actual: src.remaining(),
2076 });
2077 }
2078
2079 let mut columns = Vec::with_capacity(column_count);
2080 for _ in 0..column_count {
2081 columns.push(src.get_u16_le());
2082 }
2083
2084 Ok(Self { columns })
2085 }
2086}
2087
2088impl FeatureExtAck {
2089 pub const TERMINATOR: u8 = 0xFF;
2091
2092 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2094 let mut features = Vec::new();
2095
2096 loop {
2097 if !src.has_remaining() {
2098 return Err(ProtocolError::UnexpectedEof);
2099 }
2100
2101 let feature_id = src.get_u8();
2102 if feature_id == Self::TERMINATOR {
2103 break;
2104 }
2105
2106 if src.remaining() < 4 {
2107 return Err(ProtocolError::UnexpectedEof);
2108 }
2109
2110 let data_len = src.get_u32_le() as usize;
2111
2112 if src.remaining() < data_len {
2113 return Err(ProtocolError::IncompletePacket {
2114 expected: data_len,
2115 actual: src.remaining(),
2116 });
2117 }
2118
2119 let data = src.copy_to_bytes(data_len);
2120 features.push(FeatureAck { feature_id, data });
2121 }
2122
2123 Ok(Self { features })
2124 }
2125}
2126
2127impl SspiToken {
2128 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2130 if src.remaining() < 2 {
2131 return Err(ProtocolError::UnexpectedEof);
2132 }
2133
2134 let length = src.get_u16_le() as usize;
2135
2136 if src.remaining() < length {
2137 return Err(ProtocolError::IncompletePacket {
2138 expected: length,
2139 actual: src.remaining(),
2140 });
2141 }
2142
2143 let data = src.copy_to_bytes(length);
2144 Ok(Self { data })
2145 }
2146}
2147
2148impl FedAuthInfo {
2149 const ID_SPN: u8 = 0x01;
2151 const ID_STSURL: u8 = 0x02;
2153 const OPT_HEADER_LEN: usize = 9;
2155
2156 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2168 if src.remaining() < 4 {
2169 return Err(ProtocolError::UnexpectedEof);
2170 }
2171 let token_len = src.get_u32_le() as usize;
2172 if src.remaining() < token_len {
2173 return Err(ProtocolError::UnexpectedEof);
2174 }
2175
2176 let region = src.copy_to_bytes(token_len);
2179 if region.len() < 4 {
2180 return Err(ProtocolError::UnexpectedEof);
2181 }
2182 let count = u32::from_le_bytes([region[0], region[1], region[2], region[3]]) as usize;
2183
2184 let headers_end = count
2188 .checked_mul(Self::OPT_HEADER_LEN)
2189 .and_then(|n| n.checked_add(4))
2190 .ok_or(ProtocolError::UnexpectedEof)?;
2191 if headers_end > region.len() {
2192 return Err(ProtocolError::UnexpectedEof);
2193 }
2194
2195 let mut sts_url = String::new();
2196 let mut spn = String::new();
2197
2198 for i in 0..count {
2199 let h = 4 + i * Self::OPT_HEADER_LEN;
2200 let info_id = region[h];
2201 let data_len =
2202 u32::from_le_bytes([region[h + 1], region[h + 2], region[h + 3], region[h + 4]])
2203 as usize;
2204 let data_off =
2205 u32::from_le_bytes([region[h + 5], region[h + 6], region[h + 7], region[h + 8]])
2206 as usize;
2207
2208 if info_id != Self::ID_SPN && info_id != Self::ID_STSURL {
2211 continue;
2212 }
2213
2214 let data_end = data_off
2215 .checked_add(data_len)
2216 .ok_or(ProtocolError::UnexpectedEof)?;
2217 if data_end > region.len() {
2218 return Err(ProtocolError::UnexpectedEof);
2219 }
2220 if data_len % 2 != 0 {
2221 return Err(ProtocolError::StringEncoding(
2222 #[cfg(feature = "std")]
2223 "FEDAUTHINFO option data has odd length, not UTF-16".to_string(),
2224 #[cfg(not(feature = "std"))]
2225 "FEDAUTHINFO option data has odd length, not UTF-16",
2226 ));
2227 }
2228
2229 let chars: Vec<u16> = region[data_off..data_end]
2230 .chunks_exact(2)
2231 .map(|b| u16::from_le_bytes([b[0], b[1]]))
2232 .collect();
2233 let value = String::from_utf16(&chars).map_err(|_| {
2234 ProtocolError::StringEncoding(
2235 #[cfg(feature = "std")]
2236 "invalid UTF-16 in FEDAUTHINFO option".to_string(),
2237 #[cfg(not(feature = "std"))]
2238 "invalid UTF-16 in FEDAUTHINFO option",
2239 )
2240 })?;
2241
2242 if info_id == Self::ID_SPN {
2243 spn = value;
2244 } else {
2245 sts_url = value;
2246 }
2247 }
2248
2249 Ok(Self { sts_url, spn })
2250 }
2251}
2252
/// Sequential decoder over a buffer of TDS tokens, tracking a read offset.
pub struct TokenParser {
    // Token stream being decoded.
    data: Bytes,
    // Offset in `data` of the next unread byte.
    position: usize,
    // When true, COLMETADATA is decoded with `ColMetaData::decode_encrypted`.
    encryption_enabled: bool,
}
2300
2301impl TokenParser {
2302 #[must_use]
2304 pub fn new(data: Bytes) -> Self {
2305 Self {
2306 data,
2307 position: 0,
2308 encryption_enabled: false,
2309 }
2310 }
2311
2312 #[must_use]
2317 pub fn with_encryption(mut self, enabled: bool) -> Self {
2318 self.encryption_enabled = enabled;
2319 self
2320 }
2321
2322 #[must_use]
2324 pub fn remaining(&self) -> usize {
2325 self.data.len().saturating_sub(self.position)
2326 }
2327
2328 #[must_use]
2330 pub fn has_remaining(&self) -> bool {
2331 self.position < self.data.len()
2332 }
2333
2334 #[must_use]
2336 pub fn peek_token_type(&self) -> Option<TokenType> {
2337 if self.position < self.data.len() {
2338 TokenType::from_u8(self.data[self.position])
2339 } else {
2340 None
2341 }
2342 }
2343
2344 pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
2352 self.next_token_with_metadata(None)
2353 }
2354
2355 pub fn next_token_with_metadata(
2362 &mut self,
2363 metadata: Option<&ColMetaData>,
2364 ) -> Result<Option<Token>, ProtocolError> {
2365 if !self.has_remaining() {
2366 return Ok(None);
2367 }
2368
2369 let mut buf = &self.data[self.position..];
2370 let start_pos = self.position;
2371
2372 let token_type_byte = buf.get_u8();
2373 let token_type = TokenType::from_u8(token_type_byte);
2374
2375 let token = match token_type {
2376 Some(TokenType::Done) => {
2377 let done = Done::decode(&mut buf)?;
2378 Token::Done(done)
2379 }
2380 Some(TokenType::DoneProc) => {
2381 let done = DoneProc::decode(&mut buf)?;
2382 Token::DoneProc(done)
2383 }
2384 Some(TokenType::DoneInProc) => {
2385 let done = DoneInProc::decode(&mut buf)?;
2386 Token::DoneInProc(done)
2387 }
2388 Some(TokenType::Error) => {
2389 let error = ServerError::decode(&mut buf)?;
2390 Token::Error(error)
2391 }
2392 Some(TokenType::Info) => {
2393 let info = ServerInfo::decode(&mut buf)?;
2394 Token::Info(info)
2395 }
2396 Some(TokenType::LoginAck) => {
2397 let login_ack = LoginAck::decode(&mut buf)?;
2398 Token::LoginAck(login_ack)
2399 }
2400 Some(TokenType::EnvChange) => {
2401 let env_change = EnvChange::decode(&mut buf)?;
2402 Token::EnvChange(env_change)
2403 }
2404 Some(TokenType::Order) => {
2405 let order = Order::decode(&mut buf)?;
2406 Token::Order(order)
2407 }
2408 Some(TokenType::FeatureExtAck) => {
2409 let ack = FeatureExtAck::decode(&mut buf)?;
2410 Token::FeatureExtAck(ack)
2411 }
2412 Some(TokenType::Sspi) => {
2413 let sspi = SspiToken::decode(&mut buf)?;
2414 Token::Sspi(sspi)
2415 }
2416 Some(TokenType::FedAuthInfo) => {
2417 let info = FedAuthInfo::decode(&mut buf)?;
2418 Token::FedAuthInfo(info)
2419 }
2420 Some(TokenType::ReturnStatus) => {
2421 if buf.remaining() < 4 {
2422 return Err(ProtocolError::UnexpectedEof);
2423 }
2424 let status = buf.get_i32_le();
2425 Token::ReturnStatus(status)
2426 }
2427 Some(TokenType::ColMetaData) => {
2428 let col_meta = if self.encryption_enabled {
2429 ColMetaData::decode_encrypted(&mut buf)?
2430 } else {
2431 ColMetaData::decode(&mut buf)?
2432 };
2433 Token::ColMetaData(col_meta)
2434 }
2435 Some(TokenType::Row) => {
2436 let meta = metadata.ok_or_else(|| {
2437 ProtocolError::StringEncoding(
2438 #[cfg(feature = "std")]
2439 "Row token requires column metadata".to_string(),
2440 #[cfg(not(feature = "std"))]
2441 "Row token requires column metadata",
2442 )
2443 })?;
2444 let row = RawRow::decode(&mut buf, meta)?;
2445 Token::Row(row)
2446 }
2447 Some(TokenType::NbcRow) => {
2448 let meta = metadata.ok_or_else(|| {
2449 ProtocolError::StringEncoding(
2450 #[cfg(feature = "std")]
2451 "NbcRow token requires column metadata".to_string(),
2452 #[cfg(not(feature = "std"))]
2453 "NbcRow token requires column metadata",
2454 )
2455 })?;
2456 let row = NbcRow::decode(&mut buf, meta)?;
2457 Token::NbcRow(row)
2458 }
2459 Some(TokenType::ReturnValue) => {
2460 let ret_val = ReturnValue::decode(&mut buf)?;
2461 Token::ReturnValue(ret_val)
2462 }
2463 Some(TokenType::SessionState) => {
2464 let session = SessionState::decode(&mut buf)?;
2465 Token::SessionState(session)
2466 }
2467 Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
2468 if buf.remaining() < 2 {
2471 return Err(ProtocolError::UnexpectedEof);
2472 }
2473 let length = buf.get_u16_le() as usize;
2474 if buf.remaining() < length {
2475 return Err(ProtocolError::IncompletePacket {
2476 expected: length,
2477 actual: buf.remaining(),
2478 });
2479 }
2480 buf.advance(length);
2482 self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
2484 return self.next_token_with_metadata(metadata);
2485 }
2486 None => {
2487 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2488 }
2489 };
2490
2491 let consumed = self.data.len() - start_pos - buf.remaining();
2493 self.position = start_pos + consumed;
2494
2495 Ok(Some(token))
2496 }
2497
2498 pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
2502 if !self.has_remaining() {
2503 return Ok(());
2504 }
2505
2506 let token_type_byte = self.data[self.position];
2507 let token_type = TokenType::from_u8(token_type_byte);
2508
2509 let skip_amount = match token_type {
2511 Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
2513 1 + Done::SIZE }
2515 Some(TokenType::ReturnStatus) => {
2516 1 + 4 }
2518 Some(TokenType::Error)
2520 | Some(TokenType::Info)
2521 | Some(TokenType::LoginAck)
2522 | Some(TokenType::EnvChange)
2523 | Some(TokenType::Order)
2524 | Some(TokenType::Sspi)
2525 | Some(TokenType::ColInfo)
2526 | Some(TokenType::TabName)
2527 | Some(TokenType::Offset)
2528 | Some(TokenType::ReturnValue) => {
2529 if self.remaining() < 3 {
2530 return Err(ProtocolError::UnexpectedEof);
2531 }
2532 let length = u16::from_le_bytes([
2533 self.data[self.position + 1],
2534 self.data[self.position + 2],
2535 ]) as usize;
2536 1 + 2 + length }
2538 Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
2540 if self.remaining() < 5 {
2541 return Err(ProtocolError::UnexpectedEof);
2542 }
2543 let length = u32::from_le_bytes([
2544 self.data[self.position + 1],
2545 self.data[self.position + 2],
2546 self.data[self.position + 3],
2547 self.data[self.position + 4],
2548 ]) as usize;
2549 1 + 4 + length
2550 }
2551 Some(TokenType::FeatureExtAck) => {
2553 let mut buf = &self.data[self.position + 1..];
2555 let _ = FeatureExtAck::decode(&mut buf)?;
2556 self.data.len() - self.position - buf.remaining()
2557 }
2558 Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
2560 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2561 }
2562 None => {
2563 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2564 }
2565 };
2566
2567 if self.remaining() < skip_amount {
2568 return Err(ProtocolError::UnexpectedEof);
2569 }
2570
2571 self.position += skip_amount;
2572 Ok(())
2573 }
2574
2575 #[must_use]
2577 pub fn position(&self) -> usize {
2578 self.position
2579 }
2580
2581 pub fn reset(&mut self) {
2583 self.position = 0;
2584 }
2585}
2586
2587#[cfg(test)]
2592#[allow(clippy::unwrap_used, clippy::panic)]
2593mod tests {
2594 use super::*;
2595 use bytes::BytesMut;
2596
2597 #[test]
2598 fn test_done_roundtrip() {
2599 let done = Done {
2600 status: DoneStatus {
2601 more: false,
2602 error: false,
2603 in_xact: false,
2604 count: true,
2605 attn: false,
2606 srverror: false,
2607 },
2608 cur_cmd: 193, row_count: 42,
2610 };
2611
2612 let mut buf = BytesMut::new();
2613 done.encode(&mut buf);
2614
2615 let mut cursor = &buf[1..];
2617 let decoded = Done::decode(&mut cursor).unwrap();
2618
2619 assert_eq!(decoded.status.count, done.status.count);
2620 assert_eq!(decoded.cur_cmd, done.cur_cmd);
2621 assert_eq!(decoded.row_count, done.row_count);
2622 }
2623
2624 #[test]
2625 fn test_done_status_bits() {
2626 let status = DoneStatus {
2627 more: true,
2628 error: true,
2629 in_xact: true,
2630 count: true,
2631 attn: false,
2632 srverror: false,
2633 };
2634
2635 let bits = status.to_bits();
2636 let restored = DoneStatus::from_bits(bits);
2637
2638 assert_eq!(status.more, restored.more);
2639 assert_eq!(status.error, restored.error);
2640 assert_eq!(status.in_xact, restored.in_xact);
2641 assert_eq!(status.count, restored.count);
2642 }
2643
2644 #[test]
2645 fn test_token_parser_done() {
2646 let data = Bytes::from_static(&[
2648 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
2653
2654 let mut parser = TokenParser::new(data);
2655 let token = parser.next_token().unwrap().unwrap();
2656
2657 match token {
2658 Token::Done(done) => {
2659 assert!(done.status.count);
2660 assert!(!done.status.more);
2661 assert_eq!(done.cur_cmd, 193);
2662 assert_eq!(done.row_count, 5);
2663 }
2664 _ => panic!("Expected Done token"),
2665 }
2666
2667 assert!(parser.next_token().unwrap().is_none());
2669 }
2670
2671 #[test]
2672 fn test_env_change_type_from_u8() {
2673 assert_eq!(EnvChangeType::from_u8(1), Some(EnvChangeType::Database));
2674 assert_eq!(EnvChangeType::from_u8(20), Some(EnvChangeType::Routing));
2675 assert_eq!(EnvChangeType::from_u8(100), None);
2676 }
2677
2678 #[test]
2685 fn test_env_change_routing_consumes_declared_length() {
2686 let host = "redirect.example";
2687 let host_utf16: Vec<u16> = host.encode_utf16().collect();
2688
2689 let mut data = BytesMut::new();
2690 let routing_len = 1 + 2 + 2 + host_utf16.len() * 2;
2692 let env_len = 1 + 2 + routing_len + 2;
2695 data.put_u16_le(env_len as u16);
2696 data.put_u8(20); data.put_u16_le(routing_len as u16);
2698 data.put_u8(0); data.put_u16_le(11000); data.put_u16_le(host_utf16.len() as u16);
2701 for c in &host_utf16 {
2702 data.put_u16_le(*c);
2703 }
2704 data.put_u16_le(0); data.put_u8(0xFD);
2707
2708 let mut buf: &[u8] = &data;
2709 let env = EnvChange::decode(&mut buf).unwrap();
2710 assert_eq!(env.routing_info(), Some((host, 11000)));
2711 assert_eq!(
2712 buf,
2713 &[0xFD],
2714 "decode must consume exactly the declared ENVCHANGE frame"
2715 );
2716 }
2717
2718 fn put_b_varchar(buf: &mut BytesMut, s: &str) {
2719 let utf16: Vec<u16> = s.encode_utf16().collect();
2720 buf.put_u8(utf16.len() as u8);
2721 for c in utf16 {
2722 buf.put_u16_le(c);
2723 }
2724 }
2725
2726 fn put_us_varchar(buf: &mut BytesMut, s: &str) {
2727 let utf16: Vec<u16> = s.encode_utf16().collect();
2728 buf.put_u16_le(utf16.len() as u16);
2729 for c in utf16 {
2730 buf.put_u16_le(c);
2731 }
2732 }
2733
2734 #[test]
2741 fn test_udt_info_metadata_uses_b_varchar_names() {
2742 let mut data = BytesMut::new();
2743 data.put_u16_le(0xFFFF); put_b_varchar(&mut data, "master");
2745 put_b_varchar(&mut data, "dbo");
2746 put_b_varchar(&mut data, "hierarchyid");
2747 put_us_varchar(
2748 &mut data,
2749 "Microsoft.SqlServer.Types.SqlHierarchyId, Microsoft.SqlServer.Types",
2750 );
2751 data.put_u8(0xFD);
2753
2754 let mut buf: &[u8] = &data;
2755 let info = decode_type_info(&mut buf, TypeId::Udt, TypeId::Udt as u8).unwrap();
2756 assert_eq!(info.max_length, Some(0xFFFF));
2757 assert_eq!(
2758 buf,
2759 &[0xFD],
2760 "decode must consume exactly the UDT_INFO frame"
2761 );
2762 }
2763
2764 #[test]
2768 fn test_xml_info_schema_bound_uses_b_varchar_names() {
2769 let mut data = BytesMut::new();
2770 data.put_u8(1); put_b_varchar(&mut data, "master");
2772 put_b_varchar(&mut data, "dbo");
2773 put_us_varchar(&mut data, "MyXmlSchemaCollection");
2774 data.put_u8(0xFD);
2775
2776 let mut buf: &[u8] = &data;
2777 decode_type_info(&mut buf, TypeId::Xml, TypeId::Xml as u8).unwrap();
2778 assert_eq!(
2779 buf,
2780 &[0xFD],
2781 "decode must consume exactly the XML_INFO frame"
2782 );
2783 }
2784
2785 #[test]
2786 fn hostile_env_change_binary_truncated_is_not_panic() {
2787 let data = [0x01, 0x00, 0x08];
2792 let mut buf: &[u8] = &data;
2793 let env = EnvChange::decode(&mut buf).unwrap();
2794 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2795 }
2796
2797 #[test]
2800 fn hostile_env_change_under_declared_cannot_steal_following_bytes() {
2801 let mut data = BytesMut::new();
2806 data.put_u16_le(1); data.put_u8(0x08); let following: &[u8] = &[0x08, 1, 2, 3, 4, 5, 6, 7, 8, 0x00];
2809 data.extend_from_slice(following);
2810
2811 let mut buf: &[u8] = &data;
2812 let env = EnvChange::decode(&mut buf).unwrap();
2813 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2814 match &env.new_value {
2815 EnvChangeValue::Binary(b) => {
2816 assert!(
2817 b.is_empty(),
2818 "under-declared frame yields the lenient empty value"
2819 );
2820 }
2821 other => panic!("expected empty Binary value, got {other:?}"),
2822 }
2823 assert_eq!(
2824 buf, following,
2825 "bytes beyond the declared frame belong to the next token"
2826 );
2827 }
2828
2829 #[test]
2832 fn hostile_env_change_zero_length_frame_errors() {
2833 let data = [0x00, 0x00, 0xFD];
2834 let mut buf: &[u8] = &data;
2835 assert!(EnvChange::decode(&mut buf).is_err());
2836 }
2837
2838 #[test]
2839 fn test_colmetadata_no_columns() {
2840 let data = Bytes::from_static(&[0xFF, 0xFF]);
2842 let mut cursor: &[u8] = &data;
2843 let meta = ColMetaData::decode(&mut cursor).unwrap();
2844 assert!(meta.is_empty());
2845 assert_eq!(meta.column_count(), 0);
2846 }
2847
2848 #[test]
2849 fn test_colmetadata_single_int_column() {
2850 let mut data = BytesMut::new();
2853 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut cursor: &[u8] = &data;
2862 let meta = ColMetaData::decode(&mut cursor).unwrap();
2863
2864 assert_eq!(meta.column_count(), 1);
2865 assert_eq!(meta.columns[0].name, "id");
2866 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
2867 assert!(meta.columns[0].is_nullable());
2868 }
2869
2870 #[test]
2871 fn test_colmetadata_nvarchar_column() {
2872 let mut data = BytesMut::new();
2874 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0xE7]); data.extend_from_slice(&[0x64, 0x00]); data.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); data.extend_from_slice(&[0x04]); data.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
2884
2885 let mut cursor: &[u8] = &data;
2886 let meta = ColMetaData::decode(&mut cursor).unwrap();
2887
2888 assert_eq!(meta.column_count(), 1);
2889 assert_eq!(meta.columns[0].name, "name");
2890 assert_eq!(meta.columns[0].type_id, TypeId::NVarChar);
2891 assert_eq!(meta.columns[0].type_info.max_length, Some(100));
2892 assert!(meta.columns[0].type_info.collation.is_some());
2893 }
2894
2895 #[test]
2896 fn test_raw_row_decode_int() {
2897 let metadata = ColMetaData {
2899 cek_table: None,
2900 columns: vec![ColumnData {
2901 name: "id".to_string(),
2902 type_id: TypeId::Int4,
2903 col_type: 0x38,
2904 flags: 0,
2905 user_type: 0,
2906 type_info: TypeInfo::default(),
2907 crypto_metadata: None,
2908 }],
2909 };
2910
2911 let data = Bytes::from_static(&[0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
2914 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2915
2916 assert_eq!(row.data.len(), 4);
2918 assert_eq!(&row.data[..], &[0x2A, 0x00, 0x00, 0x00]);
2919 }
2920
2921 #[test]
2922 fn test_raw_row_decode_nullable_int() {
2923 let metadata = ColMetaData {
2925 cek_table: None,
2926 columns: vec![ColumnData {
2927 name: "id".to_string(),
2928 type_id: TypeId::IntN,
2929 col_type: 0x26,
2930 flags: 0x01, user_type: 0,
2932 type_info: TypeInfo {
2933 max_length: Some(4),
2934 ..Default::default()
2935 },
2936 crypto_metadata: None,
2937 }],
2938 };
2939
2940 let data = Bytes::from_static(&[0x04, 0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
2943 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2944
2945 assert_eq!(row.data.len(), 5);
2946 assert_eq!(row.data[0], 4); assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
2948 }
2949
2950 #[test]
2951 fn test_raw_row_decode_null_value() {
2952 let metadata = ColMetaData {
2954 cek_table: None,
2955 columns: vec![ColumnData {
2956 name: "id".to_string(),
2957 type_id: TypeId::IntN,
2958 col_type: 0x26,
2959 flags: 0x01, user_type: 0,
2961 type_info: TypeInfo {
2962 max_length: Some(4),
2963 ..Default::default()
2964 },
2965 crypto_metadata: None,
2966 }],
2967 };
2968
2969 let data = Bytes::from_static(&[0xFF]);
2971 let mut cursor: &[u8] = &data;
2972 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2973
2974 assert_eq!(row.data.len(), 1);
2975 assert_eq!(row.data[0], 0xFF); }
2977
2978 #[test]
2979 fn test_nbcrow_null_bitmap() {
2980 let row = NbcRow {
2981 null_bitmap: vec![0b00000101], data: Bytes::new(),
2983 };
2984
2985 assert!(row.is_null(0));
2986 assert!(!row.is_null(1));
2987 assert!(row.is_null(2));
2988 assert!(!row.is_null(3));
2989 }
2990
2991 #[test]
2992 fn test_token_parser_colmetadata() {
2993 let mut data = BytesMut::new();
2995 data.extend_from_slice(&[0x81]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut parser = TokenParser::new(data.freeze());
3004 let token = parser.next_token().unwrap().unwrap();
3005
3006 match token {
3007 Token::ColMetaData(meta) => {
3008 assert_eq!(meta.column_count(), 1);
3009 assert_eq!(meta.columns[0].name, "id");
3010 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
3011 }
3012 _ => panic!("Expected ColMetaData token"),
3013 }
3014 }
3015
3016 #[test]
3017 fn test_token_parser_row_with_metadata() {
3018 let metadata = ColMetaData {
3020 cek_table: None,
3021 columns: vec![ColumnData {
3022 name: "id".to_string(),
3023 type_id: TypeId::Int4,
3024 col_type: 0x38,
3025 flags: 0,
3026 user_type: 0,
3027 type_info: TypeInfo::default(),
3028 crypto_metadata: None,
3029 }],
3030 };
3031
3032 let mut data = BytesMut::new();
3034 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3038 let token = parser
3039 .next_token_with_metadata(Some(&metadata))
3040 .unwrap()
3041 .unwrap();
3042
3043 match token {
3044 Token::Row(row) => {
3045 assert_eq!(row.data.len(), 4);
3046 }
3047 _ => panic!("Expected Row token"),
3048 }
3049 }
3050
3051 #[test]
3052 fn test_token_parser_row_without_metadata_fails() {
3053 let mut data = BytesMut::new();
3055 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3059 let result = parser.next_token(); assert!(result.is_err());
3062 }
3063
3064 #[test]
3065 fn test_token_parser_peek() {
3066 let data = Bytes::from_static(&[
3067 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3072
3073 let parser = TokenParser::new(data);
3074 assert_eq!(parser.peek_token_type(), Some(TokenType::Done));
3075 }
3076
3077 #[test]
3078 fn test_column_data_fixed_size() {
3079 let col = ColumnData {
3080 name: String::new(),
3081 type_id: TypeId::Int4,
3082 col_type: 0x38,
3083 flags: 0,
3084 user_type: 0,
3085 type_info: TypeInfo::default(),
3086 crypto_metadata: None,
3087 };
3088 assert_eq!(col.fixed_size(), Some(4));
3089
3090 let col2 = ColumnData {
3091 name: String::new(),
3092 type_id: TypeId::NVarChar,
3093 col_type: 0xE7,
3094 flags: 0,
3095 user_type: 0,
3096 type_info: TypeInfo::default(),
3097 crypto_metadata: None,
3098 };
3099 assert_eq!(col2.fixed_size(), None);
3100 }
3101
3102 #[test]
3110 fn test_decode_nvarchar_then_intn_roundtrip() {
3111 let mut wire_data = BytesMut::new();
3116
3117 let word = "World";
3120 let utf16: Vec<u16> = word.encode_utf16().collect();
3121 wire_data.put_u16_le((utf16.len() * 2) as u16); for code_unit in &utf16 {
3123 wire_data.put_u16_le(*code_unit);
3124 }
3125
3126 wire_data.put_u8(4); wire_data.put_i32_le(42);
3129
3130 let metadata = ColMetaData {
3132 cek_table: None,
3133 columns: vec![
3134 ColumnData {
3135 name: "greeting".to_string(),
3136 type_id: TypeId::NVarChar,
3137 col_type: 0xE7,
3138 flags: 0x01,
3139 user_type: 0,
3140 type_info: TypeInfo {
3141 max_length: Some(10), precision: None,
3143 scale: None,
3144 collation: None,
3145 },
3146 crypto_metadata: None,
3147 },
3148 ColumnData {
3149 name: "number".to_string(),
3150 type_id: TypeId::IntN,
3151 col_type: 0x26,
3152 flags: 0x01,
3153 user_type: 0,
3154 type_info: TypeInfo {
3155 max_length: Some(4),
3156 precision: None,
3157 scale: None,
3158 collation: None,
3159 },
3160 crypto_metadata: None,
3161 },
3162 ],
3163 };
3164
3165 let mut wire_cursor = wire_data.freeze();
3167 let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3168
3169 assert_eq!(
3171 wire_cursor.remaining(),
3172 0,
3173 "wire data should be fully consumed"
3174 );
3175
3176 let mut stored_cursor: &[u8] = &raw_row.data;
3178
3179 assert!(
3182 stored_cursor.remaining() >= 2,
3183 "need at least 2 bytes for length"
3184 );
3185 let len0 = stored_cursor.get_u16_le() as usize;
3186 assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
3187 assert!(
3188 stored_cursor.remaining() >= len0,
3189 "need {len0} bytes for data"
3190 );
3191
3192 let mut utf16_read = Vec::new();
3194 for _ in 0..(len0 / 2) {
3195 utf16_read.push(stored_cursor.get_u16_le());
3196 }
3197 let string0 = String::from_utf16(&utf16_read).unwrap();
3198 assert_eq!(string0, "World", "column 0 should be 'World'");
3199
3200 assert!(
3203 stored_cursor.remaining() >= 1,
3204 "need at least 1 byte for length"
3205 );
3206 let len1 = stored_cursor.get_u8();
3207 assert_eq!(len1, 4, "IntN length should be 4");
3208 assert!(stored_cursor.remaining() >= 4, "need 4 bytes for INT data");
3209 let int1 = stored_cursor.get_i32_le();
3210 assert_eq!(int1, 42, "column 1 should be 42");
3211
3212 assert_eq!(
3214 stored_cursor.remaining(),
3215 0,
3216 "stored data should be fully consumed"
3217 );
3218 }
3219
3220 #[test]
3221 fn test_decode_nvarchar_max_then_intn_roundtrip() {
3222 let mut wire_data = BytesMut::new();
3226
3227 let word = "Hello";
3230 let utf16: Vec<u16> = word.encode_utf16().collect();
3231 let byte_len = (utf16.len() * 2) as u64;
3232
3233 wire_data.put_u64_le(byte_len); wire_data.put_u32_le(byte_len as u32); for code_unit in &utf16 {
3236 wire_data.put_u16_le(*code_unit);
3237 }
3238 wire_data.put_u32_le(0); wire_data.put_u8(4);
3242 wire_data.put_i32_le(99);
3243
3244 let metadata = ColMetaData {
3246 cek_table: None,
3247 columns: vec![
3248 ColumnData {
3249 name: "text".to_string(),
3250 type_id: TypeId::NVarChar,
3251 col_type: 0xE7,
3252 flags: 0x01,
3253 user_type: 0,
3254 type_info: TypeInfo {
3255 max_length: Some(0xFFFF), precision: None,
3257 scale: None,
3258 collation: None,
3259 },
3260 crypto_metadata: None,
3261 },
3262 ColumnData {
3263 name: "num".to_string(),
3264 type_id: TypeId::IntN,
3265 col_type: 0x26,
3266 flags: 0x01,
3267 user_type: 0,
3268 type_info: TypeInfo {
3269 max_length: Some(4),
3270 precision: None,
3271 scale: None,
3272 collation: None,
3273 },
3274 crypto_metadata: None,
3275 },
3276 ],
3277 };
3278
3279 let mut wire_cursor = wire_data.freeze();
3281 let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3282
3283 assert_eq!(
3285 wire_cursor.remaining(),
3286 0,
3287 "wire data should be fully consumed"
3288 );
3289
3290 let mut stored_cursor: &[u8] = &raw_row.data;
3292
3293 let total_len = stored_cursor.get_u64_le();
3295 assert_eq!(total_len, 10, "PLP total length should be 10");
3296
3297 let chunk_len = stored_cursor.get_u32_le();
3298 assert_eq!(chunk_len, 10, "PLP chunk length should be 10");
3299
3300 let mut utf16_read = Vec::new();
3301 for _ in 0..(chunk_len / 2) {
3302 utf16_read.push(stored_cursor.get_u16_le());
3303 }
3304 let string0 = String::from_utf16(&utf16_read).unwrap();
3305 assert_eq!(string0, "Hello", "column 0 should be 'Hello'");
3306
3307 let terminator = stored_cursor.get_u32_le();
3308 assert_eq!(terminator, 0, "PLP should end with 0");
3309
3310 let len1 = stored_cursor.get_u8();
3312 assert_eq!(len1, 4);
3313 let int1 = stored_cursor.get_i32_le();
3314 assert_eq!(int1, 99, "column 1 should be 99");
3315
3316 assert_eq!(
3318 stored_cursor.remaining(),
3319 0,
3320 "stored data should be fully consumed"
3321 );
3322 }
3323
3324 #[test]
3329 fn test_return_status_via_parser() {
3330 let data = Bytes::from_static(&[
3332 0x79, 0x00, 0x00, 0x00, 0x00, ]);
3335
3336 let mut parser = TokenParser::new(data);
3337 let token = parser.next_token().unwrap().unwrap();
3338
3339 match token {
3340 Token::ReturnStatus(status) => {
3341 assert_eq!(status, 0);
3342 }
3343 _ => panic!("Expected ReturnStatus token, got {token:?}"),
3344 }
3345
3346 assert!(parser.next_token().unwrap().is_none());
3347 }
3348
3349 #[test]
3350 fn test_return_status_nonzero() {
3351 let mut buf = BytesMut::new();
3353 buf.put_u8(0x79); buf.put_i32_le(-6);
3355
3356 let mut parser = TokenParser::new(buf.freeze());
3357 let token = parser.next_token().unwrap().unwrap();
3358
3359 match token {
3360 Token::ReturnStatus(status) => {
3361 assert_eq!(status, -6);
3362 }
3363 _ => panic!("Expected ReturnStatus token"),
3364 }
3365 }
3366
3367 #[test]
3372 fn test_done_proc_roundtrip() {
3373 let done = DoneProc {
3374 status: DoneStatus {
3375 more: false,
3376 error: false,
3377 in_xact: false,
3378 count: true,
3379 attn: false,
3380 srverror: false,
3381 },
3382 cur_cmd: 0x00C6, row_count: 100,
3384 };
3385
3386 let mut buf = BytesMut::new();
3387 done.encode(&mut buf);
3388
3389 assert_eq!(buf[0], 0xFE);
3391
3392 let mut cursor = &buf[1..];
3394 let decoded = DoneProc::decode(&mut cursor).unwrap();
3395
3396 assert!(decoded.status.count);
3397 assert!(!decoded.status.more);
3398 assert!(!decoded.status.error);
3399 assert_eq!(decoded.cur_cmd, 0x00C6);
3400 assert_eq!(decoded.row_count, 100);
3401 }
3402
3403 #[test]
3404 fn test_done_proc_via_parser() {
3405 let data = Bytes::from_static(&[
3406 0xFE, 0x00, 0x00, 0xC6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3411
3412 let mut parser = TokenParser::new(data);
3413 let token = parser.next_token().unwrap().unwrap();
3414
3415 match token {
3416 Token::DoneProc(done) => {
3417 assert!(!done.status.count);
3418 assert!(!done.status.more);
3419 assert_eq!(done.cur_cmd, 198);
3420 assert_eq!(done.row_count, 0);
3421 }
3422 _ => panic!("Expected DoneProc token"),
3423 }
3424 }
3425
3426 #[test]
3427 fn test_done_proc_with_error_flag() {
3428 let mut buf = BytesMut::new();
3429 buf.put_u8(0xFE); buf.put_u16_le(0x0002); buf.put_u16_le(0x00C6); buf.put_u64_le(0); let mut parser = TokenParser::new(buf.freeze());
3435 let token = parser.next_token().unwrap().unwrap();
3436
3437 match token {
3438 Token::DoneProc(done) => {
3439 assert!(done.status.error);
3440 assert!(!done.status.count);
3441 assert!(!done.status.more);
3442 }
3443 _ => panic!("Expected DoneProc token"),
3444 }
3445 }
3446
3447 #[test]
3452 fn test_done_in_proc_roundtrip() {
3453 let done = DoneInProc {
3454 status: DoneStatus {
3455 more: true,
3456 error: false,
3457 in_xact: false,
3458 count: true,
3459 attn: false,
3460 srverror: false,
3461 },
3462 cur_cmd: 193, row_count: 7,
3464 };
3465
3466 let mut buf = BytesMut::new();
3467 done.encode(&mut buf);
3468
3469 assert_eq!(buf[0], 0xFF);
3470
3471 let mut cursor = &buf[1..];
3472 let decoded = DoneInProc::decode(&mut cursor).unwrap();
3473
3474 assert!(decoded.status.more);
3475 assert!(decoded.status.count);
3476 assert!(!decoded.status.error);
3477 assert_eq!(decoded.cur_cmd, 193);
3478 assert_eq!(decoded.row_count, 7);
3479 }
3480
3481 #[test]
3482 fn test_done_in_proc_via_parser() {
3483 let data = Bytes::from_static(&[
3484 0xFF, 0x11, 0x00, 0xC1, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3489
3490 let mut parser = TokenParser::new(data);
3491 let token = parser.next_token().unwrap().unwrap();
3492
3493 match token {
3494 Token::DoneInProc(done) => {
3495 assert!(done.status.more);
3496 assert!(done.status.count);
3497 assert_eq!(done.cur_cmd, 193);
3498 assert_eq!(done.row_count, 3);
3499 }
3500 _ => panic!("Expected DoneInProc token"),
3501 }
3502 }
3503
3504 #[test]
3509 fn test_server_error_decode() {
3510 let mut buf = BytesMut::new();
3513
3514 let msg_utf16: Vec<u16> = "Invalid column name 'foo'.".encode_utf16().collect();
3516 let srv_utf16: Vec<u16> = "SQLDB01".encode_utf16().collect();
3517 let proc_utf16: Vec<u16> = "".encode_utf16().collect();
3518
3519 let length: u16 = (4
3525 + 1
3526 + 1
3527 + 2
3528 + (msg_utf16.len() * 2)
3529 + 1
3530 + (srv_utf16.len() * 2)
3531 + 1
3532 + (proc_utf16.len() * 2)
3533 + 4) as u16;
3534
3535 buf.put_u16_le(length);
3536 buf.put_i32_le(207); buf.put_u8(1); buf.put_u8(16); buf.put_u16_le(msg_utf16.len() as u16);
3542 for &c in &msg_utf16 {
3543 buf.put_u16_le(c);
3544 }
3545
3546 buf.put_u8(srv_utf16.len() as u8);
3548 for &c in &srv_utf16 {
3549 buf.put_u16_le(c);
3550 }
3551
3552 buf.put_u8(proc_utf16.len() as u8);
3554
3555 buf.put_i32_le(42);
3557
3558 let mut cursor = buf.freeze();
3559 let error = ServerError::decode(&mut cursor).unwrap();
3560
3561 assert_eq!(error.number, 207);
3562 assert_eq!(error.state, 1);
3563 assert_eq!(error.class, 16);
3564 assert_eq!(error.message, "Invalid column name 'foo'.");
3565 assert_eq!(error.server, "SQLDB01");
3566 assert_eq!(error.procedure, "");
3567 assert_eq!(error.line, 42);
3568 }
3569
3570 #[test]
3571 fn test_server_error_severity_helpers() {
3572 let fatal = ServerError {
3573 number: 4014,
3574 state: 1,
3575 class: 20,
3576 message: "Fatal error".to_string(),
3577 server: String::new(),
3578 procedure: String::new(),
3579 line: 0,
3580 };
3581 assert!(fatal.is_fatal());
3582 assert!(fatal.is_batch_abort());
3583
3584 let batch_abort = ServerError {
3585 number: 547,
3586 state: 0,
3587 class: 16,
3588 message: "Constraint violation".to_string(),
3589 server: String::new(),
3590 procedure: String::new(),
3591 line: 1,
3592 };
3593 assert!(!batch_abort.is_fatal());
3594 assert!(batch_abort.is_batch_abort());
3595
3596 let informational = ServerError {
3597 number: 5701,
3598 state: 2,
3599 class: 10,
3600 message: "Changed db context".to_string(),
3601 server: String::new(),
3602 procedure: String::new(),
3603 line: 0,
3604 };
3605 assert!(!informational.is_fatal());
3606 assert!(!informational.is_batch_abort());
3607 }
3608
3609 #[test]
3610 fn test_server_error_via_parser() {
3611 let mut buf = BytesMut::new();
3613 buf.put_u8(0xAA); let msg_utf16: Vec<u16> = "Syntax error".encode_utf16().collect();
3616 let srv_utf16: Vec<u16> = "SRV".encode_utf16().collect();
3617 let proc_utf16: Vec<u16> = "sp_test".encode_utf16().collect();
3618
3619 let length: u16 = (4
3620 + 1
3621 + 1
3622 + 2
3623 + (msg_utf16.len() * 2)
3624 + 1
3625 + (srv_utf16.len() * 2)
3626 + 1
3627 + (proc_utf16.len() * 2)
3628 + 4) as u16;
3629
3630 buf.put_u16_le(length);
3631 buf.put_i32_le(102); buf.put_u8(1);
3633 buf.put_u8(15);
3634
3635 buf.put_u16_le(msg_utf16.len() as u16);
3636 for &c in &msg_utf16 {
3637 buf.put_u16_le(c);
3638 }
3639 buf.put_u8(srv_utf16.len() as u8);
3640 for &c in &srv_utf16 {
3641 buf.put_u16_le(c);
3642 }
3643 buf.put_u8(proc_utf16.len() as u8);
3644 for &c in &proc_utf16 {
3645 buf.put_u16_le(c);
3646 }
3647 buf.put_i32_le(5);
3648
3649 let mut parser = TokenParser::new(buf.freeze());
3650 let token = parser.next_token().unwrap().unwrap();
3651
3652 match token {
3653 Token::Error(err) => {
3654 assert_eq!(err.number, 102);
3655 assert_eq!(err.class, 15);
3656 assert_eq!(err.message, "Syntax error");
3657 assert_eq!(err.server, "SRV");
3658 assert_eq!(err.procedure, "sp_test");
3659 assert_eq!(err.line, 5);
3660 }
3661 _ => panic!("Expected Error token"),
3662 }
3663 }
3664
3665 fn build_return_value_intn(
3672 ordinal: u16,
3673 name: &str,
3674 status: u8,
3675 value: Option<i32>,
3676 ) -> BytesMut {
3677 let mut inner = BytesMut::new();
3678
3679 inner.put_u16_le(ordinal);
3681
3682 let name_utf16: Vec<u16> = name.encode_utf16().collect();
3684 inner.put_u8(name_utf16.len() as u8);
3685 for &c in &name_utf16 {
3686 inner.put_u16_le(c);
3687 }
3688
3689 inner.put_u8(status);
3691
3692 inner.put_u32_le(0);
3694
3695 inner.put_u16_le(0x0001); inner.put_u8(0x26);
3700
3701 inner.put_u8(4);
3703
3704 match value {
3706 Some(v) => {
3707 inner.put_u8(4); inner.put_i32_le(v);
3709 }
3710 None => {
3711 inner.put_u8(0); }
3713 }
3714
3715 inner
3718 }
3719
3720 #[test]
3721 fn test_return_value_int_output() {
3722 let buf = build_return_value_intn(1, "@result", 0x01, Some(42));
3723 let mut cursor = buf.freeze();
3724 let rv = ReturnValue::decode(&mut cursor).unwrap();
3725
3726 assert_eq!(rv.param_ordinal, 1);
3727 assert_eq!(rv.param_name, "@result");
3728 assert_eq!(rv.status, 0x01); assert_eq!(rv.col_type, 0x26); assert_eq!(rv.type_info.max_length, Some(4));
3731 assert_eq!(rv.value.len(), 5);
3733 assert_eq!(rv.value[0], 4);
3734 assert_eq!(
3735 i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3736 42
3737 );
3738 }
3739
3740 #[test]
3741 fn test_return_value_null_output() {
3742 let buf = build_return_value_intn(2, "@count", 0x01, None);
3743 let mut cursor = buf.freeze();
3744 let rv = ReturnValue::decode(&mut cursor).unwrap();
3745
3746 assert_eq!(rv.param_ordinal, 2);
3747 assert_eq!(rv.param_name, "@count");
3748 assert_eq!(rv.status, 0x01);
3749 assert_eq!(rv.col_type, 0x26);
3750 assert_eq!(rv.value.len(), 1);
3752 assert_eq!(rv.value[0], 0);
3753 }
3754
3755 #[test]
3756 fn test_return_value_udf_status() {
3757 let buf = build_return_value_intn(0, "@RETURN_VALUE", 0x02, Some(-1));
3759 let mut cursor = buf.freeze();
3760 let rv = ReturnValue::decode(&mut cursor).unwrap();
3761
3762 assert_eq!(rv.param_ordinal, 0);
3763 assert_eq!(rv.param_name, "@RETURN_VALUE");
3764 assert_eq!(rv.status, 0x02); assert_eq!(rv.value[0], 4);
3766 assert_eq!(
3767 i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3768 -1
3769 );
3770 }
3771
3772 #[test]
3773 fn test_return_value_nvarchar_output() {
3774 let mut inner = BytesMut::new();
3776
3777 inner.put_u16_le(1);
3779
3780 let name_utf16: Vec<u16> = "@name".encode_utf16().collect();
3782 inner.put_u8(name_utf16.len() as u8);
3783 for &c in &name_utf16 {
3784 inner.put_u16_le(c);
3785 }
3786
3787 inner.put_u8(0x01);
3789 inner.put_u32_le(0);
3791 inner.put_u16_le(0x0001);
3793 inner.put_u8(0xE7);
3795 inner.put_u16_le(200); inner.put_u32_le(0x0904D000); inner.put_u8(0x34); let val_utf16: Vec<u16> = "Hello".encode_utf16().collect();
3802 let byte_len = (val_utf16.len() * 2) as u16;
3803 inner.put_u16_le(byte_len);
3804 for &c in &val_utf16 {
3805 inner.put_u16_le(c);
3806 }
3807
3808 let mut cursor = inner.freeze();
3809 let rv = ReturnValue::decode(&mut cursor).unwrap();
3810
3811 assert_eq!(rv.param_ordinal, 1);
3812 assert_eq!(rv.param_name, "@name");
3813 assert_eq!(rv.status, 0x01);
3814 assert_eq!(rv.col_type, 0xE7); assert_eq!(rv.type_info.max_length, Some(200));
3816 assert!(rv.type_info.collation.is_some());
3817
3818 assert_eq!(rv.value.len(), 12); let val_len = u16::from_le_bytes([rv.value[0], rv.value[1]]);
3821 assert_eq!(val_len, 10);
3822 }
3823
3824 #[test]
3825 fn test_return_value_via_parser() {
3826 let mut data = BytesMut::new();
3828 data.put_u8(0xAC); data.extend_from_slice(&build_return_value_intn(0, "@out", 0x01, Some(99)));
3830
3831 let mut parser = TokenParser::new(data.freeze());
3832 let token = parser.next_token().unwrap().unwrap();
3833
3834 match token {
3835 Token::ReturnValue(rv) => {
3836 assert_eq!(rv.param_name, "@out");
3837 assert_eq!(rv.param_ordinal, 0);
3838 assert_eq!(rv.status, 0x01);
3839 assert_eq!(rv.col_type, 0x26);
3840 }
3841 _ => panic!("Expected ReturnValue token"),
3842 }
3843 }
3844
3845 #[test]
3850 fn test_multi_token_stored_proc_response() {
3851 let mut data = BytesMut::new();
3854
3855 data.put_u8(0xFF); data.put_u16_le(0x0010); data.put_u16_le(0x00C1); data.put_u64_le(3); data.put_u8(0x79); data.put_i32_le(0);
3864
3865 data.put_u8(0xFE); data.put_u16_le(0x0000); data.put_u16_le(0x00C6); data.put_u64_le(0);
3870
3871 let mut parser = TokenParser::new(data.freeze());
3872
3873 let t1 = parser.next_token().unwrap().unwrap();
3875 match t1 {
3876 Token::DoneInProc(done) => {
3877 assert!(done.status.count);
3878 assert_eq!(done.row_count, 3);
3879 assert_eq!(done.cur_cmd, 193);
3880 }
3881 _ => panic!("Expected DoneInProc, got {t1:?}"),
3882 }
3883
3884 let t2 = parser.next_token().unwrap().unwrap();
3886 match t2 {
3887 Token::ReturnStatus(status) => {
3888 assert_eq!(status, 0);
3889 }
3890 _ => panic!("Expected ReturnStatus, got {t2:?}"),
3891 }
3892
3893 let t3 = parser.next_token().unwrap().unwrap();
3895 match t3 {
3896 Token::DoneProc(done) => {
3897 assert!(!done.status.count);
3898 assert!(!done.status.more);
3899 assert_eq!(done.cur_cmd, 198);
3900 }
3901 _ => panic!("Expected DoneProc, got {t3:?}"),
3902 }
3903
3904 assert!(parser.next_token().unwrap().is_none());
3906 }
3907
3908 #[test]
3909 fn test_multi_token_error_in_stream() {
3910 let mut data = BytesMut::new();
3912
3913 data.put_u8(0xAA);
3915
3916 let msg_utf16: Vec<u16> = "Deadlock".encode_utf16().collect();
3917 let srv_utf16: Vec<u16> = "DB1".encode_utf16().collect();
3918
3919 let length: u16 = (4 + 1 + 1
3920 + 2 + (msg_utf16.len() * 2)
3921 + 1 + (srv_utf16.len() * 2)
3922 + 1 + 4) as u16;
3924
3925 data.put_u16_le(length);
3926 data.put_i32_le(1205); data.put_u8(51); data.put_u8(13); data.put_u16_le(msg_utf16.len() as u16);
3931 for &c in &msg_utf16 {
3932 data.put_u16_le(c);
3933 }
3934 data.put_u8(srv_utf16.len() as u8);
3935 for &c in &srv_utf16 {
3936 data.put_u16_le(c);
3937 }
3938 data.put_u8(0); data.put_i32_le(0);
3940
3941 data.put_u8(0xFD);
3943 data.put_u16_le(0x0002); data.put_u16_le(0x00C1); data.put_u64_le(0);
3946
3947 let mut parser = TokenParser::new(data.freeze());
3948
3949 let t1 = parser.next_token().unwrap().unwrap();
3951 match t1 {
3952 Token::Error(err) => {
3953 assert_eq!(err.number, 1205);
3954 assert_eq!(err.class, 13);
3955 assert_eq!(err.message, "Deadlock");
3956 assert_eq!(err.server, "DB1");
3957 }
3958 _ => panic!("Expected Error token, got {t1:?}"),
3959 }
3960
3961 let t2 = parser.next_token().unwrap().unwrap();
3963 match t2 {
3964 Token::Done(done) => {
3965 assert!(done.status.error);
3966 assert!(!done.status.count);
3967 }
3968 _ => panic!("Expected Done token, got {t2:?}"),
3969 }
3970
3971 assert!(parser.next_token().unwrap().is_none());
3972 }
3973
3974 #[test]
3975 fn test_multi_token_proc_with_return_value() {
3976 let mut data = BytesMut::new();
3978
3979 data.put_u8(0xAC);
3981 data.extend_from_slice(&build_return_value_intn(1, "@result", 0x01, Some(42)));
3982
3983 data.put_u8(0x79);
3985 data.put_i32_le(0);
3986
3987 data.put_u8(0xFE);
3989 data.put_u16_le(0x0000);
3990 data.put_u16_le(0x00C6);
3991 data.put_u64_le(0);
3992
3993 let mut parser = TokenParser::new(data.freeze());
3994
3995 let t1 = parser.next_token().unwrap().unwrap();
3996 match t1 {
3997 Token::ReturnValue(rv) => {
3998 assert_eq!(rv.param_name, "@result");
3999 assert_eq!(rv.param_ordinal, 1);
4000 }
4001 _ => panic!("Expected ReturnValue, got {t1:?}"),
4002 }
4003
4004 let t2 = parser.next_token().unwrap().unwrap();
4005 assert!(matches!(t2, Token::ReturnStatus(0)));
4006
4007 let t3 = parser.next_token().unwrap().unwrap();
4008 assert!(matches!(t3, Token::DoneProc(_)));
4009
4010 assert!(parser.next_token().unwrap().is_none());
4011 }
4012
4013 #[test]
4018 fn test_return_status_truncated() {
4019 let data = Bytes::from_static(&[0x79, 0x01, 0x02, 0x03]);
4021 let mut parser = TokenParser::new(data);
4022 assert!(parser.next_token().is_err());
4023 }
4024
4025 #[test]
4026 fn test_done_proc_truncated() {
4027 let data = Bytes::from_static(&[0xFE, 0x00, 0x00, 0xC1, 0x00, 0x01, 0x00, 0x00, 0x00]);
4029 let mut parser = TokenParser::new(data);
4030 assert!(parser.next_token().is_err());
4031 }
4032
4033 #[test]
4034 fn test_server_error_truncated() {
4035 let data = Bytes::from_static(&[0xAA, 0x20, 0x00]);
4037 let mut parser = TokenParser::new(data);
4038 assert!(parser.next_token().is_err());
4039 }
4040
4041 fn build_fed_auth_info_token(options: &[(u8, &str)]) -> Vec<u8> {
4050 let headers_end = 4 + options.len() * 9;
4051 let mut data_block = Vec::new();
4052 let mut headers = Vec::new();
4053 for (id, value) in options {
4054 let encoded: Vec<u8> = value.encode_utf16().flat_map(u16::to_le_bytes).collect();
4055 let offset = headers_end + data_block.len();
4056 headers.push(*id);
4057 headers.extend_from_slice(&u32::try_from(encoded.len()).unwrap().to_le_bytes());
4058 headers.extend_from_slice(&u32::try_from(offset).unwrap().to_le_bytes());
4059 data_block.extend_from_slice(&encoded);
4060 }
4061
4062 let token_len = 4 + headers.len() + data_block.len();
4063 let mut out = vec![0xEE];
4064 out.extend_from_slice(&u32::try_from(token_len).unwrap().to_le_bytes());
4065 out.extend_from_slice(&u32::try_from(options.len()).unwrap().to_le_bytes());
4066 out.extend_from_slice(&headers);
4067 out.extend_from_slice(&data_block);
4068 out
4069 }
4070
4071 #[test]
4072 fn test_fed_auth_info_decodes_spec_layout() {
4073 const STS: &str = "https://login.microsoftonline.com/common";
4074 const SPN: &str = "https://database.windows.net/";
4075 let token = build_fed_auth_info_token(&[(0x02, STS), (0x01, SPN)]);
4077
4078 let mut parser = TokenParser::new(Bytes::from(token));
4079 let parsed = parser.next_token().unwrap().unwrap();
4080 let Token::FedAuthInfo(info) = parsed else {
4081 panic!("expected FedAuthInfo, got {parsed:?}");
4082 };
4083 assert_eq!(info.sts_url, STS);
4084 assert_eq!(info.spn, SPN);
4085 assert!(parser.next_token().unwrap().is_none(), "exact consumption");
4086 }
4087
4088 #[test]
4089 fn test_fed_auth_info_preserves_following_tokens() {
4090 let mut stream = build_fed_auth_info_token(&[
4094 (0x02, "https://sts.example/"),
4095 (0x01, "https://db.example/"),
4096 ]);
4097 stream.push(0xFD); stream.extend_from_slice(&0u16.to_le_bytes()); stream.extend_from_slice(&0u16.to_le_bytes()); stream.extend_from_slice(&0u64.to_le_bytes()); let mut parser = TokenParser::new(Bytes::from(stream));
4103 assert!(matches!(
4104 parser.next_token().unwrap(),
4105 Some(Token::FedAuthInfo(_))
4106 ));
4107 assert!(
4108 matches!(parser.next_token().unwrap(), Some(Token::Done(_))),
4109 "DONE after FEDAUTHINFO must not be swallowed"
4110 );
4111 assert!(parser.next_token().unwrap().is_none());
4112 }
4113
4114 #[test]
4115 fn test_fed_auth_info_unknown_ids_ignored() {
4116 let token =
4118 build_fed_auth_info_token(&[(0x7F, "ignore-me"), (0x02, "https://sts.example/")]);
4119 let mut parser = TokenParser::new(Bytes::from(token));
4120 let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
4121 panic!("expected FedAuthInfo");
4122 };
4123 assert_eq!(info.sts_url, "https://sts.example/");
4124 assert_eq!(info.spn, "");
4125 }
4126
4127 #[test]
4128 fn test_fed_auth_info_hostile_inputs_error() {
4129 let mut truncated = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4131 truncated.truncate(truncated.len() - 4);
4132 assert!(
4133 TokenParser::new(Bytes::from(truncated))
4134 .next_token()
4135 .is_err()
4136 );
4137
4138 let mut bad_count = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4141 bad_count[5..9].copy_from_slice(&u32::MAX.to_le_bytes());
4142 assert!(
4143 TokenParser::new(Bytes::from(bad_count))
4144 .next_token()
4145 .is_err()
4146 );
4147
4148 let mut bad_offset = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4150 bad_offset[14..18].copy_from_slice(&u32::MAX.to_le_bytes());
4151 assert!(
4152 TokenParser::new(Bytes::from(bad_offset))
4153 .next_token()
4154 .is_err()
4155 );
4156
4157 let mut odd_len = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4159 odd_len[10..14].copy_from_slice(&3u32.to_le_bytes());
4160 assert!(TokenParser::new(Bytes::from(odd_len)).next_token().is_err());
4161 }
4162
4163 #[test]
4164 fn test_fed_auth_info_parse_and_skip_agree() {
4165 let token = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4168 let total = token.len();
4169
4170 let mut parser = TokenParser::new(Bytes::from(token.clone()));
4171 parser.next_token().unwrap();
4172 assert_eq!(parser.position(), total, "decode consumption");
4173
4174 let mut skipper = TokenParser::new(Bytes::from(token));
4175 skipper.skip_token().unwrap();
4176 assert_eq!(skipper.position(), total, "skip consumption");
4177 }
4178}