1use bytes::{Buf, BufMut, Bytes};
33
34use crate::codec::{read_b_varchar, read_us_varchar};
35use crate::error::ProtocolError;
36use crate::prelude::*;
37use crate::types::TypeId;
38
/// Token-type byte that prefixes each token in the server's token stream.
///
/// The discriminant of every variant is its wire byte, so `token as u8`
/// yields the value written on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
    ColMetaData = 0x81,
    Error = 0xAA,
    Info = 0xAB,
    LoginAck = 0xAD,
    Row = 0xD1,
    NbcRow = 0xD2,
    EnvChange = 0xE3,
    Sspi = 0xED,
    Done = 0xFD,
    DoneInProc = 0xFF,
    DoneProc = 0xFE,
    ReturnStatus = 0x79,
    ReturnValue = 0xAC,
    Order = 0xA9,
    FeatureExtAck = 0xAE,
    SessionState = 0xE4,
    FedAuthInfo = 0xEE,
    ColInfo = 0xA5,
    TabName = 0xA4,
    Offset = 0x78,
}

impl TokenType {
    /// Every known variant; the reverse-lookup table for [`TokenType::from_u8`].
    const ALL: [Self; 20] = [
        Self::ColMetaData,
        Self::Error,
        Self::Info,
        Self::LoginAck,
        Self::Row,
        Self::NbcRow,
        Self::EnvChange,
        Self::Sspi,
        Self::Done,
        Self::DoneInProc,
        Self::DoneProc,
        Self::ReturnStatus,
        Self::ReturnValue,
        Self::Order,
        Self::FeatureExtAck,
        Self::SessionState,
        Self::FedAuthInfo,
        Self::ColInfo,
        Self::TabName,
        Self::Offset,
    ];

    /// Maps a wire byte back to its token type.
    ///
    /// Returns `None` for any byte that is not one of the known discriminants.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| *candidate as u8 == value)
    }
}
114
/// A single decoded token.
///
/// Each variant wraps the payload type produced for that token kind; see the
/// individual payload types for their field layouts.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Token {
    ColMetaData(ColMetaData),
    /// Row whose column values are stored back-to-back (see [`RawRow`]).
    Row(RawRow),
    /// Row with a leading null bitmap (see [`NbcRow`]).
    NbcRow(NbcRow),
    Done(Done),
    DoneProc(DoneProc),
    DoneInProc(DoneInProc),
    /// Return status carried as a plain `i32`.
    ReturnStatus(i32),
    ReturnValue(ReturnValue),
    Error(ServerError),
    Info(ServerInfo),
    LoginAck(LoginAck),
    EnvChange(EnvChange),
    Order(Order),
    FeatureExtAck(FeatureExtAck),
    Sspi(SspiToken),
    SessionState(SessionState),
    FedAuthInfo(FedAuthInfo),
}
157
/// Column metadata describing the layout of subsequent rows.
///
/// Produced by `ColMetaData::decode` / `ColMetaData::decode_encrypted`; the
/// default value (no columns) is also what the "no metadata" marker decodes to.
#[derive(Debug, Clone, Default)]
pub struct ColMetaData {
    /// One entry per column, in wire order.
    pub columns: Vec<ColumnData>,
    /// Column-encryption key table; only `decode_encrypted` ever sets this to `Some`.
    pub cek_table: Option<crate::crypto::CekTable>,
}
167
/// Metadata for a single result column.
#[derive(Debug, Clone)]
pub struct ColumnData {
    /// Column name (empty for the synthetic column built by `ReturnValue::decode`).
    pub name: String,
    /// Parsed form of `col_type`.
    pub type_id: TypeId,
    /// Raw type byte exactly as read from the wire.
    pub col_type: u8,
    /// Column flags; bit `0x0001` is tested by `is_nullable`.
    pub flags: u16,
    /// User-type field as read from the wire.
    pub user_type: u32,
    /// Type-specific extra information (length, precision, scale, collation).
    pub type_info: TypeInfo,
    /// Encryption metadata; only populated by the encrypted metadata decoder.
    pub crypto_metadata: Option<crate::crypto::CryptoMetadata>,
}
187
/// Type-specific parameters that follow a column's type byte.
///
/// Each field is `Some` only for type families whose encoding carries it
/// (see `decode_type_info`); fixed-length types leave everything `None`.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
    /// Declared maximum length; `Some(0xFFFF)` on the variable-length
    /// string/binary types selects PLP value encoding in `RawRow`.
    pub max_length: Option<u32>,
    /// Numeric precision (decimal/numeric types only).
    pub precision: Option<u8>,
    /// Numeric scale, or fractional-second scale for time-based types.
    pub scale: Option<u8>,
    /// Collation for character types that carry one.
    pub collation: Option<Collation>,
}
200
/// Five-byte collation descriptor attached to character columns.
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
    /// First four bytes of the descriptor, little-endian, stored unmodified
    /// (no bits are masked off here).
    pub lcid: u32,
    /// Fifth byte of the descriptor; a non-zero value takes priority over
    /// `lcid` in the encoding lookups below.
    pub sort_id: u8,
}

impl Collation {
    /// Builds a collation from its 5-byte wire form.
    pub fn from_bytes(bytes: &[u8; 5]) -> Self {
        let [b0, b1, b2, b3, sort_id] = *bytes;
        Self {
            lcid: u32::from_le_bytes([b0, b1, b2, b3]),
            sort_id,
        }
    }

    /// Serialises the collation back into its 5-byte wire form.
    pub fn to_bytes(&self) -> [u8; 5] {
        let [b0, b1, b2, b3] = self.lcid.to_le_bytes();
        [b0, b1, b2, b3, self.sort_id]
    }

