1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::crypto::EncryptionTypeWire;
30use crate::prelude::*;
31use crate::token::Collation;
32
/// Well-known system stored procedures that an RPC request can invoke by
/// numeric ID instead of by name.
///
/// When a request targets a `ProcId`, it is encoded as the `0xFFFF` marker
/// followed by the variant's discriminant as a little-endian `u16`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 0x0001,
    /// `sp_cursoropen`
    CursorOpen = 0x0002,
    /// `sp_cursorprepare`
    CursorPrepare = 0x0003,
    /// `sp_cursorexecute`
    CursorExecute = 0x0004,
    /// `sp_cursorprepexec`
    CursorPrepExec = 0x0005,
    /// `sp_cursorunprepare`
    CursorUnprepare = 0x0006,
    /// `sp_cursorfetch`
    CursorFetch = 0x0007,
    /// `sp_cursoroption`
    CursorOption = 0x0008,
    /// `sp_cursorclose`
    CursorClose = 0x0009,
    /// `sp_executesql`
    ExecuteSql = 0x000A,
    /// `sp_prepare`
    Prepare = 0x000B,
    /// `sp_execute`
    Execute = 0x000C,
    /// `sp_prepexec`
    PrepExec = 0x000D,
    /// `sp_prepexecrpc`
    PrepExecRpc = 0x000E,
    /// `sp_unprepare`
    Unprepare = 0x000F,
}
72
/// Option flags written immediately after the procedure name/ID of an RPC
/// request.
///
/// [`RpcOptionFlags::encode`] packs the fields into the `u16` bit mask that is
/// sent on the wire.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// `fWithRecomp` (bit `0x0001`): request recompilation of the procedure.
    pub with_recompile: bool,
    /// `fNoMetaData` (bit `0x0002`).
    pub no_metadata: bool,
    /// `fReuseMetaData` (bit `0x0004`).
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Creates a flag set with every option cleared (same as `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets or clears the `with_recompile` option.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Sets or clears the `no_metadata` option.
    #[must_use]
    pub fn no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Sets or clears the `reuse_metadata` option.
    #[must_use]
    pub fn reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Packs the options into their wire bit mask:
    /// `with_recompile` = `0x0001`, `no_metadata` = `0x0002`,
    /// `reuse_metadata` = `0x0004`.
    pub fn encode(&self) -> u16 {
        u16::from(self.with_recompile)
            | (u16::from(self.no_metadata) << 1)
            | (u16::from(self.reuse_metadata) << 2)
    }
}
112
/// Per-parameter status flags, written as a single byte between the
/// parameter name and its TYPE_INFO.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// `fByRefValue` (bit `0x01`): output (by-reference) parameter.
    pub by_ref: bool,
    /// `fDefaultValue` (bit `0x02`).
    pub default: bool,
    /// `fEncrypted` (bit `0x08`): the value is Always Encrypted ciphertext.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with nothing set; equivalent to `ParamFlags::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags with the output (`by_ref`) bit set.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their wire byte (`by_ref` = `0x01`,
    /// `default` = `0x02`, `encrypted` = `0x08`).
    pub fn encode(&self) -> u8 {
        let table = [
            (self.by_ref, 0x01u8),
            (self.default, 0x02),
            (self.encrypted, 0x08),
        ];
        table
            .iter()
            .filter(|(set, _)| *set)
            .fold(0u8, |acc, (_, bit)| acc | bit)
    }
}
152
/// Declared SQL type of an RPC parameter; `TypeInfo::encode` serializes it as
/// the TDS TYPE_INFO that precedes the parameter value.
///
/// Only the fields relevant to `type_id` carry meaning. The constructors on
/// this type leave every other field as `None`.
#[derive(Debug, Clone)]
pub struct TypeInfo {
    /// TDS data-type token, e.g. `0x26` (INTN) or `0xE7` (NVARCHAR).
    pub type_id: u8,
    /// Maximum length in bytes. For the variable-length types, `Some(0xFFFF)`
    /// marks a MAX type whose value is sent with PLP encoding.
    pub max_length: Option<u16>,
    /// Decimal precision (total number of digits).
    pub precision: Option<u8>,
    /// Decimal scale, or fractional-second scale for time-based types.
    pub scale: Option<u8>,
    /// Five-byte collation for character types.
    pub collation: Option<[u8; 5]>,
    /// Table type name, used for table-valued parameters (`type_id == 0xF3`).
    pub tvp_type_name: Option<String>,
}
169
170impl TypeInfo {
171 pub fn int() -> Self {
173 Self {
174 type_id: 0x26, max_length: Some(4),
176 precision: None,
177 scale: None,
178 collation: None,
179 tvp_type_name: None,
180 }
181 }
182
183 pub fn bigint() -> Self {
185 Self {
186 type_id: 0x26, max_length: Some(8),
188 precision: None,
189 scale: None,
190 collation: None,
191 tvp_type_name: None,
192 }
193 }
194
195 pub fn smallint() -> Self {
197 Self {
198 type_id: 0x26, max_length: Some(2),
200 precision: None,
201 scale: None,
202 collation: None,
203 tvp_type_name: None,
204 }
205 }
206
207 pub fn tinyint() -> Self {
209 Self {
210 type_id: 0x26, max_length: Some(1),
212 precision: None,
213 scale: None,
214 collation: None,
215 tvp_type_name: None,
216 }
217 }
218
219 pub fn bit() -> Self {
221 Self {
222 type_id: 0x68, max_length: Some(1),
224 precision: None,
225 scale: None,
226 collation: None,
227 tvp_type_name: None,
228 }
229 }
230
231 pub fn float() -> Self {
233 Self {
234 type_id: 0x6D, max_length: Some(8),
236 precision: None,
237 scale: None,
238 collation: None,
239 tvp_type_name: None,
240 }
241 }
242
243 pub fn real() -> Self {
245 Self {
246 type_id: 0x6D, max_length: Some(4),
248 precision: None,
249 scale: None,
250 collation: None,
251 tvp_type_name: None,
252 }
253 }
254
255 pub fn nvarchar(max_len: u16) -> Self {
257 Self {
258 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
261 scale: None,
262 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
264 tvp_type_name: None,
265 }
266 }
267
268 pub fn nvarchar_max() -> Self {
270 Self {
271 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
274 scale: None,
275 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
276 tvp_type_name: None,
277 }
278 }
279
280 const DEFAULT_COLLATION: [u8; 5] = [0x09, 0x04, 0xD0, 0x00, 0x34];
282
283 pub fn varchar(max_len: u16) -> Self {
285 Self::varchar_with_collation(max_len, Self::DEFAULT_COLLATION)
286 }
287
288 pub fn varchar_with_collation(max_len: u16, collation: [u8; 5]) -> Self {
290 Self {
291 type_id: 0xA7, max_length: Some(max_len),
293 precision: None,
294 scale: None,
295 collation: Some(collation),
296 tvp_type_name: None,
297 }
298 }
299
300 pub fn varchar_max() -> Self {
302 Self::varchar_max_with_collation(Self::DEFAULT_COLLATION)
303 }
304
305 pub fn varchar_max_with_collation(collation: [u8; 5]) -> Self {
307 Self {
308 type_id: 0xA7, max_length: Some(0xFFFF), precision: None,
311 scale: None,
312 collation: Some(collation),
313 tvp_type_name: None,
314 }
315 }
316
317 pub fn varbinary(max_len: u16) -> Self {
319 Self {
320 type_id: 0xA5, max_length: Some(max_len),
322 precision: None,
323 scale: None,
324 collation: None,
325 tvp_type_name: None,
326 }
327 }
328
329 pub fn varbinary_max() -> Self {
331 Self {
332 type_id: 0xA5, max_length: Some(0xFFFF), precision: None,
335 scale: None,
336 collation: None,
337 tvp_type_name: None,
338 }
339 }
340
341 pub fn uniqueidentifier() -> Self {
343 Self {
344 type_id: 0x24, max_length: Some(16),
346 precision: None,
347 scale: None,
348 collation: None,
349 tvp_type_name: None,
350 }
351 }
352
353 pub fn uuid() -> Self {
355 Self {
356 type_id: 0x24, max_length: Some(16),
358 precision: None,
359 scale: None,
360 collation: None,
361 tvp_type_name: None,
362 }
363 }
364
365 pub fn date() -> Self {
367 Self {
368 type_id: 0x28, max_length: None,
370 precision: None,
371 scale: None,
372 collation: None,
373 tvp_type_name: None,
374 }
375 }
376
377 pub fn time(scale: u8) -> Self {
379 Self {
380 type_id: 0x29, max_length: None,
382 precision: None,
383 scale: Some(scale),
384 collation: None,
385 tvp_type_name: None,
386 }
387 }
388
389 pub fn datetime2(scale: u8) -> Self {
391 Self {
392 type_id: 0x2A, max_length: None,
394 precision: None,
395 scale: Some(scale),
396 collation: None,
397 tvp_type_name: None,
398 }
399 }
400
401 pub fn datetimeoffset(scale: u8) -> Self {
403 Self {
404 type_id: 0x2B, max_length: None,
406 precision: None,
407 scale: Some(scale),
408 collation: None,
409 tvp_type_name: None,
410 }
411 }
412
413 pub fn decimal(precision: u8, scale: u8) -> Self {
415 Self {
416 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
419 scale: Some(scale),
420 collation: None,
421 tvp_type_name: None,
422 }
423 }
424
425 pub fn money() -> Self {
427 Self {
428 type_id: 0x6E, max_length: Some(8),
430 precision: None,
431 scale: None,
432 collation: None,
433 tvp_type_name: None,
434 }
435 }
436
437 pub fn smallmoney() -> Self {
439 Self {
440 type_id: 0x6E, max_length: Some(4),
442 precision: None,
443 scale: None,
444 collation: None,
445 tvp_type_name: None,
446 }
447 }
448
449 pub fn smalldatetime() -> Self {
451 Self {
452 type_id: 0x6F, max_length: Some(4),
454 precision: None,
455 scale: None,
456 collation: None,
457 tvp_type_name: None,
458 }
459 }
460
461 pub fn tvp(type_name: impl Into<String>) -> Self {
466 Self {
467 type_id: 0xF3, max_length: None,
469 precision: None,
470 scale: None,
471 collation: None,
472 tvp_type_name: Some(type_name.into()),
473 }
474 }
475
476 pub fn encode(&self, buf: &mut BytesMut) {
478 if self.type_id != 0xF3 {
481 buf.put_u8(self.type_id);
482 }
483
484 match self.type_id {
486 0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
487 if let Some(len) = self.max_length {
489 buf.put_u8(len as u8);
490 }
491 }
492 0xE7 | 0xA7 | 0xA5 | 0xEF => {
493 if let Some(len) = self.max_length {
495 buf.put_u16_le(len);
496 }
497 if let Some(collation) = self.collation {
499 buf.put_slice(&collation);
500 }
501 }
502 0x24 => {
503 if let Some(len) = self.max_length {
505 buf.put_u8(len as u8);
506 }
507 }
508 0x29..=0x2B => {
509 if let Some(scale) = self.scale {
511 buf.put_u8(scale);
512 }
513 }
514 0x6C | 0x6A => {
515 if let Some(len) = self.max_length {
517 buf.put_u8(len as u8);
518 }
519 if let Some(precision) = self.precision {
520 buf.put_u8(precision);
521 }
522 if let Some(scale) = self.scale {
523 buf.put_u8(scale);
524 }
525 }
526 _ => {}
527 }
528 }
529}
530
531#[derive(Debug, Clone)]
538pub struct EncryptedParamMetadata {
539 pub base_type_info: TypeInfo,
541 pub algorithm_id: u8,
543 pub encryption_type: EncryptionTypeWire,
545 pub database_id: u32,
547 pub cek_id: u32,
549 pub cek_version: u32,
551 pub cek_md_version: u64,
553 pub normalization_rule_version: u8,
555}
556
557impl EncryptedParamMetadata {
558 pub fn encode(&self, buf: &mut BytesMut) {
561 self.base_type_info.encode(buf);
562 buf.put_u8(self.algorithm_id);
563 buf.put_u8(self.encryption_type.to_u8());
564 buf.put_u32_le(self.database_id);
565 buf.put_u32_le(self.cek_id);
566 buf.put_u32_le(self.cek_version);
567 buf.put_u64_le(self.cek_md_version);
568 buf.put_u8(self.normalization_rule_version);
569 }
570}
571
572#[derive(Debug, Clone)]
574pub struct RpcParam {
575 pub name: String,
577 pub flags: ParamFlags,
579 pub type_info: TypeInfo,
581 pub value: Option<Bytes>,
583 pub crypto_metadata: Option<EncryptedParamMetadata>,
586}
587
588impl RpcParam {
589 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
591 Self {
592 name: name.into(),
593 flags: ParamFlags::default(),
594 type_info,
595 value: Some(value),
596 crypto_metadata: None,
597 }
598 }
599
600 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
602 Self {
603 name: name.into(),
604 flags: ParamFlags::default(),
605 type_info,
606 value: None,
607 crypto_metadata: None,
608 }
609 }
610
611 pub fn encrypted(
618 name: impl Into<String>,
619 ciphertext: Bytes,
620 metadata: EncryptedParamMetadata,
621 ) -> Self {
622 Self {
623 name: name.into(),
624 flags: ParamFlags {
625 encrypted: true,
626 ..ParamFlags::default()
627 },
628 type_info: TypeInfo::varbinary_max(),
629 value: Some(ciphertext),
630 crypto_metadata: Some(metadata),
631 }
632 }
633
634 pub fn encrypted_null(name: impl Into<String>, metadata: EncryptedParamMetadata) -> Self {
640 Self {
641 name: name.into(),
642 flags: ParamFlags {
643 encrypted: true,
644 ..ParamFlags::default()
645 },
646 type_info: TypeInfo::varbinary_max(),
647 value: None,
648 crypto_metadata: Some(metadata),
649 }
650 }
651
652 pub fn int(name: impl Into<String>, value: i32) -> Self {
654 let mut buf = BytesMut::with_capacity(4);
655 buf.put_i32_le(value);
656 Self::new(name, TypeInfo::int(), buf.freeze())
657 }
658
659 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
661 let mut buf = BytesMut::with_capacity(8);
662 buf.put_i64_le(value);
663 Self::new(name, TypeInfo::bigint(), buf.freeze())
664 }
665
666 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
668 let mut buf = BytesMut::new();
669 let mut code_units: usize = 0;
670 for code_unit in value.encode_utf16() {
671 buf.put_u16_le(code_unit);
672 code_units += 1;
673 }
674 let type_info = if code_units > 4000 {
680 TypeInfo::nvarchar_max()
681 } else {
682 TypeInfo::nvarchar(code_units.max(1) as u16)
683 };
684 Self::new(name, type_info, buf.freeze())
685 }
686
687 pub fn varchar(name: impl Into<String>, value: &str) -> Self {
697 let encoded = Self::encode_varchar_bytes(value);
698 let byte_len = encoded.len();
699 let type_info = if byte_len > 8000 {
700 TypeInfo::varchar_max()
701 } else {
702 TypeInfo::varchar(byte_len.max(1) as u16)
703 };
704 Self::new(name, type_info, Bytes::from(encoded))
705 }
706
707 fn encode_varchar_bytes(value: &str) -> Vec<u8> {
710 crate::collation::encode_str_for_collation(value, None)
711 }
712
713 pub fn varchar_with_collation(
718 name: impl Into<String>,
719 value: &str,
720 collation: &Collation,
721 ) -> Self {
722 let collation_bytes = collation.to_bytes();
723 let encoded = Self::encode_varchar_bytes_for_collation(value, collation);
724 let byte_len = encoded.len();
725 let type_info = if byte_len > 8000 {
726 TypeInfo::varchar_max_with_collation(collation_bytes)
727 } else {
728 TypeInfo::varchar_with_collation(byte_len.max(1) as u16, collation_bytes)
729 };
730 Self::new(name, type_info, Bytes::from(encoded))
731 }
732
733 fn encode_varchar_bytes_for_collation(value: &str, collation: &Collation) -> Vec<u8> {
735 crate::collation::encode_str_for_collation(value, Some(collation))
736 }
737
738 #[must_use]
740 pub fn as_output(mut self) -> Self {
741 self.flags = self.flags.output();
742 self
743 }
744
745 pub fn encode(&self, buf: &mut BytesMut) {
747 let name_len = self.name.encode_utf16().count() as u8;
749 buf.put_u8(name_len);
750 if name_len > 0 {
751 for code_unit in self.name.encode_utf16() {
752 buf.put_u16_le(code_unit);
753 }
754 }
755
756 buf.put_u8(self.flags.encode());
758
759 self.type_info.encode(buf);
761
762 if let Some(ref value) = self.value {
764 match self.type_info.type_id {
766 0x26 => {
767 buf.put_u8(value.len() as u8);
769 buf.put_slice(value);
770 }
771 0x68 | 0x6D | 0x6E | 0x6F => {
772 buf.put_u8(value.len() as u8);
774 buf.put_slice(value);
775 }
776 0xE7 | 0xA7 | 0xA5 => {
777 if self.type_info.max_length == Some(0xFFFF) {
779 let total_len = value.len() as u64;
782 buf.put_u64_le(total_len);
783 buf.put_u32_le(value.len() as u32);
784 buf.put_slice(value);
785 buf.put_u32_le(0); } else {
787 buf.put_u16_le(value.len() as u16);
788 buf.put_slice(value);
789 }
790 }
791 0x24 => {
792 buf.put_u8(value.len() as u8);
794 buf.put_slice(value);
795 }
796 0x28..=0x2B => {
797 buf.put_u8(value.len() as u8);
799 buf.put_slice(value);
800 }
801 0x6C => {
802 buf.put_u8(value.len() as u8);
804 buf.put_slice(value);
805 }
806 0xF3 => {
807 buf.put_slice(value);
811 }
812 _ => {
813 buf.put_u8(value.len() as u8);
815 buf.put_slice(value);
816 }
817 }
818 } else {
819 match self.type_info.type_id {
821 0xE7 | 0xA7 | 0xA5 => {
822 if self.type_info.max_length == Some(0xFFFF) {
824 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
826 buf.put_u16_le(0xFFFF);
827 }
828 }
829 _ => {
830 buf.put_u8(0); }
832 }
833 }
834
835 if let Some(ref metadata) = self.crypto_metadata {
837 metadata.encode(buf);
838 }
839 }
840}
841
842#[derive(Debug, Clone)]
844pub struct RpcRequest {
845 proc_name: Option<String>,
847 proc_id: Option<ProcId>,
849 options: RpcOptionFlags,
851 params: Vec<RpcParam>,
853}
854
855impl RpcRequest {
856 pub fn named(proc_name: impl Into<String>) -> Self {
858 Self {
859 proc_name: Some(proc_name.into()),
860 proc_id: None,
861 options: RpcOptionFlags::default(),
862 params: Vec::new(),
863 }
864 }
865
866 pub fn by_id(proc_id: ProcId) -> Self {
868 Self {
869 proc_name: None,
870 proc_id: Some(proc_id),
871 options: RpcOptionFlags::default(),
872 params: Vec::new(),
873 }
874 }
875
876 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
894 let mut request = Self::by_id(ProcId::ExecuteSql);
895
896 request.params.push(RpcParam::nvarchar("", sql));
898
899 if !params.is_empty() {
901 let declarations = Self::build_param_declarations(¶ms);
902 request.params.push(RpcParam::nvarchar("", &declarations));
903 }
904
905 request.params.extend(params);
907
908 request
909 }
910
911 pub fn build_param_declarations(params: &[RpcParam]) -> String {
919 params
920 .iter()
921 .map(|p| {
922 let name = if p.name.starts_with('@') {
923 p.name.clone()
924 } else if p.name.is_empty() {
925 format!(
927 "@p{}",
928 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
929 )
930 } else {
931 format!("@{}", p.name)
932 };
933
934 let ti = p
937 .crypto_metadata
938 .as_ref()
939 .map(|m| &m.base_type_info)
940 .unwrap_or(&p.type_info);
941
942 let type_name: String = match ti.type_id {
943 0x26 => match ti.max_length {
944 Some(1) => "tinyint".to_string(),
945 Some(2) => "smallint".to_string(),
946 Some(4) => "int".to_string(),
947 Some(8) => "bigint".to_string(),
948 _ => "int".to_string(),
949 },
950 0x68 => "bit".to_string(),
951 0x6D => match ti.max_length {
952 Some(4) => "real".to_string(),
953 _ => "float".to_string(),
954 },
955 0xE7 => {
956 if ti.max_length == Some(0xFFFF) {
957 "nvarchar(max)".to_string()
958 } else {
959 let len = ti.max_length.unwrap_or(4000) / 2;
960 format!("nvarchar({len})")
961 }
962 }
963 0xA7 => {
964 if ti.max_length == Some(0xFFFF) {
965 "varchar(max)".to_string()
966 } else {
967 let len = ti.max_length.unwrap_or(8000);
968 format!("varchar({len})")
969 }
970 }
971 0xA5 => {
972 if ti.max_length == Some(0xFFFF) {
973 "varbinary(max)".to_string()
974 } else {
975 let len = ti.max_length.unwrap_or(8000);
976 format!("varbinary({len})")
977 }
978 }
979 0x24 => "uniqueidentifier".to_string(),
980 0x28 => "date".to_string(),
981 0x29 => {
982 let scale = ti.scale.unwrap_or(7);
983 format!("time({scale})")
984 }
985 0x2A => {
986 let scale = ti.scale.unwrap_or(7);
987 format!("datetime2({scale})")
988 }
989 0x2B => {
990 let scale = ti.scale.unwrap_or(7);
991 format!("datetimeoffset({scale})")
992 }
993 0x6C => {
994 let precision = ti.precision.unwrap_or(18);
995 let scale = ti.scale.unwrap_or(0);
996 format!("decimal({precision}, {scale})")
997 }
998 0x6E => match ti.max_length {
999 Some(4) => "smallmoney".to_string(),
1000 _ => "money".to_string(),
1001 },
1002 0x6F => match ti.max_length {
1003 Some(4) => "smalldatetime".to_string(),
1004 _ => "datetime".to_string(),
1005 },
1006 0xF3 => {
1007 if let Some(ref tvp_name) = ti.tvp_type_name {
1010 format!("{tvp_name} READONLY")
1011 } else {
1012 "sql_variant".to_string()
1014 }
1015 }
1016 _ => "sql_variant".to_string(),
1017 };
1018
1019 format!("{name} {type_name}")
1020 })
1021 .collect::<Vec<_>>()
1022 .join(", ")
1023 }
1024
1025 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
1027 let mut request = Self::by_id(ProcId::Prepare);
1028
1029 request
1031 .params
1032 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
1033
1034 let declarations = Self::build_param_declarations(params);
1036 request
1037 .params
1038 .push(RpcParam::nvarchar("@params", &declarations));
1039
1040 request.params.push(RpcParam::nvarchar("@stmt", sql));
1042
1043 request.params.push(RpcParam::int("@options", 1));
1045
1046 request
1047 }
1048
1049 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
1051 let mut request = Self::by_id(ProcId::Execute);
1052
1053 request.params.push(RpcParam::int("@handle", handle));
1055
1056 request.params.extend(params);
1058
1059 request
1060 }
1061
1062 pub fn unprepare(handle: i32) -> Self {
1064 let mut request = Self::by_id(ProcId::Unprepare);
1065 request.params.push(RpcParam::int("@handle", handle));
1066 request
1067 }
1068
1069 #[must_use]
1071 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
1072 self.options = options;
1073 self
1074 }
1075
1076 #[must_use]
1078 pub fn param(mut self, param: RpcParam) -> Self {
1079 self.params.push(param);
1080 self
1081 }
1082
1083 #[must_use]
1087 pub fn encode(&self) -> Bytes {
1088 self.encode_with_transaction(0)
1089 }
1090
1091 #[must_use]
1103 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
1104 let mut buf = BytesMut::with_capacity(256);
1105
1106 let all_headers_start = buf.len();
1109 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
1120 let len_bytes = (all_headers_len as u32).to_le_bytes();
1121 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
1122
1123 if let Some(proc_id) = self.proc_id {
1125 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
1128 } else if let Some(ref proc_name) = self.proc_name {
1129 let name_len = proc_name.encode_utf16().count() as u16;
1131 buf.put_u16_le(name_len);
1132 write_utf16_string(&mut buf, proc_name);
1133 }
1134
1135 buf.put_u16_le(self.options.encode());
1137
1138 for param in &self.params {
1140 param.encode(&mut buf);
1141 }
1142
1143 buf.freeze()
1144 }
1145}
1146
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap failures are test failures
mod tests {
    use super::*;

