1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::crypto::EncryptionTypeWire;
30use crate::prelude::*;
31use crate::token::Collation;
32
/// Well-known stored-procedure IDs.
///
/// An RPC request may reference one of these procedures by number instead of by
/// name; the numeric value is written as a little-endian `u16` on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 1,
    /// `sp_cursoropen`
    CursorOpen = 2,
    /// `sp_cursorprepare`
    CursorPrepare = 3,
    /// `sp_cursorexecute`
    CursorExecute = 4,
    /// `sp_cursorprepexec`
    CursorPrepExec = 5,
    /// `sp_cursorunprepare`
    CursorUnprepare = 6,
    /// `sp_cursorfetch`
    CursorFetch = 7,
    /// `sp_cursoroption`
    CursorOption = 8,
    /// `sp_cursorclose`
    CursorClose = 9,
    /// `sp_executesql`
    ExecuteSql = 10,
    /// `sp_prepare`
    Prepare = 11,
    /// `sp_execute`
    Execute = 12,
    /// `sp_prepexec`
    PrepExec = 13,
    /// `sp_prepexecrpc`
    PrepExecRpc = 14,
    /// `sp_unprepare`
    Unprepare = 15,
}
72
/// Option flags written right after the procedure identifier of an RPC request.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Bit `0x0001`.
    pub with_recompile: bool,
    /// Bit `0x0002`.
    pub no_metadata: bool,
    /// Bit `0x0004`.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Returns a flag set with every option cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags with `with_recompile` set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the flags into the `u16` option-flags word.
    pub fn encode(&self) -> u16 {
        let table = [
            (self.with_recompile, 0x0001u16),
            (self.no_metadata, 0x0002u16),
            (self.reuse_metadata, 0x0004u16),
        ];
        table
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(0u16, |word, (_, bit)| word | bit)
    }
}
112
/// Per-parameter status flags byte of an RPC parameter.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Output (by-reference) parameter; bit `0x01`.
    pub by_ref: bool,
    /// Use the parameter's default value; bit `0x02`.
    pub default: bool,
    /// Parameter value is encrypted; bit `0x08`.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with every bit cleared.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags marked as an output parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into the single status byte.
    pub fn encode(&self) -> u8 {
        u8::from(self.by_ref) | (u8::from(self.default) << 1) | (u8::from(self.encrypted) << 3)
    }
}
152
153#[derive(Debug, Clone)]
155pub struct TypeInfo {
156 pub type_id: u8,
158 pub max_length: Option<u16>,
160 pub precision: Option<u8>,
162 pub scale: Option<u8>,
164 pub collation: Option<[u8; 5]>,
166 pub tvp_type_name: Option<String>,
168}
169
170impl TypeInfo {
171 pub fn int() -> Self {
173 Self {
174 type_id: 0x26, max_length: Some(4),
176 precision: None,
177 scale: None,
178 collation: None,
179 tvp_type_name: None,
180 }
181 }
182
183 pub fn bigint() -> Self {
185 Self {
186 type_id: 0x26, max_length: Some(8),
188 precision: None,
189 scale: None,
190 collation: None,
191 tvp_type_name: None,
192 }
193 }
194
195 pub fn smallint() -> Self {
197 Self {
198 type_id: 0x26, max_length: Some(2),
200 precision: None,
201 scale: None,
202 collation: None,
203 tvp_type_name: None,
204 }
205 }
206
207 pub fn tinyint() -> Self {
209 Self {
210 type_id: 0x26, max_length: Some(1),
212 precision: None,
213 scale: None,
214 collation: None,
215 tvp_type_name: None,
216 }
217 }
218
219 pub fn bit() -> Self {
221 Self {
222 type_id: 0x68, max_length: Some(1),
224 precision: None,
225 scale: None,
226 collation: None,
227 tvp_type_name: None,
228 }
229 }
230
231 pub fn float() -> Self {
233 Self {
234 type_id: 0x6D, max_length: Some(8),
236 precision: None,
237 scale: None,
238 collation: None,
239 tvp_type_name: None,
240 }
241 }
242
243 pub fn real() -> Self {
245 Self {
246 type_id: 0x6D, max_length: Some(4),
248 precision: None,
249 scale: None,
250 collation: None,
251 tvp_type_name: None,
252 }
253 }
254
255 pub fn nvarchar(max_len: u16) -> Self {
257 Self {
258 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
261 scale: None,
262 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
264 tvp_type_name: None,
265 }
266 }
267
268 pub fn nvarchar_max() -> Self {
270 Self {
271 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
274 scale: None,
275 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
276 tvp_type_name: None,
277 }
278 }
279
280 const DEFAULT_COLLATION: [u8; 5] = [0x09, 0x04, 0xD0, 0x00, 0x34];
282
283 pub fn varchar(max_len: u16) -> Self {
285 Self::varchar_with_collation(max_len, Self::DEFAULT_COLLATION)
286 }
287
288 pub fn varchar_with_collation(max_len: u16, collation: [u8; 5]) -> Self {
290 Self {
291 type_id: 0xA7, max_length: Some(max_len),
293 precision: None,
294 scale: None,
295 collation: Some(collation),
296 tvp_type_name: None,
297 }
298 }
299
300 pub fn varchar_max() -> Self {
302 Self::varchar_max_with_collation(Self::DEFAULT_COLLATION)
303 }
304
305 pub fn varchar_max_with_collation(collation: [u8; 5]) -> Self {
307 Self {
308 type_id: 0xA7, max_length: Some(0xFFFF), precision: None,
311 scale: None,
312 collation: Some(collation),
313 tvp_type_name: None,
314 }
315 }
316
317 pub fn varbinary(max_len: u16) -> Self {
319 Self {
320 type_id: 0xA5, max_length: Some(max_len),
322 precision: None,
323 scale: None,
324 collation: None,
325 tvp_type_name: None,
326 }
327 }
328
329 pub fn varbinary_max() -> Self {
331 Self {
332 type_id: 0xA5, max_length: Some(0xFFFF), precision: None,
335 scale: None,
336 collation: None,
337 tvp_type_name: None,
338 }
339 }
340
341 pub fn char(len: u16) -> Self {
343 Self {
344 type_id: 0xAF, max_length: Some(len),
346 precision: None,
347 scale: None,
348 collation: Some(Self::DEFAULT_COLLATION),
349 tvp_type_name: None,
350 }
351 }
352
353 pub fn nchar(len: u16) -> Self {
355 Self {
356 type_id: 0xEF, max_length: Some(len * 2), precision: None,
359 scale: None,
360 collation: Some(Self::DEFAULT_COLLATION),
361 tvp_type_name: None,
362 }
363 }
364
365 pub fn binary(len: u16) -> Self {
367 Self {
368 type_id: 0xAD, max_length: Some(len),
370 precision: None,
371 scale: None,
372 collation: None,
373 tvp_type_name: None,
374 }
375 }
376
377 pub fn uniqueidentifier() -> Self {
379 Self {
380 type_id: 0x24, max_length: Some(16),
382 precision: None,
383 scale: None,
384 collation: None,
385 tvp_type_name: None,
386 }
387 }
388
389 pub fn uuid() -> Self {
391 Self {
392 type_id: 0x24, max_length: Some(16),
394 precision: None,
395 scale: None,
396 collation: None,
397 tvp_type_name: None,
398 }
399 }
400
401 pub fn date() -> Self {
403 Self {
404 type_id: 0x28, max_length: None,
406 precision: None,
407 scale: None,
408 collation: None,
409 tvp_type_name: None,
410 }
411 }
412
413 pub fn time(scale: u8) -> Self {
415 Self {
416 type_id: 0x29, max_length: None,
418 precision: None,
419 scale: Some(scale),
420 collation: None,
421 tvp_type_name: None,
422 }
423 }
424
425 pub fn datetime2(scale: u8) -> Self {
427 Self {
428 type_id: 0x2A, max_length: None,
430 precision: None,
431 scale: Some(scale),
432 collation: None,
433 tvp_type_name: None,
434 }
435 }
436
437 pub fn datetimeoffset(scale: u8) -> Self {
439 Self {
440 type_id: 0x2B, max_length: None,
442 precision: None,
443 scale: Some(scale),
444 collation: None,
445 tvp_type_name: None,
446 }
447 }
448
449 pub fn decimal(precision: u8, scale: u8) -> Self {
451 Self {
452 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
455 scale: Some(scale),
456 collation: None,
457 tvp_type_name: None,
458 }
459 }
460
461 pub fn money() -> Self {
463 Self {
464 type_id: 0x6E, max_length: Some(8),
466 precision: None,
467 scale: None,
468 collation: None,
469 tvp_type_name: None,
470 }
471 }
472
473 pub fn smallmoney() -> Self {
475 Self {
476 type_id: 0x6E, max_length: Some(4),
478 precision: None,
479 scale: None,
480 collation: None,
481 tvp_type_name: None,
482 }
483 }
484
485 pub fn smalldatetime() -> Self {
487 Self {
488 type_id: 0x6F, max_length: Some(4),
490 precision: None,
491 scale: None,
492 collation: None,
493 tvp_type_name: None,
494 }
495 }
496
497 pub fn datetime() -> Self {
499 Self {
500 type_id: 0x6F, max_length: Some(8),
502 precision: None,
503 scale: None,
504 collation: None,
505 tvp_type_name: None,
506 }
507 }
508
509 pub fn tvp(type_name: impl Into<String>) -> Self {
514 Self {
515 type_id: 0xF3, max_length: None,
517 precision: None,
518 scale: None,
519 collation: None,
520 tvp_type_name: Some(type_name.into()),
521 }
522 }
523
524 pub fn encode(&self, buf: &mut BytesMut) {
526 if self.type_id != 0xF3 {
529 buf.put_u8(self.type_id);
530 }
531
532 match self.type_id {
534 0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
535 if let Some(len) = self.max_length {
537 buf.put_u8(len as u8);
538 }
539 }
540 0xE7 | 0xA7 | 0xA5 | 0xEF | 0xAF | 0xAD => {
541 if let Some(len) = self.max_length {
544 buf.put_u16_le(len);
545 }
546 if let Some(collation) = self.collation {
548 buf.put_slice(&collation);
549 }
550 }
551 0x24 => {
552 if let Some(len) = self.max_length {
554 buf.put_u8(len as u8);
555 }
556 }
557 0x29..=0x2B => {
558 if let Some(scale) = self.scale {
560 buf.put_u8(scale);
561 }
562 }
563 0x6C | 0x6A => {
564 if let Some(len) = self.max_length {
566 buf.put_u8(len as u8);
567 }
568 if let Some(precision) = self.precision {
569 buf.put_u8(precision);
570 }
571 if let Some(scale) = self.scale {
572 buf.put_u8(scale);
573 }
574 }
575 _ => {}
576 }
577 }
578}
579
580#[derive(Debug, Clone)]
587pub struct EncryptedParamMetadata {
588 pub base_type_info: TypeInfo,
590 pub algorithm_id: u8,
592 pub encryption_type: EncryptionTypeWire,
594 pub database_id: u32,
596 pub cek_id: u32,
598 pub cek_version: u32,
600 pub cek_md_version: u64,
602 pub normalization_rule_version: u8,
604}
605
606impl EncryptedParamMetadata {
607 pub fn encode(&self, buf: &mut BytesMut) {
610 self.base_type_info.encode(buf);
611 buf.put_u8(self.algorithm_id);
612 buf.put_u8(self.encryption_type.to_u8());
613 buf.put_u32_le(self.database_id);
614 buf.put_u32_le(self.cek_id);
615 buf.put_u32_le(self.cek_version);
616 buf.put_u64_le(self.cek_md_version);
617 buf.put_u8(self.normalization_rule_version);
618 }
619}
620
621#[derive(Debug, Clone)]
623pub struct RpcParam {
624 pub name: String,
626 pub flags: ParamFlags,
628 pub type_info: TypeInfo,
630 pub value: Option<Bytes>,
632 pub crypto_metadata: Option<EncryptedParamMetadata>,
635}
636
637impl RpcParam {
638 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
640 Self {
641 name: name.into(),
642 flags: ParamFlags::default(),
643 type_info,
644 value: Some(value),
645 crypto_metadata: None,
646 }
647 }
648
649 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
651 Self {
652 name: name.into(),
653 flags: ParamFlags::default(),
654 type_info,
655 value: None,
656 crypto_metadata: None,
657 }
658 }
659
660 pub fn encrypted(
667 name: impl Into<String>,
668 ciphertext: Bytes,
669 metadata: EncryptedParamMetadata,
670 ) -> Self {
671 Self {
672 name: name.into(),
673 flags: ParamFlags {
674 encrypted: true,
675 ..ParamFlags::default()
676 },
677 type_info: TypeInfo::varbinary_max(),
678 value: Some(ciphertext),
679 crypto_metadata: Some(metadata),
680 }
681 }
682
683 pub fn encrypted_null(name: impl Into<String>, metadata: EncryptedParamMetadata) -> Self {
689 Self {
690 name: name.into(),
691 flags: ParamFlags {
692 encrypted: true,
693 ..ParamFlags::default()
694 },
695 type_info: TypeInfo::varbinary_max(),
696 value: None,
697 crypto_metadata: Some(metadata),
698 }
699 }
700
701 pub fn int(name: impl Into<String>, value: i32) -> Self {
703 let mut buf = BytesMut::with_capacity(4);
704 buf.put_i32_le(value);
705 Self::new(name, TypeInfo::int(), buf.freeze())
706 }
707
708 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
710 let mut buf = BytesMut::with_capacity(8);
711 buf.put_i64_le(value);
712 Self::new(name, TypeInfo::bigint(), buf.freeze())
713 }
714
715 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
717 let mut buf = BytesMut::new();
718 let mut code_units: usize = 0;
719 for code_unit in value.encode_utf16() {
720 buf.put_u16_le(code_unit);
721 code_units += 1;
722 }
723 let type_info = if code_units > 4000 {
729 TypeInfo::nvarchar_max()
730 } else {
731 TypeInfo::nvarchar(code_units.max(1) as u16)
732 };
733 Self::new(name, type_info, buf.freeze())
734 }
735
736 pub fn varchar(name: impl Into<String>, value: &str) -> Self {
746 let encoded = Self::encode_varchar_bytes(value);
747 let byte_len = encoded.len();
748 let type_info = if byte_len > 8000 {
749 TypeInfo::varchar_max()
750 } else {
751 TypeInfo::varchar(byte_len.max(1) as u16)
752 };
753 Self::new(name, type_info, Bytes::from(encoded))
754 }
755
756 fn encode_varchar_bytes(value: &str) -> Vec<u8> {
759 crate::collation::encode_str_for_collation(value, None)
760 }
761
762 pub fn varchar_with_collation(
767 name: impl Into<String>,
768 value: &str,
769 collation: &Collation,
770 ) -> Self {
771 let collation_bytes = collation.to_bytes();
772 let encoded = Self::encode_varchar_bytes_for_collation(value, collation);
773 let byte_len = encoded.len();
774 let type_info = if byte_len > 8000 {
775 TypeInfo::varchar_max_with_collation(collation_bytes)
776 } else {
777 TypeInfo::varchar_with_collation(byte_len.max(1) as u16, collation_bytes)
778 };
779 Self::new(name, type_info, Bytes::from(encoded))
780 }
781
782 fn encode_varchar_bytes_for_collation(value: &str, collation: &Collation) -> Vec<u8> {
784 crate::collation::encode_str_for_collation(value, Some(collation))
785 }
786
787 #[must_use]
789 pub fn as_output(mut self) -> Self {
790 self.flags = self.flags.output();
791 self
792 }
793
794 pub fn encode(&self, buf: &mut BytesMut) {
796 let name_len = self.name.encode_utf16().count() as u8;
798 buf.put_u8(name_len);
799 if name_len > 0 {
800 for code_unit in self.name.encode_utf16() {
801 buf.put_u16_le(code_unit);
802 }
803 }
804
805 buf.put_u8(self.flags.encode());
807
808 self.type_info.encode(buf);
810
811 if let Some(ref value) = self.value {
813 match self.type_info.type_id {
815 0x26 => {
816 buf.put_u8(value.len() as u8);
818 buf.put_slice(value);
819 }
820 0x68 | 0x6D | 0x6E | 0x6F => {
821 buf.put_u8(value.len() as u8);
823 buf.put_slice(value);
824 }
825 0xE7 | 0xA7 | 0xA5 => {
826 if self.type_info.max_length == Some(0xFFFF) {
828 let total_len = value.len() as u64;
831 buf.put_u64_le(total_len);
832 buf.put_u32_le(value.len() as u32);
833 buf.put_slice(value);
834 buf.put_u32_le(0); } else {
836 buf.put_u16_le(value.len() as u16);
837 buf.put_slice(value);
838 }
839 }
840 0x24 => {
841 buf.put_u8(value.len() as u8);
843 buf.put_slice(value);
844 }
845 0x28..=0x2B => {
846 buf.put_u8(value.len() as u8);
848 buf.put_slice(value);
849 }
850 0x6C => {
851 buf.put_u8(value.len() as u8);
853 buf.put_slice(value);
854 }
855 0xF3 => {
856 buf.put_slice(value);
860 }
861 _ => {
862 buf.put_u8(value.len() as u8);
864 buf.put_slice(value);
865 }
866 }
867 } else {
868 match self.type_info.type_id {
870 0xE7 | 0xA7 | 0xA5 => {
871 if self.type_info.max_length == Some(0xFFFF) {
873 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
875 buf.put_u16_le(0xFFFF);
876 }
877 }
878 _ => {
879 buf.put_u8(0); }
881 }
882 }
883
884 if let Some(ref metadata) = self.crypto_metadata {
886 metadata.encode(buf);
887 }
888 }
889}
890
891#[derive(Debug, Clone)]
893pub struct RpcRequest {
894 proc_name: Option<String>,
896 proc_id: Option<ProcId>,
898 options: RpcOptionFlags,
900 params: Vec<RpcParam>,
902}
903
904impl RpcRequest {
905 pub fn named(proc_name: impl Into<String>) -> Self {
907 Self {
908 proc_name: Some(proc_name.into()),
909 proc_id: None,
910 options: RpcOptionFlags::default(),
911 params: Vec::new(),
912 }
913 }
914
915 pub fn by_id(proc_id: ProcId) -> Self {
917 Self {
918 proc_name: None,
919 proc_id: Some(proc_id),
920 options: RpcOptionFlags::default(),
921 params: Vec::new(),
922 }
923 }
924
925 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
943 let mut request = Self::by_id(ProcId::ExecuteSql);
944
945 request.params.push(RpcParam::nvarchar("", sql));
947
948 if !params.is_empty() {
950 let declarations = Self::build_param_declarations(¶ms);
951 request.params.push(RpcParam::nvarchar("", &declarations));
952 }
953
954 request.params.extend(params);
956
957 request
958 }
959
960 pub fn build_param_declarations(params: &[RpcParam]) -> String {
968 params
969 .iter()
970 .enumerate()
971 .map(|(idx, p)| {
972 let name = if p.name.starts_with('@') {
973 p.name.clone()
974 } else if p.name.is_empty() {
975 format!("@p{}", idx + 1)
979 } else {
980 format!("@{}", p.name)
981 };
982
983 let ti = p
986 .crypto_metadata
987 .as_ref()
988 .map(|m| &m.base_type_info)
989 .unwrap_or(&p.type_info);
990
991 let type_name: String = match ti.type_id {
992 0x26 => match ti.max_length {
993 Some(1) => "tinyint".to_string(),
994 Some(2) => "smallint".to_string(),
995 Some(4) => "int".to_string(),
996 Some(8) => "bigint".to_string(),
997 _ => "int".to_string(),
998 },
999 0x68 => "bit".to_string(),
1000 0x6D => match ti.max_length {
1001 Some(4) => "real".to_string(),
1002 _ => "float".to_string(),
1003 },
1004 0xE7 => {
1005 if ti.max_length == Some(0xFFFF) {
1006 "nvarchar(max)".to_string()
1007 } else {
1008 let len = ti.max_length.unwrap_or(4000) / 2;
1009 format!("nvarchar({len})")
1010 }
1011 }
1012 0xA7 => {
1013 if ti.max_length == Some(0xFFFF) {
1014 "varchar(max)".to_string()
1015 } else {
1016 let len = ti.max_length.unwrap_or(8000);
1017 format!("varchar({len})")
1018 }
1019 }
1020 0xA5 => {
1021 if ti.max_length == Some(0xFFFF) {
1022 "varbinary(max)".to_string()
1023 } else {
1024 let len = ti.max_length.unwrap_or(8000);
1025 format!("varbinary({len})")
1026 }
1027 }
1028 0xAF => format!("char({})", ti.max_length.unwrap_or(1)),
1029 0xEF => format!("nchar({})", ti.max_length.unwrap_or(2) / 2),
1030 0xAD => format!("binary({})", ti.max_length.unwrap_or(1)),
1031 0x24 => "uniqueidentifier".to_string(),
1032 0x28 => "date".to_string(),
1033 0x29 => {
1034 let scale = ti.scale.unwrap_or(7);
1035 format!("time({scale})")
1036 }
1037 0x2A => {
1038 let scale = ti.scale.unwrap_or(7);
1039 format!("datetime2({scale})")
1040 }
1041 0x2B => {
1042 let scale = ti.scale.unwrap_or(7);
1043 format!("datetimeoffset({scale})")
1044 }
1045 0x6C => {
1046 let precision = ti.precision.unwrap_or(18);
1047 let scale = ti.scale.unwrap_or(0);
1048 format!("decimal({precision}, {scale})")
1049 }
1050 0x6E => match ti.max_length {
1051 Some(4) => "smallmoney".to_string(),
1052 _ => "money".to_string(),
1053 },
1054 0x6F => match ti.max_length {
1055 Some(4) => "smalldatetime".to_string(),
1056 _ => "datetime".to_string(),
1057 },
1058 0xF3 => {
1059 if let Some(ref tvp_name) = ti.tvp_type_name {
1062 format!("{tvp_name} READONLY")
1063 } else {
1064 "sql_variant".to_string()
1066 }
1067 }
1068 _ => "sql_variant".to_string(),
1069 };
1070
1071 format!("{name} {type_name}")
1072 })
1073 .collect::<Vec<_>>()
1074 .join(", ")
1075 }
1076
1077 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
1079 let mut request = Self::by_id(ProcId::Prepare);
1080
1081 request
1083 .params
1084 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
1085
1086 let declarations = Self::build_param_declarations(params);
1088 request
1089 .params
1090 .push(RpcParam::nvarchar("@params", &declarations));
1091
1092 request.params.push(RpcParam::nvarchar("@stmt", sql));
1094
1095 request.params.push(RpcParam::int("@options", 1));
1097
1098 request
1099 }
1100
1101 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
1108 let mut request = Self::by_id(ProcId::Execute);
1109
1110 request.params.push(RpcParam::int("", handle));
1112
1113 for mut param in params {
1115 param.name.clear();
1116 request.params.push(param);
1117 }
1118
1119 request
1120 }
1121
1122 pub fn prepexec(sql: &str, params: Vec<RpcParam>) -> Self {
1130 let mut request = Self::by_id(ProcId::PrepExec);
1131
1132 request
1134 .params
1135 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
1136
1137 let declarations = Self::build_param_declarations(¶ms);
1139 request
1140 .params
1141 .push(RpcParam::nvarchar("@params", &declarations));
1142
1143 request.params.push(RpcParam::nvarchar("@stmt", sql));
1145
1146 for mut param in params {
1148 param.name.clear();
1149 request.params.push(param);
1150 }
1151
1152 request
1153 }
1154
1155 pub fn unprepare(handle: i32) -> Self {
1157 let mut request = Self::by_id(ProcId::Unprepare);
1158 request.params.push(RpcParam::int("@handle", handle));
1159 request
1160 }
1161
1162 #[must_use]
1164 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
1165 self.options = options;
1166 self
1167 }
1168
1169 #[must_use]
1171 pub fn param(mut self, param: RpcParam) -> Self {
1172 self.params.push(param);
1173 self
1174 }
1175
1176 #[must_use]
1180 pub fn encode(&self) -> Bytes {
1181 self.encode_with_transaction(0)
1182 }
1183
1184 #[must_use]
1196 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
1197 let mut buf = BytesMut::with_capacity(256);
1198
1199 let all_headers_start = buf.len();
1202 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
1213 let len_bytes = (all_headers_len as u32).to_le_bytes();
1214 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
1215
1216 if let Some(proc_id) = self.proc_id {
1218 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
1221 } else if let Some(ref proc_name) = self.proc_name {
1222 let name_len = proc_name.encode_utf16().count() as u16;
1224 buf.put_u16_le(name_len);
1225 write_utf16_string(&mut buf, proc_name);
1226 }
1227
1228 buf.put_u16_le(self.options.encode());
1230
1231 for param in &self.params {
1233 param.encode(&mut buf);
1234 }
1235
1236 buf.freeze()
1237 }
1238}
1239
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encrypted parameters must serialize byte-for-byte like the captures taken
    /// from Microsoft.Data.SqlClient.
    #[test]
    fn encrypted_param_encode_matches_captured_dotnet_wire() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let mut nvarchar_50 = TypeInfo::nvarchar(50);
        nvarchar_50.collation = Some([0x09, 0x04, 0xd0, 0x00, 0x34]);

        let cases = [
            (
                "@i",
                TypeInfo::int(),
                "024000690008a5ffff41000000000000004100000001ed7b8c6030870d92358f12acb0d0c69c00bc3aa3ba578ecb0ea5f514b5045912a2b1ae52ed834f6bac49520956e4a574c30d573590fb3785556c8fe42f87c5b4000000002604020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@s",
                nvarchar_50,
                "024000730008a5ffff4100000000000000410000000150c0a7dec4d4241c7a4a617007d32d97e7131f8c57a5ad212487891170f12ecb9957fce16389f4728d1c3c65813beeea085ae3fd516d29f84298df3e97f0d05d00000000e764000904d00034020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@b",
                TypeInfo::varbinary(50),
                "024000620008a5ffff41000000000000004100000001d17165aa6df0155be6b78c6712d3b03870ea394cfed10956cf07fbfa204c4b82cddfa5e2f4fc03335f579e2767657e3067cd9da7d62a07427106b91f747b97da00000000a53200020106000000020000000100000023722f0069b4000001",
            ),
        ];

        for (name, base_type_info, golden_hex) in cases {
            let golden = unhex(golden_hex);
            // The ciphertext starts after these fields:
            // - name length byte and UTF-16 name;
            // - status flags;
            // - varbinary(max) type info (3 bytes);
            // - PLP total length (8 bytes) and chunk length (4 bytes).
            let cipher_off = 1 + name.encode_utf16().count() * 2 + 1 + 3 + 8 + 4;
            let cipher = Bytes::copy_from_slice(&golden[cipher_off..cipher_off + 65]);

            let param = RpcParam::encrypted(
                name,
                cipher,
                EncryptedParamMetadata {
                    base_type_info,
                    algorithm_id: 2,
                    encryption_type: EncryptionTypeWire::Deterministic,
                    database_id: 6,
                    cek_id: 2,
                    cek_version: 1,
                    cek_md_version: 0x0000_b469_002f_7223,
                    normalization_rule_version: 1,
                },
            );

            let mut buf = BytesMut::new();
            param.encode(&mut buf);
            assert_eq!(
                buf.to_vec(),
                golden,
                "encrypted {name} param must match the captured Microsoft.Data.SqlClient bytes"
            );
        }
    }

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // 5 UTF-16 code units = 10 bytes.
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_nvarchar_param_surrogate_pair_length() {
        // U+1F600 is outside the BMP: one char, two UTF-16 code units (4 bytes).
        let param = RpcParam::nvarchar("@p", "\u{1F600}");
        assert_eq!(param.value.as_ref().unwrap().len(), 4);
        assert_eq!(param.type_info.max_length, Some(4));

        // "Hello " (6) + two CJK chars (2) + " " (1) + surrogate pair (2) = 11 units.
        let param = RpcParam::nvarchar("@p", "Hello \u{4E16}\u{754C} \u{1F600}");
        assert_eq!(param.value.as_ref().unwrap().len(), 22);
        assert_eq!(param.type_info.max_length, Some(22));
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // statement + declarations + one parameter
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_param_declarations_unnamed_are_positional() {
        let params = vec![
            RpcParam::int("", 1),
            RpcParam::int("", 2),
            RpcParam::int("", 3),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert_eq!(decls, "@p1 int, @p2 int, @p3 int");
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        assert_eq!(rpc.params.len(), 4);
        // @handle is an output parameter.
        assert!(rpc.params[0].flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // handle + one parameter
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // handle only
        assert_eq!(rpc.params.len(), 1);
    }

    #[test]
    fn test_prepexec_request() {
        let rpc = RpcRequest::prepexec(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::PrepExec));
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params[0].flags.by_ref);
        assert_eq!(rpc.params[1].name, "@params");
        assert_eq!(rpc.params[2].name, "@stmt");
        // User parameters are sent positionally.
        assert!(rpc.params[3].name.is_empty());
    }

    #[test]
    fn test_varchar_param() {
        let param = RpcParam::varchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.value.as_ref().unwrap().len(), 5);
        assert_eq!(&param.value.as_ref().unwrap()[..], b"Alice");
    }

    #[test]
    fn test_varchar_param_max() {
        let long_str = "a".repeat(9000);
        let param = RpcParam::varchar("@big", &long_str);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert_eq!(param.value.as_ref().unwrap().len(), 9000);
    }

    #[test]
    fn test_varchar_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::varchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name varchar(5)"));
    }

    #[test]
    fn test_varchar_type_info_has_collation() {
        let ti = TypeInfo::varchar(100);
        assert_eq!(ti.type_id, 0xA7);
        assert_eq!(ti.max_length, Some(100));
        assert!(ti.collation.is_some());
    }

    #[test]
    fn test_varchar_encode_round_trip() {
        let param = RpcParam::varchar("@val", "test value");
        let mut buf = bytes::BytesMut::new();
        param.encode(&mut buf);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_collation_round_trip() {
        let collation = Collation {
            lcid: 0x00D0_0409,
            sort_id: 0x34,
        };
        let bytes = collation.to_bytes();
        assert_eq!(bytes, [0x09, 0x04, 0xD0, 0x00, 0x34]);

        let restored = Collation::from_bytes(&bytes);
        assert_eq!(restored.lcid, collation.lcid);
        assert_eq!(restored.sort_id, collation.sort_id);
    }

    #[test]
    fn test_varchar_with_collation_uses_custom_collation_bytes() {
        let collation = Collation {
            lcid: 0x0804,
            sort_id: 0,
        };
        let param = RpcParam::varchar_with_collation("@val", "test", &collation);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.collation, Some(collation.to_bytes()));
    }

    #[test]
    fn test_money_type_info() {
        let ti = TypeInfo::money();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(8));
    }

    #[test]
    fn test_smallmoney_type_info() {
        let ti = TypeInfo::smallmoney();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_smalldatetime_type_info() {
        let ti = TypeInfo::smalldatetime();
        assert_eq!(ti.type_id, 0x6F);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_money_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::new("@m", TypeInfo::money(), Bytes::from_static(&[0u8; 8])),
            RpcParam::new("@sm", TypeInfo::smallmoney(), Bytes::from_static(&[0u8; 4])),
            RpcParam::new(
                "@sdt",
                TypeInfo::smalldatetime(),
                Bytes::from_static(&[0u8; 4]),
            ),
        ]);
        assert!(decls.contains("@m money"), "got: {decls}");
        assert!(decls.contains("@sm smallmoney"), "got: {decls}");
        assert!(decls.contains("@sdt smalldatetime"), "got: {decls}");
    }

    #[test]
    fn test_money_typeinfo_encodes_max_length_byte() {
        let mut buf = bytes::BytesMut::new();
        TypeInfo::money().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x08]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smallmoney().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x04]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smalldatetime().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6F, 0x04]);
    }

    #[test]
    fn test_varchar_with_collation_default_vs_custom_differ() {
        let default_param = RpcParam::varchar("@val", "test");
        let custom_collation = Collation {
            lcid: 0x0419,
            sort_id: 0,
        };
        let custom_param = RpcParam::varchar_with_collation("@val", "test", &custom_collation);
        assert_ne!(
            default_param.type_info.collation,
            custom_param.type_info.collation
        );
    }
}