1use bytes::{Buf, BufMut, Bytes};
33
34use crate::codec::{read_b_varchar, read_us_varchar};
35use crate::error::ProtocolError;
36use crate::prelude::*;
37use crate::types::TypeId;
38
/// Identifier byte that introduces each token in a TDS token stream.
///
/// The discriminant of every variant is the byte used on the wire, so
/// `token as u8` produces the encoded identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum TokenType {
    ColMetaData = 0x81,
    Error = 0xAA,
    Info = 0xAB,
    LoginAck = 0xAD,
    Row = 0xD1,
    NbcRow = 0xD2,
    EnvChange = 0xE3,
    Sspi = 0xED,
    Done = 0xFD,
    DoneInProc = 0xFF,
    DoneProc = 0xFE,
    ReturnStatus = 0x79,
    ReturnValue = 0xAC,
    Order = 0xA9,
    FeatureExtAck = 0xAE,
    SessionState = 0xE4,
    FedAuthInfo = 0xEE,
    ColInfo = 0xA5,
    TabName = 0xA4,
    Offset = 0x78,
}

impl TokenType {
    /// Every variant, used as the reverse-lookup table for [`TokenType::from_u8`].
    const ALL_KNOWN: [Self; 20] = [
        Self::ColMetaData,
        Self::Error,
        Self::Info,
        Self::LoginAck,
        Self::Row,
        Self::NbcRow,
        Self::EnvChange,
        Self::Sspi,
        Self::Done,
        Self::DoneInProc,
        Self::DoneProc,
        Self::ReturnStatus,
        Self::ReturnValue,
        Self::Order,
        Self::FeatureExtAck,
        Self::SessionState,
        Self::FedAuthInfo,
        Self::ColInfo,
        Self::TabName,
        Self::Offset,
    ];

    /// Looks up the token type whose wire byte is `value`.
    ///
    /// Returns `None` for bytes that do not correspond to a known token.
    pub fn from_u8(value: u8) -> Option<Self> {
        Self::ALL_KNOWN
            .iter()
            .copied()
            .find(|kind| (*kind as u8) == value)
    }
}
114
115#[derive(Debug, Clone)]
120#[non_exhaustive]
121pub enum Token {
122 ColMetaData(ColMetaData),
124 Row(RawRow),
126 NbcRow(NbcRow),
128 Done(Done),
130 DoneProc(DoneProc),
132 DoneInProc(DoneInProc),
134 ReturnStatus(i32),
136 ReturnValue(ReturnValue),
138 Error(ServerError),
140 Info(ServerInfo),
142 LoginAck(LoginAck),
144 EnvChange(EnvChange),
146 Order(Order),
148 FeatureExtAck(FeatureExtAck),
150 Sspi(SspiToken),
152 SessionState(SessionState),
154 FedAuthInfo(FedAuthInfo),
156}
157
158#[derive(Debug, Clone, Default)]
160pub struct ColMetaData {
161 pub columns: Vec<ColumnData>,
163 pub cek_table: Option<crate::crypto::CekTable>,
166}
167
168#[derive(Debug, Clone)]
170pub struct ColumnData {
171 pub name: String,
173 pub type_id: TypeId,
175 pub col_type: u8,
177 pub flags: u16,
179 pub user_type: u32,
181 pub type_info: TypeInfo,
183 pub crypto_metadata: Option<crate::crypto::CryptoMetadata>,
186}
187
/// Type-specific metadata decoded from a column's TYPE_INFO.
///
/// Each data type carries a different subset of these fields, so all of
/// them are optional.
#[derive(Debug, Clone, Default)]
pub struct TypeInfo {
    /// Maximum length from the type's length prefix, when it has one.
    pub max_length: Option<u32>,
    /// Precision of decimal/numeric types.
    pub precision: Option<u8>,
    /// Scale of decimal/numeric and time-based types.
    pub scale: Option<u8>,
    /// Collation of character types that carry one.
    pub collation: Option<Collation>,
}

/// A 5-byte collation as it appears on the wire.
#[derive(Debug, Clone, Copy, Default)]
pub struct Collation {
    /// The first four collation bytes, read little-endian.
    ///
    /// NOTE(review): this is the whole 32-bit word. It presumably packs the
    /// locale id together with comparison flags and a version number; confirm
    /// that the `crate::collation` helpers mask it as needed.
    pub lcid: u32,
    /// The fifth collation byte. When non-zero, the encoding helpers prefer
    /// it over `lcid`.
    pub sort_id: u8,
}

impl Collation {
    /// Builds a collation from its five wire bytes: a little-endian 32-bit
    /// word followed by the sort id.
    pub fn from_bytes(bytes: &[u8; 5]) -> Self {
        let [b0, b1, b2, b3, sort_id] = *bytes;
        Self {
            lcid: u32::from_le_bytes([b0, b1, b2, b3]),
            sort_id,
        }
    }

    /// Serialises back to the five-byte wire layout. This is the exact
    /// inverse of [`Collation::from_bytes`].
    pub fn to_bytes(&self) -> [u8; 5] {
        let [b0, b1, b2, b3] = self.lcid.to_le_bytes();
        [b0, b1, b2, b3, self.sort_id]
    }

    /// Character encoding for this collation.
    ///
    /// A non-zero sort id is used for the lookup; otherwise `lcid` is used.
    #[cfg(feature = "encoding")]
    pub fn encoding(&self) -> Option<&'static encoding_rs::Encoding> {
        match self.sort_id {
            0 => crate::collation::encoding_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id),
        }
    }

    /// Whether `crate::collation` classifies this collation as UTF-8.
    #[cfg(feature = "encoding")]
    pub fn is_utf8(&self) -> bool {
        crate::collation::is_utf8_collation(self.lcid)
    }

    /// Code page for this collation, chosen by sort id when non-zero and by
    /// `lcid` otherwise.
    #[cfg(feature = "encoding")]
    pub fn code_page(&self) -> Option<u16> {
        match self.sort_id {
            0 => crate::collation::code_page_for_lcid(self.lcid),
            sort_id => crate::collation::code_page_for_sort_id(sort_id),
        }
    }

    /// Human-readable encoding name.
    ///
    /// Returns `"unsupported"` when a non-zero sort id has no known encoding.
    #[cfg(feature = "encoding")]
    pub fn encoding_name(&self) -> &'static str {
        match self.sort_id {
            0 => crate::collation::encoding_name_for_lcid(self.lcid),
            sort_id => crate::collation::encoding_for_sort_id(sort_id)
                .map_or("unsupported", |enc| enc.name()),
        }
    }
}
334
335#[derive(Debug, Clone)]
337pub struct RawRow {
338 pub data: bytes::Bytes,
340}
341
342#[derive(Debug, Clone)]
344pub struct NbcRow {
345 pub null_bitmap: Vec<u8>,
347 pub data: bytes::Bytes,
349}
350
/// DONE token: end of a SQL statement.
#[derive(Debug, Clone, Copy)]
pub struct Done {
    /// Decoded status flags.
    pub status: DoneStatus,
    /// Raw current-command value.
    pub cur_cmd: u16,
    /// Row count field.
    pub row_count: u64,
}

/// Status flags shared by DONE, DONEPROC and DONEINPROC tokens.
///
/// Each field mirrors one bit of the 16-bit status word.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct DoneStatus {
    /// `DONE_MORE` (0x0001) bit.
    pub more: bool,
    /// `DONE_ERROR` (0x0002) bit.
    pub error: bool,
    /// `DONE_INXACT` (0x0004) bit.
    pub in_xact: bool,
    /// `DONE_COUNT` (0x0010) bit; marks `row_count` as meaningful.
    pub count: bool,
    /// `DONE_ATTN` (0x0020) bit.
    pub attn: bool,
    /// `DONE_SRVERROR` (0x0100) bit.
    pub srverror: bool,
}

/// DONEINPROC token: end of a statement inside a stored procedure.
#[derive(Debug, Clone, Copy)]
pub struct DoneInProc {
    /// Decoded status flags.
    pub status: DoneStatus,
    /// Raw current-command value.
    pub cur_cmd: u16,
    /// Row count field.
    pub row_count: u64,
}

/// DONEPROC token: end of a stored procedure.
#[derive(Debug, Clone, Copy)]
pub struct DoneProc {
    /// Decoded status flags.
    pub status: DoneStatus,
    /// Raw current-command value.
    pub cur_cmd: u16,
    /// Row count field.
    pub row_count: u64,
}
401
402#[derive(Debug, Clone)]
404#[non_exhaustive]
405pub struct ReturnValue {
406 pub param_ordinal: u16,
408 pub param_name: String,
410 pub status: u8,
412 pub user_type: u32,
414 pub flags: u16,
416 pub col_type: u8,
418 pub type_info: TypeInfo,
420 pub value: bytes::Bytes,
422}
423
/// Decoded ERROR token.
#[derive(Debug, Clone)]
pub struct ServerError {
    /// Error number.
    pub number: i32,
    /// Error state byte.
    pub state: u8,
    /// Severity class; `is_fatal` treats 20 and above as fatal.
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Name of the server that raised the error.
    pub server: String,
    /// Procedure name (may be empty).
    pub procedure: String,
    /// Line number field.
    pub line: i32,
}

/// Decoded INFO token; same layout as [`ServerError`].
#[derive(Debug, Clone)]
pub struct ServerInfo {
    /// Message number.
    pub number: i32,
    /// State byte.
    pub state: u8,
    /// Severity class.
    pub class: u8,
    /// Message text.
    pub message: String,
    /// Name of the server that sent the message.
    pub server: String,
    /// Procedure name (may be empty).
    pub procedure: String,
    /// Line number field.
    pub line: i32,
}

