1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29use crate::prelude::*;
30
/// Well-known SQL Server system stored procedures that an RPC request can
/// invoke by numeric identifier instead of by name.
///
/// The discriminant is the 16-bit value written to the wire right after the
/// `0xFFFF` marker that says "a procedure id follows".
#[repr(u16)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProcId {
    /// `sp_cursor`
    Cursor = 1,
    /// `sp_cursoropen`
    CursorOpen = 2,
    /// `sp_cursorprepare`
    CursorPrepare = 3,
    /// `sp_cursorexecute`
    CursorExecute = 4,
    /// `sp_cursorprepexec`
    CursorPrepExec = 5,
    /// `sp_cursorunprepare`
    CursorUnprepare = 6,
    /// `sp_cursorfetch`
    CursorFetch = 7,
    /// `sp_cursoroption`
    CursorOption = 8,
    /// `sp_cursorclose`
    CursorClose = 9,
    /// `sp_executesql`
    ExecuteSql = 10,
    /// `sp_prepare`
    Prepare = 11,
    /// `sp_execute`
    Execute = 12,
    /// `sp_prepexec`
    PrepExec = 13,
    /// `sp_prepexecrpc`
    PrepExecRpc = 14,
    /// `sp_unprepare`
    Unprepare = 15,
}
69
/// Option flags written after the procedure name/id of an RPC request.
///
/// Each field maps to one bit of the 16-bit word produced by
/// [`RpcOptionFlags::encode`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// `fWithRecomp` — bit `0x0001`.
    pub with_recompile: bool,
    /// `fNoMetaData` — bit `0x0002`.
    pub no_metadata: bool,
    /// `fReuseMetaData` — bit `0x0004`.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    const WITH_RECOMPILE: u16 = 0x0001;
    const NO_METADATA: u16 = 0x0002;
    const REUSE_METADATA: u16 = 0x0004;

    /// Creates a flag set with every option cleared.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets or clears the recompile flag, returning the updated flags.
    #[must_use]
    pub fn with_recompile(mut self, value: bool) -> Self {
        self.with_recompile = value;
        self
    }

    /// Sets or clears the no-metadata flag, returning the updated flags.
    ///
    /// Provided so every flag has a builder, not only `with_recompile`.
    #[must_use]
    pub fn with_no_metadata(mut self, value: bool) -> Self {
        self.no_metadata = value;
        self
    }

    /// Sets or clears the reuse-metadata flag, returning the updated flags.
    #[must_use]
    pub fn with_reuse_metadata(mut self, value: bool) -> Self {
        self.reuse_metadata = value;
        self
    }

    /// Packs the flags into the 16-bit option word (bits not listed above stay 0).
    #[must_use]
    pub fn encode(&self) -> u16 {
        let mut flags = 0u16;
        if self.with_recompile {
            flags |= Self::WITH_RECOMPILE;
        }
        if self.no_metadata {
            flags |= Self::NO_METADATA;
        }
        if self.reuse_metadata {
            flags |= Self::REUSE_METADATA;
        }
        flags
    }
}
109
/// Per-parameter status flags, written as one byte right after the name.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// By-reference (OUTPUT) parameter — bit `0x01`.
    pub by_ref: bool,
    /// Use the parameter's default value — bit `0x02`.
    pub default: bool,
    /// Encrypted value — bit `0x08`.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Returns a flag set with nothing set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of these flags marked as by-reference (output).
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into the status byte; bit `0x04` is never set.
    pub fn encode(&self) -> u8 {
        [
            (self.by_ref, 0x01_u8),
            (self.default, 0x02),
            (self.encrypted, 0x08),
        ]
        .iter()
        .filter(|&&(is_set, _)| is_set)
        .fold(0, |mask, &(_, bit)| mask | bit)
    }
}
149
150#[derive(Debug, Clone)]
152pub struct TypeInfo {
153 pub type_id: u8,
155 pub max_length: Option<u16>,
157 pub precision: Option<u8>,
159 pub scale: Option<u8>,
161 pub collation: Option<[u8; 5]>,
163 pub tvp_type_name: Option<String>,
165}
166
167impl TypeInfo {
168 pub fn int() -> Self {
170 Self {
171 type_id: 0x26, max_length: Some(4),
173 precision: None,
174 scale: None,
175 collation: None,
176 tvp_type_name: None,
177 }
178 }
179
180 pub fn bigint() -> Self {
182 Self {
183 type_id: 0x26, max_length: Some(8),
185 precision: None,
186 scale: None,
187 collation: None,
188 tvp_type_name: None,
189 }
190 }
191
192 pub fn smallint() -> Self {
194 Self {
195 type_id: 0x26, max_length: Some(2),
197 precision: None,
198 scale: None,
199 collation: None,
200 tvp_type_name: None,
201 }
202 }
203
204 pub fn tinyint() -> Self {
206 Self {
207 type_id: 0x26, max_length: Some(1),
209 precision: None,
210 scale: None,
211 collation: None,
212 tvp_type_name: None,
213 }
214 }
215
216 pub fn bit() -> Self {
218 Self {
219 type_id: 0x68, max_length: Some(1),
221 precision: None,
222 scale: None,
223 collation: None,
224 tvp_type_name: None,
225 }
226 }
227
228 pub fn float() -> Self {
230 Self {
231 type_id: 0x6D, max_length: Some(8),
233 precision: None,
234 scale: None,
235 collation: None,
236 tvp_type_name: None,
237 }
238 }
239
240 pub fn real() -> Self {
242 Self {
243 type_id: 0x6D, max_length: Some(4),
245 precision: None,
246 scale: None,
247 collation: None,
248 tvp_type_name: None,
249 }
250 }
251
252 pub fn nvarchar(max_len: u16) -> Self {
254 Self {
255 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
258 scale: None,
259 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
261 tvp_type_name: None,
262 }
263 }
264
265 pub fn nvarchar_max() -> Self {
267 Self {
268 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
271 scale: None,
272 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
273 tvp_type_name: None,
274 }
275 }
276
277 pub fn varbinary(max_len: u16) -> Self {
279 Self {
280 type_id: 0xA5, max_length: Some(max_len),
282 precision: None,
283 scale: None,
284 collation: None,
285 tvp_type_name: None,
286 }
287 }
288
289 pub fn uniqueidentifier() -> Self {
291 Self {
292 type_id: 0x24, max_length: Some(16),
294 precision: None,
295 scale: None,
296 collation: None,
297 tvp_type_name: None,
298 }
299 }
300
301 pub fn date() -> Self {
303 Self {
304 type_id: 0x28, max_length: None,
306 precision: None,
307 scale: None,
308 collation: None,
309 tvp_type_name: None,
310 }
311 }
312
313 pub fn datetime2(scale: u8) -> Self {
315 Self {
316 type_id: 0x2A, max_length: None,
318 precision: None,
319 scale: Some(scale),
320 collation: None,
321 tvp_type_name: None,
322 }
323 }
324
325 pub fn decimal(precision: u8, scale: u8) -> Self {
327 Self {
328 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
331 scale: Some(scale),
332 collation: None,
333 tvp_type_name: None,
334 }
335 }
336
337 pub fn tvp(type_name: impl Into<String>) -> Self {
342 Self {
343 type_id: 0xF3, max_length: None,
345 precision: None,
346 scale: None,
347 collation: None,
348 tvp_type_name: Some(type_name.into()),
349 }
350 }
351
352 pub fn encode(&self, buf: &mut BytesMut) {
354 if self.type_id != 0xF3 {
357 buf.put_u8(self.type_id);
358 }
359
360 match self.type_id {
362 0x26 | 0x68 | 0x6D => {
363 if let Some(len) = self.max_length {
365 buf.put_u8(len as u8);
366 }
367 }
368 0xE7 | 0xA5 | 0xEF => {
369 if let Some(len) = self.max_length {
371 buf.put_u16_le(len);
372 }
373 if let Some(collation) = self.collation {
375 buf.put_slice(&collation);
376 }
377 }
378 0x24 => {
379 if let Some(len) = self.max_length {
381 buf.put_u8(len as u8);
382 }
383 }
384 0x29..=0x2B => {
385 if let Some(scale) = self.scale {
387 buf.put_u8(scale);
388 }
389 }
390 0x6C | 0x6A => {
391 if let Some(len) = self.max_length {
393 buf.put_u8(len as u8);
394 }
395 if let Some(precision) = self.precision {
396 buf.put_u8(precision);
397 }
398 if let Some(scale) = self.scale {
399 buf.put_u8(scale);
400 }
401 }
402 _ => {}
403 }
404 }
405}
406
407#[derive(Debug, Clone)]
409pub struct RpcParam {
410 pub name: String,
412 pub flags: ParamFlags,
414 pub type_info: TypeInfo,
416 pub value: Option<Bytes>,
418}
419
420impl RpcParam {
421 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
423 Self {
424 name: name.into(),
425 flags: ParamFlags::default(),
426 type_info,
427 value: Some(value),
428 }
429 }
430
431 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
433 Self {
434 name: name.into(),
435 flags: ParamFlags::default(),
436 type_info,
437 value: None,
438 }
439 }
440
441 pub fn int(name: impl Into<String>, value: i32) -> Self {
443 let mut buf = BytesMut::with_capacity(4);
444 buf.put_i32_le(value);
445 Self::new(name, TypeInfo::int(), buf.freeze())
446 }
447
448 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
450 let mut buf = BytesMut::with_capacity(8);
451 buf.put_i64_le(value);
452 Self::new(name, TypeInfo::bigint(), buf.freeze())
453 }
454
455 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
457 let mut buf = BytesMut::new();
458 for code_unit in value.encode_utf16() {
460 buf.put_u16_le(code_unit);
461 }
462 let char_len = value.chars().count();
463 let type_info = if char_len > 4000 {
464 TypeInfo::nvarchar_max()
465 } else {
466 TypeInfo::nvarchar(char_len.max(1) as u16)
467 };
468 Self::new(name, type_info, buf.freeze())
469 }
470
471 #[must_use]
473 pub fn as_output(mut self) -> Self {
474 self.flags = self.flags.output();
475 self
476 }
477
478 pub fn encode(&self, buf: &mut BytesMut) {
480 let name_len = self.name.encode_utf16().count() as u8;
482 buf.put_u8(name_len);
483 if name_len > 0 {
484 for code_unit in self.name.encode_utf16() {
485 buf.put_u16_le(code_unit);
486 }
487 }
488
489 buf.put_u8(self.flags.encode());
491
492 self.type_info.encode(buf);
494
495 if let Some(ref value) = self.value {
497 match self.type_info.type_id {
499 0x26 => {
500 buf.put_u8(value.len() as u8);
502 buf.put_slice(value);
503 }
504 0x68 | 0x6D => {
505 buf.put_u8(value.len() as u8);
507 buf.put_slice(value);
508 }
509 0xE7 | 0xA5 => {
510 if self.type_info.max_length == Some(0xFFFF) {
512 let total_len = value.len() as u64;
515 buf.put_u64_le(total_len);
516 buf.put_u32_le(value.len() as u32);
517 buf.put_slice(value);
518 buf.put_u32_le(0); } else {
520 buf.put_u16_le(value.len() as u16);
521 buf.put_slice(value);
522 }
523 }
524 0x24 => {
525 buf.put_u8(value.len() as u8);
527 buf.put_slice(value);
528 }
529 0x28 => {
530 buf.put_slice(value);
532 }
533 0x2A => {
534 buf.put_u8(value.len() as u8);
536 buf.put_slice(value);
537 }
538 0x6C => {
539 buf.put_u8(value.len() as u8);
541 buf.put_slice(value);
542 }
543 0xF3 => {
544 buf.put_slice(value);
548 }
549 _ => {
550 buf.put_u8(value.len() as u8);
552 buf.put_slice(value);
553 }
554 }
555 } else {
556 match self.type_info.type_id {
558 0xE7 | 0xA5 => {
559 if self.type_info.max_length == Some(0xFFFF) {
561 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
563 buf.put_u16_le(0xFFFF);
564 }
565 }
566 _ => {
567 buf.put_u8(0); }
569 }
570 }
571 }
572}
573
574#[derive(Debug, Clone)]
576pub struct RpcRequest {
577 proc_name: Option<String>,
579 proc_id: Option<ProcId>,
581 options: RpcOptionFlags,
583 params: Vec<RpcParam>,
585}
586
587impl RpcRequest {
588 pub fn named(proc_name: impl Into<String>) -> Self {
590 Self {
591 proc_name: Some(proc_name.into()),
592 proc_id: None,
593 options: RpcOptionFlags::default(),
594 params: Vec::new(),
595 }
596 }
597
598 pub fn by_id(proc_id: ProcId) -> Self {
600 Self {
601 proc_name: None,
602 proc_id: Some(proc_id),
603 options: RpcOptionFlags::default(),
604 params: Vec::new(),
605 }
606 }
607
608 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
626 let mut request = Self::by_id(ProcId::ExecuteSql);
627
628 request.params.push(RpcParam::nvarchar("", sql));
630
631 if !params.is_empty() {
633 let declarations = Self::build_param_declarations(¶ms);
634 request.params.push(RpcParam::nvarchar("", &declarations));
635 }
636
637 request.params.extend(params);
639
640 request
641 }
642
643 fn build_param_declarations(params: &[RpcParam]) -> String {
645 params
646 .iter()
647 .map(|p| {
648 let name = if p.name.starts_with('@') {
649 p.name.clone()
650 } else if p.name.is_empty() {
651 format!(
653 "@p{}",
654 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
655 )
656 } else {
657 format!("@{}", p.name)
658 };
659
660 let type_name: String = match p.type_info.type_id {
661 0x26 => match p.type_info.max_length {
662 Some(1) => "tinyint".to_string(),
663 Some(2) => "smallint".to_string(),
664 Some(4) => "int".to_string(),
665 Some(8) => "bigint".to_string(),
666 _ => "int".to_string(),
667 },
668 0x68 => "bit".to_string(),
669 0x6D => match p.type_info.max_length {
670 Some(4) => "real".to_string(),
671 _ => "float".to_string(),
672 },
673 0xE7 => {
674 if p.type_info.max_length == Some(0xFFFF) {
675 "nvarchar(max)".to_string()
676 } else {
677 let len = p.type_info.max_length.unwrap_or(4000) / 2;
678 format!("nvarchar({})", len)
679 }
680 }
681 0xA5 => {
682 if p.type_info.max_length == Some(0xFFFF) {
683 "varbinary(max)".to_string()
684 } else {
685 let len = p.type_info.max_length.unwrap_or(8000);
686 format!("varbinary({})", len)
687 }
688 }
689 0x24 => "uniqueidentifier".to_string(),
690 0x28 => "date".to_string(),
691 0x2A => {
692 let scale = p.type_info.scale.unwrap_or(7);
693 format!("datetime2({})", scale)
694 }
695 0x6C => {
696 let precision = p.type_info.precision.unwrap_or(18);
697 let scale = p.type_info.scale.unwrap_or(0);
698 format!("decimal({}, {})", precision, scale)
699 }
700 0xF3 => {
701 if let Some(ref tvp_name) = p.type_info.tvp_type_name {
704 format!("{} READONLY", tvp_name)
705 } else {
706 "sql_variant".to_string()
708 }
709 }
710 _ => "sql_variant".to_string(),
711 };
712
713 format!("{} {}", name, type_name)
714 })
715 .collect::<Vec<_>>()
716 .join(", ")
717 }
718
719 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
721 let mut request = Self::by_id(ProcId::Prepare);
722
723 request
725 .params
726 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
727
728 let declarations = Self::build_param_declarations(params);
730 request
731 .params
732 .push(RpcParam::nvarchar("@params", &declarations));
733
734 request.params.push(RpcParam::nvarchar("@stmt", sql));
736
737 request.params.push(RpcParam::int("@options", 1));
739
740 request
741 }
742
743 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
745 let mut request = Self::by_id(ProcId::Execute);
746
747 request.params.push(RpcParam::int("@handle", handle));
749
750 request.params.extend(params);
752
753 request
754 }
755
756 pub fn unprepare(handle: i32) -> Self {
758 let mut request = Self::by_id(ProcId::Unprepare);
759 request.params.push(RpcParam::int("@handle", handle));
760 request
761 }
762
763 #[must_use]
765 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
766 self.options = options;
767 self
768 }
769
770 #[must_use]
772 pub fn param(mut self, param: RpcParam) -> Self {
773 self.params.push(param);
774 self
775 }
776
777 #[must_use]
781 pub fn encode(&self) -> Bytes {
782 self.encode_with_transaction(0)
783 }
784
785 #[must_use]
797 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
798 let mut buf = BytesMut::with_capacity(256);
799
800 let all_headers_start = buf.len();
803 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
814 let len_bytes = (all_headers_len as u32).to_le_bytes();
815 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
816
817 if let Some(proc_id) = self.proc_id {
819 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
822 } else if let Some(ref proc_name) = self.proc_name {
823 let name_len = proc_name.encode_utf16().count() as u16;
825 buf.put_u16_le(name_len);
826 write_utf16_string(&mut buf, proc_name);
827 }
828
829 buf.put_u16_le(self.options.encode());
831
832 for param in &self.params {
834 param.encode(&mut buf);
835 }
836
837 buf.freeze()
838 }
839}
840
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_proc_id_values() {
        assert_eq!(ProcId::ExecuteSql as u16, 0x000A);
        assert_eq!(ProcId::Prepare as u16, 0x000B);
        assert_eq!(ProcId::Execute as u16, 0x000C);
        assert_eq!(ProcId::Unprepare as u16, 0x000F);
    }

    #[test]
    fn test_option_flags_encode() {
        let flags = RpcOptionFlags::new().with_recompile(true);
        assert_eq!(flags.encode(), 0x0001);
    }

    #[test]
    fn test_option_flags_encode_all_bits() {
        let flags = RpcOptionFlags {
            with_recompile: true,
            no_metadata: true,
            reuse_metadata: true,
        };
        assert_eq!(flags.encode(), 0x0007);
    }

    #[test]
    fn test_param_flags_encode() {
        let flags = ParamFlags::new().output();
        assert_eq!(flags.encode(), 0x01);
    }

    #[test]
    fn test_param_flags_encode_all_bits() {
        let flags = ParamFlags {
            by_ref: true,
            default: true,
            encrypted: true,
        };
        assert_eq!(flags.encode(), 0x0B);
    }

    #[test]
    fn test_int_param() {
        let param = RpcParam::int("@p1", 42);
        assert_eq!(param.name, "@p1");
        assert_eq!(param.type_info.type_id, 0x26);
        assert!(param.value.is_some());
    }

    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        assert_eq!(param.value.as_ref().unwrap().len(), 10);
    }

    #[test]
    fn test_type_info_encode() {
        let mut buf = BytesMut::new();
        TypeInfo::int().encode(&mut buf);
        assert_eq!(&buf[..], &[0x26, 0x04][..]);

        let mut buf = BytesMut::new();
        TypeInfo::decimal(18, 2).encode(&mut buf);
        assert_eq!(&buf[..], &[0x6C, 17, 18, 2][..]);

        let mut buf = BytesMut::new();
        TypeInfo::datetime2(7).encode(&mut buf);
        assert_eq!(&buf[..], &[0x2A, 7][..]);
    }

    #[test]
    fn test_int_param_wire_format() {
        let mut buf = BytesMut::new();
        RpcParam::int("@p1", 42).encode(&mut buf);
        let expected: &[u8] = &[
            0x03, b'@', 0x00, b'p', 0x00, b'1', 0x00, // name: length + UTF-16LE
            0x00, // status flags
            0x26, 0x04, // INTN, max length 4
            0x04, 0x2A, 0x00, 0x00, 0x00, // length-prefixed value
        ];
        assert_eq!(&buf[..], expected);
    }

    #[test]
    fn test_null_nvarchar_param_wire_format() {
        let mut buf = BytesMut::new();
        RpcParam::null("", TypeInfo::nvarchar(10)).encode(&mut buf);
        let expected: &[u8] = &[
            0x00, // empty name
            0x00, // status flags
            0xE7, 0x14, 0x00, // NVARCHAR, max length 20 bytes
            0x09, 0x04, 0xD0, 0x00, 0x34, // collation
            0xFF, 0xFF, // two-byte NULL marker
        ];
        assert_eq!(&buf[..], expected);
    }

    #[test]
    fn test_execute_sql_request() {
        let rpc = RpcRequest::execute_sql(
            "SELECT * FROM users WHERE id = @p1",
            vec![RpcParam::int("@p1", 42)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        assert_eq!(rpc.params.len(), 3);
    }

    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        assert!(decls.contains("@p1 int"));
        assert!(decls.contains("@name nvarchar"));
    }

    #[test]
    fn test_rpc_encode_not_empty() {
        let rpc = RpcRequest::execute_sql("SELECT 1", vec![]);
        let encoded = rpc.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_all_headers_and_proc_id_layout() {
        let encoded = RpcRequest::by_id(ProcId::Execute).encode_with_transaction(7);
        let expected: &[u8] = &[
            22, 0, 0, 0, // ALL_HEADERS total length
            18, 0, 0, 0, // header length
            0x02, 0x00, // transaction descriptor header type
            7, 0, 0, 0, 0, 0, 0, 0, // transaction descriptor
            1, 0, 0, 0, // outstanding request count
            0xFF, 0xFF, 0x0C, 0x00, // proc id marker + sp_execute
            0x00, 0x00, // option flags
        ];
        assert_eq!(&encoded[..], expected);
    }

    #[test]
    fn test_prepare_request() {
        let rpc = RpcRequest::prepare(
            "SELECT * FROM users WHERE id = @p1",
            &[RpcParam::int("@p1", 0)],
        );

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        assert_eq!(rpc.params.len(), 4);
        assert!(rpc.params[0].flags.by_ref);
    }

    #[test]
    fn test_execute_request() {
        let rpc = RpcRequest::execute(123, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        assert_eq!(rpc.params.len(), 2);
    }

    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        assert_eq!(rpc.params.len(), 1);
    }
}