1use bytes::{BufMut, Bytes, BytesMut};
27
28use crate::codec::write_utf16_string;
29
/// Numeric identifiers of the well-known stored procedures an RPC request can
/// call by id instead of by name (for example, `RpcRequest::execute_sql` uses
/// [`ProcId::ExecuteSql`]).
///
/// The discriminant of each variant is the 16-bit value written on the wire
/// (`variant as u16`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum ProcId {
    Cursor = 1,
    CursorOpen = 2,
    CursorPrepare = 3,
    CursorExecute = 4,
    CursorPrepExec = 5,
    CursorUnprepare = 6,
    CursorFetch = 7,
    CursorOption = 8,
    CursorClose = 9,
    ExecuteSql = 10,
    Prepare = 11,
    Execute = 12,
    PrepExec = 13,
    PrepExecRpc = 14,
    Unprepare = 15,
}
68
/// Request-level option bits of an RPC call.
///
/// [`RpcOptionFlags::encode`] packs the three booleans into a `u16` bit mask.
#[derive(Debug, Clone, Copy, Default)]
pub struct RpcOptionFlags {
    /// Encoded as bit `0x0001`.
    pub with_recompile: bool,
    /// Encoded as bit `0x0002`.
    pub no_metadata: bool,
    /// Encoded as bit `0x0004`.
    pub reuse_metadata: bool,
}

impl RpcOptionFlags {
    /// Creates a flag set with every option switched off.
    pub fn new() -> Self {
        Self {
            with_recompile: false,
            no_metadata: false,
            reuse_metadata: false,
        }
    }

    /// Returns a copy of the flags with `with_recompile` set to `value`.
    #[must_use]
    pub fn with_recompile(self, value: bool) -> Self {
        Self {
            with_recompile: value,
            ..self
        }
    }

    /// Packs the flags into their 16-bit representation.
    ///
    /// Each boolean contributes a single bit: `with_recompile` is bit 0,
    /// `no_metadata` is bit 1 and `reuse_metadata` is bit 2.
    pub fn encode(&self) -> u16 {
        let recompile_bit = u16::from(self.with_recompile);
        let no_metadata_bit = u16::from(self.no_metadata) << 1;
        let reuse_metadata_bit = u16::from(self.reuse_metadata) << 2;
        recompile_bit | no_metadata_bit | reuse_metadata_bit
    }
}
108
/// Per-parameter status bits of an RPC parameter.
///
/// [`ParamFlags::encode`] packs the three booleans into a single byte.
#[derive(Debug, Clone, Copy, Default)]
pub struct ParamFlags {
    /// Encoded as bit `0x01`; set by [`ParamFlags::output`].
    pub by_ref: bool,
    /// Encoded as bit `0x02`.
    pub default: bool,
    /// Encoded as bit `0x08`.
    pub encrypted: bool,
}

impl ParamFlags {
    /// Creates a flag set with every bit cleared.
    pub fn new() -> Self {
        Self {
            by_ref: false,
            default: false,
            encrypted: false,
        }
    }

    /// Returns a copy of the flags with `by_ref` switched on.
    #[must_use]
    pub fn output(self) -> Self {
        Self {
            by_ref: true,
            ..self
        }
    }