/// Decoded LOGINACK token.
#[derive(Debug, Clone)]
pub struct LoginAck {
    /// Raw interface byte.
    pub interface: u8,
    /// TDS version word reported in the token.
    pub tds_version: u32,
    /// Program name reported in the token.
    pub prog_name: String,
    /// Program version word reported in the token.
    pub prog_version: u32,
}
474
475#[derive(Debug, Clone)]
477pub struct EnvChange {
478 pub env_type: EnvChangeType,
480 pub new_value: EnvChangeValue,
482 pub old_value: EnvChangeValue,
484}
485
486#[derive(Debug, Clone, Copy, PartialEq, Eq)]
488#[repr(u8)]
489#[non_exhaustive]
490pub enum EnvChangeType {
491 Database = 1,
493 Language = 2,
495 CharacterSet = 3,
497 PacketSize = 4,
499 UnicodeSortingLocalId = 5,
501 UnicodeComparisonFlags = 6,
503 SqlCollation = 7,
505 BeginTransaction = 8,
507 CommitTransaction = 9,
509 RollbackTransaction = 10,
511 EnlistDtcTransaction = 11,
513 DefectTransaction = 12,
515 RealTimeLogShipping = 13,
517 PromoteTransaction = 15,
519 TransactionManagerAddress = 16,
521 TransactionEnded = 17,
523 ResetConnectionCompletionAck = 18,
525 UserInstanceStarted = 19,
527 Routing = 20,
529}
530
531#[derive(Debug, Clone)]
533#[non_exhaustive]
534pub enum EnvChangeValue {
535 String(String),
537 Binary(bytes::Bytes),
539 Routing {
541 host: String,
543 port: u16,
545 },
546}
547
548#[derive(Debug, Clone)]
550pub struct Order {
551 pub columns: Vec<u16>,
553}
554
555#[derive(Debug, Clone)]
557pub struct FeatureExtAck {
558 pub features: Vec<FeatureAck>,
560}
561
562#[derive(Debug, Clone)]
564pub struct FeatureAck {
565 pub feature_id: u8,
567 pub data: bytes::Bytes,
569}
570
571#[derive(Debug, Clone)]
573pub struct SspiToken {
574 pub data: bytes::Bytes,
576}
577
578#[derive(Debug, Clone)]
580pub struct SessionState {
581 pub data: bytes::Bytes,
583}
584
585#[derive(Debug, Clone)]
587pub struct FedAuthInfo {
588 pub sts_url: String,
590 pub spn: String,
592}
593
594pub(crate) fn decode_collation(src: &mut impl Buf) -> Result<Collation, ProtocolError> {
602 if src.remaining() < 5 {
603 return Err(ProtocolError::UnexpectedEof);
604 }
605 let lcid = src.get_u32_le();
607 let sort_id = src.get_u8();
608 Ok(Collation { lcid, sort_id })
609}
610
611pub(crate) fn decode_type_info(
615 src: &mut impl Buf,
616 type_id: TypeId,
617 col_type: u8,
618) -> Result<TypeInfo, ProtocolError> {
619 match type_id {
620 TypeId::Null => Ok(TypeInfo::default()),
622 TypeId::Int1 | TypeId::Bit => Ok(TypeInfo::default()),
623 TypeId::Int2 => Ok(TypeInfo::default()),
624 TypeId::Int4 => Ok(TypeInfo::default()),
625 TypeId::Int8 => Ok(TypeInfo::default()),
626 TypeId::Float4 => Ok(TypeInfo::default()),
627 TypeId::Float8 => Ok(TypeInfo::default()),
628 TypeId::Money => Ok(TypeInfo::default()),
629 TypeId::Money4 => Ok(TypeInfo::default()),
630 TypeId::DateTime => Ok(TypeInfo::default()),
631 TypeId::DateTime4 => Ok(TypeInfo::default()),
632
633 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
635 if src.remaining() < 1 {
636 return Err(ProtocolError::UnexpectedEof);
637 }
638 let max_length = src.get_u8() as u32;
639 Ok(TypeInfo {
640 max_length: Some(max_length),
641 ..Default::default()
642 })
643 }
644
645 TypeId::Guid => {
647 if src.remaining() < 1 {
648 return Err(ProtocolError::UnexpectedEof);
649 }
650 let max_length = src.get_u8() as u32;
651 Ok(TypeInfo {
652 max_length: Some(max_length),
653 ..Default::default()
654 })
655 }
656
657 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
659 if src.remaining() < 3 {
660 return Err(ProtocolError::UnexpectedEof);
661 }
662 let max_length = src.get_u8() as u32;
663 let precision = src.get_u8();
664 let scale = src.get_u8();
665 Ok(TypeInfo {
666 max_length: Some(max_length),
667 precision: Some(precision),
668 scale: Some(scale),
669 ..Default::default()
670 })
671 }
672
673 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
675 if src.remaining() < 1 {
676 return Err(ProtocolError::UnexpectedEof);
677 }
678 let max_length = src.get_u8() as u32;
679 Ok(TypeInfo {
680 max_length: Some(max_length),
681 ..Default::default()
682 })
683 }
684
685 TypeId::BigVarChar | TypeId::BigChar => {
687 if src.remaining() < 7 {
688 return Err(ProtocolError::UnexpectedEof);
690 }
691 let max_length = src.get_u16_le() as u32;
692 let collation = decode_collation(src)?;
693 Ok(TypeInfo {
694 max_length: Some(max_length),
695 collation: Some(collation),
696 ..Default::default()
697 })
698 }
699
700 TypeId::BigVarBinary | TypeId::BigBinary => {
702 if src.remaining() < 2 {
703 return Err(ProtocolError::UnexpectedEof);
704 }
705 let max_length = src.get_u16_le() as u32;
706 Ok(TypeInfo {
707 max_length: Some(max_length),
708 ..Default::default()
709 })
710 }
711
712 TypeId::NChar | TypeId::NVarChar => {
714 if src.remaining() < 7 {
715 return Err(ProtocolError::UnexpectedEof);
717 }
718 let max_length = src.get_u16_le() as u32;
719 let collation = decode_collation(src)?;
720 Ok(TypeInfo {
721 max_length: Some(max_length),
722 collation: Some(collation),
723 ..Default::default()
724 })
725 }
726
727 TypeId::Date => Ok(TypeInfo::default()),
729
730 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
732 if src.remaining() < 1 {
733 return Err(ProtocolError::UnexpectedEof);
734 }
735 let scale = src.get_u8();
736 Ok(TypeInfo {
737 scale: Some(scale),
738 ..Default::default()
739 })
740 }
741
742 TypeId::Text | TypeId::NText | TypeId::Image => {
744 if src.remaining() < 4 {
746 return Err(ProtocolError::UnexpectedEof);
747 }
748 let max_length = src.get_u32_le();
749
750 let collation = if type_id == TypeId::Text || type_id == TypeId::NText {
752 if src.remaining() < 5 {
753 return Err(ProtocolError::UnexpectedEof);
754 }
755 Some(decode_collation(src)?)
756 } else {
757 None
758 };
759
760 if src.remaining() < 1 {
763 return Err(ProtocolError::UnexpectedEof);
764 }
765 let num_parts = src.get_u8();
766 for _ in 0..num_parts {
767 let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
769 }
770
771 Ok(TypeInfo {
772 max_length: Some(max_length),
773 collation,
774 ..Default::default()
775 })
776 }
777
778 TypeId::Xml => {
780 if src.remaining() < 1 {
781 return Err(ProtocolError::UnexpectedEof);
782 }
783 let schema_present = src.get_u8();
784
785 if schema_present != 0 {
786 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; }
793
794 Ok(TypeInfo::default())
795 }
796
797 TypeId::Udt => {
799 if src.remaining() < 2 {
801 return Err(ProtocolError::UnexpectedEof);
802 }
803 let max_length = src.get_u16_le() as u32;
804
805 let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; let _ = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?; Ok(TypeInfo {
814 max_length: Some(max_length),
815 ..Default::default()
816 })
817 }
818
819 TypeId::Tvp => {
821 Err(ProtocolError::InvalidTokenType(col_type))
824 }
825
826 TypeId::Variant => {
828 if src.remaining() < 4 {
829 return Err(ProtocolError::UnexpectedEof);
830 }
831 let max_length = src.get_u32_le();
832 Ok(TypeInfo {
833 max_length: Some(max_length),
834 ..Default::default()
835 })
836 }
837 }
838}
839
840impl ColMetaData {
841 pub const NO_METADATA: u16 = 0xFFFF;
843
844 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
846 if src.remaining() < 2 {
847 return Err(ProtocolError::UnexpectedEof);
848 }
849
850 let column_count = src.get_u16_le();
851
852 if column_count == Self::NO_METADATA {
854 return Ok(Self {
855 columns: Vec::new(),
856 cek_table: None,
857 });
858 }
859
860 let mut columns = Vec::with_capacity(column_count as usize);
861
862 for _ in 0..column_count {
863 let column = Self::decode_column(src)?;
864 columns.push(column);
865 }
866
867 Ok(Self {
868 columns,
869 cek_table: None,
870 })
871 }
872
873 fn decode_column(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
875 if src.remaining() < 7 {
877 return Err(ProtocolError::UnexpectedEof);
878 }
879
880 let user_type = src.get_u32_le();
881 let flags = src.get_u16_le();
882 let col_type = src.get_u8();
883
884 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
888
889 let type_info = decode_type_info(src, type_id, col_type)?;
891
892 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
894
895 Ok(ColumnData {
896 name,
897 type_id,
898 col_type,
899 flags,
900 user_type,
901 type_info,
902 crypto_metadata: None,
903 })
904 }
905
906 pub fn decode_encrypted(src: &mut impl Buf) -> Result<Self, ProtocolError> {
919 if src.remaining() < 2 {
920 return Err(ProtocolError::UnexpectedEof);
921 }
922
923 let column_count = src.get_u16_le();
924
925 if column_count == Self::NO_METADATA {
926 return Ok(Self {
927 columns: Vec::new(),
928 cek_table: None,
929 });
930 }
931
932 let cek_table = crate::crypto::CekTable::decode(src)?;
934
935 let mut columns = Vec::with_capacity(column_count as usize);
936
937 for _ in 0..column_count {
938 let column = Self::decode_column_encrypted(src)?;
939 columns.push(column);
940 }
941
942 Ok(Self {
943 columns,
944 cek_table: Some(cek_table),
945 })
946 }
947
948 fn decode_column_encrypted(src: &mut impl Buf) -> Result<ColumnData, ProtocolError> {
952 if src.remaining() < 7 {
953 return Err(ProtocolError::UnexpectedEof);
954 }
955
956 let user_type = src.get_u32_le();
957 let flags = src.get_u16_le();
958 let col_type = src.get_u8();
959
960 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
961
962 let type_info = decode_type_info(src, type_id, col_type)?;
964
965 let crypto_metadata = if crate::crypto::is_column_encrypted(flags) {
967 Some(crate::crypto::CryptoMetadata::decode(src)?)
968 } else {
969 None
970 };
971
972 let name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
974
975 Ok(ColumnData {
976 name,
977 type_id,
978 col_type,
979 flags,
980 user_type,
981 type_info,
982 crypto_metadata,
983 })
984 }
985
986 #[must_use]
988 pub fn column_count(&self) -> usize {
989 self.columns.len()
990 }
991
992 #[must_use]
994 pub fn is_empty(&self) -> bool {
995 self.columns.is_empty()
996 }
997}
998
999impl ColumnData {
1000 #[must_use]
1002 pub fn is_nullable(&self) -> bool {
1003 (self.flags & 0x0001) != 0
1004 }
1005
1006 #[must_use]
1010 pub fn fixed_size(&self) -> Option<usize> {
1011 match self.type_id {
1012 TypeId::Null => Some(0),
1013 TypeId::Int1 | TypeId::Bit => Some(1),
1014 TypeId::Int2 => Some(2),
1015 TypeId::Int4 => Some(4),
1016 TypeId::Int8 => Some(8),
1017 TypeId::Float4 => Some(4),
1018 TypeId::Float8 => Some(8),
1019 TypeId::Money => Some(8),
1020 TypeId::Money4 => Some(4),
1021 TypeId::DateTime => Some(8),
1022 TypeId::DateTime4 => Some(4),
1023 TypeId::Date => Some(3),
1024 _ => None,
1025 }
1026 }
1027}
1028
1029impl RawRow {
1034 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1039 let mut data = bytes::BytesMut::new();
1040
1041 for col in &metadata.columns {
1042 Self::decode_column_value(src, col, &mut data)?;
1043 }
1044
1045 Ok(Self {
1046 data: data.freeze(),
1047 })
1048 }
1049
1050 pub fn decode_prefix(
1057 src: &mut impl Buf,
1058 metadata: &ColMetaData,
1059 prefix_len: usize,
1060 ) -> Result<Self, ProtocolError> {
1061 let mut data = bytes::BytesMut::new();
1062 for col in metadata.columns.iter().take(prefix_len) {
1063 Self::decode_column_value(src, col, &mut data)?;
1064 }
1065 Ok(Self {
1066 data: data.freeze(),
1067 })
1068 }
1069
1070 fn decode_column_value(
1072 src: &mut impl Buf,
1073 col: &ColumnData,
1074 dst: &mut bytes::BytesMut,
1075 ) -> Result<(), ProtocolError> {
1076 match col.type_id {
1077 TypeId::Null => {
1079 }
1081 TypeId::Int1 | TypeId::Bit => {
1082 if src.remaining() < 1 {
1083 return Err(ProtocolError::UnexpectedEof);
1084 }
1085 dst.extend_from_slice(&[src.get_u8()]);
1086 }
1087 TypeId::Int2 => {
1088 if src.remaining() < 2 {
1089 return Err(ProtocolError::UnexpectedEof);
1090 }
1091 dst.extend_from_slice(&src.get_u16_le().to_le_bytes());
1092 }
1093 TypeId::Int4 => {
1094 if src.remaining() < 4 {
1095 return Err(ProtocolError::UnexpectedEof);
1096 }
1097 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1098 }
1099 TypeId::Int8 => {
1100 if src.remaining() < 8 {
1101 return Err(ProtocolError::UnexpectedEof);
1102 }
1103 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1104 }
1105 TypeId::Float4 => {
1106 if src.remaining() < 4 {
1107 return Err(ProtocolError::UnexpectedEof);
1108 }
1109 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1110 }
1111 TypeId::Float8 => {
1112 if src.remaining() < 8 {
1113 return Err(ProtocolError::UnexpectedEof);
1114 }
1115 dst.extend_from_slice(&src.get_u64_le().to_le_bytes());
1116 }
1117 TypeId::Money => {
1118 if src.remaining() < 8 {
1119 return Err(ProtocolError::UnexpectedEof);
1120 }
1121 let hi = src.get_u32_le();
1122 let lo = src.get_u32_le();
1123 dst.extend_from_slice(&hi.to_le_bytes());
1124 dst.extend_from_slice(&lo.to_le_bytes());
1125 }
1126 TypeId::Money4 => {
1127 if src.remaining() < 4 {
1128 return Err(ProtocolError::UnexpectedEof);
1129 }
1130 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1131 }
1132 TypeId::DateTime => {
1133 if src.remaining() < 8 {
1134 return Err(ProtocolError::UnexpectedEof);
1135 }
1136 let days = src.get_u32_le();
1137 let time = src.get_u32_le();
1138 dst.extend_from_slice(&days.to_le_bytes());
1139 dst.extend_from_slice(&time.to_le_bytes());
1140 }
1141 TypeId::DateTime4 => {
1142 if src.remaining() < 4 {
1143 return Err(ProtocolError::UnexpectedEof);
1144 }
1145 dst.extend_from_slice(&src.get_u32_le().to_le_bytes());
1146 }
1147 TypeId::Date => {
1149 Self::decode_bytelen_type(src, dst)?;
1150 }
1151
1152 TypeId::IntN | TypeId::BitN | TypeId::FloatN | TypeId::MoneyN | TypeId::DateTimeN => {
1154 Self::decode_bytelen_type(src, dst)?;
1155 }
1156
1157 TypeId::Guid => {
1158 Self::decode_bytelen_type(src, dst)?;
1159 }
1160
1161 TypeId::Decimal | TypeId::Numeric | TypeId::DecimalN | TypeId::NumericN => {
1162 Self::decode_bytelen_type(src, dst)?;
1163 }
1164
1165 TypeId::Char | TypeId::VarChar | TypeId::Binary | TypeId::VarBinary => {
1167 Self::decode_bytelen_type(src, dst)?;
1168 }
1169
1170 TypeId::BigVarChar | TypeId::BigVarBinary => {
1172 if col.type_info.max_length == Some(0xFFFF) {
1174 Self::decode_plp_type(src, dst)?;
1175 } else {
1176 Self::decode_ushortlen_type(src, dst)?;
1177 }
1178 }
1179
1180 TypeId::BigChar | TypeId::BigBinary => {
1182 Self::decode_ushortlen_type(src, dst)?;
1183 }
1184
1185 TypeId::NVarChar => {
1187 if col.type_info.max_length == Some(0xFFFF) {
1189 Self::decode_plp_type(src, dst)?;
1190 } else {
1191 Self::decode_ushortlen_type(src, dst)?;
1192 }
1193 }
1194
1195 TypeId::NChar => {
1197 Self::decode_ushortlen_type(src, dst)?;
1198 }
1199
1200 TypeId::Time | TypeId::DateTime2 | TypeId::DateTimeOffset => {
1202 Self::decode_bytelen_type(src, dst)?;
1203 }
1204
1205 TypeId::Text | TypeId::NText | TypeId::Image => {
1207 Self::decode_textptr_type(src, dst)?;
1208 }
1209
1210 TypeId::Xml => {
1212 Self::decode_plp_type(src, dst)?;
1213 }
1214
1215 TypeId::Variant => {
1217 Self::decode_intlen_type(src, dst)?;
1218 }
1219
1220 TypeId::Udt => {
1221 Self::decode_plp_type(src, dst)?;
1223 }
1224
1225 TypeId::Tvp => {
1226 return Err(ProtocolError::InvalidTokenType(col.col_type));
1228 }
1229 }
1230
1231 Ok(())
1232 }
1233
1234 fn decode_bytelen_type(
1236 src: &mut impl Buf,
1237 dst: &mut bytes::BytesMut,
1238 ) -> Result<(), ProtocolError> {
1239 if src.remaining() < 1 {
1240 return Err(ProtocolError::UnexpectedEof);
1241 }
1242 let len = src.get_u8() as usize;
1243 if len == 0xFF {
1244 dst.extend_from_slice(&[0xFF]);
1246 } else if len == 0 {
1247 dst.extend_from_slice(&[0x00]);
1249 } else {
1250 if src.remaining() < len {
1251 return Err(ProtocolError::UnexpectedEof);
1252 }
1253 dst.extend_from_slice(&[len as u8]);
1254 for _ in 0..len {
1255 dst.extend_from_slice(&[src.get_u8()]);
1256 }
1257 }
1258 Ok(())
1259 }
1260
1261 fn decode_ushortlen_type(
1263 src: &mut impl Buf,
1264 dst: &mut bytes::BytesMut,
1265 ) -> Result<(), ProtocolError> {
1266 if src.remaining() < 2 {
1267 return Err(ProtocolError::UnexpectedEof);
1268 }
1269 let len = src.get_u16_le() as usize;
1270 if len == 0xFFFF {
1271 dst.extend_from_slice(&0xFFFFu16.to_le_bytes());
1273 } else if len == 0 {
1274 dst.extend_from_slice(&0u16.to_le_bytes());
1276 } else {
1277 if src.remaining() < len {
1278 return Err(ProtocolError::UnexpectedEof);
1279 }
1280 dst.extend_from_slice(&(len as u16).to_le_bytes());
1281 for _ in 0..len {
1282 dst.extend_from_slice(&[src.get_u8()]);
1283 }
1284 }
1285 Ok(())
1286 }
1287
1288 fn decode_intlen_type(
1290 src: &mut impl Buf,
1291 dst: &mut bytes::BytesMut,
1292 ) -> Result<(), ProtocolError> {
1293 if src.remaining() < 4 {
1294 return Err(ProtocolError::UnexpectedEof);
1295 }
1296 let len = src.get_u32_le() as usize;
1297 if len == 0xFFFFFFFF {
1298 dst.extend_from_slice(&0xFFFFFFFFu32.to_le_bytes());
1300 } else if len == 0 {
1301 dst.extend_from_slice(&0u32.to_le_bytes());
1303 } else {
1304 if src.remaining() < len {
1305 return Err(ProtocolError::UnexpectedEof);
1306 }
1307 dst.extend_from_slice(&(len as u32).to_le_bytes());
1308 for _ in 0..len {
1309 dst.extend_from_slice(&[src.get_u8()]);
1310 }
1311 }
1312 Ok(())
1313 }
1314
1315 fn decode_textptr_type(
1330 src: &mut impl Buf,
1331 dst: &mut bytes::BytesMut,
1332 ) -> Result<(), ProtocolError> {
1333 if src.remaining() < 1 {
1334 return Err(ProtocolError::UnexpectedEof);
1335 }
1336
1337 let textptr_len = src.get_u8() as usize;
1338
1339 if textptr_len == 0 {
1340 dst.extend_from_slice(&0xFFFFFFFFFFFFFFFFu64.to_le_bytes());
1342 return Ok(());
1343 }
1344
1345 if src.remaining() < textptr_len {
1347 return Err(ProtocolError::UnexpectedEof);
1348 }
1349 src.advance(textptr_len);
1350
1351 if src.remaining() < 8 {
1353 return Err(ProtocolError::UnexpectedEof);
1354 }
1355 src.advance(8);
1356
1357 if src.remaining() < 4 {
1359 return Err(ProtocolError::UnexpectedEof);
1360 }
1361 let data_len = src.get_u32_le() as usize;
1362
1363 if src.remaining() < data_len {
1364 return Err(ProtocolError::UnexpectedEof);
1365 }
1366
1367 dst.extend_from_slice(&(data_len as u64).to_le_bytes());
1373 dst.extend_from_slice(&(data_len as u32).to_le_bytes());
1374 for _ in 0..data_len {
1375 dst.extend_from_slice(&[src.get_u8()]);
1376 }
1377 dst.extend_from_slice(&0u32.to_le_bytes()); Ok(())
1380 }
1381
1382 fn decode_plp_type(src: &mut impl Buf, dst: &mut bytes::BytesMut) -> Result<(), ProtocolError> {
1388 if src.remaining() < 8 {
1389 return Err(ProtocolError::UnexpectedEof);
1390 }
1391
1392 let total_len = src.get_u64_le();
1393
1394 dst.extend_from_slice(&total_len.to_le_bytes());
1396
1397 if total_len == 0xFFFFFFFFFFFFFFFF {
1398 return Ok(());
1400 }
1401
1402 loop {
1404 if src.remaining() < 4 {
1405 return Err(ProtocolError::UnexpectedEof);
1406 }
1407 let chunk_len = src.get_u32_le() as usize;
1408 dst.extend_from_slice(&(chunk_len as u32).to_le_bytes());
1409
1410 if chunk_len == 0 {
1411 break;
1413 }
1414
1415 if src.remaining() < chunk_len {
1416 return Err(ProtocolError::UnexpectedEof);
1417 }
1418
1419 for _ in 0..chunk_len {
1420 dst.extend_from_slice(&[src.get_u8()]);
1421 }
1422 }
1423
1424 Ok(())
1425 }
1426}
1427
1428impl NbcRow {
1433 pub fn decode(src: &mut impl Buf, metadata: &ColMetaData) -> Result<Self, ProtocolError> {
1438 let col_count = metadata.columns.len();
1439 let bitmap_len = col_count.div_ceil(8);
1440
1441 if src.remaining() < bitmap_len {
1442 return Err(ProtocolError::UnexpectedEof);
1443 }
1444
1445 let mut null_bitmap = vec![0u8; bitmap_len];
1447 for byte in &mut null_bitmap {
1448 *byte = src.get_u8();
1449 }
1450
1451 let mut data = bytes::BytesMut::new();
1453
1454 for (i, col) in metadata.columns.iter().enumerate() {
1455 let byte_idx = i / 8;
1456 let bit_idx = i % 8;
1457 let is_null = (null_bitmap[byte_idx] & (1 << bit_idx)) != 0;
1458
1459 if !is_null {
1460 RawRow::decode_column_value(src, col, &mut data)?;
1463 }
1464 }
1465
1466 Ok(Self {
1467 null_bitmap,
1468 data: data.freeze(),
1469 })
1470 }
1471
1472 pub fn decode_prefix(
1478 src: &mut impl Buf,
1479 metadata: &ColMetaData,
1480 prefix_len: usize,
1481 ) -> Result<Self, ProtocolError> {
1482 let col_count = metadata.columns.len();
1483 let bitmap_len = col_count.div_ceil(8);
1484
1485 if src.remaining() < bitmap_len {
1486 return Err(ProtocolError::UnexpectedEof);
1487 }
1488
1489 let mut null_bitmap = vec![0u8; bitmap_len];
1490 for byte in &mut null_bitmap {
1491 *byte = src.get_u8();
1492 }
1493
1494 let mut data = bytes::BytesMut::new();
1495 for (i, col) in metadata.columns.iter().enumerate().take(prefix_len) {
1496 let is_null = (null_bitmap[i / 8] & (1 << (i % 8))) != 0;
1497 if !is_null {
1498 RawRow::decode_column_value(src, col, &mut data)?;
1499 }
1500 }
1501
1502 Ok(Self {
1503 null_bitmap,
1504 data: data.freeze(),
1505 })
1506 }
1507
1508 #[must_use]
1510 pub fn is_null(&self, column_index: usize) -> bool {
1511 let byte_idx = column_index / 8;
1512 let bit_idx = column_index % 8;
1513 if byte_idx < self.null_bitmap.len() {
1514 (self.null_bitmap[byte_idx] & (1 << bit_idx)) != 0
1515 } else {
1516 true }
1518 }
1519}
1520
1521impl ReturnValue {
1526 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1528 if src.remaining() < 2 {
1535 return Err(ProtocolError::UnexpectedEof);
1536 }
1537 let param_ordinal = src.get_u16_le();
1538
1539 let param_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1541
1542 if src.remaining() < 1 {
1544 return Err(ProtocolError::UnexpectedEof);
1545 }
1546 let status = src.get_u8();
1547
1548 if src.remaining() < 7 {
1550 return Err(ProtocolError::UnexpectedEof);
1551 }
1552 let user_type = src.get_u32_le();
1553 let flags = src.get_u16_le();
1554 let col_type = src.get_u8();
1555
1556 let type_id = TypeId::from_u8(col_type).ok_or(ProtocolError::InvalidDataType(col_type))?;
1557
1558 let type_info = decode_type_info(src, type_id, col_type)?;
1560
1561 let mut value_buf = bytes::BytesMut::new();
1563
1564 let temp_col = ColumnData {
1566 name: String::new(),
1567 type_id,
1568 col_type,
1569 flags,
1570 user_type,
1571 type_info: type_info.clone(),
1572 crypto_metadata: None,
1573 };
1574
1575 RawRow::decode_column_value(src, &temp_col, &mut value_buf)?;
1576
1577 Ok(Self {
1578 param_ordinal,
1579 param_name,
1580 status,
1581 user_type,
1582 flags,
1583 col_type,
1584 type_info,
1585 value: value_buf.freeze(),
1586 })
1587 }
1588}
1589
1590impl SessionState {
1595 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1597 if src.remaining() < 4 {
1598 return Err(ProtocolError::UnexpectedEof);
1599 }
1600
1601 let length = src.get_u32_le() as usize;
1602
1603 if src.remaining() < length {
1604 return Err(ProtocolError::IncompletePacket {
1605 expected: length,
1606 actual: src.remaining(),
1607 });
1608 }
1609
1610 let data = src.copy_to_bytes(length);
1611
1612 Ok(Self { data })
1613 }
1614}
1615
1616mod done_status_bits {
1622 pub const DONE_MORE: u16 = 0x0001;
1623 pub const DONE_ERROR: u16 = 0x0002;
1624 pub const DONE_INXACT: u16 = 0x0004;
1625 pub const DONE_COUNT: u16 = 0x0010;
1626 pub const DONE_ATTN: u16 = 0x0020;
1627 pub const DONE_SRVERROR: u16 = 0x0100;
1628}
1629
1630impl DoneStatus {
1631 #[must_use]
1633 pub fn from_bits(bits: u16) -> Self {
1634 use done_status_bits::*;
1635 Self {
1636 more: (bits & DONE_MORE) != 0,
1637 error: (bits & DONE_ERROR) != 0,
1638 in_xact: (bits & DONE_INXACT) != 0,
1639 count: (bits & DONE_COUNT) != 0,
1640 attn: (bits & DONE_ATTN) != 0,
1641 srverror: (bits & DONE_SRVERROR) != 0,
1642 }
1643 }
1644
1645 #[must_use]
1647 pub fn to_bits(&self) -> u16 {
1648 use done_status_bits::*;
1649 let mut bits = 0u16;
1650 if self.more {
1651 bits |= DONE_MORE;
1652 }
1653 if self.error {
1654 bits |= DONE_ERROR;
1655 }
1656 if self.in_xact {
1657 bits |= DONE_INXACT;
1658 }
1659 if self.count {
1660 bits |= DONE_COUNT;
1661 }
1662 if self.attn {
1663 bits |= DONE_ATTN;
1664 }
1665 if self.srverror {
1666 bits |= DONE_SRVERROR;
1667 }
1668 bits
1669 }
1670}
1671
1672impl Done {
1673 pub const SIZE: usize = 12; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1678 if src.remaining() < Self::SIZE {
1679 return Err(ProtocolError::IncompletePacket {
1680 expected: Self::SIZE,
1681 actual: src.remaining(),
1682 });
1683 }
1684
1685 let status = DoneStatus::from_bits(src.get_u16_le());
1686 let cur_cmd = src.get_u16_le();
1687 let row_count = src.get_u64_le();
1688
1689 Ok(Self {
1690 status,
1691 cur_cmd,
1692 row_count,
1693 })
1694 }
1695
1696 pub fn encode(&self, dst: &mut impl BufMut) {
1698 dst.put_u8(TokenType::Done as u8);
1699 dst.put_u16_le(self.status.to_bits());
1700 dst.put_u16_le(self.cur_cmd);
1701 dst.put_u64_le(self.row_count);
1702 }
1703
1704 #[must_use]
1706 pub const fn has_more(&self) -> bool {
1707 self.status.more
1708 }
1709
1710 #[must_use]
1712 pub const fn has_error(&self) -> bool {
1713 self.status.error
1714 }
1715
1716 #[must_use]
1718 pub const fn has_count(&self) -> bool {
1719 self.status.count
1720 }
1721}
1722
1723impl DoneProc {
1724 pub const SIZE: usize = 12;
1726
1727 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1729 if src.remaining() < Self::SIZE {
1730 return Err(ProtocolError::IncompletePacket {
1731 expected: Self::SIZE,
1732 actual: src.remaining(),
1733 });
1734 }
1735
1736 let status = DoneStatus::from_bits(src.get_u16_le());
1737 let cur_cmd = src.get_u16_le();
1738 let row_count = src.get_u64_le();
1739
1740 Ok(Self {
1741 status,
1742 cur_cmd,
1743 row_count,
1744 })
1745 }
1746
1747 pub fn encode(&self, dst: &mut impl BufMut) {
1749 dst.put_u8(TokenType::DoneProc as u8);
1750 dst.put_u16_le(self.status.to_bits());
1751 dst.put_u16_le(self.cur_cmd);
1752 dst.put_u64_le(self.row_count);
1753 }
1754}
1755
1756impl DoneInProc {
1757 pub const SIZE: usize = 12;
1759
1760 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1762 if src.remaining() < Self::SIZE {
1763 return Err(ProtocolError::IncompletePacket {
1764 expected: Self::SIZE,
1765 actual: src.remaining(),
1766 });
1767 }
1768
1769 let status = DoneStatus::from_bits(src.get_u16_le());
1770 let cur_cmd = src.get_u16_le();
1771 let row_count = src.get_u64_le();
1772
1773 Ok(Self {
1774 status,
1775 cur_cmd,
1776 row_count,
1777 })
1778 }
1779
1780 pub fn encode(&self, dst: &mut impl BufMut) {
1782 dst.put_u8(TokenType::DoneInProc as u8);
1783 dst.put_u16_le(self.status.to_bits());
1784 dst.put_u16_le(self.cur_cmd);
1785 dst.put_u64_le(self.row_count);
1786 }
1787}
1788
1789impl ServerError {
1790 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1792 if src.remaining() < 2 {
1795 return Err(ProtocolError::UnexpectedEof);
1796 }
1797
1798 let _length = src.get_u16_le();
1799
1800 if src.remaining() < 6 {
1801 return Err(ProtocolError::UnexpectedEof);
1802 }
1803
1804 let number = src.get_i32_le();
1805 let state = src.get_u8();
1806 let class = src.get_u8();
1807
1808 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1809 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1810 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1811
1812 if src.remaining() < 4 {
1813 return Err(ProtocolError::UnexpectedEof);
1814 }
1815 let line = src.get_i32_le();
1816
1817 Ok(Self {
1818 number,
1819 state,
1820 class,
1821 message,
1822 server,
1823 procedure,
1824 line,
1825 })
1826 }
1827
1828 #[must_use]
1830 pub const fn is_fatal(&self) -> bool {
1831 self.class >= 20
1832 }
1833
1834 #[must_use]
1836 pub const fn is_batch_abort(&self) -> bool {
1837 self.class >= 16
1838 }
1839}
1840
1841impl ServerInfo {
1842 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1846 if src.remaining() < 2 {
1847 return Err(ProtocolError::UnexpectedEof);
1848 }
1849
1850 let _length = src.get_u16_le();
1851
1852 if src.remaining() < 6 {
1853 return Err(ProtocolError::UnexpectedEof);
1854 }
1855
1856 let number = src.get_i32_le();
1857 let state = src.get_u8();
1858 let class = src.get_u8();
1859
1860 let message = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1861 let server = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1862 let procedure = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1863
1864 if src.remaining() < 4 {
1865 return Err(ProtocolError::UnexpectedEof);
1866 }
1867 let line = src.get_i32_le();
1868
1869 Ok(Self {
1870 number,
1871 state,
1872 class,
1873 message,
1874 server,
1875 procedure,
1876 line,
1877 })
1878 }
1879}
1880
1881impl LoginAck {
1882 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1884 if src.remaining() < 2 {
1886 return Err(ProtocolError::UnexpectedEof);
1887 }
1888
1889 let _length = src.get_u16_le();
1890
1891 if src.remaining() < 5 {
1892 return Err(ProtocolError::UnexpectedEof);
1893 }
1894
1895 let interface = src.get_u8();
1896 let tds_version = src.get_u32_le();
1897 let prog_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
1898
1899 if src.remaining() < 4 {
1900 return Err(ProtocolError::UnexpectedEof);
1901 }
1902 let prog_version = src.get_u32_le();
1903
1904 Ok(Self {
1905 interface,
1906 tds_version,
1907 prog_name,
1908 prog_version,
1909 })
1910 }
1911
1912 #[must_use]
1914 pub fn tds_version(&self) -> crate::version::TdsVersion {
1915 crate::version::TdsVersion::new(self.tds_version)
1916 }
1917}
1918
1919impl EnvChangeType {
1920 pub fn from_u8(value: u8) -> Option<Self> {
1922 match value {
1923 1 => Some(Self::Database),
1924 2 => Some(Self::Language),
1925 3 => Some(Self::CharacterSet),
1926 4 => Some(Self::PacketSize),
1927 5 => Some(Self::UnicodeSortingLocalId),
1928 6 => Some(Self::UnicodeComparisonFlags),
1929 7 => Some(Self::SqlCollation),
1930 8 => Some(Self::BeginTransaction),
1931 9 => Some(Self::CommitTransaction),
1932 10 => Some(Self::RollbackTransaction),
1933 11 => Some(Self::EnlistDtcTransaction),
1934 12 => Some(Self::DefectTransaction),
1935 13 => Some(Self::RealTimeLogShipping),
1936 15 => Some(Self::PromoteTransaction),
1937 16 => Some(Self::TransactionManagerAddress),
1938 17 => Some(Self::TransactionEnded),
1939 18 => Some(Self::ResetConnectionCompletionAck),
1940 19 => Some(Self::UserInstanceStarted),
1941 20 => Some(Self::Routing),
1942 _ => None,
1943 }
1944 }
1945}
1946
1947impl EnvChange {
1948 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
1950 if src.remaining() < 3 {
1951 return Err(ProtocolError::UnexpectedEof);
1952 }
1953
1954 let length = src.get_u16_le() as usize;
1955 if length == 0 {
1956 return Err(ProtocolError::UnexpectedEof);
1959 }
1960 if src.remaining() < length {
1961 return Err(ProtocolError::IncompletePacket {
1962 expected: length,
1963 actual: src.remaining(),
1964 });
1965 }
1966
1967 let mut frame = src.copy_to_bytes(length);
1976 let src = &mut frame;
1977
1978 let env_type_byte = src.get_u8();
1979 let env_type = EnvChangeType::from_u8(env_type_byte)
1980 .ok_or(ProtocolError::InvalidTokenType(env_type_byte))?;
1981
1982 let (new_value, old_value) = match env_type {
1983 EnvChangeType::Routing => {
1984 let new_value = Self::decode_routing_value(src)?;
1986 let old_value = EnvChangeValue::Binary(Bytes::new());
1987 (new_value, old_value)
1988 }
1989 EnvChangeType::BeginTransaction
1990 | EnvChangeType::CommitTransaction
1991 | EnvChangeType::RollbackTransaction
1992 | EnvChangeType::EnlistDtcTransaction
1993 | EnvChangeType::SqlCollation => {
1994 let new_len = if src.has_remaining() {
2003 src.get_u8() as usize
2004 } else {
2005 0
2006 };
2007 let new_value = if new_len > 0 && src.remaining() >= new_len {
2008 EnvChangeValue::Binary(src.copy_to_bytes(new_len))
2009 } else {
2010 EnvChangeValue::Binary(Bytes::new())
2011 };
2012
2013 let old_len = if src.has_remaining() {
2014 src.get_u8() as usize
2015 } else {
2016 0
2017 };
2018 let old_value = if old_len > 0 && src.remaining() >= old_len {
2019 EnvChangeValue::Binary(src.copy_to_bytes(old_len))
2020 } else {
2021 EnvChangeValue::Binary(Bytes::new())
2022 };
2023
2024 (new_value, old_value)
2025 }
2026 _ => {
2027 let new_value = read_b_varchar(src)
2029 .map(EnvChangeValue::String)
2030 .unwrap_or(EnvChangeValue::String(String::new()));
2031
2032 let old_value = read_b_varchar(src)
2033 .map(EnvChangeValue::String)
2034 .unwrap_or(EnvChangeValue::String(String::new()));
2035
2036 (new_value, old_value)
2037 }
2038 };
2039
2040 Ok(Self {
2046 env_type,
2047 new_value,
2048 old_value,
2049 })
2050 }
2051
2052 fn decode_routing_value(src: &mut impl Buf) -> Result<EnvChangeValue, ProtocolError> {
2053 if src.remaining() < 2 {
2055 return Err(ProtocolError::UnexpectedEof);
2056 }
2057
2058 let _routing_len = src.get_u16_le();
2059
2060 if src.remaining() < 5 {
2061 return Err(ProtocolError::UnexpectedEof);
2062 }
2063
2064 let _protocol = src.get_u8();
2065 let port = src.get_u16_le();
2066 let server_len = src.get_u16_le() as usize;
2067
2068 if src.remaining() < server_len * 2 {
2070 return Err(ProtocolError::UnexpectedEof);
2071 }
2072
2073 let mut chars = Vec::with_capacity(server_len);
2074 for _ in 0..server_len {
2075 chars.push(src.get_u16_le());
2076 }
2077
2078 let host = String::from_utf16(&chars).map_err(|_| {
2079 ProtocolError::StringEncoding(
2080 #[cfg(feature = "std")]
2081 "invalid UTF-16 in routing hostname".to_string(),
2082 #[cfg(not(feature = "std"))]
2083 "invalid UTF-16 in routing hostname",
2084 )
2085 })?;
2086
2087 Ok(EnvChangeValue::Routing { host, port })
2088 }
2089
2090 #[must_use]
2092 pub fn is_routing(&self) -> bool {
2093 self.env_type == EnvChangeType::Routing
2094 }
2095
2096 #[must_use]
2098 pub fn routing_info(&self) -> Option<(&str, u16)> {
2099 if let EnvChangeValue::Routing { host, port } = &self.new_value {
2100 Some((host, *port))
2101 } else {
2102 None
2103 }
2104 }
2105
2106 #[must_use]
2108 pub fn new_database(&self) -> Option<&str> {
2109 if self.env_type == EnvChangeType::Database {
2110 if let EnvChangeValue::String(s) = &self.new_value {
2111 return Some(s);
2112 }
2113 }
2114 None
2115 }
2116}
2117
2118impl Order {
2119 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2121 if src.remaining() < 2 {
2122 return Err(ProtocolError::UnexpectedEof);
2123 }
2124
2125 let length = src.get_u16_le() as usize;
2126 let column_count = length / 2;
2127
2128 if src.remaining() < length {
2129 return Err(ProtocolError::IncompletePacket {
2130 expected: length,
2131 actual: src.remaining(),
2132 });
2133 }
2134
2135 let mut columns = Vec::with_capacity(column_count);
2136 for _ in 0..column_count {
2137 columns.push(src.get_u16_le());
2138 }
2139
2140 Ok(Self { columns })
2141 }
2142}
2143
2144impl FeatureExtAck {
2145 pub const TERMINATOR: u8 = 0xFF;
2147
2148 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2150 let mut features = Vec::new();
2151
2152 loop {
2153 if !src.has_remaining() {
2154 return Err(ProtocolError::UnexpectedEof);
2155 }
2156
2157 let feature_id = src.get_u8();
2158 if feature_id == Self::TERMINATOR {
2159 break;
2160 }
2161
2162 if src.remaining() < 4 {
2163 return Err(ProtocolError::UnexpectedEof);
2164 }
2165
2166 let data_len = src.get_u32_le() as usize;
2167
2168 if src.remaining() < data_len {
2169 return Err(ProtocolError::IncompletePacket {
2170 expected: data_len,
2171 actual: src.remaining(),
2172 });
2173 }
2174
2175 let data = src.copy_to_bytes(data_len);
2176 features.push(FeatureAck { feature_id, data });
2177 }
2178
2179 Ok(Self { features })
2180 }
2181}
2182
2183impl SspiToken {
2184 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2186 if src.remaining() < 2 {
2187 return Err(ProtocolError::UnexpectedEof);
2188 }
2189
2190 let length = src.get_u16_le() as usize;
2191
2192 if src.remaining() < length {
2193 return Err(ProtocolError::IncompletePacket {
2194 expected: length,
2195 actual: src.remaining(),
2196 });
2197 }
2198
2199 let data = src.copy_to_bytes(length);
2200 Ok(Self { data })
2201 }
2202}
2203
2204impl FedAuthInfo {
2205 const ID_STSURL: u8 = 0x01;
2207 const ID_SPN: u8 = 0x02;
2210 const OPT_HEADER_LEN: usize = 9;
2212
2213 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
2225 if src.remaining() < 4 {
2226 return Err(ProtocolError::UnexpectedEof);
2227 }
2228 let token_len = src.get_u32_le() as usize;
2229 if src.remaining() < token_len {
2230 return Err(ProtocolError::UnexpectedEof);
2231 }
2232
2233 let region = src.copy_to_bytes(token_len);
2236 if region.len() < 4 {
2237 return Err(ProtocolError::UnexpectedEof);
2238 }
2239 let count = u32::from_le_bytes([region[0], region[1], region[2], region[3]]) as usize;
2240
2241 let headers_end = count
2245 .checked_mul(Self::OPT_HEADER_LEN)
2246 .and_then(|n| n.checked_add(4))
2247 .ok_or(ProtocolError::UnexpectedEof)?;
2248 if headers_end > region.len() {
2249 return Err(ProtocolError::UnexpectedEof);
2250 }
2251
2252 let mut sts_url = String::new();
2253 let mut spn = String::new();
2254
2255 for i in 0..count {
2256 let h = 4 + i * Self::OPT_HEADER_LEN;
2257 let info_id = region[h];
2258 let data_len =
2259 u32::from_le_bytes([region[h + 1], region[h + 2], region[h + 3], region[h + 4]])
2260 as usize;
2261 let data_off =
2262 u32::from_le_bytes([region[h + 5], region[h + 6], region[h + 7], region[h + 8]])
2263 as usize;
2264
2265 if info_id != Self::ID_SPN && info_id != Self::ID_STSURL {
2268 continue;
2269 }
2270
2271 let data_end = data_off
2272 .checked_add(data_len)
2273 .ok_or(ProtocolError::UnexpectedEof)?;
2274 if data_end > region.len() {
2275 return Err(ProtocolError::UnexpectedEof);
2276 }
2277 if data_len % 2 != 0 {
2278 return Err(ProtocolError::StringEncoding(
2279 #[cfg(feature = "std")]
2280 "FEDAUTHINFO option data has odd length, not UTF-16".to_string(),
2281 #[cfg(not(feature = "std"))]
2282 "FEDAUTHINFO option data has odd length, not UTF-16",
2283 ));
2284 }
2285
2286 let chars: Vec<u16> = region[data_off..data_end]
2287 .chunks_exact(2)
2288 .map(|b| u16::from_le_bytes([b[0], b[1]]))
2289 .collect();
2290 let value = String::from_utf16(&chars).map_err(|_| {
2291 ProtocolError::StringEncoding(
2292 #[cfg(feature = "std")]
2293 "invalid UTF-16 in FEDAUTHINFO option".to_string(),
2294 #[cfg(not(feature = "std"))]
2295 "invalid UTF-16 in FEDAUTHINFO option",
2296 )
2297 })?;
2298
2299 if info_id == Self::ID_SPN {
2300 spn = value;
2301 } else {
2302 sts_url = value;
2303 }
2304 }
2305
2306 Ok(Self { sts_url, spn })
2307 }
2308}
2309
/// Cursor-style decoder that reads TDS tokens one at a time from a byte buffer.
pub struct TokenParser {
    /// Buffer holding the token stream being parsed.
    data: Bytes,
    /// Byte offset in `data` of the next unread token.
    position: usize,
    /// When `true`, COLMETADATA tokens are decoded with `ColMetaData::decode_encrypted`.
    encryption_enabled: bool,
}
2357
2358impl TokenParser {
2359 #[must_use]
2361 pub fn new(data: Bytes) -> Self {
2362 Self {
2363 data,
2364 position: 0,
2365 encryption_enabled: false,
2366 }
2367 }
2368
2369 #[must_use]
2374 pub fn with_encryption(mut self, enabled: bool) -> Self {
2375 self.encryption_enabled = enabled;
2376 self
2377 }
2378
2379 #[must_use]
2381 pub fn remaining(&self) -> usize {
2382 self.data.len().saturating_sub(self.position)
2383 }
2384
2385 #[must_use]
2387 pub fn has_remaining(&self) -> bool {
2388 self.position < self.data.len()
2389 }
2390
2391 #[must_use]
2393 pub fn peek_token_type(&self) -> Option<TokenType> {
2394 if self.position < self.data.len() {
2395 TokenType::from_u8(self.data[self.position])
2396 } else {
2397 None
2398 }
2399 }
2400
2401 pub fn next_token(&mut self) -> Result<Option<Token>, ProtocolError> {
2409 self.next_token_with_metadata(None)
2410 }
2411
2412 pub fn next_token_with_metadata(
2419 &mut self,
2420 metadata: Option<&ColMetaData>,
2421 ) -> Result<Option<Token>, ProtocolError> {
2422 if !self.has_remaining() {
2423 return Ok(None);
2424 }
2425
2426 let mut buf = &self.data[self.position..];
2427 let start_pos = self.position;
2428
2429 let token_type_byte = buf.get_u8();
2430 let token_type = TokenType::from_u8(token_type_byte);
2431
2432 let token = match token_type {
2433 Some(TokenType::Done) => {
2434 let done = Done::decode(&mut buf)?;
2435 Token::Done(done)
2436 }
2437 Some(TokenType::DoneProc) => {
2438 let done = DoneProc::decode(&mut buf)?;
2439 Token::DoneProc(done)
2440 }
2441 Some(TokenType::DoneInProc) => {
2442 let done = DoneInProc::decode(&mut buf)?;
2443 Token::DoneInProc(done)
2444 }
2445 Some(TokenType::Error) => {
2446 let error = ServerError::decode(&mut buf)?;
2447 Token::Error(error)
2448 }
2449 Some(TokenType::Info) => {
2450 let info = ServerInfo::decode(&mut buf)?;
2451 Token::Info(info)
2452 }
2453 Some(TokenType::LoginAck) => {
2454 let login_ack = LoginAck::decode(&mut buf)?;
2455 Token::LoginAck(login_ack)
2456 }
2457 Some(TokenType::EnvChange) => {
2458 let env_change = EnvChange::decode(&mut buf)?;
2459 Token::EnvChange(env_change)
2460 }
2461 Some(TokenType::Order) => {
2462 let order = Order::decode(&mut buf)?;
2463 Token::Order(order)
2464 }
2465 Some(TokenType::FeatureExtAck) => {
2466 let ack = FeatureExtAck::decode(&mut buf)?;
2467 Token::FeatureExtAck(ack)
2468 }
2469 Some(TokenType::Sspi) => {
2470 let sspi = SspiToken::decode(&mut buf)?;
2471 Token::Sspi(sspi)
2472 }
2473 Some(TokenType::FedAuthInfo) => {
2474 let info = FedAuthInfo::decode(&mut buf)?;
2475 Token::FedAuthInfo(info)
2476 }
2477 Some(TokenType::ReturnStatus) => {
2478 if buf.remaining() < 4 {
2479 return Err(ProtocolError::UnexpectedEof);
2480 }
2481 let status = buf.get_i32_le();
2482 Token::ReturnStatus(status)
2483 }
2484 Some(TokenType::ColMetaData) => {
2485 let col_meta = if self.encryption_enabled {
2486 ColMetaData::decode_encrypted(&mut buf)?
2487 } else {
2488 ColMetaData::decode(&mut buf)?
2489 };
2490 Token::ColMetaData(col_meta)
2491 }
2492 Some(TokenType::Row) => {
2493 let meta = metadata.ok_or_else(|| {
2494 ProtocolError::StringEncoding(
2495 #[cfg(feature = "std")]
2496 "Row token requires column metadata".to_string(),
2497 #[cfg(not(feature = "std"))]
2498 "Row token requires column metadata",
2499 )
2500 })?;
2501 let row = RawRow::decode(&mut buf, meta)?;
2502 Token::Row(row)
2503 }
2504 Some(TokenType::NbcRow) => {
2505 let meta = metadata.ok_or_else(|| {
2506 ProtocolError::StringEncoding(
2507 #[cfg(feature = "std")]
2508 "NbcRow token requires column metadata".to_string(),
2509 #[cfg(not(feature = "std"))]
2510 "NbcRow token requires column metadata",
2511 )
2512 })?;
2513 let row = NbcRow::decode(&mut buf, meta)?;
2514 Token::NbcRow(row)
2515 }
2516 Some(TokenType::ReturnValue) => {
2517 let ret_val = ReturnValue::decode(&mut buf)?;
2518 Token::ReturnValue(ret_val)
2519 }
2520 Some(TokenType::SessionState) => {
2521 let session = SessionState::decode(&mut buf)?;
2522 Token::SessionState(session)
2523 }
2524 Some(TokenType::ColInfo) | Some(TokenType::TabName) | Some(TokenType::Offset) => {
2525 if buf.remaining() < 2 {
2528 return Err(ProtocolError::UnexpectedEof);
2529 }
2530 let length = buf.get_u16_le() as usize;
2531 if buf.remaining() < length {
2532 return Err(ProtocolError::IncompletePacket {
2533 expected: length,
2534 actual: buf.remaining(),
2535 });
2536 }
2537 buf.advance(length);
2539 self.position = start_pos + (self.data.len() - start_pos - buf.remaining());
2541 return self.next_token_with_metadata(metadata);
2542 }
2543 None => {
2544 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2545 }
2546 };
2547
2548 let consumed = self.data.len() - start_pos - buf.remaining();
2550 self.position = start_pos + consumed;
2551
2552 Ok(Some(token))
2553 }
2554
2555 pub fn skip_token(&mut self) -> Result<(), ProtocolError> {
2559 if !self.has_remaining() {
2560 return Ok(());
2561 }
2562
2563 let token_type_byte = self.data[self.position];
2564 let token_type = TokenType::from_u8(token_type_byte);
2565
2566 let skip_amount = match token_type {
2568 Some(TokenType::Done) | Some(TokenType::DoneProc) | Some(TokenType::DoneInProc) => {
2570 1 + Done::SIZE }
2572 Some(TokenType::ReturnStatus) => {
2573 1 + 4 }
2575 Some(TokenType::Error)
2577 | Some(TokenType::Info)
2578 | Some(TokenType::LoginAck)
2579 | Some(TokenType::EnvChange)
2580 | Some(TokenType::Order)
2581 | Some(TokenType::Sspi)
2582 | Some(TokenType::ColInfo)
2583 | Some(TokenType::TabName)
2584 | Some(TokenType::Offset)
2585 | Some(TokenType::ReturnValue) => {
2586 if self.remaining() < 3 {
2587 return Err(ProtocolError::UnexpectedEof);
2588 }
2589 let length = u16::from_le_bytes([
2590 self.data[self.position + 1],
2591 self.data[self.position + 2],
2592 ]) as usize;
2593 1 + 2 + length }
2595 Some(TokenType::SessionState) | Some(TokenType::FedAuthInfo) => {
2597 if self.remaining() < 5 {
2598 return Err(ProtocolError::UnexpectedEof);
2599 }
2600 let length = u32::from_le_bytes([
2601 self.data[self.position + 1],
2602 self.data[self.position + 2],
2603 self.data[self.position + 3],
2604 self.data[self.position + 4],
2605 ]) as usize;
2606 1 + 4 + length
2607 }
2608 Some(TokenType::FeatureExtAck) => {
2610 let mut buf = &self.data[self.position + 1..];
2612 let _ = FeatureExtAck::decode(&mut buf)?;
2613 self.data.len() - self.position - buf.remaining()
2614 }
2615 Some(TokenType::ColMetaData) | Some(TokenType::Row) | Some(TokenType::NbcRow) => {
2617 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2618 }
2619 None => {
2620 return Err(ProtocolError::InvalidTokenType(token_type_byte));
2621 }
2622 };
2623
2624 if self.remaining() < skip_amount {
2625 return Err(ProtocolError::UnexpectedEof);
2626 }
2627
2628 self.position += skip_amount;
2629 Ok(())
2630 }
2631
2632 #[must_use]
2634 pub fn position(&self) -> usize {
2635 self.position
2636 }
2637
2638 pub fn reset(&mut self) {
2640 self.position = 0;
2641 }
2642}
2643
2644#[cfg(test)]
2649#[allow(clippy::unwrap_used, clippy::panic)]
2650mod tests {
2651 use super::*;
2652 use bytes::BytesMut;
2653
2654 #[test]
2655 fn test_done_roundtrip() {
2656 let done = Done {
2657 status: DoneStatus {
2658 more: false,
2659 error: false,
2660 in_xact: false,
2661 count: true,
2662 attn: false,
2663 srverror: false,
2664 },
2665 cur_cmd: 193, row_count: 42,
2667 };
2668
2669 let mut buf = BytesMut::new();
2670 done.encode(&mut buf);
2671
2672 let mut cursor = &buf[1..];
2674 let decoded = Done::decode(&mut cursor).unwrap();
2675
2676 assert_eq!(decoded.status.count, done.status.count);
2677 assert_eq!(decoded.cur_cmd, done.cur_cmd);
2678 assert_eq!(decoded.row_count, done.row_count);
2679 }
2680
2681 #[test]
2682 fn test_done_status_bits() {
2683 let status = DoneStatus {
2684 more: true,
2685 error: true,
2686 in_xact: true,
2687 count: true,
2688 attn: false,
2689 srverror: false,
2690 };
2691
2692 let bits = status.to_bits();
2693 let restored = DoneStatus::from_bits(bits);
2694
2695 assert_eq!(status.more, restored.more);
2696 assert_eq!(status.error, restored.error);
2697 assert_eq!(status.in_xact, restored.in_xact);
2698 assert_eq!(status.count, restored.count);
2699 }
2700
2701 #[test]
2702 fn test_token_parser_done() {
2703 let data = Bytes::from_static(&[
2705 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
2710
2711 let mut parser = TokenParser::new(data);
2712 let token = parser.next_token().unwrap().unwrap();
2713
2714 match token {
2715 Token::Done(done) => {
2716 assert!(done.status.count);
2717 assert!(!done.status.more);
2718 assert_eq!(done.cur_cmd, 193);
2719 assert_eq!(done.row_count, 5);
2720 }
2721 _ => panic!("Expected Done token"),
2722 }
2723
2724 assert!(parser.next_token().unwrap().is_none());
2726 }
2727
2728 #[test]
2729 fn test_env_change_type_from_u8() {
2730 assert_eq!(EnvChangeType::from_u8(1), Some(EnvChangeType::Database));
2731 assert_eq!(EnvChangeType::from_u8(20), Some(EnvChangeType::Routing));
2732 assert_eq!(EnvChangeType::from_u8(100), None);
2733 }
2734
2735 #[test]
2742 fn test_env_change_routing_consumes_declared_length() {
2743 let host = "redirect.example";
2744 let host_utf16: Vec<u16> = host.encode_utf16().collect();
2745
2746 let mut data = BytesMut::new();
2747 let routing_len = 1 + 2 + 2 + host_utf16.len() * 2;
2749 let env_len = 1 + 2 + routing_len + 2;
2752 data.put_u16_le(env_len as u16);
2753 data.put_u8(20); data.put_u16_le(routing_len as u16);
2755 data.put_u8(0); data.put_u16_le(11000); data.put_u16_le(host_utf16.len() as u16);
2758 for c in &host_utf16 {
2759 data.put_u16_le(*c);
2760 }
2761 data.put_u16_le(0); data.put_u8(0xFD);
2764
2765 let mut buf: &[u8] = &data;
2766 let env = EnvChange::decode(&mut buf).unwrap();
2767 assert_eq!(env.routing_info(), Some((host, 11000)));
2768 assert_eq!(
2769 buf,
2770 &[0xFD],
2771 "decode must consume exactly the declared ENVCHANGE frame"
2772 );
2773 }
2774
2775 fn put_b_varchar(buf: &mut BytesMut, s: &str) {
2776 let utf16: Vec<u16> = s.encode_utf16().collect();
2777 buf.put_u8(utf16.len() as u8);
2778 for c in utf16 {
2779 buf.put_u16_le(c);
2780 }
2781 }
2782
2783 fn put_us_varchar(buf: &mut BytesMut, s: &str) {
2784 let utf16: Vec<u16> = s.encode_utf16().collect();
2785 buf.put_u16_le(utf16.len() as u16);
2786 for c in utf16 {
2787 buf.put_u16_le(c);
2788 }
2789 }
2790
2791 #[test]
2798 fn test_udt_info_metadata_uses_b_varchar_names() {
2799 let mut data = BytesMut::new();
2800 data.put_u16_le(0xFFFF); put_b_varchar(&mut data, "master");
2802 put_b_varchar(&mut data, "dbo");
2803 put_b_varchar(&mut data, "hierarchyid");
2804 put_us_varchar(
2805 &mut data,
2806 "Microsoft.SqlServer.Types.SqlHierarchyId, Microsoft.SqlServer.Types",
2807 );
2808 data.put_u8(0xFD);
2810
2811 let mut buf: &[u8] = &data;
2812 let info = decode_type_info(&mut buf, TypeId::Udt, TypeId::Udt as u8).unwrap();
2813 assert_eq!(info.max_length, Some(0xFFFF));
2814 assert_eq!(
2815 buf,
2816 &[0xFD],
2817 "decode must consume exactly the UDT_INFO frame"
2818 );
2819 }
2820
2821 #[test]
2825 fn test_xml_info_schema_bound_uses_b_varchar_names() {
2826 let mut data = BytesMut::new();
2827 data.put_u8(1); put_b_varchar(&mut data, "master");
2829 put_b_varchar(&mut data, "dbo");
2830 put_us_varchar(&mut data, "MyXmlSchemaCollection");
2831 data.put_u8(0xFD);
2832
2833 let mut buf: &[u8] = &data;
2834 decode_type_info(&mut buf, TypeId::Xml, TypeId::Xml as u8).unwrap();
2835 assert_eq!(
2836 buf,
2837 &[0xFD],
2838 "decode must consume exactly the XML_INFO frame"
2839 );
2840 }
2841
2842 #[test]
2843 fn hostile_env_change_binary_truncated_is_not_panic() {
2844 let data = [0x01, 0x00, 0x08];
2849 let mut buf: &[u8] = &data;
2850 let env = EnvChange::decode(&mut buf).unwrap();
2851 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2852 }
2853
2854 #[test]
2857 fn hostile_env_change_under_declared_cannot_steal_following_bytes() {
2858 let mut data = BytesMut::new();
2863 data.put_u16_le(1); data.put_u8(0x08); let following: &[u8] = &[0x08, 1, 2, 3, 4, 5, 6, 7, 8, 0x00];
2866 data.extend_from_slice(following);
2867
2868 let mut buf: &[u8] = &data;
2869 let env = EnvChange::decode(&mut buf).unwrap();
2870 assert_eq!(env.env_type, EnvChangeType::BeginTransaction);
2871 match &env.new_value {
2872 EnvChangeValue::Binary(b) => {
2873 assert!(
2874 b.is_empty(),
2875 "under-declared frame yields the lenient empty value"
2876 );
2877 }
2878 other => panic!("expected empty Binary value, got {other:?}"),
2879 }
2880 assert_eq!(
2881 buf, following,
2882 "bytes beyond the declared frame belong to the next token"
2883 );
2884 }
2885
2886 #[test]
2889 fn hostile_env_change_zero_length_frame_errors() {
2890 let data = [0x00, 0x00, 0xFD];
2891 let mut buf: &[u8] = &data;
2892 assert!(EnvChange::decode(&mut buf).is_err());
2893 }
2894
2895 #[test]
2896 fn test_colmetadata_no_columns() {
2897 let data = Bytes::from_static(&[0xFF, 0xFF]);
2899 let mut cursor: &[u8] = &data;
2900 let meta = ColMetaData::decode(&mut cursor).unwrap();
2901 assert!(meta.is_empty());
2902 assert_eq!(meta.column_count(), 0);
2903 }
2904
2905 #[test]
2906 fn test_colmetadata_single_int_column() {
2907 let mut data = BytesMut::new();
2910 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut cursor: &[u8] = &data;
2919 let meta = ColMetaData::decode(&mut cursor).unwrap();
2920
2921 assert_eq!(meta.column_count(), 1);
2922 assert_eq!(meta.columns[0].name, "id");
2923 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
2924 assert!(meta.columns[0].is_nullable());
2925 }
2926
2927 #[test]
2928 fn test_colmetadata_nvarchar_column() {
2929 let mut data = BytesMut::new();
2931 data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0xE7]); data.extend_from_slice(&[0x64, 0x00]); data.extend_from_slice(&[0x09, 0x04, 0xD0, 0x00, 0x34]); data.extend_from_slice(&[0x04]); data.extend_from_slice(&[b'n', 0x00, b'a', 0x00, b'm', 0x00, b'e', 0x00]);
2941
2942 let mut cursor: &[u8] = &data;
2943 let meta = ColMetaData::decode(&mut cursor).unwrap();
2944
2945 assert_eq!(meta.column_count(), 1);
2946 assert_eq!(meta.columns[0].name, "name");
2947 assert_eq!(meta.columns[0].type_id, TypeId::NVarChar);
2948 assert_eq!(meta.columns[0].type_info.max_length, Some(100));
2949 assert!(meta.columns[0].type_info.collation.is_some());
2950 }
2951
2952 #[test]
2953 fn test_raw_row_decode_int() {
2954 let metadata = ColMetaData {
2956 cek_table: None,
2957 columns: vec![ColumnData {
2958 name: "id".to_string(),
2959 type_id: TypeId::Int4,
2960 col_type: 0x38,
2961 flags: 0,
2962 user_type: 0,
2963 type_info: TypeInfo::default(),
2964 crypto_metadata: None,
2965 }],
2966 };
2967
2968 let data = Bytes::from_static(&[0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
2971 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
2972
2973 assert_eq!(row.data.len(), 4);
2975 assert_eq!(&row.data[..], &[0x2A, 0x00, 0x00, 0x00]);
2976 }
2977
2978 #[test]
2979 fn test_raw_row_decode_nullable_int() {
2980 let metadata = ColMetaData {
2982 cek_table: None,
2983 columns: vec![ColumnData {
2984 name: "id".to_string(),
2985 type_id: TypeId::IntN,
2986 col_type: 0x26,
2987 flags: 0x01, user_type: 0,
2989 type_info: TypeInfo {
2990 max_length: Some(4),
2991 ..Default::default()
2992 },
2993 crypto_metadata: None,
2994 }],
2995 };
2996
2997 let data = Bytes::from_static(&[0x04, 0x2A, 0x00, 0x00, 0x00]); let mut cursor: &[u8] = &data;
3000 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
3001
3002 assert_eq!(row.data.len(), 5);
3003 assert_eq!(row.data[0], 4); assert_eq!(&row.data[1..], &[0x2A, 0x00, 0x00, 0x00]);
3005 }
3006
3007 #[test]
3008 fn test_raw_row_decode_null_value() {
3009 let metadata = ColMetaData {
3011 cek_table: None,
3012 columns: vec![ColumnData {
3013 name: "id".to_string(),
3014 type_id: TypeId::IntN,
3015 col_type: 0x26,
3016 flags: 0x01, user_type: 0,
3018 type_info: TypeInfo {
3019 max_length: Some(4),
3020 ..Default::default()
3021 },
3022 crypto_metadata: None,
3023 }],
3024 };
3025
3026 let data = Bytes::from_static(&[0xFF]);
3028 let mut cursor: &[u8] = &data;
3029 let row = RawRow::decode(&mut cursor, &metadata).unwrap();
3030
3031 assert_eq!(row.data.len(), 1);
3032 assert_eq!(row.data[0], 0xFF); }
3034
3035 #[test]
3036 fn test_nbcrow_null_bitmap() {
3037 let row = NbcRow {
3038 null_bitmap: vec![0b00000101], data: Bytes::new(),
3040 };
3041
3042 assert!(row.is_null(0));
3043 assert!(!row.is_null(1));
3044 assert!(row.is_null(2));
3045 assert!(!row.is_null(3));
3046 }
3047
3048 #[test]
3049 fn test_token_parser_colmetadata() {
3050 let mut data = BytesMut::new();
3052 data.extend_from_slice(&[0x81]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x00, 0x00, 0x00, 0x00]); data.extend_from_slice(&[0x01, 0x00]); data.extend_from_slice(&[0x38]); data.extend_from_slice(&[0x02]); data.extend_from_slice(&[b'i', 0x00, b'd', 0x00]); let mut parser = TokenParser::new(data.freeze());
3061 let token = parser.next_token().unwrap().unwrap();
3062
3063 match token {
3064 Token::ColMetaData(meta) => {
3065 assert_eq!(meta.column_count(), 1);
3066 assert_eq!(meta.columns[0].name, "id");
3067 assert_eq!(meta.columns[0].type_id, TypeId::Int4);
3068 }
3069 _ => panic!("Expected ColMetaData token"),
3070 }
3071 }
3072
3073 #[test]
3074 fn test_token_parser_row_with_metadata() {
3075 let metadata = ColMetaData {
3077 cek_table: None,
3078 columns: vec![ColumnData {
3079 name: "id".to_string(),
3080 type_id: TypeId::Int4,
3081 col_type: 0x38,
3082 flags: 0,
3083 user_type: 0,
3084 type_info: TypeInfo::default(),
3085 crypto_metadata: None,
3086 }],
3087 };
3088
3089 let mut data = BytesMut::new();
3091 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3095 let token = parser
3096 .next_token_with_metadata(Some(&metadata))
3097 .unwrap()
3098 .unwrap();
3099
3100 match token {
3101 Token::Row(row) => {
3102 assert_eq!(row.data.len(), 4);
3103 }
3104 _ => panic!("Expected Row token"),
3105 }
3106 }
3107
3108 #[test]
3109 fn test_token_parser_row_without_metadata_fails() {
3110 let mut data = BytesMut::new();
3112 data.extend_from_slice(&[0xD1]); data.extend_from_slice(&[0x2A, 0x00, 0x00, 0x00]); let mut parser = TokenParser::new(data.freeze());
3116 let result = parser.next_token(); assert!(result.is_err());
3119 }
3120
3121 #[test]
3122 fn test_token_parser_peek() {
3123 let data = Bytes::from_static(&[
3124 0xFD, 0x10, 0x00, 0xC1, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]);
3129
3130 let parser = TokenParser::new(data);
3131 assert_eq!(parser.peek_token_type(), Some(TokenType::Done));
3132 }
3133
3134 #[test]
3135 fn test_column_data_fixed_size() {
3136 let col = ColumnData {
3137 name: String::new(),
3138 type_id: TypeId::Int4,
3139 col_type: 0x38,
3140 flags: 0,
3141 user_type: 0,
3142 type_info: TypeInfo::default(),
3143 crypto_metadata: None,
3144 };
3145 assert_eq!(col.fixed_size(), Some(4));
3146
3147 let col2 = ColumnData {
3148 name: String::new(),
3149 type_id: TypeId::NVarChar,
3150 col_type: 0xE7,
3151 flags: 0,
3152 user_type: 0,
3153 type_info: TypeInfo::default(),
3154 crypto_metadata: None,
3155 };
3156 assert_eq!(col2.fixed_size(), None);
3157 }
3158
#[test]
/// Decoding a row whose NVARCHAR column precedes an INTN column must consume
/// the whole wire payload. It must also store each column's length prefix and
/// value back to back, so the stored bytes can be walked in column order.
fn test_decode_nvarchar_then_intn_roundtrip() {
    let text_units: Vec<u16> = "World".encode_utf16().collect();

    let mut wire_data = BytesMut::new();
    // Column 0: u16 byte-length prefix, then the UTF-16LE code units.
    wire_data.put_u16_le((text_units.len() * 2) as u16);
    text_units.iter().for_each(|&unit| wire_data.put_u16_le(unit));
    // Column 1: one-byte length prefix, then the i32 value.
    wire_data.put_u8(4);
    wire_data.put_i32_le(42);

    // Both columns share flags 0x01 (presumably the nullable bit -- confirm)
    // and differ only in name, type and declared width.
    let column = |name: &str, type_id, col_type, max_length| ColumnData {
        name: name.to_string(),
        type_id,
        col_type,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(max_length),
            precision: None,
            scale: None,
            collation: None,
        },
        crypto_metadata: None,
    };
    let metadata = ColMetaData {
        cek_table: None,
        columns: vec![
            column("greeting", TypeId::NVarChar, 0xE7, 10),
            column("number", TypeId::IntN, 0x26, 4),
        ],
    };

    let mut wire = wire_data.freeze();
    let row = RawRow::decode(&mut wire, &metadata).unwrap();
    assert_eq!(wire.remaining(), 0, "wire data should be fully consumed");

    let mut stored: &[u8] = &row.data;

    // Column 0 as stored: length prefix + UTF-16LE text.
    assert!(stored.remaining() >= 2, "need at least 2 bytes for length");
    let len0 = usize::from(stored.get_u16_le());
    assert_eq!(len0, 10, "NVarChar length should be 10 bytes");
    assert!(stored.remaining() >= len0, "need {len0} bytes for data");
    let decoded_units: Vec<u16> = (0..len0 / 2).map(|_| stored.get_u16_le()).collect();
    let first_value = String::from_utf16(&decoded_units).unwrap();
    assert_eq!(first_value, "World", "column 0 should be 'World'");

    // Column 1 as stored: length byte + i32.
    assert!(stored.remaining() >= 1, "need at least 1 byte for length");
    let int_len = stored.get_u8();
    assert_eq!(int_len, 4, "IntN length should be 4");
    assert!(stored.remaining() >= 4, "need 4 bytes for INT data");
    let second_value = stored.get_i32_le();
    assert_eq!(second_value, 42, "column 1 should be 42");

    assert_eq!(stored.remaining(), 0, "stored data should be fully consumed");
}
3276
#[test]
/// An NVARCHAR column declared with max_length 0xFFFF is framed as a PLP
/// value: u64 total length, one u32-prefixed chunk, then a zero u32
/// terminator. The decoded row must keep that framing verbatim, followed by
/// the INTN column, and consume the wire payload completely.
fn test_decode_nvarchar_max_then_intn_roundtrip() {
    let text_units: Vec<u16> = "Hello".encode_utf16().collect();
    let byte_len = (text_units.len() * 2) as u64;

    let mut wire_data = BytesMut::new();
    wire_data.put_u64_le(byte_len); // PLP total length
    wire_data.put_u32_le(byte_len as u32); // single chunk length
    text_units.iter().for_each(|&unit| wire_data.put_u16_le(unit));
    wire_data.put_u32_le(0); // PLP terminator
    wire_data.put_u8(4); // INTN length byte
    wire_data.put_i32_le(99);

    let column = |name: &str, type_id, col_type, max_length| ColumnData {
        name: name.to_string(),
        type_id,
        col_type,
        flags: 0x01,
        user_type: 0,
        type_info: TypeInfo {
            max_length: Some(max_length),
            precision: None,
            scale: None,
            collation: None,
        },
        crypto_metadata: None,
    };
    let metadata = ColMetaData {
        cek_table: None,
        columns: vec![
            column("text", TypeId::NVarChar, 0xE7, 0xFFFF),
            column("num", TypeId::IntN, 0x26, 4),
        ],
    };

    let mut wire = wire_data.freeze();
    let row = RawRow::decode(&mut wire, &metadata).unwrap();
    assert_eq!(wire.remaining(), 0, "wire data should be fully consumed");

    let mut stored: &[u8] = &row.data;

    let total_len = stored.get_u64_le();
    assert_eq!(total_len, 10, "PLP total length should be 10");

    let chunk_len = stored.get_u32_le();
    assert_eq!(chunk_len, 10, "PLP chunk length should be 10");

    let decoded_units: Vec<u16> = (0..chunk_len / 2).map(|_| stored.get_u16_le()).collect();
    let first_value = String::from_utf16(&decoded_units).unwrap();
    assert_eq!(first_value, "Hello", "column 0 should be 'Hello'");

    let terminator = stored.get_u32_le();
    assert_eq!(terminator, 0, "PLP should end with 0");

    let int_len = stored.get_u8();
    assert_eq!(int_len, 4);
    let second_value = stored.get_i32_le();
    assert_eq!(second_value, 99, "column 1 should be 99");

    assert_eq!(stored.remaining(), 0, "stored data should be fully consumed");
}
3380
#[test]
/// A RETURNSTATUS token (0x79) carrying a zero i32 decodes to
/// `Token::ReturnStatus(0)` and is the only token in the stream.
fn test_return_status_via_parser() {
    let mut parser = TokenParser::new(Bytes::from_static(&[
        0x79, // RETURNSTATUS
        0x00, 0x00, 0x00, 0x00, // status value (i32, little-endian)
    ]));

    let token = parser.next_token().unwrap().unwrap();
    let Token::ReturnStatus(status) = token else {
        panic!("Expected ReturnStatus token, got {token:?}");
    };
    assert_eq!(status, 0);

    assert!(parser.next_token().unwrap().is_none());
}
3405
#[test]
/// Negative return codes survive the signed little-endian round trip.
fn test_return_status_nonzero() {
    let mut encoded = BytesMut::new();
    encoded.put_u8(0x79); // RETURNSTATUS
    encoded.put_i32_le(-6);

    let mut parser = TokenParser::new(encoded.freeze());
    let Token::ReturnStatus(code) = parser.next_token().unwrap().unwrap() else {
        panic!("Expected ReturnStatus token");
    };
    assert_eq!(code, -6);
}
3423
#[test]
/// Encoding a DONEPROC writes the 0xFE type byte first. Decoding the bytes
/// after it restores the status flags, current command and row count.
fn test_done_proc_roundtrip() {
    let status = DoneStatus {
        count: true,
        more: false,
        error: false,
        in_xact: false,
        attn: false,
        srverror: false,
    };
    let original = DoneProc {
        status,
        cur_cmd: 0x00C6,
        row_count: 100,
    };

    let mut encoded = BytesMut::new();
    original.encode(&mut encoded);
    assert_eq!(encoded[0], 0xFE);

    // `DoneProc::decode` takes the body only, without the type byte.
    let mut body = &encoded[1..];
    let decoded = DoneProc::decode(&mut body).unwrap();

    assert!(decoded.status.count);
    assert!(!decoded.status.more);
    assert!(!decoded.status.error);
    assert_eq!(decoded.cur_cmd, 0x00C6);
    assert_eq!(decoded.row_count, 100);
}
3459
#[test]
/// A DONEPROC (0xFE) with no status bits set decodes through the parser with
/// its current command (0xC6 = 198) and a zero row count.
fn test_done_proc_via_parser() {
    let mut wire = BytesMut::new();
    wire.put_u8(0xFE);
    wire.put_u16_le(0x0000); // status: no bits set
    wire.put_u16_le(0x00C6); // cur_cmd
    wire.put_u64_le(0); // row count
    let mut parser = TokenParser::new(wire.freeze());

    let Token::DoneProc(done) = parser.next_token().unwrap().unwrap() else {
        panic!("Expected DoneProc token");
    };
    assert!(!done.status.count);
    assert!(!done.status.more);
    assert_eq!(done.cur_cmd, 198);
    assert_eq!(done.row_count, 0);
}
3482
#[test]
/// Status bit 0x0002 on a DONEPROC surfaces as `status.error` and does not
/// leak into the `count` or `more` flags.
fn test_done_proc_with_error_flag() {
    let bytes: Vec<u8> = std::iter::once(0xFE_u8)
        .chain(0x0002_u16.to_le_bytes()) // status: error bit
        .chain(0x00C6_u16.to_le_bytes()) // cur_cmd
        .chain(0_u64.to_le_bytes()) // row count
        .collect();
    let mut parser = TokenParser::new(Bytes::from(bytes));

    let Token::DoneProc(done) = parser.next_token().unwrap().unwrap() else {
        panic!("Expected DoneProc token");
    };
    assert!(done.status.error);
    assert!(!done.status.count);
    assert!(!done.status.more);
}
3503
#[test]
/// Encoding a DONEINPROC writes the 0xFF type byte first. Decoding the rest
/// restores the `more`/`count` flags, current command and row count.
fn test_done_in_proc_roundtrip() {
    let status = DoneStatus {
        more: true,
        count: true,
        error: false,
        in_xact: false,
        attn: false,
        srverror: false,
    };
    let original = DoneInProc {
        status,
        cur_cmd: 193,
        row_count: 7,
    };

    let mut encoded = BytesMut::new();
    original.encode(&mut encoded);
    assert_eq!(encoded[0], 0xFF);

    // Skip the type byte; the decoder consumes only the body.
    let mut body = &encoded[1..];
    let decoded = DoneInProc::decode(&mut body).unwrap();

    assert!(decoded.status.more);
    assert!(decoded.status.count);
    assert!(!decoded.status.error);
    assert_eq!(decoded.cur_cmd, 193);
    assert_eq!(decoded.row_count, 7);
}
3537
#[test]
/// A DONEINPROC (0xFF) whose status 0x0011 combines `more` (0x01) and
/// `count` (0x10) decodes with both flags and a row count of 3.
fn test_done_in_proc_via_parser() {
    let mut raw: Vec<u8> = vec![0xFF];
    raw.extend_from_slice(&0x0011_u16.to_le_bytes()); // status
    raw.extend_from_slice(&0x00C1_u16.to_le_bytes()); // cur_cmd = 193
    raw.extend_from_slice(&3_u64.to_le_bytes()); // row count
    let mut parser = TokenParser::new(Bytes::from(raw));

    let Token::DoneInProc(done) = parser.next_token().unwrap().unwrap() else {
        panic!("Expected DoneInProc token");
    };
    assert!(done.status.more);
    assert!(done.status.count);
    assert_eq!(done.cur_cmd, 193);
    assert_eq!(done.row_count, 3);
}
3560
#[test]
/// Decodes an ERROR token body (everything after the 0xAA type byte). The body
/// is: u16 length, i32 number, u8 state, u8 class, u16-counted UTF-16 message,
/// u8-counted server name, u8-counted procedure name, and an i32 line number.
fn test_server_error_decode() {
    let message: Vec<u16> = "Invalid column name 'foo'.".encode_utf16().collect();
    let server: Vec<u16> = "SQLDB01".encode_utf16().collect();
    let procedure: Vec<u16> = "".encode_utf16().collect();

    // Fixed-width fields: number(4) + state(1) + class(1) + message count(2)
    // + server count(1) + procedure count(1) + line(4); plus 2 bytes per unit.
    let fixed_width = 4 + 1 + 1 + 2 + 1 + 1 + 4;
    let text_width = 2 * (message.len() + server.len() + procedure.len());
    let length = (fixed_width + text_width) as u16;

    let mut buf = BytesMut::new();
    buf.put_u16_le(length);
    buf.put_i32_le(207);
    buf.put_u8(1);
    buf.put_u8(16);
    buf.put_u16_le(message.len() as u16);
    message.iter().for_each(|&unit| buf.put_u16_le(unit));
    buf.put_u8(server.len() as u8);
    server.iter().for_each(|&unit| buf.put_u16_le(unit));
    buf.put_u8(procedure.len() as u8);
    procedure.iter().for_each(|&unit| buf.put_u16_le(unit));
    buf.put_i32_le(42);

    let mut cursor = buf.freeze();
    let decoded = ServerError::decode(&mut cursor).unwrap();

    assert_eq!(decoded.number, 207);
    assert_eq!(decoded.state, 1);
    assert_eq!(decoded.class, 16);
    assert_eq!(decoded.message, "Invalid column name 'foo'.");
    assert_eq!(decoded.server, "SQLDB01");
    assert_eq!(decoded.procedure, "");
    assert_eq!(decoded.line, 42);
}
3626
#[test]
/// Class 20 counts as both fatal and batch-aborting, class 16 only as
/// batch-aborting, and class 10 as neither.
fn test_server_error_severity_helpers() {
    let error_of = |number, state, class, message: &str, line| ServerError {
        number,
        state,
        class,
        message: message.to_string(),
        server: String::new(),
        procedure: String::new(),
        line,
    };

    let fatal = error_of(4014, 1, 20, "Fatal error", 0);
    assert!(fatal.is_fatal());
    assert!(fatal.is_batch_abort());

    let batch_abort = error_of(547, 0, 16, "Constraint violation", 1);
    assert!(!batch_abort.is_fatal());
    assert!(batch_abort.is_batch_abort());

    let informational = error_of(5701, 2, 10, "Changed db context", 0);
    assert!(!informational.is_fatal());
    assert!(!informational.is_batch_abort());
}
3665
#[test]
/// A complete ERROR token (0xAA plus body) with non-empty server and
/// procedure names decodes through the parser into `Token::Error`.
fn test_server_error_via_parser() {
    let msg: Vec<u16> = "Syntax error".encode_utf16().collect();
    let srv: Vec<u16> = "SRV".encode_utf16().collect();
    let proc_name: Vec<u16> = "sp_test".encode_utf16().collect();

    // Fixed-width fields plus 2 bytes per UTF-16 unit of text.
    let body_len = 4 + 1 + 1 + 2 + 1 + 1 + 4 + 2 * (msg.len() + srv.len() + proc_name.len());

    let mut buf = BytesMut::new();
    buf.put_u8(0xAA);
    buf.put_u16_le(body_len as u16);
    buf.put_i32_le(102); // number
    buf.put_u8(1); // state
    buf.put_u8(15); // class
    buf.put_u16_le(msg.len() as u16);
    msg.iter().for_each(|&unit| buf.put_u16_le(unit));
    buf.put_u8(srv.len() as u8);
    srv.iter().for_each(|&unit| buf.put_u16_le(unit));
    buf.put_u8(proc_name.len() as u8);
    proc_name.iter().for_each(|&unit| buf.put_u16_le(unit));
    buf.put_i32_le(5); // line

    let mut parser = TokenParser::new(buf.freeze());
    let Token::Error(err) = parser.next_token().unwrap().unwrap() else {
        panic!("Expected Error token");
    };
    assert_eq!(err.number, 102);
    assert_eq!(err.class, 15);
    assert_eq!(err.message, "Syntax error");
    assert_eq!(err.server, "SRV");
    assert_eq!(err.procedure, "sp_test");
    assert_eq!(err.line, 5);
}
3721
3722 fn build_return_value_intn(
3729 ordinal: u16,
3730 name: &str,
3731 status: u8,
3732 value: Option<i32>,
3733 ) -> BytesMut {
3734 let mut inner = BytesMut::new();
3735
3736 inner.put_u16_le(ordinal);
3738
3739 let name_utf16: Vec<u16> = name.encode_utf16().collect();
3741 inner.put_u8(name_utf16.len() as u8);
3742 for &c in &name_utf16 {
3743 inner.put_u16_le(c);
3744 }
3745
3746 inner.put_u8(status);
3748
3749 inner.put_u32_le(0);
3751
3752 inner.put_u16_le(0x0001); inner.put_u8(0x26);
3757
3758 inner.put_u8(4);
3760
3761 match value {
3763 Some(v) => {
3764 inner.put_u8(4); inner.put_i32_le(v);
3766 }
3767 None => {
3768 inner.put_u8(0); }
3770 }
3771
3772 inner
3775 }
3776
#[test]
/// An INTN output parameter's raw `value` keeps its one-byte length prefix
/// followed by the little-endian i32.
fn test_return_value_int_output() {
    let mut cursor = build_return_value_intn(1, "@result", 0x01, Some(42)).freeze();
    let rv = ReturnValue::decode(&mut cursor).unwrap();

    assert_eq!(rv.param_ordinal, 1);
    assert_eq!(rv.param_name, "@result");
    assert_eq!(rv.status, 0x01);
    assert_eq!(rv.col_type, 0x26);
    assert_eq!(rv.type_info.max_length, Some(4));
    assert_eq!(rv.value.len(), 5);

    let (prefix, payload) = rv.value.split_at(1);
    assert_eq!(prefix[0], 4);
    let le_bytes = <[u8; 4]>::try_from(payload).unwrap();
    assert_eq!(i32::from_le_bytes(le_bytes), 42);
}
3796
#[test]
/// A NULL output parameter is represented by a lone zero length byte.
fn test_return_value_null_output() {
    let rv = ReturnValue::decode(&mut build_return_value_intn(2, "@count", 0x01, None).freeze())
        .unwrap();

    assert_eq!(rv.param_ordinal, 2);
    assert_eq!(rv.param_name, "@count");
    assert_eq!(rv.status, 0x01);
    assert_eq!(rv.col_type, 0x26);
    assert_eq!(rv.value.len(), 1);
    assert_eq!(rv.value.first(), Some(&0));
}
3811
#[test]
/// `@RETURN_VALUE` with status 0x02 at ordinal 0 decodes with its status and
/// a negative value intact.
fn test_return_value_udf_status() {
    let mut cursor = build_return_value_intn(0, "@RETURN_VALUE", 0x02, Some(-1)).freeze();
    let rv = ReturnValue::decode(&mut cursor).unwrap();

    assert_eq!(rv.param_ordinal, 0);
    assert_eq!(rv.param_name, "@RETURN_VALUE");
    assert_eq!(rv.status, 0x02);
    assert_eq!(rv.value[0], 4);
    let payload = [rv.value[1], rv.value[2], rv.value[3], rv.value[4]];
    assert_eq!(i32::from_le_bytes(payload), -1);
}
3828
#[test]
/// A RETURNVALUE for an NVARCHAR output parameter carries a 5-byte collation
/// in its type info and a u16-length-prefixed UTF-16LE value. The raw `value`
/// must hold the prefix plus the exact "Hello" payload, and the decoder must
/// consume the whole body.
fn test_return_value_nvarchar_output() {
    let mut inner = BytesMut::new();

    inner.put_u16_le(1); // param ordinal

    let name_utf16: Vec<u16> = "@name".encode_utf16().collect();
    inner.put_u8(name_utf16.len() as u8);
    for &c in &name_utf16 {
        inner.put_u16_le(c);
    }

    inner.put_u8(0x01); // status
    inner.put_u32_le(0);
    inner.put_u16_le(0x0001);
    inner.put_u8(0xE7); // NVarChar column type
    inner.put_u16_le(200); // max length in bytes
    inner.put_u32_le(0x0904D000); // collation, first 4 bytes
    inner.put_u8(0x34); // collation, fifth byte

    let val_utf16: Vec<u16> = "Hello".encode_utf16().collect();
    let byte_len = (val_utf16.len() * 2) as u16;
    inner.put_u16_le(byte_len);
    for &c in &val_utf16 {
        inner.put_u16_le(c);
    }

    let mut cursor = inner.freeze();
    let rv = ReturnValue::decode(&mut cursor).unwrap();
    // Leftover bytes would be misread as the next token in a real stream.
    assert_eq!(cursor.remaining(), 0);

    assert_eq!(rv.param_ordinal, 1);
    assert_eq!(rv.param_name, "@name");
    assert_eq!(rv.status, 0x01);
    assert_eq!(rv.col_type, 0xE7);
    assert_eq!(rv.type_info.max_length, Some(200));
    assert!(rv.type_info.collation.is_some());

    // 2-byte length prefix + 10 bytes of UTF-16LE text.
    assert_eq!(rv.value.len(), 12);
    let val_len = u16::from_le_bytes([rv.value[0], rv.value[1]]);
    assert_eq!(val_len, 10);

    // Verify the payload itself, not just its length prefix.
    let payload: Vec<u16> = rv.value[2..]
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
        .collect();
    assert_eq!(String::from_utf16(&payload).unwrap(), "Hello");
}
3880
#[test]
/// The parser routes the 0xAC type byte to the RETURNVALUE decoder.
fn test_return_value_via_parser() {
    let body = build_return_value_intn(0, "@out", 0x01, Some(99));
    let mut data = BytesMut::with_capacity(1 + body.len());
    data.put_u8(0xAC);
    data.extend_from_slice(&body);

    let mut parser = TokenParser::new(data.freeze());
    let Token::ReturnValue(rv) = parser.next_token().unwrap().unwrap() else {
        panic!("Expected ReturnValue token");
    };
    assert_eq!(rv.param_name, "@out");
    assert_eq!(rv.param_ordinal, 0);
    assert_eq!(rv.status, 0x01);
    assert_eq!(rv.col_type, 0x26);
}
3901
#[test]
/// A stored-procedure tail of DONEINPROC (count bit, 3 rows), RETURNSTATUS 0
/// and DONEPROC decodes token by token in order, after which the stream is
/// empty.
fn test_multi_token_stored_proc_response() {
    let mut data = BytesMut::new();

    // DONEINPROC: status 0x0010 (count), cur_cmd 0xC1, 3 rows.
    data.put_u8(0xFF);
    data.put_u16_le(0x0010);
    data.put_u16_le(0x00C1);
    data.put_u64_le(3);

    // RETURNSTATUS 0.
    data.put_u8(0x79);
    data.put_i32_le(0);

    // DONEPROC: no status bits, cur_cmd 0xC6, 0 rows.
    data.put_u8(0xFE);
    data.put_u16_le(0x0000);
    data.put_u16_le(0x00C6);
    data.put_u64_le(0);

    let mut parser = TokenParser::new(data.freeze());

    let t1 = parser.next_token().unwrap().unwrap();
    let Token::DoneInProc(done) = &t1 else {
        panic!("Expected DoneInProc, got {t1:?}");
    };
    assert!(done.status.count);
    assert_eq!(done.row_count, 3);
    assert_eq!(done.cur_cmd, 193);

    let t2 = parser.next_token().unwrap().unwrap();
    let Token::ReturnStatus(status) = &t2 else {
        panic!("Expected ReturnStatus, got {t2:?}");
    };
    assert_eq!(*status, 0);

    let t3 = parser.next_token().unwrap().unwrap();
    let Token::DoneProc(done) = &t3 else {
        panic!("Expected DoneProc, got {t3:?}");
    };
    assert!(!done.status.count);
    assert!(!done.status.more);
    assert_eq!(done.cur_cmd, 198);

    assert!(parser.next_token().unwrap().is_none());
}
3964
#[test]
/// An ERROR token followed by a DONE with the error bit (0x0002) set: both
/// decode in order, and the stream is exhausted afterwards.
fn test_multi_token_error_in_stream() {
    let msg_utf16: Vec<u16> = "Deadlock".encode_utf16().collect();
    let srv_utf16: Vec<u16> = "DB1".encode_utf16().collect();
    // Fixed-width fields (the procedure name is empty) plus the UTF-16 text.
    let length = (4 + 1 + 1 + 2 + 1 + 1 + 4 + 2 * (msg_utf16.len() + srv_utf16.len())) as u16;

    let mut data = BytesMut::new();

    // ERROR token.
    data.put_u8(0xAA);
    data.put_u16_le(length);
    data.put_i32_le(1205); // number
    data.put_u8(51); // state
    data.put_u8(13); // class
    data.put_u16_le(msg_utf16.len() as u16);
    msg_utf16.iter().for_each(|&unit| data.put_u16_le(unit));
    data.put_u8(srv_utf16.len() as u8);
    srv_utf16.iter().for_each(|&unit| data.put_u16_le(unit));
    data.put_u8(0); // empty procedure name
    data.put_i32_le(0); // line

    // DONE with the error bit set.
    data.put_u8(0xFD);
    data.put_u16_le(0x0002);
    data.put_u16_le(0x00C1);
    data.put_u64_le(0);

    let mut parser = TokenParser::new(data.freeze());

    let t1 = parser.next_token().unwrap().unwrap();
    let Token::Error(err) = &t1 else {
        panic!("Expected Error token, got {t1:?}");
    };
    assert_eq!(err.number, 1205);
    assert_eq!(err.class, 13);
    assert_eq!(err.message, "Deadlock");
    assert_eq!(err.server, "DB1");

    let t2 = parser.next_token().unwrap().unwrap();
    let Token::Done(done) = &t2 else {
        panic!("Expected Done token, got {t2:?}");
    };
    assert!(done.status.error);
    assert!(!done.status.count);

    assert!(parser.next_token().unwrap().is_none());
}
4030
#[test]
/// RETURNVALUE, RETURNSTATUS and DONEPROC back to back: each decoder must
/// stop exactly at its token boundary so the next one parses cleanly.
fn test_multi_token_proc_with_return_value() {
    let mut data = BytesMut::new();

    data.put_u8(0xAC);
    data.extend_from_slice(&build_return_value_intn(1, "@result", 0x01, Some(42)));

    data.put_u8(0x79);
    data.put_i32_le(0);

    data.put_u8(0xFE);
    data.put_u16_le(0x0000);
    data.put_u16_le(0x00C6);
    data.put_u64_le(0);

    let mut parser = TokenParser::new(data.freeze());

    let t1 = parser.next_token().unwrap().unwrap();
    let Token::ReturnValue(rv) = &t1 else {
        panic!("Expected ReturnValue, got {t1:?}");
    };
    assert_eq!(rv.param_name, "@result");
    assert_eq!(rv.param_ordinal, 1);

    let second = parser.next_token().unwrap().unwrap();
    assert!(matches!(second, Token::ReturnStatus(0)));

    let third = parser.next_token().unwrap().unwrap();
    assert!(matches!(third, Token::DoneProc(_)));

    assert!(parser.next_token().unwrap().is_none());
}
4069
#[test]
/// RETURNSTATUS (0x79) needs a 4-byte value; only three bytes follow here.
fn test_return_status_truncated() {
    let short = Bytes::from_static(&[0x79, 0x01, 0x02, 0x03]);
    assert!(TokenParser::new(short).next_token().is_err());
}
4081
#[test]
/// DONEPROC (0xFE) needs 12 body bytes (status, cur_cmd, u64 row count);
/// the row count here is cut short.
fn test_done_proc_truncated() {
    let short = Bytes::from_static(&[0xFE, 0x00, 0x00, 0xC1, 0x00, 0x01, 0x00, 0x00, 0x00]);
    assert!(TokenParser::new(short).next_token().is_err());
}
4089
#[test]
/// An ERROR token (0xAA) declaring a 32-byte body but supplying none must be
/// rejected.
fn test_server_error_truncated() {
    let short = Bytes::from_static(&[0xAA, 0x20, 0x00]);
    assert!(TokenParser::new(short).next_token().is_err());
}
4097
/// Builds a complete FEDAUTHINFO token (type byte 0xEE) from `(id, value)`
/// pairs.
///
/// Layout: u32 token length, u32 option count, one 9-byte header per option
/// (u8 id, u32 data length, u32 data offset), then the UTF-16LE option values
/// concatenated in order. Offsets and the token length are measured from the
/// start of the option-count field, i.e. just after the token length itself.
fn build_fed_auth_info_token(options: &[(u8, &str)]) -> Vec<u8> {
    const OPTION_HEADER_LEN: usize = 1 + 4 + 4;
    let data_start = 4 + options.len() * OPTION_HEADER_LEN;

    let encoded_values: Vec<Vec<u8>> = options
        .iter()
        .map(|(_, value)| value.encode_utf16().flat_map(u16::to_le_bytes).collect())
        .collect();

    let mut headers = Vec::new();
    let mut next_offset = data_start;
    for ((id, _), encoded) in options.iter().zip(&encoded_values) {
        headers.push(*id);
        headers.extend_from_slice(&u32::try_from(encoded.len()).unwrap().to_le_bytes());
        headers.extend_from_slice(&u32::try_from(next_offset).unwrap().to_le_bytes());
        next_offset += encoded.len();
    }
    let data_block = encoded_values.concat();

    let token_len = 4 + headers.len() + data_block.len();
    let mut out = Vec::with_capacity(1 + 4 + token_len);
    out.push(0xEE);
    out.extend_from_slice(&u32::try_from(token_len).unwrap().to_le_bytes());
    out.extend_from_slice(&u32::try_from(options.len()).unwrap().to_le_bytes());
    out.extend_from_slice(&headers);
    out.extend_from_slice(&data_block);
    out
}
4127
#[test]
/// Option id 0x01 maps to `sts_url` and 0x02 to `spn`, and the token is
/// consumed exactly.
fn test_fed_auth_info_decodes_spec_layout() {
    const STS: &str = "https://login.microsoftonline.com/common";
    const SPN: &str = "https://database.windows.net/";

    let bytes = build_fed_auth_info_token(&[(0x01, STS), (0x02, SPN)]);
    let mut parser = TokenParser::new(Bytes::from(bytes));

    let info = match parser.next_token().unwrap().unwrap() {
        Token::FedAuthInfo(info) => info,
        parsed => panic!("expected FedAuthInfo, got {parsed:?}"),
    };
    assert_eq!(info.sts_url, STS);
    assert_eq!(info.spn, SPN);
    assert!(parser.next_token().unwrap().is_none(), "exact consumption");
}
4146
#[test]
/// A DONE token that directly follows FEDAUTHINFO must still be decoded,
/// which proves the FEDAUTHINFO decoder stops at its declared length.
fn test_fed_auth_info_preserves_following_tokens() {
    let mut stream = build_fed_auth_info_token(&[
        (0x01, "https://sts.example/"),
        (0x02, "https://db.example/"),
    ]);
    // Bare DONE: type byte, then zero status, zero cur_cmd, zero u64 row count.
    stream.push(0xFD);
    stream.extend_from_slice(&[0u8; 2 + 2 + 8]);

    let mut parser = TokenParser::new(Bytes::from(stream));

    let first = parser.next_token().unwrap();
    assert!(matches!(first, Some(Token::FedAuthInfo(_))));

    let second = parser.next_token().unwrap();
    assert!(
        matches!(second, Some(Token::Done(_))),
        "DONE after FEDAUTHINFO must not be swallowed"
    );

    assert!(parser.next_token().unwrap().is_none());
}
4172
#[test]
/// Unrecognised option ids (0x7F here) are skipped. Known ones are still
/// picked up, and an absent SPN option leaves `spn` empty.
fn test_fed_auth_info_unknown_ids_ignored() {
    let bytes = build_fed_auth_info_token(&[(0x7F, "ignore-me"), (0x01, "https://sts.example/")]);
    let mut parser = TokenParser::new(Bytes::from(bytes));

    let info = match parser.next_token().unwrap() {
        Some(Token::FedAuthInfo(info)) => info,
        _ => panic!("expected FedAuthInfo"),
    };
    assert_eq!(info.sts_url, "https://sts.example/");
    assert_eq!(info.spn, "");
}
4185
#[test]
/// Malformed FEDAUTHINFO tokens must be rejected, not decoded. The cases are:
/// a body shorter than its declared length, a huge option count, an
/// out-of-range data offset, and an odd (non-UTF-16) data length.
fn test_fed_auth_info_hostile_inputs_error() {
    // Byte layout of the one-option token built below:
    // [0] type, [1..5] token length, [5..9] option count,
    // [9] option id, [10..14] data length, [14..18] data offset.
    let well_formed = || build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
    let rejected = |bytes: Vec<u8>| TokenParser::new(Bytes::from(bytes)).next_token().is_err();
    let with_field = |range: std::ops::Range<usize>, value: u32| {
        let mut bytes = well_formed();
        bytes[range].copy_from_slice(&value.to_le_bytes());
        bytes
    };

    let mut truncated = well_formed();
    truncated.truncate(truncated.len() - 4);
    assert!(rejected(truncated));

    assert!(rejected(with_field(5..9, u32::MAX)));
    assert!(rejected(with_field(14..18, u32::MAX)));
    assert!(rejected(with_field(10..14, 3)));
}
4221
#[test]
/// Decoding and skipping a FEDAUTHINFO token must both advance the parser by
/// the full token length.
fn test_fed_auth_info_parse_and_skip_agree() {
    let token = build_fed_auth_info_token(&[(0x02, "https://sts.example/")]);
    let total = token.len();

    let mut parser = TokenParser::new(Bytes::from(token.clone()));
    // The decoded value used to be discarded unchecked. Make sure the bytes
    // were understood as FEDAUTHINFO before comparing consumption.
    assert!(matches!(
        parser.next_token().unwrap(),
        Some(Token::FedAuthInfo(_))
    ));
    assert_eq!(parser.position(), total, "decode consumption");

    let mut skipper = TokenParser::new(Bytes::from(token));
    skipper.skip_token().unwrap();
    assert_eq!(skipper.position(), total, "skip consumption");
}
4237
#[test]
/// Decodes FEDAUTHINFO bytes labelled as captured from Azure. The declared
/// token length is 0xCC and there are two options. The SPN (id 0x02, 58
/// bytes at offset 0x16) comes before the STS URL (id 0x01, 124 bytes at
/// offset 0x50). The tenant GUID in the URL appears zeroed. The token must
/// be consumed exactly.
fn test_fed_auth_info_captured_from_azure() {
    const CAPTURED: &[u8] = &[
        0xEE, 0xCC, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x3A, 0x00, 0x00, 0x00,
        0x16, 0x00, 0x00, 0x00, 0x01, 0x7C, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x68,
        0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F, 0x00, 0x2F,
        0x00, 0x64, 0x00, 0x61, 0x00, 0x74, 0x00, 0x61, 0x00, 0x62, 0x00, 0x61, 0x00, 0x73,
        0x00, 0x65, 0x00, 0x2E, 0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F,
        0x00, 0x77, 0x00, 0x73, 0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F,
        0x00, 0x68, 0x00, 0x74, 0x00, 0x74, 0x00, 0x70, 0x00, 0x73, 0x00, 0x3A, 0x00, 0x2F,
        0x00, 0x2F, 0x00, 0x6C, 0x00, 0x6F, 0x00, 0x67, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x2E,
        0x00, 0x77, 0x00, 0x69, 0x00, 0x6E, 0x00, 0x64, 0x00, 0x6F, 0x00, 0x77, 0x00, 0x73,
        0x00, 0x2E, 0x00, 0x6E, 0x00, 0x65, 0x00, 0x74, 0x00, 0x2F, 0x00, 0x30, 0x00, 0x30,
        0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D,
        0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30,
        0x00, 0x30, 0x00, 0x30, 0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
        0x00, 0x2D, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30,
        0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00, 0x30, 0x00,
    ];

    let mut parser = TokenParser::new(Bytes::from_static(CAPTURED));
    let info = match parser.next_token().unwrap() {
        Some(Token::FedAuthInfo(info)) => info,
        _ => panic!("expected FedAuthInfo"),
    };

    assert_eq!(
        info.sts_url,
        "https://login.windows.net/00000000-0000-0000-0000-000000000000"
    );
    assert_eq!(info.spn, "https://database.windows.net/");

    let trailing = parser.next_token().unwrap();
    assert!(
        trailing.is_none(),
        "the captured token must be consumed exactly"
    );
}
4282}