    /// Encrypted parameters must serialize byte-for-byte like the captured
    /// Microsoft.Data.SqlClient traffic. The layout is:
    /// - name and flags (0x08);
    /// - a varbinary(max) PLP ciphertext;
    /// - the plaintext TYPE_INFO and the cipher metadata.
    #[test]
    fn encrypted_param_encode_matches_captured_dotnet_wire() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        // The captured traffic declares nvarchar(50) with this collation.
        let mut nvarchar_50 = TypeInfo::nvarchar(50);
        nvarchar_50.collation = Some([0x09, 0x04, 0xd0, 0x00, 0x34]);

        let cases = [
            (
                "@i",
                TypeInfo::int(),
                "024000690008a5ffff41000000000000004100000001ed7b8c6030870d92358f12acb0d0c69c00bc3aa3ba578ecb0ea5f514b5045912a2b1ae52ed834f6bac49520956e4a574c30d573590fb3785556c8fe42f87c5b4000000002604020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@s",
                nvarchar_50,
                "024000730008a5ffff4100000000000000410000000150c0a7dec4d4241c7a4a617007d32d97e7131f8c57a5ad212487891170f12ecb9957fce16389f4728d1c3c65813beeea085ae3fd516d29f84298df3e97f0d05d00000000e764000904d00034020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@b",
                TypeInfo::varbinary(50),
                "024000620008a5ffff41000000000000004100000001d17165aa6df0155be6b78c6712d3b03870ea394cfed10956cf07fbfa204c4b82cddfa5e2f4fc03335f579e2767657e3067cd9da7d62a07427106b91f747b97da00000000a53200020106000000020000000100000023722f0069b4000001",
            ),
        ];