    /// Text encoding for this collation, resolved by sort id when one is set
    /// and by `lcid` otherwise.
    #[cfg(feature = "encoding")]
    pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
        match self.sort_id {
            0 => crate::collation::encoding_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id),
        }
    }

    /// Whether `lcid` denotes a UTF-8 collation, as judged by
    /// `crate::collation::is_utf8_collation`.
    #[cfg(feature = "encoding")]
    pub fn is_utf8(&self) -> bool {
        crate::collation::is_utf8_collation(self.lcid)
    }

    /// Code page for this collation, resolved by sort id first, then `lcid`.
    #[cfg(feature = "encoding")]
    pub fn code_page(&self) -> Option<u16> {
        match self.sort_id {
            0 => crate::collation::code_page_for_lcid(self.lcid),
            sort_id => crate::collation::code_page_for_sort_id(sort_id),
        }
    }

    /// Human-readable encoding name; `"unsupported"` when a non-zero sort id
    /// has no known encoding.
    #[cfg(feature = "encoding")]
    pub fn encoding_name(&self) -> &'static str {
        match self.sort_id {
            0 => crate::collation::encoding_name_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id)
                .map_or("unsupported", |enc| enc.name()),
        }
    }
}
334
/// A row whose column values are stored back-to-back in column order.
#[derive(Debug, Clone)]
pub struct RawRow {
    /// Concatenated column values in the framing written by
    /// `RawRow::decode_column_value`: length prefixes are kept, and
    /// TEXT/NTEXT/IMAGE values are re-framed as total length + chunks + terminator.
    pub data: bytes::Bytes,
}
341
/// A row prefixed by a null bitmap; NULL columns have no bytes in `data`.
#[derive(Debug, Clone)]
pub struct NbcRow {
    /// One bit per column (LSB first within each byte); a set bit means NULL.
    pub null_bitmap: Vec<u8>,
    /// Values of the non-NULL columns only, framed as in [`RawRow::data`].
    pub data: bytes::Bytes,
}
350
/// Payload of a DONE token (12 bytes after the token byte).
#[derive(Debug, Clone, Copy)]
pub struct Done {
    /// Status flags decoded from the first `u16`.
    pub status: DoneStatus,
    /// Current-command field (`u16`), kept as-is.
    pub cur_cmd: u16,
    /// Row count (`u64`); see `DoneStatus::count` for whether it is meaningful.
    pub row_count: u64,
}
361
/// Decoded status bits shared by DONE, DONEPROC and DONEINPROC.
///
/// Each field corresponds to one mask in `done_status_bits`; any other
/// bits are discarded by `DoneStatus::from_bits`.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
    /// `DONE_MORE` (0x0001).
    pub more: bool,
    /// `DONE_ERROR` (0x0002).
    pub error: bool,
    /// `DONE_INXACT` (0x0004).
    pub in_xact: bool,
    /// `DONE_COUNT` (0x0010).
    pub count: bool,
    /// `DONE_ATTN` (0x0020).
    pub attn: bool,
    /// `DONE_SRVERROR` (0x0100).
    pub srverror: bool,
}
379
/// Payload of a DONEINPROC token; same 12-byte layout as [`Done`].
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
    /// Status flags decoded from the first `u16`.
    pub status: DoneStatus,
    /// Current-command field (`u16`).
    pub cur_cmd: u16,
    /// Row count (`u64`).
    pub row_count: u64,
}
390
/// Payload of a DONEPROC token; same 12-byte layout as [`Done`].
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
    /// Status flags decoded from the first `u16`.
    pub status: DoneStatus,
    /// Current-command field (`u16`).
    pub cur_cmd: u16,
    /// Row count (`u64`).
    pub row_count: u64,
}
401
/// An output parameter or return value, in the field order read by
/// `ReturnValue::decode`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ReturnValue {
    /// Parameter ordinal (`u16`).
    pub param_ordinal: u16,
    /// Parameter name.
    pub param_name: String,
    /// Status byte, kept as-is.
    pub status: u8,
    /// User-type field (`u32`).
    pub user_type: u32,
    /// Flags (`u16`).
    pub flags: u16,
    /// Raw type byte.
    pub col_type: u8,
    /// Type-specific information for `col_type`.
    pub type_info: TypeInfo,
    /// The value, framed exactly like one column of [`RawRow::data`].
    pub value: bytes::Bytes,
}
423
/// An error message reported by the server.
#[derive(Debug, Clone)]
pub struct ServerError {
    /// Error number.
    pub number: i32,
    /// Error state byte.
    pub state: u8,
    /// Severity class; see `is_fatal` (>= 20) and `is_batch_abort` (>= 16).
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Name of the reporting server.
    pub server: String,
    /// Procedure name (may be empty).
    pub procedure: String,
    /// Line number.
    pub line: i32,
}
442
/// An informational message from the server; same fields as [`ServerError`].
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Message number.
    pub number: i32,
    /// State byte.
    pub state: u8,
    /// Severity class.
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Name of the reporting server.
    pub server: String,
    /// Procedure name (may be empty).
    pub procedure: String,
    /// Line number.
    pub line: i32,
}
461
/// Login acknowledgement.
///
/// NOTE(review): the decoder is not in this chunk; field meanings below are
/// inferred from their names.
#[derive(Debug, Clone)]
pub struct LoginAck {
    /// Interface byte.
    pub interface: u8,
    /// TDS version, presumably as the raw 32-bit value from the wire.
    pub tds_version: u32,
    /// Server program name.
    pub prog_name: String,
    /// Server program version, presumably as a raw 32-bit value.
    pub prog_version: u32,
}
474
/// An environment change notification: which setting changed, plus its new
/// and old values.
#[derive(Debug, Clone)]
pub struct EnvChange {
    /// Which setting changed.
    pub env_type: EnvChangeType,
    /// Value after the change.
    pub new_value: EnvChangeValue,
    /// Value before the change.
    pub old_value: EnvChangeValue,
}
485
/// Kind of environment change; the discriminant is the wire byte.
///
/// Value 14 has no variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[non_exhaustive]
pub enum EnvChangeType {
    Database = 1,
    Language = 2,
    CharacterSet = 3,
    PacketSize = 4,
    UnicodeSortingLocalId = 5,
    UnicodeComparisonFlags = 6,
    SqlCollation = 7,
    BeginTransaction = 8,
    CommitTransaction = 9,
    RollbackTransaction = 10,
    EnlistDtcTransaction = 11,
    DefectTransaction = 12,
    RealTimeLogShipping = 13,
    PromoteTransaction = 15,
    TransactionManagerAddress = 16,
    TransactionEnded = 17,
    ResetConnectionCompletionAck = 18,
    UserInstanceStarted = 19,
    Routing = 20,
}
530
/// A value carried by an environment change.
///
/// NOTE(review): which `EnvChangeType`s map to which variant is decided by a
/// decoder not visible here.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvChangeValue {
    /// Textual value.
    String(String),
    /// Opaque binary value.
    Binary(bytes::Bytes),
    /// Routing target.
    Routing {
        /// Target host name.
        host: String,
        /// Target port.
        port: u16,
    },
}
547
/// Column ordering information.
#[derive(Debug, Clone)]
pub struct Order {
    /// Column identifiers, presumably the ordinals of the ordered-by columns.
    pub columns: Vec<u16>,
}
554
/// Feature-extension acknowledgement: one entry per acknowledged feature.
#[derive(Debug, Clone)]
pub struct FeatureExtAck {
    /// Acknowledged features, in wire order.
    pub features: Vec<FeatureAck>,
}
561
/// One acknowledged feature and its opaque payload.
#[derive(Debug, Clone)]
pub struct FeatureAck {
    /// Feature identifier byte.
    pub feature_id: u8,
    /// Feature-specific acknowledgement data, uninterpreted.
    pub data: bytes::Bytes,
}
570
/// SSPI authentication payload.
#[derive(Debug, Clone)]
pub struct SspiToken {
    /// Opaque security blob.
    pub data: bytes::Bytes,
}
577
/// Session state payload, kept as raw bytes.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// The bytes following the 4-byte length prefix (see `SessionState::decode`).
    pub data: bytes::Bytes,
}
584
/// Federated authentication information.
#[derive(Debug, Clone)]
pub struct FedAuthInfo {
    /// Security token service URL.
    pub sts_url: String,
    /// Service principal name.
    pub spn: String,
}
593
594pub(crate) fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
602 if src.remaining() < 5 {
603 return Err(ProtocolError::UnexpectedEof);
604 }
605 let lcid = src.get_u32_le();
607 let sort_id = src.get_u8();
608 Ok(Collation { lcid, sort_id })
609}
610
611pub(crate) fn decode_type_info(
615 src: &mut impl Buf,
616 type_id: TypeId,
617 col_type: u8,
618) -> Result<TypeInfo, ProtocolError> {
619 match type_id {
620 TypeId::Null => Ok(TypeInfo::default()),
622 TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
623 TypeId::Int2 => Ok(TypeInfo::default()),
624 TypeId::Int4 => Ok(TypeInfo::default()),
625 TypeId::Int8 => Ok(TypeInfo::default()),
626 TypeId::Float4 => Ok(TypeInfo::default()),
627 TypeId::Float8 => Ok(TypeInfo::default()),
628 TypeId::Money => Ok(TypeInfo::default()),
629 TypeId::Money4 => Ok(TypeInfo::default()),
630 TypeId::DateTime => Ok(TypeInfo::default()),
631 TypeId::DateTime4 => Ok(TypeInfo::default()),
632
633 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
635 if src.remaining() < 1 {
636 return Err(ProtocolError::UnexpectedEof);
637 }
638 let max_length = src.get_u8() as u32;
639 Ok(TypeInfo {
640 max_length: Some(max_length),
641 ..Default::default()
642 })
643 }
644
645 TypeId::Guid => {
647 if src.remaining() < 1 {
648 return Err(ProtocolError::UnexpectedEof);
649 }
650 let max_length = src.get_u8() as u32;
651 Ok(TypeInfo {
652 max_length: Some(max_length),
653 ..Default::default()
654 })
655 }
656
657 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
659 if src.remaining() < 3 {
660 return Err(ProtocolError::UnexpectedEof);
661 }
662 let max_length = src.get_u8() as u32;
663 let precision = src.get_u8();
664 let scale = src.get_u8();
665 Ok(TypeInfo {
666 max_length: Some(max_length),
667 precision: Some(precision),
668 scale: Some(scale),
669 ..Default::default()
670 })
671 }
672
673 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
675 if src.remaining() < 1 {
676 return Err(ProtocolError::UnexpectedEof);
677 }
678 let max_length = src.get_u8() as u32;
679 Ok(TypeInfo {
680 max_length: Some(max_length),
681 ..Default::default()
682 })
683 }
684
685 TypeId::BigVarChar | TypeId::BigChar => {
687 if src.remaining() < 7 {
688 return Err(ProtocolError::UnexpectedEof);
690 }
691 let max_length = src.get_u16_le() as u32;
692 let collation = decode_collation(src)?;
693 Ok(TypeInfo {
694 max_length: Some(max_length),
695 collation: Some(collation),
696 ..Default::default()
697 })
698 }
699
700 TypeId::BigVarBinary | TypeId::BigBinary => {
702 if src.remaining() < 2 {
703 return Err(ProtocolError::UnexpectedEof);
704 }
705 let max_length = src.get_u16_le() as u32;
706 Ok(TypeInfo {
707 max_length: Some(max_length),
708 ..Default::default()
709 })
710 }
711
712 TypeId::NChar | TypeId::NVarChar => {
714 if src.remaining() < 7 {
715 return Err(ProtocolError::UnexpectedEof);
717 }
718 let max_length = src.get_u16_le() as u32;
719 let collation = decode_collation(src)?;
720 Ok(TypeInfo {
721 max_length: Some(max_length),
722 collation: Some(collation),
723 ..Default::default()
724 })
725 }
726
727 TypeId::Date => Ok(TypeInfo::default()),
729
730 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
732 if src.remaining() < 1 {
733 return Err(ProtocolError::UnexpectedEof);
734 }
735 let scale = src.get_u8();
736 Ok(TypeInfo {
737 scale: Some(scale),
738 ..Default::default()
739 })
740 }
741
742 TypeId::Text | TypeId::NText | TypeId::Image => {
744 if src.remaining() < 4 {
746 return Err(ProtocolError::UnexpectedEof);
747 }
748 let max_length = src.get_u32_le();
749
750 let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
752 if src.remaining() < 5 {
753 return Err(ProtocolError::UnexpectedEof);
754 }
755 Some(decode_collation(src)?)
756 } else {
757 None
758 };
759
760 if src.remaining() < 1 {
763 return Err(ProtocolError::UnexpectedEof);
764 }
765 let num_parts = src.get_u8();
766 for _ in 0..num_parts {
767 let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
769 }
770
771 Ok(TypeInfo {
772 max_length: Some(max_length),
773 collation,
774 ..Default::default()
775 })
776 }
777
778 TypeId::Xml => {
780 if src.remaining() < 1 {
781 return Err(ProtocolError::UnexpectedEof);
782 }
783 let schema_present = src.get_u8();
784
785 if schema_present != 0 {
786 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; }
793
794 Ok(TypeInfo::default())
795 }
796
797 TypeId::Udt => {
799 if src.remaining() < 2 {
801 return Err(ProtocolError::UnexpectedEof);
802 }
803 let max_length = src.get_u16_le() as u32;
804
805 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; Ok(TypeInfo {
814 max_length: Some(max_length),
815 ..Default::default()
816 })
817 }
818
819 TypeId::Tvp => {
821 Err(ProtocolError::InvalidTokenType(col_type))
824 }
825
826 TypeId::Variant => {
828 if src.remaining() < 4 {
829 return Err(ProtocolError::UnexpectedEof);
830 }
831 let max_length = src.get_u32_le();
832 Ok(TypeInfo {
833 max_length: Some(max_length),
834 ..Default::default()
835 })
836 }
837 }
838}
839
840impl ColMetaData {
841 pub const NO_METADATA: u16 = 0xFFFF;
843
844 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
846 if src.remaining() < 2 {
847 return Err(ProtocolError::UnexpectedEof);
848 }
849
850 let column_count = src.get_u16_le();
851
852 if column_count == Self::NO_METADATA {
854 return Ok(Self {
855 columns: Vec::new(),
856 cek_table: None,
857 });
858 }
859
860 let mut columns = Vec::with_capacity(column_count as usize);
861
862 for _ in 0..column_count {
863 let column = Self::decode_column(src)?;
864 columns.push(column);
865 }
866
867 Ok(Self {
868 columns,
869 cek_table: None,
870 })
871 }
872
873 fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
875 if src.remaining() < 7 {
877 return Err(ProtocolError::UnexpectedEof);
878 }
879
880 let user_type = src.get_u32_le();
881 let flags = src.get_u16_le();
882 let col_type = src.get_u8();
883
884 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
888
889 let type_info = decode_type_info(src, type_id, col_type)?;
891
892 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
894
895 Ok(ColumnData {
896 name,
897 type_id,
898 col_type,
899 flags,
900 user_type,
901 type_info,
902 crypto_metadata: None,
903 })
904 }
905
906 pub fn decode_encrypted(src: &mut impl Buf) -> Result<Self, ProtocolError> {
919 if src.remaining() < 2 {
920 return Err(ProtocolError::UnexpectedEof);
921 }
922
923 let column_count = src.get_u16_le();
924
925 if column_count == Self::NO_METADATA {
926 return Ok(Self {
927 columns: Vec::new(),
928 cek_table: None,
929 });
930 }
931
932 let cek_table = crate::crypto::CekTable::decode(src)?;
934
935 let mut columns = Vec::with_capacity(column_count as usize);
936
937 for _ in 0..column_count {
938 let column = Self::decode_column_encrypted(src)?;
939 columns.push(column);
940 }
941
942 Ok(Self {
943 columns,
944 cek_table: Some(cek_table),
945 })
946 }
947
948 fn decode_column_encrypted(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
952 if src.remaining() < 7 {
953 return Err(ProtocolError::UnexpectedEof);
954 }
955
956 let user_type = src.get_u32_le();
957 let flags = src.get_u16_le();
958 let col_type = src.get_u8();
959
960 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
961
962 let type_info = decode_type_info(src, type_id, col_type)?;
964
965 let crypto_metadata = if crate::crypto::is_column_encrypted(flags) {
967 Some(crate::crypto::CryptoMetadata::decode(src)?)
968 } else {
969 None
970 };
971
972 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
974
975 Ok(ColumnData {
976 name,
977 type_id,
978 col_type,
979 flags,
980 user_type,
981 type_info,
982 crypto_metadata,
983 })
984 }
985
986 #[must_use]
988 pub fn column_count(&self) -> usize {
989 self.columns.len()
990 }
991
992 #[must_use]
994 pub fn is_empty(&self) -> bool {
995 self.columns.is_empty()
996 }
997}
998
999impl ColumnData {
1000 #[must_use]
1002 pub fn is_nullable(&self) -> bool {
1003 (self.flags & 0x0001) != 0
1004 }
1005
1006 #[must_use]
1010 pub fn fixed_size(&self) -> Option<usize> {
1011 match self.type_id {
1012 TypeId::Null => Some(0),
1013 TypeId::Int1 | TypeId::Bit => Some(1),
1014 TypeId::Int2 => Some(2),
1015 TypeId::Int4 => Some(4),
1016 TypeId::Int8 => Some(8),
1017 TypeId::Float4 => Some(4),
1018 TypeId::Float8 => Some(8),
1019 TypeId::Money => Some(8),
1020 TypeId::Money4 => Some(4),
1021 TypeId::DateTime => Some(8),
1022 TypeId::DateTime4 => Some(4),
1023 TypeId::Date => Some(3),
1024 _ => None,
1025 }
1026 }
1027}
1028
1029impl RawRow {
1034 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1039 let mut data = bytes::BytesMut::new();
1040
1041 for col in &metadata.columns {
1042 Self::decode_column_value(src, col, &mut data)?;
1043 }
1044
1045 Ok(Self {
1046 data: data.freeze(),
1047 })
1048 }
1049
1050 pub fn decode_prefix(
1057 src: &mut impl Buf,
1058 metadata: &ColMetaData,
1059 prefix_len: usize,
1060 ) -> Result<Self, ProtocolError> {
1061 let mut data = bytes::BytesMut::new();
1062 for col in metadata.columns.iter().take(prefix_len) {
1063 Self::decode_column_value(src, col, &mut data)?;
1064 }
1065 Ok(Self {
1066 data: data.freeze(),
1067 })
1068 }
1069
1070 fn decode_column_value(
1072 src: &mut impl Buf,
1073 col: &ColumnData,
1074 dst: &mut bytes::BytesMut,
1075 ) -> Result<(), ProtocolError> {
1076 match col.type_id {
1077 TypeId::Null => {
1079 }
1081 TypeId::Int1 | TypeId::Bit => {
1082 if src.remaining() < 1 {
1083 return Err(ProtocolError::UnexpectedEof);
1084 }
1085 dst.extend_from_slice(&[src.get_u8()]);
1086 }
1087 TypeId::Int2 => {
1088 if src.remaining() < 2 {
1089 return Err(ProtocolError::UnexpectedEof);
1090 }
1091 dst.extend_from_slice(&src.get_u16_le().to_le_bytes());
1092 }
1093 TypeId::Int4 => {
1094 if src.remaining() < 4 {
1095 return Err(ProtocolError::UnexpectedEof);
1096 }
1097 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1098 }
1099 TypeId::Int8 => {
1100 if src.remaining() < 8 {
1101 return Err(ProtocolError::UnexpectedEof);
1102 }
1103 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1104 }
1105 TypeId::Float4 => {
1106 if src.remaining() < 4 {
1107 return Err(ProtocolError::UnexpectedEof);
1108 }
1109 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1110 }
1111 TypeId::Float8 => {
1112 if src.remaining() < 8 {
1113 return Err(ProtocolError::UnexpectedEof);
1114 }
1115 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1116 }
1117 TypeId::Money => {
1118 if src.remaining() < 8 {
1119 return Err(ProtocolError::UnexpectedEof);
1120 }
1121 let hi = src.get_u32_le();
1122 let lo = src.get_u32_le();
1123 dst.extend_from_slice(&hi.to_le_bytes());
1124 dst.extend_from_slice(&lo.to_le_bytes());
1125 }
1126 TypeId::Money4 => {
1127 if src.remaining() < 4 {
1128 return Err(ProtocolError::UnexpectedEof);
1129 }
1130 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1131 }
1132 TypeId::DateTime => {
1133 if src.remaining() < 8 {
1134 return Err(ProtocolError::UnexpectedEof);
1135 }
1136 let days = src.get_u32_le();
1137 let time = src.get_u32_le();
1138 dst.extend_from_slice(&days.to_le_bytes());
1139 dst.extend_from_slice(&time.to_le_bytes());
1140 }
1141 TypeId::DateTime4 => {
1142 if src.remaining() < 4 {
1143 return Err(ProtocolError::UnexpectedEof);
1144 }
1145 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1146 }
1147 TypeId::Date => {
1149 Self::decode_bytelen_type(src, dst)?;
1150 }
1151
1152 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
1154 Self::decode_bytelen_type(src, dst)?;
1155 }
1156
1157 TypeId::Guid => {
1158 Self::decode_bytelen_type(src, dst)?;
1159 }
1160
1161 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1162 Self::decode_bytelen_type(src, dst)?;
1163 }
1164
1165 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
1167 Self::decode_bytelen_type(src, dst)?;
1168 }
1169
1170 TypeId::BigVarChar | TypeId::BigVarBinary => {
1172 if col.type_info.max_length == Some(0xFFFF) {
1174 Self::decode_plp_type(src, dst)?;
1175 } else {
1176 Self::decode_ushortlen_type(src, dst)?;
1177 }
1178 }
1179
1180 TypeId::BigChar | TypeId::BigBinary => {
1182 Self::decode_ushortlen_type(src, dst)?;
1183 }
1184
1185 TypeId::NVarChar => {
1187 if col.type_info.max_length == Some(0xFFFF) {
1189 Self::decode_plp_type(src, dst)?;
1190 } else {
1191 Self::decode_ushortlen_type(src, dst)?;
1192 }
1193 }
1194
1195 TypeId::NChar => {
1197 Self::decode_ushortlen_type(src, dst)?;
1198 }
1199
1200 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
1202 Self::decode_bytelen_type(src, dst)?;
1203 }
1204
1205 TypeId::Text | TypeId::NText | TypeId::Image => {
1207 Self::decode_textptr_type(src, dst)?;
1208 }
1209
1210 TypeId::Xml => {
1212 Self::decode_plp_type(src, dst)?;
1213 }
1214
1215 TypeId::Variant => {
1217 Self::decode_intlen_type(src, dst)?;
1218 }
1219
1220 TypeId::Udt => {
1221 Self::decode_plp_type(src, dst)?;
1223 }
1224
1225 TypeId::Tvp => {
1226 return Err(ProtocolError::InvalidTokenType(col.col_type));
1228 }
1229 }
1230
1231 Ok(())
1232 }
1233
1234 fn decode_bytelen_type(
1236 src: &mut impl Buf,
1237 dst: &mut bytes::BytesMut,
1238 ) -> Result<(), ProtocolError> {
1239 if src.remaining() < 1 {
1240 return Err(ProtocolError::UnexpectedEof);
1241 }
1242 let len = src.get_u8() as usize;
1243 if len == 0xFF {
1244 dst.extend_from_slice(&[0xFF]);
1246 } else if len == 0 {
1247 dst.extend_from_slice(&[0x00]);
1249 } else {
1250 if src.remaining() < len {
1251 return Err(ProtocolError::UnexpectedEof);
1252 }
1253 dst.extend_from_slice(&[len as u8]);
1254 for _ in 0..len {
1255 dst.extend_from_slice(&[src.get_u8()]);
1256 }
1257 }
1258 Ok(())
1259 }
1260
1261 fn decode_ushortlen_type(
1263 src: &mut impl Buf,
1264 dst: &mut bytes::BytesMut,
1265 ) -> Result<(), ProtocolError> {
1266 if src.remaining() < 2 {
1267 return Err(ProtocolError::UnexpectedEof);
1268 }
1269 let len = src.get_u16_le() as usize;
1270 if len == 0xFFFF {
1271 dst.extend_from_slice(&0xFFFFu16.to_le_bytes());
1273 } else if len == 0 {
1274 dst.extend_from_slice(&0u16.to_le_bytes());
1276 } else {
1277 if src.remaining() < len {
1278 return Err(ProtocolError::UnexpectedEof);
1279 }
1280 dst.extend_from_slice(&(len as u16).to_le_bytes());
1281 for _ in 0..len {
1282 dst.extend_from_slice(&[src.get_u8()]);
1283 }
1284 }
1285 Ok(())
1286 }
1287
1288 fn decode_intlen_type(
1290 src: &mut impl Buf,
1291 dst: &mut bytes::BytesMut,
1292 ) -> Result<(), ProtocolError> {
1293 if src.remaining() < 4 {
1294 return Err(ProtocolError::UnexpectedEof);
1295 }
1296 let len = src.get_u32_le() as usize;
1297 if len == 0xFFFFFFFF {
1298 dst.extend_from_slice(&0xFFFFFFFFu32.to_le_bytes());
1300 } else if len == 0 {
1301 dst.extend_from_slice(&0u32.to_le_bytes());
1303 } else {
1304 if src.remaining() < len {
1305 return Err(ProtocolError::UnexpectedEof);
1306 }
1307 dst.extend_from_slice(&(len as u32).to_le_bytes());
1308 for _ in 0..len {
1309 dst.extend_from_slice(&[src.get_u8()]);
1310 }
1311 }
1312 Ok(())
1313 }
1314
1315 fn decode_textptr_type(
1330 src: &mut impl Buf,
1331 dst: &mut bytes::BytesMut,
1332 ) -> Result<(), ProtocolError> {
1333 if src.remaining() < 1 {
1334 return Err(ProtocolError::UnexpectedEof);
1335 }
1336
1337 let textptr_len = src.get_u8() as usize;
1338
1339 if textptr_len == 0 {
1340 dst.extend_from_slice(&0xFFFFFFFFFFFFFFFFu64.to_le_bytes());
1342 return Ok(());
1343 }
1344
1345 if src.remaining() < textptr_len {
1347 return Err(ProtocolError::UnexpectedEof);
1348 }
1349 src.advance(textptr_len);
1350
1351 if src.remaining() < 8 {
1353 return Err(ProtocolError::UnexpectedEof);
1354 }
1355 src.advance(8);
1356
1357 if src.remaining() < 4 {
1359 return Err(ProtocolError::UnexpectedEof);
1360 }
1361 let data_len = src.get_u32_le() as usize;
1362
1363 if src.remaining() < data_len {
1364 return Err(ProtocolError::UnexpectedEof);
1365 }
1366
1367 dst.extend_from_slice(&(data_len as u64).to_le_bytes());
1373 dst.extend_from_slice(&(data_len as u32).to_le_bytes());
1374 for _ in 0..data_len {
1375 dst.extend_from_slice(&[src.get_u8()]);
1376 }
1377 dst.extend_from_slice(&0u32.to_le_bytes()); Ok(())
1380 }
1381
1382 fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
1388 if src.remaining() < 8 {
1389 return Err(ProtocolError::UnexpectedEof);
1390 }
1391
1392 let total_len = src.get_u64_le();
1393
1394 dst.extend_from_slice(&total_len.to_le_bytes());
1396
1397 if total_len == 0xFFFFFFFFFFFFFFFF {
1398 return Ok(());
1400 }
1401
1402 loop {
1404 if src.remaining() < 4 {
1405 return Err(ProtocolError::UnexpectedEof);
1406 }
1407 let chunk_len = src.get_u32_le() as usize;
1408 dst.extend_from_slice(&(chunk_len as u32).to_le_bytes());
1409
1410 if chunk_len == 0 {
1411 break;
1413 }
1414
1415 if src.remaining() < chunk_len {
1416 return Err(ProtocolError::UnexpectedEof);
1417 }
1418
1419 for _ in 0..chunk_len {
1420 dst.extend_from_slice(&[src.get_u8()]);
1421 }
1422 }
1423
1424 Ok(())
1425 }
1426}
1427
1428impl NbcRow {
1433 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1438 let col_count = metadata.columns.len();
1439 let bitmap_len = col_count.div_ceil(8);
1440
1441 if src.remaining() < bitmap_len {
1442 return Err(ProtocolError::UnexpectedEof);
1443 }
1444
1445 let mut null_bitmap = vec![0u8; bitmap_len];
1447 for byte in &mut null_bitmap {
1448 *byte = src.get_u8();
1449 }
1450
1451 let mut data = bytes::BytesMut::new();
1453
1454 for (i, col) in metadata.columns.iter().enumerate() {
1455 let byte_idx = i / 8;
1456 let bit_idx = i % 8;
1457 let is_null = (null_bitmap[byte_idx] & (1 << bit_idx)) != 0;
1458
1459 if !is_null {
1460 RawRow::decode_column_value(src, col, &mut data)?;
1463 }
1464 }
1465
1466 Ok(Self {
1467 null_bitmap,
1468 data: data.freeze(),
1469 })
1470 }
1471
1472 pub fn decode_prefix(
1478 src: &mut impl Buf,
1479 metadata: &ColMetaData,
1480 prefix_len: usize,
1481 ) -> Result<Self, ProtocolError> {
1482 let col_count = metadata.columns.len();
1483 let bitmap_len = col_count.div_ceil(8);
1484
1485 if src.remaining() < bitmap_len {
1486 return Err(ProtocolError::UnexpectedEof);
1487 }
1488
1489 let mut null_bitmap = vec![0u8; bitmap_len];
1490 for byte in &mut null_bitmap {
1491 *byte = src.get_u8();
1492 }
1493
1494 let mut data = bytes::BytesMut::new();
1495 for (i, col) in metadata.columns.iter().enumerate().take(prefix_len) {
1496 let is_null = (null_bitmap[i / 8] & (1 << (i % 8))) != 0;
1497 if !is_null {
1498 RawRow::decode_column_value(src, col, &mut data)?;
1499 }
1500 }
1501
1502 Ok(Self {
1503 null_bitmap,
1504 data: data.freeze(),
1505 })
1506 }
1507
1508 #[must_use]
1510 pub fn is_null(&self, column_index: usize) -> bool {
1511 let byte_idx = column_index / 8;
1512 let bit_idx = column_index % 8;
1513 if byte_idx < self.null_bitmap.len() {
1514 (self.null_bitmap[byte_idx] & (1 << bit_idx)) != 0
1515 } else {
1516 true }
1518 }
1519}
1520
1521impl ReturnValue {
1526 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1528 if src.remaining() < 2 {
1535 return Err(ProtocolError::UnexpectedEof);
1536 }
1537 let param_ordinal = src.get_u16_le();
1538
1539 let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1541
1542 if src.remaining() < 1 {
1544 return Err(ProtocolError::UnexpectedEof);
1545 }
1546 let status = src.get_u8();
1547
1548 if src.remaining() < 7 {
1550 return Err(ProtocolError::UnexpectedEof);
1551 }
1552 let user_type = src.get_u32_le();
1553 let flags = src.get_u16_le();
1554 let col_type = src.get_u8();
1555
1556 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
1557
1558 let type_info = decode_type_info(src, type_id, col_type)?;
1560
1561 let mut value_buf = bytes::BytesMut::new();
1563
1564 let temp_col = ColumnData {
1566 name: String::new(),
1567 type_id,
1568 col_type,
1569 flags,
1570 user_type,
1571 type_info: type_info.clone(),
1572 crypto_metadata: None,
1573 };
1574
1575 RawRow::decode_column_value(src, &temp_col, &mut value_buf)?;
1576
1577 Ok(Self {
1578 param_ordinal,
1579 param_name,
1580 status,
1581 user_type,
1582 flags,
1583 col_type,
1584 type_info,
1585 value: value_buf.freeze(),
1586 })
1587 }
1588}
1589
1590impl SessionState {
1595 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1597 if src.remaining() < 4 {
1598 return Err(ProtocolError::UnexpectedEof);
1599 }
1600
1601 let length = src.get_u32_le() as usize;
1602
1603 if src.remaining() < length {
1604 return Err(ProtocolError::IncompletePacket {
1605 expected: length,
1606 actual: src.remaining(),
1607 });
1608 }
1609
1610 let data = src.copy_to_bytes(length);
1611
1612 Ok(Self { data })
1613 }
1614}
1615
/// Bit masks for the status word of DONE / DONEPROC / DONEINPROC.
mod done_status_bits {
    /// More results follow.
    pub const DONE_MORE: u16 = 0b0000_0000_0000_0001;
    /// An error occurred.
    pub const DONE_ERROR: u16 = 0b0000_0000_0000_0010;
    /// A transaction is in progress.
    pub const DONE_INXACT: u16 = 0b0000_0000_0000_0100;
    /// The row count is valid.
    pub const DONE_COUNT: u16 = 0b0000_0000_0001_0000;
    /// Attention acknowledgement.
    pub const DONE_ATTN: u16 = 0b0000_0000_0010_0000;
    /// Server error.
    pub const DONE_SRVERROR: u16 = 0b0000_0001_0000_0000;
}
1629
1630impl DoneStatus {
1631 #[must_use]
1633 pub fn from_bits(bits: u16) -> Self {
1634 use done_status_bits::*;
1635 Self {
1636 more: (bits & DONE_MORE) != 0,
1637 error: (bits & DONE_ERROR) != 0,
1638 in_xact: (bits & DONE_INXACT) != 0,
1639 count: (bits & DONE_COUNT) != 0,
1640 attn: (bits & DONE_ATTN) != 0,
1641 srverror: (bits & DONE_SRVERROR) != 0,
1642 }
1643 }
1644
1645 #[must_use]
1647 pub fn to_bits(&self) -> u16 {
1648 use done_status_bits::*;
1649 let mut bits = 0u16;
1650 if self.more {
1651 bits |= DONE_MORE;
1652 }
1653 if self.error {
1654 bits |= DONE_ERROR;
1655 }
1656 if self.in_xact {
1657 bits |= DONE_INXACT;
1658 }
1659 if self.count {
1660 bits |= DONE_COUNT;
1661 }
1662 if self.attn {
1663 bits |= DONE_ATTN;
1664 }
1665 if self.srverror {
1666 bits |= DONE_SRVERROR;
1667 }
1668 bits
1669 }
1670}
1671
1672impl Done {
1673 pub const SIZE: usize = 12; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1678 if src.remaining() < Self::SIZE {
1679 return Err(ProtocolError::IncompletePacket {
1680 expected: Self::SIZE,
1681 actual: src.remaining(),
1682 });
1683 }
1684
1685 let status = DoneStatus::from_bits(src.get_u16_le());
1686 let cur_cmd = src.get_u16_le();
1687 let row_count = src.get_u64_le();
1688
1689 Ok(Self {
1690 status,
1691 cur_cmd,
1692 row_count,
1693 })
1694 }
1695
1696 pub fn encode(&self, dst: &mut impl BufMut) {
1698 dst.put_u8(TokenType::Done as u8);
1699 dst.put_u16_le(self.status.to_bits());
1700 dst.put_u16_le(self.cur_cmd);
1701 dst.put_u64_le(self.row_count);
1702 }
1703
1704 #[must_use]
1706 pub const fn has_more(&self) -> bool {
1707 self.status.more
1708 }
1709
1710 #[must_use]
1712 pub const fn has_error(&self) -> bool {
1713 self.status.error
1714 }
1715
1716 #[must_use]
1718 pub const fn has_count(&self) -> bool {
1719 self.status.count
1720 }
1721}
1722
1723impl DoneProc {
1724 pub const SIZE: usize = 12;
1726
1727 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1729 if src.remaining() < Self::SIZE {
1730 return Err(ProtocolError::IncompletePacket {
1731 expected: Self::SIZE,
1732 actual: src.remaining(),
1733 });
1734 }
1735
1736 let status = DoneStatus::from_bits(src.get_u16_le());
1737 let cur_cmd = src.get_u16_le();
1738 let row_count = src.get_u64_le();
1739
1740 Ok(Self {
1741 status,
1742 cur_cmd,
1743 row_count,
1744 })
1745 }
1746
1747 pub fn encode(&self, dst: &mut impl BufMut) {
1749 dst.put_u8(TokenType::DoneProc as u8);
1750 dst.put_u16_le(self.status.to_bits());
1751 dst.put_u16_le(self.cur_cmd);
1752 dst.put_u64_le(self.row_count);
1753 }
1754}
1755
1756impl DoneInProc {
1757 pub const SIZE: usize = 12;
1759
1760 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1762 if src.remaining() < Self::SIZE {
1763 return Err(ProtocolError::IncompletePacket {
1764 expected: Self::SIZE,
1765 actual: src.remaining(),
1766 });
1767 }
1768
1769 let status = DoneStatus::from_bits(src.get_u16_le());
1770 let cur_cmd = src.get_u16_le();
1771 let row_count = src.get_u64_le();
1772
1773 Ok(Self {
1774 status,
1775 cur_cmd,
1776 row_count,
1777 })
1778 }
1779
1780 pub fn encode(&self, dst: &mut impl BufMut) {
1782 dst.put_u8(TokenType::DoneInProc as u8);
1783 dst.put_u16_le(self.status.to_bits());
1784 dst.put_u16_le(self.cur_cmd);
1785 dst.put_u64_le(self.row_count);
1786 }
1787}
1788
1789impl ServerError {
1790 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1792 if src.remaining() < 2 {
1795 return Err(ProtocolError::UnexpectedEof);
1796 }
1797
1798 let _length = src.get_u16_le();
1799
1800 if src.remaining() < 6 {
1801 return Err(ProtocolError::UnexpectedEof);
1802 }
1803
1804 let number = src.get_i32_le();
1805 let state = src.get_u8();
1806 let class = src.get_u8();
1807
1808 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1809 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1810 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1811
1812 if src.remaining() < 4 {
1813 return Err(ProtocolError::UnexpectedEof);
1814 }
1815 let line = src.get_i32_le();
1816
1817 Ok(Self {
1818 number,
1819 state,
1820 class,
1821 message,
1822 server,
1823 procedure,
1824 line,
1825 })
1826 }
1827
1828 #[must_use]
1830 pub const fn is_fatal(&self) -> bool {
1831 self.class >= 20
1832 }
1833
1834 #[must_use]
1836 pub const fn is_batch_abort(&self) -> bool {
1837 self.class >= 16
1838 }
1839}
1840
1841impl ServerInfo {
1842 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1846 if src.remaining() < 2 {
1847 return Err(ProtocolError::UnexpectedEof);
1848 }
1849
1850 let _length = src.get_u16_le();
1851
1852 if src.remaining() < 6 {
1853 return Err(ProtocolError::UnexpectedEof);
1854 }
1855
1856 let number = src.get_i32_le();
1857 let state = src.get_u8();
1858 let class = src.get_u8();
1859
1860 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1861 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1862 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1863
1864 if src.remaining() < 4 {
1865 return Err(ProtocolError::UnexpectedEof);
1866 }
1867 let line = src.get_i32_le();
1868
1869 Ok(Self {
1870 number,
1871 state,
1872 class,
1873 message,
1874 server,
1875 procedure,
1876 line,
1877 })
1878 }
1879}
1880
1881impl LoginAck {
1882 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1884 if src.remaining() < 2 {
1886 return Err(ProtocolError::UnexpectedEof);
1887 }
1888
1889 let _length = src.get_u16_le();
1890
1891 if src.remaining() < 5 {
1892 return Err(ProtocolError::UnexpectedEof);
1893 }
1894
1895 let interface = src.get_u8();
1896 let tds_version = src.get_u32_le();
1897 let prog_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1898
1899 if src.remaining() < 4 {
1900 return Err(ProtocolError::UnexpectedEof);
1901 }
1902 let prog_version = src.get_u32_le();
1903
1904 Ok(Self {
1905 interface,
1906 tds_version,
1907 prog_name,
1908 prog_version,
1909 })
1910 }
1911
1912 #[must_use]
1914 pub fn tds_version(&self) -> crate::version::TdsVersion {
1915 crate::version::TdsVersion::new(self.tds_version)
1916 }
1917}
1918
1919impl EnvChangeType {
1920 pub fn from_u8(value: u8) -> Option<Self> {
1922 match value {
1923 1 => Some(Self::Database),
1924 2 => Some(Self::Language),
1925 3 => Some(Self::CharacterSet),
1926 4 => Some(Self::PacketSize),
1927 5 => Some(Self::UnicodeSortingLocalId),
1928 6 => Some(Self::UnicodeComparisonFlags),
1929 7 => Some(Self::SqlCollation),
1930 8 => Some(Self::BeginTransaction),
1931 9 => Some(Self::CommitTransaction),
1932 10 => Some(Self::RollbackTransaction),
1933 11 => Some(Self::EnlistDtcTransaction),
1934 12 => Some(Self::DefectTransaction),
1935 13 => Some(Self::RealTimeLogShipping),
1936 15 => Some(Self::PromoteTransaction),
1937 16 => Some(Self::TransactionManagerAddress),
1938 17 => Some(Self::TransactionEnded),
1939 18 => Some(Self::ResetConnectionCompletionAck),
1940 19 => Some(Self::UserInstanceStarted),
1941 20 => Some(Self::Routing),
1942 _ => None,
1943 }
1944 }
1945}
1946
1947impl EnvChange {
1948 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1950 if src.remaining() < 3 {
1951 return Err(ProtocolError::UnexpectedEof);
1952 }
1953
1954 let length = src.get_u16_le() as usize;
1955 if length == 0 {
1956 return Err(ProtocolError::UnexpectedEof);
1959 }
1960 if src.remaining() < length {
1961 return Err(ProtocolError::IncompletePacket {
1962 expected: length,
1963 actual: src.remaining(),
1964 });
1965 }
1966
1967 let mut frame = src.copy_to_bytes(length);
1976 let src = &mut frame;
1977
1978 let env_type_byte = src.get_u8();
1979 let env_type = EnvChangeType::from_u8(env_type_byte)
1980 .ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
1981
1982 let (new_value, old_value) = match env_type {
1983 EnvChangeType::Routing => {
1984 let new_value = Self::decode_routing_value(src)?;
1986 let old_value = EnvChangeValue::Binary(Bytes::new());
1987 (new_value, old_value)
1988 }
1989 EnvChangeType::BeginTransaction
1990 | EnvChangeType::CommitTransaction
1991 | EnvChangeType::RollbackTransaction
1992 | EnvChangeType::EnlistDtcTransaction
1993 | EnvChangeType::SqlCollation => {
1994 let new_len = if src.has_remaining() {
2003 src.get_u8() as usize
2004 } else {
2005 0
2006 };
2007 let new_value = if new_len > 0 && src.remaining() >= new_len {
2008 EnvChangeValue::Binary(src.copy_to_bytes(new_len))
2009 } else {
2010 EnvChangeValue::Binary(Bytes::new())
2011 };
2012
2013 let old_len = if src.has_remaining() {
2014 src.get_u8() as usize
2015 } else {
2016 0
2017 };
2018 let old_value = if old_len > 0 && src.remaining() >= old_len {
2019 EnvChangeValue::Binary(src.copy_to_bytes(old_len))
2020 } else {
2021 EnvChangeValue::Binary(Bytes::new())
2022 };
2023
2024 (new_value, old_value)
2025 }
2026 _ => {
2027 let new_value = read_b_varchar(src)
2029 .map(EnvChangeValue::String)
2030 .unwrap_or(EnvChangeValue::String(String::new()));
2031
2032 let old_value = read_b_varchar(src)
2033 .map(EnvChangeValue::String)
2034 .unwrap_or(EnvChangeValue::String(String::new()));
2035
2036 (new_value, old_value)
2037 }
2038 };
2039
2040 Ok(Self {
2046 env_type,
2047 new_value,
2048 old_value,
2049 })
2050 }
2051
2052 fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
2053 if src.remaining() < 2 {
2055 return Err(ProtocolError::UnexpectedEof);
2056 }
2057
2058 let _routing_len = src.get_u16_le();
2059
2060 if src.remaining() < 5 {
2061 return Err(ProtocolError::UnexpectedEof);
2062 }
2063
2064 let _protocol = src.get_u8();
2065 let port = src.get_u16_le();
2066 let server_len = src.get_u16_le() as usize;
2067
2068 if src.remaining() < server_len * 2 {
2070 return Err(ProtocolError::UnexpectedEof);
2071 }
2072
2073 let mut chars = Vec::with_capacity(server_len);
2074 for _ in 0..server_len {
2075 chars.push(src.get_u16_le());
2076 }
2077
2078 let host = String::from_utf16(&chars).map_err(|_| {
2079 ProtocolError::StringEncoding(
2080 #[cfg(feature = "std")]
2081 "invalid UTF-16 in routing hostname".to_string(),
2082 #[cfg(not(feature = "std"))]
2083 "invalid UTF-16 in routing hostname",
2084 )
2085 })?;
2086
2087 Ok(EnvChangeValue::Routing { host, port })
2088 }
2089
2090 #[must_use]
2092 pub fn is_routing(&self) -> bool {
2093 self.env_type == EnvChangeType::Routing
2094 }
2095
2096 #[must_use]
2098 pub fn routing_info(&self) -> Option<(&str, u16)> {
2099 if let EnvChangeValue::Routing { host, port } = &self.new_value {
2100 Some((host, *port))
2101 } else {
2102 None
2103 }
2104 }
2105
2106 #[must_use]
2108 pub fn new_database(&self) -> Option<&str> {
2109 if self.env_type == EnvChangeType::Database {
2110 if let EnvChangeValue::String(s) = &self.new_value {
2111 return Some(s);
2112 }
2113 }
2114 None
2115 }
2116}
2117
2118impl Order {
2119 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2121 if src.remaining() < 2 {
2122 return Err(ProtocolError::UnexpectedEof);
2123 }
2124
2125 let length = src.get_u16_le() as usize;
2126 let column_count = length / 2;
2127
2128 if src.remaining() < length {
2129 return Err(ProtocolError::IncompletePacket {
2130 expected: length,
2131 actual: src.remaining(),
2132 });
2133 }
2134
2135 let mut columns = Vec::with_capacity(column_count);
2136 for _ in 0..column_count {
2137 columns.push(src.get_u16_le());
2138 }
2139
2140 Ok(Self { columns })
2141 }
2142}
2143
2144impl FeatureExtAck {
2145 pub const TERMINATOR: u8 = 0xFF;
2147
2148 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2150 let mut features = Vec::new();
2151
2152 loop {
2153 if !src.has_remaining() {
2154 return Err(ProtocolError::UnexpectedEof);
2155 }
2156
2157 let feature_id = src.get_u8();
2158 if feature_id == Self::TERMINATOR {
2159 break;
2160 }
2161
2162 if src.remaining() < 4 {
2163 return Err(ProtocolError::UnexpectedEof);
2164 }
2165
2166 let data_len = src.get_u32_le() as usize;
2167
2168 if src.remaining() < data_len {
2169 return Err(ProtocolError::IncompletePacket {
2170 expected: data_len,
2171 actual: src.remaining(),
2172 });
2173 }
2174
2175 let data = src.copy_to_bytes(data_len);
2176 features.push(FeatureAck { feature_id, data });
2177 }
2178
2179 Ok(Self { features })
2180 }
2181}
2182
2183impl SspiToken {
2184 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2186 if src.remaining() < 2 {
2187 return Err(ProtocolError::UnexpectedEof);
2188 }
2189
2190 let length = src.get_u16_le() as usize;
2191
2192 if src.remaining() < length {
2193 return Err(ProtocolError::IncompletePacket {
2194 expected: length,
2195 actual: src.remaining(),
2196 });
2197 }
2198
2199 let data = src.copy_to_bytes(length);
2200 Ok(Self { data })
2201 }
2202}
2203
2204impl FedAuthInfo {
2205 const ID_STSURL: u8 = 0x01;
2207 const ID_SPN: u8 = 0x02;
2210 const OPT_HEADER_LEN: usize = 9;
2212
2213 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2225 if src.remaining() < 4 {
2226 return Err(ProtocolError::UnexpectedEof);
2227 }
2228 let token_len = src.get_u32_le() as usize;
2229 if src.remaining() < token_len {
2230 return Err(ProtocolError::UnexpectedEof);
2231 }
2232
2233 let region = src.copy_to_bytes(token_len);
2236 if region.len() < 4 {
2237 return Err(ProtocolError::UnexpectedEof);
2238 }
2239 let count = u32::from_le_bytes([region[0], region[1], region[2], region[3]]) as usize;
2240
2241 let headers_end = count
2245 .checked_mul(Self::OPT_HEADER_LEN)
2246 .and_then(|n| n.checked_add(4))
2247 .ok_or(ProtocolError::UnexpectedEof)?;
2248 if headers_end > region.len() {
2249 return Err(ProtocolError::UnexpectedEof);
2250 }
2251
2252 let mut sts_url = String::new();
2253 let mut spn = String::new();
2254
2255 for i in 0..count {
2256 let h = 4 + i * Self::OPT_HEADER_LEN;
2257 let info_id = region[h];
2258 let data_len =
2259 u32::from_le_bytes([region[h + 1], region[h + 2], region[h + 3], region[h + 4]])
2260 as usize;
2261 let data_off =
2262 u32::from_le_bytes([region[h + 5], region[h + 6], region[h + 7], region[h + 8]])
2263 as usize;
2264
2265 if info_id != Self::ID_SPN && info_id != Self::ID_STSURL {
2268 continue;
2269 }
2270
2271 let data_end = data_off
2272 .checked_add(data_len)
2273 .ok_or(ProtocolError::UnexpectedEof)?;
2274 if data_end > region.len() {
2275 return Err(ProtocolError::UnexpectedEof);
2276 }
2277 if data_len % 2 != 0 {
2278 return Err(ProtocolError::StringEncoding(
2279 #[cfg(feature = "std")]
2280 "FEDAUTHINFO option data has odd length, not UTF-16".to_string(),
2281 #[cfg(not(feature = "std"))]
2282 "FEDAUTHINFO option data has odd length, not UTF-16",
2283 ));
2284 }
2285
2286 let chars: Vec<u16> = region[data_off..data_end]
2287 .chunks_exact(2)
2288 .map(|b| u16::from_le_bytes([b[0], b[1]]))
2289 .collect();
2290 let value = String::from_utf16(&chars).map_err(|_| {
2291 ProtocolError::StringEncoding(
2292 #[cfg(feature = "std")]
2293 "invalid UTF-16 in FEDAUTHINFO option".to_string(),
2294 #[cfg(not(feature = "std"))]
2295 "invalid UTF-16 in FEDAUTHINFO option",
2296 )
2297 })?;
2298
2299 if info_id == Self::ID_SPN {
2300 spn = value;
2301 } else {
2302 sts_url = value;
2303 }
2304 }
2305
2306 Ok(Self { sts_url, spn })
2307 }
2308}
2309
/// Sequential reader over a buffer of TDS tokens.
///
/// Tokens are decoded one at a time from `data`, starting at `position`.
pub struct TokenParser {
    /// The complete token stream being read.
    data: Bytes,
    /// Byte offset into `data` of the next unread token.
    position: usize,
    /// When `true`, COLMETADATA tokens are decoded with `ColMetaData::decode_encrypted`
    /// instead of `ColMetaData::decode`.
    encryption_enabled: bool,
}
2357
2358impl TokenParser {
2359 #[must_use]
2361 pub fn new(data: Bytes) -> Self {
2362 Self {
2363 data,
2364 position: 0,
2365 encryption_enabled: false,
2366 }
2367 }
2368
2369 #[must_use]
2374 pub fn with_encryption(mut self, enabled: bool) -> Self {
2375 self.encryption_enabled = enabled;
2376 self
2377 }
2378
2379 #[must_use]
2381 pub fn remaining(&self) -> usize {
2382 self.data.len().saturating_sub(self.position)
2383 }
2384
2385 #[must_use]
2387 pub fn has_remaining(&self) -> bool {
2388 self.position < self.data.len()
2389 }
2390
2391 #[must_use]
2393 pub fn peek_token_type(&self) -> Option<TokenType> {
2394 if self.position < self.data.len() {
2395 TokenType::from_u8(self.data[self.position])
2396 } else {
2397 None
2398 }
2399 }
2400
2401 pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
2409 self.next_token_with_metadata(None)
2410 }
2411
2412 pub fn next_token_with_metadata(
2419 &mut self,
2420 metadata: Option<&ColMetaData>,
2421 ) -> Result<Option<Token>, ProtocolError> {
2422 loop {
2423 if !self.has_remaining() {
2424 return Ok(None);
2425 }
2426
2427 let mut buf = &self.data[self.position..];
2428 let start_pos = self.position;
2429
2430 let token_type_byte = buf.get_u8();
2431 let token_type = TokenType::from_u8(token_type_byte);
2432
2433 let token = match token_type {
2434 Some(TokenType::Done) => {
2435 let done = Done::decode(&mut buf)?;
2436 Token::Done(done)
2437 }
2438 Some(TokenType::DoneProc) => {
2439 let done = DoneProc::decode(&mut buf)?;
2440 Token::DoneProc(done)
2441 }
2442 Some(TokenType::DoneInProc) => {
2443 let done = DoneInProc::decode(&mut buf)?;
2444 Token::DoneInProc(done)
2445 }
2446 Some(TokenType::Error) => {
2447 let error = ServerError::decode(&mut buf)?;
2448 Token::Error(error)
2449 }
2450 Some(TokenType::Info) => {
2451 let info = ServerInfo::decode(&mut buf)?;
2452 Token::Info(info)
2453 }
2454 Some(TokenType::LoginAck) => {
2455 let login_ack = LoginAck::decode(&mut buf)?;
2456 Token::LoginAck(login_ack)
2457 }
2458 Some(TokenType::EnvChange) => {
2459 let env_change = EnvChange::decode(&mut buf)?;
2460 Token::EnvChange(env_change)
2461 }
2462 Some(TokenType::Order) => {
2463 let order = Order::decode(&mut buf)?;
2464 Token::Order(order)
2465 }
2466 Some(TokenType::FeatureExtAck) => {
2467 let ack = FeatureExtAck::decode(&mut buf)?;
2468 Token::FeatureExtAck(ack)
2469 }
2470 Some(TokenType::Sspi) => {
2471 let sspi = SspiToken::decode(&mut buf)?;
2472 Token::Sspi(sspi)
2473 }
2474 Some(TokenType::FedAuthInfo) => {
2475 let info = FedAuthInfo::decode(&mut buf)?;
2476 Token::FedAuthInfo(info)
2477 }
2478 Some(TokenType::ReturnStatus) => {
2479 if buf.remaining() < 4 {
2480 return Err(ProtocolError::UnexpectedEof);
2481 }
2482 let status = buf.get_i32_le();
2483 Token::ReturnStatus(status)
2484 }
2485 Some(TokenType::ColMetaData) => {
2486 let col_meta = if self.encryption_enabled {
2487 ColMetaData::decode_encrypted(&mut buf)?
2488 } else {
2489 ColMetaData::decode(&mut buf)?
2490 };
2491 Token::ColMetaData(col_meta)
2492 }
2493 Some(TokenType::Row) => {
2494 let meta = metadata.ok_or_else(|| {
2495 ProtocolError::StringEncoding(
2496 #[cfg(feature = "std")]
2497 "Row token requires column metadata".to_string(),
2498 #[cfg(not(feature = "std"))]
2499 "Row token requires column metadata",
2500 )
2501 })?;
2502 let row = RawRow::decode(&mut buf, meta)?;
2503 Token::Row(row)
2504 }
2505 Some(TokenType::NbcRow) => {
2506 let meta = metadata.ok_or_else(|| {
2507 ProtocolError::StringEncoding(
2508 #[cfg(feature = "std")]
2509 "NbcRow token requires column metadata".to_string(),
2510 #[cfg(not(feature = "std"))]
2511 "NbcRow token requires column metadata",
2512 )
2513 })?;
2514 let row = NbcRow::decode(&mut buf, meta)?;
2515 Token::NbcRow(row)
2516 }
2517 Some(TokenType::ReturnValue) => {
2518 let ret_val = ReturnValue::decode(&mut buf)?;
2519 Token::ReturnValue(ret_val)
2520 }
2521 Some(TokenType::SessionState) => {
2522 let session = SessionState::decode(&mut buf)?;
2523 Token::SessionState(session)
2524 }
2525 Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
2526 if buf.remaining() < 2 {
2529 return Err(ProtocolError::UnexpectedEof);
2530 }
2531 let length = buf.get_u16_le() as usize;
2532 if buf.remaining() < length {
2533 return Err(ProtocolError::IncompletePacket {
2534 expected: length,
2535 actual: buf.remaining(),
2536 });
2537 }
2538 buf.advance(length);
2540 self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
2545 continue;
2546 }
2547 None => {
2548 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2549 }
2550 };
2551
2552 let consumed = self.data.len() - start_pos - buf.remaining();
2554 self.position = start_pos + consumed;
2555
2556 return Ok(Some(token));
2557 }
2558 }
2559
2560 pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
2564 if !self.has_remaining() {
2565 return Ok(());
2566 }
2567
2568 let token_type_byte = self.data[self.position];
2569 let token_type = TokenType::from_u8(token_type_byte);
2570
2571 let skip_amount = match token_type {
2573 Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
2575 1 + Done::SIZE }
2577 Some(TokenType::ReturnStatus) => {
2578 1 + 4 }
2580 Some(TokenType::Error)
2582 | Some(TokenType::Info)
2583 | Some(TokenType::LoginAck)
2584 | Some(TokenType::EnvChange)
2585 | Some(TokenType::Order)
2586 | Some(TokenType::Sspi)
2587 | Some(TokenType::ColInfo)
2588 | Some(TokenType::TabName)
2589 | Some(TokenType::Offset)
2590 | Some(TokenType::ReturnValue) => {
2591 if self.remaining() < 3 {
2592 return Err(ProtocolError::UnexpectedEof);
2593 }
2594 let length = u16::from_le_bytes([
2595 self.data[self.position + 1],
2596 self.data[self.position + 2],
2597 ]) as usize;
2598 1 + 2 + length }
2600 Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
2602 if self.remaining() < 5 {
2603 return Err(ProtocolError::UnexpectedEof);
2604 }
2605 let length = u32::from_le_bytes([
2606 self.data[self.position + 1],
2607 self.data[self.position + 2],
2608 self.data[self.position + 3],
2609 self.data[self.position + 4],
2610 ]) as usize;
2611 1 + 4 + length
2612 }
2613 Some(TokenType::FeatureExtAck) => {
2615 let mut buf = &self.data[self.position + 1..];
2617 let _ = FeatureExtAck::decode(&mut buf)?;
2618 self.data.len() - self.position - buf.remaining()
2619 }
2620 Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
2622 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2623 }
2624 None => {
2625 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2626 }
2627 };
2628
2629 if self.remaining() < skip_amount {
2630 return Err(ProtocolError::UnexpectedEof);
2631 }
2632
2633 self.position += skip_amount;
2634 Ok(())
2635 }
2636
2637 #[must_use]
2639 pub fn position(&self) -> usize {
2640 self.position
2641 }
2642
2643 pub fn reset(&mut self) {
2645 self.position = 0;
2646 }
2647}
2648
2649#[cfg(test)]
2654#[allow(clippy::unwrap_used, clippy::panic)]
2655mod tests {
2656 use super::*;
2657 use bytes::BytesMut;
2658
2659 #[test]
2660 fn test_done_roundtrip() {
2661 let done = Done {
2662 status: DoneStatus {
2663 more: false,
2664 error: false,
2665 in_xact: false,
2666 count: true,
2667 attn: false,
2668 srverror: false,
2669 },
2670 cur_cmd: 193, row_count: 42,
2672 };
2673
2674 let mut buf = BytesMut::new();
2675 done.encode(&mut buf);
2676
2677 let mut cursor = &buf[1..];
2679 let decoded = Done::decode(&mut cursor).unwrap();
2680
2681 assert_eq!(decoded.status.count, done.status.count);
2682 assert_eq!(decoded.cur_cmd, done.cur_cmd);
2683 assert_eq!(decoded.row_count, done.row_count);
2684 }
2685
2686 #[test]
2687 fn test_done_status_bits() {
2688 let status = DoneStatus {
2689 more: true,
2690 error: true,
2691 in_xact: true,
2692 count: true,
2693 attn: false,
2694 srverror: false,
2695 };
2696
2697 let bits = status.to_bits();
2698 let restored = DoneStatus::from_bits(bits);
2699
2700 assert_eq!(status.more, restored.more);
2701 assert_eq!(status.error, restored.error);
2702 assert_eq!(status.in_xact, restored.in_xact);
2703 assert_eq!(status.count, restored.count);
2704 }
2705
2706 #[test]
2707 fn test_token_parser_done() {
2708 let data = Bytes::from_static(&[
2710 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
2715
2716 let mut parser = TokenParser::new(data);
2717 let token = parser.next_token().unwrap().unwrap();
2718
2719 match token {
2720 Token::Done(done) => {
2721 assert!(done.status.count);
2722 assert!(!done.status.more);
2723 assert_eq!(done.cur_cmd, 193);
2724 assert_eq!(done.row_count, 5);
2725 }
2726 _ => panic!("Expected Done token"),
2727 }
2728
2729 assert!(parser.next_token().unwrap().is_none());
2731 }
2732
2733 #[test]
2734 fn test_env_change_type_from_u8() {
2735 assert_eq!(EnvChangeType::from_u8(1), Some(EnvChangeType::Database));
2736 assert_eq!(EnvChangeType::from_u8(20), Some(EnvChangeType::Routing));
2737 assert_eq!(EnvChangeType::from_u8(100), None);
2738 }
2739
2740 #[test]
2747 fn test_env_change_routing_consumes_declared_length() {
2748 let host = "redirect.example";
2749 let host_utf16: Vec<u16> = host.encode_utf16().collect();
2750
2751 let mut data = BytesMut::new();
2752 let routing_len = 1 + 2 + 2 + host_utf16.len() * 2;
2754 let env_len = 1 + 2 + routing_len + 2;
2757 data.put_u16_le(env_len as u16);
2758 data.put_u8(20); data.put_u16_le(routing_len as u16);
2760 data.put_u8(0); data.put_u16_le(11000); data.put_u16_le(host_utf16.len() as u16);
2763 for c in &host_utf16 {
2764 data.put_u16_le(*c);
2765 }
2766 data.put_u16_le(0); data.put_u8(0xFD);
2769
2770 let mut buf: &[u8] = &data;
2771 let env = EnvChange::decode(&mut buf).unwrap();
2772 assert_eq!(env.routing_info(), Some((host, 11000)));
2773 assert_eq!(
2774 buf,
2775 &[0xFD],
2776 "decode must consume exactly the declared ENVCHANGE frame"
2777 );
2778 }
2779
2780 fn put_b_varchar(buf: &mut BytesMut, s: &str) {
2781 let utf16: Vec<u16> = s.encode_utf16().collect();
2782 buf.put_u8(utf16.len() as u8);
2783 for c in utf16 {
2784 buf.put_u16_le(c);
2785 }
2786 }
2787
2788 fn put_us_varchar(buf: &mut BytesMut, s: &str) {
2789 let utf16: Vec<u16> = s.encode_utf16().collect();
2790 buf.put_u16_le(utf16.len() as u16);
2791 for c in utf16 {
2792 buf.put_u16_le(c);
2793 }
2794 }
2795
2796 #[test]
2803 fn test_udt_info_metadata_uses_b_varchar_names() {
2804 let mut data = BytesMut::new();
2805 data.put_u16_le(0xFFFF); put_b_varchar(&mut data, "master");
2807 put_b_varchar(&mut data, "dbo");
2808 put_b_varchar(&mut data, "hierarchyid");
2809 put_us_varchar(
2810 &mut data,
2811 "Microsoft.SqlServer.Types.SqlHierarchyId, Microsoft.SqlServer.Types",
2812 );
2813 data.put_u8(0xFD);
2815
2816 let mut buf: &[u8] = &data;
2817 let info = decode_type_info(&mut buf, TypeId::Udt, TypeId::Udt as u8).unwrap();
2818 assert_eq!(info.max_length, Some(0xFFFF));
2819 assert_eq!(
2820 buf,
2821 &[0xFD],
2822 "decode must consume exactly the UDT_INFO frame"
2823 );
2824 }
2825
2826 #[test]
2830 fn test_xml_info_schema_bound_uses_b_varchar_names() {
2831 let mut data = BytesMut::new();
2832 data.put_u8(1); put_b_varchar(&mut data, "master");
2834 put_b_varchar(&mut data, "dbo");
2835 put_us_varchar(&mut data, "MyXmlSchemaCollection");
2836 data.put_u8(0xFD);
2837
2838 let mut buf: &[u8] = &data;
2839 decode_type_info(&mut buf, TypeId::Xml, TypeId::Xml as u8).unwrap();
2840 assert_eq!(
2841 buf,
2842 &[0xFD],
2843 "decode must consume exactly the XML_INFO frame"
2844 );
2845 }
2846
2847 #[test]
2848 fn hostile_env_change_binary_truncated_is_not_panic() {
2849 let data = [0x01, 0x00, 0x08];
2854 let mut buf: &[u8] = &data;
2855 let env = EnvChange::decode(&mut buf).unwrap();
2856 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2857 }
2858
2859 #[test]
2862 fn hostile_env_change_under_declared_cannot_steal_following_bytes() {
2863 let mut data = BytesMut::new();
2868 data.put_u16_le(1); data.put_u8(0x08); let following: &[u8] = &[0x08, 1, 2, 3, 4, 5, 6, 7, 8, 0x00];
2871 data.extend_from_slice(following);
2872
2873 let mut buf: &[u8] = &data;
2874 let env = EnvChange::decode(&mut buf).unwrap();
2875 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2876 match &env.new_value {
2877 EnvChangeValue::Binary(b) => {
2878 assert!(
2879 b.is_empty(),
2880 "under-declared frame yields the lenient empty value"
2881 );
2882 }
2883 other => panic!("expected empty Binary value, got {other:?}"),
2884 }
2885 assert_eq!(
2886 buf, following,
2887 "bytes beyond the declared frame belong to the next token"
2888 );
2889 }
2890
2891 #[test]
2894 fn hostile_env_change_zero_length_frame_errors() {
2895 let data = [0x00, 0x00, 0xFD];
2896 let mut buf: &[u8] = &data;
2897 assert!(EnvChange::decode(&mut buf).is_err());
2898 }
2899
2900 #[test]
2901 fn test_colmetadata_no_columns() {
2902 let data = Bytes::from_static(&[0xFF, 0xFF]);
2904 let mut cursor: &[u8] = &data;
2905 let meta = ColMetaData::decode(&mut cursor).unwrap();
2906 assert!(meta.is_empty());
2907 assert_eq!(meta.column_count(), 0);
2908 }
2909
2910 #[test]
2911 fn test_colmetadata_single_int_column() {
2912 let mut data = BytesMut::new();
2915 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut cursor: &[u8] = &data;
2924 let meta = ColMetaData::decode(&mut cursor).unwrap();
2925
2926 assert_eq!(meta.column_count(), 1);
2927 assert_eq!(meta.columns[0].name, "id");
2928 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
2929 assert!(meta.columns[0].is_nullable());
2930 }
2931
2932 #[test]
2933 fn test_colmetadata_nvarchar_column() {
2934 let mut data = BytesMut::new();
2936 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0xE7]); data.extend_from_slice(&[0x64, 0x00]); data.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); data.extend_from_slice(&[0x04]); data.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
2946
2947 let mut cursor: &[u8] = &data;
2948 let meta = ColMetaData::decode(&mut cursor).unwrap();
2949
2950 assert_eq!(meta.column_count(), 1);
2951 assert_eq!(meta.columns[0].name, "name");
2952 assert_eq!(meta.columns[0].type_id, TypeId::NVarChar);
2953 assert_eq!(meta.columns[0].type_info.max_length, Some(100));
2954 assert!(meta.columns[0].type_info.collation.is_some());
2955 }
2956
2957 #[test]
2958 fn test_raw_row_decode_int() {
2959 let metadata = ColMetaData {
2961 cek_table: None,
2962 columns: vec![ColumnData {
2963 name: "id".to_string(),
2964 type_id: TypeId::Int4,
2965 col_type: 0x38,
2966 flags: 0,
2967 user_type: 0,
2968 type_info: TypeInfo::default(),
2969 crypto_metadata: None,
2970 }],
2971 };
2972
2973 let data = Bytes::from_static(&[0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
2976 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2977
2978 assert_eq!(row.data.len(), 4);
2980 assert_eq!(&row.data[..], &[0x2A, 0x00, 0x00, 0x00]);
2981 }
2982
2983 #[test]
2984 fn test_raw_row_decode_nullable_int() {
2985 let metadata = ColMetaData {
2987 cek_table: None,
2988 columns: vec![ColumnData {
2989 name: "id".to_string(),
2990 type_id: TypeId::IntN,
2991 col_type: 0x26,
2992 flags: 0x01, user_type: 0,
2994 type_info: TypeInfo {
2995 max_length: Some(4),
2996 ..Default::default()
2997 },
2998 crypto_metadata: None,
2999 }],
3000 };
3001
3002 let data = Bytes::from_static(&[0x04, 0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
3005 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
3006
3007 assert_eq!(row.data.len(), 5);
3008 assert_eq!(row.data[0], 4); assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
3010 }
3011
3012 #[test]
3013 fn test_raw_row_decode_null_value() {
3014 let metadata = ColMetaData {
3016 cek_table: None,
3017 columns: vec![ColumnData {
3018 name: "id".to_string(),
3019 type_id: TypeId::IntN,
3020 col_type: 0x26,
3021 flags: 0x01, user_type: 0,
3023 type_info: TypeInfo {
3024 max_length: Some(4),
3025 ..Default::default()
3026 },
3027 crypto_metadata: None,
3028 }],
3029 };
3030
3031 let data = Bytes::from_static(&[0xFF]);
3033 let mut cursor: &[u8] = &data;
3034 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
3035
3036 assert_eq!(row.data.len(), 1);
3037 assert_eq!(row.data[0], 0xFF); }
3039
3040 #[test]
3041 fn test_nbcrow_null_bitmap() {
3042 let row = NbcRow {
3043 null_bitmap: vec![0b00000101], data: Bytes::new(),
3045 };
3046
3047 assert!(row.is_null(0));
3048 assert!(!row.is_null(1));
3049 assert!(row.is_null(2));
3050 assert!(!row.is_null(3));
3051 }
3052
3053 #[test]
3054 fn test_token_parser_colmetadata() {
3055 let mut data = BytesMut::new();
3057 data.extend_from_slice(&[0x81]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut parser = TokenParser::new(data.freeze());
3066 let token = parser.next_token().unwrap().unwrap();
3067
3068 match token {
3069 Token::ColMetaData(meta) => {
3070 assert_eq!(meta.column_count(), 1);
3071 assert_eq!(meta.columns[0].name, "id");
3072 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
3073 }
3074 _ => panic!("Expected ColMetaData token"),
3075 }
3076 }
3077
3078 #[test]
3079 fn test_token_parser_row_with_metadata() {
3080 let metadata = ColMetaData {
3082 cek_table: None,
3083 columns: vec![ColumnData {
3084 name: "id".to_string(),
3085 type_id: TypeId::Int4,
3086 col_type: 0x38,
3087 flags: 0,
3088 user_type: 0,
3089 type_info: TypeInfo::default(),
3090 crypto_metadata: None,
3091 }],
3092 };
3093
3094 let mut data = BytesMut::new();
3096 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3100 let token = parser
3101 .next_token_with_metadata(Some(&metadata))
3102 .unwrap()
3103 .unwrap();
3104
3105 match token {
3106 Token::Row(row) => {
3107 assert_eq!(row.data.len(), 4);
3108 }
3109 _ => panic!("Expected Row token"),
3110 }
3111 }
3112
3113 #[test]
3114 fn test_token_parser_row_without_metadata_fails() {
3115 let mut data = BytesMut::new();
3117 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3121 let result = parser.next_token(); assert!(result.is_err());
3124 }
3125
3126 #[test]
3127 fn test_token_parser_peek() {
3128 let data = Bytes::from_static(&[
3129 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3134
3135 let parser = TokenParser::new(data);
3136 assert_eq!(parser.peek_token_type(), Some(TokenType::Done));
3137 }
3138
3139 #[test]
3140 fn test_column_data_fixed_size() {
3141 let col = ColumnData {
3142 name: String::new(),
3143 type_id: TypeId::Int4,
3144 col_type: 0x38,
3145 flags: 0,
3146 user_type: 0,
3147 type_info: TypeInfo::default(),
3148 crypto_metadata: None,
3149 };
3150 assert_eq!(col.fixed_size(), Some(4));
3151
3152 let col2 = ColumnData {
3153 name: String::new(),
3154 type_id: TypeId::NVarChar,
3155 col_type: 0xE7,
3156 flags: 0,
3157 user_type: 0,
3158 type_info: TypeInfo::default(),
3159 crypto_metadata: None,
3160 };
3161 assert_eq!(col2.fixed_size(), None);
3162 }
3163
3164 #[test]
3172 fn test_decode_nvarchar_then_intn_roundtrip() {
3173 let mut wire_data = BytesMut::new();
3178
3179 let word = "World";
3182 let utf16: Vec<u16> = word.encode_utf16().collect();
3183 wire_data.put_u16_le((utf16.len() * 2) as u16); for code_unit in &utf16 {
3185 wire_data.put_u16_le(*code_unit);
3186 }
3187
3188 wire_data.put_u8(4); wire_data.put_i32_le(42);
3191
3192 let metadata = ColMetaData {
3194 cek_table: None,
3195 columns: vec![
3196 ColumnData {
3197 name: "greeting".to_string(),
3198 type_id: TypeId::NVarChar,
3199 col_type: 0xE7,
3200 flags: 0x01,
3201 user_type: 0,
3202 type_info: TypeInfo {
3203 max_length: Some(10), precision: None,
3205 scale: None,
3206 collation: None,
3207 },
3208 crypto_metadata: None,
3209 },
3210 ColumnData {
3211 name: "number".to_string(),
3212 type_id: TypeId::IntN,
3213 col_type: 0x26,
3214 flags: 0x01,
3215 user_type: 0,
3216 type_info: TypeInfo {
3217 max_length: Some(4),
3218 precision: None,
3219 scale: None,
3220 collation: None,
3221 },
3222 crypto_metadata: None,
3223 },
3224 ],
3225 };
3226
3227 let mut wire_cursor = wire_data.freeze();
3229 let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3230
3231 assert_eq!(
3233 wire_cursor.remaining(),
3234 0,
3235 "wire data should be fully consumed"
3236 );
3237
3238 let mut stored_cursor: &[u8] = &raw_row.data;
3240
3241 assert!(
3244 stored_cursor.remaining() >= 2,
3245 "need at least 2 bytes for length"
3246 );
3247 let len0 = stored_cursor.get_u16_le() as usize;
3248 assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
3249 assert!(
3250 stored_cursor.remaining() >= len0,
3251 "need {len0} bytes for data"
3252 );
3253
3254 let mut utf16_read = Vec::new();
3256 for _ in 0..(len0 / 2) {
3257 utf16_read.push(stored_cursor.get_u16_le());
3258 }
3259 let string0 = String::from_utf16(&utf16_read).unwrap();
3260 assert_eq!(string0, "World", "column 0 should be 'World'");
3261
3262 assert!(
3265 stored_cursor.remaining() >= 1,
3266 "need at least 1 byte for length"
3267 );
3268 let len1 = stored_cursor.get_u8();
3269 assert_eq!(len1, 4, "IntN length should be 4");
3270 assert!(stored_cursor.remaining() >= 4, "need 4 bytes for INT data");
3271 let int1 = stored_cursor.get_i32_le();
3272 assert_eq!(int1, 42, "column 1 should be 42");
3273
3274 assert_eq!(
3276 stored_cursor.remaining(),
3277 0,
3278 "stored data should be fully consumed"
3279 );
3280 }
3281
3282 #[test]
3283 fn test_decode_nvarchar_max_then_intn_roundtrip() {
3284 let mut wire_data = BytesMut::new();
3288
3289 let word = "Hello";
3292 let utf16: Vec<u16> = word.encode_utf16().collect();
3293 let byte_len = (utf16.len() * 2) as u64;
3294
3295 wire_data.put_u64_le(byte_len); wire_data.put_u32_le(byte_len as u32); for code_unit in &utf16 {
3298 wire_data.put_u16_le(*code_unit);
3299 }
3300 wire_data.put_u32_le(0); wire_data.put_u8(4);
3304 wire_data.put_i32_le(99);
3305
3306 let metadata = ColMetaData {
3308 cek_table: None,
3309 columns: vec![
3310 ColumnData {
3311 name: "text".to_string(),
3312 type_id: TypeId::NVarChar,
3313 col_type: 0xE7,
3314 flags: 0x01,
3315 user_type: 0,
3316 type_info: TypeInfo {
3317 max_length: Some(0xFFFF), precision: None,
3319 scale: None,
3320 collation: None,
3321 },
3322 crypto_metadata: None,
3323 },
3324 ColumnData {
3325 name: "num".to_string(),
3326 type_id: TypeId::IntN,
3327 col_type: 0x26,
3328 flags: 0x01,
3329 user_type: 0,
3330 type_info: TypeInfo {
3331 max_length: Some(4),
3332 precision: None,
3333 scale: None,
3334 collation: None,
3335 },
3336 crypto_metadata: None,
3337 },
3338 ],
3339 };
3340
3341 let mut wire_cursor = wire_data.freeze();
3343 let raw_row = RawRow::decode(&mut wire_cursor, &metadata).unwrap();
3344
3345 assert_eq!(
3347 wire_cursor.remaining(),
3348 0,
3349 "wire data should be fully consumed"
3350 );
3351
3352 let mut stored_cursor: &[u8] = &raw_row.data;
3354
3355 let total_len = stored_cursor.get_u64_le();
3357 assert_eq!(total_len, 10, "PLP total length should be 10");
3358
3359 let chunk_len = stored_cursor.get_u32_le();
3360 assert_eq!(chunk_len, 10, "PLP chunk length should be 10");
3361
3362 let mut utf16_read = Vec::new();
3363 for _ in 0..(chunk_len / 2) {
3364 utf16_read.push(stored_cursor.get_u16_le());
3365 }
3366 let string0 = String::from_utf16(&utf16_read).unwrap();
3367 assert_eq!(string0, "Hello", "column 0 should be 'Hello'");
3368
3369 let terminator = stored_cursor.get_u32_le();
3370 assert_eq!(terminator, 0, "PLP should end with 0");
3371
3372 let len1 = stored_cursor.get_u8();
3374 assert_eq!(len1, 4);
3375 let int1 = stored_cursor.get_i32_le();
3376 assert_eq!(int1, 99, "column 1 should be 99");
3377
3378 assert_eq!(
3380 stored_cursor.remaining(),
3381 0,
3382 "stored data should be fully consumed"
3383 );
3384 }
3385
3386 #[test]
3391 fn test_return_status_via_parser() {
3392 let data = Bytes::from_static(&[
3394 0x79, 0x00, 0x00, 0x00, 0x00, ]);
3397
3398 let mut parser = TokenParser::new(data);
3399 let token = parser.next_token().unwrap().unwrap();
3400
3401 match token {
3402 Token::ReturnStatus(status) => {
3403 assert_eq!(status, 0);
3404 }
3405 _ => panic!("Expected ReturnStatus token, got {token:?}"),
3406 }
3407
3408 assert!(parser.next_token().unwrap().is_none());
3409 }
3410
3411 #[test]
3412 fn test_return_status_nonzero() {
3413 let mut buf = BytesMut::new();
3415 buf.put_u8(0x79); buf.put_i32_le(-6);
3417
3418 let mut parser = TokenParser::new(buf.freeze());
3419 let token = parser.next_token().unwrap().unwrap();
3420
3421 match token {
3422 Token::ReturnStatus(status) => {
3423 assert_eq!(status, -6);
3424 }
3425 _ => panic!("Expected ReturnStatus token"),
3426 }
3427 }
3428
3429 #[test]
3434 fn test_done_proc_roundtrip() {
3435 let done = DoneProc {
3436 status: DoneStatus {
3437 more: false,
3438 error: false,
3439 in_xact: false,
3440 count: true,
3441 attn: false,
3442 srverror: false,
3443 },
3444 cur_cmd: 0x00C6, row_count: 100,
3446 };
3447
3448 let mut buf = BytesMut::new();
3449 done.encode(&mut buf);
3450
3451 assert_eq!(buf[0], 0xFE);
3453
3454 let mut cursor = &buf[1..];
3456 let decoded = DoneProc::decode(&mut cursor).unwrap();
3457
3458 assert!(decoded.status.count);
3459 assert!(!decoded.status.more);
3460 assert!(!decoded.status.error);
3461 assert_eq!(decoded.cur_cmd, 0x00C6);
3462 assert_eq!(decoded.row_count, 100);
3463 }
3464
3465 #[test]
3466 fn test_done_proc_via_parser() {
3467 let data = Bytes::from_static(&[
3468 0xFE, 0x00, 0x00, 0xC6, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3473
3474 let mut parser = TokenParser::new(data);
3475 let token = parser.next_token().unwrap().unwrap();
3476
3477 match token {
3478 Token::DoneProc(done) => {
3479 assert!(!done.status.count);
3480 assert!(!done.status.more);
3481 assert_eq!(done.cur_cmd, 198);
3482 assert_eq!(done.row_count, 0);
3483 }
3484 _ => panic!("Expected DoneProc token"),
3485 }
3486 }
3487
3488 #[test]
3489 fn test_done_proc_with_error_flag() {
3490 let mut buf = BytesMut::new();
3491 buf.put_u8(0xFE); buf.put_u16_le(0x0002); buf.put_u16_le(0x00C6); buf.put_u64_le(0); let mut parser = TokenParser::new(buf.freeze());
3497 let token = parser.next_token().unwrap().unwrap();
3498
3499 match token {
3500 Token::DoneProc(done) => {
3501 assert!(done.status.error);
3502 assert!(!done.status.count);
3503 assert!(!done.status.more);
3504 }
3505 _ => panic!("Expected DoneProc token"),
3506 }
3507 }
3508
3509 #[test]
3514 fn test_done_in_proc_roundtrip() {
3515 let done = DoneInProc {
3516 status: DoneStatus {
3517 more: true,
3518 error: false,
3519 in_xact: false,
3520 count: true,
3521 attn: false,
3522 srverror: false,
3523 },
3524 cur_cmd: 193, row_count: 7,
3526 };
3527
3528 let mut buf = BytesMut::new();
3529 done.encode(&mut buf);
3530
3531 assert_eq!(buf[0], 0xFF);
3532
3533 let mut cursor = &buf[1..];
3534 let decoded = DoneInProc::decode(&mut cursor).unwrap();
3535
3536 assert!(decoded.status.more);
3537 assert!(decoded.status.count);
3538 assert!(!decoded.status.error);
3539 assert_eq!(decoded.cur_cmd, 193);
3540 assert_eq!(decoded.row_count, 7);
3541 }
3542
3543 #[test]
3544 fn test_done_in_proc_via_parser() {
3545 let data = Bytes::from_static(&[
3546 0xFF, 0x11, 0x00, 0xC1, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3551
3552 let mut parser = TokenParser::new(data);
3553 let token = parser.next_token().unwrap().unwrap();
3554
3555 match token {
3556 Token::DoneInProc(done) => {
3557 assert!(done.status.more);
3558 assert!(done.status.count);
3559 assert_eq!(done.cur_cmd, 193);
3560 assert_eq!(done.row_count, 3);
3561 }
3562 _ => panic!("Expected DoneInProc token"),
3563 }
3564 }
3565
3566 #[test]
3571 fn test_server_error_decode() {
3572 let mut buf = BytesMut::new();
3575
3576 let msg_utf16: Vec<u16> = "Invalid column name 'foo'.".encode_utf16().collect();
3578 let srv_utf16: Vec<u16> = "SQLDB01".encode_utf16().collect();
3579 let proc_utf16: Vec<u16> = "".encode_utf16().collect();
3580
3581 let length: u16 = (4
3587 + 1
3588 + 1
3589 + 2
3590 + (msg_utf16.len() * 2)
3591 + 1
3592 + (srv_utf16.len() * 2)
3593 + 1
3594 + (proc_utf16.len() * 2)
3595 + 4) as u16;
3596
3597 buf.put_u16_le(length);
3598 buf.put_i32_le(207); buf.put_u8(1); buf.put_u8(16); buf.put_u16_le(msg_utf16.len() as u16);
3604 for &c in &msg_utf16 {
3605 buf.put_u16_le(c);
3606 }
3607
3608 buf.put_u8(srv_utf16.len() as u8);
3610 for &c in &srv_utf16 {
3611 buf.put_u16_le(c);
3612 }
3613
3614 buf.put_u8(proc_utf16.len() as u8);
3616
3617 buf.put_i32_le(42);
3619
3620 let mut cursor = buf.freeze();
3621 let error = ServerError::decode(&mut cursor).unwrap();
3622
3623 assert_eq!(error.number, 207);
3624 assert_eq!(error.state, 1);
3625 assert_eq!(error.class, 16);
3626 assert_eq!(error.message, "Invalid column name 'foo'.");
3627 assert_eq!(error.server, "SQLDB01");
3628 assert_eq!(error.procedure, "");
3629 assert_eq!(error.line, 42);
3630 }
3631
3632 #[test]
3633 fn test_server_error_severity_helpers() {
3634 let fatal = ServerError {
3635 number: 4014,
3636 state: 1,
3637 class: 20,
3638 message: "Fatal error".to_string(),
3639 server: String::new(),
3640 procedure: String::new(),
3641 line: 0,
3642 };
3643 assert!(fatal.is_fatal());
3644 assert!(fatal.is_batch_abort());
3645
3646 let batch_abort = ServerError {
3647 number: 547,
3648 state: 0,
3649 class: 16,
3650 message: "Constraint violation".to_string(),
3651 server: String::new(),
3652 procedure: String::new(),
3653 line: 1,
3654 };
3655 assert!(!batch_abort.is_fatal());
3656 assert!(batch_abort.is_batch_abort());
3657
3658 let informational = ServerError {
3659 number: 5701,
3660 state: 2,
3661 class: 10,
3662 message: "Changed db context".to_string(),
3663 server: String::new(),
3664 procedure: String::new(),
3665 line: 0,
3666 };
3667 assert!(!informational.is_fatal());
3668 assert!(!informational.is_batch_abort());
3669 }
3670
3671 #[test]
3672 fn test_server_error_via_parser() {
3673 let mut buf = BytesMut::new();
3675 buf.put_u8(0xAA); let msg_utf16: Vec<u16> = "Syntax error".encode_utf16().collect();
3678 let srv_utf16: Vec<u16> = "SRV".encode_utf16().collect();
3679 let proc_utf16: Vec<u16> = "sp_test".encode_utf16().collect();
3680
3681 let length: u16 = (4
3682 + 1
3683 + 1
3684 + 2
3685 + (msg_utf16.len() * 2)
3686 + 1
3687 + (srv_utf16.len() * 2)
3688 + 1
3689 + (proc_utf16.len() * 2)
3690 + 4) as u16;
3691
3692 buf.put_u16_le(length);
3693 buf.put_i32_le(102); buf.put_u8(1);
3695 buf.put_u8(15);
3696
3697 buf.put_u16_le(msg_utf16.len() as u16);
3698 for &c in &msg_utf16 {
3699 buf.put_u16_le(c);
3700 }
3701 buf.put_u8(srv_utf16.len() as u8);
3702 for &c in &srv_utf16 {
3703 buf.put_u16_le(c);
3704 }
3705 buf.put_u8(proc_utf16.len() as u8);
3706 for &c in &proc_utf16 {
3707 buf.put_u16_le(c);
3708 }
3709 buf.put_i32_le(5);
3710
3711 let mut parser = TokenParser::new(buf.freeze());
3712 let token = parser.next_token().unwrap().unwrap();
3713
3714 match token {
3715 Token::Error(err) => {
3716 assert_eq!(err.number, 102);
3717 assert_eq!(err.class, 15);
3718 assert_eq!(err.message, "Syntax error");
3719 assert_eq!(err.server, "SRV");
3720 assert_eq!(err.procedure, "sp_test");
3721 assert_eq!(err.line, 5);
3722 }
3723 _ => panic!("Expected Error token"),
3724 }
3725 }
3726
3727 fn build_return_value_intn(
3734 ordinal: u16,
3735 name: &str,
3736 status: u8,
3737 value: Option<i32>,
3738 ) -> BytesMut {
3739 let mut inner = BytesMut::new();
3740
3741 inner.put_u16_le(ordinal);
3743
3744 let name_utf16: Vec<u16> = name.encode_utf16().collect();
3746 inner.put_u8(name_utf16.len() as u8);
3747 for &c in &name_utf16 {
3748 inner.put_u16_le(c);
3749 }
3750
3751 inner.put_u8(status);
3753
3754 inner.put_u32_le(0);
3756
3757 inner.put_u16_le(0x0001); inner.put_u8(0x26);
3762
3763 inner.put_u8(4);
3765
3766 match value {
3768 Some(v) => {
3769 inner.put_u8(4); inner.put_i32_le(v);
3771 }
3772 None => {
3773 inner.put_u8(0); }
3775 }
3776
3777 inner
3780 }
3781
3782 #[test]
3783 fn test_return_value_int_output() {
3784 let buf = build_return_value_intn(1, "@result", 0x01, Some(42));
3785 let mut cursor = buf.freeze();
3786 let rv = ReturnValue::decode(&mut cursor).unwrap();
3787
3788 assert_eq!(rv.param_ordinal, 1);
3789 assert_eq!(rv.param_name, "@result");
3790 assert_eq!(rv.status, 0x01); assert_eq!(rv.col_type, 0x26); assert_eq!(rv.type_info.max_length, Some(4));
3793 assert_eq!(rv.value.len(), 5);
3795 assert_eq!(rv.value[0], 4);
3796 assert_eq!(
3797 i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3798 42
3799 );
3800 }
3801
3802 #[test]
3803 fn test_return_value_null_output() {
3804 let buf = build_return_value_intn(2, "@count", 0x01, None);
3805 let mut cursor = buf.freeze();
3806 let rv = ReturnValue::decode(&mut cursor).unwrap();
3807
3808 assert_eq!(rv.param_ordinal, 2);
3809 assert_eq!(rv.param_name, "@count");
3810 assert_eq!(rv.status, 0x01);
3811 assert_eq!(rv.col_type, 0x26);
3812 assert_eq!(rv.value.len(), 1);
3814 assert_eq!(rv.value[0], 0);
3815 }
3816
3817 #[test]
3818 fn test_return_value_udf_status() {
3819 let buf = build_return_value_intn(0, "@RETURN_VALUE", 0x02, Some(-1));
3821 let mut cursor = buf.freeze();
3822 let rv = ReturnValue::decode(&mut cursor).unwrap();
3823
3824 assert_eq!(rv.param_ordinal, 0);
3825 assert_eq!(rv.param_name, "@RETURN_VALUE");
3826 assert_eq!(rv.status, 0x02); assert_eq!(rv.value[0], 4);
3828 assert_eq!(
3829 i32::from_le_bytes([rv.value[1], rv.value[2], rv.value[3], rv.value[4]]),
3830 -1
3831 );
3832 }
3833
3834 #[test]
3835 fn test_return_value_nvarchar_output() {
3836 let mut inner = BytesMut::new();
3838
3839 inner.put_u16_le(1);
3841
3842 let name_utf16: Vec<u16> = "@name".encode_utf16().collect();
3844 inner.put_u8(name_utf16.len() as u8);
3845 for &c in &name_utf16 {
3846 inner.put_u16_le(c);
3847 }
3848
3849 inner.put_u8(0x01);
3851 inner.put_u32_le(0);
3853 inner.put_u16_le(0x0001);
3855 inner.put_u8(0xE7);
3857 inner.put_u16_le(200); inner.put_u32_le(0x0904D000); inner.put_u8(0x34); let val_utf16: Vec<u16> = "Hello".encode_utf16().collect();
3864 let byte_len = (val_utf16.len() * 2) as u16;
3865 inner.put_u16_le(byte_len);
3866 for &c in &val_utf16 {
3867 inner.put_u16_le(c);
3868 }
3869
3870 let mut cursor = inner.freeze();
3871 let rv = ReturnValue::decode(&mut cursor).unwrap();
3872
3873 assert_eq!(rv.param_ordinal, 1);
3874 assert_eq!(rv.param_name, "@name");
3875 assert_eq!(rv.status, 0x01);
3876 assert_eq!(rv.col_type, 0xE7); assert_eq!(rv.type_info.max_length, Some(200));
3878 assert!(rv.type_info.collation.is_some());
3879
3880 assert_eq!(rv.value.len(), 12); let val_len = u16::from_le_bytes([rv.value[0], rv.value[1]]);
3883 assert_eq!(val_len, 10);
3884 }
3885
3886 #[test]
3887 fn test_return_value_via_parser() {
3888 let mut data = BytesMut::new();
3890 data.put_u8(0xAC); data.extend_from_slice(&build_return_value_intn(0, "@out", 0x01, Some(99)));
3892
3893 let mut parser = TokenParser::new(data.freeze());
3894 let token = parser.next_token().unwrap().unwrap();
3895
3896 match token {
3897 Token::ReturnValue(rv) => {
3898 assert_eq!(rv.param_name, "@out");
3899 assert_eq!(rv.param_ordinal, 0);
3900 assert_eq!(rv.status, 0x01);
3901 assert_eq!(rv.col_type, 0x26);
3902 }
3903 _ => panic!("Expected ReturnValue token"),
3904 }
3905 }
3906
3907 #[test]
3912 fn test_multi_token_stored_proc_response() {
3913 let mut data = BytesMut::new();
3916
3917 data.put_u8(0xFF); data.put_u16_le(0x0010); data.put_u16_le(0x00C1); data.put_u64_le(3); data.put_u8(0x79); data.put_i32_le(0);
3926
3927 data.put_u8(0xFE); data.put_u16_le(0x0000); data.put_u16_le(0x00C6); data.put_u64_le(0);
3932
3933 let mut parser = TokenParser::new(data.freeze());
3934
3935 let t1 = parser.next_token().unwrap().unwrap();
3937 match t1 {
3938 Token::DoneInProc(done) => {
3939 assert!(done.status.count);
3940 assert_eq!(done.row_count, 3);
3941 assert_eq!(done.cur_cmd, 193);
3942 }
3943 _ => panic!("Expected DoneInProc, got {t1:?}"),
3944 }
3945
3946 let t2 = parser.next_token().unwrap().unwrap();
3948 match t2 {
3949 Token::ReturnStatus(status) => {
3950 assert_eq!(status, 0);
3951 }
3952 _ => panic!("Expected ReturnStatus, got {t2:?}"),
3953 }
3954
3955 let t3 = parser.next_token().unwrap().unwrap();
3957 match t3 {
3958 Token::DoneProc(done) => {
3959 assert!(!done.status.count);
3960 assert!(!done.status.more);
3961 assert_eq!(done.cur_cmd, 198);
3962 }
3963 _ => panic!("Expected DoneProc, got {t3:?}"),
3964 }
3965
3966 assert!(parser.next_token().unwrap().is_none());
3968 }
3969
3970 #[test]
3971 fn test_multi_token_error_in_stream() {
3972 let mut data = BytesMut::new();
3974
3975 data.put_u8(0xAA);
3977
3978 let msg_utf16: Vec<u16> = "Deadlock".encode_utf16().collect();
3979 let srv_utf16: Vec<u16> = "DB1".encode_utf16().collect();
3980
3981 let length: u16 = (4 + 1 + 1
3982 + 2 + (msg_utf16.len() * 2)
3983 + 1 + (srv_utf16.len() * 2)
3984 + 1 + 4) as u16;
3986
3987 data.put_u16_le(length);
3988 data.put_i32_le(1205); data.put_u8(51); data.put_u8(13); data.put_u16_le(msg_utf16.len() as u16);
3993 for &c in &msg_utf16 {
3994 data.put_u16_le(c);
3995 }
3996 data.put_u8(srv_utf16.len() as u8);
3997 for &c in &srv_utf16 {
3998 data.put_u16_le(c);
3999 }
4000 data.put_u8(0); data.put_i32_le(0);
4002
4003 data.put_u8(0xFD);
4005 data.put_u16_le(0x0002); data.put_u16_le(0x00C1); data.put_u64_le(0);
4008
4009 let mut parser = TokenParser::new(data.freeze());
4010
4011 let t1 = parser.next_token().unwrap().unwrap();
4013 match t1 {
4014 Token::Error(err) => {
4015 assert_eq!(err.number, 1205);
4016 assert_eq!(err.class, 13);
4017 assert_eq!(err.message, "Deadlock");
4018 assert_eq!(err.server, "DB1");
4019 }
4020 _ => panic!("Expected Error token, got {t1:?}"),
4021 }
4022
4023 let t2 = parser.next_token().unwrap().unwrap();
4025 match t2 {
4026 Token::Done(done) => {
4027 assert!(done.status.error);
4028 assert!(!done.status.count);
4029 }
4030 _ => panic!("Expected Done token, got {t2:?}"),
4031 }
4032
4033 assert!(parser.next_token().unwrap().is_none());
4034 }
4035
4036 #[test]
4037 fn test_multi_token_proc_with_return_value() {
4038 let mut data = BytesMut::new();
4040
4041 data.put_u8(0xAC);
4043 data.extend_from_slice(&build_return_value_intn(1, "@result", 0x01, Some(42)));
4044
4045 data.put_u8(0x79);
4047 data.put_i32_le(0);
4048
4049 data.put_u8(0xFE);
4051 data.put_u16_le(0x0000);
4052 data.put_u16_le(0x00C6);
4053 data.put_u64_le(0);
4054
4055 let mut parser = TokenParser::new(data.freeze());
4056
4057 let t1 = parser.next_token().unwrap().unwrap();
4058 match t1 {
4059 Token::ReturnValue(rv) => {
4060 assert_eq!(rv.param_name, "@result");
4061 assert_eq!(rv.param_ordinal, 1);
4062 }
4063 _ => panic!("Expected ReturnValue, got {t1:?}"),
4064 }
4065
4066 let t2 = parser.next_token().unwrap().unwrap();
4067 assert!(matches!(t2, Token::ReturnStatus(0)));
4068
4069 let t3 = parser.next_token().unwrap().unwrap();
4070 assert!(matches!(t3, Token::DoneProc(_)));
4071
4072 assert!(parser.next_token().unwrap().is_none());
4073 }
4074
4075 #[test]
4080 fn test_return_status_truncated() {
4081 let data = Bytes::from_static(&[0x79, 0x01, 0x02, 0x03]);
4083 let mut parser = TokenParser::new(data);
4084 assert!(parser.next_token().is_err());
4085 }
4086
4087 #[test]
4088 fn test_done_proc_truncated() {
4089 let data = Bytes::from_static(&[0xFE, 0x00, 0x00, 0xC1, 0x00, 0x01, 0x00, 0x00, 0x00]);
4091 let mut parser = TokenParser::new(data);
4092 assert!(parser.next_token().is_err());
4093 }
4094
4095 #[test]
4096 fn test_server_error_truncated() {
4097 let data = Bytes::from_static(&[0xAA, 0x20, 0x00]);
4099 let mut parser = TokenParser::new(data);
4100 assert!(parser.next_token().is_err());
4101 }
4102
4103 fn build_fed_auth_info_token(options: &[(u8, &str)]) -> Vec<u8> {
4112 let headers_end = 4 + options.len() * 9;
4113 let mut data_block = Vec::new();
4114 let mut headers = Vec::new();
4115 for (id, value) in options {
4116 let encoded: Vec<u8> = value.encode_utf16().flat_map(u16::to_le_bytes).collect();
4117 let offset = headers_end + data_block.len();
4118 headers.push(*id);
4119 headers.extend_from_slice(&u32::try_from(encoded.len()).unwrap().to_le_bytes());
4120 headers.extend_from_slice(&u32::try_from(offset).unwrap().to_le_bytes());
4121 data_block.extend_from_slice(&encoded);
4122 }
4123
4124 let token_len = 4 + headers.len() + data_block.len();
4125 let mut out = vec![0xEE];
4126 out.extend_from_slice(&u32::try_from(token_len).unwrap().to_le_bytes());
4127 out.extend_from_slice(&u32::try_from(options.len()).unwrap().to_le_bytes());
4128 out.extend_from_slice(&headers);
4129 out.extend_from_slice(&data_block);
4130 out
4131 }
4132
4133 #[test]
4134 fn test_fed_auth_info_decodes_spec_layout() {
4135 const STS: &str = "https://login.microsoftonline.com/common";
4136 const SPN: &str = "https://database.windows.net/";
4137 let token = build_fed_auth_info_token(&[(0x01, STS), (0x02, SPN)]);
4141
4142 let mut parser = TokenParser::new(Bytes::from(token));
4143 let parsed = parser.next_token().unwrap().unwrap();
4144 let Token::FedAuthInfo(info) = parsed else {
4145 panic!("expected FedAuthInfo, got {parsed:?}");
4146 };
4147 assert_eq!(info.sts_url, STS);
4148 assert_eq!(info.spn, SPN);
4149 assert!(parser.next_token().unwrap().is_none(), "exact consumption");
4150 }
4151
4152 #[test]
4153 fn test_fed_auth_info_preserves_following_tokens() {
4154 let mut stream = build_fed_auth_info_token(&[
4158 (0x01, "https://sts.example/"),
4159 (0x02, "https://db.example/"),
4160 ]);
4161 stream.push(0xFD); stream.extend_from_slice(&0u16.to_le_bytes()); stream.extend_from_slice(&0u16.to_le_bytes()); stream.extend_from_slice(&0u64.to_le_bytes()); let mut parser = TokenParser::new(Bytes::from(stream));
4167 assert!(matches!(
4168 parser.next_token().unwrap(),
4169 Some(Token::FedAuthInfo(_))
4170 ));
4171 assert!(
4172 matches!(parser.next_token().unwrap(), Some(Token::Done(_))),
4173 "DONE after FEDAUTHINFO must not be swallowed"
4174 );
4175 assert!(parser.next_token().unwrap().is_none());
4176 }
4177
4178 #[test]
4179 fn test_fed_auth_info_unknown_ids_ignored() {
4180 let token =
4182 build_fed_auth_info_token(&[(0x7F, "ignore-me"), (0x01, "https://sts.example/")]);
4183 let mut parser = TokenParser::new(Bytes::from(token));
4184 let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
4185 panic!("expected FedAuthInfo");
4186 };
4187 assert_eq!(info.sts_url, "https://sts.example/");
4188 assert_eq!(info.spn, "");
4189 }
4190
4191 #[test]
4192 fn test_fed_auth_info_hostile_inputs_error() {
4193 let mut truncated = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4195 truncated.truncate(truncated.len() - 4);
4196 assert!(
4197 TokenParser::new(Bytes::from(truncated))
4198 .next_token()
4199 .is_err()
4200 );
4201
4202 let mut bad_count = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4205 bad_count[5..9].copy_from_slice(&u32::MAX.to_le_bytes());
4206 assert!(
4207 TokenParser::new(Bytes::from(bad_count))
4208 .next_token()
4209 .is_err()
4210 );
4211
4212 let mut bad_offset = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4214 bad_offset[14..18].copy_from_slice(&u32::MAX.to_le_bytes());
4215 assert!(
4216 TokenParser::new(Bytes::from(bad_offset))
4217 .next_token()
4218 .is_err()
4219 );
4220
4221 let mut odd_len = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4223 odd_len[10..14].copy_from_slice(&3u32.to_le_bytes());
4224 assert!(TokenParser::new(Bytes::from(odd_len)).next_token().is_err());
4225 }
4226
4227 #[test]
4228 fn test_fed_auth_info_parse_and_skip_agree() {
4229 let token = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
4232 let total = token.len();
4233
4234 let mut parser = TokenParser::new(Bytes::from(token.clone()));
4235 parser.next_token().unwrap();
4236 assert_eq!(parser.position(), total, "decode consumption");
4237
4238 let mut skipper = TokenParser::new(Bytes::from(token));
4239 skipper.skip_token().unwrap();
4240 assert_eq!(skipper.position(), total, "skip consumption");
4241 }
4242
4243 #[test]
4254 fn test_fed_auth_info_captured_from_azure() {
4255 const CAPTURED: &[u8] = &[
4256 0xEE, 0xCC, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x3A, 0x00, 0x00, 0x00,
4257 0x16, 0x00, 0x00, 0x00, 0x01, 0x7C, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68,
4258 0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F, 0x00, 0x2F,
4259 0x00, 0x64, 0x00, 0x61, 0x00, 0x74, 0x00, 0x61, 0x00, 0x62, 0x00, 0x61, 0x00, 0x73,
4260 0x00, 0x65, 0x00, 0x2E, 0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F,
4261 0x00, 0x77, 0x00, 0x73, 0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F,
4262 0x00, 0x68, 0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F,
4263 0x00, 0x2F, 0x00, 0x6C, 0x00, 0x6F, 0x00, 0x67, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x2E,
4264 0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F, 0x00, 0x77, 0x00, 0x73,
4265 0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F, 0x00, 0x30, 0x00, 0x30,
4266 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D,
4267 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30,
4268 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
4269 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
4270 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00,
4271 ];
4272
4273 let mut parser = TokenParser::new(Bytes::from_static(CAPTURED));
4274 let Some(Token::FedAuthInfo(info)) = parser.next_token().unwrap() else {
4275 panic!("expected FedAuthInfo");
4276 };
4277 assert_eq!(
4278 info.sts_url,
4279 "https://login.windows.net/00000000-0000-0000-0000-000000000000"
4280 );
4281 assert_eq!(info.spn, "https://database.windows.net/");
4282 assert!(
4283 parser.next_token().unwrap().is_none(),
4284 "the captured token must be consumed exactly"
4285 );
4286 }
4287
4288 #[test]
4304 fn skip_tokens_iterate_not_recurse_273() {
4305 const SKIP_COUNT: usize = 200_000;
4306 let mut buf = BytesMut::with_capacity(SKIP_COUNT * 3 + 13);
4307 for _ in 0..SKIP_COUNT {
4308 buf.put_u8(TokenType::ColInfo as u8);
4309 buf.put_u16_le(0); }
4311 let done = Done {
4312 status: DoneStatus {
4313 more: false,
4314 error: false,
4315 in_xact: false,
4316 count: true,
4317 attn: false,
4318 srverror: false,
4319 },
4320 cur_cmd: 0xABCD,
4321 row_count: 99,
4322 };
4323 done.encode(&mut buf);
4324 let total_len = buf.len();
4325
4326 let mut parser = TokenParser::new(buf.freeze());
4327
4328 let Some(Token::Done(decoded)) = parser.next_token().unwrap() else {
4331 panic!("expected the DONE token after the skip run");
4332 };
4333 assert_eq!(decoded.cur_cmd, 0xABCD);
4334 assert_eq!(decoded.row_count, 99);
4335
4336 assert_eq!(parser.position(), total_len);
4338 assert!(parser.next_token().unwrap().is_none());
4339 }
4340}