1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::crypto::EncryptionTypeWire;
30use crate::prelude::*;
31use crate::token::Collation;
32
/// Well-known system stored procedures that an RPC request can reference by numeric id instead
/// of by name (TDS `ProcID`).
///
/// The discriminant is the id written on the wire right after the `0xFFFF` name-length marker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
#[non_exhaustive]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 1,
    /// `sp_cursoropen`
    CursorOpen = 2,
    /// `sp_cursorprepare`
    CursorPrepare = 3,
    /// `sp_cursorexecute`
    CursorExecute = 4,
    /// `sp_cursorprepexec`
    CursorPrepExec = 5,
    /// `sp_cursorunprepare`
    CursorUnprepare = 6,
    /// `sp_cursorfetch`
    CursorFetch = 7,
    /// `sp_cursoroption`
    CursorOption = 8,
    /// `sp_cursorclose`
    CursorClose = 9,
    /// `sp_executesql`
    ExecuteSql = 10,
    /// `sp_prepare`
    Prepare = 11,
    /// `sp_execute`
    Execute = 12,
    /// `sp_prepexec`
    PrepExec = 13,
    /// `sp_prepexecrpc`
    PrepExecRpc = 14,
    /// `sp_unprepare`
    Unprepare = 15,
}
72
/// Per-request option flags of an RPC request (TDS `OptionFlags`), written as a little-endian
/// `u16` right after the procedure name/id.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Bit 0 (`fWithRecomp`): ask the server to recompile the procedure.
    pub with_recompile: bool,
    /// Bit 1 (`fNoMetaData`): the server should not send result-set metadata.
    pub no_metadata: bool,
    /// Bit 2 (`fReuseMetaData`): reuse previously sent metadata.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Returns flags with every option cleared (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a copy of these flags with `with_recompile` set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the options into their wire bit layout: bit 0 recompile, bit 1 no-metadata,
    /// bit 2 reuse-metadata; all other bits are zero.
    pub fn encode(&self) -> u16 {
        u16::from(self.with_recompile)
            | (u16::from(self.no_metadata) << 1)
            | (u16::from(self.reuse_metadata) << 2)
    }
}
112
/// Per-parameter status flags of an RPC parameter (TDS `StatusFlags`), written as one byte
/// between the parameter name and its type info.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Bit 0 (`fByRefValue`): the parameter is passed by reference, i.e. it is an output parameter.
    pub by_ref: bool,
    /// Bit 1 (`fDefaultValue`): use the procedure's default value for this parameter.
    pub default: bool,
    /// Bit 3 (`fEncrypted`): the value is column-encryption ciphertext followed by crypto metadata.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns flags with every bit cleared (same as `Default`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns a copy of these flags marked as an output (by-reference) parameter.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their wire bit layout: bit 0 by-ref, bit 1 default, bit 3 encrypted.
    pub fn encode(&self) -> u8 {
        u8::from(self.by_ref) | (u8::from(self.default) << 1) | (u8::from(self.encrypted) << 3)
    }
}
152
153#[derive(Debug, Clone)]
155pub struct TypeInfo {
156 pub type_id: u8,
158 pub max_length: Option<u16>,
160 pub precision: Option<u8>,
162 pub scale: Option<u8>,
164 pub collation: Option<[u8; 5]>,
166 pub tvp_type_name: Option<String>,
168}
169
170impl TypeInfo {
171 pub fn int() -> Self {
173 Self {
174 type_id: 0x26, max_length: Some(4),
176 precision: None,
177 scale: None,
178 collation: None,
179 tvp_type_name: None,
180 }
181 }
182
183 pub fn bigint() -> Self {
185 Self {
186 type_id: 0x26, max_length: Some(8),
188 precision: None,
189 scale: None,
190 collation: None,
191 tvp_type_name: None,
192 }
193 }
194
195 pub fn smallint() -> Self {
197 Self {
198 type_id: 0x26, max_length: Some(2),
200 precision: None,
201 scale: None,
202 collation: None,
203 tvp_type_name: None,
204 }
205 }
206
207 pub fn tinyint() -> Self {
209 Self {
210 type_id: 0x26, max_length: Some(1),
212 precision: None,
213 scale: None,
214 collation: None,
215 tvp_type_name: None,
216 }
217 }
218
219 pub fn bit() -> Self {
221 Self {
222 type_id: 0x68, max_length: Some(1),
224 precision: None,
225 scale: None,
226 collation: None,
227 tvp_type_name: None,
228 }
229 }
230
231 pub fn float() -> Self {
233 Self {
234 type_id: 0x6D, max_length: Some(8),
236 precision: None,
237 scale: None,
238 collation: None,
239 tvp_type_name: None,
240 }
241 }
242
243 pub fn real() -> Self {
245 Self {
246 type_id: 0x6D, max_length: Some(4),
248 precision: None,
249 scale: None,
250 collation: None,
251 tvp_type_name: None,
252 }
253 }
254
255 pub fn nvarchar(max_len: u16) -> Self {
257 Self {
258 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
261 scale: None,
262 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
264 tvp_type_name: None,
265 }
266 }
267
268 pub fn nvarchar_max() -> Self {
270 Self {
271 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
274 scale: None,
275 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
276 tvp_type_name: None,
277 }
278 }
279
280 const DEFAULT_COLLATION: [u8; 5] = [0x09, 0x04, 0xD0, 0x00, 0x34];
282
283 pub fn varchar(max_len: u16) -> Self {
285 Self::varchar_with_collation(max_len, Self::DEFAULT_COLLATION)
286 }
287
288 pub fn varchar_with_collation(max_len: u16, collation: [u8; 5]) -> Self {
290 Self {
291 type_id: 0xA7, max_length: Some(max_len),
293 precision: None,
294 scale: None,
295 collation: Some(collation),
296 tvp_type_name: None,
297 }
298 }
299
300 pub fn varchar_max() -> Self {
302 Self::varchar_max_with_collation(Self::DEFAULT_COLLATION)
303 }
304
305 pub fn varchar_max_with_collation(collation: [u8; 5]) -> Self {
307 Self {
308 type_id: 0xA7, max_length: Some(0xFFFF), precision: None,
311 scale: None,
312 collation: Some(collation),
313 tvp_type_name: None,
314 }
315 }
316
317 pub fn varbinary(max_len: u16) -> Self {
319 Self {
320 type_id: 0xA5, max_length: Some(max_len),
322 precision: None,
323 scale: None,
324 collation: None,
325 tvp_type_name: None,
326 }
327 }
328
329 pub fn varbinary_max() -> Self {
331 Self {
332 type_id: 0xA5, max_length: Some(0xFFFF), precision: None,
335 scale: None,
336 collation: None,
337 tvp_type_name: None,
338 }
339 }
340
341 pub fn char(len: u16) -> Self {
343 Self {
344 type_id: 0xAF, max_length: Some(len),
346 precision: None,
347 scale: None,
348 collation: Some(Self::DEFAULT_COLLATION),
349 tvp_type_name: None,
350 }
351 }
352
353 pub fn nchar(len: u16) -> Self {
355 Self {
356 type_id: 0xEF, max_length: Some(len * 2), precision: None,
359 scale: None,
360 collation: Some(Self::DEFAULT_COLLATION),
361 tvp_type_name: None,
362 }
363 }
364
365 pub fn binary(len: u16) -> Self {
367 Self {
368 type_id: 0xAD, max_length: Some(len),
370 precision: None,
371 scale: None,
372 collation: None,
373 tvp_type_name: None,
374 }
375 }
376
377 pub fn uniqueidentifier() -> Self {
379 Self {
380 type_id: 0x24, max_length: Some(16),
382 precision: None,
383 scale: None,
384 collation: None,
385 tvp_type_name: None,
386 }
387 }
388
389 pub fn uuid() -> Self {
391 Self {
392 type_id: 0x24, max_length: Some(16),
394 precision: None,
395 scale: None,
396 collation: None,
397 tvp_type_name: None,
398 }
399 }
400
401 pub fn date() -> Self {
403 Self {
404 type_id: 0x28, max_length: None,
406 precision: None,
407 scale: None,
408 collation: None,
409 tvp_type_name: None,
410 }
411 }
412
413 pub fn time(scale: u8) -> Self {
415 Self {
416 type_id: 0x29, max_length: None,
418 precision: None,
419 scale: Some(scale),
420 collation: None,
421 tvp_type_name: None,
422 }
423 }
424
425 pub fn datetime2(scale: u8) -> Self {
427 Self {
428 type_id: 0x2A, max_length: None,
430 precision: None,
431 scale: Some(scale),
432 collation: None,
433 tvp_type_name: None,
434 }
435 }
436
437 pub fn datetimeoffset(scale: u8) -> Self {
439 Self {
440 type_id: 0x2B, max_length: None,
442 precision: None,
443 scale: Some(scale),
444 collation: None,
445 tvp_type_name: None,
446 }
447 }
448
449 pub fn decimal(precision: u8, scale: u8) -> Self {
451 Self {
452 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
455 scale: Some(scale),
456 collation: None,
457 tvp_type_name: None,
458 }
459 }
460
461 pub fn money() -> Self {
463 Self {
464 type_id: 0x6E, max_length: Some(8),
466 precision: None,
467 scale: None,
468 collation: None,
469 tvp_type_name: None,
470 }
471 }
472
473 pub fn smallmoney() -> Self {
475 Self {
476 type_id: 0x6E, max_length: Some(4),
478 precision: None,
479 scale: None,
480 collation: None,
481 tvp_type_name: None,
482 }
483 }
484
485 pub fn smalldatetime() -> Self {
487 Self {
488 type_id: 0x6F, max_length: Some(4),
490 precision: None,
491 scale: None,
492 collation: None,
493 tvp_type_name: None,
494 }
495 }
496
497 pub fn datetime() -> Self {
499 Self {
500 type_id: 0x6F, max_length: Some(8),
502 precision: None,
503 scale: None,
504 collation: None,
505 tvp_type_name: None,
506 }
507 }
508
509 pub fn tvp(type_name: impl Into<String>) -> Self {
514 Self {
515 type_id: 0xF3, max_length: None,
517 precision: None,
518 scale: None,
519 collation: None,
520 tvp_type_name: Some(type_name.into()),
521 }
522 }
523
524 pub fn encode(&self, buf: &mut BytesMut) {
526 if self.type_id != 0xF3 {
529 buf.put_u8(self.type_id);
530 }
531
532 match self.type_id {
534 0x26 | 0x68 | 0x6D | 0x6E | 0x6F => {
535 if let Some(len) = self.max_length {
537 buf.put_u8(len as u8);
538 }
539 }
540 0xE7 | 0xA7 | 0xA5 | 0xEF | 0xAF | 0xAD => {
541 if let Some(len) = self.max_length {
544 buf.put_u16_le(len);
545 }
546 if let Some(collation) = self.collation {
548 buf.put_slice(&collation);
549 }
550 }
551 0x24 => {
552 if let Some(len) = self.max_length {
554 buf.put_u8(len as u8);
555 }
556 }
557 0x29..=0x2B => {
558 if let Some(scale) = self.scale {
560 buf.put_u8(scale);
561 }
562 }
563 0x6C | 0x6A => {
564 if let Some(len) = self.max_length {
566 buf.put_u8(len as u8);
567 }
568 if let Some(precision) = self.precision {
569 buf.put_u8(precision);
570 }
571 if let Some(scale) = self.scale {
572 buf.put_u8(scale);
573 }
574 }
575 _ => {}
576 }
577 }
578}
579
580#[derive(Debug, Clone)]
587pub struct EncryptedParamMetadata {
588 pub base_type_info: TypeInfo,
590 pub algorithm_id: u8,
592 pub encryption_type: EncryptionTypeWire,
594 pub database_id: u32,
596 pub cek_id: u32,
598 pub cek_version: u32,
600 pub cek_md_version: u64,
602 pub normalization_rule_version: u8,
604}
605
606impl EncryptedParamMetadata {
607 pub fn encode(&self, buf: &mut BytesMut) {
610 self.base_type_info.encode(buf);
611 buf.put_u8(self.algorithm_id);
612 buf.put_u8(self.encryption_type.to_u8());
613 buf.put_u32_le(self.database_id);
614 buf.put_u32_le(self.cek_id);
615 buf.put_u32_le(self.cek_version);
616 buf.put_u64_le(self.cek_md_version);
617 buf.put_u8(self.normalization_rule_version);
618 }
619}
620
621#[derive(Debug, Clone)]
623pub struct RpcParam {
624 pub name: String,
626 pub flags: ParamFlags,
628 pub type_info: TypeInfo,
630 pub value: Option<Bytes>,
632 pub crypto_metadata: Option<EncryptedParamMetadata>,
635}
636
637impl RpcParam {
638 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
640 Self {
641 name: name.into(),
642 flags: ParamFlags::default(),
643 type_info,
644 value: Some(value),
645 crypto_metadata: None,
646 }
647 }
648
649 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
651 Self {
652 name: name.into(),
653 flags: ParamFlags::default(),
654 type_info,
655 value: None,
656 crypto_metadata: None,
657 }
658 }
659
660 pub fn encrypted(
667 name: impl Into<String>,
668 ciphertext: Bytes,
669 metadata: EncryptedParamMetadata,
670 ) -> Self {
671 Self {
672 name: name.into(),
673 flags: ParamFlags {
674 encrypted: true,
675 ..ParamFlags::default()
676 },
677 type_info: TypeInfo::varbinary_max(),
678 value: Some(ciphertext),
679 crypto_metadata: Some(metadata),
680 }
681 }
682
683 pub fn encrypted_null(name: impl Into<String>, metadata: EncryptedParamMetadata) -> Self {
689 Self {
690 name: name.into(),
691 flags: ParamFlags {
692 encrypted: true,
693 ..ParamFlags::default()
694 },
695 type_info: TypeInfo::varbinary_max(),
696 value: None,
697 crypto_metadata: Some(metadata),
698 }
699 }
700
701 pub fn int(name: impl Into<String>, value: i32) -> Self {
703 let mut buf = BytesMut::with_capacity(4);
704 buf.put_i32_le(value);
705 Self::new(name, TypeInfo::int(), buf.freeze())
706 }
707
708 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
710 let mut buf = BytesMut::with_capacity(8);
711 buf.put_i64_le(value);
712 Self::new(name, TypeInfo::bigint(), buf.freeze())
713 }
714
715 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
717 let mut buf = BytesMut::new();
718 let mut code_units: usize = 0;
719 for code_unit in value.encode_utf16() {
720 buf.put_u16_le(code_unit);
721 code_units += 1;
722 }
723 let type_info = if code_units > 4000 {
729 TypeInfo::nvarchar_max()
730 } else {
731 TypeInfo::nvarchar(code_units.max(1) as u16)
732 };
733 Self::new(name, type_info, buf.freeze())
734 }
735
736 pub fn varchar(name: impl Into<String>, value: &str) -> Self {
746 let encoded = Self::encode_varchar_bytes(value);
747 let byte_len = encoded.len();
748 let type_info = if byte_len > 8000 {
749 TypeInfo::varchar_max()
750 } else {
751 TypeInfo::varchar(byte_len.max(1) as u16)
752 };
753 Self::new(name, type_info, Bytes::from(encoded))
754 }
755
756 fn encode_varchar_bytes(value: &str) -> Vec<u8> {
759 crate::collation::encode_str_for_collation(value, None)
760 }
761
762 pub fn varchar_with_collation(
767 name: impl Into<String>,
768 value: &str,
769 collation: &Collation,
770 ) -> Self {
771 let collation_bytes = collation.to_bytes();
772 let encoded = Self::encode_varchar_bytes_for_collation(value, collation);
773 let byte_len = encoded.len();
774 let type_info = if byte_len > 8000 {
775 TypeInfo::varchar_max_with_collation(collation_bytes)
776 } else {
777 TypeInfo::varchar_with_collation(byte_len.max(1) as u16, collation_bytes)
778 };
779 Self::new(name, type_info, Bytes::from(encoded))
780 }
781
782 fn encode_varchar_bytes_for_collation(value: &str, collation: &Collation) -> Vec<u8> {
784 crate::collation::encode_str_for_collation(value, Some(collation))
785 }
786
787 #[must_use]
789 pub fn as_output(mut self) -> Self {
790 self.flags = self.flags.output();
791 self
792 }
793
794 pub fn encode(&self, buf: &mut BytesMut) {
796 let name_len = self.name.encode_utf16().count() as u8;
798 buf.put_u8(name_len);
799 if name_len > 0 {
800 for code_unit in self.name.encode_utf16() {
801 buf.put_u16_le(code_unit);
802 }
803 }
804
805 buf.put_u8(self.flags.encode());
807
808 self.type_info.encode(buf);
810
811 if let Some(ref value) = self.value {
813 match self.type_info.type_id {
815 0x26 => {
816 buf.put_u8(value.len() as u8);
818 buf.put_slice(value);
819 }
820 0x68 | 0x6D | 0x6E | 0x6F => {
821 buf.put_u8(value.len() as u8);
823 buf.put_slice(value);
824 }
825 0xE7 | 0xA7 | 0xA5 => {
826 if self.type_info.max_length == Some(0xFFFF) {
828 let total_len = value.len() as u64;
831 buf.put_u64_le(total_len);
832 buf.put_u32_le(value.len() as u32);
833 buf.put_slice(value);
834 buf.put_u32_le(0); } else {
836 buf.put_u16_le(value.len() as u16);
837 buf.put_slice(value);
838 }
839 }
840 0x24 => {
841 buf.put_u8(value.len() as u8);
843 buf.put_slice(value);
844 }
845 0x28..=0x2B => {
846 buf.put_u8(value.len() as u8);
848 buf.put_slice(value);
849 }
850 0x6C => {
851 buf.put_u8(value.len() as u8);
853 buf.put_slice(value);
854 }
855 0xF3 => {
856 buf.put_slice(value);
860 }
861 _ => {
862 buf.put_u8(value.len() as u8);
864 buf.put_slice(value);
865 }
866 }
867 } else {
868 match self.type_info.type_id {
870 0xE7 | 0xA7 | 0xA5 => {
871 if self.type_info.max_length == Some(0xFFFF) {
873 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
875 buf.put_u16_le(0xFFFF);
876 }
877 }
878 _ => {
879 buf.put_u8(0); }
881 }
882 }
883
884 if let Some(ref metadata) = self.crypto_metadata {
886 metadata.encode(buf);
887 }
888 }
889}
890
891#[derive(Debug, Clone)]
893pub struct RpcRequest {
894 proc_name: Option<String>,
896 proc_id: Option<ProcId>,
898 options: RpcOptionFlags,
900 params: Vec<RpcParam>,
902}
903
904impl RpcRequest {
905 pub fn named(proc_name: impl Into<String>) -> Self {
907 Self {
908 proc_name: Some(proc_name.into()),
909 proc_id: None,
910 options: RpcOptionFlags::default(),
911 params: Vec::new(),
912 }
913 }
914
915 pub fn by_id(proc_id: ProcId) -> Self {
917 Self {
918 proc_name: None,
919 proc_id: Some(proc_id),
920 options: RpcOptionFlags::default(),
921 params: Vec::new(),
922 }
923 }
924
925 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
943 let mut request = Self::by_id(ProcId::ExecuteSql);
944
945 request.params.push(RpcParam::nvarchar("", sql));
947
948 if !params.is_empty() {
950 let declarations = Self::build_param_declarations(¶ms);
951 request.params.push(RpcParam::nvarchar("", &declarations));
952 }
953
954 request.params.extend(params);
956
957 request
958 }
959
960 pub fn build_param_declarations(params: &[RpcParam]) -> String {
968 params
969 .iter()
970 .map(|p| {
971 let name = if p.name.starts_with('@') {
972 p.name.clone()
973 } else if p.name.is_empty() {
974 format!(
976 "@p{}",
977 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
978 )
979 } else {
980 format!("@{}", p.name)
981 };
982
983 let ti = p
986 .crypto_metadata
987 .as_ref()
988 .map(|m| &m.base_type_info)
989 .unwrap_or(&p.type_info);
990
991 let type_name: String = match ti.type_id {
992 0x26 => match ti.max_length {
993 Some(1) => "tinyint".to_string(),
994 Some(2) => "smallint".to_string(),
995 Some(4) => "int".to_string(),
996 Some(8) => "bigint".to_string(),
997 _ => "int".to_string(),
998 },
999 0x68 => "bit".to_string(),
1000 0x6D => match ti.max_length {
1001 Some(4) => "real".to_string(),
1002 _ => "float".to_string(),
1003 },
1004 0xE7 => {
1005 if ti.max_length == Some(0xFFFF) {
1006 "nvarchar(max)".to_string()
1007 } else {
1008 let len = ti.max_length.unwrap_or(4000) / 2;
1009 format!("nvarchar({len})")
1010 }
1011 }
1012 0xA7 => {
1013 if ti.max_length == Some(0xFFFF) {
1014 "varchar(max)".to_string()
1015 } else {
1016 let len = ti.max_length.unwrap_or(8000);
1017 format!("varchar({len})")
1018 }
1019 }
1020 0xA5 => {
1021 if ti.max_length == Some(0xFFFF) {
1022 "varbinary(max)".to_string()
1023 } else {
1024 let len = ti.max_length.unwrap_or(8000);
1025 format!("varbinary({len})")
1026 }
1027 }
1028 0xAF => format!("char({})", ti.max_length.unwrap_or(1)),
1029 0xEF => format!("nchar({})", ti.max_length.unwrap_or(2) / 2),
1030 0xAD => format!("binary({})", ti.max_length.unwrap_or(1)),
1031 0x24 => "uniqueidentifier".to_string(),
1032 0x28 => "date".to_string(),
1033 0x29 => {
1034 let scale = ti.scale.unwrap_or(7);
1035 format!("time({scale})")
1036 }
1037 0x2A => {
1038 let scale = ti.scale.unwrap_or(7);
1039 format!("datetime2({scale})")
1040 }
1041 0x2B => {
1042 let scale = ti.scale.unwrap_or(7);
1043 format!("datetimeoffset({scale})")
1044 }
1045 0x6C => {
1046 let precision = ti.precision.unwrap_or(18);
1047 let scale = ti.scale.unwrap_or(0);
1048 format!("decimal({precision}, {scale})")
1049 }
1050 0x6E => match ti.max_length {
1051 Some(4) => "smallmoney".to_string(),
1052 _ => "money".to_string(),
1053 },
1054 0x6F => match ti.max_length {
1055 Some(4) => "smalldatetime".to_string(),
1056 _ => "datetime".to_string(),
1057 },
1058 0xF3 => {
1059 if let Some(ref tvp_name) = ti.tvp_type_name {
1062 format!("{tvp_name} READONLY")
1063 } else {
1064 "sql_variant".to_string()
1066 }
1067 }
1068 _ => "sql_variant".to_string(),
1069 };
1070
1071 format!("{name} {type_name}")
1072 })
1073 .collect::<Vec<_>>()
1074 .join(", ")
1075 }
1076
1077 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
1079 let mut request = Self::by_id(ProcId::Prepare);
1080
1081 request
1083 .params
1084 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
1085
1086 let declarations = Self::build_param_declarations(params);
1088 request
1089 .params
1090 .push(RpcParam::nvarchar("@params", &declarations));
1091
1092 request.params.push(RpcParam::nvarchar("@stmt", sql));
1094
1095 request.params.push(RpcParam::int("@options", 1));
1097
1098 request
1099 }
1100
1101 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
1103 let mut request = Self::by_id(ProcId::Execute);
1104
1105 request.params.push(RpcParam::int("@handle", handle));
1107
1108 request.params.extend(params);
1110
1111 request
1112 }
1113
1114 pub fn unprepare(handle: i32) -> Self {
1116 let mut request = Self::by_id(ProcId::Unprepare);
1117 request.params.push(RpcParam::int("@handle", handle));
1118 request
1119 }
1120
1121 #[must_use]
1123 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
1124 self.options = options;
1125 self
1126 }
1127
1128 #[must_use]
1130 pub fn param(mut self, param: RpcParam) -> Self {
1131 self.params.push(param);
1132 self
1133 }
1134
1135 #[must_use]
1139 pub fn encode(&self) -> Bytes {
1140 self.encode_with_transaction(0)
1141 }
1142
1143 #[must_use]
1155 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
1156 let mut buf = BytesMut::with_capacity(256);
1157
1158 let all_headers_start = buf.len();
1161 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
1172 let len_bytes = (all_headers_len as u32).to_le_bytes();
1173 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
1174
1175 if let Some(proc_id) = self.proc_id {
1177 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
1180 } else if let Some(ref proc_name) = self.proc_name {
1181 let name_len = proc_name.encode_utf16().count() as u16;
1183 buf.put_u16_le(name_len);
1184 write_utf16_string(&mut buf, proc_name);
1185 }
1186
1187 buf.put_u16_le(self.options.encode());
1189
1190 for param in &self.params {
1192 param.encode(&mut buf);
1193 }
1194
1195 buf.freeze()
1196 }
1197}
1198
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encrypted parameters must be byte-identical to what Microsoft.Data.SqlClient sends.
    #[test]
    fn encrypted_param_encode_matches_captured_dotnet_wire() {
        fn unhex(s: &str) -> Vec<u8> {
            (0..s.len())
                .step_by(2)
                .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap())
                .collect()
        }

        let mut nvarchar_50 = TypeInfo::nvarchar(50);
        nvarchar_50.collation = Some([0x09, 0x04, 0xd0, 0x00, 0x34]);

        let cases = [
            (
                "@i",
                TypeInfo::int(),
                "024000690008a5ffff41000000000000004100000001ed7b8c6030870d92358f12acb0d0c69c00bc3aa3ba578ecb0ea5f514b5045912a2b1ae52ed834f6bac49520956e4a574c30d573590fb3785556c8fe42f87c5b4000000002604020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@s",
                nvarchar_50,
                "024000730008a5ffff4100000000000000410000000150c0a7dec4d4241c7a4a617007d32d97e7131f8c57a5ad212487891170f12ecb9957fce16389f4728d1c3c65813beeea085ae3fd516d29f84298df3e97f0d05d00000000e764000904d00034020106000000020000000100000023722f0069b4000001",
            ),
            (
                "@b",
                TypeInfo::varbinary(50),
                "024000620008a5ffff41000000000000004100000001d17165aa6df0155be6b78c6712d3b03870ea394cfed10956cf07fbfa204c4b82cddfa5e2f4fc03335f579e2767657e3067cd9da7d62a07427106b91f747b97da00000000a53200020106000000020000000100000023722f0069b4000001",
            ),
        ];

        for (name, base_type_info, golden_hex) in cases {
            let golden = unhex(golden_hex);
            // name length byte + name + flags + TYPE_INFO (a5 ffff) + PLP total + chunk length
            let cipher_off = 1 + name.encode_utf16().count() * 2 + 1 + 3 + 8 + 4;
            let cipher = Bytes::copy_from_slice(&golden[cipher_off..cipher_off + 65]);

            let param = RpcParam::encrypted(
                name,
                cipher,
                EncryptedParamMetadata {
                    base_type_info,
                    algorithm_id: 2,
                    encryption_type: EncryptionTypeWire::Deterministic,
                    database_id: 6,
                    cek_id: 2,
                    cek_version: 1,
                    cek_md_version: 0x0000_b469_002f_7223,
                    normalization_rule_version: 1,
                },
            );

            let mut buf = BytesMut::new();
            param.encode(&mut buf);
            assert_eq!(
                buf.to_vec(),
                golden,
                "encrypted {name} param must match the captured Microsoft.Data.SqlClient bytes"
            );
        }
    }

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        // 5 characters * 2 bytes per UTF-16 code unit
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_nvarchar_param_surrogate_pair_length() {
        // A supplementary-plane character is a surrogate pair: 2 code units = 4 bytes.
        let param = RpcParam::nvarchar("@p", "\u{1F600}");
        assert_eq!(param.value.as_ref().unwrap().len(), 4);
        // max_length counts bytes of UTF-16 code units, not Unicode scalar values.
        assert_eq!(param.type_info.max_length, Some(4));

        // "Hello " (6) + two BMP CJK characters (2) + " " (1) + one surrogate pair (2) = 11 units.
        let param = RpcParam::nvarchar("@p", "Hello \u{4E16}\u{754C} \u{1F600}");
        assert_eq!(param.value.as_ref().unwrap().len(), 22);
        assert_eq!(param.type_info.max_length, Some(22));
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        // statement + declarations + the one user parameter
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        // @handle, @params, @stmt, @options
        assert_eq!(rpc.params.len(), 4);
        // @handle is an output parameter
        assert!(rpc.params[0].flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        // @handle + the one user parameter
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        // @handle only
        assert_eq!(rpc.params.len(), 1);
    }

    #[test]
    fn test_varchar_param() {
        let param = RpcParam::varchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xA7);
        // single-byte encoding: one byte per ASCII character
        assert_eq!(param.value.as_ref().unwrap().len(), 5);
        assert_eq!(&param.value.as_ref().unwrap()[..], b"Alice");
    }

    #[test]
    fn test_varchar_param_max() {
        // more than 8000 bytes switches to varchar(max)
        let long_str = "a".repeat(9000);
        let param = RpcParam::varchar("@big", &long_str);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.max_length, Some(0xFFFF));
        assert_eq!(param.value.as_ref().unwrap().len(), 9000);
    }

    #[test]
    fn test_varchar_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::varchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name varchar(5)"));
    }

    #[test]
    fn test_varchar_type_info_has_collation() {
        let ti = TypeInfo::varchar(100);
        assert_eq!(ti.type_id, 0xA7);
        assert_eq!(ti.max_length, Some(100));
        assert!(ti.collation.is_some());
    }

    #[test]
    fn test_varchar_encode_round_trip() {
        let param = RpcParam::varchar("@val", "test value");
        let mut buf = bytes::BytesMut::new();
        param.encode(&mut buf);
        assert!(!buf.is_empty());
    }

    #[test]
    fn test_collation_round_trip() {
        let collation = Collation {
            lcid: 0x00D0_0409,
            sort_id: 0x34,
        };
        let bytes = collation.to_bytes();
        assert_eq!(bytes, [0x09, 0x04, 0xD0, 0x00, 0x34]);

        let restored = Collation::from_bytes(&bytes);
        assert_eq!(restored.lcid, collation.lcid);
        assert_eq!(restored.sort_id, collation.sort_id);
    }

    #[test]
    fn test_varchar_with_collation_uses_custom_collation_bytes() {
        let collation = Collation {
            lcid: 0x0804,
            sort_id: 0,
        };
        let param = RpcParam::varchar_with_collation("@val", "test", &collation);
        assert_eq!(param.type_info.type_id, 0xA7);
        assert_eq!(param.type_info.collation, Some(collation.to_bytes()));
    }

    #[test]
    fn test_money_type_info() {
        let ti = TypeInfo::money();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(8));
    }

    #[test]
    fn test_smallmoney_type_info() {
        let ti = TypeInfo::smallmoney();
        assert_eq!(ti.type_id, 0x6E);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_smalldatetime_type_info() {
        let ti = TypeInfo::smalldatetime();
        assert_eq!(ti.type_id, 0x6F);
        assert_eq!(ti.max_length, Some(4));
    }

    #[test]
    fn test_money_param_declarations() {
        let decls = RpcRequest::build_param_declarations(&[
            RpcParam::new("@m", TypeInfo::money(), Bytes::from_static(&[0u8; 8])),
            RpcParam::new("@sm", TypeInfo::smallmoney(), Bytes::from_static(&[0u8; 4])),
            RpcParam::new(
                "@sdt",
                TypeInfo::smalldatetime(),
                Bytes::from_static(&[0u8; 4]),
            ),
        ]);
        assert!(decls.contains("@m money"), "got: {decls}");
        assert!(decls.contains("@sm smallmoney"), "got: {decls}");
        assert!(decls.contains("@sdt smalldatetime"), "got: {decls}");
    }

    #[test]
    fn test_money_typeinfo_encodes_max_length_byte() {
        // type token followed by the one-byte max length
        let mut buf = bytes::BytesMut::new();
        TypeInfo::money().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x08]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smallmoney().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6E, 0x04]);

        let mut buf = bytes::BytesMut::new();
        TypeInfo::smalldatetime().encode(&mut buf);
        assert_eq!(&buf[..], &[0x6F, 0x04]);
    }

    #[test]
    fn test_varchar_with_collation_default_vs_custom_differ() {
        let default_param = RpcParam::varchar("@val", "test");
        let custom_collation = Collation {
            lcid: 0x0419,
            sort_id: 0,
        };
        let custom_param = RpcParam::varchar_with_collation("@val", "test", &custom_collation);
        assert_ne!(
            default_param.type_info.collation,
            custom_param.type_info.collation
        );
    }
}