    /// Packs the flags into their one-byte representation.
    ///
    /// Note that bit 2 (`0x04`) is never produced: `encrypted` maps to bit 3.
    pub fn encode(&self) -> u8 {
        let by_ref_bit = u8::from(self.by_ref);
        let default_bit = u8::from(self.default) << 1;
        let encrypted_bit = u8::from(self.encrypted) << 3;
        by_ref_bit | default_bit | encrypted_bit
    }
}
148
149#[derive(Debug, Clone)]
151pub struct TypeInfo {
152 pub type_id: u8,
154 pub max_length: Option<u16>,
156 pub precision: Option<u8>,
158 pub scale: Option<u8>,
160 pub collation: Option<[u8; 5]>,
162}
163
164impl TypeInfo {
165 pub fn int() -> Self {
167 Self {
168 type_id: 0x26, max_length: Some(4),
170 precision: None,
171 scale: None,
172 collation: None,
173 }
174 }
175
176 pub fn bigint() -> Self {
178 Self {
179 type_id: 0x26, max_length: Some(8),
181 precision: None,
182 scale: None,
183 collation: None,
184 }
185 }
186
187 pub fn smallint() -> Self {
189 Self {
190 type_id: 0x26, max_length: Some(2),
192 precision: None,
193 scale: None,
194 collation: None,
195 }
196 }
197
198 pub fn tinyint() -> Self {
200 Self {
201 type_id: 0x26, max_length: Some(1),
203 precision: None,
204 scale: None,
205 collation: None,
206 }
207 }
208
209 pub fn bit() -> Self {
211 Self {
212 type_id: 0x68, max_length: Some(1),
214 precision: None,
215 scale: None,
216 collation: None,
217 }
218 }
219
220 pub fn float() -> Self {
222 Self {
223 type_id: 0x6D, max_length: Some(8),
225 precision: None,
226 scale: None,
227 collation: None,
228 }
229 }
230
231 pub fn real() -> Self {
233 Self {
234 type_id: 0x6D, max_length: Some(4),
236 precision: None,
237 scale: None,
238 collation: None,
239 }
240 }
241
242 pub fn nvarchar(max_len: u16) -> Self {
244 Self {
245 type_id: 0xE7, max_length: Some(max_len * 2), precision: None,
248 scale: None,
249 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
251 }
252 }
253
254 pub fn nvarchar_max() -> Self {
256 Self {
257 type_id: 0xE7, max_length: Some(0xFFFF), precision: None,
260 scale: None,
261 collation: Some([0x09, 0x04, 0xD0, 0x00, 0x34]),
262 }
263 }
264
265 pub fn varbinary(max_len: u16) -> Self {
267 Self {
268 type_id: 0xA5, max_length: Some(max_len),
270 precision: None,
271 scale: None,
272 collation: None,
273 }
274 }
275
276 pub fn uniqueidentifier() -> Self {
278 Self {
279 type_id: 0x24, max_length: Some(16),
281 precision: None,
282 scale: None,
283 collation: None,
284 }
285 }
286
287 pub fn date() -> Self {
289 Self {
290 type_id: 0x28, max_length: None,
292 precision: None,
293 scale: None,
294 collation: None,
295 }
296 }
297
298 pub fn datetime2(scale: u8) -> Self {
300 Self {
301 type_id: 0x2A, max_length: None,
303 precision: None,
304 scale: Some(scale),
305 collation: None,
306 }
307 }
308
309 pub fn decimal(precision: u8, scale: u8) -> Self {
311 Self {
312 type_id: 0x6C, max_length: Some(17), precision: Some(precision),
315 scale: Some(scale),
316 collation: None,
317 }
318 }
319
320 pub fn encode(&self, buf: &mut BytesMut) {
322 buf.put_u8(self.type_id);
323
324 match self.type_id {
326 0x26 | 0x68 | 0x6D => {
327 if let Some(len) = self.max_length {
329 buf.put_u8(len as u8);
330 }
331 }
332 0xE7 | 0xA5 | 0xEF => {
333 if let Some(len) = self.max_length {
335 buf.put_u16_le(len);
336 }
337 if let Some(collation) = self.collation {
339 buf.put_slice(&collation);
340 }
341 }
342 0x24 => {
343 if let Some(len) = self.max_length {
345 buf.put_u8(len as u8);
346 }
347 }
348 0x29..=0x2B => {
349 if let Some(scale) = self.scale {
351 buf.put_u8(scale);
352 }
353 }
354 0x6C | 0x6A => {
355 if let Some(len) = self.max_length {
357 buf.put_u8(len as u8);
358 }
359 if let Some(precision) = self.precision {
360 buf.put_u8(precision);
361 }
362 if let Some(scale) = self.scale {
363 buf.put_u8(scale);
364 }
365 }
366 _ => {}
367 }
368 }
369}
370
371#[derive(Debug, Clone)]
373pub struct RpcParam {
374 pub name: String,
376 pub flags: ParamFlags,
378 pub type_info: TypeInfo,
380 pub value: Option<Bytes>,
382}
383
384impl RpcParam {
385 pub fn new(name: impl Into<String>, type_info: TypeInfo, value: Bytes) -> Self {
387 Self {
388 name: name.into(),
389 flags: ParamFlags::default(),
390 type_info,
391 value: Some(value),
392 }
393 }
394
395 pub fn null(name: impl Into<String>, type_info: TypeInfo) -> Self {
397 Self {
398 name: name.into(),
399 flags: ParamFlags::default(),
400 type_info,
401 value: None,
402 }
403 }
404
405 pub fn int(name: impl Into<String>, value: i32) -> Self {
407 let mut buf = BytesMut::with_capacity(4);
408 buf.put_i32_le(value);
409 Self::new(name, TypeInfo::int(), buf.freeze())
410 }
411
412 pub fn bigint(name: impl Into<String>, value: i64) -> Self {
414 let mut buf = BytesMut::with_capacity(8);
415 buf.put_i64_le(value);
416 Self::new(name, TypeInfo::bigint(), buf.freeze())
417 }
418
419 pub fn nvarchar(name: impl Into<String>, value: &str) -> Self {
421 let mut buf = BytesMut::new();
422 for code_unit in value.encode_utf16() {
424 buf.put_u16_le(code_unit);
425 }
426 let char_len = value.chars().count();
427 let type_info = if char_len > 4000 {
428 TypeInfo::nvarchar_max()
429 } else {
430 TypeInfo::nvarchar(char_len.max(1) as u16)
431 };
432 Self::new(name, type_info, buf.freeze())
433 }
434
435 #[must_use]
437 pub fn as_output(mut self) -> Self {
438 self.flags = self.flags.output();
439 self
440 }
441
442 pub fn encode(&self, buf: &mut BytesMut) {
444 let name_len = self.name.encode_utf16().count() as u8;
446 buf.put_u8(name_len);
447 if name_len > 0 {
448 for code_unit in self.name.encode_utf16() {
449 buf.put_u16_le(code_unit);
450 }
451 }
452
453 buf.put_u8(self.flags.encode());
455
456 self.type_info.encode(buf);
458
459 if let Some(ref value) = self.value {
461 match self.type_info.type_id {
463 0x26 => {
464 buf.put_u8(value.len() as u8);
466 buf.put_slice(value);
467 }
468 0x68 | 0x6D => {
469 buf.put_u8(value.len() as u8);
471 buf.put_slice(value);
472 }
473 0xE7 | 0xA5 => {
474 if self.type_info.max_length == Some(0xFFFF) {
476 let total_len = value.len() as u64;
479 buf.put_u64_le(total_len);
480 buf.put_u32_le(value.len() as u32);
481 buf.put_slice(value);
482 buf.put_u32_le(0); } else {
484 buf.put_u16_le(value.len() as u16);
485 buf.put_slice(value);
486 }
487 }
488 0x24 => {
489 buf.put_u8(value.len() as u8);
491 buf.put_slice(value);
492 }
493 0x28 => {
494 buf.put_slice(value);
496 }
497 0x2A => {
498 buf.put_u8(value.len() as u8);
500 buf.put_slice(value);
501 }
502 0x6C => {
503 buf.put_u8(value.len() as u8);
505 buf.put_slice(value);
506 }
507 _ => {
508 buf.put_u8(value.len() as u8);
510 buf.put_slice(value);
511 }
512 }
513 } else {
514 match self.type_info.type_id {
516 0xE7 | 0xA5 => {
517 if self.type_info.max_length == Some(0xFFFF) {
519 buf.put_u64_le(0xFFFFFFFFFFFFFFFF); } else {
521 buf.put_u16_le(0xFFFF);
522 }
523 }
524 _ => {
525 buf.put_u8(0); }
527 }
528 }
529 }
530}
531
532#[derive(Debug, Clone)]
534pub struct RpcRequest {
535 proc_name: Option<String>,
537 proc_id: Option<ProcId>,
539 options: RpcOptionFlags,
541 params: Vec<RpcParam>,
543}
544
545impl RpcRequest {
546 pub fn named(proc_name: impl Into<String>) -> Self {
548 Self {
549 proc_name: Some(proc_name.into()),
550 proc_id: None,
551 options: RpcOptionFlags::default(),
552 params: Vec::new(),
553 }
554 }
555
556 pub fn by_id(proc_id: ProcId) -> Self {
558 Self {
559 proc_name: None,
560 proc_id: Some(proc_id),
561 options: RpcOptionFlags::default(),
562 params: Vec::new(),
563 }
564 }
565
566 pub fn execute_sql(sql: &str, params: Vec<RpcParam>) -> Self {
584 let mut request = Self::by_id(ProcId::ExecuteSql);
585
586 request.params.push(RpcParam::nvarchar("", sql));
588
589 if !params.is_empty() {
591 let declarations = Self::build_param_declarations(¶ms);
592 request.params.push(RpcParam::nvarchar("", &declarations));
593 }
594
595 request.params.extend(params);
597
598 request
599 }
600
601 fn build_param_declarations(params: &[RpcParam]) -> String {
603 params
604 .iter()
605 .map(|p| {
606 let name = if p.name.starts_with('@') {
607 p.name.clone()
608 } else if p.name.is_empty() {
609 format!(
611 "@p{}",
612 params.iter().position(|x| x.name == p.name).unwrap_or(0) + 1
613 )
614 } else {
615 format!("@{}", p.name)
616 };
617
618 let type_name: String = match p.type_info.type_id {
619 0x26 => match p.type_info.max_length {
620 Some(1) => "tinyint".to_string(),
621 Some(2) => "smallint".to_string(),
622 Some(4) => "int".to_string(),
623 Some(8) => "bigint".to_string(),
624 _ => "int".to_string(),
625 },
626 0x68 => "bit".to_string(),
627 0x6D => match p.type_info.max_length {
628 Some(4) => "real".to_string(),
629 _ => "float".to_string(),
630 },
631 0xE7 => {
632 if p.type_info.max_length == Some(0xFFFF) {
633 "nvarchar(max)".to_string()
634 } else {
635 let len = p.type_info.max_length.unwrap_or(4000) / 2;
636 format!("nvarchar({})", len)
637 }
638 }
639 0xA5 => {
640 if p.type_info.max_length == Some(0xFFFF) {
641 "varbinary(max)".to_string()
642 } else {
643 let len = p.type_info.max_length.unwrap_or(8000);
644 format!("varbinary({})", len)
645 }
646 }
647 0x24 => "uniqueidentifier".to_string(),
648 0x28 => "date".to_string(),
649 0x2A => {
650 let scale = p.type_info.scale.unwrap_or(7);
651 format!("datetime2({})", scale)
652 }
653 0x6C => {
654 let precision = p.type_info.precision.unwrap_or(18);
655 let scale = p.type_info.scale.unwrap_or(0);
656 format!("decimal({}, {})", precision, scale)
657 }
658 _ => "sql_variant".to_string(),
659 };
660
661 format!("{} {}", name, type_name)
662 })
663 .collect::<Vec<_>>()
664 .join(", ")
665 }
666
667 pub fn prepare(sql: &str, params: &[RpcParam]) -> Self {
669 let mut request = Self::by_id(ProcId::Prepare);
670
671 request
673 .params
674 .push(RpcParam::null("@handle", TypeInfo::int()).as_output());
675
676 let declarations = Self::build_param_declarations(params);
678 request
679 .params
680 .push(RpcParam::nvarchar("@params", &declarations));
681
682 request.params.push(RpcParam::nvarchar("@stmt", sql));
684
685 request.params.push(RpcParam::int("@options", 1));
687
688 request
689 }
690
691 pub fn execute(handle: i32, params: Vec<RpcParam>) -> Self {
693 let mut request = Self::by_id(ProcId::Execute);
694
695 request.params.push(RpcParam::int("@handle", handle));
697
698 request.params.extend(params);
700
701 request
702 }
703
704 pub fn unprepare(handle: i32) -> Self {
706 let mut request = Self::by_id(ProcId::Unprepare);
707 request.params.push(RpcParam::int("@handle", handle));
708 request
709 }
710
711 #[must_use]
713 pub fn with_options(mut self, options: RpcOptionFlags) -> Self {
714 self.options = options;
715 self
716 }
717
718 #[must_use]
720 pub fn param(mut self, param: RpcParam) -> Self {
721 self.params.push(param);
722 self
723 }
724
725 #[must_use]
729 pub fn encode(&self) -> Bytes {
730 self.encode_with_transaction(0)
731 }
732
733 #[must_use]
745 pub fn encode_with_transaction(&self, transaction_descriptor: u64) -> Bytes {
746 let mut buf = BytesMut::with_capacity(256);
747
748 let all_headers_start = buf.len();
751 buf.put_u32_le(0); buf.put_u32_le(18); buf.put_u16_le(0x0002); buf.put_u64_le(transaction_descriptor); buf.put_u32_le(1); let all_headers_len = buf.len() - all_headers_start;
762 let len_bytes = (all_headers_len as u32).to_le_bytes();
763 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
764
765 if let Some(proc_id) = self.proc_id {
767 buf.put_u16_le(0xFFFF); buf.put_u16_le(proc_id as u16);
770 } else if let Some(ref proc_name) = self.proc_name {
771 let name_len = proc_name.encode_utf16().count() as u16;
773 buf.put_u16_le(name_len);
774 write_utf16_string(&mut buf, proc_name);
775 }
776
777 buf.put_u16_le(self.options.encode());
779
780 for param in &self.params {
782 param.encode(&mut buf);
783 }
784
785 buf.freeze()
786 }
787}
788
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// The ids used by the request builders keep their wire values.
    #[test]
    fn test_proc_id_values() {
        let expected = [
            (ProcId::ExecuteSql, 0x000A_u16),
            (ProcId::Prepare, 0x000B),
            (ProcId::Execute, 0x000C),
            (ProcId::Unprepare, 0x000F),
        ];
        for &(id, raw) in expected.iter() {
            assert_eq!(id as u16, raw);
        }
    }

    /// `with_recompile` maps to the lowest option bit.
    #[test]
    fn test_option_flags_encode() {
        let encoded = RpcOptionFlags::new().with_recompile(true).encode();
        assert_eq!(encoded, 0x0001);
    }

    /// An output parameter sets the lowest status bit.
    #[test]
    fn test_param_flags_encode() {
        let encoded = ParamFlags::new().output().encode();
        assert_eq!(encoded, 0x01);
    }

    /// `RpcParam::int` keeps the name, uses type `0x26` and carries a value.
    #[test]
    fn test_int_param() {
        let RpcParam {
            name,
            type_info,
            value,
            ..
        } = RpcParam::int("@p1", 42);
        assert_eq!(name, "@p1");
        assert_eq!(type_info.type_id, 0x26);
        assert!(value.is_some());
    }

    /// Five characters become ten bytes of UTF-16.
    #[test]
    fn test_nvarchar_param() {
        let param = RpcParam::nvarchar("@name", "Alice");
        assert_eq!(param.name, "@name");
        assert_eq!(param.type_info.type_id, 0xE7);
        let value = param.value.as_ref().unwrap();
        assert_eq!(value.len(), 10);
    }

    /// Statement + declarations + the one user parameter.
    #[test]
    fn test_execute_sql_request() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::execute_sql(sql, vec![RpcParam::int("@p1", 42)]);

        assert_eq!(rpc.proc_id, Some(ProcId::ExecuteSql));
        assert_eq!(rpc.params.len(), 3);
    }

    /// Declarations contain a `name type` entry for every parameter.
    #[test]
    fn test_param_declarations() {
        let params = vec![
            RpcParam::int("@p1", 42),
            RpcParam::nvarchar("@name", "Alice"),
        ];

        let decls = RpcRequest::build_param_declarations(&params);
        for fragment in ["@p1 int", "@name nvarchar"].iter() {
            assert!(decls.contains(fragment));
        }
    }

    /// Even a parameterless request produces output.
    #[test]
    fn test_rpc_encode_not_empty() {
        let encoded = RpcRequest::execute_sql("SELECT 1", vec![]).encode();
        assert!(!encoded.is_empty());
    }

    /// Prepare sends handle, declarations, statement and options.
    #[test]
    fn test_prepare_request() {
        let sql = "SELECT * FROM users WHERE id = @p1";
        let rpc = RpcRequest::prepare(sql, &[RpcParam::int("@p1", 0)]);

        assert_eq!(rpc.proc_id, Some(ProcId::Prepare));
        assert_eq!(rpc.params.len(), 4);
        // The handle comes back from the server, so it is an output parameter.
        assert!(rpc.params[0].flags.by_ref);
    }

    /// Execute sends the handle followed by the user parameters.
    #[test]
    fn test_execute_request() {
        let user_params = vec![RpcParam::int("@p1", 42)];
        let rpc = RpcRequest::execute(123, user_params);

        assert_eq!(rpc.proc_id, Some(ProcId::Execute));
        assert_eq!(rpc.params.len(), 2);
    }

    /// Unprepare sends only the handle.
    #[test]
    fn test_unprepare_request() {
        let rpc = RpcRequest::unprepare(123);

        assert_eq!(rpc.proc_id, Some(ProcId::Unprepare));
        assert_eq!(rpc.params.len(), 1);
    }
}