        for (name, base_type_info, golden_hex) in cases {
            let golden = unhex(golden_hex);
            // Skip these fields to reach the 65-byte ciphertext:
            // - name length + UTF-16 name;
            // - flags;
            // - TYPE_INFO (0xA5 + u16 0xFFFF);
            // - PLP total length (u64) and chunk length (u32).
            let cipher_off = 1 + name.encode_utf16().count() * 2 + 1 + 3 + 8 + 4;
            let cipher = Bytes::copy_from_slice(&golden[cipher_off..cipher_off + 65]);

            let param = RpcParam::encrypted(
                name,
                cipher,
                EncryptedParamMetadata {
                    base_type_info,
                    algorithm_id: 2,
                    encryption_type: EncryptionTypeWire::Deterministic,
                    database_id: 6,
                    cek_id: 2,
                    cek_version: 1,
                    cek_md_version: 0x0000_b469_002f_7223,
                    normalization_rule_version: 1,
                },
            );

            let mut buf = BytesMut::new();
            param.encode(&mut buf);
            assert_eq!(
                buf.to_vec(),
                golden,
                "encrypted {name} param must match the captured Microsoft.Data.SqlClient bytes"
            );
        }
    }

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // Five characters, two bytes each in UTF-16LE.
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_nvarchar_param_surrogate_pair_length() {
        // U+1F600 is outside the BMP: one surrogate pair = 2 code units = 4 bytes.
        let param = RpcParam::nvarchar("@p", "\u{1F600}");
        assert_eq!(param.value.as_ref().unwrap().len(), 4);
        assert_eq!(param.type_info.max_length, Some(4));

        // "Hello " (6 units), two CJK characters (2), " " (1) and one emoji
        // (2 units) make 11 code units = 22 bytes.
        let param = RpcParam::nvarchar("@p", "Hello \u{4E16}\u{754C} \u{1F30D}");
        assert_eq!(param.value.as_ref().unwrap().len(), 22);
        assert_eq!(param.type_info.max_length, Some(22));
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // Statement, declarations, then the user parameter.
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // @handle, @params, @stmt, @options.
        assert_eq!(rpc.params.len(), 4);
        // @handle is an output parameter.
        assert!(rpc.params[0].flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // @handle plus the user parameter.
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // Only @handle.
        assert_eq!(rpc.params.len(), 1);
    }

    #[test]
    fn test_varchar_param() {
        let param = RpcParam::varchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xA7);
        // ASCII text is one byte per character.
        assert_eq!(param.value.as_ref().unwrap().len(), 5);
        assert_eq!(&param.value.as_ref().unwrap()[..], b"Alice");
    }

    #[test]
    fn test_varchar_param_max() {
        // More than 8000 bytes must switch to varchar(max).
        let long_str = "a".repeat(9000);
        let param = RpcParam::varchar("@big", &long_str);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert_eq!(param.value.as_ref().unwrap().len(), 9000);
    }

    #[test]
    fn test_varchar_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::varchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name varchar(5)"));
    }

    #[test]
    fn test_varchar_type_info_has_collation() {
        let ti = TypeInfo::varchar(100);
        assert_eq!(ti.type_id, 0xA7);
        assert_eq!(ti.max_length, Some(100));
        assert!(ti.collation.is_some());
    }

    #[test]
    fn test_varchar_encode_round_trip() {
        // Smoke test: encoding a varchar parameter produces output.
        let param = RpcParam::varchar("@val", "test value");
        let mut buf = bytes::BytesMut::new();
        param.encode(&mut buf);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_collation_round_trip() {
        let collation = Collation {
            lcid: 0x00D0_0409,
            sort_id: 0x34,
        };
        let bytes = collation.to_bytes();
        assert_eq!(bytes, [0x09, 0x04, 0xD0, 0x00, 0x34]);

        let restored = Collation::from_bytes(&bytes);
        assert_eq!(restored.lcid, collation.lcid);
        assert_eq!(restored.sort_id, collation.sort_id);
    }

    #[test]
    fn test_varchar_with_collation_uses_custom_collation_bytes() {
        let collation = Collation {
            // LCID 0x0804 (zh-CN).
            lcid: 0x0804,
            sort_id: 0,
        };
        let param = RpcParam::varchar_with_collation("@val", "test", &collation);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.collation, Some(collation.to_bytes()));
    }

    #[test]
    fn test_money_type_info() {
        let ti = TypeInfo::money();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(8));
    }

    #[test]
    fn test_smallmoney_type_info() {
        let ti = TypeInfo::smallmoney();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_smalldatetime_type_info() {
        let ti = TypeInfo::smalldatetime();
        assert_eq!(ti.type_id, 0x6F);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_money_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::new("@m", TypeInfo::money(), Bytes::from_static(&[0u8; 8])),
            RpcParam::new("@sm", TypeInfo::smallmoney(), Bytes::from_static(&[0u8; 4])),
            RpcParam::new(
                "@sdt",
                TypeInfo::smalldatetime(),
                Bytes::from_static(&[0u8; 4]),
            ),
        ]);
        assert!(decls.contains("@m money"), "got: {decls}");
        assert!(decls.contains("@sm smallmoney"), "got: {decls}");
        assert!(decls.contains("@sdt smalldatetime"), "got: {decls}");
    }

    #[test]
    fn test_money_typeinfo_encodes_max_length_byte() {
        // MONEYN / DATETIMN TYPE_INFO is the token plus a one-byte length.
        let mut buf = bytes::BytesMut::new();
        TypeInfo::money().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x08]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smallmoney().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x04]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smalldatetime().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6F, 0x04]);
    }

    #[test]
    fn test_varchar_with_collation_default_vs_custom_differ() {
        let default_param = RpcParam::varchar("@val", "test");
        let custom_collation = Collation {
            // LCID 0x0419 (ru-RU).
            lcid: 0x0419,
            sort_id: 0,
        };
        let custom_param = RpcParam::varchar_with_collation("@val", "test", &custom_collation);
        assert_ne!(
            default_param.type_info.collation,
            custom_param.type_info.collation
        );
    }
}